stable diffusion 3 clip walk and latent space walk: the script below slerps between the text embeddings of consecutive prompts and lerps between their seeded initial noise latents, rendering one Stable Diffusion 3 frame per interpolation step.
import os

import fire
import numpy as np
import torch
from diffusers import StableDiffusion3Pipeline
from PIL import Image
from torch import autocast
from tqdm import tqdm
# -----------------------------------------------------------------------------
@torch.no_grad()
def diffuse(
    pipe,
    cond_embeddings,
    cond_latents,
    num_inference_steps,
    guidance_scale,
    eta,  # unused by SD3's flow-matching scheduler; kept for signature symmetry
):
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = cond_embeddings
    # stack negative and positive embeddings so one forward pass serves both CFG branches
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
    with tqdm(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(pipe.scheduler.timesteps):
            latent_model_input = torch.cat([cond_latents] * 2)
            timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.device)
            # predict the noise residual
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
            )[0]
            # classifier-free guidance: push the prediction away from the unconditional branch
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # s_noise=0.0 suppresses stochastic churn noise, keeping the walk deterministic
            cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, return_dict=False, s_noise=0.0)[0]
            progress_bar.update()
    # undo the VAE scaling and shift, then decode the latents to pixels
    cond_latents = (cond_latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(cond_latents, return_dict=False)[0]
    # map from [-1, 1] to a uint8 numpy image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    return image
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two sequences of tensors v0 and v1."""
    v2 = []
    for i in range(len(v0)):
        v0_ = v0[i]
        v1_ = v1[i]
        dot = torch.sum((v0_ / torch.norm(v0_)) * (v1_ / torch.norm(v1_)))
        dot = torch.clamp(dot, -1.0, 1.0)  # clip to handle numerical issues
        if torch.abs(dot) > DOT_THRESHOLD:
            # nearly (anti)parallel vectors: fall back to linear interpolation
            v2_ = (1 - t) * v0_ + t * v1_
        else:
            theta_0 = torch.acos(dot)
            sin_theta_0 = torch.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = torch.sin(theta_t)
            s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2_ = s0 * v0_ + s1 * v1_
        v2.append(v2_)
    return tuple(v2)
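# A minimal sanity check of slerp; the tensors below are illustrative only:
# >>> a = torch.tensor([1.0, 0.0])
# >>> b = torch.tensor([0.0, 1.0])
# >>> slerp(0.5, (a,), (b,))
# (tensor([0.7071, 0.7071]),)
# The midpoint stays on the unit sphere (norm 1), whereas plain lerp would give
# [0.5, 0.5] with norm ~0.71, pulling the interpolated point off the sphere.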
def main(
    # --------------------------------------
    # args you probably want to change
    prompts=["blueberry spaghetti", "strawberry spaghetti"],  # prompts to dream about
    seeds=[243, 523],
    gpu=0,  # id of the gpu to run on
    name='berry_good_spaghetti',  # name of this project, for the output directory
    rootdir='./dreams',
    num_steps=72,  # number of steps between each pair of sampled points
    # --------------------------------------
    # args you probably don't want to change
    num_inference_steps=30,
    guidance_scale=7.5,
    eta=0.0,
    width=1024,
    height=1024,
    # --------------------------------------
):
    assert len(prompts) == len(seeds)
    assert torch.cuda.is_available()
    assert height % 8 == 0 and width % 8 == 0  # the VAE downsamples by a factor of 8
    torch.cuda.set_device(gpu)  # make the bare "cuda" references below resolve to this GPU
    # init the output dir
    outdir = os.path.join(rootdir, name)
    os.makedirs(outdir, exist_ok=True)
    # init the model and move it to the GPU
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    )
    pipe = pipe.to("cuda")  # pipe.enable_model_cpu_offload() is an alternative if VRAM is tight
    # get the conditional text embeddings based on the prompts
    prompt_embeddings = []
    for prompt in prompts:
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt="",
            negative_prompt_2="",
            negative_prompt_3="",
            do_classifier_free_guidance=True,  # a boolean flag, not a guidance scale
            device="cuda",
        )
        # detach the embeddings
        prompt_embeds = prompt_embeds.detach()
        negative_prompt_embeds = negative_prompt_embeds.detach()
        pooled_prompt_embeds = pooled_prompt_embeds.detach()
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.detach()
        prompt_embeddings.append((prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds))
    # take the first embedding as the starting point, leaving the rest as a list to walk through
    prompt_embedding_a, *prompt_embeddings = prompt_embeddings
    # take the first seed and use it to generate the initial noise
    init_seed, *seeds = seeds
    init_a = torch.randn(
        (1, 16, height // 8, width // 8),  # SD3 latents have 16 channels
        device="cuda",
        generator=torch.Generator(device='cuda').manual_seed(init_seed),
    )
    frame_index = 0
    for p, prompt_embedding_b in enumerate(prompt_embeddings):
        init_b = torch.randn(
            (1, 16, height // 8, width // 8),
            generator=torch.Generator(device='cuda').manual_seed(seeds[p]),
            device="cuda",
        )
        for i, t in enumerate(np.linspace(0, 1, num_steps)):
            print("dreaming... ", frame_index)
            # slerp the text embeddings, lerp the initial noise;
            # slerp(float(t), (init_a,), (init_b,))[0] is a drop-in alternative that preserves the noise norm
            cond_embedding = slerp(float(t), prompt_embedding_a, prompt_embedding_b)
            init = torch.lerp(init_a, init_b, float(t))
            with autocast("cuda"):
                image = diffuse(pipe, cond_embedding, init, num_inference_steps, guidance_scale, eta)
            im = Image.fromarray(image)
            outpath = os.path.join(outdir, 'frame%06d.jpg' % frame_index)
            im.save(outpath)
            frame_index += 1
        prompt_embedding_a = prompt_embedding_b
        init_a = init_b


if __name__ == '__main__':
    fire.Fire(main)
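Assuming the script is saved as sd3_walk.py (the filename is an assumption, as is the frame rate below), and access to the gated model weights has been granted on Hugging Face, a run might look like:

python sd3_walk.py --prompts '["blueberry spaghetti", "strawberry spaghetti"]' --seeds '[243, 523]' --name berry_good_spaghetti --num_steps 72

The frames land in ./dreams/berry_good_spaghetti/ as frame000000.jpg, frame000001.jpg, and so on, so they can be stitched into a video with ffmpeg:

ffmpeg -framerate 24 -i ./dreams/berry_good_spaghetti/frame%06d.jpg -c:v libx264 -pix_fmt yuv420p walk.mp4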