61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` along dim 0 and rescale the rest.

    One Bernoulli draw with success probability ``1 - drop_prob`` is made per
    index of the first dimension.  Samples whose draw is 0 are zeroed; the
    others are divided by ``1 - drop_prob``.

    Args:
        x: tensor with at least one dimension; dim 0 is the one sampled over.
            Any number of trailing dimensions is accepted.
        drop_prob: probability of zeroing a sample.

    Returns:
        ``x`` itself (same object) when ``drop_prob <= 0``; otherwise a new
        tensor of the same shape.  The input tensor is never modified.
    """
    if drop_prob <= 0.:
        return x
    if drop_prob >= 1.:
        # keep_prob would be 0: dividing by it and then masking produces
        # NaN/inf.  Every sample is dropped, so the result is all zeros.
        return torch.zeros_like(x)

    keep_prob = 1. - drop_prob
    # One mask entry per sample, broadcast over all remaining dimensions.
    mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = x.new_zeros(mask_shape).bernoulli_(keep_prob)
    # torch.div allocates a fresh tensor, so the in-place multiply below
    # does not touch the caller's ``x``.
    scaled = torch.div(x, keep_prob)
    scaled.mul_(mask)
    return scaled
|
|
|
|
|
|
def return_alphas_str(basemodel):
    """Build a multi-line text summary of the weights exposed by ``basemodel``.

    Each section is produced only if ``basemodel`` has the corresponding
    attribute (tested with ``hasattr``).  Every weight tensor is run through
    a softmax over its last dimension before being rendered:

    * ``alphas_normal`` -- its size and softmaxed values.
    * ``alphas_reduce`` -- its softmaxed values.
    * ``get_adjacency()`` -- for every index ``i`` of the returned container,
      ``connect_normal[str(i)]`` and then ``connect_reduce[str(i)]`` are
      softmaxed, matrix-multiplied with ``adjacency[i]``, flattened and
      printed with three decimals.
    * ``alphas_connect`` -- columns 0 and 1 of the softmaxed tensor, printed
      with three decimals; if the attribute is missing the text
      ``connect = None`` is emitted instead.
    * ``get_gcn_out(True)`` -- one line per returned output.

    Returns:
        str: the concatenation of all produced sections.
    """
    pieces = []

    if hasattr(basemodel, 'alphas_normal'):
        normal_alphas = basemodel.alphas_normal
        pieces.append('normal [{:}] : \n-->>{:}'.format(
            normal_alphas.size(), F.softmax(normal_alphas, dim=-1)))

    if hasattr(basemodel, 'alphas_reduce'):
        reduce_probs = F.softmax(basemodel.alphas_reduce, dim=-1)
        pieces.append('\nreduce : {:}'.format(reduce_probs))

    if hasattr(basemodel, 'get_adjacency'):
        adjacency = basemodel.get_adjacency()
        sections = (
            ('connect_normal', '\nnormal--{:}-->{:}'),
            ('connect_reduce', '\nreduce--{:}-->{:}'),
        )
        # All "normal" lines are emitted first, then all "reduce" lines.
        for attr_name, line_fmt in sections:
            # Indexed access is kept on purpose: the container type returned
            # by get_adjacency() is not visible here.
            for idx in range(len(adjacency)):
                logits = getattr(basemodel, attr_name)[str(idx)]
                probs = F.softmax(logits, dim=-1)
                mixed = torch.mm(probs, adjacency[idx]).view(-1)
                cells = ['{:3.3f}'.format(value) for value in mixed.cpu().tolist()]
                pieces.append(line_fmt.format(idx, ', '.join(cells)))

    if hasattr(basemodel, 'alphas_connect'):
        connect_probs = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
        # Column 0 / column 1 -- presumably the "zero" and "identity"
        # choices, judging by the original local names; confirm with the model.
        first_col = ['{:.3f}'.format(value) for value in connect_probs[:, 0].tolist()]
        second_col = ['{:.3f}'.format(value) for value in connect_probs[:, 1].tolist()]
        shape = list(basemodel.alphas_connect.size())
        pieces.append('\nconnect [{:}] : \n ->{:}\n ->{:}'.format(shape, first_col, second_col))
    else:
        pieces.append('\nconnect = None')

    if hasattr(basemodel, 'get_gcn_out'):
        for idx, gcn_output in enumerate(basemodel.get_gcn_out(True)):
            pieces.append('\nnormal:[{:}] : {:}'.format(idx, F.softmax(gcn_output, dim=-1)))

    return ''.join(pieces)
|
|
|
|
|
|
def remove_duplicate_archs(all_archs):
    """Return the items of ``all_archs`` with repeated entries removed.

    Two items are treated as duplicates when ``'{:}'.format(item)`` yields
    the same text for both.  The first occurrence is kept and the original
    order is preserved.

    Args:
        all_archs: any iterable of items that can be formatted with
            ``'{:}'.format``; the items themselves need not be hashable.

    Returns:
        list: the first occurrence of each distinct item, in input order.
    """
    seen_keys = set()
    unique_archs = []
    for arch in all_archs:
        # Compare by formatted text so unhashable items can be deduplicated.
        key = '{:}'.format(arch)
        if key not in seen_keys:
            seen_keys.add(key)
            unique_archs.append(arch)
    return unique_archs
|