-
-
Save mnowatzky/0fb8b70ceea520a9525accdcdbfea324 to your computer and use it in GitHub Desktop.
"""
Usage:
  # From tensorflow/models/
  # --label_map is required; --image_dir defaults to the current directory.
  # Create train data:
  python create_tfrecord.py --csv_input=data/train_labels.csv --image_dir=images/train --label_map=data/label_map.pbtxt --output_path=train.tfrecord
  # Create test data:
  python create_tfrecord.py --csv_input=data/test_labels.csv --image_dir=images/test --label_map=data/label_map.pbtxt --output_path=test.tfrecord
"""
| from __future__ import division | |
| from __future__ import print_function | |
| from __future__ import absolute_import | |
| import os | |
| import io | |
| import pandas as pd | |
| import tensorflow as tf | |
| from PIL import Image | |
| import sys | |
| sys.path.append('../') | |
| from object_detection.utils import dataset_util | |
| from object_detection.utils import label_map_util | |
| from collections import namedtuple, OrderedDict | |
flags = tf.compat.v1.flags
# The third argument of DEFINE_string is the --help text shown to users, not
# the default value; the defaults stay empty, so pass each path explicitly.
flags.DEFINE_string('csv_input', '',
                    'Path to the input CSV file, e.g. data/train_labels.csv')
flags.DEFINE_string('output_path', '',
                    'Path of the TFRecord file to write, e.g. data/train.record')
flags.DEFINE_string('image_dir', '',
                    'Directory holding the images named in the CSV, e.g. '
                    'images/train (empty means the current directory)')
flags.DEFINE_string('label_map', '',
                    'Path to the label map .pbtxt file, e.g. '
                    'data/label_map.pbtxt (required)')
FLAGS = flags.FLAGS

# Without --label_map the loader below would try to open an empty path and fail
# with an opaque file error (on localized Windows even a UnicodeDecodeError), so
# fail early with an actionable message instead.
if not FLAGS.label_map:
    raise ValueError(
        '--label_map is required, e.g. --label_map=data/label_map.pbtxt')
# Mapping of class name -> integer id, read once at import time and used by
# class_text_to_int.
label_dict = label_map_util.get_label_map_dict(FLAGS.label_map)
def class_text_to_int(row_label, label_map=None):
    """Map a class name from the CSV to its integer id.

    Args:
        row_label: the class name as it appears in the CSV 'class' column.
        label_map: optional name -> id mapping. When omitted, the
            module-level ``label_dict`` loaded from --label_map is used.

    Returns:
        The id for ``row_label``, or 0 when the name is not in the mapping.
    """
    # A `global` statement is only needed to rebind a module name; reading
    # label_dict works without one.
    mapping = label_dict if label_map is None else label_map
    # NOTE(review): unknown class names silently map to 0. Label-map ids in the
    # TF Object Detection API conventionally start at 1, so 0 presumably marks
    # an unlabeled box -- confirm this fallback is intended rather than an error.
    return mapping.get(row_label, 0)
def split(df, group):
    """Group annotation rows by the values of column ``group``.

    Args:
        df: DataFrame of annotations, one row per bounding box.
        group: name of the column to group by (the script uses 'filename').

    Returns:
        A list of ``data(filename, object)`` namedtuples, one per distinct
        value of ``group``. ``object`` is the sub-DataFrame of matching rows.
        Groups are in sorted key order (the pandas groupby default).
    """
    data = namedtuple('data', ['filename', 'object'])
    # Iterating a GroupBy yields (key, sub-frame) pairs directly, so there is
    # no need to zip the group keys with themselves and look each one up again.
    return [data(filename, rows) for filename, rows in df.groupby(group)]
def create_tf_example(group, path):
    """Build one tf.train.Example for a single image and all of its boxes.

    Args:
        group: a ``data(filename, object)`` record as produced by ``split``.
            ``object`` holds one row per box with 'xmin', 'xmax', 'ymin',
            'ymax' and 'class' columns. The coordinates are presumably in
            pixels, since they are divided by the image size.
        path: directory that ``group.filename`` is resolved against.

    Returns:
        A tf.train.Example holding the raw encoded image bytes, its size, and
        per-box coordinates expressed as fractions of the image width/height.
    """
    image_path = os.path.join(path, str(group.filename))
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    # Only the dimensions are needed; release the PIL handle once they are read.
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        width, height = image.size

    filename = group.filename.encode('utf8')
    # NOTE(review): the format tag is always 'jpg', even if the file is a PNG --
    # confirm that downstream readers decode by content rather than this field.
    image_format = b'jpg'

    boxes = group.object
    # Express box coordinates as fractions of the image width/height.
    xmins = (boxes['xmin'] / width).tolist()
    xmaxs = (boxes['xmax'] / width).tolist()
    ymins = (boxes['ymin'] / height).tolist()
    ymaxs = (boxes['ymax'] / height).tolist()
    labels = boxes['class'].tolist()
    classes_text = [label.encode('utf8') for label in labels]
    classes = [class_text_to_int(label) for label in labels]

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Convert the CSV annotations in --csv_input into a TFRecord at --output_path.

    Rows are grouped by image filename, so each image becomes one
    tf.train.Example carrying all of its boxes. Images are resolved relative
    to --image_dir.
    """
    # Read the CSV before opening the writer, so a missing or bad CSV does not
    # truncate an existing output file.
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    # The context manager closes the writer even if an image fails to load.
    with tf.compat.v1.python_io.TFRecordWriter(FLAGS.output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, FLAGS.image_dir)
            writer.write(tf_example.SerializeToString())
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.compat.v1.app.run()
Thank you for the tip! That was a spelling error and I have updated the gist.
Hi, thank you for the code!
Unfortunately, I got an error and I don't understand why:
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 57: invalid start byte
I would really appreciate some help!
(Environment: Windows 7 64-bit, Anaconda 3, TensorFlow 2.4.1, Python 3.8.5)
Here's the whole error message:
(base) C:\Users\LABO_RF\stage\test_mac_2>python create_tfrecord.py --csv_input=data\train_labels.csv --output_path=train.tfrecord
2021-03-12 11:12:44.005094: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-03-12 11:12:44.005094: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
Traceback (most recent call last):
  File "create_tfrecord.py", line 33, in <module>
    label_dict = label_map_util.get_label_map_dict(FLAGS.label_map)
  File "C:\Users\LABO_RF\stage\object_detection\models\research\object_detection\utils\label_map_util.py", line 201, in get_label_map_dict
    label_map = load_labelmap(label_map_path_or_proto)
  File "C:\Users\LABO_RF\stage\object_detection\models\research\object_detection\utils\label_map_util.py", line 168, in load_labelmap
    label_map_string = fid.read()
  File "C:\Users\LABO_RF\anaconda3\lib\site-packages\tensorflow\python\lib\io\file_io.py", line 117, in read
    self._preread_check()
  File "C:\Users\LABO_RF\anaconda3\lib\site-packages\tensorflow\python\lib\io\file_io.py", line 79, in _preread_check
    self._read_buf = _pywrap_file_io.BufferedInputStream(
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 57: invalid start byte
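It looks like there is a problem with your label_map.pbtxt file. Are you sure that file is in the correct format?

Note, though, that the command above passes neither --label_map nor --image_dir. That means FLAGS.label_map is an empty string, and the script tries to open an empty path. The traceback also shows the error being raised while TensorFlow opens the file (in _preread_check), not while it parses the contents. On a non-English Windows installation, the resulting "file not found" message can contain a byte such as 0x92 (a curly apostrophe in Windows-1252), which TensorFlow then fails to decode as UTF-8. Try adding --label_map=data/label_map.pbtxt, and --image_dir pointing at your images folder. If the error persists, re-save label_map.pbtxt as UTF-8 and make sure it uses straight quotes (") rather than curly ones.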
If you got here by following @iKhushPatel's article, don't forget to change "label_map.pbtext" to "label_map.pbtxt"; otherwise you will run into errors.