From 67bc78067785fe7e683aa973b836caed071b2be5 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 11:11:07 +0200 Subject: [PATCH] refactor __init__(2) --- src/eval/kitti_eval.py | 57 ++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index 3237078..0d2ea66 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -19,6 +19,8 @@ class KittiEval: - 3DOP - MonoDepth """ + dic_stds = defaultdict(lambda: defaultdict(list)) + dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))) def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3): @@ -54,15 +56,10 @@ class KittiEval: # Extract validation images for evaluation names_gt = tuple(os.listdir(self.dir_gt)) _, self.set_val = split_training(names_gt, path_train, path_val) - aa = 5 def run(self): """Evaluate Monoloco methods on ALP and ALE metrics""" - - self.dic_stds = defaultdict(lambda: defaultdict(list)) - dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))) - cnt_gt = 0 # Iterate over each ground truth file in the training set @@ -73,7 +70,7 @@ class KittiEval: path_3dop = os.path.join(self.dir_3dop, name) path_md = os.path.join(self.dir_md, name) boxes_gt = [] - truncs_gt = [] # Float from 0 to 1 + truncs_gt = [] # Float from 0 to 1 occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown dds_gt = [] dic_fin = defaultdict(list) @@ -81,7 +78,7 @@ class KittiEval: # Iterate over each line of the gt file and save box location and distances with open(path_gt, "r") as f_gt: for line_gt in f_gt: - if self.check_conditions(line_gt, mode='gt'): + if check_conditions(line_gt, mode='gt'): truncs_gt.append(float(line_gt.split()[1])) occs_gt.append(int(line_gt.split()[2])) boxes_gt.append([float(x) for x in line_gt.split()[4:8]]) @@ -112,19 +109,19 @@ class KittiEval: # Save statistics for key in self.errors: for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters - dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst])) - dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst]) - dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst]) + self.dic_stats['test'][key]['mean'][clst] = sum(self.errors[key][clst]) / float(len(self.errors[key][clst])) + self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst]) + self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst]) if key == 'our': for clst in self.clusters[:-2]: - dic_stats['test'][key]['std_ale'][clst] = \ + self.dic_stats['test'][key]['std_ale'][clst] = \ sum(self.dic_stds['ale'][clst]) / float(len(self.dic_stds['ale'][clst])) - dic_stats['test'][key]['std_epi'][clst] = \ + self.dic_stats['test'][key]['std_epi'][clst] = \ sum(self.dic_stds['epi'][clst]) / float(len(self.dic_stds['epi'][clst])) - dic_stats['test'][key]['interval'][clst] = \ + self.dic_stats['test'][key]['interval'][clst] = \ sum(self.dic_stds['interval'][clst]) / float(len(self.dic_stds['interval'][clst])) - dic_stats['test'][key]['at_risk'][clst] = \ + self.dic_stats['test'][key]['at_risk'][clst] = \ sum(self.dic_stds['at_risk'][clst]) / float(len(self.dic_stds['at_risk'][clst])) # Print statistics @@ -145,14 +142,14 @@ class KittiEval: for clst in self.clusters[:-9]: print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, " "for {} annotations" - .format(key, clst, dic_stats['test'][key]['mean'][clst], dic_stats['test'][key]['max'][clst], - dic_stats['test'][key]['cnt'][clst])) + .format(key, clst, self.dic_stats['test'][key]['mean'][clst], self.dic_stats['test'][key]['max'][clst], + self.dic_stats['test'][key]['cnt'][clst])) if key == 'our': print("% of annotation inside the confidence interval: {:.1f} %, " "of which {:.1f} % at higher risk" - .format(100 * dic_stats['test'][key]['interval'][clst], - 100 * dic_stats['test'][key]['at_risk'][clst])) + .format(100 * self.dic_stats['test'][key]['interval'][clst], + 100 * self.dic_stats['test'][key]['at_risk'][clst])) for perc in ['<0.5m', '<1m', '<2m']: print("{} Instances with error {}: {:.2f} %" @@ -162,7 +159,7 @@ class KittiEval: print("-"*100) # Print images - self.print_results(dic_stats, self.show) + print_results(self.dic_stats, self.show) def parse_txts(self, path, method): boxes = [] @@ -179,7 +176,7 @@ class KittiEval: try: with open(path, "r") as ff: for line in ff: - if self.check_conditions(line, thresh=self.dic_thresh_conf[method], mode=method): + if check_conditions(line, thresh=self.dic_thresh_conf[method], mode=method): boxes.append([float(x) for x in line.split()[4:8]]) loc = ([float(x) for x in line.split()[11:14]]) dds.append(math.sqrt(loc[0] ** 2 + loc[1] ** 2 + loc[2] ** 2)) @@ -235,7 +232,7 @@ class KittiEval: file_lines = ff.readlines() for line_our in file_lines[:-1]: line_list = [float(x) for x in line_our.split()] - if self.check_conditions(line_list, thresh=self.dic_thresh_conf[method], mode=method): + if check_conditions(line_list, thresh=self.dic_thresh_conf[method], mode=method): boxes.append(line_list[:4]) xyzs.append(line_list[4:7]) dds.append(line_list[7]) @@ -264,8 +261,8 @@ class KittiEval: for idx, box in enumerate(boxes): if len(boxes_gt) >= 1: dd = dds[idx] - idx_max, iou_max = self.get_idx_max(box, boxes_gt) - cat = self.get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) + idx_max, iou_max = get_idx_max(box, boxes_gt) + cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) # Update error if match is found if iou_max > self.dic_thresh_iou[method]: dd_gt = dds_gt[idx_max] @@ -295,8 +292,8 @@ class KittiEval: epi = stds_epi[idx] xyz = xyzs[idx] xy_kp = xy_kps[idx] - idx_max, iou_max = self.get_idx_max(box, boxes_gt) - cat = self.get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) + idx_max, iou_max = get_idx_max(box, boxes_gt) + cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) # Update error if match is found if iou_max > self.dic_thresh_iou['our']: @@ -340,12 +337,12 @@ class KittiEval: if len(boxes_gt) >= 1: dd_our = dds_our[idx] dd_geom = dds_geom[idx] - idx_max, iou_max = self.get_idx_max(box, boxes_gt) - cat = self.get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) + idx_max, iou_max = get_idx_max(box, boxes_gt) + cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max]) - idx_max_3dop, iou_max_3dop = self.get_idx_max(box, boxes_3dop) - idx_max_m3d, iou_max_m3d = self.get_idx_max(box, boxes_m3d) - idx_max_md, iou_max_md = self.get_idx_max(box, boxes_md) + idx_max_3dop, iou_max_3dop = get_idx_max(box, boxes_3dop) + idx_max_m3d, iou_max_m3d = get_idx_max(box, boxes_m3d) + idx_max_md, iou_max_md = get_idx_max(box, boxes_md) iou_min = min(iou_max_3dop, iou_max_m3d, iou_max_md)