add factory method
This commit is contained in:
parent
496e147c2a
commit
6f3379e394
@ -1,13 +1,13 @@
|
||||
"""Preprocess annnotations with KITTI ground-truth"""
|
||||
"""Preprocess annotations with KITTI ground-truth"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import math
|
||||
import copy
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import datetime
|
||||
from utils.kitti import get_calibration, check_conditions, split_training
|
||||
from utils.kitti import get_calibration, split_training, parse_ground_truth
|
||||
from utils.pifpaf import get_input_data, preprocess_pif
|
||||
from utils.misc import get_idx_max, append_cluster
|
||||
|
||||
@ -47,23 +47,18 @@ class PreprocessKitti:
|
||||
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
|
||||
|
||||
def run(self):
|
||||
"""Save json files"""
|
||||
|
||||
cnt_gt = 0
|
||||
cnt_fnf = 0
|
||||
dic_cnt = {'train': 0, 'val': 0, 'test': 0}
|
||||
|
||||
for name in self.names_gt:
|
||||
# Extract ground truth
|
||||
path_gt = os.path.join(self.dir_gt, name)
|
||||
basename, _ = os.path.splitext(name)
|
||||
boxes_gt = []
|
||||
dds = []
|
||||
|
||||
if name in self.set_train:
|
||||
phase = 'train'
|
||||
elif name in self.set_val:
|
||||
phase = 'val'
|
||||
else:
|
||||
phase, flag = self._factory_phase(name)
|
||||
if flag:
|
||||
cnt_fnf += 1
|
||||
continue
|
||||
|
||||
@ -72,18 +67,11 @@ class PreprocessKitti:
|
||||
kk, tt = get_calibration(path_txt)
|
||||
|
||||
# Iterate over each line of the gt file and save box location and distances
|
||||
with open(path_gt, "r") as f_gt:
|
||||
for line_gt in f_gt:
|
||||
if check_conditions(line_gt, mode='gt'):
|
||||
box = [float(x) for x in line_gt.split()[4:8]]
|
||||
boxes_gt.append(box)
|
||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||
dd = math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)
|
||||
dds.append(dd)
|
||||
self.dic_names[basename + '.png']['boxes'].append(box)
|
||||
self.dic_names[basename + '.png']['dds'].append(dd)
|
||||
self.dic_names[basename + '.png']['K'] = kk.tolist()
|
||||
cnt_gt += 1
|
||||
boxes_gt, dds_gt, _, _ = parse_ground_truth(path_gt)
|
||||
self.dic_names[basename + '.png']['boxes'] = copy.deepcopy(boxes_gt)
|
||||
self.dic_names[basename + '.png']['dds'] = copy.deepcopy(dds_gt)
|
||||
self.dic_names[basename + '.png']['K'] = copy.deepcopy(kk.tolist())
|
||||
cnt_gt += len(boxes_gt)
|
||||
|
||||
# Find the annotations if exists
|
||||
try:
|
||||
@ -103,13 +91,13 @@ class PreprocessKitti:
|
||||
|
||||
self.dic_jo[phase]['kps'].append(uv_kps[ii])
|
||||
self.dic_jo[phase]['X'].append(inputs[ii])
|
||||
self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1)
|
||||
self.dic_jo[phase]['Y'].append([dds_gt[idx_max]]) # Trick to make it (nn,1)
|
||||
self.dic_jo[phase]['K'] = kk.tolist()
|
||||
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
|
||||
append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii])
|
||||
append_cluster(self.dic_jo, phase, inputs[ii], dds_gt[idx_max], uv_kps[ii])
|
||||
dic_cnt[phase] += 1
|
||||
boxes_gt.pop(idx_max)
|
||||
dds.pop(idx_max)
|
||||
dds_gt.pop(idx_max)
|
||||
|
||||
with open(self.path_joints, 'w') as file:
|
||||
json.dump(self.dic_jo, file)
|
||||
@ -122,5 +110,16 @@ class PreprocessKitti:
|
||||
.format(cnt_gt, cnt_fnf))
|
||||
print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints))
|
||||
|
||||
def _factory_phase(self, name):
|
||||
"""Choose the phase"""
|
||||
|
||||
phase = None
|
||||
flag = False
|
||||
if name in self.set_train:
|
||||
phase = 'train'
|
||||
elif name in self.set_val:
|
||||
phase = 'val'
|
||||
else:
|
||||
flag = True
|
||||
return phase, flag
|
||||
|
||||
|
||||
@ -161,5 +161,4 @@ def parse_ground_truth(path_gt):
|
||||
loc_gt = [float(x) for x in line_gt.split()[11:14]]
|
||||
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
|
||||
|
||||
return boxes_gt, dds_gt, truncs_gt, occs_gt
|
||||
|
||||
return boxes_gt, dds_gt, truncs_gt, occs_gt
|
||||
@ -1,6 +1,5 @@
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Ellipse
|
||||
|
||||
Loading…
Reference in New Issue
Block a user