added training code
This commit is contained in:
@@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
||||
""" Create the data loader for the corresponding trainign set """
|
||||
|
||||
if args.stage == 'chairs':
|
||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
|
||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
|
||||
train_dataset = FlyingChairs(aug_params, split='training')
|
||||
|
||||
elif args.stage == 'things':
|
||||
@@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
||||
train_dataset = clean_dataset + final_dataset
|
||||
|
||||
elif args.stage == 'sintel':
|
||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
|
||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
|
||||
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
||||
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
|
||||
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
|
||||
|
||||
if TRAIN_DS == 'C+T+K+S+H':
|
||||
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
|
||||
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
|
||||
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
|
||||
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
|
||||
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
|
||||
|
||||
elif TRAIN_DS == 'C+T+K/S':
|
||||
@@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
||||
|
||||
elif args.stage == 'kitti':
|
||||
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
|
||||
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
|
||||
train_dataset = KITTI(aug_params, split='training')
|
||||
|
||||
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
||||
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
|
||||
|
Reference in New Issue
Block a user