Fix bugs
This commit is contained in:
parent
184f2326bb
commit
f7c2bb5e32
@ -74,8 +74,12 @@ def main(args):
|
|||||||
)
|
)
|
||||||
# train the same data
|
# train the same data
|
||||||
assert idx != 0
|
assert idx != 0
|
||||||
historical_x = env_info["{:}-x".format(idx)]
|
historical_x, historical_y = [], []
|
||||||
historical_y = env_info["{:}-y".format(idx)]
|
for past_i in range(idx):
|
||||||
|
historical_x.append(env_info["{:}-x".format(past_i)])
|
||||||
|
historical_y.append(env_info["{:}-y".format(past_i)])
|
||||||
|
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
||||||
|
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||||
# build model
|
# build model
|
||||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
||||||
@ -153,7 +157,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_dir",
|
"--save_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./outputs/lfna-synthetic/use-same-timestamp",
|
default="./outputs/lfna-synthetic/use-all-past-data",
|
||||||
help="The checkpoint directory.",
|
help="The checkpoint directory.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -74,12 +74,8 @@ def main(args):
|
|||||||
)
|
)
|
||||||
# train the same data
|
# train the same data
|
||||||
assert idx != 0
|
assert idx != 0
|
||||||
historical_x, historical_y = [], []
|
historical_x = env_info["{:}-x".format(idx)]
|
||||||
for past_i in range(idx):
|
historical_y = env_info["{:}-y".format(idx)]
|
||||||
historical_x.append(env_info["{:}-x".format(past_i)])
|
|
||||||
historical_y.append(env_info["{:}-y".format(past_i)])
|
|
||||||
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
|
|
||||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
|
||||||
# build model
|
# build model
|
||||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
||||||
@ -153,11 +149,11 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("Use data at the same timestamp.")
|
parser = argparse.ArgumentParser("Use the data in the past.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_dir",
|
"--save_dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="./outputs/lfna-synthetic/use-all-past-data",
|
default="./outputs/lfna-synthetic/use-same-timestamp",
|
||||||
help="The checkpoint directory.",
|
help="The checkpoint directory.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
Loading…
Reference in New Issue
Block a user