refactor predict (#1)

* add stereo arg (#1)

* stereo skeleton

* skeleton stereo (#2)

* skeleton stereo (#3)

* refactor for stereo

* stereo running

* modify stereo parameters

* remove warnings as errors

* add p2_right

* add factory file

* add p2_right

* add scaling factor to intrinsic matrix

* Refactor predict

* refactor predict 2

* temp

* add uppercase constants

* add uppercase constants

* working predict

* add todo

* Add person_sitting as option to uncomment

* turn off stereo

* remove stereo

* pylint corrections

* add break point

* temp

* fix small bug

* Add compatibility for pifpaf and monoloco networks

* pylint fix
Commit 86438189a7 (parent d941a5bdc7)
Lorenzo Bertoni, 2019-06-20 15:43:38 +02:00, committed by GitHub
12 changed files with 391 additions and 242 deletions

(binary image file added, 694 KiB; not shown)

@@ -1,16 +1,22 @@
 """Run monoloco over all the pifpaf joints of KITTI images
 and extract and save the annotations in txt files"""
-import torch
 import math
-import numpy as np
 import os
 import glob
 import json
 import logging

+import numpy as np
+import torch

 from models.architectures import LinearModel
 from utils.misc import laplace_sampling
 from utils.kitti import eval_geometric, get_calibration
 from utils.normalize import unnormalize_bi
 from utils.pifpaf import get_input_data, preprocess_pif
+from utils.camera import get_depth_from_distance


 class RunKitti:

@@ -23,22 +29,17 @@ class RunKitti:
     average_y = 0.48
     n_samples = 100

-    def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout, stereo=False):
+    def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout):

+        # Set directories
+        assert dir_ann, "Annotations folder is required"
         self.dir_ann = dir_ann
         self.n_dropout = n_dropout
-        list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
         self.dir_kk = os.path.join('data', 'kitti', 'calib')
         self.dir_out = os.path.join('data', 'kitti', 'monoloco')
         if not os.path.exists(self.dir_out):
             os.makedirs(self.dir_out)
             print("Created output directory for txt files")
-        self.list_basename = [os.path.basename(x).split('.')[0] for x in list_ann]
-        assert self.list_basename, " Missing json annotations file to create txt files for KITTI datasets"
+        self.list_basename = factory_basename(dir_ann)

         # Load the model
         input_size = 17 * 2

@@ -54,86 +55,131 @@ class RunKitti:
         # Run inference
         for basename in self.list_basename:
             path_calib = os.path.join(self.dir_kk, basename + '.txt')
-            kk, tt = get_calibration(path_calib)
-            path_ann = os.path.join(self.dir_ann, basename + '.png.pifpaf.json')
-            with open(path_ann, 'r') as f:
-                annotations = json.load(f)
+            annotations, kk, tt, _ = factory_file(path_calib, self.dir_ann, basename)
             boxes, keypoints = preprocess_pif(annotations)
             (inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk)
             dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)

+            # Update counting
             self.cnt_ann += len(boxes)
-            inputs = torch.from_numpy(np.array(inputs)).float().to(self.device)
-            if len(inputs) == 0:
+            if not inputs:
                 self.cnt_no_file += 1
             else:
                 self.cnt_file += 1
+                # Run the model
+                inputs = torch.from_numpy(np.array(inputs)).float().to(self.device)
                 if self.n_dropout > 0:
                     total_outputs = torch.empty((0, len(uv_boxes))).to(self.device)
                     self.model.dropout.training = True
-                    for ii in range(self.n_dropout):
+                    for _ in range(self.n_dropout):
                         outputs = self.model(inputs)
                         outputs = unnormalize_bi(outputs)
                         samples = laplace_sampling(outputs, self.n_samples)
                         total_outputs = torch.cat((total_outputs, samples), 0)
                     varss = total_outputs.std(0)
                 else:
                     varss = [0]*len(uv_boxes)

                 # Don't use dropout for the mean prediction and aleatoric uncertainty
                 self.model.dropout.training = False
                 outputs_net = self.model(inputs)
                 outputs = outputs_net.cpu().detach().numpy()

-                path_txt = os.path.join(self.dir_out, basename + '.txt')
-                with open(path_txt, "w+") as ff:
-                    for idx in range(outputs.shape[0]):
-                        xx_1 = float(xy_centers[idx][0])
-                        yy_1 = float(xy_centers[idx][1])
-                        xy_kp = xy_kps[idx]
-                        dd = float(outputs[idx][0])
-                        std_ale = math.exp(float(outputs[idx][1])) * dd
-                        zz = dd / math.sqrt(1 + xx_1**2 + yy_1**2)
-                        xx_cam_0 = xx_1*zz + tt[0]  # Still to verify the sign but negligible
-                        yy_cam_0 = yy_1*zz + tt[1]
-                        zz_cam_0 = zz + tt[2]
-                        dd_cam_0 = math.sqrt(xx_cam_0**2 + yy_cam_0**2 + zz_cam_0**2)
-                        uv_box = uv_boxes[idx]
-                        twodecimals = ["%.3f" % vv for vv in [uv_box[0], uv_box[1], uv_box[2], uv_box[3],
-                                                              xx_cam_0, yy_cam_0, zz_cam_0, dd_cam_0,
-                                                              std_ale, varss[idx], uv_box[4], dds_geom[idx]]]
-                        keypoints_str = ["%.5f" % vv for vv in xy_kp]
-                        for item in twodecimals:
-                            ff.write("%s " % item)
-                        for item in keypoints_str:
-                            ff.write("%s " % item)
-                        ff.write("\n")
-                    # Save intrinsic matrix in the last row
-                    kk_list = kk.reshape(-1,).tolist()
-                    for kk_el in kk_list:
-                        ff.write("%f " % kk_el)
-                    ff.write("\n")
+                list_zzs = get_depth_from_distance(outputs, xy_centers)
+                all_outputs = [outputs, varss, dds_geom]
+                all_inputs = [uv_boxes, xy_centers, xy_kps]
+                all_params = [kk, tt]

+                # Save the file
+                all_outputs.append(list_zzs)
+                path_txt = os.path.join(self.dir_out, basename + '.txt')
+                save_txts(path_txt, all_inputs, all_outputs, all_params)
+                aa = 5

         # Print statistics
         print("Saved in {} txt {} annotations. Not found {} images"
               .format(self.cnt_file, self.cnt_ann, self.cnt_no_file))


+def save_txts(path_txt, all_inputs, all_outputs, all_params):
+
+    outputs, varss, dds_geom, zzs = all_outputs[:]
+    uv_boxes, xy_centers, xy_kps = all_inputs[:]
+    kk, tt = all_params[:]
+
+    with open(path_txt, "w+") as ff:
+        for idx in range(outputs.shape[0]):
+            xx_1 = float(xy_centers[idx][0])
+            yy_1 = float(xy_centers[idx][1])
+            xy_kp = xy_kps[idx]
+            dd = float(outputs[idx][0])
+            std_ale = math.exp(float(outputs[idx][1])) * dd
+            zz = zzs[idx]
+            xx_cam_0 = xx_1 * zz + tt[0]
+            yy_cam_0 = yy_1 * zz + tt[1]
+            zz_cam_0 = zz + tt[2]
+            dd_cam_0 = math.sqrt(xx_cam_0 ** 2 + yy_cam_0 ** 2 + zz_cam_0 ** 2)
+            uv_box = uv_boxes[idx]
+            twodecimals = ["%.3f" % vv for vv in [uv_box[0], uv_box[1], uv_box[2], uv_box[3],
+                                                  xx_cam_0, yy_cam_0, zz_cam_0, dd_cam_0,
+                                                  std_ale, varss[idx], uv_box[4], dds_geom[idx]]]
+            keypoints_str = ["%.5f" % vv for vv in xy_kp]
+            for item in twodecimals:
+                ff.write("%s " % item)
+            for item in keypoints_str:
+                ff.write("%s " % item)
+            ff.write("\n")
+
+        # Save intrinsic matrix in the last row
+        kk_list = kk.reshape(-1, ).tolist()
+        for kk_el in kk_list:
+            ff.write("%f " % kk_el)
+        ff.write("\n")
+
+
+def factory_basename(dir_ann):
+    """ Return all the basenames in the annotations folder"""
+
+    list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
+    list_basename = [os.path.basename(x).split('.')[0] for x in list_ann]
+    assert list_basename, " Missing json annotations file to create txt files for KITTI datasets"
+
+    return list_basename
+
+
+def factory_file(path_calib, dir_ann, basename, ite=0):
+    """Choose the annotation and the calibration files. Stereo option with ite = 1"""
+
+    stereo_file = True
+    p_left, p_right = get_calibration(path_calib)
+
+    if ite == 0:
+        kk, tt = p_left[:]
+        path_ann = os.path.join(dir_ann, basename + '.png.pifpaf.json')
+    else:
+        kk, tt = p_right[:]
+        path_ann = os.path.join(dir_ann + '_right', basename + '.png.pifpaf.json')
+
+    try:
+        with open(path_ann, 'r') as f:
+            annotations = json.load(f)
+    except FileNotFoundError:
+        annotations = None
+        if ite == 1:
+            stereo_file = False
+
+    return annotations, kk, tt, stereo_file
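Each row written by save_txts above carries 12 fixed fields (the 2D box corners and confidence, the camera-frame coordinates and distance, the aleatoric and epistemic uncertainties, and the geometric-baseline distance) followed by the normalized keypoints, with the flattened intrinsic matrix in the last row. A minimal sketch of a reader for that layout, assuming this interpretation of the columns (the field names are illustrative, not part of the repo):

import numpy as np

def read_monoloco_txt(path_txt):
    """Hypothetical reader for the txt layout produced by save_txts."""
    with open(path_txt, 'r') as ff:
        rows = [line.split() for line in ff if line.strip()]
    kk = np.array(rows[-1], dtype=float).reshape(3, 3)  # intrinsics in the last row
    detections = []
    for row in rows[:-1]:
        vv = [float(x) for x in row]
        detections.append({
            'box': vv[0:4],      # u_min, v_min, u_max, v_max
            'xyz': vv[4:7],      # camera-frame coordinates
            'dd': vv[7],         # predicted distance
            'std_ale': vv[8],    # aleatoric uncertainty
            'var_epi': vv[9],    # epistemic variance (0 if n_dropout == 0)
            'conf': vv[10],      # pifpaf box confidence
            'dd_geom': vv[11],   # geometric baseline distance
            'xy_kps': vv[12:],   # normalized keypoints
        })
    return detections, kk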


@@ -64,7 +64,8 @@ class PreprocessKitti:
         # Extract keypoints
         path_txt = os.path.join(self.dir_kk, basename + '.txt')
-        kk, tt = get_calibration(path_txt)
+        p_left, _ = get_calibration(path_txt)
+        kk = p_left[0]

         # Iterate over each line of the gt file and save box location and distances
         boxes_gt, dds_gt, _, _ = parse_ground_truth(path_gt)


@@ -9,7 +9,7 @@ from openpifpaf.network import nets
 from openpifpaf import decoder
 from features.preprocess_nu import PreprocessNuscenes
 from features.preprocess_ki import PreprocessKitti
-from predict.predict_2d_3d import predict
+from predict.predict import predict
 from models.trainer import Trainer
 from eval.run_kitti import RunKitti
 from eval.geom_baseline import GeomBaseline

src/predict/factory.py (new file, 91 lines)

@@ -0,0 +1,91 @@
import json
import os

from visuals.printer import Printer
from openpifpaf import show
from PIL import Image


def factory_for_gt(image, name=None, path_gt=None):
    """Look for ground-truth annotations file and define calibration matrix based on image size """

    try:
        with open(path_gt, 'r') as f:
            dic_names = json.load(f)
        print('-' * 120 + "\nMonoloco: Ground-truth file opened\n")
    except FileNotFoundError:
        print('-' * 120 + "\nMonoloco: ground-truth file not found\n")
        dic_names = {}

    try:
        kk = dic_names[name]['K']
        dic_gt = dic_names[name]
        print("Monoloco: matched ground-truth file!\n" + '-' * 120)
    except KeyError:
        dic_gt = None
        x_factor = image.size[0] / 1600
        y_factor = image.size[1] / 900
        pixel_factor = (x_factor + y_factor) / 2
        if image.size[0] / image.size[1] > 2.5:
            kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]  # Kitti calibration
        else:
            kk = [[1266.4 * pixel_factor, 0., 816.27 * x_factor],
                  [0, 1266.4 * pixel_factor, 491.5 * y_factor],
                  [0., 0., 1.]]  # nuScenes calibration
        print("Ground-truth annotations for the image not found\n"
              "Using a standard calibration matrix...\n" + '-' * 120)

    return kk, dic_gt


def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_outputs=None, kk=None):
    """Output json files or images according to the choice"""

    # Save json file
    if 'pifpaf' in args.networks:
        keypoint_sets, pifpaf_out, scores = pifpaf_outputs[:]

        # Visualizer
        keypoint_painter = show.KeypointPainter(show_box=True)
        skeleton_painter = show.KeypointPainter(show_box=False, color_connections=True,
                                                markersize=1, linewidth=4)

        if 'json' in args.output_types and keypoint_sets.size > 0:
            with open(output_path + '.pifpaf.json', 'w') as f:
                json.dump(pifpaf_out, f)

        if 'keypoints' in args.output_types:
            with show.image_canvas(images_outputs[0],
                                   output_path + '.keypoints.png',
                                   show=args.show,
                                   fig_width=args.figure_width,
                                   dpi_factor=args.dpi_factor) as ax:
                keypoint_painter.keypoints(ax, keypoint_sets)

        if 'skeleton' in args.output_types:
            with show.image_canvas(images_outputs[0],
                                   output_path + '.skeleton.png',
                                   show=args.show,
                                   fig_width=args.figure_width,
                                   dpi_factor=args.dpi_factor) as ax:
                skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)

    if 'monoloco' in args.networks:
        if any((xx in args.output_types for xx in ['front', 'bird', 'combined'])):
            epistemic = False
            if args.n_dropout > 0:
                epistemic = True

            printer = Printer(images_outputs[1], output_path, monoloco_outputs, kk, output_types=args.output_types,
                              show=args.show, z_max=args.z_max, epistemic=epistemic)
            printer.print()

        if 'json' in args.output_types:
            with open(os.path.join(args.output_path + '.monoloco.json'), 'w') as ff:
                json.dump(monoloco_outputs, ff)
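A quick usage sketch for factory_for_gt above; the image and ground-truth paths are placeholders, and when no matching entry exists in the ground-truth file, dic_gt comes back as None and a default calibration matrix is picked from the image aspect ratio:

from PIL import Image
from predict.factory import factory_for_gt

with open('docs/002282.png', 'rb') as f:  # placeholder image path
    pil_image = Image.open(f).convert('RGB')

kk, dic_gt = factory_for_gt(pil_image, name='002282.png', path_gt='names-kitti.json')  # placeholder gt path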


@@ -1,65 +1,51 @@
 """
-From a json file output images and json annotations
+Monoloco predictor. It receives pifpaf joints and outputs distances
 """
-import sys
 from collections import defaultdict
-import os
-import json
 import logging
 import time

 import numpy as np
 import torch
-from PIL import Image

 from models.architectures import LinearModel
-from visuals.printer import Printer
 from utils.camera import get_depth
 from utils.misc import laplace_sampling, get_idx_max
 from utils.normalize import unnormalize_bi
 from utils.pifpaf import get_input_data


-class PredictMonoLoco:
+class MonoLoco:

     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger(__name__)

-    output_size = 2
-    input_size = 17 * 2
+    OUTPUT_SIZE = 2
+    INPUT_SIZE = 17 * 2
+    LINEAR_SIZE = 256
+    IOU_MIN = 0.25
+    N_SAMPLES = 100

-    def __init__(self, boxes, keypoints, image_path, output_path, args):
-        self.boxes = boxes
-        self.keypoints = keypoints
-        self.image_path = image_path
-        self.output_path = output_path
-        self.device = args.device
-        self.draw_kps = args.draw_kps
-        self.z_max = args.z_max
-        self.output_types = args.output_types
-        self.path_gt = args.path_gt
-        self.show = args.show
-        self.n_samples = 100
-        self.n_dropout = args.n_dropout
+    def __init__(self, model, device, n_dropout=0):
+        self.device = device
+        self.n_dropout = n_dropout
         if self.n_dropout > 0:
             self.epistemic = True
         else:
             self.epistemic = False
-        self.iou_min = 0.25

         # load the model parameters
-        self.model = LinearModel(input_size=self.input_size, output_size=self.output_size, linear_size=args.hidden_size)
-        self.model.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
+        self.model = LinearModel(input_size=self.INPUT_SIZE, output_size=self.OUTPUT_SIZE, linear_size=self.LINEAR_SIZE)
+        self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
         self.model.eval()  # Default is train
         self.model.to(self.device)

-    def run(self):
-        # Extract calibration matrix if ground-truth file is present or use a default one
-        cnt = 0
-        dic_names, kk = factory_for_gt(self.path_gt, self.image_path)
+    def forward(self, boxes, keypoints, kk, dic_gt=None):

         (inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \
-            get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
+            get_input_data(boxes, keypoints, kk, left_to_right=True)

         # Conversion into torch tensor
         if inputs_norm:

@@ -76,7 +62,7 @@ class MonoLoco:
             for _ in range(self.n_dropout):
                 outputs = self.model(inputs)
                 outputs = unnormalize_bi(outputs)
-                samples = laplace_sampling(outputs, self.n_samples)
+                samples = laplace_sampling(outputs, self.N_SAMPLES)
                 total_outputs = torch.cat((total_outputs, samples), 0)
             varss = total_outputs.std(0)
         else:

@@ -92,11 +78,10 @@ class MonoLoco:
                   .format(self.n_dropout, (end-start) * 1000))
             print("Single forward pass time = {:.2f} ms".format((end - start_single) * 1000))

-        # Print image and save json
+        # Create output files
         dic_out = defaultdict(list)
-        if dic_names:
-            name = os.path.basename(self.image_path)
-            boxes_gt, dds_gt = dic_names[name]['boxes'], dic_names[name]['dds']
+        if dic_gt:
+            boxes_gt, dds_gt = dic_gt['boxes'], dic_gt['dds']

         for idx, box in enumerate(uv_boxes):
             dd_pred = float(outputs[idx][0])

@@ -104,9 +89,9 @@ class MonoLoco:
             var_y = float(varss[idx])

             # Find the corresponding ground truth if available
-            if dic_names:
+            if dic_gt:
                 idx_max, iou_max = get_idx_max(box, boxes_gt)
-                if iou_max > self.iou_min:
+                if iou_max > self.IOU_MIN:
                     dd_real = dds_gt[idx_max]
                     boxes_gt.pop(idx_max)
                     dds_gt.pop(idx_max)

@@ -132,42 +117,4 @@ class MonoLoco:
             dic_out['uv_centers'].append(uv_center)
             dic_out['uv_shoulders'].append(uv_shoulders[idx])

-        if any((xx in self.output_types for xx in ['front', 'bird', 'combined'])):
-            printer = Printer(self.image_path, self.output_path, dic_out, kk, output_types=self.output_types,
-                              show=self.show, z_max=self.z_max, epistemic=self.epistemic)
-            printer.print()
-
-        if 'json' in self.output_types:
-            with open(os.path.join(self.output_path + '.monoloco.json'), 'w') as ff:
-                json.dump(dic_out, ff)
-
-        sys.stdout.write('\r' + 'Saving image {}'.format(cnt) + '\t')
-
-
-def factory_for_gt(path_gt, image_path):
-    """Look for ground-truth annotations file and define calibration matrix based on image size """
-    try:
-        with open(path_gt, 'r') as f:
-            dic_names = json.load(f)
-        print('-' * 120 + "\nMonoloco: Ground-truth file opened\n")
-    except FileNotFoundError:
-        print('-' * 120 + "\nMonoloco: ground-truth file not found\n")
-        dic_names = {}
-
-    try:
-        name = os.path.basename(image_path)
-        kk = dic_names[name]['K']
-        print("Monoloco: matched ground-truth file!\n" + '-' * 120)
-    except KeyError:
-        dic_names = None
-        with open(image_path, 'rb') as f:
-            im = Image.open(f)
-        if im.size[0] / im.size[1] > 2.5:
-            kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]  # Kitti calibration
-        else:
-            kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]]  # Nuscenes calibration
-        print("Ground-truth annotations for the image not found\n"
-              "Using a standard calibration matrix...\n" + '-' * 120)
-
-    return dic_names, kk
+        return dic_out
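The epistemic branch of MonoLoco.forward keeps dropout stochastic at inference time and takes the standard deviation across several sampled forward passes as the epistemic uncertainty. A self-contained sketch of that Monte Carlo dropout pattern (the toy model and sample count here are illustrative, not the repo's LinearModel):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(34, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 2))
model.eval()                  # deterministic by default...
inputs = torch.randn(5, 34)   # 5 instances, 17 keypoints x 2 coordinates

# ...but re-enable dropout so each forward pass is a stochastic sample
for module in model.modules():
    if isinstance(module, nn.Dropout):
        module.train()

with torch.no_grad():
    samples = torch.stack([model(inputs)[:, 0] for _ in range(50)])  # 50 forward passes
epistemic_std = samples.std(0)  # one uncertainty value per instance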


@@ -1,26 +1,26 @@
 import glob
-import json
 import os
+import sys

 import numpy as np
-from openpifpaf.network import nets
-from openpifpaf import decoder, show
-from openpifpaf import transforms
-from predict.predict_monoloco import PredictMonoLoco
-from utils.pifpaf import preprocess_pif
 import torchvision
 import torch
 from PIL import Image, ImageFile
+from openpifpaf.network import nets
+from openpifpaf import decoder
+from openpifpaf import transforms
+from predict.monoloco import MonoLoco
+from predict.factory import factory_for_gt, factory_outputs
+from utils.pifpaf import preprocess_pif


 class ImageList(torch.utils.data.Dataset):
+    """It defines transformations to apply to images and outputs of the dataloader"""
     def __init__(self, image_paths, scale, image_transform=None):
         self.image_paths = image_paths
-        self.image_transform = image_transform or transforms.image_transform
+        self.image_transform = image_transform or transforms.image_transform  # to_tensor + normalize (from pifpaf)
         self.scale = scale

         # data = datasets.ImageList(args.images, preprocess=transforms.RescaleRelative(2

@@ -37,7 +37,8 @@ class ImageList(torch.utils.data.Dataset):
                 (round(self.scale * image.size[1]),
                  round(self.scale * image.size[0])),
                 interpolation=Image.BICUBIC)
-        original_image = torchvision.transforms.functional.to_tensor(image)
+        # PIL images are not iterables
+        original_image = torchvision.transforms.functional.to_tensor(image)  # 0-255 --> 0-1
         image = self.image_transform(image)

         return image_path, original_image, image

@@ -76,12 +77,16 @@ def factory_from_args(args):

 def predict(args):

+    cnt = 0
     factory_from_args(args)

-    # load model
-    model, _ = nets.factory_from_args(args)
-    model = model.to(args.device)
-    processor = decoder.factory_from_args(args, model)
+    # load pifpaf model
+    model_pifpaf, _ = nets.factory_from_args(args)
+    model_pifpaf = model_pifpaf.to(args.device)
+    processor = decoder.factory_from_args(args, model_pifpaf)
+
+    # load monoloco
+    monoloco = MonoLoco(model=args.model, device=args.device, n_dropout=args.n_dropout)

     # data
     data = ImageList(args.images, scale=args.scale)

@@ -89,11 +94,6 @@ def predict(args):
         data, batch_size=1, shuffle=False,
         pin_memory=args.pin_memory, num_workers=args.loader_workers)

-    # Visualizer
-    keypoint_painter = show.KeypointPainter(show_box=True)
-    skeleton_painter = show.KeypointPainter(show_box=False, color_connections=True,
-                                            markersize=1, linewidth=4)
-
     keypoints_whole = []
     for idx, (image_paths, image_tensors, processed_images_cpu) in enumerate(data_loader):
         images = image_tensors.permute(0, 2, 3, 1)

@@ -121,45 +121,42 @@ def predict(args):
             # Correct to not change the confidence
             scale_np = np.array([args.scale, args.scale, 1] * 17).reshape(17, 3)

-            if keypoint_sets.size > 0:
-                keypoints_whole.append(np.around((keypoint_sets / scale_np), 1)
-                                       .reshape(keypoint_sets.shape[0], -1).tolist())
-
             pifpaf_out = [
                 {'keypoints': np.around(kps / scale_np, 1).reshape(-1).tolist(),
                  'bbox': [np.min(kps[:, 0]) / args.scale, np.min(kps[:, 1]) / args.scale,
                           np.max(kps[:, 0]) / args.scale, np.max(kps[:, 1]) / args.scale]}
                 for kps in keypoint_sets
             ]

-            # Save json file
-            if 'pifpaf' in args.networks:
-                if 'json' in args.output_types and keypoint_sets.size > 0:
-                    with open(output_path + '.pifpaf.json', 'w') as f:
-                        json.dump(pifpaf_out, f)
-
-                if 'keypoints' in args.output_types:
-                    with show.image_canvas(image,
-                                           output_path + '.keypoints.png',
-                                           show=args.show,
-                                           fig_width=args.figure_width,
-                                           dpi_factor=args.dpi_factor) as ax:
-                        keypoint_painter.keypoints(ax, keypoint_sets)
-
-                if 'skeleton' in args.output_types:
-                    with show.image_canvas(image,
-                                           output_path + '.skeleton.png',
-                                           show=args.show,
-                                           fig_width=args.figure_width,
-                                           dpi_factor=args.dpi_factor) as ax:
-                        skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)
+            pifpaf_outputs = [keypoint_sets, scores, pifpaf_out]  # keypoints_sets and scores for pifpaf printing
+            images_outputs = [image]  # List of 1 or 2 elements with pifpaf tensor (resized) and monoloco original image
+
+            if keypoint_sets.size > 0:
+                keypoints_whole.append(np.around((keypoint_sets / scale_np), 1)
+                                       .reshape(keypoint_sets.shape[0], -1).tolist())

             if 'monoloco' in args.networks:
                 im_size = (float(image.size()[1] / args.scale),
                            float(image.size()[0] / args.scale))  # Width, Height (original)
-                boxes, keypoints = preprocess_pif(pifpaf_out, im_size)
-                predict_monoloco = PredictMonoLoco(boxes, keypoints, image_path, output_path, args)
-                predict_monoloco.run()

+                # Extract calibration matrix and ground truth file if present
+                with open(image_path, 'rb') as f:
+                    pil_image = Image.open(f).convert('RGB')
+                    images_outputs.append(pil_image)
+                im_name = os.path.basename(image_path)
+                kk, _ = factory_for_gt(image, name=im_name, path_gt=args.path_gt)
+
+                # Preprocess pifpaf outputs and run monoloco
+                boxes, keypoints = preprocess_pif(pifpaf_out, im_size)
+                monoloco_outputs = monoloco.forward(boxes, keypoints, kk)
+            else:
+                monoloco_outputs = None
+                kk = None
+
+            factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_outputs=monoloco_outputs, kk=kk)
+            sys.stdout.write('\r' + 'Saving image {}'.format(cnt) + '\t')
+            cnt += 1

     return keypoints_whole


@@ -193,3 +193,15 @@ def get_depth(uv_center, kk, dd):
     xyz = pixel_to_camera(uv_center_np, kk, zz).tolist()
     return xyz
+
+
+def get_depth_from_distance(outputs, xy_centers):
+    list_zzs = []
+    for idx, _ in enumerate(outputs):
+        dd = float(outputs[idx][0])
+        xx_1 = float(xy_centers[idx][0])
+        yy_1 = float(xy_centers[idx][1])
+        zz = dd / math.sqrt(1 + xx_1 ** 2 + yy_1 ** 2)
+        list_zzs.append(zz)
+
+    return list_zzs
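get_depth_from_distance converts the predicted radial distance dd into depth along the optical axis: for a normalized image point (xx_1, yy_1) the viewing ray is (xx_1, yy_1, 1), so dd = zz * sqrt(1 + xx_1**2 + yy_1**2). A quick numeric check with made-up values:

import math

dd = 10.0              # predicted distance from the camera (m)
xx_1, yy_1 = 0.3, 0.1  # normalized image coordinates of the center
zz = dd / math.sqrt(1 + xx_1 ** 2 + yy_1 ** 2)  # depth along the optical axis
xyz = [xx_1 * zz, yy_1 * zz, zz]                # back-projected 3D point
assert abs(math.sqrt(sum(v ** 2 for v in xyz)) - dd) < 1e-9  # recovers dd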


@@ -60,14 +60,24 @@ def get_calibration(path_txt):
     p2_list = [float(xx) for xx in p2_str]
     p2 = np.array(p2_list).reshape(3, 4)
-    kk = p2[:, :-1]
+    p3_str = file[3].split()[1:]
+    p3_list = [float(xx) for xx in p3_str]
+    p3 = np.array(p3_list).reshape(3, 4)
+    kk, tt = get_translation(p2)
+    kk_right, tt_right = get_translation(p3)
+    return [kk, tt], [kk_right, tt_right]
+
+
+def get_translation(pp):
+    """Separate intrinsic matrix from translation"""
+    kk = pp[:, :-1]
     f_x = kk[0, 0]
     f_y = kk[1, 1]
-    x0 = kk[2, 0]
-    y0 = kk[2, 1]
-    aa = p2[0, 3]
-    bb = p2[1, 3]
-    t3 = p2[2, 3]
+    x0, y0 = kk[2, 0:2]
+    aa, bb, t3 = pp[0:3, 3]
     t1 = (aa - x0*t3) / f_x
     t2 = (bb - y0*t3) / f_y
     tt = np.array([t1, t2, t3]).reshape(3, 1)
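get_translation inverts P = K [I | t] in closed form: the last column of P equals K t, so t3 = P[2, 3], t1 = (P[0, 3] - x0*t3) / f_x and t2 = (P[1, 3] - y0*t3) / f_y. A sketch that checks this against an explicit inverse, assuming the usual upper-triangular K with the principal point in its last column and a made-up KITTI-like P2:

import numpy as np

# Made-up KITTI-like projection matrix: P = K [I | t]
kk = np.array([[721.5, 0., 609.6],
               [0., 721.5, 172.9],
               [0., 0., 1.]])
tt = np.array([0.06, -0.001, 0.003])
pp = kk @ np.hstack([np.eye(3), tt.reshape(3, 1)])  # 3x4 projection matrix

# Closed form, reading the principal point x0, y0 from K
t3 = pp[2, 3]
t1 = (pp[0, 3] - kk[0, 2] * t3) / kk[0, 0]
t2 = (pp[1, 3] - kk[1, 2] * t3) / kk[1, 1]

# Same result via the inverse: P[:, 3] = K t  =>  t = K^-1 P[:, 3]
assert np.allclose(np.linalg.solve(kk, pp[:, 3]), [t1, t2, t3])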
@ -102,6 +112,7 @@ def check_conditions(line, mode, thresh=0.5):
check = True check = True
elif mode == 'gt': elif mode == 'gt':
# if line[:10] == 'Pedestrian' or line[:10] == 'Person_sit':
if line[:10] == 'Pedestrian': if line[:10] == 'Pedestrian':
check = True check = True


@@ -204,4 +204,3 @@ def append_cluster(dic_jo, phase, xx, dd, kps):
         dic_jo[phase]['clst']['>30']['X'].append(xx)
         dic_jo[phase]['clst']['>30']['Y'].append([dd])
-

src/utils/stereo.py (new file, 49 lines)

@@ -0,0 +1,49 @@
import copy

import numpy as np


def depth_from_disparity(zzs, zzs_right, kps, kps_right):
    """Associate instances in left and right images and compute disparity"""

    zzs_stereo = []
    cnt = 0
    for idx, zz in enumerate(zzs):
        # Find the closest human in terms of distance
        zz_stereo, idx_min, delta_d_min = calculate_disparity(zz, zzs_right, kps[idx], kps_right)

        if delta_d_min < 1:
            zzs_stereo.append(zz_stereo)
            zzs_right.pop(idx_min)
            kps_right.pop(idx_min)
            cnt += 1
        else:
            zzs_stereo.append(zz)

    return zzs_stereo, cnt


def calculate_disparity(zz, zzs_right, kp, kps_right):
    """From 2 sets of keypoints calculate disparity as the median of the disparities"""

    kp = np.array(copy.deepcopy(kp))
    kps_right = np.array(copy.deepcopy(kps_right))
    zz_stereo = 0
    idx_min = 0
    delta_z_min = 4

    for idx, zz_right in enumerate(zzs_right):
        delta_z = abs(zz - zz_right)
        diffs = np.array(np.array(kp[0] - kps_right[idx][0]))
        diff = np.mean(diffs)

        # Check only for right instances (5 pxls = 80 meters)
        if delta_z < delta_z_min and diff > 5:
            delta_z_min = delta_z
            idx_min = idx
            zzs = 0.54 * 721 / diffs
            zz_stereo = np.median(zzs[kp[2] > 0])

    return zz_stereo, idx_min, delta_z_min
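calculate_disparity relies on the standard stereo relation z = f * B / disparity; the hard-coded 0.54 and 721 are the KITTI baseline in meters and focal length in pixels, so the 5-pixel threshold in the comment corresponds to 0.54 * 721 / 5, roughly 78 m. A toy run of the association, with fabricated keypoint arrays in the (x, y, confidence) row layout the function expects:

import numpy as np

from utils.stereo import depth_from_disparity

# One person seen ~7.8 m away: 50 px disparity between left and right keypoints
kp_left = np.array([[300., 310., 305.], [120., 150., 180.], [1., 1., 1.]])  # x, y, conf rows
kp_right = kp_left.copy()
kp_right[0] -= 50.0  # shift the x coordinates by the disparity

zzs_stereo, cnt = depth_from_disparity([8.0], [7.5], [kp_left], [kp_right.tolist()])
print(zzs_stereo, cnt)  # ~[7.79], 1  (0.54 * 721 / 50)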


@@ -9,17 +9,23 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
 from matplotlib.patches import Ellipse, Circle
 import cv2
 from collections import OrderedDict
-from PIL import Image
+from utils.camera import pixel_to_camera


 class Printer:
     """
     Print results on images: birds eye view and computed distance
     """
+    RADIUS_KPS = 6
+    FONTSIZE_BV = 16
+    FONTSIZE = 18
+    TEXTCOLOR = 'darkorange'
+    COLOR_KPS = 'yellow'

-    def __init__(self, image_path, output_path, dic_ann, kk, output_types, show=False,
+    def __init__(self, image, output_path, dic_ann, kk, output_types, show=False,
                  draw_kps=False, text=True, legend=True, epistemic=False, z_max=30, fig_width=10):

+        self.im = image
         self.kk = kk
         self.output_types = output_types
         self.show = show

@@ -27,13 +33,9 @@ class Printer:
         self.text = text
         self.epistemic = epistemic
         self.legend = legend
         self.z_max = z_max  # To include ellipses in the image
         self.fig_width = fig_width
-        from utils.camera import pixel_to_camera, get_depth
-        self.pixel_to_camera = pixel_to_camera
-        self.get_depth = get_depth

         # Define the output dir
         self.path_out = output_path

@@ -52,12 +54,10 @@ class Printer:
         self.uv_shoulders = dic_ann['uv_shoulders']
         self.uv_kps = dic_ann['uv_kps']

-        # Load the image
-        with open(image_path, 'rb') as f:
-            self.im = Image.open(f).convert('RGB')
         self.uv_camera = (int(self.im.size[0] / 2), self.im.size[1])
+        self.ww = self.im.size[0]
         self.hh = self.im.size[1]
+        self.radius = 14 / 1600 * self.ww

     def print(self):
         """

@@ -66,22 +66,15 @@ class Printer:
         Either front and/or bird visualization or combined one
         """
         # Parameters
-        radius = 14
-        radius_kps = 6
-        fontsize_bv = 16
-        fontsize = 18
-        textcolor = 'darkorange'
-        color_kps = 'yellow'

         # Resize image for aesthetic proportions in combined visualization
         if 'combined' in self.output_types:
-            ww = self.im.size[0]
-            hh = self.im.size[1]
-            y_scale = ww / (hh * 1.8)  # Defined proportion
-            self.im = self.im.resize((ww, round(hh * y_scale)))
-            print(y_scale)
-            width = self.fig_width + 0.6 * self.fig_width
-            height = self.fig_width * self.im.size[1] / self.im.size[0]
+            y_scale = self.ww / (self.hh * 1.8)  # Defined proportion
+            self.im = self.im.resize((self.ww, round(self.hh * y_scale)))
+            self.ww = self.im.size[0]
+            self.hh = self.im.size[1]
+            fig_width = self.fig_width + 0.6 * self.fig_width
+            fig_height = self.fig_width * self.hh / self.ww

             # Distinguish between KITTI images and general images
             if y_scale > 1.7:

@@ -92,7 +85,7 @@ class Printer:
             ext = '.combined.png'

             fig, (ax1, ax0) = plt.subplots(1, 2, sharey=False, gridspec_kw={'width_ratios': [1, width_ratio]},
-                                           figsize=(width, height))
+                                           figsize=(fig_width, fig_height))
             ax1.set_aspect(fig_ar_1)
             fig.set_tight_layout(True)
             fig.subplots_adjust(left=0.02, right=0.98, bottom=0, top=1, hspace=0, wspace=0.02)

@@ -104,7 +97,7 @@ class Printer:
         elif 'front' in self.output_types:
             y_scale = 1
             width = self.fig_width
-            height = self.fig_width * self.im.size[1] / self.im.size[0]
+            height = self.fig_width * self.hh / self.ww
             plt.figure(0)
             fig0, ax0 = plt.subplots(1, 1, figsize=(width, height))

@@ -114,8 +107,8 @@ class Printer:
         if any(xx in self.output_types for xx in ['front', 'combined']):
             ax0.set_axis_off()
-            ax0.set_xlim(0, self.im.size[0])
-            ax0.set_ylim(self.im.size[1], 0)
+            ax0.set_xlim(0, self.ww)
+            ax0.set_ylim(self.hh, 0)
             ax0.imshow(self.im)
             z_min = 0
             bar_ticks = self.z_max // 5 + 1

@@ -125,16 +118,16 @@ class Printer:
             for idx, uv in enumerate(self.uv_shoulders):
                 if self.draw_kps:
-                    ax0 = self.show_kps(ax0, self.uv_kps[idx], y_scale, radius_kps, color_kps)
+                    ax0 = self.show_kps(ax0, self.uv_kps[idx], y_scale, self.RADIUS_KPS, self.COLOR_KPS)

                 elif min(self.zz_pred[idx], self.zz_gt[idx]) > 0:
                     color = cmap((self.zz_pred[idx] % self.z_max) / self.z_max)
-                    circle = Circle((uv[0], uv[1] * y_scale), radius=radius, color=color, fill=True)
+                    circle = Circle((uv[0], uv[1] * y_scale), radius=self.radius, color=color, fill=True)
                     ax0.add_patch(circle)

                     if self.text:
-                        ax0.text(uv[0]+radius, uv[1] * y_scale - radius, str(num),
-                                 fontsize=fontsize, color=textcolor, weight='bold')
+                        ax0.text(uv[0]+self.radius, uv[1] * y_scale - self.radius, str(num),
+                                 fontsize=self.FONTSIZE, color=self.TEXTCOLOR, weight='bold')
                         num += 1

             ax0.get_xaxis().set_visible(False)

@@ -166,7 +159,7 @@ class Printer:
         # Create bird or combine it with front)
         if any(xx in self.output_types for xx in ['bird', 'combined']):
             uv_max = np.array([0, self.hh, 1])
-            xyz_max = self.pixel_to_camera(uv_max, self.kk, self.z_max)
+            xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
             x_max = abs(xyz_max[0])  # shortcut to avoid oval circles in case of different kk

             for idx, _ in enumerate(self.xx_gt):

@@ -189,8 +182,8 @@ class Printer:
                                           height=1, angle=angle, color='b', fill=False, label="Aleatoric Uncertainty",
                                           linewidth=1.3)
                     ellipse_var = Ellipse((self.xx_pred[idx], self.zz_pred[idx]), width=self.stds_ale_epi[idx] * 2,
-                                          height=1, angle=angle, color='r', fill=False, label="Uncertainty", linewidth=1,
-                                          linestyle='--')
+                                          height=1, angle=angle, color='r', fill=False, label="Uncertainty",
+                                          linewidth=1, linestyle='--')
                     ax1.add_patch(ellipse_ale)

                     if self.epistemic:

@@ -203,7 +196,7 @@ class Printer:
                     (_, x_pos), (_, z_pos) = get_confidence(self.xx_pred[idx], self.zz_pred[idx], self.stds_ale_epi[idx])

                     if self.text:
-                        ax1.text(x_pos, z_pos, str(num), fontsize=fontsize_bv, color='darkorange')
+                        ax1.text(x_pos, z_pos, str(num), fontsize=self.FONTSIZE_BV, color='darkorange')
                         num += 1

             # To avoid repetitions in the legend

@@ -219,6 +212,10 @@ class Printer:
             ax1.set_xlabel("X [m]")
             ax1.set_ylabel("Z [m]")

+            # TO remove axis numbers
+            # plt.setp([ax1.get_yticklabels() for aa in fig.axes[:-1]], visible=False)
+            # plt.setp([ax1.get_xticklabels() for aa in fig.axes[:-1]], visible=False)
+
         if self.show:
             plt.show()
         else:

@@ -227,9 +224,7 @@ class Printer:
         if self.draw_kps:
             im = cv2.imread(self.path_out + ext)
             im = self.increase_brightness(im, value=30)
-            hh = im.size[1]
-            ww = im.size[0]
-            im_new = im[0:hh, 0:round(ww/1.7)]
+            im_new = im[0 : self.hh, 0:round(self.ww / 1.7)]
             cv2.imwrite(self.path_out, im_new)

         plt.close('all')

@@ -243,7 +238,8 @@ class Printer:
         return ax0

-    def increase_brightness(self, img, value=30):
+    @staticmethod
+    def increase_brightness(img, value=30):
         hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
         h, s, v = cv2.split(hsv)