From f03e6323e895044837a827a3b47814793d144913 Mon Sep 17 00:00:00 2001 From: JemisaR Date: Sun, 16 Jun 2019 19:15:02 +0300 Subject: [PATCH 1/2] Read classes from label map I grew tired of copying and pasting from the label map into the script. There is now an additional command-line parameter, --label_map_dir, that provides the path to the label map. The label map is read once in the script, parsed into a dictionary and passed to the create_tf_example function. The names are the keys in the dictionary and the IDs are the values. All tests pass, and the tfrecords have been created without issues. --- generate_tfrecord.py | 40 ++++++++++++++++++++++++++++----------- test_generate_tfrecord.py | 12 ++++++------ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/generate_tfrecord.py b/generate_tfrecord.py index 87fa5a98..fb082727 100644 --- a/generate_tfrecord.py +++ b/generate_tfrecord.py @@ -2,10 +2,10 @@ Usage: # From tensorflow/models/ # Create train data: - python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record + python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record --label_map_dir=training/object-detection.pbtxt.pbtxt # Create test data: - python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record + python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record --label_map_dir=training/object-detection.pbtxt.pbtxt """ from __future__ import division from __future__ import print_function @@ -24,15 +24,32 @@ flags.DEFINE_string('csv_input', '', 'Path to the CSV input') flags.DEFINE_string('output_path', '', 'Path to output TFRecord') flags.DEFINE_string('image_dir', '', 'Path to images') +flags.DEFINE_string('label_map_dir', '', 'Path to label map') FLAGS = flags.FLAGS -# TO-DO replace this with label map -def class_text_to_int(row_label): - if row_label == 'raccoon': - return 1 - else: - None +def read_label_map(label_map_path): + item_id = None + item_name = 
None + items = {} + with open(label_map_path, "r") as file: + for line in file: + line.strip() + if line == "item {": + pass + elif "id" in line: + item_id = int(line.split(":", 1)[1].strip()) + elif "name" in line: + item_name = line.split(":", 1)[1].replace("'", "").strip() + elif line == "}": + pass + + if item_id is not None and item_name is not None: + items[item_name] = item_id + item_id = None + item_name = None + + return items def split(df, group): @@ -41,7 +58,7 @@ def split(df, group): return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)] -def create_tf_example(group, path): +def create_tf_example(group, path, label_map): with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) @@ -63,7 +80,7 @@ def create_tf_example(group, path): ymins.append(row['ymin'] / height) ymaxs.append(row['ymax'] / height) classes_text.append(row['class'].encode('utf8')) - classes.append(class_text_to_int(row['class'])) + classes.append(label_map[row['class']]) tf_example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), @@ -87,8 +104,9 @@ def main(_): path = os.path.join(FLAGS.image_dir) examples = pd.read_csv(FLAGS.csv_input) grouped = split(examples, 'filename') + label_map = read_label_map(FLAGS.label_map_dir) for group in grouped: - tf_example = create_tf_example(group, path) + tf_example = create_tf_example(group, path, label_map) writer.write(tf_example.SerializeToString()) writer.close() diff --git a/test_generate_tfrecord.py b/test_generate_tfrecord.py index a7eabb68..2445fb42 100644 --- a/test_generate_tfrecord.py +++ b/test_generate_tfrecord.py @@ -25,7 +25,7 @@ def test_csv_to_tf_example_one_raccoon_per_file(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = 
generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -72,7 +72,7 @@ def test_csv_to_tf_example_multiple_raccoons_per_file(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -123,7 +123,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: if group.filename == image_file_one: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -155,7 +155,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self): example.features.feature['image/object/class/label'].int64_list.value, [1]) elif group.filename == image_file_two: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -207,7 +207,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: if group.filename == image_file_one: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( 
example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -239,7 +239,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self): example.features.feature['image/object/class/label'].int64_list.value, [1, 1]) elif group.filename == image_file_two: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( From 1587934d094516684da91430ad1b40d9353f65b8 Mon Sep 17 00:00:00 2001 From: JemisaR Date: Tue, 18 Jun 2019 17:27:36 +0300 Subject: [PATCH 2/2] Improved reading label map to be more typo tolerant. --- generate_tfrecord.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/generate_tfrecord.py b/generate_tfrecord.py index fb082727..7f21075f 100644 --- a/generate_tfrecord.py +++ b/generate_tfrecord.py @@ -32,17 +32,18 @@ def read_label_map(label_map_path): item_id = None item_name = None items = {} + with open(label_map_path, "r") as file: for line in file: - line.strip() - if line == "item {": + line.replace(" ", "") + if line == "item{": + pass + elif line == "}": pass elif "id" in line: item_id = int(line.split(":", 1)[1].strip()) elif "name" in line: item_name = line.split(":", 1)[1].replace("'", "").strip() - elif line == "}": - pass if item_id is not None and item_name is not None: items[item_name] = item_id