Skip to content

Instantly share code, notes, and snippets.

@cnmoro
Last active July 30, 2025 12:51
Show Gist options
  • Select an option

  • Save cnmoro/3c66de4f92716e8cf044e550a23ee9d2 to your computer and use it in GitHub Desktop.

Select an option

Save cnmoro/3c66de4f92716e8cf044e550a23ee9d2 to your computer and use it in GitHub Desktop.
GLiClass ONNX Conversion and Inference
## Converter Imports
from transformers import AutoTokenizer
from gliclass import GLiClassModel
import torch
## Quantizer Imports
from onnxruntime.quantization import quantize_dynamic, QuantType
## Inference imports
from tokenizers import Tokenizer
import onnxruntime as ort
import numpy as np
import os
class GLiClassOnnxConverter:
    """Export a GLiClass model and its tokenizer to an ONNX directory.

    The exported graph takes ``input_ids`` and ``attention_mask`` and produces
    ``logits``. Batch size, sequence length and label count are declared as
    dynamic axes, so the graph is not tied to the dummy input used for tracing.
    """

    def __init__(self, model_id, max_labels=100, output_directory="./onnx_model"):
        """Load the model and tokenizer for *model_id*.

        Args:
            model_id: Model id or path passed to ``GLiClassModel.from_pretrained``
                and ``AutoTokenizer.from_pretrained``.
            max_labels: Number of placeholder labels put into the dummy input
                used for tracing.
            output_directory: Directory that receives ``model.onnx`` and
                ``tokenizer.json``.
        """
        self.model_id = model_id
        self.model = GLiClassModel.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # Switch off training-only behaviour (e.g. dropout) before tracing.
        self.model.eval()
        self.output_directory = output_directory
        self.max_labels = max_labels
        self.dummy_labels = [f"label{i}" for i in range(max_labels)]

    @staticmethod
    def _build_input_text(text, labels):
        """Return the prompt ``<<LABEL>>l1<<LABEL>>l2...<<SEP>>text``.

        This is the same layout ``GLiClassOnnxInference.onnx_predict`` feeds the
        exported model, so tracing sees the token layout used at inference.
        """
        return "".join(f"<<LABEL>>{label}" for label in labels) + "<<SEP>>" + text

    def create_dynamic_dummy_inputs(self):
        """Create tracing inputs containing ``max_labels`` placeholder labels.

        The dummy is padded/truncated to 512 tokens; the dynamic axes declared
        at export time let the real label count and sequence length vary.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` PyTorch tensors.
        """
        input_text = self._build_input_text("Test text for ONNX export.", self.dummy_labels)
        inputs = self.tokenizer(
            input_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }

    def convert_and_export(self):
        """Write ``model.onnx`` and ``tokenizer.json`` into the output directory.

        The output directory is created if it does not exist. Of the files
        written by ``tokenizer.save_pretrained`` only ``tokenizer.json`` is
        kept; files that were already in the directory are left untouched.
        """
        os.makedirs(self.output_directory, exist_ok=True)
        dummy_inputs = self.create_dynamic_dummy_inputs()
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "logits": {0: "batch_size", 1: "num_labels"},  # num_labels is dynamic
        }
        # Export with flexible label handling
        torch.onnx.export(
            self.model,
            tuple(dummy_inputs.values()),
            os.path.join(self.output_directory, "model.onnx"),
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=["logits"],
            dynamic_axes=dynamic_axes,
        )
        saved_files = self.tokenizer.save_pretrained(self.output_directory)
        # Inference only needs tokenizer.json. Delete just the other tokenizer
        # artifacts that save_pretrained reports writing, rather than every
        # file in the directory, so unrelated files are never removed.
        for path in saved_files:
            if os.path.basename(path) != "tokenizer.json" and os.path.isfile(path):
                os.remove(path)
        print("ONNX model exported")
class OnnxQuantizer:
    """Write a dynamically quantized copy (QUInt8 weights) of an ONNX model."""

    def __init__(self, onnx_path: str):
        """Remember the path of the floating-point ONNX model to quantize."""
        self.onnx_path = onnx_path

    @property
    def quantized_path(self) -> str:
        """Destination path: ``<name>_QUInt8.onnx`` next to the source model.

        Only a trailing ``.onnx`` suffix is replaced, so directory names that
        contain ``.onnx`` are preserved. The result always differs from
        ``onnx_path``, so the source model is never overwritten.
        """
        suffix = ".onnx"
        root = self.onnx_path[: -len(suffix)] if self.onnx_path.endswith(suffix) else self.onnx_path
        return root + "_QUInt8" + suffix

    def quantize(self) -> str:
        """Dynamically quantize the model with per-channel QUInt8 weights.

        Returns:
            The path of the written quantized model (``quantized_path``).
        """
        output_path = self.quantized_path
        quantize_dynamic(
            model_input=self.onnx_path,
            model_output=output_path,
            weight_type=QuantType.QUInt8,
            per_channel=True,
        )
        return output_path
