xautodl/lib/spaces/basic_space.py
2021-03-19 23:57:23 +08:00

435 lines
13 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
import abc
import math
import copy
import random
import numpy as np
from collections import OrderedDict
from typing import Optional, Text
# Default ``eps`` of Continuous: bounds closer than this make the space determined.
_EPS = 1e-9
# Names exported by ``from ... import *`` (VirtualNode is not listed).
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
class Space(metaclass=abc.ABCMeta):
    """Basic search space describing the set of possible candidate values for a hyperparameter.

    All search spaces must inherit from this class and implement its abstract
    interface.  ``xrepr`` and ``abstract`` are methods (subclasses call them
    with arguments), while ``determined`` is a read-only property.
    """

    def __init__(self):
        # Caches filled by ``random`` / ``abstract`` and returned again when
        # those are called with ``reuse_last=True`` (avoids duplicate samples).
        self._last_sample = None
        self._last_abstract = None

    @abc.abstractmethod
    def xrepr(self, depth=0) -> Text:
        """Return a string representation for nesting level ``depth``."""
        raise NotImplementedError

    def __repr__(self) -> Text:
        return self.xrepr()

    @abc.abstractmethod
    def abstract(self, reuse_last=False) -> "Space":
        """Return an abstract view of this space; reuse the cached one if ``reuse_last``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random(self, recursion=True, reuse_last=False):
        """Draw a sample; reuse the cached one if ``reuse_last``."""
        raise NotImplementedError

    @abc.abstractmethod
    def clean_last_sample(self):
        """Forget the sample cached by ``random``."""
        raise NotImplementedError

    @abc.abstractmethod
    def clean_last_abstract(self):
        """Forget the abstract cached by ``abstract``."""
        raise NotImplementedError

    def clean_last(self):
        """Clear both the cached sample and the cached abstract."""
        self.clean_last_sample()
        self.clean_last_abstract()

    @property
    @abc.abstractmethod
    def determined(self) -> bool:
        """True when the space leaves nothing to choose."""
        raise NotImplementedError

    @abc.abstractmethod
    def has(self, x) -> bool:
        """Check whether x is in this search space.

        Subclasses call ``super().has(x)`` first to reject space-valued inputs.

        Raises:
            AssertionError: if ``x`` is itself a :class:`Space`.
        """
        assert not isinstance(
            x, Space
        ), "The input value itself can not be a search space."

    @abc.abstractmethod
    def __eq__(self, other):
        raise NotImplementedError

    def copy(self) -> "Space":
        """Return a deep copy of this space."""
        return copy.deepcopy(self)
class VirtualNode(Space):
    """A node of the tree used to represent a (nested) search space.

    Each node stores an optional ``id``, an optional ``value`` and an ordered
    mapping from string keys to child :class:`Space` objects.  Samples and
    abstracts in this module are expressed as such trees; for example a
    sampled leaf value ``v`` is represented as ``VirtualNode(None, v)``.
    """

    def __init__(self, id=None, value=None):
        super(VirtualNode, self).__init__()
        self._id = id
        self._value = value
        self._attributes = OrderedDict()

    @property
    def value(self):
        return self._value

    def append(self, key, value):
        """Attach the child space ``value`` under the string ``key``.

        Raises:
            TypeError: if ``key`` is not a string.
            ValueError: if ``value`` is not a :class:`Space`.
        """
        if not isinstance(key, str):
            raise TypeError(
                "Only accept string as a key instead of {:}".format(type(key))
            )
        if not isinstance(value, Space):
            raise ValueError("Invalid type of value: {:}".format(type(value)))
        # Determined children are accepted as well (a check rejecting them
        # used to exist and was disabled).
        self._attributes[key] = value

    def xrepr(self, depth=0) -> Text:
        """Render the node and its children, indenting children by ``depth + 1``."""
        strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
        for key, value in self._attributes.items():
            strs.append(key + " = " + value.xrepr(depth + 1))
        strs.append(")")
        if len(strs) == 2:
            # No children: keep everything on one line.
            return "".join(strs)
        space = " "
        xstrs = (
            [strs[0]]
            + [space * (depth + 1) + x for x in strs[1:-1]]
            + [space * depth + strs[-1]]
        )
        return ",\n".join(xstrs)

    def abstract(self, reuse_last=False) -> Space:
        """Return a node holding the abstracts of the undetermined children.

        Determined children are dropped.  The result is cached and returned
        again when ``reuse_last`` is true.
        """
        if reuse_last and self._last_abstract is not None:
            return self._last_abstract
        node = VirtualNode(id(self))
        for key, value in self._attributes.items():
            if not value.determined:
                # Keep the child's key so the abstract tree mirrors this one.
                node.append(key, value.abstract(reuse_last))
        self._last_abstract = node
        return self._last_abstract

    def random(self, recursion=True, reuse_last=False):
        """Sample every child and return them as a new node with the same value."""
        if reuse_last and self._last_sample is not None:
            return self._last_sample
        node = VirtualNode(None, self._value)
        for key, value in self._attributes.items():
            node.append(key, value.random(recursion, reuse_last))
        self._last_sample = node  # record the last sample
        return node

    def clean_last_sample(self):
        self._last_sample = None
        for value in self._attributes.values():
            value.clean_last_sample()

    def clean_last_abstract(self):
        self._last_abstract = None
        for value in self._attributes.values():
            value.clean_last_abstract()

    def has(self, x) -> bool:
        """True if any child space contains ``x``."""
        return any(value.has(x) for value in self._attributes.values())

    def __contains__(self, key):
        return key in self._attributes

    def __getitem__(self, key):
        return self._attributes[key]

    @property
    def determined(self) -> bool:
        """True when every child space is determined."""
        return all(value.determined for value in self._attributes.values())

    def __eq__(self, other):
        """Nodes are equal when their values and their children are equal.

        The ``id`` is ignored.  The key sets must match exactly so that the
        comparison is symmetric.
        """
        if not isinstance(other, VirtualNode):
            return False
        if self is other:
            return True
        if self._value != other._value:
            return False
        if len(self._attributes) != len(other._attributes):
            return False
        for key, value in self._attributes.items():
            if key not in other:
                return False
            if value != other[key]:
                return False
        return True
class Categorical(Space):
    """A space containing categorical candidate values.

    It can be a nested space: a candidate may itself be a search space, in
    which case ``random`` (with ``recursion=True``) samples from it as well.
    """

    def __init__(self, *data, default: Optional[int] = None):
        super(Categorical, self).__init__()
        self._candidates = [*data]
        # ``default`` is an index into the candidates (or None), not a value.
        self._default = default
        assert self._default is None or 0 <= self._default < len(
            self._candidates
        ), "default >= {:}".format(len(self._candidates))
        assert len(self) > 0, "Please provide at least one candidate"

    @property
    def candidates(self):
        return self._candidates

    @property
    def default(self):
        return self._default

    @property
    def determined(self):
        """True only for one candidate that is a plain value or a determined space."""
        if len(self) == 1:
            return (
                not isinstance(self._candidates[0], Space)
                or self._candidates[0].determined
            )
        return False

    def __getitem__(self, index):
        return self._candidates[index]

    def __len__(self):
        return len(self._candidates)

    def clean_last_sample(self):
        self._last_sample = None
        for candidate in self._candidates:
            if isinstance(candidate, Space):
                candidate.clean_last_sample()

    def clean_last_abstract(self):
        self._last_abstract = None
        for candidate in self._candidates:
            if isinstance(candidate, Space):
                candidate.clean_last_abstract()

    def abstract(self, reuse_last=False) -> Space:
        """Return the abstract of this space.

        A determined space becomes ``VirtualNode(id(self), self)``.  Otherwise
        a new Categorical is built whose candidates are the abstracts of the
        nested spaces, or plain values wrapped as ``VirtualNode``.  The result
        is cached and reused when ``reuse_last`` is true.
        """
        if reuse_last and self._last_abstract is not None:
            return self._last_abstract
        if self.determined:
            result = VirtualNode(id(self), self)
        else:
            data = []
            for candidate in self.candidates:
                if isinstance(candidate, Space):
                    # Forward ``reuse_last`` so nested spaces honour their
                    # cache too, just as ``random`` forwards it.
                    data.append(candidate.abstract(reuse_last))
                else:
                    data.append(VirtualNode(id(candidate), candidate))
            result = Categorical(*data, default=self._default)
        self._last_abstract = result
        return self._last_abstract

    def random(self, recursion=True, reuse_last=False):
        """Pick a candidate uniformly at random and return it as a VirtualNode.

        With ``recursion``, a candidate that is a space is sampled in turn.  A
        VirtualNode result is copied (presumably so it is not shared with the
        nested space's cached sample); other values are wrapped in
        ``VirtualNode(None, value)``.
        """
        if reuse_last and self._last_sample is not None:
            return self._last_sample
        sample = random.choice(self._candidates)
        if recursion and isinstance(sample, Space):
            sample = sample.random(recursion, reuse_last)
        if isinstance(sample, VirtualNode):
            sample = sample.copy()
        else:
            sample = VirtualNode(None, sample)
        self._last_sample = sample
        return self._last_sample

    def xrepr(self, depth=0):
        del depth
        xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )
        return xrepr

    def has(self, x):
        """True if ``x`` equals a candidate or is contained in a nested space."""
        super().has(x)
        for candidate in self._candidates:
            if isinstance(candidate, Space) and candidate.has(x):
                return True
            elif candidate == x:
                return True
        return False

    def __eq__(self, other):
        """Equal when the other is a Categorical with equal default and candidates."""
        if not isinstance(other, Categorical):
            return False
        if len(self) != len(other):
            return False
        if self.default != other.default:
            return False
        for index in range(len(self)):
            if self[index] != other[index]:
                return False
        return True
class Integer(Categorical):
    """A space containing the integers in the closed range [lower, upper]."""

    def __init__(self, lower: int, upper: int, default: Optional[int] = None):
        """
        Args:
            lower: smallest candidate (inclusive).
            upper: largest candidate (inclusive).
            default: optional default *value* in [lower, upper].  It is
                converted to the candidate index that Categorical expects.

        Raises:
            ValueError: if ``lower`` or ``upper`` is not an int, or if
                ``default`` is given but is not one of the candidates.
        """
        if not isinstance(lower, int) or not isinstance(upper, int):
            raise ValueError(
                "The lower [{:}] and upper [{:}] must be int.".format(lower, upper)
            )
        data = list(range(lower, upper + 1))
        self._raw_lower = lower
        self._raw_upper = upper
        self._raw_default = default
        if default is not None:
            if default < lower or default > upper:
                raise ValueError(
                    "The default value [{:}] is out of range.".format(default)
                )
            # Categorical stores the default as an index into its candidates.
            # Without a default there is nothing to convert (``None`` is not
            # in ``data``).
            default = data.index(default)
        super(Integer, self).__init__(*data, default=default)

    def xrepr(self, depth=0):
        del depth
        xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
            name=self.__class__.__name__,
            lower=self._raw_lower,
            upper=self._raw_upper,
            default=self._raw_default,
        )
        return xrepr
# NumPy scalar types that Continuous.convert accepts as numbers.
np_float_types = tuple(getattr(np, "float{:}".format(bits)) for bits in (16, 32, 64))
# Unsigned/signed pairs per width: uint8, int8, uint16, int16, ..., int64.
np_int_types = tuple(
    getattr(np, prefix + "int{:}".format(bits))
    for bits in (8, 16, 32, 64)
    for prefix in ("u", "")
)
class Continuous(Space):
    """A space containing the continuous values between ``lower`` and ``upper``.

    Args:
        lower: lower bound of the range.
        upper: upper bound of the range.
        default: optional default value.  It is stored and reported but not
            validated here.
        log: if True, ``random`` samples uniformly in log space.  This needs
            positive bounds, because ``math.log`` is applied to them.
        eps: bounds closer than this make the space ``determined``.
    """

    def __init__(
        self,
        lower: float,
        upper: float,
        default: Optional[float] = None,
        log: bool = False,
        eps: float = _EPS,
    ):
        super(Continuous, self).__init__()
        self._lower = lower
        self._upper = upper
        self._default = default
        self._log_scale = log
        self._eps = eps

    @property
    def lower(self):
        return self._lower

    @property
    def upper(self):
        return self._upper

    @property
    def default(self):
        return self._default

    @property
    def use_log(self):
        return self._log_scale

    @property
    def eps(self):
        return self._eps

    def abstract(self, reuse_last=False) -> Space:
        """Return a copy of this space as its abstract; reuse the cached one if ``reuse_last``."""
        if reuse_last and self._last_abstract is not None:
            return self._last_abstract
        result = self.copy()
        # ``copy`` is a deepcopy, so it would also duplicate this object's
        # caches, including the previous abstract.  Repeated calls would then
        # build an ever-growing chain of nested copies.  Start the copy clean.
        result.clean_last()
        self._last_abstract = result
        return self._last_abstract

    def random(self, recursion=True, reuse_last=False):
        """Sample a value and return it as ``VirtualNode(None, value)``.

        ``recursion`` is accepted for interface compatibility and ignored.
        """
        del recursion
        if reuse_last and self._last_sample is not None:
            return self._last_sample
        if self._log_scale:
            log_sample = random.uniform(math.log(self._lower), math.log(self._upper))
            # exp(log(x)) is not always exactly x, so the round trip can land
            # one ulp outside the bounds.  Clamp it so ``has(sample)`` holds.
            low = min(self._lower, self._upper)
            high = max(self._lower, self._upper)
            sample = float(min(max(math.exp(log_sample), low), high))
        else:
            sample = random.uniform(self._lower, self._upper)
        self._last_sample = VirtualNode(None, sample)
        return self._last_sample

    def xrepr(self, depth=0):
        del depth
        xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
            name=self.__class__.__name__,
            lower=self._lower,
            upper=self._upper,
            default=self._default,
            log=self._log_scale,
        )
        return xrepr

    def convert(self, x):
        """Return ``(float(x), True)`` for supported numeric inputs, else ``(None, False)``."""
        if isinstance(x, np_float_types) and x.size == 1:
            return float(x), True
        elif isinstance(x, np_int_types) and x.size == 1:
            return float(x), True
        elif isinstance(x, (int, float)):
            return float(x), True
        else:
            return None, False

    def has(self, x):
        """True if ``x`` converts to a number within ``[lower, upper]``."""
        super().has(x)
        converted_x, success = self.convert(x)
        return success and self.lower <= converted_x <= self.upper

    @property
    def determined(self):
        return abs(self.lower - self.upper) <= self._eps

    def clean_last_sample(self):
        self._last_sample = None

    def clean_last_abstract(self):
        self._last_abstract = None

    def __eq__(self, other):
        """Equal when bounds, default, log flag and eps all match."""
        if not isinstance(other, Continuous):
            return False
        if self is other:
            return True
        return (
            self.lower == other.lower
            and self.upper == other.upper
            and self.default == other.default
            and self.use_log == other.use_log
            and self.eps == other.eps
        )