add function for split_training

This commit is contained in:
lorenzo 2019-05-21 10:30:21 +02:00
parent 3c6c305606
commit 738095bc8c
2 changed files with 22 additions and 18 deletions

View File

@ -7,7 +7,7 @@ import logging
from collections import defaultdict
import json
import datetime
from utils.kitti import get_calibration, check_conditions
from utils.kitti import get_calibration, check_conditions, split_training
from utils.pifpaf import get_input_data, preprocess_pif
from utils.misc import get_idx_max, append_cluster
@ -44,23 +44,7 @@ class PreprocessKitti:
self.path_names = os.path.join(dir_out, 'names-kitti-' + now_time + '.json')
path_train = os.path.join('splits', 'kitti_train.txt')
path_val = os.path.join('splits', 'kitti_val.txt')
# Split training and validation images
set_gt = set(self.names_gt)
set_train = set()
set_val = set()
with open(path_train, "r") as f_train:
for line in f_train:
set_train.add(line[:-1] + '.txt')
with open(path_val, "r") as f_val:
for line in f_val:
set_val.add(line[:-1] + '.txt')
self.set_train = set_gt.intersection(set_train)
self.set_val = set_gt.intersection(set_val)
assert self.set_train and self.set_val, "No validation or training annotations"
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
def run(self):

View File

@ -123,3 +123,23 @@ def get_category(box, trunc, occ):
cat = 'hard'
return cat
def split_training(names_gt, path_train, path_val):
"""Split training and validation images"""
set_gt = set(names_gt)
set_train = set()
set_val = set()
with open(path_train, "r") as f_train:
for line in f_train:
set_train.add(line[:-1] + '.txt')
with open(path_val, "r") as f_val:
for line in f_val:
set_val.add(line[:-1] + '.txt')
set_train = set_gt.intersection(set_train)
set_val = set_gt.intersection(set_val)
assert set_train and set_val, "No validation or training annotations"
return set_train, set_val