2019-11-15 07:15:07 +01:00
|
|
|
##################################################
|
|
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
|
|
##################################################
|
2021-03-07 04:09:47 +01:00
|
|
|
from .starts import prepare_seed
|
|
|
|
from .starts import prepare_logger
|
|
|
|
from .starts import get_machine_info
|
|
|
|
from .starts import save_checkpoint
|
|
|
|
from .starts import copy_checkpoint
|
2019-09-28 10:24:47 +02:00
|
|
|
from .optimizers import get_optim_scheduler
|
2020-03-09 09:38:00 +01:00
|
|
|
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
|
|
|
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
2020-03-10 09:08:56 +01:00
|
|
|
from .funcs_nasbench import get_nas_bench_loaders
|
2019-09-28 10:24:47 +02:00
|
|
|
|
2021-03-07 04:09:47 +01:00
|
|
|
|
2019-09-28 10:24:47 +02:00
|
|
|
def get_procedures(procedure):
    """Return the ``(train_func, valid_func)`` pair registered for *procedure*.

    The procedure modules are imported inside this function, so they are only
    loaded when it is first called rather than when the package is imported.

    Args:
        procedure: Name of the procedure. Registered names are ``"basic"``,
            ``"search"``, ``"Simple-KD"`` and ``"search-v2"``.

    Returns:
        A tuple ``(train_func, valid_func)`` of the callables registered for
        *procedure*.

    Raises:
        KeyError: if *procedure* is not a registered name; the message lists
            the valid choices.
    """
    from .basic_main import basic_train, basic_valid
    from .search_main import search_train, search_valid
    from .search_main_v2 import search_train_v2
    from .simple_KD_main import simple_KD_train, simple_KD_valid

    train_funcs = {
        "basic": basic_train,
        "search": search_train,
        "Simple-KD": simple_KD_train,
        "search-v2": search_train_v2,
    }
    valid_funcs = {
        "basic": basic_valid,
        "search": search_valid,
        "Simple-KD": simple_KD_valid,
        # "search-v2" has no dedicated validation routine; it reuses search_valid.
        "search-v2": search_valid,
    }

    # Fail with an informative message instead of a bare KeyError(procedure).
    if procedure not in train_funcs:
        raise KeyError(
            "Unknown procedure {!r}; expected one of {}".format(
                procedure, sorted(train_funcs)
            )
        )

    train_func = train_funcs[procedure]
    valid_func = valid_funcs[procedure]
    return train_func, valid_func
|