Minor changes; add train_timestep_fraction

This commit is contained in:
Kevin Black
2023-06-27 22:17:32 -07:00
parent bae3f43f5f
commit 28d2d8c40e
5 changed files with 50 additions and 26 deletions

View File

@@ -34,17 +34,23 @@ logger = get_logger(__name__)
def main(_):
# basic Accelerate and logging setup
config = FLAGS.config
# number of timesteps within each trajectory to train on
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
accelerator = Accelerator(
log_with="wandb",
mixed_precision=config.mixed_precision,
project_dir=config.logdir,
gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
# we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
# _samples_ to accumulate across
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
)
if accelerator.is_main_process:
accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
logger.info(f"\n{config}")
# set seed
# set seed (device_specific is very important to get different prompts on different devices)
set_seed(config.seed, device_specific=True)
# load scheduler, tokenizer and models.
@@ -152,7 +158,8 @@ def main(_):
config.per_prompt_stat_tracking.min_count,
)
# for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
# for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
# more memory
autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
# Prepare everything with our `accelerator`.
@@ -289,8 +296,15 @@ def main(_):
#################### TRAINING ####################
for inner_epoch in range(config.train.num_inner_epochs):
# shuffle samples along batch dimension
indices = torch.randperm(total_batch_size, device=accelerator.device)
samples = {k: v[indices] for k, v in samples.items()}
perm = torch.randperm(total_batch_size, device=accelerator.device)
samples = {k: v[perm] for k, v in samples.items()}
# shuffle along time dimension independently for each sample
perms = torch.stack(
[torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
)
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
# rebatch for training
samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
@@ -300,6 +314,7 @@ def main(_):
# train
pipeline.unet.train()
info = defaultdict(list)
for i, sample in tqdm(
list(enumerate(samples_batched)),
desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
else:
embeds = sample["prompt_embeds"]
info = defaultdict(list)
for j in tqdm(
range(num_timesteps),
range(num_train_timesteps),
desc="Timestep",
position=1,
leave=False,
@@ -371,14 +385,20 @@ def main(_):
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
# log training-related stuff
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
info = accelerator.reduce(info, reduction="mean")
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
accelerator.log(info, step=global_step)
global_step += 1
info = defaultdict(list)
# make sure we did an optimization step at the end of the inner epoch
assert accelerator.sync_gradients
if __name__ == "__main__":
app.run(main)