Sync qlib
@@ -68,7 +68,7 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
     model_fit_kwargs = dict(dataset=dataset)

     # Let's start the experiment.
-    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
+    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True):
         # Setup log
         recorder_root_dir = R.get_recorder().get_local_dir()
         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
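For context, the resume pattern added in this hunk looks roughly like the sketch below when used on its own. The experiment and recorder names are placeholders, and it assumes a qlib build whose R.start accepts resume (as the call in this diff does) and that qlib.init() has already been run.

    from qlib.workflow import R

    # Re-enter the same experiment/recorder instead of creating a new one, so
    # R.get_recorder().get_local_dir() points at the same directory across
    # restarts and a previously written checkpoint can be found there.
    with R.start(experiment_name="demo_exp", recorder_name="run-0", resume=True):
        ckp_dir = R.get_recorder().get_local_dir()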
@@ -81,7 +81,9 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
        # Train model
        R.log_params(**flatten_dict(task_config))
        if "save_path" in inspect.getfullargspec(model.fit).args:
-            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckps")
+            model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model.ckp")
+        elif "save_dir" in inspect.getfullargspec(model.fit).args:
+            model_fit_kwargs["save_dir"] = os.path.join(recorder_root_dir, "model-ckps")
        model.fit(**model_fit_kwargs)
        # Get the recorder
        recorder = R.get_recorder()

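The hunk above makes run_exp adapt to whichever checkpoint argument a model's fit exposes: models that take save_path receive a single file, while models that take save_dir (such as the QuantTransformer updated below) receive a directory. A standalone sketch of that dispatch, with a made-up DummyModel and made-up paths:

    import inspect
    import os

    class DummyModel:
        def fit(self, dataset, save_dir=None):
            ...

    model, model_fit_kwargs, root = DummyModel(), {}, "/tmp/recorder"
    fit_args = inspect.getfullargspec(model.fit).args
    if "save_path" in fit_args:      # older interface: a single checkpoint file
        model_fit_kwargs["save_path"] = os.path.join(root, "model.ckp")
    elif "save_dir" in fit_args:     # newer interface: a checkpoint directory
        model_fit_kwargs["save_dir"] = os.path.join(root, "model-ckps")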
@@ -138,7 +138,7 @@ class QuantTransformer(Model):
    def fit(
        self,
        dataset: DatasetH,
-        save_path: Optional[Text] = None,
+        save_dir: Optional[Text] = None,
    ):
        def _prepare_dataset(df_data):
            return th_data.TensorDataset(
@@ -172,8 +172,8 @@ class QuantTransformer(Model):
            _prepare_loader(test_dataset, False),
        )

-        save_path = get_or_create_path(save_path, return_dir=True)
-        self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
+        save_dir = get_or_create_path(save_dir, return_dir=True)
+        self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_dir))

        def _internal_test(ckp_epoch=None, results_dict=None):
            with torch.no_grad():
@@ -196,15 +196,18 @@ class QuantTransformer(Model):
                return dict(train=train_score, valid=valid_score, test=test_score), xstr

        # Pre-fetch the potential checkpoints
-        ckp_path = os.path.join(save_path, "{:}.pth".format(self.__class__.__name__))
+        ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
        if os.path.exists(ckp_path):
            ckp_data = torch.load(ckp_path)
            import pdb

            pdb.set_trace()
            stop_steps, best_score, best_epoch = ckp_data['stop_steps'], ckp_data['best_score'], ckp_data['best_epoch']
            start_epoch, best_param = ckp_data['start_epoch'], ckp_data['best_param']
            results_dict = ckp_data['results_dict']
            self.model.load_state_dict(ckp_data['net_state_dict'])
            self.train_optimizer.load_state_dict(ckp_data['opt_state_dict'])
            self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
        else:
            stop_steps, best_score, best_epoch = 0, -np.inf, -1
-            start_epoch = 0
+            start_epoch, best_param = 0, None
            results_dict = dict(train=OrderedDict(), valid=OrderedDict(), test=OrderedDict())
            _, eval_str = _internal_test(-1, results_dict)
            self.logger.info("Training from scratch, metrics@start: {:}".format(eval_str))
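Condensed into a standalone helper, the resume branch above amounts to the following sketch. The function name and signature are illustrative rather than part of the commit, the key names are taken from the diff, and the pdb.set_trace() left in the diff is debugger scaffolding that is omitted here.

    import os
    import torch

    def load_resume_checkpoint(ckp_path, model, optimizer):
        # Restore weights and optimizer state, and return the bookkeeping
        # needed to continue training; None means train from scratch.
        if not os.path.exists(ckp_path):
            return None
        ckp_data = torch.load(ckp_path)
        model.load_state_dict(ckp_data["net_state_dict"])
        optimizer.load_state_dict(ckp_data["opt_state_dict"])
        return dict(
            stop_steps=ckp_data["stop_steps"],
            best_score=ckp_data["best_score"],
            best_epoch=ckp_data["best_epoch"],
            start_epoch=ckp_data["start_epoch"],
            best_param=ckp_data["best_param"],
            results_dict=ckp_data["results_dict"],
        )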
@@ -215,7 +218,6 @@ class QuantTransformer(Model):
                    iepoch, self.opt_config["epochs"], best_epoch, best_score
                )
            )
-
            train_loss, train_score = self.train_or_test_epoch(
                train_loader, self.model, self.loss_fn, self.metric_fn, True, self.train_optimizer
            )
@@ -241,11 +243,14 @@ class QuantTransformer(Model):
                stop_steps=stop_steps,
                best_score=best_score,
                best_epoch=best_epoch,
                results_dict=results_dict,
                start_epoch=iepoch + 1,
            )
            torch.save(save_info, ckp_path)
        self.logger.info("The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch))
        self.model.load_state_dict(best_param)
        _, eval_str = _internal_test('final', results_dict)
        self.logger.info("Reload the best parameter :: {:}".format(eval_str))

        if self.use_gpu:
            torch.cuda.empty_cache()

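The per-epoch save in the last hunk is the other half of that round trip: results_dict and start_epoch=iepoch + 1 are now written alongside the state dicts so a restarted run can pick up at the next epoch. A sketch of the write side, again with an illustrative helper name and with the net_state_dict/opt_state_dict/best_param keys assumed from the resume branch above:

    import torch

    def save_resume_checkpoint(ckp_path, model, optimizer, best_param,
                               stop_steps, best_score, best_epoch,
                               results_dict, next_epoch):
        # Persist everything load_resume_checkpoint above reads back.
        save_info = dict(
            net_state_dict=model.state_dict(),
            opt_state_dict=optimizer.state_dict(),
            best_param=best_param,
            stop_steps=stop_steps,
            best_score=best_score,
            best_epoch=best_epoch,
            results_dict=results_dict,
            start_epoch=next_epoch,
        )
        torch.save(save_info, ckp_path)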