Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions generate_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Usage:
# From tensorflow/models/
# Create train data:
python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record
python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record --label_map_dir=training/object-detection.pbtxt.pbtxt

# Create test data:
python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record
python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record --label_map_dir=training/object-detection.pbtxt.pbtxt
"""
from __future__ import division
from __future__ import print_function
Expand All @@ -24,15 +24,33 @@
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('image_dir', '', 'Path to images')
flags.DEFINE_string('label_map_dir', '', 'Path to label map')
FLAGS = flags.FLAGS


# TO-DO replace this with label map
def class_text_to_int(row_label):
if row_label == 'raccoon':
return 1
else:
None
def read_label_map(label_map_path):
item_id = None
item_name = None
items = {}

with open(label_map_path, "r") as file:
for line in file:
line.replace(" ", "")
if line == "item{":
pass
elif line == "}":
pass
elif "id" in line:
item_id = int(line.split(":", 1)[1].strip())
elif "name" in line:
item_name = line.split(":", 1)[1].replace("'", "").strip()

if item_id is not None and item_name is not None:
items[item_name] = item_id
item_id = None
item_name = None

return items


def split(df, group):
Expand All @@ -41,7 +59,7 @@ def split(df, group):
return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
def create_tf_example(group, path, label_map):
with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
Expand All @@ -63,7 +81,7 @@ def create_tf_example(group, path):
ymins.append(row['ymin'] / height)
ymaxs.append(row['ymax'] / height)
classes_text.append(row['class'].encode('utf8'))
classes.append(class_text_to_int(row['class']))
classes.append(label_map[row['class']])

tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
Expand All @@ -87,8 +105,9 @@ def main(_):
path = os.path.join(FLAGS.image_dir)
examples = pd.read_csv(FLAGS.csv_input)
grouped = split(examples, 'filename')
label_map = read_label_map(FLAGS.label_map_dir)
for group in grouped:
tf_example = create_tf_example(group, path)
tf_example = create_tf_example(group, path, label_map)
writer.write(tf_example.SerializeToString())

writer.close()
Expand Down
12 changes: 6 additions & 6 deletions test_generate_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_csv_to_tf_example_one_raccoon_per_file(self):

grouped = generate_tfrecord.split(raccoon_df, 'filename')
for group in grouped:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_csv_to_tf_example_multiple_raccoons_per_file(self):

grouped = generate_tfrecord.split(raccoon_df, 'filename')
for group in grouped:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self):
grouped = generate_tfrecord.split(raccoon_df, 'filename')
for group in grouped:
if group.filename == image_file_one:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self):
example.features.feature['image/object/class/label'].int64_list.value,
[1])
elif group.filename == image_file_two:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self):
grouped = generate_tfrecord.split(raccoon_df, 'filename')
for group in grouped:
if group.filename == image_file_one:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self):
example.features.feature['image/object/class/label'].int64_list.value,
[1, 1])
elif group.filename == image_file_two:
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir())
example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1})
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
Expand Down