add print method
This commit is contained in:
parent
6f3379e394
commit
3c7d07d081
@ -26,8 +26,7 @@ class KittiEval:
|
|||||||
dic_cnt = defaultdict(int)
|
dic_cnt = defaultdict(int)
|
||||||
errors = defaultdict(lambda: defaultdict(list))
|
errors = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
def __init__(self, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
|
||||||
self.show = show
|
|
||||||
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
||||||
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
|
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
|
||||||
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
|
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
|
||||||
@ -70,7 +69,7 @@ class KittiEval:
|
|||||||
|
|
||||||
# Extract annotations for the same file
|
# Extract annotations for the same file
|
||||||
if len(boxes_gt) > 0:
|
if len(boxes_gt) > 0:
|
||||||
boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d')
|
boxes_m3d, dds_m3d = self._parse_txts(path_m3d, method='m3d')
|
||||||
boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop')
|
boxes_3dop, dds_3dop = self._parse_txts(path_3dop, method='3dop')
|
||||||
boxes_md, dds_md = self._parse_txts(path_md, method='md')
|
boxes_md, dds_md = self._parse_txts(path_md, method='md')
|
||||||
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
|
boxes_our, dds_our, stds_ale, stds_epi, kk_list, dds_geom, xyzs, xy_kps = \
|
||||||
@ -120,8 +119,8 @@ class KittiEval:
|
|||||||
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
|
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
|
||||||
print("-"*100)
|
print("-"*100)
|
||||||
|
|
||||||
# Print images
|
def print(self, show):
|
||||||
print_results(self.dic_stats, self.show)
|
print_results(self.dic_stats, show)
|
||||||
|
|
||||||
def _parse_txts(self, path, method):
|
def _parse_txts(self, path, method):
|
||||||
boxes = []
|
boxes = []
|
||||||
@ -134,7 +133,7 @@ class KittiEval:
|
|||||||
xy_kps = []
|
xy_kps = []
|
||||||
|
|
||||||
# Iterate over each line of the txt file
|
# Iterate over each line of the txt file
|
||||||
if method == '3dop' or method == 'm3d':
|
if method in ['3dop', 'm3d']:
|
||||||
try:
|
try:
|
||||||
with open(path, "r") as ff:
|
with open(path, "r") as ff:
|
||||||
for line in ff:
|
for line in ff:
|
||||||
@ -268,7 +267,7 @@ class KittiEval:
|
|||||||
occs_gt.pop(idx_max)
|
occs_gt.pop(idx_max)
|
||||||
|
|
||||||
def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
def _compare_error(self, boxes_m3d, dds_m3d, boxes_3dop, dds_3dop, boxes_md, dds_md, boxes_our, dds_our,
|
||||||
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
boxes_gt, dds_gt, truncs_gt, occs_gt, dds_geom):
|
||||||
|
|
||||||
boxes_gt = copy.deepcopy(boxes_gt)
|
boxes_gt = copy.deepcopy(boxes_gt)
|
||||||
dds_gt = copy.deepcopy(dds_gt)
|
dds_gt = copy.deepcopy(dds_gt)
|
||||||
@ -401,3 +400,4 @@ def find_cluster(dd, clusters):
|
|||||||
|
|
||||||
return clusters[-1]
|
return clusters[-1]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -143,8 +143,9 @@ def main():
|
|||||||
run_kitti.run()
|
run_kitti.run()
|
||||||
|
|
||||||
if args.dataset == 'kitti':
|
if args.dataset == 'kitti':
|
||||||
kitti_eval = KittiEval(show=args.show)
|
kitti_eval = KittiEval()
|
||||||
kitti_eval.run()
|
kitti_eval.run()
|
||||||
|
kitti_eval.print(show=args.show)
|
||||||
|
|
||||||
if 'nuscenes' in args.dataset:
|
if 'nuscenes' in args.dataset:
|
||||||
training = Trainer(joints=args.joints)
|
training = Trainer(joints=args.joints)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user