add factory method

This commit is contained in:
lorenzo 2019-05-21 14:45:23 +02:00
parent 496e147c2a
commit 6f3379e394
3 changed files with 26 additions and 29 deletions

View File

@ -1,13 +1,13 @@
"""Preprocess annnotations with KITTI ground-truth""" """Preprocess annotations with KITTI ground-truth"""
import os import os
import glob import glob
import math import copy
import logging import logging
from collections import defaultdict from collections import defaultdict
import json import json
import datetime import datetime
from utils.kitti import get_calibration, check_conditions, split_training from utils.kitti import get_calibration, split_training, parse_ground_truth
from utils.pifpaf import get_input_data, preprocess_pif from utils.pifpaf import get_input_data, preprocess_pif
from utils.misc import get_idx_max, append_cluster from utils.misc import get_idx_max, append_cluster
@ -47,23 +47,18 @@ class PreprocessKitti:
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val) self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
def run(self): def run(self):
"""Save json files"""
cnt_gt = 0 cnt_gt = 0
cnt_fnf = 0 cnt_fnf = 0
dic_cnt = {'train': 0, 'val': 0, 'test': 0} dic_cnt = {'train': 0, 'val': 0, 'test': 0}
for name in self.names_gt: for name in self.names_gt:
# Extract ground truth
path_gt = os.path.join(self.dir_gt, name) path_gt = os.path.join(self.dir_gt, name)
basename, _ = os.path.splitext(name) basename, _ = os.path.splitext(name)
boxes_gt = []
dds = []
if name in self.set_train: phase, flag = self._factory_phase(name)
phase = 'train' if flag:
elif name in self.set_val:
phase = 'val'
else:
cnt_fnf += 1 cnt_fnf += 1
continue continue
@ -72,18 +67,11 @@ class PreprocessKitti:
kk, tt = get_calibration(path_txt) kk, tt = get_calibration(path_txt)
# Iterate over each line of the gt file and save box location and distances # Iterate over each line of the gt file and save box location and distances
with open(path_gt, "r") as f_gt: boxes_gt, dds_gt, _, _ = parse_ground_truth(path_gt)
for line_gt in f_gt: self.dic_names[basename + '.png']['boxes'] = copy.deepcopy(boxes_gt)
if check_conditions(line_gt, mode='gt'): self.dic_names[basename + '.png']['dds'] = copy.deepcopy(dds_gt)
box = [float(x) for x in line_gt.split()[4:8]] self.dic_names[basename + '.png']['K'] = copy.deepcopy(kk.tolist())
boxes_gt.append(box) cnt_gt += len(boxes_gt)
loc_gt = [float(x) for x in line_gt.split()[11:14]]
dd = math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)
dds.append(dd)
self.dic_names[basename + '.png']['boxes'].append(box)
self.dic_names[basename + '.png']['dds'].append(dd)
self.dic_names[basename + '.png']['K'] = kk.tolist()
cnt_gt += 1
# Find the annotations if exists # Find the annotations if exists
try: try:
@ -103,13 +91,13 @@ class PreprocessKitti:
self.dic_jo[phase]['kps'].append(uv_kps[ii]) self.dic_jo[phase]['kps'].append(uv_kps[ii])
self.dic_jo[phase]['X'].append(inputs[ii]) self.dic_jo[phase]['X'].append(inputs[ii])
self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1) self.dic_jo[phase]['Y'].append([dds_gt[idx_max]]) # Trick to make it (nn,1)
self.dic_jo[phase]['K'] = kk.tolist() self.dic_jo[phase]['K'] = kk.tolist()
self.dic_jo[phase]['names'].append(name) # One image name for each annotation self.dic_jo[phase]['names'].append(name) # One image name for each annotation
append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii]) append_cluster(self.dic_jo, phase, inputs[ii], dds_gt[idx_max], uv_kps[ii])
dic_cnt[phase] += 1 dic_cnt[phase] += 1
boxes_gt.pop(idx_max) boxes_gt.pop(idx_max)
dds.pop(idx_max) dds_gt.pop(idx_max)
with open(self.path_joints, 'w') as file: with open(self.path_joints, 'w') as file:
json.dump(self.dic_jo, file) json.dump(self.dic_jo, file)
@ -122,5 +110,16 @@ class PreprocessKitti:
.format(cnt_gt, cnt_fnf)) .format(cnt_gt, cnt_fnf))
print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints)) print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints))
def _factory_phase(self, name):
"""Choose the phase"""
phase = None
flag = False
if name in self.set_train:
phase = 'train'
elif name in self.set_val:
phase = 'val'
else:
flag = True
return phase, flag

View File

@ -161,5 +161,4 @@ def parse_ground_truth(path_gt):
loc_gt = [float(x) for x in line_gt.split()[11:14]] loc_gt = [float(x) for x in line_gt.split()[11:14]]
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)) dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
return boxes_gt, dds_gt, truncs_gt, occs_gt return boxes_gt, dds_gt, truncs_gt, occs_gt

View File

@ -1,6 +1,5 @@
import os import os
import time
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse from matplotlib.patches import Ellipse