refactor methods
This commit is contained in:
parent
67bc780677
commit
521a04ece5
@ -6,9 +6,8 @@ from collections import defaultdict
|
|||||||
import json
|
import json
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from utils.misc import get_idx_max
|
from utils.misc import get_idx_max
|
||||||
from utils.kitti import check_conditions, get_category, split_training
|
from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth
|
||||||
from visuals.results import print_results
|
from visuals.results import print_results
|
||||||
|
|
||||||
|
|
||||||
@ -69,30 +68,18 @@ class KittiEval:
|
|||||||
path_our = os.path.join(self.dir_our, name)
|
path_our = os.path.join(self.dir_our, name)
|
||||||
path_3dop = os.path.join(self.dir_3dop, name)
|
path_3dop = os.path.join(self.dir_3dop, name)
|
||||||
path_md = os.path.join(self.dir_md, name)
|
path_md = os.path.join(self.dir_md, name)
|
||||||
boxes_gt = []
|
|
||||||
truncs_gt = [] # Float from 0 to 1
|
|
||||||
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
|
|
||||||
dds_gt = []
|
|
||||||
dic_fin = defaultdict(list)
|
|
||||||
|
|
||||||
# Iterate over each line of the gt file and save box location and distances
|
# Iterate over each line of the gt file and save box location and distances
|
||||||
with open(path_gt, "r") as f_gt:
|
boxes_gt, dds_gt, truncs_gt, occs_gt = parse_ground_truth(path_gt)
|
||||||
for line_gt in f_gt:
|
cnt_gt += len(boxes_gt)
|
||||||
if check_conditions(line_gt, mode='gt'):
|
|
||||||
truncs_gt.append(float(line_gt.split()[1]))
|
|
||||||
occs_gt.append(int(line_gt.split()[2]))
|
|
||||||
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
|
|
||||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
|
||||||
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
|
|
||||||
cnt_gt += 1
|
|
||||||
|
|
||||||
# Extract annotations for the same file
|
# Extract annotations for the same file
|
||||||
if len(boxes_gt) > 0:
|
if len(boxes_gt) > 0:
|
||||||
boxes_m3d, dds_m3d = self.parse_txts(path_m3d, method='m3d')
|
boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d')
|
||||||
boxes_3dop, dds_3dop = self.parse_txts(path_3dop, method='3dop')
|
boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop')
|
||||||
boxes_md, dds_md = self.parse_txts(path_md, method='md')
|
boxes_md, dds_md = self._parse_txts(path_md, method='md')
|
||||||
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
|
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
|
||||||
self.parse_txts(path_our, method='our')
|
self._parse_txts(path_our, method='our')
|
||||||
|
|
||||||
# Compute the error with ground truth
|
# Compute the error with ground truth
|
||||||
|
|
||||||
@ -100,7 +87,7 @@ class KittiEval:
|
|||||||
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
|
||||||
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
|
||||||
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name)
|
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
|
||||||
|
|
||||||
# Iterate over all the files together to find a pool of common annotations
|
# Iterate over all the files together to find a pool of common annotations
|
||||||
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||||
@ -161,7 +148,7 @@ class KittiEval:
|
|||||||
# Print images
|
# Print images
|
||||||
print_results(self.dic_stats, self.show)
|
print_results(self.dic_stats, self.show)
|
||||||
|
|
||||||
def parse_txts(self, path, method):
|
def _parse_txts(self, path, method):
|
||||||
boxes = []
|
boxes = []
|
||||||
dds = []
|
dds = []
|
||||||
stds_ale = []
|
stds_ale = []
|
||||||
@ -276,7 +263,7 @@ class KittiEval:
|
|||||||
break
|
break
|
||||||
|
|
||||||
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name):
|
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
|
||||||
|
|
||||||
# Compute error (distance) and save it
|
# Compute error (distance) and save it
|
||||||
boxes_gt = copy.deepcopy(boxes_gt)
|
boxes_gt = copy.deepcopy(boxes_gt)
|
||||||
@ -307,24 +294,6 @@ class KittiEval:
|
|||||||
truncs_gt.pop(idx_max)
|
truncs_gt.pop(idx_max)
|
||||||
occs_gt.pop(idx_max)
|
occs_gt.pop(idx_max)
|
||||||
|
|
||||||
# Extract K and save it everything in a json file
|
|
||||||
dic_fin['boxes'].append(box)
|
|
||||||
dic_fin['dds_gt'].append(dd_gt)
|
|
||||||
dic_fin['dds_pred'].append(dd)
|
|
||||||
dic_fin['stds_ale'].append(ale)
|
|
||||||
dic_fin['stds_epi'].append(epi)
|
|
||||||
dic_fin['dds_geom'].append(dd_geom)
|
|
||||||
dic_fin['xyz'].append(xyz)
|
|
||||||
dic_fin['xy_kps'].append(xy_kp)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
# kk_fin = np.array(kk_list).reshape(3, 3).tolist()
|
|
||||||
# dic_fin['K'] = kk_fin
|
|
||||||
# path_json = os.path.join(self.dir_fin, name[:-4] + '.json')
|
|
||||||
# with open(path_json, 'w') as ff:
|
|
||||||
# json.dump(dic_fin, ff)
|
|
||||||
|
|
||||||
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
||||||
|
|
||||||
@ -372,7 +341,7 @@ class KittiEval:
|
|||||||
"""Compute and save errors between a single box and the gt box which match"""
|
"""Compute and save errors between a single box and the gt box which match"""
|
||||||
|
|
||||||
diff = abs(dd - dd_gt)
|
diff = abs(dd - dd_gt)
|
||||||
clst = self.find_cluster(dd_gt, self.clusters)
|
clst = find_cluster(dd_gt, self.clusters)
|
||||||
errors['all'].append(diff)
|
errors['all'].append(diff)
|
||||||
errors[cat].append(diff)
|
errors[cat].append(diff)
|
||||||
errors[clst].append(diff)
|
errors[clst].append(diff)
|
||||||
@ -395,7 +364,7 @@ class KittiEval:
|
|||||||
|
|
||||||
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
|
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
|
||||||
|
|
||||||
clst = self.find_cluster(dd_gt, self.clusters)
|
clst = find_cluster(dd_gt, self.clusters)
|
||||||
self.dic_stds['ale']['all'].append(std_ale)
|
self.dic_stds['ale']['all'].append(std_ale)
|
||||||
self.dic_stds['ale'][clst].append(std_ale)
|
self.dic_stds['ale'][clst].append(std_ale)
|
||||||
self.dic_stds['ale'][cat].append(std_ale)
|
self.dic_stds['ale'][cat].append(std_ale)
|
||||||
@ -409,7 +378,7 @@ class KittiEval:
|
|||||||
self.dic_stds['at_risk'][clst].append(1)
|
self.dic_stds['at_risk'][clst].append(1)
|
||||||
self.dic_stds['at_risk'][cat].append(1)
|
self.dic_stds['at_risk'][cat].append(1)
|
||||||
|
|
||||||
if abs(dd - dd_gt) <= (std_epi):
|
if abs(dd - dd_gt) <= std_epi:
|
||||||
self.dic_stds['interval']['all'].append(1)
|
self.dic_stds['interval']['all'].append(1)
|
||||||
self.dic_stds['interval'][clst].append(1)
|
self.dic_stds['interval'][clst].append(1)
|
||||||
self.dic_stds['interval'][cat].append(1)
|
self.dic_stds['interval'][cat].append(1)
|
||||||
@ -430,12 +399,11 @@ class KittiEval:
|
|||||||
# self.dic_stds['at_risk'][clst].append(0)
|
# self.dic_stds['at_risk'][clst].append(0)
|
||||||
# self.dic_stds['at_risk'][cat].append(0)
|
# self.dic_stds['at_risk'][cat].append(0)
|
||||||
|
|
||||||
@staticmethod
|
def find_cluster(dd, clusters):
|
||||||
def find_cluster(dd, clusters):
|
"""Find the correct cluster. The first and the last one are not numeric"""
|
||||||
"""Find the correct cluster. The first and the last one are not numeric"""
|
|
||||||
|
|
||||||
for clst in clusters[4: -1]:
|
for clst in clusters[4: -1]:
|
||||||
if dd <= int(clst):
|
if dd <= int(clst):
|
||||||
return clst
|
return clst
|
||||||
|
|
||||||
return clusters[-1]
|
return clusters[-1]
|
||||||
@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import copy
|
import copy
|
||||||
|
import math
|
||||||
|
|
||||||
from utils.camera import pixel_to_camera, get_keypoints
|
from utils.camera import pixel_to_camera, get_keypoints
|
||||||
from eval.geom_baseline import compute_distance_single
|
from eval.geom_baseline import compute_distance_single
|
||||||
@ -143,3 +144,22 @@ def split_training(names_gt, path_train, path_val):
|
|||||||
assert set_train and set_val, "No validation or training annotations"
|
assert set_train and set_val, "No validation or training annotations"
|
||||||
return set_train, set_val
|
return set_train, set_val
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ground_truth(path_gt):
|
||||||
|
"""Parse KITTI ground truth files"""
|
||||||
|
boxes_gt = []
|
||||||
|
dds_gt = []
|
||||||
|
truncs_gt = [] # Float from 0 to 1
|
||||||
|
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
|
||||||
|
|
||||||
|
with open(path_gt, "r") as f_gt:
|
||||||
|
for line_gt in f_gt:
|
||||||
|
if check_conditions(line_gt, mode='gt'):
|
||||||
|
truncs_gt.append(float(line_gt.split()[1]))
|
||||||
|
occs_gt.append(int(line_gt.split()[2]))
|
||||||
|
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
|
||||||
|
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||||
|
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
|
||||||
|
|
||||||
|
return boxes_gt, dds_gt, truncs_gt, occs_gt
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user