add function for statistics

This commit is contained in:
lorenzo 2019-05-21 11:41:47 +02:00
parent 3a42536cef
commit 632e138cd1

View File

@ -56,11 +56,10 @@ class KittiEval:
_, self.set_val = split_training(names_gt, path_train, path_val)
def run(self):
"""Evaluate Monoloco methods on ALP and ALE metrics"""
cnt_gt = 0
"""Evaluate Monoloco performances on ALP and ALE metrics"""
# Iterate over each ground truth file in the training set
cnt_gt = 0
for name in self.set_val:
path_gt = os.path.join(self.dir_gt, name)
path_m3d = os.path.join(self.dir_m3d, name)
@ -91,24 +90,11 @@ class KittiEval:
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
# Save statistics
# Update statistics of mean and max and uncertainty
for key in self.errors:
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
self.dic_stats['test'][key]['mean'][clst] = \
sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
if key == 'our':
for clst in self.clusters[:-2]:
self.dic_stats['test'][key]['std_ale'][clst] = \
sum(self.dic_stds['ale'][clst]) / float(len(self.dic_stds['ale'][clst]))
self.dic_stats['test'][key]['std_epi'][clst] = \
sum(self.dic_stds['epi'][clst]) / float(len(self.dic_stds['epi'][clst]))
self.dic_stats['test'][key]['interval'][clst] = \
sum(self.dic_stds['interval'][clst]) / float(len(self.dic_stds['interval'][clst]))
self.dic_stats['test'][key]['at_risk'][clst] = \
sum(self.dic_stds['at_risk'][clst]) / float(len(self.dic_stds['at_risk'][clst]))
get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
# Print statistics
print(" Number of GT annotations: {} ".format(cnt_gt))
@ -398,6 +384,20 @@ class KittiEval:
# self.dic_stds['at_risk'][clst].append(0)
# self.dic_stds['at_risk'][cat].append(0)
def get_statistics(dic_stats, errors, clst, dic_stds, key):
"""Update statistics of a cluster"""
dic_stats['mean'][clst] = sum(errors) / float(len(errors))
dic_stats['max'][clst] = max(errors)
dic_stats['cnt'][clst] = len(errors)
if key == 'our':
dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst]))
dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst]))
dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst]))
dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst]))
def find_cluster(dd, clusters):
"""Find the correct cluster. The first and the last one are not numeric"""
@ -405,4 +405,5 @@ def find_cluster(dd, clusters):
if dd <= int(clst):
return clst
return clusters[-1]
return clusters[-1]