Skip to content

Instantly share code, notes, and snippets.

@unLomTrois
Created January 15, 2025 04:26
Show Gist options
  • Select an option

  • Save unLomTrois/87ac433448d78e1dcdacd723c08f15be to your computer and use it in GitHub Desktop.

Select an option

Save unLomTrois/87ac433448d78e1dcdacd723c08f15be to your computer and use it in GitHub Desktop.
Wrapper for onnxruntime deepghs/wd14_tagger_with_embeddings (CUDA)
import numpy as np
import pandas as pd
import onnxruntime as rt
import torch
import huggingface_hub
from PIL import Image
# -------------------------------------------------------------------
# Configuration Constants
# -------------------------------------------------------------------
# Hugging Face Hub repository hosting the ONNX taggers used below.
HUB_REPO = "deepghs/wd14_tagger_with_embeddings"
# File names looked up inside MODEL_SUBFOLDER of HUB_REPO.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "tags_info.csv"
# Which tagger variant to load. Other available subfolders are listed at
# https://huggingface.co/deepghs/wd14_tagger_with_embeddings
MODEL_SUBFOLDER = "SmilingWolf/wd-swinv2-tagger-v3"
# -------------------------------------------------------------------
# Helper Functions
# -------------------------------------------------------------------
def rcut_threshold(
probs: np.ndarray,
k: int | None = None,
p: float | None = None
) -> float:
"""
Calculate an RCut threshold for a probability array.
Exactly one of ``k`` or ``p`` must be specified:
- k: Use top-k method.
- p: Use probability threshold method.
Parameters
----------
probs : np.ndarray
1D array of probabilities.
k : int | None, optional
Number of top items to retain.
p : float | None, optional
Probability threshold.
Returns
-------
float
The threshold value derived from ``k`` or ``p``.
Raises
------
ValueError
If both ``k`` and ``p`` are missing or both are provided,
or if values are out of valid ranges.
"""
if (k is None and p is None) or (k is not None and p is not None):
raise ValueError("Exactly one of 'k' or 'p' must be provided.")
# Sort descending
sorted_probs = np.sort(probs)[::-1]
n = len(sorted_probs)
if k is not None:
if not (1 <= k <= n):
raise ValueError(f"'k' must be between 1 and {n}.")
return sorted_probs[k - 1]
# p is not None here
if not (0 <= p <= 1):
raise ValueError("'p' must be in [0, 1].")
return p
def load_labels(label_csv: str) -> tuple[list[str], list[int], list[int], list[int]]:
    """
    Read, clean, and extract label indices from a CSV.

    Parameters
    ----------
    label_csv : str
        Path to the 'tags_info.csv' file. Must contain at least the
        ``name`` and ``category`` columns.

    Returns
    -------
    tuple[list[str], list[int], list[int], list[int]]
        A tuple of:
        - list of label strings (underscores replaced with spaces)
        - row positions (plain ints) where 'category' == 9 (rating tags)
        - row positions (plain ints) where 'category' == 0 (general tags)
        - row positions (plain ints) where 'category' == 4 (character tags)
    """
    df = pd.read_csv(label_csv)
    # regex=False: treat "_" as a literal on every pandas version.
    tag_names = df["name"].str.replace("_", " ", regex=False).tolist()
    categories = df["category"].to_numpy()

    def _rows_with(category: int) -> list[int]:
        # Positional rows of one category; .tolist() yields Python ints
        # rather than np.int64, matching the annotated return type.
        return np.flatnonzero(categories == category).tolist()

    rating_indexes = _rows_with(9)
    general_indexes = _rows_with(0)
    character_indexes = _rows_with(4)
    return tag_names, rating_indexes, general_indexes, character_indexes
def download_artifacts(
    repo_id: str,
    model_file: str,
    label_file: str,
    subfolder: str | None = None,
) -> tuple[str, str]:
    """
    Fetch the label CSV and the model file from the Hugging Face Hub.

    Parameters
    ----------
    repo_id : str
        Hugging Face repository ID.
    model_file : str
        Model filename inside the repository (e.g. model.onnx).
    label_file : str
        Label CSV filename inside the repository (e.g. tags_info.csv).
    subfolder : str | None, optional
        Subfolder within the repository that holds both files, if any.

    Returns
    -------
    tuple[str, str]
        ``(label_csv_path, model_path)``: the local paths returned by
        ``huggingface_hub.hf_hub_download``, label CSV first.
    """

    def _fetch(filename: str) -> str:
        # Both files share repo and subfolder; only the file name differs.
        return huggingface_hub.hf_hub_download(
            repo_id,
            filename,
            subfolder=subfolder,
        )

    # Tuple items evaluate left to right: the label CSV is fetched first.
    return _fetch(label_file), _fetch(model_file)
