Update configs
@@ -42,7 +42,7 @@ def get_config():
     ###### Sampling ######
     config.sample = sample = ml_collections.ConfigDict()
     # number of sampler inference steps.
-    sample.num_steps = 10
+    sample.num_steps = 50
     # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
     # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
     sample.eta = 1.0
@@ -61,7 +61,7 @@ def get_config():
     # whether to use the 8bit Adam optimizer from bitsandbytes.
     train.use_8bit_adam = False
     # learning rate.
-    train.learning_rate = 1e-4
+    train.learning_rate = 3e-4
     # Adam beta1.
     train.adam_beta1 = 0.9
     # Adam beta2.
@@ -82,7 +82,7 @@ def get_config():
     # sampling will be used during training.
     train.cfg = True
     # clip advantages to the range [-adv_clip_max, adv_clip_max].
-    train.adv_clip_max = 10
+    train.adv_clip_max = 5
     # the PPO clip range.
     train.clip_range = 1e-4
     # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
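The last hunk halves train.adv_clip_max from 10 to 5 and leaves the PPO clip_range at 1e-4. How these two values enter the loss is not part of this commit; as a rough sketch (illustrative names, not the repository's training code), a PPO-style update would typically apply them like this:

```python
import torch

def ppo_style_loss(log_probs, old_log_probs, advantages, adv_clip_max=5.0, clip_range=1e-4):
    # Clip advantages to [-adv_clip_max, adv_clip_max] so outlier rewards cannot dominate the update.
    advantages = torch.clamp(advantages, -adv_clip_max, adv_clip_max)
    # Importance ratio between the current policy and the policy that generated the samples.
    ratio = torch.exp(log_probs - old_log_probs)
    # Standard PPO clipped surrogate objective (negated, since we minimize).
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))
```

A clip range of 1e-4 is much tighter than typical PPO values, which fits the small number of gradient updates taken per sampling round in the configs below.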
@@ -5,28 +5,91 @@ import os
 base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))


-def get_config():
+def compressibility():
     config = base.get_config()

-    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+    config.pretrained.model = "CompVis/stable-diffusion-v1-4"

-    config.mixed_precision = "fp16"
-    config.allow_tf32 = True
-    config.use_lora = False
+    config.num_epochs = 100
+    config.use_lora = True
+    config.save_freq = 1
+    config.num_checkpoint_limit = 100000000

-    config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 2
-    config.train.learning_rate = 3e-5
-    config.train.clip_range = 1e-4
-
-    # sampling
-    config.sample.num_steps = 50
+    # the DGX machine I used had 8 GPUs, so this corresponds to 8 * 8 * 4 = 256 samples per epoch.
     config.sample.batch_size = 8
     config.sample.num_batches_per_epoch = 4

+    # this corresponds to (8 * 4) / (4 * 2) = 4 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 2
+
     # prompting
     config.prompt_fn = "imagenet_animals"
     config.prompt_fn_kwargs = {}

     # rewards
     config.reward_fn = "jpeg_compressibility"

     config.per_prompt_stat_tracking = {
         "buffer_size": 16,
         "min_count": 16,
     }

     return config
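The two arithmetic comments in compressibility() follow directly from the batch settings: with 8 GPUs, 8 * 8 * 4 = 256 samples are generated per epoch, and (8 * 4) / (4 * 2) = 4 optimizer steps consume them. The small helper below (hypothetical, not part of the repo) makes the bookkeeping explicit, assuming the batch sizes are per GPU and every sample is trained on exactly once:

```python
def effective_sizes(num_gpus, sample_batch_size, num_batches_per_epoch,
                    train_batch_size, gradient_accumulation_steps):
    # Samples generated per epoch across all GPUs.
    samples_per_epoch = num_gpus * sample_batch_size * num_batches_per_epoch
    # Samples consumed by a single optimizer step (one "gradient update"):
    # each GPU runs `gradient_accumulation_steps` micro-batches of `train_batch_size`.
    samples_per_update = num_gpus * train_batch_size * gradient_accumulation_steps
    return samples_per_epoch, samples_per_epoch // samples_per_update

# compressibility(): 8 GPUs -> (256 samples, 4 gradient updates) per epoch.
print(effective_sizes(8, 8, 4, 4, 2))
# prompt_image_alignment(): 6 GPUs for DDPO -> (288 samples, 2 gradient updates) per epoch.
print(effective_sizes(6, 8, 6, 4, 6))
```

The remaining presets in the diff derive from compressibility(), changing mainly the reward function and these batch settings: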
+
+
+def incompressibility():
+    config = compressibility()
+    config.reward_fn = "jpeg_incompressibility"
+    return config
+
+
+def aesthetic():
+    config = compressibility()
+    config.num_epochs = 200
+    config.reward_fn = "aesthetic_score"
+
+    # this reward is a bit harder to optimize, so I used 2 gradient updates per epoch.
+    config.train.gradient_accumulation_steps = 4
+
+    config.prompt_fn = "simple_animals"
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+    return config
+
+
+def prompt_image_alignment():
+    config = compressibility()
+
+    config.num_epochs = 200
+    # for this experiment, I reserved 2 GPUs for LLaVA inference so only 6 could be used for DDPO. the total number of
+    # samples per epoch is 8 * 6 * 6 = 288.
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 6
+
+    # again, this one is harder to optimize, so I used (8 * 6) / (4 * 6) = 2 gradient updates per epoch.
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 6
+
+    # prompting
+    config.prompt_fn = "nouns_activities"
+    config.prompt_fn_kwargs = {
+        "nouns_file": "simple_animals.txt",
+        "activities_file": "activities.txt",
+    }
+
+    # rewards
+    config.reward_fn = "llava_bertscore"
+
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 32,
+        "min_count": 16,
+    }
+
+    return config
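Every preset sets config.per_prompt_stat_tracking with a buffer_size and a min_count, but the tracker itself is outside this diff. A minimal sketch of what those two knobs suggest (illustrative, not the repository's implementation): rewards are turned into advantages using running statistics kept separately for each prompt, with a fallback to batch-level statistics until a prompt has accumulated min_count rewards.

```python
from collections import deque

import numpy as np

class PerPromptStatTracker:
    """Illustrative per-prompt reward normalizer driven by buffer_size / min_count."""

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}  # prompt -> deque of the most recent rewards for that prompt

    def update(self, prompts, rewards):
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for i, (prompt, reward) in enumerate(zip(prompts, rewards)):
            buf = self.stats.setdefault(prompt, deque(maxlen=self.buffer_size))
            buf.append(reward)
            if len(buf) < self.min_count:
                # Not enough history for this prompt yet: normalize against the whole batch.
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = np.mean(buf), np.std(buf) + 1e-6
            advantages[i] = (reward - mean) / std
        return advantages
```

The diff closes with a small dispatcher over these presets: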
+
+
+def get_config(name):
+    return globals()[name]()
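get_config(name) resolves a preset through globals(), so the module now exposes a set of named configurations. Assuming the file sits next to base.py as the imp.load_source call implies (the config/dgx.py path below is an assumption), a preset can be built by name; and because ml_collections' config_flags forwards the text after a colon in a config-file path to get_config, the same selection can be made on a training script's command line.

```python
import imp

# Load the preset module (path assumed), mirroring how it loads base.py above.
dgx = imp.load_source("dgx", "config/dgx.py")

config = dgx.get_config("aesthetic")
print(config.reward_fn)  # -> "aesthetic_score"

# With ml_collections config_flags, the part after ":" is passed to get_config,
# so a preset can also be chosen on the command line, e.g.:
#   --config config/dgx.py:prompt_image_alignment
```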