import os
import math
import logging
from collections import defaultdict
import json
import copy
import datetime

from utils.misc import get_idx_max
from utils.kitti import check_conditions, get_category, split_training
from visuals.results import print_results


class KittiEval:
    """
    Evaluate Monoloco code on KITTI dataset and compare it with:
    - Mono3D
    - 3DOP
    - MonoDepth
    """

    def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
        """Set up input directories, matching thresholds and result accumulators.

        Args:
            show: forwarded to ``print_results`` at the end of :meth:`run`.
            thresh_iou_our: IoU a Monoloco or MonoDepth box needs with a ground-truth box to count as a match.
            thresh_iou_m3d: IoU a Mono3D or 3DOP box needs with a ground-truth box to count as a match;
                also the minimum cross-method IoU used by :meth:`compare_error`.
            thresh_conf_m3d: ``thresh`` passed to ``check_conditions`` for Mono3D and 3DOP lines.
            thresh_conf_our: ``thresh`` passed to ``check_conditions`` for Monoloco lines.

        Raises:
            AssertionError: (when assertions are enabled) if the Mono3D, Monoloco or 3DOP directory is missing.
            OSError: if the ground-truth directory cannot be listed.
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.show = show

        self.dir_gt = os.path.join('data', 'kitti', 'gt')
        self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
        self.dir_3dop = os.path.join('data', 'kitti', '3dop')
        self.dir_md = os.path.join('data', 'kitti', 'monodepth')
        self.dir_our = os.path.join('data', 'kitti', 'monoloco')
        path_train = os.path.join('splits', 'kitti_train.txt')
        path_val = os.path.join('splits', 'kitti_val.txt')

        # Path for the final statistics file, stamped as yymmdd-HHMM (nothing writes to it at the moment)
        dir_logs = os.path.join('data', 'logs')
        now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:]
        self.path_results = os.path.join(dir_logs, 'eval-' + now_time + '.json')

        assert os.path.exists(self.dir_m3d) and os.path.exists(self.dir_our) \
            and os.path.exists(self.dir_3dop), "Missing Mono3D, Monoloco or 3DOP prediction directory"

        # First four entries are categories; the numeric ones are upper distance bounds (see find_cluster)
        self.clusters = ['easy', 'moderate', 'hard', 'all', '6', '10', '15', '20', '25', '30', '40', '50', '>50']
        self.dic_thresh_iou = {'m3d': thresh_iou_m3d, '3dop': thresh_iou_m3d,
                               'md': thresh_iou_our, 'our': thresh_iou_our}
        self.dic_thresh_conf = {'m3d': thresh_conf_m3d, '3dop': thresh_conf_m3d, 'our': thresh_conf_our}
        self.dic_cnt = defaultdict(int)
        self.errors = defaultdict(lambda: defaultdict(list))
        # Created here too so that the estimate_* methods also work when called outside run()
        self.dic_stds = defaultdict(lambda: defaultdict(list))

        # Extract validation images for evaluation
        names_gt = tuple(os.listdir(self.dir_gt))
        _, self.set_val = split_training(names_gt, path_train, path_val)

    def run(self):
        """Evaluate Monoloco methods on ALP and ALE metrics.

        For every validation image the ground truth and the predictions of every method are parsed,
        distance errors are accumulated per method and again on the pool of objects matched by all
        methods, then the aggregated statistics are printed and passed to ``print_results``.
        """
        # Reset every accumulator: errors/dic_cnt used to be created only in __init__,
        # so a second call to run() double-counted all annotations
        self.dic_stds = defaultdict(lambda: defaultdict(list))
        self.errors = defaultdict(lambda: defaultdict(list))
        self.dic_cnt = defaultdict(int)
        cnt_gt = 0

        # Iterate over each ground truth file in the validation set
        for name in self.set_val:
            boxes_gt, truncs_gt, occs_gt, dds_gt = self._parse_gt(os.path.join(self.dir_gt, name))
            cnt_gt += len(boxes_gt)
            if not boxes_gt:
                continue

            # Extract annotations for the same file
            boxes_m3d, dds_m3d = self.parse_txts(os.path.join(self.dir_m3d, name), method='m3d')
            boxes_3dop, dds_3dop = self.parse_txts(os.path.join(self.dir_3dop, name), method='3dop')
            boxes_md, dds_md = self.parse_txts(os.path.join(self.dir_md, name), method='md')
            boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
                self.parse_txts(os.path.join(self.dir_our, name), method='our')

            # Compute the error with ground truth
            self.estimate_error_base(boxes_m3d, dds_m3d, boxes_gt, dds_gt, truncs_gt, occs_gt, method='m3d')
            self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
            self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
            dic_fin = defaultdict(list)
            self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
                                    boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name)

            # Iterate over all the files together to find a pool of common annotations
            self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
                               boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom)

        dic_stats = self._compute_stats()
        self._print_stats(dic_stats, cnt_gt)
        print_results(dic_stats, self.show)

    def _parse_gt(self, path_gt):
        """Read one ground-truth file and keep the lines accepted by ``check_conditions(mode='gt')``.

        Returns:
            Four parallel lists: boxes (columns 4-7), truncation (column 1, float from 0 to 1),
            occlusion (column 2; 0,1,2,3 = fully visible, partly occluded, largely occluded, unknown)
            and Euclidean norm of the location (columns 11-13).
        """
        boxes_gt, truncs_gt, occs_gt, dds_gt = [], [], [], []
        with open(path_gt, "r") as f_gt:
            for line_gt in f_gt:
                if not check_conditions(line_gt, mode='gt'):
                    continue
                fields = line_gt.split()
                truncs_gt.append(float(fields[1]))
                occs_gt.append(int(fields[2]))
                boxes_gt.append([float(x) for x in fields[4:8]])
                dds_gt.append(self._norm3([float(x) for x in fields[11:14]]))
        return boxes_gt, truncs_gt, occs_gt, dds_gt

    @staticmethod
    def _norm3(loc):
        """Euclidean norm of a 3-element location."""
        return math.sqrt(loc[0] ** 2 + loc[1] ** 2 + loc[2] ** 2)

    @staticmethod
    def _mean(values):
        """Arithmetic mean of ``values``, or NaN when ``values`` is empty (instead of ZeroDivisionError)."""
        return sum(values) / float(len(values)) if values else float('nan')

    def _compute_stats(self):
        """Aggregate the accumulated errors into the nested ``dic_stats['test'][method][stat][cluster]`` dict.

        Empty clusters produce NaN mean/max and a count of 0 (they used to raise).
        """
        dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
        # dic_stats name -> dic_stds name for the Monoloco uncertainty statistics
        uncertainty_stats = (('std_ale', 'ale'), ('std_epi', 'epi'), ('interval', 'interval'), ('at_risk', 'at_risk'))

        for key in self.errors:
            for clst in self.clusters[:-2]:  # M3d and pifpaf does not have annotations above 40 meters
                errs = self.errors[key][clst]
                dic_stats['test'][key]['mean'][clst] = self._mean(errs)
                dic_stats['test'][key]['max'][clst] = max(errs) if errs else float('nan')
                dic_stats['test'][key]['cnt'][clst] = len(errs)
                if key == 'our':
                    for stat_name, std_name in uncertainty_stats:
                        dic_stats['test'][key][stat_name][clst] = self._mean(self.dic_stds[std_name][clst])
        return dic_stats

    def _print_stats(self, dic_stats, cnt_gt):
        """Print the per-method summary.

        Side effect: pads ``self.errors[key]`` '<0.5m'/'<1m'/'<2m' with zeros for unmatched GT annotations,
        so the printed percentages are relative to all GT annotations.
        """
        print(" Number of GT annotations: {} ".format(cnt_gt))
        for key in self.errors:
            if key in ['our', 'm3d', '3dop']:
                print(" Number of {} annotations with confidence >= {} : {} "
                      .format(key, self.dic_thresh_conf[key], self.dic_cnt[key]))

            # Include also missed annotations in the statistics
            matched = len(self.errors[key]['all'])
            zeros = [0] * (cnt_gt - matched)
            for perc in ('<0.5m', '<1m', '<2m'):
                self.errors[key][perc].extend(zeros)

            for clst in self.clusters[:4]:  # 'easy', 'moderate', 'hard', 'all'
                print(" {} Average error in cluster {}: {:.2f} with a max error of {:.1f}, "
                      "for {} annotations"
                      .format(key, clst, dic_stats['test'][key]['mean'][clst], dic_stats['test'][key]['max'][clst],
                              dic_stats['test'][key]['cnt'][clst]))
                if key == 'our':
                    print("% of annotation inside the confidence interval: {:.1f} %, "
                          "of which {:.1f} % at higher risk"
                          .format(100 * dic_stats['test'][key]['interval'][clst],
                                  100 * dic_stats['test'][key]['at_risk'][clst]))

            for perc in ['<0.5m', '<1m', '<2m']:
                print("{} Instances with error {}: {:.2f} %"
                      .format(key, perc, 100 * self._mean(self.errors[key][perc])))

            matched_pct = 100 * matched / cnt_gt if cnt_gt else float('nan')
            print("\n Number of matched annotations: {:.1f} %".format(matched_pct))
            print("-" * 100)

    def parse_txts(self, path, method):
        """Parse the prediction file of one method for one image.

        Args:
            path: prediction file; a missing file yields empty results.
            method: one of 'm3d', '3dop', 'md', 'psm', 'our'.

        Returns:
            ``(boxes, dds)`` for 'm3d', '3dop', 'md' and 'psm';
            ``(boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps)`` for 'our'.

        Raises:
            ValueError: for an unknown ``method`` (previously this silently returned None).
        """
        n_outputs = {'m3d': 2, '3dop': 2, 'md': 2, 'psm': 2, 'our': 8}
        if method not in n_outputs:
            raise ValueError("Unknown method: {}".format(method))

        try:
            with open(path, "r") as ff:
                if method in ('3dop', 'm3d'):
                    return self._parse_kitti_lines(ff, method)
                if method in ('md', 'psm'):
                    return self._parse_box_lines(ff, method)
                return self._parse_our_lines(ff.readlines())
        except FileNotFoundError:
            return tuple([] for _ in range(n_outputs[method]))

    def _parse_kitti_lines(self, lines, method):
        """Parse KITTI-style prediction lines (Mono3D / 3DOP) accepted by ``check_conditions``."""
        boxes, dds = [], []
        thresh = self.dic_thresh_conf[method]
        for line in lines:
            if check_conditions(line, thresh=thresh, mode=method):
                fields = line.split()
                boxes.append([float(x) for x in fields[4:8]])
                dds.append(self._norm3([float(x) for x in fields[11:14]]))
                self.dic_cnt[method] += 1
        return boxes, dds

    def _parse_box_lines(self, lines, method):
        """Parse MonoDepth ('md') or 'psm' lines: 4 box coordinates and a distance in token 5.

        No confidence filtering is applied. Each numeric token has its last character dropped
        (presumably a trailing separator such as ',' — TODO confirm against the file format).
        """
        boxes, dds = [], []
        for line in lines:
            if method == 'md':
                raw_box = line.split()[0:4]
            else:  # 'psm'
                raw_box = line[1:-1].split(',')[0:4]
            boxes.append(self._enlarge_box([float(x[:-1]) for x in raw_box]))
            dds.append(float(line.split()[5][:-1]))
            self.dic_cnt[method] += 1
        return boxes, dds

    @staticmethod
    def _enlarge_box(box):
        """Return ``box`` ([x0, y0, x1, y1]) grown by 10% of its width/height on every side."""
        delta_h = (box[3] - box[1]) / 10
        delta_w = (box[2] - box[0]) / 10
        assert delta_h > 0 and delta_w > 0, "Bounding box <=0"
        return [box[0] - delta_w, box[1] - delta_h, box[2] + delta_w, box[3] + delta_h]

    def _parse_our_lines(self, file_lines):
        """Parse Monoloco output: one prediction per line, camera matrix values on the last line.

        Per-line layout as read here: [0:4] box, [4:7] xyz, 7 distance, 8 aleatoric std,
        9 epistemic std, 11 geometric distance, [12:] keypoints (column 10 is not read here —
        presumably consumed by ``check_conditions``; TODO confirm). An empty file yields empty lists
        instead of an IndexError.
        """
        boxes, dds, stds_ale, stds_epi, dds_geom, xyzs, xy_kps = [], [], [], [], [], [], []
        if not file_lines:
            return boxes, dds, stds_ale, stds_epi, [], dds_geom, xyzs, xy_kps

        thresh = self.dic_thresh_conf['our']
        for line_our in file_lines[:-1]:
            line_list = [float(x) for x in line_our.split()]
            if check_conditions(line_list, thresh=thresh, mode='our'):
                boxes.append(line_list[:4])
                xyzs.append(line_list[4:7])
                dds.append(line_list[7])
                stds_ale.append(line_list[8])
                stds_epi.append(line_list[9])
                dds_geom.append(line_list[11])
                xy_kps.append(line_list[12:])
                self.dic_cnt['our'] += 1
        kk_list = [float(x) for x in file_lines[-1].split()]
        return boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps

    @staticmethod
    def _pop_index(idx, *lists):
        """Remove position ``idx`` from every list (consumes a matched ground-truth entry)."""
        for lst in lists:
            lst.pop(idx)

    def estimate_error_base(self, boxes, dds, boxes_gt, dds_gt, truncs_gt, occs_gt, method):
        """Greedily match each predicted box to the best remaining GT box and record distance errors.

        A match requires IoU strictly above ``self.dic_thresh_iou[method]``; a matched GT box is removed
        so it cannot be matched twice. The GT lists passed in are copied, not modified.
        """
        boxes_gt = copy.deepcopy(boxes_gt)
        dds_gt = copy.deepcopy(dds_gt)
        truncs_gt = copy.deepcopy(truncs_gt)
        occs_gt = copy.deepcopy(occs_gt)

        for box, dd in zip(boxes, dds):
            if not boxes_gt:
                break
            idx_max, iou_max = get_idx_max(box, boxes_gt)
            if iou_max > self.dic_thresh_iou[method]:
                cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
                self.update_errors(dd, dds_gt[idx_max], cat, self.errors[method])
                self._pop_index(idx_max, boxes_gt, dds_gt, truncs_gt, occs_gt)

    def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
                           boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name):
        """Same greedy matching as :meth:`estimate_error_base`, for Monoloco outputs.

        For each match, records errors for the network distance ('our') and the geometric distance
        ('geom'), updates the uncertainty statistics and appends the match to ``dic_fin``.
        ``kk_list`` and ``name`` are accepted for the (currently disabled) per-image JSON export and
        are not used.
        """
        boxes_gt = copy.deepcopy(boxes_gt)
        dds_gt = copy.deepcopy(dds_gt)
        truncs_gt = copy.deepcopy(truncs_gt)
        occs_gt = copy.deepcopy(occs_gt)

        for box, dd, ale, epi, dd_geom, xyz, xy_kp in zip(boxes, dds, stds_ale, stds_epi, dds_geom, xyzs, xy_kps):
            if not boxes_gt:
                break
            idx_max, iou_max = get_idx_max(box, boxes_gt)
            if iou_max <= self.dic_thresh_iou['our']:
                continue

            cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
            dd_gt = dds_gt[idx_max]
            self.update_errors(dd, dd_gt, cat, self.errors['our'])
            self.update_errors(dd_geom, dd_gt, cat, self.errors['geom'])
            self.update_uncertainty(ale, epi, dd, dd_gt, cat)
            self._pop_index(idx_max, boxes_gt, dds_gt, truncs_gt, occs_gt)

            dic_fin['boxes'].append(box)
            dic_fin['dds_gt'].append(dd_gt)
            dic_fin['dds_pred'].append(dd)
            dic_fin['stds_ale'].append(ale)
            dic_fin['stds_epi'].append(epi)
            dic_fin['dds_geom'].append(dd_geom)
            dic_fin['xyz'].append(xyz)
            dic_fin['xy_kps'].append(xy_kp)

    def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
                      boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
        """Record errors only for objects detected by every method (the '*_merged' error pools).

        A Monoloco box counts when its IoU with the best remaining GT box is >= the 'our' threshold
        and its IoU with the best box of each other method is >= the 'm3d' threshold.
        NOTE(review): these comparisons use >= while estimate_error_* use >; kept as-is.
        """
        # Without boxes from every competitor no common pool exists (and indexing would fail)
        if not (boxes_3dop and boxes_m3d and boxes_md):
            return

        boxes_gt = copy.deepcopy(boxes_gt)
        dds_gt = copy.deepcopy(dds_gt)
        truncs_gt = copy.deepcopy(truncs_gt)
        occs_gt = copy.deepcopy(occs_gt)

        for box, dd_our, dd_geom in zip(boxes_our, dds_our, dds_geom):
            if not boxes_gt:
                break
            idx_max, iou_max = get_idx_max(box, boxes_gt)
            idx_max_3dop, iou_max_3dop = get_idx_max(box, boxes_3dop)
            idx_max_m3d, iou_max_m3d = get_idx_max(box, boxes_m3d)
            idx_max_md, iou_max_md = get_idx_max(box, boxes_md)
            iou_min = min(iou_max_3dop, iou_max_m3d, iou_max_md)

            if iou_max >= self.dic_thresh_iou['our'] and iou_min >= self.dic_thresh_iou['m3d']:
                cat = get_category(boxes_gt[idx_max], truncs_gt[idx_max], occs_gt[idx_max])
                dd_gt = dds_gt[idx_max]
                self.update_errors(dds_3dop[idx_max_3dop], dd_gt, cat, self.errors['3dop_merged'])
                self.update_errors(dd_our, dd_gt, cat, self.errors['our_merged'])
                self.update_errors(dds_m3d[idx_max_m3d], dd_gt, cat, self.errors['m3d_merged'])
                self.update_errors(dd_geom, dd_gt, cat, self.errors['geom_merged'])
                self.update_errors(dds_md[idx_max_md], dd_gt, cat, self.errors['md_merged'])
                self.dic_cnt['merged'] += 1
                self._pop_index(idx_max, boxes_gt, dds_gt, truncs_gt, occs_gt)

    def update_errors(self, dd, dd_gt, cat, errors):
        """Compute and save errors between a single box and the gt box which match.

        Appends |dd - dd_gt| to errors['all'], errors[cat] and the distance cluster of dd_gt, and a
        0/1 flag to '<0.5m', '<1m' and '<2m' (1 when the error is within that bound, inclusive).
        """
        diff = abs(dd - dd_gt)
        clst = self.find_cluster(dd_gt, self.clusters)
        errors['all'].append(diff)
        errors[cat].append(diff)
        errors[clst].append(diff)

        for label, bound in (('<0.5m', 0.5), ('<1m', 1), ('<2m', 2)):
            errors[label].append(1 if diff <= bound else 0)

    def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
        """Accumulate Monoloco uncertainty statistics for one matched instance.

        Always records the aleatoric/epistemic stds and an 'at_risk' flag (1 when the GT is not farther
        than the prediction). The 'interval' flag (1 when |dd - dd_gt| <= std_epi) is recorded ONLY for
        at-risk instances. NOTE(review): that nesting is how the original reads; confirm it is intended.
        """
        clst = self.find_cluster(dd_gt, self.clusters)
        groups = ('all', clst, cat)

        for group in groups:
            self.dic_stds['ale'][group].append(std_ale)
        for group in groups:
            self.dic_stds['epi'][group].append(std_epi)

        # Particularly dangerous instances: the object is closer than predicted
        at_risk = dd_gt <= dd
        for group in groups:
            self.dic_stds['at_risk'][group].append(1 if at_risk else 0)

        if at_risk:
            # Number of annotations inside the confidence interval
            inside = abs(dd - dd_gt) <= std_epi
            for group in groups:
                self.dic_stds['interval'][group].append(1 if inside else 0)

    @staticmethod
    def find_cluster(dd, clusters):
        """Find the correct distance cluster for ``dd``.

        The first four clusters ('easy', 'moderate', 'hard', 'all') and the last one ('>50') are not
        numeric: return the first numeric upper bound >= dd, otherwise the last cluster.
        """
        for clst in clusters[4: -1]:
            if dd <= int(clst):
                return clst

        return clusters[-1]