refactor article figures
This commit is contained in:
parent
bbe040d4a2
commit
83a2fb8019
@ -15,7 +15,7 @@ from tabulate import tabulate
|
||||
|
||||
from ..utils import get_iou_matches, get_task_error, get_pixel_error, check_conditions, get_category, split_training, \
|
||||
parse_ground_truth
|
||||
from ..visuals import print_results
|
||||
from ..visuals import show_results, show_spread
|
||||
|
||||
|
||||
class EvalKitti:
|
||||
@ -125,11 +125,10 @@ class EvalKitti:
|
||||
print('\n' + category.upper() + ':')
|
||||
self.show_statistics()
|
||||
|
||||
# Show/save results
|
||||
self.printer(show=False)
|
||||
|
||||
def printer(self, show):
|
||||
print_results(self.dic_stats, show)
|
||||
def printer(self, show, save):
|
||||
if save or show:
|
||||
show_results(self.dic_stats, show, save)
|
||||
show_spread(self.dic_stats, show, save)
|
||||
|
||||
def _parse_txts(self, path, category, method):
|
||||
boxes = []
|
||||
|
||||
@ -86,6 +86,7 @@ def cli():
|
||||
eval_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256)
|
||||
eval_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
|
||||
eval_parser.add_argument('--show', help='whether to show statistic graphs', action='store_true')
|
||||
eval_parser.add_argument('--save', help='whether to save statistic graphs', action='store_true')
|
||||
eval_parser.add_argument('--verbose', help='verbosity of statistics', action='store_true')
|
||||
eval_parser.add_argument('--stereo', help='include stereo baseline results', action='store_true')
|
||||
|
||||
@ -146,7 +147,7 @@ def main():
|
||||
from .eval import EvalKitti
|
||||
kitti_eval = EvalKitti(verbose=args.verbose, stereo=args.stereo)
|
||||
kitti_eval.run()
|
||||
kitti_eval.printer(show=args.show)
|
||||
kitti_eval.printer(show=args.show, save=args.save)
|
||||
|
||||
if 'nuscenes' in args.dataset:
|
||||
from .train import Trainer
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
|
||||
from .printer import Printer
|
||||
from .results import print_results
|
||||
from .figures import show_results, show_spread
|
||||
|
||||
@ -12,7 +12,7 @@ from matplotlib.patches import Ellipse
|
||||
from ..utils import get_task_error
|
||||
|
||||
|
||||
def show_results(dic_stats, show=False):
|
||||
def show_results(dic_stats, show=False, save=False):
|
||||
|
||||
"""
|
||||
Visualize error as function of the distance and compare it with target errors based on human height analyses
|
||||
@ -57,7 +57,7 @@ def show_results(dic_stats, show=False):
|
||||
plt.close()
|
||||
|
||||
|
||||
def show_spread(dic_stats, show=False):
|
||||
def show_spread(dic_stats, show=False, save=False):
|
||||
"""Predicted confidence intervals and task error as a function of ground-truth distance"""
|
||||
|
||||
phase = 'test'
|
||||
@ -103,35 +103,12 @@ def show_spread(dic_stats, show=False):
|
||||
plt.legend()
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
if save:
|
||||
plt.savefig(os.path.join(dir_out, 'spread_bi.png'))
|
||||
plt.close()
|
||||
|
||||
|
||||
def show_method():
|
||||
""" method figure"""
|
||||
std_1 = 0.75
|
||||
fig = plt.figure(1)
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
|
||||
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
|
||||
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
|
||||
linestyle='dashed', linewidth=2.5)
|
||||
ax.add_patch(ell_4)
|
||||
ax.add_patch(ell_3)
|
||||
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
|
||||
plt.plot([0, 3], [0, 4], 'k--')
|
||||
plt.plot([0, -3], [0, 4], 'k--')
|
||||
plt.xlim(-3, 3)
|
||||
plt.ylim(0, 3.5)
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
plt.xlabel('X [m]')
|
||||
plt.ylabel('Z [m]')
|
||||
plt.savefig(os.path.join('docs', 'output_method.png'))
|
||||
|
||||
|
||||
def show_task_error():
|
||||
def show_task_error(show, save):
|
||||
"""Task error figure"""
|
||||
plt.figure(2)
|
||||
xx = np.linspace(0, 40, 100)
|
||||
@ -159,7 +136,33 @@ def show_task_error():
|
||||
plt.xlabel("Ground-truth distance from the camera $d_{gt}$ [m]")
|
||||
plt.ylabel("Localization error $\hat{e}$ due to human height variation [m]")
|
||||
plt.legend(loc=(0.01, 0.55)) # Location from 0 to 1 from lower left
|
||||
plt.savefig(os.path.join('docs', 'task_error.png'))
|
||||
if show:
|
||||
plt.show()
|
||||
if save:
|
||||
plt.savefig(os.path.join('docs', 'task_error.png'))
|
||||
|
||||
|
||||
def show_method():
|
||||
""" method figure"""
|
||||
std_1 = 0.75
|
||||
fig = plt.figure(1)
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
|
||||
ell_3 = Ellipse((0, 2), width=std_1 * 2, height=0.3, angle=-90, color='b', fill=False, linewidth=2.5)
|
||||
ell_4 = Ellipse((0, 2), width=std_1 * 3, height=0.3, angle=-90, color='r', fill=False,
|
||||
linestyle='dashed', linewidth=2.5)
|
||||
ax.add_patch(ell_4)
|
||||
ax.add_patch(ell_3)
|
||||
plt.plot(0, 2, marker='o', color='skyblue', markersize=9)
|
||||
plt.plot([0, 3], [0, 4], 'k--')
|
||||
plt.plot([0, -3], [0, 4], 'k--')
|
||||
plt.xlim(-3, 3)
|
||||
plt.ylim(0, 3.5)
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
plt.xlabel('X [m]')
|
||||
plt.ylabel('Z [m]')
|
||||
plt.savefig(os.path.join('docs', 'output_method.png'))
|
||||
|
||||
|
||||
def target_error(xx, mm):
|
||||
@ -218,7 +221,6 @@ def get_confidence_points(confidences, distances, errors):
|
||||
return distance_points, confidence_points
|
||||
|
||||
|
||||
|
||||
def height_distributions():
|
||||
|
||||
mu_men = 178
|
||||
|
||||
Loading…
Reference in New Issue
Block a user