Working non-lora training; other changes
@@ -10,6 +10,7 @@ def get_config():
     config.num_epochs = 100
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
+    config.use_lora = True
 
     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
@@ -20,7 +21,6 @@ def get_config():
     config.train = train = ml_collections.ConfigDict()
     train.batch_size = 1
     train.use_8bit_adam = False
-    train.scale_lr = False
     train.learning_rate = 1e-4
     train.adam_beta1 = 0.9
     train.adam_beta2 = 0.999
@@ -35,7 +35,7 @@ def get_config():
 
     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 30
+    sample.num_steps = 5
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
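The hunk above adds a `use_lora` switch to the base config (defaulting to True), and the machine config below turns it off; per the commit title, non-LoRA training is what now works. As a repo-agnostic illustration of what such a flag typically selects between, here is a minimal LoRA-versus-full-fine-tuning sketch in plain PyTorch. `LoRALinear` and `trainable_parameters` are illustrative names, not code from this commit:

    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        """A frozen Linear plus a trainable low-rank residual."""
        def __init__(self, base: nn.Linear, rank: int = 4):
            super().__init__()
            self.base = base
            self.base.requires_grad_(False)   # freeze the pretrained weight
            self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
            self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no-op at start

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + x @ self.A.t() @ self.B.t()

    def trainable_parameters(model: nn.Module, use_lora: bool):
        """Select what the optimizer sees, mirroring a config.use_lora flag."""
        if not use_lora:
            model.requires_grad_(True)        # full fine-tuning: train everything
            return list(model.parameters())
        # LoRA path: only adapter parameters were left with requires_grad=True.
        return [p for p in model.parameters() if p.requires_grad]

    layer = LoRALinear(nn.Linear(16, 16))
    assert len(trainable_parameters(layer, use_lora=True)) == 2  # just A and B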
@@ -4,16 +4,19 @@ from ddpo_pytorch.config import base
 def get_config():
     config = base.get_config()
 
-    config.mixed_precision = "bf16"
+    config.mixed_precision = "no"
     config.allow_tf32 = True
+    config.use_lora = False
 
-    config.train.batch_size = 8
-    config.train.gradient_accumulation_steps = 4
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 8
     config.train.learning_rate = 1e-5
+    config.train.clip_range = 1.0
 
     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 8
-    config.sample.num_batches_per_epoch = 4
+    config.sample.batch_size = 16
+    config.sample.num_batches_per_epoch = 2
 
+    config.per_prompt_stat_tracking = None
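The second file builds on the base via `base.get_config()` and overrides fields in place. The arithmetic behind the new numbers: the effective train batch stays at 4 × 8 = 32 samples per optimizer step (per process), and sampling now yields 16 × 2 = 32 trajectories per epoch, so the two sides still match; `mixed_precision` drops from "bf16" to "no", presumably because full-UNet training wants full-precision master weights. Below is a minimal sketch of the override pattern, with a toy base standing in for `ddpo_pytorch.config.base` (the field set is trimmed to what the diff touches):

    import ml_collections

    def get_base_config():
        # Toy stand-in for ddpo_pytorch.config.base.get_config().
        config = ml_collections.ConfigDict()
        config.mixed_precision = "fp16"
        config.use_lora = True
        config.train = ml_collections.ConfigDict()
        config.train.batch_size = 1
        config.train.gradient_accumulation_steps = 1
        config.sample = ml_collections.ConfigDict()
        config.sample.batch_size = 1
        config.sample.num_batches_per_epoch = 1
        return config

    def get_config():
        config = get_base_config()
        config.mixed_precision = "no"   # full-precision weights for non-LoRA training
        config.use_lora = False
        config.train.batch_size = 4
        config.train.gradient_accumulation_steps = 8
        config.sample.batch_size = 16
        config.sample.num_batches_per_epoch = 2
        return config

    config = get_config()
    assert config.train.batch_size * config.train.gradient_accumulation_steps == 32
    assert config.sample.batch_size * config.sample.num_batches_per_epoch == 32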