Update xmisc with yaml
This commit is contained in:
parent
aef5c7579b
commit
1a7440d2af
2
.github/workflows/basic_test.yml
vendored
2
.github/workflows/basic_test.yml
vendored
@ -32,7 +32,7 @@ jobs:
|
|||||||
echo $PWD ; ls
|
echo $PWD ; ls
|
||||||
python -m black ./exps -l 88 --check --diff --verbose
|
python -m black ./exps -l 88 --check --diff --verbose
|
||||||
python -m black ./tests -l 88 --check --diff --verbose
|
python -m black ./tests -l 88 --check --diff --verbose
|
||||||
python -m black ./xautodl/xlayers -l 88 --check --diff --verbose
|
python -m black ./xautodl/x* -l 88 --check --diff --verbose
|
||||||
python -m black ./xautodl/spaces -l 88 --check --diff --verbose
|
python -m black ./xautodl/spaces -l 88 --check --diff --verbose
|
||||||
python -m black ./xautodl/trade_models -l 88 --check --diff --verbose
|
python -m black ./xautodl/trade_models -l 88 --check --diff --verbose
|
||||||
python -m black ./xautodl/procedures -l 88 --check --diff --verbose
|
python -m black ./xautodl/procedures -l 88 --check --diff --verbose
|
||||||
|
7
configs/data.yaml/cifar10.test
Normal file
7
configs/data.yaml/cifar10.test
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
class_or_func: CIFAR10
|
||||||
|
module_path: torchvision.datasets
|
||||||
|
args: []
|
||||||
|
kwargs:
|
||||||
|
train: False
|
||||||
|
download: True
|
||||||
|
transform: null
|
7
configs/data.yaml/cifar10.train
Normal file
7
configs/data.yaml/cifar10.train
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
class_or_func: CIFAR10
|
||||||
|
module_path: torchvision.datasets
|
||||||
|
args: []
|
||||||
|
kwargs:
|
||||||
|
train: True
|
||||||
|
download: True
|
||||||
|
transform: null
|
@ -1,35 +1,28 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
|
||||||
|
#####################################################
|
||||||
|
# python exps/basic/xmain.py --save_dir outputs/x #
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, torch, random, argparse
|
import sys, time, torch, random, argparse
|
||||||
from PIL import ImageFile
|
|
||||||
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from xautodl.datasets import get_datasets
|
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||||
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
|
print("LIB-DIR: {:}".format(lib_dir))
|
||||||
from xautodl.procedures import (
|
if str(lib_dir) not in sys.path:
|
||||||
prepare_seed,
|
sys.path.insert(0, str(lib_dir))
|
||||||
prepare_logger,
|
|
||||||
save_checkpoint,
|
from xautodl.xmisc import nested_call_by_yaml
|
||||||
copy_checkpoint,
|
|
||||||
)
|
|
||||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
|
||||||
from xautodl.models import obtain_model
|
|
||||||
from xautodl.xmodels import obtain_model as obtain_xmodel
|
|
||||||
from xautodl.nas_infer_model import obtain_nas_infer_model
|
|
||||||
from xautodl.utils import get_model_infos
|
|
||||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
assert torch.cuda.is_available(), "CUDA is not available."
|
|
||||||
torch.backends.cudnn.enabled = True
|
train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
|
||||||
torch.backends.cudnn.benchmark = True
|
valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
|
||||||
# torch.backends.cudnn.deterministic = True
|
|
||||||
# torch.set_num_threads(args.workers)
|
import pdb
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
prepare_seed(args.rand_seed)
|
prepare_seed(args.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
@ -290,5 +283,44 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = obtain_args()
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Train a model with a loss function.",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_dir", type=str, help="Folder to save checkpoints and log."
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument("--model_config", type=str, help="The path to the model config")
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer config"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_data_config", type=str, help="The dataset config path."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--valid_data_config", type=str, help="The dataset config path."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_path", type=str, help="The path to the dataset."
|
||||||
|
)
|
||||||
|
parser.add_argument("--algorithm", type=str, help="The algorithm.")
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="number of data loading workers (default: 8)",
|
||||||
|
)
|
||||||
|
# Random Seed
|
||||||
|
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
if args.save_dir is None:
|
||||||
|
raise ValueError("The save-path argument can not be None")
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
27
scripts/experimental/train-vit.sh
Normal file
27
scripts/experimental/train-vit.sh
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts/experimental/train-vit.sh cifar10 -1
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 2 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 2 parameters for dataset and random-seed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
|
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dataset=$1
|
||||||
|
rseed=$2
|
||||||
|
|
||||||
|
save_dir=./outputs/${dataset}/vit-experimental
|
||||||
|
|
||||||
|
python --version
|
||||||
|
|
||||||
|
python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
|
||||||
|
--train_data_config ./configs/data.yaml/${dataset}.train \
|
||||||
|
--valid_data_config ./configs/data.yaml/${dataset}.test \
|
||||||
|
--data_path $TORCH_HOME/cifar.python
|
@ -5,45 +5,51 @@ import torch.nn as nn
|
|||||||
class ImageNetHEAD(nn.Sequential):
|
class ImageNetHEAD(nn.Sequential):
|
||||||
def __init__(self, C, stride=2):
|
def __init__(self, C, stride=2):
|
||||||
super(ImageNetHEAD, self).__init__()
|
super(ImageNetHEAD, self).__init__()
|
||||||
self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
|
self.add_module(
|
||||||
self.add_module('bn1' , nn.BatchNorm2d(C // 2))
|
"conv1",
|
||||||
self.add_module('relu1', nn.ReLU(inplace=True))
|
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||||
self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
|
)
|
||||||
self.add_module('bn2' , nn.BatchNorm2d(C))
|
self.add_module("bn1", nn.BatchNorm2d(C // 2))
|
||||||
|
self.add_module("relu1", nn.ReLU(inplace=True))
|
||||||
|
self.add_module(
|
||||||
|
"conv2",
|
||||||
|
nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||||
|
)
|
||||||
|
self.add_module("bn2", nn.BatchNorm2d(C))
|
||||||
|
|
||||||
|
|
||||||
class CifarHEAD(nn.Sequential):
|
class CifarHEAD(nn.Sequential):
|
||||||
def __init__(self, C):
|
def __init__(self, C):
|
||||||
super(CifarHEAD, self).__init__()
|
super(CifarHEAD, self).__init__()
|
||||||
self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
|
self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
|
||||||
self.add_module('bn', nn.BatchNorm2d(C))
|
self.add_module("bn", nn.BatchNorm2d(C))
|
||||||
|
|
||||||
|
|
||||||
class AuxiliaryHeadCIFAR(nn.Module):
|
class AuxiliaryHeadCIFAR(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C, num_classes):
|
def __init__(self, C, num_classes):
|
||||||
"""assuming input size 8x8"""
|
"""assuming input size 8x8"""
|
||||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||||
self.features = nn.Sequential(
|
self.features = nn.Sequential(
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
nn.AvgPool2d(
|
||||||
|
5, stride=3, padding=0, count_include_pad=False
|
||||||
|
), # image size = 2 x 2
|
||||||
nn.Conv2d(C, 128, 1, bias=False),
|
nn.Conv2d(C, 128, 1, bias=False),
|
||||||
nn.BatchNorm2d(128),
|
nn.BatchNorm2d(128),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(128, 768, 2, bias=False),
|
nn.Conv2d(128, 768, 2, bias=False),
|
||||||
nn.BatchNorm2d(768),
|
nn.BatchNorm2d(768),
|
||||||
nn.ReLU(inplace=True)
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
self.classifier = nn.Linear(768, num_classes)
|
self.classifier = nn.Linear(768, num_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.features(x)
|
x = self.features(x)
|
||||||
x = self.classifier(x.view(x.size(0),-1))
|
x = self.classifier(x.view(x.size(0), -1))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class AuxiliaryHeadImageNet(nn.Module):
|
class AuxiliaryHeadImageNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C, num_classes):
|
def __init__(self, C, num_classes):
|
||||||
"""assuming input size 14x14"""
|
"""assuming input size 14x14"""
|
||||||
super(AuxiliaryHeadImageNet, self).__init__()
|
super(AuxiliaryHeadImageNet, self).__init__()
|
||||||
@ -55,11 +61,11 @@ class AuxiliaryHeadImageNet(nn.Module):
|
|||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(128, 768, 2, bias=False),
|
nn.Conv2d(128, 768, 2, bias=False),
|
||||||
nn.BatchNorm2d(768),
|
nn.BatchNorm2d(768),
|
||||||
nn.ReLU(inplace=True)
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
self.classifier = nn.Linear(768, num_classes)
|
self.classifier = nn.Linear(768, num_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.features(x)
|
x = self.features(x)
|
||||||
x = self.classifier(x.view(x.size(0),-1))
|
x = self.classifier(x.view(x.size(0), -1))
|
||||||
return x
|
return x
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
######################################################################
|
||||||
|
# This folder is deprecated, which is re-organized in "xalgorithms". #
|
||||||
|
######################################################################
|
||||||
from .starts import prepare_seed
|
from .starts import prepare_seed
|
||||||
from .starts import prepare_logger
|
from .starts import prepare_logger
|
||||||
from .starts import get_machine_info
|
from .starts import get_machine_info
|
||||||
|
@ -47,7 +47,7 @@ class SuperSelfAttention(SuperModule):
|
|||||||
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||||
|
|
||||||
self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
|
self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
|
||||||
if proj_dim is None:
|
if proj_dim is not None:
|
||||||
self.proj = SuperLinear(input_dim, proj_dim)
|
self.proj = SuperLinear(input_dim, proj_dim)
|
||||||
self.proj_drop = SuperDropout(proj_drop or 0.0)
|
self.proj_drop = SuperDropout(proj_drop or 0.0)
|
||||||
else:
|
else:
|
||||||
|
8
xautodl/xmisc/__init__.py
Normal file
8
xautodl/xmisc/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
|
||||||
|
#####################################################
|
||||||
|
from .module_utils import call_by_dict
|
||||||
|
from .module_utils import call_by_yaml
|
||||||
|
from .module_utils import nested_call_by_dict
|
||||||
|
from .module_utils import nested_call_by_yaml
|
||||||
|
from .yaml_utils import load_yaml
|
81
xautodl/xmisc/module_utils.py
Normal file
81
xautodl/xmisc/module_utils.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
||||||
|
#####################################################
|
||||||
|
from typing import Union, Dict, Text, Any
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from .yaml_utils import load_yaml
|
||||||
|
|
||||||
|
CLS_FUNC_KEY = "class_or_func"
|
||||||
|
KEYS = (CLS_FUNC_KEY, "module_path", "args", "kwargs")
|
||||||
|
|
||||||
|
|
||||||
|
def has_key_words(xdict):
|
||||||
|
if not isinstance(xdict, dict):
|
||||||
|
return False
|
||||||
|
key_set = set(KEYS)
|
||||||
|
cur_set = set(xdict.keys())
|
||||||
|
return key_set.intersection(cur_set) == key_set
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_by_module_path(module_path):
|
||||||
|
"""Load the module from the path."""
|
||||||
|
|
||||||
|
if module_path.endswith(".py"):
|
||||||
|
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||||
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
|
module_spec.loader.exec_module(module)
|
||||||
|
else:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def call_by_dict(config: Dict[Text, Any], *args, **kwargs) -> object:
|
||||||
|
"""
|
||||||
|
get initialized instance with config
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : a dictionary, such as:
|
||||||
|
{
|
||||||
|
'cls_or_func': 'ClassName',
|
||||||
|
'args': list,
|
||||||
|
'kwargs': dict,
|
||||||
|
'model_path': a string indicating the path,
|
||||||
|
}
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
object:
|
||||||
|
An initialized object based on the config info
|
||||||
|
"""
|
||||||
|
module = get_module_by_module_path(config["module_path"])
|
||||||
|
cls_or_func = getattr(module, config[CLS_FUNC_KEY])
|
||||||
|
args = tuple(list(config["args"]) + list(args))
|
||||||
|
kwargs = {**config["kwargs"], **kwargs}
|
||||||
|
return cls_or_func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def call_by_yaml(path, *args, **kwargs) -> object:
|
||||||
|
config = load_yaml(path)
|
||||||
|
return call_by_config(config, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
|
||||||
|
"""Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called."""
|
||||||
|
if not has_key_words(config):
|
||||||
|
return config
|
||||||
|
module = get_module_by_module_path(config["module_path"])
|
||||||
|
cls_or_func = getattr(module, config[CLS_FUNC_KEY])
|
||||||
|
args = tuple(list(config["args"]) + list(args))
|
||||||
|
kwargs = {**config["kwargs"], **kwargs}
|
||||||
|
# check whether there are nested special dict
|
||||||
|
new_args = [nested_call_by_dict(x) for x in args]
|
||||||
|
new_kwargs = {}
|
||||||
|
for key, x in kwargs.items():
|
||||||
|
new_kwargs[key] = nested_call_by_dict(x)
|
||||||
|
return cls_or_func(*new_args, **new_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def nested_call_by_yaml(path, *args, **kwargs) -> object:
|
||||||
|
config = load_yaml(path)
|
||||||
|
return nested_call_by_dict(config, *args, **kwargs)
|
13
xautodl/xmisc/yaml_utils.py
Normal file
13
xautodl/xmisc/yaml_utils.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
|
||||||
|
#####################################################
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml(path):
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
raise ValueError("{:} is not a file.".format(path))
|
||||||
|
with open(path, "r") as stream:
|
||||||
|
data = yaml.safe_load(stream)
|
||||||
|
return data
|
Loading…
Reference in New Issue
Block a user