add combined for different dataset
This commit is contained in:
parent
afc439862d
commit
9c547f418b
@ -57,19 +57,25 @@ class PredictMonoLoco:
|
||||
try:
|
||||
with open(args.path_gt, 'r') as f:
|
||||
self.dic_names = json.load(f)
|
||||
print('-' * 120 + "\nMonoloco: Ground-truth file opened\n")
|
||||
except FileNotFoundError:
|
||||
self.dic_names = None
|
||||
print('-' * 120 + "\nMonoloco: ground truth file not found\n" + '-' * 120)
|
||||
print('-' * 120 + "\nMonoloco: ground-truth file not found\n")
|
||||
|
||||
def run(self):
|
||||
# Extract calibration matrix if ground-truth file is present or use a default one
|
||||
cnt = 0
|
||||
name = os.path.basename(self.image_path)
|
||||
if self.dic_names:
|
||||
try:
|
||||
kk = self.dic_names[name]['K']
|
||||
else:
|
||||
# kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]]
|
||||
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]
|
||||
print("Monoloco: matched ground-truth file!\n" + '-' * 120)
|
||||
except (KeyError, TypeError):
|
||||
self.dic_names = None
|
||||
# kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] # Kitti standard
|
||||
#kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]] # Nuscenes standard
|
||||
kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]]
|
||||
print("Ground-truth annotations for the image not found\n"
|
||||
"Using a standard calibration matrix...\n" + '-' * 120)
|
||||
|
||||
(inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \
|
||||
get_input_data(self.boxes, self.keypoints, kk, left_to_right=True)
|
||||
@ -101,8 +107,9 @@ class PredictMonoLoco:
|
||||
outputs = self.model(inputs)
|
||||
outputs = unnormalize_bi(outputs)
|
||||
end = time.time()
|
||||
print("Total Forward pass time = {:.2f} ms".format((end-start) * 1000))
|
||||
print("Single pass time = {:.2f} ms".format((end - start_single) * 1000))
|
||||
print("Total Forward pass time with {} forward passes = {:.2f} ms"
|
||||
.format(self.n_dropout, (end-start) * 1000))
|
||||
print("Single forward pass time = {:.2f} ms".format((end - start_single) * 1000))
|
||||
|
||||
# Print image and save json
|
||||
dic_out = defaultdict(list)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user