add model download and mode argument

This commit is contained in:
Lorenzo 2021-03-22 14:17:43 +01:00
parent 3b97afb89e
commit f0bbaa2a0e
6 changed files with 90 additions and 32 deletions

View File

@ -73,7 +73,7 @@ or check the file `monoloco/run.py`
# Predictions
For a quick setup download a pifpaf and MonoLoco++ / MonStereo models from
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
## A) 3D Localization
The predict script receives an image (or an entire folder using glob expressions),

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

View File

@ -25,9 +25,22 @@ class Loco:
LINEAR_SIZE_MONO = 256
N_SAMPLES = 100
def __init__(self, model, net='monstereo', device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
self.net = net
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
# Select networks
assert mode in ('mono', 'stereo'), "mode not recognized"
self.mode = mode
if net is None:
if mode == 'mono':
self.net = 'monoloco_pp'
else:
self.net = 'monstereo'
else:
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
if self.net != 'monstereo':
assert mode == 'stereo', "Assert arguments mode and net are in conflict"
self.net = net
if self.net == 'monstereo':
input_size = 68
output_size = 10

View File

@ -25,7 +25,52 @@ from .activity import show_social
LOG = logging.getLogger(__name__)
OPENPIFPAF_PATH = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model
OPENPIFPAF_MODEL = 'https://drive.google.com/file/d/1b408ockhh29OLAED8Tysd2yGZOo0N_SQ/view?usp=sharing'
MONOLOCO_MODEL = 'https://drive.google.com/file/d/1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r/view?usp=sharing'
MONSTEREO_MODEL = 'https://drive.google.com/file/d/1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49/view?usp=sharing'
def get_torch_checkpoints_dir():
if hasattr(torch, 'hub') and hasattr(torch.hub, 'get_dir'):
# new in pytorch 1.6.0
base_dir = torch.hub.get_dir()
elif os.getenv('TORCH_HOME'):
base_dir = os.getenv('TORCH_HOME')
elif os.getenv('XDG_CACHE_HOME'):
base_dir = os.path.join(os.getenv('XDG_CACHE_HOME'), 'torch')
else:
base_dir = os.path.expanduser(os.path.join('~', '.cache', 'torch'))
return os.path.join(base_dir, 'checkpoints')
def download_checkpoints(args):
torch_dir = get_torch_checkpoints_dir()
pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl')
dic_models = {'keypoints': pifpaf_model}
if not os.path.exists(pifpaf_model):
import gdown
gdown.download(OPENPIFPAF_MODEL, pifpaf_model, quiet=False)
if args.mode == 'keypoints':
return dic_models
elif args.model is not None:
dic_models[args.mode] = args.model
return dic_models
elif args.mode == 'mono':
model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl')
path = MONOLOCO_MODEL
dic_models[args.mode] = model
else:
model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl')
path = MONSTEREO_MODEL
dic_models[args.mode] = model
if not os.path.exists(model):
import gdown
gdown.download(path, model, quiet=False)
return dic_models
def factory_from_args(args):
@ -36,14 +81,9 @@ def factory_from_args(args):
if not args.images:
raise Exception("no image files given")
# Model
if not args.checkpoint:
if os.path.exists(OPENPIFPAF_PATH):
args.checkpoint = OPENPIFPAF_PATH
else:
LOG.info("Checkpoint for OpenPifPaf not specified and default model not found in 'data/models'. "
"Using a ShuffleNet backbone")
args.checkpoint = 'shufflenetv2k30'
# Models
dic_models = download_checkpoints(args)
args.checkpoint = dic_models['keypoints']
logger.configure(args, LOG) # logger first
@ -59,7 +99,7 @@ def factory_from_args(args):
args.figure_width = 10
args.dpi_factor = 1.0
if args.net == 'monstereo':
if args.mode == 'stereo':
args.batch_size = 2
args.images = sorted(args.images)
else:
@ -79,26 +119,31 @@ def factory_from_args(args):
show.configure(args)
visualizer.configure(args)
return args
return args, dic_models
def predict(args):
cnt = 0
args = factory_from_args(args)
assert args.mode in ('keypoints', 'mono', 'stereo')
args, dic_models = factory_from_args(args)
# Load Models
assert args.net in ('monoloco_pp', 'monstereo', 'pifpaf')
if args.net in ('monoloco_pp', 'monstereo'):
net = Loco(model=args.model, net=args.net, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
if args.mode in ('mono', 'stereo'):
net = Loco(
model=dic_models[args.mode],
mode=args.mode,
device=args.device,
n_dropout=args.n_dropout,
p_dropout=args.dropout)
# data
processor, model = processor_factory(args)
processor, pifpaf_model = processor_factory(args)
preprocess = preprocess_factory(args)
# data
data = datasets.ImageList(args.images, preprocess=preprocess)
if args.net == 'monstereo':
if args.mode == 'stereo':
assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
data_loader = torch.utils.data.DataLoader(
@ -106,7 +151,7 @@ def predict(args):
pin_memory=False, collate_fn=datasets.collate_images_anns_meta)
for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader):
pred_batch = processor.batch(model, image_tensors_batch, device=args.device)
pred_batch = processor.batch(pifpaf_model, image_tensors_batch, device=args.device)
# unbatch (only for MonStereo)
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
@ -136,14 +181,14 @@ def predict(args):
pifpaf_outs['right'] = [ann.json_data() for ann in pred]
# 3D Predictions
if args.net in ('monoloco_pp', 'monstereo'):
if args.mode != 'keypoints':
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt)
# Preprocess pifpaf outputs and run monoloco
boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)
if args.net == 'monoloco_pp':
if args.mode == 'mono':
LOG.info("Prediction with MonoLoco++")
dic_out = net.forward(keypoints, kk)
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt)
@ -171,11 +216,11 @@ def factory_outputs(args, pifpaf_outs, dic_out, output_path, kk=None):
# Verify conflicting options
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
assert args.net != 'pifpaf', "please use pifpaf original arguments"
assert args.mode != 'keypoints', "for keypooints please use pifpaf original arguments"
if args.social_distance:
assert args.net == 'monoloco_pp', "Social distancing only works with MonoLoco++ network"
assert args.mode == 'mono', "Social distancing only works with monocular network"
if args.net == 'pifpaf':
if args.mode == 'keypoints':
annotation_painter = openpifpaf.show.AnnotationPainter()
with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax:
annotation_painter.annotations(ax, pifpaf_outs['pred'])

View File

@ -28,7 +28,7 @@ def cli():
# Predict (2D pose and/or 3D location from images)
# General
predict_parser.add_argument('--mode', help='pifpaf, mono, stereo', default='stereo')
predict_parser.add_argument('images', nargs='*', help='input images')
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
predict_parser.add_argument('-o', '--output-directory', help='Output directory')
@ -58,12 +58,11 @@ def cli():
visualizer.cli(parser)
# Monoloco
predict_parser.add_argument('--net', help='Choose network: monoloco, monoloco_p, monoloco_pp, monstereo')
predict_parser.add_argument('--model', help='path of MonoLoco model to load', required=True)
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=512)
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
default='data/arrays/names-kitti-200615-1022.json')
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100)
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)

View File

@ -37,6 +37,7 @@ setup(
'pandas',
'pylint',
'pytest',
'gdown',
],
'prep': [
'nuscenes-devkit==1.0.2',