autodl-projects/lib/nas_infer_model/__init__.py

17 lines
770 B
Python
Raw Normal View History

2019-11-15 07:15:07 +01:00
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
2019-09-28 10:24:47 +02:00
import torch
def obtain_nas_infer_model(config):
    """Construct an inference network for a searched NAS architecture.

    Only the ``'dxys'`` architecture family is supported. Its genotype is
    looked up by name in ``DXYs.Networks`` and passed to either ``CifarNet``
    or ``ImageNet``, depending on ``config.dataset``.

    Args:
        config: An attribute-style configuration object. The attributes read
            are ``arch``, ``genotype``, ``dataset``, ``ichannel``, ``layers``,
            ``auxiliary`` and ``class_num``. ``stem_multi`` is also read, but
            only when ``dataset == 'cifar'``.

    Returns:
        A ``CifarNet`` instance for ``'cifar'`` or an ``ImageNet`` instance
        for ``'imagenet'``.

    Raises:
        ValueError: If ``config.arch`` is not ``'dxys'``, or if
            ``config.dataset`` is neither ``'cifar'`` nor ``'imagenet'``.
            The genotype lookup runs before the dataset check. An unknown
            ``config.genotype`` therefore fails first; presumably this raises
            a KeyError if ``Networks`` is a dict (not visible here, so
            confirm).
    """
    if config.arch != 'dxys':
        raise ValueError('invalid nas arch type : {:}'.format(config.arch))

    # Deferred import: the DXYs module is loaded only once the arch check
    # has passed.
    from .DXYs import CifarNet, ImageNet, Networks
    genotype = Networks[config.genotype]

    if config.dataset == 'cifar':
        return CifarNet(config.ichannel, config.layers, config.stem_multi,
                        config.auxiliary, genotype, config.class_num)
    if config.dataset == 'imagenet':
        # Unlike the CIFAR variant, no stem_multi argument is passed here.
        return ImageNet(config.ichannel, config.layers, config.auxiliary,
                        genotype, config.class_num)
    raise ValueError('invalid dataset : {:}'.format(config.dataset))