From 3c7d07d081f2d5a16930d1641ddaa42f0468a0ab Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 15:03:46 +0200 Subject: [PATCH] add print method --- src/eval/kitti_eval.py | 14 +++++++------- src/main.py | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index 9c4dc70..1475de1 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -26,8 +26,7 @@ class KittiEval: dic_cnt = defaultdict(int) errors = defaultdict(lambda: defaultdict(list)) - def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3): - self.show = show + def __init__(self, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3): self.dir_gt = os.path.join('data', 'kitti', 'gt') self.dir_m3d = os.path.join('data', 'kitti', 'm3d') self.dir_3dop = os.path.join('data', 'kitti', '3dop') @@ -70,7 +69,7 @@ class KittiEval: # Extract annotations for the same file if len(boxes_gt) > 0: - boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d') + boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d') boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop') boxes_md, dds_md = self._parse_txts(path_md, method='md') boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \ @@ -120,8 +119,8 @@ class KittiEval: print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched'])) print("-"*100) - # Print images - print_results(self.dic_stats, self.show) + def print(self, show): + print_results(self.dic_stats, show) def _parse_txts(self, path, method): boxes = [] @@ -134,7 +133,7 @@ class KittiEval: xy_kps = [] # Iterate over each line of the txt file - if method == '3dop' or method == 'm3d': + if method in ['3dop', 'm3d']: try: with open(path, "r") as ff: for line in ff: @@ -268,7 +267,7 @@ class KittiEval: occs_gt.pop(idx_max) def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, - boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom): + boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom): boxes_gt = copy.deepcopy(boxes_gt) dds_gt = copy.deepcopy(dds_gt) @@ -401,3 +400,4 @@ def find_cluster(dd, clusters): return clusters[-1] + diff --git a/src/main.py b/src/main.py index 28127e2..caae2eb 100644 --- a/src/main.py +++ b/src/main.py @@ -143,8 +143,9 @@ def main(): run_kitti.run() if args.dataset == 'kitti': - kitti_eval = KittiEval(show=args.show) + kitti_eval = KittiEval() kitti_eval.run() + kitti_eval.print(show=args.show) if 'nuscenes' in args.dataset: training = Trainer(joints=args.joints)