Add the SuperMLP class
This commit is contained in:
		| @@ -96,7 +96,8 @@ Some methods use knowledge distillation (KD), which require pre-trained models. | |||||||
|  |  | ||||||
| Please use | Please use | ||||||
| ``` | ``` | ||||||
| git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git | git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git XAutoDL | ||||||
|  | git clone --recurse-submodules https://github.com/D-X-Y/AutoDL-Projects.git XAutoDL | ||||||
| ``` | ``` | ||||||
| to download this repo with submodules. | to download this repo with submodules. | ||||||
|  |  | ||||||
|   | |||||||
| @@ -22,19 +22,32 @@ class Space(metaclass=abc.ABCMeta): | |||||||
|     All search space must inherit from this basic class. |     All search space must inherit from this basic class. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         # used to avoid duplicate sample | ||||||
|  |         self._last_sample = None | ||||||
|  |         self._last_abstract = None | ||||||
|  |  | ||||||
|     @abc.abstractproperty |     @abc.abstractproperty | ||||||
|     def xrepr(self, prefix="") -> Text: |     def xrepr(self, depth=0) -> Text: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self) -> Text: |     def __repr__(self) -> Text: | ||||||
|         return self.xrepr() |         return self.xrepr() | ||||||
|  |  | ||||||
|     @abc.abstractproperty |     @abc.abstractproperty | ||||||
|     def abstract(self) -> "Space": |     def abstract(self, reuse_last=False) -> "Space": | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True, reuse_last=False): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def clean_last_sample(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def clean_last_abstract(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractproperty |     @abc.abstractproperty | ||||||
| @@ -63,6 +76,7 @@ class VirtualNode(Space): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, id=None, value=None): |     def __init__(self, id=None, value=None): | ||||||
|  |         super(VirtualNode, self).__init__() | ||||||
|         self._id = id |         self._id = id | ||||||
|         self._value = value |         self._value = value | ||||||
|         self._attributes = OrderedDict() |         self._attributes = OrderedDict() | ||||||
| @@ -82,26 +96,51 @@ class VirtualNode(Space): | |||||||
|         #    raise ValueError("Can not attach a determined value: {:}".format(value)) |         #    raise ValueError("Can not attach a determined value: {:}".format(value)) | ||||||
|         self._attributes[key] = value |         self._attributes[key] = value | ||||||
|  |  | ||||||
|     def xrepr(self, prefix="  ") -> Text: |     def xrepr(self, depth=0) -> Text: | ||||||
|         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] |         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             strs.append(value.xrepr(prefix + "  " + key + " = ")) |             strs.append(key + " = " + value.xrepr(depth + 1)) | ||||||
|         strs.append(")") |         strs.append(")") | ||||||
|         return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs) |         if len(strs) == 2: | ||||||
|  |             return "".join(strs) | ||||||
|  |         else: | ||||||
|  |             space = "  " | ||||||
|  |             xstrs = ( | ||||||
|  |                 [strs[0]] | ||||||
|  |                 + [space * (depth + 1) + x for x in strs[1:-1]] | ||||||
|  |                 + [space * depth + strs[-1]] | ||||||
|  |             ) | ||||||
|  |             return ",\n".join(xstrs) | ||||||
|  |  | ||||||
|     def abstract(self) -> Space: |     def abstract(self, reuse_last=False) -> Space: | ||||||
|  |         if reuse_last and self._last_abstract is not None: | ||||||
|  |             return self._last_abstract | ||||||
|         node = VirtualNode(id(self)) |         node = VirtualNode(id(self)) | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             if not value.determined: |             if not value.determined: | ||||||
|                 node.append(value.abstract()) |                 node.append(value.abstract(reuse_last)) | ||||||
|         return node |         self._last_abstract = node | ||||||
|  |         return self._last_abstract | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True, reuse_last=False): | ||||||
|  |         if reuse_last and self._last_sample is not None: | ||||||
|  |             return self._last_sample | ||||||
|         node = VirtualNode(None, self._value) |         node = VirtualNode(None, self._value) | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             node.append(key, value.random(recursion)) |             node.append(key, value.random(recursion, reuse_last)) | ||||||
|  |         self._last_sample = node  # record the last sample | ||||||
|         return node |         return node | ||||||
|  |  | ||||||
|  |     def clean_last_sample(self): | ||||||
|  |         self._last_sample = None | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             value.clean_last_sample() | ||||||
|  |  | ||||||
|  |     def clean_last_abstract(self): | ||||||
|  |         self._last_abstract = None | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             value.clean_last_abstract() | ||||||
|  |  | ||||||
|     def has(self, x) -> bool: |     def has(self, x) -> bool: | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             if value.has(x): |             if value.has(x): | ||||||
| @@ -117,7 +156,7 @@ class VirtualNode(Space): | |||||||
|     @property |     @property | ||||||
|     def determined(self) -> bool: |     def determined(self) -> bool: | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             if not value.determined(x): |             if not value.determined: | ||||||
|                 return False |                 return False | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
| @@ -138,6 +177,7 @@ class Categorical(Space): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, *data, default: Optional[int] = None): |     def __init__(self, *data, default: Optional[int] = None): | ||||||
|  |         super(Categorical, self).__init__() | ||||||
|         self._candidates = [*data] |         self._candidates = [*data] | ||||||
|         self._default = default |         self._default = default | ||||||
|         assert self._default is None or 0 <= self._default < len( |         assert self._default is None or 0 <= self._default < len( | ||||||
| @@ -169,32 +209,54 @@ class Categorical(Space): | |||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._candidates) |         return len(self._candidates) | ||||||
|  |  | ||||||
|     def abstract(self) -> Space: |     def clean_last_sample(self): | ||||||
|         if self.determined: |         self._last_sample = None | ||||||
|             return VirtualNode(id(self), self) |         for candidate in self._candidates: | ||||||
|         # [TO-IMPROVE] |  | ||||||
|         data = [] |  | ||||||
|         for candidate in self.candidates: |  | ||||||
|             if isinstance(candidate, Space): |             if isinstance(candidate, Space): | ||||||
|                 data.append(candidate.abstract()) |                 candidate.clean_last_sample() | ||||||
|             else: |  | ||||||
|                 data.append(VirtualNode(id(candidate), candidate)) |  | ||||||
|         return Categorical(*data, default=self._default) |  | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def clean_last_abstract(self): | ||||||
|  |         self._last_abstract = None | ||||||
|  |         for candidate in self._candidates: | ||||||
|  |             if isinstance(candidate, Space): | ||||||
|  |                 candidate.clean_last_abstract() | ||||||
|  |  | ||||||
|  |     def abstract(self, reuse_last=False) -> Space: | ||||||
|  |         if reuse_last and self._last_abstract is not None: | ||||||
|  |             return self._last_abstract | ||||||
|  |         if self.determined: | ||||||
|  |             result = VirtualNode(id(self), self) | ||||||
|  |         else: | ||||||
|  |             # [TO-IMPROVE] | ||||||
|  |             data = [] | ||||||
|  |             for candidate in self.candidates: | ||||||
|  |                 if isinstance(candidate, Space): | ||||||
|  |                     data.append(candidate.abstract()) | ||||||
|  |                 else: | ||||||
|  |                     data.append(VirtualNode(id(candidate), candidate)) | ||||||
|  |             result = Categorical(*data, default=self._default) | ||||||
|  |         self._last_abstract = result | ||||||
|  |         return self._last_abstract | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True, reuse_last=False): | ||||||
|  |         if reuse_last and self._last_sample is not None: | ||||||
|  |             return self._last_sample | ||||||
|         sample = random.choice(self._candidates) |         sample = random.choice(self._candidates) | ||||||
|         if recursion and isinstance(sample, Space): |         if recursion and isinstance(sample, Space): | ||||||
|             sample = sample.random(recursion) |             sample = sample.random(recursion, reuse_last) | ||||||
|         if isinstance(sample, VirtualNode): |         if isinstance(sample, VirtualNode): | ||||||
|             return sample.copy() |             sample = sample.copy() | ||||||
|         else: |         else: | ||||||
|             return VirtualNode(None, sample) |             sample = VirtualNode(None, sample) | ||||||
|  |         self._last_sample = sample | ||||||
|  |         return self._last_sample | ||||||
|  |  | ||||||
|     def xrepr(self, prefix=""): |     def xrepr(self, depth=0): | ||||||
|  |         del depth | ||||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( |         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default |             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||||
|         ) |         ) | ||||||
|         return prefix + xrepr |         return xrepr | ||||||
|  |  | ||||||
|     def has(self, x): |     def has(self, x): | ||||||
|         super().has(x) |         super().has(x) | ||||||
| @@ -213,7 +275,7 @@ class Categorical(Space): | |||||||
|         if self.default != other.default: |         if self.default != other.default: | ||||||
|             return False |             return False | ||||||
|         for index in range(len(self)): |         for index in range(len(self)): | ||||||
|             if self.__getitem__[index] != other[index]: |             if self.__getitem__(index) != other[index]: | ||||||
|                 return False |                 return False | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
| @@ -235,14 +297,15 @@ class Integer(Categorical): | |||||||
|             default = data.index(default) |             default = data.index(default) | ||||||
|         super(Integer, self).__init__(*data, default=default) |         super(Integer, self).__init__(*data, default=default) | ||||||
|  |  | ||||||
|     def xrepr(self, prefix=""): |     def xrepr(self, depth=0): | ||||||
|  |         del depth | ||||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._raw_lower, |             lower=self._raw_lower, | ||||||
|             upper=self._raw_upper, |             upper=self._raw_upper, | ||||||
|             default=self._raw_default, |             default=self._raw_default, | ||||||
|         ) |         ) | ||||||
|         return prefix + xrepr |         return xrepr | ||||||
|  |  | ||||||
|  |  | ||||||
| np_float_types = (np.float16, np.float32, np.float64) | np_float_types = (np.float16, np.float32, np.float64) | ||||||
| @@ -269,6 +332,7 @@ class Continuous(Space): | |||||||
|         log: bool = False, |         log: bool = False, | ||||||
|         eps: float = _EPS, |         eps: float = _EPS, | ||||||
|     ): |     ): | ||||||
|  |         super(Continuous, self).__init__() | ||||||
|         self._lower = lower |         self._lower = lower | ||||||
|         self._upper = upper |         self._upper = upper | ||||||
|         self._default = default |         self._default = default | ||||||
| @@ -295,19 +359,26 @@ class Continuous(Space): | |||||||
|     def eps(self): |     def eps(self): | ||||||
|         return self._eps |         return self._eps | ||||||
|  |  | ||||||
|     def abstract(self) -> Space: |     def abstract(self, reuse_last=False) -> Space: | ||||||
|         return self.copy() |         if reuse_last and self._last_abstract is not None: | ||||||
|  |             return self._last_abstract | ||||||
|  |         self._last_abstract = self.copy() | ||||||
|  |         return self._last_abstract | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True, reuse_last=False): | ||||||
|         del recursion |         del recursion | ||||||
|  |         if reuse_last and self._last_sample is not None: | ||||||
|  |             return self._last_sample | ||||||
|         if self._log_scale: |         if self._log_scale: | ||||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) |             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||||
|             sample = math.exp(sample) |             sample = math.exp(sample) | ||||||
|         else: |         else: | ||||||
|             sample = random.uniform(self._lower, self._upper) |             sample = random.uniform(self._lower, self._upper) | ||||||
|         return VirtualNode(None, sample) |         self._last_sample = VirtualNode(None, sample) | ||||||
|  |         return self._last_sample | ||||||
|  |  | ||||||
|     def xrepr(self, prefix=""): |     def xrepr(self, depth=0): | ||||||
|  |         del depth | ||||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._lower, |             lower=self._lower, | ||||||
| @@ -315,7 +386,7 @@ class Continuous(Space): | |||||||
|             default=self._default, |             default=self._default, | ||||||
|             log=self._log_scale, |             log=self._log_scale, | ||||||
|         ) |         ) | ||||||
|         return prefix + xrepr |         return xrepr | ||||||
|  |  | ||||||
|     def convert(self, x): |     def convert(self, x): | ||||||
|         if isinstance(x, np_float_types) and x.size == 1: |         if isinstance(x, np_float_types) and x.size == 1: | ||||||
| @@ -338,6 +409,12 @@ class Continuous(Space): | |||||||
|     def determined(self): |     def determined(self): | ||||||
|         return abs(self.lower - self.upper) <= self._eps |         return abs(self.lower - self.upper) <= self._eps | ||||||
|  |  | ||||||
|  |     def clean_last_sample(self): | ||||||
|  |         self._last_sample = None | ||||||
|  |  | ||||||
|  |     def clean_last_abstract(self): | ||||||
|  |         self._last_abstract = None | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other): | ||||||
|         if not isinstance(other, Continuous): |         if not isinstance(other, Continuous): | ||||||
|             return False |             return False | ||||||
|   | |||||||
| @@ -3,4 +3,5 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| from .super_module import SuperRunMode | from .super_module import SuperRunMode | ||||||
| from .super_module import SuperModule | from .super_module import SuperModule | ||||||
| from .super_mlp import SuperLinear | from .super_linear import SuperLinear | ||||||
|  | from .super_linear import SuperMLP | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ import torch.nn as nn | |||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| 
 | 
 | ||||||
