improve training comments

train.py (109 changed lines)
@@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
 from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
 
 
-# define the handler function
-# for training on a slurm cluster
-def sig_handler(signum, frame):
-    print("caught signal", signum)
-    print(socket.gethostname(), "USR1 signal caught.")
-    # do other stuff to cleanup here
-    print("requeuing job " + os.environ["SLURM_JOB_ID"])
-    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
-    sys.exit(-1)
-
-
-def term_handler(signum, frame):
-    print("bypassing sigterm", flush=True)
-
-
 def fetch_optimizer(args, model):
     """Create the optimizer and learning rate scheduler"""
     optimizer = optim.AdamW(
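With this hunk the script no longer catches SIGUSR1/SIGTERM, so it will not requeue itself when a SLURM scheduler pre-empts the job. If that behaviour is still wanted, it can be re-registered outside this file; the sketch below simply mirrors the deleted handlers and assumes the standard SLURM_JOB_ID environment variable is set, it is not part of the commit.

    import os, signal, socket, sys

    def sig_handler(signum, frame):
        # requeue the current SLURM job, then exit so the scheduler restarts it
        print(socket.gethostname(), "caught signal", signum)
        os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
        sys.exit(-1)

    def term_handler(signum, frame):
        # swallow SIGTERM so a pre-emption warning does not kill the run
        print("bypassing sigterm", flush=True)

    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)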
@@ -302,9 +287,7 @@ class Lite(LightningLite):
             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
 
         if "tapvid_davis_first" in args.eval_datasets:
-            data_root = os.path.join(
-                args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
-            )
+            data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                 eval_dataset,
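The path change above is more than cosmetic: os.path.join discards every earlier component when a later one is absolute, so the old call with a leading "/" silently ignored args.dataset_root. A quick illustration (the "/data" root is just a placeholder):

    >>> import os.path
    >>> os.path.join("/data", "/tapvid_davis/tapvid_davis.pkl")
    '/tapvid_davis/tapvid_davis.pkl'
    >>> os.path.join("/data", "tapvid_davis/tapvid_davis.pkl")
    '/data/tapvid_davis/tapvid_davis.pkl'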
@@ -551,17 +534,15 @@ class Lite(LightningLite):
 
 
 if __name__ == "__main__":
-    signal.signal(signal.SIGUSR1, sig_handler)
-    signal.signal(signal.SIGTERM, term_handler)
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", default="cotracker", help="model name")
-    parser.add_argument("--restore_ckpt", help="restore checkpoint")
-    parser.add_argument("--ckpt_path", help="restore checkpoint")
+    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
+    parser.add_argument("--ckpt_path", help="path to save checkpoints")
     parser.add_argument(
         "--batch_size", type=int, default=4, help="batch size used during training."
     )
     parser.add_argument(
-        "--num_workers", type=int, default=6, help="left right consistency loss"
+        "--num_workers", type=int, default=6, help="number of dataloader workers"
     )
 
     parser.add_argument(
@@ -578,20 +559,34 @@ if __name__ == "__main__":
         "--evaluate_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="evaluate during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
         "--save_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="save checkpoints during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
-        "--validate_at_start", action="store_true", help="use mixed precision"
+        "--validate_at_start",
+        action="store_true",
+        help="whether to run evaluation before training starts",
+    )
+    parser.add_argument(
+        "--save_freq",
+        type=int,
+        default=100,
+        help="frequency of trajectory visualization during training",
+    )
+    parser.add_argument(
+        "--traj_per_sample",
+        type=int,
+        default=768,
+        help="the number of trajectories to sample for training",
+    )
+    parser.add_argument(
+        "--dataset_root", type=str, help="path to all the datasets (train and eval)"
     )
-    parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
-    parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
-    parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
 
     parser.add_argument(
         "--train_iters",
@@ -605,49 +600,75 @@ if __name__ == "__main__":
     parser.add_argument(
         "--eval_datasets",
         nargs="+",
-        default=["things", "badja", "fastcapture"],
-        help="eval datasets.",
+        default=["things", "badja"],
+        help="what datasets to use for evaluation",
     )
 
     parser.add_argument(
-        "--remove_space_attn", action="store_true", help="use mixed precision"
+        "--remove_space_attn",
+        action="store_true",
+        help="remove space attention from CoTracker",
     )
     parser.add_argument(
-        "--dont_use_augs", action="store_true", help="use mixed precision"
+        "--dont_use_augs",
+        action="store_true",
+        help="don't apply augmentations during training",
     )
     parser.add_argument(
-        "--sample_vis_1st_frame", action="store_true", help="use mixed precision"
+        "--sample_vis_1st_frame",
+        action="store_true",
+        help="only sample trajectories with points visible on the first frame",
     )
     parser.add_argument(
-        "--sliding_window_len", type=int, default=8, help="use mixed precision"
+        "--sliding_window_len",
+        type=int,
+        default=8,
+        help="length of the CoTracker sliding window",
     )
     parser.add_argument(
-        "--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
+        "--updateformer_hidden_size",
+        type=int,
+        default=384,
+        help="hidden dimension of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_num_heads", type=int, default=8, help="use mixed precision"
+        "--updateformer_num_heads",
+        type=int,
+        default=8,
+        help="number of heads of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_space_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_space_depth",
+        type=int,
+        default=12,
+        help="number of group attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_time_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_time_depth",
+        type=int,
+        default=12,
+        help="number of time attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--model_stride", type=int, default=8, help="use mixed precision"
+        "--model_stride",
+        type=int,
+        default=8,
+        help="stride of the CoTracker feature network",
     )
     parser.add_argument(
         "--crop_size",
         type=int,
         nargs="+",
         default=[384, 512],
-        help="use mixed precision",
+        help="crop videos to this resolution during training",
     )
     parser.add_argument(
-        "--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
+        "--eval_max_seq_len",
+        type=int,
+        default=1000,
+        help="maximum length of evaluation videos",
     )
     args = parser.parse_args()
 
     logging.basicConfig(
         level=logging.INFO,
         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@@ -661,5 +682,5 @@ if __name__ == "__main__":
         devices="auto",
         accelerator="gpu",
         precision=32,
-        num_nodes=4,
+        # num_nodes=4,
     ).run(args)
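Taken together, the rewritten help strings describe the training CLI. A hypothetical invocation using only flags that appear in this diff could look like the following; all paths are placeholders and the numeric values simply repeat the defaults shown above, nothing here is taken from the repository itself:

    python train.py \
        --dataset_root /path/to/datasets \
        --ckpt_path ./checkpoints \
        --batch_size 4 \
        --num_workers 6 \
        --traj_per_sample 768 \
        --sliding_window_len 8 \
        --eval_datasets things badja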