refactor article figures

This commit is contained in:
lorenzo 2019-07-26 08:32:18 +02:00
parent bbe040d4a2
commit 83a2fb8019
5 changed files with 39 additions and 37 deletions

View File

@ -15,7 +15,7 @@ from tabulate import tabulate
from ..utils import get_iou_matches, get_task_error, get_pixel_error, check_conditions, get_category, split_training, \
parse_ground_truth
from ..visuals import print_results
from ..visuals import show_results, show_spread
class EvalKitti:
@ -125,11 +125,10 @@ class EvalKitti:
print('\n' + category.upper() + ':')
self.show_statistics()
# Show/save results
self.printer(show=False)
def printer(self, show):
print_results(self.dic_stats, show)
def printer(self, show, save):
if save or show:
show_results(self.dic_stats, show, save)
show_spread(self.dic_stats, show, save)
def _parse_txts(self, path, category, method):
boxes = []

View File

@ -86,6 +86,7 @@ def cli():
eval_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256)
eval_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
eval_parser.add_argument('--show', help='whether to show statistic graphs', action='store_true')
eval_parser.add_argument('--save', help='whether to save statistic graphs', action='store_true')
eval_parser.add_argument('--verbose', help='verbosity of statistics', action='store_true')
eval_parser.add_argument('--stereo', help='include stereo baseline results', action='store_true')
@ -146,7 +147,7 @@ def main():
from .eval import EvalKitti
kitti_eval = EvalKitti(verbose=args.verbose, stereo=args.stereo)
kitti_eval.run()
kitti_eval.printer(show=args.show)
kitti_eval.printer(show=args.show, save=args.save)
if 'nuscenes' in args.dataset:
from .train import Trainer

View File

@ -1,3 +1,3 @@
from .printer import Printer
from .results import print_results
from .figures import show_results, show_spread

View File

@ -12,7 +12,7 @@ from matplotlib.patches import Ellipse
from ..utils import get_task_error
def show_results(dic_stats, show=False):
def show_results(dic_stats, show=False, save=False):
"""
Visualize error as function of the distance and compare it with target errors based on human height analyses
@ -57,7 +57,7 @@ def show_results(dic_stats, show=False):
plt.close()
def show_spread(dic_stats, show=False):
def show_spread(dic_stats, show=False, save=False):
"""Predicted confidence intervals and task error as a function of ground-truth distance"""
phase = 'test'
@ -103,35 +103,12 @@ def show_spread(dic_stats, show=False):
plt.legend()
if show:
plt.show()
else:
if save:
plt.savefig(os.path.join(dir_out, 'spread_bi.png'))
plt.close()
def show_method():
""" method figure"""
std_1 = 0.75
fig = plt.figure(1)
ax = fig.add_subplot(1, 1, 1)
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
linestyle='dashed', linewidth=2.5)
ax.add_patch(ell_4)
ax.add_patch(ell_3)
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
plt.plot([0, 3], [0, 4], 'k--')
plt.plot([0, -3], [0, 4], 'k--')
plt.xlim(-3, 3)
plt.ylim(0, 3.5)
plt.xticks([])
plt.yticks([])
plt.xlabel('X [m]')
plt.ylabel('Z [m]')
plt.savefig(os.path.join('docs', 'output_method.png'))
def show_task_error():
def show_task_error(show, save):
"""Task error figure"""
plt.figure(2)
xx = np.linspace(0, 40, 100)
@ -159,7 +136,33 @@ def show_task_error():
plt.xlabel("Ground-truth distance from the camera $d_{gt}$ [m]")
plt.ylabel("Localization error $\hat{e}$ due to human height variation [m]")
plt.legend(loc=(0.01, 0.55)) # Location from 0 to 1 from lower left
plt.savefig(os.path.join('docs', 'task_error.png'))
if show:
plt.show()
if save:
plt.savefig(os.path.join('docs', 'task_error.png'))
def show_method():
""" method figure"""
std_1 = 0.75
fig = plt.figure(1)
ax = fig.add_subplot(1, 1, 1)
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
linestyle='dashed', linewidth=2.5)
ax.add_patch(ell_4)
ax.add_patch(ell_3)
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
plt.plot([0, 3], [0, 4], 'k--')
plt.plot([0, -3], [0, 4], 'k--')
plt.xlim(-3, 3)
plt.ylim(0, 3.5)
plt.xticks([])
plt.yticks([])
plt.xlabel('X [m]')
plt.ylabel('Z [m]')
plt.savefig(os.path.join('docs', 'output_method.png'))
def target_error(xx, mm):
@ -218,7 +221,6 @@ def get_confidence_points(confidences, distances, errors):
return distance_points, confidence_points
def height_distributions():
mu_men = 178