| import math | import math | ||||||
| from typing import Optional, Union | from typing import Optional, Union, Callable | ||||||
| 
 | 
 | ||||||
| import spaces | import spaces | ||||||
| from .super_module import SuperModule | from .super_module import SuperModule | ||||||
| @@ -57,11 +57,15 @@ class SuperLinear(SuperModule): | |||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         root_node = spaces.VirtualNode(id(self)) |         root_node = spaces.VirtualNode(id(self)) | ||||||
|         if not spaces.is_determined(self._in_features): |         if not spaces.is_determined(self._in_features): | ||||||
|             root_node.append("_in_features", self._in_features.abstract()) |             root_node.append( | ||||||
|  |                 "_in_features", self._in_features.abstract(reuse_last=True) | ||||||
|  |             ) | ||||||
|         if not spaces.is_determined(self._out_features): |         if not spaces.is_determined(self._out_features): | ||||||
|             root_node.append("_out_features", self._out_features.abstract()) |             root_node.append( | ||||||
|  |                 "_out_features", self._out_features.abstract(reuse_last=True) | ||||||
|  |             ) | ||||||
|         if not spaces.is_determined(self._bias): |         if not spaces.is_determined(self._bias): | ||||||
|             root_node.append("_bias", self._bias.abstract()) |             root_node.append("_bias", self._bias.abstract(reuse_last=True)) | ||||||
|         return root_node |         return root_node | ||||||
| 
 | 
 | ||||||
