2020-02-23 00:30:37 +01:00
|
|
|
#####################################################
|
|
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
|
|
|
#####################################################
|
2020-01-14 14:52:06 +01:00
|
|
|
# I wrote this package to make AutoDL-Projects compatible with the old GDAS projects.
|
|
|
|
# Ideally, this package will be merged into lib/models/cell_infers in the future.
|
|
|
|
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
|
|
|
##################################################
|
|
|
|
|
2020-03-06 09:29:07 +01:00
|
|
|
import os, torch
|
2019-09-28 10:24:47 +02:00
|
|
|
|
2020-03-06 09:29:07 +01:00
|
|
|
def obtain_nas_infer_model(config, extra_model_path=None):
    """Build an inference network for a searched (NAS) cell architecture.

    Args:
        config: configuration object. Read attributes: ``arch`` (only
            ``'dxys'`` is supported), ``genotype`` (a key into ``DXYs.Networks``,
            or None to load the genotype from ``extra_model_path``), ``dataset``
            (``'cifar'`` or ``'imagenet'``), ``ichannel``, ``layers``,
            ``auxiliary``, ``class_num``, and ``stem_multi`` (CIFAR only).
        extra_model_path: path to a checkpoint saved by ``torch.load``-compatible
            code; required when ``config.genotype`` is None. The checkpoint
            must contain the keys ``'epoch'`` and ``'genotypes'``.

    Returns:
        A ``CifarNet`` or ``ImageNet`` instance from the ``DXYs`` package.

    Raises:
        ValueError: if ``config.arch`` or ``config.dataset`` is unsupported, or
            if ``config.genotype`` is None and ``extra_model_path`` is not an
            existing file.
    """
    if config.arch != 'dxys':
        raise ValueError('invalid nas arch type : {:}'.format(config.arch))

    # Validate inputs before the (lazy) package import so bad calls fail fast.
    # The original condition used `is not None and ...`, which let a missing
    # path slip through to torch.load(None).
    if config.genotype is None and (extra_model_path is None or not os.path.isfile(extra_model_path)):
        raise ValueError('When genotype in config is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path))

    from .DXYs import CifarNet, ImageNet, Networks
    from .DXYs import build_genotype_from_dict

    if config.genotype is None:
        # Only the genotype record is used, so load onto CPU; this also works
        # for checkpoints saved from GPU on CPU-only machines.
        xdata = torch.load(extra_model_path, map_location='cpu')
        current_epoch = xdata['epoch']
        # NOTE(review): assumes the genotype of the latest finished epoch is
        # stored at index `epoch - 1` -- confirm against the search script.
        genotype_dict = xdata['genotypes'][current_epoch - 1]
        genotype = build_genotype_from_dict(genotype_dict)
    else:
        genotype = Networks[config.genotype]

    if config.dataset == 'cifar':
        return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
    elif config.dataset == 'imagenet':
        return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
    else:
        raise ValueError('invalid dataset : {:}'.format(config.dataset))
|