diff --git a/monoloco/predict.py b/monoloco/predict.py index e157027..66eabcf 100644 --- a/monoloco/predict.py +++ b/monoloco/predict.py @@ -53,13 +53,15 @@ def get_torch_checkpoints_dir(): def download_checkpoints(args): torch_dir = get_torch_checkpoints_dir() + os.makedirs(torch_dir, exist_ok=True) if args.checkpoint is None: + os.makedirs(torch_dir, exist_ok=True) pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl') else: pifpaf_model = args.checkpoint dic_models = {'keypoints': pifpaf_model} if not os.path.exists(pifpaf_model): - assert DOWNLOAD is not None, "pip install gdown to download pifpaf model, or pass it as --checkpoint" + assert DOWNLOAD is not None, "install gdown to download pifpaf model, or pass it as --checkpoint" LOG.info('Downloading OpenPifPaf model in %s', torch_dir) DOWNLOAD(OPENPIFPAF_MODEL, pifpaf_model, quiet=False) @@ -83,6 +85,7 @@ def download_checkpoints(args): model = os.path.join(torch_dir, name) dic_models[args.mode] = model if not os.path.exists(model): + os.makedirs(torch_dir, exist_ok=True) assert DOWNLOAD is not None, "pip install gdown to download monoloco model, or pass it as --model" LOG.info('Downloading model in %s', torch_dir) DOWNLOAD(path, model, quiet=False)