autodl-projects/xautodl/nas_infer_model/DXYs/construct_utils.py
2021-05-18 14:08:00 +00:00

61 lines
2.2 KiB
Python

import torch
import torch.nn.functional as F
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Each slice ``x[i]`` along dimension 0 is dropped with probability
    ``drop_prob``. Kept slices are divided by ``1 - drop_prob``, so the
    expected value of the output equals ``x``.

    Args:
        x: Tensor whose first dimension indexes samples. Any rank >= 1 is
            accepted; the per-sample mask broadcasts over all trailing dims.
        drop_prob: Probability of dropping a sample. Values ``<= 0`` return
            ``x`` itself unchanged; values ``>= 1`` return an all-zero tensor.

    Returns:
        A tensor with the same shape as ``x``. The input tensor is never
        modified in place.
    """
    if drop_prob <= 0.:
        return x
    if drop_prob >= 1.:
        # keep_prob would be 0: dividing by it produces inf/nan (and nan after
        # masking), so drop every sample explicitly instead.
        return torch.zeros_like(x)
    keep_prob = 1. - drop_prob
    # One Bernoulli draw per sample, shaped (N, 1, ..., 1) so it broadcasts
    # over however many trailing dimensions x has.
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = x.new_zeros(mask_shape).bernoulli_(keep_prob)
    # torch.div allocates a new tensor, so the in-place mul_ below does not
    # touch the caller's x.
    x = torch.div(x, keep_prob)
    x.mul_(mask)
    return x
def return_alphas_str(basemodel):
    """Build a readable multi-line summary of a model's architecture weights.

    Every section is optional and is emitted only when ``basemodel`` has the
    corresponding attribute or method (tested with ``hasattr``):

    * ``alphas_normal`` / ``alphas_reduce``: softmax over the last dimension.
    * ``get_adjacency()``: for each returned matrix ``A_i``, the flattened
      product ``softmax(connect_normal[str(i)]) @ A_i`` printed with three
      decimals, followed by the same lines for ``connect_reduce``.
    * ``alphas_connect``: softmax over the last dimension, with columns 0 and 1
      listed separately. When the attribute is missing, a ``connect = None``
      line is written instead.
    * ``get_gcn_out(True)``: softmax over the last dimension of each output.

    Args:
        basemodel: Duck-typed search model; see the sections above.

    Returns:
        The concatenated summary string.
    """
    pieces = []

    if hasattr(basemodel, 'alphas_normal'):
        normal = basemodel.alphas_normal
        pieces.append('normal [{:}] : \n-->>{:}'.format(normal.size(), F.softmax(normal, dim=-1)))
    if hasattr(basemodel, 'alphas_reduce'):
        pieces.append('\nreduce : {:}'.format(F.softmax(basemodel.alphas_reduce, dim=-1)))

    if hasattr(basemodel, 'get_adjacency'):
        adjacency = basemodel.get_adjacency()
        # Identical report for both cell types. The connection table is looked
        # up inside the loop so it is only needed when adjacency is non-empty.
        sections = (('\nnormal--{:}-->{:}', 'connect_normal'),
                    ('\nreduce--{:}-->{:}', 'connect_reduce'))
        for template, table_name in sections:
            for idx, adj_matrix in enumerate(adjacency):
                weight = F.softmax(getattr(basemodel, table_name)[str(idx)], dim=-1)
                flat = torch.mm(weight, adj_matrix).view(-1)
                cells = ['{:3.3f}'.format(v) for v in flat.cpu().tolist()]
                pieces.append(template.format(idx, ', '.join(cells)))

    if not hasattr(basemodel, 'alphas_connect'):
        pieces.append('\nconnect = None')
    else:
        probs = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
        zero_col = ['{:.3f}'.format(v) for v in probs[:, 0].tolist()]
        iden_col = ['{:.3f}'.format(v) for v in probs[:, 1].tolist()]
        shape = list(basemodel.alphas_connect.size())
        pieces.append('\nconnect [{:}] : \n ->{:}\n ->{:}'.format(shape, zero_col, iden_col))

    if hasattr(basemodel, 'get_gcn_out'):
        gcn_outputs = basemodel.get_gcn_out(True)
        for idx, output in enumerate(gcn_outputs):
            pieces.append('\nnormal:[{:}] : {:}'.format(idx, F.softmax(output, dim=-1)))

    return ''.join(pieces)
def remove_duplicate_archs(all_archs):
    """Return the architectures with duplicates removed, keeping first occurrences.

    Two entries count as duplicates when their ``'{:}'.format`` string forms
    are equal, so objects without a usable ``__eq__``/``__hash__`` can still be
    de-duplicated. The original objects (not their strings) are returned, in
    their original order.

    Args:
        all_archs: Iterable of architecture objects.

    Returns:
        A new list holding the first object seen for each distinct string form.
    """
    # A set of already-seen keys makes this O(n) instead of re-scanning every
    # earlier entry for each element (O(n^2)).
    seen = set()
    archs = []
    for arch in all_archs:
        key = '{:}'.format(arch)
        if key not in seen:
            seen.add(key)
            archs.append(arch)
    return archs