Skip to content

Instantly share code, notes, and snippets.

@unLomTrois
Last active March 24, 2025 20:57
Show Gist options
  • Select an option

  • Save unLomTrois/7cf20bf8b4bd0042397a300f1e4afa4f to your computer and use it in GitHub Desktop.

Select an option

Save unLomTrois/7cf20bf8b4bd0042397a300f1e4afa4f to your computer and use it in GitHub Desktop.
smilingtimm.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyO75meM027garACKhCh2OKi",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/unLomTrois/7cf20bf8b4bd0042397a300f1e4afa4f/smilingtimm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import timm\n",
"import torch\n",
"from huggingface_hub import hf_hub_download\n",
"from huggingface_hub.utils import HfHubHTTPError\n",
"from PIL import Image\n",
"from simple_parsing import field, parse_known_args\n",
"from timm.data import create_transform, resolve_data_config\n",
"from torch import Tensor, nn\n",
"from torch.nn import functional as F\n",
"\n",
"torch_device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"MODEL_REPO_MAP = {\n",
" \"vit\": \"SmilingWolf/wd-vit-tagger-v3\",\n",
" \"swinv2\": \"SmilingWolf/wd-swinv2-tagger-v3\",\n",
" \"convnext\": \"SmilingWolf/wd-convnext-tagger-v3\",\n",
"}\n",
"\n",
"\n",
"def pil_ensure_rgb(image: Image.Image) -> Image.Image:\n",
" # convert to RGB/RGBA if not already (deals with palette images etc.)\n",
" if image.mode not in [\"RGB\", \"RGBA\"]:\n",
" image = image.convert(\"RGBA\") if \"transparency\" in image.info else image.convert(\"RGB\")\n",
" # convert RGBA to RGB with white background\n",
" if image.mode == \"RGBA\":\n",
" canvas = Image.new(\"RGBA\", image.size, (255, 255, 255))\n",
" canvas.alpha_composite(image)\n",
" image = canvas.convert(\"RGB\")\n",
" return image\n",
"\n",
"\n",
"def pil_pad_square(image: Image.Image) -> Image.Image:\n",
" w, h = image.size\n",
" # get the largest dimension so we can pad to a square\n",
" px = max(image.size)\n",
" # pad to square with white background\n",
" canvas = Image.new(\"RGB\", (px, px), (255, 255, 255))\n",
" canvas.paste(image, ((px - w) // 2, (px - h) // 2))\n",
" return canvas\n",
"\n",
"\n",
"@dataclass\n",
"class LabelData:\n",
" names: list[str]\n",
" rating: list[np.int64]\n",
" general: list[np.int64]\n",
" character: list[np.int64]\n",
"\n",
"\n",
"def load_labels_hf(\n",
" repo_id: str,\n",
" revision: Optional[str] = None,\n",
" token: Optional[str] = None,\n",
") -> LabelData:\n",
" try:\n",
" csv_path = hf_hub_download(\n",
" repo_id=repo_id, filename=\"selected_tags.csv\", revision=revision, token=token\n",
" )\n",
" csv_path = Path(csv_path).resolve()\n",
" except HfHubHTTPError as e:\n",
" raise FileNotFoundError(f\"selected_tags.csv failed to download from {repo_id}\") from e\n",
"\n",
" df: pd.DataFrame = pd.read_csv(csv_path, usecols=[\"name\", \"category\"])\n",
" tag_data = LabelData(\n",
" names=df[\"name\"].tolist(),\n",
" rating=list(np.where(df[\"category\"] == 9)[0]),\n",
" general=list(np.where(df[\"category\"] == 0)[0]),\n",
" character=list(np.where(df[\"category\"] == 4)[0]),\n",
" )\n",
"\n",
" return tag_data\n",
"\n",
"\n",
"def get_tags(\n",
" probs: Tensor,\n",
" labels: LabelData,\n",
" gen_threshold: float,\n",
" char_threshold: float,\n",
"):\n",
" # Convert indices+probs to labels\n",
" probs = list(zip(labels.names, probs.numpy()))\n",
"\n",
" # First 4 labels are actually ratings\n",
" rating_labels = dict([probs[i] for i in labels.rating])\n",
"\n",
" # General labels, pick any where prediction confidence > threshold\n",
" gen_labels = [probs[i] for i in labels.general]\n",
" gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])\n",
" gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))\n",
"\n",
" # Character labels, pick any where prediction confidence > threshold\n",
" char_labels = [probs[i] for i in labels.character]\n",
" char_labels = dict([x for x in char_labels if x[1] > char_threshold])\n",
" char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))\n",
"\n",
" # Combine general and character labels, sort by confidence\n",
" combined_names = [x for x in gen_labels]\n",
" combined_names.extend([x for x in char_labels])\n",
"\n",
" # Convert to a string suitable for use as a training caption\n",
" caption = \", \".join(combined_names)\n",
" taglist = caption.replace(\"_\", \" \").replace(\"(\", \"\\(\").replace(\")\", \"\\)\")\n",
"\n",
" return caption, taglist, rating_labels, char_labels, gen_labels\n",
"\n",
"\n",
"@dataclass\n",
"class ScriptOptions:\n",
" image_file: Path = field(positional=True)\n",
" model: str = field(default=\"swinv2\")\n",
" gen_threshold: float = field(default=0.35)\n",
" char_threshold: float = field(default=0.75)\n",
"\n",
"\n",
"def main(opts: ScriptOptions):\n",
" repo_id = MODEL_REPO_MAP.get(opts.model)\n",
" image_path = Path(opts.image_file).resolve()\n",
" if not image_path.is_file():\n",
" raise FileNotFoundError(f\"Image file not found: {image_path}\")\n",
"\n",
" print(f\"Loading model '{opts.model}' from '{repo_id}'...\")\n",
" model: nn.Module = timm.create_model(\"hf-hub:\" + repo_id).eval()\n",
" state_dict = timm.models.load_state_dict_from_hf(repo_id)\n",
" model.load_state_dict(state_dict)\n",
"\n",
" print(\"Loading tag list...\")\n",
" labels: LabelData = load_labels_hf(repo_id=repo_id)\n",
"\n",
" print(\"Creating data transform...\")\n",
" transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))\n",
"\n",
" print(\"Loading image and preprocessing...\")\n",
" # get image\n",
" img_input: Image.Image = Image.open(image_path)\n",
" # ensure image is RGB\n",
" img_input = pil_ensure_rgb(img_input)\n",
" # pad to square with white background\n",
" img_input = pil_pad_square(img_input)\n",
" display(img_input)\n",
" # run the model's input transform to convert to tensor and rescale\n",
" inputs: Tensor = transform(img_input).unsqueeze(0)\n",
" # NCHW image RGB to BGR\n",
" inputs = inputs[:, [2, 1, 0]]\n",
"\n",
" print(\"Running inference...\")\n",
" with torch.inference_mode():\n",
" # move model to GPU, if available\n",
" if torch_device.type != \"cpu\":\n",
" model = model.to(torch_device)\n",
" inputs = inputs.to(torch_device)\n",
" # run the model\n",
" outputs = model.forward(inputs)\n",
" # apply the final activation function (timm doesn't support doing this internally)\n",
" outputs = F.sigmoid(outputs)\n",
" # move inputs, outputs, and model back to to cpu if we were on GPU\n",
" if torch_device.type != \"cpu\":\n",
" inputs = inputs.to(\"cpu\")\n",
" outputs = outputs.to(\"cpu\")\n",
" model = model.to(\"cpu\")\n",
"\n",
" print(\"Processing results...\")\n",
" caption, taglist, ratings, character, general = get_tags(\n",
" probs=outputs.squeeze(0),\n",
" labels=labels,\n",
" gen_threshold=opts.gen_threshold,\n",
" char_threshold=opts.char_threshold,\n",
" )\n",
"\n",
" print(\"--------\")\n",
" print(f\"Caption: {caption}\")\n",
" print(\"--------\")\n",
" print(f\"Tags: {taglist}\")\n",
"\n",
" print(\"--------\")\n",
" print(\"Ratings:\")\n",
" for k, v in ratings.items():\n",
" print(f\" {k}: {v:.3f}\")\n",
"\n",
" print(\"--------\")\n",
" print(f\"Character tags (threshold={opts.char_threshold}):\")\n",
" for k, v in character.items():\n",
" print(f\" {k}: {v:.3f}\")\n",
"\n",
" print(\"--------\")\n",
" print(f\"General tags (threshold={opts.gen_threshold}):\")\n",
" for k, v in general.items():\n",
" print(f\" {k}: {v:.3f}\")\n",
"\n",
" print(\"Done!\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" opts, _ = parse_known_args(ScriptOptions)\n",
" if opts.model not in MODEL_REPO_MAP:\n",
" print(f\"Available models: {list(MODEL_REPO_MAP.keys())}\")\n",
" raise ValueError(f\"Unknown model name '{opts.model}'\")\n",
" main(opts)"
],
"metadata": {
"id": "vBdzpmyj30h4"
},
"execution_count": null,
"outputs": []
}
]
}
# -*- coding: utf-8 -*-
"""smilingtimm.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/gist/unLomTrois/7cf20bf8b4bd0042397a300f1e4afa4f/smilingtimm.ipynb
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from simple_parsing import field, parse_known_args
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F
# Run on GPU when one is available; model and tensors are moved to this device in main().
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Short CLI model aliases -> Hugging Face Hub repo ids for SmilingWolf's WD v3 taggers.
MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any transparency onto white."""
    # Normalize palette/grayscale/etc.; keep alpha only when transparency data exists.
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    # Composite RGBA images over an opaque white canvas, then drop the alpha channel.
    if image.mode == "RGBA":
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.alpha_composite(image)
        image = background.convert("RGB")
    return image
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Center *image* on a white square canvas sized to its longest dimension."""
    width, height = image.size
    side = max(width, height)
    squared = Image.new("RGB", (side, side), (255, 255, 255))
    # Integer division centers the image; any odd pixel goes to the right/bottom edge.
    squared.paste(image, ((side - width) // 2, (side - height) // 2))
    return squared
@dataclass
class LabelData:
    """Tag names plus per-category index lists parsed from selected_tags.csv."""
    # Full tag-name list, aligned with the model's output vector.
    names: list[str]
    # Row indices (into `names`) per category code from selected_tags.csv.
    rating: list[np.int64]     # category 9: content-rating tags
    general: list[np.int64]    # category 0: general descriptive tags
    character: list[np.int64]  # category 4: character-name tags
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Fetch selected_tags.csv from *repo_id* and split its rows into tag categories.

    Raises:
        FileNotFoundError: if the CSV cannot be downloaded from the Hub.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    csv_path = Path(downloaded).resolve()
    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])

    # Category codes used by the WD tagger label files: 9=rating, 0=general, 4=character.
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Convert a per-tag probability vector into rating/general/character tag dicts.

    Args:
        probs: 1-D tensor of confidences aligned with ``labels.names``.
        labels: Label metadata carrying per-category index lists.
        gen_threshold: Minimum confidence for a general tag to be kept.
        char_threshold: Minimum confidence for a character tag to be kept.

    Returns:
        ``(caption, taglist, rating_labels, char_labels, gen_labels)`` —
        two comma-joined strings plus three name -> confidence dicts
        (general/character sorted by descending confidence).
    """
    # Pair each tag name with its predicted confidence.
    named_probs = list(zip(labels.names, probs.numpy()))

    # Rating tags are reported unconditionally (no thresholding).
    rating_labels = dict(named_probs[i] for i in labels.rating)

    # General tags: keep those above threshold, sorted by descending confidence.
    gen_labels = {name: p for name, p in (named_probs[i] for i in labels.general) if p > gen_threshold}
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # Character tags: same scheme with their own threshold.
    char_labels = {name: p for name, p in (named_probs[i] for i in labels.character) if p > char_threshold}
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # Concatenate general then character names (each group already sorted above).
    combined_names = list(gen_labels) + list(char_labels)

    # Caption keeps raw tag names; taglist is escaped for use as a training prompt:
    # underscores become spaces and parentheses are backslash-escaped.
    # NOTE: "\\(" (escaped backslash) replaces the original "\(" — same runtime
    # string, but "\(" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    caption = ", ".join(combined_names)
    taglist = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return caption, taglist, rating_labels, char_labels, gen_labels
