add model download and mode argument

This commit is contained in:
Lorenzo 2021-03-22 14:17:43 +01:00
parent 3b97afb89e
commit f0bbaa2a0e
6 changed files with 90 additions and 32 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

View File

@ -25,9 +25,22 @@ class Loco:
LINEAR_SIZE_MONO = 256 LINEAR_SIZE_MONO = 256
N_SAMPLES = 100 N_SAMPLES = 100
def __init__(self, model, net='monstereo', device=None, n_dropout=0, p_dropout=0.2, linear_size=1024): def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
self.net = net
# Select networks
assert mode in ('mono', 'stereo'), "mode not recognized"
self.mode = mode
if net is None:
if mode == 'mono':
self.net = 'monoloco_pp'
else:
self.net = 'monstereo'
else:
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp') assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
if self.net != 'monstereo':
assert mode == 'mono', "arguments mode and net are in conflict"
self.net = net
if self.net == 'monstereo': if self.net == 'monstereo':
input_size = 68 input_size = 68
output_size = 10 output_size = 10

View File

@ -25,7 +25,52 @@ from .activity import show_social
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
OPENPIFPAF_PATH = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model OPENPIFPAF_MODEL = 'https://drive.google.com/file/d/1b408ockhh29OLAED8Tysd2yGZOo0N_SQ/view?usp=sharing'
MONOLOCO_MODEL = 'https://drive.google.com/file/d/1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r/view?usp=sharing'
MONSTEREO_MODEL = 'https://drive.google.com/file/d/1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49/view?usp=sharing'
def get_torch_checkpoints_dir():
    """Return the directory where torch checkpoint files are cached.

    Resolution order mirrors torch.hub: the hub API itself (pytorch >= 1.6.0),
    then the TORCH_HOME and XDG_CACHE_HOME environment variables, and finally
    the per-user default ``~/.cache/torch``.
    """
    if hasattr(torch, 'hub') and hasattr(torch.hub, 'get_dir'):
        # torch.hub.get_dir() is new in pytorch 1.6.0
        base_dir = torch.hub.get_dir()
    else:
        torch_home = os.getenv('TORCH_HOME')
        xdg_cache = os.getenv('XDG_CACHE_HOME')
        if torch_home:
            base_dir = torch_home
        elif xdg_cache:
            base_dir = os.path.join(xdg_cache, 'torch')
        else:
            base_dir = os.path.expanduser(os.path.join('~', '.cache', 'torch'))
    return os.path.join(base_dir, 'checkpoints')
def download_checkpoints(args):
    """Resolve (and download if missing) the checkpoints needed for args.mode.

    Returns a dict with key 'keypoints' mapped to the OpenPifPaf model path,
    plus key args.mode ('mono' or 'stereo') mapped to the MonoLoco++/MonStereo
    model path when 3D localization is requested. An explicit args.model, if
    given, is used as-is without downloading.
    """

    def _fetch(url, destination):
        # Download url into destination only when it is not already cached.
        if os.path.exists(destination):
            return
        # gdown cannot resolve Google Drive "view" share links directly;
        # convert "/file/d/<id>/view?usp=sharing" to the direct uc?id= form.
        if '/file/d/' in url:
            file_id = url.split('/file/d/')[1].split('/')[0]
            url = 'https://drive.google.com/uc?id=' + file_id
        import gdown  # local import: only needed when a download happens
        gdown.download(url, destination, quiet=False)

    torch_dir = get_torch_checkpoints_dir()
    os.makedirs(torch_dir, exist_ok=True)  # cache dir may not exist on a fresh machine

    pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl')
    _fetch(OPENPIFPAF_MODEL, pifpaf_model)
    dic_models = {'keypoints': pifpaf_model}

    if args.mode == 'keypoints':
        return dic_models
    if args.model is not None:
        # User supplied an explicit model path: trust it, no download.
        dic_models[args.mode] = args.model
        return dic_models

    if args.mode == 'mono':
        model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl')
        path = MONOLOCO_MODEL
    else:
        model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl')
        path = MONSTEREO_MODEL
    _fetch(path, model)
    dic_models[args.mode] = model
    return dic_models
