#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################

import os
from pathlib import Path
import abc
import tempfile
import warnings
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from enum import Enum

from xautodl import spaces
from .super_utils import IntSpaceType, BoolSpaceType
from .super_utils import LayerOrder, SuperRunMode
from .super_utils import TensorContainer
from .super_utils import ShapeContainer

# Keys used inside ``SuperModule._meta_info`` for best-checkpoint bookkeeping.
BEST_DIR_KEY = "best_model_dir"
BEST_NAME_KEY = "best_model_name"
BEST_SCORE_KEY = "best_model_score"


class SuperModule(abc.ABC, nn.Module):
    """This class equips the nn.Module class with the ability to apply AutoDL.

    Subclasses implement ``forward_raw`` and ``forward_candidate``;
    ``forward`` dispatches between them according to ``super_run_type``.
    The class also keeps a small ``_meta_info`` dict used to save / load the
    best-scoring ``state_dict`` seen so far.
    """

    def __init__(self):
        super(SuperModule, self).__init__()
        self._super_run_type = SuperRunMode.Default
        self._abstract_child = None
        self._verbose = False
        self._meta_info = {}

    def set_super_run_type(self, super_run_type):
        """Set the run mode on this module and every SuperModule descendant."""

        def _reset_super_run(m):
            if isinstance(m, SuperModule):
                m._super_run_type = super_run_type

        self.apply(_reset_super_run)

    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
        """Register a child module, warning if it is not a SuperModule.

        ``None`` is accepted as a placeholder (as the annotation allows) and
        does not trigger the warning.
        """
        if module is not None and not isinstance(module, SuperModule):
            warnings.warn(
                "Add {:}:{:} module, which is not SuperModule, into {:}".format(
                    name, module.__class__.__name__, self.__class__.__name__
                )
                + "\n"
                + "It may cause some functions invalid."
            )
        super(SuperModule, self).add_module(name, module)

    def apply_verbose(self, verbose):
        """Set the verbose flag on this module and every SuperModule descendant."""

        def _reset_verbose(m):
            if isinstance(m, SuperModule):
                m._verbose = verbose

        self.apply(_reset_verbose)

    def apply_candidate(self, abstract_child):
        """Store the candidate description used by ``forward_candidate``.

        Raises:
            ValueError: if ``abstract_child`` is not a ``spaces.VirtualNode``.
        """
        if not isinstance(abstract_child, spaces.VirtualNode):
            raise ValueError(
                "Invalid abstract child program: {:}".format(abstract_child)
            )
        self._abstract_child = abstract_child

    def get_w_container(self):
        """Collect all named parameters and buffers into a TensorContainer.

        Parameters are appended with the flag ``True`` and buffers with the
        flag ``False``.
        """
        container = TensorContainer()
        for name, param in self.named_parameters():
            container.append(name, param, True)
        for name, buf in self.named_buffers():
            container.append(name, buf, False)
        return container

    def analyze_weights(self):
        """Print the name, shape, mean and std of every named parameter."""
        with torch.no_grad():
            for name, param in self.named_parameters():
                shapestr = "[{:10s}] shape={:}".format(name, list(param.shape))
                finalstr = shapestr + "{:.2f} +- {:.2f}".format(
                    param.mean(), param.std()
                )
                print(finalstr)

    def numel(self, buffer=True):
        """Return the number of elements in all parameters.

        Args:
            buffer: when true, elements of registered buffers are counted too.
        """
        total = sum(param.numel() for _, param in self.named_parameters())
        if buffer:
            total += sum(buf.numel() for _, buf in self.named_buffers())
        return total

    def set_best_dir(self, xdir):
        """Record the directory for best checkpoints and create it if needed."""
        self._meta_info[BEST_DIR_KEY] = str(xdir)
        Path(xdir).mkdir(parents=True, exist_ok=True)

    def set_best_name(self, xname):
        """Record the file name used for the best checkpoint."""
        self._meta_info[BEST_NAME_KEY] = str(xname)

    def _best_save_name(self):
        # Configured checkpoint file name, or a default derived from the class name.
        return self._meta_info.get(
            BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
        )

    def save_best(self, score):
        """Save ``state_dict`` if ``score`` is at least the best score so far.

        When no directory was configured, a fresh temporary directory is
        created and remembered.

        Returns:
            A tuple ``(saved, best_score)`` where ``saved`` tells whether a
            checkpoint was written by this call.
        """
        if BEST_DIR_KEY not in self._meta_info:
            tempdir = tempfile.mkdtemp("-xlayers")
            self._meta_info[BEST_DIR_KEY] = tempdir
        if BEST_SCORE_KEY not in self._meta_info:
            self._meta_info[BEST_SCORE_KEY] = None
        best_score = self._meta_info[BEST_SCORE_KEY]
        if best_score is None or best_score <= score:
            best_save_path = os.path.join(
                self._meta_info[BEST_DIR_KEY], self._best_save_name()
            )
            # Write the checkpoint first so that a failed save does not leave
            # a recorded best score without a matching file on disk.
            torch.save(self.state_dict(), best_save_path)
            self._meta_info[BEST_SCORE_KEY] = score
            return True, self._meta_info[BEST_SCORE_KEY]
        else:
            return False, self._meta_info[BEST_SCORE_KEY]

    def load_best(self, best_save_path=None):
        """Load a saved ``state_dict`` into this module.

        Args:
            best_save_path: explicit checkpoint path; when ``None`` the path
                is rebuilt from the recorded directory and file name.

        Raises:
            ValueError: if no path is given and ``save_best`` has not recorded
                a directory and a score yet.
        """
        if best_save_path is None:
            if (
                BEST_DIR_KEY not in self._meta_info
                or BEST_SCORE_KEY not in self._meta_info
            ):
                raise ValueError("Please call save_best at first")
            best_save_path = os.path.join(
                self._meta_info[BEST_DIR_KEY], self._best_save_name()
            )
        # NOTE(review): torch.load unpickles the file; only pass trusted paths.
        state_dict = torch.load(best_save_path)
        self.load_state_dict(state_dict)

    def has_best(self, best_name=None):
        """Return whether a checkpoint file exists in the best directory.

        Args:
            best_name: file name to look for; defaults to the configured (or
                class-derived) best checkpoint name.

        Raises:
            ValueError: if no best directory has been recorded.
        """
        if BEST_DIR_KEY not in self._meta_info:
            raise ValueError("Please set BEST_DIR_KEY at first")
        if best_name is None:
            best_save_name = self._best_save_name()
        else:
            best_save_name = best_name
        best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
        return os.path.exists(best_save_path)

    @property
    def abstract_search_space(self):
        """Search space of this module; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def super_run_type(self):
        """The current run mode."""
        return self._super_run_type

    @property
    def abstract_child(self):
        """The candidate set by ``apply_candidate`` (``None`` if unset)."""
        return self._abstract_child

    @property
    def verbose(self):
        """Whether ``forward`` prints input / output shapes."""
        return self._verbose

    @abc.abstractmethod
    def forward_raw(self, *inputs):
        """Use the largest candidate for forward. Similar to the original PyTorch model."""
        raise NotImplementedError

    @abc.abstractmethod
    def forward_candidate(self, *inputs):
        """Forward pass used when the run mode is ``SuperRunMode.Candidate``."""
        raise NotImplementedError

    @property
    def name_with_id(self):
        """A string with the class name and ``id`` of this instance."""
        return "name={:}, id={:}".format(self.__class__.__name__, id(self))

    def get_shape_str(self, tensors):
        """Return a printable shape string for a tensor or a nested list/tuple.

        Raises:
            TypeError: if an element is neither a tensor nor a list/tuple.
        """
        if isinstance(tensors, (list, tuple)):
            shapes = [self.get_shape_str(tensor) for tensor in tensors]
            if len(shapes) == 1:
                return shapes[0]
            else:
                return ", ".join(shapes)
        elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
            return str(tuple(tensors.shape))
        else:
            raise TypeError("Invalid input type: {:}.".format(type(tensors)))

    def forward(self, *inputs):
        """Dispatch to ``forward_raw`` or ``forward_candidate`` by run mode.

        Raises:
            ValueError: if the run mode is neither ``SuperRunMode.FullModel``
                nor ``SuperRunMode.Candidate``.
        """
        if self.verbose:
            print(
                "[{:}] inputs shape: {:}".format(
                    self.name_with_id, self.get_shape_str(inputs)
                )
            )
        if self.super_run_type == SuperRunMode.FullModel:
            outputs = self.forward_raw(*inputs)
        elif self.super_run_type == SuperRunMode.Candidate:
            outputs = self.forward_candidate(*inputs)
        else:
            # The original raised an undefined ``ModeError`` (a NameError at
            # runtime); ValueError matches the other validation errors here.
            raise ValueError(
                "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
            )
        if self.verbose:
            print(
                "[{:}] outputs shape: {:}".format(
                    self.name_with_id, self.get_shape_str(outputs)
                )
            )
        return outputs

    # NOTE: the mutable ``prefix=[]`` default is kept for interface
    # compatibility; overriding implementations should not mutate it.
    def forward_with_container(self, inputs, container, prefix=[]):
        """Forward using weights from ``container``; subclasses must override."""
        raise NotImplementedError