"""Evaluate Monoloco code on KITTI dataset using ALE and ALP metrics"""
import os
import math
import logging
import datetime
from collections import defaultdict

from utils.iou import get_iou_matches
from utils.misc import get_task_error
from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth
from visuals.results import print_results


class KittiEval:
"""
Evaluate Monoloco code and compare it with the following baselines:
- Mono3D
- 3DOP
- MonoDepth
"""
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
CLUSTERS = ('easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50')
dic_stds = defaultdict(lambda: defaultdict(list))
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
dic_cnt = defaultdict(int)
errors = defaultdict(lambda: defaultdict(list))
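    # NOTE: the containers above are class-level attributes, shared by every instance
    # of KittiEval; a single evaluation instance per run is assumed.
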
def __init__(self, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
self.dir_gt = os.path.join('data', 'kitti', 'gt')
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
self.dir_md = os.path.join('data', 'kitti', 'monodepth')
self.dir_our = os.path.join('data', 'kitti', 'monoloco')
path_train = os.path.join('splits', 'kitti_train.txt')
path_val = os.path.join('splits', 'kitti_val.txt')
dir_logs = os.path.join('data', 'logs')
        assert os.path.exists(dir_logs), "No directory to save final statistics"
now = datetime.datetime.now()
now_time = now.strftime("%Y%m%d-%H%M")[2:]
self.path_results = os.path.join(dir_logs, 'eval-' + now_time + '.json')
        assert os.path.exists(self.dir_m3d) and os.path.exists(self.dir_our) \
            and os.path.exists(self.dir_3dop), "Baseline directories not found"
self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d,
'md': thresh_iou_our, 'our': thresh_iou_our}
self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
# Extract validation images for evaluation
names_gt = tuple(os.listdir(self.dir_gt))
_, self.set_val = split_training(names_gt, path_train, path_val)
def run(self):
"""Evaluate Monoloco performances on ALP and ALE metrics"""
        # Iterate over each ground-truth file in the validation set
cnt_gt = 0
for name in self.set_val:
path_gt = os.path.join(self.dir_gt, name)
path_m3d = os.path.join(self.dir_m3d, name)
path_our = os.path.join(self.dir_our, name)
path_3dop = os.path.join(self.dir_3dop, name)
path_md = os.path.join(self.dir_md, name)
# Iterate over each line of the gt file and save box location and distances
out_gt = parse_ground_truth(path_gt)
cnt_gt += len(out_gt[0])
# Extract annotations for the same file
if out_gt[0]:
out_m3d = self._parse_txts(path_m3d, method='m3d')
out_3dop = self._parse_txts(path_3dop, method='3dop')
out_md = self._parse_txts(path_md, method='md')
out_our = self._parse_txts(path_our, method='our')
# Compute the error with ground truth
self._estimate_error(out_gt, out_m3d, method='m3d')
self._estimate_error(out_gt, out_3dop, method='3dop')
self._estimate_error(out_gt, out_md, method='md')
self._estimate_error(out_gt, out_our, method='our')
# Iterate over all the files together to find a pool of common annotations
self._compare_error(out_gt, out_m3d, out_3dop, out_md, out_our)
# Update statistics of errors and uncertainty
for key in self.errors:
add_true_negatives(self.errors[key], cnt_gt)
            for clst in self.CLUSTERS[:-2]:  # M3D and PifPaf have no annotations above 40 meters
get_statistics(self.dic_stats['test'][key][clst], self.errors[key][clst], self.dic_stds[clst], key)
# Show statistics
print(" Number of GT annotations: {} ".format(cnt_gt))
for key in self.errors:
if key in ['our', 'm3d', '3dop']:
print(" Number of {} annotations with confidence >= {} : {} "
.format(key, self.dic_thresh_conf[key], self.dic_cnt[key]))
                for clst in self.CLUSTERS[:-9]:  # only the easy, moderate, hard and all categories
print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
"for {} annotations"
.format(key, clst, self.dic_stats['test'][key][clst]['mean'],
self.dic_stats['test'][key][clst]['max'],
self.dic_stats['test'][key][clst]['cnt']))
if key == 'our':
print("% of annotation inside the confidence interval: {:.1f} %, "
"of which {:.1f} % at higher risk"
.format(100 * self.dic_stats['test'][key][clst]['interval'],
100 * self.dic_stats['test'][key][clst]['at_risk']))
for perc in ['<0.5m', '<1m', '<2m']:
print("{} Instances with error {}: {:.2f} %"
.format(key, perc, 100 * sum(self.errors[key][perc])/len(self.errors[key][perc])))
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
print("-"*100)
print("\n Annotations inside the confidence interval: {:.1f} %"
.format(100 * self.dic_stats['test']['our']['all']['interval']))
print("precision 1: {:.2f}".format(self.dic_stats['test']['our']['all']['prec_1']))
print("precision 2: {:.2f}".format(self.dic_stats['test']['our']['all']['prec_2']))
def printer(self, show):
print_results(self.dic_stats, show)
def _parse_txts(self, path, method):
boxes = []
dds = []
stds_ale = []
stds_epi = []
dds_geom = []
# xyzs = []
# xy_kps = []
        # Iterate over each line of the txt file
        if method in ['3dop', 'm3d']:
            try:
                with open(path, "r") as ff:
                    for line in ff:
                        if check_conditions(line, thresh=self.dic_thresh_conf[method], mode=method):
                            fields = line.split()
                            boxes.append([float(x) for x in fields[4:8]])
                            loc = [float(x) for x in fields[11:14]]
                            dds.append(math.sqrt(loc[0] ** 2 + loc[1] ** 2 + loc[2] ** 2))
                            self.dic_cnt[method] += 1
                return boxes, dds
            except FileNotFoundError:
                return [], []
elif method == 'md':
try:
with open(path, "r") as ff:
for line in ff:
                        box = [float(x[:-1]) for x in line.split()[0:4]]  # drop the trailing separator of each field
                        # Enlarge the box by 10% of its height and width on each side
                        delta_h = (box[3] - box[1]) / 10
                        delta_w = (box[2] - box[0]) / 10
                        assert delta_h > 0 and delta_w > 0, "Bounding box dimensions must be positive"
                        box[0] -= delta_w
                        box[1] -= delta_h
                        box[2] += delta_w
                        box[3] += delta_h
boxes.append(box)
dds.append(float(line.split()[5][:-1]))
self.dic_cnt[method] += 1
return boxes, dds
except FileNotFoundError:
return [], []
else:
assert method == 'our', "method not recognized"
try:
with open(path, "r") as ff:
file_lines = ff.readlines()
                    for line_our in file_lines[:-1]:  # the last line contains the calibration values
line_list = [float(x) for x in line_our.split()]
if check_conditions(line_list, thresh=self.dic_thresh_conf[method], mode=method):
boxes.append(line_list[:4])
dds.append(line_list[8])
stds_ale.append(line_list[9])
stds_epi.append(line_list[10])
dds_geom.append(line_list[11])
self.dic_cnt[method] += 1
# kk_list = [float(x) for x in file_lines[-1].split()]
return boxes, dds, stds_ale, stds_epi, dds_geom
except FileNotFoundError:
return [], [], [], [], []
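
    # Layout of a line in a 'our' prediction file, as inferred from the indices used
    # above (an assumption, not a documented format): columns 0-3 hold the bounding box,
    # 8 the predicted distance, 9 the aleatoric std, 10 the epistemic std and
    # 11 the geometric-baseline distance.
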
def _estimate_error(self, out_gt, out, method):
"""Estimate localization error"""
boxes_gt, _, dds_gt, truncs_gt, occs_gt = out_gt
if method == 'our':
boxes, dds, stds_ale, stds_epi, dds_geom = out
else:
boxes, dds = out
matches = get_iou_matches(boxes, boxes_gt, self.dic_thresh_iou[method])
for (idx, idx_gt) in matches:
# Update error if match is found
cat = get_category(boxes_gt[idx_gt], truncs_gt[idx_gt], occs_gt[idx_gt])
self.update_errors(dds[idx], dds_gt[idx_gt], cat, self.errors[method])
if method == 'our':
self.update_errors(dds_geom[idx], dds_gt[idx_gt], cat, self.errors['geom'])
self.update_uncertainty(stds_ale[idx], stds_epi[idx], dds[idx], dds_gt[idx_gt], cat)
def _compare_error(self, out_gt, out_m3d, out_3dop, out_md, out_our):
"""Compare the error for a pool of instances commonly matched by all methods"""
# Extract outputs of each method
boxes_gt, _, dds_gt, truncs_gt, occs_gt = out_gt
boxes_m3d, dds_m3d = out_m3d
boxes_3dop, dds_3dop = out_3dop
boxes_md, dds_md = out_md
boxes_our, dds_our, _, _, dds_geom = out_our
# Find IoU matches
matches_our = get_iou_matches(boxes_our, boxes_gt, self.dic_thresh_iou['our'])
matches_m3d = get_iou_matches(boxes_m3d, boxes_gt, self.dic_thresh_iou['m3d'])
matches_3dop = get_iou_matches(boxes_3dop, boxes_gt, self.dic_thresh_iou['3dop'])
matches_md = get_iou_matches(boxes_md, boxes_gt, self.dic_thresh_iou['md'])
# Update error of commonly matched instances
for (idx, idx_gt) in matches_our:
check, indices = extract_indices(idx_gt, matches_m3d, matches_3dop, matches_md)
if check:
cat = get_category(boxes_gt[idx_gt], truncs_gt[idx_gt], occs_gt[idx_gt])
dd_gt = dds_gt[idx_gt]
self.update_errors(dds_our[idx], dd_gt, cat, self.errors['our_merged'])
self.update_errors(dds_geom[idx], dd_gt, cat, self.errors['geom_merged'])
self.update_errors(dds_m3d[indices[0]], dd_gt, cat, self.errors['m3d_merged'])
self.update_errors(dds_3dop[indices[1]], dd_gt, cat, self.errors['3dop_merged'])
self.update_errors(dds_md[indices[2]], dd_gt, cat, self.errors['md_merged'])
self.dic_cnt['merged'] += 1
def update_errors(self, dd, dd_gt, cat, errors):
"""Compute and save errors between a single box and the gt box which match"""
diff = abs(dd - dd_gt)
clst = find_cluster(dd_gt, self.CLUSTERS)
errors['all'].append(diff)
errors[cat].append(diff)
errors[clst].append(diff)
        # Flag whether the error is below the 0.5, 1 and 2 metre ALP thresholds
        errors['<0.5m'].append(1 if diff <= 0.5 else 0)
        errors['<1m'].append(1 if diff <= 1 else 0)
        errors['<2m'].append(1 if diff <= 2 else 0)
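
        # Worked example (illustrative): dd = 10.3 and dd_gt = 9.6 give diff = 0.7,
        # so '<1m' and '<2m' receive a 1, '<0.5m' receives a 0, and the error is also
        # appended to the '10' cluster since 6 < dd_gt <= 10.
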
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
clst = find_cluster(dd_gt, self.CLUSTERS)
self.dic_stds['all']['ale'].append(std_ale)
self.dic_stds[clst]['ale'].append(std_ale)
self.dic_stds[cat]['ale'].append(std_ale)
self.dic_stds['all']['epi'].append(std_epi)
self.dic_stds[clst]['epi'].append(std_epi)
self.dic_stds[cat]['epi'].append(std_epi)
# Number of annotations inside the confidence interval
std = std_epi if std_epi > 0 else std_ale # consider aleatoric uncertainty if epistemic is not calculated
if abs(dd - dd_gt) <= std:
self.dic_stds['all']['interval'].append(1)
self.dic_stds[clst]['interval'].append(1)
self.dic_stds[cat]['interval'].append(1)
else:
self.dic_stds['all']['interval'].append(0)
self.dic_stds[clst]['interval'].append(0)
self.dic_stds[cat]['interval'].append(0)
        # Annotations at risk: the ground truth is closer than the predicted distance
if dd_gt <= dd:
self.dic_stds['all']['at_risk'].append(1)
self.dic_stds[clst]['at_risk'].append(1)
self.dic_stds[cat]['at_risk'].append(1)
if abs(dd - dd_gt) <= std_epi:
self.dic_stds['all']['at_risk-interval'].append(1)
self.dic_stds[clst]['at_risk-interval'].append(1)
self.dic_stds[cat]['at_risk-interval'].append(1)
else:
self.dic_stds['all']['at_risk-interval'].append(0)
self.dic_stds[clst]['at_risk-interval'].append(0)
self.dic_stds[cat]['at_risk-interval'].append(0)
else:
self.dic_stds['all']['at_risk'].append(0)
self.dic_stds[clst]['at_risk'].append(0)
self.dic_stds[cat]['at_risk'].append(0)
# Precision of uncertainty
eps = 1e-4
task_error = get_task_error(dd)
prec_1 = abs(dd - dd_gt) / (std_epi + eps)
prec_2 = abs(std_epi - task_error)
self.dic_stds['all']['prec_1'].append(prec_1)
self.dic_stds[clst]['prec_1'].append(prec_1)
self.dic_stds[cat]['prec_1'].append(prec_1)
self.dic_stds['all']['prec_2'].append(prec_2)
self.dic_stds[clst]['prec_2'].append(prec_2)
self.dic_stds[cat]['prec_2'].append(prec_2)
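
        # Interpretation (illustrative): prec_1 expresses the localization error in units of the
        # predicted epistemic spread (values <= 1 roughly mean the gt falls inside the interval),
        # while prec_2 measures how far that spread is from the distance-dependent task error.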
def get_statistics(dic_stats, errors, dic_stds, key):
"""Update statistics of a cluster"""
dic_stats['mean'] = sum(errors) / float(len(errors))
dic_stats['max'] = max(errors)
dic_stats['cnt'] = len(errors)
if key == 'our':
dic_stats['std_ale'] = sum(dic_stds['ale']) / float(len(dic_stds['ale']))
dic_stats['std_epi'] = sum(dic_stds['epi']) / float(len(dic_stds['epi']))
dic_stats['interval'] = sum(dic_stds['interval']) / float(len(dic_stds['interval']))
dic_stats['at_risk'] = sum(dic_stds['at_risk']) / float(len(dic_stds['at_risk']))
dic_stats['prec_1'] = sum(dic_stds['prec_1']) / float(len(dic_stds['prec_1']))
dic_stats['prec_2'] = sum(dic_stds['prec_2']) / float(len(dic_stds['prec_2']))
def add_true_negatives(err, cnt_gt):
"""Update errors statistics of a specific method with missing detections"""
matched = len(err['all'])
missed = cnt_gt - matched
zeros = [0] * missed
err['<0.5m'].extend(zeros)
err['<1m'].extend(zeros)
err['<2m'].extend(zeros)
err['matched'] = 100 * matched / cnt_gt
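
# Example (illustrative): with cnt_gt = 100 and 80 matched detections, 20 zeros are
# appended to each ALP list, so the '<0.5m', '<1m' and '<2m' percentages are computed
# over all ground-truth instances rather than only the matched ones, and err['matched'] = 80.0.
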
def find_cluster(dd, clusters):
    """Find the cluster a ground-truth distance belongs to.
    The first four clusters (easy, moderate, hard, all) and the last one (>50) are not numeric"""
    for clst in clusters[4:-1]:
        if dd <= int(clst):
            return clst
    return clusters[-1]
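
# Examples (illustrative): find_cluster(8.3, KittiEval.CLUSTERS) returns '10', while
# find_cluster(55.0, KittiEval.CLUSTERS) falls past the numeric clusters and returns '>50'.
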
def extract_indices(idx_to_check, *args):
    """
    Check whether a given ground-truth index appears in every series of matches (idx_pred, idx_gt)
    and return the corresponding predicted index of each method
    idx_to_check --> gt index to check for correspondences in the other methods
    idx_method --> position of the method in args
    idx_gt --> gt index of a match of the method
    idx_pred --> index of the predicted box of the method
    indices --> list of predicted indices, one per method, corresponding to the gt index to check
    """
checks = [False]*len(args)
indices = []
for idx_method, method in enumerate(args):
for (idx_pred, idx_gt) in method:
if idx_gt == idx_to_check:
checks[idx_method] = True
indices.append(idx_pred)
return all(checks), indices
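
# Minimal usage sketch (an assumption about how the evaluation is driven, not part of the
# original entry point): it requires the ground-truth and prediction txt files to already
# be in place under data/kitti/.
if __name__ == '__main__':
    kitti_eval = KittiEval()
    kitti_eval.run()
    kitti_eval.printer(show=False)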