ResNet ONNX/PyTorch Embedding Generator - PEP 723 standalone script for image similarity search (1x7 grid with ResNet input visualization)
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "onnxruntime>=1.18.0",
# "numpy>=1.24.0",
# "pillow>=10.0.0",
# "h5py>=3.9.0",
# "torch>=2.0.0",
# "torchvision>=0.15.0",
# "onnxscript>=0.1.0",
# "onnx>=1.14.0",
# ]
# ///
"""
╔════════════════════════════════════════════════════════════════════════════╗
║ ResNet ONNX Embedding Generator with Similarity Search ║
╚════════════════════════════════════════════════════════════════════════════╝
QUICK START:
1. Create folder structure in project root:
mkdir db input output
2. Place database stamps in db/ folder
3. Place query stamps in input/ folder
4. Run script (uses ResNet50 by default):
uv run scripts/generate_embeddings_onnx.py .
5. Or specify different model:
uv run scripts/generate_embeddings_onnx.py . --model resnet18
uv run scripts/generate_embeddings_onnx.py . --model resnet50
uv run scripts/generate_embeddings_onnx.py . --model resnet152
6. View result images in output/ folder
7. Or use pure PyTorch mode (no ONNX conversion needed):
uv run scripts/generate_embeddings_onnx.py . --backend torch
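8. Optional: convert and smoke-test all three models, then exit:
uv run scripts/generate_embeddings_onnx.py --test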
───────────────────────────────────────────────────────────────────────────────
BACKEND MODES:
--backend onnx (default) - Fast ONNX Runtime inference, generates models automatically
--backend torch (alternative) - Pure PyTorch inference on CPU, no model generation needed
───────────────────────────────────────────────────────────────────────────────
WORKFLOW:
1. First run: Generate embeddings from "db/" and cache them (db_embeddings_{model}.h5)
2. Subsequent runs: Load cached "db" embeddings
3. Generate embeddings for all images in "input/"
4. For each input image: Find top 5 similar images from DB
5. Save visual results to "output/" (one image per input image)
───────────────────────────────────────────────────────────────────────────────
FOLDER STRUCTURE:
project_root/
├── scripts/
│ └── generate_embeddings_onnx.py (this file)
├── db/ (database stamps - processed once)
├── input/ (query stamps - processed each run)
├── output/ (visual results auto-created)
├── .models/ (ONNX models auto-generated from PyTorch on first run)
└── db_embeddings_resnet50.h5 (cached DB embeddings; one db_embeddings_{model}.h5 per model)
───────────────────────────────────────────────────────────────────────────────
OUTPUT FORMAT:
For each input image, creates a visual result image:
output/{filename}_results.jpg
Layout (1x7 grid):
┌─────────────────────────────────────────────────────────────────────────────┐
│ Input │ ResNet │ Match1 │ Match2 │ Match3 │ Match4 │ Match5 │
│ │ Input │ Sim │ Sim │ Sim │ Sim │ Sim │
│ │(224x224)│ 0.95 │ 0.92 │ 0.89 │ 0.87 │ 0.84 │
└─────────────────────────────────────────────────────────────────────────────┘
Shows the original input, the preprocessed 224x224 image actually fed to ResNet, and the top 5 matches.
All thumbnails are the same size with aspect ratio preserved, except the ResNet input, which shows the real (possibly distorted) preprocessing.
───────────────────────────────────────────────────────────────────────────────
SUPPORTED MODELS:
- resnet18 (512-dim embeddings)
- resnet50 (2048-dim embeddings) - DEFAULT
- resnet152 (2048-dim embeddings)
───────────────────────────────────────────────────────────────────────────────
REGENERATING DB EMBEDDINGS:
Delete the cache file to force regeneration:
- Windows: del db_embeddings_resnet50.h5
- Mac/Linux: rm db_embeddings_resnet50.h5
- Or delete manually in file explorer
Then run the script again. New embeddings will be generated.
───────────────────────────────────────────────────────────────────────────────
ONNX MODEL MANAGEMENT:
The script automatically generates ONNX models from PyTorch on first run.
They are stored in: .models/
To DELETE ONNX MODELS and regenerate:
- Windows: rmdir /s .models
- Mac/Linux: rm -rf .models
- Or delete manually in file explorer
Next run will automatically regenerate ONNX models (~2-5 min per model).
To DELETE ONLY A SPECIFIC MODEL:
- Windows: del .models\resnet50.onnx
- Mac/Linux: rm .models/resnet50.onnx
───────────────────────────────────────────────────────────────────────────────
SUPPORTED IMAGE FORMATS:
.jpg, .jpeg, .png (case-insensitive)
REQUIREMENTS:
- Python 3.11+
- Dependencies auto-installed by uv
- ~45-230 MB disk per ONNX model (generated locally; PyTorch weights download on first run)
- CPU-only (works on all platforms)
"""
import sys
import argparse
from pathlib import Path
from typing import List, Union, Dict, Tuple
import numpy as np
from PIL import Image, ImageDraw
import onnxruntime as ort
import h5py
import torch
import torch.onnx
from torchvision import models
# ═══════════════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════════════
RESIZE = (224, 224)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# ONNX ResNet models
# Models are automatically generated from PyTorch on first run
MODELS = {
"resnet18": {
"torch_model": models.resnet18,
"embedding_dim": 512,
"file": "resnet18.onnx"
},
"resnet50": {
"torch_model": models.resnet50,
"embedding_dim": 2048,
"file": "resnet50.onnx"
},
"resnet152": {
"torch_model": models.resnet152,
"embedding_dim": 2048,
"file": "resnet152.onnx"
}
}
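# Embedding sizes follow the backbone: resnet18 uses BasicBlock stages ending
# at 512 channels, while resnet50/152 use Bottleneck blocks with 4x channel
# expansion, ending at 2048.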
DEFAULT_MODEL = "resnet50"
# Local paths: this script lives in scripts/, so parent.parent is the project root.
# (.models/ and the embeddings cache always live there, regardless of the project_root argument.)
SCRIPT_DIR = Path(__file__).parent.parent
MODELS_DIR = SCRIPT_DIR / ".models"
# Image layout config
THUMB_SIZE = 150 # Size of each thumbnail image
THUMB_PADDING = 10 # Padding between images
LABEL_HEIGHT = 40 # Space for similarity score label
# ═══════════════════════════════════════════════════════════════════════════
# Model Conversion & Loading
# ═══════════════════════════════════════════════════════════════════════════
def convert_pytorch_to_onnx(model_name: str, model_path: Path) -> None:
"""
Convert PyTorch ResNet model to ONNX format.
Args:
model_name: Name of the model (resnet18, resnet50, resnet152)
model_path: Destination path for the exported ONNX model
"""
if model_name not in MODELS:
print(f"✗ Unknown model: {model_name}")
sys.exit(1)
print(f"⚙️ Generating ONNX model from PyTorch (first time only)...")
print(f" Model: {model_name.upper()}")
print(f" This may take 2-5 minutes...\n")
model_config = MODELS[model_name]
torch_model_fn = model_config["torch_model"]
# Create models directory
model_path.parent.mkdir(parents=True, exist_ok=True)
# Load pretrained PyTorch model
print(f" Downloading PyTorch weights...")
with torch.no_grad():
pytorch_model = torch_model_fn(weights="DEFAULT").eval()
        # ResNet ends with avgpool -> fc; children()[:-1] drops fc and keeps
        # everything up to and including avgpool, leaving a pooled feature extractor
pytorch_model = torch.nn.Sequential(*list(pytorch_model.children())[:-1])
# Create dummy input and export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
print(f" Exporting to ONNX...")
torch.onnx.export(
pytorch_model,
dummy_input,
str(model_path),
input_names=["input"],
output_names=["output"],
opset_version=12,
do_constant_folding=True,
verbose=False
)
print(f"✓ ONNX model saved: {model_path}\n")
def get_or_create_model(model_name: str) -> Path:
"""Get ONNX model path, generating from PyTorch if needed."""
if model_name not in MODELS:
print(f"✗ Unknown model: {model_name}")
print(f"Available models: {', '.join(MODELS.keys())}")
sys.exit(1)
model_config = MODELS[model_name]
model_path = MODELS_DIR / model_config["file"]
if model_path.exists():
print(f"✓ ONNX model found: {model_path.name}")
return model_path
# Convert from PyTorch on first run
convert_pytorch_to_onnx(model_name, model_path)
return model_path
# ═══════════════════════════════════════════════════════════════════════════
# PyTorch Inference
# ═══════════════════════════════════════════════════════════════════════════
class TorchFeatureExtractor:
"""PyTorch-based ResNet feature extractor."""
def __init__(self, model_name: str, device: str = "cpu"):
"""
Initialize PyTorch model as feature extractor.
Args:
model_name: Name of the model (resnet18, resnet50, resnet152)
device: Device to run on ("cpu" or "cuda")
"""
if model_name not in MODELS:
raise ValueError(f"Unknown model: {model_name}")
self.model_name = model_name
self.device = device
self.model_fn = MODELS[model_name]["torch_model"]
# Load pretrained model and remove classification layer
print(f"Loading PyTorch {model_name.upper()} model...")
with torch.no_grad():
pytorch_model = self.model_fn(weights="DEFAULT").to(device).eval()
# Remove classification layer (keep only avgpool output)
self.model = torch.nn.Sequential(*list(pytorch_model.children())[:-1]).to(device)
print(f"✓ PyTorch model loaded on {device.upper()}\n")
def extract_features(self, image_tensor: torch.Tensor) -> List[float]:
"""
Extract features from preprocessed image tensor.
Args:
image_tensor: Image tensor with shape (1, 3, 224, 224)
Returns:
Feature vector as Python list
"""
with torch.no_grad():
image_tensor = image_tensor.to(self.device)
features = self.model(image_tensor)
            # The truncated model already ends in avgpool, so this extra pooling is a
            # harmless no-op safeguard; then drop the 1x1 spatial dims
features = torch.nn.functional.adaptive_avg_pool2d(features, 1)
features = features.flatten(1)
return features.cpu().numpy().flatten().tolist()
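# Usage sketch (the file name is illustrative):
#   extractor = TorchFeatureExtractor("resnet50")
#   emb = extractor.extract_features(preprocess_image_torch("input/query.jpg"))
#   len(emb)  # -> 2048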
# ═══════════════════════════════════════════════════════════════════════════
# Image Preprocessing
# ═══════════════════════════════════════════════════════════════════════════
def preprocess_image(image: Union[Image.Image, str]) -> np.ndarray:
"""
Preprocess image for ONNX ResNet inference (returns numpy array).
Args:
image: PIL Image object or path to image file
Returns:
Preprocessed tensor with shape (1, 3, 224, 224) as numpy array
"""
if isinstance(image, (str, Path)):
image = Image.open(image).convert("RGB")
# Resize to 224x224
image = image.resize(RESIZE)
# Normalize: [0, 255] -> [0, 1]
img = np.array(image, dtype=np.float32) / 255.0
# ImageNet normalization
img = (img - IMAGENET_MEAN) / IMAGENET_STD
# Convert HWC to CHW
img = np.transpose(img, (2, 0, 1))
# Add batch dimension
return np.expand_dims(img, 0)
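# For reference: the returned array has shape (1, 3, 224, 224) and, after
# ImageNet normalization, values in roughly [-2.12, 2.64]
# (e.g. (0 - 0.485) / 0.229 ≈ -2.12 at the low end of the red channel).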
def preprocess_image_torch(image: Union[Image.Image, str]) -> torch.Tensor:
"""
Preprocess image for PyTorch ResNet inference (returns tensor).
Args:
image: PIL Image object or path to image file
Returns:
Preprocessed tensor with shape (1, 3, 224, 224)
"""
if isinstance(image, (str, Path)):
image = Image.open(image).convert("RGB")
# Resize to 224x224
image = image.resize(RESIZE)
# Normalize: [0, 255] -> [0, 1]
img = np.array(image, dtype=np.float32) / 255.0
# ImageNet normalization
img = (img - IMAGENET_MEAN) / IMAGENET_STD
# Convert HWC to CHW
img = np.transpose(img, (2, 0, 1))
# Add batch dimension and convert to tensor
img = np.expand_dims(img, 0)
return torch.from_numpy(img)
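# preprocess_image_torch mirrors preprocess_image exactly, so both backends see
# identical inputs. torch.from_numpy() shares memory with `img`, which is safe
# here because the array is not mutated afterwards.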
# ═══════════════════════════════════════════════════════════════════════════
# Embedding Generation
# ═══════════════════════════════════════════════════════════════════════════
def generate_embedding_onnx(session: ort.InferenceSession, image: Union[Image.Image, str]) -> List[float]:
"""
Generate embedding using ONNX ResNet model.
Args:
session: ONNX Runtime InferenceSession
image: PIL Image object or path to image file
Returns:
Embedding vector as Python list of floats
"""
input_tensor = preprocess_image(image)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
outputs = session.run([output_name], {input_name: input_tensor})
return outputs[0].flatten().tolist()
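# Usage sketch (paths are illustrative):
#   session = ort.InferenceSession(".models/resnet50.onnx", providers=["CPUExecutionProvider"])
#   emb = generate_embedding_onnx(session, "db/stamp_001.jpg")
#   len(emb)  # -> 2048 for resnet50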
def generate_embedding_torch(extractor: TorchFeatureExtractor, image: Union[Image.Image, str]) -> List[float]:
"""
Generate embedding using PyTorch ResNet model.
Args:
extractor: TorchFeatureExtractor instance
image: PIL Image object or path to image file
Returns:
Embedding vector as Python list of floats
"""
input_tensor = preprocess_image_torch(image)
return extractor.extract_features(input_tensor)
# ═══════════════════════════════════════════════════════════════════════════
# File & Folder Operations
# ═══════════════════════════════════════════════════════════════════════════
def collect_images(folder: Path) -> List[Path]:
"""Collect all supported image files from folder (non-recursive)."""
images = []
for ext in ["*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG"]:
images.extend(folder.glob(ext))
return sorted(set(images))
def load_and_resize_image(img_path: Path, size: Tuple[int, int]) -> Image.Image:
"""Load image and resize to specified size."""
img = Image.open(img_path).convert("RGB")
img.thumbnail(size, Image.Resampling.LANCZOS)
    # thumbnail() preserved aspect ratio; center-paste onto a white canvas of the exact target size
new_img = Image.new("RGB", size, (255, 255, 255))
offset = ((size[0] - img.width) // 2, (size[1] - img.height) // 2)
new_img.paste(img, offset)
return new_img
# ═══════════════════════════════════════════════════════════════════════════
# Database Embeddings (Generate & Load)
# ═══════════════════════════════════════════════════════════════════════════
def get_db_cache_path(model_name: str) -> Path:
"""Get cache file path for a specific model."""
return SCRIPT_DIR / f"db_embeddings_{model_name}.h5"
def generate_db_embeddings(inference_fn, db_folder: Path, model_name: str, embedding_dim: int) -> Dict[str, List[float]]:
"""
Generate and cache embeddings for all images in database folder.
Args:
inference_fn: Callable for generating embeddings (generate_embedding_onnx or generate_embedding_torch)
db_folder: Path to database folder
model_name: Name of the model (for cache naming)
embedding_dim: Dimension of embeddings for this model
Returns:
Dictionary mapping filenames to embeddings
"""
print(f"\n📦 Generating DB embeddings from: {db_folder}")
db_images = collect_images(db_folder)
if not db_images:
print("✗ No images found in db folder")
return {}
embeddings = {}
for idx, img_path in enumerate(db_images, 1):
try:
embeddings[img_path.name] = inference_fn(img_path)
if idx % 100 == 0 or idx == len(db_images):
print(f" [{idx}/{len(db_images)}] Processed")
except Exception as e:
print(f" ✗ Failed {img_path.name}: {e}")
# Save to HDF5
db_cache = get_db_cache_path(model_name)
db_cache.parent.mkdir(parents=True, exist_ok=True)
keys = np.array(list(embeddings.keys()), dtype='S')
vectors = np.array(list(embeddings.values()), dtype=np.float32)
with h5py.File(db_cache, "w") as f:
f.create_dataset("keys", data=keys, maxshape=(None,), chunks=True)
f.create_dataset("vectors", data=vectors, maxshape=(None, embedding_dim), chunks=True)
print(f"✓ Saved {len(embeddings)} DB embeddings to: {db_cache}")
print(f" (Delete this file to regenerate DB embeddings)")
return embeddings
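# Cache file layout, for reference:
#   keys    : (N,)   bytes   - image filenames
#   vectors : (N, D) float32 - embeddings, D = embedding_dim
# Both datasets are chunked and resizable (maxshape=None), so appends are possible.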
def load_db_embeddings(model_name: str) -> Dict[str, np.ndarray]:
"""Load cached DB embeddings from HDF5 file."""
db_cache = get_db_cache_path(model_name)
if not db_cache.exists():
return {}
with h5py.File(db_cache, "r") as f:
keys = f["keys"][:].astype(str)
vectors = f["vectors"][:]
return dict(zip(keys, vectors))
# ═══════════════════════════════════════════════════════════════════════════
# Similarity Search
# ═══════════════════════════════════════════════════════════════════════════
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def find_top_matches(input_embedding: List[float], db_embeddings: Dict[str, np.ndarray], top_k: int = 5) -> List[Dict]:
"""
Find top K most similar images from database.
Args:
input_embedding: Embedding vector
db_embeddings: Dictionary of filename -> embedding vector
top_k: Number of top matches to return (default: 5)
Returns:
List of dicts with rank, filename, and similarity score
"""
input_vec = np.array(input_embedding, dtype=np.float32)
similarities = {}
for db_name, db_vec in db_embeddings.items():
similarities[db_name] = cosine_similarity(input_vec, db_vec)
sorted_matches = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:top_k]
return [
{"rank": i + 1, "db_file": name, "similarity": float(score)}
for i, (name, score) in enumerate(sorted_matches)
]
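# For large databases, the per-pair loop above can be replaced by one
# vectorized pass (a sketch; `mat` is a hypothetical (N, D) float32 array
# stacked from db_embeddings.values()):
#   mat_n = mat / np.linalg.norm(mat, axis=1, keepdims=True)
#   q_n = input_vec / np.linalg.norm(input_vec)
#   sims = mat_n @ q_n  # cosine similarity to every DB image at once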
# ═══════════════════════════════════════════════════════════════════════════
# Visualization
# ═══════════════════════════════════════════════════════════════════════════
def create_result_image(input_img_path: Path, matches: List[Dict], db_folder: Path, output_path: Path):
"""
Create visual result image in 1x7 grid (input + preprocessed + 5 matches).
All images same size with aspect ratio preserved.
Layout:
┌────────────────────────────────────────────────────────────────┐
│ Input │ ResNet Input │ M1 │ M2 │ M3 │ M4 │ M5 │
│ │ (224x224) │Sim │Sim │Sim │Sim │Sim │
└────────────────────────────────────────────────────────────────┘
"""
# Load and prepare all images (input + preprocessed + 5 matches)
all_images = []
# Load input image (original with aspect ratio preserved)
input_img = load_and_resize_image(input_img_path, (THUMB_SIZE, THUMB_SIZE))
all_images.append((input_img, "Input"))
# Load preprocessed image (what actually goes into ResNet - 224x224 stretched)
preprocessed_img = Image.open(input_img_path).convert("RGB")
preprocessed_img = preprocessed_img.resize(RESIZE) # Direct resize to 224x224 (may distort)
# Scale up for display
preprocessed_display = preprocessed_img.resize((THUMB_SIZE, THUMB_SIZE), Image.Resampling.NEAREST)
all_images.append((preprocessed_display, "ResNet\nInput"))
# Load match images
for match in matches:
match_path = db_folder / match["db_file"]
if match_path.exists():
match_img = load_and_resize_image(match_path, (THUMB_SIZE, THUMB_SIZE))
else:
# Create placeholder if image not found
match_img = Image.new("RGB", (THUMB_SIZE, THUMB_SIZE), (200, 200, 200))
all_images.append((match_img, f"{match['similarity']:.3f}"))
# Create result image: 7 columns x 1 row + label space
total_width = (7 * THUMB_SIZE) + (8 * THUMB_PADDING)
total_height = THUMB_SIZE + LABEL_HEIGHT + (2 * THUMB_PADDING)
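    # With the defaults above: 7*150 + 8*10 = 1130 px wide, 150 + 40 + 2*10 = 210 px tall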
result = Image.new("RGB", (total_width, total_height), (255, 255, 255))
draw = ImageDraw.Draw(result)
# Paste all images in a row
for col_idx, (img, label) in enumerate(all_images):
x = THUMB_PADDING + (col_idx * (THUMB_SIZE + THUMB_PADDING))
y = THUMB_PADDING
# Paste image
result.paste(img, (x, y))
# Draw label below image
label_text = label if col_idx <= 1 else f"Sim:\n{label}"
bbox = draw.textbbox((0, 0), label_text)
label_width = bbox[2] - bbox[0]
label_height = bbox[3] - bbox[1]
label_x = x + (THUMB_SIZE - label_width) // 2
label_y = y + THUMB_SIZE + 5
        # Draw an opaque light-gray box behind the text for readability
draw.rectangle(
[(label_x - 5, label_y - 5), (label_x + label_width + 5, label_y + label_height + 5)],
fill=(240, 240, 240),
outline=(200, 200, 200)
)
draw.text((label_x, label_y), label_text, fill=(0, 0, 0))
result.save(output_path, quality=95)
# ═══════════════════════════════════════════════════════════════════════════
# Input Processing & Output Generation
# ═══════════════════════════════════════════════════════════════════════════
def process_input_folder(inference_fn, input_folder: Path, output_folder: Path, db_folder: Path, db_embeddings: Dict):
"""
Process all input images and generate visual similarity results.
Args:
inference_fn: Callable for generating embeddings (generate_embedding_onnx or generate_embedding_torch)
input_folder: Path to input folder with query images
output_folder: Path to output folder (auto-created)
db_folder: Path to database folder
db_embeddings: Cached database embeddings
"""
input_images = collect_images(input_folder)
if not input_images:
print("✗ No images found in input folder")
return
output_folder.mkdir(parents=True, exist_ok=True)
print(f"\n🔍 Processing {len(input_images)} input images...")
for idx, img_path in enumerate(input_images, 1):
try:
embedding = inference_fn(img_path)
matches = find_top_matches(embedding, db_embeddings, top_k=5)
# Create visual result
output_file = output_folder / f"{img_path.stem}_results.jpg"
create_result_image(img_path, matches, db_folder, output_file)
top_match = matches[0]["db_file"] if matches else "NO MATCHES"
top_sim = matches[0]["similarity"] if matches else 0.0
print(f" [{idx}/{len(input_images)}] {img_path.name} → {top_match} (sim: {top_sim:.3f})")
except Exception as e:
print(f" ✗ Failed {img_path.name}: {e}")
print(f"\n✓ Results saved to: {output_folder}")
print(f" View JPG images to see top 5 matches for each input image")
# ═══════════════════════════════════════════════════════════════════════════
# Testing
# ═══════════════════════════════════════════════════════════════════════════
def test_all_models() -> bool:
"""
Test all three models by converting and verifying they work.
Returns:
True if all models pass, False otherwise
"""
print("╔════════════════════════════════════════════════════════════════╗")
print("║ Testing All ResNet Models (Convert & Verify) ║")
print("╚════════════════════════════════════════════════════════════════╝\n")
all_passed = True
for model_name in ["resnet18", "resnet50", "resnet152"]:
print(f"\n{'='*60}")
print(f"Testing {model_name.upper()}")
print(f"{'='*60}")
try:
# Get or create model
model_path = get_or_create_model(model_name)
# Verify model file exists and has reasonable size
if not model_path.exists():
print(f"✗ Model file not found: {model_path}")
all_passed = False
continue
file_size_mb = model_path.stat().st_size / (1024 * 1024)
print(f"✓ Model file exists: {file_size_mb:.1f} MB")
# Try to load with ONNX Runtime
print(f" Loading with ONNX Runtime...")
session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
print(f"✓ ONNX Runtime loaded successfully")
# Test inference with dummy input
print(f" Testing inference...")
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
outputs = session.run([output_name], {input_name: dummy_input})
embedding = outputs[0].flatten()
# Verify output shape
expected_dim = MODELS[model_name]["embedding_dim"]
if len(embedding) != expected_dim:
print(f"✗ Wrong embedding dimension: {len(embedding)} (expected {expected_dim})")
all_passed = False
continue
print(f"✓ Inference successful: {len(embedding)}-dim embedding")
print(f"✓ {model_name.upper()} PASSED")
except Exception as e:
print(f"✗ {model_name.upper()} FAILED: {e}")
all_passed = False
print(f"\n{'='*60}")
if all_passed:
print("✓ ALL TESTS PASSED")
else:
print("✗ SOME TESTS FAILED")
print(f"{'='*60}\n")
return all_passed
# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="ResNet ONNX Embedding Generator with Similarity Search",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
uv run scripts/generate_embeddings_onnx.py .
uv run scripts/generate_embeddings_onnx.py . --model resnet18
uv run scripts/generate_embeddings_onnx.py . --model resnet152
uv run scripts/generate_embeddings_onnx.py --test
"""
)
parser.add_argument("project_root", nargs="?", default=".", help="Project root directory (default: current directory)")
parser.add_argument(
"--model",
choices=list(MODELS.keys()),
default=DEFAULT_MODEL,
help=f"Model to use (default: {DEFAULT_MODEL})"
)
parser.add_argument(
"--test",
action="store_true",
help="Test all models (convert from PyTorch and verify)"
)
parser.add_argument(
"--backend",
choices=["onnx", "torch"],
default="onnx",
help="Backend to use: onnx (fast, auto-converts) or torch (pure PyTorch, no conversion needed)"
)
args = parser.parse_args()
# Handle test mode
if args.test:
success = test_all_models()
sys.exit(0 if success else 1)
project_root = Path(args.project_root).resolve()
model_name = args.model
db_folder = project_root / "db"
input_folder = project_root / "input"
output_folder = project_root / "output"
print("╔════════════════════════════════════════════════════════════════╗")
print("║ ResNet ONNX Embedding Generator with Similarity Search ║")
print("╚════════════════════════════════════════════════════════════════╝\n")
# Validate folders
if not db_folder.exists():
print(f"✗ DB folder not found: {db_folder}")
sys.exit(1)
if not input_folder.exists():
print(f"✗ Input folder not found: {input_folder}")
sys.exit(1)
print(f"📂 Project root: {project_root}")
print(f"🔧 Model: {model_name.upper()}")
print(f"🔧 Backend: {args.backend.upper()}")
print(f"📂 Database folder: {db_folder}")
print(f"📂 Input folder: {input_folder}")
print(f"📂 Output folder: {output_folder}\n")
# Get model config
model_config = MODELS[model_name]
embedding_dim = model_config["embedding_dim"]
# Setup backend
if args.backend == "onnx":
# Get or create ONNX model
print(f"🔧 Preparing {model_name.upper()} model...")
model_path = get_or_create_model(model_name)
# Load ONNX session (CPU only)
print("Loading ONNX Runtime...")
session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
print("✓ ONNX Runtime loaded\n")
# Create inference function
inference_fn = lambda img_path: generate_embedding_onnx(session, img_path)
else: # torch backend
# Load PyTorch model
extractor = TorchFeatureExtractor(model_name, device="cpu")
# Create inference function
inference_fn = lambda img_path: generate_embedding_torch(extractor, img_path)
# Load or generate DB embeddings
db_cache = get_db_cache_path(model_name)
if db_cache.exists():
print(f"⚡ Loading cached DB embeddings: {db_cache}")
db_embeddings = load_db_embeddings(model_name)
print(f"✓ Loaded {len(db_embeddings)} cached embeddings\n")
else:
db_embeddings = generate_db_embeddings(inference_fn, db_folder, model_name, embedding_dim)
# Process input folder
if db_embeddings:
process_input_folder(inference_fn, input_folder, output_folder, db_folder, db_embeddings)
else:
print("✗ No DB embeddings available. Cannot process input folder.")
sys.exit(1)
print("\n✓ Done!")
if __name__ == "__main__":
main()