Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions dataloaders/kitti_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataloaders import transforms
from dataloaders.pose_estimator import get_pose_pnp

input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']
input_options = ['d', 'rgb', 'rgbd', 'y', 'yd']

def load_calib():
"""
Expand All @@ -35,7 +35,7 @@ def load_calib():
root_d = os.path.join('..', 'data', 'kitti_depth')
root_rgb = os.path.join('..', 'data', 'kitti_rgb')
def get_paths_and_transform(split, args):
assert (args.use_d or args.use_rgb or args.use_g), 'no proper input selected'
assert (args.use_d or args.use_rgb or args.use_y), 'no proper input selected'

if split == "train":
transform = train_transform
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_rgb_paths(p):
raise(RuntimeError("Requested sparse depth but none was found"))
if len(paths_rgb) == 0 and args.use_rgb:
raise(RuntimeError("Requested rgb images but none was found"))
if len(paths_rgb) == 0 and args.use_g:
if len(paths_rgb) == 0 and args.use_y:
raise(RuntimeError("Requested gray images but no rgb was found"))
if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt):
raise(RuntimeError("Produced different sizes for datasets"))
Expand Down Expand Up @@ -186,7 +186,7 @@ def no_transform(rgb, sparse, target, rgb_near, args):
def handle_gray(rgb, args):
if rgb is None:
return None, None
if not args.use_g:
if not args.use_y:
return rgb, None
else:
img = np.array(Image.fromarray(rgb).convert('L'))
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(self, split, args):

def __getraw__(self, index):
rgb = rgb_read(self.paths['rgb'][index]) if \
(self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
(self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_y)) else None
sparse = depth_read(self.paths['d'][index]) if \
(self.paths['d'][index] is not None and self.args.use_d) else None
target = depth_read(self.paths['gt'][index]) if \
Expand All @@ -265,7 +265,7 @@ def __getitem__(self, index):

rgb, gray = handle_gray(rgb, self.args)
candidates = {"rgb":rgb, "d":sparse, "gt":target, \
"g":gray, "r_mat":r_mat, "t_vec":t_vec, "rgb_near":rgb_near}
"y":gray, "r_mat":r_mat, "t_vec":t_vec, "rgb_near":rgb_near}
items = {key:to_float_tensor(val) for key, val in candidates.items() if val is not None}

return items
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-i','--input', type=str, default='gd',
parser.add_argument('-i','--input', type=str, default='yd',
choices=input_options, help='input: | '.join(input_options))
parser.add_argument('-l','--layers', type=int, default=34,
help='use 16 for sparse_conv; use 18 or 34 for resnet')
Expand All @@ -60,7 +60,7 @@
args.result = os.path.join('..', 'results')
args.use_rgb = ('rgb' in args.input) or args.use_pose
args.use_d = 'd' in args.input
args.use_g = 'g' in args.input
args.use_y = 'y' in args.input
if args.use_pose:
args.w1, args.w2 = 0.1, 0.1
else:
Expand Down
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, args):
if 'rgb' in self.modality:
channels = 64 * 3 // len(self.modality)
self.conv1_img = conv_bn_relu(3, channels, kernel_size=3, stride=1, padding=1)
elif 'g' in self.modality:
elif 'y' in self.modality:
channels = 64 // len(self.modality)
self.conv1_img = conv_bn_relu(1, channels, kernel_size=3, stride=1, padding=1)

Expand Down Expand Up @@ -107,10 +107,10 @@ def forward(self, x):
conv1_d = self.conv1_d(x['d'])
if 'rgb' in self.modality:
conv1_img = self.conv1_img(x['rgb'])
elif 'g' in self.modality:
conv1_img = self.conv1_img(x['g'])
elif 'y' in self.modality:
conv1_img = self.conv1_img(x['y'])

if self.modality=='rgbd' or self.modality=='gd':
if self.modality=='rgbd' or self.modality=='yd':
conv1 = torch.cat((conv1_d, conv1_img),1)
else:
conv1 = conv1_d if (self.modality=='d') else conv1_img
Expand Down
8 changes: 4 additions & 4 deletions vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def preprocess_depth(x):
rgb = np.squeeze(ele['rgb'][0,...].data.cpu().numpy())
rgb = np.transpose(rgb, (1, 2, 0))
img_list.append(rgb)
elif 'g' in ele:
g = np.squeeze(ele['g'][0,...].data.cpu().numpy())
g = np.array(Image.fromarray(g).convert('RGB'))
img_list.append(g)
elif 'y' in ele:
y = np.squeeze(ele['y'][0,...].data.cpu().numpy())
y = np.array(Image.fromarray(y).convert('RGB'))
img_list.append(y)
if 'd' in ele:
img_list.append(preprocess_depth(ele['d'][0,...]))
img_list.append(preprocess_depth(pred[0,...]))
Expand Down