diff --git a/docs/out_002282.png.multi_all.jpg b/docs/out_002282.png.multi_all.jpg index 08722fd..d327ac1 100644 Binary files a/docs/out_002282.png.multi_all.jpg and b/docs/out_002282.png.multi_all.jpg differ diff --git a/docs/out_002282_pifpaf.jpg b/docs/out_002282_pifpaf.jpg new file mode 100644 index 0000000..7a3b3a1 Binary files /dev/null and b/docs/out_002282_pifpaf.jpg differ diff --git a/monstereo/predict.py b/monstereo/predict.py index 3c5f342..571af47 100644 --- a/monstereo/predict.py +++ b/monstereo/predict.py @@ -22,6 +22,7 @@ from .activity import show_social LOG = logging.getLogger(__name__) +OPENPIFPAF_PATH = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model def factory_from_args(args): @@ -33,7 +34,12 @@ def factory_from_args(args): # Model if not args.checkpoint: - args.checkpoint = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model + if os.path.exists(OPENPIFPAF_PATH): + args.checkpoint = OPENPIFPAF_PATH + else: + print("Checkpoint for OpenPifPaf not specified and default model not found in 'data/models'. " + "Using a ShuffleNet backbone") + args.checkpoint = 'shufflenetv2k30' # Devices args.device = torch.device('cpu')