From 6f3379e394a6b5be476538414deef95e3d20551b Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 21 May 2019 14:45:23 +0200 Subject: [PATCH] add factory method --- src/features/preprocess_ki.py | 51 +++++++++++++++++------------------ src/utils/kitti.py | 3 +-- src/visuals/results.py | 1 - 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/features/preprocess_ki.py b/src/features/preprocess_ki.py index 2ae77cc..ce0d99c 100644 --- a/src/features/preprocess_ki.py +++ b/src/features/preprocess_ki.py @@ -1,13 +1,13 @@ -"""Preprocess annnotations with KITTI ground-truth""" +"""Preprocess annotations with KITTI ground-truth""" import os import glob -import math +import copy import logging from collections import defaultdict import json import datetime -from utils.kitti import get_calibration, check_conditions, split_training +from utils.kitti import get_calibration, split_training, parse_ground_truth from utils.pifpaf import get_input_data, preprocess_pif from utils.misc import get_idx_max, append_cluster @@ -47,23 +47,18 @@ class PreprocessKitti: self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val) def run(self): + """Save json files""" cnt_gt = 0 cnt_fnf = 0 dic_cnt = {'train': 0, 'val': 0, 'test': 0} for name in self.names_gt: - # Extract ground truth path_gt = os.path.join(self.dir_gt, name) basename, _ = os.path.splitext(name) - boxes_gt = [] - dds = [] - if name in self.set_train: - phase = 'train' - elif name in self.set_val: - phase = 'val' - else: + phase, flag = self._factory_phase(name) + if flag: cnt_fnf += 1 continue @@ -72,18 +67,11 @@ class PreprocessKitti: kk, tt = get_calibration(path_txt) # Iterate over each line of the gt file and save box location and distances - with open(path_gt, "r") as f_gt: - for line_gt in f_gt: - if check_conditions(line_gt, mode='gt'): - box = [float(x) for x in line_gt.split()[4:8]] - boxes_gt.append(box) - loc_gt = [float(x) for x in line_gt.split()[11:14]] - dd = math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2) - dds.append(dd) - self.dic_names[basename + '.png']['boxes'].append(box) - self.dic_names[basename + '.png']['dds'].append(dd) - self.dic_names[basename + '.png']['K'] = kk.tolist() - cnt_gt += 1 + boxes_gt, dds_gt, _, _ = parse_ground_truth(path_gt) + self.dic_names[basename + '.png']['boxes'] = copy.deepcopy(boxes_gt) + self.dic_names[basename + '.png']['dds'] = copy.deepcopy(dds_gt) + self.dic_names[basename + '.png']['K'] = copy.deepcopy(kk.tolist()) + cnt_gt += len(boxes_gt) # Find the annotations if exists try: @@ -103,13 +91,13 @@ class PreprocessKitti: self.dic_jo[phase]['kps'].append(uv_kps[ii]) self.dic_jo[phase]['X'].append(inputs[ii]) - self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1) + self.dic_jo[phase]['Y'].append([dds_gt[idx_max]]) # Trick to make it (nn,1) self.dic_jo[phase]['K'] = kk.tolist() self.dic_jo[phase]['names'].append(name) # One image name for each annotation - append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii]) + append_cluster(self.dic_jo, phase, inputs[ii], dds_gt[idx_max], uv_kps[ii]) dic_cnt[phase] += 1 boxes_gt.pop(idx_max) - dds.pop(idx_max) + dds_gt.pop(idx_max) with open(self.path_joints, 'w') as file: json.dump(self.dic_jo, file) @@ -122,5 +110,16 @@ class PreprocessKitti: .format(cnt_gt, cnt_fnf)) print("\nOutput files:\n{}\n{}\n".format(self.path_names, self.path_joints)) + def _factory_phase(self, name): + """Choose the phase""" + phase = None + flag = False + if name in self.set_train: + phase = 'train' + elif name in self.set_val: + phase = 'val' + else: + flag = True + return phase, flag diff --git a/src/utils/kitti.py b/src/utils/kitti.py index c9ecb75..bd48992 100644 --- a/src/utils/kitti.py +++ b/src/utils/kitti.py @@ -161,5 +161,4 @@ def parse_ground_truth(path_gt): loc_gt = [float(x) for x in line_gt.split()[11:14]] dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2)) - return boxes_gt, dds_gt, truncs_gt, occs_gt - + return boxes_gt, dds_gt, truncs_gt, occs_gt \ No newline at end of file diff --git a/src/visuals/results.py b/src/visuals/results.py index d21c2c7..d29b9a5 100644 --- a/src/visuals/results.py +++ b/src/visuals/results.py @@ -1,6 +1,5 @@ import os -import time import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Ellipse