printer.py cleanup

This commit is contained in:
charlesbvll 2021-04-25 10:06:26 +02:00
parent e64ab138b3
commit 3458cc58e9
5 changed files with 63 additions and 60 deletions

View File

@ -10,7 +10,9 @@ import torch
import matplotlib.pyplot as plt
from .network.process import laplace_sampling
from .visuals.pifpaf_show import KeypointPainter, image_canvas, get_pifpaf_outputs, draw_orientation, social_distance_colors
from .visuals.pifpaf_show import (
KeypointPainter, image_canvas, get_pifpaf_outputs, draw_orientation, social_distance_colors
)
def social_interactions(idx, centers, angles, dds, stds=None, social_distance=False,

View File

@ -48,7 +48,8 @@ def cli():
visualizer.cli(parser)
# Monoloco
predict_parser.add_argument('--activities', nargs='+', help='Choose activities to show: social_distance, raise_hand')
predict_parser.add_argument('--activities', nargs='+',
help='Choose activities to show: social_distance, raise_hand')
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')

View File

@ -109,10 +109,12 @@ class KeypointPainter:
for ci, connection in enumerate(np.array(self.skeleton) - 1):
c = color
linewidth=self.linewidth
if ((connection[0] == 5 and connection[1] == 7) or (connection[0] == 7 and connection[1] == 9)) and raise_hand in ['left','both']:
if ((connection[0] == 5 and connection[1] == 7)
or (connection[0] == 7 and connection[1] == 9)) and raise_hand in ['left','both']:
c = 'yellow'
linewidth = l_arm_width
if ((connection[0] == 6 and connection[1] == 8) or (connection[0] == 8 and connection[1] == 10)) and raise_hand in ['right', 'both']:
if ((connection[0] == 6 and connection[1] == 8)
or (connection[0] == 8 and connection[1] == 10)) and raise_hand in ['right', 'both']:
c = 'yellow'
linewidth = r_arm_width
if self.color_connections:
@ -190,7 +192,8 @@ class KeypointPainter:
matplotlib.patches.Rectangle(
(x - scale, y - scale), 2 * scale, 2 * scale, fill=False, color=color))
def keypoints(self, ax, keypoint_sets, *, size=None, scores=None, color=None, colors=None, texts=None, raise_hand='none'):
def keypoints(self, ax, keypoint_sets, *,
size=None, scores=None, color=None, colors=None, texts=None, raise_hand='none'):
if keypoint_sets is None:
return
@ -211,7 +214,7 @@ class KeypointPainter:
if isinstance(color, (int, np.integer)):
color = matplotlib.cm.get_cmap('tab20')((color % 20 + 0.05) / 20)
if raise_hand is not 'none':
if raise_hand != 'none':
# if raise_hand[:][i] is 'both' or raise_hand[:][i] is 'left' or raise_hand[:][i] is 'right':
# color = 'green'
self._draw_skeleton(ax, x, y, v, size=size, color=color, raise_hand=raise_hand[:][i])
@ -229,19 +232,6 @@ class KeypointPainter:
if texts is not None:
self._draw_text(ax, x, y, v, texts[i], color)
# nose = 0
# l_ear = 3
# l_shoulder = 5
# r_ear = 4
# r_shoulder = 6
# head_width = kps[l_ear][0]- kps[r_ear][0]
# head_top = (kps[nose][1] - head_width)
# ax.plot([kps[l_shoulder][0],kps[l_shoulder][0]], [kps[l_shoulder][1],head_top], linewidth=10, color='red')
# ax.plot([kps[r_shoulder][0],kps[r_shoulder][0]], [kps[r_shoulder][1],head_top], linewidth=10, color='red')
# ax.plot([kps[l_shoulder][0],kps[r_shoulder][0]], [head_top,head_top], linewidth=10, color='red')
def annotations(self, ax, annotations, *,
color=None, colors=None, texts=None):

View File

@ -195,6 +195,39 @@ class Printer:
def social_distance_bird(self, axis, colors):
draw_orientation(axis, self.xz_centers, [], self.angles, colors, mode='bird')
def _front_loop(self, iterator, axes, number, colors, annotations, dic_out):
for idx in iterator:
if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
if self.args.activities:
if 'social_distance' in self.args.activities:
self.social_distance_front(axes[0], colors, annotations, dic_out)
elif 'raise_hand' in self.args.activities:
self.social_distance_front(axes[0], colors, annotations, dic_out)
else:
self._draw_front(axes[0],
self.dd_pred[idx],
idx,
number)
number['num'] += 1
def _bird_loop(self, iterator, axes, colors, number):
for idx in iterator:
if any(xx in self.output_types for xx in ['bird', 'multi']) and self.zz_pred[idx] > 0:
if self.args.activities:
if 'social_distance' in self.args.activities:
self.social_distance_bird(axes[1], colors)
# Draw ground truth and uncertainty
self._draw_uncertainty(axes, idx)
# Draw bird eye view text
if number['flag']:
self._draw_text_bird(axes, idx, number['num'])
number['num'] += 1
def draw(self, figures, axes, image, dic_out=None, annotations=None):
if self.args.activities:
@ -211,37 +244,16 @@ class Printer:
number = dict(flag=False, num=97)
if any(xx in self.output_types for xx in ['front', 'multi']):
number['flag'] = True # add numbers
# Remove image if social distance is activated
if not self.args.activities or 'social_distance' not in self.args.activities:
self.mpl_im0.set_data(image)
for idx in iterator:
if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
if self.args.activities:
if 'social_distance' in self.args.activities:
self.social_distance_front(axes[0], colors, annotations, dic_out)
elif 'raise_hand' in self.args.activities:
self.social_distance_front(axes[0], colors, annotations, dic_out)
else:
self._draw_front(axes[0],
self.dd_pred[idx],
idx,
number)
number['num'] += 1
self._front_loop(iterator, axes, number, colors, annotations, dic_out)
# Draw the bird figure
number['num'] = 97
for idx in iterator:
if any(xx in self.output_types for xx in ['bird', 'multi']) and self.zz_pred[idx] > 0:
self._bird_loop(iterator, axes, colors, number)
if self.args.activities:
if 'social_distance' in self.args.activities:
self.social_distance_bird(axes[1], colors)
# Draw ground truth and uncertainty
self._draw_uncertainty(axes, idx)
# Draw bird eye view text
if number['flag']:
self._draw_text_bird(axes, idx, number['num'])
number['num'] += 1
self._draw_legend(axes)
# Draw, save or/and show the figures
@ -255,7 +267,6 @@ class Printer:
plt.close(fig)
def _draw_front(self, ax, z, idx, number):
# Bbox

View File

@ -7,7 +7,6 @@ Implementation adapted from https://github.com/vita-epfl/openpifpaf/blob/master/
"""
import time
import os
import logging
import torch
@ -42,7 +41,7 @@ def factory_from_args(args):
assert 'bird' not in args.output_types
if 'json' not in args.output_types:
assert len(args.output_types) is 1
assert len(args.output_types) == 1
else:
assert len(args.output_types) < 3
@ -78,7 +77,7 @@ def factory_from_args(args):
def webcam(args):
assert args.mode in ('mono')
assert args.mode in 'mono'
args, dic_models = factory_from_args(args)