GLiClass ONNX Conversion and Inference
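
Converts GLiClass zero-shot classification models to ONNX, applies dynamic QUInt8 quantization, and runs inference with onnxruntime and the standalone tokenizers library.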

## Converter Imports
from transformers import AutoTokenizer
from gliclass import GLiClassModel
import torch

## Quantizer Imports
from onnxruntime.quantization import quantize_dynamic, QuantType

## Inference Imports
from tokenizers import Tokenizer
import onnxruntime as ort
import numpy as np
import os
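
# Assumed dependencies (the gist does not pin any):
#   pip install gliclass transformers torch onnx onnxruntime tokenizers numpy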


class GLiClassOnnxConverter:
    def __init__(self, model_id, max_labels=100, output_directory="./onnx_model"):
        self.model_id = model_id
        self.model = GLiClassModel.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model.eval()
        self.output_directory = output_directory
        self.max_labels = max_labels
        self.dummy_labels = [f"label{i}" for i in range(max_labels)]

    def create_dynamic_dummy_inputs(self):
        """Create dummy inputs with max_labels labels; the dynamic axes declared
        at export time handle the actual label count seen at inference."""
        dummy_text = "Test text for ONNX export."
        # Match the prompt layout used by onnx_predict below:
        # <<LABEL>>label0<<LABEL>>label1...<<SEP>>text
        input_text = "".join(f"<<LABEL>>{label}" for label in self.dummy_labels)
        input_text += "<<SEP>>" + dummy_text
        inputs = self.tokenizer(
            input_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"]
        }

    def convert_and_export(self):
        dummy_inputs = self.create_dynamic_dummy_inputs()
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "logits": {0: "batch_size", 1: "num_labels"}  # num_labels is dynamic
        }
        # Export with flexible label handling
        torch.onnx.export(
            self.model,
            tuple(dummy_inputs.values()),
            os.path.join(self.output_directory, "model.onnx"),
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=["logits"],
            dynamic_axes=dynamic_axes
        )
        self.tokenizer.save_pretrained(self.output_directory)
        # Remove files that are not "model.onnx" or "tokenizer.json"
        for file in os.listdir(self.output_directory):
            if file not in ["model.onnx", "tokenizer.json"]:
                os.remove(os.path.join(self.output_directory, file))
        print("ONNX model exported")


class OnnxQuantizer:
    def __init__(self, onnx_path: str):
        self.onnx_path = onnx_path

    def quantize(self):
        quantize_dynamic(
            model_input=self.onnx_path,
            model_output=self.onnx_path.replace(".onnx", "_QUInt8.onnx"),
            weight_type=QuantType.QUInt8,
            per_channel=True
        )


class GLiClassOnnxInference:
    def __init__(self, onnx_path: str, tokenizer_json_path: str):
        self.onnx_runtime_session = ort.InferenceSession(onnx_path)
        self.tokenizer = Tokenizer.from_file(tokenizer_json_path)

    def encode(self, text: str, max_length: int = 512, pad: bool = True):
        encoded = self.tokenizer.encode(text)
        ids = encoded.ids
        mask = encoded.attention_mask
        if pad and len(ids) < max_length:
            # Assumes the model's padding token is literally "[PAD]"; adjust if
            # the tokenizer uses a different pad token.
            pad_len = max_length - len(ids)
            ids += [self.tokenizer.token_to_id("[PAD]")] * pad_len
            mask += [0] * pad_len
        ids = ids[:max_length]
        mask = mask[:max_length]
        return np.array([ids], dtype=np.int64), np.array([mask], dtype=np.int64)

    def onnx_predict(self, text: str, labels: list[str]):
        # GLiClass prompt: <<LABEL>>label1<<LABEL>>label2...<<SEP>>text
        full_text = "".join(f"<<LABEL>>{l}" for l in labels) + "<<SEP>>" + text
        ids, mask = self.encode(full_text, max_length=512)
        ort_inputs = {"input_ids": ids, "attention_mask": mask}
        logits = self.onnx_runtime_session.run(None, ort_inputs)[0]
        probs = 1 / (1 + np.exp(-logits[0]))  # sigmoid: independent multi-label scores
        return [{"label": label, "score": float(prob)} for label, prob in zip(labels, probs)]


if __name__ == "__main__":
    for model_id in [
        "knowledgator/gliclass-edge-v3.0",
        "knowledgator/gliclass-base-v3.0",
        "knowledgator/gliclass-modern-large-v3.0",
        "knowledgator/gliclass-large-v3.0",
        "knowledgator/gliclass-x-base"
    ]:
        # Convert then quantize
        model_name = model_id.split("/")[-1]
        output_dir = os.path.join("onnx_models", model_name)
        os.makedirs(output_dir, exist_ok=True)
        converter = GLiClassOnnxConverter(model_id, output_directory=output_dir)
        converter.convert_and_export()
        quantizer = OnnxQuantizer(onnx_path=os.path.join(output_dir, "model.onnx"))
        quantizer.quantize()

        # Test out inference
        print(f"Testing {model_name}...")
        inference_session = GLiClassOnnxInference(
            onnx_path=os.path.join(output_dir, "model_QUInt8.onnx"),
            tokenizer_json_path=os.path.join(output_dir, "tokenizer.json")
        )
        results = inference_session.onnx_predict(
            "One day I will see the world!",
            ["travel", "dreams", "sport", "science", "politics"]
        )
        for r in results:
            print(f"{r['label']} => {r['score']:.3f}")