Add __eq__
This commit is contained in:
parent
ae7136645f
commit
b3eed4ca5a
@ -9,7 +9,7 @@ import random
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Text
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
|
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
|
||||||
@ -22,29 +22,37 @@ class Space(metaclass=abc.ABCMeta):
|
|||||||
All search space must inherit from this basic class.
|
All search space must inherit from this basic class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def xrepr(self, indent=0) -> Text:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self) -> Text:
|
||||||
|
return self.xrepr()
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def abstract(self) -> "Space":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def random(self, recursion=True):
|
def random(self, recursion=True):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@abc.abstractproperty
|
||||||
def determined(self):
|
def determined(self) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def xrepr(self, indent=0):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.xrepr()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def has(self, x):
|
def has(self, x) -> bool:
|
||||||
"""Check whether x is in this search space."""
|
"""Check whether x is in this search space."""
|
||||||
assert not isinstance(
|
assert not isinstance(
|
||||||
x, Space
|
x, Space
|
||||||
), "The input value itself can not be a search space."
|
), "The input value itself can not be a search space."
|
||||||
|
|
||||||
def copy(self):
|
@abc.abstractmethod
|
||||||
|
def __eq__(self, other):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def copy(self) -> "Space":
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
|
||||||
@ -59,33 +67,56 @@ class VirtualNode(Space):
|
|||||||
self._value = value
|
self._value = value
|
||||||
self._attributes = OrderedDict()
|
self._attributes = OrderedDict()
|
||||||
|
|
||||||
def has(self, x):
|
|
||||||
for key, value in self._attributes.items():
|
|
||||||
if value.has(x):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def append(self, key, value):
|
def append(self, key, value):
|
||||||
if not isinstance(value, Space):
|
if not isinstance(value, Space):
|
||||||
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||||
self._attributes[key] = value
|
self._attributes[key] = value
|
||||||
|
|
||||||
def determined(self):
|
def xrepr(self, indent=0) -> Text:
|
||||||
for key, value in self._attributes.items():
|
|
||||||
if not value.determined(x):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def random(self, recursion=True):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def xrepr(self, indent=0):
|
|
||||||
strs = [self.__class__.__name__ + "("]
|
strs = [self.__class__.__name__ + "("]
|
||||||
for key, value in self._attributes.items():
|
for key, value in self._attributes.items():
|
||||||
strs.append(value.xrepr(indent + 2) + ",")
|
strs.append(value.xrepr(indent + 2) + ",")
|
||||||
strs.append(")")
|
strs.append(")")
|
||||||
return "\n".join(strs)
|
return "\n".join(strs)
|
||||||
|
|
||||||
|
def abstract(self) -> Space:
|
||||||
|
node = VirtualNode(id(self))
|
||||||
|
for key, value in self._attributes.items():
|
||||||
|
if not value.determined:
|
||||||
|
node.append(value.abstract())
|
||||||
|
return node
|
||||||
|
|
||||||
|
def random(self, recursion=True):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def has(self, x) -> bool:
|
||||||
|
for key, value in self._attributes.items():
|
||||||
|
if value.has(x):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return key in self._attributes
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._attributes[key]
|
||||||
|
|
||||||
|
def determined(self) -> bool:
|
||||||
|
for key, value in self._attributes.items():
|
||||||
|
if not value.determined(x):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, VirtualNode):
|
||||||
|
return False
|
||||||
|
for key, value in self._attributes.items():
|
||||||
|
if not key in other:
|
||||||
|
return False
|
||||||
|
if value != other[key]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Categorical(Space):
|
class Categorical(Space):
|
||||||
"""A space contains the categorical values.
|
"""A space contains the categorical values.
|
||||||
@ -104,6 +135,10 @@ class Categorical(Space):
|
|||||||
def candidates(self):
|
def candidates(self):
|
||||||
return self._candidates
|
return self._candidates
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default(self):
|
||||||
|
return self._default
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def determined(self):
|
def determined(self):
|
||||||
if len(self) == 1:
|
if len(self) == 1:
|
||||||
@ -120,6 +155,25 @@ class Categorical(Space):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._candidates)
|
return len(self._candidates)
|
||||||
|
|
||||||
|
def abstract(self) -> Space:
|
||||||
|
if self.determined:
|
||||||
|
return VirtualNode(id(self), self)
|
||||||
|
# [TO-IMPROVE]
|
||||||
|
data = []
|
||||||
|
for candidate in self.candidates:
|
||||||
|
if isinstance(candidate, Space):
|
||||||
|
data.append(candidate.abstract())
|
||||||
|
else:
|
||||||
|
data.append(VirtualNode(id(candidate), candidate))
|
||||||
|
return Categorical(*data, self._default)
|
||||||
|
|
||||||
|
def random(self, recursion=True):
|
||||||
|
sample = random.choice(self._candidates)
|
||||||
|
if recursion and isinstance(sample, Space):
|
||||||
|
return sample.random(recursion)
|
||||||
|
else:
|
||||||
|
return sample
|
||||||
|
|
||||||
def xrepr(self, indent=0):
|
def xrepr(self, indent=0):
|
||||||
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||||
@ -135,12 +189,17 @@ class Categorical(Space):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def random(self, recursion=True):
|
def __eq__(self, other):
|
||||||
sample = random.choice(self._candidates)
|
if not isinstance(other, Categorical):
|
||||||
if recursion and isinstance(sample, Space):
|
return False
|
||||||
return sample.random(recursion)
|
if len(self) != len(other):
|
||||||
else:
|
return False
|
||||||
return sample
|
if self.default != other.default:
|
||||||
|
return False
|
||||||
|
for index in range(len(self)):
|
||||||
|
if self.__getitem__[index] != other[index]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Integer(Categorical):
|
class Integer(Categorical):
|
||||||
@ -213,8 +272,23 @@ class Continuous(Space):
|
|||||||
return self._default
|
return self._default
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def determined(self):
|
def use_log(self):
|
||||||
return abs(self.lower - self.upper) <= self._eps
|
return self._log_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eps(self):
|
||||||
|
return self._eps
|
||||||
|
|
||||||
|
def abstract(self) -> Space:
|
||||||
|
return self.copy()
|
||||||
|
|
||||||
|
def random(self, recursion=True):
|
||||||
|
del recursion
|
||||||
|
if self._log_scale:
|
||||||
|
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||||
|
return math.exp(sample)
|
||||||
|
else:
|
||||||
|
return random.uniform(self._lower, self._upper)
|
||||||
|
|
||||||
def xrepr(self, indent=0):
|
def xrepr(self, indent=0):
|
||||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||||
@ -243,10 +317,20 @@ class Continuous(Space):
|
|||||||
converted_x, success = self.convert(x)
|
converted_x, success = self.convert(x)
|
||||||
return success and self.lower <= converted_x <= self.upper
|
return success and self.lower <= converted_x <= self.upper
|
||||||
|
|
||||||
def random(self, recursion=True):
|
@property
|
||||||
del recursion
|
def determined(self):
|
||||||
if self._log_scale:
|
return abs(self.lower - self.upper) <= self._eps
|
||||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
|
||||||
return math.exp(sample)
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Continuous):
|
||||||
|
return False
|
||||||
|
if self is other:
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
return random.uniform(self._lower, self._upper)
|
return (
|
||||||
|
self.lower == other.lower
|
||||||
|
and self.upper == other.upper
|
||||||
|
and self.default == other.default
|
||||||
|
and self.use_log == other.use_log
|
||||||
|
and self.eps == other.eps
|
||||||
|
)
|
||||||
|
@ -96,3 +96,14 @@ class TestBasicSpace(unittest.TestCase):
|
|||||||
# Test Simple Op
|
# Test Simple Op
|
||||||
self.assertTrue(is_determined(1))
|
self.assertTrue(is_determined(1))
|
||||||
self.assertFalse(is_determined(nested_space))
|
self.assertFalse(is_determined(nested_space))
|
||||||
|
|
||||||
|
|
||||||
|
class TestAbstractSpace(unittest.TestCase):
|
||||||
|
"""Test the abstract search spaces."""
|
||||||
|
|
||||||
|
def test_continous(self):
|
||||||
|
space = Continuous(0, 1)
|
||||||
|
self.assertEqual(space, space.abstract())
|
||||||
|
print(space.abstract())
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user