Update LFNA

This commit is contained in:
D-X-Y 2021-05-17 12:01:58 +00:00
parent 5c851ac25a
commit 85f7f1a400

View File

@ -132,8 +132,8 @@ def main(args):
) )
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, optimizer,
milestones=[1, 2], milestones=[1, 2, 3, 4, 5],
gamma=0.1, gamma=0.2,
) )
logger.log("The base-model is\n{:}".format(base_model)) logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model)) logger.log("The meta-model is\n{:}".format(meta_model))
@ -223,11 +223,12 @@ def main(args):
logger, logger,
) )
if iepoch - last_success_epoch >= args.early_stop_thresh: if iepoch - last_success_epoch >= args.early_stop_thresh:
if lr_scheduler.last_epoch > 2: if lr_scheduler.last_epoch > 4:
logger.log("Early stop at {:}".format(iepoch)) logger.log("Early stop at {:}".format(iepoch))
break break
else: else:
last_epoch.step() last_success_epoch = iepoch
lr_scheduler.step()
logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch)) logger.log("Decay the lr [{:}]".format(lr_scheduler.last_epoch))
per_epoch_time.update(time.time() - start_time) per_epoch_time.update(time.time() - start_time)
@ -375,7 +376,7 @@ if __name__ == "__main__":
help="The #epochs for early stop.", help="The #epochs for early stop.",
) )
parser.add_argument( parser.add_argument(
"--seq_length", type=int, default=5, help="The sequence length." "--seq_length", type=int, default=10, help="The sequence length."
) )
parser.add_argument( parser.add_argument(
"--workers", type=int, default=4, help="The number of workers in parallel." "--workers", type=int, default=4, help="The number of workers in parallel."
@ -392,11 +393,12 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None" assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-d{:}_{:}_{:}-lr{:}-wd{:}-e{:}-env{:}".format( args.save_dir = "{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir, args.save_dir,
args.hidden_dim, args.hidden_dim,
args.layer_dim, args.layer_dim,
args.time_dim, args.time_dim,
args.seq_length,
args.lr, args.lr,
args.weight_decay, args.weight_decay,
args.epochs, args.epochs,