Skip to content

Instantly share code, notes, and snippets.

@pieper
Created October 3, 2025 14:49
Show Gist options
  • Select an option

  • Save pieper/8225fe75c63c89c487fb9ed074365398 to your computer and use it in GitHub Desktop.

Select an option

Save pieper/8225fe75c63c89c487fb9ed074365398 to your computer and use it in GitHub Desktop.
GrowCutWarp
import warp as wp
import numpy as np
"""
This is a demo of OpenCL code ported automatically to Warp using Google Gemini.
Given the input file below, the code in this file was generated and worked the first time:
https://github.com/pieper/SlicerCL/blob/master/GrowCutCL/GrowCutCL.cl.in
To run in Slicer (probably any version 5.8.1 or later works; tested on a recent preview build),
update the path below to wherever you saved GrowCutWarp.py, then run:
pip_install("warp-lang[extras]")
import runpy; runResult = runpy.run_path('GrowCutWarp.py', run_name='__main__')
"""
@wp.func
def set_neighbors(
    volume: wp.array3d(dtype=wp.int16),
    slice: wp.int32,
    row: wp.int32,
    column: wp.int32,
    value: wp.int16,
    shape: wp.vec3i,
):
    """Write ``value`` into every in-bounds voxel of the 3x3x3 block centred
    on (slice, row, column), the centre voxel included.

    Instead of testing each of the 27 offsets against the volume bounds, the
    per-axis index ranges are clamped to ``[0, shape[axis])`` first, so only
    in-bounds voxels are ever visited.
    """
    s_begin = wp.max(slice - 1, 0)
    s_end = wp.min(slice + 2, shape[0])
    r_begin = wp.max(row - 1, 0)
    r_end = wp.min(row + 2, shape[1])
    c_begin = wp.max(column - 1, 0)
    c_end = wp.min(column + 2, shape[2])
    for s in range(s_begin, s_end):
        for r in range(r_begin, r_end):
            for c in range(c_begin, c_end):
                volume[s, r, c] = value
@wp.kernel
def initial_candidates_kernel(
    labels: wp.array3d(dtype=wp.int16),
    candidates: wp.array3d(dtype=wp.int16),
    shape: wp.vec3i,
):
    """Seed the candidate mask before the first GrowCut iteration.

    One thread runs per voxel. Each labelled (non-zero) voxel marks itself and
    its in-bounds 3x3x3 neighbourhood as candidates (value 1) through
    ``set_neighbors``; unlabelled voxels leave the mask untouched.
    """
    i, j, k = wp.tid()
    if labels[i, j, k] == 0:
        return  # unlabelled voxels seed nothing
    set_neighbors(candidates, i, j, k, wp.int16(1), shape)
@wp.kernel
def grow_cut_kernel(
    volume: wp.array3d(dtype=wp.int16),
    label: wp.array3d(dtype=wp.int16),
    theta: wp.array3d(dtype=wp.float32),
    label_next: wp.array3d(dtype=wp.int16),
    theta_next: wp.array3d(dtype=wp.float32),
    candidates: wp.array3d(dtype=wp.int16),
    candidates_next: wp.array3d(dtype=wp.int16),
    volume_max: wp.float32,
    changed_flag: wp.array(dtype=wp.int32),
    shape: wp.vec3i,
):
    """Perform one GrowCut iteration (one thread per voxel).

    The current state (``label``, ``theta``, ``candidates``) is only read and
    the next state (``label_next``, ``theta_next``, ``candidates_next``) is
    only written, so threads never observe each other's writes in a launch.

    Every voxel first copies its label and strength into the next buffers.
    A candidate voxel is then attacked by each labelled voxel in its 3x3x3
    neighbourhood with strength
    ``theta[nb] * (1 - |volume[self] - volume[nb]| / volume_max)``; any attack
    stronger than the voxel's current strength takes over, and later
    neighbours must beat the new strength.

    Args:
        volume: background intensities.
        label, theta: current labels and strengths.
        label_next, theta_next: output labels and strengths.
        candidates: non-zero for voxels that are examined this iteration.
        candidates_next: output candidate mask. This kernel only ever sets
            entries to 1, so the caller is expected to zero it beforehand.
        volume_max: divisor in the attack formula; expected to be positive.
        changed_flag: ``changed_flag[0]`` is incremented once for every voxel
            whose label/strength changed in this launch.
        shape: volume dimensions, used for neighbourhood bounds checks.
    """
    slice, row, column = wp.tid()
    # Carry the current state forward so every entry of the *_next buffers is
    # written on every launch, whether or not this voxel is a candidate.
    label_next[slice, row, column] = label[slice, row, column]
    theta_now = theta[slice, row, column]
    theta_next[slice, row, column] = theta_now
    if candidates[slice, row, column] == 0:
        return
    sample = wp.float32(volume[slice, row, column])
    # Set when any neighbour conquers this voxel. The change counter and the
    # candidate mask are updated once after the scan, not once per attack,
    # so a voxel overtaken several times is still counted as one change.
    changed = wp.int32(0)
    size = 1
    for slice_off in range(-size, size + 1):
        for row_off in range(-size, size + 1):
            for col_off in range(-size, size + 1):
                s = slice + slice_off
                r = row + row_off
                c = column + col_off
                if s >= 0 and s < shape[0] and r >= 0 and r < shape[1] and c >= 0 and c < shape[2]:
                    other_label = label[s, r, c]
                    if other_label != 0:
                        other_sample = wp.float32(volume[s, r, c])
                        other_theta = theta[s, r, c]
                        sample_diff = wp.abs(sample - other_sample)
                        attack_strength = other_theta * (1.0 - (sample_diff / volume_max))
                        if attack_strength > theta_now:
                            label_next[slice, row, column] = other_label
                            theta_next[slice, row, column] = attack_strength
                            # Later neighbours must beat the new strength.
                            theta_now = attack_strength
                            changed = wp.int32(1)
    if changed != 0:
        # Count this voxel once and make it and its neighbourhood candidates
        # for the next iteration.
        wp.atomic_add(changed_flag, 0, 1)
        set_neighbors(candidates_next, slice, row, column, wp.int16(1), shape)
