add model download and mode argument
This commit is contained in:
parent
3b97afb89e
commit
f0bbaa2a0e
@ -73,7 +73,7 @@ or check the file `monoloco/run.py`
|
||||
|
||||
# Predictions
|
||||
For a quick setup download a pifpaf and MonoLoco++ / MonStereo models from
|
||||
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
|
||||
[here](https://drive.google.com/drive/folders/1jZToVMBEZQMdLB5BAIq2CdCLP5kzNo9t?usp=sharing) and save them into `data/models`.
|
||||
|
||||
## A) 3D Localization
|
||||
The predict script receives an image (or an entire folder using glob expressions),
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.7 MiB |
@ -25,9 +25,22 @@ class Loco:
|
||||
LINEAR_SIZE_MONO = 256
|
||||
N_SAMPLES = 100
|
||||
|
||||
def __init__(self, model, net='monstereo', device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
|
||||
self.net = net
|
||||
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
|
||||
def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
|
||||
|
||||
# Select networks
|
||||
assert mode in ('mono', 'stereo'), "mode not recognized"
|
||||
self.mode = mode
|
||||
if net is None:
|
||||
if mode == 'mono':
|
||||
self.net = 'monoloco_pp'
|
||||
else:
|
||||
self.net = 'monstereo'
|
||||
else:
|
||||
assert self.net in ('monstereo', 'monoloco', 'monoloco_p', 'monoloco_pp')
|
||||
if self.net != 'monstereo':
|
||||
assert mode == 'stereo', "Assert arguments mode and net are in conflict"
|
||||
self.net = net
|
||||
|
||||
if self.net == 'monstereo':
|
||||
input_size = 68
|
||||
output_size = 10
|
||||
|
||||
@ -25,7 +25,52 @@ from .activity import show_social
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
OPENPIFPAF_PATH = 'data/models/shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl' # Default model
|
||||
OPENPIFPAF_MODEL = 'https://drive.google.com/file/d/1b408ockhh29OLAED8Tysd2yGZOo0N_SQ/view?usp=sharing'
|
||||
MONOLOCO_MODEL = 'https://drive.google.com/file/d/1krkB8J9JhgQp4xppmDu-YBRUxZvOs96r/view?usp=sharing'
|
||||
MONSTEREO_MODEL = 'https://drive.google.com/file/d/1xztN07dmp2e_nHI6Lcn103SAzt-Ntg49/view?usp=sharing'
|
||||
|
||||
|
||||
def get_torch_checkpoints_dir():
|
||||
if hasattr(torch, 'hub') and hasattr(torch.hub, 'get_dir'):
|
||||
# new in pytorch 1.6.0
|
||||
base_dir = torch.hub.get_dir()
|
||||
elif os.getenv('TORCH_HOME'):
|
||||
base_dir = os.getenv('TORCH_HOME')
|
||||
elif os.getenv('XDG_CACHE_HOME'):
|
||||
base_dir = os.path.join(os.getenv('XDG_CACHE_HOME'), 'torch')
|
||||
else:
|
||||
base_dir = os.path.expanduser(os.path.join('~', '.cache', 'torch'))
|
||||
return os.path.join(base_dir, 'checkpoints')
|
||||
|
||||
|
||||
def download_checkpoints(args):
|
||||
torch_dir = get_torch_checkpoints_dir()
|
||||
pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl')
|
||||
dic_models = {'keypoints': pifpaf_model}
|
||||
|
||||
if not os.path.exists(pifpaf_model):
|
||||
import gdown
|
||||
gdown.download(OPENPIFPAF_MODEL, pifpaf_model, quiet=False)
|
||||
|
||||
if args.mode == 'keypoints':
|
||||
return dic_models
|
||||
elif args.model is not None:
|
||||
dic_models[args.mode] = args.model
|
||||
return dic_models
|
||||
elif args.mode == 'mono':
|
||||
model = os.path.join(torch_dir, 'monoloco_pp-201203-1424.pkl')
|
||||
path = MONOLOCO_MODEL
|
||||
dic_models[args.mode] = model
|
||||
else:
|
||||
model = os.path.join(torch_dir, 'monstereo-201202-1212.pkl')
|
||||
path = MONSTEREO_MODEL
|
||||
dic_models[args.mode] = model
|
||||
|
||||
if not os.path.exists(model):
|
||||
import gdown
|
||||
gdown.download(path, model, quiet=False)
|
||||
|
||||
return dic_models
|
||||
|
||||
|
||||
def factory_from_args(args):
|
||||
@ -36,14 +81,9 @@ def factory_from_args(args):
|
||||
if not args.images:
|
||||
raise Exception("no image files given")
|
||||
|
||||
# Model
|
||||
if not args.checkpoint:
|
||||
if os.path.exists(OPENPIFPAF_PATH):
|
||||
args.checkpoint = OPENPIFPAF_PATH
|
||||
else:
|
||||
LOG.info("Checkpoint for OpenPifPaf not specified and default model not found in 'data/models'. "
|
||||
"Using a ShuffleNet backbone")
|
||||
args.checkpoint = 'shufflenetv2k30'
|
||||
# Models
|
||||
dic_models = download_checkpoints(args)
|
||||
args.checkpoint = dic_models['keypoints']
|
||||
|
||||
logger.configure(args, LOG) # logger first
|
||||
|
||||
@ -59,7 +99,7 @@ def factory_from_args(args):
|
||||
args.figure_width = 10
|
||||
args.dpi_factor = 1.0
|
||||
|
||||
if args.net == 'monstereo':
|
||||
if args.mode == 'stereo':
|
||||
args.batch_size = 2
|
||||
args.images = sorted(args.images)
|
||||
else:
|
||||
@ -79,26 +119,31 @@ def factory_from_args(args):
|
||||
show.configure(args)
|
||||
visualizer.configure(args)
|
||||
|
||||
return args
|
||||
return args, dic_models
|
||||
|
||||
|
||||
def predict(args):
|
||||
|
||||
cnt = 0
|
||||
args = factory_from_args(args)
|
||||
assert args.mode in ('keypoints', 'mono', 'stereo')
|
||||
args, dic_models = factory_from_args(args)
|
||||
|
||||
# Load Models
|
||||
assert args.net in ('monoloco_pp', 'monstereo', 'pifpaf')
|
||||
if args.net in ('monoloco_pp', 'monstereo'):
|
||||
net = Loco(model=args.model, net=args.net, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
|
||||
if args.mode in ('mono', 'stereo'):
|
||||
net = Loco(
|
||||
model=dic_models[args.mode],
|
||||
mode=args.mode,
|
||||
device=args.device,
|
||||
n_dropout=args.n_dropout,
|
||||
p_dropout=args.dropout)
|
||||
|
||||
# data
|
||||
processor, model = processor_factory(args)
|
||||
processor, pifpaf_model = processor_factory(args)
|
||||
preprocess = preprocess_factory(args)
|
||||
|
||||
# data
|
||||
data = datasets.ImageList(args.images, preprocess=preprocess)
|
||||
if args.net == 'monstereo':
|
||||
if args.mode == 'stereo':
|
||||
assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
@ -106,7 +151,7 @@ def predict(args):
|
||||
pin_memory=False, collate_fn=datasets.collate_images_anns_meta)
|
||||
|
||||
for batch_i, (image_tensors_batch, _, meta_batch) in enumerate(data_loader):
|
||||
pred_batch = processor.batch(model, image_tensors_batch, device=args.device)
|
||||
pred_batch = processor.batch(pifpaf_model, image_tensors_batch, device=args.device)
|
||||
|
||||
# unbatch (only for MonStereo)
|
||||
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
|
||||
@ -136,14 +181,14 @@ def predict(args):
|
||||
pifpaf_outs['right'] = [ann.json_data() for ann in pred]
|
||||
|
||||
# 3D Predictions
|
||||
if args.net in ('monoloco_pp', 'monstereo'):
|
||||
if args.mode != 'keypoints':
|
||||
im_size = (cpu_image.size[0], cpu_image.size[1]) # Original
|
||||
kk, dic_gt = factory_for_gt(im_size, focal_length=args.focal, name=file_name, path_gt=args.path_gt)
|
||||
|
||||
# Preprocess pifpaf outputs and run monoloco
|
||||
boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)
|
||||
|
||||
if args.net == 'monoloco_pp':
|
||||
if args.mode == 'mono':
|
||||
LOG.info("Prediction with MonoLoco++")
|
||||
dic_out = net.forward(keypoints, kk)
|
||||
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt)
|
||||
@ -171,11 +216,11 @@ def factory_outputs(args, pifpaf_outs, dic_out, output_path, kk=None):
|
||||
|
||||
# Verify conflicting options
|
||||
if any((xx in args.output_types for xx in ['front', 'bird', 'multi'])):
|
||||
assert args.net != 'pifpaf', "please use pifpaf original arguments"
|
||||
assert args.mode != 'keypoints', "for keypooints please use pifpaf original arguments"
|
||||
if args.social_distance:
|
||||
assert args.net == 'monoloco_pp', "Social distancing only works with MonoLoco++ network"
|
||||
assert args.mode == 'mono', "Social distancing only works with monocular network"
|
||||
|
||||
if args.net == 'pifpaf':
|
||||
if args.mode == 'keypoints':
|
||||
annotation_painter = openpifpaf.show.AnnotationPainter()
|
||||
with openpifpaf.show.image_canvas(pifpaf_outs['image'], output_path) as ax:
|
||||
annotation_painter.annotations(ax, pifpaf_outs['pred'])
|
||||
|
||||
@ -28,7 +28,7 @@ def cli():
|
||||
|
||||
# Predict (2D pose and/or 3D location from images)
|
||||
# General
|
||||
predict_parser.add_argument('--mode', help='pifpaf, mono, stereo', default='stereo')
|
||||
|
||||
predict_parser.add_argument('images', nargs='*', help='input images')
|
||||
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
|
||||
predict_parser.add_argument('-o', '--output-directory', help='Output directory')
|
||||
@ -58,12 +58,11 @@ def cli():
|
||||
visualizer.cli(parser)
|
||||
|
||||
# Monoloco
|
||||
predict_parser.add_argument('--net', help='Choose network: monoloco, monoloco_p, monoloco_pp, monstereo')
|
||||
predict_parser.add_argument('--model', help='path of MonoLoco model to load', required=True)
|
||||
predict_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=512)
|
||||
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
|
||||
predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
|
||||
predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')
|
||||
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization',
|
||||
default='data/arrays/names-kitti-200615-1022.json')
|
||||
predict_parser.add_argument('--transform', help='transformation for the pose', default='None')
|
||||
predict_parser.add_argument('--z_max', type=int, help='maximum meters distance for predictions', default=100)
|
||||
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
|
||||
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user