refactor article figures

This commit is contained in:
lorenzo 2019-07-26 08:32:18 +02:00
parent bbe040d4a2
commit 83a2fb8019
5 changed files with 39 additions and 37 deletions

View File

@ -15,7 +15,7 @@ from tabulate import tabulate
from ..utils import get_iou_matches, get_task_error, get_pixel_error, check_conditions, get_category, split_training, \ from ..utils import get_iou_matches, get_task_error, get_pixel_error, check_conditions, get_category, split_training, \
parse_ground_truth parse_ground_truth
from ..visuals import print_results from ..visuals import show_results, show_spread
class EvalKitti: class EvalKitti:
@ -125,11 +125,10 @@ class EvalKitti:
print('\n' + category.upper() + ':') print('\n' + category.upper() + ':')
self.show_statistics() self.show_statistics()
# Show/save results def printer(self, show, save):
self.printer(show=False) if save or show:
show_results(self.dic_stats, show, save)
def printer(self, show): show_spread(self.dic_stats, show, save)
print_results(self.dic_stats, show)
def _parse_txts(self, path, category, method): def _parse_txts(self, path, category, method):
boxes = [] boxes = []

View File

@ -86,6 +86,7 @@ def cli():
eval_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256) eval_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256)
eval_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3) eval_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
eval_parser.add_argument('--show', help='whether to show statistic graphs', action='store_true') eval_parser.add_argument('--show', help='whether to show statistic graphs', action='store_true')
eval_parser.add_argument('--save', help='whether to save statistic graphs', action='store_true')
eval_parser.add_argument('--verbose', help='verbosity of statistics', action='store_true') eval_parser.add_argument('--verbose', help='verbosity of statistics', action='store_true')
eval_parser.add_argument('--stereo', help='include stereo baseline results', action='store_true') eval_parser.add_argument('--stereo', help='include stereo baseline results', action='store_true')
@ -146,7 +147,7 @@ def main():
from .eval import EvalKitti from .eval import EvalKitti
kitti_eval = EvalKitti(verbose=args.verbose, stereo=args.stereo) kitti_eval = EvalKitti(verbose=args.verbose, stereo=args.stereo)
kitti_eval.run() kitti_eval.run()
kitti_eval.printer(show=args.show) kitti_eval.printer(show=args.show, save=args.save)
if 'nuscenes' in args.dataset: if 'nuscenes' in args.dataset:
from .train import Trainer from .train import Trainer

View File

@ -1,3 +1,3 @@
from .printer import Printer from .printer import Printer
from .results import print_results from .figures import show_results, show_spread

View File

@ -12,7 +12,7 @@ from matplotlib.patches import Ellipse
from ..utils import get_task_error from ..utils import get_task_error
def show_results(dic_stats, show=False): def show_results(dic_stats, show=False, save=False):
""" """
Visualize error as function of the distance and compare it with target errors based on human height analyses Visualize error as function of the distance and compare it with target errors based on human height analyses
@ -57,7 +57,7 @@ def show_results(dic_stats, show=False):
plt.close() plt.close()
def show_spread(dic_stats, show=False): def show_spread(dic_stats, show=False, save=False):
"""Predicted confidence intervals and task error as a function of ground-truth distance""" """Predicted confidence intervals and task error as a function of ground-truth distance"""
phase = 'test' phase = 'test'
@ -103,35 +103,12 @@ def show_spread(dic_stats, show=False):
plt.legend() plt.legend()
if show: if show:
plt.show() plt.show()
else: if save:
plt.savefig(os.path.join(dir_out, 'spread_bi.png')) plt.savefig(os.path.join(dir_out, 'spread_bi.png'))
plt.close() plt.close()
def show_method(): def show_task_error(show, save):
""" method figure"""
std_1 = 0.75
fig = plt.figure(1)
ax = fig.add_subplot(1, 1, 1)
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
linestyle='dashed', linewidth=2.5)
ax.add_patch(ell_4)
ax.add_patch(ell_3)
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
plt.plot([0, 3], [0, 4], 'k--')
plt.plot([0, -3], [0, 4], 'k--')
plt.xlim(-3, 3)
plt.ylim(0, 3.5)
plt.xticks([])
plt.yticks([])
plt.xlabel('X [m]')
plt.ylabel('Z [m]')
plt.savefig(os.path.join('docs', 'output_method.png'))
def show_task_error():
"""Task error figure""" """Task error figure"""
plt.figure(2) plt.figure(2)
xx = np.linspace(0, 40, 100) xx = np.linspace(0, 40, 100)
@ -159,7 +136,33 @@ def show_task_error():
plt.xlabel("Ground-truth distance from the camera $d_{gt}$ [m]") plt.xlabel("Ground-truth distance from the camera $d_{gt}$ [m]")
plt.ylabel("Localization error $\hat{e}$ due to human height variation [m]") plt.ylabel("Localization error $\hat{e}$ due to human height variation [m]")
plt.legend(loc=(0.01, 0.55)) # Location from 0 to 1 from lower left plt.legend(loc=(0.01, 0.55)) # Location from 0 to 1 from lower left
plt.savefig(os.path.join('docs', 'task_error.png')) if show:
plt.show()
if save:
plt.savefig(os.path.join('docs', 'task_error.png'))
def show_method():
""" method figure"""
std_1 = 0.75
fig = plt.figure(1)
ax = fig.add_subplot(1, 1, 1)
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
linestyle='dashed', linewidth=2.5)
ax.add_patch(ell_4)
ax.add_patch(ell_3)
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
plt.plot([0, 3], [0, 4], 'k--')
plt.plot([0, -3], [0, 4], 'k--')
plt.xlim(-3, 3)
plt.ylim(0, 3.5)
plt.xticks([])
plt.yticks([])
plt.xlabel('X [m]')
plt.ylabel('Z [m]')
plt.savefig(os.path.join('docs', 'output_method.png'))
def target_error(xx, mm): def target_error(xx, mm):
@ -218,7 +221,6 @@ def get_confidence_points(confidences, distances, errors):
return distance_points, confidence_points return distance_points, confidence_points
def height_distributions(): def height_distributions():
mu_men = 178 mu_men = 178