From e3cc03a5459d43762a40c6ec2c74bdc45ef82001 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 2 Jul 2019 10:36:28 +0200 Subject: [PATCH] refactor printer --- src/predict/factory.py | 6 +-- src/visuals/printer.py | 99 ++++++++++++++++++++++-------------------- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/src/predict/factory.py b/src/predict/factory.py index 8c94b07..ae636d4 100644 --- a/src/predict/factory.py +++ b/src/predict/factory.py @@ -76,10 +76,10 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=N epistemic = True if dic_out['boxes']: # Only print in case of detections - printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types, - show=args.show, z_max=args.z_max, epistemic=epistemic) + printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types + , z_max=args.z_max, epistemic=epistemic) figures, axes = printer.factory_axes() - printer.draw(figures, axes, dic_out) + printer.draw(figures, axes, dic_out, images_outputs[1], save=True, show=args.show) if 'json' in args.output_types: with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff: diff --git a/src/visuals/printer.py b/src/visuals/printer.py index 74a8b73..67fce25 100644 --- a/src/visuals/printer.py +++ b/src/visuals/printer.py @@ -2,7 +2,6 @@ import math from collections import OrderedDict import numpy as np -import cv2 import matplotlib import matplotlib.pyplot as plt @@ -24,27 +23,28 @@ class Printer: TEXTCOLOR = 'darkorange' COLOR_KPS = 'yellow' - def __init__(self, image, output_path, kk, output_types, show=False, - text=True, legend=True, epistemic=False, z_max=30, fig_width=10): + def __init__(self, image, output_path, kk, output_types, text=True, legend=True, epistemic=False, + z_max=30, fig_width=10): self.im = image self.kk = kk self.output_types = output_types - self.show = show self.text = text self.epistemic = epistemic self.legend = legend self.z_max = z_max # To include ellipses in the image self.y_scale = 1 - self.ww = self.im.size[0] - self.hh = self.im.size[1] + self.width = self.im.size[0] + self.height = self.im.size[1] self.fig_width = fig_width # Define the output dir self.path_out = output_path self.cmap = cm.get_cmap('jet') + self.extensions = [] + self.mpl_im0 = None - def _process_input(self, dic_ann): + def _process_results(self, dic_ann): # Include the vectors inside the interval given by z_max self.stds_ale = dic_ann['stds_ale'] self.stds_ale_epi = dic_ann['stds_epi'] @@ -60,57 +60,60 @@ class Printer: self.uv_kps = dic_ann['uv_kps'] self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1]) - self.radius = 14 / 1600 * self.ww - self.ext = ".png" + self.radius = 14 / 1600 * self.width def factory_axes(self): + """Create axes for figures: front bird combined""" axes = [] figures = [] - self.mpl_im0 = None - # Resize image for aesthetic proportions in combined visualization + + # Initialize combined figure, resizing it for aesthetic proportions if 'combined' in self.output_types: - self.y_scale = self.ww / (self.hh * 1.8) # Defined proportion - self.im = self.im.resize((self.ww, round(self.hh * self.y_scale))) - self.ww = self.im.size[0] - self.hh = self.im.size[1] + assert 'bird' and 'front' not in self.output_types, \ + "combined figure cannot be print together with front or bird ones" + + self.y_scale = self.width / (self.height * 1.8) # Defined proportion + self.im = self.im.resize((self.width, round(self.height * self.y_scale))) + self.width = self.im.size[0] + self.height = self.im.size[1] fig_width = self.fig_width + 0.6 * self.fig_width - fig_height = self.fig_width * self.hh / self.ww - + fig_height = self.fig_width * self.height / self.width + # Distinguish between KITTI images and general images if self.y_scale > 1.7: fig_ar_1 = 1.7 else: fig_ar_1 = 1.3 width_ratio = 1.9 - self.ext = '.combined.png' + self.extensions.append('.combined.png') fig, (ax1, ax0) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [1, width_ratio]}, figsize=(fig_width, fig_height)) ax1.set_aspect(fig_ar_1) fig.set_tight_layout(True) fig.subplots_adjust(left=0.02, right=0.98, bottom=0, top=1, hspace=0, wspace=0.02) - + figures.append(fig) assert 'front' not in self.output_types and 'bird' not in self.output_types, \ "--combined arguments is not supported with other visualizations" - # Initialize front + # Initialize front figure elif 'front' in self.output_types: width = self.fig_width - height = self.fig_width * self.hh / self.ww - self.ext = ".front.png" + height = self.fig_width * self.height / self.width + self.extensions.append(".front.png") plt.figure(0) fig0, ax0 = plt.subplots(1, 1, figsize=(width, height)) fig0.set_tight_layout(True) - + figures.append(fig0) - # Create front + # Create front figure axis if any(xx in self.output_types for xx in ['front', 'combined']): ax0.set_axis_off() - ax0.set_xlim(0, self.ww) - ax0.set_ylim(self.hh, 0) + ax0.set_xlim(0, self.width) + ax0.set_ylim(self.height, 0) self.mpl_im0 = ax0.imshow(self.im) z_min = 0 bar_ticks = self.z_max // 5 + 1 @@ -121,32 +124,25 @@ class Printer: cax = divider.append_axes('right', size='3%', pad=0.05) norm = matplotlib.colors.Normalize(vmin=z_min, vmax=self.z_max) - sm = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm) - sm.set_array([]) - plt.colorbar(sm, ticks=np.linspace(z_min, self.z_max, bar_ticks), + scalar_mappable = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm) + scalar_mappable.set_array([]) + plt.colorbar(scalar_mappable, ticks=np.linspace(z_min, self.z_max, bar_ticks), boundaries=np.arange(z_min - 0.05, self.z_max + 0.1, .1), cax=cax, label='Z [m]') - + axes.append(ax0) - if len(axes) == 0: + if not axes: axes.append(None) if 'bird' in self.output_types: - self.ext = ".bird.png" #TODO multiple savings external - plt.figure(1) + self.extensions.append(".bird.png") fig1, ax1 = plt.subplots(1, 1) fig1.set_tight_layout(True) figures.append(fig1) if any(xx in self.output_types for xx in ['bird', 'combined']): - uv_max = [0., float(self.hh)] + uv_max = [0., float(self.height)] xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max) x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk - # To avoid repetitions in the legend - if self.legend: - handles, labels = ax1.get_legend_handles_labels() - by_label = OrderedDict(zip(labels, handles)) - ax1.legend(by_label.values(), by_label.keys()) - # Adding field of view ax1.plot([0, x_max], [0, self.z_max], 'k--') ax1.plot([0, -x_max], [0, self.z_max], 'k--') @@ -156,9 +152,9 @@ class Printer: axes.append(ax1) return figures, axes - def draw(self, figures, axes, dic_out, image, save=False): + def draw(self, figures, axes, dic_out, image, save=False, show=False): - self._process_input(dic_out) + self._process_results(dic_out) num = 0 if any(xx in self.output_types for xx in ['front', 'combined']): self.mpl_im0.set_data(image) @@ -203,17 +199,26 @@ class Printer: axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], 'ro', label="Predicted", markersize=3) + # Setup the legend to avoid repetitions + if self.legend: + handles, labels = axes[1].get_legend_handles_labels() + by_label = OrderedDict(zip(labels, handles)) + axes[1].legend(by_label.values(), by_label.keys()) + # Plot the number - (_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx], self.stds_ale_epi[idx]) + (_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx], + self.stds_ale_epi[idx]) if self.text: axes[1].text(x_pos, z_pos, str(num), fontsize=self.FONTSIZE_BV, color='darkorange') num += 1 - for fig in figures: - fig.canvas.draw() - if save: - plt.savefig(self.path_out + self.ext, bbox_inches='tight') + for idx, fig in enumerate(figures): + fig.canvas.draw() + if save: + fig.savefig(self.path_out + self.extensions[idx], bbox_inches='tight') + if show: + fig.show() def get_confidence(xx, zz, std):