diff --git a/exps/synthetic/baseline.py b/exps/synthetic/baseline.py index 09f3367..3f4f46c 100644 --- a/exps/synthetic/baseline.py +++ b/exps/synthetic/baseline.py @@ -73,7 +73,7 @@ def main(save_dir): additional_xaxis = np.arange(-6, 6, 0.2) models = dict() - + for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): xaxis_all = dataset[:, 0].numpy() # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) @@ -84,15 +84,19 @@ def main(save_dir): # split the dataset indexes = list(range(xaxis_all.shape[0])) random.shuffle(indexes) - train_indexes = indexes[:len(indexes)//2] - valid_indexes = indexes[len(indexes)//2:] + train_indexes = indexes[: len(indexes) // 2] + valid_indexes = indexes[len(indexes) // 2 :] train_xs, train_ys = xaxis_all[train_indexes], yaxis_all[train_indexes] valid_xs, valid_ys = xaxis_all[valid_indexes], yaxis_all[valid_indexes] - + model, loss_fn, train_loss = optimize_fn(train_xs, train_ys) # model, loss_fn, train_loss = optimize_fn(xaxis_all, yaxis_all) pred_valid_ys, valid_loss = evaluate_fn(model, valid_xs, valid_ys, loss_fn) - print("[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format(idx, timestamp, train_loss, valid_loss)) + print( + "[{:03d}] T-{:03d}, train-loss={:.5f}, valid-loss={:.5f}".format( + idx, timestamp, train_loss, valid_loss + ) + ) # the first plot scatter_list = [] @@ -114,10 +118,10 @@ def main(save_dir): "color": "r", "s": 10, "alpha": 0.5, - "label": "MLP at now" + "label": "MLP at now", } ) - + draw_fig(save_dir, timestamp, scatter_list) print("Save all figures into {:}".format(save_dir)) save_dir = save_dir.resolve() diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index df77fc3..9004f34 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -49,8 +49,8 @@ class SuperModule(abc.ABC, nn.Module): def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: if not isinstance(module, SuperModule): warnings.warn( - "Add {:} module, which is not SuperModule, into {:}".format( - name, self.__class__.__name__ + "Add {:}:{:} module, which is not SuperModule, into {:}".format( + name, module.__class__.__name__, self.__class__.__name__ ) + "\n" + "It may cause some functions invalid."