Re-org baselines
This commit is contained in:
		| @@ -15,9 +15,11 @@ | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg Transformer# | ||||
| # python exps/trading/baselines.py --alg TSF-A      # | ||||
| # python exps/trading/baselines.py --alg TSF        # | ||||
| # python exps/trading/baselines.py --alg TSF-4x64-d0 | ||||
| ##################################################### | ||||
| import sys | ||||
| import copy | ||||
| import argparse | ||||
| from collections import OrderedDict | ||||
| from pathlib import Path | ||||
| @@ -38,6 +40,33 @@ from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
|  | ||||
|  | ||||
| def to_pos_drop(config, value): | ||||
|     config = copy.deepcopy(config) | ||||
|     net = config["task"]["model"]["kwargs"]["net_config"] | ||||
|     net["pos_drop"] = value | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def to_layer(config, embed_dim, depth): | ||||
|     config = copy.deepcopy(config) | ||||
|     net = config["task"]["model"]["kwargs"]["net_config"] | ||||
|     net["embed_dim"] = embed_dim | ||||
|     net["num_heads"] = [4] * depth | ||||
|     net["mlp_hidden_multipliers"] = [4] * depth | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def extend_transformer_settings(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     for i in range(6): | ||||
|         for j in [24, 32, 48, 64]: | ||||
|             for k in [0, 0.1]: | ||||
|                 alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( | ||||
|                     to_pos_drop(config, k), j, i | ||||
|                 ) | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def retrieve_configs(): | ||||
|     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ | ||||
|     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() | ||||
| @@ -60,29 +89,28 @@ def retrieve_configs(): | ||||
|     alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml" | ||||
|     alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml" | ||||
|     alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml" | ||||
|     alg2names["TSF-A"] = "workflow_config_transformer_basic_Alpha360.yaml" | ||||
|     alg2names["TSF"] = "workflow_config_transformer_basic_Alpha360.yaml" | ||||
|  | ||||
|     # find the yaml paths | ||||
|     alg2paths = OrderedDict() | ||||
|     alg2configs = OrderedDict() | ||||
|     print("Start retrieving the algorithm configurations") | ||||
|     for idx, (alg, name) in enumerate(alg2names.items()): | ||||
|         path = config_dir / name | ||||
|         assert path.exists(), "{:} does not exist.".format(path) | ||||
|         alg2paths[alg] = str(path) | ||||
|         with open(path) as fp: | ||||
|             alg2configs[alg] = yaml.safe_load(fp) | ||||
|         print( | ||||
|             "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( | ||||
|                 idx, len(alg2names), alg, path | ||||
|                 idx, len(alg2configs), alg, path | ||||
|             ) | ||||
|         ) | ||||
|     return alg2paths | ||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF-A") | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def main(xargs, exp_yaml): | ||||
|     assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml) | ||||
| def main(xargs, config): | ||||
|  | ||||
|     pprint("Run {:}".format(xargs.alg)) | ||||
|     with open(exp_yaml) as fp: | ||||
|         config = yaml.safe_load(fp) | ||||
|     config = update_market(config, xargs.market) | ||||
|     config = update_gpu(config, xargs.gpu) | ||||
|  | ||||
| @@ -105,7 +133,7 @@ def main(xargs, exp_yaml): | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     alg2paths = retrieve_configs() | ||||
|     alg2configs = retrieve_configs() | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Baselines") | ||||
|     parser.add_argument( | ||||
| @@ -121,7 +149,7 @@ if __name__ == "__main__": | ||||
|         choices=["csi100", "csi300", "all"], | ||||
|         help="The market indicator.", | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--times", type=int, default=5, help="The repeated run times.") | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
| @@ -134,4 +162,4 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args, alg2paths[args.alg]) | ||||
|     main(args, alg2configs[args.alg]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user