break down sd-turbo one-step sampling for a simple-headed person like me
# break down sd-turbo one-step sampling for a simple-headed person like me
import torch
from PIL import Image
from diffusers import (
    UNet2DConditionModel,
    AutoencoderKL,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPTextModel, CLIPTokenizer
sdturbo_path = "stabilityai/sd-turbo"
batch_size = 1
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
device = "cuda"
height = 512
width = 512
generator = torch.Generator(device=device).manual_seed(42)
# ------------ prepare models
tokenizer = CLIPTokenizer.from_pretrained(
    sdturbo_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    sdturbo_path, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(
    sdturbo_path, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(
    sdturbo_path, subfolder="unet").to(device)
# ------------ encode prompt
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,  # CLIP's fixed context of 77 tokens
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)
prompt_embeds = prompt_embeds[0]  # last_hidden_state: one embedding per token
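# SD-Turbo runs with guidance_scale == 0.0, so unlike regular SD there is no
# second, unconditional embedding for classifier-free guidance. Shape sanity
# check (an assumption, based on sd-turbo being distilled from SD 2.1 with its
# 1024-wide OpenCLIP text encoder):
# assert prompt_embeds.shape == (batch_size, 77, 1024)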
# ------------ prepare latents
vae_scale_factor = 8  # the SD VAE downsamples 8x: 512x512 pixels -> 64x64 latents
shape = (batch_size, 4, height // vae_scale_factor, width // vae_scale_factor)
latents = randn_tensor(shape, generator=generator, device=torch.device(device))
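# No extra noise scaling is needed here: under the epsilon/DDPM parameterization
# used below, x_t at t = 999 is essentially unit-variance Gaussian noise
# (diffusers' DDPMScheduler likewise reports init_noise_sigma == 1.0).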
# ------------ predict noise
t = torch.tensor([999], device=device)  # the single, final training timestep
noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    return_dict=False,
)[0]
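# The UNet is an epsilon-predictor: noise_pred has the same shape as latents.
# One pass is enough because SD-Turbo was trained with adversarial diffusion
# distillation to jump from pure noise at t=999 straight to a clean sample.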
# ------------ prepare x0
beta_start = 0.00085
beta_end = 0.012
num_train_timesteps = 1000
# beta_schedule == "scaled_linear": linear in sqrt(beta), then squared
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
betas = betas.to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alpha_prod_t = alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
# invert the forward process x_t = sqrt(alpha_prod_t)*x0 + sqrt(beta_prod_t)*eps
predicted_original_sample = (latents - beta_prod_t.sqrt() * noise_pred) / alpha_prod_t.sqrt()
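# Cross-check sketch (commented out; assumptions: the stored scheduler config is
# accepted by diffusers' DDPMScheduler, uses timestep_spacing="trailing" so one
# inference step lands on t=999, and clip_sample=False keeps the formula exact):
# from diffusers import DDPMScheduler
# scheduler = DDPMScheduler.from_pretrained(sdturbo_path, subfolder="scheduler",
#                                           clip_sample=False)
# scheduler.set_timesteps(1, device=device)  # scheduler.timesteps -> [999]
# x0 = scheduler.step(noise_pred, 999, latents).pred_original_sample
# torch.testing.assert_close(x0, predicted_original_sample)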
# ------------ decode image
image = vae.decode(predicted_original_sample / vae.config.scaling_factor,
                   return_dict=False, generator=generator)[0]
image = image.permute(0, 2, 3, 1).detach().cpu().numpy()  # NCHW -> NHWC
image = ((image + 1) / 2).clip(0, 1)  # VAE output is in [-1, 1]; map to [0, 1]
image = (image * 255).astype("uint8")
image = Image.fromarray(image[0])
image.save("result.png")
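# For reference, the high-level diffusers route that this script unpacks
# (a sketch; AutoPipelineForText2Image needs a reasonably recent diffusers):
# from diffusers import AutoPipelineForText2Image
# pipe = AutoPipelineForText2Image.from_pretrained(sdturbo_path).to(device)
# pipe(prompt, num_inference_steps=1, guidance_scale=0.0,
#      generator=generator).images[0].save("result_pipeline.png")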