File tree Expand file tree Collapse file tree 7 files changed +26
-10
lines changed Expand file tree Collapse file tree 7 files changed +26
-10
lines changed Original file line number Diff line number Diff line change 28
28
from tensorflow_datasets .core import file_adapters
29
29
from tensorflow_datasets .core import splits as splits_lib
30
30
from tensorflow_datasets .core .utils import type_utils
31
+ from tensorflow_datasets .core .utils .lazy_imports_utils import array_record_data_source
31
32
import tree
32
33
33
- from array_record .python import array_record_data_source
34
-
35
34
T = TypeVar ('T' )
36
35
37
36
_DEFAULT_ITERATION_STEP = 1000
@@ -55,9 +54,9 @@ class ArrayRecordDataSource(AbcSequence):
55
54
None
56
55
)
57
56
iteration_step : int = _DEFAULT_ITERATION_STEP
58
- data_source : array_record_data_source . ArrayRecordDataSource = (
59
- dataclasses . field ( init = False )
60
- )
57
+ # In order to lazy load array_record, we don't load
58
+ # `array_record_data_source.ArrayRecordDataSource` here.
59
+ data_source : Any = dataclasses . field ( init = False )
61
60
length : int = dataclasses .field (init = False )
62
61
63
62
def __post_init__ (self ):
Original file line number Diff line number Diff line change 25
25
from tensorflow_datasets .core import splits as splits_lib
26
26
from tensorflow_datasets .core .data_sources import array_record
27
27
from tensorflow_datasets .core .utils import shard_utils
28
-
29
- from array_record .python import array_record_data_source
28
+ from tensorflow_datasets .core .utils .lazy_imports_utils import array_record_data_source
30
29
31
30
32
31
_FILE_INSTRUCTIONS = [
Original file line number Diff line number Diff line change 59
59
if typing .TYPE_CHECKING :
60
60
import apache_beam as beam
61
61
from apache_beam .runners import runner
62
- from array_record .python import array_record_data_source
63
62
64
63
ListOrTreeOrElem = type_utils .ListOrTreeOrElem
65
64
Tree = type_utils .Tree
Original file line number Diff line number Diff line change 24
24
25
25
from etils import epath
26
26
from tensorflow_datasets .core .utils import type_utils
27
+ from tensorflow_datasets .core .utils .lazy_imports_utils import array_record_module
27
28
from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
28
29
29
- from array_record .python import array_record_module
30
-
31
30
ExamplePositions = List [Any ]
32
31
33
32
Original file line number Diff line number Diff line change 30
30
from tensorflow_datasets .core import splits as split_lib
31
31
from tensorflow_datasets .core import utils
32
32
from tensorflow_datasets .core .proto import dataset_info_pb2
33
+ from tensorflow_datasets .core .utils .lazy_imports_utils import array_record_module
33
34
34
35
from google .protobuf import json_format
35
36
Original file line number Diff line number Diff line change @@ -185,9 +185,26 @@ def tf_success_callback(**kwargs):
185
185
ensure_tf_version (kwargs ["module" ])
186
186
187
187
188
+ def array_record_error_callback (** kwargs ):
189
+ del kwargs
190
+ print ("\n \n ***************************************************************" )
191
+ print (
192
+ "Failed to import ArrayRecord. This probably means that you are running"
193
+ " on macOS or Windows. ArrayRecord currently does not work for your"
194
+ " infrastructure, because it uses Python bindings in C++. We are actively"
195
+ " working on this issue. Thanks for your understanding."
196
+ )
197
+ print ("***************************************************************\n \n " )
198
+
199
+
188
200
with lazy_imports (
189
201
error_callback = tf_error_callback , success_callback = tf_success_callback
190
202
):
191
203
import tensorflow as tf # pylint: disable=g-import-not-at-top,unused-import
192
204
205
+
206
+ with lazy_imports (error_callback = array_record_error_callback ):
207
+ from array_record .python import array_record_data_source # pylint: disable=g-import-not-at-top,unused-import
208
+ from array_record .python import array_record_module # pylint: disable=g-import-not-at-top,unused-import
209
+
193
210
tensorflow = tf
Original file line number Diff line number Diff line change @@ -60,6 +60,7 @@ def _generate_examples(self):
60
60
def test_import_tfds_without_loading_tf ():
61
61
with mock .patch .object (logging , 'log_first_n' ) as log_first_n :
62
62
assert 'tensorflow' not in sys .modules
63
+ assert 'array_record' not in sys .modules
63
64
64
65
data_dir = '/tmp/import_without_tf'
65
66
builder = DummyDataset (data_dir = data_dir )
@@ -77,3 +78,4 @@ def test_import_tfds_without_loading_tf():
77
78
# No warning concerning TensorFlow DTypes was dispatched while loading
78
79
assert not log_first_n .called
79
80
assert 'tensorflow' not in sys .modules
81
+ assert 'array_record' in sys .modules
You can’t perform that action at this time.
0 commit comments