xautodl/lib/spaces/basic_space.py

174 lines
4.5 KiB
Python
Raw Normal View History

2021-03-17 13:48:43 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
import abc
2021-03-18 07:05:29 +01:00
import math
2021-03-18 08:04:14 +01:00
import copy
2021-03-17 13:48:43 +01:00
import random
2021-03-18 08:04:14 +01:00
import numpy as np
2021-03-17 13:48:43 +01:00
2021-03-18 07:05:29 +01:00
from typing import Optional
2021-03-18 08:04:14 +01:00
_EPS = 1e-9
2021-03-17 13:48:43 +01:00
class Space(metaclass=abc.ABCMeta):
    """Basic search space describing the set of possible candidate values for hyperparameter.

    All search space must inherit from this basic class.
    """

    @abc.abstractmethod
    def random(self, recursion=True):
        """Sample one value from this space.

        Args:
            recursion: flag that concrete spaces may use to decide whether a
                sampled nested space should itself be sampled.
        """
        raise NotImplementedError

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # `property` on top of `abstractmethod` is the supported equivalent.
    @property
    @abc.abstractmethod
    def determined(self):
        """Whether this space is determined; the exact meaning is defined by each subclass."""
        raise NotImplementedError

    @abc.abstractmethod
    def __repr__(self):
        """Return a printable description of this space."""
        raise NotImplementedError

    @abc.abstractmethod
    def has(self, x):
        """Check whether x is in this search space.

        The base implementation only validates the argument and returns
        None; subclasses can call it through ``super().has(x)`` before
        running their own membership test.

        Raises:
            AssertionError: if ``x`` is itself a ``Space`` instance.
        """
        assert not isinstance(
            x, Space
        ), "The input value itself can not be a search space."

    def copy(self):
        """Return a deep copy of this space."""
        return copy.deepcopy(self)
2021-03-17 13:48:43 +01:00
class Categorical(Space):
    """A space contains the categorical values.

    It can be a nested space, which means that the candidate in this space can also be a search space.
    """

    def __init__(self, *data, default: Optional[int] = None):
        """Create the space from its candidates.

        Args:
            *data: the candidate values; a candidate may itself be a Space.
            default: optional index into the candidates.

        Raises:
            AssertionError: if no candidate is given, or if ``default`` is
                not a valid index into the candidates.
        """
        self._candidates = [*data]
        self._default = default
        # Check emptiness first: otherwise an empty space with a default
        # would be reported with the less helpful "default >= 0" message.
        assert len(self) > 0, "Please provide at least one candidate"
        assert self._default is None or 0 <= self._default < len(
            self._candidates
        ), "default >= {:}".format(len(self._candidates))

    @property
    def default(self):
        """The default candidate index given at construction, or None."""
        return self._default

    @property
    def determined(self):
        """True only for a single candidate that is a plain value or a determined space."""
        if len(self) != 1:
            return False
        only = self._candidates[0]
        return not isinstance(only, Space) or only.determined

    def __getitem__(self, index):
        """Return the candidate stored at ``index``."""
        return self._candidates[index]

    def __len__(self):
        """Return the number of candidates."""
        return len(self._candidates)

    def __repr__(self):
        return "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )

    def has(self, x):
        """Check whether x equals a candidate or is contained in a nested candidate space.

        Raises:
            AssertionError: if ``x`` is itself a ``Space`` (checked by the base class).
        """
        super().has(x)
        for candidate in self._candidates:
            # A nested space is asked first; plain equality is the fallback.
            if isinstance(candidate, Space) and candidate.has(x):
                return True
            elif candidate == x:
                return True
        return False

    def random(self, recursion=True):
        """Pick one candidate uniformly at random.

        Args:
            recursion: when true and the picked candidate is a Space, return
                a sample drawn from that space instead of the space itself.
        """
        sample = random.choice(self._candidates)
        if recursion and isinstance(sample, Space):
            return sample.random(recursion)
        return sample
2021-03-18 08:04:14 +01:00
# NumPy floating-point scalar types that Continuous.convert accepts.
np_float_types = (np.float16, np.float32, np.float64)

# NumPy integer scalar types that Continuous.convert accepts, listed as
# (unsigned, signed) pairs per bit width.
np_int_types = (
    np.uint8, np.int8,
    np.uint16, np.int16,
    np.uint32, np.int32,
    np.uint64, np.int64,
)
2021-03-18 07:05:29 +01:00
class Continuous(Space):
    """A space contains the continuous values."""

    def __init__(
        self,
        lower: float,
        upper: float,
        default: Optional[float] = None,
        log: bool = False,
        eps: float = _EPS,
    ):
        """Store the description of the interval.

        Args:
            lower: lower bound of the interval.
            upper: upper bound of the interval.
            default: optional default value; stored as given.
            log: if true, ``random`` samples uniformly in log space.
            eps: tolerance used by ``determined`` to compare the two bounds.
        """
        self._lower = lower
        self._upper = upper
        self._default = default
        self._log_scale = log
        self._eps = eps

    @property
    def lower(self):
        """The lower bound."""
        return self._lower

    @property
    def upper(self):
        """The upper bound."""
        return self._upper

    @property
    def default(self):
        """The default value given at construction, or None."""
        return self._default

    @property
    def determined(self):
        """True when the two bounds differ by at most ``eps``."""
        return abs(self.lower - self.upper) <= self._eps

    def __repr__(self):
        return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
            name=self.__class__.__name__,
            lower=self._lower,
            upper=self._upper,
            default=self._default,
            log=self._log_scale,
        )

    def convert(self, x):
        """Try to turn ``x`` into a Python float.

        Returns:
            ``(float(x), True)`` if ``x`` is a Python int/float or one of the
            NumPy scalar types listed in ``np_float_types``/``np_int_types``;
            ``(None, False)`` otherwise.
        """
        is_numpy_scalar = (
            isinstance(x, np_float_types + np_int_types) and x.size == 1
        )
        if is_numpy_scalar or isinstance(x, (int, float)):
            return float(x), True
        return None, False

    def has(self, x):
        """Check whether x is a convertible number within ``[lower, upper]``.

        Raises:
            AssertionError: if ``x`` is itself a ``Space`` (checked by the base class).
        """
        super().has(x)
        converted_x, success = self.convert(x)
        return success and self.lower <= converted_x <= self.upper

    def random(self, recursion=True):
        """Sample a value between the bounds.

        Args:
            recursion: unused; accepted for interface compatibility.

        Raises:
            ValueError: if log-scale sampling is requested while a bound is
                not strictly positive.
        """
        del recursion
        if not self._log_scale:
            return random.uniform(self._lower, self._upper)
        # math.log would raise an opaque "math domain error" here; fail with
        # a message that names the offending bounds instead.
        if self._lower <= 0 or self._upper <= 0:
            raise ValueError(
                "Log-scale sampling requires positive bounds, got "
                "lower={:} and upper={:}".format(self._lower, self._upper)
            )
        sample = random.uniform(math.log(self._lower), math.log(self._upper))
        return math.exp(sample)