|     def reset_parameters(self) -> None: |     def reset_parameters(self) -> None: | ||||||
| @@ -116,24 +120,51 @@ class SuperMLP(SuperModule): | |||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         in_features, |         in_features: IntSpaceType, | ||||||
|         hidden_features: Optional[int] = None, |         hidden_features: IntSpaceType, | ||||||
|         out_features: Optional[int] = None, |         out_features: IntSpaceType, | ||||||
|         act_layer=nn.GELU, |         act_layer: Callable[[], nn.Module] = nn.GELU, | ||||||
|         drop: Optional[float] = None, |         drop: Optional[float] = None, | ||||||
|     ): |     ): | ||||||
|         super(SuperMLP, self).__init__() |         super(SuperMLP, self).__init__() | ||||||
|         out_features = out_features or in_features |         self._in_features = in_features | ||||||
|         hidden_features = hidden_features or in_features |         self._hidden_features = hidden_features | ||||||
|         self.fc1 = nn.Linear(in_features, hidden_features) |         self._out_features = out_features | ||||||
|  |         self._drop_rate = drop | ||||||
|  |         self.fc1 = SuperLinear(in_features, hidden_features) | ||||||
|         self.act = act_layer() |         self.act = act_layer() | ||||||
|         self.fc2 = nn.Linear(hidden_features, out_features) |         self.fc2 = SuperLinear(hidden_features, out_features) | ||||||
|         self.drop = nn.Dropout(drop or 0.0) |         self.drop = nn.Dropout(drop or 0.0) | ||||||
| 
 | 
 | ||||||
