# pip install -U --user diffusers transformers huggingface_hub
Requirement already satisfied: diffusers (0.6.0), transformers (4.23.1), huggingface_hub (0.10.1) and their dependencies in /usr/local/lib/python3.9/dist-packages
Note: you may need to restart the kernel to use updated packages.
%load_ext autoreload
%autoreload 2
from PIL import Image
from fastcore.all import concat
import torch, logging
from pathlib import Path
from huggingface_hub import notebook_login
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from miniai.stability import *
logging.disable(logging.WARNING)
torch.manual_seed(1)
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
guidance_scale = 7.5
num_inference_steps = 50
width = height = 512
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda")
# Here we use a different VAE from the one in the original release; it has been fine-tuned for more steps
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to("cuda")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
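# Optional: a rough size check of the three networks we just loaded (a sketch; it only relies on
# the standard `nn.Module.parameters()` API).
for nm, m in (('text_encoder', text_encoder), ('vae', vae), ('unet', unet)):
    print(nm, f'{sum(p.numel() for p in m.parameters())/1e6:.0f}M params')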
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
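# Optional: visualise the noise schedule. A sketch that assumes the LMS scheduler exposes `sigmas`
# the way diffusers 0.6.x does; `set_timesteps` switches from the 1000 training steps to the
# 50 inference steps we use below.
scheduler.set_timesteps(num_inference_steps)
plt.plot(scheduler.sigmas);
plt.title('noise level (sigma) at each sampling step');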
from urllib.request import urlretrieve
img_url = 'https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_12_1.png'
img_path = Path('horse.png')
if not img_path.exists(): urlretrieve(img_url, img_path)
def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0  # HWC uint8 -> float in [0,1]
    image = image[None].transpose(0, 3, 1, 2)  # add batch dim and go HWC -> CHW
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # scale to [-1,1], the range the VAE expects
img = preprocess(Image.open(img_path).convert('RGB'))
show_image(img[0]);
def encode(x):
    latents = vae.encode(x.to("cuda", dtype=torch.float16)).latent_dist.sample()
    return latents * 0.18215  # the latent scaling factor Stable Diffusion was trained with
latents = encode(img)
show_images(latents[0].detach()); # TODO: detach in `show_images`?
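# Quick sanity check (a sketch): the VAE downsamples by a factor of 8, so an RGB image of shape
# (1, 3, H, W) becomes a (1, 4, H/8, W/8) latent.
print(img.shape, latents.shape)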
# TODO: try different `r`s
r = len(scheduler.timesteps)//2
timesteps = scheduler.timesteps[[r]]
noise = torch.randn_like(latents, device='cuda')
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
show_images(noisy_latents[0].detach());
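# The noise level we just added (a sketch; assumes `scheduler.sigmas[i]` lines up with
# `scheduler.timesteps[i]`, as in diffusers' LMS scheduler). For this k-diffusion-style scheduler,
# `add_noise` amounts to `latents + sigma * noise`.
print(timesteps, scheduler.sigmas[r])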
inp = scheduler.scale_model_input(torch.cat([noisy_latents] * 2), timesteps).cuda()
show_images(decode(inp).detach().to(dtype=torch.float32))  # `decode` is defined further down
def embed(prompts):
    tokens = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    return text_encoder(tokens.input_ids.to("cuda"))[0].half()  # last hidden state of the CLIP text encoder
prompts = ['horse', 'zebra']
t = embed(prompts)
u = embed([""] * len(prompts))
emb = torch.cat([u, t])
emb.shape
torch.Size([4, 77, 768])
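# `emb` stacks the unconditional embeddings first and the prompt embeddings second, so `chunk(2)`
# recovers them (and, later, the corresponding UNet predictions) in that order.
u_emb, t_emb = emb.chunk(2)
u_emb.shape, t_emb.shape  # each (2, 77, 768): one row per prompt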
torch.manual_seed(100)
g = guidance_scale
ts = timesteps
scheduler.set_timesteps(num_inference_steps)
with torch.no_grad(): u,t = unet(inp, ts.cuda(), encoder_hidden_states=emb.cuda()).sample.chunk(2)
pred = u + g*(t-u)
latents = scheduler.step(pred, ts, noisy_latents).prev_sample
RuntimeError (traceback condensed; raised in diffusers' CrossAttention._attention):
    The size of tensor a (16) must match the size of tensor b (32) at non-singleton dimension 0
The UNet input `inp` has batch size 2 while `emb` has batch size 4 (unconditional + conditional embeddings for two prompts), so the cross-attention batches (2 x 8 heads = 16 vs 4 x 8 heads = 32) don't line up.
show_images(decode(torch.concat([u,t,pred,latents])).detach().to(dtype=torch.float32))
for i,ts in enumerate(tqdm(scheduler.timesteps)):
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
    with torch.no_grad(): u,t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
    pred = u + g*(t-u)  # classifier-free guidance
    latents = scheduler.step(pred, ts, latents).prev_sample
def decode(x):
    with torch.no_grad(): res = vae.decode(1 / 0.18215 * x).sample  # undo the 0.18215 latent scaling
    return (res / 2 + 0.5).clamp(0, 1)  # map from [-1,1] back to [0,1] for display
res = decode(latents)
show_images(res.detach().to(dtype=torch.float32));
def text_enc(prompts, maxlen=None):
    # Like `embed` above, but with a configurable max token length
    if maxlen is None: maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    return text_encoder(inp.input_ids.to("cuda"))[0].half()
def mk_img(t):
    image = (t/2+0.5).clamp(0,1).detach().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray((image*255).round().astype("uint8"))
prompts = ['a photograph of an astronaut riding a horse']
seed = 100
steps = 50
bs = len(prompts)
text = text_enc(prompts)
uncond = text_enc([""] * bs, text.shape[1])
emb = torch.cat([uncond, text])
if seed: torch.manual_seed(seed)
latents = torch.randn((bs, unet.in_channels, height//8, width//8))
scheduler.set_timesteps(steps)
latents = latents.to("cuda").half() * scheduler.init_noise_sigma
for i,ts in enumerate(tqdm(scheduler.timesteps)):
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)  # duplicate latents to match `emb`'s batch
    with torch.no_grad(): u,t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
    pred = u + g*(t-u)
    latents = scheduler.step(pred, ts, latents).prev_sample
with torch.no_grad(): res = vae.decode(1 / 0.18215 * latents).sample
100%|██████████| 50/50 [00:27<00:00, 1.85it/s]
# Reproducing the batch-size pitfall: passing the (batch 1) latents with the (batch 2) stacked
# embeddings makes the UNet's cross-attention fail.
ts = scheduler.timesteps[0]
inp = scheduler.scale_model_input(latents, ts)
with torch.no_grad(): x = unet(inp, ts, encoder_hidden_states=emb).sample
RuntimeError (traceback condensed; raised in diffusers' CrossAttention._attention):
    The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0
Same mismatch as before: batch 1 latents (1 x 8 heads = 8) against batch 2 embeddings (2 x 8 heads = 16).
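# The fix, as in the sampling loops above (a sketch, just to confirm the shapes): duplicate the
# latents so the UNet batch matches the stacked embeddings.
inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
with torch.no_grad(): x = unet(inp, ts, encoder_hidden_states=emb).sample
x.shape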
show_images(res.to(dtype=torch.float32), figsize=(10,10));
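# `mk_img`, defined above, converts a single decoded sample to a PIL image instead of a
# matplotlib grid (a usage sketch):
mk_img(res[0])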