refactor printer

This commit is contained in:
lorenzo 2019-07-02 10:36:28 +02:00
parent 3b7de1e592
commit e3cc03a545
2 changed files with 55 additions and 50 deletions

View File

@ -76,10 +76,10 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=N
epistemic = True epistemic = True
if dic_out['boxes']: # Only print in case of detections if dic_out['boxes']: # Only print in case of detections
printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types, printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types
show=args.show, z_max=args.z_max, epistemic=epistemic) , z_max=args.z_max, epistemic=epistemic)
figures, axes = printer.factory_axes() figures, axes = printer.factory_axes()
printer.draw(figures, axes, dic_out) printer.draw(figures, axes, dic_out, images_outputs[1], save=True, show=args.show)
if 'json' in args.output_types: if 'json' in args.output_types:
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff: with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:

View File

@ -2,7 +2,6 @@
import math import math
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import cv2
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -24,27 +23,28 @@ class Printer:
TEXTCOLOR = 'darkorange' TEXTCOLOR = 'darkorange'
COLOR_KPS = 'yellow' COLOR_KPS = 'yellow'
def __init__(self, image, output_path, kk, output_types, show=False, def __init__(self, image, output_path, kk, output_types, text=True, legend=True, epistemic=False,
text=True, legend=True, epistemic=False, z_max=30, fig_width=10): z_max=30, fig_width=10):
self.im = image self.im = image
self.kk = kk self.kk = kk
self.output_types = output_types self.output_types = output_types
self.show = show
self.text = text self.text = text
self.epistemic = epistemic self.epistemic = epistemic
self.legend = legend self.legend = legend
self.z_max = z_max # To include ellipses in the image self.z_max = z_max # To include ellipses in the image
self.y_scale = 1 self.y_scale = 1
self.ww = self.im.size[0] self.width = self.im.size[0]
self.hh = self.im.size[1] self.height = self.im.size[1]
self.fig_width = fig_width self.fig_width = fig_width
# Define the output dir # Define the output dir
self.path_out = output_path self.path_out = output_path
self.cmap = cm.get_cmap('jet') self.cmap = cm.get_cmap('jet')
self.extensions = []
self.mpl_im0 = None
def _process_input(self, dic_ann): def _process_results(self, dic_ann):
# Include the vectors inside the interval given by z_max # Include the vectors inside the interval given by z_max
self.stds_ale = dic_ann['stds_ale'] self.stds_ale = dic_ann['stds_ale']
self.stds_ale_epi = dic_ann['stds_epi'] self.stds_ale_epi = dic_ann['stds_epi']
@ -60,21 +60,24 @@ class Printer:
self.uv_kps = dic_ann['uv_kps'] self.uv_kps = dic_ann['uv_kps']
self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1]) self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1])
self.radius = 14 / 1600 * self.ww self.radius = 14 / 1600 * self.width
self.ext = ".png"
def factory_axes(self): def factory_axes(self):
"""Create axes for figures: front bird combined"""
axes = [] axes = []
figures = [] figures = []
self.mpl_im0 = None
# Resize image for aesthetic proportions in combined visualization # Initialize combined figure, resizing it for aesthetic proportions
if 'combined' in self.output_types: if 'combined' in self.output_types:
self.y_scale = self.ww / (self.hh * 1.8) # Defined proportion assert 'bird' and 'front' not in self.output_types, \
self.im = self.im.resize((self.ww, round(self.hh * self.y_scale))) "combined figure cannot be print together with front or bird ones"
self.ww = self.im.size[0]
self.hh = self.im.size[1] self.y_scale = self.width / (self.height * 1.8) # Defined proportion
self.im = self.im.resize((self.width, round(self.height * self.y_scale)))
self.width = self.im.size[0]
self.height = self.im.size[1]
fig_width = self.fig_width + 0.6 * self.fig_width fig_width = self.fig_width + 0.6 * self.fig_width
fig_height = self.fig_width * self.hh / self.ww fig_height = self.fig_width * self.height / self.width
# Distinguish between KITTI images and general images # Distinguish between KITTI images and general images
if self.y_scale > 1.7: if self.y_scale > 1.7:
@ -82,7 +85,7 @@ class Printer:
else: else:
fig_ar_1 = 1.3 fig_ar_1 = 1.3
width_ratio = 1.9 width_ratio = 1.9
self.ext = '.combined.png' self.extensions.append('.combined.png')
fig, (ax1, ax0) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [1, width_ratio]}, fig, (ax1, ax0) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [1, width_ratio]},
figsize=(fig_width, fig_height)) figsize=(fig_width, fig_height))
@ -94,23 +97,23 @@ class Printer:
assert 'front' not in self.output_types and 'bird' not in self.output_types, \ assert 'front' not in self.output_types and 'bird' not in self.output_types, \
"--combined arguments is not supported with other visualizations" "--combined arguments is not supported with other visualizations"
# Initialize front # Initialize front figure
elif 'front' in self.output_types: elif 'front' in self.output_types:
width = self.fig_width width = self.fig_width
height = self.fig_width * self.hh / self.ww height = self.fig_width * self.height / self.width
self.ext = ".front.png" self.extensions.append(".front.png")
plt.figure(0) plt.figure(0)
fig0, ax0 = plt.subplots(1, 1, figsize=(width, height)) fig0, ax0 = plt.subplots(1, 1, figsize=(width, height))
fig0.set_tight_layout(True) fig0.set_tight_layout(True)
figures.append(fig0) figures.append(fig0)
# Create front # Create front figure axis
if any(xx in self.output_types for xx in ['front', 'combined']): if any(xx in self.output_types for xx in ['front', 'combined']):
ax0.set_axis_off() ax0.set_axis_off()
ax0.set_xlim(0, self.ww) ax0.set_xlim(0, self.width)
ax0.set_ylim(self.hh, 0) ax0.set_ylim(self.height, 0)
self.mpl_im0 = ax0.imshow(self.im) self.mpl_im0 = ax0.imshow(self.im)
z_min = 0 z_min = 0
bar_ticks = self.z_max // 5 + 1 bar_ticks = self.z_max // 5 + 1
@ -121,32 +124,25 @@ class Printer:
cax = divider.append_axes('right', size='3%', pad=0.05) cax = divider.append_axes('right', size='3%', pad=0.05)
norm = matplotlib.colors.Normalize(vmin=z_min, vmax=self.z_max) norm = matplotlib.colors.Normalize(vmin=z_min, vmax=self.z_max)
sm = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm) scalar_mappable = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm)
sm.set_array([]) scalar_mappable.set_array([])
plt.colorbar(sm, ticks=np.linspace(z_min, self.z_max, bar_ticks), plt.colorbar(scalar_mappable, ticks=np.linspace(z_min, self.z_max, bar_ticks),
boundaries=np.arange(z_min - 0.05, self.z_max + 0.1, .1), cax=cax, label='Z [m]') boundaries=np.arange(z_min - 0.05, self.z_max + 0.1, .1), cax=cax, label='Z [m]')
axes.append(ax0) axes.append(ax0)
if len(axes) == 0: if not axes:
axes.append(None) axes.append(None)
if 'bird' in self.output_types: if 'bird' in self.output_types:
self.ext = ".bird.png" #TODO multiple savings external self.extensions.append(".bird.png")
plt.figure(1)
fig1, ax1 = plt.subplots(1, 1) fig1, ax1 = plt.subplots(1, 1)
fig1.set_tight_layout(True) fig1.set_tight_layout(True)
figures.append(fig1) figures.append(fig1)
if any(xx in self.output_types for xx in ['bird', 'combined']): if any(xx in self.output_types for xx in ['bird', 'combined']):
uv_max = [0., float(self.hh)] uv_max = [0., float(self.height)]
xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max) xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk
# To avoid repetitions in the legend
if self.legend:
handles, labels = ax1.get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
ax1.legend(by_label.values(), by_label.keys())
# Adding field of view # Adding field of view
ax1.plot([0, x_max], [0, self.z_max], 'k--') ax1.plot([0, x_max], [0, self.z_max], 'k--')
ax1.plot([0, -x_max], [0, self.z_max], 'k--') ax1.plot([0, -x_max], [0, self.z_max], 'k--')
@ -156,9 +152,9 @@ class Printer:
axes.append(ax1) axes.append(ax1)
return figures, axes return figures, axes
def draw(self, figures, axes, dic_out, image, save=False): def draw(self, figures, axes, dic_out, image, save=False, show=False):
self._process_input(dic_out) self._process_results(dic_out)
num = 0 num = 0
if any(xx in self.output_types for xx in ['front', 'combined']): if any(xx in self.output_types for xx in ['front', 'combined']):
self.mpl_im0.set_data(image) self.mpl_im0.set_data(image)
@ -203,17 +199,26 @@ class Printer:
axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], 'ro', label="Predicted", markersize=3) axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], 'ro', label="Predicted", markersize=3)
# Setup the legend to avoid repetitions
if self.legend:
handles, labels = axes[1].get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
axes[1].legend(by_label.values(), by_label.keys())
# Plot the number # Plot the number
(_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx], self.stds_ale_epi[idx]) (_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx],
self.stds_ale_epi[idx])
if self.text: if self.text:
axes[1].text(x_pos, z_pos, str(num), fontsize=self.FONTSIZE_BV, color='darkorange') axes[1].text(x_pos, z_pos, str(num), fontsize=self.FONTSIZE_BV, color='darkorange')
num += 1 num += 1
for fig in figures:
fig.canvas.draw()
if save: for idx, fig in enumerate(figures):
plt.savefig(self.path_out + self.ext, bbox_inches='tight') fig.canvas.draw()
if save:
fig.savefig(self.path_out + self.extensions[idx], bbox_inches='tight')
if show:
fig.show()
def get_confidence(xx, zz, std): def get_confidence(xx, zz, std):