Update configs
This commit is contained in:
parent ec499edf84
commit beb8c2f86d
@@ -42,7 +42,7 @@ def get_config():
     ###### Sampling ######
     config.sample = sample = ml_collections.ConfigDict()
     # number of sampler inference steps.
-    sample.num_steps = 10
+    sample.num_steps = 50
     # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
     # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
     sample.eta = 1.0
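The hunk above only raises the default step count from 10 to 50; sample.eta stays at 1.0. As a hedged sketch, not part of this commit, of how those two fields would typically drive a diffusers DDIMScheduler (the model id is taken from the second config file in this commit, and a zero tensor stands in for the UNet's noise prediction):

import torch
from diffusers import DDIMScheduler

# Illustrative only: the training script that actually consumes sample.num_steps and
# sample.eta is not shown in this diff.
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler.set_timesteps(50)  # sample.num_steps: raised from 10 to 50 by this commit

latents = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    noise_pred = torch.zeros_like(latents)  # placeholder for the UNet's epsilon prediction
    # eta=1.0 injects DDPM-level noise at each step; eta=0.0 would make the sampler deterministic
    latents = scheduler.step(noise_pred, t, latents, eta=1.0).prev_sample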
@@ -61,7 +61,7 @@ def get_config():
     # whether to use the 8bit Adam optimizer from bitsandbytes.
     train.use_8bit_adam = False
     # learning rate.
-    train.learning_rate = 1e-4
+    train.learning_rate = 3e-4
     # Adam beta1.
     train.adam_beta1 = 0.9
     # Adam beta2.
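The learning rate moves from 1e-4 to 3e-4 while the 8-bit Adam switch stays off. A minimal sketch of how these optimizer fields are commonly wired up (the actual trainer is outside this diff; params is a stand-in for the trainable UNet or LoRA parameters):

import torch

def build_optimizer(params, train):
    # 'train' is the ConfigDict populated above; this is the usual wiring, not repo code.
    if train.use_8bit_adam:
        import bitsandbytes as bnb
        optimizer_cls = bnb.optim.AdamW8bit  # requires the bitsandbytes package
    else:
        optimizer_cls = torch.optim.AdamW
    return optimizer_cls(
        params,
        lr=train.learning_rate,  # 3e-4 after this commit (was 1e-4)
        betas=(train.adam_beta1, train.adam_beta2),
    )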
@@ -82,7 +82,7 @@ def get_config():
     # sampling will be used during training.
     train.cfg = True
     # clip advantages to the range [-adv_clip_max, adv_clip_max].
-    train.adv_clip_max = 10
+    train.adv_clip_max = 5
     # the PPO clip range.
     train.clip_range = 1e-4
     # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
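Lowering adv_clip_max from 10 to 5 tightens the clamp applied to advantages before they enter the PPO objective, while clip_range bounds the importance ratio. A generic sketch of the clipped-surrogate loss these two knobs control (variable names are illustrative, not the repo's exact code):

import torch

def ppo_loss(log_prob, old_log_prob, advantages, adv_clip_max=5.0, clip_range=1e-4):
    # clip advantages to [-adv_clip_max, adv_clip_max]; this commit lowers the bound from 10 to 5
    advantages = torch.clamp(advantages, -adv_clip_max, adv_clip_max)
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # take the pessimistic (larger) of the two losses, averaged over the batch
    return torch.mean(torch.maximum(unclipped, clipped))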
|
@@ -5,28 +5,91 @@ import os
 base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))
 
 
-def get_config():
+def compressibility():
     config = base.get_config()
 
-    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+    config.pretrained.model = "CompVis/stable-diffusion-v1-4"
 
-    config.mixed_precision = "fp16"
-    config.allow_tf32 = True
-    config.use_lora = False
+    config.num_epochs = 100
+    config.use_lora = True
+    config.save_freq = 1
+    config.num_checkpoint_limit = 100000000
 
-    config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 2
-    config.train.learning_rate = 3e-5
-    config.train.clip_range = 1e-4
-
-    # sampling
-    config.sample.num_steps = 50
+    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
     config.sample.batch_size = 8
     config.sample.num_batches_per_epoch = 4
 
+    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
+    # prompting
+    config.prompt_fn = "imagenet_animals"
+    config.prompt_fn_kwargs = {}
+
+    # rewards
+    config.reward_fn = "jpeg_compressibility"
+
     config.per_prompt_stat_tracking = {
         "buffer_size": 16,
         "min_count": 16,
     }
 
     return config
+
+
+def incompressibility():
+    config = compressibility()
+    config.reward_fn = "jpeg_incompressibility"
+    return config
+
+
+def aesthetic():
+    config = compressibility()
+    config.num_epochs = 200
+    config.reward_fn = "aesthetic_score"
+
+    # this reward is a bit harder to optimize, so I used 2 gradient updates per epoch.
+    config.train.gradient_accumulation_steps = 4
+
+    config.prompt_fn = "simple_animals"
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+    return config
+
+
+def prompt_image_alignment():
+    config = compressibility()
+
+    config.num_epochs = 200
+    # for this experiment, I reserved 2 GPUs for LLaVA inference so only 6 could be used for DDPO. the total number of
+    # samples per epoch is 8 * 6 * 6 = 288.
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 6
+
+    # again, this one is harder to optimize, so I used (8 * 6) / (4 * 6) = 2 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 6
+
+    # prompting
+    config.prompt_fn = "nouns_activities"
+    config.prompt_fn_kwargs = {
+        "nouns_file": "simple_animals.txt",
+        "activities_file": "activities.txt",
+    }
+
+    # rewards
+    config.reward_fn = "llava_bertscore"
+
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+
+    return config
+
+
+def get_config(name):
+    return globals()[name]()
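With the second file now exposing one function per experiment, get_config(name) dispatches by name, and the per-epoch arithmetic in the comments follows directly from the sample/train fields. A hedged usage sketch, assuming it runs in (or imports) the config module above; the launch-script plumbing is not part of this diff:

config = get_config("compressibility")  # or "incompressibility", "aesthetic", "prompt_image_alignment"

num_gpus = 8  # the DGX machine referenced in the comments
samples_per_epoch = num_gpus * config.sample.batch_size * config.sample.num_batches_per_epoch
# -> 8 * 8 * 4 = 256, matching the "256 samples per epoch" comment

updates_per_epoch = (config.sample.batch_size * config.sample.num_batches_per_epoch) // (
    config.train.batch_size * config.train.gradient_accumulation_steps
)
# -> (8 * 4) // (4 * 2) = 4, matching the "4 gradient updates per epoch" comment

The same arithmetic gives the 2 gradient updates per epoch quoted in the aesthetic and prompt_image_alignment comments.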