@Georgefwt · Created June 18, 2024 15:11
stable diffusion 3 clip walk and latent space walk
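The script interpolates between consecutive prompt/seed pairs: the SD3 text embeddings are walked with spherical interpolation (slerp) and the initial latents with linear interpolation, and each interpolated point is rendered to a numbered frame under rootdir/name. Besides the Fire command line, it can be driven programmatically; a minimal sketch, assuming the gist is saved as sd3_walk.py (a hypothetical filename):

# hypothetical usage, assuming the gist below is saved as sd3_walk.py
from sd3_walk import main

main(
    prompts=["blueberry spaghetti", "strawberry spaghetti", "mango spaghetti"],
    seeds=[243, 523, 999],    # one seed per prompt
    name="berry_to_mango",    # frames are written to ./dreams/berry_to_mango
    num_steps=72,             # frames rendered between each pair of prompts
)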
import os
import inspect
import fire
from diffusers import StableDiffusion3Pipeline
from time import time
from PIL import Image
from einops import rearrange
import numpy as np
import torch
from torch import autocast
from torchvision.utils import make_grid
from tqdm import tqdm
# -----------------------------------------------------------------------------
@torch.no_grad()
def diffuse(
    pipe,
    cond_embeddings,
    cond_latents,
    num_inference_steps,
    guidance_scale,
    eta,  # unused; kept for parity with the original walk-script interface
):
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = cond_embeddings

    # stack unconditional and conditional embeddings for classifier-free guidance
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")

    with tqdm(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(pipe.scheduler.timesteps):
            latent_model_input = torch.cat([cond_latents] * 2)
            timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.device)

            # predict the noise residual
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
            )[0]

            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, return_dict=False, s_noise=0.0)[0]
            progress_bar.update()

    # scale and decode the image latents with the vae
    cond_latents = (cond_latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(cond_latents, return_dict=False)[0]

    # generate output numpy image as uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    return image
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Helper function to spherically interpolate, element-wise, two tuples of tensors v0 and v1 in PyTorch."""
    v2 = []
    for i in range(len(v0)):
        v0_ = v0[i]
        v1_ = v1[i]
        dot = torch.sum((v0_ / torch.norm(v0_)) * (v1_ / torch.norm(v1_)))
        dot = torch.clamp(dot, -1.0, 1.0)  # clip to handle numerical issues
        if torch.abs(dot) > DOT_THRESHOLD:
            # vectors are nearly (anti-)parallel; fall back to linear interpolation
            v2_ = (1 - t) * v0_ + t * v1_
        else:
            theta_0 = torch.acos(dot)
            sin_theta_0 = torch.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = torch.sin(theta_t)
            s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2_ = s0 * v0_ + s1 * v1_
        v2.append(v2_)
    return tuple(v2)
def main(
    # --------------------------------------
    # args you probably want to change
    prompts = ["blueberry spaghetti", "strawberry spaghetti"],  # prompts to dream about
    seeds = [243, 523],
    gpu = 0,  # id of the gpu to run on
    name = 'berry_good_spaghetti',  # name of this project, for the output directory
    rootdir = './dreams',
    num_steps = 72,  # number of steps between each pair of sampled points
    # --------------------------------------
    # args you probably don't want to change
    num_inference_steps = 30,
    guidance_scale = 7.5,
    eta = 0.0,
    width = 1024,
    height = 1024,
    # --------------------------------------
):
    assert len(prompts) == len(seeds)
    assert torch.cuda.is_available()
    assert height % 8 == 0 and width % 8 == 0

    # init the output dir
    outdir = os.path.join(rootdir, name)
    os.makedirs(outdir, exist_ok=True)

    # init all of the models and move them to a given GPU
    pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    # get the conditional text embeddings based on the prompts
    prompt_embeddings = []
    for prompt in prompts:
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt="",
            negative_prompt_2="",
            negative_prompt_3="",
            do_classifier_free_guidance=True,
            device="cuda",
        )
        # detach the embeddings
        prompt_embeds = prompt_embeds.detach()
        negative_prompt_embeds = negative_prompt_embeds.detach()
        pooled_prompt_embeds = pooled_prompt_embeds.detach()
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.detach()
        prompt_embeddings.append((prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds))

    # Take the first embedding and set it as the starting point, leaving the rest as a list we'll loop over.
    prompt_embedding_a, *prompt_embeddings = prompt_embeddings

    # Take the first seed and use it to generate the initial noise
    init_seed, *seeds = seeds
    init_a = torch.randn(
        (1, 16, height // 8, width // 8),
        device="cuda",
        generator=torch.Generator(device='cuda').manual_seed(init_seed)
    )

    frame_index = 0
    for p, prompt_embedding_b in enumerate(prompt_embeddings):
        init_b = torch.randn(
            (1, 16, height // 8, width // 8),
            generator=torch.Generator(device='cuda').manual_seed(seeds[p]),
            device="cuda"
        )
        for i, t in enumerate(np.linspace(0, 1, num_steps)):
            print("dreaming... ", frame_index)
            # slerp between the text embeddings, lerp between the initial latents
            cond_embedding = slerp(float(t), prompt_embedding_a, prompt_embedding_b)
            init = torch.lerp(init_a, init_b, float(t))
            with autocast("cuda"):
                image = diffuse(pipe, cond_embedding, init, num_inference_steps, guidance_scale, eta)
            im = Image.fromarray(image)
            outpath = os.path.join(outdir, 'frame%06d.jpg' % frame_index)
            im.save(outpath)
            frame_index += 1
        prompt_embedding_a = prompt_embedding_b
        init_a = init_b

if __name__ == '__main__':
    fire.Fire(main)
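
The numbered frames can then be assembled into an animation. A minimal sketch using Pillow; make_gif and the output path are assumptions, not part of the gist:

import glob
import os
from PIL import Image

def make_gif(frame_dir, out_path="walk.gif", duration_ms=50):
    """Collect the frame%06d.jpg files written by main() and save them as a looping GIF."""
    paths = sorted(glob.glob(os.path.join(frame_dir, "frame*.jpg")))
    frames = [Image.open(p) for p in paths]
    frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=duration_ms, loop=0)

make_gif("./dreams/berry_good_spaghetti")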