-
-
Save 4PixelsDev/7400df5cd6f004c4d630c849660577d6 to your computer and use it in GitHub Desktop.
| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Writes metadata and a label file to an object detection TFLite model.""" | |
| """ | |
| Usage - run the script from a terminal; it accepts 3 required flags: | |
| --model_file - path to the .tflite model without metadata | |
| --label_file - path to the .txt file with class labels (one class per line) | |
| --export_directory - directory where the .tflite model with metadata is saved | |
| """ | |
| # python ./metadata_writer_for_object_detection.py \ | |
| # --model_file=./model_without_metadata/final_model.tflite \ | |
| # --label_file=./labels.txt \ | |
| # --export_directory=./model_with_metadata | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| from absl import app | |
| from absl import flags | |
| import tensorflow as tf | |
| import flatbuffers | |
| # pylint: disable=g-direct-tensorflow-import | |
| from tflite_support import metadata_schema_py_generated as _metadata_fb | |
| from tflite_support import metadata as _metadata | |
| # pylint: enable=g-direct-tensorflow-import | |
FLAGS = flags.FLAGS


def define_flags():
  """Registers this script's command-line flags and marks each as required.

  The `__main__` guard calls this before `app.run()`, so absl can parse and
  validate the flags before `main()` reads them from FLAGS.
  """
  # (flag name, help text) pairs, registered in this order.
  flag_specs = (
      ("model_file", "Path and file name to the TFLite model file."),
      ("label_file", "Path to the label file."),
      ("export_directory",
       "Path to save the TFLite model files with metadata."),
  )
  for flag_name, help_text in flag_specs:
    flags.DEFINE_string(flag_name, None, help_text)
  # None of the flags has a usable default, so every one must be supplied.
  for flag_name, _ in flag_specs:
    flags.mark_flag_as_required(flag_name)
class ModelSpecificInfo(object):
  """Plain record describing one model's input, used to build its metadata.

  Attributes:
    name: human-readable model name written into the model metadata.
    version: version string written into the model metadata.
    image_width: expected input image width.
    image_height: expected input image height.
    image_min: smallest raw input value (written to the input stats).
    image_max: largest raw input value (written to the input stats).
    mean: normalization mean(s) for the input normalization process unit.
    std: normalization standard deviation(s) for that process unit.
    num_classes: number of classes; not used when building the metadata.
  """

  def __init__(self, name, version, image_width, image_height, image_min,
               image_max, mean, std, num_classes):
    # Identification.
    self.name, self.version = name, version
    # Input geometry.
    self.image_width, self.image_height = image_width, image_height
    # Raw value bounds.
    self.image_min, self.image_max = image_min, image_max
    # Normalization parameters.
    self.mean, self.std = mean, std
    self.num_classes = num_classes
# Per-model input descriptions, keyed by the .tflite file's basename.
# main() raises ValueError for any model whose basename is missing here, so
# add an entry for every model file you want to process.
_MODEL_INFO = {}
_MODEL_INFO["final_model.tflite"] = ModelSpecificInfo(
    # NOTE(review): the name says "image classifier", but this script writes
    # object-detection metadata (boxes, categories, scores) — confirm.
    name="MobileNetV1 image classifier",
    version="v1",
    image_width=300,
    image_height=300,
    image_min=0,
    image_max=255,
    # Assuming TFLite's (x - mean) / std normalization, these values map raw
    # inputs in [0, 255] to [-1, 1].
    mean=[127.5],
    std=[127.5],
    num_classes=1,
)
class MetadataPopulatorForObjectDetection(object):
  """Populates TFLite metadata for an SSD-style object detection model.

  The metadata describes one RGB image input and four outputs: box
  locations, category indices, scores and the number of detections. The
  label file is packed into the model as an associated file of the category
  output.
  """

  def __init__(self, model_file, model_info, label_file_path):
    """Stores the configuration; nothing is read or written yet.

    Args:
      model_file: path of the .tflite file that populate() rewrites in place.
      model_info: ModelSpecificInfo describing the expected input image.
      label_file_path: path of the label file packed into the model.
    """
    self.model_file = model_file
    self.model_info = model_info
    self.label_file_path = label_file_path
    self.metadata_buf = None  # Filled in by _create_metadata().

  def populate(self):
    """Builds the metadata, then writes it and the label file into the model."""
    self._create_metadata()
    self._populate_metadata()

  def _create_metadata(self):
    """Builds the metadata flatbuffer and stores it in self.metadata_buf."""
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = self.model_info.name
    model_meta.description = "Equipment."
    model_meta.version = self.model_info.version
    model_meta.author = "TensorFlow"
    model_meta.license = ("Apache License. Version 2.0 "
                          "http://www.apache.org/licenses/LICENSE-2.0.")

    output_location_meta = self._create_location_metadata()
    output_class_meta = self._create_category_metadata()
    output_score_meta = self._create_feature_metadata(
        "score", "The scores of the detected boxes.", with_range=True)
    output_number_meta = self._create_feature_metadata(
        "number of detections", "The number of the detected boxes.",
        with_range=False)

    # Groups the three per-box outputs as one "detection result".
    group = _metadata_fb.TensorGroupT()
    group.name = "detection result"
    group.tensorNames = [
        output_location_meta.name, output_class_meta.name,
        output_score_meta.name
    ]

    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [self._create_input_metadata()]
    # NOTE(review): assumes the model has exactly these four outputs in this
    # order (locations, categories, scores, count) — confirm for your
    # converted model. If the model's output tensor count differs, populate()
    # fails with a "number of output tensors" ValueError.
    subgraph.outputTensorMetadata = [
        output_location_meta, output_class_meta, output_score_meta,
        output_number_meta
    ]
    subgraph.outputTensorGroups = [group]
    model_meta.subgraphMetadata = [subgraph]

    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = builder.Output()

  def _create_input_metadata(self):
    """Returns metadata for the image input, including its normalization."""
    info = self.model_info
    input_meta = _metadata_fb.TensorMetadataT()
    input_meta.name = "image"
    # Size and value range come from model_info, so the text stays in sync
    # with the stats and normalization recorded below.
    input_meta.description = (
        "The expected image is {0} x {1}, with three channels (red, green, "
        "and blue) per pixel. Raw pixel values range from {2} to {3} before "
        "normalization.".format(info.image_width, info.image_height,
                                info.image_min, info.image_max))
    input_meta.content = _metadata_fb.ContentT()
    input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
    input_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    input_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)

    normalization = _metadata_fb.ProcessUnitT()
    normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    normalization.options = _metadata_fb.NormalizationOptionsT()
    normalization.options.mean = info.mean
    normalization.options.std = info.std
    input_meta.processUnits = [normalization]

    stats = _metadata_fb.StatsT()
    stats.max = [info.image_max]
    stats.min = [info.image_min]
    input_meta.stats = stats
    return input_meta

  @staticmethod
  def _create_detection_range():
    """Returns a fresh ValueRange (min = max = 2) for a per-box output.

    NOTE(review): presumably this names the tensor dimension that holds the
    content — confirm against the TFLite metadata schema.
    """
    value_range = _metadata_fb.ValueRangeT()
    value_range.min = 2
    value_range.max = 2
    return value_range

  def _create_location_metadata(self):
    """Returns metadata for the bounding-box output tensor."""
    meta = _metadata_fb.TensorMetadataT()
    meta.name = "location"
    meta.description = "The locations of the detected boxes."
    meta.content = _metadata_fb.ContentT()
    meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.BoundingBoxProperties)
    meta.content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
    # Element order inside each box; presumably the model emits
    # [top, left, bottom, right] — TODO confirm for your model.
    meta.content.contentProperties.index = [1, 0, 3, 2]
    meta.content.contentProperties.type = (
        _metadata_fb.BoundingBoxType.BOUNDARIES)
    meta.content.contentProperties.coordinateType = (
        _metadata_fb.CoordinateType.RATIO)
    meta.content.range = self._create_detection_range()
    return meta

  def _create_category_metadata(self):
    """Returns metadata for the category output with the label file attached."""
    meta = self._create_feature_metadata(
        "category", "The categories of the detected boxes.", with_range=True)
    label_file = _metadata_fb.AssociatedFileT()
    # Must match the basename of the file packed in _populate_metadata().
    label_file.name = os.path.basename(self.label_file_path)
    label_file.description = "Label of objects that this model can recognize."
    label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
    meta.associatedFiles = [label_file]
    return meta

  def _create_feature_metadata(self, name, description, with_range):
    """Returns metadata for a numeric (FeatureProperties) output tensor.

    Args:
      name: tensor name recorded in the metadata.
      description: human-readable description of the tensor.
      with_range: whether to attach the per-box ValueRange.
    """
    meta = _metadata_fb.TensorMetadataT()
    meta.name = name
    meta.description = description
    meta.content = _metadata_fb.ContentT()
    meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.FeatureProperties)
    meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
    if with_range:
      meta.content.range = self._create_detection_range()
    return meta

  def _populate_metadata(self):
    """Writes self.metadata_buf and the label file into self.model_file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.load_associated_files([self.label_file_path])
    populator.populate()
