add model download and mode argument
This commit is contained in:
parent
3b97afb89e
commit
f0bbaa2a0e
@ -73,7 +73,7 @@ or check the file `monoloco/run.py`
|
|||||||
|
|
||||||
# Predictions
|
# Predictions
|
||||||
For a quick setup download a pifpaf and MonoLoco++ / MonStereo models from
|
For a quick setup download a pifpaf and MonoLoco++ / MonStereo models from
|
||||||
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
|
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
|
||||||
|
|
||||||
## A) 3D Localization
|
## A) 3D Localization
|
||||||
The predict script receives an image (or an entire folder using glob expressions),
|
The predict script receives an image (or an entire folder using glob expressions),
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.7 MiB |
@ -25,9 +25,22 @@ class Loco:
|
|||||||
LINEAR_SIZE_MONO = 256
|
LINEAR_SIZE_MONO = 256
|
||||||
N_SAMPLES = 100
|
N_SAMPLES = 100
|
||||||
|
|
||||||
def __init__(self, model, net='monstereo', device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
|
def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
|
||||||
self.net = net
|
|
||||||
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
|
# Select networks
|
||||||
|
assert mode in ('mono', 'stereo'), "mode not recognized"
|
||||||
|
self.mode = mode
|
||||||
|
if net is None:
|
||||||
|
if mode == 'mono':
|
||||||
|
self.net = 'monoloco_pp'
|
||||||
|
else:
|
||||||
|
self.net = 'monstereo'
|
||||||
|
else:
|
||||||
|
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
|
||||||
|
if self.net != 'monstereo':
|
||||||
|
assert mode == 'stereo', "Assert arguments mode and net are in conflict"
|
||||||
|
self.net = net
|
||||||
|
|
||||||
if self.net == 'monstereo':
|
if self.net == 'monstereo':
|
||||||
input_size = 68
|
input_size = 68
|
||||||
output_size = 10
|
output_size = 10
|
||||||
|
|||||||
@ -25,7 +25,52 @@ from .activity import show_social
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
OPENPIFPAF_PATH = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model
|
# Google Drive locations of the default checkpoints, one per model.
# They are written as direct-download links (``uc?id=<file id>``) rather than
# as share-page links (``/file/d/<file id>/view``): gdown.download is
# documented with the former, and a share-page link would need fuzzy matching.
OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ'
MONOLOCO_MODEL = 'https://drive.google.com/uc?id=1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r'
MONSTEREO_MODEL = 'https://drive.google.com/uc?id=1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49'
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_checkpoints_dir():
    """Return the path of the ``checkpoints`` folder inside the torch cache.

    The base directory is chosen in this order:

    1. ``torch.hub.get_dir()`` when the installed torch provides it,
    2. the ``TORCH_HOME`` environment variable,
    3. ``$XDG_CACHE_HOME/torch``,
    4. ``~/.cache/torch``.

    The directory is only computed, never created.
    """
    hub = getattr(torch, 'hub', None)
    if hub is not None and hasattr(hub, 'get_dir'):
        # new in pytorch 1.6.0
        root = hub.get_dir()
    else:
        torch_home = os.getenv('TORCH_HOME')
        xdg_cache = os.getenv('XDG_CACHE_HOME')
        if torch_home:
            root = torch_home
        elif xdg_cache:
            root = os.path.join(xdg_cache, 'torch')
        else:
            root = os.path.expanduser(os.path.join('~', '.cache', 'torch'))

    return os.path.join(root, 'checkpoints')
|
||||||
|
|
||||||
|
|
||||||
|
def _download_if_missing(url, destination):
    """Fetch ``url`` into ``destination`` with gdown unless the file already exists."""
    if os.path.exists(destination):
        return
    # The torch cache folder may not exist yet on a fresh machine; create it
    # so the download has somewhere to write.
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    import gdown  # imported lazily: only needed when a download actually happens
    # NOTE(review): gdown.download is documented with direct 'uc?id=' links;
    # confirm the *_MODEL constants are in that form (or enable fuzzy matching).
    gdown.download(url, destination, quiet=False)


def download_checkpoints(args):
    """Resolve the checkpoint paths needed for ``args.mode``, downloading defaults if absent.

    The OpenPifPaf keypoint model is always resolved (and downloaded into the
    torch checkpoints folder when missing). For the 3D network:

    * ``args.mode == 'keypoints'``: nothing more is needed;
    * ``args.model`` given: that path is used as-is, nothing is downloaded;
    * ``args.mode == 'mono'``: the default MonoLoco++ checkpoint is used;
    * any other mode: the default MonStereo checkpoint is used.

    Args:
        args: parsed command-line namespace; reads ``args.mode`` and ``args.model``.

    Returns:
        dict mapping ``'keypoints'`` to the OpenPifPaf checkpoint path and,
        unless the mode is ``'keypoints'``, ``args.mode`` to the 3D model path.
    """
    torch_dir = get_torch_checkpoints_dir()
    pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl')
    dic_models = {'keypoints': pifpaf_model}
    _download_if_missing(OPENPIFPAF_MODEL, pifpaf_model)

    if args.mode == 'keypoints':
        return dic_models

    if args.model is not None:
        # A user-supplied model always wins over the downloadable defaults.
        dic_models[args.mode] = args.model
        return dic_models

    if args.mode == 'mono':
        model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl')
        url = MONOLOCO_MODEL
    else:
        model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl')
        url = MONSTEREO_MODEL

    dic_models[args.mode] = model
    _download_if_missing(url, model)
    return dic_models
|
||||||
|
|
||||||
|
|
||||||
def factory_from_args(args):
|
def factory_from_args(args):
|
||||||
@ -36,14 +81,9 @@ def factory_from_args(args):
|
|||||||
if not args.images:
|
if not args.images:
|
||||||
raise Exception("no image files given")
|
raise Exception("no image files given")
|
||||||
|
|
||||||
# Model
|
# Models
|
||||||
if not args.checkpoint:
|
dic_models = download_checkpoints(args)
|
||||||
if os.path.exists(OPENPIFPAF_PATH):
|
args.checkpoint = dic_models['keypoints']
|
||||||
args.checkpoint = OPENPIFPAF_PATH
|
|
||||||
else:
|
|
||||||
LOG.info("Checkpoint for OpenPifPaf not specified and default model not found in 'data/models'. "
|
|
||||||
"Using a ShuffleNet backbone")
|
|
||||||
args.checkpoint = 'shufflenetv2k30'
|
|
||||||
|
|
||||||
logger.configure(args, LOG) # logger first
|
logger.configure(args, LOG) # logger first
|
||||||
|
|
||||||
@ -59,7 +99,7 @@ def factory_from_args(args):
|
|||||||
args.figure_width = 10
|
args.figure_width = 10
|
||||||
args.dpi_factor = 1.0
|
args.dpi_factor = 1.0
|
||||||
|
|
||||||
if args.net == 'monstereo':
|
if args.mode == 'stereo':
|
||||||
args.batch_size = 2
|
args.batch_size = 2
|
||||||
args.images = sorted(args.images)
|
args.images = sorted(args.images)
|
||||||
else:
|
else:
|
||||||
@ -79,26 +119,31 @@ def factory_from_args(args):
|
|||||||
show.configure(args)
|
show.configure(args)
|
||||||
visualizer.configure(args)
|
visualizer.configure(args)
|
||||||
|
|
||||||
return args
|
return args, dic_models
|
||||||
|
|
||||||
|
|
||||||
def predict(args):
|
def predict(args):
|
||||||
|
|
||||||
cnt = 0
|
cnt = 0
|
||||||
args = factory_from_args(args)
|
assert args.mode in ('keypoints', 'mono', 'stereo')
|
||||||
|
args, dic_models = factory_from_args(args)
|
||||||
|
|
||||||
# Load Models
|
# Load Models
|
||||||
assert args.net in ('monoloco_pp', 'monstereo', 'pifpaf')
|
if args.mode in ('mono', 'stereo'):
|
||||||
if args.net in ('monoloco_pp', 'monstereo'):
|
net = Loco(
|
||||||
net = Loco(model=args.model, net=args.net, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
|
model=dic_models[args.mode],
|
||||||
|
mode=args.mode,
|
||||||
|
device=args.device,
|
||||||
|
n_dropout=args.n_dropout,
|
||||||
|
p_dropout=args.dropout)
|
||||||
|
|
||||||
# data
|
# data
|
||||||
processor, model = processor_factory(args)
|
processor, pifpaf_model = processor_factory(args)
|
||||||
preprocess = preprocess_factory(args)
|
preprocess = preprocess_factory(args)
|
||||||
|
|
||||||
# data
|
# data
|
||||||
data = datasets.ImageList(args.images, preprocess=preprocess)
|
data = datasets.ImageList(args.images, preprocess=preprocess)
|
||||||
if args.net == 'monstereo':
|
if args.mode == 'stereo':
|
||||||
assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
|
assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
|
||||||
|
|
||||||
data_loader = torch.utils.data.DataLoader(
|
data_loader = torch.utils.data.DataLoader(
|
||||||
@ -106,7 +151,7 @@ def predict(args):
|
|||||||
pin_memory=False, collate_fn=datasets.collate_images_anns_meta)
|
pin_memory=False, collate_fn=datasets.collate_images_anns_meta)
|
||||||
|
|
||||||
for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader):
|
for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader):
|
||||||
pred_batch = processor.batch(model, image_tensors_batch, device=args.device)
|
pred_batch = processor.batch(pifpaf_model, image_tensors_batch, device=args.device)
|
||||||
|
|
||||||
# unbatch (only for MonStereo)
|
# unbatch (only for MonStereo)
|
||||||
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
|
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
|
||||||
@ -136,14 +181,14 @@ def predict(args):
|
|||||||
pifpaf_outs['right'] = [ann.json_data() for ann in pred]
|
pifpaf_outs['right'] = [ann.json_data() for ann in pred]
|
||||||
|
|
||||||
# 3D Predictions
|
# 3D Predictions
|
||||||
if args.net in ('monoloco_pp', 'monstereo'):
|
if args.mode != 'keypoints':
|
||||||
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
|
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
|
||||||
kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt)
|
kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt)
|
||||||
|
|
||||||
# Preprocess pifpaf outputs and run monoloco
|
# Preprocess pifpaf outputs and run monoloco
|
||||||
boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)
|
boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)
|
||||||
|
|
||||||
if args.net == 'monoloco_pp':
|
if args.mode == 'mono':
|
||||||
LOG.info("Prediction with MonoLoco++")
|
LOG.info("Prediction with MonoLoco++")
|
||||||
dic_out = net.forward(keypoints, kk)
|
dic_out = net.forward(keypoints, kk)
|
||||||
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt)
|
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt)
|
||||||
@ -171,11 +216,11 @@ def factory_outputs(args, pifpaf_outs, dic_out, output_path, kk=None):
|
|||||||
|
|
||||||
# Verify conflicting options
|
# Verify conflicting options
|
||||||
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
|
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
|
||||||
assert args.net != 'pifpaf', "please use pifpaf original arguments"
|
assert args.mode != 'keypoints', "for keypooints please use pifpaf original arguments"
|
||||||
if args.social_distance:
|
if args.social_distance:
|
||||||
assert args.net == 'monoloco_pp', "Social distancing only works with MonoLoco++ network"
|
assert args.mode == 'mono', "Social distancing only works with monocular network"
|
||||||
|
|
||||||
if args.net == 'pifpaf':
|
if args.mode == 'keypoints':
|
||||||
annotation_painter = openpifpaf.show.AnnotationPainter()
|
annotation_painter = openpifpaf.show.AnnotationPainter()
|
||||||
with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax:
|
with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax:
|
||||||
annotation_painter.annotations(ax, pifpaf_outs['pred'])
|
annotation_painter.annotations(ax, pifpaf_outs['pred'])
|
||||||
|
|||||||
@ -28,7 +28,7 @@ def cli():
|
|||||||
|
|
||||||
# Predict (2D pose and/or 3D location from images)
|
# Predict (2D pose and/or 3D location from images)
|
||||||
# General
|
# General
|
||||||
predict_parser.add_argument('--mode', help='pifpaf, mono, stereo', default='stereo')
|
|
||||||
predict_parser.add_argument('images', nargs='*', help='input images')
|
predict_parser.add_argument('images', nargs='*', help='input images')
|
||||||
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
|
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
|
||||||
predict_parser.add_argument('-o', '--output-directory', help='Output directory')
|
predict_parser.add_argument('-o', '--output-directory', help='Output directory')
|
||||||
@ -58,12 +58,11 @@ def cli():
|
|||||||
visualizer.cli(parser)
|
visualizer.cli(parser)
|
||||||
|
|
||||||
# Monoloco
|
# Monoloco
|
||||||
predict_parser.add_argument('--net', help='Choose network: monoloco, monoloco_p, monoloco_pp, monstereo')
|
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
|
||||||
predict_parser.add_argument('--model', help='path of MonoLoco model to load', required=True)
|
predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
|
||||||
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=512)
|
predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')
|
||||||
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
|
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
|
||||||
default='data/arrays/names-kitti-200615-1022.json')
|
default='data/arrays/names-kitti-200615-1022.json')
|
||||||
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')
|
|
||||||
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100)
|
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100)
|
||||||
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
|
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
|
||||||
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)
|
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user