Changes for pull request
This commit is contained in:
parent
8f271111a8
commit
d13b480f06
@ -73,7 +73,7 @@ def download_checkpoints(args):
|
||||
assert not args.social_distance, "Social distance not supported in stereo modality"
|
||||
path = MONSTEREO_MODEL
|
||||
name = 'monstereo-201202-1212.pkl'
|
||||
elif args.social_distance or (args.activities and 'social_distance' in args.activities) or args.webcam:
|
||||
elif (args.activities and 'social_distance' in args.activities) or args.webcam:
|
||||
path = MONOLOCO_MODEL_NU
|
||||
name = 'monoloco_pp-201207-1350.pkl'
|
||||
else:
|
||||
@ -220,7 +220,7 @@ def predict(args):
|
||||
dic_out = net.forward(keypoints, kk)
|
||||
dic_out = net.post_process(
|
||||
dic_out, boxes, keypoints, kk, dic_gt)
|
||||
if args.social_distance or (args.activities and 'social_distance' in args.activities):
|
||||
if args.activities and 'social_distance' in args.activities:
|
||||
dic_out = net.social_distance(dic_out, args)
|
||||
if args.activities and 'raise_hand' in args.activities:
|
||||
dic_out = net.raising_hand(dic_out, keypoints)
|
||||
|
||||
@ -20,7 +20,7 @@ def cli():
|
||||
predict_parser.add_argument('--glob', help='glob expression for input images (for many images)')
|
||||
predict_parser.add_argument('--checkpoint', help='pifpaf model')
|
||||
predict_parser.add_argument('-o', '--output-directory', help='Output directory')
|
||||
predict_parser.add_argument('--output_types', nargs='+',
|
||||
predict_parser.add_argument('--output_types', nargs='+', default=['multi'],
|
||||
help='what to output: json keypoints skeleton for Pifpaf'
|
||||
'json bird front or multi for MonStereo')
|
||||
predict_parser.add_argument('--no_save', help='to show images', action='store_true')
|
||||
@ -65,7 +65,6 @@ def cli():
|
||||
type=float, default=5.7)
|
||||
|
||||
# Social distancing and social interactions
|
||||
predict_parser.add_argument('--social_distance', help='social', action='store_true')
|
||||
predict_parser.add_argument('--threshold_prob', type=float, help='concordance for samples', default=0.25)
|
||||
predict_parser.add_argument('--threshold_dist', type=float, help='min distance of people', default=2.5)
|
||||
predict_parser.add_argument('--radii', type=tuple, help='o-space radii', default=(0.3, 0.5, 1))
|
||||
@ -137,8 +136,6 @@ def main():
|
||||
from .visuals.webcam import webcam
|
||||
webcam(args)
|
||||
else:
|
||||
if args.output_types is None:
|
||||
args.output_types = ['json']
|
||||
from .predict import predict
|
||||
predict(args)
|
||||
|
||||
|
||||
@ -93,9 +93,11 @@ class KeypointPainter:
|
||||
self.solid_threshold = solid_threshold
|
||||
self.dashed_threshold = 0.1 # Patch to still allow force complete pose (set to zero to resume original)
|
||||
|
||||
def _draw_skeleton(self, ax, x, y, v, *, size=None, color=None, raise_hand='none'):
|
||||
if not np.any(v > 0):
|
||||
return
|
||||
|
||||
def _highlighted_arm(self, x, y, connection, color, lwidth, raise_hand, size=None):
|
||||
|
||||
c = color
|
||||
linewidth = lwidth
|
||||
|
||||
width, height = (1,1)
|
||||
if size:
|
||||
@ -105,10 +107,6 @@ class KeypointPainter:
|
||||
l_arm_width = np.sqrt(((x[9]-x[7])/width)**2 + ((y[9]-y[7])/height)**2)*100
|
||||
r_arm_width = np.sqrt(((x[10]-x[8])/width)**2 + ((y[10]-y[8])/height)**2)*100
|
||||
|
||||
if self.skeleton is not None:
|
||||
for ci, connection in enumerate(np.array(self.skeleton) - 1):
|
||||
c = color
|
||||
linewidth=self.linewidth
|
||||
if ((connection[0] == 5 and connection[1] == 7)
|
||||
or (connection[0] == 7 and connection[1] == 9)) and raise_hand in ['left','both']:
|
||||
c = 'yellow'
|
||||
@ -117,6 +115,24 @@ class KeypointPainter:
|
||||
or (connection[0] == 8 and connection[1] == 10)) and raise_hand in ['right', 'both']:
|
||||
c = 'yellow'
|
||||
linewidth = r_arm_width
|
||||
|
||||
return c, linewidth
|
||||
|
||||
|
||||
def _draw_skeleton(self, ax, x, y, v, i, *, size=None, color=None, activities=None, dic_out=None):
|
||||
if not np.any(v > 0):
|
||||
return
|
||||
|
||||
if self.skeleton is not None:
|
||||
for ci, connection in enumerate(np.array(self.skeleton) - 1):
|
||||
c = color
|
||||
linewidth = self.linewidth
|
||||
|
||||
if activities:
|
||||
if 'raise_hand' in activities:
|
||||
c, linewidth = self._highlighted_arm(x, y, connection, c, linewidth,
|
||||
dic_out['raising_hand'][:][i], size=size)
|
||||
|
||||
if self.color_connections:
|
||||
c = matplotlib.cm.get_cmap('tab20')(ci / len(self.skeleton))
|
||||
if np.all(v[connection] > self.dashed_threshold):
|
||||
@ -193,7 +209,8 @@ class KeypointPainter:
|
||||
(x - scale, y - scale), 2 * scale, 2 * scale, fill=False, color=color))
|
||||
|
||||
def keypoints(self, ax, keypoint_sets, *,
|
||||
size=None, scores=None, color=None, colors=None, texts=None, raise_hand='none'):
|
||||
size=None, scores=None, color=None,
|
||||
colors=None, texts=None, activities=None, dic_out=None):
|
||||
if keypoint_sets is None:
|
||||
return
|
||||
|
||||
@ -214,12 +231,8 @@ class KeypointPainter:
|
||||
if isinstance(color, (int, np.integer)):
|
||||
color = matplotlib.cm.get_cmap('tab20')((color % 20 + 0.05) / 20)
|
||||
|
||||
if raise_hand != 'none':
|
||||
# if raise_hand[:][i] is 'both' or raise_hand[:][i] is 'left' or raise_hand[:][i] is 'right':
|
||||
# color = 'green'
|
||||
self._draw_skeleton(ax, x, y, v, size=size, color=color, raise_hand=raise_hand[:][i])
|
||||
else:
|
||||
self._draw_skeleton(ax, x, y, v, color=color)
|
||||
self._draw_skeleton(ax, x, y, v, i, size=size, color=color, activities=activities, dic_out=dic_out)
|
||||
|
||||
score = scores[i] if scores is not None else None
|
||||
if score is not None:
|
||||
z_str = str(score).split(sep='.')
|
||||
|
||||
@ -52,7 +52,6 @@ class Printer:
|
||||
boxes_gt, uv_camera, radius, auxs = nones(16)
|
||||
|
||||
def __init__(self, image, output_path, kk, args):
|
||||
|
||||
self.im = image
|
||||
self.width = self.im.size[0]
|
||||
self.height = self.im.size[1]
|
||||
@ -60,11 +59,12 @@ class Printer:
|
||||
self.kk = kk
|
||||
self.output_types = args.output_types
|
||||
self.z_max = args.z_max # set max distance to show instances
|
||||
self.show_all = args.show_all or args.webcam
|
||||
self.show = args.show_all or args.webcam
|
||||
self.save = not args.no_save and not args.webcam
|
||||
self.plt_close = not args.webcam
|
||||
self.args = args
|
||||
self.webcam = args.webcam
|
||||
self.show_all = args.show_all or self.webcam
|
||||
self.show = args.show_all or self.webcam
|
||||
self.save = not args.no_save and not self.webcam
|
||||
self.plt_close = not self.webcam
|
||||
self.activities = args.activities
|
||||
|
||||
# define image attributes
|
||||
self.attr = image_attributes(args.dpi, args.output_types)
|
||||
@ -177,33 +177,36 @@ class Printer:
|
||||
return figures, axes
|
||||
|
||||
|
||||
def social_distance_front(self, axis, colors, annotations, dic_out):
|
||||
def _webcam_front(self, axis, colors, activities, annotations, dic_out):
|
||||
sizes = [abs(self.centers[idx][1] - uv_s[1]*self.y_scale) / 1.5 for idx, uv_s in
|
||||
enumerate(self.uv_shoulders)]
|
||||
|
||||
keypoint_sets, _ = get_pifpaf_outputs(annotations)
|
||||
keypoint_painter = KeypointPainter(show_box=False, y_scale=self.y_scale)
|
||||
r_h = 'none'
|
||||
if 'raise_hand' in self.args.activities:
|
||||
r_h = dic_out['raising_hand']
|
||||
|
||||
if activities:
|
||||
keypoint_painter.keypoints(
|
||||
axis, keypoint_sets, size=self.im.size,scores=self.dd_pred,colors=colors, raise_hand=r_h)
|
||||
axis, keypoint_sets, size=self.im.size,
|
||||
scores=self.dd_pred, colors=colors, activities=activities, dic_out=dic_out)
|
||||
|
||||
if 'social_distance' in activities:
|
||||
draw_orientation(axis, self.centers,
|
||||
sizes, self.angles, colors, mode='front')
|
||||
else:
|
||||
keypoint_painter.keypoints(
|
||||
axis, keypoint_sets, size=self.im.size, scores=self.dd_pred)
|
||||
|
||||
|
||||
def social_distance_bird(self, axis, colors):
|
||||
def _activities_bird(self, axis, colors, activities):
|
||||
if 'social_distance' in activities:
|
||||
draw_orientation(axis, self.xz_centers, [], self.angles, colors, mode='bird')
|
||||
|
||||
|
||||
def _front_loop(self, iterator, axes, number, colors, annotations, dic_out):
|
||||
for idx in iterator:
|
||||
if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
|
||||
if self.args.activities:
|
||||
if 'social_distance' in self.args.activities:
|
||||
self.social_distance_front(axes[0], colors, annotations, dic_out)
|
||||
elif 'raise_hand' in self.args.activities:
|
||||
self.social_distance_front(axes[0], colors, annotations, dic_out)
|
||||
if self.webcam:
|
||||
self._webcam_front(axes[0], colors, self.activities, annotations, dic_out)
|
||||
else:
|
||||
self._draw_front(axes[0],
|
||||
self.dd_pred[idx],
|
||||
@ -215,10 +218,8 @@ class Printer:
|
||||
def _bird_loop(self, iterator, axes, colors, number):
|
||||
for idx in iterator:
|
||||
if any(xx in self.output_types for xx in ['bird', 'multi']) and self.zz_pred[idx] > 0:
|
||||
|
||||
if self.args.activities:
|
||||
if 'social_distance' in self.args.activities:
|
||||
self.social_distance_bird(axes[1], colors)
|
||||
if self.activities:
|
||||
self._activities_bird(axes[1], colors, self.activities)
|
||||
# Draw ground truth and uncertainty
|
||||
self._draw_uncertainty(axes, idx)
|
||||
|
||||
@ -231,9 +232,9 @@ class Printer:
|
||||
def draw(self, figures, axes, image, dic_out=None, annotations=None):
|
||||
|
||||
colors = []
|
||||
if self.args.activities:
|
||||
if self.activities:
|
||||
colors = ['deepskyblue' for _ in self.uv_heads]
|
||||
if 'social_distance' in self.args.activities:
|
||||
if 'social_distance' in self.activities:
|
||||
colors = social_distance_colors(colors, dic_out)
|
||||
|
||||
# whether to include instances that don't match the ground-truth
|
||||
@ -246,7 +247,7 @@ class Printer:
|
||||
if any(xx in self.output_types for xx in ['front', 'multi']):
|
||||
number['flag'] = True # add numbers
|
||||
# Remove image if social distance is activated
|
||||
if not self.args.activities or 'social_distance' not in self.args.activities:
|
||||
if not self.activities or 'social_distance' not in self.activities:
|
||||
self.mpl_im0.set_data(image)
|
||||
|
||||
self._front_loop(iterator, axes, number, colors, annotations, dic_out)
|
||||
@ -420,7 +421,7 @@ class Printer:
|
||||
ax.set_axis_off()
|
||||
ax.set_xlim(0, self.width)
|
||||
ax.set_ylim(self.height, 0)
|
||||
if not self.args.activities or 'social_distance' not in self.args.activities:
|
||||
if not self.activities or 'social_distance' not in self.activities:
|
||||
self.mpl_im0 = ax.imshow(self.im)
|
||||
ax.get_xaxis().set_visible(False)
|
||||
ax.get_yaxis().set_visible(False)
|
||||
@ -428,8 +429,7 @@ class Printer:
|
||||
else:
|
||||
uv_max = [0., float(self.height)]
|
||||
xyz_max = pixel_to_camera(uv_max, self.kk, self.z_max)
|
||||
x_max = abs(xyz_max[0]) # shortcut to avoid oval circles in case of different kk
|
||||
x_max=6
|
||||
x_max = max(abs(xyz_max[0]), 6) # shortcut to avoid oval circles in case of different kk
|
||||
corr = round(float(x_max / 3))
|
||||
ax.plot([0, x_max], [0, self.z_max], 'k--')
|
||||
ax.plot([0, -x_max], [0, self.z_max], 'k--')
|
||||
|
||||
@ -36,14 +36,7 @@ def factory_from_args(args):
|
||||
|
||||
logger.configure(args, LOG) # logger first
|
||||
|
||||
if args.output_types is None:
|
||||
args.output_types = ['multi']
|
||||
|
||||
assert 'bird' not in args.output_types
|
||||
if 'json' not in args.output_types:
|
||||
assert len(args.output_types) == 1
|
||||
else:
|
||||
assert len(args.output_types) < 3
|
||||
assert len(args.output_types) == 1 and 'json' not in args.output_types
|
||||
|
||||
# Devices
|
||||
args.device = torch.device('cpu')
|
||||
@ -129,8 +122,7 @@ def webcam(args):
|
||||
print("Escape hit, closing...")
|
||||
break
|
||||
|
||||
intrinsic_size = [xx * 1.3 for xx in pil_image.size]
|
||||
kk, dic_gt = factory_for_gt(intrinsic_size, focal_length=args.focal) # better intrinsics for mac camera
|
||||
kk, dic_gt = factory_for_gt(pil_image.size, focal_length=args.focal)
|
||||
boxes, keypoints = preprocess_pifpaf(
|
||||
pifpaf_outs['left'], (width, height))
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ PREDICT_COMMAND_SOCIAL_DISTANCE = [
|
||||
'python3', '-m', 'monoloco.run',
|
||||
'predict',
|
||||
'docs/frame0032.jpg',
|
||||
'--social_distance',
|
||||
'--activities', 'social_distance',
|
||||
'--output_types', 'front', 'bird',
|
||||
'--decoder-workers=0' # for windows'
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user