Fix bugs
This commit is contained in:
		| @@ -74,8 +74,12 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|         # train the same data |         # train the same data | ||||||
|         assert idx != 0 |         assert idx != 0 | ||||||
|         historical_x = env_info["{:}-x".format(idx)] |         historical_x, historical_y = [], [] | ||||||
|         historical_y = env_info["{:}-y".format(idx)] |         for past_i in range(idx): | ||||||
|  |             historical_x.append(env_info["{:}-x".format(past_i)]) | ||||||
|  |             historical_y.append(env_info["{:}-y".format(past_i)]) | ||||||
|  |         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) | ||||||
|  |         historical_x, historical_y = subsample(historical_x, historical_y) | ||||||
|         # build model |         # build model | ||||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() |         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
| @@ -153,7 +157,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--save_dir", |         "--save_dir", | ||||||
|         type=str, |         type=str, | ||||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", |         default="./outputs/lfna-synthetic/use-all-past-data", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|   | |||||||
| @@ -74,12 +74,8 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|         # train the same data |         # train the same data | ||||||
|         assert idx != 0 |         assert idx != 0 | ||||||
|         historical_x, historical_y = [], [] |         historical_x = env_info["{:}-x".format(idx)] | ||||||
|         for past_i in range(idx): |         historical_y = env_info["{:}-y".format(idx)] | ||||||
|             historical_x.append(env_info["{:}-x".format(past_i)]) |  | ||||||
|             historical_y.append(env_info["{:}-y".format(past_i)]) |  | ||||||
|         historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) |  | ||||||
|         historical_x, historical_y = subsample(historical_x, historical_y) |  | ||||||
|         # build model |         # build model | ||||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() |         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
| @@ -153,11 +149,11 @@ def main(args): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     parser = argparse.ArgumentParser("Use data at the same timestamp.") |     parser = argparse.ArgumentParser("Use the data in the past.") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--save_dir", |         "--save_dir", | ||||||
|         type=str, |         type=str, | ||||||
|         default="./outputs/lfna-synthetic/use-all-past-data", |         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user