added training code

This commit is contained in:
Zach Teed
2020-07-30 21:25:36 -06:00
parent dc370f877b
commit a1d8344039
5 changed files with 25 additions and 13 deletions

View File

@@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
@@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
@@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)