can run, but need to test which .pth path is correct
@@ -282,7 +282,7 @@ def test(cfg: DictConfig):
 
     # Normal reward function
     from nas_201_api import NASBench201API as API
-    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
     def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
         if reward_model == 'swap':
@@ -308,9 +308,9 @@ def test(cfg: DictConfig):
                     reward = swap_scores[api.query_index_by_arch(arch_str)]
                     rewards.append(reward)
                 
-        for graph in graphs:
-            reward = 1.0
-            rewards.append(reward)
+        # for graph in graphs:
+        #     reward = 1.0
+        #     rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
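
The reward path above queries NAS-Bench-201 by architecture string and looks up a precomputed SWAP proxy score; the loop that gave every graph a constant reward of 1.0 is now commented out as a leftover placeholder. Below is a self-contained sketch of that lookup pattern. Only NASBench201API and its query_index_by_arch call are the real API; the swap_scores table, its placeholder values, and the graph_to_arch_str helper are assumptions for illustration.

import torch
from nas_201_api import NASBench201API as API

# Real NAS-Bench-201 API; the benchmark checkpoint path is machine-specific.
api = API('NAS-Bench-201-v1_1-096897.pth')

# Assumption: SWAP proxy scores precomputed offline, one per NAS-Bench-201
# architecture index (the benchmark has 15,625 architectures). The real
# source of swap_scores is outside this diff.
swap_scores = [0.0] * 15625  # hypothetical placeholder values

def graph_to_arch_str(graph):
    # Hypothetical helper: render a sampled graph as a NAS-Bench-201
    # architecture string, e.g. '|nor_conv_3x3~0|+|skip_connect~0|none~1|+|...'.
    return graph.arch_str

def graph_reward_fn(graphs, device=None, reward_model='swap'):
    # Look up each sampled graph's precomputed SWAP score as its reward.
    rewards = []
    if reward_model == 'swap':
        for graph in graphs:
            index = api.query_index_by_arch(graph_to_arch_str(graph))  # real API call
            rewards.append(swap_scores[index])
    # requires_grad=True mirrors the original code, although rewards
    # do not normally need gradients.
    return torch.tensor(rewards, dtype=torch.float32,
                        requires_grad=True).unsqueeze(0).to(device)
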
@@ -326,6 +326,8 @@ def test(cfg: DictConfig):
                                         keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+
 
         samples_with_log_probs.append((cur_sample, log_probs, reward))
         
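
The added advantages line is standard per-batch reward standardization from policy-gradient fine-tuning: center the batch of rewards at zero and scale it to unit variance, with a small epsilon keeping the division finite when all rewards in the batch are identical. A minimal sketch of the computation in isolation:

import torch

def normalize_advantages(reward: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Zero-mean, unit-variance rewards; eps guards against std == 0.
    return (reward - torch.mean(reward)) / (torch.std(reward) + eps)

# Example: raw rewards from one sampling round.
reward = torch.tensor([0.2, 0.5, 0.9, 0.4])
advantages = normalize_advantages(reward)
print(advantages)         # above-average samples positive, below-average negative
print(advantages.mean())  # ~0

Because the result is invariant to the scale of the raw rewards, rescaling the SWAP scores would leave the subsequent gradient step unchanged.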