def main(_):
  """Copies the model, writes detection metadata into the copy, dumps JSON.

  Args:
    _: positional argv left over after absl flag parsing (unused).

  Raises:
    ValueError: if the model file's basename has no entry in _MODEL_INFO.
  """
  model_file = FLAGS.model_file
  model_basename = os.path.basename(model_file)
  if model_basename not in _MODEL_INFO:
    raise ValueError(
        "The model info for, {0}, is not defined yet. Add an entry to "
        "_MODEL_INFO; known models: {1}.".format(
            model_basename, ", ".join(sorted(_MODEL_INFO))))
  model_info = _MODEL_INFO[model_basename]

  # The copy fails if the export directory does not exist yet, so create it
  # first; makedirs succeeds when the directory is already there.
  tf.io.gfile.makedirs(FLAGS.export_directory)
  export_model_path = os.path.join(FLAGS.export_directory, model_basename)
  # Metadata is written into the copy. The input model is left untouched
  # unless export_directory is the model's own directory.
  tf.io.gfile.copy(model_file, export_model_path, overwrite=True)

  populator = MetadataPopulatorForObjectDetection(
      export_model_path, model_info, FLAGS.label_file)
  populator.populate()

  # Validate the result by reading the metadata back from the written model
  # and saving it as JSON next to it.
  displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
  export_json_file = os.path.join(
      FLAGS.export_directory, os.path.splitext(model_basename)[0] + ".json")
  with tf.io.gfile.GFile(export_json_file, "w") as f:
    f.write(displayer.get_metadata_json())

  print("Finished populating metadata and associated file to the model:")
  print(export_model_path)
  print("The metadata json file has been saved to:")
  print(export_json_file)
  print("The associated file that has been packed to the model is:")
  print(displayer.get_packed_associated_file_list())


