add internal methods
This commit is contained in:
parent
521a04ece5
commit
3a42536cef
@ -20,6 +20,8 @@ class KittiEval:
|
||||
"""
|
||||
dic_stds = defaultdict(lambda: defaultdict(list))
|
||||
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
||||
dic_cnt = defaultdict(int)
|
||||
errors = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
||||
|
||||
@ -49,9 +51,6 @@ class KittiEval:
|
||||
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our}
|
||||
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
|
||||
|
||||
self.dic_cnt = defaultdict(int)
|
||||
self.errors = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Extract validation images for evaluation
|
||||
names_gt = tuple(os.listdir(self.dir_gt))
|
||||
_, self.set_val = split_training(names_gt, path_train, path_val)
|
||||
@ -82,21 +81,21 @@ class KittiEval:
|
||||
self._parse_txts(path_our, method='our')
|
||||
|
||||
# Compute the error with ground truth
|
||||
|
||||
self.estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
|
||||
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||
self._estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
|
||||
self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||
self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||
self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||
|
||||
# Iterate over all the files together to find a pool of common annotations
|
||||
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
||||
|
||||
# Save statistics
|
||||
for key in self.errors:
|
||||
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
|
||||
self.dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
|
||||
self.dic_stats['test'][key]['mean'][clst] = \
|
||||
sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
|
||||
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
|
||||
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
|
||||
|
||||
@ -237,7 +236,7 @@ class KittiEval:
|
||||
except FileNotFoundError:
|
||||
return [], [], [], [], [], [], [], []
|
||||
|
||||
def estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
|
||||
def _estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
|
||||
|
||||
# Compute error (distance) and save it
|
||||
boxes_gt = copy.deepcopy(boxes_gt)
|
||||
@ -262,7 +261,7 @@ class KittiEval:
|
||||
else:
|
||||
break
|
||||
|
||||
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||
|
||||
# Compute error (distance) and save it
|
||||
@ -294,7 +293,7 @@ class KittiEval:
|
||||
truncs_gt.pop(idx_max)
|
||||
occs_gt.pop(idx_max)
|
||||
|
||||
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
||||
|
||||
boxes_gt = copy.deepcopy(boxes_gt)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user