From 632e138cd177890f3669c095ced451cd26704055 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 11:41:47 +0200 Subject: [PATCH] add function for statistics --- src/eval/kitti_eval.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index 1069da5..0a58d54 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -56,11 +56,10 @@ class KittiEval: _, self.set_val = split_training(names_gt, path_train, path_val) def run(self): - - """Evaluate Monoloco methods on ALP and ALE metrics""" - cnt_gt = 0 + """Evaluate Monoloco performances on ALP and ALE metrics""" # Iterate over each ground truth file in the training set + cnt_gt = 0 for name in self.set_val: path_gt = os.path.join(self.dir_gt, name) path_m3d = os.path.join(self.dir_m3d, name) @@ -91,24 +90,11 @@ class KittiEval: self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom) - # Save statistics + # Update statistics of mean and max and uncertainty + for key in self.errors: for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters - self.dic_stats['test'][key]['mean'][clst] = \ - sum(self.errors[key][clst]) / float(len(self.errors[key][clst])) - self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst]) - self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst]) - - if key == 'our': - for clst in self.clusters[:-2]: - self.dic_stats['test'][key]['std_ale'][clst] = \ - sum(self.dic_stds['ale'][clst]) / float(len(self.dic_stds['ale'][clst])) - self.dic_stats['test'][key]['std_epi'][clst] = \ - sum(self.dic_stds['epi'][clst]) / float(len(self.dic_stds['epi'][clst])) - self.dic_stats['test'][key]['interval'][clst] = \ - sum(self.dic_stds['interval'][clst]) / float(len(self.dic_stds['interval'][clst])) - self.dic_stats['test'][key]['at_risk'][clst] = \ - sum(self.dic_stds['at_risk'][clst]) / float(len(self.dic_stds['at_risk'][clst])) + get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key) # Print statistics print(" Number of GT annotations: {} ".format(cnt_gt)) @@ -398,6 +384,20 @@ class KittiEval: # self.dic_stds['at_risk'][clst].append(0) # self.dic_stds['at_risk'][cat].append(0) +def get_statistics(dic_stats, errors, clst, dic_stds, key): + + """Update statistics of a cluster""" + dic_stats['mean'][clst] = sum(errors) / float(len(errors)) + dic_stats['max'][clst] = max(errors) + dic_stats['cnt'][clst] = len(errors) + + if key == 'our': + dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst])) + dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst])) + dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst])) + dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst])) + + def find_cluster(dd, clusters): """Find the correct cluster. The first and the last one are not numeric""" @@ -405,4 +405,5 @@ def find_cluster(dd, clusters): if dd <= int(clst): return clst - return clusters[-1] \ No newline at end of file + return clusters[-1] +