Complete Super Linear
@@ -1,5 +1,6 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .super_module import SuperRunMode
from .super_module import SuperModule
from .super_mlp import SuperLinear

@@ -3,6 +3,7 @@
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import Optional, Union
@@ -52,14 +53,15 @@ class SuperLinear(SuperModule):
    def bias(self):
        return spaces.has_categorical(self._bias, True)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._in_features):
            root_node.append("_in_features", self._in_features)
            root_node.append("_in_features", self._in_features.abstract())
        if not spaces.is_determined(self._out_features):
            root_node.append("_out_features", self._out_features)
            root_node.append("_out_features", self._out_features.abstract())
        if not spaces.is_determined(self._bias):
            root_node.append("_bias", self._bias)
            root_node.append("_bias", self._bias.abstract())
        return root_node

    def reset_parameters(self) -> None:
@@ -69,6 +71,37 @@ class SuperLinear(SuperModule):
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self._super_bias, -bound, bound)

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # check inputs ->
        if not spaces.is_determined(self._in_features):
            expected_input_dim = self.abstract_child["_in_features"].value
        else:
            expected_input_dim = spaces.get_determined_value(self._in_features)
        if input.size(-1) != expected_input_dim:
            raise ValueError(
                "Expect the input dim of {:} instead of {:}".format(
                    expected_input_dim, input.size(-1)
                )
            )
        # create the weight matrix
        if not spaces.is_determined(self._out_features):
            out_dim = self.abstract_child["_out_features"].value
        else:
            out_dim = spaces.get_determined_value(self._out_features)
        candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
        # create the bias matrix
        if not spaces.is_determined(self._bias):
            if self.abstract_child["_bias"].value:
                candidate_bias = self._super_bias[:out_dim]
            else:
                candidate_bias = None
        else:
            if spaces.get_determined_value(self._bias):
                candidate_bias = self._super_bias[:out_dim]
            else:
                candidate_bias = None
        return F.linear(input, candidate_weight, candidate_bias)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self._super_weight, self._super_bias)

@@ -78,8 +111,9 @@ class SuperLinear(SuperModule):
        )


class SuperMLP(nn.Module):
    # MLP: FC -> Activation -> Drop -> FC -> Drop
class SuperMLP(SuperModule):
    """An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""

    def __init__(
        self,
        in_features,
@@ -88,13 +122,13 @@ class SuperMLP(nn.Module):
        act_layer=nn.GELU,
        drop: Optional[float] = None,
    ):
        super(MLP, self).__init__()
        super(SuperMLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop or 0)
        self.drop = nn.Dropout(drop or 0.0)

    def forward(self, x):
        x = self.fc1(x)

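For readers following the forward_candidate logic above: a candidate layer is realized by slicing the leading rows and columns of the shared super weight and passing the slice to F.linear. Below is a minimal, self-contained sketch of that slicing idea in plain PyTorch; the dimensions are made up for illustration and nothing here depends on the spaces package.

import torch
import torch.nn.functional as F

# "Super" parameters sized for the largest candidate: 64 output x 32 input features.
super_weight = torch.randn(64, 32)
super_bias = torch.randn(64)

# One candidate picks smaller dimensions, e.g. 16 outputs computed from 8 inputs.
out_dim, in_dim = 16, 8
x = torch.randn(4, in_dim)  # a batch of 4 inputs

# forward_candidate-style slicing: reuse the top-left block of the super parameters.
candidate_weight = super_weight[:out_dim, :in_dim]
candidate_bias = super_bias[:out_dim]
y = F.linear(x, candidate_weight, candidate_bias)
print(tuple(y.shape))  # (4, 16)
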
@@ -6,11 +6,14 @@ import abc
import torch.nn as nn
from enum import Enum

import spaces


class SuperRunMode(Enum):
    """This class defines the enumerations for Super Model Running Mode."""

    FullModel = "fullmodel"
    Candidate = "candidate"
    Default = "fullmodel"


@@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module):
    def __init__(self):
        super(SuperModule, self).__init__()
        self._super_run_type = SuperRunMode.Default
        self._abstract_child = None

    @abc.abstractmethod
    def set_super_run_type(self, super_run_type):
        def _reset_super_run(m):
            if isinstance(m, SuperModule):
                m._super_run_type = super_run_type

        self.apply(_reset_super_run)

    def apply_candidate(self, abstract_child):
        if not isinstance(abstract_child, spaces.VirtualNode):
            raise ValueError(
                "Invalid abstract child program: {:}".format(abstract_child)
            )
        self._abstract_child = abstract_child

    @property
    def abstract_search_space(self):
        raise NotImplementedError

@@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module):
    def super_run_type(self):
        return self._super_run_type

    @property
    def abstract_child(self):
        return self._abstract_child

    @abc.abstractmethod
    def forward_raw(self, *inputs):
        """Use the largest candidate for forward. Similar to the original PyTorch model."""
        raise NotImplementedError

    @abc.abstractmethod
    def forward_candidate(self, *inputs):
        raise NotImplementedError

    def forward(self, *inputs):
        if self.super_run_type == SuperRunMode.FullModel:
            return self.forward_raw(*inputs)
        elif self.super_run_type == SuperRunMode.Candidate:
            return self.forward_candidate(*inputs)
        else:
            raise ValueError(
                "Unknown Super Model Run Mode: {:}".format(self.super_run_type)

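Taken together, the intended workflow appears to be: describe the searchable dimensions with the spaces package, read the candidate choices from abstract_search_space, attach one concrete child with apply_candidate, switch the module into Candidate mode, and then call it as usual so that forward dispatches to forward_candidate. The following is a rough usage sketch of that flow; the spaces.Categorical constructor, the random() call used to sample a child from the returned VirtualNode, and the import paths are assumptions not confirmed by this diff.

import torch
import spaces                                   # search-space package used by the diff
from xlayers import SuperLinear, SuperRunMode   # import path is an assumption

# Input width is fixed; output width is a searchable choice (constructor assumed).
layer = SuperLinear(10, spaces.Categorical(12, 24, 36))

x = torch.randn(4, 10)
full_out = layer(x)  # Default mode -> forward_raw, the full super weight is used

# Sample one concrete child from the abstract search space (sampling API assumed),
# attach it, and switch modes so forward() dispatches to forward_candidate.
abstract_child = layer.abstract_search_space.random()
layer.set_super_run_type(SuperRunMode.Candidate)
layer.apply_candidate(abstract_child)
cand_out = layer(x)  # the super weight is sliced to the sampled output width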