
Commit 869b31a

Propose a method for handling datasets which doesn't explicitly require TensorFlow (#21758)
* wip: initial draft of backend agnostic dataset handling logic
* localize the local imports
* update unit tests to support torch dataset return value
* update error msg
* correct typo
* formatting
* update docstring
* add docstrings to tensorflow and torch dataset handlers
* eliminate extra arg original_dataset
* ensure tests for torch dataset handler are run
* ensure we return tensorflow as the default handler
* Update keras/src/utils/dataset_utils.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update keras/src/utils/dataset_utils.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update keras/src/utils/dataset_utils.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* fix formatting from commit in github UI
* fix pydoc
* initial pass at a purely function approach
* fix pydocs
* fix tests
* fix docstring
* fix docstring
* ensure compatibility works for both tf <> torch and vice versa
* ensure compatibility works for both tf <> torch and vice versa
* simplify condition for identifying backend
* fix imports, simplify dataset check

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent f78fd8c commit 869b31a
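
For reference, a minimal usage sketch of the API this commit introduces, based on the new signature and docstring (the printed sizes assume the 80/20 split divides evenly, as in the docstring example):

    import numpy as np
    import keras

    data = np.random.random(size=(1000, 4))

    # Backend inferred from the input: a NumPy array falls back to the
    # current Keras backend, so TensorFlow users get tf.data.Dataset splits.
    left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)

    # Forcing the torch backend returns torch.utils.data.Dataset objects,
    # which support len() rather than .cardinality().
    left_t, right_t = keras.utils.split_dataset(
        data, left_size=0.8, preferred_backend="torch"
    )
    print(len(left_t), len(right_t))  # 800 200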

File tree

2 files changed (+307, -52 lines)


keras/src/utils/dataset_utils.py

Lines changed: 230 additions & 35 deletions
@@ -6,17 +6,22 @@
 
 import numpy as np
 
+from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
 from keras.src.utils.module_utils import grain
-from keras.src.utils.module_utils import tensorflow as tf
 
 
 @keras_export("keras.utils.split_dataset")
 def split_dataset(
-    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+    dataset,
+    left_size=None,
+    right_size=None,
+    shuffle=False,
+    seed=None,
+    preferred_backend=None,
 ):
     """Splits a dataset into a left half and a right half (e.g. train / test).
 
@@ -37,27 +42,86 @@ def split_dataset(
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
+        preferred_backend: String, specifying which backend
+            (e.g. "tensorflow", "torch") to use. If `None`, the
+            backend is inferred from the type of `dataset` - if
+            `dataset` is a `tf.data.Dataset`, the "tensorflow" backend
+            is used, if `dataset` is a `torch.utils.data.Dataset`,
+            the "torch" backend is used, and if `dataset` is a list/tuple/np.array,
+            the current Keras backend is used. Defaults to `None`.
 
     Returns:
-        A tuple of two `tf.data.Dataset` objects:
-        the left and right splits.
-
+        A tuple of two dataset objects, the left and right splits. The exact
+        type of the returned objects depends on the `preferred_backend`.
+        For example, with a "tensorflow" backend,
+        `tf.data.Dataset` objects are returned. With a "torch" backend,
+        `torch.utils.data.Dataset` objects are returned.
     Example:
 
     >>> data = np.random.random(size=(1000, 4))
     >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
-    >>> int(left_ds.cardinality())
-    800
-    >>> int(right_ds.cardinality())
-    200
+    >>> # For a tf.data.Dataset, you can use .cardinality()
+    >>> # >>> int(left_ds.cardinality())
+    >>> # 800
+    >>> # For a torch.utils.data.Dataset, you can use len()
+    >>> # >>> len(left_ds)
+    >>> # 800
+    """
+    preferred_backend = preferred_backend or _infer_preferred_backend(dataset)
+    if preferred_backend != "torch":
+        return _split_dataset_tf(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+    else:
+        return _split_dataset_torch(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+
+
+def _split_dataset_tf(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `tf.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer,
+            it signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `tf.data.Dataset` objects:
+        the left and right splits.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
-            "object, or a list/tuple of arrays. "
+            "a `tf.data.Dataset` object, or"
+            "a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )
 
@@ -106,6 +170,103 @@ def split_dataset(
     return left_split, right_split
 
 
+def _split_dataset_torch(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `torch.utils.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer,
+            it signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `torch.utils.data.Dataset` objects:
+        the left and right splits.
+    """
+    import torch
+    from torch.utils.data import TensorDataset
+    from torch.utils.data import random_split
+
+    dataset_type_spec = _get_type_spec(dataset)
+    if dataset_type_spec is None:
+        raise TypeError(
+            "The `dataset` argument must be a `torch.utils.data.Dataset`"
+            " object, or a list/tuple of arrays."
+            f" Received: dataset={dataset} of type {type(dataset)}"
+        )
+
+    if not isinstance(dataset, torch.utils.data.Dataset):
+        if dataset_type_spec is np.ndarray:
+            dataset = TensorDataset(torch.from_numpy(dataset))
+        elif dataset_type_spec in (list, tuple):
+            tensors = [torch.from_numpy(x) for x in dataset]
+            dataset = TensorDataset(*tensors)
+        elif is_tf_dataset(dataset):
+            dataset_as_list = _convert_dataset_to_list(
+                dataset, dataset_type_spec
+            )
+            tensors = [
+                torch.from_numpy(np.array(sample))
+                for sample in zip(*dataset_as_list)
+            ]
+            dataset = TensorDataset(*tensors)
+
+    if right_size is None and left_size is None:
+        raise ValueError(
+            "At least one of the `left_size` or `right_size` "
+            "must be specified. "
+            "Received: left_size=None and right_size=None"
+        )
+
+    # Calculate total length and rescale split sizes
+    total_length = len(dataset)
+    left_size, right_size = _rescale_dataset_split_sizes(
+        left_size, right_size, total_length
+    )
+
+    # Shuffle the dataset if required
+    if shuffle:
+        generator = torch.Generator()
+        if seed is not None:
+            generator.manual_seed(seed)
+        else:
+            generator.seed()
+    else:
+        generator = None
+
+    left_split, right_split = random_split(
+        dataset, [left_size, right_size], generator=generator
+    )
+
+    return left_split, right_split
+
+
+def _infer_preferred_backend(dataset):
+    """Infer the backend from the dataset type."""
+    if isinstance(dataset, (list, tuple, np.ndarray)):
+        return backend.backend()
+    if is_tf_dataset(dataset):
+        return "tensorflow"
+    elif is_torch_dataset(dataset):
+        return "torch"
+    else:
+        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
+
+
 def _convert_dataset_to_list(
     dataset,
     dataset_type_spec,
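
The torch path delegates the partitioning itself to `torch.utils.data.random_split`, seeding a `torch.Generator` only when `shuffle=True`. A standalone sketch of that core step, outside Keras:

    import torch
    from torch.utils.data import TensorDataset, random_split

    dataset = TensorDataset(torch.arange(10))

    # A manually seeded generator makes the shuffled split reproducible,
    # mirroring the seed handling in _split_dataset_torch above.
    generator = torch.Generator()
    generator.manual_seed(42)

    left, right = random_split(dataset, [8, 2], generator=generator)
    print(len(left), len(right))  # 8 2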
@@ -208,7 +369,7 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         )
 
         return iter(zip(*dataset))
-    elif dataset_type_spec is tf.data.Dataset:
+    elif is_tf_dataset(dataset):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
@@ -242,15 +403,20 @@ def _get_next_sample(
     Yields:
         data_sample: The next sample.
     """
+    from keras.src.trainers.data_adapters.data_adapter_utils import (
+        is_tensorflow_tensor,
+    )
     from keras.src.trainers.data_adapters.data_adapter_utils import (
         is_torch_tensor,
     )
 
     try:
         dataset_iterator = iter(dataset_iterator)
         first_sample = next(dataset_iterator)
-        if isinstance(first_sample, (tf.Tensor, np.ndarray)) or is_torch_tensor(
-            first_sample
+        if (
+            isinstance(first_sample, np.ndarray)
+            or is_tensorflow_tensor(first_sample)
+            or is_torch_tensor(first_sample)
         ):
             first_sample_shape = np.array(first_sample).shape
         else:
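
With this hunk, the first-sample check relies on backend-agnostic predicates rather than an `isinstance` test against `tf.Tensor`, so no module-level TensorFlow import is needed. A small sketch of the same check in isolation (assuming `data_adapter_utils` exposes both predicates, as the imports above indicate):

    import numpy as np
    from keras.src.trainers.data_adapters.data_adapter_utils import (
        is_tensorflow_tensor,
        is_torch_tensor,
    )

    first_sample = np.ones((4, 2))
    if (
        isinstance(first_sample, np.ndarray)
        or is_tensorflow_tensor(first_sample)
        or is_torch_tensor(first_sample)
    ):
        # Works identically for NumPy arrays, TF tensors, and torch tensors.
        print(np.array(first_sample).shape)  # (4, 2)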
@@ -291,23 +457,36 @@ def _get_next_sample(
         yield sample
 
 
-def is_torch_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ == "Dataset" and str(
-                parent.__module__
-            ).startswith("torch.utils.data"):
-                return True
-    return False
+def is_tf_dataset(dataset):
+    return _mro_matches(
+        dataset,
+        class_names=("DatasetV2", "Dataset"),
+        module_prefixes=(
+            "tensorflow.python.data",  # TF classic
+            "tensorflow.data",  # newer TF paths
+        ),
+    )
 
 
 def is_grain_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ in (
-                "MapDataset",
-                "IterDataset",
-            ) and str(parent.__module__).startswith("grain._src.python"):
+    return _mro_matches(
+        dataset,
+        class_names=("MapDataset", "IterDataset"),
+        module_prefixes=("grain._src.python",),
+    )
+
+
+def is_torch_dataset(dataset):
+    return _mro_matches(dataset, ("Dataset",), ("torch.utils.data",))
+
+
+def _mro_matches(dataset, class_names, module_prefixes):
+    if not hasattr(dataset, "__class__"):
+        return False
+    for parent in dataset.__class__.__mro__:
+        if parent.__name__ in class_names:
+            mod = str(parent.__module__)
+            if any(mod.startswith(pref) for pref in module_prefixes):
                 return True
     return False
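The consolidated `_mro_matches` helper recognizes framework datasets by walking the class MRO and matching class names plus module prefixes, so neither TensorFlow nor torch has to be importable for the check to run. A contrived demonstration of that duck-typing (the stand-in `Dataset` class is fabricated for illustration; no torch import is involved):

    from keras.src.utils.dataset_utils import is_torch_dataset

    # Fake a base class whose name and module mimic torch.utils.data.Dataset.
    Dataset = type("Dataset", (), {"__module__": "torch.utils.data.dataset"})

    class MyDataset(Dataset):
        pass

    # _mro_matches finds "Dataset" in the MRO with a module that starts
    # with "torch.utils.data", so the check passes.
    print(is_torch_dataset(MyDataset()))  # True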

@@ -441,8 +620,10 @@ def _restore_dataset_from_list(
     dataset_as_list, dataset_type_spec, original_dataset
 ):
     """Restore the dataset from the list of arrays."""
-    if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
-        original_dataset
+    if (
+        dataset_type_spec in [tuple, list]
+        or is_tf_dataset(original_dataset)
+        or is_torch_dataset(original_dataset)
     ):
         # Save structure by taking the first element.
         element_spec = dataset_as_list[0]
@@ -483,7 +664,9 @@ def _get_type_spec(dataset):
         return list
     elif isinstance(dataset, np.ndarray):
         return np.ndarray
-    elif isinstance(dataset, tf.data.Dataset):
+    elif is_tf_dataset(dataset):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         return tf.data.Dataset
     elif is_torch_dataset(dataset):
         from torch.utils.data import Dataset as TorchDataset
@@ -543,6 +726,8 @@ def index_directory(
         order.
     """
     if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         os_module = tf.io.gfile
         path_module = tf.io.gfile
     else:
@@ -647,7 +832,12 @@
 
 
 def iter_valid_files(directory, follow_links, formats):
-    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        io_module = tf.io.gfile
+    else:
+        io_module = os
 
     if not follow_links:
         walk = io_module.walk(directory)
@@ -674,9 +864,12 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
         paths, and `labels` is a list of integer labels corresponding
         to these files.
     """
-    path_module = (
-        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
-    )
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        path_module = tf.io.gfile
+    else:
+        path_module = os.path
 
     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
@@ -746,6 +939,8 @@ def labels_to_dataset_tf(labels, label_mode, num_classes):
     Returns:
         A `tf.data.Dataset` instance.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     label_ds = tf.data.Dataset.from_tensor_slices(labels)
     if label_mode == "binary":
         label_ds = label_ds.map(
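
The remaining hunks all apply the same lazy-import pattern: TensorFlow is imported only inside the branch that actually needs `tf.io.gfile` for remote paths, keeping local filesystem use TF-free. The shape of the pattern in isolation (a sketch with a hypothetical helper name, not repo code):

    import os

    from keras.src.utils import file_utils

    def _pick_io_module(directory):
        # TensorFlow is only imported when a remote path needs tf.io.gfile;
        # plain local paths never trigger the import.
        if file_utils.is_remote_path(directory):
            from keras.src.utils.module_utils import tensorflow as tf

            return tf.io.gfile
        return os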
