Update LFNA
@@ -94,7 +94,10 @@ def main(args):
     )
 
     optimizer = torch.optim.Adam(
-        meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True
+        meta_model.parameters(),
+        lr=args.init_lr,
+        weight_decay=args.weight_decay,
+        amsgrad=True,
     )
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
@@ -137,7 +140,7 @@ def main(args):
         )
         success, best_score = meta_model.save_best(-loss_meter.avg)
         if success:
-            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
+            logger.log("Achieve the best with best-score = {:.5f}".format(best_score))
             last_success_epoch = iepoch
             save_checkpoint(
                 {
@@ -262,6 +265,12 @@ if __name__ == "__main__":
         default=0.005,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.00001,
+        help="The weight decay for the optimizer (default is Adam)",
+    )
     parser.add_argument(
         "--meta_batch",
         type=int,
@@ -274,11 +283,11 @@ if __name__ == "__main__":
         default=5,
         help="Enlarge the #iterations for an epoch",
     )
-    parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.")
+    parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
     parser.add_argument(
         "--early_stop_thresh",
         type=int,
-        default=25,
+        default=100,
         help="The maximum epochs for early stop.",
     )
     parser.add_argument(
@@ -299,11 +308,13 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
+    args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
        args.save_dir,
        args.hidden_dim,
        args.layer_dim,
        args.time_dim,
+       args.init_lr,
+       args.weight_decay,
        args.epochs,
        args.env_version,
     )
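For context, a minimal sketch of how the changed pieces fit together: the new --weight_decay flag replaces the hard-coded 1e-5 in the Adam call and is also baked into the save-dir name. Only the argument names, defaults, and format string come from the hunks above; the model and save-dir default below are stand-ins, not the actual LFNA code.

# Sketch only: shows the new --weight_decay flag flowing into the optimizer
# and the save-dir name, mirroring the hunks above.
import argparse

import torch
import torch.nn as nn

parser = argparse.ArgumentParser("LFNA weight-decay sketch")
parser.add_argument("--init_lr", type=float, default=0.005)
parser.add_argument("--weight_decay", type=float, default=0.00001)
parser.add_argument("--epochs", type=int, default=10000)
parser.add_argument("--save_dir", type=str, default="./outputs/lfna")  # hypothetical default
args = parser.parse_args([])  # use defaults here; a real run parses sys.argv[1:]

meta_model = nn.Linear(16, 16)  # stand-in for the real meta model
optimizer = torch.optim.Adam(
    meta_model.parameters(),
    lr=args.init_lr,
    weight_decay=args.weight_decay,  # was hard-coded to 1e-5 before this commit
    amsgrad=True,
)

# The save dir now also encodes the learning rate and weight decay.
save_dir = "{:}-lr{:}-wd{:}-e{:}".format(
    args.save_dir, args.init_lr, args.weight_decay, args.epochs
)
print(save_dir)  # e.g. ./outputs/lfna-lr0.005-wd1e-05-e10000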