122 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			122 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2021 Samsung Electronics Co., Ltd.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| 
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| # =============================================================================
 | |
| 
 | |
| import os
 | |
| import argparse
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| from .nasbench2_ops import *
 | |
| 
 | |
| 
 | |
| def gen_searchcell_mask_from_arch_str(arch_str):
 | |
|     nodes = arch_str.split('+') 
 | |
|     nodes = [node[1:-1].split('|') for node in nodes]
 | |
|     nodes = [[op_and_input.split('~')  for op_and_input in node] for node in nodes]
 | |
| 
 | |
|     keep_mask = []
 | |
|     for curr_node_idx in range(len(nodes)):
 | |
|             for prev_node_idx in range(curr_node_idx+1): 
 | |
|                 _op = [edge[0] for edge in nodes[curr_node_idx] if int(edge[1]) == prev_node_idx]
 | |
|                 assert len(_op) == 1, 'The arch string does not follow the assumption of 1 connection between two nodes.'
 | |
|                 for _op_name in OPS.keys():
 | |
|                     keep_mask.append(_op[0] == _op_name)
 | |
|     return keep_mask
 | |
| 
 | |
| 
 | |
| def get_model_from_arch_str(arch_str, num_classes, use_bn=True, init_channels=16):
 | |
|     keep_mask = gen_searchcell_mask_from_arch_str(arch_str)
 | |
|     net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn, keep_mask=keep_mask, stem_ch=init_channels)
 | |
|     return net
 | |
| 
 | |
| 
 | |
| def get_super_model(num_classes, use_bn=True):
 | |
|     net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn)
 | |
|     return net
 | |
| 
 | |
| 
 | |
| class NAS201Model(nn.Module):
 | |
| 
 | |
|     def __init__(self, arch_str, num_classes, use_bn=True, keep_mask=None, stem_ch=16):
 | |
|         super(NAS201Model, self).__init__()
 | |
|         self.arch_str=arch_str
 | |
|         self.num_classes=num_classes
 | |
|         self.use_bn= use_bn
 | |
| 
 | |
|         self.stem = stem(out_channels=stem_ch, use_bn=use_bn)
 | |
|         self.stack_cell1 = nn.Sequential(*[SearchCell(in_channels=stem_ch, out_channels=stem_ch, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
 | |
|         self.reduction1 = reduction(in_channels=stem_ch, out_channels=stem_ch*2)
 | |
|         self.stack_cell2 = nn.Sequential(*[SearchCell(in_channels=stem_ch*2, out_channels=stem_ch*2, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
 | |
|         self.reduction2 = reduction(in_channels=stem_ch*2, out_channels=stem_ch*4)
 | |
|         self.stack_cell3 = nn.Sequential(*[SearchCell(in_channels=stem_ch*4, out_channels=stem_ch*4, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
 | |
|         self.top = top(in_dims=stem_ch*4, num_classes=num_classes, use_bn=use_bn)
 | |
| 
 | |
|     def forward(self, x):
 | |
|         x = self.stem(x)        
 | |
| 
 | |
|         x = self.stack_cell1(x)
 | |
|         x = self.reduction1(x)
 | |
| 
 | |
|         x = self.stack_cell2(x)
 | |
|         x = self.reduction2(x)
 | |
| 
 | |
|         x = self.stack_cell3(x)
 | |
| 
 | |
|         x = self.top(x)
 | |
|         return x
 | |
|     
 | |
|     def get_prunable_copy(self, bn=False):
 | |
|         model_new = get_model_from_arch_str(self.arch_str, self.num_classes, use_bn=bn)
 | |
| 
 | |
|         #TODO this is quite brittle and doesn't work with nn.Sequential when bn is different
 | |
|         # it is only required to maintain initialization -- maybe init after get_punable_copy?
 | |
|         model_new.load_state_dict(self.state_dict(), strict=False)
 | |
|         model_new.train()
 | |
| 
 | |
|         return model_new
 | |
|     
 | |
| 
 | |
| def get_arch_str_from_model(net):
 | |
|     search_cell = net.stack_cell1[0].options
 | |
|     keep_mask = net.stack_cell1[0].keep_mask
 | |
|     num_nodes = net.stack_cell1[0].num_nodes
 | |
| 
 | |
|     nodes = []
 | |
|     idx = 0
 | |
|     for curr_node in range(num_nodes -1):
 | |
|         edges = []
 | |
|         for prev_node in range(curr_node+1): # n-1 prev nodes
 | |
|             for _op_name in OPS.keys():
 | |
|                 if keep_mask[idx]:
 | |
|                     edges.append(f'{_op_name}~{prev_node}')
 | |
|                 idx += 1
 | |
|         node_str = '|'.join(edges)
 | |
|         node_str = f'|{node_str}|'
 | |
|         nodes.append(node_str) 
 | |
|     arch_str = '+'.join(nodes)
 | |
|     return arch_str
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     arch_str = '|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'
 | |
|     
 | |
|     n = get_model_from_arch_str(arch_str=arch_str, num_classes=10)
 | |
|     print(n.stack_cell1[0])
 | |
|     
 | |
|     arch_str2 = get_arch_str_from_model(n)
 | |
|     print(arch_str)
 | |
|     print(arch_str2)
 | |
|     print(f'Are the two arch strings same? {arch_str == arch_str2}')
 |