Update scripts
This commit is contained in:
		| @@ -210,5 +210,23 @@ if __name__ == "__main__": | ||||
|             args.save_dir, | ||||
|             args.gpu, | ||||
|         ) | ||||
|     else: | ||||
|         print("-") | ||||
|     elif len(args.alg) > 1: | ||||
|         assert args.shared_dataset, "Must allow share dataset" | ||||
|         configs = [ | ||||
|             update_gpu(update_market(alg2configs[name], args.market), args.gpu) | ||||
|             for name in args.alg | ||||
|         ] | ||||
|         qlib.init(**configs[0].get("qlib_init")) | ||||
|         dataset_config = configs[0].get("task").get("dataset") | ||||
|         dataset = init_instance_by_config(dataset_config) | ||||
|         pprint(dataset_config) | ||||
|         pprint(dataset) | ||||
|         for alg_name, config in zip(args.alg, configs): | ||||
|             for irun in range(args.times): | ||||
|                 run_exp( | ||||
|                     config.get("task"), | ||||
|                     dataset, | ||||
|                     alg_name, | ||||
|                     "recorder-{:02d}-{:02d}".format(irun, args.times), | ||||
|                     "{:}-{:}".format(args.save_dir, args.market), | ||||
|                 ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user