class GLiClassOnnxInference:
    """Score labels for a text with an exported GLiClass ONNX model."""

    def __init__(self, onnx_path: str, tokenizer_json_path: str):
        """Open an ONNX Runtime session and load the ``tokenizers`` tokenizer."""
        self.onnx_runtime_session = ort.InferenceSession(onnx_path)
        self.tokenizer = Tokenizer.from_file(tokenizer_json_path)

    @staticmethod
    def _build_input_text(text: str, labels: list[str]) -> str:
        """Return the prompt ``<<LABEL>>l1<<LABEL>>l2...<<SEP>>text``."""
        return "".join(f"<<LABEL>>{label}" for label in labels) + "<<SEP>>" + text

    @staticmethod
    def _sigmoid(logits: np.ndarray) -> np.ndarray:
        """Element-wise logistic function that does not overflow.

        exp(-log(1 + exp(-x))) equals 1 / (1 + exp(-x)), but ``logaddexp``
        never overflows, so large-magnitude logits raise no RuntimeWarning.
        """
        return np.exp(-np.logaddexp(0, -np.asarray(logits)))

    def _pad_token_id(self) -> int:
        """Return the id of the tokenizer's padding token.

        Raises:
            ValueError: if neither ``[PAD]`` nor ``<pad>`` is in the vocabulary.
        """
        for token in ("[PAD]", "<pad>"):
            token_id = self.tokenizer.token_to_id(token)
            if token_id is not None:
                return token_id
        raise ValueError(
            "Tokenizer has no '[PAD]' or '<pad>' token; "
            "call encode(..., pad=False) to skip padding."
        )

    def encode(self, text: str, max_length: int = 512, pad: bool = True):
        """Tokenize *text* into model inputs.

        Args:
            text: Full prompt to tokenize.
            max_length: Output is cut to at most this many tokens and, when
                *pad* is true, right-padded up to exactly this many.
            pad: Whether to pad short inputs to *max_length*.

        Returns:
            ``(input_ids, attention_mask)`` as int64 arrays of shape ``(1, L)``.

        NOTE(review): truncation simply cuts at *max_length*, so a trailing
        special token appended by the tokenizer (if any) is dropped for
        over-long inputs.
        """
        encoded = self.tokenizer.encode(text)
        ids = list(encoded.ids)
        mask = list(encoded.attention_mask)
        if pad and len(ids) < max_length:
            pad_len = max_length - len(ids)
            ids += [self._pad_token_id()] * pad_len
            mask += [0] * pad_len  # padded positions are masked out
        ids = ids[:max_length]
        mask = mask[:max_length]
        return np.array([ids], dtype=np.int64), np.array([mask], dtype=np.int64)

    def onnx_predict(self, text: str, labels: list[str]):
        """Score each label in *labels* for *text*.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts in label order.
            Each score is the sigmoid of the model's logit for that label.
        """
        full_text = self._build_input_text(text, labels)
        ids, mask = self.encode(full_text, max_length=512)
        ort_inputs = {"input_ids": ids, "attention_mask": mask}
        logits = self.onnx_runtime_session.run(None, ort_inputs)[0]
        probs = self._sigmoid(logits[0])
        return [{"label": label, "score": float(prob)} for label, prob in zip(labels, probs)]
if __name__ == "__main__":
    # Convert, quantize and smoke-test each GLiClass checkpoint in turn.
    for model_id in [
        "knowledgator/gliclass-edge-v3.0",
        "knowledgator/gliclass-base-v3.0",
        "knowledgator/gliclass-modern-large-v3.0",
        "knowledgator/gliclass-large-v3.0",
        "knowledgator/gliclass-x-base",
    ]:
        # Convert then quantize
        model_name = model_id.split("/")[-1]
        output_dir = os.path.join("onnx_models", model_name)
        # exist_ok avoids the exists()-then-makedirs() race.
        os.makedirs(output_dir, exist_ok=True)
        converter = GLiClassOnnxConverter(model_id, output_directory=output_dir)
        converter.convert_and_export()
        # Drop the full-precision PyTorch model before the next checkpoint
        # is loaded, so two models are never held in memory at once.
        del converter
        quantizer = OnnxQuantizer(onnx_path=os.path.join(output_dir, "model.onnx"))
        quantizer.quantize()
        # Test out inference
        print(f"Testing {model_name}...")
        inference_session = GLiClassOnnxInference(
            onnx_path=os.path.join(output_dir, "model_QUInt8.onnx"),
            tokenizer_json_path=os.path.join(output_dir, "tokenizer.json"),
        )
        results = inference_session.onnx_predict(
            "One day I will see the world!",
            ["travel", "dreams", "sport", "science", "politics"],
        )
        for r in results:
            print(f"{r['label']} => {r['score']:.3f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment