change dic_stats format

lorenzo 2019-05-21 12:12:57 +02:00
parent 45bb9df863
commit 762163877b
2 changed files with 57 additions and 62 deletions

View File

@@ -1,9 +1,9 @@
+"""Evaluate Monoloco code on KITTI dataset using ALE and ALP metrics"""
 import os
 import math
 import logging
 from collections import defaultdict
-import json
 import copy
 import datetime

 from utils.misc import get_idx_max
@@ -13,11 +13,13 @@ from visuals.results import print_results

 class KittiEval:
     """
-    Evaluate Monoloco code on KITTI dataset and compare it with:
+    Evaluate Monoloco code and compare it with the following baselines:
     - Mono3D
     - 3DOP
     - MonoDepth
     """
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
     CLUSTERS = ('easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50')
     dic_stds = defaultdict(lambda: defaultdict(list))
     dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
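For reference, this commit flips the nesting order of the two dictionaries above: the cluster (or category) key now comes before the metric name, both in dic_stds and in the per-phase dic_stats. A minimal sketch of the resulting access pattern; the '10' cluster and the numeric values are made up for the example:

from collections import defaultdict

dic_stds = defaultdict(lambda: defaultdict(list))
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# New order (this commit): cluster/category first, metric second
dic_stds['10']['ale'].append(0.5)             # previously dic_stds['ale']['10']
dic_stats['test']['our']['10']['mean'] = 1.2  # previously dic_stats['test']['our']['mean']['10']

print(dic_stats['test']['our']['10']['mean'])  # 1.2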
@@ -25,10 +27,6 @@ class KittiEval:
     errors = defaultdict(lambda: defaultdict(list))

     def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
-        logging.basicConfig(level=logging.INFO)
-        self.logger = logging.getLogger(__name__)
-
         self.show = show
         self.dir_gt = os.path.join('data', 'kitti', 'gt')
         self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
@@ -83,50 +81,43 @@ class KittiEval:
             self._estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
             self._estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
             self._estimate_error_mloco(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
                                        boxes_gt, dds_gt, truncs_gt, occs_gt, name)

             # Iterate over all the files together to find a pool of common annotations
             self._compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
                                 boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)

-        # Update statistics of mean and max and uncertainty
+        # Update statistics of errors and uncertainty
         for key in self.errors:
+            add_true_negatives(self.errors[key], cnt_gt)
             for clst in self.CLUSTERS[:-2]:  # M3d and pifpaf does not have annotations above 40 meters
-                get_statistics(self.dic_stats['test'][key], self.errors[key][clst], clst, self.dic_stds, key)
+                get_statistics(self.dic_stats['test'][key][clst], self.errors[key][clst], self.dic_stds[clst], key)

-        # Print statistics
+        # Show statistics
         print(" Number of GT annotations: {} ".format(cnt_gt))
         for key in self.errors:
             if key in ['our', 'm3d', '3dop']:
                 print(" Number of {} annotations with confidence >= {} : {} "
                       .format(key, self.dic_thresh_conf[key], self.dic_cnt[key]))

-            # Include also missed annotations in the statistics
-            matched = len(self.errors[key]['all'])
-            missed = cnt_gt - matched
-            zeros = [0] * missed
-            self.errors[key]['<0.5m'].extend(zeros)
-            self.errors[key]['<1m'].extend(zeros)
-            self.errors[key]['<2m'].extend(zeros)
-
             for clst in self.CLUSTERS[:-9]:
                 print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
                       "for {} annotations"
-                      .format(key, clst, self.dic_stats['test'][key]['mean'][clst], self.dic_stats['test'][key]['max'][clst],
-                              self.dic_stats['test'][key]['cnt'][clst]))
+                      .format(key, clst, self.dic_stats['test'][key][clst]['mean'],
+                              self.dic_stats['test'][key][clst]['max'],
+                              self.dic_stats['test'][key][clst]['cnt']))

                 if key == 'our':
                     print("% of annotation inside the confidence interval: {:.1f} %, "
                           "of which {:.1f} % at higher risk"
-                          .format(100 * self.dic_stats['test'][key]['interval'][clst],
-                                  100 * self.dic_stats['test'][key]['at_risk'][clst]))
+                          .format(100 * self.dic_stats['test'][key][clst]['interval'],
+                                  100 * self.dic_stats['test'][key][clst]['at_risk']))

             for perc in ['<0.5m', '<1m', '<2m']:
                 print("{} Instances with error {}: {:.2f} %"
                       .format(key, perc, 100 * sum(self.errors[key][perc])/len(self.errors[key][perc])))

-            print("\n Number of matched annotations: {:.1f} %".format(100 * matched/cnt_gt))
+            print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
             print("-"*100)

         # Print images
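The matched/missed bookkeeping that used to live inline in this loop is now delegated to the add_true_negatives helper added further down in this file, and the matched percentage is stored under errors[key]['matched'] for the final print. A small usage sketch, assuming add_true_negatives from the new version of this file is in scope; the counts are illustrative:

from collections import defaultdict

errors_our = defaultdict(list)
errors_our['all'] = [0.7, 1.4]      # two matched annotations
errors_our['<0.5m'] = [0, 1]
errors_our['<1m'] = [1, 1]
errors_our['<2m'] = [1, 1]

add_true_negatives(errors_our, cnt_gt=4)  # 4 ground-truth annotations, so 2 were missed

print(errors_our['<1m'])      # [1, 1, 0, 0]: missed ground truths count as failures for every threshold
print(errors_our['matched'])  # 50.0: percentage of ground truths that were matched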
@@ -247,7 +238,7 @@ class KittiEval:
                     break

     def _estimate_error_mloco(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
                               boxes_gt, dds_gt, truncs_gt, occs_gt, name):

         # Compute error (distance) and save it
         boxes_gt = copy.deepcopy(boxes_gt)
@@ -261,8 +252,6 @@ class KittiEval:
             dd_geom = dds_geom[idx]
             ale = stds_ale[idx]
             epi = stds_epi[idx]
-            xyz = xyzs[idx]
-            xy_kp = xy_kps[idx]

             idx_max, iou_max = get_idx_max(box, boxes_gt)
             cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
@@ -300,7 +289,6 @@ class KittiEval:
             iou_min = min(iou_max_3dop, iou_max_m3d, iou_max_md)

             if iou_max >= self.dic_thresh_iou['our'] and iou_min >= self.dic_thresh_iou['m3d']:
-                dd_gt = dds_gt[idx_max]
                 dd_3dop = dds_3dop[idx_max_3dop]
                 dd_m3d = dds_m3d[idx_max_m3d]
@@ -349,52 +337,59 @@ class KittiEval:
     def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
         clst = find_cluster(dd_gt, self.CLUSTERS)
-        self.dic_stds['ale']['all'].append(std_ale)
-        self.dic_stds['ale'][clst].append(std_ale)
-        self.dic_stds['ale'][cat].append(std_ale)
-        self.dic_stds['epi']['all'].append(std_epi)
-        self.dic_stds['epi'][clst].append(std_epi)
-        self.dic_stds['epi'][cat].append(std_epi)
+        self.dic_stds['all']['ale'].append(std_ale)
+        self.dic_stds[clst]['ale'].append(std_ale)
+        self.dic_stds[cat]['ale'].append(std_ale)
+        self.dic_stds['all']['epi'].append(std_epi)
+        self.dic_stds[clst]['epi'].append(std_epi)
+        self.dic_stds[cat]['epi'].append(std_epi)

         # Number of annotations inside the confidence interval
         if dd_gt <= dd:  # Particularly dangerous instances
-            self.dic_stds['at_risk']['all'].append(1)
-            self.dic_stds['at_risk'][clst].append(1)
-            self.dic_stds['at_risk'][cat].append(1)
+            self.dic_stds['all']['at_risk'].append(1)
+            self.dic_stds[clst]['at_risk'].append(1)
+            self.dic_stds[cat]['at_risk'].append(1)

             if abs(dd - dd_gt) <= std_epi:
-                self.dic_stds['interval']['all'].append(1)
-                self.dic_stds['interval'][clst].append(1)
-                self.dic_stds['interval'][cat].append(1)
+                self.dic_stds['all']['interval'].append(1)
+                self.dic_stds[clst]['interval'].append(1)
+                self.dic_stds[cat]['interval'].append(1)
             else:
-                self.dic_stds['interval']['all'].append(0)
-                self.dic_stds['interval'][clst].append(0)
-                self.dic_stds['interval'][cat].append(0)
+                self.dic_stds['all']['interval'].append(0)
+                self.dic_stds[clst]['interval'].append(0)
+                self.dic_stds[cat]['interval'].append(0)

         else:
-            self.dic_stds['at_risk']['all'].append(0)
-            self.dic_stds['at_risk'][clst].append(0)
-            self.dic_stds['at_risk'][cat].append(0)
-            # self.dic_stds['at_risk']['all'].append(0)
-            # self.dic_stds['at_risk'][clst].append(0)
-            # self.dic_stds['at_risk'][cat].append(0)
+            self.dic_stds['all']['at_risk'].append(0)
+            self.dic_stds[clst]['at_risk'].append(0)
+            self.dic_stds[cat]['at_risk'].append(0)


-def get_statistics(dic_stats, errors, clst, dic_stds, key):
+def get_statistics(dic_stats, errors, dic_stds, key):
     """Update statistics of a cluster"""
-    dic_stats['mean'][clst] = sum(errors) / float(len(errors))
-    dic_stats['max'][clst] = max(errors)
-    dic_stats['cnt'][clst] = len(errors)
+    dic_stats['mean'] = sum(errors) / float(len(errors))
+    dic_stats['max'] = max(errors)
+    dic_stats['cnt'] = len(errors)

     if key == 'our':
-        dic_stats['std_ale'][clst] = sum(dic_stds['ale'][clst]) / float(len(dic_stds['ale'][clst]))
-        dic_stats['std_epi'][clst] = sum(dic_stds['epi'][clst]) / float(len(dic_stds['epi'][clst]))
-        dic_stats['interval'][clst] = sum(dic_stds['interval'][clst]) / float(len(dic_stds['interval'][clst]))
-        dic_stats['at_risk'][clst] = sum(dic_stds['at_risk'][clst]) / float(len(dic_stds['at_risk'][clst]))
+        dic_stats['std_ale'] = sum(dic_stds['ale']) / float(len(dic_stds['ale']))
+        dic_stats['std_epi'] = sum(dic_stds['epi']) / float(len(dic_stds['epi']))
+        dic_stats['interval'] = sum(dic_stds['interval']) / float(len(dic_stds['interval']))
+        dic_stats['at_risk'] = sum(dic_stds['at_risk']) / float(len(dic_stds['at_risk']))


+def add_true_negatives(err, cnt_gt):
+    """Update errors statistics of a specific method with missing detections"""
+    matched = len(err['all'])
+    missed = cnt_gt - matched
+    zeros = [0] * missed
+    err['<0.5m'].extend(zeros)
+    err['<1m'].extend(zeros)
+    err['<2m'].extend(zeros)
+    err['matched'] = 100 * matched / cnt_gt
+
+
 def find_cluster(dd, clusters):
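With the clst argument gone, get_statistics now writes into a flat per-cluster dictionary that the caller selects, e.g. self.dic_stats['test'][key][clst]. A minimal standalone sketch of a call for one cluster, assuming the get_statistics defined above is in scope; the numbers are made up:

from collections import defaultdict

errors_clst = [0.8, 1.5, 2.3]                  # distance errors for one cluster
stds_clst = {'ale': [0.4, 0.5], 'epi': [0.9, 1.1],
             'interval': [1, 0], 'at_risk': [1, 1]}
stats_clst = defaultdict(float)

get_statistics(stats_clst, errors_clst, stds_clst, key='our')

print(stats_clst['mean'], stats_clst['max'], stats_clst['cnt'])   # 1.533..., 2.3, 3
print(stats_clst['std_epi'], stats_clst['interval'])              # 1.0, 0.5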

View File

@@ -21,7 +21,7 @@ def print_results(dic_stats, show=False, save=False):
     mm_std = 0.04
     mm_gender = 0.0556
     excl_clusters = ['all', '50', '>50', 'easy', 'moderate', 'hard']
-    clusters = tuple([clst for clst in dic_stats[phase]['our']['mean'] if clst not in excl_clusters])
+    clusters = tuple([clst for clst in dic_stats[phase]['our'] if clst not in excl_clusters])
     yy_gender = target_error(xx, mm_gender)
     yy_gps = np.linspace(5., 5., xx.shape[0])
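With the new layout the cluster names are the direct keys of dic_stats[phase]['our'], which is why the comprehension above no longer reaches into ['mean']. A toy illustration with a hand-built dictionary; the values are placeholders:

dic_stats = {'test': {'our': {'10': {'mean': 1.2}, '20': {'mean': 2.1}, 'all': {'mean': 1.6}}}}
phase = 'test'
excl_clusters = ['all', '50', '>50', 'easy', 'moderate', 'hard']

clusters = tuple([clst for clst in dic_stats[phase]['our'] if clst not in excl_clusters])
print(clusters)  # ('10', '20')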