Description
I have a situation where I train a model and save its checkpoint files, and later need to restore the graph from the meta file and feed a new data iterator to continue training. I found an existing TensorFlow issue discussing this pattern (linked in the code below) and wrote some code to demonstrate my situation.
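To make the workflow concrete, here is a condensed outline of the save-and-restore pattern the repro scripts below follow. The names match the full code further down; this fragment is only an outline, not a runnable script.

# Condensed outline (not runnable by itself); names match the repro code below.
import tensorflow as tf

# First run: build a graph with a feedable iterator handle, train, and save it.
iterator_handle = tf.placeholder(tf.string, shape=[])
tf.add_to_collection('iterator_handle', iterator_handle)
# ... build tf.data.Iterator.from_string_handle(iterator_handle, ...), training ops,
# then tf.train.Saver().save(sess, 'checkpoints_hb/fufu')

# Second run: restore the meta graph and feed an iterator over a *new* dataset.
saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')  # this call raises the error below
iterator_handle = tf.get_collection('iterator_handle')[0]
# ... make an initializable iterator over the new dataset, run its string_handle(),
# and feed that handle through iterator_handle to keep training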
Current behavior
When I use ParquetDataset to feed the iterator, I cannot restore the meta file and get the following error:
Traceback (most recent call last):
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test/io/restore_hb.py", line 223, in <module>
resume_training(another_train_dataset, another_test_dataset)
File "test/io/restore_hb.py", line 132, in resume_training
saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
**kwargs)[0]
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
**kwargs))
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
return_elements=return_elements)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
producer_op_list=producer_op_list)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.
I suspect this is not a bug in HybridBackend itself (presumably the newly constructed dataset registers a tf.data function whose generated name collides with one already serialized in the meta graph), because I also tried TFRecordDataset and got a similar error:
Traceback (most recent call last):
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test/io/restore_pb.py", line 225, in <module>
restore_feed()
File "test/io/restore_pb.py", line 220, in restore_feed
resume_training(another_train_dataset, another_test_dataset)
File "test/io/restore_pb.py", line 155, in resume_training
saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
**kwargs)[0]
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
**kwargs))
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
return_elements=return_elements)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
producer_op_list=producer_op_list)
File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.
However, the same process works with from_tensor_slices and CsvDataset. I am curious about the difference and would like to know how to restore the checkpoint and feed a new dataset iterator.
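For reference, the dataset construction in the variant that restores cleanly for me (taken from the commented-out lines in the repro scripts below) is simply:

import numpy as np
import tensorflow as tf

# Datasets built like this (and with CsvDataset, see the third script) can be fed
# to the restored iterator handle without hitting the duplicate-function error.
train_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
test_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))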
Expected behavior
When I use ParquetDataset during training, I can restore the checkpoint and feed a new ParquetDataset iterator.
System information
- GPU model and memory: Tesla T4, 16 GB
- OS Platform: Ubuntu 18.04.5 LTS (Bionic Beaver)
- Docker version: 20.10.14
- GCC/CUDA/cuDNN version: gcc 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
- Python/conda version: Python 3.6.12
- TensorFlow/PyTorch version: tensorflow 1.15.5+deeprec2201
Code to reproduce
Training and restoring with a ParquetDataset feed (this does not work):
# Tensorflow 1.15
# https://github.com/tensorflow/tensorflow/issues/11679#
#
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from hybridbackend.tensorflow.data import DataFrame
from hybridbackend.tensorflow.data import ParquetDataset
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}
train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_parquet('train.parquet')
test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_parquet('test.parquet')


def make_initializable_iterator(ds):
    if hasattr(dataset_ops, 'make_initializable_iterator'):
        return dataset_ops.make_initializable_iterator(ds)
    return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
    if hasattr(dataset_ops, 'make_one_shot_iterator'):
        return dataset_ops.make_one_shot_iterator(ds)
    return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
    """
    Create graph with a Dataset and Iterator and save the model.
    There is some op that is applied to the data from the iterator.
    """
    iterator_handle = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('iterator_handle', iterator_handle)
    iterator = tf.data.Iterator.from_string_handle(
        iterator_handle,
        dataset_ops.get_legacy_output_types(train_dataset),
        dataset_ops.get_legacy_output_shapes(train_dataset),
        dataset_ops.get_legacy_output_classes(train_dataset))
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)
    element = iterator.get_next()
    v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))
    # to use when saving summaries
    global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
    increament_global_step = tf.assign(global_step, global_step + 1)
    global_step = global_step + 1
    tf.add_to_collection('increament_global_step', increament_global_step)
    some_op = tf.assign(v, v + tf.abs(element['test1']))
    tf.add_to_collection('some_op', tf.reduce_sum(some_op))
    tf.summary.scalar('v_sum', tf.reduce_sum(v))
    tf.summary.scalar('some_op', tf.reduce_mean(some_op))
    merged_summary = tf.summary.merge_all()
    tf.add_to_collection('merged_summary', merged_summary)
    writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_hb/fufu')


def resume_training(train_dataset, test_dataset):
    """Restore the model from file and pass some new data through it
    for further training."""
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
        saver.restore(sess, 'checkpoints_hb/fufu')
        iterator_handle = tf.get_collection('iterator_handle')[0]
        some_op = tf.get_collection('some_op')[0]
        increament_global_step = tf.get_collection('increament_global_step')[0]
        merged_summary = tf.get_collection('merged_summary')[0]
        writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())
        # Make new iterators and handles
        train_iter = make_initializable_iterator(train_dataset)
        test_iter = make_initializable_iterator(test_dataset)
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Further training the model using new datasets (which may be different from original ones)
        print("Resume training ...")
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_hb/fufu')


def train_feed():
    # delete existing saved models and summary files
    if os.path.exists('checkpoints_hb'):
        shutil.rmtree('checkpoints_hb')
    # train_dataset = tf.data.Dataset.from_tensor_slices(
    #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
    train_dataset = ParquetDataset('train.parquet',
                                   batch_size=1,
                                   fields=[DataFrame.Field('test1', tf.float32),
                                           DataFrame.Field('test2', tf.float32)])
    test_dataset = ParquetDataset('test.parquet',
                                  batch_size=1,
                                  fields=[DataFrame.Field('test1', tf.float32),
                                          DataFrame.Field('test2', tf.float32)])
    # test_dataset = tf.data.Dataset.from_tensor_slices(
    #     tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))
    train(train_dataset, test_dataset)


def restore_feed():
    # Load and fine-tune the saved model using new data
    another_train_dataset = ParquetDataset(
        'train.parquet',
        batch_size=1,
        fields=[DataFrame.Field('test1', tf.float32),
                DataFrame.Field('test2', tf.float32)])
    another_test_dataset = ParquetDataset(
        'test.parquet',
        batch_size=1,
        fields=[DataFrame.Field('test1', tf.float32),
                DataFrame.Field('test2', tf.float32)])
    resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
    train_feed()
    restore_feed()

It does not work for TFRecordDataset either:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.ops import dataset_ops


def make_one_shot_iterator(ds):
    if hasattr(dataset_ops, 'make_one_shot_iterator'):
        return dataset_ops.make_one_shot_iterator(ds)
    return ds.make_one_shot_iterator()


def make_initializable_iterator(ds):
    if hasattr(dataset_ops, 'make_initializable_iterator'):
        return dataset_ops.make_initializable_iterator(ds)
    return ds.make_initializable_iterator()


# Define features
feature_description = {
    'test1': tf.io.FixedLenFeature([], dtype=tf.float32),
    'test2': tf.io.FixedLenFeature([], dtype=tf.float32)
}


def _parse_function(example_proto):
    return tf.io.parse_example(example_proto, feature_description)


def write_pb(df, file):
    # Write TFRecord file
    with tf.io.TFRecordWriter(file) as writer:
        for index, row in df.iterrows():
            print(row['test1'], row['test2'])
            # Create the Example
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'test1': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test1']])),
                    'test2': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test2']]))
                }))
            writer.write(example.SerializeToString())


new_dtypes = {"test1": np.float32, "test2": np.float32}
train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
write_pb(train_df, 'train.tfrecord')
test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
write_pb(test_df, 'test.tfrecord')


def train(train_dataset, test_dataset):
    """
    Create graph with a Dataset and Iterator and save the model.
    There is some op that is applied to the data from the iterator.
    """
    iterator_handle = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('iterator_handle', iterator_handle)
    iterator = tf.data.Iterator.from_string_handle(
        iterator_handle,
        dataset_ops.get_legacy_output_types(train_dataset),
        dataset_ops.get_legacy_output_shapes(train_dataset),
        dataset_ops.get_legacy_output_classes(train_dataset))
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)
    element = iterator.get_next()
    v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))
    # to use when saving summaries
    global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
    increament_global_step = tf.assign(global_step, global_step + 1)
    global_step = global_step + 1
    tf.add_to_collection('increament_global_step', increament_global_step)
    some_op = tf.assign(v, v + tf.abs(element['test1']))
    tf.add_to_collection('some_op', tf.reduce_sum(some_op))
    tf.summary.scalar('v_sum', tf.reduce_sum(v))
    tf.summary.scalar('some_op', tf.reduce_mean(some_op))
    merged_summary = tf.summary.merge_all()
    tf.add_to_collection('merged_summary', merged_summary)
    writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_pb/fufu')


def resume_training(train_dataset, test_dataset):
    """Restore the model from file and pass some new data through it
    for further training."""
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
        saver.restore(sess, 'checkpoints_pb/fufu')
        iterator_handle = tf.get_collection('iterator_handle')[0]
        some_op = tf.get_collection('some_op')[0]
        increament_global_step = tf.get_collection('increament_global_step')[0]
        merged_summary = tf.get_collection('merged_summary')[0]
        writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())
        # Make new iterators and handles
        train_iter = make_initializable_iterator(train_dataset)
        test_iter = make_initializable_iterator(test_dataset)
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Further training the model using new datasets (which may be different from original ones)
        print("Resume training ...")
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_pb/fufu')


def train_feed():
    # delete existing saved models and summary files
    if os.path.exists('checkpoints_pb'):
        shutil.rmtree('checkpoints_pb')
    # train_dataset = tf.data.Dataset.from_tensor_slices(
    #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
    train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
    test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)
    train(train_dataset, test_dataset)


def restore_feed():
    # Load and fine-tune the saved model using new data
    another_train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
    another_test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)
    resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
    train_feed()
    restore_feed()

But it works for CsvDataset:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}
train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_csv('train.csv', index=False)
test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_csv('test.csv', index=False)


def make_initializable_iterator(ds):
    if hasattr(dataset_ops, 'make_initializable_iterator'):
        return dataset_ops.make_initializable_iterator(ds)
    return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
    if hasattr(dataset_ops, 'make_one_shot_iterator'):
        return dataset_ops.make_one_shot_iterator(ds)
    return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
    """
    Create graph with a Dataset and Iterator and save the model.
    There is some op that is applied to the data from the iterator.
    """
    iterator_handle = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('iterator_handle', iterator_handle)
    iterator = tf.data.Iterator.from_string_handle(
        iterator_handle,
        dataset_ops.get_legacy_output_types(train_dataset),
        dataset_ops.get_legacy_output_shapes(train_dataset),
        dataset_ops.get_legacy_output_classes(train_dataset))
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)
    element = iterator.get_next()
    v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))
    # to use when saving summaries
    global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
    increament_global_step = tf.assign(global_step, global_step + 1)
    global_step = global_step + 1
    tf.add_to_collection('increament_global_step', increament_global_step)
    some_op = tf.assign(v, v + tf.abs(element))
    tf.add_to_collection('some_op', tf.reduce_sum(some_op))
    tf.summary.scalar('v_sum', tf.reduce_sum(v))
    tf.summary.scalar('some_op', tf.reduce_mean(some_op))
    merged_summary = tf.summary.merge_all()
    tf.add_to_collection('merged_summary', merged_summary)
    writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_csv/fufu')


def resume_training(train_dataset, test_dataset):
    """Restore the model from file and pass some new data through it
    for further training."""
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('checkpoints_csv/fufu.meta')
        saver.restore(sess, 'checkpoints_csv/fufu')
        iterator_handle = tf.get_collection('iterator_handle')[0]
        some_op = tf.get_collection('some_op')[0]
        increament_global_step = tf.get_collection('increament_global_step')[0]
        merged_summary = tf.get_collection('merged_summary')[0]
        writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())
        # Make new iterators and handles
        train_iter = make_initializable_iterator(train_dataset)
        test_iter = make_initializable_iterator(test_dataset)
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Further training the model using new datasets (which may be different from original ones)
        print("Resume training ...")
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())
        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)
        # "Training"
        print("Training")
        while True:
            try:
                [op, summary_values, g_step] = sess.run(
                    [some_op, merged_summary, increament_global_step],
                    feed_dict={iterator_handle: train_handle})
                writer.add_summary(summary_values, global_step=g_step)
                print(op)
            except tf.errors.OutOfRangeError:
                break
        # "Test evaluation"
        print("Testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, 'checkpoints_csv/fufu')


def train_feed():
    # delete existing saved models and summary files
    if os.path.exists('checkpoints_csv'):
        shutil.rmtree('checkpoints_csv')
    # train_dataset = tf.data.Dataset.from_tensor_slices(
    #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
    train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
    test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)
    # test_dataset = tf.data.Dataset.from_tensor_slices(
    #     tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))
    train(train_dataset, test_dataset)


def restore_feed():
    # Load and fine-tune the saved model using new data
    another_train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
    another_test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)
    resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
    train_feed()
    restore_feed()

Willing to contribute
Yes