fix no detection case

This commit is contained in:
lorenzo 2019-07-01 18:58:20 +02:00
parent 2357ce67fc
commit 93d878666d
5 changed files with 15 additions and 16 deletions

View File

@ -14,15 +14,15 @@ def factory_for_gt(im_size, name=None, path_gt=None):
try:
with open(path_gt, 'r') as f:
dic_names = json.load(f)
print('-' * 120 + "\nMonoloco: Ground-truth file opened")
print('-' * 120 + "\nGround-truth file opened")
except (FileNotFoundError, TypeError):
print('-' * 120 + "\nMonoloco: ground-truth file not found\n")
print('-' * 120 + "\nGround-truth file not found")
dic_names = {}
try:
kk = dic_names[name]['K']
dic_gt = dic_names[name]
print("Monoloco: matched ground-truth file!\n" + '-' * 120)
print("Matched ground-truth file!")
except KeyError:
dic_gt = None
x_factor = im_size[0] / 1600
@ -35,8 +35,7 @@ def factory_for_gt(im_size, name=None, path_gt=None):
[0, 1266.4 * pixel_factor, 491.5 * y_factor],
[0., 0., 1.]] # nuScenes calibration
print("Ground-truth annotations for the image not found\n"
"Using a standard calibration matrix...\n" + '-' * 120)
print("Using a standard calibration matrix...")
return kk, dic_gt
@ -74,18 +73,16 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_
skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)
if 'monoloco' in args.networks:
dic_out = monoloco_post_process(monoloco_outputs)
if any((xx in args.output_types for xx in ['front', 'bird', 'combined'])):
epistemic = False
if args.n_dropout > 0:
epistemic = True
printer = Printer(images_outputs[1], output_path, dic_out, kk, output_types=args.output_types,
show=args.show, z_max=args.z_max, epistemic=epistemic)
printer.print()
if dic_out['boxes']: # Only print in case of detections
printer = Printer(images_outputs[1], output_path, dic_out, kk, output_types=args.output_types,
show=args.show, z_max=args.z_max, epistemic=epistemic)
printer.print()
if 'json' in args.output_types:
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:
@ -95,8 +92,11 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_
def monoloco_post_process(monoloco_outputs, iou_min=0.25):
"""Post process monoloco to output final dictionary with all information for visualizations"""
dic_out = defaultdict(list)
outputs, varss, boxes, keypoints, kk, dic_gt = monoloco_outputs[:]
dic_out = defaultdict(list)
if outputs is None:
return dic_out
if dic_gt:
boxes_gt, dds_gt = dic_gt['boxes'], dic_gt['dds']
matches = get_iou_matches(boxes, boxes_gt, thresh=iou_min)

View File

@ -39,7 +39,7 @@ class MonoLoco:
def forward(self, keypoints, kk):
"""forward pass of monoloco network"""
if not keypoints:
return None
return None, None
with torch.no_grad():
inputs = get_network_inputs(torch.tensor(keypoints).to(self.device), torch.tensor(kk).to(self.device))

View File

@ -158,7 +158,7 @@ def predict(args):
kk = None
factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_outputs=monoloco_outputs, kk=kk)
sys.stdout.write('\r' + 'Saving image {}'.format(cnt) + '\t')
print('Image {}\n'.format(cnt) + '-' * 120)
cnt += 1
return keypoints_whole

View File

@ -180,7 +180,7 @@ def xyz_from_distance(distances, xy_centers):
if type(distances) == float:
distances = torch.tensor(distances).unsqueeze(0)
if len(distances.size()) == 1:
distances = torch.tensor(distances).unsqueeze(1)
distances = distances.unsqueeze(1)
if len(xy_centers.size()) == 1:
xy_centers = xy_centers.unsqueeze(0)

View File

@ -39,7 +39,6 @@ class Printer:
self.z_max = z_max # To include ellipses in the image
self.fig_width = fig_width
# Define the output dir
self.path_out = output_path