fixed bug with alternate_corr flag
This commit is contained in:
		
							
								
								
									
										30
									
								
								core/corr.py
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								core/corr.py
									
									
									
									
									
								
							| @@ -60,26 +60,6 @@ class CorrBlock: | ||||
|         return corr  / torch.sqrt(torch.tensor(dim).float()) | ||||
|  | ||||
|  | ||||
| class CorrLayer(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     def forward(ctx, fmap1, fmap2, coords, r): | ||||
|         fmap1 = fmap1.contiguous() | ||||
|         fmap2 = fmap2.contiguous() | ||||
|         coords = coords.contiguous() | ||||
|         ctx.save_for_backward(fmap1, fmap2, coords) | ||||
|         ctx.r = r | ||||
|         corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) | ||||
|         return corr | ||||
|  | ||||
|     @staticmethod | ||||
|     def backward(ctx, grad_corr): | ||||
|         fmap1, fmap2, coords = ctx.saved_tensors | ||||
|         grad_corr = grad_corr.contiguous() | ||||
|         fmap1_grad, fmap2_grad, coords_grad = \ | ||||
|             correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) | ||||
|         return fmap1_grad, fmap2_grad, coords_grad, None | ||||
|  | ||||
|  | ||||
| class AlternateCorrBlock: | ||||
|     def __init__(self, fmap1, fmap2, num_levels=4, radius=4): | ||||
|         self.num_levels = num_levels | ||||
| @@ -92,20 +72,20 @@ class AlternateCorrBlock: | ||||
|             self.pyramid.append((fmap1, fmap2)) | ||||
|  | ||||
|     def __call__(self, coords): | ||||
|  | ||||
|         coords = coords.permute(0, 2, 3, 1) | ||||
|         B, H, W, _ = coords.shape | ||||
|         dim = self.pyramid[0][0].shape[1] | ||||
|  | ||||
|         corr_list = [] | ||||
|         for i in range(self.num_levels): | ||||
|             r = self.radius | ||||
|             fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) | ||||
|             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) | ||||
|             fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() | ||||
|             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() | ||||
|  | ||||
|             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() | ||||
|             corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) | ||||
|             corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) | ||||
|             corr_list.append(corr.squeeze(1)) | ||||
|  | ||||
|         corr = torch.stack(corr_list, dim=1) | ||||
|         corr = corr.reshape(B, -1, H, W) | ||||
|         return corr / 16.0 | ||||
|         return corr / torch.sqrt(torch.tensor(dim).float()) | ||||
|   | ||||
							
								
								
									
										11
									
								
								core/raft.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								core/raft.py
									
									
									
									
									
								
							| @@ -38,11 +38,11 @@ class RAFT(nn.Module): | ||||
|             args.corr_levels = 4 | ||||
|             args.corr_radius = 4 | ||||
|  | ||||
|         if 'dropout' not in args._get_kwargs(): | ||||
|             args.dropout = 0 | ||||
|         if 'dropout' not in self.args: | ||||
|             self.args.dropout = 0 | ||||
|  | ||||
|         if 'alternate_corr' not in args._get_kwargs(): | ||||
|             args.alternate_corr = False | ||||
|         if 'alternate_corr' not in self.args: | ||||
|             self.args.alternate_corr = False | ||||
|  | ||||
|         # feature network, context network, and update block | ||||
|         if args.small: | ||||
| @@ -55,7 +55,6 @@ class RAFT(nn.Module): | ||||
|             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) | ||||
|             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) | ||||
|  | ||||
|  | ||||
|     def freeze_bn(self): | ||||
|         for m in self.modules(): | ||||
|             if isinstance(m, nn.BatchNorm2d): | ||||
| @@ -103,7 +102,7 @@ class RAFT(nn.Module): | ||||
|         fmap1 = fmap1.float() | ||||
|         fmap2 = fmap2.float() | ||||
|         if self.args.alternate_corr: | ||||
|             corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius) | ||||
|             corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) | ||||
|         else: | ||||
|             corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user