make utils in torch and remove redundant functions (#3)

* Add precision metrics

* add mode gt_all and change default threshold

* add cyclists

* add iou matrix

* add cyclists only in training phase

* add dropout in model name

* small typos

* small typo

* fix error on uv_boxes

* change default mode from gt_ped to gt

* 2 decimals

* fix name bug

* refactor prepare_pif_kps

* corrected get_keypoints_batch

* add pixel to camera for 3d vectors

* preprocessing in torch

* return original outputs

* Skeleton for post_process

* baseline version for post processing

* add keypoints torch in post_processing

* cleaning misc

* add reorder_matches

* update preprocess with get_iou_matches

* fix indices

* remove aa

* temp

* skeleton kitti_generate

* skeleton kitti_generate (2)

* refactor file

* remove old get_input_data

* refactor geometric eval

* refactor geometric eval(2)

* temp

* refactor geometric

* change saving order for txts

* update pixel to camera

* update depth

* Fix pixel to camera

* add xyz_from_distance

* use new function

* fix std_ale calculation in eval

* remove debug points
This commit is contained in:
Lorenzo Bertoni 2019-06-28 18:33:58 +02:00 committed by GitHub
parent 2aea30cb7d
commit 019b6b0fad
19 changed files with 579 additions and 778 deletions

View File

@ -190,7 +190,7 @@ in txt file with format comparable to other baseline.
Then the model performs evaluation.
The following graph is obtained running:
`python3 src/main.py eval --dataset kitti --run_kitti --model data/models/monoloco-190513-1437.pkl
`python3 src/main.py eval --dataset kitti --generate --model data/models/monoloco-190513-1437.pkl
--dir_ann <folder containing pifpaf annotations of KITTI images>`
![kitti_evaluation](docs/results.png)

src/eval/generate_kitti.py (new file, 159 lines)
View File

@ -0,0 +1,159 @@
"""Run monoloco over all the pifpaf joints of KITTI images
and extract and save the annotations in txt files"""
import math
import os
import glob
import json
import shutil
import itertools
import numpy as np
import torch
from predict.monoloco import MonoLoco
from eval.geom_baseline import compute_distance
from utils.kitti import get_calibration
from utils.pifpaf import preprocess_pif
from utils.camera import xyz_from_distance, get_keypoints, pixel_to_camera
def generate_kitti(model, dir_ann, p_dropout=0.2, n_dropout=0):
cnt_ann = 0
cnt_file = 0
cnt_no_file = 0
dir_kk = os.path.join('data', 'kitti', 'calib')
dir_out = os.path.join('data', 'kitti', 'monoloco')
# Remove the output directory if it already exists (avoid residual txt files)
if os.path.exists(dir_out):
shutil.rmtree(dir_out)
os.makedirs(dir_out)
print("Created empty output directory for txt files")
# Load monoloco
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
monoloco = MonoLoco(model_path=model, device=device, n_dropout=n_dropout, p_dropout=p_dropout)
# Run monoloco over the list of images
list_basename = factory_basename(dir_ann)
for basename in list_basename:
path_calib = os.path.join(dir_kk, basename + '.txt')
annotations, kk, tt, _ = factory_file(path_calib, dir_ann, basename)
boxes, keypoints = preprocess_pif(annotations, im_size=(1242, 374))
if not keypoints:
cnt_no_file += 1
else:
# Run the network and the geometric baseline
outputs, varss = monoloco.forward(keypoints, kk)
dds_geom = eval_geometric(keypoints, kk, average_y=0.48)
# Save the file
all_outputs = [outputs.detach().cpu(), varss.detach().cpu(), dds_geom]
all_inputs = [boxes, keypoints]
all_params = [kk, tt]
path_txt = os.path.join(dir_out, basename + '.txt')
save_txts(path_txt, all_inputs, all_outputs, all_params)
# Update counting
cnt_ann += len(boxes)
cnt_file += 1
# Print statistics
print("Saved in {} txt {} annotations. Not found {} images"
.format(cnt_file, cnt_ann, cnt_no_file))
def save_txts(path_txt, all_inputs, all_outputs, all_params):
outputs, varss, dds_geom = all_outputs[:]
uv_boxes, keypoints = all_inputs[:]
kk, tt = all_params[:]
uv_centers = get_keypoints(keypoints, mode='center')
xy_centers = pixel_to_camera(uv_centers, kk, 1)
zzs = xyz_from_distance(outputs[:, 0:1], xy_centers)[:, 2].tolist()
with open(path_txt, "w+") as ff:
for idx in range(outputs.shape[0]):
xx = float(xy_centers[idx][0]) * zzs[idx] + tt[0]
yy = float(xy_centers[idx][1]) * zzs[idx] + tt[1]
zz = zzs[idx] + tt[2]
dd = math.sqrt(xx ** 2 + yy ** 2 + zz ** 2)
cam_0 = [xx, yy, zz, dd]
for el in uv_boxes[idx][:]:
ff.write("%s " % el)
for el in cam_0:
ff.write("%s " % el)
ff.write("%s " % float(outputs[idx][1]))
ff.write("%s " % float(varss[idx]))
ff.write("%s " % dds_geom[idx])
ff.write("\n")
# Save intrinsic matrix in the last row
for kk_el in itertools.chain(*kk): # Flatten a list of lists
ff.write("%f " % kk_el)
ff.write("\n")
def factory_basename(dir_ann):
""" Return all the basenames in the annotations folder"""
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
list_basename = [os.path.basename(x).split('.')[0] for x in list_ann]
assert list_basename, "Missing json annotation files to create txt files for the KITTI dataset"
return list_basename
def factory_file(path_calib, dir_ann, basename, ite=0):
"""Choose the annotation and the calibration files. Stereo option with ite = 1"""
stereo_file = True
p_left, p_right = get_calibration(path_calib)
if ite == 0:
kk, tt = p_left[:]
path_ann = os.path.join(dir_ann, basename + '.png.pifpaf.json')
else:
kk, tt = p_right[:]
path_ann = os.path.join(dir_ann + '_right', basename + '.png.pifpaf.json')
try:
with open(path_ann, 'r') as f:
annotations = json.load(f)
except FileNotFoundError:
annotations = None
if ite == 1:
stereo_file = False
return annotations, kk, tt, stereo_file
def eval_geometric(keypoints, kk, average_y=0.48):
""" Evaluate geometric distance"""
dds_geom = []
uv_centers = get_keypoints(keypoints, mode='center')
uv_shoulders = get_keypoints(keypoints, mode='shoulder')
uv_hips = get_keypoints(keypoints, mode='hip')
xy_centers = pixel_to_camera(uv_centers, kk, 1)
xy_shoulders = pixel_to_camera(uv_shoulders, kk, 1)
xy_hips = pixel_to_camera(uv_hips, kk, 1)
for idx, xy_center in enumerate(xy_centers):
zz = compute_distance(xy_shoulders[idx], xy_hips[idx], average_y)
xyz_center = np.array([xy_center[0], xy_center[1], zz])
dd_geom = float(np.linalg.norm(xyz_center))
dds_geom.append(dd_geom)
return dds_geom
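For reference, a minimal sketch of how the new entry point can be called directly from Python; the model path and annotation folder are placeholders, and the behaviour described in the comments follows the code above.

```python
# Minimal usage sketch (paths are placeholders, not part of the repository)
from eval.generate_kitti import generate_kitti

generate_kitti(
    model='data/models/monoloco-190513-1437.pkl',  # trained MonoLoco checkpoint
    dir_ann='data/kitti/pifpaf',                   # folder with <basename>.png.pifpaf.json files
    p_dropout=0.2,                                 # dropout used when rebuilding the network
    n_dropout=0,                                   # 0 disables MC-dropout (epistemic) sampling
)
# One txt file per image is written to data/kitti/monoloco/, with the flattened
# intrinsic matrix appended as the last row of each file.
```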

View File

@ -1,85 +1,70 @@
import glob
import json
import logging
import os
import numpy as np
import math
from collections import defaultdict
from utils.camera import pixel_to_camera
import numpy as np
from utils.camera import pixel_to_camera, get_keypoints
AVERAGE_Y = 0.48
CLUSTERS = ['10', '20', '30', 'all']
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GeomBaseline:
def geometric_baseline(joints):
"""
List of json files --> 2 lists with mean and std for each segment and the total count of instances
def __init__(self, joints):
For each annotation:
1. From gt boxes calculate the height (deltaY) for the segments head, shoulder, hip, ankle
2. From mask boxes calculate distance of people using average height of people and real pixel height
# Initialize directories
self.clusters = ['10', '20', '30', '>30', 'all']
self.average_y = 0.48
self.joints = joints
For left-right ambiguities we always choose the average of the joints
from utils.misc import calculate_iou
self.calculate_iou = calculate_iou
from utils.nuscenes import get_unique_tokens, split_scenes
self.get_unique_tokens = get_unique_tokens
self.split_scenes = split_scenes
The joints are mapped from 0 to 16 in the following order:
['nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', 'left_elbow',
'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle',
'right_ankle']
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
"""
cnt_tot = 0
dic_dist = defaultdict(lambda: defaultdict(list))
def run(self):
"""
List of json files --> 2 lists with mean and std for each segment and the total count of instances
# Access the joints file
with open(joints, 'r') as ff:
dic_joints = json.load(ff)
For each annotation:
1. From gt boxes calculate the height (deltaY) for the segments head, shoulder, hip, ankle
2. From mask boxes calculate distance of people using average height of people and real pixel height
# Calculate distances for all the instances in the joints dictionary
for phase in ['train', 'val']:
cnt = update_distances(dic_joints[phase], dic_dist, phase, AVERAGE_Y)
cnt_tot += cnt
For left-right ambiguities we always choose the average of the joints
# Calculate mean and std of each segment
dic_h_means = calculate_heights(dic_dist['heights'], mode='mean')
dic_h_stds = calculate_heights(dic_dist['heights'], mode='std')
errors = calculate_error(dic_dist['error'])
The joints are mapped from 0 to 16 in the following order:
['nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', 'left_elbow',
'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle',
'right_ankle']
"""
cnt_tot = 0
dic_dist = defaultdict(lambda: defaultdict(list))
# Access the joints file
with open(self.joints, 'r') as ff:
dic_joints = json.load(ff)
# Calculate distances for all the segments
for phase in ['train', 'val']:
cnt = update_distances(dic_joints[phase], dic_dist, phase, self.average_y)
cnt_tot += cnt
dic_h_means = calculate_heights(dic_dist['heights'], mode='mean')
dic_h_stds = calculate_heights(dic_dist['heights'], mode='std')
self.logger.info("Computed distance of {} annotations".format(cnt_tot))
for key in dic_h_means:
self.logger.info("Average height of segment {} is {:.2f} with a std of {:.2f}".
format(key, dic_h_means[key], dic_h_stds[key]))
errors = calculate_error(dic_dist['error'])
for clst in self.clusters:
self.logger.info("Average distance over the val set for clst {}: {:.2f}".format(clst, errors[clst]))
self.logger.info("Joints used: {}".format(self.joints))
# Show results
logger.info("Computed distance of {} annotations".format(cnt_tot))
for key in dic_h_means:
logger.info("Average height of segment {} is {:.2f} with a std of {:.2f}".
format(key, dic_h_means[key], dic_h_stds[key]))
for clst in CLUSTERS:
logger.info("Average error over the val set for clst {}: {:.2f}".format(clst, errors[clst]))
logger.info("Joints used: {}".format(joints))
def update_distances(dic_fin, dic_dist, phase, average_y):
# Loop over each annotation in the json file corresponding to the image
cnt = 0
for idx, kps in enumerate(dic_fin['kps']):
# Extract pixel coordinates of head, shoulder, hip, ankle and save them
dic_uv = extract_pixel_coord(kps)
dic_uv = {mode: get_keypoints(kps, mode) for mode in ['head', 'shoulder', 'hip', 'ankle']}
# Convert segments from pixel coordinate to camera coordinate
kk = dic_fin['K'][idx]
@ -87,26 +72,21 @@ def update_distances(dic_fin, dic_dist, phase, average_y):
# Create a dict with all annotations in meters
dic_xyz = {key: pixel_to_camera(dic_uv[key], kk, z_met) for key in dic_uv}
dic_xyz_norm = {key: pixel_to_camera(dic_uv[key], kk, 1) for key in dic_uv}
# Compute real height
dy_met = abs(dic_xyz['hip'][1] - dic_xyz['shoulder'][1])
dy_met = abs(float((dic_xyz['hip'][0][1] - dic_xyz['shoulder'][0][1])))
# Estimate distance for a single annotation
z_met_real, _ = compute_distance_single(dic_uv['shoulder'], dic_uv['hip'], kk, average_y,
mode='real', dy_met=dy_met)
z_met_approx, _ = compute_distance_single(dic_uv['shoulder'], dic_uv['hip'], kk, average_y,
mode='average')
z_met_real = compute_distance(dic_xyz_norm['shoulder'][0], dic_xyz_norm['hip'][0], average_y,
mode='real', dy_met=dy_met)
z_met_approx = compute_distance(dic_xyz_norm['shoulder'][0], dic_xyz_norm['hip'][0], average_y, mode='average')
# Compute distance with respect to the center of the 3D bounding box
xyz_met = np.array(dic_fin['boxes_3d'][idx][0:3])
d_met = np.linalg.norm(xyz_met)
d_real = math.sqrt(z_met_real ** 2 + dic_fin['boxes_3d'][idx][0] ** 2 + dic_fin['boxes_3d'][idx][1] ** 2)
d_approx = math.sqrt(z_met_approx ** 2 +
dic_fin['boxes_3d'][idx][0] ** 2 + dic_fin['boxes_3d'][idx][1] ** 2)
# if abs(d_qmet - d_real) > 1e-1: # "Error in computing distance with real height in pixels"
# aa = 5
# Update the dictionary with distance and heights metrics
dic_dist = update_dic_dist(dic_dist, dic_xyz, d_real, d_approx, phase)
cnt += 1
@ -114,22 +94,18 @@ def update_distances(dic_fin, dic_dist, phase, average_y):
return cnt
def compute_distance_single(uv_1, uv_2, kk, average_y, mode='average', dy_met=0):
def compute_distance(xyz_norm_1, xyz_norm_2, average_y, mode='average', dy_met=0):
"""
Compute distance Z of a mask annotation (solving a linear system) for 2 possible cases:
1. knowing specific height of the annotation (head-ankle) dy_met
2. using mean height of people (average_y)
"""
assert mode == 'average' or mode == 'real'
# Transform into normalized camera coordinates (plane at 1 m)
xyz_met_norm_1 = pixel_to_camera(uv_1, kk, 1)
xyz_met_norm_2 = pixel_to_camera(uv_2, kk, 1)
x1 = xyz_met_norm_1[0]
y1 = xyz_met_norm_1[1]
x2 = xyz_met_norm_2[0]
y2 = xyz_met_norm_2[1]
x1 = float(xyz_norm_1[0])
y1 = float(xyz_norm_1[1])
x2 = float(xyz_norm_2[0])
y2 = float(xyz_norm_2[1])
xx = (x1 + x2) / 2
# Choose if solving for provided height or average one.
@ -138,9 +114,6 @@ def compute_distance_single(uv_1, uv_2, kk, average_y, mode='average', dy_met=0)
else:
cc = -dy_met
# if - 3 * average_y <= cc <= -2:
# aa = 5
# Solving the linear system Ax = b
Aa = np.array([[y1, 0, -xx],
[0, -y1, 1],
@ -151,26 +124,7 @@ def compute_distance_single(uv_1, uv_2, kk, average_y, mode='average', dy_met=0)
xx = np.linalg.lstsq(Aa, bb, rcond=None)
z_met = abs(np.float(xx[0][1])) # Abs takes into account specularity behind the observer
# Compute the absolute x and y coordinates in meters
xyz_met_1 = xyz_met_norm_1 * z_met
xyz_met_2 = xyz_met_norm_2 * z_met
return z_met, (xyz_met_1, xyz_met_2)
def extract_pixel_coord(kps):
"""Extract uv coordinates from keypoints and save them in a dict """
# For each level of height (e.g. 5 points in the head), take the average of them
uv_head = np.array([np.average(kps[0][0:5]), np.average(kps[1][0:5]), 1])
uv_shoulder = np.array([np.average(kps[0][5:7]), np.average(kps[1][5:7]), 1])
uv_hip = np.array([np.average(kps[0][11:13]), np.average(kps[1][11:13]), 1])
uv_ankle = np.array([np.average(kps[0][15:17]), np.average(kps[1][15:17]), 1])
dic_uv = {'head': uv_head, 'shoulder': uv_shoulder, 'hip': uv_hip, 'ankle': uv_ankle}
return dic_uv
return z_met
def update_dic_dist(dic_dist, dic_xyz, d_real, d_approx, phase):
@ -178,10 +132,10 @@ def update_dic_dist(dic_dist, dic_xyz, d_real, d_approx, phase):
# Update the dict with heights metric
if phase == 'train':
dic_dist['heights']['head'].append(np.float(dic_xyz['head'][1]))
dic_dist['heights']['shoulder'].append(np.float(dic_xyz['shoulder'][1]))
dic_dist['heights']['hip'].append(np.float(dic_xyz['hip'][1]))
dic_dist['heights']['ankle'].append(np.float(dic_xyz['ankle'][1]))
dic_dist['heights']['head'].append(float(dic_xyz['head'][0][1]))
dic_dist['heights']['shoulder'].append(float(dic_xyz['shoulder'][0][1]))
dic_dist['heights']['hip'].append(float(dic_xyz['hip'][0][1]))
dic_dist['heights']['ankle'].append(float(dic_xyz['ankle'][0][1]))
# Update the dict with distance metrics for the test phase
if phase == 'val':
@ -235,11 +189,8 @@ def calculate_error(dic_errors):
"""
Compute error statistics for each distance cluster
"""
errors = {}
for clst in dic_errors:
errors[clst] = np.float(np.mean(np.array(dic_errors[clst])))
return errors
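To make the refactored geometric pipeline concrete, here is a small sketch that builds one invented annotation and computes its geometric depth; only the signatures of get_keypoints, pixel_to_camera and compute_distance shown above are assumed, and all numbers are made up.

```python
# Sketch with invented data: geometric depth from shoulder and hip keypoints
from utils.camera import get_keypoints, pixel_to_camera
from eval.geom_baseline import compute_distance

kk = [[718.0, 0.0, 620.0], [0.0, 718.0, 188.0], [0.0, 0.0, 1.0]]  # example intrinsics
keypoints = [[[600.0 + ii for ii in range(17)],      # u coordinates of the 17 COCO joints
              [150.0 + 6 * ii for ii in range(17)],  # v coordinates
              [0.9] * 17]]                           # confidences -> shape (1, 3, 17)

xy_shoulders = pixel_to_camera(get_keypoints(keypoints, mode='shoulder'), kk, 1)
xy_hips = pixel_to_camera(get_keypoints(keypoints, mode='hip'), kk, 1)

# Depth of the single instance, using the average torso height of 0.48 m
zz = compute_distance(xy_shoulders[0], xy_hips[0], average_y=0.48)
```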

View File

@ -6,7 +6,7 @@ import logging
from collections import defaultdict
import datetime
from utils.misc import get_iou_matches
from utils.misc import get_iou_matches, get_task_error
from utils.kitti import check_conditions, get_category, split_training, parse_ground_truth
from visuals.results import print_results
@ -117,6 +117,11 @@ class KittiEval:
print("\n Number of matched annotations: {:.1f} %".format(self.errors[key]['matched']))
print("-"*100)
print("\n Annotations inside the confidence interval: {:.1f} %"
.format(100 * self.dic_stats['test']['our']['all']['interval']))
print("precision 1: {:.2f}".format(self.dic_stats['test']['our']['all']['prec_1']))
print("precision 2: {:.2f}".format(self.dic_stats['test']['our']['all']['prec_2']))
def printer(self, show):
print_results(self.dic_stats, show)
@ -171,15 +176,13 @@ class KittiEval:
file_lines = ff.readlines()
for line_our in file_lines[:-1]:
line_list = [float(x) for x in line_our.split()]
if check_conditions(line_list, thresh=self.dic_thresh_conf[method], mode=method):
boxes.append(line_list[:4])
# xyzs.append(line_list[4:7])
dds.append(line_list[7])
stds_ale.append(line_list[8])
stds_epi.append(line_list[9])
dds.append(line_list[8])
stds_ale.append(line_list[9])
stds_epi.append(line_list[10])
dds_geom.append(line_list[11])
# xy_kps.append(line_list[12:])
self.dic_cnt[method] += 1
# kk_list = [float(x) for x in file_lines[-1].split()]
@ -238,7 +241,6 @@ class KittiEval:
self.dic_cnt['merged'] += 1
def update_errors(self, dd, dd_gt, cat, errors):
"""Compute and save errors between a single box and the gt box which match"""
diff = abs(dd - dd_gt)
@ -274,26 +276,49 @@ class KittiEval:
self.dic_stds[cat]['epi'].append(std_epi)
# Number of annotations inside the confidence interval
if dd_gt <= dd: # Particularly dangerous instances
std = std_epi if std_epi > 0 else std_ale # consider aleatoric uncertainty if epistemic is not calculated
if abs(dd - dd_gt) <= std:
self.dic_stds['all']['interval'].append(1)
self.dic_stds[clst]['interval'].append(1)
self.dic_stds[cat]['interval'].append(1)
else:
self.dic_stds['all']['interval'].append(0)
self.dic_stds[clst]['interval'].append(0)
self.dic_stds[cat]['interval'].append(0)
# Annotations at risk inside the confidence interval
if dd_gt <= dd:
self.dic_stds['all']['at_risk'].append(1)
self.dic_stds[clst]['at_risk'].append(1)
self.dic_stds[cat]['at_risk'].append(1)
if abs(dd - dd_gt) <= std_epi:
self.dic_stds['all']['interval'].append(1)
self.dic_stds[clst]['interval'].append(1)
self.dic_stds[cat]['interval'].append(1)
self.dic_stds['all']['at_risk-interval'].append(1)
self.dic_stds[clst]['at_risk-interval'].append(1)
self.dic_stds[cat]['at_risk-interval'].append(1)
else:
self.dic_stds['all']['interval'].append(0)
self.dic_stds[clst]['interval'].append(0)
self.dic_stds[cat]['interval'].append(0)
self.dic_stds['all']['at_risk-interval'].append(0)
self.dic_stds[clst]['at_risk-interval'].append(0)
self.dic_stds[cat]['at_risk-interval'].append(0)
else:
self.dic_stds['all']['at_risk'].append(0)
self.dic_stds[clst]['at_risk'].append(0)
self.dic_stds[cat]['at_risk'].append(0)
# Precision of uncertainty
eps = 1e-4
task_error = get_task_error(dd)
prec_1 = abs(dd - dd_gt) / (std_epi + eps)
prec_2 = abs(std_epi - task_error)
self.dic_stds['all']['prec_1'].append(prec_1)
self.dic_stds[clst]['prec_1'].append(prec_1)
self.dic_stds[cat]['prec_1'].append(prec_1)
self.dic_stds['all']['prec_2'].append(prec_2)
self.dic_stds[clst]['prec_2'].append(prec_2)
self.dic_stds[cat]['prec_2'].append(prec_2)
def get_statistics(dic_stats, errors, dic_stds, key):
"""Update statistics of a cluster"""
@ -307,6 +332,8 @@ def get_statistics(dic_stats, errors, dic_stds, key):
dic_stats['std_epi'] = sum(dic_stds['epi']) / float(len(dic_stds['epi']))
dic_stats['interval'] = sum(dic_stds['interval']) / float(len(dic_stds['interval']))
dic_stats['at_risk'] = sum(dic_stds['at_risk']) / float(len(dic_stds['at_risk']))
dic_stats['prec_1'] = sum(dic_stds['prec_1']) / float(len(dic_stds['prec_1']))
dic_stats['prec_2'] = sum(dic_stds['prec_2']) / float(len(dic_stds['prec_2']))
def add_true_negatives(err, cnt_gt):
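To make the two new "precision" statistics concrete, here is a small worked example with invented numbers, following the formulas above (dd is the predicted distance, dd_gt the ground truth, std_epi the epistemic spread):

```python
# Worked example of the uncertainty precision metrics (numbers are invented)
dd, dd_gt, std_epi = 20.4, 21.0, 0.9
eps = 1e-4
task_error = 0.0556 * dd                    # get_task_error: gender-agnostic target error (~1.13 m)
prec_1 = abs(dd - dd_gt) / (std_epi + eps)  # ~0.67: localization error in units of predicted std
prec_2 = abs(std_epi - task_error)          # ~0.23 m: gap between predicted std and the task error
```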

View File

@ -1,186 +0,0 @@
"""Run monoloco over all the pifpaf joints of KITTI images
and extract and save the annotations in txt files"""
import math
import os
import glob
import json
import logging
import numpy as np
import torch
from models.architectures import LinearModel
from utils.misc import laplace_sampling
from utils.kitti import eval_geometric, get_calibration
from utils.normalize import unnormalize_bi
from utils.pifpaf import get_input_data, preprocess_pif
from utils.camera import get_depth_from_distance
class RunKitti:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
cnt_ann = 0
cnt_file = 0
cnt_no_file = 0
average_y = 0.48
n_samples = 100
def __init__(self, model, dir_ann, dropout, hidden_size, n_stage, n_dropout):
self.dir_ann = dir_ann
self.n_dropout = n_dropout
self.dir_kk = os.path.join('data', 'kitti', 'calib')
self.dir_out = os.path.join('data', 'kitti', 'monoloco')
if not os.path.exists(self.dir_out):
os.makedirs(self.dir_out)
print("Created output directory for txt files")
self.list_basename = factory_basename(dir_ann)
# Load the model
input_size = 17 * 2
use_cuda = torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.model = LinearModel(input_size=input_size, output_size=2, linear_size=hidden_size,
p_dropout=dropout, num_stage=n_stage)
self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
self.model.eval() # Default is train
self.model.to(self.device)
def run(self):
# Run inference
for basename in self.list_basename:
path_calib = os.path.join(self.dir_kk, basename + '.txt')
annotations, kk, tt, _ = factory_file(path_calib, self.dir_ann, basename)
boxes, keypoints = preprocess_pif(annotations)
(inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = get_input_data(boxes, keypoints, kk)
dds_geom, xy_centers = eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48)
# Update counting
self.cnt_ann += len(boxes)
if not inputs:
self.cnt_no_file += 1
else:
self.cnt_file += 1
# Run the model
inputs = torch.from_numpy(np.array(inputs)).float().to(self.device)
if self.n_dropout > 0:
total_outputs = torch.empty((0, len(uv_boxes))).to(self.device)
self.model.dropout.training = True
for _ in range(self.n_dropout):
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.n_samples)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)
else:
varss = [0]*len(uv_boxes)
# Don't use dropout for the mean prediction and aleatoric uncertainty
self.model.dropout.training = False
outputs_net = self.model(inputs)
outputs = outputs_net.cpu().detach().numpy()
list_zzs = get_depth_from_distance(outputs, xy_centers)
all_outputs = [outputs, varss, dds_geom]
all_inputs = [uv_boxes, xy_centers, xy_kps]
all_params = [kk, tt]
# Save the file
all_outputs.append(list_zzs)
path_txt = os.path.join(self.dir_out, basename + '.txt')
save_txts(path_txt, all_inputs, all_outputs, all_params)
aa = 5
# Print statistics
print("Saved in {} txt {} annotations. Not found {} images"
.format(self.cnt_file, self.cnt_ann, self.cnt_no_file))
def save_txts(path_txt, all_inputs, all_outputs, all_params):
outputs, varss, dds_geom, zzs = all_outputs[:]
uv_boxes, xy_centers, xy_kps = all_inputs[:]
kk, tt = all_params[:]
with open(path_txt, "w+") as ff:
for idx in range(outputs.shape[0]):
xx_1 = float(xy_centers[idx][0])
yy_1 = float(xy_centers[idx][1])
xy_kp = xy_kps[idx]
dd = float(outputs[idx][0])
std_ale = math.exp(float(outputs[idx][1])) * dd
zz = zzs[idx]
xx_cam_0 = xx_1 * zz + tt[0]
yy_cam_0 = yy_1 * zz + tt[1]
zz_cam_0 = zz + tt[2]
dd_cam_0 = math.sqrt(xx_cam_0 ** 2 + yy_cam_0 ** 2 + zz_cam_0 ** 2)
uv_box = uv_boxes[idx]
twodecimals = ["%.3f" % vv for vv in [uv_box[0], uv_box[1], uv_box[2], uv_box[3],
xx_cam_0, yy_cam_0, zz_cam_0, dd_cam_0,
std_ale, varss[idx], uv_box[4], dds_geom[idx]]]
keypoints_str = ["%.5f" % vv for vv in xy_kp]
for item in twodecimals:
ff.write("%s " % item)
for item in keypoints_str:
ff.write("%s " % item)
ff.write("\n")
# Save intrinsic matrix in the last row
kk_list = kk.reshape(-1, ).tolist()
for kk_el in kk_list:
ff.write("%f " % kk_el)
ff.write("\n")
def factory_basename(dir_ann):
""" Return all the basenames in the annotations folder"""
list_ann = glob.glob(os.path.join(dir_ann, '*.json'))
list_basename = [os.path.basename(x).split('.')[0] for x in list_ann]
assert list_basename, " Missing json annotations file to create txt files for KITTI datasets"
return list_basename
def factory_file(path_calib, dir_ann, basename, ite=0):
"""Choose the annotation and the calibration files. Stereo option with ite = 1"""
stereo_file = True
p_left, p_right = get_calibration(path_calib)
if ite == 0:
kk, tt = p_left[:]
path_ann = os.path.join(dir_ann, basename + '.png.pifpaf.json')
else:
kk, tt = p_right[:]
path_ann = os.path.join(dir_ann + '_right', basename + '.png.pifpaf.json')
try:
with open(path_ann, 'r') as f:
annotations = json.load(f)
except FileNotFoundError:
annotations = None
if ite == 1:
stereo_file = False
return annotations, kk, tt, stereo_file

View File

@ -7,9 +7,11 @@ import logging
from collections import defaultdict
import json
import datetime
import torch
from utils.kitti import get_calibration, split_training, parse_ground_truth
from utils.pifpaf import get_input_data, preprocess_pif
from utils.misc import get_idx_max, append_cluster
from utils.pifpaf import get_network_inputs, preprocess_pif
from utils.misc import get_iou_matches, append_cluster
class PreprocessKitti:
@ -26,10 +28,10 @@ class PreprocessKitti:
clst=defaultdict(lambda: defaultdict(list)))}
dic_names = defaultdict(lambda: defaultdict(list))
def __init__(self, dir_ann, iou_thresh=0.3):
def __init__(self, dir_ann, iou_min=0.3):
self.dir_ann = dir_ann
self.iou_thresh = iou_thresh
self.iou_min = iou_min
self.dir_gt = os.path.join('data', 'kitti', 'gt')
self.names_gt = tuple(os.listdir(self.dir_gt))
self.dir_kk = os.path.join('data', 'kitti', 'calib')
@ -70,10 +72,14 @@ class PreprocessKitti:
kk = p_left[0]
# Iterate over each line of the gt file and save box location and distances
(boxes_gt, boxes_3d, dds_gt, _, _) = parse_ground_truth(path_gt)
if phase == 'train':
(boxes_gt, boxes_3d, dds_gt, _, _) = parse_ground_truth(path_gt, mode='gt_all') # Also cyclists
else:
(boxes_gt, boxes_3d, dds_gt, _, _) = parse_ground_truth(path_gt, mode='gt') # only pedestrians
self.dic_names[basename + '.png']['boxes'] = copy.deepcopy(boxes_gt)
self.dic_names[basename + '.png']['dds'] = copy.deepcopy(dds_gt)
self.dic_names[basename + '.png']['K'] = copy.deepcopy(kk.tolist())
self.dic_names[basename + '.png']['K'] = copy.deepcopy(kk)
cnt_gt += len(boxes_gt)
cnt_files += 1
cnt_files_ped += min(len(boxes_gt), 1) # if no boxes 0 else 1
@ -82,28 +88,23 @@ class PreprocessKitti:
try:
with open(os.path.join(self.dir_ann, basename + '.png.pifpaf.json'), 'r') as f:
annotations = json.load(f)
boxes, keypoints = preprocess_pif(annotations)
(inputs, _), (uv_kps, uv_boxes, _, _) = get_input_data(boxes, keypoints, kk)
boxes, keypoints = preprocess_pif(annotations, im_size=(1238, 374))
inputs = get_network_inputs(keypoints, kk).tolist()
except FileNotFoundError:
uv_boxes = []
boxes = []
# Match each set of keypoints with a ground truth
for ii, box in enumerate(uv_boxes):
idx_max, iou_max = get_idx_max(box, boxes_gt)
if iou_max >= self.iou_thresh:
self.dic_jo[phase]['kps'].append(uv_kps[ii])
self.dic_jo[phase]['X'].append(inputs[ii])
self.dic_jo[phase]['Y'].append([dds_gt[idx_max]]) # Trick to make it (nn,1)
self.dic_jo[phase]['boxes_3d'].append(boxes_3d[idx_max])
self.dic_jo[phase]['K'].append(kk.tolist())
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
append_cluster(self.dic_jo, phase, inputs[ii], dds_gt[idx_max], uv_kps[ii])
dic_cnt[phase] += 1
boxes_gt.pop(idx_max)
dds_gt.pop(idx_max)
matches = get_iou_matches(boxes, boxes_gt, self.iou_min)
for (idx, idx_gt) in matches:
self.dic_jo[phase]['kps'].append(keypoints[idx])
self.dic_jo[phase]['X'].append(inputs[idx])
self.dic_jo[phase]['Y'].append([dds_gt[idx_gt]]) # Trick to make it (nn,1)
self.dic_jo[phase]['boxes_3d'].append(boxes_3d[idx_gt])
self.dic_jo[phase]['K'].append(kk)
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
append_cluster(self.dic_jo, phase, inputs[idx], dds_gt[idx_gt], keypoints[idx])
dic_cnt[phase] += 1
with open(self.path_joints, 'w') as file:
json.dump(self.dic_jo, file)
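As a quick illustration of the matching that replaces get_idx_max, the sketch below runs get_iou_matches on two invented boxes; the exact ordering of the returned pairs is only indicative.

```python
# Sketch of get_iou_matches with invented boxes (format: [x1, y1, x2, y2])
from utils.misc import get_iou_matches

boxes = [[100., 120., 160., 300.], [400., 130., 450., 280.]]    # pifpaf detections
boxes_gt = [[402., 128., 452., 276.], [98., 118., 158., 305.]]  # KITTI ground-truth boxes
matches = get_iou_matches(boxes, boxes_gt, 0.3)
# -> e.g. [(0, 1), (1, 0)]: each (idx, idx_gt) pair has IoU above the threshold
#    and every ground-truth box is used at most once
```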

View File

@ -12,10 +12,10 @@ import numpy as np
from nuscenes.nuscenes import NuScenes
from nuscenes.utils import splits
from utils.misc import get_idx_max, append_cluster
from utils.misc import get_iou_matches, append_cluster
from utils.nuscenes import select_categories
from utils.camera import project_3d
from utils.pifpaf import get_input_data, preprocess_pif
from utils.pifpaf import preprocess_pif, get_network_inputs
class PreprocessNuscenes:
@ -90,6 +90,7 @@ class PreprocessNuscenes:
sd_token = sample_dic['data'][cam]
cnt_sd += 1
path_im, boxes_obj, kk = self.nusc.get_sample_data(sd_token, box_vis_level=1) # At least one corner
kk = kk.tolist()
# Extract all the annotations of the person
boxes_gt = []
@ -110,7 +111,7 @@ class PreprocessNuscenes:
boxes_3d.append(box_3d)
self.dic_names[name]['boxes'].append(box)
self.dic_names[name]['dds'].append(dd)
self.dic_names[name]['K'] = kk.tolist()
self.dic_names[name]['K'] = kk
# Run IoU with pifpaf detections and save
path_pif = os.path.join(self.dir_ann, name + '.pifpaf.json')
@ -120,27 +121,22 @@ class PreprocessNuscenes:
with open(path_pif, 'r') as file:
annotations = json.load(file)
boxes, keypoints = preprocess_pif(annotations, im_size=None)
(inputs, _), (uv_kps, uv_boxes, _, _) = get_input_data(boxes, keypoints, kk)
boxes, keypoints = preprocess_pif(annotations, im_size=(1600, 900))
for ii, box in enumerate(uv_boxes):
idx_max, iou_max = get_idx_max(box, boxes_gt)
if keypoints:
inputs = get_network_inputs(keypoints, kk).tolist()
if iou_max > self.iou_min:
self.dic_jo[phase]['kps'].append(uv_kps[ii])
self.dic_jo[phase]['X'].append(inputs[ii])
self.dic_jo[phase]['Y'].append([dds[idx_max]]) # Trick to make it (nn,1)
matches = get_iou_matches(boxes, boxes_gt, self.iou_min)
for (idx, idx_gt) in matches:
self.dic_jo[phase]['kps'].append(keypoints[idx])
self.dic_jo[phase]['X'].append(inputs[idx])
self.dic_jo[phase]['Y'].append([dds[idx_gt]]) # Trick to make it (nn,1)
self.dic_jo[phase]['names'].append(name) # One image name for each annotation
self.dic_jo[phase]['boxes_3d'].append(boxes_3d[idx_max])
self.dic_jo[phase]['K'].append(kk.tolist())
append_cluster(self.dic_jo, phase, inputs[ii], dds[idx_max], uv_kps[ii])
boxes_gt.pop(idx_max)
dds.pop(idx_max)
boxes_3d.pop(idx_max)
self.dic_jo[phase]['boxes_3d'].append(boxes_3d[idx_gt])
self.dic_jo[phase]['K'].append(kk)
append_cluster(self.dic_jo, phase, inputs[idx], dds[idx_gt], keypoints[idx])
cnt_ann += 1
sys.stdout.write('\r' + 'Saved annotations {}'
.format(cnt_ann) + '\t')
sys.stdout.write('\r' + 'Saved annotations {}'.format(cnt_ann) + '\t')
current_token = sample_dic['next']

View File

@ -11,8 +11,8 @@ from features.preprocess_nu import PreprocessNuscenes
from features.preprocess_ki import PreprocessKitti
from predict.predict import predict
from models.trainer import Trainer
from eval.run_kitti import RunKitti
from eval.geom_baseline import GeomBaseline
from eval.generate_kitti import generate_kitti
from eval.geom_baseline import geometric_baseline
from models.hyp_tuning import HypTuning
from eval.kitti_eval import KittiEval
@ -66,6 +66,7 @@ def cli():
predict_parser.add_argument('--predict', help='whether to make prediction', action='store_true')
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=22)
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)
predict_parser.add_argument('--combined', help='to print combined images', action='store_true')
# Training
@ -88,7 +89,7 @@ def cli():
# Evaluation
eval_parser.add_argument('--dataset', help='datasets to evaluate, kitti or nuscenes', default='kitti')
eval_parser.add_argument('--geometric', help='to evaluate geometric distance', action='store_true')
eval_parser.add_argument('--run_kitti', help='create txt files for KITTI evaluation', action='store_true')
eval_parser.add_argument('--generate', help='create txt files for KITTI evaluation', action='store_true')
eval_parser.add_argument('--dir_ann', help='directory of annotations of 2d joints (for KITTI evaluation)')
eval_parser.add_argument('--model', help='path of MonoLoco model to load', required=True)
eval_parser.add_argument('--joints', help='Json file with input joints to evaluate (for nuScenes evaluation)')
@ -133,14 +134,10 @@ def main():
elif args.command == 'eval':
if args.geometric:
geometric_baseline = GeomBaseline(args.joints)
geometric_baseline.run()
geometric_baseline(args.joints)
if args.run_kitti:
run_kitti = RunKitti(model=args.model, dir_ann=args.dir_ann,
dropout=args.dropout, hidden_size=args.hidden_size, n_stage=args.n_stage,
n_dropout=args.n_dropout)
run_kitti.run()
if args.generate:
generate_kitti(args.model, args.dir_ann, p_dropout=args.dropout, n_dropout=args.n_dropout)
if args.dataset == 'kitti':
kitti_eval = KittiEval()
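For completeness, after the renaming the KITTI txt files are produced with `python3 src/main.py eval --dataset kitti --generate --model data/models/monoloco-190513-1437.pkl --dir_ann <folder containing pifpaf annotations of KITTI images>`, while the geometric baseline runs with `python3 src/main.py eval --geometric --model data/models/monoloco-190513-1437.pkl --joints <json file with input joints>` (flag names as defined in cli() above; paths are placeholders).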

View File

@ -32,11 +32,7 @@ class HypTuning:
now = datetime.datetime.now()
now_time = now.strftime("%Y%m%d-%H%M")[2:]
if baseline:
name_out = 'hyp-baseline-' + now_time
else:
name_out = 'hyp-monoloco-' + now_time
name_out = 'hyp-baseline-' if baseline else 'hyp-monoloco-'
self.path_log = os.path.join(dir_logs, name_out + now_time)
self.path_model = os.path.join(dir_out, name_out + now_time + '.pkl')

View File

@ -95,7 +95,7 @@ class Trainer:
# Select the device and load the data
use_cuda = torch.cuda.is_available()
self.device = torch.device("cuda:0" if use_cuda else "cpu")
self.device = torch.device("cuda:1" if use_cuda else "cpu")
print('Device: ', self.device)
# Set the seed for random initialization
@ -331,24 +331,20 @@ class Trainer:
else:
mean_bi = torch.mean(outputs[:, 1]).item()
max_bi = torch.max(outputs[:, 1]).item()
low_bound_bi = labels >= (outputs[:, 0] - outputs[:, 1])
up_bound_bi = labels <= (outputs[:, 0] + outputs[:, 1])
bools_bi = low_bound_bi & up_bound_bi
conf_bi = float(torch.sum(bools_bi)) / float(bools_bi.shape[0])
if varss[0] == 0:
aa = 5
else:
mean_var = torch.mean(varss).item()
max_var = torch.max(varss).item()
low_bound_var = labels >= (outputs[:, 0] - varss)
up_bound_var = labels <= (outputs[:, 0] + varss)
bools_var = low_bound_var & up_bound_var
conf_var = float(torch.sum(bools_var)) / float(bools_var.shape[0])
# if varss[0] >= 0:
# mean_var = torch.mean(varss).item()
# max_var = torch.max(varss).item()
#
# low_bound_var = labels >= (outputs[:, 0] - varss)
# up_bound_var = labels <= (outputs[:, 0] + varss)
# bools_var = low_bound_var & up_bound_var
# conf_var = float(torch.sum(bools_var)) / float(bools_var.shape[0])
dic_err['mean'] += mean_mu * (outputs.size(0) / size_eval)
dic_err['bi'] += mean_bi * (outputs.size(0) / size_eval)

View File

@ -1,10 +1,11 @@
import json
import os
from visuals.printer import Printer
from collections import defaultdict
from openpifpaf import show
from PIL import Image
from visuals.printer import Printer
from utils.misc import get_iou_matches, reorder_matches
from utils.camera import get_keypoints, pixel_to_camera, xyz_from_distance
def factory_for_gt(im_size, name=None, path_gt=None):
@ -13,7 +14,7 @@ def factory_for_gt(im_size, name=None, path_gt=None):
try:
with open(path_gt, 'r') as f:
dic_names = json.load(f)
print('-' * 120 + "\nMonoloco: Ground-truth file opened\n")
print('-' * 120 + "\nMonoloco: Ground-truth file opened")
except FileNotFoundError:
print('-' * 120 + "\nMonoloco: ground-truth file not found\n")
dic_names = {}
@ -45,7 +46,6 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_
# Save json file
if 'pifpaf' in args.networks:
keypoint_sets, scores, pifpaf_out = pifpaf_outputs[:]
# Visualizer
@ -74,13 +74,16 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_
skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)
if 'monoloco' in args.networks:
dic_out = monoloco_post_process(monoloco_outputs)
if any((xx in args.output_types for xx in ['front', 'bird', 'combined'])):
epistemic = False
if args.n_dropout > 0:
epistemic = True
printer = Printer(images_outputs[1], output_path, monoloco_outputs, kk, output_types=args.output_types,
printer = Printer(images_outputs[1], output_path, dic_out, kk, output_types=args.output_types,
show=args.show, z_max=args.z_max, epistemic=epistemic)
printer.print()
@ -89,3 +92,49 @@ def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, monoloco_
json.dump(monoloco_outputs, ff)
def monoloco_post_process(monoloco_outputs, iou_min=0.25):
"""Post process monoloco to output final dictionary with all information for visualizations"""
dic_out = defaultdict(list)
outputs, varss, boxes, keypoints, kk, dic_gt = monoloco_outputs[:]
if dic_gt:
boxes_gt, dds_gt = dic_gt['boxes'], dic_gt['dds']
matches = get_iou_matches(boxes, boxes_gt, thresh=iou_min)
else:
matches = [(idx, idx) for idx in range(len(boxes))] # Replicate boxes when no ground truth
matches = reorder_matches(matches, boxes, mode='left_right')
uv_shoulders = get_keypoints(keypoints, mode='shoulder')
uv_centers = get_keypoints(keypoints, mode='center')
xy_centers = pixel_to_camera(uv_centers, kk, 1)
# Match with ground truth if available
for idx, idx_gt in matches:
dd_pred = float(outputs[idx][0])
ale = float(outputs[idx][1])
var_y = float(varss[idx])
dd_real = dds_gt[idx_gt] if dic_gt else dd_pred
kps = keypoints[idx]
box = boxes[idx]
uu_s, vv_s = uv_shoulders.tolist()[idx][0:2]
uu_c, vv_c = uv_centers.tolist()[idx][0:2]
uv_shoulder = [round(uu_s), round(vv_s)]
uv_center = [round(uu_c), round(vv_c)]
xyz_real = xyz_from_distance(dd_real, xy_centers[idx])
xyz_pred = xyz_from_distance(dd_pred, xy_centers[idx])
dic_out['boxes'].append(box)
dic_out['dds_real'].append(dd_real)
dic_out['dds_pred'].append(dd_pred)
dic_out['stds_ale'].append(ale)
dic_out['stds_epi'].append(var_y)
dic_out['xyz_real'].append(xyz_real.squeeze().tolist())
dic_out['xyz_pred'].append(xyz_pred.squeeze().tolist())
dic_out['uv_kps'].append(kps)
dic_out['uv_centers'].append(uv_center)
dic_out['uv_shoulders'].append(uv_shoulder)
return dic_out
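A short sketch of how the post-processing step is wired after a forward pass, mirroring predict.py below; the inputs (keypoints, boxes, kk, dic_gt) are assumed to come from preprocess_pif and factory_for_gt, so this is a data-flow sketch rather than a standalone script.

```python
# Data-flow sketch: from network outputs to the dictionary used by the Printer
outputs, varss = monoloco.forward(keypoints, kk)                   # see predict.py below
monoloco_outputs = [outputs, varss, boxes, keypoints, kk, dic_gt]
dic_out = monoloco_post_process(monoloco_outputs, iou_min=0.25)
# dic_out['dds_pred'], dic_out['xyz_pred'], dic_out['stds_ale'], dic_out['stds_epi'], ...
# are the lists consumed by the front/bird/combined visualizations
```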

View File

@ -3,18 +3,14 @@
Monoloco predictor. It receives pifpaf joints and outputs distances
"""
from collections import defaultdict
import logging
import time
import numpy as np
import torch
from models.architectures import LinearModel
from utils.camera import get_depth
from utils.misc import laplace_sampling, get_idx_max
from utils.misc import laplace_sampling
from utils.normalize import unnormalize_bi
from utils.pifpaf import get_input_data
from utils.pifpaf import get_network_inputs
class MonoLoco:
@ -24,97 +20,44 @@ class MonoLoco:
OUTPUT_SIZE = 2
INPUT_SIZE = 17 * 2
LINEAR_SIZE = 256
IOU_MIN = 0.25
N_SAMPLES = 100
def __init__(self, model, device, n_dropout=0):
def __init__(self, model_path, device, n_dropout=0, p_dropout=0.2):
self.device = device
self.n_dropout = n_dropout
if self.n_dropout > 0:
self.epistemic = True
else:
self.epistemic = False
self.epistemic = True if self.n_dropout > 0 else False
# load the model parameters
self.model = LinearModel(input_size=self.INPUT_SIZE, output_size=self.OUTPUT_SIZE, linear_size=self.LINEAR_SIZE)
self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
self.model = LinearModel(p_dropout=p_dropout,
input_size=self.INPUT_SIZE, output_size=self.OUTPUT_SIZE, linear_size=self.LINEAR_SIZE,
)
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
self.model.eval() # Default is train
self.model.to(self.device)
def forward(self, boxes, keypoints, kk, dic_gt=None):
def forward(self, keypoints, kk):
"""forward pass of monoloco network"""
if not keypoints:
return None
(inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \
get_input_data(boxes, keypoints, kk, left_to_right=True)
with torch.no_grad():
inputs = get_network_inputs(torch.tensor(keypoints).to(self.device), torch.tensor(kk).to(self.device))
if self.n_dropout > 0:
self.model.dropout.training = True # Manually reactivate dropout in eval
total_outputs = torch.empty((0, inputs.size()[0])).to(self.device)
# Conversion into torch tensor
if inputs_norm:
with torch.no_grad():
inputs = torch.from_numpy(np.array(inputs_norm)).float()
inputs = inputs.to(self.device)
# self.model.to("cpu")
start = time.time()
# Manually reactivate dropout in eval
self.model.dropout.training = True
total_outputs = torch.empty((0, len(xy_kps))).to(self.device)
if self.n_dropout > 0:
for _ in range(self.n_dropout):
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.N_SAMPLES)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)
else:
varss = [0] * len(inputs_norm)
# # Don't use dropout for the mean prediction
start_single = time.time()
for _ in range(self.n_dropout):
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.N_SAMPLES)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)
self.model.dropout.training = False
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
end = time.time()
print("Total Forward pass time with {} forward passes = {:.2f} ms"
.format(self.n_dropout, (end-start) * 1000))
print("Single forward pass time = {:.2f} ms".format((end - start_single) * 1000))
# Create output files
dic_out = defaultdict(list)
if dic_gt:
boxes_gt, dds_gt = dic_gt['boxes'], dic_gt['dds']
for idx, box in enumerate(uv_boxes):
dd_pred = float(outputs[idx][0])
ale = float(outputs[idx][1])
var_y = float(varss[idx])
# Find the corresponding ground truth if available
if dic_gt:
idx_max, iou_max = get_idx_max(box, boxes_gt)
if iou_max > self.IOU_MIN:
dd_real = dds_gt[idx_max]
boxes_gt.pop(idx_max)
dds_gt.pop(idx_max)
# In case of no matching
else:
dd_real = 0
# In case of no ground truth
else:
dd_real = dd_pred
varss = torch.zeros(inputs.size()[0])
uv_center = uv_centers[idx]
xyz_real = get_depth(uv_center, kk, dd_real)
xyz_pred = get_depth(uv_center, kk, dd_pred)
dic_out['boxes'].append(box)
dic_out['dds_real'].append(dd_real)
dic_out['dds_pred'].append(dd_pred)
dic_out['stds_ale'].append(ale)
dic_out['stds_epi'].append(var_y)
dic_out['xyz_real'].append(xyz_real)
dic_out['xyz_pred'].append(xyz_pred)
dic_out['xy_kps'].append(xy_kps[idx])
dic_out['uv_kps'].append(uv_kps[idx])
dic_out['uv_centers'].append(uv_center)
dic_out['uv_shoulders'].append(uv_shoulders[idx])
return dic_out
# Don't use dropout for the mean prediction
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
return outputs, varss
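A minimal sketch of the refactored predictor API, putting the pieces above together; the json path and intrinsics are placeholders, and the image size matches the one used for KITTI in generate_kitti.py.

```python
# Minimal usage sketch of the refactored MonoLoco class (paths/values are placeholders)
import json
import torch
from predict.monoloco import MonoLoco
from utils.pifpaf import preprocess_pif

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
monoloco = MonoLoco(model_path='data/models/monoloco-190513-1437.pkl',
                    device=device, n_dropout=0, p_dropout=0.2)

with open('000001.png.pifpaf.json', 'r') as ff:   # hypothetical pifpaf output file
    annotations = json.load(ff)
kk = [[718.0, 0.0, 620.0], [0.0, 718.0, 188.0], [0.0, 0.0, 1.0]]  # example intrinsics

boxes, keypoints = preprocess_pif(annotations, im_size=(1242, 374))
if keypoints:
    outputs, varss = monoloco.forward(keypoints, kk)
    # outputs[:, 0] = predicted distance, outputs[:, 1] = aleatoric spread (after unnormalize_bi)
    # varss = epistemic std per instance (all zeros when n_dropout == 0)
```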

View File

@ -86,7 +86,7 @@ def predict(args):
processor = decoder.factory_from_args(args, model_pifpaf)
# load monoloco
monoloco = MonoLoco(model=args.model, device=args.device, n_dropout=args.n_dropout)
monoloco = MonoLoco(model_path=args.model, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
# data
data = ImageList(args.images, scale=args.scale)
@ -146,11 +146,13 @@ def predict(args):
im_name = os.path.basename(image_path)
kk, _ = factory_for_gt(im_size, name=im_name, path_gt=args.path_gt)
kk, dic_gt = factory_for_gt(im_size, name=im_name, path_gt=args.path_gt)
# Preprocess pifpaf outputs and run monoloco
boxes, keypoints = preprocess_pif(pifpaf_out, im_size)
monoloco_outputs = monoloco.forward(boxes, keypoints, kk)
outputs, varss = monoloco.forward(keypoints, kk)
monoloco_outputs = [outputs, varss, boxes, keypoints, kk, dic_gt]
else:
monoloco_outputs = None
kk = None

View File

@ -1,17 +1,29 @@
import numpy as np
import math
import torch
import torch.nn.functional as F
def pixel_to_camera(uv1, kk, z_met):
def pixel_to_camera(uv_tensor, kk, z_met):
"""
(3,) array --> (3,) array
Convert a point in pixel coordinate to absolute camera coordinates
Convert a tensor of pixel coordinates to absolute camera coordinates
It accepts lists or tensors of (m, 2) or (m, x, 2) or (m, 2, x)
where x is the number of keypoints
"""
if type(uv_tensor) == list:
uv_tensor = torch.tensor(uv_tensor)
if type(kk) == list:
kk = torch.tensor(kk)
if uv_tensor.size()[-1] != 2:
uv_tensor = uv_tensor.permute(0, 2, 1) # permute to have 2 as last dim to be padded
assert uv_tensor.size()[-1] == 2, "Tensor size not recognized"
uv_padded = F.pad(uv_tensor, pad=(0, 1), mode="constant", value=1) # pad only last-dim below with value 1
kk_1 = np.linalg.inv(kk)
xyz_met_norm = np.dot(kk_1, uv1)
kk_1 = torch.inverse(kk)
xyz_met_norm = torch.matmul(uv_padded, kk_1.t()) # More general than torch.mm
xyz_met = xyz_met_norm * z_met
return xyz_met
@ -28,9 +40,7 @@ def project_3d(box_obj, kk):
"""
Project a 3D bounding box into the image plane using the central corners
"""
box_2d = []
# Obtain the 3d points of the box
xc, yc, zc = box_obj.center
ww, ll, hh, = box_obj.wlh
@ -55,59 +65,39 @@ def project_3d(box_obj, kk):
return box_2d
def preprocess_single(kps, kk):
""" Preprocess input of a single annotations
Input_kps = list of 4 elements with 0=x, 1=y, 2= confidence, 3 = ? in pixels
Output_kps = [x0, y0, x1,...x15, y15] in meters normalized (z=1) and zero-centered using the center of the box
def get_keypoints(keypoints, mode):
"""
Extract center, head, shoulder, hip or ankle points from a set of keypoints
Input --> list or torch.tensor [(m, 3, 17) or (3, 17)]
Output --> torch.tensor [(m, 2)]
"""
if type(keypoints) == list:
keypoints = torch.tensor(keypoints)
if len(keypoints.size()) == 2: # add batch dim
keypoints = keypoints.unsqueeze(0)
kps_uv = []
kps_0c = []
kps_orig = []
# Create center of the bounding box using min max of the keypoints
uu_c, vv_c = get_keypoints(kps[0], kps[1], mode='center')
uv_center = np.array([uu_c, vv_c, 1])
# Create a list of single arrays of (u, v, 1)
for idx, _ in enumerate(kps[0]):
uv_kp = np.array([kps[0][idx], kps[1][idx], 1])
kps_uv.append(uv_kp)
# Projection in normalized image coordinates and zero-center with the center of the bounding box
xy1_center = pixel_to_camera(uv_center, kk, 1) * 10
for idx, kp in enumerate(kps_uv):
kp_proj = pixel_to_camera(kp, kk, 1) * 10
kp_proj_0c = kp_proj - xy1_center
kps_0c.append(float(kp_proj_0c[0]))
kps_0c.append(float(kp_proj_0c[1]))
kp_orig = pixel_to_camera(kp, kk, 1)
kps_orig.append(float(kp_orig[0]))
kps_orig.append(float(kp_orig[1]))
return kps_0c, kps_orig
def get_keypoints(kps_0, kps_1, mode):
"""Get the center of 2 lists"""
assert mode == 'center' or mode == 'shoulder' or mode == 'hip'
assert len(keypoints.size()) == 3 and keypoints.size()[1] == 3, "tensor dimensions not recognized"
assert mode in ['center', 'head', 'shoulder', 'hip', 'ankle']
kps_in = keypoints[:, 0:2, :] # (m, 2, 17)
if mode == 'center':
uu = (max(kps_0) - min(kps_0)) / 2 + min(kps_0)
vv = (max(kps_1) - min(kps_1)) / 2 + min(kps_1)
kps_max, _ = kps_in.max(2) # returns value, indices
kps_min, _ = kps_in.min(2)
kps_out = (kps_max - kps_min) / 2 + kps_min # (m, 2) as keepdims is False
elif mode == 'head':
kps_out = kps_in[:, :, 0:5].mean(2)
elif mode == 'shoulder':
uu = float(np.average(kps_0[5:7]))
vv = float(np.average(kps_1[5:7]))
kps_out = kps_in[:, :, 5:7].mean(2)
elif mode == 'hip':
uu = float(np.average(kps_0[11:13]))
vv = float(np.average(kps_1[11:13]))
kps_out = kps_in[:, :, 11:13].mean(2)
return uu, vv
elif mode == 'ankle':
kps_out = kps_in[:, :, 15:17].mean(2)
return kps_out # (m, 2)
def transform_kp(kps, tr_mode):
@ -118,7 +108,7 @@ def transform_kp(kps, tr_mode):
or tr_mode == 'shoulder' or tr_mode == 'knee' or tr_mode == 'upside' or tr_mode == 'falling' \
or tr_mode == 'random'
uu_c, vv_c = get_keypoints(kps[0], kps[1], mode='center')
uu_c, vv_c = get_keypoints(kps, mode='center')
if tr_mode == "None":
return kps
@ -180,25 +170,33 @@ def transform_kp(kps, tr_mode):
return [uus, vvs, kps[2], []]
def get_depth(uv_center, kk, dd):
def xyz_from_distance(distances, xy_centers):
"""
From distances and normalized image coordinates (z=1), extract the real world position xyz
distances --> tensor (m,1) or (m) or float
xy_centers --> tensor(m,3) or (3)
"""
if len(uv_center) == 2:
uv_center.extend([1])
uv_center_np = np.array(uv_center)
xyz_norm = pixel_to_camera(uv_center, kk, 1)
zz = dd / math.sqrt(1 + xyz_norm[0] ** 2 + xyz_norm[1] ** 2)
if type(distances) == float:
distances = torch.tensor(distances).unsqueeze(0)
if len(distances.size()) == 1:
distances = torch.tensor(distances).unsqueeze(1)
if len(xy_centers.size()) == 1:
xy_centers = xy_centers.unsqueeze(0)
xyz = pixel_to_camera(uv_center_np, kk, zz).tolist()
return xyz
assert xy_centers.size()[-1] == 3 and distances.size()[-1] == 1, "Size of tensor not recognized"
return xy_centers * distances / torch.sqrt(1 + xy_centers[:, 0:1].pow(2) + xy_centers[:, 1:2].pow(2))
def get_depth_from_distance(outputs, xy_centers):
list_zzs = []
for idx, _ in enumerate(outputs):
dd = float(outputs[idx][0])
xx_1 = float(xy_centers[idx][0])
yy_1 = float(xy_centers[idx][1])
zz = dd / math.sqrt(1 + xx_1 ** 2 + yy_1 ** 2)
list_zzs.append(zz)
return list_zzs
def pixel_to_camera_old(uv1, kk, z_met):
"""
(3,) array --> (3,) array
Convert a point in pixel coordinate to absolute camera coordinates
"""
if len(uv1) == 2:
uv1.append(1)
kk_1 = np.linalg.inv(kk)
xyz_met_norm = np.dot(kk_1, uv1)
xyz_met = xyz_met_norm * z_met
return xyz_met
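A quick shape sketch for the torch utilities above, with invented values; only the signatures shown in this file are assumed.

```python
# Shape sketch for the torch-based camera utilities (values are invented)
import torch
from utils.camera import get_keypoints, pixel_to_camera, xyz_from_distance

kk = [[718.0, 0.0, 620.0], [0.0, 718.0, 188.0], [0.0, 0.0, 1.0]]
keypoints = torch.rand(4, 3, 17) * 300                 # 4 instances of (u, v, conf) x 17 joints

uv_centers = get_keypoints(keypoints, mode='center')   # (4, 2) box centers in pixels
xy_centers = pixel_to_camera(uv_centers, kk, 1)        # (4, 3) normalized coordinates at z = 1
distances = torch.tensor([[8.0], [15.0], [22.0], [31.0]])
xyz = xyz_from_distance(distances, xy_centers)         # (4, 3) metric camera coordinates
```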

View File

@ -1,37 +1,7 @@
import numpy as np
import copy
import math
from utils.camera import pixel_to_camera, get_keypoints
from eval.geom_baseline import compute_distance_single
def eval_geometric(uv_kps, uv_centers, uv_shoulders, kk, average_y=0.48):
"""
Evaluate geometric distance
"""
xy_centers = []
dds_geom = []
for idx, _ in enumerate(uv_centers):
uv_center = copy.deepcopy(uv_centers[idx])
uv_center.append(1)
uv_shoulder = copy.deepcopy(uv_shoulders[idx])
uv_shoulder.append(1)
uv_kp = uv_kps[idx]
xy_center = pixel_to_camera(uv_center, kk, 1)
xy_centers.append(xy_center.tolist())
uu_2, vv_2 = get_keypoints(uv_kp[0], uv_kp[1], mode='hip')
uv_hip = [uu_2, vv_2, 1]
zz, _ = compute_distance_single(uv_shoulder, uv_hip, kk, average_y)
xyz_center = np.array([xy_center[0], xy_center[1], zz])
dd_geom = float(np.linalg.norm(xyz_center))
dds_geom.append(dd_geom)
return dds_geom, xy_centers
def get_calibration(path_txt):
"""Read calibration parameters from txt file:
@ -71,17 +41,17 @@ def get_calibration(path_txt):
def get_translation(pp):
"""Separate intrinsic matrix from translation"""
"""Separate intrinsic matrix from translation and convert in lists"""
kk = pp[:, :-1]
f_x = kk[0, 0]
f_y = kk[1, 1]
x0, y0 = kk[2, 0:2]
aa, bb, t3 = pp[0:3, 3]
t1 = (aa - x0*t3) / f_x
t2 = (bb - y0*t3) / f_y
tt = np.array([t1, t2, t3]).reshape(3, 1)
return kk, tt
t1 = float((aa - x0*t3) / f_x)
t2 = float((bb - y0*t3) / f_y)
tt = [t1, t2, float(t3)]
return kk.tolist(), tt
def get_simplified_calibration(path_txt):
@ -99,12 +69,11 @@ def get_simplified_calibration(path_txt):
raise ValueError('Matrix K_02 not found in the file')
def check_conditions(line, mode, thresh=0.5):
def check_conditions(line, mode, thresh=0.3):
"""Check conditions of our or m3d txt file"""
check = False
assert mode == 'gt' or mode == 'm3d' or mode == '3dop' or mode == 'our', "Type not recognized"
assert mode in ['gt', 'gt_all', 'm3d', '3dop', 'our'], "Mode %r not recognized" % mode
if mode == 'm3d' or mode == '3dop':
conf = line.split()[15]
@ -116,8 +85,13 @@ def check_conditions(line, mode, thresh=0.5):
if line[:10] == 'Pedestrian':
check = True
# Consider also person sitting and cyclists categories
elif mode == 'gt_all':
if line[:10] == 'Pedestrian' or line[:10] == 'Person_sit' or line[:7] == 'Cyclist':
check = True
elif mode == 'our':
if line[10] >= thresh:
if line[4] >= thresh:
check = True
return check
@ -126,7 +100,6 @@ def check_conditions(line, mode, thresh=0.5):
def get_category(box, trunc, occ):
hh = box[3] - box[1]
if hh >= 40 and trunc <= 0.15 and occ <= 0:
cat = 'easy'
elif trunc <= 0.3 and occ <= 1 and hh >= 25:
@ -135,7 +108,6 @@ def get_category(box, trunc, occ):
cat = 'hard'
else:
cat = 'excluded'
return cat
@ -158,7 +130,7 @@ def split_training(names_gt, path_train, path_val):
return set_train, set_val
def parse_ground_truth(path_gt):
def parse_ground_truth(path_gt, mode='gt'):
"""Parse KITTI ground truth files"""
boxes_gt = []
dds_gt = []
@ -168,7 +140,7 @@ def parse_ground_truth(path_gt):
with open(path_gt, "r") as f_gt:
for line_gt in f_gt:
if check_conditions(line_gt, mode='gt'):
if check_conditions(line_gt, mode=mode):
truncs_gt.append(float(line_gt.split()[1]))
occs_gt.append(int(line_gt.split()[2]))
boxes_gt.append([float(x) for x in line_gt.split()[4:8]])
@ -177,4 +149,4 @@ def parse_ground_truth(path_gt):
boxes_3d.append(loc_gt + wlh)
dds_gt.append(math.sqrt(loc_gt[0] ** 2 + loc_gt[1] ** 2 + loc_gt[2] ** 2))
return (boxes_gt, boxes_3d, dds_gt, truncs_gt, occs_gt)
return boxes_gt, boxes_3d, dds_gt, truncs_gt, occs_gt
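For reference, a small sketch of the new mode switch when parsing one KITTI label file; the path is a placeholder.

```python
# Sketch of parse_ground_truth with the new mode argument (path is a placeholder)
from utils.kitti import parse_ground_truth

# mode='gt'      -> pedestrians only (evaluation)
# mode='gt_all'  -> pedestrians + sitting persons + cyclists (training, as in preprocess_ki.py)
boxes_gt, boxes_3d, dds_gt, truncs_gt, occs_gt = parse_ground_truth(
    'data/kitti/gt/000001.txt', mode='gt_all')
```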

View File

@ -3,7 +3,7 @@ import numpy as np
import torch
import time
import logging
# from shapely.geometry import box as Sbox
def set_logger(log_path):
"""Set the logger to log info in terminal and file `log_path`.
@ -70,7 +70,6 @@ def get_iou_matrix(boxes, boxes_gt):
Dim: (boxes, boxes_gt)
"""
iou_matrix = np.zeros((len(boxes), len(boxes_gt)))
for idx, box in enumerate(boxes):
for idx_gt, box_gt in enumerate(boxes_gt):
iou_matrix[idx, idx_gt] = calculate_iou(box, box_gt)
@ -96,36 +95,21 @@ def get_iou_matches(boxes, boxes_gt, thresh):
return matches
def reparametrize_box3d(box):
"""Reparametrized 3D box in the XZ plane and add the height"""
def reorder_matches(matches, boxes, mode='left_right'):
"""
Reorder a list of (idx, idx_gt) matches based on position of the detections in the image
ordered_boxes = (5, 6, 7, 0, 1, 4, 2, 3)
matches = [(0, x), (2, x), (4, x), (3, x), (5, x)]
Output --> [(5, x), (0, x), (4, x), (2, x), (3, x)]
"""
hh, ww, ll = box[0:3]
x_c, y_c, z_c = box[3:6]
assert mode == 'left_right'
x1 = x_c - ll/2
z1 = z_c - ww/2
x2 = x_c + ll/2
z2 = z_c + ww / 2
# Order the boxes based on their left-right position in the image
ordered_boxes = np.argsort([box[0] for box in boxes]) # indices of boxes ordered from left to right
matches_left = [idx for (idx, _) in matches]
return [x1, z1, x2, z2, hh]
# def calculate_iou3d(box3d_1, box3d_2):
# """3D intersection over union. Boxes are parametrized as x1, z1, x2, z2, hh
# We compute 2d iou in the birds plane and then add a factor for height differences (0-1)"""
#
# poly1 = Sbox(box3d_1[0], box3d_1[1], box3d_1[2], box3d_1[3])
# poly2 = Sbox(box3d_2[0], box3d_2[1], box3d_2[2], box3d_2[3])
#
# inter_2d = poly1.intersection(poly2).area
# union_2d = poly1.area + poly2.area - inter_2d
#
# # height_factor = 1 - abs(box3d_1[4] - box3d_2[4]) / max(box3d_1[4], box3d_2[4])
#
# #
# iou_3d = inter_2d / union_2d # * height_factor
#
# return iou_3d
return [matches[matches_left.index(idx_boxes)] for idx_boxes in ordered_boxes if idx_boxes in matches_left]
def laplace_sampling(outputs, n_samples):
@ -135,7 +119,6 @@ def laplace_sampling(outputs, n_samples):
mu = outputs[:, 0]
bi = torch.abs(outputs[:, 1])
# Analytical
# uu = np.random.uniform(low=-0.5, high=0.5, size=mu.shape[0])
# xx = mu - bi * np.sign(uu) * np.log(1 - 2 * np.abs(uu))
@ -148,30 +131,13 @@ def laplace_sampling(outputs, n_samples):
device = torch.device(type="cuda", index=get_device)
else:
device = torch.device("cpu")
t1 = time.time()
xxs = torch.empty((0, mu.shape[0])).to(device)
t2 = time.time()
laplace = torch.distributions.Laplace(mu, bi)
t3 = time.time()
for ii in range(1):
xx = laplace.sample((n_samples,))
t4a = time.time()
xxs = torch.cat((xxs, xx.view(n_samples, -1)), 0)
t4 = time.time()
# time_tot = t4 - t0
# time_1 = t1 - t0
# time_2 = t2 - t1
# time_3 = t3 - t2
# time_4a = t4a - t3
# time_4 = t4 - t3
# print("Time 1: {:.1f}%".format(time_1 / time_tot * 100))
# print("Time 2: {:.1f}%".format(time_2 / time_tot * 100))
# print("Time 3: {:.1f}%".format(time_3 / time_tot * 100))
# print("Time 4a: {:.1f}%".format(time_4a / time_tot * 100))
# print("Time 4: {:.1f}%".format(time_4 / time_tot * 100))
return xxs
@ -191,48 +157,30 @@ def append_cluster(dic_jo, phase, xx, dd, kps):
"""Append the annotation based on its distance"""
# if dd <= 6:
# dic_jo[phase]['clst']['6']['kps'].append(kps)
# dic_jo[phase]['clst']['6']['X'].append(xx)
# dic_jo[phase]['clst']['6']['Y'].append([dd]) # Trick to make it (nn,1) instead of (nn, )
if dd <= 10:
dic_jo[phase]['clst']['10']['kps'].append(kps)
dic_jo[phase]['clst']['10']['X'].append(xx)
dic_jo[phase]['clst']['10']['Y'].append([dd])
# elif dd <= 15:
# dic_jo[phase]['clst']['15']['kps'].append(kps)
# dic_jo[phase]['clst']['15']['X'].append(xx)
# dic_jo[phase]['clst']['15']['Y'].append([dd])
elif dd <= 20:
dic_jo[phase]['clst']['20']['kps'].append(kps)
dic_jo[phase]['clst']['20']['X'].append(xx)
dic_jo[phase]['clst']['20']['Y'].append([dd])
# elif dd <= 25:
# dic_jo[phase]['clst']['25']['kps'].append(kps)
# dic_jo[phase]['clst']['25']['X'].append(xx)
# dic_jo[phase]['clst']['25']['Y'].append([dd])
elif dd <= 30:
dic_jo[phase]['clst']['30']['kps'].append(kps)
dic_jo[phase]['clst']['30']['X'].append(xx)
dic_jo[phase]['clst']['30']['Y'].append([dd])
# elif dd <= 40:
# dic_jo[phase]['clst']['40']['kps'].append(kps)
# dic_jo[phase]['clst']['40']['X'].append(xx)
# dic_jo[phase]['clst']['40']['Y'].append([dd])
#
# elif dd <= 50:
# dic_jo[phase]['clst']['50']['kps'].append(kps)
# dic_jo[phase]['clst']['50']['X'].append(xx)
# dic_jo[phase]['clst']['50']['Y'].append([dd])
else:
dic_jo[phase]['clst']['>30']['kps'].append(kps)
dic_jo[phase]['clst']['>30']['X'].append(xx)
dic_jo[phase]['clst']['>30']['Y'].append([dd])
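# Sketch of the dictionary layout append_cluster expects (a hypothetical minimal
# skeleton; the real one is built during dataset preprocessing):
clusters_ex = ['10', '20', '30', '>30']
dic_jo_ex = {'train': {'clst': {clst: {'kps': [], 'X': [], 'Y': []} for clst in clusters_ex}}}
append_cluster(dic_jo_ex, 'train', xx=[0.1] * 34, dd=17.3, kps=[[0.] * 17] * 3)
# dd = 17.3 m falls in the '20' cluster (10 < dd <= 20)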
def get_task_error(dd):
"""Get target error not knowing the gender"""
mm_gender = 0.0556
return mm_gender * dd
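# The task error grows linearly with distance: at 20 m the expected localization
# error when the gender of the pedestrian is unknown is roughly 0.0556 * 20 ~ 1.1 m.
print(get_task_error(20.))  # 1.112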

View File

@ -40,68 +40,3 @@ def unnormalize_bi(outputs):
outputs[:, 1] = torch.exp(outputs[:, 1]) * outputs[:, 0]
return outputs
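# Usage sketch: the spread is predicted in log-relative form, log(b / mu), and
# unnormalize_bi converts it back to the absolute Laplace scale b (in place).
out_ex = torch.tensor([[10.0, -3.0]])  # mu = 10 m, predicted log(b / mu) = -3
print(unnormalize_bi(out_ex))  # ~[[10.0, 0.50]]: a spread of about 0.5 m at 10 m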
# def normalize_arrays_jo(dic_jo):
# """Normalize according to the mean and std of each keypoint in the training dataset
# PS normalization of training also for test and val"""
#
# # Normalize
# phase = 'train'
# kps_orig_tr = np.array(dic_jo[phase]['X'])
# # dd_orig_tr = np.array(dic_jo[phase]['Y']).reshape(-1, 1)
# kps_mean = np.mean(kps_orig_tr, axis=0)
# plt.hist(kps_orig_tr, bins=100)
# plt.show()
# kps_std = np.std(kps_orig_tr, axis=0)
# # dd_mean = np.mean(dd_orig_tr, axis=0)
# # dd_std = np.std(dd_orig_tr, axis=0)
#
# for phase in dic_jo:
#
# # Compute the normalized arrays
# kps_orig = np.array(dic_jo[phase]['X'])
# dd_orig = np.array(dic_jo[phase]['Y']).reshape(-1, 1)
# kps_norm = np.divide((kps_orig - kps_mean), kps_std)
#
# # dd_norm = np.divide((dd_orig - dd_mean), dd_std) # ! No normalization on the output
#
# # Substitute the new values in the dictionary and save the mean and std
# dic_jo[phase]['X'] = kps_norm.tolist()
# dic_jo[phase]['mean']['X'] = kps_mean.tolist()
# dic_jo[phase]['std']['X'] = kps_std.tolist()
#
# dic_jo[phase]['Y'] = dd_orig.tolist()
# # dic_jo[phase]['mean']['Y'] = float(dd_mean)
# # dic_jo[phase]['std']['Y'] = float(dd_std)
#
# # Normalize all the clusters
# for clst in dic_jo[phase]['clst']:
#
# # Extract
# kps_orig = np.array(dic_jo[phase]['clst'][clst]['X'])
# dd_orig = np.array(dic_jo[phase]['clst'][clst]['Y']).reshape(-1, 1)
# # Normalize
# kps_norm = np.divide((kps_orig - kps_mean), kps_std)
#
# # dd_norm = np.divide((dd_orig - dd_mean), dd_std) #! No normalization on the output
#
# # Put back
# dic_jo[phase]['clst'][clst]['X'] = kps_norm.tolist()
# dic_jo[phase]['clst'][clst]['Y'] = dd_orig.tolist()
#
# return dic_jo
#
#
# def check_cluster_dim(dic_jo):
# """ Check that the sum of the clusters corresponds to all annotations"""
#
# for phase in ['train', 'val', 'test']:
# cnt_clst = 0
# cnt_all = len(dic_jo[phase]['X'])
# for clst in dic_jo[phase]['clst']:
# cnt_clst += len(dic_jo[phase]['clst'][clst]['X'])
# assert cnt_all == cnt_clst

View File

@ -1,6 +1,7 @@
import numpy as np
from utils.camera import preprocess_single, get_keypoints, pixel_to_camera
import torch
from utils.camera import get_keypoints, pixel_to_camera
def preprocess_pif(annotations, im_size=None):
@ -45,56 +46,75 @@ def preprocess_pif(annotations, im_size=None):
return boxes, keypoints
def get_input_data(boxes, keypoints, kk, left_to_right=False):
inputs = []
xy_centers = []
uv_boxes = []
uv_centers = []
uv_shoulders = []
uv_kps = []
xy_kps = []
def get_network_inputs(keypoints, kk):
if left_to_right: # Order boxes from left to right
ordered = np.argsort([xx[0] for xx in boxes])
""" Preprocess batches of inputs
keypoints = torch tensors of (m, 3, 17) or list [3,17]
Outputs = torch tensors of (m, 34) in meters normalized (z=1) and zero-centered using the center of the box
"""
if type(keypoints) == list:
keypoints = torch.tensor(keypoints)
if type(kk) == list:
kk = torch.tensor(kk)
# Projection in normalized image coordinates and zero-center with the center of the bounding box
uv_center = get_keypoints(keypoints, mode='center')
xy1_center = pixel_to_camera(uv_center, kk, 1) * 10
xy1_all = pixel_to_camera(keypoints[:, 0:2, :], kk, 1) * 10
kps_norm = xy1_all - xy1_center.unsqueeze(1) # (m, 17, 3) - (m, 1, 3)
kps_out = kps_norm[:, :, 0:2].reshape(kps_norm.size()[0], -1)  # reshape instead of view: the sliced tensor is not contiguous
return kps_out
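# Usage sketch with hypothetical KITTI-like intrinsics (values made up for illustration):
kk_ex = [[718.0, 0., 608.0], [0., 718.0, 180.0], [0., 0., 1.]]
kps_ex = [[[300.] * 17, [200.] * 17, [0.9] * 17]]  # one instance, shape (1, 3, 17)
print(get_network_inputs(kps_ex, kk_ex).shape)  # torch.Size([1, 34])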
else: # Order boxes from most to less confident
confs = []
for idx, box in enumerate(boxes):
confs.append(box[4])
ordered = np.argsort(confs).tolist()[::-1]
for idx in ordered:
kps = keypoints[idx]
uv_kps.append(kps)
uv_boxes.append(boxes[idx])
def preprocess_pif(annotations, im_size=None):
"""
Preprocess pif annotations:
1. Enlarge each bounding box by 10%
2. Constrain it inside the image (if im_size is provided)
"""
uu_c, vv_c = get_keypoints(kps[0], kps[1], "center")
uv_centers.append([round(uu_c), round(vv_c)])
xy_center = pixel_to_camera(np.array([uu_c, vv_c, 1]), kk, 1)
xy_centers.append(xy_center)
boxes = []
keypoints = []
uu_1, vv_1 = get_keypoints(kps[0], kps[1], "shoulder")
uv_shoulders.append([round(uu_1), round(vv_1)])
for dic in annotations:
box = dic['bbox']
if box[3] < 0.5: # Check for no detections (boxes 0,0,0,0)
return [], []
# 2 steps of input normalization for each instance
kps_prep, kps_orig = preprocess_single(kps, kk)
inputs.append(kps_prep)
xy_kps.append(kps_orig)
else:
kps = prepare_pif_kps(dic['keypoints'])
conf = float(np.mean(np.array(kps[2])))
return (inputs, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders)
# Enlarge the box by 10% in both height and width
delta_h = (box[3] - box[1]) / 10
delta_w = (box[2] - box[0]) / 10
assert delta_h > 0 and delta_w > 0, "Bounding box <=0"
box[0] -= delta_w
box[1] -= delta_h
box[2] += delta_w
box[3] += delta_h
# Put the box inside the image
if im_size is not None:
box[0] = max(0, box[0])
box[1] = max(0, box[1])
box[2] = min(box[2], im_size[0])
box[3] = min(box[3], im_size[1])
box.append(conf)
boxes.append(box)
keypoints.append(kps)
return boxes, keypoints
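# Usage sketch with one made-up pifpaf-style annotation (51 = 17 x 3 keypoint values):
annotation_ex = {'bbox': [50., 60., 150., 360.], 'keypoints': [100., 200., 0.9] * 17}
boxes_ex, keypoints_ex = preprocess_pif([annotation_ex], im_size=(1242, 374))
print(boxes_ex[0])  # box enlarged by 10%, clipped to the image, with the mean confidence appended
print(len(keypoints_ex[0][0]))  # 17 x-coordinates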
def prepare_pif_kps(kps_in):
"""Convert from a list of 51 to a list of 3, 17"""
keypoints = np.array(kps_in).reshape(-1, 3).tolist()
xxs = []
yys = []
ccs = []
for kp_triple in keypoints:
xxs.append(kp_triple[0])
yys.append(kp_triple[1])
ccs.append(kp_triple[2])
assert len(kps_in) % 3 == 0, "keypoints expected as a multiple of 3"
xxs = kps_in[0:][::3]
yys = kps_in[1:][::3] # from offset 1 every 3
ccs = kps_in[2:][::3]
return [xxs, yys, ccs]
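# Sketch of the flat-to-[3, 17] conversion, here with only two keypoints
# (u, v, confidence) = (10, 20, 0.8) and (30, 40, 0.9):
print(prepare_pif_kps([10, 20, 0.8, 30, 40, 0.9]))  # --> [[10, 30], [20, 40], [0.8, 0.9]]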

View File

@ -1,15 +1,18 @@
import os
import math
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.patches import Ellipse, Circle
import cv2
from collections import OrderedDict
from utils.camera import pixel_to_camera
from utils.misc import get_task_error
class Printer:
@ -158,13 +161,13 @@ class Printer:
# Create bird or combine it with front)
if any(xx in self.output_types for xx in ['bird', 'combined']):
uv_max = np.array([0, self.hh, 1])
uv_max = [0., float(self.hh)]
xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk
for idx, _ in enumerate(self.xx_gt):
if self.zz_gt[idx] > 0:
target = get_target_error(self.dds_real[idx])
target = get_task_error(self.dds_real[idx])
angle = get_angle(self.xx_gt[idx], self.zz_gt[idx])
ellipse_real = Ellipse((self.xx_gt[idx], self.zz_gt[idx]), width=target * 2, height=1,
@ -270,9 +273,3 @@ def get_angle(xx, zz):
angle = theta * (180 / math.pi)
return angle
def get_target_error(dd):
"""Get target error not knowing the gender"""
mm_gender = 0.0556
return mm_gender * dd