Prototype generic NAS model (cont.).

D-X-Y 2020-07-19 09:07:05 +00:00
parent 31a896346a
commit b9a5d2880f
2 changed files with 17 additions and 11 deletions

View File

@@ -1,21 +1,25 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 ######################################################################################
-# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
 # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
 ####
-# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
 # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
 ####
-# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 1
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
 # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
 ####
-# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 1
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
 # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
+####
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
+# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
+# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
 ######################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -268,7 +272,7 @@ def main(xargs):
   logger.log('The parameters of the search model = {:.2f} MB'.format(params))
   logger.log('search-space : {:}'.format(search_space))
   try:
-    api = API()
+    api = API(verbose=False)
   except:
     api = None
   logger.log('{:} create API = {:} done'.format(time_string(), api))
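For context, `API` here is the NAS-Bench-201 benchmark API (`nas_201_api.NASBench201API`), and the try/except guards the case where the pre-computed benchmark file is not available. A minimal sketch of how such an API object is typically queried; the method names come from the package's public interface, and the architecture string is only an illustrative example:

# Minimal sketch, assuming the nas_201_api package and its benchmark file are installed.
from nas_201_api import NASBench201API as API

api = API(verbose=False)  # loads the pre-computed benchmark table
# Example NAS-Bench-201 cell encoding (illustrative, not from this commit):
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
index = api.query_index_by_arch(arch)  # map the sampled cell to its benchmark index
info = api.query_by_index(index)       # pre-computed training/evaluation statistics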
@@ -385,10 +389,10 @@ if __name__ == '__main__':
   parser.add_argument('--channel'   , type=int, default=16, help='The number of channels.')
   parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
   #
-  parser.add_argument('--eval_candidate_num', type=int, help='The number of selected architectures to evaluate.')
+  parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.')
   #
   parser.add_argument('--track_running_stats', type=int, default=0, choices=[0,1], help='Whether to use track_running_stats in the BN layers.')
-  parser.add_argument('--affine', type=int, default=0, choices=[0,1], help='Whether to use affine=True in the BN layers.')
+  parser.add_argument('--affine', type=int, default=1, choices=[0,1], help='Whether to use affine=True in the BN layers.')
   parser.add_argument('--config_path', type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of the configuration file.')
   # architecture learning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
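Both BN flags are plain ints used as booleans. A small illustration (not part of the diff) of how they map onto the batch-norm layers the search model builds:

# Illustration only; the two variables mirror the CLI flags above, with the new defaults.
import torch.nn as nn

affine, track_running_stats = 1, 0
bn = nn.BatchNorm2d(16, affine=bool(affine),
                    track_running_stats=bool(track_running_stats))
# affine=True keeps the learnable scale/shift; track_running_stats=False makes BN
# normalize with batch statistics even at eval time, a common choice in weight-sharing NAS.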
@@ -401,6 +405,8 @@ if __name__ == '__main__':
   parser.add_argument('--rand_seed', type=int, help='manual seed')
   args = parser.parse_args()
   if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
-  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, '{:}-{:}'.format(args.algo, args.drop_path_rate))
+  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
+                               args.dataset,
+                               '{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
   main(args)
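The new layout encodes the BN settings into the output path, so runs with different affine/track_running_stats choices no longer overwrite each other. A quick illustration with hypothetical argument values:

# Hypothetical values, only to show the resulting directory layout.
import os

save_dir, search_space, dataset = './output/search', 'tss', 'cifar10'
algo, affine, bn_stats, drop_path_rate = 'gdas', 1, 0, 0.0
path = os.path.join('{:}-{:}'.format(save_dir, search_space), dataset,
                    '{:}-affine{:}_BN{:}-{:}'.format(algo, affine, bn_stats, drop_path_rate))
print(path)  # ./output/search-tss/cifar10/gdas-affine1_BN0-0.0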

View File

@@ -36,7 +36,7 @@ class ReLUConvBN(nn.Module):
     super(ReLUConvBN, self).__init__()
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
-      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
+      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
     )
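The pattern `bias=not affine` ties the conv bias to the BN configuration: the conv carries a bias exactly when the following BN layer has its learnable scale/shift disabled. A self-contained sketch of the block after this change; `forward` is assumed, since it is not part of the hunk:

import torch.nn as nn

class ReLUConvBN(nn.Module):
  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation,
               affine, track_running_stats=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      # bias is enabled exactly when BN has no affine (shift) parameters
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding,
                dilation=dilation, bias=not affine),
      nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))

  def forward(self, x):  # assumed; not shown in this hunk
    return self.op(x)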
@@ -51,7 +51,7 @@ class SepConv(nn.Module):
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
-      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
     )
@@ -171,7 +171,7 @@ class FactorizedReduce(nn.Module):
       C_outs = [C_out // 2, C_out - C_out // 2]
       self.convs = nn.ModuleList()
       for i in range(2):
-        self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
+        self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     elif stride == 1:
       self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
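For orientation, FactorizedReduce halves the spatial resolution with two parallel 1x1 stride-2 convs whose outputs are concatenated along the channel dimension. A sketch of the usual DARTS-style forward pass for the stride-2 branch, assuming the module also defines `self.relu` and `self.bn` (neither appears in this hunk):

import torch

# Sketch only: the second conv reads a one-pixel-shifted view (enabled by the
# ConstantPad2d above), so the two stride-2 branches together cover every pixel.
def forward(self, x):
  x = self.relu(x)
  y = self.pad(x)
  out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
  return self.bn(out)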