refactor methods

This commit is contained in:
lorenzo 2019-05-21 11:20:35 +02:00
parent 67bc780677
commit 521a04ece5
2 changed files with 39 additions and 51 deletions

View File

@ -6,9 +6,8 @@ from collections import defaultdict
import json
import copy
import datetime
from utils.misc import get_idx_max
from utils.kitti import check_conditions, get_category, split_training
from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth
from visuals.results import print_results
@ -69,30 +68,18 @@ class KittiEval:
path_our = os.path.join(self.dir_our, name)
path_3dop = os.path.join(self.dir_3dop, name)
path_md = os.path.join(self.dir_md, name)
boxes_gt = []
truncs_gt = [] # Float from 0 to 1
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
dds_gt = []
dic_fin = defaultdict(list)
# Iterate over each line of the gt file and save box location and distances
with open(path_gt, "r") as f_gt:
for line_gt in f_gt:
if check_conditions(line_gt, mode='gt'):
truncs_gt.append(float(line_gt.split()[1]))
occs_gt.append(int(line_gt.split()[2]))
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
loc_gt = [float(x) for x in line_gt.split()[11:14]]
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
cnt_gt += 1
boxes_gt, dds_gt, truncs_gt, occs_gt = parse_ground_truth(path_gt)
cnt_gt += len(boxes_gt)
# Extract annotations for the same file
if len(boxes_gt) > 0:
boxes_m3d, dds_m3d = self.parse_txts(path_m3d, method='m3d')
boxes_3dop, dds_3dop = self.parse_txts(path_3dop, method='3dop')
boxes_md, dds_md = self.parse_txts(path_md, method='md')
boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d')
boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop')
boxes_md, dds_md = self._parse_txts(path_md, method='md')
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
self.parse_txts(path_our, method='our')
self._parse_txts(path_our, method='our')
# Compute the error with ground truth
@ -100,7 +87,7 @@ class KittiEval:
self.estimate_error_base(boxes_3dop, dds_3dop, boxes_gt, dds_gt, truncs_gt, occs_gt, method='3dop')
self.estimate_error_base(boxes_md, dds_md, boxes_gt, dds_gt, truncs_gt, occs_gt, method='md')
self.estimate_error_our(boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name)
boxes_gt, dds_gt, truncs_gt, occs_gt, name)
# Iterate over all the files together to find a pool of common annotations
self.compare_error(boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
@ -161,7 +148,7 @@ class KittiEval:
# Print images
print_results(self.dic_stats, self.show)
def parse_txts(self, path, method):
def _parse_txts(self, path, method):
boxes = []
dds = []
stds_ale = []
@ -276,7 +263,7 @@ class KittiEval:
break
def estimate_error_our(self, boxes, dds, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps,
boxes_gt, dds_gt, truncs_gt, occs_gt, dic_fin, name):
boxes_gt, dds_gt, truncs_gt, occs_gt, name):
# Compute error (distance) and save it
boxes_gt = copy.deepcopy(boxes_gt)
@ -307,24 +294,6 @@ class KittiEval:
truncs_gt.pop(idx_max)
occs_gt.pop(idx_max)
# Extract K and save it everything in a json file
dic_fin['boxes'].append(box)
dic_fin['dds_gt'].append(dd_gt)
dic_fin['dds_pred'].append(dd)
dic_fin['stds_ale'].append(ale)
dic_fin['stds_epi'].append(epi)
dic_fin['dds_geom'].append(dd_geom)
dic_fin['xyz'].append(xyz)
dic_fin['xy_kps'].append(xy_kp)
else:
break
# kk_fin = np.array(kk_list).reshape(3, 3).tolist()
# dic_fin['K'] = kk_fin
# path_json = os.path.join(self.dir_fin, name[:-4] + '.json')
# with open(path_json, 'w') as ff:
# json.dump(dic_fin, ff)
def compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
@ -372,7 +341,7 @@ class KittiEval:
"""Compute and save errors between a single box and the gt box which match"""
diff = abs(dd - dd_gt)
clst = self.find_cluster(dd_gt, self.clusters)
clst = find_cluster(dd_gt, self.clusters)
errors['all'].append(diff)
errors[cat].append(diff)
errors[clst].append(diff)
@ -395,7 +364,7 @@ class KittiEval:
def update_uncertainty(self, std_ale, std_epi, dd, dd_gt, cat):
clst = self.find_cluster(dd_gt, self.clusters)
clst = find_cluster(dd_gt, self.clusters)
self.dic_stds['ale']['all'].append(std_ale)
self.dic_stds['ale'][clst].append(std_ale)
self.dic_stds['ale'][cat].append(std_ale)
@ -409,7 +378,7 @@ class KittiEval:
self.dic_stds['at_risk'][clst].append(1)
self.dic_stds['at_risk'][cat].append(1)
if abs(dd - dd_gt) <= (std_epi):
if abs(dd - dd_gt) <= std_epi:
self.dic_stds['interval']['all'].append(1)
self.dic_stds['interval'][clst].append(1)
self.dic_stds['interval'][cat].append(1)
@ -430,12 +399,11 @@ class KittiEval:
# self.dic_stds['at_risk'][clst].append(0)
# self.dic_stds['at_risk'][cat].append(0)
@staticmethod
def find_cluster(dd, clusters):
"""Find the correct cluster. The first and the last one are not numeric"""
def find_cluster(dd, clusters):
"""Find the correct cluster. The first and the last one are not numeric"""
for clst in clusters[4: -1]:
if dd <= int(clst):
return clst
for clst in clusters[4: -1]:
if dd <= int(clst):
return clst
return clusters[-1]
return clusters[-1]

View File

@ -1,6 +1,7 @@
import numpy as np
import copy
import math
from utils.camera import pixel_to_camera, get_keypoints
from eval.geom_baseline import compute_distance_single
@ -143,3 +144,22 @@ def split_training(names_gt, path_train, path_val):
assert set_train and set_val, "No validation or training annotations"
return set_train, set_val
def parse_ground_truth(path_gt):
"""Parse KITTI ground truth files"""
boxes_gt = []
dds_gt = []
truncs_gt = [] # Float from 0 to 1
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
with open(path_gt, "r") as f_gt:
for line_gt in f_gt:
if check_conditions(line_gt, mode='gt'):
truncs_gt.append(float(line_gt.split()[1]))
occs_gt.append(int(line_gt.split()[2]))
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
loc_gt = [float(x) for x in line_gt.split()[11:14]]
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
return boxes_gt, dds_gt, truncs_gt, occs_gt