Fix test bugs
This commit is contained in:
parent
4c14c7b85b
commit
f6a024a6ff
@ -82,7 +82,14 @@ def main(args):
|
|||||||
historical_x, historical_y = subsample(historical_x, historical_y)
|
historical_x, historical_y = subsample(historical_x, historical_y)
|
||||||
# build model
|
# build model
|
||||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
model_kwargs = dict(
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=1,
|
||||||
|
act_cls="leaky_relu",
|
||||||
|
norm_cls="simple_norm",
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
)
|
||||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||||
# build optimizer
|
# build optimizer
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||||
|
@ -78,7 +78,14 @@ def main(args):
|
|||||||
historical_y = env_info["{:}-y".format(idx)]
|
historical_y = env_info["{:}-y".format(idx)]
|
||||||
# build model
|
# build model
|
||||||
mean, std = historical_x.mean().item(), historical_x.std().item()
|
mean, std = historical_x.mean().item(), historical_x.std().item()
|
||||||
model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
|
model_kwargs = dict(
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=1,
|
||||||
|
act_cls="leaky_relu",
|
||||||
|
norm_cls="simple_norm",
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
)
|
||||||
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
|
||||||
# build optimizer
|
# build optimizer
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||||
|
@ -24,6 +24,8 @@ from models.xcore import get_model
|
|||||||
|
|
||||||
|
|
||||||
class Population:
|
class Population:
|
||||||
|
"""A population used to maintain models at different timestamps."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._time2model = dict()
|
self._time2model = dict()
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
|
|||||||
model.apply_verbose(True)
|
model.apply_verbose(True)
|
||||||
|
|
||||||
print(model.super_run_type)
|
print(model.super_run_type)
|
||||||
self.assertTrue(model[1].bias)
|
self.assertTrue(model[2].bias)
|
||||||
|
|
||||||
inputs = torch.rand(20, 10)
|
inputs = torch.rand(20, 10)
|
||||||
print("Input shape: {:}".format(inputs.shape))
|
print("Input shape: {:}".format(inputs.shape))
|
||||||
@ -80,6 +80,6 @@ class TestSuperSimpleNorm(unittest.TestCase):
|
|||||||
model.set_super_run_type(super_core.SuperRunMode.Candidate)
|
model.set_super_run_type(super_core.SuperRunMode.Candidate)
|
||||||
model.apply_candidate(abstract_child)
|
model.apply_candidate(abstract_child)
|
||||||
|
|
||||||
output_shape = (20, abstract_child["1"]["_out_features"].value)
|
output_shape = (20, abstract_child["2"]["_out_features"].value)
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
self.assertEqual(tuple(outputs.shape), output_shape)
|
self.assertEqual(tuple(outputs.shape), output_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user