refactor class
This commit is contained in:
parent
ad252e76e7
commit
086a4d909a
@ -6,22 +6,29 @@ import os
|
|||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from models.architectures import TriLinear, LinearModel
|
from models.architectures import LinearModel
|
||||||
|
from utils.misc import laplace_sampling
|
||||||
|
from utils.kitti import eval_geometric, get_calibration
|
||||||
|
from utils.normalize import unnormalize_bi
|
||||||
|
from utils.pifpaf import get_input_data, preprocess_pif
|
||||||
|
|
||||||
|
|
||||||
class RunKitti:
|
class RunKitti:
|
||||||
|
|
||||||
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout):
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
cnt_ann = 0
|
||||||
|
cnt_file = 0
|
||||||
|
cnt_no_file = 0
|
||||||
|
average_y = 0.48
|
||||||
|
n_samples = 100
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout, stereo):
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Set directories
|
# Set directories
|
||||||
assert dir_ann, "Annotations folder is required"
|
assert dir_ann, "Annotations folder is required"
|
||||||
self.dir_ann = dir_ann
|
self.dir_ann = dir_ann
|
||||||
self.average_y = 0.48
|
|
||||||
self.n_dropout = n_dropout
|
self.n_dropout = n_dropout
|
||||||
self.n_samples = 100
|
|
||||||
|
|
||||||
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
|
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
|
||||||
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
||||||
@ -43,43 +50,21 @@ class RunKitti:
|
|||||||
self.model.eval() # Default is train
|
self.model.eval() # Default is train
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
# Import
|
|
||||||
from utils.misc import epistemic_variance, laplace_sampling
|
|
||||||
self.epistemic_variance = epistemic_variance
|
|
||||||
self.laplace_sampling = laplace_sampling
|
|
||||||
from visuals.printer import Printer
|
|
||||||
self.Printer = Printer
|
|
||||||
from utils.kitti import eval_geometric, get_calibration
|
|
||||||
self.eval_geometric = eval_geometric
|
|
||||||
self.get_calibration = get_calibration
|
|
||||||
|
|
||||||
from utils.normalize import unnormalize_bi
|
|
||||||
self.unnormalize_bi = unnormalize_bi
|
|
||||||
|
|
||||||
from utils.pifpaf import get_input_data, preprocess_pif
|
|
||||||
self.get_input_data = get_input_data
|
|
||||||
self.preprocess_pif = preprocess_pif
|
|
||||||
|
|
||||||
# Counters
|
|
||||||
self.cnt_ann = 0
|
|
||||||
self.cnt_file = 0
|
|
||||||
self.cnt_no_file = 0
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
for basename in self.list_basename:
|
for basename in self.list_basename:
|
||||||
path_calib = os.path.join(self.dir_kk, basename + '.txt')
|
path_calib = os.path.join(self.dir_kk, basename + '.txt')
|
||||||
kk, tt = self.get_calibration(path_calib)
|
kk, tt = get_calibration(path_calib)
|
||||||
|
|
||||||
path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json')
|
path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json')
|
||||||
with open(path_ann, 'r') as f:
|
with open(path_ann, 'r') as f:
|
||||||
annotations = json.load(f)
|
annotations = json.load(f)
|
||||||
|
|
||||||
boxes, keypoints = self.preprocess_pif(annotations)
|
boxes, keypoints = preprocess_pif(annotations)
|
||||||
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = self.get_input_data(boxes, keypoints, kk)
|
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk)
|
||||||
|
|
||||||
dds_geom, xy_centers = self.eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
|
dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
|
||||||
|
|
||||||
self.cnt_ann += len(boxes)
|
self.cnt_ann += len(boxes)
|
||||||
|
|
||||||
@ -96,8 +81,8 @@ class RunKitti:
|
|||||||
self.model.dropout.training = True
|
self.model.dropout.training = True
|
||||||
for ii in range(self.n_dropout):
|
for ii in range(self.n_dropout):
|
||||||
outputs = self.model(inputs)
|
outputs = self.model(inputs)
|
||||||
outputs = self.unnormalize_bi(outputs)
|
outputs = unnormalize_bi(outputs)
|
||||||
samples = self.laplace_sampling(outputs, self.n_samples)
|
samples = laplace_sampling(outputs, self.n_samples)
|
||||||
total_outputs = torch.cat((total_outputs, samples), 0)
|
total_outputs = torch.cat((total_outputs, samples), 0)
|
||||||
varss = total_outputs.std(0)
|
varss = total_outputs.std(0)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user