Update LFNA
This commit is contained in:
		| @@ -100,9 +100,15 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         weight_decay=args.weight_decay, |         weight_decay=args.weight_decay, | ||||||
|         amsgrad=True, |         amsgrad=True, | ||||||
|     ) |     ) | ||||||
|  |     logger.log("Pre-train the meta-model") | ||||||
|  |     logger.log("Using the optimizer: {:}".format(optimizer)) | ||||||
|  |  | ||||||
|     meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") |     meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain") | ||||||
|  |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|     for iepoch in range(args.epochs): |     for iepoch in range(args.epochs): | ||||||
|  |         left_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|  |         ) | ||||||
|         total_meta_losses, total_match_losses = [], [] |         total_meta_losses, total_match_losses = [], [] | ||||||
|         for ibatch in range(args.meta_batch): |         for ibatch in range(args.meta_batch): | ||||||
|             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) |             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1) | ||||||
| @@ -151,7 +157,11 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|                 final_match_loss.item(), |                 final_match_loss.item(), | ||||||
|             ) |             ) | ||||||
|             + ", batch={:}".format(len(total_meta_losses)) |             + ", batch={:}".format(len(total_meta_losses)) | ||||||
|  |             + ", success={:}, best_score={:.4f}".format(success, -best_score) | ||||||
|  |             + " {:}".format(left_time) | ||||||
|         ) |         ) | ||||||
|  |         per_epoch_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user