##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # ##################################################### import abc import math import random from typing import Optional class Space(metaclass=abc.ABCMeta): @abc.abstractmethod def random(self, recursion=True): raise NotImplementedError @abc.abstractmethod def __repr__(self): raise NotImplementedError class Categorical(Space): def __init__(self, *data, default: Optional[int] = None): self._candidates = [*data] self._default = default assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format( len(self._candidates) ) def __getitem__(self, index): return self._candidates[index] def __len__(self): return len(self._candidates) def __repr__(self): return "{name:}(candidates={cs:}, default_index={default:})".format( name=self.__class__.__name__, cs=self._candidates, default=self._default ) def random(self, recursion=True): sample = random.choice(self._candidates) if recursion and isinstance(sample, Space): return sample.random(recursion) else: return sample class Continuous(Space): def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False): self._lower = lower self._upper = upper self._default = default self._log_scale = log @property def lower(self): return self._lower @property def upper(self): return self._upper @property def default(self): return self._default def __repr__(self): return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( name=self.__class__.__name__, lower=self._lower, upper=self._upper, default=self._default, log=self._log_scale, ) def random(self, recursion=True): del recursion if self._log_scale: sample = random.uniform(math.log(self._lower), math.log(self._upper)) return math.exp(sample) else: return random.uniform(self._lower, self._upper)