-
-
Save mnowatzky/0fb8b70ceea520a9525accdcdbfea324 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
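"""
Usage:
  # From tensorflow/models/
  # Create train data:
  python create_tfrecord.py --csv_input=data/train_labels.csv --image_dir=images/train --label_map=data/label_map.pbtxt --output_path=train.tfrecord
  # Create test data:
  python create_tfrecord.py --csv_input=data/test_labels.csv --image_dir=images/test --label_map=data/label_map.pbtxt --output_path=test.tfrecord

--label_map is required. --image_dir is the folder holding the images named in
the CSV; if it is omitted, image paths are resolved relative to the current
working directory.
"""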
| from __future__ import division | |
| from __future__ import print_function | |
| from __future__ import absolute_import | |
| import os | |
| import io | |
| import pandas as pd | |
| import tensorflow as tf | |
| from PIL import Image | |
| import sys | |
| sys.path.append('../') | |
| from object_detection.utils import dataset_util | |
| from object_detection.utils import label_map_util | |
| from collections import namedtuple, OrderedDict | |
flags = tf.compat.v1.flags
# The third argument of DEFINE_string is the --help text, not a default value;
# every flag below defaults to the empty string.
flags.DEFINE_string('csv_input', '',
                    'Path to the CSV of box annotations, e.g. data/train_labels.csv')
flags.DEFINE_string('output_path', '',
                    'Path of the TFRecord file to write, e.g. data/train.record')
flags.DEFINE_string('image_dir', '',
                    'Directory containing the images named in the CSV, e.g. images/train')
flags.DEFINE_string('label_map', '',
                    'Path to the label map .pbtxt file, e.g. data/label_map.pbtxt')
FLAGS = flags.FLAGS

# The label map is loaded at import time. Fail early with a readable message
# instead of the opaque error that reading an empty or missing path can produce
# (e.g. a UnicodeDecodeError on Windows).
# NOTE(review): reading FLAGS here, before app.run() parses argv, relies on
# tf.compat.v1.flags parsing sys.argv lazily on first access — confirm for the
# TF version in use (plain absl flags would raise UnparsedFlagAccessError).
if not FLAGS.label_map:
    raise SystemExit('--label_map is required (e.g. --label_map=data/label_map.pbtxt)')
if not tf.io.gfile.exists(FLAGS.label_map):
    raise SystemExit('--label_map file not found: {}'.format(FLAGS.label_map))
label_dict = label_map_util.get_label_map_dict(FLAGS.label_map)
def class_text_to_int(row_label, label_map=None):
    """Map a class name from the CSV to its integer id.

    Args:
        row_label: Class name as it appears in the CSV ``class`` column.
        label_map: Optional ``{name: id}`` dict to look the name up in.
            Defaults to the module-level ``label_dict`` loaded from
            ``--label_map``.

    Returns:
        The id for ``row_label``, or 0 when the name is not in the map.
        NOTE(review): 0 is usually reserved for "background" in Object
        Detection API label maps, so an unknown class is silently written as
        an invalid label — consider failing loudly instead.
    """
    if label_map is None:
        # Read-only access to the module global, so no ``global`` is needed.
        label_map = label_dict
    return label_map.get(row_label, 0)
def split(df, group):
    """Group annotation rows by the ``group`` column (one entry per key).

    Args:
        df: DataFrame of annotations.
        group: Name of the column to group by, e.g. ``'filename'``.

    Returns:
        A list of ``data(filename, object)`` namedtuples, where ``filename``
        is the group key and ``object`` is the sub-DataFrame of rows sharing
        that key, in the key order produced by ``DataFrame.groupby``.
    """
    data = namedtuple('data', ['filename', 'object'])
    # Iterating a GroupBy yields (key, sub-frame) pairs directly; this replaces
    # zipping the same keys with themselves plus a get_group() call per key.
    return [data(filename, rows) for filename, rows in df.groupby(group)]
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: A ``data(filename, object)`` record as produced by ``split``;
            ``group.object`` holds that image's rows with ``xmin``, ``xmax``,
            ``ymin``, ``ymax`` (divided by the image width/height below, so
            presumably pixel coordinates) and ``class`` (label text) columns.
        path: Directory containing the image files.

    Returns:
        A ``tf.train.Example`` with the raw encoded image bytes, the image
        size, box coordinates normalized by image width/height, and the class
        texts and integer ids.
    """
    filename = str(group.filename)
    with tf.io.gfile.GFile(os.path.join(path, filename), 'rb') as fid:
        encoded_jpg = fid.read()
    # PIL is used only to read the image dimensions; the original bytes are
    # what get stored. ``with`` closes the image handle.
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        width, height = image.size
    encoded_filename = filename.encode('utf8')
    # NOTE(review): the format is hard-coded even if a file is not a JPEG;
    # confirm all inputs are JPEGs or derive it from ``image.format``.
    image_format = b'jpg'

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for _, row in group.object.iterrows():
        # Box coordinates are stored as fractions of the image size.
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(encoded_filename),
        'image/source_id': dataset_util.bytes_feature(encoded_filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Write one TFRecord example per image listed in ``--csv_input``.

    Reads the annotation CSV, groups its rows by ``filename``, and serializes
    one example per image into ``--output_path``. Image files are looked up
    in ``--image_dir``. The positional argument (passed by the app runner) is
    unused.
    """
    if not FLAGS.csv_input or not FLAGS.output_path:
        raise SystemExit('Both --csv_input and --output_path must be set.')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    # The context manager closes the writer even if an image fails to load
    # part-way through.
    with tf.compat.v1.python_io.TFRecordWriter(FLAGS.output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, FLAGS.image_dir)
            writer.write(tf_example.SerializeToString())
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which invokes main().
    tf.compat.v1.app.run(main=main)
Author
It looks like there is a problem with your label_map.pbtxt file. Are you sure that file is in the correct format?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, thank you for the code!
Unfortunately, I got an error and I don't understand why:
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 57: invalid start byte
I would really appreciate some help!
(Environment: win7-64 - Anaconda 3 - Tensorflow 2.4.1 - Python 3.8.5)
Here's the whole error message:
(base) C:\Users\LABO_RF\stage\test_mac_2>python create_tfrecord.py --csv_input=data\train_labels.csv --output_path=train.tfrecord
2021-03-12 11:12:44.005094: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-03-12 11:12:44.005094: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
Traceback (most recent call last):
  File "create_tfrecord.py", line 33, in <module>
    label_dict = label_map_util.get_label_map_dict(FLAGS.label_map)
  File "C:\Users\LABO_RF\stage\object_detection\models\research\object_detection\utils\label_map_util.py", line 201, in get_label_map_dict
    label_map = load_labelmap(label_map_path_or_proto)
  File "C:\Users\LABO_RF\stage\object_detection\models\research\object_detection\utils\label_map_util.py", line 168, in load_labelmap
    label_map_string = fid.read()
  File "C:\Users\LABO_RF\anaconda3\lib\site-packages\tensorflow\python\lib\io\file_io.py", line 117, in read
    self._preread_check()
  File "C:\Users\LABO_RF\anaconda3\lib\site-packages\tensorflow\python\lib\io\file_io.py", line 79, in _preread_check
    self._read_buf = _pywrap_file_io.BufferedInputStream(
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x92 in position 57: invalid start byte