diff --git a/src/predict/predict_monoloco.py b/src/predict/predict_monoloco.py index 42657ab..e46d618 100644 --- a/src/predict/predict_monoloco.py +++ b/src/predict/predict_monoloco.py @@ -57,19 +57,25 @@ class PredictMonoLoco: try: with open(args.path_gt, 'r') as f: self.dic_names = json.load(f) + print('-' * 120 + "\nMonoloco: Ground-truth file opened\n") except FileNotFoundError: self.dic_names = None - print('-' * 120 + "\nMonoloco: ground truth file not found\n" + '-' * 120) + print('-' * 120 + "\nMonoloco: ground-truth file not found\n") def run(self): # Extract calibration matrix if ground-truth file is present or use a default one cnt = 0 name = os.path.basename(self.image_path) - if self.dic_names: + try: kk = self.dic_names[name]['K'] - else: - # kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]] - kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] + print("Monoloco: matched ground-truth file!\n" + '-' * 120) + except (KeyError, TypeError): + self.dic_names = None + # kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] # Kitti standard + #kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]] # Nuscenes standard + kk = [[1266.4, 0., 816.27], [0, 1266.4, 491.5], [0., 0., 1.]] + print("Ground-truth annotations for the image not found\n" + "Using a standard calibration matrix...\n" + '-' * 120) (inputs_norm, xy_kps), (uv_kps, uv_boxes, uv_centers, uv_shoulders) = \ get_input_data(self.boxes, self.keypoints, kk, left_to_right=True) @@ -101,8 +107,9 @@ class PredictMonoLoco: outputs = self.model(inputs) outputs = unnormalize_bi(outputs) end = time.time() - print("Total Forward pass time = {:.2f} ms".format((end-start) * 1000)) - print("Single pass time = {:.2f} ms".format((end - start_single) * 1000)) + print("Total Forward pass time with {} forward passes = {:.2f} ms" + .format(self.n_dropout, (end-start) * 1000)) + print("Single forward pass time = {:.2f} ms".format((end - start_single) * 1000)) # Print image and save json dic_out = defaultdict(list)