"""
Class for drawing frontal, bird-eye-view and combined figures
"""
# pylint: disable=attribute-defined-outside-init

import math
from collections import OrderedDict

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import Ellipse, Circle, Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ..utils import pixel_to_camera, get_task_error


class Printer:
    """
    Print results on images: birds eye view and computed distance
    """
    FONTSIZE_BV = 16
    FONTSIZE = 18
    TEXTCOLOR = 'darkorange'
    COLOR_KPS = 'yellow'

    def __init__(self, image, output_path, kk, output_types, epistemic=False, z_max=30, fig_width=10):
        """
        Args:
            image: input image; only ``.size`` and ``.resize()`` are used here (presumably a PIL Image)
            output_path: path prefix; the figure extension (e.g. '.front.png') is appended when saving
            kk: camera intrinsic matrix, forwarded to ``pixel_to_camera``
            output_types: collection containing any of 'front', 'bird', 'combined'
            epistemic: if True, ``draw_ellipses`` also draws the epistemic-uncertainty ellipse
            z_max: maximum depth [m] drawn in the figures
            fig_width: base figure width, in matplotlib ``figsize`` units
        """
        self.im = image
        self.kk = kk
        self.output_types = output_types
        self.epistemic = epistemic
        self.z_max = z_max  # To include ellipses in the image
        self.y_scale = 1
        self.width = self.im.size[0]
        self.height = self.im.size[1]
        self.fig_width = fig_width

        # Define the output dir
        self.output_path = output_path
        try:
            self.cmap = matplotlib.colormaps['jet']
        except AttributeError:
            # matplotlib < 3.5 has no colormap registry (cm.get_cmap itself was removed in 3.9)
            self.cmap = cm.get_cmap('jet')
        self.extensions = []

        # Define variables of the class to change for every image
        self.mpl_im0 = self.stds_ale = self.stds_epi = self.xx_gt = self.zz_gt = self.xx_pred = self.zz_pred = \
            self.dds_real = self.uv_centers = self.uv_shoulders = self.uv_kps = self.boxes = self.boxes_gt = \
            self.uv_camera = self.radius = self.auxs = self.gt = None

    def _process_results(self, dic_ann):
        """Copy the per-image annotations of ``dic_ann`` into the attributes read by the drawing methods.

        Depths not smaller than ``z_max - stds_epi[idx]`` are replaced by 0; the drawing loops skip
        instances whose predicted depth is not positive.
        """
        # Include the vectors inside the interval given by z_max
        self.stds_ale = dic_ann['stds_ale']
        self.stds_epi = dic_ann['stds_epi']
        self.gt = dic_ann['gt']  # regulate ground-truth matching
        self.xx_gt = [xx[0] for xx in dic_ann['xyz_real']]
        self.xx_pred = [xx[0] for xx in dic_ann['xyz_pred']]

        # Do not print instances outside z_max
        self.zz_gt = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
                      for idx, xx in enumerate(dic_ann['xyz_real'])]
        self.zz_pred = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
                        for idx, xx in enumerate(dic_ann['xyz_pred'])]
        self.dds_real = dic_ann['dds_real']
        self.uv_shoulders = dic_ann['uv_shoulders']
        self.boxes = dic_ann['boxes']
        self.boxes_gt = dic_ann['boxes_gt']
        self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1])
        self.radius = 11 / 1600 * self.width
        # Always reassign so an earlier annotation's auxiliary values cannot leak into this one
        self.auxs = dic_ann['aux'] if dic_ann['aux'] else None

    def factory_axes(self):
        """Create axes for figures: front bird combined

        Returns:
            figures: the created matplotlib figures, in the same order as their entries in ``self.extensions``
            axes: [front_axis, bird_axis]; the front entry is None when neither 'front' nor 'combined'
                is requested, and the bird entry is absent when neither 'bird' nor 'combined' is requested
        """
        axes = []
        figures = []

        # Initialize combined figure, resizing it for aesthetic proportions
        if 'combined' in self.output_types:
            assert 'bird' not in self.output_types and 'front' not in self.output_types, \
                "combined figure cannot be print together with front or bird ones"

            self.y_scale = self.width / (self.height * 2)  # Defined proportion
            if self.y_scale < 0.95 or self.y_scale > 1.05:  # allows more variation without resizing
                self.im = self.im.resize((self.width, round(self.height * self.y_scale)))
            else:
                # The image is left untouched, so drawn pixel coordinates must not be rescaled either
                self.y_scale = 1
            self.width = self.im.size[0]
            self.height = self.im.size[1]
            fig_width = self.fig_width + 0.6 * self.fig_width
            fig_height = self.fig_width * self.height / self.width

            # Aspect ratio of the bird-eye-view axis and width of the image panel relative to it
            fig_ar_1 = 0.8
            width_ratio = 1.9
            self.extensions.append('.combined.png')

            fig, (ax0, ax1) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [width_ratio, 1]},
                                           figsize=(fig_width, fig_height))
            ax1.set_aspect(fig_ar_1)
            fig.set_tight_layout(True)
            fig.subplots_adjust(left=0.02, right=0.98, bottom=0, top=1, hspace=0, wspace=0.02)
            figures.append(fig)

        # Initialize front figure
        elif 'front' in self.output_types:
            width = self.fig_width
            height = self.fig_width * self.height / self.width
            self.extensions.append(".front.png")
            plt.figure(0)
            fig0, ax0 = plt.subplots(1, 1, figsize=(width, height))
            fig0.set_tight_layout(True)
            figures.append(fig0)

        # Create front figure axis
        if any(xx in self.output_types for xx in ['front', 'combined']):
            ax0 = self.set_axes(ax0, axis=0)

            divider = make_axes_locatable(ax0)
            cax = divider.append_axes('right', size='3%', pad=0.05)
            # np.linspace needs an integer sample count, even when z_max is given as a float
            bar_ticks = int(self.z_max // 5) + 1
            norm = matplotlib.colors.Normalize(vmin=0, vmax=self.z_max)
            scalar_mappable = plt.cm.ScalarMappable(cmap=self.cmap, norm=norm)
            scalar_mappable.set_array([])
            plt.colorbar(scalar_mappable, ticks=np.linspace(0, self.z_max, bar_ticks),
                         boundaries=np.arange(- 0.05, self.z_max + 0.1, .1), cax=cax, label='Z [m]')

            axes.append(ax0)
        if not axes:
            # Keep the bird-eye-view axis at index 1
            axes.append(None)

        # Initialize bird-eye-view figure
        if 'bird' in self.output_types:
            self.extensions.append(".bird.png")
            fig1, ax1 = plt.subplots(1, 1)
            fig1.set_tight_layout(True)
            figures.append(fig1)
        if any(xx in self.output_types for xx in ['bird', 'combined']):
            ax1 = self.set_axes(ax1, axis=1)  # Adding field of view
            axes.append(ax1)
        return figures, axes

    def draw(self, figures, axes, dic_out, image, show_all=False, draw_text=True, legend=True, draw_box=False,
             save=False, show=False):
        """Draw the annotations of ``dic_out`` on the axes from ``factory_axes``, then save and/or show the figures.

        Args:
            figures, axes: output of ``factory_axes``
            dic_out: per-image annotation dictionary (keys as read by ``_process_results``)
            image: image displayed in the front axis (ignored when there is no front axis)
            show_all: iterate over all predictions instead of only the first ``len(xyz_real)`` entries
            draw_text: number the instances in both views
            legend: add a de-duplicated legend to the bird-eye-view axis
            draw_box: draw predicted and ground-truth boxes in the front view
            save: save each figure to ``output_path`` plus its extension
            show: call ``fig.show()`` on each figure before it is closed
        """
        # Process the annotation dictionary of monoloco
        self._process_results(dic_out)

        # whether to include instances that don't match the ground-truth
        iterator = range(len(self.zz_pred)) if show_all else range(len(self.zz_gt))
        if not iterator:
            print("-" * 110 + '\n' + "! No instances detected, be sure to include file with ground-truth values or "
                  "use the command --show_all" + '\n' + "-" * 110)

        draw_front = any(xx in self.output_types for xx in ['front', 'combined'])
        draw_bird = any(xx in self.output_types for xx in ['bird', 'combined'])

        # Draw the front figure; the image artist only exists when set_axes(axis=0) has run
        if self.mpl_im0 is not None:
            self.mpl_im0.set_data(image)
        num = 0
        for idx in iterator:
            if draw_front and self.zz_pred[idx] > 0:
                color = self.cmap((self.zz_pred[idx] % self.z_max) / self.z_max)
                self.draw_circle(axes, self.uv_shoulders[idx], color)
                if draw_box:
                    self.draw_boxes(axes, idx, color)
                if draw_text:
                    self.draw_text_front(axes, self.uv_shoulders[idx], num)
                num += 1

        # Draw the bird figure
        num = 0
        for idx in iterator:
            if draw_bird and self.zz_pred[idx] > 0:
                # Draw ground truth and uncertainty
                self.draw_uncertainty(axes, idx)

                # Draw bird eye view text
                if draw_text:
                    self.draw_text_bird(axes, idx, num)
                num += 1

        # Add the legend
        if legend:
            draw_legend(axes)

        # Draw, save or/and show the figures
        for idx, fig in enumerate(figures):
            fig.canvas.draw()
            if save:
                fig.savefig(self.output_path + self.extensions[idx], bbox_inches='tight')
            if show:
                fig.show()
            plt.close(fig)

    def draw_uncertainty(self, axes, idx):
        """Draw instance ``idx`` in the bird-eye view.

        Each uncertainty is a segment of half-length std centred on the prediction and oriented along the
        direction from the origin to the predicted point. The ground truth is marked when ``self.gt[idx]``
        is truthy.
        """
        theta = math.atan2(self.zz_pred[idx], self.xx_pred[idx])
        dic_std = {'ale': self.stds_ale[idx], 'epi': self.stds_epi[idx]}
        dic_x, dic_y = {}, {}

        # Aleatoric and epistemic
        for key, std in dic_std.items():
            delta_x = std * math.cos(theta)
            delta_z = std * math.sin(theta)
            dic_x[key] = (self.xx_pred[idx] - delta_x, self.xx_pred[idx] + delta_x)
            dic_y[key] = (self.zz_pred[idx] - delta_z, self.zz_pred[idx] + delta_z)

        # MonoLoco
        if not self.auxs:
            axes[1].plot(dic_x['epi'], dic_y['epi'], color='coral', linewidth=2, label="Epistemic Uncertainty")
            axes[1].plot(dic_x['ale'], dic_y['ale'], color='deepskyblue', linewidth=4, label="Aleatoric Uncertainty")
            axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], color='cornflowerblue', label="Prediction",
                         markersize=6, marker='o')

        # MonStereo(stereo case)
        elif self.auxs[idx] > 0.5:
            # The red segment is fully covered by the blue one drawn next; presumably it is
            # plotted so that both labels always reach the legend
            axes[1].plot(dic_x['ale'], dic_y['ale'], color='r', linewidth=4, label="Prediction (mono)")
            axes[1].plot(dic_x['ale'], dic_y['ale'], color='deepskyblue', linewidth=4,
                         label="Prediction (stereo+mono)")

        # MonStereo (monocular case)
        else:
            axes[1].plot(dic_x['ale'], dic_y['ale'], color='deepskyblue', linewidth=4,
                         label="Prediction (stereo+mono)")
            axes[1].plot(dic_x['ale'], dic_y['ale'], color='r', linewidth=4, label="Prediction (mono)")

        if self.gt[idx]:
            axes[1].plot(self.xx_gt[idx], self.zz_gt[idx], color='k', label="Ground-truth",
                         markersize=8, marker='x')

    def draw_ellipses(self, axes, idx):
        """draw uncertainty ellipses"""
        target = get_task_error(self.dds_real[idx])
        angle_gt = get_angle(self.xx_gt[idx], self.zz_gt[idx])
        ellipse_real = Ellipse((self.xx_gt[idx], self.zz_gt[idx]), width=target * 2, height=1,
                               angle=angle_gt, color='lightgreen', fill=True, label="Task error")
        axes[1].add_patch(ellipse_real)
        if abs(self.zz_gt[idx] - self.zz_pred[idx]) > 0.001:
            axes[1].plot(self.xx_gt[idx], self.zz_gt[idx], 'kx', label="Ground truth", markersize=3)

        angle = get_angle(self.xx_pred[idx], self.zz_pred[idx])
        ellipse_ale = Ellipse((self.xx_pred[idx], self.zz_pred[idx]), width=self.stds_ale[idx] * 2,
                              height=1, angle=angle, color='b', fill=False, label="Aleatoric Uncertainty",
                              linewidth=1.3)
        ellipse_var = Ellipse((self.xx_pred[idx], self.zz_pred[idx]), width=self.stds_epi[idx] * 2,
                              height=1, angle=angle, color='r', fill=False, label="Uncertainty",
                              linewidth=1, linestyle='--')

        axes[1].add_patch(ellipse_ale)
        if self.epistemic:
            axes[1].add_patch(ellipse_var)

        axes[1].plot(self.xx_pred[idx], self.zz_pred[idx], 'ro', label="Predicted", markersize=3)

    def draw_boxes(self, axes, idx, color):
        """Draw the predicted box (in ``color``) and the ground-truth box (green) of instance ``idx``.

        Boxes are read as (x0, y0, x1, y1); vertical coordinates are multiplied by ``y_scale``.
        """
        ww_box = self.boxes[idx][2] - self.boxes[idx][0]
        hh_box = (self.boxes[idx][3] - self.boxes[idx][1]) * self.y_scale
        ww_box_gt = self.boxes_gt[idx][2] - self.boxes_gt[idx][0]
        hh_box_gt = (self.boxes_gt[idx][3] - self.boxes_gt[idx][1]) * self.y_scale

        rectangle = Rectangle((self.boxes[idx][0], self.boxes[idx][1] * self.y_scale),
                              width=ww_box, height=hh_box, fill=False, color=color, linewidth=3)
        rectangle_gt = Rectangle((self.boxes_gt[idx][0], self.boxes_gt[idx][1] * self.y_scale),
                                 width=ww_box_gt, height=hh_box_gt, fill=False, color='g', linewidth=2)
        axes[0].add_patch(rectangle_gt)
        axes[0].add_patch(rectangle)

    def draw_text_front(self, axes, uv, num):
        """Write the instance number ``num`` next to pixel ``uv`` in the front view."""
        axes[0].text(uv[0] + self.radius, uv[1] * self.y_scale - self.radius, str(num),
                     fontsize=self.FONTSIZE, color=self.TEXTCOLOR, weight='bold')

    def draw_text_bird(self, axes, idx, num):
        """Plot the number in the bird eye view map"""
        # Place the label at the far end of the epistemic segment (aleatoric if epistemic std is not positive)
        std = self.stds_epi[idx] if self.stds_epi[idx] > 0 else self.stds_ale[idx]
        theta = math.atan2(self.zz_pred[idx], self.xx_pred[idx])

        delta_x = std * math.cos(theta)
        delta_z = std * math.sin(theta)

        axes[1].text(self.xx_pred[idx] + delta_x, self.zz_pred[idx] + delta_z,
                     str(num), fontsize=self.FONTSIZE_BV, color='darkorange')

    def draw_circle(self, axes, uv, color):
        """Draw a filled circle of radius ``self.radius`` at pixel ``uv`` in the front view."""
        circle = Circle((uv[0], uv[1] * self.y_scale), radius=self.radius, color=color, fill=True)
        axes[0].add_patch(circle)

    def set_axes(self, ax, axis):
        """Configure ``ax`` as the front image axis (``axis=0``) or the bird-eye-view axis (``axis=1``).

        For ``axis=0`` the image is displayed and its artist stored in ``self.mpl_im0``; for ``axis=1``
        the field-of-view lines and limits up to ``z_max`` are drawn. Returns ``ax``.
        """
        assert axis in (0, 1)

        if axis == 0:
            ax.set_axis_off()
            ax.set_xlim(0, self.width)
            ax.set_ylim(self.height, 0)
            self.mpl_im0 = ax.imshow(self.im)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        else:
            # Lateral extent at depth z_max of the bottom-left pixel, used as the field of view
            uv_max = [0., float(self.height)]
            xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
            x_max = abs(xyz_max[0])  # shortcut to avoid oval circles in case of different kk
            corr = round(float(x_max / 3))
            ax.plot([0, x_max], [0, self.z_max], 'k--')
            ax.plot([0, -x_max], [0, self.z_max], 'k--')
            ax.set_xlim(-x_max + corr, x_max - corr)
            ax.set_ylim(0, self.z_max + 1)
            ax.set_xlabel("X [m]")

        return ax


def draw_legend(axes):
    """Add a legend to the bird-eye-view axis, keeping one entry per distinct label."""
    handles, labels = axes[1].get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    axes[1].legend(by_label.values(), by_label.keys(), loc='best')


def get_angle(xx, zz):
    """Obtain the points to plot the confidence of each annotation"""
    # Angle in degrees of the point (xx, zz) measured from the x axis
    theta = math.atan2(zz, xx)
    angle = theta * (180 / math.pi)

    return angle