refactor methods
This commit is contained in:
parent
67bc780677
commit
521a04ece5
@ -6,9 +6,8 @@ from collections import defaultdict
|
||||
import json
|
||||
import copy
|
||||
import datetime
|
||||
|
||||
from utils.misc import get_idx_max
|
||||
from utils.kitti import check_conditions, get_category, split_training
|
||||
from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth
|
||||
from visuals.results import print_results
|
||||
|
||||
|
||||
@ -69,30 +68,18 @@ class KittiEval:
|
||||
path_our = os.path.join(self.dir_our, name)
|
||||
path_3dop = os.path.join(self.dir_3dop, name)
|
||||
path_md = os.path.join(self.dir_md, name)
|
||||
boxes_gt = []
|
||||
truncs_gt = [] # Float from 0 to 1
|
||||
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
|
||||
dds_gt = []
|
||||
dic_fin = defaultdict(list)
|
||||
|
||||
# Iterate over each line of the gt file and save box location and distances
|
||||
with open(path_gt, "r") as f_gt:
|
||||
for line_gt in f_gt:
|
||||
if check_conditions(line_gt, mode='gt'):
|
||||
truncs_gt.append(float(line_gt.split()[1]))
|
||||
occs_gt.append(int(line_gt.split()[2]))
|
||||
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
|
||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
|
||||
cnt_gt += 1
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt = parse_ground_truth(path_gt)
|
||||
cnt_gt += len(boxes_gt)
|
||||
|
||||
# Extract annotations for the same file
|
||||
if len(boxes_gt) > 0:
|
||||
boxes_m3d, dds_m3d = self.parse_txts(path_m3d, method='m3d')
|
||||
boxes_3dop, dds_3dop = self.parse_txts(path_3dop, method='3dop')
|
||||
boxes_md, dds_md = self.parse_txts(path_md, method='md')
|
||||
boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d')
|
||||
boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop')
|
||||
boxes_md, dds_md = self._parse_txts(path_md, method='md')
|
||||
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
|
||||
self.parse_txts(path_our, method='our')
|
||||
self._parse_txts(path_our, method='our')
|
||||
|
||||
# Compute the error with ground truth
|
||||
|
||||
@ -100,7 +87,7 @@ class KittiEval:
|
||||
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name)
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||
|
||||
# Iterate over all the files together to find a pool of common annotations
|
||||
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
@ -161,7 +148,7 @@ class KittiEval:
|
||||
# Print images
|
||||
print_results(self.dic_stats, self.show)
|
||||
|
||||
def parse_txts(self, path, method):
|
||||
def _parse_txts(self, path, method):
|
||||
boxes = []
|
||||
dds = []
|
||||
stds_ale = []
|
||||
@ -276,7 +263,7 @@ class KittiEval:
|
||||
break
|
||||
|
||||
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name):
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||
|
||||
# Compute error (distance) and save it
|
||||
boxes_gt = copy.deepcopy(boxes_gt)
|
||||
@ -307,24 +294,6 @@ class KittiEval:
|
||||
truncs_gt.pop(idx_max)
|
||||
occs_gt.pop(idx_max)
|
||||
|
||||
# Extract K and save it everything in a json file
|
||||
dic_fin['boxes'].append(box)
|
||||
dic_fin['dds_gt'].append(dd_gt)
|
||||
dic_fin['dds_pred'].append(dd)
|
||||
dic_fin['stds_ale'].append(ale)
|
||||
dic_fin['stds_epi'].append(epi)
|
||||
dic_fin['dds_geom'].append(dd_geom)
|
||||
dic_fin['xyz'].append(xyz)
|
||||
dic_fin['xy_kps'].append(xy_kp)
|
||||
else:
|
||||
break
|
||||
|
||||
# kk_fin = np.array(kk_list).reshape(3, 3).tolist()
|
||||
# dic_fin['K'] = kk_fin
|
||||
# path_json = os.path.join(self.dir_fin, name[:-4] + '.json')
|
||||
# with open(path_json, 'w') as ff:
|
||||
# json.dump(dic_fin, ff)
|
||||
|
||||
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
||||
|
||||
@ -372,7 +341,7 @@ class KittiEval:
|
||||
"""Compute and save errors between a single box and the gt box which match"""
|
||||
|
||||
diff = abs(dd - dd_gt)
|
||||
clst = self.find_cluster(dd_gt, self.clusters)
|
||||
clst = find_cluster(dd_gt, self.clusters)
|
||||
errors['all'].append(diff)
|
||||
errors[cat].append(diff)
|
||||
errors[clst].append(diff)
|
||||
@ -395,7 +364,7 @@ class KittiEval:
|
||||
|
||||
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
|
||||
|
||||
clst = self.find_cluster(dd_gt, self.clusters)
|
||||
clst = find_cluster(dd_gt, self.clusters)
|
||||
self.dic_stds['ale']['all'].append(std_ale)
|
||||
self.dic_stds['ale'][clst].append(std_ale)
|
||||
self.dic_stds['ale'][cat].append(std_ale)
|
||||
@ -409,7 +378,7 @@ class KittiEval:
|
||||
self.dic_stds['at_risk'][clst].append(1)
|
||||
self.dic_stds['at_risk'][cat].append(1)
|
||||
|
||||
if abs(dd - dd_gt) <= (std_epi):
|
||||
if abs(dd - dd_gt) <= std_epi:
|
||||
self.dic_stds['interval']['all'].append(1)
|
||||
self.dic_stds['interval'][clst].append(1)
|
||||
self.dic_stds['interval'][cat].append(1)
|
||||
@ -430,12 +399,11 @@ class KittiEval:
|
||||
# self.dic_stds['at_risk'][clst].append(0)
|
||||
# self.dic_stds['at_risk'][cat].append(0)
|
||||
|
||||
@staticmethod
|
||||
def find_cluster(dd, clusters):
|
||||
"""Find the correct cluster. The first and the last one are not numeric"""
|
||||
def find_cluster(dd, clusters):
|
||||
"""Find the correct cluster. The first and the last one are not numeric"""
|
||||
|
||||
for clst in clusters[4: -1]:
|
||||
if dd <= int(clst):
|
||||
return clst
|
||||
for clst in clusters[4: -1]:
|
||||
if dd <= int(clst):
|
||||
return clst
|
||||
|
||||
return clusters[-1]
|
||||
return clusters[-1]
|
||||
@ -1,6 +1,7 @@
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
import math
|
||||
|
||||
from utils.camera import pixel_to_camera, get_keypoints
|
||||
from eval.geom_baseline import compute_distance_single
|
||||
@ -143,3 +144,22 @@ def split_training(names_gt, path_train, path_val):
|
||||
assert set_train and set_val, "No validation or training annotations"
|
||||
return set_train, set_val
|
||||
|
||||
|
||||
def parse_ground_truth(path_gt):
|
||||
"""Parse KITTI ground truth files"""
|
||||
boxes_gt = []
|
||||
dds_gt = []
|
||||
truncs_gt = [] # Float from 0 to 1
|
||||
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
|
||||
|
||||
with open(path_gt, "r") as f_gt:
|
||||
for line_gt in f_gt:
|
||||
if check_conditions(line_gt, mode='gt'):
|
||||
truncs_gt.append(float(line_gt.split()[1]))
|
||||
occs_gt.append(int(line_gt.split()[2]))
|
||||
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
|
||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
|
||||
|
||||
return boxes_gt, dds_gt, truncs_gt, occs_gt
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user