add function for statistics
This commit is contained in:
parent
3a42536cef
commit
632e138cd1
@ -56,11 +56,10 @@ class KittiEval:
|
||||
_, self.set_val = split_training(names_gt, path_train, path_val)
|
||||
|
||||
def run(self):
|
||||
|
||||
"""Evaluate Monoloco methods on ALP and ALE metrics"""
|
||||
cnt_gt = 0
|
||||
"""Evaluate Monoloco performances on ALP and ALE metrics"""
|
||||
|
||||
# Iterate over each ground truth file in the training set
|
||||
cnt_gt = 0
|
||||
for name in self.set_val:
|
||||
path_gt = os.path.join(self.dir_gt, name)
|
||||
path_m3d = os.path.join(self.dir_m3d, name)
|
||||
@ -91,24 +90,11 @@ class KittiEval:
|
||||
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
||||
|
||||
# Save statistics
|
||||
# Update statistics of mean and max and uncertainty
|
||||
|
||||
for key in self.errors:
|
||||
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
|
||||
self.dic_stats['test'][key]['mean'][clst] = \
|
||||
sum(self.errors[key][clst]) / float(len(self.errors[key][clst]))
|
||||
self.dic_stats['test'][key]['max'][clst] = max(self.errors[key][clst])
|
||||
self.dic_stats['test'][key]['cnt'][clst] = len(self.errors[key][clst])
|
||||
|
||||
if key == 'our':
|
||||
for clst in self.clusters[:-2]:
|
||||
self.dic_stats['test'][key]['std_ale'][clst] = \
|
||||
sum(self.dic_stds['ale'][clst]) / float(len(self.dic_stds['ale'][clst]))
|
||||
self.dic_stats['test'][key]['std_epi'][clst] = \
|
||||
sum(self.dic_stds['epi'][clst]) / float(len(self.dic_stds['epi'][clst]))
|
||||
self.dic_stats['test'][key]['interval'][clst] = \
|
||||
sum(self.dic_stds['interval'][clst]) / float(len(self.dic_stds['interval'][clst]))
|
||||
self.dic_stats['test'][key]['at_risk'][clst] = \
|
||||
sum(self.dic_stds['at_risk'][clst]) / float(len(self.dic_stds['at_risk'][clst]))
|
||||
get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
|
||||
|
||||
# Print statistics
|
||||
print(" Number of GT annotations: {} ".format(cnt_gt))
|
||||
@ -398,6 +384,20 @@ class KittiEval:
|
||||
# self.dic_stds['at_risk'][clst].append(0)
|
||||
# self.dic_stds['at_risk'][cat].append(0)
|
||||
|
||||
def get_statistics(dic_stats, errors, clst, dic_stds, key):
|
||||
|
||||
"""Update statistics of a cluster"""
|
||||
dic_stats['mean'][clst] = sum(errors) / float(len(errors))
|
||||
dic_stats['max'][clst] = max(errors)
|
||||
dic_stats['cnt'][clst] = len(errors)
|
||||
|
||||
if key == 'our':
|
||||
dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst]))
|
||||
dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst]))
|
||||
dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst]))
|
||||
dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst]))
|
||||
|
||||
|
||||
def find_cluster(dd, clusters):
|
||||
"""Find the correct cluster. The first and the last one are not numeric"""
|
||||
|
||||
@ -405,4 +405,5 @@ def find_cluster(dd, clusters):
|
||||
if dd <= int(clst):
|
||||
return clst
|
||||
|
||||
return clusters[-1]
|
||||
return clusters[-1]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user