Skip to content

Instantly share code, notes, and snippets.

@kipavy
Created August 13, 2025 14:14
Show Gist options
  • Select an option

  • Save kipavy/edd457f6b1e8be24b90f169f2d967237 to your computer and use it in GitHub Desktop.

Select an option

Save kipavy/edd457f6b1e8be24b90f169f2d967237 to your computer and use it in GitHub Desktop.
Python wrapper for SAM, SAM2, SAM2.1, MobileSAM (Segment Anything Model from Meta)
"""
Python wrapper for Ultralytics SAM
Available models:
https://docs.ultralytics.com/models/sam/
SAM base: 'sam_b.pt'
SAM large: 'sam_l.pt'
https://docs.ultralytics.com/models/sam-2/
SAM 2 tiny: 'sam2_t.pt'
SAM 2 small: 'sam2_s.pt'
SAM 2 base: 'sam2_b.pt'
SAM 2 large: 'sam2_l.pt'
SAM 2.1 tiny: 'sam2.1_t.pt'
SAM 2.1 small: 'sam2.1_s.pt'
SAM 2.1 base: 'sam2.1_b.pt'
SAM 2.1 large: 'sam2.1_l.pt'
MobileSAM is approximately 5 times smaller and 7 times faster than FastSAM. https://docs.ultralytics.com/models/mobile-sam/
MobileSAM: 'mobile_sam.pt'
"""
from ultralytics import SAM
import numpy as np
import tifffile
from PIL import Image
from pathlib import Path
from skimage import exposure, color, filters, util
from skimage.measure import find_contours
from skimage.morphology import disk
from skimage.filters import rank
def to_grayscale(img):
    """Convert an image to a single-channel 2-D array.

    Args:
        img: A ``PIL.Image.Image`` or an array (or array-like) shaped
            (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).

    Returns:
        A 2-D ``numpy.ndarray``. PIL input is converted with mode ``"L"``;
        (H, W) and (H, W, 1) input keeps its values and dtype; RGB / RGBA
        input is converted with ``skimage.color.rgb2gray``.

    Raises:
        ValueError: If the array shape is not one of the supported layouts.
    """
    if not isinstance(img, np.ndarray):
        if isinstance(img, Image.Image):
            return np.array(img.convert("L"))
        img = np.asarray(img)

    if img.ndim == 2:
        return img

    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            # Already grayscale with a trailing channel axis; rgb2gray would
            # reject it, so just drop the axis.
            return img[..., 0]
        if channels == 4:
            # Drop the alpha channel by slicing so that any dtype works
            # (a PIL round trip is limited to 8-bit data).
            img = img[..., :3]
            channels = 3
        if channels == 3:
            return color.rgb2gray(img)

    raise ValueError(f"Unsupported image shape for to_grayscale: {img.shape}")
def to_rgb(img):
    """Convert an image to a 3-channel (H, W, 3) array.

    Args:
        img: A ``PIL.Image.Image`` or an array (or array-like) shaped
            (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).

    Returns:
        A ``numpy.ndarray`` with three channels on the last axis. Array
        input keeps its dtype; an (H, W, 3) ndarray is returned as-is
        (no copy).

    Raises:
        ValueError: If the array shape is not one of the supported layouts.
    """
    if not isinstance(img, np.ndarray):
        if isinstance(img, Image.Image):
            return np.array(img.convert("RGB"))
        img = np.asarray(img)

    if img.ndim == 2:
        return color.gray2rgb(img)

    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            return color.gray2rgb(img[..., 0])
        if channels == 3:
            return img
        if channels == 4:
            # Discard alpha in numpy rather than via PIL: load_img() calls
            # this before the uint8 conversion, so 16-bit / float RGBA data
            # can reach this branch and must keep its dtype and values.
            return np.ascontiguousarray(img[..., :3])

    raise ValueError(f"Unsupported image shape for to_rgb: {img.shape}")
def to_int8(img):
    """Rescale a numpy image to unsigned 8-bit.

    Despite the name, the result dtype is ``uint8``. The image is divided
    by its own maximum before conversion, so the brightest pixel maps to
    255.

    Args:
        img: Any object. Only ``numpy.ndarray`` instances whose dtype is not
            already ``uint8`` are converted; everything else is returned
            unchanged.

    Returns:
        A ``uint8`` array for converted input, otherwise ``img`` itself.
    """
    if not isinstance(img, np.ndarray) or img.dtype == np.uint8:
        return img

    if img.size == 0:
        # Nothing to scale; .max() would raise on an empty array.
        return img.astype(np.uint8)

    peak = img.max()
    if peak <= 0:
        # All-zero (or non-positive) image: dividing by the peak would give
        # NaN/inf or values outside the range img_as_ubyte accepts.
        return np.zeros(img.shape, dtype=np.uint8)

    # Clip so negative pixels cannot push the ratio below the float range
    # accepted by img_as_ubyte.
    scaled = np.clip(img / peak, 0.0, 1.0)
    return util.img_as_ubyte(scaled)
def load_img(path):
    """Load an image file as a contiguous 8-bit RGB array.

    ``.tif`` / ``.tiff`` files are read with ``tifffile``; every other
    extension is opened with PIL. The result is passed through ``to_rgb``
    and then ``to_int8``.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the image to read.

    Returns:
        A C-contiguous ``numpy.ndarray`` produced by ``to_int8(to_rgb(...))``.
    """
    if Path(path).suffix.lower() in (".tif", ".tiff"):
        rgb = to_rgb(tifffile.imread(path))
    else:
        # Image.open is lazy; convert inside the context manager so the
        # pixel data is read and the file handle is released deterministically.
        with Image.open(path) as pil_img:
            rgb = to_rgb(pil_img)
    return np.ascontiguousarray(to_int8(rgb))
class SAM2Detector:
    """Thin wrapper around an Ultralytics ``SAM`` model.

    A call to ``run`` executes three stages -- optional ``preprocess``,
    ``process`` (model inference + contour extraction) and optional
    ``postprocess`` -- and stores the intermediate results on the instance
    as ``img``, ``preprocessed_img``, ``contours`` and ``processed_img``.
    """

    def __init__(self, model="sam2.1_b.pt", device="auto"):
        """Load the model weights and move the model to ``device``.

        Args:
            model: Weights name or path handed to ``SAM(...)``.
            device: Anything accepted by ``set_device``.

        Raises:
            ImportError: If ``SAM`` is ``None`` (Ultralytics unavailable).
        """
        super().__init__()
        if SAM is None:
            raise ImportError(
                "Ultralytics SAM is not installed. Install with: pip install ultralytics"
            )
        self.model = SAM(model)
        self.set_device(device)

    ##################################################################
    #                  GPU/CPU DETECTION/SELECTION                   #
    ##################################################################
    @staticmethod
    def list_available_devices():
        """
        Returns a dict mapping device string to info dict:
        {
            'cpu': {'name': 'CPU', 'vram_used_gb': None, 'vram_total_gb': None, 'vram_free_gb': None},
            'cuda:0': {'name': 'GPU 0 (NVIDIA RTX 3090)', 'vram_used_gb': ..., 'vram_total_gb': ..., 'vram_free_gb': ...},
            ...
        }
        pynvml is optional: when it is not installed, CUDA devices are still
        listed but their VRAM fields are None.
        """
        import torch

        def _entry(name, used=None, total=None, free=None):
            return {
                "name": name,
                "vram_used_gb": used,
                "vram_total_gb": total,
                "vram_free_gb": free,
            }

        devices = {"cpu": _entry("CPU")}

        # CUDA GPUs
        if torch.cuda.is_available():
            try:
                import pynvml
            except ImportError:
                pynvml = None

            gib = 1024**3
            if pynvml is None:
                for i in range(torch.cuda.device_count()):
                    name = torch.cuda.get_device_name(i)
                    devices[f"cuda:{i}"] = _entry(f"GPU {i} ({name})")
            else:
                pynvml.nvmlInit()
                try:
                    for i in range(torch.cuda.device_count()):
                        name = torch.cuda.get_device_name(i)
                        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                        devices[f"cuda:{i}"] = _entry(
                            f"GPU {i} ({name})",
                            used=round(meminfo.used / gib, 2),
                            total=round(meminfo.total / gib, 2),
                            free=round((meminfo.total - meminfo.used) / gib, 2),
                        )
                finally:
                    # Always release NVML, even if a query above fails.
                    pynvml.nvmlShutdown()

        # Apple MPS (the backend attribute is absent on older torch builds)
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            devices["mps"] = _entry("Apple MPS")

        return devices

    def set_device(self, device):
        """
        Set the device for the model.
        device: 'auto', None, 'cpu', 'cuda', 'cuda:0', 'cuda:1', ...
        'auto' or None will select the CUDA device with the most free VRAM, else MPS, else CPU.
        When VRAM information is unavailable the first CUDA device wins.
        """
        if device == "auto" or device is None:
            devices = self.list_available_devices()
            cuda_names = [name for name in devices if name.startswith("cuda")]
            if cuda_names:
                # Missing VRAM info (None) is ranked as 0 so max() still works.
                device = max(
                    cuda_names,
                    key=lambda name: devices[name]["vram_free_gb"] or 0.0,
                )
            elif "mps" in devices:
                device = "mps"
            else:
                device = "cpu"
        self.model.to(device)

    def get_device(self):
        """Return the device currently reported by the wrapped model."""
        return self.model.device

    ##################################################################
    #                           PROCESSING                           #
    ##################################################################
    def preprocess(self, img):
        """Contrast-enhance and smooth ``img`` before inference.

        Pipeline: grayscale -> global histogram equalisation -> CLAHE
        (``clip_limit=0.03``) -> Gaussian blur (``sigma=1``) -> RGB -> uint8.

        Returns:
            An 8-bit RGB ``numpy.ndarray``.
        """
        gray = to_grayscale(img)
        equalized = exposure.equalize_hist(gray)
        clahe_img = exposure.equalize_adapthist(equalized, clip_limit=0.03)
        blurred = filters.gaussian(clahe_img, sigma=1, preserve_range=True)
        # The skimage filters above yield floats in [0, 1]; convert back to
        # 8-bit so the model receives the same value range as load_img()
        # produces.
        return to_int8(to_rgb(blurred))

    def process(self, img, **kwargs):
        """Run the model on ``img`` and extract mask contours.

        Args:
            img: Image passed straight to the model call.
            **kwargs: Forwarded to the model call (e.g. ``points``,
                ``labels`` prompts).

        Returns:
            ``(contours, result)`` where ``contours`` is a list of (K, 2)
            arrays in (x, y) order and ``result`` is the first element of
            the model output.
        """
        # NOTE(review): Ultralytics appears to treat ndarray input as BGR
        # while this file produces RGB arrays -- confirm channel order.
        processed_img = self.model(img, **kwargs)[0]

        # Extract masks and find contours using skimage
        contours = []
        mask_set = getattr(processed_img, "masks", None)
        if mask_set is not None:
            masks = mask_set.data.cpu().numpy()  # (N, H, W)
            for mask in masks:
                mask_bin = (mask > 0.5).astype(np.uint8)
                # find_contours yields (row, col); flip columns to get (x, y).
                contours.extend(
                    np.fliplr(contour) for contour in find_contours(mask_bin, 0.5)
                )
        return contours, processed_img

    def postprocess(self, contours, processed_img):
        """Hook for subclasses; the base implementation is the identity."""
        return contours, processed_img

    def run(
        self,
        img=None,
        img_path=None,
        preprocess=True,
        postprocess=True,
        **kwargs,
    ):
        """Execute the full pipeline on an in-memory image or an image file.

        Args:
            img: Image to segment. Takes precedence over ``img_path``.
            img_path: File to load with ``load_img`` when ``img`` is None.
            preprocess: Whether to apply ``self.preprocess`` first.
            postprocess: Whether to pass the result through
                ``self.postprocess``.
            **kwargs: Forwarded to ``process`` (and on to the model).

        Returns:
            ``(contours, processed_img)``.

        Raises:
            ValueError: If neither ``img`` nor ``img_path`` is given.
        """
        # Load image if not provided
        if img is None:
            if img_path is None:
                raise ValueError("Either 'img' or 'img_path' must be provided")
            img = load_img(img_path)
        self.img = img

        # Preprocessing
        self.preprocessed_img = self.preprocess(self.img) if preprocess else self.img

        # Processing
        self.contours, self.processed_img = self.process(
            self.preprocessed_img, **kwargs
        )

        # Postprocessing
        if postprocess:
            return self.postprocess(self.contours, self.processed_img)
        return self.contours, self.processed_img
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from skimage import data

    def plot_sam2_showcase(img, sam2, points=None, labels=None, title=""):
        """Show original, preprocessed, segmented and contour views side by side.

        ``sam2`` must already have been run so that ``preprocessed_img``,
        ``processed_img`` and ``contours`` are populated. Prompt points are
        overlaid on the first and last panel when both ``points`` and
        ``labels`` are given.
        """
        with_prompts = points is not None and labels is not None
        if with_prompts:
            pts = np.array(points)
            lbl = np.array(labels)
            pos_pts = pts[lbl == 1]
            neg_pts = pts[lbl == 0]

        def draw_prompts(ax):
            # Positive prompts first, then negative ones.
            if len(pos_pts) > 0:
                ax.scatter(pos_pts[:, 0], pos_pts[:, 1], c="lime", s=80, marker="+", label="Positive")
            if len(neg_pts) > 0:
                ax.scatter(neg_pts[:, 0], neg_pts[:, 1], c="red", s=80, marker="x", label="Negative")

        fig, panels = plt.subplots(1, 4, figsize=(16, 8))
        ax_orig, ax_pre, ax_seg, ax_cont = panels

        # Panel 1: input image (+ prompts and legend)
        ax_orig.imshow(img)
        ax_orig.set_title("Original")
        if with_prompts:
            draw_prompts(ax_orig)
            ax_orig.legend()

        # Panel 2: what was actually fed to the model
        ax_pre.imshow(sam2.preprocessed_img)
        ax_pre.set_title("Preprocessed")

        # Panel 3: rendering produced by the result object
        ax_seg.imshow(sam2.processed_img.plot(labels=False, boxes=False))
        ax_seg.set_title("Processed")

        # Panel 4: extracted contours over the input image
        ax_cont.imshow(img)
        for outline in sam2.contours:
            ax_cont.plot(outline[:, 0], outline[:, 1], linewidth=2, color="red")
        if with_prompts:
            draw_prompts(ax_cont)
        suffix = " + Points" if points is not None else ""
        ax_cont.set_title("Contours" + suffix)

        for panel in panels:
            panel.axis("off")
        fig.suptitle(title)
        plt.tight_layout()
        plt.show()

    img = data.astronaut()
    sam2 = SAM2Detector(model="../data/mobile_sam.pt")  # downloaded automatically or specify path
    # sam2.set_device("cpu")
    print("Using", sam2.get_device())

    # Different ways of using SAM: (points, labels, figure title)
    showcases = [
        # Segment Everything (no prompt)
        (None, None, "SAM2: Segment Everything (no prompt)"),
        # Single positive point
        ([[100, 100]], [1], "SAM2: Single Positive Point"),
        # Multiple positive and negative points
        (
            [[235, 40], [220, 120], [200, 200]],
            [1, 1, 0],
            "SAM2: Multiple Positive and Negative Points",
        ),
        # Multiple positive and negative points per object
        (
            [[[235, 40], [220, 120], [200, 200]]],
            [[1, 1, 0]],
            "SAM2: Multiple Positive and Negative Points per Object",
        ),
    ]
    for demo_points, demo_labels, demo_title in showcases:
        if demo_points is None:
            sam2.run(img=img, preprocess=False)
            plot_sam2_showcase(img, sam2, title=demo_title)
        else:
            sam2.run(img=img, points=demo_points, labels=demo_labels, preprocess=False)
            plot_sam2_showcase(img, sam2, points=demo_points, labels=demo_labels, title=demo_title)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment