predict with new pifpaf
This commit is contained in:
parent
c877a16c4b
commit
966b692e4d
@ -56,7 +56,7 @@ class Loco:
|
||||
output_size=output_size)
|
||||
else:
|
||||
self.model = MonStereoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
|
||||
linear_size=linear_size, device=self.device)
|
||||
linear_size=linear_size, device=self.device)
|
||||
|
||||
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
||||
else:
|
||||
|
||||
@ -82,7 +82,7 @@ def factory_for_gt(im_size, name=None, path_gt=None, verbose=True):
|
||||
dic_gt = None
|
||||
x_factor = im_size[0] / 1600
|
||||
y_factor = im_size[1] / 900
|
||||
pixel_factor = (x_factor + y_factor) / 1.75 # 1.7 for MOT
|
||||
pixel_factor = (x_factor + y_factor) / 1.75 # 1.75 for MOT
|
||||
# pixel_factor = 1
|
||||
if im_size[0] / im_size[1] > 2.5:
|
||||
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] # Kitti calibration
|
||||
@ -274,7 +274,6 @@ def extract_outputs(outputs, tasks=()):
|
||||
|
||||
if outputs.shape[1] == 10:
|
||||
dic_out['aux'] = torch.sigmoid(dic_out['aux'])
|
||||
|
||||
return dic_out
|
||||
|
||||
|
||||
|
||||
@ -105,19 +105,18 @@ def predict(args):
|
||||
pifpaf_out = [ann.json_data() for ann in pred]
|
||||
|
||||
if batch_i == 0:
|
||||
pifpaf_outputs = [keypoint_sets, scores, pifpaf_out] # keypoints_sets and scores for pifpaf printing
|
||||
images_outputs = [cpu_image] # List of 1 or 2 elements with pifpaf tensor and monoloco original image
|
||||
pifpaf_outputs = pred # to only print left image for stereo
|
||||
pifpaf_outs = {'left': pifpaf_out}
|
||||
with open(meta_batch[0]['file_name'], 'rb') as f:
|
||||
cpu_image = PIL.Image.open(f).convert('RGB')
|
||||
else:
|
||||
pifpaf_outs['right'] = pifpaf_out
|
||||
|
||||
# Load the original image
|
||||
if args.net in ('monoloco_pp', 'monstereo'):
|
||||
with open(meta['file_name'], 'rb') as f:
|
||||
cpu_image = PIL.Image.open(f).convert('RGB')
|
||||
|
||||
im_name = os.path.basename(meta['file_name'])
|
||||
im_size = (cpu_image.size()[1], cpu_image.size()[0]) # Original
|
||||
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
|
||||
kk, dic_gt = factory_for_gt(im_size, name=im_name, path_gt=args.path_gt)
|
||||
|
||||
# Preprocess pifpaf outputs and run monoloco
|
||||
@ -138,53 +137,52 @@ def predict(args):
|
||||
dic_out = defaultdict(list)
|
||||
kk = None
|
||||
|
||||
factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=dic_out, kk=kk)
|
||||
# TODO Clean
|
||||
factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out,
|
||||
dic_out=dic_out, kk=kk)
|
||||
print('Image {}\n'.format(cnt) + '-' * 120)
|
||||
cnt += 1
|
||||
|
||||
|
||||
def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=None, kk=None):
|
||||
def factory_outputs(args, annotation_painter, cpu_image, output_path, pred, pifpaf_out, dic_out=None, kk=None):
|
||||
"""Output json files or images according to the choice"""
|
||||
|
||||
# Save json file
|
||||
if args.mode == 'pifpaf':
|
||||
with show.image_canvas(cpu_image, image_out_name) as ax:
|
||||
if args.net == 'pifpaf':
|
||||
with openpifpaf.show.image_canvas(cpu_image, output_path) as ax:
|
||||
annotation_painter.annotations(ax, pred)
|
||||
|
||||
|
||||
keypoint_sets, scores, pifpaf_out = pifpaf_outputs[:]
|
||||
|
||||
# Visualizer
|
||||
keypoint_painter = KeypointPainter(show_box=False)
|
||||
skeleton_painter = KeypointPainter(show_box=False, color_connections=True, markersize=1, linewidth=4)
|
||||
|
||||
if 'json' in args.output_types and keypoint_sets.size > 0:
|
||||
if 'json' in args.output_types and len(pred) > 0:
|
||||
with open(output_path + '.pifpaf.json', 'w') as f:
|
||||
json.dump(pifpaf_out, f)
|
||||
|
||||
if 'keypoints' in args.output_types:
|
||||
with image_canvas(images_outputs[0],
|
||||
output_path + '.keypoints.png',
|
||||
show=args.show,
|
||||
fig_width=args.figure_width,
|
||||
dpi_factor=args.dpi_factor) as ax:
|
||||
keypoint_painter.keypoints(ax, keypoint_sets)
|
||||
|
||||
if 'skeleton' in args.output_types:
|
||||
with image_canvas(images_outputs[0],
|
||||
output_path + '.skeleton.png',
|
||||
show=args.show,
|
||||
fig_width=args.figure_width,
|
||||
dpi_factor=args.dpi_factor) as ax:
|
||||
skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)
|
||||
# if 'keypoints' in args.output_types:
|
||||
# with image_canvas(images_outputs[0],
|
||||
# output_path + '.keypoints.png',
|
||||
# show=args.show,
|
||||
# fig_width=args.figure_width,
|
||||
# dpi_factor=args.dpi_factor) as ax:
|
||||
# keypoint_painter.keypoints(ax, keypoint_sets)
|
||||
#
|
||||
# if 'skeleton' in args.output_types:
|
||||
# with image_canvas(images_outputs[0],
|
||||
# output_path + '.skeleton.png',
|
||||
# show=args.show,
|
||||
# fig_width=args.figure_width,
|
||||
# dpi_factor=args.dpi_factor) as ax:
|
||||
# skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)
|
||||
|
||||
else:
|
||||
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
|
||||
print(output_path)
|
||||
if dic_out['boxes']: # Only print in case of detections
|
||||
printer = Printer(images_outputs[1], output_path, kk, args)
|
||||
printer = Printer(cpu_image, output_path, kk, args)
|
||||
figures, axes = printer.factory_axes()
|
||||
printer.draw(figures, axes, dic_out, images_outputs[1])
|
||||
printer.draw(figures, axes, dic_out, cpu_image)
|
||||
|
||||
if 'json' in args.output_types:
|
||||
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:
|
||||
|
||||
@ -86,8 +86,7 @@ class Printer:
|
||||
self.boxes = dic_ann['boxes']
|
||||
self.boxes_gt = dic_ann['boxes_gt']
|
||||
self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1])
|
||||
if dic_ann['aux']:
|
||||
self.auxs = dic_ann['aux'] if dic_ann['aux'] else None
|
||||
self.auxs = dic_ann['aux']
|
||||
|
||||
def factory_axes(self):
|
||||
"""Create axes for figures: front bird multi"""
|
||||
@ -199,7 +198,12 @@ class Printer:
|
||||
|
||||
def _draw_front(self, ax, z, idx, number):
|
||||
|
||||
mode = 'stereo' if self.auxs[idx] > 0.3 else 'mono'
|
||||
if len(self.auxs) == 0:
|
||||
mode = 'mono'
|
||||
elif self.auxs[idx] <= 0.3:
|
||||
mode = 'mono'
|
||||
else:
|
||||
mode = 'stereo'
|
||||
|
||||
# Bbox
|
||||
w = min(self.width-2, self.boxes[idx][2] - self.boxes[idx][0])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user