Fix black issues
This commit is contained in:
		| @@ -161,7 +161,10 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, required=True, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
| @@ -170,10 +173,16 @@ if __name__ == "__main__": | |||||||
|         help="The initial learning rate for the optimizer (default is Adam)", |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--batch_size", type=int, default=512, help="The batch size", |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--epochs", type=int, default=1000, help="The total number of epochs.", |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--srange", type=str, required=True, help="The range of models to be evaluated" |         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||||
|   | |||||||
| @@ -41,7 +41,10 @@ class MAML: | |||||||
|         ) |         ) | ||||||
|         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|             self.meta_optimizer, |             self.meta_optimizer, | ||||||
|             milestones=[int(epochs * 0.8), int(epochs * 0.9),], |             milestones=[ | ||||||
|  |                 int(epochs * 0.8), | ||||||
|  |                 int(epochs * 0.9), | ||||||
|  |             ], | ||||||
|             gamma=0.1, |             gamma=0.1, | ||||||
|         ) |         ) | ||||||
|         self.inner_lr = inner_lr |         self.inner_lr = inner_lr | ||||||
| @@ -194,7 +197,10 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, default=16, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         default=16, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_lr", |         "--meta_lr", | ||||||
| @@ -224,10 +230,16 @@ if __name__ == "__main__": | |||||||
|         help="The gap between prev_time and current_timestamp", |         help="The gap between prev_time and current_timestamp", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_batch", type=int, default=64, help="The batch size for the meta-model", |         "--meta_batch", | ||||||
|  |         type=int, | ||||||
|  |         default=64, | ||||||
|  |         help="The batch size for the meta-model", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--epochs", type=int, default=2000, help="The total number of epochs.", |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=2000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--early_stop_thresh", |         "--early_stop_thresh", | ||||||
|   | |||||||
| @@ -149,7 +149,10 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, required=True, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
| @@ -164,10 +167,16 @@ if __name__ == "__main__": | |||||||
|         help="The gap between prev_time and current_timestamp", |         help="The gap between prev_time and current_timestamp", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--batch_size", type=int, default=512, help="The batch size", |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--epochs", type=int, default=300, help="The total number of epochs.", |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=300, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--workers", |         "--workers", | ||||||
|   | |||||||
| @@ -149,7 +149,10 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, required=True, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
| @@ -158,10 +161,16 @@ if __name__ == "__main__": | |||||||
|         help="The initial learning rate for the optimizer (default is Adam)", |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--batch_size", type=int, default=512, help="The batch size", |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--epochs", type=int, default=300, help="The total number of epochs.", |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=300, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--workers", |         "--workers", | ||||||
|   | |||||||
| @@ -62,7 +62,10 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|         optimizer, |         optimizer, | ||||||
|         milestones=[int(args.epochs * 0.8), int(args.epochs * 0.9),], |         milestones=[ | ||||||
|  |             int(args.epochs * 0.8), | ||||||
|  |             int(args.epochs * 0.9), | ||||||
|  |         ], | ||||||
|         gamma=0.1, |         gamma=0.1, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
| @@ -170,7 +173,10 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, required=True, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     ##### |     ##### | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -180,7 +186,10 @@ if __name__ == "__main__": | |||||||
|         help="The initial learning rate for the optimizer (default is Adam)", |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_batch", type=int, default=64, help="The batch size for the meta-model", |         "--meta_batch", | ||||||
|  |         type=int, | ||||||
|  |         default=64, | ||||||
|  |         help="The batch size for the meta-model", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--early_stop_thresh", |         "--early_stop_thresh", | ||||||
| @@ -189,13 +198,22 @@ if __name__ == "__main__": | |||||||
|         help="The maximum epochs for early stop.", |         help="The maximum epochs for early stop.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--epochs", type=int, default=2000, help="The total number of epochs.", |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=2000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--per_epoch_step", type=int, default=20, help="The total number of epochs.", |         "--per_epoch_step", | ||||||
|  |         type=int, | ||||||
|  |         default=20, | ||||||
|  |         help="The total number of epochs.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--device", type=str, default="cpu", help="", |         "--device", | ||||||
|  |         type=str, | ||||||
|  |         default="cpu", | ||||||
|  |         help="", | ||||||
|     ) |     ) | ||||||
|     # Random Seed |     # Random Seed | ||||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|   | |||||||
| @@ -101,7 +101,10 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|         optimizer, |         optimizer, | ||||||
|         milestones=[int(args.epochs * 0.8), int(args.epochs * 0.9),], |         milestones=[ | ||||||
|  |             int(args.epochs * 0.8), | ||||||
|  |             int(args.epochs * 0.9), | ||||||
|  |         ], | ||||||
|         gamma=0.1, |         gamma=0.1, | ||||||
|     ) |     ) | ||||||
|     logger.log("The base-model is\n{:}".format(base_model)) |     logger.log("The base-model is\n{:}".format(base_model)) | ||||||
| @@ -240,13 +243,22 @@ if __name__ == "__main__": | |||||||
|         help="The synthetic enviornment version.", |         help="The synthetic enviornment version.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--hidden_dim", type=int, default=16, help="The hidden dimension.", |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         default=16, | ||||||
|  |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--layer_dim", type=int, default=16, help="The layer chunk dimension.", |         "--layer_dim", | ||||||
|  |         type=int, | ||||||
|  |         default=16, | ||||||
|  |         help="The layer chunk dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--time_dim", type=int, default=16, help="The timestamp dimension.", |         "--time_dim", | ||||||
|  |         type=int, | ||||||
|  |         default=16, | ||||||
|  |         help="The timestamp dimension.", | ||||||
|     ) |     ) | ||||||
|     ##### |     ##### | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -262,7 +274,10 @@ if __name__ == "__main__": | |||||||
|         help="The weight decay for the optimizer (default is Adam)", |         help="The weight decay for the optimizer (default is Adam)", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--meta_batch", type=int, default=64, help="The batch size for the meta-model", |         "--meta_batch", | ||||||
|  |         type=int, | ||||||
|  |         default=64, | ||||||
|  |         help="The batch size for the meta-model", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--sampler_enlarge", |         "--sampler_enlarge", | ||||||
| @@ -284,7 +299,10 @@ if __name__ == "__main__": | |||||||
|         "--workers", type=int, default=4, help="The number of workers in parallel." |         "--workers", type=int, default=4, help="The number of workers in parallel." | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--device", type=str, default="cpu", help="", |         "--device", | ||||||
|  |         type=str, | ||||||
|  |         default="cpu", | ||||||
|  |         help="", | ||||||
|     ) |     ) | ||||||
|     # Random Seed |     # Random Seed | ||||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|   | |||||||
| @@ -75,7 +75,8 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|  |  | ||||||
|         # unknown token |         # unknown token | ||||||
|         self.register_parameter( |         self.register_parameter( | ||||||
|             "_unknown_token", torch.nn.Parameter(torch.Tensor(1, time_embedding)), |             "_unknown_token", | ||||||
|  |             torch.nn.Parameter(torch.Tensor(1, time_embedding)), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # initialization |         # initialization | ||||||
|   | |||||||
| @@ -164,9 +164,11 @@ def compare_cl(save_dir): | |||||||
|         ) |         ) | ||||||
|     print("Save all figures into {:}".format(save_dir)) |     print("Save all figures into {:}".format(save_dir)) | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format( |     base_cmd = ( | ||||||
|  |         "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format( | ||||||
|             xdir=save_dir |             xdir=save_dir | ||||||
|         ) |         ) | ||||||
|  |     ) | ||||||
|     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format( |     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format( | ||||||
|         base_cmd, xdir=save_dir |         base_cmd, xdir=save_dir | ||||||
|     ) |     ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user