try to get the original perf
@@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule):
        # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))

    def on_train_epoch_start(self) -> None:
        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
            print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
        if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
            # print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
            print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs))
        self.start_epoch_time = time.time()
        self.train_loss.reset()
        self.train_metrics.reset()

    def on_train_epoch_end(self) -> None:

        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
        if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
            log = True
        else:
            log = False
@@ -601,8 +604,8 @@ class Graph_DiT(pl.LightningModule):

        assert (E == torch.transpose(E, 1, 2)).all()

        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
        # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s_int in reversed(range(0, self.T)):

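A note on the milestone check above: `self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]` only fires when the division lands exactly on one of those floats, so for many values of n_epochs it never triggers. A float-free variant, as a minimal sketch (the function name is assumed, not from this repo):

def is_quarter_milestone(current_epoch: int, n_epochs: int) -> bool:
    # True exactly at the 25%/50%/75%/100% points of training, using
    # integer arithmetic instead of float equality.
    return current_epoch > 0 and (current_epoch * 4) % n_epochs == 0

# Example: n_epochs=8 -> epochs 2, 4, 6, 8 are milestones; n_epochs=10 ->
# only 5 and 10, since 2.5 and 7.5 are not integer epochs.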
@@ -161,7 +161,8 @@ def test(cfg: DictConfig):
    accelerator = Accelerator(
        mixed_precision='no',
        project_config=accelerator_config,
        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
        # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
    )

    # Debug: confirm available devices
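For reference, the usual Accelerate gradient-accumulation pattern pairs the `gradient_accumulation_steps` argument restored above with an `accelerator.accumulate(...)` context around the optimizer step (the commented-out sampling loop further down uses it this way). A minimal sketch, with placeholder model/optimizer/dataloader names:

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):
        loss = compute_loss(model, batch)  # placeholder loss function
        accelerator.backward(loss)
        # Under accumulate(), the prepared optimizer only applies the
        # step on every 4th batch here.
        optimizer.step()
        optimizer.zero_grad()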
@@ -219,29 +220,34 @@ def test(cfg: DictConfig):
    for epoch in range(cfg.train.n_epochs):
        graph_dit_model.train()  # set the model to training mode
        print(f"Epoch {epoch}", end="\n")
        graph_dit_model.on_train_epoch_start()
        for data in train_dataloader:  # fetch one batch of data from the dataloader
            data.to(accelerator.device)
            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
            dense_data = dense_data.mask(node_mask)
            X, E = dense_data.X, dense_data.E
            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
            pred = graph_dit_model.forward(noisy_data)
            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
                                true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
                                log=epoch % graph_dit_model.log_every_steps == 0)
            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                            log=epoch % graph_dit_model.log_every_steps == 0)
            graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
            print(f"training loss: {loss}")
            with open("training-loss.csv", "a") as f:
                f.write(f"{loss}, {epoch}\n")
            # data.to(accelerator.device)
            # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
            # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
            # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
            # dense_data = dense_data.mask(node_mask)
            # X, E = dense_data.X, dense_data.E
            # noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
            # pred = graph_dit_model.forward(noisy_data)
            # loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
            #                     true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
            #                     log=epoch % graph_dit_model.log_every_steps == 0)
            # # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
            # graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
            #                 log=epoch % graph_dit_model.log_every_steps == 0)
            # graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
            # print(f"training loss: {loss}")
            # with open("training-loss.csv", "a") as f:
            #     f.write(f"{loss}, {epoch}\n")
            loss = graph_dit_model.training_step(data, epoch)
            loss = loss['loss']

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            # return {'loss': loss}
        graph_dit_model.on_train_epoch_end()
        if epoch % cfg.train.check_val_every_n_epoch == 0:
            print(f'print validation loss')
            graph_dit_model.eval()
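The hunk above replaces the inlined forward/loss code with the module's own Lightning hooks driven by hand under Accelerate. The resulting pattern, distilled as a sketch (not the full script; note that `training_step` is called with `epoch` where Lightning would normally pass a batch index, and is assumed to return a dict with a 'loss' key, as the code above relies on):

for epoch in range(n_epochs):
    model.on_train_epoch_start()
    for batch in train_dataloader:
        out = model.training_step(batch, epoch)  # Lightning convention: (batch, batch_idx)
        accelerator.backward(out['loss'])
        optimizer.step()
        optimizer.zero_grad()
    model.on_train_epoch_end()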
@@ -253,126 +259,69 @@ def test(cfg: DictConfig):
    print("start testing")
    graph_dit_model.eval()
    test_dataloader = accelerator.prepare(datamodule.test_dataloader())
    graph_dit_model.on_test_epoch_start()
    for data in test_dataloader:
        data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
        nll = graph_dit_model.test_step(data, epoch)
        # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
        # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
        dense_data = dense_data.mask(node_mask)
        noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
        pred = graph_dit_model.forward(noisy_data)
        nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
        graph_dit_model.test_y_collection.append(data.y)
        # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
        # dense_data = dense_data.mask(node_mask)
        # noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
        # pred = graph_dit_model.forward(noisy_data)
        # nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
        # graph_dit_model.test_y_collection.append(data.y)
        print(f'test loss: {nll}')

    graph_dit_model.on_test_epoch_end()

    # start sampling

    samples_left_to_generate = cfg.general.final_model_samples_to_generate
    samples_left_to_save = cfg.general.final_model_samples_to_save
    chains_left_to_save = cfg.general.final_model_chains_to_save
    # samples_left_to_generate = cfg.general.final_model_samples_to_generate
    # samples_left_to_save = cfg.general.final_model_samples_to_save
    # chains_left_to_save = cfg.general.final_model_chains_to_save

    samples, all_ys, batch_id = [], [], 0
    samples_with_log_probs = []
    test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
    num_examples = test_y_collection.size(0)
    if cfg.general.final_model_samples_to_generate > num_examples:
        ratio = cfg.general.final_model_samples_to_generate // num_examples
        test_y_collection = test_y_collection.repeat(ratio+1, 1)
        num_examples = test_y_collection.size(0)
    # samples, all_ys, batch_id = [], [], 0
    # samples_with_log_probs = []
    # test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
    # num_examples = test_y_collection.size(0)
    # if cfg.general.final_model_samples_to_generate > num_examples:
    #     ratio = cfg.general.final_model_samples_to_generate // num_examples
    #     test_y_collection = test_y_collection.repeat(ratio+1, 1)
    #     num_examples = test_y_collection.size(0)

    # Normal reward function
    from nas_201_api import NASBench201API as API
    api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
        rewards = []
        if reward_model == 'swap':
            import csv
            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
                reader = csv.reader(f)
                header = next(reader)
                data = [row for row in reader]
                swap_scores = [float(row[0]) for row in data]
                for graph in graphs:
                    node_tensor = graph[0]
                    node = node_tensor.cpu().numpy().tolist()
    # from nas_201_api import NASBench201API as API
    # api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
    # def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
    #     rewards = []
    #     if reward_model == 'swap':
    #         import csv
    #         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
    #             reader = csv.reader(f)
    #             header = next(reader)
    #             data = [row for row in reader]
    #             swap_scores = [float(row[0]) for row in data]
    #             for graph in graphs:
    #                 node_tensor = graph[0]
    #                 node = node_tensor.cpu().numpy().tolist()

                    def nodes_to_arch_str(nodes):
                        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
                        nodes_str = [num_to_op[node] for node in nodes]
                        arch_str = '|' + nodes_str[1] + '~0|+' + \
                                '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
                                '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
                        return arch_str
    #                 def nodes_to_arch_str(nodes):
    #                     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
    #                     nodes_str = [num_to_op[node] for node in nodes]
    #                     arch_str = '|' + nodes_str[1] + '~0|+' + \
    #                             '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
    #                             '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
    #                     return arch_str

                    arch_str = nodes_to_arch_str(node)
                    reward = swap_scores[api.query_index_by_arch(arch_str)]
                    rewards.append(reward)
    #                 arch_str = nodes_to_arch_str(node)
    #                 reward = swap_scores[api.query_index_by_arch(arch_str)]
    #                 rewards.append(reward)

        # for graph in graphs:
        #     reward = 1.0
        #     rewards.append(reward)
        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
    old_log_probs = None
    while samples_left_to_generate > 0:
        print(f'samples left to generate: {samples_left_to_generate}/'
            f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
        bs = 1 * cfg.train.batch_size
        to_generate = min(samples_left_to_generate, bs)
        to_save = min(samples_left_to_save, bs)
        chains_save = min(chains_left_to_save, bs)
        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)

        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
                                        keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
        samples = samples + cur_sample
        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
        advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
        if old_log_probs is None:
            old_log_probs = log_probs.clone()
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
        loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()


        samples_with_log_probs.append((cur_sample, log_probs, reward))

        all_ys.append(batch_y)
        batch_id += to_generate

        samples_left_to_save -= to_save
        samples_left_to_generate -= to_generate
        chains_left_to_save -= chains_save

    print(f"final Computing sampling metrics...")
    graph_dit_model.sampling_metrics.reset()
    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
    graph_dit_model.sampling_metrics.reset()
    print(f"Done.")

    # save samples
    print("Samples:")
    print(samples)

    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
    # samples, log_probs, rewards = samples_with_log_probs[perm]
    # samples = list(samples)
    # log_probs = list(log_probs)
    # for i in range(len(log_probs)):
    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
    # print(f'log_probs: {log_probs[:5]}')
    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
    # rewards = list(rewards)
    # log_probs = torch.cat(log_probs, dim=0)
    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
    # old_log_probs = log_probs.clone()

    # ===

    #     # for graph in graphs:
    #     #     reward = 1.0
    #     #     rewards.append(reward)
    #     return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
    # old_log_probs = None
    # while samples_left_to_generate > 0:
    #     print(f'samples left to generate: {samples_left_to_generate}/'
@@ -381,27 +330,28 @@ def test(cfg: DictConfig):
    #     to_generate = min(samples_left_to_generate, bs)
    #     to_save = min(samples_left_to_save, bs)
    #     chains_save = min(chains_left_to_save, bs)
    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)

    #     with accelerator.accumulate(graph_dit_model):
    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
    #         samples = samples + new_samples
    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
    #         if old_log_probs is None:
    #             old_log_probs = log_probs.clone()
    #         ratio = torch.exp(log_probs - old_log_probs)
    #         unclipped_loss = -advantages * ratio
    #         clipped_loss = -advantages * torch.clamp(ratio,
    #                         1.0 - cfg.ppo.clip_param,
    #                         1.0 + cfg.ppo.clip_param)
    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
    #         accelerator.backward(loss)
    #         optimizer.step()
    #         optimizer.zero_grad()
    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
    #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
    #     log_probs = torch.sum(log_probs, dim=-1).unsqueeze(1)
    #     samples = samples + cur_sample
    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
    #     advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
    #     print(f'reward: {reward.shape}, advantages: {advantages.shape}, log_probs: {log_probs.shape}, cur_sample: {len(cur_sample)}')
    #     if old_log_probs is None:
    #         old_log_probs = log_probs.clone()
    #     ratio = torch.exp(log_probs - old_log_probs)
    #     unclipped_loss = -advantages * ratio
    #     clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
    #     loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
    #     accelerator.backward(loss)
    #     optimizer.step()
    #     optimizer.zero_grad()


    #     samples_with_log_probs.append((new_samples, log_probs, reward))
    #     samples_with_log_probs.append((cur_sample, log_probs, reward))

    #     all_ys.append(batch_y)
    #     batch_id += to_generate
@@ -409,7 +359,6 @@ def test(cfg: DictConfig):
    #     samples_left_to_save -= to_save
    #     samples_left_to_generate -= to_generate
    #     chains_left_to_save -= chains_save
    #     # break

    # print(f"final Computing sampling metrics...")
    # graph_dit_model.sampling_metrics.reset()
@@ -421,46 +370,10 @@ def test(cfg: DictConfig):
    # print("Samples:")
    # print(samples)

    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
    # samples, log_probs, rewards = samples_with_log_probs[perm]
    # samples = list(samples)
    # log_probs = list(log_probs)
    # for i in range(len(log_probs)):
    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
    # print(f'log_probs: {log_probs[:5]}')
    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
    # rewards = list(rewards)
    # log_probs = torch.cat(log_probs, dim=0)
    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
    # old_log_probs = log_probs.clone()
    # # multi metrics range
    # # reward hacking hiking
    # for inner_epoch in range(cfg.train.n_epochs):
    #     # print(f'rewards: {rewards.shape}') # torch.Size([1000])
    #     print(f'rewards: {rewards[:5]}')
    #     print(f'len rewards: {len(rewards)}')
    #     print(f'type rewards: {type(rewards)}')
    #     if len(rewards) > 1 and isinstance(rewards, list):
    #         rewards = torch.cat(rewards, dim=0)
    #     elif len(rewards) == 1 and isinstance(rewards, list):
    #         rewards = rewards[0]
    #     # print(f'rewards: {rewards.shape}')
    #     advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
    #     print(f'advantages: {advantages.shape}')
    #     with accelerator.accumulate(graph_dit_model):
    #         ratio = torch.exp(log_probs - old_log_probs)
    #         unclipped_loss = -advantages * ratio
    #         # z-score normalization
    #         clipped_loss = -advantages * torch.clamp(ratio,
    #                         1.0 - cfg.ppo.clip_param,
    #                         1.0 + cfg.ppo.clip_param)
    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
    #         accelerator.backward(loss)
    #         optimizer.step()
    #         optimizer.zero_grad()
    # ========================


    #     accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
    #     print(f"loss: {loss.item()}, epoch: {inner_epoch}")



    # trainer = Trainer(

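For reference, the reward path in graph_reward_fn above hinges on nodes_to_arch_str, which turns a sampled 8-node op sequence into a NAS-Bench-201 architecture string before the api.query_index_by_arch lookup. A standalone version (the op vocabulary is copied from the diff; the example input is illustrative, not from the repo):

# The 8-node sequence is assumed to be [input, op1, ..., op6, output],
# where each entry indexes NUM_TO_OP.
NUM_TO_OP = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
             'skip_connect', 'none', 'output']

def nodes_to_arch_str(nodes):
    ops = [NUM_TO_OP[n] for n in nodes]
    # NAS-Bench-201 cells have 3 node groups; '~k' names the input node index.
    return ('|' + ops[1] + '~0|+'
            '|' + ops[2] + '~0|' + ops[3] + '~1|+'
            '|' + ops[4] + '~0|' + ops[5] + '~1|' + ops[6] + '~2|')

# Illustrative example:
print(nodes_to_arch_str([0, 2, 1, 4, 3, 5, 2, 6]))
# -> |nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|avg_pool_3x3~0|none~1|nor_conv_3x3~2|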