Created
August 13, 2025 14:14
-
-
Save kipavy/edd457f6b1e8be24b90f169f2d967237 to your computer and use it in GitHub Desktop.
Python wrapper for SAM, SAM2, SAM2.1, MobileSAM (Segment Anything Model from Meta)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Python wrapper for Ultralytics SAM | |
| Available models: | |
| https://docs.ultralytics.com/models/sam/ | |
| SAM base: 'sam_b.pt' | |
| SAM large: 'sam_l.pt' | |
| https://docs.ultralytics.com/models/sam-2/ | |
| SAM 2 tiny: 'sam2_t.pt' | |
| SAM 2 small: 'sam2_s.pt' | |
| SAM 2 base: 'sam2_b.pt' | |
| SAM 2 large: 'sam2_l.pt' | |
| SAM 2.1 tiny: 'sam2.1_t.pt' | |
| SAM 2.1 small: 'sam2.1_s.pt' | |
| SAM 2.1 base: 'sam2.1_b.pt' | |
| SAM 2.1 large: 'sam2.1_l.pt' | |
| MobileSAM is approximately 5 times smaller and 7 times faster than FastSAM. https://docs.ultralytics.com/models/mobile-sam/ | |
| MobileSAM: 'mobile_sam.pt' | |
| """ | |
| from ultralytics import SAM | |
| import numpy as np | |
| import tifffile | |
| from PIL import Image | |
| from pathlib import Path | |
| from skimage import exposure, color, filters, util | |
| from skimage.measure import find_contours | |
| from skimage.morphology import disk | |
| from skimage.filters import rank | |
def to_grayscale(img):
    """Convert an image to a single-channel (2-D) array.

    Args:
        img: A ``PIL.Image.Image`` or a NumPy array shaped ``(H, W)``,
            ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        A 2-D ``numpy.ndarray``:

        * PIL images are converted with ``convert("L")``.
        * ``(H, W)`` arrays are returned unchanged (same object).
        * ``(H, W, 1)`` arrays are returned with the channel axis removed.
        * ``(H, W, 3)`` / ``(H, W, 4)`` arrays go through
          ``skimage.color.rgb2gray``; the alpha channel is dropped first.

    Raises:
        ValueError: If the array has any other shape.
        TypeError: If ``img`` is neither a NumPy array nor a PIL image.
    """
    if isinstance(img, np.ndarray):
        if img.ndim == 2:
            return img
        if img.ndim == 3:
            channels = img.shape[2]
            if channels == 1:
                # Already grayscale, just carrying a trailing channel axis
                # (to_rgb accepts this layout too).
                return img[..., 0]
            if channels in (3, 4):
                # Dropping alpha by slicing gives the same pixels as PIL's
                # RGBA -> RGB conversion, but also works for non-uint8 data
                # (e.g. 16-bit TIFFs), which PIL cannot ingest.
                return color.rgb2gray(img[..., :3])
        raise ValueError(
            f"Unsupported image shape for to_grayscale: {img.shape}"
        )
    if isinstance(img, Image.Image):  # PIL Image
        return np.array(img.convert("L"))
    raise TypeError(
        f"Unsupported image type for to_grayscale: {type(img).__name__}"
    )
def to_rgb(img):
    """Convert an image to a 3-channel ``(H, W, 3)`` array.

    Args:
        img: A ``PIL.Image.Image`` or a NumPy array shaped ``(H, W)``,
            ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        A ``numpy.ndarray`` with three channels on the last axis:

        * PIL images are converted with ``convert("RGB")``.
        * ``(H, W)`` and ``(H, W, 1)`` arrays are expanded with
          ``skimage.color.gray2rgb``.
        * ``(H, W, 3)`` arrays are returned unchanged (same object).
        * ``(H, W, 4)`` arrays are returned as a new contiguous array
          without the alpha channel; the dtype is preserved.

    Raises:
        ValueError: If the array has any other shape.
        TypeError: If ``img`` is neither a NumPy array nor a PIL image.
    """
    if isinstance(img, np.ndarray):
        if img.ndim == 2:
            return color.gray2rgb(img)
        if img.ndim == 3:
            channels = img.shape[2]
            if channels == 1:
                return color.gray2rgb(img[..., 0])
            if channels == 3:
                return img
            if channels == 4:
                # Drop alpha with NumPy instead of a PIL round-trip: same
                # pixels for uint8 input, and non-uint8 arrays (which PIL
                # rejects) keep working. The copy mirrors the fresh array
                # the PIL path used to return.
                return np.ascontiguousarray(img[..., :3])
        raise ValueError(f"Unsupported image shape for to_rgb: {img.shape}")
    if isinstance(img, Image.Image):  # PIL Image
        return np.array(img.convert("RGB"))
    raise TypeError(f"Unsupported image type for to_rgb: {type(img).__name__}")
def to_int8(img):
    """Rescale a NumPy image to ``uint8`` (despite the name, unsigned 8-bit).

    Non-``uint8`` arrays are divided by their maximum and converted with
    ``skimage.util.img_as_ubyte``, so the brightest pixel maps to 255.

    Args:
        img: Image to convert. Anything that is not a ``numpy.ndarray``
            (e.g. a PIL image) and any array that is already ``uint8`` is
            returned unchanged.

    Returns:
        The ``uint8`` array, or ``img`` itself when no conversion applies.
    """
    if not isinstance(img, np.ndarray) or img.dtype == np.uint8:
        return img
    if img.size == 0:
        # Nothing to scale; ``max()`` would raise on an empty array.
        return img.astype(np.uint8)
    peak = img.max()
    if peak == 0:
        # Avoid 0/0 -> NaN: an image whose maximum is 0 maps to all zeros.
        return np.zeros(img.shape, dtype=np.uint8)
    return util.img_as_ubyte(img / peak)
def load_img(path):
    """Load an image file as a contiguous ``uint8`` RGB array.

    Files with a ``.tif`` / ``.tiff`` suffix (case-insensitive) are read with
    ``tifffile``; everything else is opened with PIL. The result is passed
    through ``to_rgb`` and then ``to_int8``.

    Args:
        path: Path to the image file (``str`` or ``pathlib.Path``).

    Returns:
        A C-contiguous ``numpy.ndarray`` produced by ``to_rgb`` + ``to_int8``.
    """
    if Path(path).suffix.lower() in {".tif", ".tiff"}:
        img = to_rgb(tifffile.imread(path))
    else:
        # Image.open is lazy and keeps the file handle open; convert inside
        # the context manager so the pixels are read before it is closed.
        with Image.open(path) as pil_img:
            img = to_rgb(pil_img)
    return np.ascontiguousarray(to_int8(img))
class SAM2Detector:
    """Convenience wrapper around an Ultralytics ``SAM`` model.

    Typical use is ``run(img=...)`` (or ``run(img_path=...)``), which chains
    ``preprocess`` -> ``process`` -> ``postprocess`` and stores the
    intermediate results on the instance:

    * ``img``: the input image.
    * ``preprocessed_img``: the image handed to the model.
    * ``contours``: list of ``(K, 2)`` arrays of ``(x, y)`` contour points.
    * ``processed_img``: the first result object returned by the model.
    """

    def __init__(self, model="sam2.1_b.pt", device="auto"):
        """Load the model weights and move the model to ``device``.

        Args:
            model: Weight file name or path passed to ``SAM``.
            device: Device string, or ``"auto"`` / ``None`` for automatic
                selection (see ``set_device``).

        Raises:
            ImportError: If ``SAM`` is bound to ``None``.
        """
        super().__init__()
        if SAM is None:
            raise ImportError(
                "Ultralytics SAM is not installed. Install with: pip install ultralytics"
            )
        self.model = SAM(model)
        self.set_device(device)

    ##################################################################
    #                  GPU/CPU DETECTION/SELECTION                   #
    ##################################################################
    @staticmethod
    def list_available_devices():
        """Return the usable compute devices and their memory statistics.

        Returns:
            A dict mapping device string to an info dict, e.g.::

                {
                    'cpu': {'name': 'CPU', 'vram_used_gb': None,
                            'vram_total_gb': None, 'vram_free_gb': None},
                    'cuda:0': {'name': 'GPU 0 (NVIDIA RTX 3090)',
                               'vram_used_gb': ..., 'vram_total_gb': ...,
                               'vram_free_gb': ...},
                    'mps': {...},
                }

            ``'cpu'`` is always present. Memory figures are rounded to two
            decimals using 1024**3 bytes per "GB".

        CUDA memory is read through ``pynvml`` when it is installed and
        through ``torch.cuda.mem_get_info`` otherwise, so ``pynvml`` is
        optional.
        """
        import torch

        bytes_per_gb = 1024**3
        devices = {
            "cpu": {
                "name": "CPU",
                "vram_used_gb": None,
                "vram_total_gb": None,
                "vram_free_gb": None,
            }
        }

        # CUDA GPUs
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            try:
                import pynvml
            except ImportError:
                pynvml = None

            usage = {}  # device index -> (used_bytes, total_bytes)
            if pynvml is not None:
                # NOTE(review): NVML indices are assumed to match torch's
                # device indices; this may not hold when
                # CUDA_VISIBLE_DEVICES reorders or hides GPUs.
                pynvml.nvmlInit()
                try:
                    for i in range(gpu_count):
                        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                        usage[i] = (meminfo.used, meminfo.total)
                finally:
                    # Always release NVML, even if a query above failed.
                    pynvml.nvmlShutdown()
            else:
                for i in range(gpu_count):
                    free_bytes, total_bytes = torch.cuda.mem_get_info(i)
                    usage[i] = (total_bytes - free_bytes, total_bytes)

            for i in range(gpu_count):
                used_bytes, total_bytes = usage[i]
                devices[f"cuda:{i}"] = {
                    "name": f"GPU {i} ({torch.cuda.get_device_name(i)})",
                    "vram_used_gb": round(used_bytes / bytes_per_gb, 2),
                    "vram_total_gb": round(total_bytes / bytes_per_gb, 2),
                    "vram_free_gb": round(
                        (total_bytes - used_bytes) / bytes_per_gb, 2
                    ),
                }

        # Apple MPS (the backend module is looked up defensively so a torch
        # build without ``torch.backends.mps`` does not raise here).
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            devices["mps"] = {
                "name": "Apple MPS",
                "vram_used_gb": None,
                "vram_total_gb": None,
                "vram_free_gb": None,
            }
        return devices

    def set_device(self, device):
        """Move the model to ``device``.

        Args:
            device: ``'auto'``, ``None``, ``'cpu'``, ``'cuda'``, ``'cuda:0'``,
                ``'cuda:1'``, ... ``'auto'`` or ``None`` selects the CUDA
                device with the most free VRAM, else MPS, else CPU.
        """
        if device == "auto" or device is None:
            devices = self.list_available_devices()
            cuda_devices = {
                name: info
                for name, info in devices.items()
                if name.startswith("cuda")
            }
            if cuda_devices:
                device = max(
                    cuda_devices,
                    key=lambda name: cuda_devices[name]["vram_free_gb"],
                )
            elif "mps" in devices:
                device = "mps"
            else:
                device = "cpu"
        self.model.to(device)

    def get_device(self):
        """Return the device the wrapped model currently reports."""
        return self.model.device

    ##################################################################
    #                           PROCESSING                           #
    ##################################################################
    def preprocess(self, img):
        """Enhance contrast and smooth ``img`` before segmentation.

        Steps: grayscale conversion, global histogram equalisation, CLAHE
        (``clip_limit=0.03``), Gaussian blur (``sigma=1``), then expansion
        back to three channels.

        Args:
            img: PIL image or NumPy array accepted by ``to_grayscale``.

        Returns:
            A ``uint8`` ``(H, W, 3)`` array.
        """
        gray = to_grayscale(img)
        equalized = exposure.equalize_hist(gray)
        clahe_img = exposure.equalize_adapthist(equalized, clip_limit=0.03)
        blurred = filters.gaussian(clahe_img, sigma=1, preserve_range=True)
        # The equalisation steps yield floats in [0, 1], while the model
        # pipeline normalises inputs as 0-255 pixel values. Convert to uint8
        # so the preprocessed image is not seen as (almost) black. The clip
        # guards against tiny floating-point overshoot that img_as_ubyte
        # would reject.
        as_ubyte = util.img_as_ubyte(np.clip(blurred, 0.0, 1.0))
        return to_rgb(as_ubyte)

    def process(self, img, **kwargs):
        """Run the model on ``img`` and extract mask contours.

        Args:
            img: Image passed straight to the model.
            **kwargs: Forwarded to the model call (e.g. ``points``,
                ``labels``).

        Returns:
            ``(contours, result)`` where ``result`` is the first element of
            the model output and ``contours`` is a flat list of ``(K, 2)``
            arrays in ``(x, y)`` order, collected over all masks. The list
            is empty when the result carries no masks.
        """
        result = self.model(img, **kwargs)[0]

        contours = []
        mask_container = getattr(result, "masks", None)
        if mask_container is not None:
            masks = mask_container.data.cpu().numpy()  # (N, H, W)
            for mask in masks:
                binary = (mask > 0.5).astype(np.uint8)
                # find_contours yields (row, col) coordinates; flip the
                # columns so each contour is in (x, y) plotting order.
                contours.extend(
                    np.fliplr(contour)
                    for contour in find_contours(binary, 0.5)
                )
        return contours, result

    def postprocess(self, contours, processed_img):
        """Hook for post-processing; currently returns its inputs unchanged."""
        return contours, processed_img

    def run(
        self,
        img=None,
        img_path=None,
        preprocess=True,
        postprocess=True,
        **kwargs,
    ):
        """Run the full pipeline on an in-memory image or an image file.

        Args:
            img: Image to segment. Takes precedence over ``img_path``.
            img_path: Path of an image to load with ``load_img`` when
                ``img`` is ``None``.
            preprocess: Whether to apply ``preprocess`` before the model.
            postprocess: Whether to pass the results through
                ``postprocess``.
            **kwargs: Forwarded to ``process`` (and on to the model).

        Returns:
            ``(contours, processed_img)``; see ``process``.

        Raises:
            ValueError: If neither ``img`` nor ``img_path`` is given.
        """
        # Load image if not provided
        if img is None:
            if img_path is None:
                raise ValueError("Either `img` or `img_path` must be provided.")
            # The module-level loader; the class has no ``load`` method.
            img = load_img(img_path)
        self.img = img

        # Preprocessing
        self.preprocessed_img = self.preprocess(img) if preprocess else img

        # Processing
        self.contours, self.processed_img = self.process(
            self.preprocessed_img, **kwargs
        )

        # Postprocessing
        if postprocess:
            return self.postprocess(self.contours, self.processed_img)
        return self.contours, self.processed_img
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from skimage import data

    def plot_sam2_showcase(img, sam2, points=None, labels=None, title=""):
        """Show original, preprocessed, segmented and contour views side by side.

        Reads ``preprocessed_img``, ``processed_img`` and ``contours`` from
        ``sam2``, so ``sam2.run(...)`` must have been called first. Prompt
        points are drawn only when both ``points`` and ``labels`` are given.
        """
        show_prompts = points is not None and labels is not None
        if show_prompts:
            pts = np.array(points)
            lbls = np.array(labels)
            pos_pts = pts[lbls == 1]
            neg_pts = pts[lbls == 0]

        def scatter_prompts(ax):
            # Positive prompts as lime "+", negative prompts as red "x".
            if len(pos_pts) > 0:
                ax.scatter(pos_pts[:, 0], pos_pts[:, 1], c="lime", s=80, marker="+", label="Positive")
            if len(neg_pts) > 0:
                ax.scatter(neg_pts[:, 0], neg_pts[:, 1], c="red", s=80, marker="x", label="Negative")

        fig, axs = plt.subplots(1, 4, figsize=(16, 8))
        original_ax, pre_ax, seg_ax, contour_ax = axs

        # Panel 1: input image (+ prompts and legend)
        original_ax.imshow(img)
        original_ax.set_title("Original")
        if show_prompts:
            scatter_prompts(original_ax)
            original_ax.legend()

        # Panel 2: image fed to the model
        pre_ax.imshow(sam2.preprocessed_img)
        pre_ax.set_title("Preprocessed")

        # Panel 3: model result rendered by the result object itself
        seg_ax.imshow(sam2.processed_img.plot(labels=False, boxes=False))
        seg_ax.set_title("Processed")

        # Panel 4: extracted contours over the input image (+ prompts)
        contour_ax.imshow(img)
        for outline in sam2.contours:
            contour_ax.plot(outline[:, 0], outline[:, 1], linewidth=2, color="red")
        if show_prompts:
            scatter_prompts(contour_ax)
        suffix = " + Points" if points is not None else ""
        contour_ax.set_title("Contours" + suffix)

        for ax in axs:
            ax.axis("off")
        fig.suptitle(title)
        plt.tight_layout()
        plt.show()

    img = data.astronaut()
    sam2 = SAM2Detector(model="../data/mobile_sam.pt")  # downloaded automatically or specify path
    # To force a device: sam2.set_device("cpu")
    print("Using", sam2.get_device())

    # Different ways of using SAM: (figure title, points, labels)
    demos = [
        # Segment Everything (no prompt)
        ("SAM2: Segment Everything (no prompt)", None, None),
        # Single positive point
        ("SAM2: Single Positive Point", [[100, 100]], [1]),
        # Multiple positive and negative points
        (
            "SAM2: Multiple Positive and Negative Points",
            [[235, 40], [220, 120], [200, 200]],
            [1, 1, 0],
        ),
        # Multiple positive and negative points per object
        (
            "SAM2: Multiple Positive and Negative Points per Object",
            [[[235, 40], [220, 120], [200, 200]]],
            [[1, 1, 0]],
        ),
    ]
    for demo_title, demo_points, demo_labels in demos:
        # Only forward prompt kwargs when the demo actually has prompts.
        prompts = (
            {}
            if demo_points is None
            else {"points": demo_points, "labels": demo_labels}
        )
        sam2.run(img=img, preprocess=False, **prompts)
        plot_sam2_showcase(
            img, sam2, points=demo_points, labels=demo_labels, title=demo_title
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment