@sayakpaul
Created September 5, 2025 08:29
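Benchmark AoT (AOTInductor) compilation of the FLUX.1-dev transformer for ZeroGPU: either export the whole transformer in one shot, or export the double- and single-stream transformer blocks individually (--regional) and reuse one compiled artifact per block class.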
import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
from time import perf_counter
import argparse
CKPT_ID = "black-forest-labs/FLUX.1-dev"


def get_pipe_kwargs():
    # Fixed generation settings so runs are comparable across variants.
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs

def load_pipeline():
    pipe = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=torch.bfloat16, device_map="cuda")
    pipe.set_progress_bar_config(disable=True)
    return pipe

@torch.no_grad()
def aot_compile_load(pipe, regional=False):
    prompt = "arbitrary example prompt"
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        if regional:
            # Regional compilation: capture example inputs for one
            # double-stream and one single-stream block, export and compile
            # each once, then reuse the compiled artifact for every block
            # of the same class.
            with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call_double_blocks:
                pipe(prompt=prompt)
            with spaces.aoti_capture(
                pipe.transformer.single_transformer_blocks[0]
            ) as call_single_blocks:
                pipe(prompt=prompt)
            exported_double = torch.export.export(
                mod=pipe.transformer.transformer_blocks[0],
                args=call_double_blocks.args,
                kwargs=call_double_blocks.kwargs,
            )
            exported_single = torch.export.export(
                mod=pipe.transformer.single_transformer_blocks[0],
                args=call_single_blocks.args,
                kwargs=call_single_blocks.kwargs,
            )
            compiled_double = spaces.aoti_compile(exported_double)
            compiled_single = spaces.aoti_compile(exported_single)
            # Bind each block's forward to the shared compiled archive,
            # paired with that block's own weights.
            for block in pipe.transformer.transformer_blocks:
                weights = ZeroGPUWeights(block.state_dict())
                compiled_block = ZeroGPUCompiledModel(compiled_double.archive_file, weights)
                block.forward = compiled_block
            for block in pipe.transformer.single_transformer_blocks:
                weights = ZeroGPUWeights(block.state_dict())
                compiled_block = ZeroGPUCompiledModel(compiled_single.archive_file, weights)
                block.forward = compiled_block
        else:
            # Full-graph compilation: capture inputs for the whole
            # transformer and export/compile it in one shot.
            with spaces.aoti_capture(pipe.transformer) as call:
                pipe(prompt=prompt)
            exported = torch.export.export(
                pipe.transformer,
                args=call.args,
                kwargs=call.kwargs,
            )
            compiled = spaces.aoti_compile(exported)
            spaces.aoti_apply(compiled, pipe.transformer)
    return pipe

def measure_compile_time(pipe, regional=False):
    start = perf_counter()
    pipe = aot_compile_load(pipe=pipe, regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # Sanity check: run a full generation with the compiled blocks in place.
    image = pipe(**get_pipe_kwargs()).images[0]
    return end - start, image

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--regional", action="store_true")
    args = parser.parse_args()

    pipe = load_pipeline()
    latency, image = measure_compile_time(pipe, regional=args.regional)
    print(f"{args.regional=}, {latency=} secs.")
    image.save(f"regional@{args.regional}.png")
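Usage (a minimal sketch; the filename is an assumption, save the gist as e.g. aoti_flux.py):

    python aoti_flux.py              # export and compile the whole transformer
    python aoti_flux.py --regional   # export and compile per transformer-block class

Each run prints the measured compile latency and saves the verification image as regional@{True,False}.png.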