add internal methods

This commit is contained in:
lorenzo 2019-05-21 11:23:45 +02:00
parent 521a04ece5
commit 3a42536cef

View File

@ -20,6 +20,8 @@ class KittiEval:
"""
dic_stds = defaultdict(lambda: defaultdict(list))
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
dic_cnt = defaultdict(int)
errors = defaultdict(lambda: defaultdict(list))
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
@ -49,9 +51,6 @@ class KittiEval:
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our}
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
self.dic_cnt = defaultdict(int)
self.errors = defaultdict(lambda: defaultdict(list))
# Extract validation images for evaluation
names_gt = tuple(os.listdir(self.dir_gt))
_, self.set_val = split_training(names_gt, path_train, path_val)
@ -82,21 +81,21 @@ class KittiEval:
self._parse_txts(path_our, method='our')
# Compute the error with ground truth
self.estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
self._estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
# Iterate over all the files together to find a pool of common annotations
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
# Save statistics
for key in self.errors:
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
self.dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
self.dic_stats['test'][key]['mean'][clst] = \
sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
@ -237,7 +236,7 @@ class KittiEval:
except FileNotFoundError:
return [], [], [], [], [], [], [], []
def estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
def _estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
# Compute error (distance) and save it
boxes_gt = copy.deepcopy(boxes_gt)
@ -262,7 +261,7 @@ class KittiEval:
else:
break
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
# Compute error (distance) and save it
@ -294,7 +293,7 @@ class KittiEval:
truncs_gt.pop(idx_max)
occs_gt.pop(idx_max)
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
boxes_gt = copy.deepcopy(boxes_gt)