pylint refactor (1)
This commit is contained in:
parent
02ac6626d6
commit
eeea8945fb
@ -58,6 +58,7 @@ def cli():
|
||||
# 2) Monoloco argument
|
||||
predict_parser.add_argument('--model', help='path of MonoLoco model to load',
|
||||
default="data/models/best_model__seed_2_.pickle")
|
||||
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256)
|
||||
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
|
||||
default='data/arrays/names-kitti-190513-1754.json')
|
||||
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')
|
||||
|
||||
@ -4,7 +4,6 @@ import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from openpifpaf.network import nets
|
||||
from openpifpaf import decoder, show
|
||||
@ -47,7 +46,7 @@ class ImageList(torch.utils.data.Dataset):
|
||||
return len(self.image_paths)
|
||||
|
||||
|
||||
def elaborate_cli(args):
|
||||
def factory_from_args(args):
|
||||
|
||||
# Merge the model_pifpaf argument
|
||||
if not args.checkpoint:
|
||||
@ -77,7 +76,7 @@ def elaborate_cli(args):
|
||||
|
||||
def predict(args):
|
||||
|
||||
elaborate_cli(args)
|
||||
factory_from_args(args)
|
||||
|
||||
# load model
|
||||
model, _ = nets.factory_from_args(args)
|
||||
|
||||
@ -3,36 +3,33 @@
|
||||
From a json file output images and json annotations
|
||||
"""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from models.architectures import LinearModel
|
||||
from utils.camera import preprocess_single, get_keypoints, get_depth
|
||||
from utils.misc import epistemic_variance, laplace_sampling, get_idx_max
|
||||
from visuals.printer import Printer
|
||||
from utils.normalize import unnormalize_bi
|
||||
from utils.kitti import get_simplified_calibration, get_calibration
|
||||
from utils.pifpaf import get_input_data
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class PredictMonoLoco:
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
output_size = 2
|
||||
input_size = 17 * 2
|
||||
|
||||
def __init__(self, boxes, keypoints, image_path, output_path, args):
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
basename, _ = os.path.splitext(os.path.basename(image_path))
|
||||
|
||||
# Check for ground-truth file
|
||||
try:
|
||||
with open(args.path_gt, 'r') as f:
|
||||
self.dic_names = json.load(f)
|
||||
except FileNotFoundError:
|
||||
self.dic_names = None
|
||||
|
||||
print('-' * 120)
|
||||
print("Monoloco: ground truth file not found")
|
||||
print('-' * 120)
|
||||
|
||||
self.boxes = boxes
|
||||
self.keypoints = keypoints
|
||||
self.image_path = image_path
|
||||
@ -43,7 +40,6 @@ class PredictMonoLoco:
|
||||
self.z_max = args.z_max
|
||||
self.output_types = args.output_types
|
||||
self.show = args.show
|
||||
output_size = 2
|
||||
self.n_samples = 100
|
||||
self.n_dropout = args.n_dropout
|
||||
if self.n_dropout > 0:
|
||||
@ -52,43 +48,23 @@ class PredictMonoLoco:
|
||||
self.epistemic = False
|
||||
self.iou_min = 0.25
|
||||
|
||||
# load the model
|
||||
input_size = 17 * 2
|
||||
|
||||
# self.model = TriLinear(input_size=input_size, output_size=output_size, p_dropout=dropout)
|
||||
self.model = LinearModel(input_size=input_size, output_size=output_size)
|
||||
# load the model parameters
|
||||
self.model = LinearModel(input_size=self.input_size, output_size=self.output_size, linear_size=args.hidden_size)
|
||||
self.model.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
|
||||
self.model.eval() # Default is train
|
||||
self.model.to(self.device)
|
||||
|
||||
# Import
|
||||
from utils.camera import preprocess_single, get_keypoints, get_depth
|
||||
self.preprocess_single = preprocess_single
|
||||
self.get_keypoints = get_keypoints
|
||||
self.get_depth = get_depth
|
||||
|
||||
from utils.misc import epistemic_variance, laplace_sampling, get_idx_max
|
||||
self.epistemic_variance = epistemic_variance
|
||||
self.laplace_sampling = laplace_sampling
|
||||
self.get_idx_max = get_idx_max
|
||||
from visuals.printer import Printer
|
||||
self.Printer = Printer
|
||||
|
||||
from utils.normalize import unnormalize_bi
|
||||
self.unnormalize_bi = unnormalize_bi
|
||||
|
||||
from utils.kitti import get_simplified_calibration, get_calibration
|
||||
self.get_simplified_calibration = get_simplified_calibration
|
||||
self.get_calibration = get_calibration
|
||||
|
||||
from utils.pifpaf import get_input_data
|
||||
self.get_input_data = get_input_data
|
||||
# Check for ground-truth file
|
||||
try:
|
||||
with open(args.path_gt, 'r') as f:
|
||||
self.dic_names = json.load(f)
|
||||
except FileNotFoundError:
|
||||
self.dic_names = None
|
||||
print('-' * 120 + "\nMonoloco: ground truth file not found\n" + '-' * 120)
|
||||
|
||||
def run(self):
|
||||
|
||||
cnt = 0
|
||||
|
||||
# Extract calibration matrix if ground-truth file is present or use a default one
|
||||
cnt = 0
|
||||
name = os.path.basename(self.image_path)
|
||||
if self.dic_names:
|
||||
kk = self.dic_names[name]['K']
|
||||
@ -97,7 +73,7 @@ class PredictMonoLoco:
|
||||
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]
|
||||
|
||||
(inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \
|
||||
self.get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
|
||||
get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
|
||||
|
||||
# Conversion into torch tensor
|
||||
if len(inputs_norm) > 0:
|
||||
@ -113,8 +89,8 @@ class PredictMonoLoco:
|
||||
if self.n_dropout > 0:
|
||||
for ii in range(self.n_dropout):
|
||||
outputs = self.model(inputs)
|
||||
outputs = self.unnormalize_bi(outputs)
|
||||
samples = self.laplace_sampling(outputs, self.n_samples)
|
||||
outputs = unnormalize_bi(outputs)
|
||||
samples = laplace_sampling(outputs, self.n_samples)
|
||||
total_outputs = torch.cat((total_outputs, samples), 0)
|
||||
varss = total_outputs.std(0)
|
||||
else:
|
||||
@ -124,7 +100,7 @@ class PredictMonoLoco:
|
||||
start_single = time.time()
|
||||
self.model.dropout.training = False
|
||||
outputs = self.model(inputs)
|
||||
outputs = self.unnormalize_bi(outputs)
|
||||
outputs = unnormalize_bi(outputs)
|
||||
end = time.time()
|
||||
print("Total Forward pass time = {:.2f} ms".format((end-start) * 1000))
|
||||
print("Single pass time = {:.2f} ms".format((end - start_single) * 1000))
|
||||
@ -141,13 +117,11 @@ class PredictMonoLoco:
|
||||
|
||||
# Find the corresponding ground truth if available
|
||||
if self.dic_names:
|
||||
|
||||
idx_max, iou_max = self.get_idx_max(box, boxes_gt)
|
||||
idx_max, iou_max = get_idx_max(box, boxes_gt)
|
||||
if iou_max > self.iou_min:
|
||||
dd_real = dds_gt[idx_max]
|
||||
boxes_gt.pop(idx_max)
|
||||
dds_gt.pop(idx_max)
|
||||
|
||||
# In case of no matching
|
||||
else:
|
||||
dd_real = 0
|
||||
@ -156,9 +130,8 @@ class PredictMonoLoco:
|
||||
dd_real = dd_pred
|
||||
|
||||
uv_center = uv_centers[idx]
|
||||
xyz_real = self.get_depth(uv_center, kk, dd_real)
|
||||
xyz_pred = self.get_depth(uv_center, kk, dd_pred)
|
||||
|
||||
xyz_real = get_depth(uv_center, kk, dd_real)
|
||||
xyz_pred = get_depth(uv_center, kk, dd_pred)
|
||||
dic_out['boxes'].append(box)
|
||||
dic_out['dds_real'].append(dd_real)
|
||||
dic_out['dds_pred'].append(dd_pred)
|
||||
@ -172,14 +145,12 @@ class PredictMonoLoco:
|
||||
dic_out['uv_shoulders'].append(uv_shoulders[idx])
|
||||
|
||||
if any((xx in self.output_types for xx in ['front', 'bird', 'combined'])):
|
||||
|
||||
printer = self.Printer(self.image_path, self.output_path, dic_out, kk,
|
||||
y_scale=self.y_scale, output_types=self.output_types,
|
||||
show=self.show, z_max=self.z_max, epistemic=self.epistemic)
|
||||
printer = Printer(self.image_path, self.output_path, dic_out, kk,
|
||||
y_scale=self.y_scale, output_types=self.output_types,
|
||||
show=self.show, z_max=self.z_max, epistemic=self.epistemic)
|
||||
printer.print()
|
||||
|
||||
if 'json' in self.output_types:
|
||||
|
||||
with open(os.path.join(self.output_path + '.monoloco.json'), 'w') as ff:
|
||||
json.dump(dic_out, ff)
|
||||
|
||||
|
||||
@ -90,12 +90,6 @@ def preprocess_single(kps, kk):
|
||||
kps_orig.append(float(kp_orig[0]))
|
||||
kps_orig.append(float(kp_orig[1]))
|
||||
|
||||
|
||||
# Append the y of the ground foot to the keypoints
|
||||
# kp_gr = np.array([0, vv_gr, 1])
|
||||
# xy1_gr = pixel_to_camera(kp_gr, kk, 1)
|
||||
# kps_0c.append(float(xy1_gr[1]))
|
||||
|
||||
return kps_0c, kps_orig
|
||||
|
||||
|
||||
@ -198,5 +192,4 @@ def get_depth(uv_center, kk, dd):
|
||||
zz = dd / math.sqrt(1 + xyz_norm[0] ** 2 + xyz_norm[1] ** 2)
|
||||
|
||||
xyz = pixel_to_camera(uv_center_np, kk, zz).tolist()
|
||||
|
||||
return xyz
|
||||
return xyz
|
||||
|
||||
Loading…
Reference in New Issue
Block a user