def factory_from_args(args): def factory_from_args(args):
@ -36,14 +81,9 @@ def factory_from_args(args):
if not args.images: if not args.images:
raise Exception("no image files given") raise Exception("no image files given")
# Model # Models
if not args.checkpoint: dic_models = download_checkpoints(args)
if os.path.exists(OPENPIFPAF_PATH): args.checkpoint = dic_models['keypoints']
args.checkpoint = OPENPIFPAF_PATH
else:
LOG.info("Checkpoint for OpenPifPaf not specified and default model not found in 'data/models'. "
"Using a ShuffleNet backbone")
args.checkpoint = 'shufflenetv2k30'
logger.configure(args, LOG) # logger first logger.configure(args, LOG) # logger first
@ -59,7 +99,7 @@ def factory_from_args(args):
args.figure_width = 10 args.figure_width = 10
args.dpi_factor = 1.0 args.dpi_factor = 1.0
if args.net == 'monstereo': if args.mode == 'stereo':
args.batch_size = 2 args.batch_size = 2
args.images = sorted(args.images) args.images = sorted(args.images)
else: else:
@ -79,26 +119,31 @@ def factory_from_args(args):
show.configure(args) show.configure(args)
visualizer.configure(args) visualizer.configure(args)
return args return args, dic_models
def predict(args): def predict(args):
cnt = 0 cnt = 0
args = factory_from_args(args) assert args.mode in ('keypoints', 'mono', 'stereo')
args, dic_models = factory_from_args(args)
# Load Models # Load Models
assert args.net in ('monoloco_pp', 'monstereo', 'pifpaf') if args.mode in ('mono', 'stereo'):
if args.net in ('monoloco_pp', 'monstereo'): net = Loco(
net = Loco(model=args.model, net=args.net, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout) model=dic_models[args.mode],
mode=args.mode,
device=args.device,
n_dropout=args.n_dropout,
p_dropout=args.dropout)
# data # data
processor, model = processor_factory(args) processor, pifpaf_model = processor_factory(args)
preprocess = preprocess_factory(args) preprocess = preprocess_factory(args)
# data # data
data = datasets.ImageList(args.images, preprocess=preprocess) data = datasets.ImageList(args.images, preprocess=preprocess)
if args.net == 'monstereo': if args.mode == 'stereo':
assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting" assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
@ -106,7 +151,7 @@ def predict(args):
pin_memory=False, collate_fn=datasets.collate_images_anns_meta) pin_memory=False, collate_fn=datasets.collate_images_anns_meta)
for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader): for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader):
pred_batch = processor.batch(model, image_tensors_batch, device=args.device) pred_batch = processor.batch(pifpaf_model, image_tensors_batch, device=args.device)
# unbatch (only for MonStereo) # unbatch (only for MonStereo)
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)): for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
@ -136,14 +181,14 @@ def predict(args):
pifpaf_outs['right'] = [ann.json_data() for ann in pred] pifpaf_outs['right'] = [ann.json_data() for ann in pred]
# 3D Predictions # 3D Predictions
if args.net in ('monoloco_pp', 'monstereo'): if args.mode != 'keypoints':
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt) kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt)
# Preprocess pifpaf outputs and run monoloco # Preprocess pifpaf outputs and run monoloco
boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False) boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)
if args.net == 'monoloco_pp': if args.mode == 'mono':
LOG.info("Prediction with MonoLoco++") LOG.info("Prediction with MonoLoco++")
dic_out = net.forward(keypoints, kk) dic_out = net.forward(keypoints, kk)
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt) dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt)
@ -171,11 +216,11 @@ def factory_outputs(args, pifpaf_outs, dic_out, output_path, kk=None):
# Verify conflicting options # Verify conflicting options
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])): if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
assert args.net != 'pifpaf', "please use pifpaf original arguments" assert args.mode != 'keypoints', "for keypoints please use pifpaf original arguments"
if args.social_distance: if args.social_distance:
assert args.net == 'monoloco_pp', "Social distancing only works with MonoLoco++ network" assert args.mode == 'mono', "Social distancing only works with monocular network"
if args.net == 'pifpaf': if args.mode == 'keypoints':
annotation_painter = openpifpaf.show.AnnotationPainter() annotation_painter = openpifpaf.show.AnnotationPainter()
with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax: with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax:
annotation_painter.annotations(ax, pifpaf_outs['pred']) annotation_painter.annotations(ax, pifpaf_outs['pred'])

View File

@ -28,7 +28,7 @@ def cli():
# Predict (2D pose and/or 3D location from images) # Predict (2D pose and/or 3D location from images)
# General # General
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
predict_parser.add_argument('images', nargs='*', help='input images') predict_parser.add_argument('images', nargs='*', help='input images')
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)') predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
predict_parser.add_argument('-o', '--output-directory', help='Output directory') predict_parser.add_argument('-o', '--output-directory', help='Output directory')
@ -58,12 +58,11 @@ def cli():
visualizer.cli(parser) visualizer.cli(parser)
# Monoloco # Monoloco
predict_parser.add_argument('--net', help='Choose network: monoloco, monoloco_p, monoloco_pp, monstereo') predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
predict_parser.add_argument('--model', help='path of MonoLoco model to load', required=True) predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=512) predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization', predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
default='data/arrays/names-kitti-200615-1022.json') default='data/arrays/names-kitti-200615-1022.json')
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100) predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100)
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0) predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2) predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)

View File

@ -37,6 +37,7 @@ setup(
'pandas', 'pandas',
'pylint', 'pylint',
'pytest', 'pytest',
'gdown',
], ],
'prep': [ 'prep': [
'nuscenes-devkit==1.0.2', 'nuscenes-devkit==1.0.2',