update code styles
This commit is contained in:
		| @@ -8,7 +8,6 @@ from tqdm import tqdm | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
| import matplotlib | ||||
| @@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_ | ||||
|  | ||||
|   def get_accs(xdata): | ||||
|     epochs, xresults = xdata['epoch'], [] | ||||
|     metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False) | ||||
|     xresults.append( metrics['accuracy'] ) | ||||
|     for iepoch in range(epochs): | ||||
|       genotype = xdata['genotypes'][iepoch] | ||||
|       index = api.query_index_by_arch(genotype) | ||||
| @@ -547,7 +548,6 @@ if __name__ == '__main__': | ||||
|   #visualize_relative_ranking(vis_save_dir) | ||||
|  | ||||
|   api = API(args.api_path) | ||||
|   """ | ||||
|   for x_maxs in [50, 250]: | ||||
|     show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
| @@ -555,11 +555,12 @@ if __name__ == '__main__': | ||||
|     show_nas_sharing_w(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|     show_nas_sharing_w(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs) | ||||
|   just_show(api) | ||||
|   """ | ||||
|   just_show(api) | ||||
|   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar10'       , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'cifar100'      , 'x-test'  , vis_save_dir, 'nas-com.pdf', (55,75, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , vis_save_dir, 'nas-com.pdf', (35,50, 3)) | ||||
|   """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user