change dic_stats format
This commit is contained in:
parent
45bb9df863
commit
762163877b
@ -1,9 +1,9 @@
|
||||
"""Evaluate Monoloco code on KITTI dataset using ALE and ALP metrics"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import copy
|
||||
import datetime
|
||||
from utils.misc import get_idx_max
|
||||
@ -13,11 +13,13 @@ from visuals.results import print_results
|
||||
|
||||
class KittiEval:
|
||||
"""
|
||||
Evaluate Monoloco code on KITTI dataset and compare it with:
|
||||
Evaluate Monoloco code and compare it with the following baselines:
|
||||
- Mono3D
|
||||
- 3DOP
|
||||
- MonoDepth
|
||||
"""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
CLUSTERS = ('easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50')
|
||||
dic_stds = defaultdict(lambda: defaultdict(list))
|
||||
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
||||
@ -25,10 +27,6 @@ class KittiEval:
|
||||
errors = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.show = show
|
||||
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
||||
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
|
||||
@ -83,50 +81,43 @@ class KittiEval:
|
||||
self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||
self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||
self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||
|
||||
# Iterate over all the files together to find a pool of common annotations
|
||||
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
||||
|
||||
# Update statistics of mean and max and uncertainty
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
|
||||
|
||||
# Update statistics of errors and uncertainty
|
||||
for key in self.errors:
|
||||
add_true_negatives(self.errors[key], cnt_gt)
|
||||
for clst in self.CLUSTERS[:-2]: # M3d and pifpaf does not have annotations above 40 meters
|
||||
get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
|
||||
get_statistics(self.dic_stats['test'][key][clst], self.errors[key][clst], self.dic_stds[clst], key)
|
||||
|
||||
# Print statistics
|
||||
# Show statistics
|
||||
print(" Number of GT annotations: {} ".format(cnt_gt))
|
||||
for key in self.errors:
|
||||
if key in ['our', 'm3d', '3dop']:
|
||||
print(" Number of {} annotations with confidence >= {} : {} "
|
||||
.format(key, self.dic_thresh_conf[key], self.dic_cnt[key]))
|
||||
|
||||
# Include also missed annotations in the statistics
|
||||
matched = len(self.errors[key]['all'])
|
||||
missed = cnt_gt - matched
|
||||
zeros = [0] * missed
|
||||
self.errors[key]['<0.5m'].extend(zeros)
|
||||
self.errors[key]['<1m'].extend(zeros)
|
||||
self.errors[key]['<2m'].extend(zeros)
|
||||
|
||||
for clst in self.CLUSTERS[:-9]:
|
||||
print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
|
||||
"for {} annotations"
|
||||
.format(key, clst, self.dic_stats['test'][key]['mean'][clst], self.dic_stats['test'][key]['max'][clst],
|
||||
self.dic_stats['test'][key]['cnt'][clst]))
|
||||
.format(key, clst, self.dic_stats['test'][key][clst]['mean'],
|
||||
self.dic_stats['test'][key][clst]['max'],
|
||||
self.dic_stats['test'][key][clst]['cnt']))
|
||||
|
||||
if key == 'our':
|
||||
print("% of annotation inside the confidence interval: {:.1f} %, "
|
||||
"of which {:.1f} % at higher risk"
|
||||
.format(100 * self.dic_stats['test'][key]['interval'][clst],
|
||||
100 * self.dic_stats['test'][key]['at_risk'][clst]))
|
||||
.format(100 * self.dic_stats['test'][key][clst]['interval'],
|
||||
100 * self.dic_stats['test'][key][clst]['at_risk']))
|
||||
|
||||
for perc in ['<0.5m', '<1m', '<2m']:
|
||||
print("{} Instances with error {}: {:.2f} %"
|
||||
.format(key, perc, 100 * sum(self.errors[key][perc])/len(self.errors[key][perc])))
|
||||
|
||||
print("\n Number of matched annotations: {:.1f} %".format(100 * matched/cnt_gt))
|
||||
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
|
||||
print("-"*100)
|
||||
|
||||
# Print images
|
||||
@ -247,7 +238,7 @@ class KittiEval:
|
||||
break
|
||||
|
||||
def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||
|
||||
# Compute error (distance) and save it
|
||||
boxes_gt = copy.deepcopy(boxes_gt)
|
||||
@ -261,8 +252,6 @@ class KittiEval:
|
||||
dd_geom = dds_geom[idx]
|
||||
ale = stds_ale[idx]
|
||||
epi = stds_epi[idx]
|
||||
xyz = xyzs[idx]
|
||||
xy_kp = xy_kps[idx]
|
||||
idx_max, iou_max = get_idx_max(box, boxes_gt)
|
||||
cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
|
||||
|
||||
@ -300,7 +289,6 @@ class KittiEval:
|
||||
iou_min = min(iou_max_3dop, iou_max_m3d, iou_max_md)
|
||||
|
||||
if iou_max >= self.dic_thresh_iou['our'] and iou_min >= self.dic_thresh_iou['m3d']:
|
||||
|
||||
dd_gt = dds_gt[idx_max]
|
||||
dd_3dop = dds_3dop[idx_max_3dop]
|
||||
dd_m3d = dds_m3d[idx_max_m3d]
|
||||
@ -349,52 +337,59 @@ class KittiEval:
|
||||
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
|
||||
|
||||
clst = find_cluster(dd_gt, self.CLUSTERS)
|
||||
self.dic_stds['ale']['all'].append(std_ale)
|
||||
self.dic_stds['ale'][clst].append(std_ale)
|
||||
self.dic_stds['ale'][cat].append(std_ale)
|
||||
self.dic_stds['epi']['all'].append(std_epi)
|
||||
self.dic_stds['epi'][clst].append(std_epi)
|
||||
self.dic_stds['epi'][cat].append(std_epi)
|
||||
self.dic_stds['all']['ale'].append(std_ale)
|
||||
self.dic_stds[clst]['ale'].append(std_ale)
|
||||
self.dic_stds[cat]['ale'].append(std_ale)
|
||||
self.dic_stds['all']['epi'].append(std_epi)
|
||||
self.dic_stds[clst]['epi'].append(std_epi)
|
||||
self.dic_stds[cat]['epi'].append(std_epi)
|
||||
|
||||
# Number of annotations inside the confidence interval
|
||||
if dd_gt <= dd: # Particularly dangerous instances
|
||||
self.dic_stds['at_risk']['all'].append(1)
|
||||
self.dic_stds['at_risk'][clst].append(1)
|
||||
self.dic_stds['at_risk'][cat].append(1)
|
||||
self.dic_stds['all']['at_risk'].append(1)
|
||||
self.dic_stds[clst]['at_risk'].append(1)
|
||||
self.dic_stds[cat]['at_risk'].append(1)
|
||||
|
||||
if abs(dd - dd_gt) <= std_epi:
|
||||
self.dic_stds['interval']['all'].append(1)
|
||||
self.dic_stds['interval'][clst].append(1)
|
||||
self.dic_stds['interval'][cat].append(1)
|
||||
self.dic_stds['all']['interval'].append(1)
|
||||
self.dic_stds[clst]['interval'].append(1)
|
||||
self.dic_stds[cat]['interval'].append(1)
|
||||
|
||||
else:
|
||||
self.dic_stds['interval']['all'].append(0)
|
||||
self.dic_stds['interval'][clst].append(0)
|
||||
self.dic_stds['interval'][cat].append(0)
|
||||
|
||||
self.dic_stds['all']['interval'].append(0)
|
||||
self.dic_stds[clst]['interval'].append(0)
|
||||
self.dic_stds[cat]['interval'].append(0)
|
||||
|
||||
else:
|
||||
self.dic_stds['at_risk']['all'].append(0)
|
||||
self.dic_stds['at_risk'][clst].append(0)
|
||||
self.dic_stds['at_risk'][cat].append(0)
|
||||
self.dic_stds['all']['at_risk'].append(0)
|
||||
self.dic_stds[clst]['at_risk'].append(0)
|
||||
self.dic_stds[cat]['at_risk'].append(0)
|
||||
|
||||
|
||||
# self.dic_stds['at_risk']['all'].append(0)
|
||||
# self.dic_stds['at_risk'][clst].append(0)
|
||||
# self.dic_stds['at_risk'][cat].append(0)
|
||||
|
||||
def get_statistics(dic_stats, errors, clst, dic_stds, key):
|
||||
|
||||
def get_statistics(dic_stats, errors, dic_stds, key):
|
||||
"""Update statistics of a cluster"""
|
||||
dic_stats['mean'][clst] = sum(errors) / float(len(errors))
|
||||
dic_stats['max'][clst] = max(errors)
|
||||
dic_stats['cnt'][clst] = len(errors)
|
||||
|
||||
dic_stats['mean'] = sum(errors) / float(len(errors))
|
||||
dic_stats['max'] = max(errors)
|
||||
dic_stats['cnt'] = len(errors)
|
||||
|
||||
if key == 'our':
|
||||
dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst]))
|
||||
dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst]))
|
||||
dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst]))
|
||||
dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst]))
|
||||
dic_stats['std_ale'] = sum(dic_stds['ale']) / float(len(dic_stds['ale']))
|
||||
dic_stats['std_epi'] = sum(dic_stds['epi']) / float(len(dic_stds['epi']))
|
||||
dic_stats['interval'] = sum(dic_stds['interval']) / float(len(dic_stds['interval']))
|
||||
dic_stats['at_risk'] = sum(dic_stds['at_risk']) / float(len(dic_stds['at_risk']))
|
||||
|
||||
|
||||
def add_true_negatives(err, cnt_gt):
|
||||
"""Update errors statistics of a specific method with missing detections"""
|
||||
|
||||
matched = len(err['all'])
|
||||
missed = cnt_gt - matched
|
||||
zeros = [0] * missed
|
||||
err['<0.5m'].extend(zeros)
|
||||
err['<1m'].extend(zeros)
|
||||
err['<2m'].extend(zeros)
|
||||
err['matched'] = 100 * matched / cnt_gt
|
||||
|
||||
|
||||
def find_cluster(dd, clusters):
|
||||
|
||||
@ -21,7 +21,7 @@ def print_results(dic_stats, show=False, save=False):
|
||||
mm_std = 0.04
|
||||
mm_gender = 0.0556
|
||||
excl_clusters = ['all', '50', '>50', 'easy', 'moderate', 'hard']
|
||||
clusters = tuple([clst for clst in dic_stats[phase]['our']['mean'] if clst not in excl_clusters])
|
||||
clusters = tuple([clst for clst in dic_stats[phase]['our'] if clst not in excl_clusters])
|
||||
yy_gender = target_error(xx, mm_gender)
|
||||
yy_gps = np.linspace(5., 5., xx.shape[0])
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user