Update models
This commit is contained in:
		| @@ -99,7 +99,12 @@ def run_exp( | ||||
|  | ||||
|         # Train model | ||||
|         try: | ||||
|             model = R.load_object(model_obj_name) | ||||
|             if hasattr(model, "to"):  # Recoverable model | ||||
|                 device = model.device | ||||
|                 model = R.load_object(model_obj_name) | ||||
|                 model.to(device) | ||||
|             else: | ||||
|                 model = R.load_object(model_obj_name) | ||||
|             logger.info("[Find existing object from {:}]".format(model_obj_name)) | ||||
|         except OSError: | ||||
|             R.log_params(**flatten_dict(task_config)) | ||||
| @@ -112,16 +117,29 @@ def run_exp( | ||||
|                     recorder_root_dir, "model-ckps" | ||||
|                 ) | ||||
|             model.fit(**model_fit_kwargs) | ||||
|             R.save_objects(**{model_obj_name: model}) | ||||
|         except: | ||||
|             raise ValueError("Something wrong.") | ||||
|             # remove model to CPU for saving | ||||
|             if hasattr(model, "to"): | ||||
|                 model.to("cpu") | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|                 model.to() | ||||
|             else: | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|         except Exception as e: | ||||
|             import pdb | ||||
|  | ||||
|             pdb.set_trace() | ||||
|             raise ValueError("Something wrong: {:}".format(e)) | ||||
|         # Get the recorder | ||||
|         recorder = R.get_recorder() | ||||
|  | ||||
|         # Generate records: prediction, backtest, and analysis | ||||
|         for record in task_config["record"]: | ||||
|             record = deepcopy(record) | ||||
|             if record["class"] == "SignalRecord": | ||||
|             if record["class"] == "MultiSegRecord": | ||||
|                 record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate(**record["generate_kwargs"]) | ||||
|             elif record["class"] == "SignalRecord": | ||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||
|                 record["kwargs"].update(srconf) | ||||
|                 sr = init_instance_by_config(record) | ||||
|   | ||||
| @@ -112,6 +112,12 @@ class QuantTransformer(Model): | ||||
|     def use_gpu(self): | ||||
|         return self.device != torch.device("cpu") | ||||
|  | ||||
|     def to(self, device): | ||||
|         if device is None: | ||||
|             self.model.to(self.device) | ||||
|         else: | ||||
|             self.model.to("cpu") | ||||
|  | ||||
|     def loss_fn(self, pred, label): | ||||
|         mask = ~torch.isnan(label) | ||||
|         if self.opt_config["loss"] == "mse": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user