From 086a4d909aaba6ab573175877eec7b26a9511176 Mon Sep 17 00:00:00 2001 From: lorenzo Date: Tue, 28 May 2019 08:52:10 +0200 Subject: [PATCH] refactor class --- src/eval/run_kitti.py | 53 ++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/src/eval/run_kitti.py b/src/eval/run_kitti.py index d7b249d..1387a0f 100644 --- a/src/eval/run_kitti.py +++ b/src/eval/run_kitti.py @@ -6,22 +6,29 @@ import os import glob import json import logging -from models.architectures import TriLinear, LinearModel +from models.architectures import LinearModel +from utils.misc import laplace_sampling +from utils.kitti import eval_geometric, get_calibration +from utils.normalize import unnormalize_bi +from utils.pifpaf import get_input_data, preprocess_pif class RunKitti: - def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + cnt_ann = 0 + cnt_file = 0 + cnt_no_file = 0 + average_y = 0.48 + n_samples = 100 - logging.basicConfig(level=logging.INFO) - self.logger = logging.getLogger(__name__) + def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout, stereo): # Set directories assert dir_ann, "Annotations folder is required" self.dir_ann = dir_ann - self.average_y = 0.48 self.n_dropout = n_dropout - self.n_samples = 100 list_ann = glob.glob(os.path.join(dir_ann, '*.json')) self.dir_kk = os.path.join('data', 'kitti', 'calib') @@ -43,43 +50,21 @@ class RunKitti: self.model.eval() # Default is train self.model.to(self.device) - # Import - from utils.misc import epistemic_variance, laplace_sampling - self.epistemic_variance = epistemic_variance - self.laplace_sampling = laplace_sampling - from visuals.printer import Printer - self.Printer = Printer - from utils.kitti import eval_geometric, get_calibration - self.eval_geometric = eval_geometric - self.get_calibration = get_calibration - - from utils.normalize import unnormalize_bi - self.unnormalize_bi = unnormalize_bi - - from utils.pifpaf import get_input_data, preprocess_pif - self.get_input_data = get_input_data - self.preprocess_pif = preprocess_pif - - # Counters - self.cnt_ann = 0 - self.cnt_file = 0 - self.cnt_no_file = 0 - def run(self): # Run inference for basename in self.list_basename: path_calib = os.path.join(self.dir_kk, basename + '.txt') - kk, tt = self.get_calibration(path_calib) + kk, tt = get_calibration(path_calib) path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json') with open(path_ann, 'r') as f: annotations = json.load(f) - boxes, keypoints = self.preprocess_pif(annotations) - (inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = self.get_input_data(boxes, keypoints, kk) + boxes, keypoints = preprocess_pif(annotations) + (inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk) - dds_geom, xy_centers = self.eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48) + dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48) self.cnt_ann += len(boxes) @@ -96,8 +81,8 @@ class RunKitti: self.model.dropout.training = True for ii in range(self.n_dropout): outputs = self.model(inputs) - outputs = self.unnormalize_bi(outputs) - samples = self.laplace_sampling(outputs, self.n_samples) + outputs = unnormalize_bi(outputs) + samples = laplace_sampling(outputs, self.n_samples) total_outputs = torch.cat((total_outputs, samples), 0) varss = total_outputs.std(0)