diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index 486fe34..1069da5 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -20,6 +20,8 @@ class KittiEval: """ dic_stds = defaultdict(lambda: defaultdict(list)) dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))) + dic_cnt = defaultdict(int) + errors = defaultdict(lambda: defaultdict(list)) def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3): @@ -49,9 +51,6 @@ class KittiEval: self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our} self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our} - self.dic_cnt = defaultdict(int) - self.errors = defaultdict(lambda: defaultdict(list)) - # Extract validation images for evaluation names_gt = tuple(os.listdir(self.dir_gt)) _, self.set_val = split_training(names_gt, path_train, path_val) @@ -82,21 +81,21 @@ class KittiEval: self._parse_txts(path_our, method='our') # Compute the error with ground truth - - self.estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d') - self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop') - self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md') - self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, - boxes_gt, dds_gt, truncs_gt, occs_gt, name) + self._estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d') + self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop') + self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md') + self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, + boxes_gt, dds_gt, truncs_gt, occs_gt, name) # Iterate over all the files together to find a pool of common annotations - self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, + self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom) # Save statistics for key in self.errors: for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters - self.dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst])) + self.dic_stats['test'][key]['mean'][clst] = \ + sum(self.errors[key][clst]) / float(len(self.errors[key][clst])) self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst]) self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst]) @@ -237,7 +236,7 @@ class KittiEval: except FileNotFoundError: return [], [], [], [], [], [], [], [] - def estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method): + def _estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method): # Compute error (distance) and save it boxes_gt = copy.deepcopy(boxes_gt) @@ -262,7 +261,7 @@ class KittiEval: else: break - def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, + def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, boxes_gt, dds_gt, truncs_gt, occs_gt, name): # Compute error (distance) and save it @@ -294,7 +293,7 @@ class KittiEval: truncs_gt.pop(idx_max) occs_gt.pop(idx_max) - def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, + def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom): boxes_gt = copy.deepcopy(boxes_gt)