diff --git a/lib/spaces/__init__.py b/lib/spaces/__init__.py
index 9cfe5b1..d777d0c 100644
--- a/lib/spaces/__init__.py
+++ b/lib/spaces/__init__.py
@@ -12,5 +12,6 @@ from .basic_space import VirtualNode
 from .basic_op import has_categorical
 from .basic_op import has_continuous
 from .basic_op import is_determined
+from .basic_op import get_determined_value
 from .basic_op import get_min
 from .basic_op import get_max
diff --git a/lib/spaces/basic_op.py b/lib/spaces/basic_op.py
index fbb75b3..4f5f9b8 100644
--- a/lib/spaces/basic_op.py
+++ b/lib/spaces/basic_op.py
@@ -1,4 +1,5 @@
 from spaces.basic_space import Space
+from spaces.basic_space import VirtualNode
 from spaces.basic_space import Integer
 from spaces.basic_space import Continuous
 from spaces.basic_space import Categorical
@@ -26,6 +27,20 @@ def is_determined(space_or_value):
     return True
 
 
+def get_determined_value(space_or_value):
+    if not is_determined(space_or_value):
+        raise ValueError("This input is not determined: {:}".format(space_or_value))
+    if isinstance(space_or_value, Space):
+        if isinstance(space_or_value, Continuous):
+            return space_or_value.lower
+        elif isinstance(space_or_value, Categorical):
+            return get_determined_value(space_or_value[0])
+        else:  # VirtualNode
+            return space_or_value.value
+    else:
+        return space_or_value
+
+
 def get_max(space_or_value):
     if isinstance(space_or_value, Integer):
         return max(space_or_value.candidates)
diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py
index cd4ad63..a0b6465 100644
--- a/lib/spaces/basic_space.py
+++ b/lib/spaces/basic_space.py
@@ -23,7 +23,7 @@ class Space(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractproperty
-    def xrepr(self, indent=0) -> Text:
+    def xrepr(self, prefix="") -> Text:
         raise NotImplementedError
 
     def __repr__(self) -> Text:
@@ -67,17 +67,27 @@ class VirtualNode(Space):
         self._value = value
         self._attributes = OrderedDict()
 
+    @property
+    def value(self):
+        return self._value
+
     def append(self, key, value):
+        if not isinstance(key, str):
+            raise TypeError(
+                "Only accept string as a key instead of {:}".format(type(key))
+            )
         if not isinstance(value, Space):
             raise ValueError("Invalid type of value: {:}".format(type(value)))
+        # if value.determined:
+        #     raise ValueError("Can not attach a determined value: {:}".format(value))
         self._attributes[key] = value
 
-    def xrepr(self, indent=0) -> Text:
-        strs = [self.__class__.__name__ + "("]
+    def xrepr(self, prefix=" ") -> Text:
+        strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
         for key, value in self._attributes.items():
-            strs.append(value.xrepr(indent + 2) + ",")
+            strs.append(value.xrepr(prefix + " " + key + " = "))
         strs.append(")")
-        return "\n".join(strs)
+        return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs)
 
     def abstract(self) -> Space:
         node = VirtualNode(id(self))
@@ -87,7 +97,10 @@ class VirtualNode(Space):
         return node
 
     def random(self, recursion=True):
-        raise NotImplementedError
+        node = VirtualNode(None, self._value)
+        for key, value in self._attributes.items():
+            node.append(key, value.random(recursion))
+        return node
 
     def has(self, x) -> bool:
         for key, value in self._attributes.items():
@@ -101,6 +114,7 @@ class VirtualNode(Space):
     def __getitem__(self, key):
         return self._attributes[key]
 
+    @property
     def determined(self) -> bool:
         for key, value in self._attributes.items():
-            if not value.determined(x):
+            if not value.determined:
@@ -165,20 +179,22 @@ class Categorical(Space):
                 data.append(candidate.abstract())
             else:
                 data.append(VirtualNode(id(candidate), candidate))
-        return Categorical(*data, self._default)
+        return Categorical(*data, default=self._default)
 
     def random(self, recursion=True):
         sample = random.choice(self._candidates)
         if recursion and isinstance(sample, Space):
-            return sample.random(recursion)
+            sample = sample.random(recursion)
+        if isinstance(sample, VirtualNode):
+            return sample.copy()
         else:
-            return sample
+            return VirtualNode(None, sample)
 
-    def xrepr(self, indent=0):
+    def xrepr(self, prefix=""):
         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
             name=self.__class__.__name__, cs=self._candidates, default=self._default
         )
-        return " " * indent + xrepr
+        return prefix + xrepr
 
     def has(self, x):
         super().has(x)
@@ -219,14 +235,14 @@ class Integer(Categorical):
             default = data.index(default)
         super(Integer, self).__init__(*data, default=default)
 
-    def xrepr(self, indent=0):
+    def xrepr(self, prefix=""):
         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
             name=self.__class__.__name__,
             lower=self._raw_lower,
             upper=self._raw_upper,
             default=self._raw_default,
         )
-        return " " * indent + xrepr
+        return prefix + xrepr
 
 
 np_float_types = (np.float16, np.float32, np.float64)
@@ -286,11 +302,12 @@ class Continuous(Space):
         del recursion
         if self._log_scale:
             sample = random.uniform(math.log(self._lower), math.log(self._upper))
-            return math.exp(sample)
+            sample = math.exp(sample)
         else:
-            return random.uniform(self._lower, self._upper)
+            sample = random.uniform(self._lower, self._upper)
+        return VirtualNode(None, sample)
 
-    def xrepr(self, indent=0):
+    def xrepr(self, prefix=""):
         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
             name=self.__class__.__name__,
             lower=self._lower,
@@ -298,7 +315,7 @@ class Continuous(Space):
             default=self._default,
             log=self._log_scale,
         )
-        return " " * indent + xrepr
+        return prefix + xrepr
 
     def convert(self, x):
         if isinstance(x, np_float_types) and x.size == 1:
diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py
index eb41901..8c7b056 100644
--- a/lib/xlayers/super_core.py
+++ b/lib/xlayers/super_core.py
@@ -1,5 +1,6 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 #####################################################
+from .super_module import SuperRunMode
 from .super_module import SuperModule
 from .super_mlp import SuperLinear
diff --git a/lib/xlayers/super_mlp.py b/lib/xlayers/super_mlp.py
index f21a8d8..d80ceed 100644
--- a/lib/xlayers/super_mlp.py
+++ b/lib/xlayers/super_mlp.py
@@ -3,6 +3,7 @@
 #####################################################
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math
 from typing import Optional, Union
 
@@ -52,14 +53,15 @@ class SuperLinear(SuperModule):
     def bias(self):
         return spaces.has_categorical(self._bias, True)
 
+    @property
     def abstract_search_space(self):
         root_node = spaces.VirtualNode(id(self))
         if not spaces.is_determined(self._in_features):
-            root_node.append("_in_features", self._in_features)
+            root_node.append("_in_features", self._in_features.abstract())
         if not spaces.is_determined(self._out_features):
-            root_node.append("_out_features", self._out_features)
+            root_node.append("_out_features", self._out_features.abstract())
         if not spaces.is_determined(self._bias):
-            root_node.append("_bias", self._bias)
+            root_node.append("_bias", self._bias.abstract())
         return root_node
 
     def reset_parameters(self) -> None:
@@ -69,6 +71,37 @@ class SuperLinear(SuperModule):
             bound = 1 / math.sqrt(fan_in)
             nn.init.uniform_(self._super_bias, -bound, bound)
 
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        if not spaces.is_determined(self._in_features):
+            expected_input_dim = self.abstract_child["_in_features"].value
+        else:
+            expected_input_dim = spaces.get_determined_value(self._in_features)
+        if input.size(-1) != expected_input_dim:
+            raise ValueError(
+                "Expect the input dim of {:} instead of {:}".format(
+                    expected_input_dim, input.size(-1)
+                )
+            )
+        # create the weight matrix
+        if not spaces.is_determined(self._out_features):
+            out_dim = self.abstract_child["_out_features"].value
+        else:
+            out_dim = spaces.get_determined_value(self._out_features)
+        candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
+        # create the bias matrix
+        if not spaces.is_determined(self._bias):
+            if self.abstract_child["_bias"].value:
+                candidate_bias = self._super_bias[:out_dim]
+            else:
+                candidate_bias = None
+        else:
+            if spaces.get_determined_value(self._bias):
+                candidate_bias = self._super_bias[:out_dim]
+            else:
+                candidate_bias = None
+        return F.linear(input, candidate_weight, candidate_bias)
+
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self._super_weight, self._super_bias)
 
@@ -78,8 +111,9 @@ class SuperLinear(SuperModule):
         )
 
 
-class SuperMLP(nn.Module):
-    # MLP: FC -> Activation -> Drop -> FC -> Drop
+class SuperMLP(SuperModule):
+    """An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
+
     def __init__(
         self,
         in_features,
@@ -88,13 +122,13 @@ class SuperMLP(SuperModule):
         act_layer=nn.GELU,
         drop: Optional[float] = None,
     ):
-        super(MLP, self).__init__()
+        super(SuperMLP, self).__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop or 0)
+        self.drop = nn.Dropout(drop or 0.0)
 
     def forward(self, x):
         x = self.fc1(x)
diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py
index cdfab6c..1e0702c 100644
--- a/lib/xlayers/super_module.py
+++ b/lib/xlayers/super_module.py
@@ -6,11 +6,14 @@
 import abc
 import torch.nn as nn
 from enum import Enum
 
+import spaces
+
 
 class SuperRunMode(Enum):
     """This class defines the enumerations for Super Model Running Mode."""
 
     FullModel = "fullmodel"
+    Candidate = "candidate"
     Default = "fullmodel"
 
@@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module):
     def __init__(self):
         super(SuperModule, self).__init__()
         self._super_run_type = SuperRunMode.Default
+        self._abstract_child = None
 
-    @abc.abstractmethod
+    def set_super_run_type(self, super_run_type):
+        def _reset_super_run(m):
+            if isinstance(m, SuperModule):
+                m._super_run_type = super_run_type
+
+        self.apply(_reset_super_run)
+
+    def apply_candidate(self, abstract_child):
+        if not isinstance(abstract_child, spaces.VirtualNode):
+            raise ValueError(
+                "Invalid abstract child program: {:}".format(abstract_child)
+            )
+        self._abstract_child = abstract_child
+
+    @property
     def abstract_search_space(self):
         raise NotImplementedError
 
@@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module):
     def super_run_type(self):
         return self._super_run_type
 
+    @property
+    def abstract_child(self):
+        return self._abstract_child
+
     @abc.abstractmethod
     def forward_raw(self, *inputs):
+        """Use the largest candidate for forward. Similar to the original PyTorch model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def forward_candidate(self, *inputs):
         raise NotImplementedError
 
     def forward(self, *inputs):
         if self.super_run_type == SuperRunMode.FullModel:
             return self.forward_raw(*inputs)
+        elif self.super_run_type == SuperRunMode.Candidate:
+            return self.forward_candidate(*inputs)
         else:
             raise ModeError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py
index d6f47c5..2de430e 100644
--- a/tests/test_basic_space.py
+++ b/tests/test_basic_space.py
@@ -41,14 +41,14 @@ class TestBasicSpace(unittest.TestCase):
     def test_continuous(self):
         random.seed(999)
         space = Continuous(0, 1)
-        self.assertGreaterEqual(space.random(), 0)
-        self.assertGreaterEqual(1, space.random())
+        self.assertGreaterEqual(space.random().value, 0)
+        self.assertGreaterEqual(1, space.random().value)
 
         lower, upper = 1.5, 4.6
         space = Continuous(lower, upper, log=False)
         values = []
         for i in range(1000000):
-            x = space.random()
+            x = space.random().value
             self.assertGreaterEqual(x, lower)
             self.assertGreaterEqual(upper, x)
             values.append(x)
@@ -89,7 +89,7 @@ class TestBasicSpace(unittest.TestCase):
             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
             12,
         )
-        print(nested_space)
+        print("\nThe nested search space:\n{:}".format(nested_space))
         for i in range(1, 13):
             self.assertTrue(nested_space.has(i))
 
@@ -102,6 +102,19 @@ class TestAbstractSpace(unittest.TestCase):
     """Test the abstract search spaces."""
 
     def test_continous(self):
+        print("")
         space = Continuous(0, 1)
         self.assertEqual(space, space.abstract())
+        print("The abstract search space for Continuous: {:}".format(space.abstract()))
+
+        space = Categorical(1, 2, 3)
+        self.assertEqual(len(space.abstract()), 3)
         print(space.abstract())
+
+        nested_space = Categorical(
+            Categorical(1, 2, 3),
+            Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
+            12,
+        )
+        abstract_nested_space = nested_space.abstract()
+        print("The abstract nested search space:\n{:}".format(abstract_nested_space))
diff --git a/tests/test_super_model.py b/tests/test_super_model.py
index c117363..7df1f4a 100644
--- a/tests/test_super_model.py
+++ b/tests/test_super_model.py
@@ -25,6 +25,26 @@ class TestSuperLinear(unittest.TestCase):
         out_features = spaces.Categorical(12, 24, 36)
         bias = spaces.Categorical(True, False)
         model = super_core.SuperLinear(10, out_features, bias=bias)
-        print(model)
+        print("The simple super linear module is:\n{:}".format(model))
+        print(model.super_run_type)
 
-        print(model.abstract_search_space())
+        self.assertTrue(model.bias)
+
+        inputs = torch.rand(32, 10)
+        print("Input shape: {:}".format(inputs.shape))
+        print("Weight shape: {:}".format(model._super_weight.shape))
+        print("Bias shape: {:}".format(model._super_bias.shape))
+        outputs = model(inputs)
+        self.assertEqual(tuple(outputs.shape), (32, 36))
+
+        abstract_space = model.abstract_search_space
+        abstract_child = abstract_space.random()
+        print("The abstract search space:\n{:}".format(abstract_space))
+        print("The abstract child program:\n{:}".format(abstract_child))
+
+        model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.apply_candidate(abstract_child)
+
+        output_shape = (32, abstract_child["_out_features"].value)
+        outputs = model(inputs)
+        self.assertEqual(tuple(outputs.shape), output_shape)
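
Note: taken together, these changes give SuperModule a sample-and-slice workflow. The sketch below is distilled from tests/test_super_model.py and is illustrative only; it assumes the repo's lib/ directory is on PYTHONPATH so that `spaces` and `xlayers` resolve the same way they do in the test suite:

    import torch
    import spaces
    from xlayers import super_core

    # A super-layer whose output width and bias flag are both searchable.
    model = super_core.SuperLinear(
        10, spaces.Categorical(12, 24, 36), bias=spaces.Categorical(True, False)
    )
    inputs = torch.rand(32, 10)

    # Default (FullModel) mode forwards with the largest weights: (32, 36).
    print(model(inputs).shape)

    # Sample a child program from the abstract search space, then switch to
    # the new Candidate mode so forward() routes to forward_candidate(),
    # which slices the super-weight down to the sampled dimensions.
    abstract_child = model.abstract_search_space.random()
    model.set_super_run_type(super_core.SuperRunMode.Candidate)
    model.apply_candidate(abstract_child)
    print(model(inputs).shape)  # (32, 12), (32, 24), or (32, 36)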