Fix test bugs
@@ -82,7 +82,14 @@ def main(args):
         historical_x, historical_y = subsample(historical_x, historical_y)
         # build model
         mean, std = historical_x.mean().item(), historical_x.std().item()
-        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
+        model_kwargs = dict(
+            input_dim=1,
+            output_dim=1,
+            act_cls="leaky_relu",
+            norm_cls="simple_norm",
+            mean=mean,
+            std=std,
+        )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)

@@ -78,7 +78,14 @@ def main(args):
         historical_y = env_info["{:}-y".format(idx)]
         # build model
         mean, std = historical_x.mean().item(), historical_x.std().item()
-        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
+        model_kwargs = dict(
+            input_dim=1,
+            output_dim=1,
+            act_cls="leaky_relu",
+            norm_cls="simple_norm",
+            mean=mean,
+            std=std,
+        )
         model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
         # build optimizer
         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)

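Both scripts now spell out the full set of keyword arguments, passing the activation and normalization choices by name ("leaky_relu", "simple_norm") instead of relying on defaults. A rough sketch of the string-keyed dispatch this implies; the mapping and helper below are hypothetical, and the real models.xcore.get_model internals are not shown in this diff:

import torch.nn as nn

# Hypothetical name-to-class table; the real lookup lives inside models.xcore.
_ACT_CLASSES = {"leaky_relu": nn.LeakyReLU, "relu": nn.ReLU}

def build_activation(act_cls: str) -> nn.Module:
    # Resolve the string key to an activation class and instantiate it.
    return _ACT_CLASSES[act_cls]()

activation = build_activation("leaky_relu")  # same key as act_cls above
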
@@ -24,6 +24,8 @@ from models.xcore import get_model
 
 
 class Population:
+    """A population used to maintain models at different timestamps."""
+
     def __init__(self):
         self._time2model = dict()
 

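The new docstring states the intent of Population: keep one model per timestamp in the _time2model mapping. A purely illustrative sketch of that pattern; the accessor names below are hypothetical and not part of this diff:

class TinyPopulation:
    """Illustrative stand-in: maps a timestamp to the model trained at that time."""

    def __init__(self):
        self._time2model = dict()

    def append(self, timestamp, model):
        # Store the model under its timestamp key (hypothetical helper).
        self._time2model[timestamp] = model

    def query(self, timestamp):
        # Return the model stored for this timestamp (hypothetical helper).
        return self._time2model[timestamp]
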
@@ -64,7 +64,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         model.apply_verbose(True)
 
         print(model.super_run_type)
-        self.assertTrue(model[1].bias)
+        self.assertTrue(model[2].bias)
 
         inputs = torch.rand(20, 10)
         print("Input shape: {:}".format(inputs.shape))
@@ -80,6 +80,6 @@ class TestSuperSimpleNorm(unittest.TestCase):
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
         model.apply_candidate(abstract_child)
 
-        output_shape = (20, abstract_child["1"]["_out_features"].value)
+        output_shape = (20, abstract_child["2"]["_out_features"].value)
         outputs = model(inputs)
         self.assertEqual(tuple(outputs.shape), output_shape)

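The test fixes swap index 1 for index 2 both when checking the linear layer's bias and when reading the search-space key, which is consistent with an extra module now sitting in front of that layer in the container. A plain-PyTorch analogy of the index shift (the actual super_core modules are not reproduced here):

import torch.nn as nn

# Inserting a module ahead of the linear layer moves the linear layer's
# positional index inside a sequential container from 1 to 2.
before = nn.Sequential(nn.Identity(), nn.Linear(10, 12))                   # linear at index 1
after = nn.Sequential(nn.Identity(), nn.LayerNorm(10), nn.Linear(10, 12))  # linear at index 2
assert isinstance(before[1], nn.Linear)
assert isinstance(after[2], nn.Linear)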