xautodl/lib/spaces/basic_space.py
2021-03-18 14:05:29 +08:00

84 lines
2.3 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
import abc
import math
import random
from typing import Optional
class Space(metaclass=abc.ABCMeta):
    """Abstract interface shared by every search space.

    A concrete space must know how to describe itself (``__repr__``) and how
    to draw a random element from itself (``random``). Neither method has a
    usable default here; both raise ``NotImplementedError`` if reached.
    """

    @abc.abstractmethod
    def __repr__(self):
        """Return a human-readable description of this space."""
        raise NotImplementedError()

    @abc.abstractmethod
    def random(self, recursion=True):
        """Draw one random sample from this space.

        Args:
          recursion: whether a sample that is itself a ``Space`` should be
            sampled further; how it is honored is up to the subclass.
        """
        raise NotImplementedError()
class Categorical(Space):
    """A discrete space over an explicit, ordered list of candidates.

    Candidates may themselves be ``Space`` instances; ``random`` with
    ``recursion=True`` then samples from whichever sub-space was picked.

    Args:
      *data: the candidate values, stored in the order given.
      default: optional index into the candidates marking the default choice.

    Raises:
      AssertionError: if ``default`` is given but is not a valid index,
        i.e. not in ``[0, len(candidates))``.
    """

    def __init__(self, *data, default: Optional[int] = None):
        self._candidates = [*data]
        self._default = default
        assert self._default is None or 0 <= self._default < len(
            self._candidates
        ), "default index must be in [0, {:}), got {:}".format(
            len(self._candidates), self._default
        )

    @property
    def candidates(self):
        """A shallow copy of the candidate list; mutating it does not affect this space."""
        return list(self._candidates)

    @property
    def default(self):
        """The default candidate index, or ``None`` when no default was given."""
        return self._default

    def __getitem__(self, index):
        """Return the candidate at ``index`` (standard list indexing)."""
        return self._candidates[index]

    def __len__(self):
        """Return the number of candidates."""
        return len(self._candidates)

    def __repr__(self):
        return "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )

    def random(self, recursion=True):
        """Pick one candidate uniformly at random.

        Args:
          recursion: if true and the picked candidate is a ``Space``, return a
            sample drawn from that sub-space instead of the sub-space itself.

        Raises:
          IndexError: if there are no candidates (raised by ``random.choice``).
        """
        sample = random.choice(self._candidates)
        if recursion and isinstance(sample, Space):
            return sample.random(recursion)
        return sample
class Continuous(Space):
    """A continuous interval ``[lower, upper]`` sampled uniformly.

    Args:
      lower: lower bound of the interval.
      upper: upper bound of the interval.
      default: optional default value; stored and reported but not used by
        ``random`` and not validated against the bounds.
      log: if true, sample uniformly in log space (both bounds must then be
        positive, otherwise ``math.log`` raises ``ValueError`` when sampling).
    """

    def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
        self._lower = lower
        self._upper = upper
        self._default = default
        self._log_scale = log

    @property
    def lower(self):
        """The lower bound given at construction."""
        return self._lower

    @property
    def upper(self):
        """The upper bound given at construction."""
        return self._upper

    @property
    def default(self):
        """The default value given at construction (may be ``None``)."""
        return self._default

    @property
    def log_scale(self):
        """Whether ``random`` samples uniformly in log space."""
        return self._log_scale

    def __repr__(self):
        return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
            name=self.__class__.__name__,
            lower=self._lower,
            upper=self._upper,
            default=self._default,
            log=self._log_scale,
        )

    def random(self, recursion=True):
        """Draw one value from the interval.

        Args:
          recursion: accepted for interface compatibility with ``Space``;
            ignored because a continuous value has no sub-space.
        """
        del recursion
        if self._log_scale:
            sample = random.uniform(math.log(self._lower), math.log(self._upper))
            value = math.exp(sample)
            # exp(log(x)) does not round-trip exactly (e.g. exp(log(3)) is
            # 3.0000000000000004), so clamp to keep the sample inside the bounds.
            lo, hi = sorted((self._lower, self._upper))
            return min(max(value, lo), hi)
        return random.uniform(self._lower, self._upper)