diff --git a/monoloco/predict.py b/monoloco/predict.py
index 1066c0c..c267fd1 100644
--- a/monoloco/predict.py
+++ b/monoloco/predict.py
@@ -73,7 +73,7 @@ def download_checkpoints(args):
         assert not args.social_distance, "Social distance not supported in stereo modality"
         path = MONSTEREO_MODEL
         name = 'monstereo-201202-1212.pkl'
-    elif args.social_distance or (args.activities and 'social_distance' in args.activities) or args.webcam:
+    elif (args.activities and 'social_distance' in args.activities) or args.webcam:
        path = MONOLOCO_MODEL_NU
        name = 'monoloco_pp-201207-1350.pkl'
     else:
@@ -220,7 +220,7 @@ def predict(args):
             dic_out = net.forward(keypoints, kk)
             dic_out = net.post_process(
                 dic_out, boxes, keypoints, kk, dic_gt)
-            if args.social_distance or (args.activities and 'social_distance' in args.activities):
+            if args.activities and 'social_distance' in args.activities:
                 dic_out = net.social_distance(dic_out, args)
             if args.activities and 'raise_hand' in args.activities:
                 dic_out = net.raising_hand(dic_out, keypoints)
diff --git a/monoloco/run.py b/monoloco/run.py
index 74dc0ed..ae72242 100644
--- a/monoloco/run.py
+++ b/monoloco/run.py
@@ -20,7 +20,7 @@ def cli():
     predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
     predict_parser.add_argument('--checkpoint', help='pifpaf model')
     predict_parser.add_argument('-o', '--output-directory', help='Output directory')
-    predict_parser.add_argument('--output_types', nargs='+',
+    predict_parser.add_argument('--output_types', nargs='+', default=['multi'],
                                 help='what to output: json keypoints skeleton for Pifpaf'
                                      'json bird front or multi for MonStereo')
     predict_parser.add_argument('--no_save', help='to show images', action='store_true')
@@ -65,7 +65,6 @@ def cli():
                                 type=float, default=5.7)
 
     # Social distancing and social interactions
-    predict_parser.add_argument('--social_distance', help='social', action='store_true')
     predict_parser.add_argument('--threshold_prob', type=float, help='concordance for samples', default=0.25)
     predict_parser.add_argument('--threshold_dist', type=float, help='min distance of people', default=2.5)
     predict_parser.add_argument('--radii', type=tuple, help='o-space radii', default=(0.3, 0.5, 1))
@@ -137,8 +136,6 @@ def main():
             from .visuals.webcam import webcam
             webcam(args)
         else:
-            if args.output_types is None:
-                args.output_types = ['json']
 
             from .predict import predict
             predict(args)
diff --git a/monoloco/visuals/pifpaf_show.py b/monoloco/visuals/pifpaf_show.py
index 21bd1ec..0e2a807 100644
--- a/monoloco/visuals/pifpaf_show.py
+++ b/monoloco/visuals/pifpaf_show.py
@@ -93,9 +93,11 @@ class KeypointPainter:
         self.solid_threshold = solid_threshold
         self.dashed_threshold = 0.1  # Patch to still allow force complete pose (set to zero to resume original)
 
-    def _draw_skeleton(self, ax, x, y, v, *, size=None, color=None, raise_hand='none'):
-        if not np.any(v > 0):
-            return
+
+    def _highlighted_arm(self, x, y, connection, color, lwidth, raise_hand, size=None):
+
+        c = color
+        linewidth = lwidth
 
         width, height = (1,1)
         if size:
@@ -105,18 +107,32 @@ class KeypointPainter:
         l_arm_width = np.sqrt(((x[9]-x[7])/width)**2 + ((y[9]-y[7])/height)**2)*100
         r_arm_width = np.sqrt(((x[10]-x[8])/width)**2 + ((y[10]-y[8])/height)**2)*100
 
+        if ((connection[0] == 5 and connection[1] == 7)
+                or (connection[0] == 7 and connection[1] == 9)) and raise_hand in ['left', 'both']:
+            c = 'yellow'
+            linewidth = l_arm_width
+        if ((connection[0] == 6 and connection[1] == 8)
+                or (connection[0] == 8 and connection[1] == 10)) and raise_hand in ['right', 'both']:
+            c = 'yellow'
+            linewidth = r_arm_width
+
+        return c, linewidth
+
+
+    def _draw_skeleton(self, ax, x, y, v, i, *, size=None, color=None, activities=None, dic_out=None):
+        if not np.any(v > 0):
+            return
+
         if self.skeleton is not None:
             for ci, connection in enumerate(np.array(self.skeleton) - 1):
                 c = color
-                linewidth=self.linewidth
-                if ((connection[0] == 5 and connection[1] == 7)
-                        or (connection[0] == 7 and connection[1] == 9)) and raise_hand in ['left', 'both']:
-                    c = 'yellow'
-                    linewidth = l_arm_width
-                if ((connection[0] == 6 and connection[1] == 8)
-                        or (connection[0] == 8 and connection[1] == 10)) and raise_hand in ['right', 'both']:
-                    c = 'yellow'
-                    linewidth = r_arm_width
+                linewidth = self.linewidth
+
+                if activities:
+                    if 'raise_hand' in activities:
+                        c, linewidth = self._highlighted_arm(x, y, connection, c, linewidth,
+                                                             dic_out['raising_hand'][:][i], size=size)
+
                 if self.color_connections:
                     c = matplotlib.cm.get_cmap('tab20')(ci / len(self.skeleton))
                 if np.all(v[connection] > self.dashed_threshold):
@@ -193,7 +209,8 @@ class KeypointPainter:
                 (x - scale, y - scale), 2 * scale, 2 * scale, fill=False, color=color))
 
     def keypoints(self, ax, keypoint_sets, *,
-                  size=None, scores=None, color=None, colors=None, texts=None, raise_hand='none'):
+                  size=None, scores=None, color=None,
+                  colors=None, texts=None, activities=None, dic_out=None):
         if keypoint_sets is None:
             return
 
@@ -214,12 +231,8 @@ class KeypointPainter:
             if isinstance(color, (int, np.integer)):
                 color = matplotlib.cm.get_cmap('tab20')((color % 20 + 0.05) / 20)
 
-            if raise_hand != 'none':
-                # if raise_hand[:][i] is 'both' or raise_hand[:][i] is 'left' or raise_hand[:][i] is 'right':
-                #     color = 'green'
-                self._draw_skeleton(ax, x, y, v, size=size, color=color, raise_hand=raise_hand[:][i])
-            else:
-                self._draw_skeleton(ax, x, y, v, color=color)
+            self._draw_skeleton(ax, x, y, v, i, size=size, color=color, activities=activities, dic_out=dic_out)
+
             score = scores[i] if scores is not None else None
             if score is not None:
                 z_str = str(score).split(sep='.')
diff --git a/monoloco/visuals/printer.py b/monoloco/visuals/printer.py
index 6161407..4b5793d 100644
--- a/monoloco/visuals/printer.py
+++ b/monoloco/visuals/printer.py
@@ -52,7 +52,6 @@ class Printer:
         boxes_gt, uv_camera, radius, auxs = nones(16)
 
     def __init__(self, image, output_path, kk, args):
-
         self.im = image
         self.width = self.im.size[0]
         self.height = self.im.size[1]
@@ -60,11 +59,12 @@ class Printer:
         self.kk = kk
         self.output_types = args.output_types
         self.z_max = args.z_max  # set max distance to show instances
-        self.show_all = args.show_all or args.webcam
-        self.show = args.show_all or args.webcam
-        self.save = not args.no_save and not args.webcam
-        self.plt_close = not args.webcam
-        self.args = args
+        self.webcam = args.webcam
+        self.show_all = args.show_all or self.webcam
+        self.show = args.show_all or self.webcam
+        self.save = not args.no_save and not self.webcam
+        self.plt_close = not self.webcam
+        self.activities = args.activities
 
         # define image attributes
         self.attr = image_attributes(args.dpi, args.output_types)
@@ -177,33 +177,36 @@ class Printer:
 
         return figures, axes
 
-    def social_distance_front(self, axis, colors, annotations, dic_out):
+    def _webcam_front(self, axis, colors, activities, annotations, dic_out):
         sizes = [abs(self.centers[idx][1] - uv_s[1]*self.y_scale) / 1.5
                  for idx, uv_s in enumerate(self.uv_shoulders)]
         keypoint_sets, _ = get_pifpaf_outputs(annotations)
         keypoint_painter = KeypointPainter(show_box=False,
                                            y_scale=self.y_scale)
 
-        r_h = 'none'
-        if 'raise_hand' in self.args.activities:
-            r_h = dic_out['raising_hand']
-        keypoint_painter.keypoints(
-            axis, keypoint_sets, size=self.im.size,scores=self.dd_pred,colors=colors, raise_hand=r_h)
-        draw_orientation(axis, self.centers,
-                         sizes, self.angles, colors, mode='front')
+
+        if activities:
+            keypoint_painter.keypoints(
+                axis, keypoint_sets, size=self.im.size,
+                scores=self.dd_pred, colors=colors, activities=activities, dic_out=dic_out)
+
+            if 'social_distance' in activities:
+                draw_orientation(axis, self.centers,
+                                 sizes, self.angles, colors, mode='front')
+        else:
+            keypoint_painter.keypoints(
+                axis, keypoint_sets, size=self.im.size, scores=self.dd_pred)
 
-    def social_distance_bird(self, axis, colors):
-        draw_orientation(axis, self.xz_centers, [], self.angles, colors, mode='bird')
+    def _activities_bird(self, axis, colors, activities):
+        if 'social_distance' in activities:
+            draw_orientation(axis, self.xz_centers, [], self.angles, colors, mode='bird')
 
 
     def _front_loop(self, iterator, axes, number, colors, annotations, dic_out):
         for idx in iterator:
             if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
-                if self.args.activities:
-                    if 'social_distance' in self.args.activities:
-                        self.social_distance_front(axes[0], colors, annotations, dic_out)
-                    elif 'raise_hand' in self.args.activities:
-                        self.social_distance_front(axes[0], colors, annotations, dic_out)
+                if self.webcam:
+                    self._webcam_front(axes[0], colors, self.activities, annotations, dic_out)
                 else:
                     self._draw_front(axes[0],
                                      self.dd_pred[idx],
@@ -215,10 +218,8 @@ class Printer:
     def _bird_loop(self, iterator, axes, colors, number):
         for idx in iterator:
             if any(xx in self.output_types for xx in ['bird', 'multi']) and self.zz_pred[idx] > 0:
-
-                if self.args.activities:
-                    if 'social_distance' in self.args.activities:
-                        self.social_distance_bird(axes[1], colors)
+                if self.activities:
+                    self._activities_bird(axes[1], colors, self.activities)
 
                 # Draw ground truth and uncertainty
                 self._draw_uncertainty(axes, idx)
@@ -231,9 +232,9 @@ class Printer:
 
     def draw(self, figures, axes, image, dic_out=None, annotations=None):
         colors = []
-        if self.args.activities:
+        if self.activities:
             colors = ['deepskyblue' for _ in self.uv_heads]
-            if 'social_distance' in self.args.activities:
+            if 'social_distance' in self.activities:
                 colors = social_distance_colors(colors, dic_out)
 
         # whether to include instances that don't match the ground-truth
@@ -246,7 +247,7 @@ class Printer:
         if any(xx in self.output_types for xx in ['front', 'multi']):
             number['flag'] = True  # add numbers
             # Remove image if social distance is activated
-            if not self.args.activities or 'social_distance' not in self.args.activities:
+            if not self.activities or 'social_distance' not in self.activities:
                 self.mpl_im0.set_data(image)
 
             self._front_loop(iterator, axes, number, colors, annotations, dic_out)
@@ -420,7 +421,7 @@ class Printer:
             ax.set_axis_off()
             ax.set_xlim(0, self.width)
             ax.set_ylim(self.height, 0)
-            if not self.args.activities or 'social_distance' not in self.args.activities:
+            if not self.activities or 'social_distance' not in self.activities:
                 self.mpl_im0 = ax.imshow(self.im)
             ax.get_xaxis().set_visible(False)
             ax.get_yaxis().set_visible(False)
@@ -428,8 +429,7 @@ class Printer:
         else:
             uv_max = [0., float(self.height)]
             xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
-            x_max = abs(xyz_max[0])  # shortcut to avoid oval circles in case of different kk
-            x_max=6
+            x_max = max(abs(xyz_max[0]), 6)  # shortcut to avoid oval circles in case of different kk
             corr = round(float(x_max / 3))
             ax.plot([0, x_max], [0, self.z_max], 'k--')
             ax.plot([0, -x_max], [0, self.z_max], 'k--')
diff --git a/monoloco/visuals/webcam.py b/monoloco/visuals/webcam.py
index 8edab16..94bcc0b 100644
--- a/monoloco/visuals/webcam.py
+++ b/monoloco/visuals/webcam.py
@@ -36,14 +36,7 @@ def factory_from_args(args):
 
     logger.configure(args, LOG)  # logger first
 
-    if args.output_types is None:
-        args.output_types = ['multi']
-
-    assert 'bird' not in args.output_types
-    if 'json' not in args.output_types:
-        assert len(args.output_types) == 1
-    else:
-        assert len(args.output_types) < 3
+    assert len(args.output_types) == 1 and 'json' not in args.output_types
 
     # Devices
     args.device = torch.device('cpu')
@@ -129,8 +122,7 @@ def webcam(args):
                 print("Escape hit, closing...")
                 break
 
-            intrinsic_size = [xx * 1.3 for xx in pil_image.size]
-            kk, dic_gt = factory_for_gt(intrinsic_size, focal_length=args.focal)  # better intrinsics for mac camera
+            kk, dic_gt = factory_for_gt(pil_image.size, focal_length=args.focal)
 
             boxes, keypoints = preprocess_pifpaf(
                 pifpaf_outs['left'], (width, height))
diff --git a/tests/test_train_mono.py b/tests/test_train_mono.py
index 964ea3d..af08a10 100644
--- a/tests/test_train_mono.py
+++ b/tests/test_train_mono.py
@@ -33,7 +33,7 @@ PREDICT_COMMAND_SOCIAL_DISTANCE = [
     'python3', '-m', 'monoloco.run',
     'predict',
     'docs/frame0032.jpg',
-    '--social_distance',
+    '--activities', 'social_distance',
     '--output_types', 'front', 'bird',
     '--decoder-workers=0'  # for windows
 ]
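Example of the updated CLI invocation (a sketch mirroring the modified test command above; social distancing is now requested through '--activities' rather than the removed '--social_distance' flag):

    python3 -m monoloco.run predict docs/frame0032.jpg \
        --activities social_distance \
        --output_types front bird \
        --decoder-workers=0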