diff --git a/README.md b/README.md index 68a9e1a..0469f28 100644 --- a/README.md +++ b/README.md @@ -171,14 +171,14 @@ For more info, run: **Examples**
An example from the Collective Activity Dataset is provided below. - + To visualize social distancing run the below, command: ``` -python -m monoloco.run predict docs/frame0038.jpg \ +python -m monoloco.run predict docs/frame0032.jpg \ --social_distance --output_types front bird ``` - + ## C) Orientation and Bounding Box dimensions diff --git a/docs/frame0032.jpg b/docs/frame0032.jpg new file mode 100644 index 0000000..34f3396 Binary files /dev/null and b/docs/frame0032.jpg differ diff --git a/docs/frame0038.jpg b/docs/frame0038.jpg deleted file mode 100644 index 7050b2c..0000000 Binary files a/docs/frame0038.jpg and /dev/null differ diff --git a/docs/out_frame0032_front_bird.jpg b/docs/out_frame0032_front_bird.jpg new file mode 100644 index 0000000..268d2db Binary files /dev/null and b/docs/out_frame0032_front_bird.jpg differ diff --git a/docs/out_frame0038.jpg.front_bird.jpg b/docs/out_frame0038.jpg.front_bird.jpg deleted file mode 100644 index 3fbfaac..0000000 Binary files a/docs/out_frame0038.jpg.front_bird.jpg and /dev/null differ diff --git a/monoloco/predict.py b/monoloco/predict.py index b2de878..4dd5d26 100644 --- a/monoloco/predict.py +++ b/monoloco/predict.py @@ -26,7 +26,8 @@ from .activity import show_social LOG = logging.getLogger(__name__) OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ' -MONOLOCO_MODEL = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r' +MONOLOCO_MODEL_KI = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r' +MONOLOCO_MODEL_NU = 'https://drive.google.com/uc?id=1BKZWJ1rmkg5AF9rmBEfxF1r8s8APwcyC' MONSTEREO_MODEL = 'https://drive.google.com/uc?id=1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49' @@ -54,18 +55,23 @@ def download_checkpoints(args): if args.mode == 'keypoints': return dic_models - elif args.model is not None: + if args.model is not None: + assert os.path.exists(args.model), "Model path not found" dic_models[args.mode] = args.model return dic_models - elif args.mode == 'mono': - model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl') - path = MONOLOCO_MODEL - dic_models[args.mode] = model - else: - model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl') + if args.mode == 'stereo': + assert not args.social_distance, "Social distance not supported in stereo modality" path = MONSTEREO_MODEL - dic_models[args.mode] = model + name = 'monstereo-201202-1212.pkl' + elif args.social_distance: + path = MONOLOCO_MODEL_NU + name = 'monoloco_pp-201207-1350.pkl' + else: + path = MONOLOCO_MODEL_KI + name = 'monoloco_pp-201203-1424.pkl' + model = os.path.join(torch_dir, name) + dic_models[args.mode] = model if not os.path.exists(model): import gdown LOG.info(f'Downloading model (modality: {args.mode}) in {torch_dir}') diff --git a/monoloco/run.py b/monoloco/run.py index 566e71b..bb7dc03 100644 --- a/monoloco/run.py +++ b/monoloco/run.py @@ -46,7 +46,7 @@ def cli(): predict_parser.add_argument('--monocolor-connections', default=False, action='store_true', help='use a single color per instance') predict_parser.add_argument('--instance-threshold', type=float, default=None, help='threshold for entire instance') - predict_parser.add_argument('--seed-threshold', type=float, default=None, help='threshold for single seed') + predict_parser.add_argument('--seed-threshold', type=float, default=0.5, help='threshold for single seed') predict_parser.add_argument('--disable-cuda', action='store_true', help='disable CUDA') predict_parser.add_argument('--focal', help='focal length in mm for a sensor size of 7.2x5.4 mm. Default nuScenes sensor',