Update LFNA
@@ -132,8 +132,8 @@ def main(args):
     )
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer,
-        milestones=[1, 2],
-        gamma=0.1,
+        milestones=[1, 2, 3, 4, 5],
+        gamma=0.2,
     )
     logger.log("The base-model is\n{:}".format(base_model))
     logger.log("The meta-model is\n{:}".format(meta_model))
@@ -223,11 +223,12 @@ def main(args):
                 logger,
             )
         if iepoch - last_success_epoch >= args.early_stop_thresh:
-            if lr_scheduler.last_epoch > 2:
+            if lr_scheduler.last_epoch > 4:
                 logger.log("Early stop at {:}".format(iepoch))
                 break
             else:
-                last_epoch.step()
+                last_success_epoch = iepoch
+                lr_scheduler.step()
                 logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch))

         per_epoch_time.update(time.time() - start_time)
@@ -375,7 +376,7 @@ if __name__ == "__main__":
         help="The #epochs for early stop.",
     )
     parser.add_argument(
-        "--seq_length", type=int, default=5, help="The sequence length."
+        "--seq_length", type=int, default=10, help="The sequence length."
     )
     parser.add_argument(
         "--workers", type=int, default=4, help="The number of workers in parallel."
@@ -392,11 +393,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format(
+    args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
         args.save_dir,
         args.hidden_dim,
         args.layer_dim,
         args.time_dim,
+        args.seq_length,
         args.lr,
         args.weight_decay,
         args.epochs,