!pip install -Uqq accelerate diffusers["training"] transformers ftfy "ipywidgets>=7,<8" fastcore bitsandbytes
import argparse, itertools, math, os, random, PIL
import numpy as np, torch, torch.nn.functional as F, torch.utils.checkpoint
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from PIL.Image import Resampling
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import fastcore.all as fc
from huggingface_hub import notebook_login
from pathlib import Path
import torchvision.transforms.functional as tf
import accelerate
torch.manual_seed(1)
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
model_nm = "CompVis/stable-diffusion-v1-4"
urls = [
"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
"https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
## You can add additional images here
]
def image_grid(imgs, rows, cols):
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
import requests
from io import BytesIO
save_path = Path.home()/"my_concept"
save_path.mkdir(exist_ok=True)
images = []
for i,url in enumerate(urls):
p = save_path/f"{i}.jpeg"
if p.exists(): image = Image.open(p)
else:
response = requests.get(url)
image = Image.open(BytesIO(response.content)).convert("RGB")
image.save(p)
images.append(image)
image_grid(images, 1, len(images))
# path = Path.home()/'Downloads/photos/'
# paths = list(path.iterdir())
# images = [Image.open(p).resize((512, 512), resample=Resampling.BICUBIC).convert("RGB") for p in paths]
what_to_teach = "object"
placeholder_token = "<tiny>"
initializer_token = "teddy"
templates = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
class TextualInversionDataset:
def __init__(self, tokenizer, images, learnable_property="object", size=512,
repeats=100, interpolation=Resampling.BICUBIC, flip_p=0.5, set="train", placeholder_token="*"):
fc.store_attr()
self.num_images = len(images)
if set == "train": self._length = self.num_images * repeats
self.templates = style_templates if learnable_property == "style" else templates
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
def __len__(self): return self.num_images
def __getitem__(self, i):
image = tf.to_tensor(self.images[i%self.num_images])*2-1
text = random.choice(self.templates).format(self.placeholder_token)
ids = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, return_tensors="pt")
return dict(input_ids=ids.input_ids[0], pixel_values=image)
tokenizer = CLIPTokenizer.from_pretrained(model_nm, subfolder="tokenizer", torch_type=torch.float16, revision="fp16")
# TODO: torch_type not supported?
text_encoder = CLIPTextModel.from_pretrained(model_nm, subfolder="text_encoder", revision="fp16")
vae = AutoencoderKL.from_pretrained(model_nm, subfolder="vae", torch_type=torch.float16, revision="fp16")
unet = UNet2DConditionModel.from_pretrained(model_nm, subfolder="unet", torch_type=torch.float16, revision="fp16")
num_added_tokens = tokenizer.add_tokens(placeholder_token)
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
# Freeze all parameters except for the token embeddings in text encoder
tm = text_encoder.text_model
for o in (vae, unet, tm.encoder, tm.final_layer_norm, tm.embeddings.position_embedding):
for p in o.parameters(): p.requires_grad = False
train_dataset = TextualInversionDataset(
images=images, tokenizer=tokenizer, size=512, placeholder_token=placeholder_token,
repeats=100, learnable_property=what_to_teach, set="train")
def create_dataloader(bs=1): return DataLoader(train_dataset, batch_size=bs, shuffle=True)
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
import bitsandbytes as bnb
def training_function(text_encoder, vae, unet, train_batch_size, gradient_accumulation_steps,
lr, max_train_steps, scale_lr, output_dir):
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision='fp16')
train_dataloader = create_dataloader(train_batch_size)
if scale_lr: lr = (lr * gradient_accumulation_steps * train_batch_size * accelerator.num_processes)
#optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=lr)
optimizer = bnb.optim.AdamW8bit(text_encoder.get_input_embeddings().parameters(), lr=lr)
text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
vae.to(accelerator.device).eval()
unet.to(accelerator.device).eval()
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(num_train_epochs):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() * 0.18215
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# We only want to optimize the concept embeddings
grads = text_encoder.get_input_embeddings().weight.grad
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
optimizer.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
progress_bar.set_postfix(loss=loss.detach().item())
if global_step >= max_train_steps: break
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae, unet=unet, tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"))
pipeline.save_pretrained(output_dir)
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin"))
from functools import partial
f = partial(training_function, train_batch_size=1, gradient_accumulation_steps=2, lr=5e-04,
max_train_steps=3000, scale_lr=True, output_dir="sd-concept-output")
import accelerate
torch.manual_seed(42)
accelerate.notebook_launcher(f, args=(text_encoder, vae, unet), num_processes=1)
Launching training on one GPU.
0%| | 0/3000 [00:00<?, ?it/s]
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Input In [60], in <cell line: 3>() 1 import accelerate 2 torch.manual_seed(42) ----> 3 accelerate.notebook_launcher(f, args=(text_encoder, vae, unet), num_processes=1) File /usr/local/lib/python3.9/dist-packages/accelerate/launchers.py:134, in notebook_launcher(function, args, num_processes, use_fp16, mixed_precision, use_port) 132 print("Launching training on CPU.") 133 with patch_environment(use_mps_device=use_mps_device): --> 134 function(*args) Input In [52], in training_function(text_encoder, vae, unet, train_batch_size, gradient_accumulation_steps, lr, max_train_steps, scale_lr, output_dir) 28 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 29 encoder_hidden_states = text_encoder(batch["input_ids"])[0] ---> 30 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 31 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 32 accelerator.backward(loss) File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], [] File /usr/local/lib/python3.9/dist-packages/diffusers/models/unet_2d_condition.py:296, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, return_dict) 294 for downsample_block in self.down_blocks: 295 if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: --> 296 sample, res_samples = downsample_block( 297 hidden_states=sample, 298 temb=emb, 299 encoder_hidden_states=encoder_hidden_states, 300 ) 301 else: 302 sample, res_samples = downsample_block(hidden_states=sample, temb=emb) File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], [] File /usr/local/lib/python3.9/dist-packages/diffusers/models/unet_blocks.py:563, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states) 561 else: 562 hidden_states = resnet(hidden_states, temb) --> 563 hidden_states = attn(hidden_states, context=encoder_hidden_states) 565 output_states += (hidden_states,) 567 if self.downsamplers is not None: File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], [] File /usr/local/lib/python3.9/dist-packages/diffusers/models/attention.py:162, in SpatialTransformer.forward(self, hidden_states, context) 160 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 161 for block in self.transformer_blocks: --> 162 hidden_states = block(hidden_states, context=context) 163 hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) 164 hidden_states = self.proj_out(hidden_states) File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], [] File /usr/local/lib/python3.9/dist-packages/diffusers/models/attention.py:211, in BasicTransformerBlock.forward(self, hidden_states, context) 209 def forward(self, hidden_states, context=None): 210 hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states --> 211 hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states 212 hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states 213 hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], [] File /usr/local/lib/python3.9/dist-packages/diffusers/models/attention.py:283, in CrossAttention.forward(self, hidden_states, context, mask) 278 # TODO(PVP) - mask is currently never used. Remember to re-implement when used 279 280 # attention, what we cannot get enough of 282 if self._slice_size is None or query.shape[0] // self._slice_size == 1: --> 283 hidden_states = self._attention(query, key, value) 284 else: 285 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) File /usr/local/lib/python3.9/dist-packages/diffusers/models/attention.py:291, in CrossAttention._attention(self, query, key, value) 289 def _attention(self, query, key, value): 290 # TODO: use baddbmm for better performance --> 291 attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale 292 attention_probs = attention_scores.softmax(dim=-1) 293 # compute attention output RuntimeError: CUDA out of memory. Tried to allocate 4.12 GiB (GPU 0; 15.90 GiB total capacity; 9.37 GiB already allocated; 2.29 GiB free; 12.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF