diff --git a/generate_tfrecord.py b/generate_tfrecord.py index 87fa5a98..7f21075f 100644 --- a/generate_tfrecord.py +++ b/generate_tfrecord.py @@ -2,10 +2,10 @@ Usage: # From tensorflow/models/ # Create train data: - python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record + python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record --label_map_dir=training/object-detection.pbtxt # Create test data: - python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record + python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record --label_map_dir=training/object-detection.pbtxt """ from __future__ import division from __future__ import print_function @@ -24,15 +24,33 @@ flags.DEFINE_string('csv_input', '', 'Path to the CSV input') flags.DEFINE_string('output_path', '', 'Path to output TFRecord') flags.DEFINE_string('image_dir', '', 'Path to images') +flags.DEFINE_string('label_map_dir', '', 'Path to label map') FLAGS = flags.FLAGS -# TO-DO replace this with label map -def class_text_to_int(row_label): - if row_label == 'raccoon': - return 1 - else: - None +def read_label_map(label_map_path): + item_id = None + item_name = None + items = {} + + with open(label_map_path, "r") as file: + for line in file: + line.replace(" ", "") + if line == "item{": + pass + elif line == "}": + pass + elif "id" in line: + item_id = int(line.split(":", 1)[1].strip()) + elif "name" in line: + item_name = line.split(":", 1)[1].replace("'", "").strip() + + if item_id is not None and item_name is not None: + items[item_name] = item_id + item_id = None + item_name = None + + return items def split(df, group): @@ -41,7 +59,7 @@ def split(df, group): return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)] -def create_tf_example(group, path): +def create_tf_example(group, path, label_map): with tf.gfile.GFile(os.path.join(path, 
'{}'.format(group.filename)), 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) @@ -63,7 +81,7 @@ def create_tf_example(group, path): ymins.append(row['ymin'] / height) ymaxs.append(row['ymax'] / height) classes_text.append(row['class'].encode('utf8')) - classes.append(class_text_to_int(row['class'])) + classes.append(label_map[row['class']]) tf_example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), @@ -87,8 +105,9 @@ def main(_): path = os.path.join(FLAGS.image_dir) examples = pd.read_csv(FLAGS.csv_input) grouped = split(examples, 'filename') + label_map = read_label_map(FLAGS.label_map_dir) for group in grouped: - tf_example = create_tf_example(group, path) + tf_example = create_tf_example(group, path, label_map) writer.write(tf_example.SerializeToString()) writer.close() diff --git a/test_generate_tfrecord.py b/test_generate_tfrecord.py index a7eabb68..2445fb42 100644 --- a/test_generate_tfrecord.py +++ b/test_generate_tfrecord.py @@ -25,7 +25,7 @@ def test_csv_to_tf_example_one_raccoon_per_file(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -72,7 +72,7 @@ def test_csv_to_tf_example_multiple_raccoons_per_file(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -123,7 +123,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self): 
grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: if group.filename == image_file_one: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -155,7 +155,7 @@ def test_csv_to_tf_example_one_raccoons_multiple_files(self): example.features.feature['image/object/class/label'].int64_list.value, [1]) elif group.filename == image_file_two: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -207,7 +207,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self): grouped = generate_tfrecord.split(raccoon_df, 'filename') for group in grouped: if group.filename == image_file_one: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual( @@ -239,7 +239,7 @@ def test_csv_to_tf_example_multiple_raccoons_multiple_files(self): example.features.feature['image/object/class/label'].int64_list.value, [1, 1]) elif group.filename == image_file_two: - example = generate_tfrecord.create_tf_example(group, self.get_temp_dir()) + example = generate_tfrecord.create_tf_example(group, self.get_temp_dir(), {"raccoon": 1}) self._assertProtoEqual( example.features.feature['image/height'].int64_list.value, [256]) self._assertProtoEqual(