 
 import numpy as np
 
+from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
 from keras.src.utils.module_utils import grain
-from keras.src.utils.module_utils import tensorflow as tf
 
 
 @keras_export("keras.utils.split_dataset")
 def split_dataset(
-    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+    dataset,
+    left_size=None,
+    right_size=None,
+    shuffle=False,
+    seed=None,
+    preferred_backend=None,
 ):
     """Splits a dataset into a left half and a right half (e.g. train / test).
 
@@ -37,27 +42,86 @@ def split_dataset(
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
+        preferred_backend: String, specifying which backend to use,
+            e.g. `"tensorflow"` or `"torch"`. If `None`, the backend
+            is inferred from the type of `dataset`: if `dataset` is a
+            `tf.data.Dataset`, the "tensorflow" backend is used; if it
+            is a `torch.utils.data.Dataset`, the "torch" backend is
+            used; and if it is a list/tuple/NumPy array, the current
+            Keras backend is used. Defaults to `None`.
 
     Returns:
-        A tuple of two `tf.data.Dataset` objects:
-        the left and right splits.
-
+        A tuple of two dataset objects, the left and right splits. The
+        exact type of the returned objects depends on
+        `preferred_backend`: with the "tensorflow" backend,
+        `tf.data.Dataset` objects are returned; with the "torch"
+        backend, `torch.utils.data.Dataset` objects are returned.
     Example:
 
     >>> data = np.random.random(size=(1000, 4))
     >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
-    >>> int(left_ds.cardinality())
-    800
-    >>> int(right_ds.cardinality())
-    200
+    >>> # For a tf.data.Dataset, check the sizes with `.cardinality()`:
+    >>> # int(left_ds.cardinality())  -> 800
+    >>> # For a torch.utils.data.Dataset, use `len()`:
+    >>> # len(left_ds)  -> 800
+    """
+    preferred_backend = preferred_backend or _infer_preferred_backend(dataset)
+    if preferred_backend != "torch":
+        return _split_dataset_tf(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+    else:
+        return _split_dataset_torch(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+
+
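
For reviewers, a minimal usage sketch of the new dispatch (a hypothetical session assuming this patch is applied; note the torch path returns `torch.utils.data.Subset` objects, so `len()` works on the splits):

    import numpy as np
    import keras

    data = np.random.random(size=(1000, 4))
    # Backend inferred from the input type: plain arrays fall back to
    # the active Keras backend.
    left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
    # Or force the torch path explicitly:
    left_ds, right_ds = keras.utils.split_dataset(
        data, left_size=0.8, preferred_backend="torch"
    )
    print(len(left_ds), len(right_ds))  # 800 200
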
+def _split_dataset_tf(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset: A `tf.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If
+            integer, it signifies the number of samples to pack in the
+            left dataset. If `None`, defaults to the complement to
+            `right_size`. Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `tf.data.Dataset` objects:
+        the left and right splits.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
58122 "The `dataset` argument must be either"
59- "a `tf.data.Dataset`, a `torch.utils.data.Dataset` "
60- "object, or a list/tuple of arrays. "
123+ "a `tf.data.Dataset` object, or "
124+ "a list/tuple of arrays. "
61125 f"Received: dataset={ dataset } of type { type (dataset )} "
62126 )
63127
@@ -106,6 +170,103 @@ def split_dataset(
     return left_split, right_split
 
 
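
The TensorFlow path is essentially the previous `split_dataset` body extracted into a helper, so existing behavior (including the old doctest) still holds there. A quick check against the helper (hypothetical usage; `_split_dataset_tf` is private and not part of the public API):

    import numpy as np
    from keras.src.utils import dataset_utils

    data = np.arange(20).reshape(10, 2)
    left_ds, right_ds = dataset_utils._split_dataset_tf(data, left_size=0.8)
    # tf.data.Dataset splits report a known cardinality.
    print(int(left_ds.cardinality()), int(right_ds.cardinality()))  # 8 2
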
+def _split_dataset_torch(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset: A `torch.utils.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If
+            integer, it signifies the number of samples to pack in the
+            left dataset. If `None`, defaults to the complement to
+            `right_size`. Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `torch.utils.data.Dataset` objects:
+        the left and right splits.
+    """
+    import torch
+    from torch.utils.data import Subset
+    from torch.utils.data import TensorDataset
+    from torch.utils.data import random_split
+
+    dataset_type_spec = _get_type_spec(dataset)
+    if dataset_type_spec is None:
+        raise TypeError(
+            "The `dataset` argument must be a `torch.utils.data.Dataset`"
+            " object, or a list/tuple of arrays."
+            f" Received: dataset={dataset} of type {type(dataset)}"
+        )
+
+    if not isinstance(dataset, torch.utils.data.Dataset):
+        if dataset_type_spec is np.ndarray:
+            dataset = TensorDataset(torch.from_numpy(dataset))
+        elif dataset_type_spec in (list, tuple):
+            tensors = [torch.from_numpy(x) for x in dataset]
+            dataset = TensorDataset(*tensors)
+        elif is_tf_dataset(dataset):
+            dataset_as_list = _convert_dataset_to_list(
+                dataset, dataset_type_spec
+            )
+            tensors = [
+                torch.from_numpy(np.array(sample))
+                for sample in zip(*dataset_as_list)
+            ]
+            dataset = TensorDataset(*tensors)
+
+    if right_size is None and left_size is None:
+        raise ValueError(
+            "At least one of the `left_size` or `right_size` "
+            "must be specified. "
+            "Received: left_size=None and right_size=None"
+        )
+
+    # Calculate the total length and rescale the split sizes.
+    total_length = len(dataset)
+    left_size, right_size = _rescale_dataset_split_sizes(
+        left_size, right_size, total_length
+    )
+
+    if shuffle:
+        # Shuffle before splitting: draw a random permutation from a
+        # (optionally seeded) generator.
+        generator = torch.Generator()
+        if seed is not None:
+            generator.manual_seed(seed)
+        left_split, right_split = random_split(
+            dataset, [left_size, right_size], generator=generator
+        )
+    else:
+        # `random_split` always permutes, so when `shuffle=False` keep
+        # the original sample order with plain index-based subsets.
+        indices = list(range(total_length))
+        left_split = Subset(dataset, indices[:left_size])
+        right_split = Subset(dataset, indices[left_size:])
+
+    return left_split, right_split
+
+
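
Worth noting: `torch.utils.data.random_split` returns lightweight `Subset` views over the source dataset rather than copies. A standalone illustration (plain PyTorch, independent of this patch):

    import torch
    from torch.utils.data import TensorDataset, random_split

    ds = TensorDataset(torch.arange(10))
    gen = torch.Generator().manual_seed(42)
    left, right = random_split(ds, [8, 2], generator=gen)
    print(len(left), len(right))  # 8 2
    print(type(left).__name__)    # Subset -- indexes lazily into `ds`
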
+def _infer_preferred_backend(dataset):
+    """Infer the preferred backend from the dataset type."""
+    if isinstance(dataset, (list, tuple, np.ndarray)):
+        return backend.backend()
+    if is_tf_dataset(dataset):
+        return "tensorflow"
+    elif is_torch_dataset(dataset):
+        return "torch"
+    else:
+        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
+
+
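
To summarize the inference rules above (a sketch, assuming the standard framework types; `backend.backend()` returns the active backend name such as "tensorflow" or "jax"):

    import numpy as np

    _infer_preferred_backend(np.zeros((4, 2)))   # -> backend.backend()
    _infer_preferred_backend([np.zeros(3)] * 2)  # -> backend.backend()
    # _infer_preferred_backend(some_tf_dataset)     -> "tensorflow"
    # _infer_preferred_backend(some_torch_dataset)  -> "torch"
    # Anything else raises a TypeError.
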
 def _convert_dataset_to_list(
     dataset,
     dataset_type_spec,
@@ -208,7 +369,7 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         )
 
         return iter(zip(*dataset))
-    elif dataset_type_spec is tf.data.Dataset:
+    elif is_tf_dataset(dataset):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
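
For context on the `unbatch()` call: the splitter counts individual samples, so element-wise iteration is required and any batching on the input pipeline is undone first. A standalone TensorFlow illustration (independent of this patch):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(list(range(6))).batch(2)
    # unbatch() restores per-sample iteration regardless of batch size.
    print([int(x) for x in ds.unbatch()])  # [0, 1, 2, 3, 4, 5]
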
@@ -242,15 +403,20 @@ def _get_next_sample(
     Yields:
         data_sample: The next sample.
     """
+    from keras.src.trainers.data_adapters.data_adapter_utils import (
+        is_tensorflow_tensor,
+    )
     from keras.src.trainers.data_adapters.data_adapter_utils import (
         is_torch_tensor,
     )
 
     try:
         dataset_iterator = iter(dataset_iterator)
         first_sample = next(dataset_iterator)
-        if isinstance(first_sample, (tf.Tensor, np.ndarray)) or is_torch_tensor(
-            first_sample
+        if (
+            isinstance(first_sample, np.ndarray)
+            or is_tensorflow_tensor(first_sample)
+            or is_torch_tensor(first_sample)
         ):
             first_sample_shape = np.array(first_sample).shape
         else:
@@ -291,23 +457,36 @@ def _get_next_sample(
         yield sample
 
 
-def is_torch_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ == "Dataset" and str(
-                parent.__module__
-            ).startswith("torch.utils.data"):
-                return True
-    return False
+def is_tf_dataset(dataset):
+    return _mro_matches(
+        dataset,
+        class_names=("DatasetV2", "Dataset"),
+        module_prefixes=(
+            "tensorflow.python.data",  # classic TF module layout
+            "tensorflow.data",  # newer TF module layout
+        ),
+    )
 
 
 def is_grain_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ in (
-                "MapDataset",
-                "IterDataset",
-            ) and str(parent.__module__).startswith("grain._src.python"):
+    return _mro_matches(
+        dataset,
+        class_names=("MapDataset", "IterDataset"),
+        module_prefixes=("grain._src.python",),
+    )
+
+
+def is_torch_dataset(dataset):
+    return _mro_matches(dataset, ("Dataset",), ("torch.utils.data",))
+
+
+def _mro_matches(dataset, class_names, module_prefixes):
+    if not hasattr(dataset, "__class__"):
+        return False
+    for parent in dataset.__class__.__mro__:
+        if parent.__name__ in class_names:
+            mod = str(parent.__module__)
+            if any(mod.startswith(pref) for pref in module_prefixes):
                 return True
     return False
 
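
The consolidated `_mro_matches` helper keeps all three type checks import-free: it inspects class names and module paths along the MRO instead of calling `isinstance` against a framework class, so merely checking for a torch dataset never imports torch. A self-contained sketch of the technique (the `FakeDataset` stand-in is hypothetical):

    class _TorchLikeBase:
        pass

    # Simulate a class defined in torch.utils.data without importing torch.
    _TorchLikeBase.__name__ = "Dataset"
    _TorchLikeBase.__module__ = "torch.utils.data.dataset"

    class FakeDataset(_TorchLikeBase):
        pass

    # _mro_matches walks FakeDataset.__mro__, finds a parent named
    # "Dataset" whose module starts with "torch.utils.data", and matches.
    print(_mro_matches(FakeDataset(), ("Dataset",), ("torch.utils.data",)))
    # True
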
@@ -441,8 +620,10 @@ def _restore_dataset_from_list(
     dataset_as_list, dataset_type_spec, original_dataset
 ):
     """Restore the dataset from the list of arrays."""
-    if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
-        original_dataset
+    if (
+        dataset_type_spec in [tuple, list]
+        or is_tf_dataset(original_dataset)
+        or is_torch_dataset(original_dataset)
     ):
         # Save structure by taking the first element.
         element_spec = dataset_as_list[0]
@@ -483,7 +664,9 @@ def _get_type_spec(dataset):
         return list
     elif isinstance(dataset, np.ndarray):
         return np.ndarray
-    elif isinstance(dataset, tf.data.Dataset):
+    elif is_tf_dataset(dataset):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         return tf.data.Dataset
     elif is_torch_dataset(dataset):
         from torch.utils.data import Dataset as TorchDataset
@@ -543,6 +726,8 @@ def index_directory(
             order.
     """
     if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         os_module = tf.io.gfile
         path_module = tf.io.gfile
     else:
@@ -647,7 +832,12 @@ def index_directory(
 
 
 def iter_valid_files(directory, follow_links, formats):
-    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        io_module = tf.io.gfile
+    else:
+        io_module = os
 
     if not follow_links:
         walk = io_module.walk(directory)
@@ -674,9 +864,12 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
         paths, and `labels` is a list of integer labels corresponding
         to these files.
     """
-    path_module = (
-        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
-    )
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        path_module = tf.io.gfile
+    else:
+        path_module = os.path
 
     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
@@ -746,6 +939,8 @@ def labels_to_dataset_tf(labels, label_mode, num_classes):
     Returns:
         A `tf.data.Dataset` instance.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     label_ds = tf.data.Dataset.from_tensor_slices(labels)
     if label_mode == "binary":
         label_ds = label_ds.map(
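
All of the `from keras.src.utils.module_utils import tensorflow as tf` moves in this patch follow the same idea: `module_utils` exposes TensorFlow through a lazy proxy, so the import cost is paid only on first attribute access and torch-only users never pay it. A simplified sketch of that pattern (the real `LazyModule` in Keras carries extra error handling):

    import importlib

    class LazyModule:
        def __init__(self, name):
            self._name = name
            self._module = None

        def __getattr__(self, attr):
            # Called only when normal attribute lookup fails, i.e. for
            # attributes of the wrapped module; imports exactly once.
            if self._module is None:
                self._module = importlib.import_module(self._name)
            return getattr(self._module, attr)

    tf = LazyModule("tensorflow")  # nothing imported yet
    # tf.io.gfile  # TensorFlow would be imported here, on first use
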