add cpu-only mode
@@ -185,7 +185,11 @@ class Evaluator:
                 if not all(gotit):
                     print("batch is None")
                     continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                         queries[:, :, 1],
                     ],
                     dim=2,
-                )
+                ).to(device)
             else:
                 queries = torch.cat(
                     [
@@ -213,7 +217,7 @@ class Evaluator:
                         sample.trajectory[:, 0],
                     ],
                     dim=2,
-                )
+                ).to(device)
 
             pred_tracks = model(sample.video, queries)
             if "strided" in dataset_name:
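The pattern in the hunks above — pick the device once, then move tensors with `.to(device)` — is what lets the evaluator run without a GPU. A minimal self-contained sketch of that idiom (the tensor shapes are illustrative, not the evaluator's real ones):

```python
import torch

# Select CUDA when present, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the queries on CPU, then move them to the selected device;
# .to(device) is a no-op on CPU-only machines, unlike the CUDA-only .cuda().
queries = torch.cat(
    [torch.zeros(1, 8, 1), torch.rand(1, 8, 2)],  # dummy (frame, x, y) queries
    dim=2,
).to(device)
```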
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
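`run_eval` applies the same guard to the model itself: the predictor is constructed on CPU and only moved to the GPU when one exists. A sketch of the guard, with `torch.nn.Linear` standing in for the real predictor model:

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the CoTracker predictor model
if torch.cuda.is_available():
    model = model.cuda()  # only touch CUDA when it is actually present
```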
@@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].cuda()
+        ]
 
     grid_y, grid_x = meshgrid2d(
-        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
+        1, grid_size, grid_size, stack=False, norm=False, device=device
     )
     step = interp_shape[1] // 64
     if grid_center[0] != 0 or grid_center[1] != 0:
@@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
 
     grid_y = grid_y + grid_center[0]
     grid_x = grid_x + grid_center[1]
-    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
+    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
     return xy
 
 
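`get_points_on_a_grid` now takes a `device` argument (still defaulting to `"cuda"`, so existing callers keep their behavior) instead of calling `.cuda()` unconditionally. A self-contained sketch of the same signature change, using a hypothetical `center_point` helper reduced from the `grid_size == 1` branch above:

```python
import torch

def center_point(interp_shape, device="cuda"):
    # Allocate directly on the requested device instead of appending .cuda().
    return torch.tensor(
        [interp_shape[1] / 2, interp_shape[0] / 2], device=device
    )[None, None]

pt = center_point((384, 512), device="cpu")  # runs on a CPU-only machine
```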
@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
 
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
 
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
+        device = rgbs.device
+
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
 
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                )  #
                 queries = torch.cat([queries, xy], dim=1)  #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
             query = torch.cat([query, xy_target], dim=1).to(device)  #
 
         if self.grid_size > 0:
-            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
             query = torch.cat([query, xy], dim=1).to(device)  #
         # crop the video to start from the queried frame
         query[0, 0, 0] = 0
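Inside `EvaluationPredictor.forward`, the device is no longer hard-coded: it is read off the input tensor (`device = rgbs.device`), so all intermediate buffers land on whatever device the caller put the video on. A minimal sketch of that idiom, with a hypothetical `make_buffers` helper (shapes are illustrative):

```python
import torch

def make_buffers(video, n_points=8):
    # Inherit the device from the input instead of assuming CUDA.
    B, T, C, H, W = video.shape
    device = video.device
    traj_e = torch.zeros((B, T, n_points, 2), device=device)
    vis_e = torch.zeros((B, T, n_points), device=device)
    return traj_e, vis_e

traj_e, vis_e = make_buffers(torch.rand(1, 4, 3, 64, 64))  # stays on CPU
```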
@@ -116,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -65,26 +65,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "/private/home/nikitakaraev/dev/neurips_2023/co-tracker\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/private/home/nikitakaraev/.conda/envs/stereoformer/lib/python3.8/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (None)/charset_normalizer (3.2.0) doesn't match a supported version!\n",
-      "  warnings.warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "%cd ..\n",
     "import os\n",
@@ -105,7 +89,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
    "metadata": {},
    "outputs": [],
@@ -116,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
    "metadata": {},
    "outputs": [
@@ -129,7 +113,7 @@
       "<IPython.core.display.HTML object>"
      ]
     },
-     "execution_count": 3,
+     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -175,8 +159,8 @@
    "outputs": [],
    "source": [
     "if torch.cuda.is_available():\n",
-    "    model=model.cuda()\n",
-    "    video=video.cuda()"
+    "    model = model.cuda()\n",
+    "    video = video.cuda()"
    ]
   },
   {
@@ -282,7 +266,9 @@
     "    [10., 600., 500.], # frame number 10\n",
     "    [20., 750., 600.], # ...\n",
     "    [30., 900., 200.]\n",
-    "]).cuda()"
+    "])\n",
+    "if torch.cuda.is_available():\n",
+    "    queries = queries.cuda()"
    ]
   },
   {
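Taken together, the notebook cells now follow one idiom: build everything on CPU, then move the model, video, and queries to the GPU in a single guarded step. As a plain script (with a dummy model and clip standing in for the notebook's real ones):

```python
import torch

model = torch.nn.Identity()          # stand-in for the CoTracker model
video = torch.rand(1, 4, 3, 64, 64)  # dummy B x T x C x H x W clip
queries = torch.tensor([
    [0.0, 400.0, 350.0],   # frame number, x, y
    [10.0, 600.0, 500.0],
])

if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()
    queries = queries.cuda()
```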
train.py
@@ -138,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
             single_point=False,
             n_iters=6,
         )
+        if torch.cuda.is_available():
+            predictor.model = predictor.model.cuda()
 
         metrics = evaluator.evaluate_sequence(
             model=predictor,