remove class imports
This commit is contained in:
parent
755688818f
commit
3c6c305606
@ -1,4 +1,5 @@
|
|||||||
"""Preprocess annnotations with KITTI ground-truth"""
|
"""Preprocess annnotations with KITTI ground-truth"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
@ -6,11 +7,25 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
|
from utils.kitti import get_calibration, check_conditions
|
||||||
|
from utils.pifpaf import get_input_data, preprocess_pif
|
||||||
|
from utils.misc import get_idx_max, append_cluster
|
||||||
|
|
||||||
|
|
||||||
class PreprocessKitti:
|
class PreprocessKitti:
|
||||||
"""Prepare arrays with same format as nuScenes preprocessing but using ground truth txt files"""
|
"""Prepare arrays with same format as nuScenes preprocessing but using ground truth txt files"""
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
dic_jo = {'train': dict(X=[], Y=[], names=[], kps=[], K=[],
|
||||||
|
clst=defaultdict(lambda: defaultdict(list))),
|
||||||
|
'val': dict(X=[], Y=[], names=[], kps=[], K=[],
|
||||||
|
clst=defaultdict(lambda: defaultdict(list))),
|
||||||
|
'test': dict(X=[], Y=[], names=[], kps=[], K=[],
|
||||||
|
clst=defaultdict(lambda: defaultdict(list)))}
|
||||||
|
dic_names = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
def __init__(self, dir_ann, iou_thresh=0.3):
|
def __init__(self, dir_ann, iou_thresh=0.3):
|
||||||
|
|
||||||
self.dir_ann = dir_ann
|
self.dir_ann = dir_ann
|
||||||
@ -30,26 +45,6 @@ class PreprocessKitti:
|
|||||||
path_train = os.path.join('splits', 'kitti_train.txt')
|
path_train = os.path.join('splits', 'kitti_train.txt')
|
||||||
path_val = os.path.join('splits', 'kitti_val.txt')
|
path_val = os.path.join('splits', 'kitti_val.txt')
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from utils.kitti import get_calibration, check_conditions
|
|
||||||
self.get_calibration = get_calibration
|
|
||||||
self.check_conditions = check_conditions
|
|
||||||
|
|
||||||
from utils.pifpaf import get_input_data, preprocess_pif
|
|
||||||
self.get_input_data = get_input_data
|
|
||||||
self.preprocess_pif = preprocess_pif
|
|
||||||
|
|
||||||
from utils.misc import get_idx_max, append_cluster
|
|
||||||
self.get_idx_max = get_idx_max
|
|
||||||
self.append_cluster = append_cluster
|
|
||||||
|
|
||||||
# self.clusters = ['all', '6', '10', '15', '20', '25', '30', '40', '50', '>50'
|
|
||||||
self.cnt_gt = 0
|
|
||||||
self.cnt_fnf = 0
|
|
||||||
self.dic_cnt = {'train': 0, 'val': 0, 'test': 0}
|
|
||||||
|
|
||||||
# Split training and validation images
|
# Split training and validation images
|
||||||
set_gt = set(self.names_gt)
|
set_gt = set(self.names_gt)
|
||||||
set_train = set()
|
set_train = set()
|
||||||
@ -66,17 +61,13 @@ class PreprocessKitti:
|
|||||||
self.set_val = set_gt.intersection(set_val)
|
self.set_val = set_gt.intersection(set_val)
|
||||||
assert self.set_train and self.set_val, "No validation or training annotations"
|
assert self.set_train and self.set_val, "No validation or training annotations"
|
||||||
|
|
||||||
self.dic_jo = {'train': dict(X=[], Y=[], names=[], kps=[], K=[],
|
|
||||||
clst=defaultdict(lambda: defaultdict(list))),
|
|
||||||
'val': dict(X=[], Y=[], names=[], kps=[], K=[],
|
|
||||||
clst=defaultdict(lambda: defaultdict(list))),
|
|
||||||
'test': dict(X=[], Y=[], names=[], kps=[], K=[],
|
|
||||||
clst=defaultdict(lambda: defaultdict(list)))}
|
|
||||||
|
|
||||||
self.dic_names = defaultdict(lambda: defaultdict(list))
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
cnt_gt = 0
|
||||||
|
cnt_fnf = 0
|
||||||
|
dic_cnt = {'train': 0, 'val': 0, 'test': 0}
|
||||||
|
|
||||||
for name in self.names_gt:
|
for name in self.names_gt:
|
||||||
# Extract ground truth
|
# Extract ground truth
|
||||||
path_gt = os.path.join(self.dir_gt, name)
|
path_gt = os.path.join(self.dir_gt, name)
|
||||||
@ -89,17 +80,17 @@ class PreprocessKitti:
|
|||||||
elif name in self.set_val:
|
elif name in self.set_val:
|
||||||
phase = 'val'
|
phase = 'val'
|
||||||
else:
|
else:
|
||||||
self.cnt_fnf += 1
|
cnt_fnf += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract keypoints
|
# Extract keypoints
|
||||||
path_txt = os.path.join(self.dir_kk, basename + '.txt')
|
path_txt = os.path.join(self.dir_kk, basename + '.txt')
|
||||||
kk, tt = self.get_calibration(path_txt)
|
kk, tt = get_calibration(path_txt)
|
||||||
|
|
||||||
# Iterate over each line of the gt file and save box location and distances
|
# Iterate over each line of the gt file and save box location and distances
|
||||||
with open(path_gt, "r") as f_gt:
|
with open(path_gt, "r") as f_gt:
|
||||||
for line_gt in f_gt:
|
for line_gt in f_gt:
|
||||||
if self.check_conditions(line_gt, mode='gt'):
|
if check_conditions(line_gt, mode='gt'):
|
||||||
box = [float(x) for x in line_gt.split()[4:8]]
|
box = [float(x) for x in line_gt.split()[4:8]]
|
||||||
boxes_gt.append(box)
|
boxes_gt.append(box)
|
||||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||||
@ -108,21 +99,21 @@ class PreprocessKitti:
|
|||||||
self.dic_names[basename + '.png']['boxes'].append(box)
|
self.dic_names[basename + '.png']['boxes'].append(box)
|
||||||
self.dic_names[basename + '.png']['dds'].append(dd)
|
self.dic_names[basename + '.png']['dds'].append(dd)
|
||||||
self.dic_names[basename + '.png']['K'] = kk.tolist()
|
self.dic_names[basename + '.png']['K'] = kk.tolist()
|
||||||
self.cnt_gt += 1
|
cnt_gt += 1
|
||||||
|
|
||||||
# Find the annotations if exists
|
# Find the annotations if exists
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(self.dir_ann, basename + '.png.pifpaf.json'), 'r') as f:
|
with open(os.path.join(self.dir_ann, basename + '.png.pifpaf.json'), 'r') as f:
|
||||||
annotations = json.load(f)
|
annotations = json.load(f)
|
||||||
boxes, keypoints = self.preprocess_pif(annotations)
|
boxes, keypoints = preprocess_pif(annotations)
|
||||||
(inputs, _), (uv_kps, uv_boxes, _, _) = self.get_input_data(boxes, keypoints, kk)
|
(inputs, _), (uv_kps, uv_boxes, _, _) = get_input_data(boxes, keypoints, kk)
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
uv_boxes = []
|
uv_boxes = []
|
||||||
|
|
||||||
# Match each set of keypoint with a ground truth
|
# Match each set of keypoint with a ground truth
|
||||||
for ii, box in enumerate(uv_boxes):
|
for ii, box in enumerate(uv_boxes):
|
||||||
idx_max, iou_max = self.get_idx_max(box, boxes_gt)
|
idx_max, iou_max = get_idx_max(box, boxes_gt)
|
||||||
|
|
||||||
if iou_max >= self.iou_thresh:
|
if iou_max >= self.iou_thresh:
|
||||||
|
|
||||||
@ -131,8 +122,8 @@ class PreprocessKitti:
|
|||||||
self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1)
|
self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1)
|
||||||
self.dic_jo[phase]['K'] = kk.tolist()
|
self.dic_jo[phase]['K'] = kk.tolist()
|
||||||
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
|
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
|
||||||
self.append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii])
|
append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii])
|
||||||
self.dic_cnt[phase] += 1
|
dic_cnt[phase] += 1
|
||||||
boxes_gt.pop(idx_max)
|
boxes_gt.pop(idx_max)
|
||||||
dds.pop(idx_max)
|
dds.pop(idx_max)
|
||||||
|
|
||||||
@ -142,9 +133,9 @@ class PreprocessKitti:
|
|||||||
json.dump(self.dic_names, file)
|
json.dump(self.dic_names, file)
|
||||||
for phase in ['train', 'val', 'test']:
|
for phase in ['train', 'val', 'test']:
|
||||||
print("Saved {} annotations for phase {}"
|
print("Saved {} annotations for phase {}"
|
||||||
.format(self.dic_cnt[phase], phase))
|
.format(dic_cnt[phase], phase))
|
||||||
print("Number of GT files: {}. Files not found: {}"
|
print("Number of GT files: {}. Files not found: {}"
|
||||||
.format(self.cnt_gt, self.cnt_fnf))
|
.format(cnt_gt, cnt_fnf))
|
||||||
print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints))
|
print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from utils.nuscenes import select_categories
|
|||||||
from utils.camera import project_3d
|
from utils.camera import project_3d
|
||||||
from utils.pifpaf import get_input_data, preprocess_pif
|
from utils.pifpaf import get_input_data, preprocess_pif
|
||||||
|
|
||||||
|
|
||||||
class PreprocessNuscenes:
|
class PreprocessNuscenes:
|
||||||
"""
|
"""
|
||||||
Preprocess Nuscenes dataset
|
Preprocess Nuscenes dataset
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user