Add __eq__
This commit is contained in:
		| @@ -9,7 +9,7 @@ import random | |||||||
| import numpy as np | import numpy as np | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
|  |  | ||||||
| from typing import Optional | from typing import Optional, Text | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] | __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] | ||||||
| @@ -22,29 +22,37 @@ class Space(metaclass=abc.ABCMeta): | |||||||
|     All search space must inherit from this basic class. |     All search space must inherit from this basic class. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     @abc.abstractproperty | ||||||
|  |     def xrepr(self, indent=0) -> Text: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self) -> Text: | ||||||
|  |         return self.xrepr() | ||||||
|  |  | ||||||
|  |     @abc.abstractproperty | ||||||
|  |     def abstract(self) -> "Space": | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractproperty |     @abc.abstractproperty | ||||||
|     def determined(self): |     def determined(self) -> bool: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractproperty |  | ||||||
|     def xrepr(self, indent=0): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return self.xrepr() |  | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def has(self, x): |     def has(self, x) -> bool: | ||||||
|         """Check whether x is in this search space.""" |         """Check whether x is in this search space.""" | ||||||
|         assert not isinstance( |         assert not isinstance( | ||||||
|             x, Space |             x, Space | ||||||
|         ), "The input value itself can not be a search space." |         ), "The input value itself can not be a search space." | ||||||
|  |  | ||||||
|     def copy(self): |     @abc.abstractmethod | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def copy(self) -> "Space": | ||||||
|         return copy.deepcopy(self) |         return copy.deepcopy(self) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -59,33 +67,56 @@ class VirtualNode(Space): | |||||||
|         self._value = value |         self._value = value | ||||||
|         self._attributes = OrderedDict() |         self._attributes = OrderedDict() | ||||||
|  |  | ||||||
|     def has(self, x): |  | ||||||
|         for key, value in self._attributes.items(): |  | ||||||
|             if value.has(x): |  | ||||||
|                 return True |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     def append(self, key, value): |     def append(self, key, value): | ||||||
|         if not isinstance(value, Space): |         if not isinstance(value, Space): | ||||||
|             raise ValueError("Invalid type of value: {:}".format(type(value))) |             raise ValueError("Invalid type of value: {:}".format(type(value))) | ||||||
|         self._attributes[key] = value |         self._attributes[key] = value | ||||||
|  |  | ||||||
|     def determined(self): |     def xrepr(self, indent=0) -> Text: | ||||||
|         for key, value in self._attributes.items(): |  | ||||||
|             if not value.determined(x): |  | ||||||
|                 return False |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |  | ||||||
|         strs = [self.__class__.__name__ + "("] |         strs = [self.__class__.__name__ + "("] | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             strs.append(value.xrepr(indent + 2) + ",") |             strs.append(value.xrepr(indent + 2) + ",") | ||||||
|         strs.append(")") |         strs.append(")") | ||||||
|         return "\n".join(strs) |         return "\n".join(strs) | ||||||
|  |  | ||||||
|  |     def abstract(self) -> Space: | ||||||
|  |         node = VirtualNode(id(self)) | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             if not value.determined: | ||||||
|  |                 node.append(value.abstract()) | ||||||
|  |         return node | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def has(self, x) -> bool: | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             if value.has(x): | ||||||
|  |                 return True | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def __contains__(self, key): | ||||||
|  |         return key in self._attributes | ||||||
|  |  | ||||||
|  |     def __getitem__(self, key): | ||||||
|  |         return self._attributes[key] | ||||||
|  |  | ||||||
|  |     def determined(self) -> bool: | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             if not value.determined(x): | ||||||
|  |                 return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         if not isinstance(other, VirtualNode): | ||||||
|  |             return False | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             if not key in other: | ||||||
|  |                 return False | ||||||
|  |             if value != other[key]: | ||||||
|  |                 return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
| class Categorical(Space): | class Categorical(Space): | ||||||
|     """A space contains the categorical values. |     """A space contains the categorical values. | ||||||
| @@ -104,6 +135,10 @@ class Categorical(Space): | |||||||
|     def candidates(self): |     def candidates(self): | ||||||
|         return self._candidates |         return self._candidates | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def default(self): | ||||||
|  |         return self._default | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def determined(self): |     def determined(self): | ||||||
|         if len(self) == 1: |         if len(self) == 1: | ||||||
| @@ -120,6 +155,25 @@ class Categorical(Space): | |||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._candidates) |         return len(self._candidates) | ||||||
|  |  | ||||||
|  |     def abstract(self) -> Space: | ||||||
|  |         if self.determined: | ||||||
|  |             return VirtualNode(id(self), self) | ||||||
|  |         # [TO-IMPROVE] | ||||||
|  |         data = [] | ||||||
|  |         for candidate in self.candidates: | ||||||
|  |             if isinstance(candidate, Space): | ||||||
|  |                 data.append(candidate.abstract()) | ||||||
|  |             else: | ||||||
|  |                 data.append(VirtualNode(id(candidate), candidate)) | ||||||
|  |         return Categorical(*data, self._default) | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True): | ||||||
|  |         sample = random.choice(self._candidates) | ||||||
|  |         if recursion and isinstance(sample, Space): | ||||||
|  |             return sample.random(recursion) | ||||||
|  |         else: | ||||||
|  |             return sample | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |     def xrepr(self, indent=0): | ||||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( |         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default |             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||||
| @@ -135,12 +189,17 @@ class Categorical(Space): | |||||||
|                 return True |                 return True | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def __eq__(self, other): | ||||||
|         sample = random.choice(self._candidates) |         if not isinstance(other, Categorical): | ||||||
|         if recursion and isinstance(sample, Space): |             return False | ||||||
|             return sample.random(recursion) |         if len(self) != len(other): | ||||||
|         else: |             return False | ||||||
|             return sample |         if self.default != other.default: | ||||||
|  |             return False | ||||||
|  |         for index in range(len(self)): | ||||||
|  |             if self.__getitem__[index] != other[index]: | ||||||
|  |                 return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
| class Integer(Categorical): | class Integer(Categorical): | ||||||
| @@ -213,8 +272,23 @@ class Continuous(Space): | |||||||
|         return self._default |         return self._default | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def determined(self): |     def use_log(self): | ||||||
|         return abs(self.lower - self.upper) <= self._eps |         return self._log_scale | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def eps(self): | ||||||
|  |         return self._eps | ||||||
|  |  | ||||||
|  |     def abstract(self) -> Space: | ||||||
|  |         return self.copy() | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True): | ||||||
|  |         del recursion | ||||||
|  |         if self._log_scale: | ||||||
|  |             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||||
|  |             return math.exp(sample) | ||||||
|  |         else: | ||||||
|  |             return random.uniform(self._lower, self._upper) | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |     def xrepr(self, indent=0): | ||||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||||
| @@ -243,10 +317,20 @@ class Continuous(Space): | |||||||
|         converted_x, success = self.convert(x) |         converted_x, success = self.convert(x) | ||||||
|         return success and self.lower <= converted_x <= self.upper |         return success and self.lower <= converted_x <= self.upper | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     @property | ||||||
|         del recursion |     def determined(self): | ||||||
|         if self._log_scale: |         return abs(self.lower - self.upper) <= self._eps | ||||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) |  | ||||||
|             return math.exp(sample) |     def __eq__(self, other): | ||||||
|  |         if not isinstance(other, Continuous): | ||||||
|  |             return False | ||||||
|  |         if self is other: | ||||||
|  |             return True | ||||||
|         else: |         else: | ||||||
|             return random.uniform(self._lower, self._upper) |             return ( | ||||||
|  |                 self.lower == other.lower | ||||||
|  |                 and self.upper == other.upper | ||||||
|  |                 and self.default == other.default | ||||||
|  |                 and self.use_log == other.use_log | ||||||
|  |                 and self.eps == other.eps | ||||||
|  |             ) | ||||||
|   | |||||||
| @@ -96,3 +96,14 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|         # Test Simple Op |         # Test Simple Op | ||||||
|         self.assertTrue(is_determined(1)) |         self.assertTrue(is_determined(1)) | ||||||
|         self.assertFalse(is_determined(nested_space)) |         self.assertFalse(is_determined(nested_space)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestAbstractSpace(unittest.TestCase): | ||||||
|  |     """Test the abstract search spaces.""" | ||||||
|  |  | ||||||
|  |     def test_continous(self): | ||||||
|  |       space = Continuous(0, 1) | ||||||
|  |       self.assertEqual(space, space.abstract()) | ||||||
|  |       print(space.abstract()) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user