added upsampling module

Zach Teed
2020-07-25 17:36:17 -06:00
parent dc1220825d
commit a2408eab78
32 changed files with 23559 additions and 619 deletions

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
@@ -12,10 +13,10 @@ class CorrBlock:
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.view(batch*h1*w1, dim, h2, w2)
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels):
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
@@ -40,14 +41,16 @@ class CorrBlock:
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
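
Note: a toy check (shapes assumed, not part of this commit) of the pyramid sizing above — the base correlation volume plus num_levels-1 average-pooling steps yields exactly num_levels levels:

import torch
import torch.nn.functional as F

num_levels = 4
corr = torch.randn(2 * 8 * 8, 1, 8, 8)  # (batch*h1*w1, dim, h2, w2) with h2 = w2 = 8
pyramid = [corr]
for _ in range(num_levels - 1):
    corr = F.avg_pool2d(corr, 2, stride=2)
    pyramid.append(corr)
print([lvl.shape[-1] for lvl in pyramid])  # [8, 4, 2, 1]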

View File

@@ -6,53 +6,42 @@ import torch.utils.data as data
import torch.nn.functional as F
import os
import cv2
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, FlowAugmentorKITTI
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class CombinedDataset(data.Dataset):
def __init__(self, datasets):
self.datasets = datasets
def __len__(self):
length = 0
for i in range(len(self.datasets)):
length += len(self.datasets[i])
return length
def __getitem__(self, index):
i = 0
for j in range(len(self.datasets)):
if index < i + len(self.datasets[j]):
return self.datasets[j][index-i]
break
i += len(self.datasets[j])
def __add__(self, other):
self.datasets.append(other)
return self
class FlowDataset(data.Dataset):
def __init__(self, args, image_size=None, do_augument=False):
self.image_size = image_size
self.do_augument = do_augument
if self.do_augument:
self.augumentor = FlowAugmentor(self.image_size)
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
@@ -62,133 +51,96 @@ class FlowDataset(data.Dataset):
self.init_seed = True
index = index % len(self.image_list)
flow = frame_utils.read_gen(self.flow_list[index])
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
if self.do_augument:
img1, img2, flow = self.augumentor(img1, img2, flow)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
valid = torch.ones_like(flow[0])
return img1, img2, flow, valid
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
def __add__(self, other):
return CombinedDataset([self, other])
class MpiSintelTest(FlowDataset):
def __init__(self, args, root='datasets/Sintel/test', dstype='clean'):
super(MpiSintelTest, self).__init__(args, image_size=None, do_augument=False)
self.root = root
self.dstype = dstype
image_dir = osp.join(self.root, dstype)
all_sequences = os.listdir(image_dir)
self.image_list = []
for sequence in all_sequences:
frames = sorted(glob(osp.join(image_dir, sequence, '*.png')))
for i in range(len(frames)-1):
self.image_list += [[frames[i], frames[i+1], sequence, i]]
def __getitem__(self, index):
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
sequence = self.image_list[index][2]
frame = self.image_list[index][3]
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, sequence, frame
class MpiSintel(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, root='datasets/Sintel/training', dstype='clean'):
super(MpiSintel, self).__init__(args, image_size, do_augument)
if do_augument:
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 0.7
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
self.root = root
self.dstype = dstype
if split == 'test':
self.is_test = True
flow_root = osp.join(root, 'flow')
image_root = osp.join(root, dstype)
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
file_list = sorted(glob(osp.join(flow_root, '*/*.flo')))
for flo in file_list:
fbase = flo[len(flow_root)+1:]
fprefix = fbase[:-8]
fnum = int(fbase[-8:-4])
img1 = osp.join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
img2 = osp.join(image_root, fprefix + "%04d"%(fnum+1) + '.png')
if not osp.isfile(img1) or not osp.isfile(img2) or not osp.isfile(flo):
continue
self.image_list.append((img1, img2))
self.flow_list.append(flo)
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(args, image_size, do_augument)
self.root = root
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 1.0
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
self.flow_list = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(self.flow_list))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
self.image_list = []
for i in range(len(self.flow_list)):
im1 = images[2*i]
im2 = images[2*i + 1]
self.image_list.append([im1, im2])
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class SceneFlow(FlowDataset):
def __init__(self, args, image_size, do_augument=True, root='datasets',
dstype='frames_cleanpass', use_flyingthings=True, use_monkaa=False, use_driving=False):
super(SceneFlow, self).__init__(args, image_size, do_augument)
self.root = root
self.dstype = dstype
self.augumentor.min_scale = -0.2
self.augumentor.max_scale = 0.8
if use_flyingthings:
self.add_flyingthings()
if use_monkaa:
self.add_monkaa()
if use_driving:
self.add_driving()
def add_flyingthings(self):
root = osp.join(self.root, 'FlyingThings3D')
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, self.dstype, 'TRAIN/*/*')))
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
@@ -199,114 +151,85 @@ class SceneFlow(FlowDataset):
flows = sorted(glob(osp.join(fdir, '*.pfm')))
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [[images[i], images[i+1]]]
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [[images[i+1], images[i]]]
self.flow_list += [flows[i+1]]
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
def add_monkaa(self):
pass # we don't use monkaa
def add_driving(self):
pass # we don't use driving
class KITTI(FlowDataset):
def __init__(self, args, image_size=None, do_augument=True, is_test=False, is_val=False, do_pad=False, split=True, root='datasets/KITTI'):
super(KITTI, self).__init__(args, image_size, do_augument)
self.root = root
self.is_test = is_test
self.is_val = is_val
self.do_pad = do_pad
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
if self.do_augument:
self.augumentor = FlowAugmentorKITTI(self.image_size, min_scale=-0.2, max_scale=0.5)
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
if self.is_test:
images1 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_10.png')))
images2 = sorted(glob(os.path.join(root, 'testing', 'image_2/*_11.png')))
for i in range(len(images1)):
self.image_list += [[images1[i], images2[i]]]
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
else:
flows = sorted(glob(os.path.join(root, 'training', 'flow_occ/*_10.png')))
images1 = sorted(glob(os.path.join(root, 'training', 'image_2/*_10.png')))
images2 = sorted(glob(os.path.join(root, 'training', 'image_2/*_11.png')))
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
for i in range(len(flows)):
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [[images1[i], images2[i]]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def __getitem__(self, index):
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if self.is_test:
frame_id = self.image_list[index][0]
frame_id = frame_id.split('/')[-1]
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, frame_id
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
else:
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
index = index % len(self.image_list)
frame_id = self.image_list[index][0]
frame_id = frame_id.split('/')[-1]
print('Training with %d image pairs' % len(train_dataset))
return train_loader
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
if self.do_augument:
img1, img2, flow, valid = self.augumentor(img1, img2, flow, valid)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
valid = torch.from_numpy(valid).float()
if self.do_pad:
ht, wd = img1.shape[1:]
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
pad_ht1 = [0, pad_ht]
pad_wd1 = [pad_wd//2, pad_wd - pad_wd//2]
pad = pad_wd1 + pad_ht1
img1 = img1.view(1, 3, ht, wd)
img2 = img2.view(1, 3, ht, wd)
flow = flow.view(1, 2, ht, wd)
valid = valid.view(1, 1, ht, wd)
img1 = torch.nn.functional.pad(img1, pad, mode='replicate')
img2 = torch.nn.functional.pad(img2, pad, mode='replicate')
flow = torch.nn.functional.pad(flow, pad, mode='constant', value=0)
valid = torch.nn.functional.pad(valid, pad, mode='replicate', value=0)
img1 = img1.view(3, ht+pad_ht, wd+pad_wd)
img2 = img2.view(3, ht+pad_ht, wd+pad_wd)
flow = flow.view(2, ht+pad_ht, wd+pad_wd)
valid = valid.view(ht+pad_ht, wd+pad_wd)
if self.is_test:
return img1, img2, flow, valid, frame_id
return img1, img2, flow, valid
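
Note: a minimal sketch (not part of this commit; the pair count is illustrative) of how the integer weights in fetch_dataloader above work — FlowDataset.__rmul__ repeats the underlying file lists, so 100*sintel_clean is sampled 100x as often as a single copy:

aug_params = {'crop_size': (368, 768), 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
print(len(sintel_clean))        # e.g. 1041 image pairs
print(len(100 * sintel_clean))  # 104100 -- the file lists are tiled in place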

View File

@@ -143,10 +143,9 @@ class BasicEncoder(nn.Module):
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -184,7 +183,7 @@ class BasicEncoder(nn.Module):
x = self.conv2(x)
if self.dropout is not None:
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
@@ -218,10 +217,9 @@ class SmallEncoder(nn.Module):
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
@@ -260,8 +258,8 @@ class SmallEncoder(nn.Module):
x = self.layer3(x)
x = self.conv2(x)
# if self.dropout is not None:
# x = self.dropout(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
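
Note (a toy check, not from this diff): nn.Dropout2d is already the identity in eval mode, so the extra self.training guard above is belt-and-braces rather than a behavior change at inference:

import torch
import torch.nn as nn

drop = nn.Dropout2d(p=0.5).eval()
x = torch.ones(1, 8, 4, 4)
assert torch.equal(drop(x), x)  # identity when not training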

View File

@@ -3,11 +3,23 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.update import BasicUpdateBlock, SmallUpdateBlock
from modules.extractor import BasicEncoder, SmallEncoder
from modules.corr import CorrBlock
from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
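
Note: with this fallback, the autocast call sites below run unchanged on PyTorch < 1.6, where the context manager is simply a no-op (sketch using the shim defined above):

import torch

x = torch.ones(4)
with autocast(enabled=True):
    y = x * 2  # executes normally; the shim only supplies the context-manager API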
class RAFT(nn.Module):
def __init__(self, args):
@@ -26,7 +38,7 @@ class RAFT(nn.Module):
args.corr_levels = 4
args.corr_radius = 4
if not hasattr(args, 'dropout'):
if 'dropout' not in args._get_kwargs():
args.dropout = 0
# feature network, context network, and update block
@@ -40,6 +52,7 @@ class RAFT(nn.Module):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -54,46 +67,73 @@ class RAFT(nn.Module):
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
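
Note: a self-contained sketch (toy shapes; the random mask stands in for the update block's prediction) of the convex-combination math above — the softmax over the 9 weights makes every fine-resolution flow vector a convex combination of its 3x3 coarse neighborhood, with displacements scaled by 8:

import torch
import torch.nn.functional as F

N, H, W = 1, 4, 4
flow = torch.randn(N, 2, H, W)
mask = torch.randn(N, 64 * 9, H, W)  # 9 weights for each of the 8x8 fine pixels

mask = torch.softmax(mask.view(N, 1, 9, 8, 8, H, W), dim=2)  # weights sum to 1
up_flow = F.unfold(8 * flow, [3, 3], padding=1)              # 3x3 neighborhoods
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)                   # convex combination
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3).reshape(N, 2, 8 * H, 8 * W)
print(up_flow.shape)  # torch.Size([1, 2, 32, 32])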
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
fmap1, fmap2 = self.fnet([image1, image2])
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net, inp = torch.tanh(net), torch.relu(inp)
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
# if dropout is being used, reset mask
self.update_block.reset_mask(net, inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
net, delta_flow = self.update_block(net, inp, corr, flow)
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
if upsample:
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
flow_predictions.append(flow_up)
else:
flow_predictions.append(coords1 - coords0)
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
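
Note: a hypothetical inference sketch (model construction, padding, and image tensors are assumed, not shown in this diff); with test_mode=True the final low-resolution flow and its convex upsampling are returned as a pair:

model.eval()
with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)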

View File

@@ -2,34 +2,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# VariationalHidDropout from https://github.com/locuslab/trellisnet/tree/master/TrellisNet
class VariationalHidDropout(nn.Module):
def __init__(self, dropout=0.0):
"""
Hidden-to-hidden (VD-based) dropout that applies the same mask at every time step and every layer of TrellisNet
:param dropout: The dropout rate (0 means no dropout is applied)
"""
super(VariationalHidDropout, self).__init__()
self.dropout = dropout
self.mask = None
def reset_mask(self, x):
dropout = self.dropout
# Dimension (N, C, L)
n, c, h, w = x.shape
m = x.data.new(n, c, 1, 1).bernoulli_(1 - dropout)
with torch.no_grad():
mask = m / (1 - dropout)
self.mask = mask
return mask
def forward(self, x):
if not self.training or self.dropout == 0:
return x
assert self.mask is not None, "You need to reset mask before using VariationalHidDropout"
return self.mask * x
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
@@ -41,7 +13,6 @@ class FlowHead(nn.Module):
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
@@ -59,7 +30,6 @@ class ConvGRU(nn.Module):
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
@@ -133,49 +103,37 @@ class SmallUpdateBlock(nn.Module):
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
self.drop_net = VariationalHidDropout(dropout=args.dropout)
def reset_mask(self, net, inp):
self.drop_inp.reset_mask(inp)
self.drop_net.reset_mask(net)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, delta_flow
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.drop_inp = VariationalHidDropout(dropout=args.dropout)
self.drop_net = VariationalHidDropout(dropout=args.dropout)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def reset_mask(self, net, inp):
self.drop_inp.reset_mask(inp)
self.drop_net.reset_mask(net)
def forward(self, net, inp, corr, flow):
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
if self.training:
net = self.drop_net(net)
inp = self.drop_inp(inp)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, delta_flow
# scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow
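
Note: a shape sketch (toy tensors; names illustrative) of why the mask head emits 64*9 = 576 channels — 9 convex weights for each of the 8x8 fine pixels that RAFT.upsample_flow consumes per coarse pixel; the .25 factor, per the comment above, rescales the mask branch to balance gradients:

import torch
import torch.nn as nn

net = torch.randn(1, 128, 4, 4)
mask_head = nn.Sequential(
    nn.Conv2d(128, 256, 3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 64 * 9, 1, padding=0))
print(mask_head(net).shape)  # torch.Size([1, 576, 4, 4])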

View File

@@ -1,46 +1,55 @@
import numpy as np
import random
import math
import cv2
from PIL import Image
import cv2
import torch
import torchvision
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.augcolor = torchvision.transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.spatial_aug_prob = 0.8
self.eraser_aug_prob = 0.5
self.min_scale = min_scale
self.max_scale = max_scale
self.max_stretch = 0.2
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.margin = 20
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.augcolor(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.augcolor(Image.fromarray(img2)), dtype=np.uint8)
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
@@ -55,22 +64,18 @@ class FlowAugmentor:
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
max_scale = self.max_scale
min_scale = max(min_scale, self.min_scale)
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
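
Note: a worked example (numbers assumed) of the scale floor above — the sampled zoom is clipped so the resized image still covers the crop with an 8-pixel margin:

import numpy as np

ht, wd, crop_size = 436, 1024, (368, 768)   # e.g. a Sintel frame
min_scale = np.maximum((crop_size[0] + 8) / float(ht),
                       (crop_size[1] + 8) / float(wd))   # ~0.862 here
scale = 2 ** np.random.uniform(-0.2, 0.5)
scale = np.clip(scale, min_scale, None)     # never shrink below the crop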
@@ -81,22 +86,20 @@ class FlowAugmentor:
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < 0.1: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(-self.margin, img1.shape[0] - self.crop_size[0] + self.margin)
x0 = np.random.randint(-self.margin, img1.shape[1] - self.crop_size[1] + self.margin)
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
@@ -114,22 +117,29 @@ class FlowAugmentor:
return img1, img2, flow
class FlowAugmentorKITTI:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5):
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.augcolor = torchvision.transforms.ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.max_scale = max_scale
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.augcolor(Image.fromarray(image_stack)), dtype=np.uint8)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
@@ -198,11 +208,12 @@ class FlowAugmentorKITTI:
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50

View File

@@ -103,6 +103,13 @@ def readFlowKITTI(filename):
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
@@ -120,5 +127,8 @@ def read_gen(file_name, pil=False):
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
return flow[:, :, :-1]
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
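
Note: a round-trip sketch (synthetic values) of the KITTI 16-bit flow encoding used by readFlowKITTI/writeFlowKITTI above — components are stored as uint16 via 64*uv + 2**15, so decoding is exact to within 1/64 pixel:

import numpy as np

uv = np.random.uniform(-100, 100, (4, 4, 2)).astype(np.float32)
enc = (64.0 * uv + 2**15).astype(np.uint16)
dec = (enc.astype(np.float32) - 2**15) / 64.0
assert np.abs(uv - dec).max() < 1.0 / 64.0  # quantization error only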

View File

@@ -4,21 +4,21 @@ import numpy as np
from scipy import interpolate
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
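
Note: a typical usage sketch of InputPadder (model and image tensors assumed, not part of this diff) — pad a pair to multiples of 8 before the network, then crop the prediction back:

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
flow = padder.unpad(flow_up)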
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
@@ -42,15 +42,33 @@ def forward_interpolate(flow):
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest')
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest')
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()