From 738095bc8c67660c34409f54c32d62b3e8ed3577 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 10:30:21 +0200 Subject: [PATCH] add function for split_training --- src/features/preprocess_ki.py | 20 ++------------------ src/utils/kitti.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/features/preprocess_ki.py b/src/features/preprocess_ki.py index 9e29682..ffd2675 100644 --- a/src/features/preprocess_ki.py +++ b/src/features/preprocess_ki.py @@ -7,7 +7,7 @@ import logging from collections import defaultdict import json import datetime -from utils.kitti import get_calibration, check_conditions +from utils.kitti import get_calibration, check_conditions, split_training from utils.pifpaf import get_input_data, preprocess_pif from utils.misc import get_idx_max, append_cluster @@ -44,23 +44,7 @@ class PreprocessKitti: self.path_names = os.path.join(dir_out, 'names-kitti-' + now_time + '.json') path_train = os.path.join('splits', 'kitti_train.txt') path_val = os.path.join('splits', 'kitti_val.txt') - - # Split training and validation images - set_gt = set(self.names_gt) - set_train = set() - set_val = set() - - with open(path_train, "r") as f_train: - for line in f_train: - set_train.add(line[:-1] + '.txt') - with open(path_val, "r") as f_val: - for line in f_val: - set_val.add(line[:-1] + '.txt') - - self.set_train = set_gt.intersection(set_train) - self.set_val = set_gt.intersection(set_val) - assert self.set_train and self.set_val, "No validation or training annotations" - + self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val) def run(self): diff --git a/src/utils/kitti.py b/src/utils/kitti.py index 3d7703e..ba85332 100644 --- a/src/utils/kitti.py +++ b/src/utils/kitti.py @@ -123,3 +123,23 @@ def get_category(box, trunc, occ): cat = 'hard' return cat + + +def split_training(names_gt, path_train, path_val): + """Split training and validation images""" + set_gt = set(names_gt) + set_train = set() + set_val = set() + + with open(path_train, "r") as f_train: + for line in f_train: + set_train.add(line[:-1] + '.txt') + with open(path_val, "r") as f_val: + for line in f_val: + set_val.add(line[:-1] + '.txt') + + set_train = set_gt.intersection(set_train) + set_val = set_gt.intersection(set_val) + assert set_train and set_val, "No validation or training annotations" + return set_train, set_val +