if __name__ == "__main__":
  define_flags()
  app.run(main)
Dear Igor,
I am trying to convert a working TF2 model to the TFLite format for use in a mobile app. I found your excellent code, but I get an error at line 210:
`ValueError: The number of output tensors (8) should match the number of output tensor metadata (4)`
I have been trying to solve it without success.
Any ideas on how to proceed?
Thanks so much!

@agql Where did your error turn out to be in the end? I am facing the same issue.
@MikeMpapa
Hello, I am not sure about it, but I think it might be a conversion problem (your trained model -> your trained model in TFLite).
Before using this script, make sure your model has been converted correctly.
I first run export_tflite_graph_tf2.py, which you can find here: https://github.com/tensorflow/models/tree/master/research/object_detection
Then I convert the exported model to TFLite with this code:
```python
import tensorflow as tf

# Convert the model
saved_model_dir = " path to your saved_model_dir "
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()
open("name_model.tflite", "wb").write(tflite_model)
```
After that, there are no more issues with the number of output tensors.
Hello sir, I tried running your script and got this error: `ValueError: "The model info for, {0}, is not defined yet.".format(model_basename))`. Could you please help me fix my mistake? I used the ssd_mobilenet_v2_320x320 model.
Hi, try changing the model info in the script (the `_MODEL_INFO` entry) to match your model's configuration; that works for me.
Dear Igor,
I am trying to convert a working TF2 model to the TFLite format for use in a mobile app. I found your excellent code, but I get an error at line 210:
`ValueError: The number of output tensors (8) should match the number of output tensor metadata (4)`
I have been trying to solve it without success.
Any ideas on how to proceed?
Thanks so much!

@agql Where did your error turn out to be in the end? I am facing the same issue.
@MikeMpapa Hello, I am not sure about it, but I think it might be a conversion problem (your trained model -> your trained model in TFLite). Before using this script, make sure your model has been converted correctly.
I first run export_tflite_graph_tf2.py, which you can find here: https://github.com/tensorflow/models/tree/master/research/object_detection
Then I convert the exported model to TFLite with this code:
```python
import tensorflow as tf

# Convert the model
saved_model_dir = " path to your saved_model_dir "
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()
open("name_model.tflite", "wb").write(tflite_model)
```
After that, there are no more issues with the number of output tensors.
Works for me, thanks!

@agql Where did your error turn out to be in the end? I am facing the same issue.