diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index fca229f..c8b369b 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -74,8 +74,12 @@ def main(args): ) # train the same data assert idx != 0 - historical_x = env_info["{:}-x".format(idx)] - historical_y = env_info["{:}-y".format(idx)] + historical_x, historical_y = [], [] + for past_i in range(idx): + historical_x.append(env_info["{:}-x".format(past_i)]) + historical_y.append(env_info["{:}-y".format(past_i)]) + historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) + historical_x, historical_y = subsample(historical_x, historical_y) # build model mean, std = historical_x.mean().item(), historical_x.std().item() model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) @@ -153,7 +157,7 @@ if __name__ == "__main__": parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-same-timestamp", + default="./outputs/lfna-synthetic/use-all-past-data", help="The checkpoint directory.", ) parser.add_argument( diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py index 578c6f7..0a889a9 100644 --- a/exps/LFNA/basic-same.py +++ b/exps/LFNA/basic-same.py @@ -74,12 +74,8 @@ def main(args): ) # train the same data assert idx != 0 - historical_x, historical_y = [], [] - for past_i in range(idx): - historical_x.append(env_info["{:}-x".format(past_i)]) - historical_y.append(env_info["{:}-y".format(past_i)]) - historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y) - historical_x, historical_y = subsample(historical_x, historical_y) + historical_x = env_info["{:}-x".format(idx)] + historical_y = env_info["{:}-y".format(idx)] # build model mean, std = historical_x.mean().item(), historical_x.std().item() model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) @@ -153,11 +149,11 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser("Use data at the same timestamp.") + parser = argparse.ArgumentParser("Use the data in the past.") parser.add_argument( "--save_dir", type=str, - default="./outputs/lfna-synthetic/use-all-past-data", + default="./outputs/lfna-synthetic/use-same-timestamp", help="The checkpoint directory.", ) parser.add_argument(