# -------------------------------------------------------------------
# Main Class
# -------------------------------------------------------------------
class WD14Tagger:
    """
    Tag images and extract embeddings with a WD14 tagger exported to ONNX.

    The tagger:
    - Downloads the model and label CSV from the Hugging Face Hub
    - Loads label info from the CSV
    - Prepares images (alpha onto white, pad to square, resize, RGB -> BGR)
    - Predicts embeddings
    - Predicts tags using an RCut threshold

    No external code depends on this class yet, so the API is very open
    to changes and refactoring.
    """

    def __init__(
        self,
        repo_id: str = HUB_REPO,
        model_file: str = MODEL_FILENAME,
        label_file: str = LABEL_FILENAME,
        subfolder: str = MODEL_SUBFOLDER,
    ) -> None:
        """
        Initialize the WD14Tagger by downloading artifacts and creating the ONNX session.

        Parameters
        ----------
        repo_id : str, optional
            Hugging Face repository ID (default is the global constant HUB_REPO).
        model_file : str, optional
            Model filename on HF Hub (default is MODEL_FILENAME).
        label_file : str, optional
            Label CSV filename on HF Hub (default is LABEL_FILENAME).
        subfolder : str, optional
            Subfolder on HF Hub (default is MODEL_SUBFOLDER).
        """
        label_csv, model_path = download_artifacts(
            repo_id=repo_id,
            model_file=model_file,
            label_file=label_file,
            subfolder=subfolder,
        )
        # Tag names plus the row positions of each tag category in the CSV.
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(label_csv)
        self.model: rt.InferenceSession = self.create_inference_session(model_path)
        # Side length of the square image the model expects.
        self.model_target_size: int = self._extract_model_target_size()
        print(f"Model loaded. Input size: {self.model_target_size}x{self.model_target_size}")

    @staticmethod
    def prepare_providers() -> list:
        """
        Return the execution providers for ONNX Runtime, CUDA first when available.

        If ONNX Runtime was built with CUDA and torch can see a CUDA device,
        the CUDA provider is configured with torch's current device and CUDA
        stream. If ONNX Runtime offers CUDA but torch cannot see a device,
        the CUDA provider is requested with default options. The CPU provider
        is always appended as the final fallback.

        Returns
        -------
        list
            Provider names or ``(name, options)`` tuples for ``InferenceSession``.
        """
        providers: list = []
        if "CUDAExecutionProvider" in rt.get_available_providers():
            if torch.cuda.is_available():
                # Reuse torch's current device and CUDA stream for ORT.
                providers.append(
                    (
                        "CUDAExecutionProvider",
                        {
                            "device_id": torch.cuda.current_device(),
                            "user_compute_stream": str(
                                torch.cuda.current_stream().cuda_stream
                            ),
                        },
                    )
                )
            else:
                # torch.cuda.current_device() raises when torch has no usable
                # CUDA (e.g. a CPU-only torch build), so let ORT configure
                # CUDA on its own instead of crashing here.
                providers.append("CUDAExecutionProvider")
        providers.append("CPUExecutionProvider")
        return providers

    @classmethod
    def create_inference_session(cls, model_path: str) -> rt.InferenceSession:
        """
        Create a new inference session for the ONNX model.

        Parameters
        ----------
        model_path : str
            File path to the ONNX model.

        Returns
        -------
        onnxruntime.InferenceSession
            The ONNX inference session object.
        """
        session_options = rt.SessionOptions()
        # ORT severities: 0=verbose, 1=info, 2=warning, 3=error, 4=fatal.
        # 3 keeps ORT's own log output to errors only.
        session_options.log_severity_level = 3
        return rt.InferenceSession(
            model_path,
            sess_options=session_options,
            providers=cls.prepare_providers(),
        )

    def _extract_model_target_size(self) -> int:
        """
        Extract the target (input) size for images from the ONNX model.

        Assumes shape format: (batch, height, width, channels) and a square
        input, so only the height is read.

        Returns
        -------
        int
            The height (and width) that the model expects for input images.
        """
        input_meta = self.model.get_inputs()[0]
        # NOTE(review): assumes a fixed (non-symbolic) height dimension;
        # a dynamic dimension would not be reported as an int — confirm per model.
        _, height, _, _ = input_meta.shape
        return height

    def prepare_image(self, pil_img: Image.Image) -> np.ndarray:
        """
        Convert an image to a model-friendly format.

        1. Convert to RGBA and composite onto a white background -> RGB
        2. Pad to a square (white background), content centered
        3. Resize to the model target size (bicubic)
        4. Convert RGB -> BGR
        5. Return as a float32 array with shape (1, H, W, 3)

        Parameters
        ----------
        pil_img : Image.Image
            The input image, in any mode Pillow can convert to RGBA.

        Returns
        -------
        np.ndarray
            A contiguous 4D float32 array of shape (1, H, W, 3) in BGR order.
        """
        # alpha_composite only accepts RGBA images, so normalize the mode
        # first; RGB / L / P inputs would otherwise be rejected.
        rgba_img = pil_img if pil_img.mode == "RGBA" else pil_img.convert("RGBA")
        rgba_canvas = Image.new("RGBA", rgba_img.size, (255, 255, 255))
        rgba_canvas.alpha_composite(rgba_img)
        rgb_img = rgba_canvas.convert("RGB")

        # Pad to a square, centering the original content.
        max_dim = max(rgb_img.size)
        pad_left = (max_dim - rgb_img.size[0]) // 2
        pad_top = (max_dim - rgb_img.size[1]) // 2
        padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded.paste(rgb_img, (pad_left, pad_top))

        if max_dim != self.model_target_size:
            padded = padded.resize(
                (self.model_target_size, self.model_target_size),
                Image.BICUBIC,
            )

        # Flip the channel axis RGB -> BGR. The flip is a negative-stride
        # view, so materialize a C-contiguous array before feeding ORT.
        arr = np.ascontiguousarray(np.asarray(padded, dtype=np.float32)[:, :, ::-1])
        return np.expand_dims(arr, axis=0)

    def predict_probabilities_and_embeddings(
        self,
        pil_img: Image.Image,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Run one forward pass through the ONNX model.

        The model is expected to have exactly two outputs, in this order:
        [probs_output, embeddings_output]. The batch dimension is dropped
        from both.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            A tuple of (probabilities_array of shape [#tags],
            embeddings_array of shape [embedding_dim]).

        Raises
        ------
        RuntimeError
            If the ONNX session has not been created.
        """
        if self.model is None:
            raise RuntimeError("ONNX model session is not initialized.")
        inp = self.prepare_image(pil_img)
        input_name = self.model.get_inputs()[0].name
        probs_name, emb_name = [o.name for o in self.model.get_outputs()]
        probs_out, emb_out = self.model.run([probs_name, emb_name], {input_name: inp})
        # Strip the batch dimension (batch size is always 1 here).
        return probs_out[0], emb_out[0]

    def _filter_tags(
        self,
        probs: np.ndarray,
        k: int | None,
        p: float | None,
    ) -> list[tuple[str, float]]:
        """
        Return ``(tag_name, probability)`` pairs at or above the RCut threshold.

        Pairs are sorted by descending probability. Because ``list.sort`` is
        stable, tags with equal probability keep their CSV order.
        """
        threshold = rcut_threshold(probs, k=k, p=p)
        selected: list[tuple[str, float]] = [
            (name, float(prob))
            for name, prob in zip(self.tag_names, probs)
            if prob >= threshold
        ]
        selected.sort(key=lambda pair: pair[1], reverse=True)
        return selected

    def predict_tags(
        self,
        pil_img: Image.Image,
        k: int | None = None,
        p: float | None = None,
    ) -> list[tuple[str, float]]:
        """
        Predict (tag_name, probability) pairs, applying an RCut threshold.

        Exactly one of `k` (top-k) or `p` (probability threshold) must be specified.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.
        k : int | None, optional
            Use top-k threshold if set.
        p : float | None, optional
            Use probability threshold if set.

        Returns
        -------
        list[tuple[str, float]]
            (tag_name, probability) tuples for tags meeting the threshold,
            sorted by descending probability.

        Raises
        ------
        ValueError
            If `k`/`p` are both or neither given, or out of range.
        """
        probs, _ = self.predict_probabilities_and_embeddings(pil_img)
        return self._filter_tags(probs, k, p)

    def predict_embeddings(self, pil_img: Image.Image) -> np.ndarray:
        """
        Return just the embedding vector for the given image.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.

        Returns
        -------
        np.ndarray
            The embedding vector extracted from the model.
        """
        _, emb = self.predict_probabilities_and_embeddings(pil_img)
        return emb

    def predict_tags_and_embeddings(
        self,
        pil_img: Image.Image,
        k: int | None = None,
        p: float | None = None,
    ) -> tuple[list[tuple[str, float]], np.ndarray]:
        """
        Predict tags and embeddings from a single forward pass.

        Parameters
        ----------
        pil_img : Image.Image
            The input image.
        k : int | None, optional
            Use top-k threshold if set.
        p : float | None, optional
            Use probability threshold if set.

        Returns
        -------
        tuple[list[tuple[str, float]], np.ndarray]
            A tuple of:
            - List of (tag_name, probability) for tags meeting the threshold,
              sorted by descending probability
            - The embedding vector

        Raises
        ------
        ValueError
            If `k`/`p` are both or neither given, or out of range.
        """
        probs, emb = self.predict_probabilities_and_embeddings(pil_img)
        return self._filter_tags(probs, k, p), emb
predictor = WD14Tagger()

# Report whether the created session includes the CUDA execution provider.
providers = predictor.model.get_providers()
print(
    "Using CUDAExecutionProvider"
    if "CUDAExecutionProvider" in providers
    else "Not using CUDAExecutionProvider"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment