Add the SuperMLP class
This commit is contained in:
		| @@ -3,4 +3,5 @@ | ||||
| ##################################################### | ||||
| from .super_module import SuperRunMode | ||||
| from .super_module import SuperModule | ||||
| from .super_mlp import SuperLinear | ||||
| from .super_linear import SuperLinear | ||||
| from .super_linear import SuperMLP | ||||
|   | ||||
| @@ -6,7 +6,7 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| 
 | ||||
| import math | ||||
| from typing import Optional, Union | ||||
| from typing import Optional, Union, Callable | ||||
| 
 | ||||
| import spaces | ||||
| from .super_module import SuperModule | ||||
| @@ -57,11 +57,15 @@ class SuperLinear(SuperModule): | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             root_node.append("_in_features", self._in_features.abstract()) | ||||
|             root_node.append( | ||||
|                 "_in_features", self._in_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             root_node.append("_out_features", self._out_features.abstract()) | ||||
|             root_node.append( | ||||
|                 "_out_features", self._out_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._bias): | ||||
|             root_node.append("_bias", self._bias.abstract()) | ||||
|             root_node.append("_bias", self._bias.abstract(reuse_last=True)) | ||||
|         return root_node | ||||
| 
 | ||||
|     def reset_parameters(self) -> None: | ||||
| @@ -116,24 +120,51 @@ class SuperMLP(SuperModule): | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features, | ||||
|         hidden_features: Optional[int] = None, | ||||
|         out_features: Optional[int] = None, | ||||
|         act_layer=nn.GELU, | ||||
|         in_features: IntSpaceType, | ||||
|         hidden_features: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         act_layer: Callable[[], nn.Module] = nn.GELU, | ||||
|         drop: Optional[float] = None, | ||||
|     ): | ||||
|         super(SuperMLP, self).__init__() | ||||
|         out_features = out_features or in_features | ||||
|         hidden_features = hidden_features or in_features | ||||
|         self.fc1 = nn.Linear(in_features, hidden_features) | ||||
|         self._in_features = in_features | ||||
|         self._hidden_features = hidden_features | ||||
|         self._out_features = out_features | ||||
|         self._drop_rate = drop | ||||
|         self.fc1 = SuperLinear(in_features, hidden_features) | ||||
|         self.act = act_layer() | ||||
|         self.fc2 = nn.Linear(hidden_features, out_features) | ||||
|         self.fc2 = SuperLinear(hidden_features, out_features) | ||||
|         self.drop = nn.Dropout(drop or 0.0) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         space_fc1 = self.fc1.abstract_search_space | ||||
|         space_fc2 = self.fc2.abstract_search_space | ||||
|         if not spaces.is_determined(space_fc1): | ||||
|             root_node.append("fc1", space_fc1) | ||||
|         if not spaces.is_determined(space_fc2): | ||||
|             root_node.append("fc2", space_fc2) | ||||
|         return root_node | ||||
| 
 | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self._unified_forward(x) | ||||
| 
 | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self._unified_forward(x) | ||||
| 
 | ||||
|     def _unified_forward(self, x): | ||||
|         x = self.fc1(x) | ||||
|         x = self.act(x) | ||||
|         x = self.drop(x) | ||||
|         x = self.fc2(x) | ||||
|         x = self.drop(x) | ||||
|         return x | ||||
| 
 | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( | ||||
|             self._in_features, | ||||
|             self._hidden_features, | ||||
|             self._out_features, | ||||
|             self._drop_rate, | ||||
|         ) | ||||
		Reference in New Issue
	
	Block a user