Created
October 3, 2025 14:49
-
-
Save pieper/8225fe75c63c89c487fb9ed074365398 to your computer and use it in GitHub Desktop.
GrowCutWarp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import warp as wp | |
| import numpy as np | |
| """ | |
| This is a demo of OpenCL code ported automatically to Warp using Google Gemini. | |
| Given this input file, the code below was generated and worked the first time. | |
| https://github.com/pieper/SlicerCL/blob/master/GrowCutCL/GrowCutCL.cl.in | |
To run in Slicer (probably any version 5.8.1 or later works, but it was tested on a recent preview build),
update the path in the command below to wherever you saved GrowCutWarp.py, then run:
| pip_install("warp-lang[extras]") | |
| import runpy; runResult = runpy.run_path('GrowCutWarp.py', run_name='__main__') | |
| """ | |
@wp.func
def set_neighbors(
    volume: wp.array3d(dtype=wp.int16),
    slice: wp.int32,
    row: wp.int32,
    column: wp.int32,
    value: wp.int16,
    shape: wp.vec3i,
):
    """Write ``value`` into the 3x3x3 block centred on (slice, row, column).

    Neighbours whose index falls outside ``shape`` along any axis are skipped,
    so voxels on the volume border only touch their in-bounds neighbours.
    """
    radius = 1
    for ds in range(-radius, radius + 1):
        s = slice + ds
        # Check each axis as soon as its index is known.
        if s >= 0 and s < shape[0]:
            for dr in range(-radius, radius + 1):
                r = row + dr
                if r >= 0 and r < shape[1]:
                    for dc in range(-radius, radius + 1):
                        c = column + dc
                        if c >= 0 and c < shape[2]:
                            volume[s, r, c] = value
@wp.kernel
def initial_candidates_kernel(
    labels: wp.array3d(dtype=wp.int16),
    candidates: wp.array3d(dtype=wp.int16),
    shape: wp.vec3i,
):
    """Mark the starting candidate voxels for GrowCut.

    Launched with one thread per voxel. Each labeled (non-zero) voxel writes 1
    into ``candidates`` over its 3x3x3 neighbourhood. Voxels left at 0 are
    skipped by ``grow_cut_kernel`` on the first iteration.
    """
    i, j, k = wp.tid()
    # Unlabeled voxels contribute nothing to the candidate mask.
    if labels[i, j, k] == 0:
        return
    set_neighbors(candidates, i, j, k, wp.int16(1), shape)
@wp.kernel
def grow_cut_kernel(
    volume: wp.array3d(dtype=wp.int16),
    label: wp.array3d(dtype=wp.int16),
    theta: wp.array3d(dtype=wp.float32),
    label_next: wp.array3d(dtype=wp.int16),
    theta_next: wp.array3d(dtype=wp.float32),
    candidates: wp.array3d(dtype=wp.int16),
    candidates_next: wp.array3d(dtype=wp.int16),
    volume_max: wp.float32,
    changed_flag: wp.array(dtype=wp.int32),
    shape: wp.vec3i,
):
    """Performs one iteration of the GrowCut algorithm.

    Launched with one thread per voxel. Every thread first copies its current
    label and strength (theta) into ``label_next`` / ``theta_next``, so every
    element of the *_next buffers is written on each launch.

    Voxels whose ``candidates`` flag is non-zero are then attacked by every
    labeled voxel in their 3x3x3 neighbourhood. An attack has strength
    ``theta_q * (1 - |I_p - I_q| / volume_max)``. The voxel adopts the label
    of the strongest attack that beats its own strength; ties go to the first
    neighbour visited.

    Only the current-generation arrays (``label``, ``theta``) are read, and
    only this voxel's entry of the *_next arrays is written. As a result,
    threads never race on label state.

    A voxel that changes adds exactly 1 to ``changed_flag[0]``. It also marks
    its neighbourhood in ``candidates_next`` so those voxels are revisited in
    the next iteration.
    """
    slice, row, column = wp.tid()

    # Copy current state to next state (overwritten below if conquered).
    label_next[slice, row, column] = label[slice, row, column]
    theta_now = theta[slice, row, column]
    theta_next[slice, row, column] = theta_now

    if candidates[slice, row, column] == 0:
        return

    sample = wp.float32(volume[slice, row, column])
    # Declared with a type constructor so it can be mutated inside the loop.
    changed = wp.int32(0)

    size = 1
    for slice_off in range(-size, size + 1):
        for row_off in range(-size, size + 1):
            for col_off in range(-size, size + 1):
                s = slice + slice_off
                r = row + row_off
                c = column + col_off
                if s >= 0 and s < shape[0] and r >= 0 and r < shape[1] and c >= 0 and c < shape[2]:
                    other_label = label[s, r, c]
                    if other_label != 0:
                        other_sample = wp.float32(volume[s, r, c])
                        other_theta = theta[s, r, c]
                        sample_diff = wp.abs(sample - other_sample)
                        attack_strength = other_theta * (1.0 - (sample_diff / volume_max))
                        if attack_strength > theta_now:
                            label_next[slice, row, column] = other_label
                            theta_next[slice, row, column] = attack_strength
                            theta_now = attack_strength
                            changed = wp.int32(1)

    # Report and propagate once per voxel, however many attacks succeeded.
    if changed != 0:
        wp.atomic_add(changed_flag, 0, 1)
        set_neighbors(candidates_next, slice, row, column, wp.int16(1), shape)
class GrowCutWarpLogic:
    """GrowCut segmentation of a 3D volume, executed on a Warp device.

    The background intensities and seed labels are uploaded once. The label,
    strength (theta), and candidate state is kept in double-buffered device
    arrays. Call :meth:`step` to advance the automaton and :meth:`get_result`
    to fetch the current labels as a numpy array.
    """

    def __init__(self, background_array, label_array, device="cpu"):
        """Upload the inputs and allocate the device buffers.

        Args:
            background_array: 3D array of intensities; copied to ``device`` as
                int16.
            label_array: 3D array of seed labels (0 = unlabeled) with the same
                shape as ``background_array``; copied to ``device`` as int16.
            device: Warp device string, e.g. ``"cpu"`` or ``"cuda"``.

        Raises:
            ValueError: if ``background_array`` is not 3D or the two arrays
                have different shapes (the kernels index both with the same
                thread indices).
        """
        background_array = np.asarray(background_array)
        label_array = np.asarray(label_array)
        if background_array.ndim != 3:
            raise ValueError(
                f"background_array must be 3D, got shape {background_array.shape}"
            )
        if label_array.shape != background_array.shape:
            raise ValueError(
                f"label_array shape {label_array.shape} does not match "
                f"background_array shape {background_array.shape}"
            )

        wp.init()
        self.device = device
        self.shape = background_array.shape
        self.wp_shape = wp.vec3i(self.shape[0], self.shape[1], self.shape[2])

        self.background_wp = wp.array(background_array, dtype=wp.int16, device=self.device)
        self.label_wp = wp.array(label_array, dtype=wp.int16, device=self.device)
        self.label_next_wp = wp.zeros_like(self.label_wp)

        # Normaliser for the attack weight g = 1 - |I_p - I_q| / volume_max.
        # Using the largest absolute intensity keeps volume_max positive for
        # non-positive data, so g <= 1 and strengths can only decay, which
        # guarantees convergence. For non-negative volumes this equals the raw
        # maximum. An all-zero volume falls back to 1.0 to avoid 0/0.
        # float() before abs() avoids int16 overflow on -32768.
        volume_max = max(
            abs(float(background_array.max())),
            abs(float(background_array.min())),
        )
        self.volume_max = volume_max if volume_max > 0.0 else 1.0

        # Seeds start at full strength (2**15); all other voxels at 0.
        binary_labels = (label_array != 0).astype(np.float32)
        theta = (2**15) * binary_labels
        self.theta_wp = wp.array(theta, dtype=wp.float32, device=self.device)
        self.theta_next_wp = wp.zeros_like(self.theta_wp)

        self.candidates_wp = wp.zeros_like(self.label_wp)
        self.candidates_next_wp = wp.zeros_like(self.label_wp)
        self.changed_flag_wp = wp.zeros(1, dtype=wp.int32, device=self.device)
        self.candidates_initialized = False

    def _initialize_candidates(self):
        """Seed the candidate mask from the labels, once per instance."""
        if self.candidates_initialized:
            return
        wp.launch(
            kernel=initial_candidates_kernel,
            dim=self.shape,
            inputs=[self.label_wp, self.candidates_wp, self.wp_shape],
            device=self.device,
        )
        self.candidates_initialized = True

    def step(self, iterations=1):
        """Advance the automaton by ``iterations`` generations.

        Returns:
            The value the kernel accumulated in ``changed_flag`` during the
            last generation run; 0 means nothing changed (converged).
            Returns 0 if ``iterations`` < 1.
        """
        self._initialize_candidates()
        if iterations < 1:
            return 0
        for _ in range(iterations):
            self.changed_flag_wp.zero_()
            self.candidates_next_wp.zero_()
            wp.launch(
                kernel=grow_cut_kernel,
                dim=self.shape,
                inputs=[
                    self.background_wp, self.label_wp, self.theta_wp,
                    self.label_next_wp, self.theta_next_wp,
                    self.candidates_wp, self.candidates_next_wp,
                    self.volume_max, self.changed_flag_wp, self.wp_shape,
                ],
                device=self.device,
            )
            # The kernel writes every element of label_next/theta_next, and
            # candidates_next was zeroed above. Swapping the buffers is
            # therefore equivalent to copying next -> current, without three
            # full-volume device copies per iteration.
            self.label_wp, self.label_next_wp = self.label_next_wp, self.label_wp
            self.theta_wp, self.theta_next_wp = self.theta_next_wp, self.theta_wp
            self.candidates_wp, self.candidates_next_wp = (
                self.candidates_next_wp,
                self.candidates_wp,
            )
        return int(self.changed_flag_wp.numpy()[0])

    def get_result(self):
        """Return a host copy of the current label volume.

        On CPU devices ``wp.array.numpy()`` aliases device memory. Copying
        ensures later steps do not mutate an array already handed to the
        caller, matching the (always-copied) CUDA behaviour.
        """
        return self.label_wp.numpy().copy()
if __name__ == "__main__":
    # Demo: segment MRHead from three seed blobs inside 3D Slicer. Kept at
    # module level so runpy.run_path() exposes logic, label_array, etc.
    import time

    try:  # pragma: no cover
        import slicer
        import SampleData
    except ImportError:
        print("This test script must be run within the 3D Slicer Python environment.")
        # `exit` is injected by the `site` module and may be missing in
        # embedded interpreters; SystemExit is always available.
        raise SystemExit(1) from None

    print("Setting up GrowCutWarp test scene...")

    # 1. Download MRHead sample data
    print("Downloading sample data...")
    background_volume_node = SampleData.SampleDataLogic().downloadSample("MRHead")

    # 2. Create a label map volume
    print("Creating label map...")
    volumes_logic = slicer.modules.volumes.logic()
    label_volume_node = volumes_logic.CloneVolume(
        slicer.mrmlScene, background_volume_node, "GrowCutWarp-labels"
    )
    label_volume_node.GetDisplayNode().SetAndObserveColorNodeID(
        slicer.util.getNode("GenericAnatomyColors").GetID()
    )

    # Make the label map visible
    app_logic = slicer.app.applicationLogic()
    selection_node = app_logic.GetSelectionNode()
    selection_node.SetReferenceActiveLabelVolumeID(label_volume_node.GetID())
    app_logic.PropagateVolumeSelection()

    # 3. Get numpy arrays
    background_array = slicer.util.arrayFromVolume(background_volume_node)
    label_array = slicer.util.arrayFromVolume(label_volume_node)
    label_array[:] = 0  # Clear the label map

    # 4. Add some seed points (blobs)
    print("Adding seed labels...")
    # Foreground seed (label 1)
    label_array[70:75, 130:135, 90:95] = 1
    # Background seed (label 2)
    label_array[20:25, 20:25, 20:25] = 2
    label_array[110:115, 230:235, 120:125] = 2
    # Notify Slicer that the label array has been modified
    slicer.util.arrayFromVolumeModified(label_volume_node)

    # 5. Instantiate and run the GrowCutWarp logic
    wp.init()
    device = "cuda" if wp.is_cuda_available() else "cpu"
    # device = "cpu" # Uncomment to force CPU execution
    print("Initializing GrowCutWarp logic...")
    logic = GrowCutWarpLogic(background_array, label_array, device=device)

    print(f"Running GrowCut iterations on device: {device}")
    max_run_time_seconds = 10.0
    iteration_count = 0
    total_start_time = time.perf_counter()
    while True:
        iteration_start_time = time.perf_counter()
        iteration_count += 1
        changed_count = logic.step(1)
        now = time.perf_counter()
        elapsed_time = now - iteration_start_time
        converged = (changed_count == 0)
        timed_out = (now - total_start_time) >= max_run_time_seconds
        # Also publish on timeout so the scene never lags the final state.
        if (iteration_count % 10 == 0) or converged or timed_out:
            print(f" Iteration {iteration_count}: {changed_count} pixels changed in {elapsed_time:.4f} seconds (updating scene)")
            # Update the Slicer scene with the intermediate result
            result_array = logic.get_result()
            label_array[:] = result_array
            slicer.util.arrayFromVolumeModified(label_volume_node)
            slicer.app.processEvents()  # Allow GUI to update
        # Convergence takes precedence, so only one stop reason is reported.
        if converged:
            print("Convergence reached.")
            break
        if timed_out:
            print(f"Stopping after {max_run_time_seconds:.1f} seconds.")
            break

    total_elapsed_time = time.perf_counter() - total_start_time
    print(f"GrowCutWarp test finished in {total_elapsed_time:.2f} seconds.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment