change dic_stats format

This commit is contained in:
lorenzo 2019-05-21 12:12:57 +02:00
parent 45bb9df863
commit 762163877b
2 changed files with 57 additions and 62 deletions

View File

@ -1,9 +1,9 @@
"""Evaluate Monoloco code on KITTI dataset using ALE and ALP metrics"""
import os
import math
import logging
from collections import defaultdict
import json
import copy
import datetime
from utils.misc import get_idx_max
@ -13,11 +13,13 @@ from visuals.results import print_results
class KittiEval:
"""
Evaluate Monoloco code on KITTI dataset and compare it with:
Evaluate Monoloco code and compare it with the following baselines:
- Mono3D
- 3DOP
- MonoDepth
"""
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
CLUSTERS = ('easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50')
dic_stds = defaultdict(lambda: defaultdict(list))
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
@ -25,10 +27,6 @@ class KittiEval:
errors = defaultdict(lambda: defaultdict(list))
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
self.show = show
self.dir_gt = os.path.join('data', 'kitti', 'gt')
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
@ -83,50 +81,43 @@ class KittiEval:
self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
# Iterate over all the files together to find a pool of common annotations
self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
# Update statistics of mean and max and uncertainty
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)
# Update statistics of errors and uncertainty
for key in self.errors:
add_true_negatives(self.errors[key], cnt_gt)
for clst in self.CLUSTERS[:-2]: # M3d and pifpaf does not have annotations above 40 meters
get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
get_statistics(self.dic_stats['test'][key][clst], self.errors[key][clst], self.dic_stds[clst], key)
# Print statistics
# Show statistics
print(" Number of GT annotations: {} ".format(cnt_gt))
for key in self.errors:
if key in ['our', 'm3d', '3dop']:
print(" Number of {} annotations with confidence >= {} : {} "
.format(key, self.dic_thresh_conf[key], self.dic_cnt[key]))
# Include also missed annotations in the statistics
matched = len(self.errors[key]['all'])
missed = cnt_gt - matched
zeros = [0] * missed
self.errors[key]['<0.5m'].extend(zeros)
self.errors[key]['<1m'].extend(zeros)
self.errors[key]['<2m'].extend(zeros)
for clst in self.CLUSTERS[:-9]:
print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
"for {} annotations"
.format(key, clst, self.dic_stats['test'][key]['mean'][clst], self.dic_stats['test'][key]['max'][clst],
self.dic_stats['test'][key]['cnt'][clst]))
.format(key, clst, self.dic_stats['test'][key][clst]['mean'],
self.dic_stats['test'][key][clst]['max'],
self.dic_stats['test'][key][clst]['cnt']))
if key == 'our':
print("% of annotation inside the confidence interval: {:.1f} %, "
"of which {:.1f} % at higher risk"
.format(100 * self.dic_stats['test'][key]['interval'][clst],
100 * self.dic_stats['test'][key]['at_risk'][clst]))
.format(100 * self.dic_stats['test'][key][clst]['interval'],
100 * self.dic_stats['test'][key][clst]['at_risk']))
for perc in ['<0.5m', '<1m', '<2m']:
print("{} Instances with error {}: {:.2f} %"
.format(key, perc, 100 * sum(self.errors[key][perc])/len(self.errors[key][perc])))
print("\n Number of matched annotations: {:.1f} %".format(100 * matched/cnt_gt))
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
print("-"*100)
# Print images
@ -247,7 +238,7 @@ class KittiEval:
break
def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
# Compute error (distance) and save it
boxes_gt = copy.deepcopy(boxes_gt)
@ -261,8 +252,6 @@ class KittiEval:
dd_geom = dds_geom[idx]
ale = stds_ale[idx]
epi = stds_epi[idx]
xyz = xyzs[idx]
xy_kp = xy_kps[idx]
idx_max, iou_max = get_idx_max(box, boxes_gt)
cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
@ -300,7 +289,6 @@ class KittiEval:
iou_min = min(iou_max_3dop, iou_max_m3d, iou_max_md)
if iou_max >= self.dic_thresh_iou['our'] and iou_min >= self.dic_thresh_iou['m3d']:
dd_gt = dds_gt[idx_max]
dd_3dop = dds_3dop[idx_max_3dop]
dd_m3d = dds_m3d[idx_max_m3d]
@ -349,52 +337,59 @@ class KittiEval:
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
clst = find_cluster(dd_gt, self.CLUSTERS)
self.dic_stds['ale']['all'].append(std_ale)
self.dic_stds['ale'][clst].append(std_ale)
self.dic_stds['ale'][cat].append(std_ale)
self.dic_stds['epi']['all'].append(std_epi)
self.dic_stds['epi'][clst].append(std_epi)
self.dic_stds['epi'][cat].append(std_epi)
self.dic_stds['all']['ale'].append(std_ale)
self.dic_stds[clst]['ale'].append(std_ale)
self.dic_stds[cat]['ale'].append(std_ale)
self.dic_stds['all']['epi'].append(std_epi)
self.dic_stds[clst]['epi'].append(std_epi)
self.dic_stds[cat]['epi'].append(std_epi)
# Number of annotations inside the confidence interval
if dd_gt <= dd: # Particularly dangerous instances
self.dic_stds['at_risk']['all'].append(1)
self.dic_stds['at_risk'][clst].append(1)
self.dic_stds['at_risk'][cat].append(1)
self.dic_stds['all']['at_risk'].append(1)
self.dic_stds[clst]['at_risk'].append(1)
self.dic_stds[cat]['at_risk'].append(1)
if abs(dd - dd_gt) <= std_epi:
self.dic_stds['interval']['all'].append(1)
self.dic_stds['interval'][clst].append(1)
self.dic_stds['interval'][cat].append(1)
self.dic_stds['all']['interval'].append(1)
self.dic_stds[clst]['interval'].append(1)
self.dic_stds[cat]['interval'].append(1)
else:
self.dic_stds['interval']['all'].append(0)
self.dic_stds['interval'][clst].append(0)
self.dic_stds['interval'][cat].append(0)
self.dic_stds['all']['interval'].append(0)
self.dic_stds[clst]['interval'].append(0)
self.dic_stds[cat]['interval'].append(0)
else:
self.dic_stds['at_risk']['all'].append(0)
self.dic_stds['at_risk'][clst].append(0)
self.dic_stds['at_risk'][cat].append(0)
self.dic_stds['all']['at_risk'].append(0)
self.dic_stds[clst]['at_risk'].append(0)
self.dic_stds[cat]['at_risk'].append(0)
# self.dic_stds['at_risk']['all'].append(0)
# self.dic_stds['at_risk'][clst].append(0)
# self.dic_stds['at_risk'][cat].append(0)
def get_statistics(dic_stats, errors, clst, dic_stds, key):
def get_statistics(dic_stats, errors, dic_stds, key):
"""Update statistics of a cluster"""
dic_stats['mean'][clst] = sum(errors) / float(len(errors))
dic_stats['max'][clst] = max(errors)
dic_stats['cnt'][clst] = len(errors)
dic_stats['mean'] = sum(errors) / float(len(errors))
dic_stats['max'] = max(errors)
dic_stats['cnt'] = len(errors)
if key == 'our':
dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst]))
dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst]))
dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst]))
dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst]))
dic_stats['std_ale'] = sum(dic_stds['ale']) / float(len(dic_stds['ale']))
dic_stats['std_epi'] = sum(dic_stds['epi']) / float(len(dic_stds['epi']))
dic_stats['interval'] = sum(dic_stds['interval']) / float(len(dic_stds['interval']))
dic_stats['at_risk'] = sum(dic_stds['at_risk']) / float(len(dic_stds['at_risk']))
def add_true_negatives(err, cnt_gt):
"""Update errors statistics of a specific method with missing detections"""
matched = len(err['all'])
missed = cnt_gt - matched
zeros = [0] * missed
err['<0.5m'].extend(zeros)
err['<1m'].extend(zeros)
err['<2m'].extend(zeros)
err['matched'] = 100 * matched / cnt_gt
def find_cluster(dd, clusters):

View File

@ -21,7 +21,7 @@ def print_results(dic_stats, show=False, save=False):
mm_std = 0.04
mm_gender = 0.0556
excl_clusters = ['all', '50', '>50', 'easy', 'moderate', 'hard']
clusters = tuple([clst for clst in dic_stats[phase]['our']['mean'] if clst not in excl_clusters])
clusters = tuple([clst for clst in dic_stats[phase]['our'] if clst not in excl_clusters])
yy_gender = target_error(xx, mm_gender)
yy_gps = np.linspace(5., 5., xx.shape[0])