@dataclass
class ScriptOptions:
    """CLI options, parsed by simple_parsing in the __main__ guard."""
    # Positional path to the image file to tag.
    image_file: Path = field(positional=True)
    # Model alias; must be a key of MODEL_REPO_MAP.
    model: str = field(default="swinv2")
    # Confidence threshold for keeping general tags.
    gen_threshold: float = field(default=0.35)
    # Confidence threshold for keeping character tags.
    char_threshold: float = field(default=0.75)
def main(opts: ScriptOptions):
    """Tag one image: load model and labels, preprocess, run inference, print results.

    Raises:
        ValueError: if ``opts.model`` is not a known model alias.
        FileNotFoundError: if ``opts.image_file`` does not exist.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)
    # Guard here as well as in __main__: direct callers previously got an opaque
    # TypeError from '"hf-hub:" + None' when the alias was unknown.
    if repo_id is None:
        raise ValueError(f"Unknown model name '{opts.model}' (available: {list(MODEL_REPO_MAP.keys())})")

    image_path = Path(opts.image_file).resolve()
    if not image_path.is_file():
        raise FileNotFoundError(f"Image file not found: {image_path}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    print("Loading image and preprocessing...")
    img_input: Image.Image = Image.open(image_path)
    # ensure RGB, then pad to a white square (model expects square inputs)
    img_input = pil_ensure_rgb(img_input)
    img_input = pil_pad_square(img_input)
    try:
        # display() only exists as an IPython/Jupyter builtin; guarded so the
        # script also runs outside a notebook instead of raising NameError.
        display(img_input)
    except NameError:
        pass
    # run the model's input transform to convert to tensor and rescale
    inputs: Tensor = transform(img_input).unsqueeze(0)
    # NCHW image RGB to BGR
    inputs = inputs[:, [2, 1, 0]]

    print("Running inference...")
    with torch.inference_mode():
        # move model to GPU, if available
        if torch_device.type != "cpu":
            model = model.to(torch_device)
            inputs = inputs.to(torch_device)
        # run the model
        outputs = model.forward(inputs)
        # apply the final activation (timm doesn't do this internally);
        # torch.sigmoid replaces the deprecated F.sigmoid alias
        outputs = torch.sigmoid(outputs)
        # move inputs, outputs, and model back to cpu if we were on GPU
        if torch_device.type != "cpu":
            inputs = inputs.to("cpu")
            outputs = outputs.to("cpu")
            model = model.to("cpu")

    print("Processing results...")
    caption, taglist, ratings, character, general = get_tags(
        probs=outputs.squeeze(0),
        labels=labels,
        gen_threshold=opts.gen_threshold,
        char_threshold=opts.char_threshold,
    )

    print("--------")
    print(f"Caption: {caption}")
    print("--------")
    print(f"Tags: {taglist}")

    print("--------")
    print("Ratings:")
    for k, v in ratings.items():
        print(f" {k}: {v:.3f}")

    print("--------")
    print(f"Character tags (threshold={opts.char_threshold}):")
    for k, v in character.items():
        print(f" {k}: {v:.3f}")

    print("--------")
    print(f"General tags (threshold={opts.gen_threshold}):")
    for k, v in general.items():
        print(f" {k}: {v:.3f}")

    print("Done!")
if __name__ == "__main__":
opts, _ = parse_known_args(ScriptOptions)
if opts.model not in MODEL_REPO_MAP:
print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
raise ValueError(f"Unknown model name '{opts.model}'")
main(opts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment