| 
									
										
										
										
											2019-11-15 17:15:07 +11:00
										 |  |  | ################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | 
					
						
							|  |  |  | ################################################## | 
					
						
							| 
									
										
										
										
											2019-09-28 18:24:47 +10:00
										 |  |  | import torch | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | from SoftSelect import ChannelWiseInter | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   tensors = torch.rand((16, 128, 7, 7)) | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   for oc in range(200, 210): | 
					
						
							|  |  |  |     out_v1  = ChannelWiseInter(tensors, oc, 'v1') | 
					
						
							|  |  |  |     out_v2  = ChannelWiseInter(tensors, oc, 'v2') | 
					
						
							|  |  |  |     assert (out_v1 == out_v2).any().item() == 1 | 
					
						
							|  |  |  |   for oc in range(48, 160): | 
					
						
							|  |  |  |     out_v1  = ChannelWiseInter(tensors, oc, 'v1') | 
					
						
							|  |  |  |     out_v2  = ChannelWiseInter(tensors, oc, 'v2') | 
					
						
							|  |  |  |     assert (out_v1 == out_v2).any().item() == 1 |