correct query-point predictions (#32)
This commit is contained in:
		| @@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] |             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] | ||||||
|         thr = 0.9 |         thr = 0.9 | ||||||
|         visibilities = visibilities > thr |         visibilities = visibilities > thr | ||||||
|  |  | ||||||
|  |         # correct query-point predictions | ||||||
|  |         # see https://github.com/facebookresearch/co-tracker/issues/28 | ||||||
|  |  | ||||||
|  |         # TODO: batchify | ||||||
|  |         for i in range(len(queries)): | ||||||
|  |             queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) | ||||||
|  |             arange = torch.arange(0, len(queries_t)) | ||||||
|  |  | ||||||
|  |             # overwrite the predictions with the query points | ||||||
|  |             tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] | ||||||
|  |  | ||||||
|  |             # correct visibilities, the query points should be visible | ||||||
|  |             visibilities[i, queries_t, arange] = True | ||||||
|  |  | ||||||
|         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) |         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||||
|         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) |         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||||
|         return tracks, visibilities |         return tracks, visibilities | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user