refactor class
This commit is contained in:
parent
ad252e76e7
commit
086a4d909a
@ -6,22 +6,29 @@ import os
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
from models.architectures import TriLinear, LinearModel
|
||||
from models.architectures import LinearModel
|
||||
from utils.misc import laplace_sampling
|
||||
from utils.kitti import eval_geometric, get_calibration
|
||||
from utils.normalize import unnormalize_bi
|
||||
from utils.pifpaf import get_input_data, preprocess_pif
|
||||
|
||||
|
||||
class RunKitti:
|
||||
|
||||
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
cnt_ann = 0
|
||||
cnt_file = 0
|
||||
cnt_no_file = 0
|
||||
average_y = 0.48
|
||||
n_samples = 100
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout, stereo):
|
||||
|
||||
# Set directories
|
||||
assert dir_ann, "Annotations folder is required"
|
||||
self.dir_ann = dir_ann
|
||||
self.average_y = 0.48
|
||||
self.n_dropout = n_dropout
|
||||
self.n_samples = 100
|
||||
|
||||
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
|
||||
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
||||
@ -43,43 +50,21 @@ class RunKitti:
|
||||
self.model.eval() # Default is train
|
||||
self.model.to(self.device)
|
||||
|
||||
# Import
|
||||
from utils.misc import epistemic_variance, laplace_sampling
|
||||
self.epistemic_variance = epistemic_variance
|
||||
self.laplace_sampling = laplace_sampling
|
||||
from visuals.printer import Printer
|
||||
self.Printer = Printer
|
||||
from utils.kitti import eval_geometric, get_calibration
|
||||
self.eval_geometric = eval_geometric
|
||||
self.get_calibration = get_calibration
|
||||
|
||||
from utils.normalize import unnormalize_bi
|
||||
self.unnormalize_bi = unnormalize_bi
|
||||
|
||||
from utils.pifpaf import get_input_data, preprocess_pif
|
||||
self.get_input_data = get_input_data
|
||||
self.preprocess_pif = preprocess_pif
|
||||
|
||||
# Counters
|
||||
self.cnt_ann = 0
|
||||
self.cnt_file = 0
|
||||
self.cnt_no_file = 0
|
||||
|
||||
def run(self):
|
||||
|
||||
# Run inference
|
||||
for basename in self.list_basename:
|
||||
path_calib = os.path.join(self.dir_kk, basename + '.txt')
|
||||
kk, tt = self.get_calibration(path_calib)
|
||||
kk, tt = get_calibration(path_calib)
|
||||
|
||||
path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json')
|
||||
with open(path_ann, 'r') as f:
|
||||
annotations = json.load(f)
|
||||
|
||||
boxes, keypoints = self.preprocess_pif(annotations)
|
||||
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = self.get_input_data(boxes, keypoints, kk)
|
||||
boxes, keypoints = preprocess_pif(annotations)
|
||||
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk)
|
||||
|
||||
dds_geom, xy_centers = self.eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
|
||||
dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
|
||||
|
||||
self.cnt_ann += len(boxes)
|
||||
|
||||
@ -96,8 +81,8 @@ class RunKitti:
|
||||
self.model.dropout.training = True
|
||||
for ii in range(self.n_dropout):
|
||||
outputs = self.model(inputs)
|
||||
outputs = self.unnormalize_bi(outputs)
|
||||
samples = self.laplace_sampling(outputs, self.n_samples)
|
||||
outputs = unnormalize_bi(outputs)
|
||||
samples = laplace_sampling(outputs, self.n_samples)
|
||||
total_outputs = torch.cat((total_outputs, samples), 0)
|
||||
varss = total_outputs.std(0)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user