This commit is contained in:
Lorenzo 2021-01-07 11:38:34 +01:00
parent 339793d6b4
commit f5d350e7b0
4 changed files with 36 additions and 31 deletions

View File

@ -218,22 +218,22 @@ def show_social(args, image_t, output_path, annotations, dic_out):
stds = dic_out['stds_ale']
xz_centers = [[xx[0], xx[2]] for xx in dic_out['xyz_pred']]
# Prepare color for social distancing
colors = ['r' if social_interactions(idx, xz_centers, angles, dds,
stds=stds,
threshold_prob=args.threshold_prob,
threshold_dist=args.threshold_dist,
radii=args.radii)
else 'deepskyblue'
for idx, _ in enumerate(dic_out['xyz_pred'])]
if 'front' in args.output_types:
# Resize back the tensor image to its original dimensions
if not 0.99 < args.scale < 1.01:
size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width
image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width
image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0)
# Prepare color for social distancing
colors = ['r' if social_interactions(idx, xz_centers, angles, dds,
stds=stds,
threshold_prob=args.threshold_prob,
threshold_dist=args.threshold_dist,
radii=args.radii)
else 'deepskyblue'
for idx, _ in enumerate(dic_out['xyz_pred'])]
# if not 0.99 < args.scale < 1.01:
# size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width
# image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width
# image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0)
# Draw keypoints and orientation
keypoint_sets, scores = get_pifpaf_outputs(annotations)

View File

@ -135,7 +135,9 @@ def predict(args):
dic_out = net.forward(keypoints, kk)
reorder = False if args.social_distance else True
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt, reorder=reorder)
if args.social_distance:
# image_t = torchvision.transforms.functional.to_tensor(image).permute(1, 2, 0)
show_social(args, cpu_image, output_path, pifpaf_out, dic_out)
else:
@ -148,9 +150,9 @@ def predict(args):
dic_out = defaultdict(list)
kk = None
# TODO Clean
factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out,
dic_out=dic_out, kk=kk)
if not args.social_distance:
factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out,
dic_out=dic_out, kk=kk)
print('Image {}\n'.format(cnt) + '-' * 120)
cnt += 1
@ -192,7 +194,7 @@ def factory_outputs(args, annotation_painter, cpu_image, output_path, pred, pifp
print(output_path)
if dic_out['boxes']: # Only print in case of detections
printer = Printer(cpu_image, output_path, kk, args)
figures, axes = printer.factory_axes()
figures, axes = printer.factory_axes(dic_out)
printer.draw(figures, axes, dic_out, cpu_image)
if 'json' in args.output_types:

View File

@ -39,21 +39,20 @@ def canvas(fig_file=None, show=True, **kwargs):
@contextmanager
def image_canvas(image, fig_file=None, show=True, dpi_factor=1.0, fig_width=10.0, **kwargs):
if 'figsize' not in kwargs:
kwargs['figsize'] = (fig_width, fig_width * image.shape[0] / image.shape[1])
kwargs['figsize'] = (fig_width, fig_width * image.size[1] / image.size[0])
fig = plt.figure(**kwargs)
ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
ax.set_axis_off()
ax.set_xlim(0, image.shape[1])
ax.set_ylim(image.shape[0], 0)
ax.set_xlim(0, image.size[0])
ax.set_ylim(image.size[1], 0)
fig.add_axes(ax)
image_2 = ndimage.gaussian_filter(image, sigma=2.5)
ax.imshow(image_2, alpha=0.4)
yield ax
if fig_file:
fig.savefig(fig_file, dpi=image.shape[1] / kwargs['figsize'][0] * dpi_factor)
fig.savefig(fig_file, dpi=image.size[0] / kwargs['figsize'][0] * dpi_factor)
print('keypoints image saved')
if show:
plt.show()

View File

@ -58,7 +58,7 @@ class Printer:
self.output_path = output_path
self.kk = kk
self.output_types = args.output_types
self.z_max = args.z_max # To include ellipses in the image
self.z_max = args.z_max # set max distance to show instances
self.show_all = args.show_all
self.show = args.show_all
self.save = not args.no_save
@ -74,13 +74,17 @@ class Printer:
self.xx_gt = [xx[0] for xx in dic_ann['xyz_real']]
self.xx_pred = [xx[0] for xx in dic_ann['xyz_pred']]
# Set maximum distance
self.dd_pred = dic_ann['dds_pred']
self.dd_real = dic_ann['dds_real']
self.z_max = int(min(self.z_max + 4, max(max(self.dd_pred), max(self.dd_real, default=0))))
# Do not print instances outside z_max
self.zz_gt = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
for idx, xx in enumerate(dic_ann['xyz_real'])]
self.zz_pred = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
for idx, xx in enumerate(dic_ann['xyz_pred'])]
self.dd_pred = dic_ann['dds_pred']
self.dd_real = dic_ann['dds_real']
self.uv_heads = dic_ann['uv_heads']
self.uv_shoulders = dic_ann['uv_shoulders']
self.boxes = dic_ann['boxes']
@ -97,11 +101,14 @@ class Printer:
else:
self.modes.append('stereo')
def factory_axes(self):
def factory_axes(self, dic_out):
"""Create axes for figures: front bird multi"""
axes = []
figures = []
# Process the annotation dictionary of monoloco
self._process_results(dic_out)
# Initialize multi figure, resizing it for aesthetic proportion
if 'multi' in self.output_types:
assert 'bird' and 'front' not in self.output_types, \
@ -160,9 +167,6 @@ class Printer:
def draw(self, figures, axes, dic_out, image):
# Process the annotation dictionary of monoloco
self._process_results(dic_out)
# whether to include instances that don't match the ground-truth
iterator = range(len(self.zz_pred)) if self.show_all else range(len(self.zz_gt))
if not iterator:
@ -171,9 +175,9 @@ class Printer:
# Draw the front figure
number = dict(flag=False, num=97)
if 'multi' in self.output_types:
if any(xx in self.output_types for xx in ['front', 'multi']):
number['flag'] = True # add numbers
self.mpl_im0.set_data(image)
self.mpl_im0.set_data(image)
for idx in iterator:
if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
self._draw_front(axes[0],