diff --git a/monoloco/predict.py b/monoloco/predict.py index c86f46f..928d2c4 100644 --- a/monoloco/predict.py +++ b/monoloco/predict.py @@ -25,9 +25,9 @@ from .activity import show_social LOG = logging.getLogger(__name__) -OPENPIFPAF_MODEL = 'https://drive.google.com/file/d/1b408ockhh29OLAED8Tysd2yGZOo0N_SQ/view?usp=sharing' -MONOLOCO_MODEL = 'https://drive.google.com/file/d/1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r/view?usp=sharing' -MONSTEREO_MODEL = 'https://drive.google.com/file/d/1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49/view?usp=sharing' +OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ' +MONOLOCO_MODEL = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r' +MONSTEREO_MODEL = 'https://drive.google.com/uc?id=1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49' def get_torch_checkpoints_dir(): @@ -47,9 +47,10 @@ def download_checkpoints(args): torch_dir = get_torch_checkpoints_dir() pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl') dic_models = {'keypoints': pifpaf_model} - + print(torch_dir) if not os.path.exists(pifpaf_model): import gdown + LOG.info("Downloading OpenPifPaf model in %s".format(torch_dir)) gdown.download(OPENPIFPAF_MODEL, pifpaf_model, quiet=False) if args.mode == 'keypoints': @@ -68,8 +69,8 @@ def download_checkpoints(args): if not os.path.exists(model): import gdown + LOG.info("Downloading model in %s".format(torch_dir)) gdown.download(path, model, quiet=False) - return dic_models