Update VIS-CODES and SCRIPTS
This commit is contained in:
		| @@ -363,9 +363,9 @@ def main(xargs): | |||||||
|   params = count_parameters_in_MB(search_model) |   params = count_parameters_in_MB(search_model) | ||||||
|   logger.log('The parameters of the search model = {:.2f} MB'.format(params)) |   logger.log('The parameters of the search model = {:.2f} MB'.format(params)) | ||||||
|   logger.log('search-space : {:}'.format(search_space)) |   logger.log('search-space : {:}'.format(search_space)) | ||||||
|   try: |   if bool(xargs.use_api): | ||||||
|     api = API(verbose=False) |     api = API(verbose=False) | ||||||
|   except: |   else: | ||||||
|     api = None |     api = None | ||||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) |   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||||
|  |  | ||||||
| @@ -486,6 +486,7 @@ if __name__ == '__main__': | |||||||
|   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') |   parser.add_argument('--dataset'     ,       type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||||
|   parser.add_argument('--search_space',       type=str,   default='tss', choices=['tss'], help='The search space name.') |   parser.add_argument('--search_space',       type=str,   default='tss', choices=['tss'], help='The search space name.') | ||||||
|   parser.add_argument('--algo'        ,       type=str,   choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.') |   parser.add_argument('--algo'        ,       type=str,   choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.') | ||||||
|  |   parser.add_argument('--use_api'     ,       type=int,   default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).') | ||||||
|   # FOR GDAS |   # FOR GDAS | ||||||
|   parser.add_argument('--tau_min',            type=float, default=0.1,  help='The minimum tau for Gumbel Softmax.') |   parser.add_argument('--tau_min',            type=float, default=0.1,  help='The minimum tau for Gumbel Softmax.') | ||||||
|   parser.add_argument('--tau_max',            type=float, default=10,   help='The maximum tau for Gumbel Softmax.') |   parser.add_argument('--tau_max',            type=float, default=10,   help='The maximum tau for Gumbel Softmax.') | ||||||
|   | |||||||
| @@ -30,14 +30,16 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): | |||||||
|   ss_dir = '{:}-{:}'.format(root_dir, search_space) |   ss_dir = '{:}-{:}'.format(root_dir, search_space) | ||||||
|   alg2name, alg2path = OrderedDict(), OrderedDict() |   alg2name, alg2path = OrderedDict(), OrderedDict() | ||||||
|   seeds = [777] |   seeds = [777] | ||||||
|  |   if search_space == 'tss': | ||||||
|     alg2name['GDAS'] = 'gdas-affine0_BN0-None' |     alg2name['GDAS'] = 'gdas-affine0_BN0-None' | ||||||
|     alg2name['RSPS'] = 'random-affine0_BN0-None' |     alg2name['RSPS'] = 'random-affine0_BN0-None' | ||||||
|     alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None' |     alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None' | ||||||
|  |     alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None' | ||||||
|     alg2name['ENAS'] = 'enas-affine0_BN0-None' |     alg2name['ENAS'] = 'enas-affine0_BN0-None' | ||||||
|   """ |     alg2name['SETN'] = 'setn-affine0_BN0-None' | ||||||
|   alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None' |   else: | ||||||
|   alg2name['SETN'] = 'setn-affine1_BN0-None' |     alg2name['TAS'] = 'tas-affine0_BN0' | ||||||
|   """ |     alg2name['FBNetV2'] = 'fbv2-affine0_BN0' | ||||||
|   for alg, name in alg2name.items(): |   for alg, name in alg2name.items(): | ||||||
|     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') |     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth') | ||||||
|   alg2data = OrderedDict() |   alg2data = OrderedDict() | ||||||
| @@ -66,6 +68,10 @@ y_max_s = {('cifar10', 'tss'): 94.5, | |||||||
|            ('ImageNet16-120', 'tss'): 44, |            ('ImageNet16-120', 'tss'): 44, | ||||||
|            ('ImageNet16-120', 'sss'): 46} |            ('ImageNet16-120', 'sss'): 46} | ||||||
|  |  | ||||||
|  | name2label = {'cifar10': 'CIFAR-10', | ||||||
|  |               'cifar100': 'CIFAR-100', | ||||||
|  |               'ImageNet16-120': 'ImageNet-16-120'} | ||||||
|  |  | ||||||
| def visualize_curve(api, vis_save_dir, search_space): | def visualize_curve(api, vis_save_dir, search_space): | ||||||
|   vis_save_dir = vis_save_dir.resolve() |   vis_save_dir = vis_save_dir.resolve() | ||||||
|   vis_save_dir.mkdir(parents=True, exist_ok=True) |   vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
| @@ -94,8 +100,8 @@ def visualize_curve(api, vis_save_dir, search_space): | |||||||
|       alg2accuracies[alg] = accuracies |       alg2accuracies[alg] = accuracies | ||||||
|       ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg)) |       ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg)) | ||||||
|       ax.set_xlabel('The searching epoch', fontsize=LabelSize) |       ax.set_xlabel('The searching epoch', fontsize=LabelSize) | ||||||
|       ax.set_ylabel('Test accuracy on {:}'.format(dataset), fontsize=LabelSize) |       ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize) | ||||||
|       ax.set_title('Searching results on {:}'.format(dataset), fontsize=LabelSize+4) |       ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4) | ||||||
|     ax.legend(loc=4, fontsize=LegendFontsize) |     ax.legend(loc=4, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|   fig, axs = plt.subplots(1, 3, figsize=figsize) |   fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user