add print method

This commit is contained in:
lorenzo 2019-05-21 15:03:46 +02:00
parent 6f3379e394
commit 3c7d07d081
2 changed files with 9 additions and 8 deletions

View File

@ -26,8 +26,7 @@ class KittiEval:
dic_cnt = defaultdict(int)
errors = defaultdict(lambda: defaultdict(list))
def __init__(self, show=False, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
self.show = show
def __init__(self, thresh_iou_our=0.3, thresh_iou_m3d=0.5, thresh_conf_m3d=0.5, thresh_conf_our=0.3):
self.dir_gt = os.path.join('data', 'kitti', 'gt')
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
@ -120,8 +119,8 @@ class KittiEval:
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
print("-"*100)
# Print images
print_results(self.dic_stats, self.show)
def print(self, show):
print_results(self.dic_stats, show)
def _parse_txts(self, path, method):
boxes = []
@ -134,7 +133,7 @@ class KittiEval:
xy_kps = []
# Iterate over each line of the txt file
if method == '3dop' or method == 'm3d':
if method in ['3dop', 'm3d']:
try:
with open(path, "r") as ff:
for line in ff:
@ -401,3 +400,4 @@ def find_cluster(dd, clusters):
return clusters[-1]

View File

@ -143,8 +143,9 @@ def main():
run_kitti.run()
if args.dataset == 'kitti':
kitti_eval = KittiEval(show=args.show)
kitti_eval = KittiEval()
kitti_eval.run()
kitti_eval.print(show=args.show)
if 'nuscenes' in args.dataset:
training = Trainer(joints=args.joints)