autodl-projects/xautodl/datasets/synthetic_core.py

39 lines
1.4 KiB
Python
Raw Normal View History

2021-05-09 12:53:18 +02:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
2021-05-24 07:06:10 +02:00
import math
2021-05-09 12:53:18 +02:00
from .synthetic_utils import TimeStamp
2021-04-28 17:56:25 +02:00
from .synthetic_env import SyntheticDEnv
2021-05-09 13:11:56 +02:00
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc
2021-05-24 07:06:10 +02:00
from .math_core import GaussianDGenerator
2021-04-28 17:56:25 +02:00
2021-05-09 12:53:18 +02:00
# Public API of this module: the re-exported TimeStamp / SyntheticDEnv classes
# (imported above from .synthetic_utils and .synthetic_env) plus the
# get_synthetic_env factory defined below.
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]


def get_synthetic_env(
    total_timestamp=1000, num_per_task=1000, mode=None, version="v1"
):
    """Build a synthetic dynamic environment for the requested ``version``.

    Only ``version="v1"`` is supported. The "v1" recipe wires together:
      * a ``GaussianDGenerator`` built from a constant mean function (0), a
        constant std function (1) and the extra argument ``(-2, 2)``
        (NOTE(review): presumably the sampling range/bounds -- confirm in
        math_core);
      * a ``TimeStamp`` sequence with ``min_timestamp=0`` and
        ``max_timestamp=8 * pi``;
      * a ``DynamicLinearFunc`` oracle whose two parameters (keys 0 and 1) are
        ``ComposedSinFunc`` objects (presumably evaluated as functions of
        time -- confirm in math_core).

    Args:
        total_timestamp: forwarded to ``TimeStamp`` as ``num`` (presumably the
            number of timestamps generated).
        num_per_task: forwarded to ``SyntheticDEnv``; presumably the number of
            samples drawn per timestamp -- TODO confirm.
        mode: forwarded unchanged to ``TimeStamp``.
        version: name of the environment recipe to build.

    Returns:
        The ``SyntheticDEnv`` instance assembled from the pieces above.

    Raises:
        ValueError: if ``version`` is anything other than ``"v1"``.
    """
    # Reject unsupported versions before constructing any component.
    if version != "v1":
        raise ValueError("Unknown version: {:}".format(version))

    # Input distribution: constant mean 0 and constant std 1.
    data_generator = GaussianDGenerator(
        [ConstantFunc(0)], [[ConstantFunc(1)]], (-2, 2)
    )

    # Time axis spanning [0, 8 * pi].
    time_generator = TimeStamp(
        min_timestamp=0,
        max_timestamp=math.pi * 8,
        num=total_timestamp,
        mode=mode,
    )

    # Oracle: a dynamic linear map whose two parameters are sinusoid compositions.
    linear_params = {
        0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
        1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
    }
    oracle_map = DynamicLinearFunc(params=linear_params)

    return SyntheticDEnv(data_generator, oracle_map, time_generator, num_per_task)