Add name filters for exp-org
This commit is contained in:
		| @@ -100,9 +100,9 @@ def run_exp( | ||||
|         # Train model | ||||
|         try: | ||||
|             if hasattr(model, "to"):  # Recoverable model | ||||
|                 device = model.device | ||||
|                 ori_device = model.device | ||||
|                 model = R.load_object(model_obj_name) | ||||
|                 model.to(device) | ||||
|                 model.to(ori_device) | ||||
|             else: | ||||
|                 model = R.load_object(model_obj_name) | ||||
|             logger.info("[Find existing object from {:}]".format(model_obj_name)) | ||||
| @@ -119,9 +119,10 @@ def run_exp( | ||||
|             model.fit(**model_fit_kwargs) | ||||
|             # remove model to CPU for saving | ||||
|             if hasattr(model, "to"): | ||||
|                 old_device = model.device | ||||
|                 model.to("cpu") | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|                 model.to() | ||||
|                 model.to(old_device) | ||||
|             else: | ||||
|                 R.save_objects(**{model_obj_name: model}) | ||||
|         except Exception as e: | ||||
|   | ||||
| @@ -114,9 +114,9 @@ class QuantTransformer(Model): | ||||
|  | ||||
|     def to(self, device): | ||||
|         if device is None: | ||||
|             self.model.to(self.device) | ||||
|         else: | ||||
|             self.model.to("cpu") | ||||
|             device = "cpu" | ||||
|         self.device = device | ||||
|         self.model.to(self.device) | ||||
|  | ||||
|     def loss_fn(self, pred, label): | ||||
|         mask = ~torch.isnan(label) | ||||
| @@ -227,7 +227,7 @@ class QuantTransformer(Model): | ||||
|         # Pre-fetch the potential checkpoints | ||||
|         ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__)) | ||||
|         if os.path.exists(ckp_path): | ||||
|             ckp_data = torch.load(ckp_path) | ||||
|             ckp_data = torch.load(ckp_path, map_location=self.device) | ||||
|             stop_steps, best_score, best_epoch = ( | ||||
|                 ckp_data["stop_steps"], | ||||
|                 ckp_data["best_score"], | ||||
| @@ -298,7 +298,7 @@ class QuantTransformer(Model): | ||||
|                 results_dict=results_dict, | ||||
|                 start_epoch=iepoch + 1, | ||||
|             ) | ||||
|             torch.save(save_info, ckp_path) | ||||
|             torch.save(save_info, ckp_path, map_location="cpu") | ||||
|         self.logger.info( | ||||
|             "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user