correct query-point predictions (#32)
This commit is contained in:
		| @@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] | ||||
|         thr = 0.9 | ||||
|         visibilities = visibilities > thr | ||||
|  | ||||
|         # correct query-point predictions | ||||
|         # see https://github.com/facebookresearch/co-tracker/issues/28 | ||||
|  | ||||
|         # TODO: batchify | ||||
|         for i in range(len(queries)): | ||||
|             queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) | ||||
|             arange = torch.arange(0, len(queries_t)) | ||||
|  | ||||
|             # overwrite the predictions with the query points | ||||
|             tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] | ||||
|  | ||||
|             # correct visibilities, the query points should be visible | ||||
|             visibilities[i, queries_t, arange] = True | ||||
|  | ||||
|         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||
|         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||
|         return tracks, visibilities | ||||
|   | ||||
		Reference in New Issue
	
	Block a user