fixed problems with variational dropout
This commit is contained in:
		| @@ -133,8 +133,20 @@ class SmallUpdateBlock(nn.Module): | ||||
|         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) | ||||
|         self.flow_head = FlowHead(hidden_dim, hidden_dim=128) | ||||
|  | ||||
|         self.drop_inp = VariationalHidDropout(dropout=args.dropout) | ||||
|         self.drop_net = VariationalHidDropout(dropout=args.dropout) | ||||
|  | ||||
|     def reset_mask(self, net, inp): | ||||
|         self.drop_inp.reset_mask(inp) | ||||
|         self.drop_net.reset_mask(net) | ||||
|  | ||||
|     def forward(self, net, inp, corr, flow): | ||||
|         motion_features = self.encoder(flow, corr) | ||||
|  | ||||
|         if self.training: | ||||
|             net = self.drop_net(net) | ||||
|             inp = self.drop_inp(inp) | ||||
|  | ||||
|         inp = torch.cat([inp, motion_features], dim=1) | ||||
|         net = self.gru(net, inp) | ||||
|         delta_flow = self.flow_head(net) | ||||
| @@ -157,12 +169,12 @@ class BasicUpdateBlock(nn.Module): | ||||
|  | ||||
|     def forward(self, net, inp, corr, flow): | ||||
|         motion_features = self.encoder(flow, corr) | ||||
|         inp = torch.cat([inp, motion_features], dim=1) | ||||
|  | ||||
|         if self.training: | ||||
|             net = self.drop_net(net) | ||||
|             inp = self.drop_inp(inp) | ||||
|  | ||||
|          | ||||
|         inp = torch.cat([inp, motion_features], dim=1) | ||||
|         net = self.gru(net, inp) | ||||
|         delta_flow = self.flow_head(net) | ||||
|  | ||||
|   | ||||
| @@ -26,7 +26,7 @@ class RAFT(nn.Module): | ||||
|             args.corr_levels = 4 | ||||
|             args.corr_radius = 4 | ||||
|  | ||||
|         if 'dropout' not in args._get_kwargs(): | ||||
|         if not hasattr(args, 'dropout'): | ||||
|             args.dropout = 0 | ||||
|  | ||||
|         # feature network, context network, and update block | ||||
|   | ||||
		Reference in New Issue
	
	Block a user