Minor changes; add train_timestep_fraction
@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
 
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")
 
-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)
 
     # load scheduler, tokenizer and models.
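This hunk introduces `config.train.timestep_fraction`: only `num_train_timesteps` of each trajectory's `config.sample.num_steps` denoising steps are trained on (a random subset, thanks to the per-sample time shuffle added further down), and Accelerate's accumulation window is widened to match, so the optimizer still steps once per `config.train.gradient_accumulation_steps` samples. A minimal sketch of that arithmetic; the config values are illustrative, not the repo's defaults:

```python
num_steps = 50                    # config.sample.num_steps: denoising steps per trajectory
timestep_fraction = 0.5           # config.train.timestep_fraction: fraction trained on
gradient_accumulation_steps = 4   # config.train.gradient_accumulation_steps: *samples* to accumulate

num_train_timesteps = int(num_steps * timestep_fraction)  # 25
# Each trained timestep does one backward pass, and Accelerate counts backward
# passes, so stepping once per `gradient_accumulation_steps` samples means
# accumulating over this many passes:
accelerator_accumulation = gradient_accumulation_steps * num_train_timesteps
print(num_train_timesteps, accelerator_accumulation)  # 25 100
```

The expanded seed comment documents behavior that matters for data diversity: with `device_specific=True`, Accelerate offsets the seed by the process index, so each device samples different prompts.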
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )
 
-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
 
     # Prepare everything with our `accelerator`.
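Note that both branches assign a zero-argument callable that returns a context manager, so call sites can always write `with autocast():` regardless of which was picked. A small self-contained sketch of that pattern, where the `use_lora` flag and `torch.autocast` stand in for the config object and `accelerator.autocast`:

```python
import contextlib

import torch

use_lora = True  # illustrative stand-in for config.use_lora

# Both branches are callables producing a context manager, so the training
# step never needs to know which one was selected.
autocast = contextlib.nullcontext if use_lora else (lambda: torch.autocast("cuda"))

with autocast():
    pass  # forward pass / loss computation would go here
```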
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
 
             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
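The new block shuffles the time axis independently for each trajectory, which is what lets `timestep_fraction < 1` train on a random subset of steps rather than always the same prefix. The indexing trick is worth unpacking: broadcasting a `(B, 1)` row index against the `(B, T)` matrix of permutations reorders each row by its own permutation. A runnable toy version with made-up shapes:

```python
import torch

total_batch_size, num_timesteps = 3, 4  # toy sizes, not the training config

# one tensor standing in for e.g. samples["timesteps"], shape (B, T)
timesteps = torch.arange(num_timesteps).repeat(total_batch_size, 1)

# one independent permutation of the time axis per sample, shape (B, T)
perms = torch.stack([torch.randperm(num_timesteps) for _ in range(total_batch_size)])

# result[i, j] = timesteps[i, perms[i, j]]: each row gets its own ordering,
# unlike timesteps[:, perm], which would apply one shared permutation.
shuffled = timesteps[torch.arange(total_batch_size)[:, None], perms]
print(shuffled)
```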
@@ -300,6 +314,7 @@ def main(_):
 
             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]
 
-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
@@ -371,14 +385,20 @@ def main(_):
                         optimizer.step()
                         optimizer.zero_grad()
 
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
                         info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
+                        info = defaultdict(list)
 
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients
 
 
 if __name__ == "__main__":
     app.run(main)
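The last hunk gates logging on `accelerator.sync_gradients`, which Accelerate sets to True only on the backward pass that actually triggers an optimizer step. Combined with the wider accumulation window from the first hunk, that step can only land on the last trained timestep of every `gradient_accumulation_steps`-th sample batch, which is exactly what the new assertion encodes (and why `info` now resets once per optimizer step rather than once per sample batch). A pure-Python sketch of that cadence, with illustrative loop sizes:

```python
# Illustrative sizes, not the training config.
gradient_accumulation_steps = 2   # sample batches per optimizer step
num_train_timesteps = 3           # trained timesteps per sample batch
accumulation_window = gradient_accumulation_steps * num_train_timesteps

backward_calls = 0
for i in range(4):                        # sample batches
    for j in range(num_train_timesteps):  # trained timesteps
        backward_calls += 1
        # Accelerate syncs once per full accumulation window of backward passes
        sync = backward_calls % accumulation_window == 0
        if sync:
            # mirrors the assertion in the diff
            assert j == num_train_timesteps - 1
            assert (i + 1) % gradient_accumulation_steps == 0
            print(f"optimizer step after sample batch {i}, timestep {j}")
```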