Fix black
This commit is contained in:
		| @@ -73,7 +73,7 @@ def main(save_dir): | ||||
|  | ||||
|     additional_xaxis = np.arange(-6, 6, 0.2) | ||||
|     models = dict() | ||||
|      | ||||
|  | ||||
|     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         xaxis_all = dataset[:, 0].numpy() | ||||
|         # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) | ||||
| @@ -84,15 +84,19 @@ def main(save_dir): | ||||
|         # split the dataset | ||||
|         indexes = list(range(xaxis_all.shape[0])) | ||||
|         random.shuffle(indexes) | ||||
|         train_indexes = indexes[:len(indexes)//2] | ||||
|         valid_indexes = indexes[len(indexes)//2:] | ||||
|         train_indexes = indexes[: len(indexes) // 2] | ||||
|         valid_indexes = indexes[len(indexes) // 2 :] | ||||
|         train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] | ||||
|         valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] | ||||
|          | ||||
|  | ||||
|         model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) | ||||
|         # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) | ||||
|         pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) | ||||
|         print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss)) | ||||
|         print( | ||||
|             "[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format( | ||||
|                 idx, timestamp, train_loss, valid_loss | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # the first plot | ||||
|         scatter_list = [] | ||||
| @@ -114,10 +118,10 @@ def main(save_dir): | ||||
|                 "color": "r", | ||||
|                 "s": 10, | ||||
|                 "alpha": 0.5, | ||||
|                 "label": "MLP at now" | ||||
|                 "label": "MLP at now", | ||||
|             } | ||||
|         ) | ||||
|          | ||||
|  | ||||
|         draw_fig(save_dir, timestamp, scatter_list) | ||||
|     print("Save all figures into {:}".format(save_dir)) | ||||
|     save_dir = save_dir.resolve() | ||||
|   | ||||
| @@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module): | ||||
|     def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: | ||||
|         if not isinstance(module, SuperModule): | ||||
|             warnings.warn( | ||||
|                 "Add {:} module, which is not SuperModule, into {:}".format( | ||||
|                     name, self.__class__.__name__ | ||||
|                 "Add {:}:{:} module, which is not SuperModule, into {:}".format( | ||||
|                     name, module.__class__.__name__, self.__class__.__name__ | ||||
|                 ) | ||||
|                 + "\n" | ||||
|                 + "It may cause some functions invalid." | ||||
|   | ||||
		Reference in New Issue
	
	Block a user