Minor changes; add train_timestep_fraction
@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
+
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")
 
-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)
 
     # load scheduler, tokenizer and models.
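A minimal sketch of the accumulation arithmetic introduced in this hunk, using illustrative values (num_steps, timestep_fraction, and the per-sample accumulation count are assumptions, not this repo's defaults):

num_steps = 50                  # config.sample.num_steps: sampler steps per trajectory
timestep_fraction = 0.5         # config.train.timestep_fraction: fraction of timesteps trained on
samples_to_accumulate = 4       # config.train.gradient_accumulation_steps

# timesteps actually trained on per trajectory
num_train_timesteps = int(num_steps * timestep_fraction)              # 25
# micro-batches Accelerate accumulates before each real optimizer step
accumulation_steps = samples_to_accumulate * num_train_timesteps      # 100
print(num_train_timesteps, accumulation_steps)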
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )
 
-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
 
     # Prepare everything with our `accelerator`.
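The autocast switch above picks between a no-op context manager and Accelerate's autocast. A small stand-alone sketch of the same pattern, where torch.autocast stands in for accelerator.autocast and use_lora is an assumed flag:

import contextlib
import torch

use_lora = False  # assumed flag for illustration

# contextlib.nullcontext() is a no-op "with" block; torch.autocast stands in for accelerator.autocast
device_type = "cuda" if torch.cuda.is_available() else "cpu"
autocast = contextlib.nullcontext if use_lora else (lambda: torch.autocast(device_type))

with autocast():
    y = torch.randn(2, 2) @ torch.randn(2, 2)  # runs in reduced precision only when autocast is active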
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
 
             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
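The per-sample time shuffle added above pairs a stack of independent permutations with advanced indexing, and the rebatch is a plain reshape. A self-contained sketch with toy shapes (the sizes and the two keys here are illustrative assumptions):

import torch

total_batch_size, num_timesteps, train_batch_size = 6, 4, 2
samples = {
    "timesteps": torch.arange(num_timesteps).repeat(total_batch_size, 1),
    "log_probs": torch.randn(total_batch_size, num_timesteps),
}

# shuffle along the batch dimension with a single permutation
perm = torch.randperm(total_batch_size)
samples = {k: v[perm] for k, v in samples.items()}

# shuffle along the time dimension independently per sample:
# row i of `perms` is its own permutation of the timesteps of sample i
perms = torch.stack([torch.randperm(num_timesteps) for _ in range(total_batch_size)])
for key in ["timesteps", "log_probs"]:
    samples[key] = samples[key][torch.arange(total_batch_size)[:, None], perms]

# rebatch: (total_batch_size, ...) -> (num_batches, train_batch_size, ...)
samples_batched = {k: v.reshape(-1, train_batch_size, *v.shape[1:]) for k, v in samples.items()}
print(samples_batched["timesteps"].shape)  # torch.Size([3, 2, 4])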
@@ -300,6 +314,7 @@ def main(_):
 
             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]
 
-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
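The two hunks above also move the info = defaultdict(list) reset out of the per-sample loop and restrict the inner loop to num_train_timesteps, so metrics accumulate across every trained timestep of every sample in an accumulation window before they are averaged and logged. A toy sketch of that accumulate-then-average pattern (loop sizes are illustrative):

import torch
from collections import defaultdict

num_train_timesteps, num_samples = 3, 2
info = defaultdict(list)               # reset once, before the sample loop
for i in range(num_samples):
    for j in range(num_train_timesteps):
        loss = torch.rand(())          # stand-in for the per-timestep training loss
        info["loss"].append(loss)

logged = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
print(logged["loss"])                  # mean over num_samples * num_train_timesteps entries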
@@ -371,14 +385,20 @@ def main(_):
                         optimizer.step()
                         optimizer.zero_grad()
 
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
                         info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
                         info = defaultdict(list)
 
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients
+
 
 if __name__ == "__main__":
     app.run(main)
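A small sketch of when the sync_gradients assertion above should hold: with the Accelerator configured as in the first hunk, an optimizer step happens once per gradient_accumulation_steps * num_train_timesteps micro-batches, i.e. exactly at the last trained timestep of every gradient_accumulation_steps-th sample (the values below are illustrative assumptions):

gradient_accumulation_steps = 2   # samples to accumulate across
num_train_timesteps = 3

micro_batch = 0
for i in range(4):                          # sample index
    for j in range(num_train_timesteps):    # timestep index
        micro_batch += 1
        sync = micro_batch % (gradient_accumulation_steps * num_train_timesteps) == 0
        if sync:
            # mirrors the assertion added in the last hunk
            assert j == num_train_timesteps - 1
            assert (i + 1) % gradient_accumulation_steps == 0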