add function for split_training
This commit is contained in:
parent
3c6c305606
commit
738095bc8c
@ -7,7 +7,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import datetime
|
||||
from utils.kitti import get_calibration, check_conditions
|
||||
from utils.kitti import get_calibration, check_conditions, split_training
|
||||
from utils.pifpaf import get_input_data, preprocess_pif
|
||||
from utils.misc import get_idx_max, append_cluster
|
||||
|
||||
@ -44,23 +44,7 @@ class PreprocessKitti:
|
||||
self.path_names = os.path.join(dir_out, 'names-kitti-' + now_time + '.json')
|
||||
path_train = os.path.join('splits', 'kitti_train.txt')
|
||||
path_val = os.path.join('splits', 'kitti_val.txt')
|
||||
|
||||
# Split training and validation images
|
||||
set_gt = set(self.names_gt)
|
||||
set_train = set()
|
||||
set_val = set()
|
||||
|
||||
with open(path_train, "r") as f_train:
|
||||
for line in f_train:
|
||||
set_train.add(line[:-1] + '.txt')
|
||||
with open(path_val, "r") as f_val:
|
||||
for line in f_val:
|
||||
set_val.add(line[:-1] + '.txt')
|
||||
|
||||
self.set_train = set_gt.intersection(set_train)
|
||||
self.set_val = set_gt.intersection(set_val)
|
||||
assert self.set_train and self.set_val, "No validation or training annotations"
|
||||
|
||||
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
|
||||
|
||||
def run(self):
|
||||
|
||||
|
||||
@ -123,3 +123,23 @@ def get_category(box, trunc, occ):
|
||||
cat = 'hard'
|
||||
|
||||
return cat
|
||||
|
||||
|
||||
def split_training(names_gt, path_train, path_val):
|
||||
"""Split training and validation images"""
|
||||
set_gt = set(names_gt)
|
||||
set_train = set()
|
||||
set_val = set()
|
||||
|
||||
with open(path_train, "r") as f_train:
|
||||
for line in f_train:
|
||||
set_train.add(line[:-1] + '.txt')
|
||||
with open(path_val, "r") as f_val:
|
||||
for line in f_val:
|
||||
set_val.add(line[:-1] + '.txt')
|
||||
|
||||
set_train = set_gt.intersection(set_train)
|
||||
set_val = set_gt.intersection(set_val)
|
||||
assert set_train and set_val, "No validation or training annotations"
|
||||
return set_train, set_val
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user