Update the model: run validation every cfg.train.check_val_every_n_epoch epochs during training, and fold the PPO clipped-surrogate update into the sampling loop
@@ -239,7 +239,7 @@ class Graph_DiT(pl.LightningModule):
 
                    self.val_X_logp.compute(), self.val_E_logp.compute()]
 
-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
         with open("validation-metrics.csv", "a") as f:

@@ -242,6 +242,12 @@ def test(cfg: DictConfig):
             optimizer.step()
             optimizer.zero_grad()
             # return {'loss': loss}
+        if epoch % cfg.train.check_val_every_n_epoch == 0:
+            print(f'print validation loss')
+            graph_dit_model.eval()
+            graph_dit_model.on_validation_epoch_start()
+            graph_dit_model.validation_step(data, epoch)
+            graph_dit_model.on_validation_epoch_end()
 
     # start testing
     print("start testing")
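Note on the hunk above: the validation pass is driven by hand rather than by the Lightning Trainer, so the epoch-level hooks must be invoked explicitly. Below is a minimal sketch of that pattern; the torch.no_grad() guard and the switch back to train() are additions of this note, not part of the commit, and Lightning's validation_step conventionally receives (batch, batch_idx) where the hunk passes (data, epoch).

    import torch

    def run_validation_pass(model, val_batch, batch_idx):
        model.eval()                       # eval-mode dropout / normalization
        model.on_validation_epoch_start()  # reset epoch-level metric state
        with torch.no_grad():              # validation needs no gradients
            model.validation_step(val_batch, batch_idx)
        model.on_validation_epoch_end()    # aggregate and report the metrics
        model.train()                      # restore training mode afterwards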
@@ -281,6 +287,53 @@ def test(cfg: DictConfig):
             reward = 1.0
             rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+    #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #     samples = samples + cur_sample
+    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
+
+    #     all_ys.append(batch_y)
+    #     batch_id += to_generate
+
+    #     samples_left_to_save -= to_save
+    #     samples_left_to_generate -= to_generate
+    #     chains_left_to_save -= chains_save
+
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
+
+    # # save samples
+    # print("Samples:")
+    # print(samples)
+
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    old_log_probs = None
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
             f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
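Note on the hunk above: old_log_probs = None is the only live added line; the rest re-lands the previous bulk sampling pass as comments. Instead of sampling everything up front and cloning the log-probabilities, the loop that follows initializes old_log_probs lazily from the first batch, which makes the first importance ratio exactly one. A tiny illustration with made-up values:

    import torch

    log_probs = torch.log(torch.tensor([0.3, 0.5]))  # stand-in policy log-probs
    old_log_probs = None                             # sentinel, as in the commit
    if old_log_probs is None:
        old_log_probs = log_probs.clone()            # first batch: reference = current
    ratio = torch.exp(log_probs - old_log_probs)
    print(ratio)  # tensor([1., 1.]) -- clipping is a no-op on this first batch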
@@ -289,14 +342,34 @@ def test(cfg: DictConfig):
         to_save = min(samples_left_to_save, bs)
         chains_save = min(chains_left_to_save, bs)
         # batch_y = test_y_collection[batch_id : batch_id + to_generate]
 
-        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                        keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        samples = samples + cur_sample
-        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                        # keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        # samples = samples + cur_sample
+        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        with accelerator.accumulate(graph_dit_model):
+            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+            samples = samples + new_samples
+            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+            if old_log_probs is None:
+                old_log_probs = log_probs.clone()
+            ratio = torch.exp(log_probs - old_log_probs)
+            unclipped_loss = -advantages * ratio
+            clipped_loss = -advantages * torch.clamp(ratio,
+                            1.0 - cfg.ppo.clip_param,
+                            1.0 + cfg.ppo.clip_param)
+            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+
 
-        samples_with_log_probs.append((cur_sample, log_probs, reward))
+        samples_with_log_probs.append((new_samples, log_probs, reward))
 
         all_ys.append(batch_y)
         batch_id += to_generate
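Note on the hunk above: the new accelerator.accumulate(...) block performs a PPO-style clipped-surrogate update per sampled batch: rewards are z-score normalized into advantages, the importance ratio exp(log_probs - old_log_probs) is formed in log space, and the pessimistic maximum of the unclipped and clipped surrogate losses is minimized. A self-contained sketch of the same objective; the function name and the 0.2 default (standing in for cfg.ppo.clip_param) are illustrative, not from the commit:

    import torch

    def ppo_clip_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
        # z-score normalize rewards into advantages (1e-6 avoids divide-by-zero)
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        # importance ratio pi_theta / pi_theta_old, computed in log space
        ratio = torch.exp(log_probs - old_log_probs)
        # negated surrogates: minimizing them maximizes expected advantage
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        # elementwise max of the two losses is the standard clipped PPO bound
        return torch.max(unclipped, clipped).mean()

Since old_log_probs is set only once, later iterations compare log-probabilities of freshly drawn samples against those of the first batch; whether that cross-batch ratio is intended is worth double-checking.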
@@ -304,6 +377,7 @@ def test(cfg: DictConfig):
         samples_left_to_save -= to_save
         samples_left_to_generate -= to_generate
         chains_left_to_save -= chains_save
+        # break
 
     print(f"final Computing sampling metrics...")
     graph_dit_model.sampling_metrics.reset()
@@ -315,47 +389,46 @@ def test(cfg: DictConfig):
     print("Samples:")
     print(samples)
 
-    perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
-    samples, log_probs, rewards = samples_with_log_probs[perm]
-    samples = list(samples)
-    log_probs = list(log_probs)
-    for i in range(len(log_probs)):
-        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
-    print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
-    rewards = list(rewards)
-    log_probs = torch.cat(log_probs, dim=0)
-    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
-    old_log_probs = log_probs.clone()
-
-    # multi metrics range
-    # reward hacking hiking
-    for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
-        print(f'rewards: {rewards[:5]}')
-        print(f'len rewards: {len(rewards)}')
-        print(f'type rewards: {type(rewards)}')
-        if len(rewards) > 1 and isinstance(rewards, list):
-            rewards = torch.cat(rewards, dim=0)
-        elif len(rewards) == 1 and isinstance(rewards, list):
-            rewards = rewards[0]
-        # print(f'rewards: {rewards.shape}')
-        advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        print(f'advantages: {advantages.shape}')
-        with accelerator.accumulate(graph_dit_model):
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            # z-score normalization
-            clipped_loss = -advantages * torch.clamp(ratio,
-                            1.0 - cfg.ppo.clip_param,
-                            1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-
-        accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
-        print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    # # multi metrics range
+    # # reward hacking hiking
+    # for inner_epoch in range(cfg.train.n_epochs):
+    #     # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+    #     print(f'rewards: {rewards[:5]}')
+    #     print(f'len rewards: {len(rewards)}')
+    #     print(f'type rewards: {type(rewards)}')
+    #     if len(rewards) > 1 and isinstance(rewards, list):
+    #         rewards = torch.cat(rewards, dim=0)
+    #     elif len(rewards) == 1 and isinstance(rewards, list):
+    #         rewards = rewards[0]
+    #     # print(f'rewards: {rewards.shape}')
+    #     advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
+    #     print(f'advantages: {advantages.shape}')
+    #     with accelerator.accumulate(graph_dit_model):
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         # z-score normalization
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                         1.0 - cfg.ppo.clip_param,
+    #                         1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
+
+    #     accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
+    #     print(f"loss: {loss.item()}, epoch: {inner_epoch}")
 
 
     # trainer = Trainer(
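Note: both update sites wrap the backward/step in accelerator.accumulate(graph_dit_model), Hugging Face Accelerate's gradient-accumulation context. A minimal runnable sketch of that pattern; the model, data, and step counts are made up for illustration:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    model, optimizer = accelerator.prepare(model, optimizer)

    for step in range(8):
        x = torch.randn(16, 8, device=accelerator.device)
        y = torch.randn(16, 1, device=accelerator.device)
        with accelerator.accumulate(model):  # gradients sync every 4th step
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)       # scales loss by 1/4 internally
            optimizer.step()                 # steps only at sync boundaries
            optimizer.zero_grad()            # likewise deferred between them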