diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py index 9bf707b..cd4ad63 100644 --- a/lib/spaces/basic_space.py +++ b/lib/spaces/basic_space.py @@ -9,7 +9,7 @@ import random import numpy as np from collections import OrderedDict -from typing import Optional +from typing import Optional, Text __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] @@ -22,29 +22,37 @@ class Space(metaclass=abc.ABCMeta): All search space must inherit from this basic class. """ + @abc.abstractproperty + def xrepr(self, indent=0) -> Text: + raise NotImplementedError + + def __repr__(self) -> Text: + return self.xrepr() + + @abc.abstractproperty + def abstract(self) -> "Space": + raise NotImplementedError + @abc.abstractmethod def random(self, recursion=True): raise NotImplementedError @abc.abstractproperty - def determined(self): + def determined(self) -> bool: raise NotImplementedError - @abc.abstractproperty - def xrepr(self, indent=0): - raise NotImplementedError - - def __repr__(self): - return self.xrepr() - @abc.abstractmethod - def has(self, x): + def has(self, x) -> bool: """Check whether x is in this search space.""" assert not isinstance( x, Space ), "The input value itself can not be a search space." - def copy(self): + @abc.abstractmethod + def __eq__(self, other): + raise NotImplementedError + + def copy(self) -> "Space": return copy.deepcopy(self) @@ -59,33 +67,56 @@ class VirtualNode(Space): self._value = value self._attributes = OrderedDict() - def has(self, x): - for key, value in self._attributes.items(): - if value.has(x): - return True - return False - def append(self, key, value): if not isinstance(value, Space): raise ValueError("Invalid type of value: {:}".format(type(value))) self._attributes[key] = value - def determined(self): - for key, value in self._attributes.items(): - if not value.determined(x): - return False - return True - - def random(self, recursion=True): - raise NotImplementedError - - def xrepr(self, indent=0): + def xrepr(self, indent=0) -> Text: strs = [self.__class__.__name__ + "("] for key, value in self._attributes.items(): strs.append(value.xrepr(indent + 2) + ",") strs.append(")") return "\n".join(strs) + def abstract(self) -> Space: + node = VirtualNode(id(self)) + for key, value in self._attributes.items(): + if not value.determined: + node.append(value.abstract()) + return node + + def random(self, recursion=True): + raise NotImplementedError + + def has(self, x) -> bool: + for key, value in self._attributes.items(): + if value.has(x): + return True + return False + + def __contains__(self, key): + return key in self._attributes + + def __getitem__(self, key): + return self._attributes[key] + + def determined(self) -> bool: + for key, value in self._attributes.items(): + if not value.determined(x): + return False + return True + + def __eq__(self, other): + if not isinstance(other, VirtualNode): + return False + for key, value in self._attributes.items(): + if not key in other: + return False + if value != other[key]: + return False + return True + class Categorical(Space): """A space contains the categorical values. @@ -104,6 +135,10 @@ class Categorical(Space): def candidates(self): return self._candidates + @property + def default(self): + return self._default + @property def determined(self): if len(self) == 1: @@ -120,6 +155,25 @@ class Categorical(Space): def __len__(self): return len(self._candidates) + def abstract(self) -> Space: + if self.determined: + return VirtualNode(id(self), self) + # [TO-IMPROVE] + data = [] + for candidate in self.candidates: + if isinstance(candidate, Space): + data.append(candidate.abstract()) + else: + data.append(VirtualNode(id(candidate), candidate)) + return Categorical(*data, self._default) + + def random(self, recursion=True): + sample = random.choice(self._candidates) + if recursion and isinstance(sample, Space): + return sample.random(recursion) + else: + return sample + def xrepr(self, indent=0): xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( name=self.__class__.__name__, cs=self._candidates, default=self._default @@ -135,12 +189,17 @@ class Categorical(Space): return True return False - def random(self, recursion=True): - sample = random.choice(self._candidates) - if recursion and isinstance(sample, Space): - return sample.random(recursion) - else: - return sample + def __eq__(self, other): + if not isinstance(other, Categorical): + return False + if len(self) != len(other): + return False + if self.default != other.default: + return False + for index in range(len(self)): + if self.__getitem__[index] != other[index]: + return False + return True class Integer(Categorical): @@ -213,8 +272,23 @@ class Continuous(Space): return self._default @property - def determined(self): - return abs(self.lower - self.upper) <= self._eps + def use_log(self): + return self._log_scale + + @property + def eps(self): + return self._eps + + def abstract(self) -> Space: + return self.copy() + + def random(self, recursion=True): + del recursion + if self._log_scale: + sample = random.uniform(math.log(self._lower), math.log(self._upper)) + return math.exp(sample) + else: + return random.uniform(self._lower, self._upper) def xrepr(self, indent=0): xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( @@ -243,10 +317,20 @@ class Continuous(Space): converted_x, success = self.convert(x) return success and self.lower <= converted_x <= self.upper - def random(self, recursion=True): - del recursion - if self._log_scale: - sample = random.uniform(math.log(self._lower), math.log(self._upper)) - return math.exp(sample) + @property + def determined(self): + return abs(self.lower - self.upper) <= self._eps + + def __eq__(self, other): + if not isinstance(other, Continuous): + return False + if self is other: + return True else: - return random.uniform(self._lower, self._upper) + return ( + self.lower == other.lower + and self.upper == other.upper + and self.default == other.default + and self.use_log == other.use_log + and self.eps == other.eps + ) diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index 983514c..cf1f7c1 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -96,3 +96,14 @@ class TestBasicSpace(unittest.TestCase): # Test Simple Op self.assertTrue(is_determined(1)) self.assertFalse(is_determined(nested_space)) + + +class TestAbstractSpace(unittest.TestCase): + """Test the abstract search spaces.""" + + def test_continous(self): + space = Continuous(0, 1) + self.assertEqual(space, space.abstract()) + print(space.abstract()) + +