diff --git a/src/eval/kitti_eval.py b/src/eval/kitti_eval.py index dab6e41..3237078 100644 --- a/src/eval/kitti_eval.py +++ b/src/eval/kitti_eval.py @@ -7,6 +7,10 @@ import json import copy import datetime +from utils.misc import get_idx_max +from utils.kitti import check_conditions, get_category, split_training +from visuals.results import print_results + class KittiEval: """ @@ -22,19 +26,12 @@ class KittiEval: self.logger = logging.getLogger(__name__) self.show = show - from utils.misc import get_idx_max - self.get_idx_max = get_idx_max - from utils.kitti import check_conditions, get_category - self.check_conditions = check_conditions - self.get_category = get_category - from visuals.results import print_results - self.print_results = print_results - self.dir_gt = os.path.join('data', 'kitti', 'gt') self.dir_m3d = os.path.join('data', 'kitti', 'm3d') self.dir_3dop = os.path.join('data', 'kitti', '3dop') self.dir_md = os.path.join('data', 'kitti', 'monodepth') self.dir_our = os.path.join('data', 'kitti', 'monoloco') + path_train = os.path.join('splits', 'kitti_train.txt') path_val = os.path.join('splits', 'kitti_val.txt') dir_logs = os.path.join('data', 'logs') assert dir_logs, "No directory to save final statistics" @@ -54,27 +51,22 @@ class KittiEval: self.dic_cnt = defaultdict(int) self.errors = defaultdict(lambda: defaultdict(list)) - # Only consider validation images - set_gt = set(os.listdir(self.dir_gt)) - set_val = set() - - with open(path_val, "r") as f_val: - for line in f_val: - set_val.add(line[:-1] + '.txt') - self.list_gt = list(set_gt.intersection(set_val)) - assert self.list_gt, "No images in the folder" + # Extract validation images for evaluation + names_gt = tuple(os.listdir(self.dir_gt)) + _, self.set_val = split_training(names_gt, path_train, path_val) + aa = 5 def run(self): + """Evaluate Monoloco methods on ALP and ALE metrics""" + self.dic_stds = defaultdict(lambda: defaultdict(list)) dic_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float)))) cnt_gt = 0 # Iterate over each ground truth file in the training set - for name in self.list_gt: - if name == '004647.txt': - aa = 5 + for name in self.set_val: path_gt = os.path.join(self.dir_gt, name) path_m3d = os.path.join(self.dir_m3d, name) path_our = os.path.join(self.dir_our, name) diff --git a/src/features/preprocess_ki.py b/src/features/preprocess_ki.py index ffd2675..2ae77cc 100644 --- a/src/features/preprocess_ki.py +++ b/src/features/preprocess_ki.py @@ -31,7 +31,7 @@ class PreprocessKitti: self.dir_ann = dir_ann self.iou_thresh = iou_thresh self.dir_gt = os.path.join('data', 'kitti', 'gt') - self.names_gt = os.listdir(self.dir_gt) + self.names_gt = tuple(os.listdir(self.dir_gt)) self.dir_kk = os.path.join('data', 'kitti', 'calib') self.list_gt = glob.glob(self.dir_gt + '/*.txt') assert os.path.exists(self.dir_gt), "Ground truth dir does not exist" diff --git a/src/utils/kitti.py b/src/utils/kitti.py index ba85332..910b647 100644 --- a/src/utils/kitti.py +++ b/src/utils/kitti.py @@ -138,8 +138,8 @@ def split_training(names_gt, path_train, path_val): for line in f_val: set_val.add(line[:-1] + '.txt') - set_train = set_gt.intersection(set_train) - set_val = set_gt.intersection(set_val) + set_train = tuple(set_gt.intersection(set_train)) + set_val = tuple(set_gt.intersection(set_val)) assert set_train and set_val, "No validation or training annotations" return set_train, set_val