84 lines
2.3 KiB
Python
84 lines
2.3 KiB
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
|
#####################################################
|
|
|
|
import abc
|
|
import math
|
|
import random
|
|
|
|
from typing import Optional
|
|
|
|
|
|
class Space(metaclass=abc.ABCMeta):
    """Abstract interface for every search space defined in this module.

    A concrete subclass must provide both ``__repr__`` and ``random``; until it
    does, ``abc`` refuses to instantiate it.
    """

    @abc.abstractmethod
    def __repr__(self):
        """Return a human-readable description of the space."""
        raise NotImplementedError

    @abc.abstractmethod
    def random(self, recursion=True):
        """Draw and return one sample from the space.

        Args:
            recursion: whether a sample that is itself a space should be
                sampled in turn; how (or whether) this is honoured is up to
                the subclass.
        """
        raise NotImplementedError
|
|
|
|
|
|
class Categorical(Space):
    """A search space over a finite, ordered list of candidate values.

    A candidate may itself be a ``Space``; ``random`` can then recurse into the
    chosen sub-space instead of returning it directly.
    """

    def __init__(self, *data, default: Optional[int] = None):
        """Build the space from the given candidates.

        Args:
            *data: candidate values, stored in the order given.
            default: optional index of the default candidate, or ``None``.

        Raises:
            AssertionError: if ``default`` is given and is not a valid index
                in ``[0, len(data))``.
        """
        self._candidates = [*data]
        self._default = default
        # An ``assert`` statement is stripped under ``python -O``, which would let an
        # out-of-range default through silently. Raise explicitly instead, keeping the
        # AssertionError type so callers that already catch it are unaffected.
        if self._default is not None and not (0 <= self._default < len(self._candidates)):
            raise AssertionError(
                "default index must be in [0, {:}), got {:}".format(len(self._candidates), self._default)
            )

    @property
    def candidates(self):
        """The list of candidate values (the internal list itself, not a copy)."""
        return self._candidates

    @property
    def default(self):
        """Index of the default candidate, or ``None`` if no default was given."""
        return self._default

    def __getitem__(self, index):
        """Return the candidate(s) at ``index`` (any index or slice a list accepts)."""
        return self._candidates[index]

    def __len__(self):
        """Return the number of candidates."""
        return len(self._candidates)

    def __repr__(self):
        return "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )

    def random(self, recursion=True):
        """Pick one candidate uniformly at random.

        Args:
            recursion: if true and the picked candidate is itself a ``Space``,
                return a sample drawn from that sub-space (recursively) instead.

        Returns:
            The chosen candidate, or a sample from it when recursing.

        Raises:
            IndexError: if the space has no candidates (raised by ``random.choice``).
        """
        sample = random.choice(self._candidates)
        if recursion and isinstance(sample, Space):
            return sample.random(recursion)
        return sample
|
|
|
|
|
|
class Continuous(Space):
    """A search space over the real interval between ``lower`` and ``upper``.

    Values are sampled uniformly on the raw scale, or log-uniformly (uniformly
    in log-space) when ``log=True``.
    """

    def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
        """Create the space.

        Args:
            lower: lower bound of the interval.
            upper: upper bound of the interval.
            default: optional default value; stored as given and not validated.
            log: if true, sample log-uniformly between the bounds.

        Raises:
            ValueError: if ``log`` is true and either bound is not strictly
                positive, since its logarithm would be undefined.
        """
        # Fail fast: with log=True, ``random`` takes math.log of both bounds, which
        # would otherwise raise an opaque "math domain error" only at sampling time.
        if log and not (lower > 0 and upper > 0):
            raise ValueError(
                "log-scale bounds must be positive, got lower={:} and upper={:}".format(lower, upper)
            )
        self._lower = lower
        self._upper = upper
        self._default = default
        self._log_scale = log

    @property
    def lower(self):
        """Lower bound of the interval."""
        return self._lower

    @property
    def upper(self):
        """Upper bound of the interval."""
        return self._upper

    @property
    def default(self):
        """Default value passed at construction (may be ``None``)."""
        return self._default

    def __repr__(self):
        return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
            name=self.__class__.__name__,
            lower=self._lower,
            upper=self._upper,
            default=self._default,
            log=self._log_scale,
        )

    def random(self, recursion=True):
        """Draw one float from the interval.

        Args:
            recursion: unused; accepted for interface compatibility with ``Space``.

        Returns:
            A float between the bounds. On the log scale the value is clamped to
            the bounds, because ``exp(log(x))`` can round to a value just outside
            ``[lower, upper]``.
        """
        del recursion
        if not self._log_scale:
            return random.uniform(self._lower, self._upper)
        sample = math.exp(random.uniform(math.log(self._lower), math.log(self._upper)))
        # The exp/log round trip is inexact in floating point, so the raw sample can
        # land one ulp outside the interval; clamp it back. min/max make this work for
        # bounds given in either order, and float() keeps the return type a float.
        low = float(min(self._lower, self._upper))
        high = float(max(self._lower, self._upper))
        return min(max(sample, low), high)
|