added training code
This commit is contained in:
		| @@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | ||||
|     """ Create the data loader for the corresponding trainign set """ | ||||
|  | ||||
|     if args.stage == 'chairs': | ||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True} | ||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} | ||||
|         train_dataset = FlyingChairs(aug_params, split='training') | ||||
|      | ||||
|     elif args.stage == 'things': | ||||
| @@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | ||||
|         train_dataset = clean_dataset + final_dataset | ||||
|  | ||||
|     elif args.stage == 'sintel': | ||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True} | ||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} | ||||
|         things = FlyingThings3D(aug_params, dstype='frames_cleanpass') | ||||
|         sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') | ||||
|         sintel_final = MpiSintel(aug_params, split='training', dstype='final')         | ||||
|  | ||||
|         if TRAIN_DS == 'C+T+K+S+H': | ||||
|             kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}) | ||||
|             hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True}) | ||||
|             kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) | ||||
|             hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) | ||||
|             train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things | ||||
|  | ||||
|         elif TRAIN_DS == 'C+T+K/S': | ||||
| @@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): | ||||
|  | ||||
|     elif args.stage == 'kitti': | ||||
|         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} | ||||
|         train_dataset = KITTI(args, image_size=args.image_size, is_val=False) | ||||
|         train_dataset = KITTI(aug_params, split='training') | ||||
|  | ||||
|     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,  | ||||
|         pin_memory=False, shuffle=True, num_workers=4, drop_last=True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user