refactor figures

This commit is contained in:
lorenzo 2019-05-21 14:25:35 +02:00
parent 762163877b
commit 496e147c2a

View File

@ -1,5 +1,6 @@
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
@ -13,19 +14,20 @@ def print_results(dic_stats, show=False, save=False):
Position error in meters due to a height variation of 7 cm (Standard deviation already knowing the sex)
Position error not knowing the gender (13cm as average difference --> 7.5cm of error to add)
"""
# ALE figure
dir_out = 'docs'
phase = 'test'
x_min = 0
x_max = 38
xx = np.linspace(0, 60, 100)
mm_std = 0.04
mm_gender = 0.0556
excl_clusters = ['all', '50', '>50', 'easy', 'moderate', 'hard']
clusters = tuple([clst for clst in dic_stats[phase]['our'] if clst not in excl_clusters])
yy_gender = target_error(xx, mm_gender)
yy_gps = np.linspace(5., 5., xx.shape[0])
# Precision on same instances
plt.figure(0)
fig_name = 'results.png'
plt.xlabel("Distance [meters]")
plt.ylabel("Average localization error [m]")
@ -39,8 +41,7 @@ def print_results(dic_stats, show=False, save=False):
plt.plot(xx, yy_gps, '-', label="GPS Error", color='y')
for idx, method in enumerate(['m3d_merged', 'geom_merged', 'md_merged', 'our_merged', '3dop_merged']):
dic_errs = dic_stats[phase][method]['mean']
errs = get_values(dic_errs, clusters)
errs = [dic_stats[phase][method][clst]['mean'] for clst in clusters]
xxs = get_distances(clusters)
plt.plot(xxs, errs, marker=mks[idx], markersize=mksizes[idx], linewidth=lws[idx], label=labels[idx],
@ -53,29 +54,19 @@ def print_results(dic_stats, show=False, save=False):
plt.savefig(os.path.join(dir_out, fig_name))
plt.close()
# FIGURE SPREAD
fig_name = 'spread.png'
# fig = plt.figure(3)
# ax = fig.add_subplot(1, 1, 1)
# SPREAD b Figure
plt.figure(1)
fig, ax = plt.subplots(2, sharex=True)
plt.xlabel("Distance [m]")
plt.ylabel("Aleatoric uncertainty [m]")
ar = 0.5 # Change aspect ratio of ellipses
scale = 1.5 # Factor to scale ellipses
rec_c = 0 # Center of the rectangle
# rec_h = 2.8 # Height of the rectangle
plots_line = True
# ax[0].set_ylim([-3, 3])
# ax[1].set_ylim([0, 3])
# ax[1].set_ylabel("Aleatoric uncertainty [m]")
# ax[0].set_ylabel("Confidence intervals")
dic_ale = dic_stats[phase]['our']['std_ale']
bbs = np.array(get_values(dic_ale, clusters))
bbs = np.array([dic_stats[phase]['our'][key]['std_ale'] for key in clusters])
xxs = get_distances(clusters)
yys = target_error(np.array(xxs), mm_gender)
# ale_het = tuple(bbs - yys)
# plt.plot(xxs, ale_het, marker='s', label=method)
ax[1].plot(xxs, bbs, marker='s', color='b', label="Spread b")
ax[1].plot(xxs, yys, '--', color='lightgreen', label="Task error", linewidth=2.5)
yys_up = [rec_c + ar/2 * scale * yy for yy in yys]
@ -94,7 +85,6 @@ def print_results(dic_stats, show=False, save=False):
bi = Ellipse((xx, rec_c), width=bbs[idx]*ar*scale, height=scale, angle=90, color='b',linewidth=1.8,
fill=False)
# ax[0].add_patch(rectangle)
ax[0].add_patch(te)
ax[0].add_patch(bi)
@ -107,10 +97,12 @@ def print_results(dic_stats, show=False, save=False):
def target_error(xx, mm):
"""Multiplication"""
return mm * xx
def get_distances(clusters):
"""Extract distances as intermediate values between 2 clusters"""
clusters_ext = list(clusters)
clusters_ext.insert(0, str(0))
@ -122,14 +114,6 @@ def get_distances(clusters):
return tuple(distances)
def get_values(dic_err, clusters):
errs = []
for key in clusters:
errs.append(dic_err[key])
return errs
def get_confidence_points(confidences, distances, errors):
confidence_points = []
@ -142,4 +126,4 @@ def get_confidence_points(confidences, distances, errors):
distance_points.append(dd)
distance_points.append(dd)
return distance_points, confidence_points
return distance_points, confidence_points