Update VIS-CODES and SCRIPTS
This commit is contained in:
parent
8d27050f6f
commit
a2a1abcb7d
@ -1,22 +1,46 @@
|
||||
#!/bin/bash
|
||||
# bash ./exps/algos-v2/run-all.sh
|
||||
# bash ./exps/algos-v2/run-all.sh mul
|
||||
# bash ./exps/algos-v2/run-all.sh ws
|
||||
set -e
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
if [ "$#" -ne 1 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
echo "Need 1 parameters for type of algorithms."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
datasets="cifar10 cifar100 ImageNet16-120"
|
||||
search_spaces="tss sss"
|
||||
alg_type=$1
|
||||
|
||||
for dataset in ${datasets}
|
||||
do
|
||||
for search_space in ${search_spaces}
|
||||
if [ "$alg_type" == "mul" ]; then
|
||||
search_spaces="tss sss"
|
||||
|
||||
for dataset in ${datasets}
|
||||
do
|
||||
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
||||
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
||||
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||
for search_space in ${search_spaces}
|
||||
do
|
||||
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
||||
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
||||
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||
else
|
||||
seeds="777 888 999"
|
||||
epoch=200
|
||||
for seed in ${seeds}
|
||||
do
|
||||
for alg in "darts-v1 darts-v2 gdas setn random enas"
|
||||
do
|
||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
|
@ -22,8 +22,8 @@
|
||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
||||
####
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas
|
||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas
|
||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
######################################################################################
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
@ -333,7 +333,11 @@ def main(xargs):
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
if xargs.overwite_epochs is None:
|
||||
extra_info = {'class_num': class_num, 'xshape': xshape}
|
||||
else:
|
||||
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
|
||||
config = load_config(xargs.config_path, extra_info, logger)
|
||||
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
|
||||
(config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
@ -496,6 +500,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
|
||||
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
|
||||
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
|
||||
parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
@ -508,8 +513,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||
args.dataset,
|
||||
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
|
||||
if args.overwite_epochs is None:
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||
args.dataset,
|
||||
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
|
||||
else:
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
|
||||
args.dataset,
|
||||
'{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate))
|
||||
|
||||
main(args)
|
||||
|
@ -30,12 +30,12 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
||||
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
||||
alg2name, alg2path = OrderedDict(), OrderedDict()
|
||||
seeds = [777]
|
||||
alg2name['GDAS'] = 'gdas-affine1_BN0-None'
|
||||
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
|
||||
alg2name['RSPS'] = 'random-affine0_BN0-None'
|
||||
"""
|
||||
alg2name['DARTS (1st)'] = 'darts-v1-affine1_BN0-None'
|
||||
alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None'
|
||||
alg2name['SETN'] = 'setn-affine1_BN0-None'
|
||||
alg2name['RSPS'] = 'random-affine1_BN0-None'
|
||||
"""
|
||||
for alg, name in alg2name.items():
|
||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
||||
@ -76,7 +76,7 @@ def visualize_curve(api, vis_save_dir, search_space):
|
||||
def sub_plot_fn(ax, dataset):
|
||||
alg2data = fetch_data(search_space=search_space, dataset=dataset)
|
||||
alg2accuracies = OrderedDict()
|
||||
epochs = 20
|
||||
epochs = 100
|
||||
colors = ['b', 'g', 'c', 'm', 'y']
|
||||
ax.set_xlim(0, epochs)
|
||||
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
|
||||
|
Loading…
Reference in New Issue
Block a user