refactor __init (1)
This commit is contained in:
parent
738095bc8c
commit
6fbc496702
@ -7,6 +7,10 @@ import json
|
|||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from utils.misc import get_idx_max
|
||||||
|
from utils.kitti import check_conditions, get_category, split_training
|
||||||
|
from visuals.results import print_results
|
||||||
|
|
||||||
|
|
||||||
class KittiEval:
|
class KittiEval:
|
||||||
"""
|
"""
|
||||||
@ -22,19 +26,12 @@ class KittiEval:
|
|||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
self.show = show
|
self.show = show
|
||||||
from utils.misc import get_idx_max
|
|
||||||
self.get_idx_max = get_idx_max
|
|
||||||
from utils.kitti import check_conditions, get_category
|
|
||||||
self.check_conditions = check_conditions
|
|
||||||
self.get_category = get_category
|
|
||||||
from visuals.results import print_results
|
|
||||||
self.print_results = print_results
|
|
||||||
|
|
||||||
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
||||||
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
|
self.dir_m3d = os.path.join('data', 'kitti', 'm3d')
|
||||||
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
|
self.dir_3dop = os.path.join('data', 'kitti', '3dop')
|
||||||
self.dir_md = os.path.join('data', 'kitti', 'monodepth')
|
self.dir_md = os.path.join('data', 'kitti', 'monodepth')
|
||||||
self.dir_our = os.path.join('data', 'kitti', 'monoloco')
|
self.dir_our = os.path.join('data', 'kitti', 'monoloco')
|
||||||
|
path_train = os.path.join('splits', 'kitti_train.txt')
|
||||||
path_val = os.path.join('splits', 'kitti_val.txt')
|
path_val = os.path.join('splits', 'kitti_val.txt')
|
||||||
dir_logs = os.path.join('data', 'logs')
|
dir_logs = os.path.join('data', 'logs')
|
||||||
assert dir_logs, "No directory to save final statistics"
|
assert dir_logs, "No directory to save final statistics"
|
||||||
@ -54,27 +51,22 @@ class KittiEval:
|
|||||||
self.dic_cnt = defaultdict(int)
|
self.dic_cnt = defaultdict(int)
|
||||||
self.errors = defaultdict(lambda: defaultdict(list))
|
self.errors = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
# Only consider validation images
|
# Extract validation images for evaluation
|
||||||
set_gt = set(os.listdir(self.dir_gt))
|
names_gt = tuple(os.listdir(self.dir_gt))
|
||||||
set_val = set()
|
_, self.set_val = split_training(names_gt, path_train, path_val)
|
||||||
|
aa = 5
|
||||||
with open(path_val, "r") as f_val:
|
|
||||||
for line in f_val:
|
|
||||||
set_val.add(line[:-1] + '.txt')
|
|
||||||
self.list_gt = list(set_gt.intersection(set_val))
|
|
||||||
assert self.list_gt, "No images in the folder"
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
"""Evaluate Monoloco methods on ALP and ALE metrics"""
|
||||||
|
|
||||||
self.dic_stds = defaultdict(lambda: defaultdict(list))
|
self.dic_stds = defaultdict(lambda: defaultdict(list))
|
||||||
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
|
||||||
|
|
||||||
cnt_gt = 0
|
cnt_gt = 0
|
||||||
|
|
||||||
# Iterate over each ground truth file in the training set
|
# Iterate over each ground truth file in the training set
|
||||||
for name in self.list_gt:
|
for name in self.set_val:
|
||||||
if name == '004647.txt':
|
|
||||||
aa = 5
|
|
||||||
path_gt = os.path.join(self.dir_gt, name)
|
path_gt = os.path.join(self.dir_gt, name)
|
||||||
path_m3d = os.path.join(self.dir_m3d, name)
|
path_m3d = os.path.join(self.dir_m3d, name)
|
||||||
path_our = os.path.join(self.dir_our, name)
|
path_our = os.path.join(self.dir_our, name)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class PreprocessKitti:
|
|||||||
self.dir_ann = dir_ann
|
self.dir_ann = dir_ann
|
||||||
self.iou_thresh = iou_thresh
|
self.iou_thresh = iou_thresh
|
||||||
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
self.dir_gt = os.path.join('data', 'kitti', 'gt')
|
||||||
self.names_gt = os.listdir(self.dir_gt)
|
self.names_gt = tuple(os.listdir(self.dir_gt))
|
||||||
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
||||||
self.list_gt = glob.glob(self.dir_gt + '/*.txt')
|
self.list_gt = glob.glob(self.dir_gt + '/*.txt')
|
||||||
assert os.path.exists(self.dir_gt), "Ground truth dir does not exist"
|
assert os.path.exists(self.dir_gt), "Ground truth dir does not exist"
|
||||||
|
|||||||
@ -138,8 +138,8 @@ def split_training(names_gt, path_train, path_val):
|
|||||||
for line in f_val:
|
for line in f_val:
|
||||||
set_val.add(line[:-1] + '.txt')
|
set_val.add(line[:-1] + '.txt')
|
||||||
|
|
||||||
set_train = set_gt.intersection(set_train)
|
set_train = tuple(set_gt.intersection(set_train))
|
||||||
set_val = set_gt.intersection(set_val)
|
set_val = tuple(set_gt.intersection(set_val))
|
||||||
assert set_train and set_val, "No validation or training annotations"
|
assert set_train and set_val, "No validation or training annotations"
|
||||||
return set_train, set_val
|
return set_train, set_val
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user