From 521a04ece5c9eba204c0f5b6dcee29abb194efd0 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 11:20:35 +0200 Subject: [PATCH] refactor methods --- src/eval/kitti_eval.py | 70 ++++++++++++------------------------------ src/utils/kitti.py | 20 ++++++++++++ 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index 0d2ea66..486fe34 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -6,9 +6,8 @@ from collections import defaultdict import json import copy import datetime - from utils.misc import get_idx_max -from utils.kitti import check_conditions, get_category, split_training +from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth from visuals.results import print_results @@ -69,30 +68,18 @@ class KittiEval: path_our = os.path.join(self.dir_our, name) path_3dop = os.path.join(self.dir_3dop, name) path_md = os.path.join(self.dir_md, name) - boxes_gt = [] - truncs_gt = [] # Float from 0 to 1 - occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown - dds_gt = [] - dic_fin = defaultdict(list) # Iterate over each line of the gt file and save box location and distances - with open(path_gt, "r") as f_gt: - for line_gt in f_gt: - if check_conditions(line_gt, mode='gt'): - truncs_gt.append(float(line_gt.split()[1])) - occs_gt.append(int(line_gt.split()[2])) - boxes_gt.append([float(x) for x in line_gt.split()[4:8]]) - loc_gt = [float(x) for x in line_gt.split()[11:14]] - dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)) - cnt_gt += 1 + boxes_gt, dds_gt, truncs_gt, occs_gt = parse_ground_truth(path_gt) + cnt_gt += len(boxes_gt) # Extract annotations for the same file if len(boxes_gt) > 0: - boxes_m3d, dds_m3d = self.parse_txts(path_m3d, method='m3d') - boxes_3dop, dds_3dop = self.parse_txts(path_3dop, method='3dop') - boxes_md, dds_md = self.parse_txts(path_md, method='md') + boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d') + boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop') + boxes_md, dds_md = self._parse_txts(path_md, method='md') boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \ - self.parse_txts(path_our, method='our') + self._parse_txts(path_our, method='our') # Compute the error with ground truth @@ -100,7 +87,7 @@ class KittiEval: self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop') self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md') self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, - boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name) + boxes_gt, dds_gt, truncs_gt, occs_gt, name) # Iterate over all the files together to find a pool of common annotations self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, @@ -161,7 +148,7 @@ class KittiEval: # Print images print_results(self.dic_stats, self.show) - def parse_txts(self, path, method): + def _parse_txts(self, path, method): boxes = [] dds = [] stds_ale = [] @@ -276,7 +263,7 @@ class KittiEval: break def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps, - boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name): + boxes_gt, dds_gt, truncs_gt, occs_gt, name): # Compute error (distance) and save it boxes_gt = copy.deepcopy(boxes_gt) @@ -307,24 +294,6 @@ class KittiEval: truncs_gt.pop(idx_max) occs_gt.pop(idx_max) - # Extract K and save it everything in a json file - dic_fin['boxes'].append(box) - dic_fin['dds_gt'].append(dd_gt) - dic_fin['dds_pred'].append(dd) - dic_fin['stds_ale'].append(ale) - dic_fin['stds_epi'].append(epi) - dic_fin['dds_geom'].append(dd_geom) - dic_fin['xyz'].append(xyz) - dic_fin['xy_kps'].append(xy_kp) - else: - break - - # kk_fin = np.array(kk_list).reshape(3, 3).tolist() - # dic_fin['K'] = kk_fin - # path_json = os.path.join(self.dir_fin, name[:-4] + '.json') - # with open(path_json, 'w') as ff: - # json.dump(dic_fin, ff) - def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our, boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom): @@ -372,7 +341,7 @@ class KittiEval: """Compute and save errors between a single box and the gt box which match""" diff = abs(dd - dd_gt) - clst = self.find_cluster(dd_gt, self.clusters) + clst = find_cluster(dd_gt, self.clusters) errors['all'].append(diff) errors[cat].append(diff) errors[clst].append(diff) @@ -395,7 +364,7 @@ class KittiEval: def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat): - clst = self.find_cluster(dd_gt, self.clusters) + clst = find_cluster(dd_gt, self.clusters) self.dic_stds['ale']['all'].append(std_ale) self.dic_stds['ale'][clst].append(std_ale) self.dic_stds['ale'][cat].append(std_ale) @@ -409,7 +378,7 @@ class KittiEval: self.dic_stds['at_risk'][clst].append(1) self.dic_stds['at_risk'][cat].append(1) - if abs(dd - dd_gt) <= (std_epi): + if abs(dd - dd_gt) <= std_epi: self.dic_stds['interval']['all'].append(1) self.dic_stds['interval'][clst].append(1) self.dic_stds['interval'][cat].append(1) @@ -430,12 +399,11 @@ class KittiEval: # self.dic_stds['at_risk'][clst].append(0) # self.dic_stds['at_risk'][cat].append(0) - @staticmethod - def find_cluster(dd, clusters): - """Find the correct cluster. The first and the last one are not numeric""" +def find_cluster(dd, clusters): + """Find the correct cluster. The first and the last one are not numeric""" - for clst in clusters[4: -1]: - if dd <= int(clst): - return clst + for clst in clusters[4: -1]: + if dd <= int(clst): + return clst - return clusters[-1] + return clusters[-1] \ No newline at end of file diff --git a/src/utils/kitti.py b/src/utils/kitti.py index 910b647..c9ecb75 100644 --- a/src/utils/kitti.py +++ b/src/utils/kitti.py @@ -1,6 +1,7 @@ import numpy as np import copy +import math from utils.camera import pixel_to_camera, get_keypoints from eval.geom_baseline import compute_distance_single @@ -143,3 +144,22 @@ def split_training(names_gt, path_train, path_val): assert set_train and set_val, "No validation or training annotations" return set_train, set_val + +def parse_ground_truth(path_gt): + """Parse KITTI ground truth files""" + boxes_gt = [] + dds_gt = [] + truncs_gt = [] # Float from 0 to 1 + occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown + + with open(path_gt, "r") as f_gt: + for line_gt in f_gt: + if check_conditions(line_gt, mode='gt'): + truncs_gt.append(float(line_gt.split()[1])) + occs_gt.append(int(line_gt.split()[2])) + boxes_gt.append([float(x) for x in line_gt.split()[4:8]]) + loc_gt = [float(x) for x in line_gt.split()[11:14]] + dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)) + + return boxes_gt, dds_gt, truncs_gt, occs_gt +