update README
This commit is contained in:
		@@ -22,6 +22,8 @@ from utils import get_model_infos
 | 
			
		||||
flop, param  = get_model_infos(net, (1,3,32,32))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
 | 
			
		||||
In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								exps/vis/random-nn.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								exps/vis/random-nn.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
from graphviz import Digraph
 | 
			
		||||
							
								
								
									
										67
									
								
								exps/vis/show-results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								exps/vis/show-results.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
			
		||||
# python ./vis-exps/show-results.py
 | 
			
		||||
import os, sys
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | 
			
		||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | 
			
		||||
 | 
			
		||||
from aa_nas_api   import AANASBenchAPI
 | 
			
		||||
 | 
			
		||||
api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
 | 
			
		||||
 | 
			
		||||
def plot_results_nas(dataset, xset, file_name, y_lims):
 | 
			
		||||
  import matplotlib
 | 
			
		||||
  matplotlib.use('agg')
 | 
			
		||||
  import matplotlib.pyplot as plt
 | 
			
		||||
  root = Path('./output/cell-search-tiny-vis').resolve()
 | 
			
		||||
  print ('root path : {:}'.format( root ))
 | 
			
		||||
  root.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
  checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth',
 | 
			
		||||
                 './output/cell-search-tiny/REINFORCE-cifar10/results.pth',
 | 
			
		||||
                 './output/cell-search-tiny/RAND-cifar10/results.pth',
 | 
			
		||||
                 './output/cell-search-tiny/BOHB-cifar10/results.pth'
 | 
			
		||||
                ]
 | 
			
		||||
  legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
 | 
			
		||||
  All_Accs = OrderedDict()
 | 
			
		||||
  for legend, checkpoint in zip(legends, checkpoints):
 | 
			
		||||
    all_indexes = torch.load(checkpoint, map_location='cpu')
 | 
			
		||||
    accuracies  = []
 | 
			
		||||
    for x in all_indexes:
 | 
			
		||||
      info = api.arch2infos[ x ]
 | 
			
		||||
      _, accy = info.get_metrics(dataset, xset, None, False)
 | 
			
		||||
      accuracies.append( accy )
 | 
			
		||||
    if indexes is None: indexes = list(range(len(all_indexes)))
 | 
			
		||||
    All_Accs[legend] = sorted(accuracies)
 | 
			
		||||
  
 | 
			
		||||
  color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
 | 
			
		||||
  dpi, width, height = 300, 3400, 2600
 | 
			
		||||
  LabelSize, LegendFontsize = 26, 26
 | 
			
		||||
  figsize = width / float(dpi), height / float(dpi)
 | 
			
		||||
  fig = plt.figure(figsize=figsize)
 | 
			
		||||
  x_axis = np.arange(0, 600)
 | 
			
		||||
  plt.xlim(0, max(indexes))
 | 
			
		||||
  plt.ylim(y_lims[0], y_lims[1])
 | 
			
		||||
  interval_x, interval_y = 100, y_lims[2]
 | 
			
		||||
  plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
 | 
			
		||||
  plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
 | 
			
		||||
  plt.grid()
 | 
			
		||||
  plt.xlabel('The index of runs', fontsize=LabelSize)
 | 
			
		||||
  plt.ylabel('The accuracy (%)', fontsize=LabelSize)
 | 
			
		||||
 | 
			
		||||
  for idx, legend in enumerate(legends):
 | 
			
		||||
    plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
 | 
			
		||||
    print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
 | 
			
		||||
  plt.legend(loc=4, fontsize=LegendFontsize)
 | 
			
		||||
  save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
 | 
			
		||||
  print('save figure into {:}\n'.format(save_path))
 | 
			
		||||
  fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1))
 | 
			
		||||
  plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3))
 | 
			
		||||
  plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3))
 | 
			
		||||
  plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3))
 | 
			
		||||
  plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3))
 | 
			
		||||
@@ -74,20 +74,15 @@ class Structure:
 | 
			
		||||
      nodes[i+1] = sum(sums) > 0
 | 
			
		||||
    return nodes[len(self.nodes)]
 | 
			
		||||
 | 
			
		||||
  def to_unique_str(self, consider_zero=False):
 | 
			
		||||
  def to_unique_str(self):
 | 
			
		||||
    # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
 | 
			
		||||
    # two operations are special, i.e., none and skip_connect
 | 
			
		||||
    nodes = {0: '0'}
 | 
			
		||||
    for i_node, node_info in enumerate(self.nodes):
 | 
			
		||||
      cur_node = []
 | 
			
		||||
      for op, xin in node_info:
 | 
			
		||||
        if consider_zero:
 | 
			
		||||
          if op == 'none' or nodes[xin] == '#': x = '#' # zero
 | 
			
		||||
          elif op == 'skip_connect': x = nodes[xin]
 | 
			
		||||
          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
 | 
			
		||||
        else:
 | 
			
		||||
          if op == 'skip_connect': x = nodes[xin]
 | 
			
		||||
          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
 | 
			
		||||
        if op == 'skip_connect': x = nodes[xin]
 | 
			
		||||
        else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
 | 
			
		||||
        cur_node.append(x)
 | 
			
		||||
      nodes[i_node+1] = '+'.join( sorted(cur_node) )
 | 
			
		||||
    return nodes[ len(self.nodes) ]
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user