Commenting pass
parent 8779f62a1c
commit c0bc708549

config/base.py
@@ -1,60 +1,113 @@
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    ###### General ######
    # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime.
    config.run_name = ""
    # random seed for reproducibility.
    config.seed = 42
    # top-level logging directory for checkpoint saving.
    config.logdir = "logs"
    # number of epochs to train for. each epoch is one round of sampling from the model followed by training on those
    # samples.
    config.num_epochs = 100
    # number of epochs between saving model checkpoints.
    config.save_freq = 20
    # number of checkpoints to keep before overwriting old ones.
    config.num_checkpoint_limit = 5
    # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
    config.mixed_precision = "fp16"
    # allow tf32 on Ampere GPUs, which can speed up training.
    config.allow_tf32 = True
    # resume training from a checkpoint. either an exact checkpoint directory (e.g. checkpoint_50), or a directory
    # containing checkpoints, in which case the latest one will be used. `config.use_lora` must be set to the same value
    # as the run that generated the saved checkpoint.
    config.resume_from = ""
    # whether or not to use LoRA. LoRA reduces memory usage significantly by injecting small weight matrices into the
    # attention layers of the UNet. with LoRA, fp16, and a batch size of 1, finetuning Stable Diffusion should take
    # about 10GB of GPU memory. beware that if LoRA is disabled, training will take a lot of memory and saved checkpoint
    # files will also be large.
    config.use_lora = True

    ###### Pretrained Model ######
    config.pretrained = pretrained = ml_collections.ConfigDict()
    # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
    pretrained.model = "runwayml/stable-diffusion-v1-5"
    # revision of the model to load.
    pretrained.revision = "main"

    ###### Sampling ######
    config.sample = sample = ml_collections.ConfigDict()
    # number of sampler inference steps.
    sample.num_steps = 10
    # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
    # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
    sample.eta = 1.0
    # classifier-free guidance weight. 1.0 is no guidance.
    sample.guidance_scale = 5.0
    # batch size (per GPU!) to use for sampling.
    sample.batch_size = 1
    # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
    # batch_size * num_gpus`.
    sample.num_batches_per_epoch = 2

    ###### Training ######
    config.train = train = ml_collections.ConfigDict()
    # batch size (per GPU!) to use for training.
    train.batch_size = 1
    # whether to use the 8bit Adam optimizer from bitsandbytes.
    train.use_8bit_adam = False
    # learning rate.
    train.learning_rate = 1e-4
    # Adam beta1.
    train.adam_beta1 = 0.9
    # Adam beta2.
    train.adam_beta2 = 0.999
    # Adam weight decay.
    train.adam_weight_decay = 1e-4
    # Adam epsilon.
    train.adam_epsilon = 1e-8
    # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus *
    # gradient_accumulation_steps`.
    train.gradient_accumulation_steps = 1
    # maximum gradient norm for gradient clipping.
    train.max_grad_norm = 1.0
    # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one
    # outer epoch's round of sampling.
    train.num_inner_epochs = 1
    # whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during
    # sampling will be used during training.
    train.cfg = True
    # clip advantages to the range [-adv_clip_max, adv_clip_max].
    train.adv_clip_max = 10
    # the PPO clip range.
    train.clip_range = 1e-4
    # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
    # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates.
    train.timestep_fraction = 1.0

    ###### Prompt Function ######
    # prompt function to use. see `prompts.py` for available prompt functions.
    config.prompt_fn = "imagenet_animals"
    # kwargs to pass to the prompt function.
    config.prompt_fn_kwargs = {}

    ###### Reward Function ######
    # reward function to use. see `rewards.py` for available reward functions.
    config.reward_fn = "jpeg_compressibility"

    ###### Per-Prompt Stat Tracking ######
    # when enabled, the model will track the mean and std of reward on a per-prompt basis and use that to compute
    # advantages. set `config.per_prompt_stat_tracking` to None to disable per-prompt stat tracking, in which case
    # advantages will be calculated using the mean and std of the entire batch.
    config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    # number of reward values to store in the buffer for each prompt. the buffer persists across epochs.
    config.per_prompt_stat_tracking.buffer_size = 16
    # the minimum number of reward values to store in the buffer before using the per-prompt mean and std. if the buffer
    # contains fewer than `min_count` values, the mean and std of the entire batch will be used instead.
    config.per_prompt_stat_tracking.min_count = 16

    return config
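The per-prompt stat tracking block above implies a small reward normalizer that keeps a rolling buffer of recent rewards for each prompt. A minimal sketch of such a tracker, assuming a buffer-backed class; the class name, method name, and NumPy-based implementation are illustrative and not taken from this commit:

import numpy as np
from collections import deque


class PerPromptStatTracker:
    # hedged sketch: turns rewards into advantages using per-prompt mean/std when enough
    # history is available, otherwise falls back to whole-batch statistics
    def __init__(self, buffer_size=16, min_count=16):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}  # prompt -> deque of recent rewards; persists across epochs

    def update(self, prompts, rewards):
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for i, (prompt, reward) in enumerate(zip(prompts, rewards)):
            buf = self.stats.setdefault(prompt, deque(maxlen=self.buffer_size))
            buf.append(reward)
            if len(buf) < self.min_count:
                # not enough history for this prompt yet: use the batch mean and std
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = np.mean(buf), np.std(buf) + 1e-6
            advantages[i] = (reward - mean) / std
        return advantages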
@@ -36,7 +36,7 @@ def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config

    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    if not config.run_name:
        config.run_name = unique_id
    else:
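The `config = FLAGS.config` line in the hunk above suggests the training script loads config/base.py through absl flags and `ml_collections.config_flags`. A minimal sketch of that wiring, assuming a standalone entry point; the `train.py` name and the command-line override in the comment are illustrative assumptions:

from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
# register config/base.py as the default config; any field can then be overridden on the
# command line, e.g.  python train.py --config config/base.py --config.train.learning_rate=3e-4
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")


def main(_):
    config = FLAGS.config
    print(config.run_name, config.train.learning_rate)


if __name__ == "__main__":
    app.run(main)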
@@ -67,8 +67,9 @@ def main(_):
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
    )
    if accelerator.is_main_process:
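To make the comment above concrete, a small worked example of the accumulation arithmetic with the defaults from config/base.py. Deriving `num_train_timesteps` as `int(sample.num_steps * train.timestep_fraction)` is an assumption based on the config comments; that derivation is not shown in this hunk.

# illustrative numbers using the defaults from config/base.py
num_steps = 10                    # sample.num_steps
timestep_fraction = 1.0           # train.timestep_fraction
gradient_accumulation_steps = 1   # train.gradient_accumulation_steps (counted in *samples*)

# assumed derivation: how many denoising timesteps per sample are actually trained on
num_train_timesteps = int(num_steps * timestep_fraction)  # -> 10

# Accelerate counts one micro-step per timestep, so the value handed to Accelerator is
# samples-to-accumulate multiplied by timesteps-per-sample.
effective_accumulation = gradient_accumulation_steps * num_train_timesteps  # -> 10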
@@ -243,6 +244,7 @@ def main(_):
    logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
    logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")

    assert config.sample.batch_size >= config.train.batch_size
    assert config.sample.batch_size % config.train.batch_size == 0
    assert samples_per_epoch % total_train_batch_size == 0
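For reference, the quantities these assertions relate, spelled out with the defaults from config/base.py and an assumed two-GPU run; the `num_gpus = 2` value is illustrative, in practice it comes from the accelerator:

num_gpus = 2                     # illustrative; accelerator.num_processes in practice
sample_batch_size = 1            # config.sample.batch_size
train_batch_size = 1             # config.train.batch_size
num_batches_per_epoch = 2        # config.sample.num_batches_per_epoch
gradient_accumulation_steps = 1  # config.train.gradient_accumulation_steps

samples_per_epoch = num_batches_per_epoch * sample_batch_size * num_gpus            # 4
total_train_batch_size = train_batch_size * num_gpus * gradient_accumulation_steps  # 2

assert sample_batch_size >= train_batch_size
assert sample_batch_size % train_batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0  # 4 % 2 == 0 -> 2 gradient updates per inner epoch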
@@ -418,6 +420,7 @@ def main(_):
noise_pred = pipeline.unet(
    sample["latents"][:, j], sample["timesteps"][:, j], embeds
).sample
# compute the log prob of next_latents given latents under the current model
_, log_prob = ddim_step_with_logprob(
    pipeline.scheduler,
    noise_pred,
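The log-prob computed by `ddim_step_with_logprob` is what a PPO-style update ultimately consumes, together with the `train.clip_range` and `train.adv_clip_max` settings from config/base.py. A hedged sketch of such a clipped objective; the argument names `old_log_prob` and `advantages` stand for quantities assumed to be stored with each sample and are illustrative, not taken from this hunk:

import torch


def ppo_loss(log_prob, old_log_prob, advantages, clip_range=1e-4, adv_clip_max=10.0):
    # clip advantages to [-adv_clip_max, adv_clip_max], as described in config/base.py
    advantages = torch.clamp(advantages, -adv_clip_max, adv_clip_max)
    # importance ratio between the current policy and the one used at sampling time
    ratio = torch.exp(log_prob - old_log_prob)
    # standard PPO clipped surrogate (negated, since we minimize)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))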