Continue implementation
This commit is contained in:
		| @@ -25,9 +25,9 @@ def get_config(): | |||||||
|     train.learning_rate = 1e-4 |     train.learning_rate = 1e-4 | ||||||
|     train.adam_beta1 = 0.9 |     train.adam_beta1 = 0.9 | ||||||
|     train.adam_beta2 = 0.999 |     train.adam_beta2 = 0.999 | ||||||
|     train.adam_weight_decay = 1e-2 |     train.adam_weight_decay = 1e-4 | ||||||
|     train.adam_epsilon = 1e-8 |     train.adam_epsilon = 1e-8 | ||||||
|     train.gradient_accumulation_steps = 1 |     train.gradient_accumulation_steps = 32 | ||||||
|     train.max_grad_norm = 1.0 |     train.max_grad_norm = 1.0 | ||||||
|     train.num_inner_epochs = 1 |     train.num_inner_epochs = 1 | ||||||
|     train.cfg = True |     train.cfg = True | ||||||
| @@ -36,11 +36,11 @@ def get_config(): | |||||||
|  |  | ||||||
|     # sampling |     # sampling | ||||||
|     config.sample = sample = ml_collections.ConfigDict() |     config.sample = sample = ml_collections.ConfigDict() | ||||||
|     sample.num_steps = 5 |     sample.num_steps = 30 | ||||||
|     sample.eta = 1.0 |     sample.eta = 1.0 | ||||||
|     sample.guidance_scale = 5.0 |     sample.guidance_scale = 5.0 | ||||||
|     sample.batch_size = 1 |     sample.batch_size = 4 | ||||||
|     sample.num_batches_per_epoch = 4 |     sample.num_batches_per_epoch = 8 | ||||||
|  |  | ||||||
|     # prompting |     # prompting | ||||||
|     config.prompt_fn = "imagenet_animals" |     config.prompt_fn = "imagenet_animals" | ||||||
| @@ -50,7 +50,7 @@ def get_config(): | |||||||
|     config.reward_fn = "jpeg_compressibility" |     config.reward_fn = "jpeg_compressibility" | ||||||
|  |  | ||||||
|     config.per_prompt_stat_tracking = ml_collections.ConfigDict() |     config.per_prompt_stat_tracking = ml_collections.ConfigDict() | ||||||
|     config.per_prompt_stat_tracking.buffer_size = 128 |     config.per_prompt_stat_tracking.buffer_size = 64 | ||||||
|     config.per_prompt_stat_tracking.min_count = 16 |     config.per_prompt_stat_tracking.min_count = 16 | ||||||
|  |  | ||||||
|     return config |     return config | ||||||
| @@ -1,6 +1,9 @@ | |||||||
| # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py | # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py | ||||||
| # with the following modifications: | # with the following modifications: | ||||||
| # - | # - It computes and returns the log prob of `prev_sample` given the UNet prediction. | ||||||
|  | # - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided, | ||||||
|  | #   it uses it to compute the log prob. | ||||||
|  | # - Timesteps can be a batched torch.Tensor. | ||||||
|  |  | ||||||
| from typing import Optional, Tuple, Union | from typing import Optional, Tuple, Union | ||||||
|  |  | ||||||
| @@ -11,6 +14,19 @@ from diffusers.utils import randn_tensor | |||||||
| from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler | ||||||
|  |  | ||||||
|  |  | ||||||
def _get_variance(self, timestep, prev_timestep):
    """Return the DDIM variance term for each (timestep, prev_timestep) pair.

    Computes ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)`` where ``a_t`` and
    ``a_prev`` are looked up in ``self.alphas_cumprod`` at ``timestep`` and
    ``prev_timestep`` respectively. Wherever ``prev_timestep`` is negative,
    ``self.final_alpha_cumprod`` is used for ``a_prev`` instead.

    Args:
        self: scheduler-like object exposing ``alphas_cumprod`` and
            ``final_alpha_cumprod`` tensors.
        timestep: integer index tensor into ``alphas_cumprod``
            (assumed 1-D, one entry per batch element -- TODO confirm with caller).
        prev_timestep: integer index tensor with the same shape as ``timestep``;
            entries may be negative.

    Returns:
        Tensor of variances, moved to ``timestep.device``.
    """
    # Index lookups are done on CPU copies of the indices; presumably
    # `alphas_cumprod` is kept on CPU -- verify against the scheduler setup.
    timestep_cpu = timestep.cpu()
    prev_timestep_cpu = prev_timestep.cpu()

    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep_cpu).to(timestep.device)

    # `gather` is evaluated before `torch.where` picks a branch and it rejects
    # negative indices, so clamp the index to a valid position first; the value
    # gathered for negative entries is discarded by `torch.where` below.
    gathered_prev = self.alphas_cumprod.gather(0, prev_timestep_cpu.clamp(min=0))
    alpha_prod_t_prev = torch.where(
        prev_timestep_cpu >= 0, gathered_prev, self.final_alpha_cumprod
    ).to(timestep.device)

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

    return variance
