Minor changes; add train_timestep_fraction
@@ -32,14 +32,15 @@ def get_config():
     train.cfg = True
     train.adv_clip_max = 10
     train.clip_range = 1e-4
+    train.timestep_fraction = 1.0

     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 5
+    sample.num_steps = 10
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1
-    sample.num_batches_per_epoch = 1
+    sample.num_batches_per_epoch = 2

     # prompting
     config.prompt_fn = "imagenet_animals"
@@ -49,7 +50,7 @@ def get_config():
     config.reward_fn = "jpeg_compressibility"

     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
-    config.per_prompt_stat_tracking.buffer_size = 64
+    config.per_prompt_stat_tracking.buffer_size = 16
     config.per_prompt_stat_tracking.min_count = 16

     return config
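Note: the new `train.timestep_fraction` is consumed in the train script further down in this commit. A minimal sketch of the arithmetic, using the values set above (variable names mirror the train.py hunk below):

num_steps = 10               # sample.num_steps after this change
timestep_fraction = 1.0      # train.timestep_fraction; 1.0 means train on every step
num_train_timesteps = int(num_steps * timestep_fraction)  # 10; e.g. 0.5 would give 5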
@@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
 def get_config():
     config = base.get_config()

-    config.mixed_precision = "no"
+    config.pretrained.model = "runwayml/stable-diffusion-v1-5"
+
+    config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = False

     config.train.batch_size = 4
-    config.train.gradient_accumulation_steps = 8
-    config.train.learning_rate = 1e-5
-    config.train.clip_range = 1.0
+    config.train.gradient_accumulation_steps = 2
+    config.train.learning_rate = 3e-5
+    config.train.clip_range = 1e-4

     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 16
-    config.sample.num_batches_per_epoch = 2
+    config.sample.batch_size = 8
+    config.sample.num_batches_per_epoch = 4

-    config.per_prompt_stat_tracking = None
+    config.per_prompt_stat_tracking = {
+        "buffer_size": 16,
+        "min_count": 16,
+    }

     return config
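Turning `per_prompt_stat_tracking` on replaces global advantage normalization with per-prompt statistics. A rough, illustrative sketch of what a tracker with `buffer_size`/`min_count` semantics could look like (this is an assumption about the mechanism, not the repo's actual class):

import numpy as np
from collections import deque

class PerPromptStatTrackerSketch:
    def __init__(self, buffer_size=16, min_count=16):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.buffers = {}  # prompt -> deque of the most recent rewards

    def update(self, prompts, rewards):
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for i, (prompt, reward) in enumerate(zip(prompts, rewards)):
            buf = self.buffers.setdefault(prompt, deque(maxlen=self.buffer_size))
            buf.append(reward)
            if len(buf) < self.min_count:
                # not enough history for this prompt yet: fall back to batch stats
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = np.mean(buf), np.std(buf) + 1e-6
            advantages[i] = (reward - mean) / std
        return advantages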
@@ -14,18 +14,15 @@ class MLP(nn.Module):
         super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(768, 1024),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(1024, 128),
-            nn.Identity(),
+            nn.Dropout(0.2),
             nn.Linear(128, 64),
-            nn.Identity(),
+            nn.Dropout(0.1),
             nn.Linear(64, 16),
             nn.Linear(16, 1),
         )

-        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
-        self.load_state_dict(state_dict)
-
     @torch.no_grad()
     def forward(self, embed):
         return self.layers(embed)
@@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
+        state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
+        self.mlp.load_state_dict(state_dict)
+        self.eval()

     @torch.no_grad()
     def __call__(self, images):
@@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
         inputs = {k: v.cuda() for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
-        embed = embed / embed.norm(dim=-1, keepdim=True)
+        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
         return self.mlp(embed)
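Because the `nn.Identity` placeholders are now real `nn.Dropout` layers, the added `self.eval()` call is what keeps scoring deterministic: dropout is stochastic in train mode and a no-op in eval mode. A quick standalone demonstration:

import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(0.2)
x = torch.ones(4)

drop.train()
print(drop(x))  # random elements zeroed, survivors scaled by 1 / (1 - 0.2)

drop.eval()
print(drop(x))  # identity: tensor([1., 1., 1., 1.])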
@@ -35,8 +35,6 @@ def aesthetic_score():
     scorer = AestheticScorer().cuda()

     def _fn(images, prompts, metadata):
-        if not isinstance(images, torch.Tensor):
-            images = torch.as_tensor(images)
         scores = scorer(images)
         return scores, {}
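With the tensor coercion gone, `_fn` simply forwards whatever image format the CLIP processor accepts. A hedged usage sketch; the uint8 HWC batch shape is an assumption here, and the snippet needs a CUDA device plus the aesthetic head weights:

import numpy as np

reward_fn = aesthetic_score()  # factory returns the inner _fn
images = np.random.randint(0, 256, size=(2, 512, 512, 3), dtype=np.uint8)
scores, info = reward_fn(images, ["a photo of a cat"] * 2, [{}] * 2)
print(scores.shape)  # (2, 1): one aesthetic score per image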
@@ -34,17 +34,23 @@ logger = get_logger(__name__)
 def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
+
+    # number of timesteps within each trajectory to train on
+    num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
-        gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps,
+        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
+        # _samples_ to accumulate across
+        gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
     logger.info(f"\n{config}")

-    # set seed
+    # set seed (device_specific is very important to get different prompts on different devices)
     set_seed(config.seed, device_specific=True)

     # load scheduler, tokenizer and models.
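A worked example of the new accumulation window, plugging in the full-training config values from this commit:

sample_num_steps = 50        # config.sample.num_steps
timestep_fraction = 1.0      # config.train.timestep_fraction
grad_accum_samples = 2       # config.train.gradient_accumulation_steps

num_train_timesteps = int(sample_num_steps * timestep_fraction)  # 50
accumulation = grad_accum_samples * num_train_timesteps          # 100
# one optimizer step per 100 micro-batches: 2 sample batches x 50 timesteps each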
@@ -152,7 +158,8 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )

-    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+    # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast

     # Prepare everything with our `accelerator`.
@@ -289,8 +296,15 @@ def main(_):
         #################### TRAINING ####################
         for inner_epoch in range(config.train.num_inner_epochs):
             # shuffle samples along batch dimension
-            indices = torch.randperm(total_batch_size, device=accelerator.device)
-            samples = {k: v[indices] for k, v in samples.items()}
+            perm = torch.randperm(total_batch_size, device=accelerator.device)
+            samples = {k: v[perm] for k, v in samples.items()}
+
+            # shuffle along time dimension independently for each sample
+            perms = torch.stack(
+                [torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
+            )
+            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+                samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]

             # rebatch for training
             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
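The per-sample time shuffle relies on advanced indexing to apply a different permutation to each row. The same pattern in isolation:

import torch

batch, timesteps = 3, 5
x = torch.arange(batch * timesteps).reshape(batch, timesteps)

# one independent permutation per row
perms = torch.stack([torch.randperm(timesteps) for _ in range(batch)])

# the row indices broadcast against the per-row column permutations
shuffled = x[torch.arange(batch)[:, None], perms]
assert torch.equal(shuffled.sort(dim=1).values, x)  # each row is a permutation of the original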
@@ -300,6 +314,7 @@ def main(_):

             # train
             pipeline.unet.train()
+            info = defaultdict(list)
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -312,9 +327,8 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]

-                info = defaultdict(list)
                 for j in tqdm(
-                    range(num_timesteps),
+                    range(num_train_timesteps),
                     desc="Timestep",
                     position=1,
                     leave=False,
@@ -371,14 +385,20 @@ def main(_):
                         optimizer.step()
                         optimizer.zero_grad()

+                    # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
+                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
+                        info = accelerator.reduce(info, reduction="mean")
                         info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                         accelerator.log(info, step=global_step)
                         global_step += 1
                         info = defaultdict(list)
+
+            # make sure we did an optimization step at the end of the inner epoch
+            assert accelerator.sync_gradients


 if __name__ == "__main__":
     app.run(main)
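The new asserts encode Accelerate's accumulation contract: `sync_gradients` flips to True only on the micro-batch that actually triggers an optimizer step. A toy sketch of that behavior (a throwaway linear model, not the training loop itself):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(8):
    with accelerator.accumulate(model):
        loss = model(torch.randn(1, 2)).sum()
        accelerator.backward(loss)
        optimizer.step()      # a real step only happens when the window closes
        optimizer.zero_grad()
    print(step, accelerator.sync_gradients)  # True on steps 3 and 7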