Re-org baselines
This commit is contained in:
		 Submodule .latent-data/qlib updated: ba56e4071e...3886022669
									
								
							| @@ -15,9 +15,11 @@ | |||||||
| # python exps/trading/baselines.py --alg TabNet     # | # python exps/trading/baselines.py --alg TabNet     # | ||||||
| #                                                   # | #                                                   # | ||||||
| # python exps/trading/baselines.py --alg Transformer# | # python exps/trading/baselines.py --alg Transformer# | ||||||
| # python exps/trading/baselines.py --alg TSF-A      # | # python exps/trading/baselines.py --alg TSF        # | ||||||
|  | # python exps/trading/baselines.py --alg TSF-4x64-d0 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys | import sys | ||||||
|  | import copy | ||||||
| import argparse | import argparse | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -38,6 +40,33 @@ from qlib.workflow import R | |||||||
| from qlib.utils import flatten_dict | from qlib.utils import flatten_dict | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def to_pos_drop(config, value): | ||||||
|  |     config = copy.deepcopy(config) | ||||||
|  |     net = config["task"]["model"]["kwargs"]["net_config"] | ||||||
|  |     net["pos_drop"] = value | ||||||
|  |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def to_layer(config, embed_dim, depth): | ||||||
|  |     config = copy.deepcopy(config) | ||||||
|  |     net = config["task"]["model"]["kwargs"]["net_config"] | ||||||
|  |     net["embed_dim"] = embed_dim | ||||||
|  |     net["num_heads"] = [4] * depth | ||||||
|  |     net["mlp_hidden_multipliers"] = [4] * depth | ||||||
|  |     return config | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def extend_transformer_settings(alg2configs, name): | ||||||
|  |     config = copy.deepcopy(alg2configs[name]) | ||||||
|  |     for i in range(6): | ||||||
|  |         for j in [24, 32, 48, 64]: | ||||||
|  |             for k in [0, 0.1]: | ||||||
|  |                 alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( | ||||||
|  |                     to_pos_drop(config, k), j, i | ||||||
|  |                 ) | ||||||
|  |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| def retrieve_configs(): | def retrieve_configs(): | ||||||
|     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ |     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ | ||||||
|     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() |     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() | ||||||
| @@ -60,29 +89,28 @@ def retrieve_configs(): | |||||||
|     alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml" |     alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml" | ||||||
|     alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml" |     alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml" | ||||||
|     alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml" |     alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml" | ||||||
|     alg2names["TSF-A"] = "workflow_config_transformer_basic_Alpha360.yaml" |     alg2names["TSF"] = "workflow_config_transformer_basic_Alpha360.yaml" | ||||||
|  |  | ||||||
|     # find the yaml paths |     # find the yaml paths | ||||||
|     alg2paths = OrderedDict() |     alg2configs = OrderedDict() | ||||||
|     print("Start retrieving the algorithm configurations") |     print("Start retrieving the algorithm configurations") | ||||||
|     for idx, (alg, name) in enumerate(alg2names.items()): |     for idx, (alg, name) in enumerate(alg2names.items()): | ||||||
|         path = config_dir / name |         path = config_dir / name | ||||||
|         assert path.exists(), "{:} does not exist.".format(path) |         assert path.exists(), "{:} does not exist.".format(path) | ||||||
|         alg2paths[alg] = str(path) |         with open(path) as fp: | ||||||
|  |             alg2configs[alg] = yaml.safe_load(fp) | ||||||
|         print( |         print( | ||||||
|             "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( |             "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( | ||||||
|                 idx, len(alg2names), alg, path |                 idx, len(alg2names), alg, path | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|     return alg2paths |     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||||
|  |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, exp_yaml): | def main(xargs, config): | ||||||
|     assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml) |  | ||||||
|  |  | ||||||
|     pprint("Run {:}".format(xargs.alg)) |     pprint("Run {:}".format(xargs.alg)) | ||||||
|     with open(exp_yaml) as fp: |  | ||||||
|         config = yaml.safe_load(fp) |  | ||||||
|     config = update_market(config, xargs.market) |     config = update_market(config, xargs.market) | ||||||
|     config = update_gpu(config, xargs.gpu) |     config = update_gpu(config, xargs.gpu) | ||||||
|  |  | ||||||
| @@ -105,7 +133,7 @@ def main(xargs, exp_yaml): | |||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  |  | ||||||
|     alg2paths = retrieve_configs() |     alg2configs = retrieve_configs() | ||||||
|  |  | ||||||
|     parser = argparse.ArgumentParser("Baselines") |     parser = argparse.ArgumentParser("Baselines") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
| @@ -121,7 +149,7 @@ if __name__ == "__main__": | |||||||
|         choices=["csi100", "csi300", "all"], |         choices=["csi100", "csi300", "all"], | ||||||
|         help="The market indicator.", |         help="The market indicator.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") |     parser.add_argument("--times", type=int, default=5, help="The repeated run times.") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." |         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||||
|     ) |     ) | ||||||
| @@ -134,4 +162,4 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     main(args, alg2paths[args.alg]) |     main(args, alg2configs[args.alg]) | ||||||
|   | |||||||
| @@ -126,7 +126,7 @@ def query_info(save_dir, verbose): | |||||||
|     experiments = R.list_experiments() |     experiments = R.list_experiments() | ||||||
|  |  | ||||||
|     key_map = { |     key_map = { | ||||||
|         "RMSE": "RMSE", |         # "RMSE": "RMSE", | ||||||
|         "IC": "IC", |         "IC": "IC", | ||||||
|         "ICIR": "ICIR", |         "ICIR": "ICIR", | ||||||
|         "Rank IC": "Rank_IC", |         "Rank IC": "Rank_IC", | ||||||
|   | |||||||
| @@ -18,7 +18,7 @@ class SuperLayerNorm1D(SuperModule): | |||||||
|     """Super Layer Norm.""" |     """Super Layer Norm.""" | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, dim: IntSpaceType, eps: float = 1e-5, elementwise_affine: bool = True |         self, dim: IntSpaceType, eps: float = 1e-6, elementwise_affine: bool = True | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         super(SuperLayerNorm1D, self).__init__() |         super(SuperLayerNorm1D, self).__init__() | ||||||
|         self._in_dim = dim |         self._in_dim = dim | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user