Add early stop
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 20000 --init_lr 0.01 | # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 | ||||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda | # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| @@ -76,6 +76,7 @@ def main(args): | |||||||
|     # LFNA meta-training |     # LFNA meta-training | ||||||
|     loss_meter = AverageMeter() |     loss_meter = AverageMeter() | ||||||
|     per_epoch_time, start_time = AverageMeter(), time.time() |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|  |     last_success = 0 | ||||||
|     for iepoch in range(args.epochs): |     for iepoch in range(args.epochs): | ||||||
|  |  | ||||||
|         need_time = "Time Left: {:}".format( |         need_time = "Time Left: {:}".format( | ||||||
| @@ -108,6 +109,13 @@ def main(args): | |||||||
|         lr_scheduler.step() |         lr_scheduler.step() | ||||||
|  |  | ||||||
|         loss_meter.update(final_loss.item()) |         loss_meter.update(final_loss.item()) | ||||||
|  |         success, best_score = hypernet.save_best(-loss_meter.val) | ||||||
|  |         if success: | ||||||
|  |             logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) | ||||||
|  |             last_success = iepoch | ||||||
|  |         if iepoch - last_success >= args.early_stop_thresh: | ||||||
|  |             logger.log("Early stop at {:}".format(iepoch)) | ||||||
|  |             break | ||||||
|         if iepoch % 20 == 0: |         if iepoch % 20 == 0: | ||||||
|             logger.log( |             logger.log( | ||||||
|                 head_str |                 head_str | ||||||
| @@ -119,11 +127,6 @@ def main(args): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             success, best_score = hypernet.save_best(-loss_meter.avg) |  | ||||||
|             if success: |  | ||||||
|                 logger.log( |  | ||||||
|                     "Achieve the best with best_score = {:.3f}".format(best_score) |  | ||||||
|                 ) |  | ||||||
|             save_checkpoint( |             save_checkpoint( | ||||||
|                 { |                 { | ||||||
|                     "hypernet": hypernet.state_dict(), |                     "hypernet": hypernet.state_dict(), | ||||||
| @@ -192,6 +195,12 @@ if __name__ == "__main__": | |||||||
|         required=True, |         required=True, | ||||||
|         help="The hidden dimension.", |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--early_stop_thresh", | ||||||
|  |         type=int, | ||||||
|  |         default=100, | ||||||
|  |         help="The maximum epochs for early stop.", | ||||||
|  |     ) | ||||||
|     ##### |     ##### | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user