refactor class

This commit is contained in:
lorenzo 2019-05-28 08:52:10 +02:00
parent ad252e76e7
commit 086a4d909a

View File

@ -6,22 +6,29 @@ import os
import glob
import json
import logging
from models.architectures import TriLinear, LinearModel
from models.architectures import LinearModel
from utils.misc import laplace_sampling
from utils.kitti import eval_geometric, get_calibration
from utils.normalize import unnormalize_bi
from utils.pifpaf import get_input_data, preprocess_pif
class RunKitti:
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
cnt_ann = 0
cnt_file = 0
cnt_no_file = 0
average_y = 0.48
n_samples = 100
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout, stereo):
# Set directories
assert dir_ann, "Annotations folder is required"
self.dir_ann = dir_ann
self.average_y = 0.48
self.n_dropout = n_dropout
self.n_samples = 100
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
self.dir_kk = os.path.join('data', 'kitti', 'calib')
@ -43,43 +50,21 @@ class RunKitti:
self.model.eval() # Default is train
self.model.to(self.device)
# Import
from utils.misc import epistemic_variance, laplace_sampling
self.epistemic_variance = epistemic_variance
self.laplace_sampling = laplace_sampling
from visuals.printer import Printer
self.Printer = Printer
from utils.kitti import eval_geometric, get_calibration
self.eval_geometric = eval_geometric
self.get_calibration = get_calibration
from utils.normalize import unnormalize_bi
self.unnormalize_bi = unnormalize_bi
from utils.pifpaf import get_input_data, preprocess_pif
self.get_input_data = get_input_data
self.preprocess_pif = preprocess_pif
# Counters
self.cnt_ann = 0
self.cnt_file = 0
self.cnt_no_file = 0
def run(self):
# Run inference
for basename in self.list_basename:
path_calib = os.path.join(self.dir_kk, basename + '.txt')
kk, tt = self.get_calibration(path_calib)
kk, tt = get_calibration(path_calib)
path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json')
with open(path_ann, 'r') as f:
annotations = json.load(f)
boxes, keypoints = self.preprocess_pif(annotations)
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = self.get_input_data(boxes, keypoints, kk)
boxes, keypoints = preprocess_pif(annotations)
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk)
dds_geom, xy_centers = self.eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
self.cnt_ann += len(boxes)
@ -96,8 +81,8 @@ class RunKitti:
self.model.dropout.training = True
for ii in range(self.n_dropout):
outputs = self.model(inputs)
outputs = self.unnormalize_bi(outputs)
samples = self.laplace_sampling(outputs, self.n_samples)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.n_samples)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)