61 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			61 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import torch | ||
|  | import torch.nn.functional as F | ||
|  | 
 | ||
|  | 
 | ||
def drop_path(x, drop_prob):
  """Randomly zero whole samples of a batch (drop-path).

  One Bernoulli(keep_prob) draw is made per entry along dim 0 and broadcast
  over all remaining dims. Kept samples are rescaled by ``1 / keep_prob`` so
  the expected value of the output equals ``x``. Works for inputs of any rank
  >= 1 (the mask shape is derived from ``x.dim()`` instead of assuming 4-D).

  Args:
    x: tensor whose first dimension indexes samples.
    drop_prob: probability of dropping each sample; ``<= 0`` is a no-op.

  Returns:
    ``x`` itself when ``drop_prob <= 0``; otherwise a new tensor with the same
    shape, dtype and device (``x`` is not modified in place).
  """
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    if keep_prob <= 0.:
      # Every sample is dropped; dividing by keep_prob would give inf and
      # inf * 0 would turn the result into nan.
      return torch.zeros_like(x)
    # Shape (N, 1, ..., 1) so a single draw per sample broadcasts over the rest.
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = x.new_zeros(mask_shape).bernoulli_(keep_prob)
    # torch.div allocates a new tensor, so the in-place mul_ never touches the input.
    x = torch.div(x, keep_prob)
    x.mul_(mask)
  return x
|  | 
 | ||
|  | 
 | ||
def _format_adjacency(basemodel, connect_attr, template, adjacency):
  """Render one ``template`` line per node: softmax(connect[str(i)]) @ adjacency[i], flattened."""
  lines = []
  # i indexes both `adjacency` and the str-keyed connect container, so keep an explicit index.
  for i in range(len(adjacency)):
    # Looked up per iteration (as before), so an empty adjacency never touches connect_attr.
    weight = F.softmax(getattr(basemodel, connect_attr)[str(i)], dim=-1)
    values = torch.mm(weight, adjacency[i]).view(-1)
    formatted = ['{:3.3f}'.format(v) for v in values.cpu().tolist()]
    lines.append(template.format(i, ', '.join(formatted)))
  return ''.join(lines)


def return_alphas_str(basemodel):
  """Build a human-readable summary of a model's architecture parameters.

  Each section is emitted only when ``basemodel`` has the matching attribute
  (checked with ``hasattr``), in this order:
    * ``alphas_normal`` / ``alphas_reduce``: softmax over the last dim.
    * ``get_adjacency()``: per-node rows of softmax(``connect_normal[str(i)]``)
      and softmax(``connect_reduce[str(i)]``) matrix-multiplied with the
      i-th adjacency entry.
    * ``alphas_connect``: softmax columns 0 and 1 as 3-decimal strings;
      when absent, a ``connect = None`` line is emitted instead.
    * ``get_gcn_out(True)``: softmax of every returned output.

  Args:
    basemodel: model-like object inspected via ``hasattr``.

  Returns:
    str: the concatenated summary.
  """
  parts = []
  if hasattr(basemodel, 'alphas_normal'):
    parts.append('normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) ))
  if hasattr(basemodel, 'alphas_reduce'):
    parts.append('\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) ))

  if hasattr(basemodel, 'get_adjacency'):
    adjacency = basemodel.get_adjacency()
    # NOTE(review): connect_normal / connect_reduce are assumed to exist whenever
    # get_adjacency does; they are not hasattr-checked (unchanged behavior).
    parts.append(_format_adjacency(basemodel, 'connect_normal', '\nnormal--{:}-->{:}', adjacency))
    parts.append(_format_adjacency(basemodel, 'connect_reduce', '\nreduce--{:}-->{:}', adjacency))

  if hasattr(basemodel, 'alphas_connect'):
    weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
    ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()]
    IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()]
    parts.append('\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN ))
  else:
    parts.append('\nconnect = None')

  if hasattr(basemodel, 'get_gcn_out'):
    outputs = basemodel.get_gcn_out(True)
    for i, output in enumerate(outputs):
      parts.append('\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) ))

  return ''.join(parts)
|  | 
 | ||
|  | 
 | ||
def remove_duplicate_archs(all_archs):
  """Drop architectures whose string form was already seen, keeping first occurrences.

  Two entries count as duplicates when ``'{:}'.format(arch)`` yields the same
  text. Order of the surviving entries is preserved.

  Args:
    all_archs: iterable of architecture objects.

  Returns:
    list: the de-duplicated architectures, in original order.
  """
  # A seen-set makes this O(n) instead of rescanning all earlier keys per item.
  seen = set()
  archs = []
  for arch in all_archs:
    key = '{:}'.format(arch)
    if key not in seen:
      seen.add(key)
      archs.append(arch)
  return archs