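Benchmark script for ahead-of-time (AoT) compilation of the FLUX.1-dev transformer using the Hugging Face `spaces` AoTI helpers. It compares compiling the full transformer as one artifact against regional (per-block) compilation, where a single compiled artifact is shared across all blocks of the same kind.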
import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
from time import perf_counter
import argparse

CKPT_ID = "black-forest-labs/FLUX.1-dev"

def get_pipe_kwargs():
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs

def load_pipeline():
    pipe = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=torch.bfloat16, device_map="cuda")
    pipe.set_progress_bar_config(disable=True)
    return pipe

@torch.no_grad()
def aot_compile_load(pipe, regional=False):
    prompt = "arbitrary example prompt"
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        if regional:
            # Capture example inputs for one double-stream block and one
            # single-stream block by running a regular pipeline call.
            with spaces.aoti_capture(pipe.transformer.transformer_blocks[0]) as call_double_blocks:
                pipe(prompt=prompt)
            with spaces.aoti_capture(
                pipe.transformer.single_transformer_blocks[0]
            ) as call_single_blocks:
                pipe(prompt=prompt)
            exported_double = torch.export.export(
                mod=pipe.transformer.transformer_blocks[0],
                args=call_double_blocks.args,
                kwargs=call_double_blocks.kwargs,
            )
            exported_single = torch.export.export(
                mod=pipe.transformer.single_transformer_blocks[0],
                args=call_single_blocks.args,
                kwargs=call_single_blocks.kwargs,
            )
            compiled_double = spaces.aoti_compile(exported_double)
            compiled_single = spaces.aoti_compile(exported_single)
            # Blocks of the same kind share one architecture, so each compiled
            # artifact is reused across all blocks of that kind; only the
            # per-block weights differ.
            for block in pipe.transformer.transformer_blocks:
                weights = ZeroGPUWeights(block.state_dict())
                compiled_block = ZeroGPUCompiledModel(compiled_double.archive_file, weights)
                block.forward = compiled_block
            for block in pipe.transformer.single_transformer_blocks:
                weights = ZeroGPUWeights(block.state_dict())
                compiled_block = ZeroGPUCompiledModel(compiled_single.archive_file, weights)
                block.forward = compiled_block
        else:
            # Compile the whole transformer as one AoT artifact.
            with spaces.aoti_capture(pipe.transformer) as call:
                pipe(prompt=prompt)
            exported = torch.export.export(
                pipe.transformer,
                args=call.args,
                kwargs=call.kwargs,
            )
            compiled = spaces.aoti_compile(exported)
            spaces.aoti_apply(compiled, pipe.transformer)
    return pipe

def measure_compile_time(pipe, regional=False):
    start = perf_counter()
    pipe = aot_compile_load(pipe=pipe, regional=regional)
    torch.cuda.synchronize()
    end = perf_counter()
    # Run a full generation to make sure the compiled pipeline still works.
    image = pipe(**get_pipe_kwargs()).images[0]
    return end - start, image

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--regional", action="store_true")
    args = parser.parse_args()
    pipe = load_pipeline()
    latency, image = measure_compile_time(pipe, regional=args.regional)
    print(f"{args.regional=}, {latency=} secs.")
    image.save(f"regional@{args.regional}.png")
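A hypothetical invocation, assuming the gist is saved as `compile_flux_aoti.py` (the filename is not part of the gist) on a machine with a CUDA GPU and the `spaces` package installed:

    python compile_flux_aoti.py            # AoT-compile the full transformer
    python compile_flux_aoti.py --regional # regional, per-block AoT compile

Regional compilation exports and compiles two small block graphs instead of one large transformer graph, so it typically finishes compiling faster, at the cost of a little per-block dispatch overhead at inference time.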