class GrowCutWarpLogic:
    """Host-side driver that runs GrowCut on a 3-D volume with Warp kernels.

    The background, labels, strengths (theta) and candidate mask live on the
    selected Warp device. Labels, strengths and candidates are double-buffered:
    each iteration of :meth:`step` launches ``grow_cut_kernel`` reading the
    current buffers and writing the ``*_next`` buffers, then copies the
    results back.

    Args:
        background_array: 3-D array of background intensities. It is stored on
            the device as int16; values outside the int16 range are clipped
            instead of being wrapped by the dtype conversion.
        label_array: seed labels with the same shape as ``background_array``;
            0 means unlabelled.
        device: Warp device string, e.g. ``"cpu"`` or ``"cuda"``.

    Raises:
        ValueError: if the background is not 3-D or the label shape differs
            from the background shape.
    """

    def __init__(self, background_array, label_array, device="cpu"):
        wp.init()
        self.device = device
        background, self.volume_max = self._prepare_background(background_array)
        labels = np.asarray(label_array)
        if background.ndim != 3:
            raise ValueError(
                f"background_array must be 3-D, got shape {background.shape}"
            )
        if labels.shape != background.shape:
            # Kernels are launched over the background's shape and index the
            # label arrays with the same thread ids, so a mismatch could index
            # out of bounds.
            raise ValueError(
                f"label_array shape {labels.shape} does not match "
                f"background_array shape {background.shape}"
            )
        self.shape = background.shape
        self.wp_shape = wp.vec3i(self.shape[0], self.shape[1], self.shape[2])
        self.background_wp = wp.array(background, dtype=wp.int16, device=self.device)
        self.label_wp = wp.array(labels, dtype=wp.int16, device=self.device)
        self.label_next_wp = wp.zeros_like(self.label_wp)
        # Seed voxels start at strength 2**15, unlabelled voxels at 0.
        binary_labels = (labels != 0).astype(np.float32)
        theta = (2**15) * binary_labels
        self.theta_wp = wp.array(theta, dtype=wp.float32, device=self.device)
        self.theta_next_wp = wp.zeros_like(self.theta_wp)
        self.candidates_wp = wp.zeros_like(self.label_wp)
        self.candidates_next_wp = wp.zeros_like(self.label_wp)
        # changed_flag_wp[0] receives the kernel's per-launch change counter.
        self.changed_flag_wp = wp.zeros(1, dtype=wp.int32, device=self.device)
        self.candidates_initialized = False

    @staticmethod
    def _prepare_background(background_array):
        """Return ``(int16 background array, intensity normaliser)``.

        Values outside the int16 range are clipped, because a plain dtype cast
        would wrap them (e.g. a uint16 value of 40000 would become negative).
        The normaliser is the maximum of the int16 data; it falls back to 1.0
        when that is not positive (empty, all-zero or all-negative volume), so
        the kernel never divides by zero or flips the sign of attacks.
        """
        background = np.asarray(background_array)
        if background.dtype != np.int16:
            limits = np.iinfo(np.int16)
            background = np.clip(
                background.astype(np.float64), limits.min, limits.max
            ).astype(np.int16)
        volume_max = float(background.max()) if background.size else 0.0
        if not volume_max > 0.0:
            volume_max = 1.0
        return background, volume_max

    def _initialize_candidates(self):
        """Launch ``initial_candidates_kernel`` once, before the first step."""
        if not self.candidates_initialized:
            wp.launch(
                kernel=initial_candidates_kernel,
                dim=self.shape,
                inputs=[self.label_wp, self.candidates_wp, self.wp_shape],
                device=self.device,
            )
            self.candidates_initialized = True

    def step(self, iterations=1):
        """Advance the automaton by ``iterations`` GrowCut iterations.

        Returns:
            The change count accumulated by ``grow_cut_kernel`` during the last
            iteration run, as a Python int; 0 means nothing changed
            (converged). Returns 0 without running a GrowCut iteration when
            ``iterations`` < 1.
        """
        self._initialize_candidates()
        if iterations < 1:
            # Nothing ran, so nothing changed; don't report a stale count
            # left over from an earlier call.
            return 0
        for _ in range(iterations):
            self.changed_flag_wp.zero_()
            # The kernel only writes 1s into candidates_next, so it has to
            # start every iteration cleared.
            self.candidates_next_wp.zero_()
            wp.launch(
                kernel=grow_cut_kernel,
                dim=self.shape,
                inputs=[
                    self.background_wp, self.label_wp, self.theta_wp,
                    self.label_next_wp, self.theta_next_wp,
                    self.candidates_wp, self.candidates_next_wp,
                    self.volume_max, self.changed_flag_wp, self.wp_shape,
                ],
                device=self.device,
            )
            wp.copy(self.label_wp, self.label_next_wp)
            wp.copy(self.theta_wp, self.theta_next_wp)
            wp.copy(self.candidates_wp, self.candidates_next_wp)
        return int(self.changed_flag_wp.numpy()[0])

    def get_result(self):
        """Return the current labels as a NumPy array via ``wp.array.numpy()``."""
        return self.label_wp.numpy()
if __name__ == "__main__":
    import time

    try:  # pragma: no cover
        import slicer
        import SampleData
    except ImportError:
        print("This test script must be run within the 3D Slicer Python environment.")
        # Raise SystemExit directly: the exit() helper is installed by the
        # site module and is not guaranteed to exist in every interpreter.
        raise SystemExit(1)
    print("Setting up GrowCutWarp test scene...")
    # 1. Download MRHead sample data
    print("Downloading sample data...")
    background_volume_node = SampleData.SampleDataLogic().downloadSample("MRHead")
    # 2. Create a label map volume (a clone of the background, cleared below)
    print("Creating label map...")
    volumes_logic = slicer.modules.volumes.logic()
    label_volume_node = volumes_logic.CloneVolume(
        slicer.mrmlScene, background_volume_node, "GrowCutWarp-labels"
    )
    label_volume_node.GetDisplayNode().SetAndObserveColorNodeID(
        slicer.util.getNode("GenericAnatomyColors").GetID()
    )
    # Make the label map visible
    app_logic = slicer.app.applicationLogic()
    selection_node = app_logic.GetSelectionNode()
    selection_node.SetReferenceActiveLabelVolumeID(label_volume_node.GetID())
    app_logic.PropagateVolumeSelection()
    # 3. Get numpy arrays
    background_array = slicer.util.arrayFromVolume(background_volume_node)
    label_array = slicer.util.arrayFromVolume(label_volume_node)
    label_array[:] = 0  # Clear the label map
    # 4. Add some seed points (blobs)
    print("Adding seed labels...")
    # Foreground seed (label 1)
    label_array[70:75, 130:135, 90:95] = 1
    # Background seed (label 2)
    label_array[20:25, 20:25, 20:25] = 2
    label_array[110:115, 230:235, 120:125] = 2
    # Notify Slicer that the label array has been modified
    slicer.util.arrayFromVolumeModified(label_volume_node)
    # 5. Instantiate and run the GrowCutWarp logic
    wp.init()
    device = "cuda" if wp.is_cuda_available() else "cpu"
    # device = "cpu"  # Uncomment to force CPU execution
    print("Initializing GrowCutWarp logic...")
    logic = GrowCutWarpLogic(background_array, label_array, device=device)
    print(f"Running GrowCut iterations on device: {device}")
    max_run_time_seconds = 10.0
    scene_update_interval = 10  # push results to the scene every N iterations
    iteration_count = 0
    total_start_time = time.time()
    while (time.time() - total_start_time) < max_run_time_seconds:
        iteration_start_time = time.time()
        iteration_count += 1
        changed_count = logic.step(1)
        elapsed_time = time.time() - iteration_start_time
        converged = (changed_count == 0)
        if (iteration_count % scene_update_interval == 0) or converged:
            print(f" Iteration {iteration_count}: {changed_count} pixels changed in {elapsed_time:.4f} seconds (updating scene)")
            # Update the Slicer scene with the intermediate result
            result_array = logic.get_result()
            label_array[:] = result_array
            slicer.util.arrayFromVolumeModified(label_volume_node)
            slicer.app.processEvents()  # Allow GUI to update
        if converged:
            print("Convergence reached.")
            break
    else:
        # The loop ended on the time limit without converging (no break).
        print(f"Stopping after {max_run_time_seconds:.1f} seconds.")
        if iteration_count % scene_update_interval != 0:
            # The latest iterations were not pushed by the periodic update.
            result_array = logic.get_result()
            label_array[:] = result_array
            slicer.util.arrayFromVolumeModified(label_volume_node)
            slicer.app.processEvents()
    total_elapsed_time = time.time() - total_start_time
    print(f"GrowCutWarp test finished in {total_elapsed_time:.2f} seconds.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment