diff --git a/monstereo/run.py b/monstereo/run.py index e6ba121..7d0ef06 100644 --- a/monstereo/run.py +++ b/monstereo/run.py @@ -41,6 +41,7 @@ def cli(): predict_parser.add_argument('--no_save', help='to show images', action='store_true') predict_parser.add_argument('--show', help='to show images', action='store_true') predict_parser.add_argument('--dpi', help='image resolution', type=int, default=100) + predict_parser.add_argument('--force-complete-pose', help='', action ='store_true') # Pifpaf openpifpaf_cli(predict_parser) diff --git a/monstereo/visuals/printer.py b/monstereo/visuals/printer.py index cd6bd8f..70d5eb4 100644 --- a/monstereo/visuals/printer.py +++ b/monstereo/visuals/printer.py @@ -87,6 +87,15 @@ class Printer: self.boxes_gt = dic_ann['boxes_gt'] self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1]) self.auxs = dic_ann['aux'] + if len(self.auxs) == 0: + self.modes = ['mono'] * len(self.dd_pred) + else: + self.modes = [] + for aux in self.auxs: + if aux <= 0.3: + self.modes.append('mono') + else: + self.modes.append('stereo') def factory_axes(self): """Create axes for figures: front bird multi""" @@ -198,13 +207,6 @@ class Printer: def _draw_front(self, ax, z, idx, number): - if len(self.auxs) == 0: - mode = 'mono' - elif self.auxs[idx] <= 0.3: - mode = 'mono' - else: - mode = 'stereo' - # Bbox w = min(self.width-2, self.boxes[idx][2] - self.boxes[idx][0]) h = min(self.height-2, (self.boxes[idx][3] - self.boxes[idx][1]) * self.y_scale) @@ -215,12 +217,12 @@ class Printer: width=w, height=h, fill=False, - color=self.attr[mode]['color'], - linewidth=self.attr[mode]['linewidth']) + color=self.attr[self.modes[idx]]['color'], + linewidth=self.attr[self.modes[idx]]['linewidth']) ax.add_patch(rectangle) z_str = str(z).split(sep='.') text = z_str[0] + '.' + z_str[1][0] - bbox_config = {'facecolor': self.attr[mode]['color'], 'alpha': 0.4, 'linewidth': 0} + bbox_config = {'facecolor': self.attr[self.modes[idx]]['color'], 'alpha': 0.4, 'linewidth': 0} x_t = x0 - 1.5 y_t = y1 + self.attr['y_box_margin'] @@ -240,12 +242,12 @@ class Printer: y1 + 14, chr(number['num']), fontsize=self.attr['fontsize_num'], - color=self.attr[mode]['numcolor'], + color=self.attr[self.modes[idx]]['numcolor'], weight='bold') def _draw_text_bird(self, axes, idx, num): """Plot the number in the bird eye view map""" - mode = 'stereo' if self.auxs[idx] > 0.3 else 'mono' + std = self.stds_epi[idx] if self.stds_epi[idx] > 0 else self.stds_ale[idx] theta = math.atan2(self.zz_pred[idx], self.xx_pred[idx]) @@ -254,7 +256,7 @@ class Printer: axes[1].text(self.xx_pred[idx] + delta_x + 0.2, self.zz_pred[idx] + delta_z + 0/2, chr(num), fontsize=self.attr['fontsize_bv'], - color=self.attr[mode]['numcolor']) + color=self.attr[self.modes[idx]]['numcolor']) def _draw_uncertainty(self, axes, idx):