Working on DGX

ddpo_pytorch/config/base.py (new file, +55)
@@ -0,0 +1,55 @@
import ml_collections

def get_config():

    config = ml_collections.ConfigDict()

    # misc
    config.seed = 42
    config.logdir = "logs"
    config.num_epochs = 100
    config.mixed_precision = "fp16"
    config.allow_tf32 = True

    # pretrained model initialization
    config.pretrained = pretrained = ml_collections.ConfigDict()
    pretrained.model = "runwayml/stable-diffusion-v1-5"
    pretrained.revision = "main"

    # training
    config.train = train = ml_collections.ConfigDict()
    train.batch_size = 1
    train.use_8bit_adam = False
    train.scale_lr = False
    train.learning_rate = 1e-4
    train.adam_beta1 = 0.9
    train.adam_beta2 = 0.999
    train.adam_weight_decay = 1e-4
    train.adam_epsilon = 1e-8
    train.gradient_accumulation_steps = 1
    train.max_grad_norm = 1.0
    train.num_inner_epochs = 1
    train.cfg = True
    train.adv_clip_max = 10
    train.clip_range = 1e-4

    # sampling
    config.sample = sample = ml_collections.ConfigDict()
    sample.num_steps = 30
    sample.eta = 1.0
    sample.guidance_scale = 5.0
    sample.batch_size = 1
    sample.num_batches_per_epoch = 1

    # prompting
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # rewards
    config.reward_fn = "jpeg_compressibility"

    config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    config.per_prompt_stat_tracking.buffer_size = 64
    config.per_prompt_stat_tracking.min_count = 16

    return config
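
Note: get_config() returns a plain ml_collections.ConfigDict, and the usual way to consume such a file is through config_flags in an absl entry point rather than importing it directly. The sketch below only illustrates that pattern; the script, flag name, and main function are assumptions, not part of this commit.

# hypothetical consumer of ddpo_pytorch/config/base.py (not included in this commit)
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "ddpo_pytorch/config/base.py", "Training configuration.")

def main(_):
    config = FLAGS.config                      # the ConfigDict built by get_config()
    print(config.train.learning_rate, config.sample.num_steps)

if __name__ == "__main__":
    app.run(main)

With this pattern, an override file such as dgx.py below can be selected at launch time with, e.g., --config ddpo_pytorch/config/dgx.py.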

ddpo_pytorch/config/dgx.py (new file, +20)
@@ -0,0 +1,20 @@
import ml_collections
from ddpo_pytorch.config import base

def get_config():
    config = base.get_config()

    config.mixed_precision = "bf16"
    config.allow_tf32 = True

    config.train.batch_size = 8
    config.train.gradient_accumulation_steps = 4

    # sampling
    config.sample.num_steps = 50
    config.sample.batch_size = 8
    config.sample.num_batches_per_epoch = 4

    config.per_prompt_stat_tracking = None

    return config
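
A quick way to read these overrides: per process, sampling produces sample.batch_size * sample.num_batches_per_epoch = 32 trajectories per epoch, while each optimizer step consumes train.batch_size * train.gradient_accumulation_steps = 32 of them, so an inner epoch works out to roughly one gradient update per process (multiply by the number of GPUs for the global batch). A small sanity-check sketch, assuming only the config above:

# illustrative arithmetic on the DGX config (not part of the commit)
from ddpo_pytorch.config import dgx

config = dgx.get_config()
samples_per_epoch = config.sample.batch_size * config.sample.num_batches_per_epoch        # 8 * 4 = 32
samples_per_update = config.train.batch_size * config.train.gradient_accumulation_steps   # 8 * 4 = 32
print(samples_per_epoch, samples_per_update, samples_per_epoch // samples_per_update)     # 32 32 1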

@@ -14,6 +14,11 @@ from diffusers.utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler


+def _left_broadcast(t, shape):
+    assert t.ndim <= len(shape)
+    return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
+
+
def _get_variance(self, timestep, prev_timestep):
    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
    alpha_prod_t_prev = torch.where(
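
The new _left_broadcast helper appends trailing singleton dimensions to a tensor and expands it to a target shape, so per-sample scalars (alphas, standard deviations) can be applied elementwise to a latent batch. A minimal sketch of its effect, with assumed shapes:

# illustrative shapes; Stable Diffusion latents are typically (B, 4, 64, 64)
import torch

t = torch.tensor([0.9, 0.5])                      # one value per batch element, shape (2,)
sample = torch.randn(2, 4, 64, 64)
t_b = t.reshape(t.shape + (1,) * (sample.ndim - t.ndim)).broadcast_to(sample.shape)
assert t_b.shape == sample.shape                  # (2, 4, 64, 64), values repeated along trailing dims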

@@ -82,13 +87,16 @@ def ddim_step_with_logprob(

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    # to prevent OOB on gather
+    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # 2. compute alphas, betas
-    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
-    ).to(timestep.device)
+    )
+    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
+    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)

    beta_prod_t = 1 - alpha_prod_t
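
The added clamp guards the gather above: at the last denoising step, prev_timestep = timestep - num_train_timesteps // num_inference_steps goes negative, and torch.gather does not accept negative indices (on CUDA this surfaces as a device-side assert). A small sketch of the failure mode, with made-up numbers:

# illustrative only: a 1000-timestep schedule sampled with 50 inference steps
import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)
prev_timestep = torch.tensor([0]) - 1000 // 50        # final step: prev_timestep = -20
# alphas_cumprod.gather(0, prev_timestep)             # raises an index-out-of-bounds error
safe = torch.clamp(prev_timestep, 0, 999)
alphas_cumprod.gather(0, safe)                        # gathers index 0 instead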

@@ -121,6 +129,7 @@ def ddim_step_with_logprob(
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = _get_variance(self, timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)
+    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
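
For reference, the σ_t comment above is the standard DDIM posterior standard deviation scaled by eta, and _get_variance returns the term under the square roots. A numeric sketch with made-up alpha values:

# made-up alpha values, just to evaluate the commented formula (illustrative only)
import torch

alpha_prod_t, alpha_prod_t_prev, eta = torch.tensor(0.5), torch.tensor(0.8), 1.0
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
std_dev_t = eta * variance ** 0.5                     # corresponds to eta * _get_variance(...) ** 0.5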

@@ -153,4 +162,4 @@ ddim_step_with_logprob(
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

-    return prev_sample, log_prob
+    return prev_sample.type(sample.dtype), log_prob
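
Two details of the return values worth illustrating: the per-element log-densities are averaged over every non-batch dimension so each sample in the batch gets a single scalar log-probability, and prev_sample is cast back to sample.dtype, presumably because arithmetic with the float32 alphas promotes it out of half precision under mixed-precision training. A small sketch with assumed shapes and dtypes:

# assumed latent shape (B, 4, 64, 64) and fp16 inputs (illustrative only)
import torch

log_prob = torch.randn(2, 4, 64, 64)                                  # per-element log-densities
per_sample = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))        # shape (2,): one scalar per sample

sample = torch.randn(2, 4, 64, 64, dtype=torch.float16)
prev_sample = torch.randn(2, 4, 64, 64)                               # stands in for a value promoted to fp32
prev_sample = prev_sample.type(sample.dtype)                          # cast back to match the input latents
assert prev_sample.dtype == torch.float16 and per_sample.shape == (2,)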