diff --git a/config/base.py b/config/base.py
index ca0878f..e5e6139 100644
--- a/config/base.py
+++ b/config/base.py
@@ -32,14 +32,15 @@ def get_config():
     train.cfg = True
     train.adv_clip_max = 10
     train.clip_range = 1e-4
+    train.timestep_fraction = 1.0
 
     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 5
+    sample.num_steps = 10
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
-    sample.num_batches_per_epoch = 1
+    sample.num_batches_per_epoch = 2
 
     # prompting
     config.prompt_fn = "imagenet_animals"
@@ -49,7 +50,7 @@ def get_config():
     config.reward_fn = "jpeg_compressibility"
 
     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
-    config.per_prompt_stat_tracking.buffer_size = 64
+    config.per_prompt_stat_tracking.buffer_size = 16
     config.per_prompt_stat_tracking.min_count = 16
 
     return config
\ No newline at end of file
diff --git a/config/dgx.py b/config/dgx.py
index 16b902b..c5a6230 100644
--- a/config/dgx.py
+++ b/config/dgx.py
@@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
 def get_config():
     config = base.get_config()
 
-    config.mixed_precision = "no"
+    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+
+    config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = False
 
     config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 8
-    config.train.learning_rate = 1e-5
-    config.train.clip_range = 1.0
+    config.train.gradient_accumulation_steps = 2
+    config.train.learning_rate = 3e-5
+    config.train.clip_range = 1e-4
 
     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 16
-    config.sample.num_batches_per_epoch = 2
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4
 
-    config.per_prompt_stat_tracking = None
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 16,
+        "min_count": 16,
+    }
 
     return config
diff --git a/ddpo_pytorch/aesthetic_scorer.py b/ddpo_pytorch/aesthetic_scorer.py
index 9f5263b..bb9af25 100644
--- a/ddpo_pytorch/aesthetic_scorer.py
+++ b/ddpo_pytorch/aesthetic_scorer.py
@@ -14,18 +14,15 @@ class MLP(nn.Module):
         super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(768, 1024),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(1024, 128),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(128, 64),
-            nn.Identity(),
+            nn.Dropout(0.1),
             nn.Linear(64, 16),
             nn.Linear(16, 1),
         )
 
-        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
-        self.load_state_dict(state_dict)
-
     @torch.no_grad()
     def forward(self, embed):
         return self.layers(embed)
@@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
+        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        self.mlp.load_state_dict(state_dict)
+        self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
@@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
         inputs = {k: v.cuda() for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
-        embed = embed / embed.norm(dim=-1, keepdim=True)
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
         return self.mlp(embed)
diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py
index 716a8aa..1328ac6 100644
--- a/ddpo_pytorch/rewards.py
+++ b/ddpo_pytorch/rewards.py
@@ -35,8 +35,6 @@ def aesthetic_score():
     scorer = AestheticScorer().cuda()
 
     def _fn(images, prompts, metadata):
-        if not isinstance(images, torch.Tensor):
-            images = torch.as_tensor(images)
         scores = scorer(images)
         return scores, {}
 
diff --git a/scripts/train.py b/scripts/train.py
index ba474b9..58395d9 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
+
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")
 
-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)
 
     # load scheduler, tokenizer and models.
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )
 
-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
 
     # Prepare everything with our `accelerator`.
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
 
             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
@@ -300,6 +314,7 @@ def main(_):
 
             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]
 
-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
@@ -371,14 +385,20 @@ def main(_):
                         optimizer.step()
                         optimizer.zero_grad()
 
+                    # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
+                        info = accelerator.reduce(info, reduction="mean")
                         info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
                         info = defaultdict(list)
 
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients
+
 
 if __name__ == "__main__":
     app.run(main)