Created
January 15, 2025 04:26
-
-
Save unLomTrois/87ac433448d78e1dcdacd723c08f15be to your computer and use it in GitHub Desktop.
Wrapper for onnxruntime deepghs/wd14_tagger_with_embeddings (CUDA)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import pandas as pd | |
| import onnxruntime as rt | |
| import torch | |
| import huggingface_hub | |
| from PIL import Image | |
| # ------------------------------------------------------------------- | |
| # Configuration Constants | |
| # ------------------------------------------------------------------- | |
# Hugging Face Hub repository that hosts the tagger artifacts.
HUB_REPO = "deepghs/wd14_tagger_with_embeddings"

# Subfolder of the repository to download from. The available subfolders are
# listed at https://huggingface.co/deepghs/wd14_tagger_with_embeddings
MODEL_SUBFOLDER = "SmilingWolf/wd-swinv2-tagger-v3"

# Names of the ONNX model file and the label CSV file to download.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "tags_info.csv"
| # ------------------------------------------------------------------- | |
| # Helper Functions | |
| # ------------------------------------------------------------------- | |
| def rcut_threshold( | |
| probs: np.ndarray, | |
| k: int | None = None, | |
| p: float | None = None | |
| ) -> float: | |
| """ | |
| Calculate an RCut threshold for a probability array. | |
| Exactly one of ``k`` or ``p`` must be specified: | |
| - k: Use top-k method. | |
| - p: Use probability threshold method. | |
| Parameters | |
| ---------- | |
| probs : np.ndarray | |
| 1D array of probabilities. | |
| k : int | None, optional | |
| Number of top items to retain. | |
| p : float | None, optional | |
| Probability threshold. | |
| Returns | |
| ------- | |
| float | |
| The threshold value derived from ``k`` or ``p``. | |
| Raises | |
| ------ | |
| ValueError | |
| If both ``k`` and ``p`` are missing or both are provided, | |
| or if values are out of valid ranges. | |
| """ | |
| if (k is None and p is None) or (k is not None and p is not None): | |
| raise ValueError("Exactly one of 'k' or 'p' must be provided.") | |
| # Sort descending | |
| sorted_probs = np.sort(probs)[::-1] | |
| n = len(sorted_probs) | |
| if k is not None: | |
| if not (1 <= k <= n): | |
| raise ValueError(f"'k' must be between 1 and {n}.") | |
| return sorted_probs[k - 1] | |
| # p is not None here | |
| if not (0 <= p <= 1): | |
| raise ValueError("'p' must be in [0, 1].") | |
| return p | |
def load_labels(label_csv: str) -> tuple[list[str], list[int], list[int], list[int]]:
    """
    Read, clean, and extract label indices from a CSV.

    The CSV must contain a ``name`` column and a ``category`` column.

    Parameters
    ----------
    label_csv : str
        Path to the 'tags_info.csv' file.

    Returns
    -------
    tuple[list[str], list[int], list[int], list[int]]
        A tuple of:

        - list of label strings (underscores replaced with spaces)
        - indexes where 'category' == 9
        - indexes where 'category' == 0
        - indexes where 'category' == 4

        Indexes are row positions in the CSV, returned as plain Python ints.
    """
    df = pd.read_csv(label_csv)

    # Literal replacement; regex=False makes this independent of the
    # pandas-version-specific default for the ``regex`` argument.
    tag_names = df["name"].str.replace("_", " ", regex=False).tolist()

    categories = df["category"].to_numpy()

    def positions_of(category: int) -> list[int]:
        # .tolist() converts np.int64 to Python int, matching the annotation.
        return np.flatnonzero(categories == category).tolist()

    rating_indexes = positions_of(9)
    general_indexes = positions_of(0)
    character_indexes = positions_of(4)
    return tag_names, rating_indexes, general_indexes, character_indexes
def download_artifacts(
    repo_id: str,
    model_file: str,
    label_file: str,
    subfolder: str | None = None
) -> tuple[str, str]:
    """
    Fetch the label CSV and the model file from the Hugging Face Hub.

    Parameters
    ----------
    repo_id : str
        The repository ID on Hugging Face.
    model_file : str
        Name of the model file (e.g., model.onnx).
    label_file : str
        Name of the label CSV file (e.g., tags_info.csv).
    subfolder : str | None, optional
        Subfolder name on HF Hub, if any.

    Returns
    -------
    tuple[str, str]
        ``(label_csv_path, model_path)`` -- note that the label path comes
        first, which is the reverse of the parameter order.
    """
    def fetch(filename: str) -> str:
        # Both artifacts come from the same repository and subfolder.
        return huggingface_hub.hf_hub_download(
            repo_id,
            filename,
            subfolder=subfolder
        )

    # The label file is requested first, then the model.
    return fetch(label_file), fetch(model_file)
| # ------------------------------------------------------------------- | |
| # Main Class | |
| # ------------------------------------------------------------------- | |
class WD14Tagger:
    """
    A tagger class that uses an ONNX model to:

    - Load label info from a CSV
    - Prepare images (composite on white, pad, resize, RGB -> BGR)
    - Predict embeddings
    - Predict tags using an RCut threshold

    No external code depends on this class yet, so the API is very open
    to changes and refactoring.
    """

    def __init__(
        self,
        repo_id: str = HUB_REPO,
        model_file: str = MODEL_FILENAME,
        label_file: str = LABEL_FILENAME,
        subfolder: str = MODEL_SUBFOLDER
    ) -> None:
        """
        Initialize the WD14Tagger by downloading artifacts and creating the ONNX session.

        Parameters
        ----------
        repo_id : str, optional
            Hugging Face repository ID (default is the global constant HUB_REPO).
        model_file : str, optional
            Model filename on HF Hub (default is MODEL_FILENAME).
        label_file : str, optional
            Label CSV filename on HF Hub (default is LABEL_FILENAME).
        subfolder : str, optional
            Subfolder on HF Hub (default is MODEL_SUBFOLDER).
        """
        label_csv, model_path = download_artifacts(
            repo_id=repo_id,
            model_file=model_file,
            label_file=label_file,
            subfolder=subfolder
        )

        # Load labels
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes
        ) = load_labels(label_csv)

        # Create the inference session
        self.model: rt.InferenceSession = self.create_inference_session(model_path)

        # Store the input size needed by the model
        self.model_target_size: int = self._extract_model_target_size()
        print(f"Model loaded. Input size: {self.model_target_size}x{self.model_target_size}")

    @staticmethod
    def prepare_providers() -> list:
        """
        Return a list of providers for ONNXRuntime to use, preferring
        CUDA if available, then falling back to CPU.

        When torch can also see a CUDA device, the CUDA provider is bound to
        torch's current device and compute stream. When ONNXRuntime offers
        CUDA but torch does not (e.g. a CPU-only torch build), the CUDA
        provider is requested with its default options instead of failing
        inside ``torch.cuda``.

        Returns
        -------
        list
            List of providers for ONNXRuntime.
        """
        providers: list = []
        if "CUDAExecutionProvider" in rt.get_available_providers():
            if torch.cuda.is_available():
                cuda_options = {
                    "device_id": torch.cuda.current_device(),
                    "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)
                }
                providers.append(("CUDAExecutionProvider", cuda_options))
            else:
                # torch.cuda.current_device() would raise here; let
                # ONNXRuntime pick its own defaults.
                providers.append("CUDAExecutionProvider")
        providers.append("CPUExecutionProvider")
        return providers

    @classmethod
    def create_inference_session(cls, model_path: str) -> rt.InferenceSession:
        """
        Create a new inference session for the ONNX model.

        Parameters
        ----------
        model_path : str
            File path to the ONNX model.

        Returns
        -------
        onnxruntime.InferenceSession
            The ONNX inference session object.
        """
        session_options = rt.SessionOptions()
        # ONNXRuntime severity levels: 0=verbose, 1=info, 2=warning (default),
        # 3=error, 4=fatal. Use 3 so only errors are logged; 0 would be the
        # *most* verbose setting.
        session_options.log_severity_level = 3
        providers = cls.prepare_providers()
        return rt.InferenceSession(
            model_path,
            sess_options=session_options,
            providers=providers
        )

    def _extract_model_target_size(self) -> int:
        """
        Extract the target (input) size for images from the ONNX model.

        Assumes shape format: (batch, height, width, channels).

        Returns
        -------
        int
            The height (and width) that the model expects for input images.
        """
        input_meta = self.model.get_inputs()[0]
        # Typically shape is (None, height, width, 3)
        _, height, _, _ = input_meta.shape
        return height

    def prepare_image(self, pil_img: Image.Image) -> np.ndarray:
        """
        Convert an image to a model-friendly format:

        1. Convert to RGBA and composite onto a white background -> RGB
        2. Square padding (white background)
        3. Resize to model target size
        4. Convert RGB -> BGR
        5. Return as float32 np array with shape (1, H, W, 3)

        Parameters
        ----------
        pil_img : Image.Image
            The input image, in any mode Pillow can convert to RGBA.

        Returns
        -------
        np.ndarray
            A 4D float32 array of shape (1, H, W, 3) in BGR format.
        """
        # alpha_composite only accepts RGBA sources, so normalise the mode
        # first; otherwise RGB / L / P inputs raise ValueError.
        rgba_img = pil_img.convert("RGBA")
        rgba_canvas = Image.new("RGBA", rgba_img.size, (255, 255, 255))
        rgba_canvas.alpha_composite(rgba_img)
        rgb_img = rgba_canvas.convert("RGB")

        # Pad to a square, centring the image on a white canvas
        width, height = rgb_img.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded.paste(rgb_img, (pad_left, pad_top))

        # Resize if necessary
        if max_dim != self.model_target_size:
            padded = padded.resize(
                (self.model_target_size, self.model_target_size),
                Image.BICUBIC
            )

        # Reverse the channel axis: RGB -> BGR
        arr: np.ndarray = np.asarray(padded, dtype=np.float32)[:, :, ::-1]
        # Add the leading batch dimension
        return np.expand_dims(arr, axis=0)

    def predict_probabilities_and_embeddings(
        self,
        pil_img: Image.Image
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Forward pass through the ONNX model. Returns:

        - probabilities (float32 array of shape [#tags])
        - embeddings (float32 array of shape [embedding_dim])

        The model is expected to have exactly two outputs:
        [probs_output, embeddings_output].

        Parameters
        ----------
        pil_img : Image.Image
            The input image.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            A tuple of (probabilities_array, embeddings_array).

        Raises
        ------
        RuntimeError
            If the ONNX session has not been initialized.
        """
        if self.model is None:
            raise RuntimeError("ONNX model session is not initialized.")

        # Prepare input
        inp = self.prepare_image(pil_img)
        input_name = self.model.get_inputs()[0].name
        probs_name, emb_name = [o.name for o in self.model.get_outputs()]

        # Model outputs = [ [probabilities], [embeddings] ]; index 0 strips
        # the batch dimension of the single-image batch.
        probs_out, emb_out = self.model.run([probs_name, emb_name], {input_name: inp})
        return probs_out[0], emb_out[0]

    def _select_tags(
        self,
        probs: np.ndarray,
        k: int | None = None,
        p: float | None = None
    ) -> list[tuple[str, float]]:
        """
        Keep the tags whose probability meets the RCut threshold.

        Shared by ``predict_tags`` and ``predict_tags_and_embeddings`` so the
        two entry points cannot drift apart.

        Parameters
        ----------
        probs : np.ndarray
            Per-tag probabilities, aligned with ``self.tag_names``.
        k : int | None, optional
            Use top-k threshold if set.
        p : float | None, optional
            Use probability threshold if set.

        Returns
        -------
        list[tuple[str, float]]
            (tag_name, probability) pairs with probability >= threshold,
            sorted by descending probability. With ``k``, ties at the
            threshold value are all kept, so more than ``k`` tags may be
            returned.

        Raises
        ------
        ValueError
            Propagated from ``rcut_threshold`` for invalid ``k`` / ``p``.
        """
        threshold = rcut_threshold(probs, k=k, p=p)
        selected: list[tuple[str, float]] = [
            (name, float(prob))
            for name, prob in zip(self.tag_names, probs)
            if prob >= threshold
        ]
        # Stable sort: equally probable tags keep their label-file order.
        selected.sort(key=lambda pair: pair[1], reverse=True)
        return selected

    def predict_tags(
        self,
        pil_img: Image.Image,
        k: int | None = None,
        p: float | None = None
    ) -> list[tuple[str, float]]:
        """
        Predict (tag_name, probability) pairs for each tag, applying an RCut threshold.

        Exactly one of `k` (top-k) or `p` (probability threshold) must be specified.
        The tags are sorted by descending probability.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.
        k : int | None, optional
            Use top-k threshold if set.
        p : float | None, optional
            Use probability threshold if set.

        Returns
        -------
        list[tuple[str, float]]
            A list of (tag_name, probability) tuples for tags
            meeting the threshold, sorted by descending probability.
        """
        probs, _ = self.predict_probabilities_and_embeddings(pil_img)
        return self._select_tags(probs, k=k, p=p)

    def predict_embeddings(self, pil_img: Image.Image) -> np.ndarray:
        """
        Return just the embedding vector for the given image.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.

        Returns
        -------
        np.ndarray
            The embedding vector extracted from the model.
        """
        _, emb = self.predict_probabilities_and_embeddings(pil_img)
        return emb

    def predict_tags_and_embeddings(
        self,
        pil_img: Image.Image,
        k: int | None = None,
        p: float | None = None
    ) -> tuple[list[tuple[str, float]], np.ndarray]:
        """
        Predict tags and embeddings in a single call (one forward pass).

        Parameters
        ----------
        pil_img : Image.Image
            The input image.
        k : int | None, optional
            Use top-k threshold if set.
        p : float | None, optional
            Use probability threshold if set.

        Returns
        -------
        tuple[list[tuple[str, float]], np.ndarray]
            A tuple of:

            - List of (tag_name, probability) for tags meeting the threshold,
              sorted by descending probability
            - The embedding vector
        """
        probs, emb = self.predict_probabilities_and_embeddings(pil_img)
        return self._select_tags(probs, k=k, p=p), emb
# Build the tagger with its default settings; this downloads the artifacts
# and creates the ONNX session.
predictor = WD14Tagger()

# Report whether the session is using the CUDA execution provider.
providers = predictor.model.get_providers()
print(
    "Using CUDAExecutionProvider"
    if "CUDAExecutionProvider" in providers
    else "Not using CUDAExecutionProvider"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment