Fix small bugs
This commit is contained in:
parent
58733c18be
commit
d04edcd211
2
.gitmodules
vendored
2
.gitmodules
vendored
@ -1,6 +1,6 @@
|
|||||||
[submodule ".latent-data/qlib"]
|
[submodule ".latent-data/qlib"]
|
||||||
path = .latent-data/qlib
|
path = .latent-data/qlib
|
||||||
url = git@github.com:D-X-Y/qlib.git
|
url = git@github.com:microsoft/qlib.git
|
||||||
[submodule ".latent-data/NATS-Bench"]
|
[submodule ".latent-data/NATS-Bench"]
|
||||||
path = .latent-data/NATS-Bench
|
path = .latent-data/NATS-Bench
|
||||||
url = git@github.com:D-X-Y/NATS-Bench.git
|
url = git@github.com:D-X-Y/NATS-Bench.git
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 2d4f0e80f98211ba2e1f25a329ad2421fb8087cd
|
Subproject commit 6608a40965cb824269fea637c8a6be177994b523
|
10
configs/archs/NAS-CIFAR-DARTS.config
Normal file
10
configs/archs/NAS-CIFAR-DARTS.config
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"arch" : ["str", "dxys"],
|
||||||
|
"genotype" : ["str", "DARTS"],
|
||||||
|
"dataset" : ["str", "cifar"],
|
||||||
|
"ichannel" : ["int", 36],
|
||||||
|
"layers" : ["int", 6],
|
||||||
|
"stem_multi": ["int", 3],
|
||||||
|
"auxiliary" : ["bool", 1],
|
||||||
|
"drop_path_prob": ["float", 0.2]
|
||||||
|
}
|
@ -47,7 +47,7 @@ task:
|
|||||||
net_config:
|
net_config:
|
||||||
name: basic
|
name: basic
|
||||||
d_feat: 6
|
d_feat: 6
|
||||||
embed_dim: 48
|
embed_dim: 32
|
||||||
num_heads: [4, 4, 4, 4, 4]
|
num_heads: [4, 4, 4, 4, 4]
|
||||||
mlp_hidden_multipliers: [4, 4, 4, 4, 4]
|
mlp_hidden_multipliers: [4, 4, 4, 4, 4]
|
||||||
qkv_bias: True
|
qkv_bias: True
|
||||||
|
@ -32,6 +32,9 @@ Train some NAS models:
|
|||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 DARTS 96 -1
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN1 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN1 256 -1
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k DARTS 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k DARTS 256 -1
|
||||||
|
@ -22,7 +22,7 @@ from utils import get_model_infos
|
|||||||
flop, param = get_model_infos(net, (1,3,32,32))
|
flop, param = get_model_infos(net, (1,3,32,32))
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/nas_infer_model/DXYs/genotypes.py).
|
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@ -34,7 +34,7 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 GDAS_V1 96 -1
|
|||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
||||||
```
|
```
|
||||||
If you are interested in the configs of each NAS-searched architecture, they are defined at [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/nas_infer_model/DXYs/genotypes.py).
|
If you are interested in the configs of each NAS-searched architecture, they are defined at [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
### Searching on the NASNet search space
|
### Searching on the NASNet search space
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
||||||
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
|
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
|
||||||
#####################################################
|
#####################################################
|
||||||
|
import yaml
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from xautodl.procedures.q_exps import update_gpu
|
from xautodl.procedures.q_exps import update_gpu
|
||||||
@ -57,7 +58,7 @@ def main(xargs):
|
|||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"class": "QuantTransformer",
|
"class": "QuantTransformer",
|
||||||
"module_path": "trade_models",
|
"module_path": "xautodl.trade_models.quant_transformer",
|
||||||
"kwargs": {
|
"kwargs": {
|
||||||
"net_config": None,
|
"net_config": None,
|
||||||
"opt_config": None,
|
"opt_config": None,
|
||||||
@ -108,6 +109,62 @@ def main(xargs):
|
|||||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||||
|
|
||||||
|
from qlib.utils import init_instance_by_config
|
||||||
|
|
||||||
|
xconfig = """
|
||||||
|
model:
|
||||||
|
class: SFM
|
||||||
|
module_path: qlib.contrib.model.pytorch_sfm
|
||||||
|
kwargs:
|
||||||
|
d_feat: 6
|
||||||
|
hidden_size: 64
|
||||||
|
output_dim: 32
|
||||||
|
freq_dim: 25
|
||||||
|
dropout_W: 0.5
|
||||||
|
dropout_U: 0.5
|
||||||
|
n_epochs: 20
|
||||||
|
lr: 1e-3
|
||||||
|
batch_size: 1600
|
||||||
|
early_stop: 20
|
||||||
|
eval_steps: 5
|
||||||
|
loss: mse
|
||||||
|
optimizer: adam
|
||||||
|
GPU: 0
|
||||||
|
"""
|
||||||
|
xconfig = """
|
||||||
|
model:
|
||||||
|
class: TabnetModel
|
||||||
|
module_path: qlib.contrib.model.pytorch_tabnet
|
||||||
|
kwargs:
|
||||||
|
d_feat: 360
|
||||||
|
pretrain: True
|
||||||
|
"""
|
||||||
|
xconfig = """
|
||||||
|
model:
|
||||||
|
class: GRU
|
||||||
|
module_path: qlib.contrib.model.pytorch_gru
|
||||||
|
kwargs:
|
||||||
|
d_feat: 6
|
||||||
|
hidden_size: 64
|
||||||
|
num_layers: 4
|
||||||
|
dropout: 0.0
|
||||||
|
n_epochs: 200
|
||||||
|
lr: 0.001
|
||||||
|
early_stop: 20
|
||||||
|
batch_size: 800
|
||||||
|
metric: loss
|
||||||
|
loss: mse
|
||||||
|
GPU: 0
|
||||||
|
"""
|
||||||
|
xconfig = yaml.safe_load(xconfig)
|
||||||
|
model = init_instance_by_config(xconfig["model"])
|
||||||
|
from xautodl.utils.flop_benchmark import count_parameters_in_MB
|
||||||
|
|
||||||
|
# print(count_parameters_in_MB(model.tabnet_model))
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
|
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
|
||||||
dataset = init_instance_by_config(dataset_config)
|
dataset = init_instance_by_config(dataset_config)
|
||||||
for irun in range(xargs.times):
|
for irun in range(xargs.times):
|
||||||
|
@ -12,7 +12,6 @@ black ./tests/
|
|||||||
black ./xautodl/procedures
|
black ./xautodl/procedures
|
||||||
black ./xautodl/datasets
|
black ./xautodl/datasets
|
||||||
black ./xautodl/xlayers
|
black ./xautodl/xlayers
|
||||||
black ./exps/LFNA
|
|
||||||
black ./exps/trading
|
black ./exps/trading
|
||||||
rm -rf ./xautodl.egg-info
|
rm -rf ./xautodl.egg-info
|
||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
|
6
setup.py
6
setup.py
@ -17,6 +17,8 @@
|
|||||||
# TODO(xuanyidong): upload it to conda
|
# TODO(xuanyidong): upload it to conda
|
||||||
#
|
#
|
||||||
# [2021.06.01] v0.9.9
|
# [2021.06.01] v0.9.9
|
||||||
|
# [2021.08.14] v1.0.0
|
||||||
|
#
|
||||||
import os
|
import os
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
@ -24,7 +26,7 @@ NAME = "xautodl"
|
|||||||
REQUIRES_PYTHON = ">=3.6"
|
REQUIRES_PYTHON = ">=3.6"
|
||||||
DESCRIPTION = "Automated Deep Learning Package"
|
DESCRIPTION = "Automated Deep Learning Package"
|
||||||
|
|
||||||
VERSION = "0.9.9"
|
VERSION = "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
def read(fname="README.md"):
|
def read(fname="README.md"):
|
||||||
@ -35,7 +37,7 @@ def read(fname="README.md"):
|
|||||||
|
|
||||||
|
|
||||||
# What packages are required for this module to be executed?
|
# What packages are required for this module to be executed?
|
||||||
REQUIRED = ["numpy>=1.16.5,<=1.19.5", "pyyaml>=5.0.0"]
|
REQUIRED = ["numpy>=1.16.5,<=1.19.5", "pyyaml>=5.0.0", "fvcore"]
|
||||||
|
|
||||||
packages = find_packages(
|
packages = find_packages(
|
||||||
exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")
|
exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")
|
||||||
|
@ -7,5 +7,6 @@
|
|||||||
|
|
||||||
|
|
||||||
def version():
|
def version():
|
||||||
versions = ["0.9.9"] # 2021.05.19
|
versions = ["0.9.9"] # 2021.06.01
|
||||||
|
versions = ["1.0.0"] # 2021.08.14
|
||||||
return versions[-1]
|
return versions[-1]
|
||||||
|
@ -42,12 +42,13 @@ def _assert_types(x, expected_types):
|
|||||||
|
|
||||||
|
|
||||||
DEFAULT_NET_CONFIG = None
|
DEFAULT_NET_CONFIG = None
|
||||||
_default_max_depth = 5
|
_default_max_depth = 6
|
||||||
DefaultSearchSpace = dict(
|
DefaultSearchSpace = dict(
|
||||||
d_feat=6,
|
d_feat=6,
|
||||||
embed_dim=spaces.Categorical(*_get_list_mul(8, 16)),
|
embed_dim=32,
|
||||||
num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth),
|
# embed_dim=spaces.Categorical(*_get_list_mul(8, 16)),
|
||||||
mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth),
|
num_heads=[4] * _default_max_depth,
|
||||||
|
mlp_hidden_multipliers=[4] * _default_max_depth,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
pos_drop=0.0,
|
pos_drop=0.0,
|
||||||
other_drop=0.0,
|
other_drop=0.0,
|
||||||
|
@ -14,20 +14,24 @@ def count_parameters(model_or_parameters, unit="mb", deprecated=False):
|
|||||||
if isinstance(model_or_parameters, nn.Module):
|
if isinstance(model_or_parameters, nn.Module):
|
||||||
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
|
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
|
||||||
elif isinstance(model_or_parameters, nn.Parameter):
|
elif isinstance(model_or_parameters, nn.Parameter):
|
||||||
counts = models_or_parameters.numel()
|
counts = model_or_parameters.numel()
|
||||||
elif isinstance(model_or_parameters, (list, tuple)):
|
elif isinstance(model_or_parameters, (list, tuple)):
|
||||||
counts = sum(
|
counts = sum(
|
||||||
count_parameters(x, None, deprecated) for x in models_or_parameters
|
count_parameters(x, None, deprecated) for x in model_or_parameters
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
counts = sum(np.prod(v.size()) for v in model_or_parameters)
|
counts = sum(np.prod(v.size()) for v in model_or_parameters)
|
||||||
if unit.lower() == "kb" or unit.lower() == "k":
|
if not isinstance(unit, str) and unit is not None:
|
||||||
|
raise ValueError("Unknow type of unit: {:}".format(unit))
|
||||||
|
elif unit is None:
|
||||||
|
counts = counts
|
||||||
|
elif unit.lower() == "kb" or unit.lower() == "k":
|
||||||
counts /= 1e3 if deprecated else 2 ** 10 # changed from 1e3 to 2^10
|
counts /= 1e3 if deprecated else 2 ** 10 # changed from 1e3 to 2^10
|
||||||
elif unit.lower() == "mb" or unit.lower() == "m":
|
elif unit.lower() == "mb" or unit.lower() == "m":
|
||||||
counts /= 1e6 if deprecated else 2 ** 20 # changed from 1e6 to 2^20
|
counts /= 1e6 if deprecated else 2 ** 20 # changed from 1e6 to 2^20
|
||||||
elif unit.lower() == "gb" or unit.lower() == "g":
|
elif unit.lower() == "gb" or unit.lower() == "g":
|
||||||
counts /= 1e9 if deprecated else 2 ** 30 # changed from 1e9 to 2^30
|
counts /= 1e9 if deprecated else 2 ** 30 # changed from 1e9 to 2^30
|
||||||
elif unit is not None:
|
else:
|
||||||
raise ValueError("Unknow unit: {:}".format(unit))
|
raise ValueError("Unknow unit: {:}".format(unit))
|
||||||
return counts
|
return counts
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user