diff --git a/src/main.py b/src/main.py index fa9d5aa..92da185 100644 --- a/src/main.py +++ b/src/main.py @@ -52,8 +52,7 @@ def cli(): # 1)Pifpaf arguments nets.cli(predict_parser) decoder.cli(predict_parser, force_complete_pose=True, instance_threshold=0.1) - predict_parser.add_argument('--model_pifpaf', help='pifpaf model to load', - default="data/models/resnet152-190412.pkl") + predict_parser.add_argument('--checkpoint', help='pifpaf model to load') predict_parser.add_argument('--scale', default=1.0, type=float, help='change the scale of the image to preprocess') # 2) Monoloco argument diff --git a/src/predict/pifpaf.py b/src/predict/pifpaf.py index ea1f9b4..980577a 100644 --- a/src/predict/pifpaf.py +++ b/src/predict/pifpaf.py @@ -46,7 +46,13 @@ def factory_from_args(args): # Merge the model_pifpaf argument if not args.checkpoint: - args.checkpoint = args.model_pifpaf + args.checkpoint = 'resnet152' # Default model Resnet 152 + elif args.checkpoint == 'resnet50': + args.checkpoint = 'data/models/resnet50block5-pif-paf-edge401-190424-122009-f26a1f53.pkl' + elif args.checkpoint == 'resnet101': + args.checkpoint = 'data/models/resnet101block5-pif-paf-edge401-190412-151013-513a2d2d.pkl' + elif args.checkpoint == 'resnet152': + args.checkpoint = 'data/models/resnet152block5-pif-paf-edge401-190412-121848-8d771fcc.pkl' # glob if not args.webcam: if args.glob: diff --git a/src/visuals/webcam.py b/src/visuals/webcam.py index ebbc6ed..1f65dd3 100644 --- a/src/visuals/webcam.py +++ b/src/visuals/webcam.py @@ -38,6 +38,7 @@ def webcam(args): visualizer_monoloco = None while True: + start = time.time() ret, frame = cam.read() image = cv2.resize(frame, None, fx=args.scale, fy=args.scale) height, width, _ = image.shape @@ -68,6 +69,8 @@ def webcam(args): outputs, varss = monoloco.forward(keypoints, kk) dic_out = monoloco.post_process(outputs, varss, boxes, keypoints, kk, dict_gt) visualizer_monoloco.send((pil_image, dic_out)) + end = time.time() + print("run-time: {:.2f} ms".format((end-start)*1000)) cam.release() @@ -95,8 +98,7 @@ class VisualizerMonoloco: while True: image, dict_ann = yield - draw_start = time.time() - while axes and ((axes[0] and axes[0].patches) or (axes[-1] and axes[-1].patches)): + while axes and (axes[-1] and axes[-1].patches): # for front -1==0, for bird/combined -1 == 1 if axes[0]: del axes[0].patches[0] del axes[0].texts[0] @@ -105,9 +107,9 @@ class VisualizerMonoloco: del axes[1].patches[0] # the one became the 0 if len(axes[1].lines) > 2: del axes[1].lines[2] - del axes[1].texts[0] + if len(axes[1].texts) > 0: # in case of no text + del axes[1].texts[0] printer.draw(figures, axes, dict_ann, image) - print('draw', time.time() - draw_start) mypause(0.01)