diff --git a/monstereo/activity.py b/monstereo/activity.py index 2a343fb..a7a937a 100644 --- a/monstereo/activity.py +++ b/monstereo/activity.py @@ -218,22 +218,22 @@ def show_social(args, image_t, output_path, annotations, dic_out): stds = dic_out['stds_ale'] xz_centers = [[xx[0], xx[2]] for xx in dic_out['xyz_pred']] + # Prepare color for social distancing + colors = ['r' if social_interactions(idx, xz_centers, angles, dds, + stds=stds, + threshold_prob=args.threshold_prob, + threshold_dist=args.threshold_dist, + radii=args.radii) + else 'deepskyblue' + for idx, _ in enumerate(dic_out['xyz_pred'])] + if 'front' in args.output_types: # Resize back the tensor image to its original dimensions - if not 0.99 < args.scale < 1.01: - size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width - image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width - image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0) - - # Prepare color for social distancing - colors = ['r' if social_interactions(idx, xz_centers, angles, dds, - stds=stds, - threshold_prob=args.threshold_prob, - threshold_dist=args.threshold_dist, - radii=args.radii) - else 'deepskyblue' - for idx, _ in enumerate(dic_out['xyz_pred'])] + # if not 0.99 < args.scale < 1.01: + # size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width + # image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width + # image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0) # Draw keypoints and orientation keypoint_sets, scores = get_pifpaf_outputs(annotations) diff --git a/monstereo/predict.py b/monstereo/predict.py index 11976e3..c46c0df 100644 --- a/monstereo/predict.py +++ b/monstereo/predict.py @@ -135,7 +135,9 @@ def predict(args): dic_out = net.forward(keypoints, kk) reorder = False if args.social_distance else True dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt, reorder=reorder) + if args.social_distance: + # image_t = torchvision.transforms.functional.to_tensor(image).permute(1, 2, 0) show_social(args, cpu_image, output_path, pifpaf_out, dic_out) else: @@ -148,9 +150,9 @@ def predict(args): dic_out = defaultdict(list) kk = None - # TODO Clean - factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out, - dic_out=dic_out, kk=kk) + if not args.social_distance: + factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out, + dic_out=dic_out, kk=kk) print('Image {}\n'.format(cnt) + '-' * 120) cnt += 1 @@ -192,7 +194,7 @@ def factory_outputs(args, annotation_painter, cpu_image, output_path, pred, pifp print(output_path) if dic_out['boxes']: # Only print in case of detections printer = Printer(cpu_image, output_path, kk, args) - figures, axes = printer.factory_axes() + figures, axes = printer.factory_axes(dic_out) printer.draw(figures, axes, dic_out, cpu_image) if 'json' in args.output_types: diff --git a/monstereo/visuals/pifpaf_show.py b/monstereo/visuals/pifpaf_show.py index 7a00736..95139f0 100644 --- a/monstereo/visuals/pifpaf_show.py +++ b/monstereo/visuals/pifpaf_show.py @@ -39,21 +39,20 @@ def canvas(fig_file=None, show=True, **kwargs): @contextmanager def image_canvas(image, fig_file=None, show=True, dpi_factor=1.0, fig_width=10.0, **kwargs): if 'figsize' not in kwargs: - kwargs['figsize'] = (fig_width, fig_width * image.shape[0] / image.shape[1]) + kwargs['figsize'] = (fig_width, fig_width * image.size[1] / image.size[0]) fig = plt.figure(**kwargs) ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() - ax.set_xlim(0, image.shape[1]) - ax.set_ylim(image.shape[0], 0) + ax.set_xlim(0, image.size[0]) + ax.set_ylim(image.size[1], 0) fig.add_axes(ax) image_2 = ndimage.gaussian_filter(image, sigma=2.5) ax.imshow(image_2, alpha=0.4) - yield ax if fig_file: - fig.savefig(fig_file, dpi=image.shape[1] / kwargs['figsize'][0] * dpi_factor) + fig.savefig(fig_file, dpi=image.size[0] / kwargs['figsize'][0] * dpi_factor) print('keypoints image saved') if show: plt.show() diff --git a/monstereo/visuals/printer.py b/monstereo/visuals/printer.py index 70d5eb4..fa933de 100644 --- a/monstereo/visuals/printer.py +++ b/monstereo/visuals/printer.py @@ -58,7 +58,7 @@ class Printer: self.output_path = output_path self.kk = kk self.output_types = args.output_types - self.z_max = args.z_max # To include ellipses in the image + self.z_max = args.z_max # set max distance to show instances self.show_all = args.show_all self.show = args.show_all self.save = not args.no_save @@ -74,13 +74,17 @@ class Printer: self.xx_gt = [xx[0] for xx in dic_ann['xyz_real']] self.xx_pred = [xx[0] for xx in dic_ann['xyz_pred']] + # Set maximum distance + self.dd_pred = dic_ann['dds_pred'] + self.dd_real = dic_ann['dds_real'] + self.z_max = int(min(self.z_max + 4, max(max(self.dd_pred), max(self.dd_real, default=0)))) + # Do not print instances outside z_max self.zz_gt = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0 for idx, xx in enumerate(dic_ann['xyz_real'])] self.zz_pred = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0 for idx, xx in enumerate(dic_ann['xyz_pred'])] - self.dd_pred = dic_ann['dds_pred'] - self.dd_real = dic_ann['dds_real'] + self.uv_heads = dic_ann['uv_heads'] self.uv_shoulders = dic_ann['uv_shoulders'] self.boxes = dic_ann['boxes'] @@ -97,11 +101,14 @@ class Printer: else: self.modes.append('stereo') - def factory_axes(self): + def factory_axes(self, dic_out): """Create axes for figures: front bird multi""" axes = [] figures = [] + # Process the annotation dictionary of monoloco + self._process_results(dic_out) + # Initialize multi figure, resizing it for aesthetic proportion if 'multi' in self.output_types: assert 'bird' and 'front' not in self.output_types, \ @@ -160,9 +167,6 @@ class Printer: def draw(self, figures, axes, dic_out, image): - # Process the annotation dictionary of monoloco - self._process_results(dic_out) - # whether to include instances that don't match the ground-truth iterator = range(len(self.zz_pred)) if self.show_all else range(len(self.zz_gt)) if not iterator: @@ -171,9 +175,9 @@ class Printer: # Draw the front figure number = dict(flag=False, num=97) - if 'multi' in self.output_types: + if any(xx in self.output_types for xx in ['front', 'multi']): number['flag'] = True # add numbers - self.mpl_im0.set_data(image) + self.mpl_im0.set_data(image) for idx in iterator: if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0: self._draw_front(axes[0],