add internal methods
This commit is contained in:
parent
521a04ece5
commit
3a42536cef
@ -20,6 +20,8 @@ class KittiEval:
|
|||||||
"""
|
"""
|
||||||
dic_stds = defaultdict(lambda: defaultdict(list))
|
dic_stds = defaultdict(lambda: defaultdict(list))
|
||||||
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
||||||
|
dic_cnt = defaultdict(int)
|
||||||
|
errors = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
||||||
|
|
||||||
@ -49,9 +51,6 @@ class KittiEval:
|
|||||||
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our}
|
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our}
|
||||||
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
|
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
|
||||||
|
|
||||||
self.dic_cnt = defaultdict(int)
|
|
||||||
self.errors = defaultdict(lambda: defaultdict(list))
|
|
||||||
|
|
||||||
# Extract validation images for evaluation
|
# Extract validation images for evaluation
|
||||||
names_gt = tuple(os.listdir(self.dir_gt))
|
names_gt = tuple(os.listdir(self.dir_gt))
|
||||||
_, self.set_val = split_training(names_gt, path_train, path_val)
|
_, self.set_val = split_training(names_gt, path_train, path_val)
|
||||||
@ -82,21 +81,21 @@ class KittiEval:
|
|||||||
self._parse_txts(path_our, method='our')
|
self._parse_txts(path_our, method='our')
|
||||||
|
|
||||||
# Compute the error with ground truth
|
# Compute the error with ground truth
|
||||||
|
self._estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
|
||||||
self.estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
|
self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||||
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||||
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||||
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
|
||||||
|
|
||||||
# Iterate over all the files together to find a pool of common annotations
|
# Iterate over all the files together to find a pool of common annotations
|
||||||
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
||||||
|
|
||||||
# Save statistics
|
# Save statistics
|
||||||
for key in self.errors:
|
for key in self.errors:
|
||||||
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
|
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
|
||||||
self.dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
|
self.dic_stats['test'][key]['mean'][clst] = \
|
||||||
|
sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
|
||||||
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
|
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
|
||||||
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
|
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
|
||||||
|
|
||||||
@ -237,7 +236,7 @@ class KittiEval:
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return [], [], [], [], [], [], [], []
|
return [], [], [], [], [], [], [], []
|
||||||
|
|
||||||
def estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
|
def _estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
|
||||||
|
|
||||||
# Compute error (distance) and save it
|
# Compute error (distance) and save it
|
||||||
boxes_gt = copy.deepcopy(boxes_gt)
|
boxes_gt = copy.deepcopy(boxes_gt)
|
||||||
@ -262,7 +261,7 @@ class KittiEval:
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||||
|
|
||||||
# Compute error (distance) and save it
|
# Compute error (distance) and save it
|
||||||
@ -294,7 +293,7 @@ class KittiEval:
|
|||||||
truncs_gt.pop(idx_max)
|
truncs_gt.pop(idx_max)
|
||||||
occs_gt.pop(idx_max)
|
occs_gt.pop(idx_max)
|
||||||
|
|
||||||
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
||||||
|
|
||||||
boxes_gt = copy.deepcopy(boxes_gt)
|
boxes_gt = copy.deepcopy(boxes_gt)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user