From beb8c2f86d8b6f8efd4bde5cb76047a51eb7b4cd Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Tue, 4 Jul 2023 00:25:37 -0700
Subject: [PATCH] Update configs

---
 config/base.py |  6 ++--
 config/dgx.py  | 87 +++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/config/base.py b/config/base.py
index c9601cd..8ac8556 100644
--- a/config/base.py
+++ b/config/base.py
@@ -42,7 +42,7 @@ def get_config():
     ###### Sampling ######
     config.sample = sample = ml_collections.ConfigDict()
     # number of sampler inference steps.
-    sample.num_steps = 10
+    sample.num_steps = 50
     # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
     # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
     sample.eta = 1.0
@@ -61,7 +61,7 @@ def get_config():
     # whether to use the 8bit Adam optimizer from bitsandbytes.
     train.use_8bit_adam = False
     # learning rate.
-    train.learning_rate = 1e-4
+    train.learning_rate = 3e-4
     # Adam beta1.
     train.adam_beta1 = 0.9
     # Adam beta2.
@@ -82,7 +82,7 @@ def get_config():
     # sampling will be used during training.
     train.cfg = True
     # clip advantages to the range [-adv_clip_max, adv_clip_max].
-    train.adv_clip_max = 10
+    train.adv_clip_max = 5
     # the PPO clip range.
     train.clip_range = 1e-4
     # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
diff --git a/config/dgx.py b/config/dgx.py
index c5a6230..d0c60bf 100644
--- a/config/dgx.py
+++ b/config/dgx.py
@@ -5,28 +5,91 @@ import os
 base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
 
 
-def get_config():
+def compressibility():
     config = base.get_config()
 
-    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+    config.pretrained.model = "CompVis/stable-diffusion-v1-4"
 
-    config.mixed_precision = "fp16"
-    config.allow_tf32 = True
-    config.use_lora = False
+    config.num_epochs = 100
+    config.use_lora = True
+    config.save_freq = 1
+    config.num_checkpoint_limit = 100000000
 
-    config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 2
-    config.train.learning_rate = 3e-5
-    config.train.clip_range = 1e-4
-
-    # sampling
-    config.sample.num_steps = 50
+    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
     config.sample.batch_size = 8
     config.sample.num_batches_per_epoch = 4
 
+    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
+    # prompting
+    config.prompt_fn = "imagenet_animals"
+    config.prompt_fn_kwargs = {}
+
+    # rewards
+    config.reward_fn = "jpeg_compressibility"
+
     config.per_prompt_stat_tracking = {
         "buffer_size": 16,
         "min_count": 16,
     }
 
     return config
+
+
+def incompressibility():
+    config = compressibility()
+    config.reward_fn = "jpeg_incompressibility"
+    return config
+
+
+def aesthetic():
+    config = compressibility()
+    config.num_epochs = 200
+    config.reward_fn = "aesthetic_score"
+
+    # this reward is a bit harder to optimize, so I used 2 gradient updates per epoch.
+    config.train.gradient_accumulation_steps = 4
+
+    config.prompt_fn = "simple_animals"
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+    return config
+
+
+def prompt_image_alignment():
+    config = compressibility()
+
+    config.num_epochs = 200
+    # for this experiment, I reserved 2 GPUs for LLaVA inference so only 6 could be used for DDPO. the total number of
+    # samples per epoch is 8 * 6 * 6 = 288.
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 6
+
+    # again, this one is harder to optimize, so I used (8 * 6) / (4 * 6) = 2 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 6
+
+    # prompting
+    config.prompt_fn = "nouns_activities"
+    config.prompt_fn_kwargs = {
+        "nouns_file": "simple_animals.txt",
+        "activities_file": "activities.txt",
+    }
+
+    # rewards
+    config.reward_fn = "llava_bertscore"
+
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+
+    return config
+
+
+def get_config(name):
+    return globals()[name]()
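
Usage sketch (an assumption, not part of the patch itself): the get_config(name) dispatch added
at the bottom of config/dgx.py follows the ml_collections config_flags convention, in which the
text after a colon in the config path is forwarded to get_config(). A minimal, hypothetical
launcher under that assumption (the script name, flag name, and printed fields are illustrative):

    # hypothetical launcher; pick a named config with, e.g.:
    #   python train.py --config=config/dgx.py:aesthetic
    from absl import app, flags
    from ml_collections import config_flags

    FLAGS = flags.FLAGS
    config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")


    def main(_):
        # FLAGS.config is the ConfigDict built by get_config("aesthetic") in the example above.
        config = FLAGS.config
        print(config.reward_fn, config.sample.batch_size, config.train.gradient_accumulation_steps)


    if __name__ == "__main__":
        app.run(main)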