Commenting pass

This commit is contained in:
Kevin Black
2023-06-29 00:51:38 -07:00
parent 8779f62a1c
commit c0bc708549
2 changed files with 84 additions and 28 deletions

View File

@@ -36,7 +36,7 @@ def main(_):
# basic Accelerate and logging setup
config = FLAGS.config
unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
if not config.run_name:
config.run_name = unique_id
else:
@@ -67,8 +67,9 @@ def main(_):
log_with="wandb",
mixed_precision=config.mixed_precision,
project_config=accelerator_config,
# we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
# _samples_ to accumulate across
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
# the total number of optimizer steps to accumulate across.
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
)
if accelerator.is_main_process:
@@ -243,6 +244,7 @@ def main(_):
logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
assert config.sample.batch_size >= config.train.batch_size
assert config.sample.batch_size % config.train.batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0
@@ -418,6 +420,7 @@ def main(_):
noise_pred = pipeline.unet(
sample["latents"][:, j], sample["timesteps"][:, j], embeds
).sample
# compute the log prob of next_latents given latents under the current model
_, log_prob = ddim_step_with_logprob(
pipeline.scheduler,
noise_pred,