ResNet ONNX/PyTorch Embedding Generator - PEP 723 standalone script for image similarity search (1x7 grid with ResNet input visualization)
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "onnxruntime>=1.18.0",
#     "numpy>=1.24.0",
#     "pillow>=10.0.0",
#     "h5py>=3.9.0",
#     "torch>=2.0.0",
#     "torchvision>=0.15.0",
#     "onnxscript>=0.1.0",
#     "onnx>=1.14.0",
# ]
# ///
| """ | |
| ╔════════════════════════════════════════════════════════════════════════════╗ | |
| ║ ResNet ONNX Embedding Generator with Similarity Search ║ | |
| ╚════════════════════════════════════════════════════════════════════════════╝ | |
| QUICK START: | |
| 1. Create folder structure in project root: | |
| mkdir db input output | |
| 2. Place database stamps in db/ folder | |
| 3. Place query stamps in input/ folder | |
| 4. Run script (uses ResNet50 by default): | |
| uv run scripts/generate_embeddings_onnx.py . | |
| 5. Or specify different model: | |
| uv run scripts/generate_embeddings_onnx.py . --model resnet18 | |
| uv run scripts/generate_embeddings_onnx.py . --model resnet50 | |
| uv run scripts/generate_embeddings_onnx.py . --model resnet152 | |
| 6. View result images in output/ folder | |
| 7. Or use pure PyTorch mode (no ONNX conversion needed): | |
| uv run scripts/generate_embeddings_onnx.py . --backend torch | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| BACKEND MODES: | |
| --backend onnx (default) - Fast ONNX Runtime inference, generates models automatically | |
| --backend torch (alternative) - Pure PyTorch inference on CPU, no model generation needed | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| WORKFLOW: | |
| 1. First run: Generate embeddings from "db/" and cache them (db_embeddings_{model}.h5) | |
| 2. Subsequent runs: Load cached "db" embeddings | |
| 3. Generate embeddings for all images in "input/" | |
| 4. For each input image: Find top 5 similar images from DB | |
| 5. Save visual results to "output/" (one image per input image) | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| FOLDER STRUCTURE: | |
| project_root/ | |
| ├── scripts/ | |
| │ └── generate_embeddings_onnx.py (this file) | |
| ├── db/ (database stamps - processed once) | |
| ├── input/ (query stamps - processed each run) | |
| ├── output/ (visual results auto-created) | |
| ├── .models/ (ONNX models auto-downloaded) | |
| └── db_embeddings_resnet50.h5 (cached DB embeddings for each model) | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| OUTPUT FORMAT: | |
| For each input image, creates a visual result image: | |
| output/{filename}_results.jpg | |
| Layout (1x7 grid): | |
| ┌─────────────────────────────────────────────────────────────────────────────┐ | |
| │ Input │ ResNet │ Match1 │ Match2 │ Match3 │ Match4 │ Match5 │ | |
| │ │ Input │ Sim │ Sim │ Sim │ Sim │ Sim │ | |
| │ │(224x224)│ 0.95 │ 0.92 │ 0.89 │ 0.87 │ 0.84 │ | |
| └─────────────────────────────────────────────────────────────────────────────┘ | |
| Shows original input, preprocessed 224x224 image fed to ResNet, and top 5 matches. | |
| All images same size with aspect ratio preserved (except ResNet input shows actual preprocessing). | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| SUPPORTED MODELS: | |
| - resnet18 (512-dim embeddings) | |
| - resnet50 (2048-dim embeddings) - DEFAULT | |
| - resnet152 (2048-dim embeddings) | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| REGENERATING DB EMBEDDINGS: | |
| Delete the cache file to force regeneration: | |
| - Windows: del db_embeddings_resnet50.h5 | |
| - Mac/Linux: rm db_embeddings_resnet50.h5 | |
| - Or delete manually in file explorer | |
| Then run the script again. New embeddings will be generated. | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| ONNX MODEL MANAGEMENT: | |
| The script automatically generates ONNX models from PyTorch on first run. | |
| They are stored in: .models/ | |
| To DELETE ONNX MODELS and regenerate: | |
| - Windows: rmdir /s .models | |
| - Mac/Linux: rm -rf .models | |
| - Or delete manually in file explorer | |
| Next run will automatically regenerate ONNX models (~2-5 min per model). | |
| To DELETE ONLY A SPECIFIC MODEL: | |
| - Windows: del .models\resnet50.onnx | |
| - Mac/Linux: rm .models/resnet50.onnx | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| SUPPORTED IMAGE FORMATS: | |
| .jpg, .jpeg, .png (case-insensitive) | |
| REQUIREMENTS: | |
| - Python 3.11+ | |
| - Dependencies auto-installed by uv | |
| - ~100MB disk per ONNX model (auto-downloaded) | |
| - CPU-only (works on all platforms) | |
| """ | |
import sys
import argparse
from pathlib import Path
from typing import List, Union, Dict, Tuple

import numpy as np
from PIL import Image, ImageDraw
import onnxruntime as ort
import h5py
import torch
import torch.onnx
from torchvision import models
# ═══════════════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════════════

RESIZE = (224, 224)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# ONNX ResNet models.
# Models are automatically generated from PyTorch on first run.
MODELS = {
    "resnet18": {
        "torch_model": models.resnet18,
        "embedding_dim": 512,
        "file": "resnet18.onnx",
    },
    "resnet50": {
        "torch_model": models.resnet50,
        "embedding_dim": 2048,
        "file": "resnet50.onnx",
    },
    "resnet152": {
        "torch_model": models.resnet152,
        "embedding_dim": 2048,
        "file": "resnet152.onnx",
    },
}
DEFAULT_MODEL = "resnet50"

# Local paths. This script lives in scripts/, so its parent's parent is the
# project root; models and embedding caches are stored there.
SCRIPT_DIR = Path(__file__).parent.parent
MODELS_DIR = SCRIPT_DIR / ".models"

# Image layout config
THUMB_SIZE = 150     # Size of each thumbnail image
THUMB_PADDING = 10   # Padding between images
LABEL_HEIGHT = 40    # Space for similarity score label
# ═══════════════════════════════════════════════════════════════════════════
# Model Conversion & Loading
# ═══════════════════════════════════════════════════════════════════════════

def convert_pytorch_to_onnx(model_name: str, model_path: Path) -> None:
    """
    Convert a pretrained PyTorch ResNet model to ONNX format.

    Args:
        model_name: Name of the model (resnet18, resnet50, resnet152)
        model_path: Path where the ONNX model will be saved
    """
    if model_name not in MODELS:
        print(f"✗ Unknown model: {model_name}")
        sys.exit(1)

    print("⚙️  Generating ONNX model from PyTorch (first time only)...")
    print(f"   Model: {model_name.upper()}")
    print("   This may take 2-5 minutes...\n")

    model_config = MODELS[model_name]
    torch_model_fn = model_config["torch_model"]

    # Create models directory
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Load pretrained PyTorch model
    print("   Downloading PyTorch weights...")
    with torch.no_grad():
        pytorch_model = torch_model_fn(weights="DEFAULT").eval()

    # Remove the classification head to get pooled features:
    # ResNet ends with avgpool -> fc; we keep everything up to avgpool.
    pytorch_model = torch.nn.Sequential(*list(pytorch_model.children())[:-1])

    # Create dummy input and export to ONNX
    dummy_input = torch.randn(1, 3, 224, 224)
    print("   Exporting to ONNX...")
    torch.onnx.export(
        pytorch_model,
        dummy_input,
        str(model_path),
        input_names=["input"],
        output_names=["output"],
        opset_version=12,
        do_constant_folding=True,
        verbose=False,
    )
    print(f"✓ ONNX model saved: {model_path}\n")

def get_or_create_model(model_name: str) -> Path:
    """Get ONNX model path, generating from PyTorch if needed."""
    if model_name not in MODELS:
        print(f"✗ Unknown model: {model_name}")
        print(f"Available models: {', '.join(MODELS.keys())}")
        sys.exit(1)

    model_config = MODELS[model_name]
    model_path = MODELS_DIR / model_config["file"]

    if model_path.exists():
        print(f"✓ ONNX model found: {model_path.name}")
        return model_path

    # Convert from PyTorch on first run
    convert_pytorch_to_onnx(model_name, model_path)
    return model_path
# ═══════════════════════════════════════════════════════════════════════════
# PyTorch Inference
# ═══════════════════════════════════════════════════════════════════════════

class TorchFeatureExtractor:
    """PyTorch-based ResNet feature extractor."""

    def __init__(self, model_name: str, device: str = "cpu"):
        """
        Initialize a PyTorch model as a feature extractor.

        Args:
            model_name: Name of the model (resnet18, resnet50, resnet152)
            device: Device to run on ("cpu" or "cuda")
        """
        if model_name not in MODELS:
            raise ValueError(f"Unknown model: {model_name}")

        self.model_name = model_name
        self.device = device
        self.model_fn = MODELS[model_name]["torch_model"]

        # Load pretrained model and remove the classification layer
        print(f"Loading PyTorch {model_name.upper()} model...")
        with torch.no_grad():
            pytorch_model = self.model_fn(weights="DEFAULT").to(device).eval()
            # Keep everything up to (and including) avgpool; drop the fc head
            self.model = torch.nn.Sequential(*list(pytorch_model.children())[:-1]).to(device)
        print(f"✓ PyTorch model loaded on {device.upper()}\n")

    def extract_features(self, image_tensor: torch.Tensor) -> List[float]:
        """
        Extract features from a preprocessed image tensor.

        Args:
            image_tensor: Image tensor with shape (1, 3, 224, 224)

        Returns:
            Feature vector as a Python list
        """
        with torch.no_grad():
            image_tensor = image_tensor.to(self.device)
            features = self.model(image_tensor)
            # The truncated model already ends in avgpool, so this pooling is
            # a shape-safe no-op that collapses any remaining spatial dims
            features = torch.nn.functional.adaptive_avg_pool2d(features, 1)
            features = features.flatten(1)
        return features.cpu().numpy().flatten().tolist()
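
# Illustrative usage (assumes torchvision weights are available locally or
# downloadable):
#   extractor = TorchFeatureExtractor("resnet50", device="cpu")
#   vec = extractor.extract_features(torch.randn(1, 3, 224, 224))
#   len(vec)  # -> 2048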
# ═══════════════════════════════════════════════════════════════════════════
# Image Preprocessing
# ═══════════════════════════════════════════════════════════════════════════

def preprocess_image(image: Union[Image.Image, str]) -> np.ndarray:
    """
    Preprocess an image for ONNX ResNet inference (returns a numpy array).

    Args:
        image: PIL Image object or path to an image file

    Returns:
        Preprocessed tensor with shape (1, 3, 224, 224) as a numpy array
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")

    # Resize to 224x224
    image = image.resize(RESIZE)

    # Normalize: [0, 255] -> [0, 1]
    img = np.array(image, dtype=np.float32) / 255.0

    # ImageNet normalization
    img = (img - IMAGENET_MEAN) / IMAGENET_STD

    # Convert HWC to CHW
    img = np.transpose(img, (2, 0, 1))

    # Add batch dimension
    return np.expand_dims(img, 0)
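
# Shape walkthrough (illustrative): a 640x480 JPEG becomes
#   (224, 224, 3) after resize -> (3, 224, 224) after transpose
#   -> (1, 3, 224, 224) after expand_dims, float32 throughout.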

def preprocess_image_torch(image: Union[Image.Image, str]) -> torch.Tensor:
    """
    Preprocess an image for PyTorch ResNet inference (returns a tensor).

    Args:
        image: PIL Image object or path to an image file

    Returns:
        Preprocessed tensor with shape (1, 3, 224, 224)
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")

    # Resize to 224x224
    image = image.resize(RESIZE)

    # Normalize: [0, 255] -> [0, 1]
    img = np.array(image, dtype=np.float32) / 255.0

    # ImageNet normalization
    img = (img - IMAGENET_MEAN) / IMAGENET_STD

    # Convert HWC to CHW
    img = np.transpose(img, (2, 0, 1))

    # Add batch dimension and convert to a tensor
    img = np.expand_dims(img, 0)
    return torch.from_numpy(img)
# ═══════════════════════════════════════════════════════════════════════════
# Embedding Generation
# ═══════════════════════════════════════════════════════════════════════════

def generate_embedding_onnx(session: ort.InferenceSession, image: Union[Image.Image, str]) -> List[float]:
    """
    Generate an embedding using an ONNX ResNet model.

    Args:
        session: ONNX Runtime InferenceSession
        image: PIL Image object or path to an image file

    Returns:
        Embedding vector as a Python list of floats
    """
    input_tensor = preprocess_image(image)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    outputs = session.run([output_name], {input_name: input_tensor})
    return outputs[0].flatten().tolist()


def generate_embedding_torch(extractor: TorchFeatureExtractor, image: Union[Image.Image, str]) -> List[float]:
    """
    Generate an embedding using a PyTorch ResNet model.

    Args:
        extractor: TorchFeatureExtractor instance
        image: PIL Image object or path to an image file

    Returns:
        Embedding vector as a Python list of floats
    """
    input_tensor = preprocess_image_torch(image)
    return extractor.extract_features(input_tensor)
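
# Illustrative usage (file names are hypothetical): both backends return
# plain lists of the same dimensionality for a given model.
#   session = ort.InferenceSession(".models/resnet50.onnx", providers=["CPUExecutionProvider"])
#   vec = generate_embedding_onnx(session, "input/stamp_001.jpg")
#   len(vec)  # -> 2048 for resnet50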
# ═══════════════════════════════════════════════════════════════════════════
# File & Folder Operations
# ═══════════════════════════════════════════════════════════════════════════

def collect_images(folder: Path) -> List[Path]:
    """Collect all supported image files from a folder (non-recursive)."""
    images = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG"]:
        images.extend(folder.glob(ext))
    # set() dedupes files matched by both cases on case-insensitive filesystems
    return sorted(set(images))


def load_and_resize_image(img_path: Path, size: Tuple[int, int]) -> Image.Image:
    """Load an image and letterbox it to the given size (aspect ratio preserved)."""
    img = Image.open(img_path).convert("RGB")
    img.thumbnail(size, Image.Resampling.LANCZOS)

    # Paste onto a white canvas of the exact size to avoid stretching
    new_img = Image.new("RGB", size, (255, 255, 255))
    offset = ((size[0] - img.width) // 2, (size[1] - img.height) // 2)
    new_img.paste(img, offset)
    return new_img
# ═══════════════════════════════════════════════════════════════════════════
# Database Embeddings (Generate & Load)
# ═══════════════════════════════════════════════════════════════════════════

def get_db_cache_path(model_name: str) -> Path:
    """Get the cache file path for a specific model."""
    return SCRIPT_DIR / f"db_embeddings_{model_name}.h5"


def generate_db_embeddings(inference_fn, db_folder: Path, model_name: str, embedding_dim: int) -> Dict[str, List[float]]:
    """
    Generate and cache embeddings for all images in the database folder.

    Args:
        inference_fn: Callable mapping an image path to an embedding
                      (wraps generate_embedding_onnx or generate_embedding_torch)
        db_folder: Path to database folder
        model_name: Name of the model (for cache naming)
        embedding_dim: Dimension of embeddings for this model

    Returns:
        Dictionary mapping filenames to embeddings
    """
    print(f"\n📦 Generating DB embeddings from: {db_folder}")
    db_images = collect_images(db_folder)
    if not db_images:
        print("✗ No images found in db folder")
        return {}

    embeddings = {}
    for idx, img_path in enumerate(db_images, 1):
        try:
            embeddings[img_path.name] = inference_fn(img_path)
            if idx % 100 == 0 or idx == len(db_images):
                print(f"  [{idx}/{len(db_images)}] Processed")
        except Exception as e:
            print(f"  ✗ Failed {img_path.name}: {e}")

    # Save to HDF5: filenames as byte strings, embeddings as a float32 matrix
    db_cache = get_db_cache_path(model_name)
    db_cache.parent.mkdir(parents=True, exist_ok=True)
    keys = np.array(list(embeddings.keys()), dtype="S")
    vectors = np.array(list(embeddings.values()), dtype=np.float32)
    with h5py.File(db_cache, "w") as f:
        f.create_dataset("keys", data=keys, maxshape=(None,), chunks=True)
        f.create_dataset("vectors", data=vectors, maxshape=(None, embedding_dim), chunks=True)

    print(f"✓ Saved {len(embeddings)} DB embeddings to: {db_cache}")
    print("  (Delete this file to regenerate DB embeddings)")
    return embeddings


def load_db_embeddings(model_name: str) -> Dict[str, np.ndarray]:
    """Load cached DB embeddings from the HDF5 file."""
    db_cache = get_db_cache_path(model_name)
    if not db_cache.exists():
        return {}

    with h5py.File(db_cache, "r") as f:
        keys = f["keys"][:].astype(str)
        vectors = f["vectors"][:]
    return dict(zip(keys, vectors))
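
# Cache layout (as written by generate_db_embeddings above):
#   keys:    (N,)   byte-string dataset of image filenames
#   vectors: (N, D) float32 dataset of embeddings, D = embedding_dim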
# ═══════════════════════════════════════════════════════════════════════════
# Similarity Search
# ═══════════════════════════════════════════════════════════════════════════

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute the cosine similarity between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
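
# Worked example: cos(a, b) = a·b / (||a|| ||b||). For a = [1, 0] and
# b = [1, 1]: a·b = 1, ||a|| = 1, ||b|| = √2, so similarity ≈ 0.707.
# ResNet avgpool features are non-negative (post-ReLU), so in practice
# the scores here fall in [0, 1].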

def find_top_matches(input_embedding: List[float], db_embeddings: Dict[str, np.ndarray], top_k: int = 5) -> List[Dict]:
    """
    Find the top K most similar images in the database.

    Args:
        input_embedding: Embedding vector
        db_embeddings: Dictionary of filename -> embedding vector
        top_k: Number of top matches to return (default: 5)

    Returns:
        List of dicts with rank, filename, and similarity score
    """
    input_vec = np.array(input_embedding, dtype=np.float32)
    similarities = {}
    for db_name, db_vec in db_embeddings.items():
        similarities[db_name] = cosine_similarity(input_vec, db_vec)

    sorted_matches = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:top_k]
    return [
        {"rank": i + 1, "db_file": name, "similarity": float(score)}
        for i, (name, score) in enumerate(sorted_matches)
    ]
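
# For large databases, the per-item loop above could be replaced by a single
# vectorized pass (a sketch, assuming the same dict layout as db_embeddings):
#   names = list(db_embeddings.keys())
#   mat = np.stack([db_embeddings[n] for n in names]).astype(np.float32)
#   mat /= np.linalg.norm(mat, axis=1, keepdims=True)
#   q = input_vec / np.linalg.norm(input_vec)
#   scores = mat @ q                       # (N,) cosine similarities
#   top = np.argsort(scores)[::-1][:top_k]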
# ═══════════════════════════════════════════════════════════════════════════
# Visualization
# ═══════════════════════════════════════════════════════════════════════════

def create_result_image(input_img_path: Path, matches: List[Dict], db_folder: Path, output_path: Path):
    """
    Create a visual result image as a 1x7 grid (input + preprocessed + 5 matches).
    All thumbnails are the same size with aspect ratio preserved.

    Layout:
    ┌────────────────────────────────────────────────────────────────┐
    │ Input │ ResNet Input │ M1  │ M2  │ M3  │ M4  │ M5  │
    │       │  (224x224)   │ Sim │ Sim │ Sim │ Sim │ Sim │
    └────────────────────────────────────────────────────────────────┘
    """
    # Load and prepare all images (input + preprocessed + 5 matches)
    all_images = []

    # Input image (original, aspect ratio preserved)
    input_img = load_and_resize_image(input_img_path, (THUMB_SIZE, THUMB_SIZE))
    all_images.append((input_img, "Input"))

    # Preprocessed image (what actually goes into ResNet: 224x224, possibly distorted)
    preprocessed_img = Image.open(input_img_path).convert("RGB")
    preprocessed_img = preprocessed_img.resize(RESIZE)  # Direct resize to 224x224 (may distort)
    # Rescale to thumbnail size for display
    preprocessed_display = preprocessed_img.resize((THUMB_SIZE, THUMB_SIZE), Image.Resampling.NEAREST)
    all_images.append((preprocessed_display, "ResNet\nInput"))

    # Match images
    for match in matches:
        match_path = db_folder / match["db_file"]
        if match_path.exists():
            match_img = load_and_resize_image(match_path, (THUMB_SIZE, THUMB_SIZE))
        else:
            # Placeholder if the image is not found
            match_img = Image.new("RGB", (THUMB_SIZE, THUMB_SIZE), (200, 200, 200))
        all_images.append((match_img, f"{match['similarity']:.3f}"))

    # Create result canvas: 7 columns x 1 row plus label space
    total_width = (7 * THUMB_SIZE) + (8 * THUMB_PADDING)
    total_height = THUMB_SIZE + LABEL_HEIGHT + (2 * THUMB_PADDING)
    result = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    draw = ImageDraw.Draw(result)

    # Paste all images in a row
    for col_idx, (img, label) in enumerate(all_images):
        x = THUMB_PADDING + (col_idx * (THUMB_SIZE + THUMB_PADDING))
        y = THUMB_PADDING
        result.paste(img, (x, y))

        # Draw the label below the image, centered
        label_text = label if col_idx <= 1 else f"Sim:\n{label}"
        bbox = draw.textbbox((0, 0), label_text)
        label_width = bbox[2] - bbox[0]
        label_height = bbox[3] - bbox[1]
        label_x = x + (THUMB_SIZE - label_width) // 2
        label_y = y + THUMB_SIZE + 5

        # Light background box behind the text
        draw.rectangle(
            [(label_x - 5, label_y - 5), (label_x + label_width + 5, label_y + label_height + 5)],
            fill=(240, 240, 240),
            outline=(200, 200, 200),
        )
        draw.text((label_x, label_y), label_text, fill=(0, 0, 0))

    result.save(output_path, quality=95)
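
# With the default constants this yields a 1130x210 canvas:
#   width  = 7*150 + 8*10   = 1130 px
#   height = 150 + 40 + 2*10 = 210 px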
# ═══════════════════════════════════════════════════════════════════════════
# Input Processing & Output Generation
# ═══════════════════════════════════════════════════════════════════════════

def process_input_folder(inference_fn, input_folder: Path, output_folder: Path, db_folder: Path, db_embeddings: Dict):
    """
    Process all input images and generate visual similarity results.

    Args:
        inference_fn: Callable for generating embeddings (generate_embedding_onnx or generate_embedding_torch)
        input_folder: Path to input folder with query images
        output_folder: Path to output folder (auto-created)
        db_folder: Path to database folder
        db_embeddings: Cached database embeddings
    """
    input_images = collect_images(input_folder)
    if not input_images:
        print("✗ No images found in input folder")
        return

    output_folder.mkdir(parents=True, exist_ok=True)
    print(f"\n🔍 Processing {len(input_images)} input images...")

    for idx, img_path in enumerate(input_images, 1):
        try:
            embedding = inference_fn(img_path)
            matches = find_top_matches(embedding, db_embeddings, top_k=5)

            # Create the visual result
            output_file = output_folder / f"{img_path.stem}_results.jpg"
            create_result_image(img_path, matches, db_folder, output_file)

            top_match = matches[0]["db_file"] if matches else "NO MATCHES"
            top_sim = matches[0]["similarity"] if matches else 0.0
            print(f"  [{idx}/{len(input_images)}] {img_path.name} → {top_match} (sim: {top_sim:.3f})")
        except Exception as e:
            print(f"  ✗ Failed {img_path.name}: {e}")

    print(f"\n✓ Results saved to: {output_folder}")
    print("  View the JPG images to see the top 5 matches for each input image")
# ═══════════════════════════════════════════════════════════════════════════
# Testing
# ═══════════════════════════════════════════════════════════════════════════

def test_all_models() -> bool:
    """
    Test all three models by converting them and verifying that they work.

    Returns:
        True if all models pass, False otherwise
    """
    print("╔════════════════════════════════════════════════════════════════╗")
    print("║          Testing All ResNet Models (Convert & Verify)          ║")
    print("╚════════════════════════════════════════════════════════════════╝\n")

    all_passed = True
    for model_name in ["resnet18", "resnet50", "resnet152"]:
        print(f"\n{'=' * 60}")
        print(f"Testing {model_name.upper()}")
        print(f"{'=' * 60}")
        try:
            # Get or create the model
            model_path = get_or_create_model(model_name)

            # Verify the model file exists and has a reasonable size
            if not model_path.exists():
                print(f"✗ Model file not found: {model_path}")
                all_passed = False
                continue
            file_size_mb = model_path.stat().st_size / (1024 * 1024)
            print(f"✓ Model file exists: {file_size_mb:.1f} MB")

            # Try to load it with ONNX Runtime
            print("  Loading with ONNX Runtime...")
            session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
            print("✓ ONNX Runtime loaded successfully")

            # Test inference with a dummy input
            print("  Testing inference...")
            dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
            input_name = session.get_inputs()[0].name
            output_name = session.get_outputs()[0].name
            outputs = session.run([output_name], {input_name: dummy_input})
            embedding = outputs[0].flatten()

            # Verify the output shape
            expected_dim = MODELS[model_name]["embedding_dim"]
            if len(embedding) != expected_dim:
                print(f"✗ Wrong embedding dimension: {len(embedding)} (expected {expected_dim})")
                all_passed = False
                continue
            print(f"✓ Inference successful: {len(embedding)}-dim embedding")
            print(f"✓ {model_name.upper()} PASSED")
        except Exception as e:
            print(f"✗ {model_name.upper()} FAILED: {e}")
            all_passed = False

    print(f"\n{'=' * 60}")
    if all_passed:
        print("✓ ALL TESTS PASSED")
    else:
        print("✗ SOME TESTS FAILED")
    print(f"{'=' * 60}\n")
    return all_passed
# ═══════════════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════════════

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="ResNet ONNX Embedding Generator with Similarity Search",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  uv run scripts/generate_embeddings_onnx.py .
  uv run scripts/generate_embeddings_onnx.py . --model resnet18
  uv run scripts/generate_embeddings_onnx.py . --model resnet152
  uv run scripts/generate_embeddings_onnx.py --test
        """,
    )
    parser.add_argument("project_root", nargs="?", default=".", help="Project root directory (default: current directory)")
    parser.add_argument(
        "--model",
        choices=list(MODELS.keys()),
        default=DEFAULT_MODEL,
        help=f"Model to use (default: {DEFAULT_MODEL})",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Test all models (convert from PyTorch and verify)",
    )
    parser.add_argument(
        "--backend",
        choices=["onnx", "torch"],
        default="onnx",
        help="Backend to use: onnx (fast, auto-converts) or torch (pure PyTorch, no conversion needed)",
    )
    args = parser.parse_args()

    # Handle test mode
    if args.test:
        success = test_all_models()
        sys.exit(0 if success else 1)

    project_root = Path(args.project_root).resolve()
    model_name = args.model
    db_folder = project_root / "db"
    input_folder = project_root / "input"
    output_folder = project_root / "output"

    print("╔════════════════════════════════════════════════════════════════╗")
    print("║    ResNet ONNX Embedding Generator with Similarity Search     ║")
    print("╚════════════════════════════════════════════════════════════════╝\n")

    # Validate folders
    if not db_folder.exists():
        print(f"✗ DB folder not found: {db_folder}")
        sys.exit(1)
    if not input_folder.exists():
        print(f"✗ Input folder not found: {input_folder}")
        sys.exit(1)

    print(f"📂 Project root: {project_root}")
    print(f"🔧 Model: {model_name.upper()}")
    print(f"🔧 Backend: {args.backend.upper()}")
    print(f"📂 Database folder: {db_folder}")
    print(f"📂 Input folder: {input_folder}")
    print(f"📂 Output folder: {output_folder}\n")

    # Get model config
    model_config = MODELS[model_name]
    embedding_dim = model_config["embedding_dim"]

    # Set up the backend; inference_fn maps an image path to an embedding
    if args.backend == "onnx":
        # Get or create the ONNX model
        print(f"🔧 Preparing {model_name.upper()} model...")
        model_path = get_or_create_model(model_name)

        # Load the ONNX session (CPU only)
        print("Loading ONNX Runtime...")
        session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
        print("✓ ONNX Runtime loaded\n")
        inference_fn = lambda img_path: generate_embedding_onnx(session, img_path)
    else:  # torch backend
        extractor = TorchFeatureExtractor(model_name, device="cpu")
        inference_fn = lambda img_path: generate_embedding_torch(extractor, img_path)

    # Load or generate DB embeddings
    db_cache = get_db_cache_path(model_name)
    if db_cache.exists():
        print(f"⚡ Loading cached DB embeddings: {db_cache}")
        db_embeddings = load_db_embeddings(model_name)
        print(f"✓ Loaded {len(db_embeddings)} cached embeddings\n")
    else:
        db_embeddings = generate_db_embeddings(inference_fn, db_folder, model_name, embedding_dim)

    # Process the input folder
    if db_embeddings:
        process_input_folder(inference_fn, input_folder, output_folder, db_folder, db_embeddings)
    else:
        print("✗ No DB embeddings available. Cannot process input folder.")
        sys.exit(1)

    print("\n✓ Done!")


if __name__ == "__main__":
    main()