Created
August 11, 2025 06:02
-
-
Save daydreamt/70e8d1f3039fecf35b8342b83506c618 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import subprocess | |
| import time | |
| from dataclasses import dataclass | |
| from typing import List, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import requests | |
| import torch | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import os | |
| import re | |
| # ----------------------------------------------------------------------------- | |
| # Data containers | |
| # ----------------------------------------------------------------------------- | |
| @dataclass | |
| class ClassificationResult: | |
| """Top-1 prediction for a single sentence.""" | |
| label: str | |
| score: float | |
| @dataclass | |
| class TestResult: | |
| text: str | |
| candle: ClassificationResult | |
| hf: ClassificationResult | |
| tei: ClassificationResult | |
| candle_time_ms: float | |
| hf_time_ms: float | |
| tei_time_ms: float | |
| # ----------------------------------------------------------------------------- | |
| # Benchmark framework | |
| # ----------------------------------------------------------------------------- | |
| class DeBERTaBatchTester: | |
| """Run correctness & speed benchmarks for Candle / TEI / 🤗 in batches.""" | |
| def __init__( | |
| self, | |
| model_id: str, | |
| candle_binary: str, | |
| tei_url: str, | |
| num_samples: int, | |
| batch_size: int = 8, | |
| ) -> None: | |
| self.model_id = model_id | |
| self.candle_binary = candle_binary | |
| self.tei_url = tei_url | |
| self.num_samples = num_samples | |
| self.batch_size = batch_size | |
| # Hugging Face model | |
| self.hf_model = AutoModelForSequenceClassification.from_pretrained(model_id) | |
| self.hf_tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| self.hf_model.eval() | |
| ds = load_dataset("JasperLS/prompt-injections", split="train") | |
| self.sentences: List[str] = [x for x in [x[:(512 * 3)].strip() for x in ds["text"][:num_samples]] if len(x) >= 2] # truncate too | |
| # --------------------------------------------------------------------- | |
| # Candle (Rust) – multi-sentence via repeated --sentence flags | |
| # --------------------------------------------------------------------- | |
| def _candle_batch(self, sentences: List[str]) -> Tuple[List[ClassificationResult], float]: | |
| cmd = [ | |
| self.candle_binary, | |
| f"--model-id={self.model_id}", | |
| "--revision=main", | |
| "--task=text-classification",] + [f"--sentence={s}" for s in sentences] | |
| out = subprocess.check_output(cmd, text=True) | |
| # Extract Candle-reported time like: "Inferenced inputs in 7.700875ms" | |
| m = re.search(r"Inferenced inputs in\s*([0-9]*\.?[0-9]+)\s*ms", out) | |
| # Extract Candle-reported time like: "Tokenized and loaded inputs in 5.600125ms" | |
| m2 = re.search(r"Tokenized and loaded inputs in\s*([0-9]*\.?[0-9]+)\s*ms", out) | |
| if not m: | |
| raise RuntimeError("Candle inference time line not found in output") | |
| # print(f"{m}, candle time for {len(sentences)}") | |
| candle_time_ms = float(m.group(1)) + float(m2.group(1)) | |
| # Candle prints a pseudo-JSON list on the last line. | |
| last_line = out.strip().split("\n")[-1] | |
| items = last_line.strip("[]").split("},") | |
| results: List[ClassificationResult] = [] | |
| for item in items: | |
| parts = ( | |
| item.replace("TextClassificationItem {", "") | |
| .replace("}", "") | |
| .split(",") | |
| ) | |
| label = parts[0].split(":")[-1].strip().strip('"') | |
| score = float(parts[1].split(":")[-1]) | |
| results.append(ClassificationResult(label, score)) | |
| assert len(results) == len(sentences), "Candle output size mismatch." | |
| return results, candle_time_ms | |
| # --------------------------------------------------------------------- | |
| # Hugging Face – batched tensors | |
| # --------------------------------------------------------------------- | |
| def _hf_batch(self, sentences: List[str]) -> Tuple[List[ClassificationResult], float]: | |
| t0 = time.time() | |
| toks = self.hf_tokenizer( | |
| sentences, return_tensors="pt", padding=True, truncation=False | |
| ) | |
| with torch.no_grad(): | |
| probs = torch.nn.functional.softmax( | |
| self.hf_model(**toks).logits, dim=-1 | |
| ) | |
| elapsed = (time.time() - t0) * 1000 | |
| id2label = self.hf_model.config.id2label | |
| top1 = torch.topk(probs, 1, dim=-1) | |
| results = [ | |
| ClassificationResult(id2label[idx.item()], score.item()) | |
| for score, idx in zip(top1.values.squeeze(1), top1.indices.squeeze(1)) | |
| ] | |
| return results, elapsed | |
| # --------------------------------------------------------------------- | |
| # TEI – POST /predict with {"inputs": [[s1], [s2], ...]} | |
| # --------------------------------------------------------------------- | |
| def _tei_batch(self, sentences: List[str]) -> Tuple[List[ClassificationResult], float]: | |
| payload = {"inputs": [[s] for s in sentences]} | |
| t0 = time.time() | |
| r = requests.post(self.tei_url, json=payload, timeout=30) | |
| r.raise_for_status() | |
| elapsed = (time.time() - t0) * 1000 | |
| data = r.json() # outer list per sentence, inner list per class | |
| results: List[ClassificationResult] = [] | |
| for preds in data: | |
| best = max(preds, key=lambda x: x["score"]) | |
| results.append(ClassificationResult(best["label"], best["score"])) | |
| return results, elapsed | |
| # ------------------------------------------------------------------ | |
| # Main loop | |
| # ------------------------------------------------------------------ | |
| def run(self) -> List[TestResult]: | |
| tests: List[TestResult] = [] | |
| for i in tqdm(range(0, self.num_samples, self.batch_size), desc="Batches"): | |
| batch = self.sentences[i : i + self.batch_size] | |
| if not len(batch): | |
| continue | |
| # candle_res, candle_t = self._candle_batch(batch) | |
| hf_res, hf_t = self._hf_batch(batch) | |
| # TODO: | |
| tei_res, tei_t = self._tei_batch(batch) | |
| candle_res, candle_t = tei_res, tei_t | |
| # Distribute batch time equally – sufficient for averages. | |
| per_candle = candle_t / len(batch) | |
| per_hf = hf_t / len(batch) | |
| per_tei = tei_t / len(batch) | |
| for s, c, h, t in zip(batch, candle_res, hf_res, tei_res): | |
| tests.append( | |
| TestResult( | |
| text=s, | |
| candle=c, | |
| hf=h, | |
| tei=t, | |
| candle_time_ms=per_candle, | |
| hf_time_ms=per_hf, | |
| tei_time_ms=per_tei, | |
| ) | |
| ) | |
| return tests | |
| # -------------------------------------------------------------- | |
| # Aggregated metrics | |
| # -------------------------------------------------------------- | |
| @staticmethod | |
| def to_dataframe(results: List[TestResult]) -> pd.DataFrame: | |
| df = pd.DataFrame( | |
| { | |
| "text": [r.text for r in results], | |
| "candle_label": [r.candle.label for r in results], | |
| "hf_label": [r.hf.label for r in results], | |
| "tei_label": [r.tei.label for r in results], | |
| "candle_score": [r.candle.score for r in results], | |
| "hf_score": [r.hf.score for r in results], | |
| "tei_score": [r.tei.score for r in results], | |
| "candle_ms": [r.candle_time_ms for r in results], | |
| "hf_ms": [r.hf_time_ms for r in results], | |
| "tei_ms": [r.tei_time_ms for r in results], | |
| } | |
| ) | |
| # Pairwise absolute score differences | |
| df["diff_hf_tei"] = (df.hf_score - df.tei_score).abs() | |
| df["diff_hf_candle"] = (df.hf_score - df.candle_score).abs() | |
| df["diff_candle_tei"] = (df.candle_score - df.tei_score).abs() | |
| # Label agreement flag | |
| df["labels_match"] = ( | |
| (df.candle_label == df.hf_label) & (df.hf_label == df.tei_label) | |
| ) | |
| return df | |
| # ----------------------------------------------------------------------------- | |
| # CLI entry point | |
| # ----------------------------------------------------------------------------- | |
| if __name__ == "__main__": | |
| import argparse | |
| p = argparse.ArgumentParser( | |
| description="Benchmark Candle/TEI/HF DeBERTa with batching and similarity stats" | |
| ) | |
| p.add_argument("--model-id", default="meta-llama/Prompt-Guard-86M") | |
| p.add_argument( | |
| "--candle-binary", | |
| default=os.path.expanduser("~/niko/candle/target/release/examples/debertav2"), | |
| help="Path to Candle debertav2 binary", | |
| ) | |
| p.add_argument("--tei-url", default="http://127.0.0.1:8080/predict") | |
| p.add_argument("--num-samples", type=int, default=1000) | |
| p.add_argument("--batch-size", type=int, default=8) | |
| p.add_argument("--out-csv", default="test_results.csv") | |
| args = p.parse_args() | |
| tester = DeBERTaBatchTester( | |
| model_id=args.model_id, | |
| candle_binary=args.candle_binary, | |
| tei_url=args.tei_url, | |
| num_samples=args.num_samples, | |
| batch_size=args.batch_size, | |
| ) | |
| df = DeBERTaBatchTester.to_dataframe(tester.run()) | |
| print("\nAverage latency (ms):") | |
| print(df[["candle_ms", "hf_ms", "tei_ms"]].mean()) | |
| agree_rate = df.labels_match.mean() | |
| print(f"\nLabel agreement: {agree_rate*100:.2f}% ({df.labels_match.sum()}/{len(df)})") | |
| print("\nAbsolute score differences:") | |
| for col in ["diff_hf_tei", "diff_hf_candle", "diff_candle_tei"]: | |
| print( | |
| f"{col}: avg={df[col].mean():.6f}, max={df[col].max():.6f}" | |
| ) | |
| # Persist | |
| df.to_csv(args.out_csv, index=False) | |
| print(f"\nSaved detailed results to {args.out_csv}") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment