refactor printer
This commit is contained in:
parent
3b7de1e592
commit
e3cc03a545
@ -76,10 +76,10 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=N
|
||||
epistemic = True
|
||||
|
||||
if dic_out['boxes']: # Only print in case of detections
|
||||
printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types,
|
||||
show=args.show, z_max=args.z_max, epistemic=epistemic)
|
||||
printer = Printer(images_outputs[1], output_path, kk, output_types=args.output_types
|
||||
, z_max=args.z_max, epistemic=epistemic)
|
||||
figures, axes = printer.factory_axes()
|
||||
printer.draw(figures, axes, dic_out)
|
||||
printer.draw(figures, axes, dic_out, images_outputs[1], save=True, show=args.show)
|
||||
|
||||
if 'json' in args.output_types:
|
||||
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
@ -24,27 +23,28 @@ class Printer:
|
||||
TEXTCOLOR = 'darkorange'
|
||||
COLOR_KPS = 'yellow'
|
||||
|
||||
def __init__(self, image, output_path, kk, output_types, show=False,
|
||||
text=True, legend=True, epistemic=False, z_max=30, fig_width=10):
|
||||
def __init__(self, image, output_path, kk, output_types, text=True, legend=True, epistemic=False,
|
||||
z_max=30, fig_width=10):
|
||||
|
||||
self.im = image
|
||||
self.kk = kk
|
||||
self.output_types = output_types
|
||||
self.show = show
|
||||
self.text = text
|
||||
self.epistemic = epistemic
|
||||
self.legend = legend
|
||||
self.z_max = z_max # To include ellipses in the image
|
||||
self.y_scale = 1
|
||||
self.ww = self.im.size[0]
|
||||
self.hh = self.im.size[1]
|
||||
self.width = self.im.size[0]
|
||||
self.height = self.im.size[1]
|
||||
self.fig_width = fig_width
|
||||
|
||||
# Define the output dir
|
||||
self.path_out = output_path
|
||||
self.cmap = cm.get_cmap('jet')
|
||||
self.extensions = []
|
||||
self.mpl_im0 = None
|
||||
|
||||
def _process_input(self, dic_ann):
|
||||
def _process_results(self, dic_ann):
|
||||
# Include the vectors inside the interval given by z_max
|
||||
self.stds_ale = dic_ann['stds_ale']
|
||||
self.stds_ale_epi = dic_ann['stds_epi']
|
||||
@ -60,57 +60,60 @@ class Printer:
|
||||
self.uv_kps = dic_ann['uv_kps']
|
||||
|
||||
self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1])
|
||||
self.radius = 14 / 1600 * self.ww
|
||||
self.ext = ".png"
|
||||
self.radius = 14 / 1600 * self.width
|
||||
|
||||
def factory_axes(self):
|
||||
"""Create axes for figures: front bird combined"""
|
||||
axes = []
|
||||
figures = []
|
||||
self.mpl_im0 = None
|
||||
# Resize image for aesthetic proportions in combined visualization
|
||||
|
||||
# Initialize combined figure, resizing it for aesthetic proportions
|
||||
if 'combined' in self.output_types:
|
||||
self.y_scale = self.ww / (self.hh * 1.8) # Defined proportion
|
||||
self.im = self.im.resize((self.ww, round(self.hh * self.y_scale)))
|
||||
self.ww = self.im.size[0]
|
||||
self.hh = self.im.size[1]
|
||||
assert 'bird' and 'front' not in self.output_types, \
|
||||
"combined figure cannot be print together with front or bird ones"
|
||||
|
||||
self.y_scale = self.width / (self.height * 1.8) # Defined proportion
|
||||
self.im = self.im.resize((self.width, round(self.height * self.y_scale)))
|
||||
self.width = self.im.size[0]
|
||||
self.height = self.im.size[1]
|
||||
fig_width = self.fig_width + 0.6 * self.fig_width
|
||||
fig_height = self.fig_width * self.hh / self.ww
|
||||
|
||||
fig_height = self.fig_width * self.height / self.width
|
||||
|
||||
# Distinguish between KITTI images and general images
|
||||
if self.y_scale > 1.7:
|
||||
fig_ar_1 = 1.7
|
||||
else:
|
||||
fig_ar_1 = 1.3
|
||||
width_ratio = 1.9
|
||||
self.ext = '.combined.png'
|
||||
self.extensions.append('.combined.png')
|
||||
|
||||
fig, (ax1, ax0) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [1, width_ratio]},
|
||||
figsize=(fig_width, fig_height))
|
||||
ax1.set_aspect(fig_ar_1)
|
||||
fig.set_tight_layout(True)
|
||||
fig.subplots_adjust(left=0.02, right=0.98, bottom=0, top=1, hspace=0, wspace=0.02)
|
||||
|
||||
|
||||
figures.append(fig)
|
||||
assert 'front' not in self.output_types and 'bird' not in self.output_types, \
|
||||
"--combined arguments is not supported with other visualizations"
|
||||
|
||||
# Initialize front
|
||||
# Initialize front figure
|
||||
elif 'front' in self.output_types:
|
||||
width = self.fig_width
|
||||
height = self.fig_width * self.hh / self.ww
|
||||
self.ext = ".front.png"
|
||||
height = self.fig_width * self.height / self.width
|
||||
self.extensions.append(".front.png")
|
||||
plt.figure(0)
|
||||
fig0, ax0 = plt.subplots(1, 1, figsize=(width, height))
|
||||
fig0.set_tight_layout(True)
|
||||
|
||||
|
||||
figures.append(fig0)
|
||||
|
||||
# Create front
|
||||
# Create front figure axis
|
||||
if any(xx in self.output_types for xx in ['front', 'combined']):
|
||||
|
||||
ax0.set_axis_off()
|
||||
ax0.set_xlim(0, self.ww)
|
||||
ax0.set_ylim(self.hh, 0)
|
||||
ax0.set_xlim(0, self.width)
|
||||
ax0.set_ylim(self.height, 0)
|
||||
self.mpl_im0 = ax0.imshow(self.im)
|
||||
z_min = 0
|
||||
bar_ticks = self.z_max // 5 + 1
|
||||
@ -121,32 +124,25 @@ class Printer:
|
||||
cax = divider.append_axes('right', size='3%', pad=0.05)
|
||||
|
||||
norm = matplotlib.colors.Normalize(vmin=z_min, vmax=self.z_max)
|
||||
sm = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm)
|
||||
sm.set_array([])
|
||||
plt.colorbar(sm, ticks=np.linspace(z_min, self.z_max, bar_ticks),
|
||||
scalar_mappable = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm)
|
||||
scalar_mappable.set_array([])
|
||||
plt.colorbar(scalar_mappable, ticks=np.linspace(z_min, self.z_max, bar_ticks),
|
||||
boundaries=np.arange(z_min - 0.05, self.z_max + 0.1, .1), cax=cax, label='Z [m]')
|
||||
|
||||
|
||||
axes.append(ax0)
|
||||
if len(axes) == 0:
|
||||
if not axes:
|
||||
axes.append(None)
|
||||
|
||||
if 'bird' in self.output_types:
|
||||
self.ext = ".bird.png" #TODO multiple savings external
|
||||
plt.figure(1)
|
||||
self.extensions.append(".bird.png")
|
||||
fig1, ax1 = plt.subplots(1, 1)
|
||||
fig1.set_tight_layout(True)
|
||||
figures.append(fig1)
|
||||
if any(xx in self.output_types for xx in ['bird', 'combined']):
|
||||
uv_max = [0., float(self.hh)]
|
||||
uv_max = [0., float(self.height)]
|
||||
xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
|
||||
x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk
|
||||
|
||||
# To avoid repetitions in the legend
|
||||
if self.legend:
|
||||
handles, labels = ax1.get_legend_handles_labels()
|
||||
by_label = OrderedDict(zip(labels, handles))
|
||||
ax1.legend(by_label.values(), by_label.keys())
|
||||
|
||||
# Adding field of view
|
||||
ax1.plot([0, x_max], [0, self.z_max], 'k--')
|
||||
ax1.plot([0, -x_max], [0, self.z_max], 'k--')
|
||||
@ -156,9 +152,9 @@ class Printer:
|
||||
axes.append(ax1)
|
||||
return figures, axes
|
||||
|
||||
def draw(self, figures, axes, dic_out, image, save=False):
|
||||
def draw(self, figures, axes, dic_out, image, save=False, show=False):
|
||||
|
||||
self._process_input(dic_out)
|
||||
self._process_results(dic_out)
|
||||
num = 0
|
||||
if any(xx in self.output_types for xx in ['front', 'combined']):
|
||||
self.mpl_im0.set_data(image)
|
||||
@ -203,17 +199,26 @@ class Printer:
|
||||
|
||||
axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], 'ro', label="Predicted", markersize=3)
|
||||
|
||||
# Setup the legend to avoid repetitions
|
||||
if self.legend:
|
||||
handles, labels = axes[1].get_legend_handles_labels()
|
||||
by_label = OrderedDict(zip(labels, handles))
|
||||
axes[1].legend(by_label.values(), by_label.keys())
|
||||
|
||||
# Plot the number
|
||||
(_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx], self.stds_ale_epi[idx])
|
||||
(_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx],
|
||||
self.stds_ale_epi[idx])
|
||||
|
||||
if self.text:
|
||||
axes[1].text(x_pos, z_pos, str(num), fontsize=self.FONTSIZE_BV, color='darkorange')
|
||||
num += 1
|
||||
for fig in figures:
|
||||
fig.canvas.draw()
|
||||
|
||||
if save:
|
||||
plt.savefig(self.path_out + self.ext, bbox_inches='tight')
|
||||
for idx, fig in enumerate(figures):
|
||||
fig.canvas.draw()
|
||||
if save:
|
||||
fig.savefig(self.path_out + self.extensions[idx], bbox_inches='tight')
|
||||
if show:
|
||||
fig.show()
|
||||
|
||||
|
||||
def get_confidence(xx, zz, std):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user