|     def forward(self, x): |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         space_fc1 = self.fc1.abstract_search_space | ||||||
|  |         space_fc2 = self.fc2.abstract_search_space | ||||||
|  |         if not spaces.is_determined(space_fc1): | ||||||
|  |             root_node.append("fc1", space_fc1) | ||||||
|  |         if not spaces.is_determined(space_fc2): | ||||||
|  |             root_node.append("fc2", space_fc2) | ||||||
|  |         return root_node | ||||||
|  | 
 | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         return self._unified_forward(x) | ||||||
|  | 
 | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         return self._unified_forward(x) | ||||||
|  | 
 | ||||||
|  |     def _unified_forward(self, x): | ||||||
|         x = self.fc1(x) |         x = self.fc1(x) | ||||||
|         x = self.act(x) |         x = self.act(x) | ||||||
|         x = self.drop(x) |         x = self.drop(x) | ||||||
|         x = self.fc2(x) |         x = self.fc2(x) | ||||||
|         x = self.drop(x) |         x = self.drop(x) | ||||||
|         return x |         return x | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( | ||||||
|  |             self._in_features, | ||||||
|  |             self._hidden_features, | ||||||
|  |             self._out_features, | ||||||
|  |             self._drop_rate, | ||||||
|  |         ) | ||||||
| @@ -48,7 +48,7 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|         space = Continuous(lower, upper, log=False) |         space = Continuous(lower, upper, log=False) | ||||||
|         values = [] |         values = [] | ||||||
|         for i in range(1000000): |         for i in range(1000000): | ||||||
|             x = space.random().value |             x = space.random(reuse_last=False).value | ||||||
|             self.assertGreaterEqual(x, lower) |             self.assertGreaterEqual(x, lower) | ||||||
|             self.assertGreaterEqual(upper, x) |             self.assertGreaterEqual(upper, x) | ||||||
|             values.append(x) |             values.append(x) | ||||||
| @@ -97,6 +97,12 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|         self.assertTrue(is_determined(1)) |         self.assertTrue(is_determined(1)) | ||||||
|         self.assertFalse(is_determined(nested_space)) |         self.assertFalse(is_determined(nested_space)) | ||||||
|  |  | ||||||
|  |     def test_duplicate(self): | ||||||
|  |         space = Categorical(1, 2, 3, 4) | ||||||
|  |         x = space.random() | ||||||
|  |         for _ in range(100): | ||||||
|  |             self.assertEqual(x, space.random(reuse_last=True)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestAbstractSpace(unittest.TestCase): | class TestAbstractSpace(unittest.TestCase): | ||||||
|     """Test the abstract search spaces.""" |     """Test the abstract search spaces.""" | ||||||
|   | |||||||
| @@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         output_shape = (32, abstract_child["_out_features"].value) |         output_shape = (32, abstract_child["_out_features"].value) | ||||||
|         outputs = model(inputs) |         outputs = model(inputs) | ||||||
|         self.assertEqual(tuple(outputs.shape), output_shape) |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
|  |  | ||||||
|  |     def test_super_mlp(self): | ||||||
|  |         hidden_features = spaces.Categorical(12, 24, 36) | ||||||
|  |         out_features = spaces.Categorical(12, 24, 36) | ||||||
|  |         mlp = super_core.SuperMLP(10, hidden_features, out_features) | ||||||
|  |         print(mlp) | ||||||
|  |         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) | ||||||
|  |  | ||||||
|  |         abstract_space = mlp.abstract_search_space | ||||||
|  |         print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space)) | ||||||
|  |         self.assertEqual( | ||||||
|  |             abstract_space["fc1"]["_out_features"], | ||||||
|  |             abstract_space["fc2"]["_in_features"], | ||||||
|  |         ) | ||||||
|  |         self.assertTrue( | ||||||
|  |             abstract_space["fc1"]["_out_features"] | ||||||
|  |             is abstract_space["fc2"]["_in_features"] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         abstract_space.clean_last_sample() | ||||||
|  |         abstract_child = abstract_space.random(reuse_last=True) | ||||||
|  |         print("The abstract child program is:\n{:}".format(abstract_child)) | ||||||
|  |         self.assertEqual( | ||||||
|  |             abstract_child["fc1"]["_out_features"].value, | ||||||
|  |             abstract_child["fc2"]["_in_features"].value, | ||||||
|  |         ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user