35 lines
		
	
	
		
			828 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			35 lines
		
	
	
		
			828 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
 | |
| #####################################################
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| 
 | |
| 
 | |
| def additive_func(A, B):
 | |
|   assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
 | |
|   C = min(A.size(1), B.size(1))
 | |
|   if A.size(1) == B.size(1):
 | |
|     return A + B
 | |
|   elif A.size(1) < B.size(1):
 | |
|     out = B.clone()
 | |
|     out[:,:C] += A
 | |
|     return out
 | |
|   else:
 | |
|     out = A.clone()
 | |
|     out[:,:C] += B
 | |
|     return out
 | |
| 
 | |
| 
 | |
| def change_key(key, value):
 | |
|   def func(m):
 | |
|     if hasattr(m, key):
 | |
|       setattr(m, key, value)
 | |
|   return func
 | |
| 
 | |
| 
 | |
| def parse_channel_info(xstring):
 | |
|   blocks = xstring.split(' ')
 | |
|   blocks = [x.split('-') for x in blocks]
 | |
|   blocks = [[int(_) for _ in x] for x in blocks]
 | |
|   return blocks
 |