Working non-lora training; other changes
This commit is contained in:
parent c680890d5c
commit 269615a35e
.gitignore (vendored): 1 line changed
@@ -303,3 +303,4 @@ tags
 
 # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
 
+wandb/
ddpo_pytorch/config/base.py
@@ -10,6 +10,7 @@ def get_config():
     config.num_epochs = 100
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
+    config.use_lora = True
 
     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
@@ -20,7 +21,6 @@ def get_config():
     config.train = train = ml_collections.ConfigDict()
     train.batch_size = 1
     train.use_8bit_adam = False
-    train.scale_lr = False
     train.learning_rate = 1e-4
     train.adam_beta1 = 0.9
     train.adam_beta2 = 0.999
@@ -35,7 +35,7 @@ def get_config():
 
     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 30
+    sample.num_steps = 5
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
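The base config now carries a `use_lora` flag defaulting to True, alongside a much shorter `sample.num_steps`. Since `scripts/train.py` already imports `config_flags` from ml_collections, such a flag can be flipped per run from the command line; a minimal sketch of that wiring, with a placeholder default path rather than one confirmed by this commit:

# sketch only: how an ml_collections config file is typically loaded and overridden;
# the default path below is a placeholder, not one taken from the repo
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "ddpo_pytorch/config/base.py", "Training configuration.")


def main(_):
    config = FLAGS.config
    print(f"use_lora={config.use_lora}, sample.num_steps={config.sample.num_steps}")


if __name__ == "__main__":
    app.run(main)

Invoked as, say, `python scripts/train.py --config <derived config path> --config.use_lora=False`, the same override mechanism should let you switch between LoRA and full-UNet runs without editing either file.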
@@ -4,16 +4,19 @@ from ddpo_pytorch.config import base
 def get_config():
     config = base.get_config()
 
-    config.mixed_precision = "bf16"
+    config.mixed_precision = "no"
     config.allow_tf32 = True
+    config.use_lora = False
 
-    config.train.batch_size = 8
-    config.train.gradient_accumulation_steps = 4
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 8
+    config.train.learning_rate = 1e-5
+    config.train.clip_range = 1.0
 
     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 8
-    config.sample.num_batches_per_epoch = 4
+    config.sample.batch_size = 16
+    config.sample.num_batches_per_epoch = 2
 
     config.per_prompt_stat_tracking = None
 
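The derived config above halves the per-step train batch but doubles gradient accumulation and the sampling batch, so the totals still satisfy the divisibility asserts in `scripts/train.py` further down. A quick single-process check; the `total_train_batch_size` formula is the usual accelerate-style bookkeeping, assumed here rather than quoted from the elided code:

# back-of-the-envelope check of the new config values (sketch, single process)
sample_batch_size = 16
num_batches_per_epoch = 2
train_batch_size = 4
gradient_accumulation_steps = 8
num_processes = 1

samples_per_epoch = sample_batch_size * num_processes * num_batches_per_epoch            # 32
total_train_batch_size = train_batch_size * num_processes * gradient_accumulation_steps  # 32

assert sample_batch_size % train_batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0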
scripts/train.py: 110 lines changed
@@ -1,4 +1,6 @@
 from collections import defaultdict
+import contextlib
+import os
 from absl import app, flags, logging
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -17,6 +19,8 @@ import torch
 import wandb
 from functools import partial
 import tqdm
+import tempfile
+from PIL import Image
 
 tqdm = partial(tqdm.tqdm, dynamic_ncols=True)
 
@@ -46,9 +50,9 @@ def main(_):
     # load scheduler, tokenizer and models.
     pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
     # freeze parameters of models to save more memory
-    pipeline.unet.requires_grad_(False)
     pipeline.vae.requires_grad_(False)
     pipeline.text_encoder.requires_grad_(False)
+    pipeline.unet.requires_grad_(not config.use_lora)
     # disable safety checker
     pipeline.safety_checker = None
     # make the progress bar nicer
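This is the core of non-LoRA support: the UNet is only frozen when LoRA is enabled. A quick sanity check one could drop in right after this block (my own sketch; `pipeline` is the `StableDiffusionPipeline` loaded above):

# sketch: confirm which parameters will actually receive gradients
def count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# expected: 0 with config.use_lora=True (until the LoRA processors are attached below),
# and the full UNet (on the order of 8.6e8 parameters for SD v1) with config.use_lora=False
print(f"trainable UNet params: {count_trainable(pipeline.unet):,}")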
@@ -56,27 +60,33 @@ def main(_):
         position=1,
         disable=not accelerator.is_local_main_process,
         leave=False,
+        desc="Timestep",
+        dynamic_ncols=True,
     )
     # switch to DDIM scheduler
     pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
 
     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
+    inference_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
+        inference_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+        inference_dtype = torch.bfloat16
 
-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    pipeline.unet.to(accelerator.device, dtype=weight_dtype)
-    pipeline.vae.to(accelerator.device, dtype=weight_dtype)
-    pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)
+    # Move unet, vae and text_encoder to device and cast to inference_dtype
+    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
+    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
+    if config.use_lora:
+        pipeline.unet.to(accelerator.device, dtype=inference_dtype)
 
+    if config.use_lora:
         # Set correct lora layers
         lora_attn_procs = {}
         for name in pipeline.unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
+            cross_attention_dim = (
+                None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
+            )
             if name.startswith("mid_block"):
                 hidden_size = pipeline.unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
@@ -87,9 +97,10 @@ def main(_):
                 hidden_size = pipeline.unet.config.block_out_channels[block_id]
 
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+    else:
+        trainable_layers = pipeline.unet
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
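These hunks do two things: the frozen-module dtype becomes `inference_dtype`, with the UNet only downcast when it stays frozen (LoRA on), and the LoRA attention-processor setup moves under `if config.use_lora:` so that `trainable_layers` is either the LoRA `AttnProcsLayers` wrapper or the whole UNet. A standalone sketch of just the dtype selection, for reference:

# sketch: map an accelerate mixed_precision string to the dtype used for frozen modules
import torch

def inference_dtype_for(mixed_precision: str) -> torch.dtype:
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)

assert inference_dtype_for("no") is torch.float32
assert inference_dtype_for("fp16") is torch.float16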
@@ -110,7 +121,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        trainable_layers.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
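The optimizer now trains `trainable_layers.parameters()`, so the same call covers both modes. The elided lines just above this hunk pick `optimizer_cls`; the base config's `train.use_8bit_adam` toggle suggests a selection along these lines (a sketch assuming `bitsandbytes` is available, not quoted from the repo):

# sketch: choose the optimizer class from the config toggle
import torch

if config.train.use_8bit_adam:
    import bitsandbytes as bnb
    optimizer_cls = bnb.optim.AdamW8bit
else:
    optimizer_cls = torch.optim.AdamW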
@@ -121,8 +132,31 @@ def main(_):
     prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
     reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()
 
+    # generate negative prompt embeddings
+    neg_prompt_embed = pipeline.text_encoder(
+        pipeline.tokenizer(
+            [""],
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=pipeline.tokenizer.model_max_length,
+        ).input_ids.to(accelerator.device)
+    )[0]
+    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
+    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
+
+    # initialize stat tracker
+    if config.per_prompt_stat_tracking:
+        stat_tracker = PerPromptStatTracker(
+            config.per_prompt_stat_tracking.buffer_size,
+            config.per_prompt_stat_tracking.min_count,
+        )
+
+    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)
+    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
 
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
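The `autocast` switch is what makes non-LoRA training work under mixed precision while leaving LoRA runs unchanged: `contextlib.nullcontext` and `accelerator.autocast` are both zero-argument callables that return a context manager, so every later `with autocast():` site works the same either way. A small illustration of the pattern (using `torch.autocast` in place of the accelerator method):

# sketch: a no-op context and a real autocast context behind one call site
import contextlib
import functools
import torch

def pick_autocast(use_noop: bool):
    # both branches are callables returning a context manager
    return contextlib.nullcontext if use_noop else functools.partial(torch.autocast, "cuda")

autocast = pick_autocast(use_noop=True)
with autocast():
    pass  # same syntax whether autocast is active or not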
@@ -144,27 +178,10 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0
 
-    neg_prompt_embed = pipeline.text_encoder(
-        pipeline.tokenizer(
-            [""],
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-            max_length=pipeline.tokenizer.model_max_length,
-        ).input_ids.to(accelerator.device)
-    )[0]
-    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
-    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
-
-    if config.per_prompt_stat_tracking:
-        stat_tracker = PerPromptStatTracker(
-            config.per_prompt_stat_tracking.buffer_size,
-            config.per_prompt_stat_tracking.min_count,
-        )
-
     global_step = 0
     for epoch in range(config.num_epochs):
         #################### SAMPLING ####################
+        pipeline.unet.eval()
         samples = []
         prompts = []
         for i in tqdm(
@@ -189,8 +206,7 @@ def main(_):
             prompt_embeds = pipeline.text_encoder(prompt_ids)[0]
 
             # sample
-            pipeline.unet.eval()
-            pipeline.vae.eval()
+            with autocast():
                 images, _, latents, log_probs = pipeline_with_logprob(
                     pipeline,
                     prompt_embeds=prompt_embeds,
@@ -226,14 +242,26 @@ def main(_):
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
 
-        # log sample-related stuff
-        accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step)
+        # log rewards and images
         accelerator.log(
-            {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]},
+            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
+            step=global_step,
+        )
+        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
+        with tempfile.TemporaryDirectory() as tmpdir:
+            for i, image in enumerate(images):
+                pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
+                pil = pil.resize((256, 256))
+                pil.save(os.path.join(tmpdir, f"{i}.jpg"))
+            accelerator.log(
+                {
+                    "images": [
+                        wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt)
+                        for i, prompt in enumerate(prompts)
+                    ],
+                },
                 step=global_step,
             )
-        # from PIL import Image
-        # Image.fromarray((images[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save(f"test.png")
 
         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking:
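The resized 256x256 JPEGs keep the wandb run size down compared with logging raw arrays, which wandb stores as PNGs (the in-code comment calls this a hack for exactly that reason). Right after this, the elided per-prompt block uses `PerPromptStatTracker` to turn rewards into advantages; the tracker itself is not shown in the diff, but a rough illustration of the idea (my own sketch, not the repo's implementation):

# sketch: per-prompt reward normalization with a bounded buffer per prompt
from collections import defaultdict, deque
import numpy as np

class SimplePerPromptTracker:
    def __init__(self, buffer_size=16, min_count=4):
        self.buffers = defaultdict(lambda: deque(maxlen=buffer_size))
        self.min_count = min_count

    def update(self, prompts, rewards):
        rewards = np.asarray(rewards, dtype=np.float32)
        for prompt, reward in zip(prompts, rewards):
            self.buffers[prompt].append(float(reward))
        advantages = np.empty_like(rewards)
        for i, prompt in enumerate(prompts):
            buf = np.asarray(self.buffers[prompt])
            # fall back to batch-level statistics until this prompt has enough history
            mean, std = (buf.mean(), buf.std()) if len(buf) >= self.min_count else (rewards.mean(), rewards.std())
            advantages[i] = (rewards[i] - mean) / (std + 1e-6)
        return advantages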
@@ -271,6 +299,7 @@ def main(_):
             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
 
             # train
+            pipeline.unet.train()
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -286,12 +315,13 @@ def main(_):
                 info = defaultdict(list)
                 for j in tqdm(
                     range(num_timesteps),
-                    desc=f"Timestep",
+                    desc="Timestep",
                     position=1,
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
                     with accelerator.accumulate(pipeline.unet):
+                        with autocast():
                             if config.train.cfg:
                                 noise_pred = pipeline.unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
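Wrapping the forward pass in `with autocast():` is the training-side counterpart of the sampling change above. The `torch.cat([sample["latents"][:, j]] * 2)` under `config.train.cfg` is the standard classifier-free guidance setup: run the UNet on an unconditional and a conditional copy of the batch, then recombine the two predictions, roughly as in this sketch (standard CFG, not quoted from the elided lines):

# sketch: classifier-free guidance recombination of a doubled-batch prediction
import torch

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # first half of the batch: unconditional (negative-prompt) prediction,
    # second half: prompt-conditioned prediction
    noise_uncond, noise_cond = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)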
@@ -337,7 +367,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()
 