added training code
This commit is contained in:
		| @@ -8,11 +8,11 @@ Zachary Teed and Jia Deng<br/> | |||||||
| <img src="RAFT.png"> | <img src="RAFT.png"> | ||||||
|  |  | ||||||
| ## Requirements | ## Requirements | ||||||
| The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build. | The code has been tested with PyTorch 1.6 and Cuda 10.1. | ||||||
| ```Shell | ```Shell | ||||||
| conda create --name raft | conda create --name raft | ||||||
| conda activate raft | conda activate raft | ||||||
| conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch | ||||||
| conda install matplotlib | conda install matplotlib | ||||||
| conda install tensorboard | conda install tensorboard | ||||||
| conda install scipy | conda install scipy | ||||||
| @@ -67,8 +67,7 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ## Training | ## Training | ||||||
| Training code will be made available in the next few days | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard | ||||||
| <!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard |  | ||||||
| ```Shell | ```Shell | ||||||
| ./train_standard.sh | ./train_standard.sh | ||||||
| ``` | ``` | ||||||
| @@ -76,4 +75,4 @@ Training code will be made available in the next few days | |||||||
| If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU) | If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU) | ||||||
| ```Shell | ```Shell | ||||||
| ./train_mixed.sh | ./train_mixed.sh | ||||||
| ``` --> | ``` | ||||||
|   | |||||||
| @@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | |||||||
|     """ Create the data loader for the corresponding trainign set """ |     """ Create the data loader for the corresponding trainign set """ | ||||||
|  |  | ||||||
|     if args.stage == 'chairs': |     if args.stage == 'chairs': | ||||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True} |         aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} | ||||||
|         train_dataset = FlyingChairs(aug_params, split='training') |         train_dataset = FlyingChairs(aug_params, split='training') | ||||||
|      |      | ||||||
|     elif args.stage == 'things': |     elif args.stage == 'things': | ||||||
| @@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | |||||||
|         train_dataset = clean_dataset + final_dataset |         train_dataset = clean_dataset + final_dataset | ||||||
|  |  | ||||||
|     elif args.stage == 'sintel': |     elif args.stage == 'sintel': | ||||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True} |         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} | ||||||
|         things = FlyingThings3D(aug_params, dstype='frames_cleanpass') |         things = FlyingThings3D(aug_params, dstype='frames_cleanpass') | ||||||
|         sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') |         sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') | ||||||
|         sintel_final = MpiSintel(aug_params, split='training', dstype='final')         |         sintel_final = MpiSintel(aug_params, split='training', dstype='final')         | ||||||
|  |  | ||||||
|         if TRAIN_DS == 'C+T+K+S+H': |         if TRAIN_DS == 'C+T+K+S+H': | ||||||
|             kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}) |             kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) | ||||||
|             hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True}) |             hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) | ||||||
|             train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things |             train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things | ||||||
|  |  | ||||||
|         elif TRAIN_DS == 'C+T+K/S': |         elif TRAIN_DS == 'C+T+K/S': | ||||||
| @@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | |||||||
|  |  | ||||||
|     elif args.stage == 'kitti': |     elif args.stage == 'kitti': | ||||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} |         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} | ||||||
|         train_dataset = KITTI(args, image_size=args.image_size, is_val=False) |         train_dataset = KITTI(aug_params, split='training') | ||||||
|  |  | ||||||
|     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,  |     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,  | ||||||
|         pin_memory=False, shuffle=True, num_workers=4, drop_last=True) |         pin_memory=False, shuffle=True, num_workers=4, drop_last=True) | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								train.py
									
									
									
									
									
								
							| @@ -39,7 +39,7 @@ except: | |||||||
|  |  | ||||||
|  |  | ||||||
| # exclude extremly large displacements | # exclude extremly large displacements | ||||||
| MAX_FLOW = 500 | MAX_FLOW = 400 | ||||||
| SUM_FREQ = 100 | SUM_FREQ = 100 | ||||||
| VAL_FREQ = 5000 | VAL_FREQ = 5000 | ||||||
|  |  | ||||||
| @@ -181,13 +181,14 @@ def train(args): | |||||||
|  |  | ||||||
|             loss, metrics = sequence_loss(flow_predictions, flow, valid) |             loss, metrics = sequence_loss(flow_predictions, flow, valid) | ||||||
|             scaler.scale(loss).backward() |             scaler.scale(loss).backward() | ||||||
|  |             scaler.unscale_(optimizer)                 | ||||||
|             scaler.unscale_(optimizer) |  | ||||||
|             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) |             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) | ||||||
|              |              | ||||||
|             scaler.step(optimizer) |             scaler.step(optimizer) | ||||||
|             scheduler.step() |             scheduler.step() | ||||||
|             scaler.update() |             scaler.update() | ||||||
|  |  | ||||||
|  |  | ||||||
|             logger.push(metrics) |             logger.push(metrics) | ||||||
|  |  | ||||||
|             if total_steps % VAL_FREQ == VAL_FREQ - 1: |             if total_steps % VAL_FREQ == VAL_FREQ - 1: | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								train_mixed.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										6
									
								
								train_mixed.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | mkdir -p checkpoints | ||||||
|  | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision  | ||||||
|  | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision | ||||||
|  | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision | ||||||
|  | python -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision | ||||||
							
								
								
									
										6
									
								
								train_standard.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										6
									
								
								train_standard.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | mkdir -p checkpoints | ||||||
|  | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 | ||||||
|  | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 | ||||||
|  | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 | ||||||
|  | python -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 | ||||||
		Reference in New Issue
	
	Block a user