pylint refactor (1)

This commit is contained in:
lorenzo 2019-05-21 18:34:57 +02:00
parent 02ac6626d6
commit eeea8945fb
4 changed files with 39 additions and 75 deletions

View File

@ -58,6 +58,7 @@ def cli():
# 2) Monoloco argument
predict_parser.add_argument('--model', help='path of MonoLoco model to load',
default="data/models/best_model__seed_2_.pickle")
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=256)
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
default='data/arrays/names-kitti-190513-1754.json')
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')

View File

@ -4,7 +4,6 @@ import json
import os
import numpy as np
import torch
from openpifpaf.network import nets
from openpifpaf import decoder, show
@ -47,7 +46,7 @@ class ImageList(torch.utils.data.Dataset):
return len(self.image_paths)
def elaborate_cli(args):
def factory_from_args(args):
# Merge the model_pifpaf argument
if not args.checkpoint:
@ -77,7 +76,7 @@ def elaborate_cli(args):
def predict(args):
elaborate_cli(args)
factory_from_args(args)
# load model
model, _ = nets.factory_from_args(args)

View File

@ -3,36 +3,33 @@
From a json file output images and json annotations
"""
import torch
import sys
import numpy as np
from collections import defaultdict
import os
import json
import logging
import time
from models.architectures import LinearModel
from utils.camera import preprocess_single, get_keypoints, get_depth
from utils.misc import epistemic_variance, laplace_sampling, get_idx_max
from visuals.printer import Printer
from utils.normalize import unnormalize_bi
from utils.kitti import get_simplified_calibration, get_calibration
from utils.pifpaf import get_input_data
import numpy as np
import torch
class PredictMonoLoco:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
output_size = 2
input_size = 17 * 2
def __init__(self, boxes, keypoints, image_path, output_path, args):
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
basename, _ = os.path.splitext(os.path.basename(image_path))
# Check for ground-truth file
try:
with open(args.path_gt, 'r') as f:
self.dic_names = json.load(f)
except FileNotFoundError:
self.dic_names = None
print('-' * 120)
print("Monoloco: ground truth file not found")
print('-' * 120)
self.boxes = boxes
self.keypoints = keypoints
self.image_path = image_path
@ -43,7 +40,6 @@ class PredictMonoLoco:
self.z_max = args.z_max
self.output_types = args.output_types
self.show = args.show
output_size = 2
self.n_samples = 100
self.n_dropout = args.n_dropout
if self.n_dropout > 0:
@ -52,43 +48,23 @@ class PredictMonoLoco:
self.epistemic = False
self.iou_min = 0.25
# load the model
input_size = 17 * 2
# self.model = TriLinear(input_size=input_size, output_size=output_size, p_dropout=dropout)
self.model = LinearModel(input_size=input_size, output_size=output_size)
# load the model parameters
self.model = LinearModel(input_size=self.input_size, output_size=self.output_size, linear_size=args.hidden_size)
self.model.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
self.model.eval() # Default is train
self.model.to(self.device)
# Import
from utils.camera import preprocess_single, get_keypoints, get_depth
self.preprocess_single = preprocess_single
self.get_keypoints = get_keypoints
self.get_depth = get_depth
from utils.misc import epistemic_variance, laplace_sampling, get_idx_max
self.epistemic_variance = epistemic_variance
self.laplace_sampling = laplace_sampling
self.get_idx_max = get_idx_max
from visuals.printer import Printer
self.Printer = Printer
from utils.normalize import unnormalize_bi
self.unnormalize_bi = unnormalize_bi
from utils.kitti import get_simplified_calibration, get_calibration
self.get_simplified_calibration = get_simplified_calibration
self.get_calibration = get_calibration
from utils.pifpaf import get_input_data
self.get_input_data = get_input_data
# Check for ground-truth file
try:
with open(args.path_gt, 'r') as f:
self.dic_names = json.load(f)
except FileNotFoundError:
self.dic_names = None
print('-' * 120 + "\nMonoloco: ground truth file not found\n" + '-' * 120)
def run(self):
cnt = 0
# Extract calibration matrix if ground-truth file is present or use a default one
cnt = 0
name = os.path.basename(self.image_path)
if self.dic_names:
kk = self.dic_names[name]['K']
@ -97,7 +73,7 @@ class PredictMonoLoco:
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]
(inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \
self.get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
# Conversion into torch tensor
if len(inputs_norm) > 0:
@ -113,8 +89,8 @@ class PredictMonoLoco:
if self.n_dropout > 0:
for ii in range(self.n_dropout):
outputs = self.model(inputs)
outputs = self.unnormalize_bi(outputs)
samples = self.laplace_sampling(outputs, self.n_samples)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.n_samples)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)
else:
@ -124,7 +100,7 @@ class PredictMonoLoco:
start_single = time.time()
self.model.dropout.training = False
outputs = self.model(inputs)
outputs = self.unnormalize_bi(outputs)
outputs = unnormalize_bi(outputs)
end = time.time()
print("Total Forward pass time = {:.2f} ms".format((end-start) * 1000))
print("Single pass time = {:.2f} ms".format((end - start_single) * 1000))
@ -141,13 +117,11 @@ class PredictMonoLoco:
# Find the corresponding ground truth if available
if self.dic_names:
idx_max, iou_max = self.get_idx_max(box, boxes_gt)
idx_max, iou_max = get_idx_max(box, boxes_gt)
if iou_max > self.iou_min:
dd_real = dds_gt[idx_max]
boxes_gt.pop(idx_max)
dds_gt.pop(idx_max)
# In case of no matching
else:
dd_real = 0
@ -156,9 +130,8 @@ class PredictMonoLoco:
dd_real = dd_pred
uv_center = uv_centers[idx]
xyz_real = self.get_depth(uv_center, kk, dd_real)
xyz_pred = self.get_depth(uv_center, kk, dd_pred)
xyz_real = get_depth(uv_center, kk, dd_real)
xyz_pred = get_depth(uv_center, kk, dd_pred)
dic_out['boxes'].append(box)
dic_out['dds_real'].append(dd_real)
dic_out['dds_pred'].append(dd_pred)
@ -172,14 +145,12 @@ class PredictMonoLoco:
dic_out['uv_shoulders'].append(uv_shoulders[idx])
if any((xx in self.output_types for xx in ['front', 'bird', 'combined'])):
printer = self.Printer(self.image_path, self.output_path, dic_out, kk,
y_scale=self.y_scale, output_types=self.output_types,
show=self.show, z_max=self.z_max, epistemic=self.epistemic)
printer = Printer(self.image_path, self.output_path, dic_out, kk,
y_scale=self.y_scale, output_types=self.output_types,
show=self.show, z_max=self.z_max, epistemic=self.epistemic)
printer.print()
if 'json' in self.output_types:
with open(os.path.join(self.output_path + '.monoloco.json'), 'w') as ff:
json.dump(dic_out, ff)

View File

@ -90,12 +90,6 @@ def preprocess_single(kps, kk):
kps_orig.append(float(kp_orig[0]))
kps_orig.append(float(kp_orig[1]))
# Append the y of the ground foot to the keypoints
# kp_gr = np.array([0, vv_gr, 1])
# xy1_gr = pixel_to_camera(kp_gr, kk, 1)
# kps_0c.append(float(xy1_gr[1]))
return kps_0c, kps_orig
@ -198,5 +192,4 @@ def get_depth(uv_center, kk, dd):
zz = dd / math.sqrt(1 + xyz_norm[0] ** 2 + xyz_norm[1] ** 2)
xyz = pixel_to_camera(uv_center_np, kk, zz).tolist()
return xyz
return xyz