|  |  | ||||||
|  |  | ||||||
| def ddim_step_with_logprob( | def ddim_step_with_logprob( | ||||||
|     self: DDIMScheduler, |     self: DDIMScheduler, | ||||||
|     model_output: torch.FloatTensor, |     model_output: torch.FloatTensor, | ||||||
| @@ -66,16 +82,13 @@ def ddim_step_with_logprob( | |||||||
|  |  | ||||||
|     # 1. get previous step value (=t-1) |     # 1. get previous step value (=t-1) | ||||||
|     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps |     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps | ||||||
|  |     prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) | ||||||
|  |  | ||||||
|     # 2. compute alphas, betas |     # 2. compute alphas, betas | ||||||
|     self.alphas_cumprod = self.alphas_cumprod.to(timestep.device) |     alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device) | ||||||
|     self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device) |     alpha_prod_t_prev = torch.where( | ||||||
|     alpha_prod_t = self.alphas_cumprod.gather(0, timestep) |         prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod | ||||||
|     alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod) |     ).to(timestep.device) | ||||||
|     print(timestep) |  | ||||||
|     print(alpha_prod_t) |  | ||||||
|     print(alpha_prod_t_prev) |  | ||||||
|     print(prev_timestep) |  | ||||||
|  |  | ||||||
|     beta_prod_t = 1 - alpha_prod_t |     beta_prod_t = 1 - alpha_prod_t | ||||||
|  |  | ||||||
| @@ -106,7 +119,7 @@ def ddim_step_with_logprob( | |||||||
|  |  | ||||||
|     # 5. compute variance: "sigma_t(η)" -> see formula (16) |     # 5. compute variance: "sigma_t(η)" -> see formula (16) | ||||||
|     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) |     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) | ||||||
|     variance = self._get_variance(timestep, prev_timestep) |     variance = _get_variance(self, timestep, prev_timestep) | ||||||
|     std_dev_t = eta * variance ** (0.5) |     std_dev_t = eta * variance ** (0.5) | ||||||
|  |  | ||||||
|     if use_clipped_model_output: |     if use_clipped_model_output: | ||||||
|   | |||||||
| @@ -12,8 +12,12 @@ from ddpo_pytorch.stat_tracking import PerPromptStatTracker | |||||||
| from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob | from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob | ||||||
| from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob | from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob | ||||||
| import torch | import torch | ||||||
|  | import wandb | ||||||
|  | from functools import partial | ||||||
| import tqdm | import tqdm | ||||||
|  |  | ||||||
|  | tqdm = partial(tqdm.tqdm, dynamic_ncols=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| FLAGS = flags.FLAGS | FLAGS = flags.FLAGS | ||||||
| config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") | config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") | ||||||
| @@ -25,7 +29,7 @@ def main(_): | |||||||
|     # basic Accelerate and logging setup |     # basic Accelerate and logging setup | ||||||
|     config = FLAGS.config |     config = FLAGS.config | ||||||
|     accelerator = Accelerator( |     accelerator = Accelerator( | ||||||
|         log_with="all", |         log_with="wandb", | ||||||
|         mixed_precision=config.mixed_precision, |         mixed_precision=config.mixed_precision, | ||||||
|         project_dir=config.logdir, |         project_dir=config.logdir, | ||||||
|     ) |     ) | ||||||
| @@ -163,11 +167,12 @@ def main(_): | |||||||
|             config.per_prompt_stat_tracking.min_count, |             config.per_prompt_stat_tracking.min_count, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     global_step = 0 | ||||||
|     for epoch in range(config.num_epochs): |     for epoch in range(config.num_epochs): | ||||||
|         #################### SAMPLING #################### |         #################### SAMPLING #################### | ||||||
|         samples = [] |         samples = [] | ||||||
|         prompts = [] |         prompts = [] | ||||||
|         for i in tqdm.tqdm( |         for i in tqdm( | ||||||
|             range(config.sample.num_batches_per_epoch), |             range(config.sample.num_batches_per_epoch), | ||||||
|             desc=f"Epoch {epoch}: sampling", |             desc=f"Epoch {epoch}: sampling", | ||||||
|             disable=not accelerator.is_local_main_process, |             disable=not accelerator.is_local_main_process, | ||||||
| @@ -216,7 +221,7 @@ def main(_): | |||||||
|                     "latents": latents[:, :-1],  # each entry is the latent before timestep t |                     "latents": latents[:, :-1],  # each entry is the latent before timestep t | ||||||
|                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t |                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t | ||||||
|                     "log_probs": log_probs, |                     "log_probs": log_probs, | ||||||
|                     "rewards": torch.as_tensor(rewards), |                     "rewards": torch.as_tensor(rewards, device=accelerator.device), | ||||||
|                 } |                 } | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
| @@ -226,6 +231,13 @@ def main(_): | |||||||
|         # gather rewards across processes |         # gather rewards across processes | ||||||
|         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() |         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||||
|  |  | ||||||
|  |         # log sample-related stuff | ||||||
|  |         accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step) | ||||||
|  |         accelerator.log( | ||||||
|  |             {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]}, | ||||||
|  |             step=global_step, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # per-prompt mean/std tracking |         # per-prompt mean/std tracking | ||||||
|         if config.per_prompt_stat_tracking: |         if config.per_prompt_stat_tracking: | ||||||
|             # gather the prompts across processes |             # gather the prompts across processes | ||||||
| @@ -268,10 +280,11 @@ def main(_): | |||||||
|             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] |             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] | ||||||
|  |  | ||||||
|             # train |             # train | ||||||
|             for i, sample in tqdm.tqdm( |             for i, sample in tqdm( | ||||||
|                 list(enumerate(samples_batched)), |                 list(enumerate(samples_batched)), | ||||||
|                 desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training", |                 desc=f"Epoch {epoch}.{inner_epoch}: training", | ||||||
|                 position=0, |                 position=0, | ||||||
|  |                 disable=not accelerator.is_local_main_process, | ||||||
|             ): |             ): | ||||||
|                 if config.train.cfg: |                 if config.train.cfg: | ||||||
|                     # concat negative prompts to sample prompts to avoid two forward passes |                     # concat negative prompts to sample prompts to avoid two forward passes | ||||||
| @@ -279,11 +292,12 @@ def main(_): | |||||||
|                 else: |                 else: | ||||||
|                     embeds = sample["prompt_embeds"] |                     embeds = sample["prompt_embeds"] | ||||||
|  |  | ||||||
|                 for j in tqdm.trange( |                 for j in tqdm( | ||||||
|                     num_timesteps, |                     range(num_timesteps), | ||||||
|                     desc=f"Timestep", |                     desc=f"Timestep", | ||||||
|                     position=1, |                     position=1, | ||||||
|                     leave=False, |                     leave=False, | ||||||
|  |                     disable=not accelerator.is_local_main_process, | ||||||
|                 ): |                 ): | ||||||
|                     with accelerator.accumulate(pipeline.unet): |                     with accelerator.accumulate(pipeline.unet): | ||||||
|                         if config.train.cfg: |                         if config.train.cfg: | ||||||
| @@ -311,7 +325,7 @@ def main(_): | |||||||
|  |  | ||||||
|                         # ppo logic |                         # ppo logic | ||||||
|                         advantages = torch.clamp( |                         advantages = torch.clamp( | ||||||
|                             sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max |                             sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max | ||||||
|                         ) |                         ) | ||||||
|                         ratio = torch.exp(log_prob - sample["log_probs"][:, j]) |                         ratio = torch.exp(log_prob - sample["log_probs"][:, j]) | ||||||
|                         unclipped_loss = -advantages * ratio |                         unclipped_loss = -advantages * ratio | ||||||
| @@ -326,9 +340,14 @@ def main(_): | |||||||
|                         # estimator, but most existing code uses this so... |                         # estimator, but most existing code uses this so... | ||||||
|                         # http://joschu.net/blog/kl-approx.html |                         # http://joschu.net/blog/kl-approx.html | ||||||
|                         info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) |                         info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) | ||||||
|                         info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range) |                         info["clipfrac"] = torch.mean((torch.abs(ratio - 1.0) > config.train.clip_range).float()) | ||||||
|                         info["loss"] = loss |                         info["loss"] = loss | ||||||
|  |  | ||||||
|  |                         # log training-related stuff | ||||||
|  |                         info.update({"epoch": epoch, "inner_epoch": inner_epoch, "timestep": j}) | ||||||
|  |                         accelerator.log(info, step=global_step) | ||||||
|  |                         global_step += 1 | ||||||
|  |  | ||||||
|                         # backward pass |                         # backward pass | ||||||
|                         accelerator.backward(loss) |                         accelerator.backward(loss) | ||||||
|                         if accelerator.sync_gradients: |                         if accelerator.sync_gradients: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user