diff --git a/README.md b/README.md
index 68a9e1a..0469f28 100644
--- a/README.md
+++ b/README.md
@@ -171,14 +171,14 @@ For more info, run:
**Examples**
An example from the Collective Activity Dataset is provided below.
-
+
To visualize social distancing run the below, command:
```
-python -m monoloco.run predict docs/frame0038.jpg \
+python -m monoloco.run predict docs/frame0032.jpg \
--social_distance --output_types front bird
```
-
+
## C) Orientation and Bounding Box dimensions
diff --git a/docs/frame0032.jpg b/docs/frame0032.jpg
new file mode 100644
index 0000000..34f3396
Binary files /dev/null and b/docs/frame0032.jpg differ
diff --git a/docs/frame0038.jpg b/docs/frame0038.jpg
deleted file mode 100644
index 7050b2c..0000000
Binary files a/docs/frame0038.jpg and /dev/null differ
diff --git a/docs/out_frame0032_front_bird.jpg b/docs/out_frame0032_front_bird.jpg
new file mode 100644
index 0000000..268d2db
Binary files /dev/null and b/docs/out_frame0032_front_bird.jpg differ
diff --git a/docs/out_frame0038.jpg.front_bird.jpg b/docs/out_frame0038.jpg.front_bird.jpg
deleted file mode 100644
index 3fbfaac..0000000
Binary files a/docs/out_frame0038.jpg.front_bird.jpg and /dev/null differ
diff --git a/monoloco/predict.py b/monoloco/predict.py
index b2de878..4dd5d26 100644
--- a/monoloco/predict.py
+++ b/monoloco/predict.py
@@ -26,7 +26,8 @@ from .activity import show_social
LOG = logging.getLogger(__name__)
OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ'
-MONOLOCO_MODEL = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r'
+MONOLOCO_MODEL_KI = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r'
+MONOLOCO_MODEL_NU = 'https://drive.google.com/uc?id=1BKZWJ1rmkg5AF9rmBEfxF1r8s8APwcyC'
MONSTEREO_MODEL = 'https://drive.google.com/uc?id=1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49'
@@ -54,18 +55,23 @@ def download_checkpoints(args):
if args.mode == 'keypoints':
return dic_models
- elif args.model is not None:
+ if args.model is not None:
+ assert os.path.exists(args.model), "Model path not found"
dic_models[args.mode] = args.model
return dic_models
- elif args.mode == 'mono':
- model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl')
- path = MONOLOCO_MODEL
- dic_models[args.mode] = model
- else:
- model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl')
+ if args.mode == 'stereo':
+ assert not args.social_distance, "Social distance not supported in stereo modality"
path = MONSTEREO_MODEL
- dic_models[args.mode] = model
+ name = 'monstereo-201202-1212.pkl'
+ elif args.social_distance:
+ path = MONOLOCO_MODEL_NU
+ name = 'monoloco_pp-201207-1350.pkl'
+ else:
+ path = MONOLOCO_MODEL_KI
+ name = 'monoloco_pp-201203-1424.pkl'
+ model = os.path.join(torch_dir, name)
+ dic_models[args.mode] = model
if not os.path.exists(model):
import gdown
LOG.info(f'Downloading model (modality: {args.mode}) in {torch_dir}')
diff --git a/monoloco/run.py b/monoloco/run.py
index 566e71b..bb7dc03 100644
--- a/monoloco/run.py
+++ b/monoloco/run.py
@@ -46,7 +46,7 @@ def cli():
predict_parser.add_argument('--monocolor-connections', default=False, action='store_true',
help='use a single color per instance')
predict_parser.add_argument('--instance-threshold', type=float, default=None, help='threshold for entire instance')
- predict_parser.add_argument('--seed-threshold', type=float, default=None, help='threshold for single seed')
+ predict_parser.add_argument('--seed-threshold', type=float, default=0.5, help='threshold for single seed')
predict_parser.add_argument('--disable-cuda', action='store_true', help='disable CUDA')
predict_parser.add_argument('--focal',
help='focal length in mm for a sensor size of 7.2x5.4 mm. Default nuScenes sensor',