add class variables

This commit is contained in:
lorenzo 2019-05-21 11:44:50 +02:00
parent 632e138cd1
commit 45bb9df863

View File

@ -18,6 +18,7 @@ class KittiEval:
- 3DOP
- MonoDepth
"""
CLUSTERS = ('easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50')
dic_stds = defaultdict(lambda: defaultdict(list))
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
dic_cnt = defaultdict(int)
@ -46,8 +47,6 @@ class KittiEval:
assert os.path.exists(self.dir_m3d) and os.path.exists(self.dir_our) \
and os.path.exists(self.dir_3dop)
self.clusters = ['easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50']
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d, 'md': thresh_iou_our, 'our': thresh_iou_our}
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
@ -93,7 +92,7 @@ class KittiEval:
# Update statistics of mean and max and uncertainty
for key in self.errors:
for clst in self.clusters[:-2]: # M3d and pifpaf does not have annotations above 40 meters
for clst in self.CLUSTERS[:-2]: # M3d and pifpaf does not have annotations above 40 meters
get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
# Print statistics
@ -111,7 +110,7 @@ class KittiEval:
self.errors[key]['<1m'].extend(zeros)
self.errors[key]['<2m'].extend(zeros)
for clst in self.clusters[:-9]:
for clst in self.CLUSTERS[:-9]:
print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
"for {} annotations"
.format(key, clst, self.dic_stats['test'][key]['mean'][clst], self.dic_stats['test'][key]['max'][clst],
@ -326,7 +325,7 @@ class KittiEval:
"""Compute and save errors between a single box and the gt box which match"""
diff = abs(dd - dd_gt)
clst = find_cluster(dd_gt, self.clusters)
clst = find_cluster(dd_gt, self.CLUSTERS)
errors['all'].append(diff)
errors[cat].append(diff)
errors[clst].append(diff)
@ -349,7 +348,7 @@ class KittiEval:
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
clst = find_cluster(dd_gt, self.clusters)
clst = find_cluster(dd_gt, self.CLUSTERS)
self.dic_stds['ale']['all'].append(std_ale)
self.dic_stds['ale'][clst].append(std_ale)
self.dic_stds['ale'][cat].append(std_ale)