diff --git a/.gitignore b/.gitignore
index 30b5f32..aecb0ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -303,4 +303,6 @@ tags
 
 # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
 
-wandb/
\ No newline at end of file
+wandb/
+logs/
+notebooks/
\ No newline at end of file
diff --git a/config/base.py b/config/base.py
index e5e6139..6bc59ca 100644
--- a/config/base.py
+++ b/config/base.py
@@ -5,12 +5,16 @@ def get_config():
     config = ml_collections.ConfigDict()
 
     # misc
+    config.run_name = ""
     config.seed = 42
     config.logdir = "logs"
     config.num_epochs = 100
+    config.save_freq = 20
+    config.num_checkpoint_limit = 5
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = True
+    config.resume_from = ""
 
     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
diff --git a/scripts/train.py b/scripts/train.py
index 6e7214c..fc458c4 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,12 +1,13 @@
 from collections import defaultdict
 import contextlib
 import os
+import datetime
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from accelerate.logging import get_logger
-from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.attention_processor import LoRAAttnProcessor
 import numpy as np
@@ -35,19 +36,46 @@ def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
 
+    # append a timestamp so that every launch gets its own run name and log/checkpoint directory
+    unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if not config.run_name:
+        config.run_name = unique_id
+    else:
+        config.run_name += "_" + unique_id
+
+    if config.resume_from:
+        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
+        if "checkpoint_" not in os.path.basename(config.resume_from):
+            # get the most recent checkpoint in this directory
+            checkpoints = [name for name in os.listdir(config.resume_from) if "checkpoint_" in name]
+            if not checkpoints:
+                raise ValueError(f"No checkpoints found in {config.resume_from}")
+            config.resume_from = os.path.join(
+                config.resume_from,
+                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
+            )
+
     # number of timesteps within each trajectory to train on
     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
 
+    accelerator_config = ProjectConfiguration(
+        project_dir=os.path.join(config.logdir, config.run_name),
+        automatic_checkpoint_naming=True,
+        total_limit=config.num_checkpoint_limit,
+    )
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
-        project_dir=config.logdir,
+        project_config=accelerator_config,
         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
         # _samples_ to accumulate across
         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
+        accelerator.init_trackers(
+            project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
+        )
     logger.info(f"\n{config}")
 
     # set seed (device_specific is very important to get different prompts on different devices)
@@ -108,6 +136,43 @@ def main(_):
     else:
         trainable_layers = pipeline.unet
 
+    # set up diffusers-friendly checkpoint saving with Accelerate
+
+    def save_model_hook(models, weights, output_dir):
+        """Save the single prepared model to `output_dir` using the diffusers save methods."""
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            pipeline.unet.save_attn_procs(output_dir)
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            models[0].save_pretrained(os.path.join(output_dir, "unet"))
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+    def load_model_hook(models, input_dir):
+        """Restore the single prepared model from a checkpoint written by `save_model_hook`."""
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            # load the LoRA weights into a temporary UNet and copy them into the existing layers, rather than
+            # calling pipeline.unet.load_attn_procs(input_dir) directly
+            tmp_unet = UNet2DConditionModel.from_pretrained(
+                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
+            )
+            tmp_unet.load_attn_procs(input_dir)
+            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
+            del tmp_unet
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+            models[0].register_to_config(**load_model.config)
+            models[0].load_state_dict(load_model.state_dict())
+            del load_model
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if config.allow_tf32:
@@ -185,8 +250,16 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0
 
+    if config.resume_from:
+        logger.info(f"Resuming from {config.resume_from}")
+        accelerator.load_state(config.resume_from)
+        # checkpoint directories are named after the epoch in which they were saved, so continue with the next one
+        first_epoch = int(config.resume_from.split("_")[-1]) + 1
+    else:
+        first_epoch = 0
+
     global_step = 0
-    for epoch in range(config.num_epochs):
+    for epoch in range(first_epoch, config.num_epochs):
         #################### SAMPLING ####################
         pipeline.unet.eval()
         samples = []
@@ -387,7 +460,9 @@ def main(_):
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
+                        assert (j == num_train_timesteps - 1) and (
+                            i + 1
+                        ) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
@@ -399,6 +474,12 @@ def main(_):
             # make sure we did an optimization step at the end of the inner epoch
             assert accelerator.sync_gradients
 
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
+            # name the checkpoint directory after the epoch (instead of Accelerate's running save count) so
+            # that the epoch can be recovered from the directory name when resuming
+            accelerator.project_configuration.iteration = epoch
+            accelerator.save_state()
+
 
 if __name__ == "__main__":
     app.run(main)