diff --git a/input_pipeline.py b/input_pipeline.py index 5ef8310..036b5f6 100644 --- a/input_pipeline.py +++ b/input_pipeline.py @@ -52,7 +52,7 @@ def input_fn(): # same input file is sent to all workers. if isinstance(input_file, str) or len(input_file) == 1: options = tf.data.Options() - options.experimental_distribute.auto_shard = False + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF d = d.with_options(options) return d