add function for split_training
This commit is contained in:
parent
3c6c305606
commit
738095bc8c
@ -7,7 +7,7 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
from utils.kitti import get_calibration, check_conditions
|
from utils.kitti import get_calibration, check_conditions, split_training
|
||||||
from utils.pifpaf import get_input_data, preprocess_pif
|
from utils.pifpaf import get_input_data, preprocess_pif
|
||||||
from utils.misc import get_idx_max, append_cluster
|
from utils.misc import get_idx_max, append_cluster
|
||||||
|
|
||||||
@ -44,23 +44,7 @@ class PreprocessKitti:
|
|||||||
self.path_names = os.path.join(dir_out, 'names-kitti-' + now_time + '.json')
|
self.path_names = os.path.join(dir_out, 'names-kitti-' + now_time + '.json')
|
||||||
path_train = os.path.join('splits', 'kitti_train.txt')
|
path_train = os.path.join('splits', 'kitti_train.txt')
|
||||||
path_val = os.path.join('splits', 'kitti_val.txt')
|
path_val = os.path.join('splits', 'kitti_val.txt')
|
||||||
|
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
|
||||||
# Split training and validation images
|
|
||||||
set_gt = set(self.names_gt)
|
|
||||||
set_train = set()
|
|
||||||
set_val = set()
|
|
||||||
|
|
||||||
with open(path_train, "r") as f_train:
|
|
||||||
for line in f_train:
|
|
||||||
set_train.add(line[:-1] + '.txt')
|
|
||||||
with open(path_val, "r") as f_val:
|
|
||||||
for line in f_val:
|
|
||||||
set_val.add(line[:-1] + '.txt')
|
|
||||||
|
|
||||||
self.set_train = set_gt.intersection(set_train)
|
|
||||||
self.set_val = set_gt.intersection(set_val)
|
|
||||||
assert self.set_train and self.set_val, "No validation or training annotations"
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
|||||||
@ -123,3 +123,23 @@ def get_category(box, trunc, occ):
|
|||||||
cat = 'hard'
|
cat = 'hard'
|
||||||
|
|
||||||
return cat
|
return cat
|
||||||
|
|
||||||
|
|
||||||
|
def split_training(names_gt, path_train, path_val):
|
||||||
|
"""Split training and validation images"""
|
||||||
|
set_gt = set(names_gt)
|
||||||
|
set_train = set()
|
||||||
|
set_val = set()
|
||||||
|
|
||||||
|
with open(path_train, "r") as f_train:
|
||||||
|
for line in f_train:
|
||||||
|
set_train.add(line[:-1] + '.txt')
|
||||||
|
with open(path_val, "r") as f_val:
|
||||||
|
for line in f_val:
|
||||||
|
set_val.add(line[:-1] + '.txt')
|
||||||
|
|
||||||
|
set_train = set_gt.intersection(set_train)
|
||||||
|
set_val = set_gt.intersection(set_val)
|
||||||
|
assert set_train and set_val, "No validation or training annotations"
|
||||||
|
return set_train, set_val
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user