Fix black
This commit is contained in:
		| @@ -73,7 +73,7 @@ def main(save_dir): | |||||||
|  |  | ||||||
|     additional_xaxis = np.arange(-6, 6, 0.2) |     additional_xaxis = np.arange(-6, 6, 0.2) | ||||||
|     models = dict() |     models = dict() | ||||||
|      |  | ||||||
|     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         xaxis_all = dataset[:, 0].numpy() |         xaxis_all = dataset[:, 0].numpy() | ||||||
|         # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) |         # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) | ||||||
| @@ -84,15 +84,19 @@ def main(save_dir): | |||||||
|         # split the dataset |         # split the dataset | ||||||
|         indexes = list(range(xaxis_all.shape[0])) |         indexes = list(range(xaxis_all.shape[0])) | ||||||
|         random.shuffle(indexes) |         random.shuffle(indexes) | ||||||
|         train_indexes = indexes[:len(indexes)//2] |         train_indexes = indexes[: len(indexes) // 2] | ||||||
|         valid_indexes = indexes[len(indexes)//2:] |         valid_indexes = indexes[len(indexes) // 2 :] | ||||||
|         train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] |         train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] | ||||||
|         valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] |         valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] | ||||||
|          |  | ||||||
|         model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) |         model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) | ||||||
|         # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) |         # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) | ||||||
|         pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) |         pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) | ||||||
|         print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss)) |         print( | ||||||
|  |             "[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format( | ||||||
|  |                 idx, timestamp, train_loss, valid_loss | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         # the first plot |         # the first plot | ||||||
|         scatter_list = [] |         scatter_list = [] | ||||||
| @@ -114,10 +118,10 @@ def main(save_dir): | |||||||
|                 "color": "r", |                 "color": "r", | ||||||
|                 "s": 10, |                 "s": 10, | ||||||
|                 "alpha": 0.5, |                 "alpha": 0.5, | ||||||
|                 "label": "MLP at now" |                 "label": "MLP at now", | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|          |  | ||||||
|         draw_fig(save_dir, timestamp, scatter_list) |         draw_fig(save_dir, timestamp, scatter_list) | ||||||
|     print("Save all figures into {:}".format(save_dir)) |     print("Save all figures into {:}".format(save_dir)) | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|   | |||||||
| @@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|     def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: |     def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: | ||||||
|         if not isinstance(module, SuperModule): |         if not isinstance(module, SuperModule): | ||||||
|             warnings.warn( |             warnings.warn( | ||||||
|                 "Add {:} module, which is not SuperModule, into {:}".format( |                 "Add {:}:{:} module, which is not SuperModule, into {:}".format( | ||||||
|                     name, self.__class__.__name__ |                     name, module.__class__.__name__, self.__class__.__name__ | ||||||
|                 ) |                 ) | ||||||
|                 + "\n" |                 + "\n" | ||||||
|                 + "It may cause some functions invalid." |                 + "It may cause some functions invalid." | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user