From 418be43566440d69d425daedcdcd572d6bcbf3bc Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 27 May 2021 13:03:59 +0000 Subject: [PATCH] Fix bugs --- exps/GeMOSA/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py index 6824179..199cf8f 100644 --- a/exps/GeMOSA/main.py +++ b/exps/GeMOSA/main.py @@ -204,11 +204,13 @@ def main(args): train_env = get_synthetic_env(mode="train", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version) trainval_env = get_synthetic_env(mode="trainval", version=args.env_version) + test_env = get_synthetic_env(mode="test", version=args.env_version) all_env = get_synthetic_env(mode=None, version=args.env_version) logger.log("The training enviornment: {:}".format(train_env)) logger.log("The validation enviornment: {:}".format(valid_env)) logger.log("The trainval enviornment: {:}".format(trainval_env)) logger.log("The total enviornment: {:}".format(all_env)) + logger.log("The test enviornment: {:}".format(test_env)) model_kwargs = dict( config=dict(model_type="norm_mlp"), input_dim=all_env.meta_info["input_dim"], @@ -268,10 +270,10 @@ def main(args): logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) """ _, loss_adapt_v1, metric_adapt_v1 = online_evaluate( - valid_env, meta_model, base_model, criterion, metric, args, logger, False, False + test_env, meta_model, base_model, criterion, metric, args, logger, False, False ) _, loss_adapt_v2, metric_adapt_v2 = online_evaluate( - valid_env, meta_model, base_model, criterion, metric, args, logger, False, True + test_env, meta_model, base_model, criterion, metric, args, logger, False, True ) logger.log( "[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(