clean
This commit is contained in:
parent
339793d6b4
commit
f5d350e7b0
@ -218,22 +218,22 @@ def show_social(args, image_t, output_path, annotations, dic_out):
|
||||
stds = dic_out['stds_ale']
|
||||
xz_centers = [[xx[0], xx[2]] for xx in dic_out['xyz_pred']]
|
||||
|
||||
# Prepare color for social distancing
|
||||
colors = ['r' if social_interactions(idx, xz_centers, angles, dds,
|
||||
stds=stds,
|
||||
threshold_prob=args.threshold_prob,
|
||||
threshold_dist=args.threshold_dist,
|
||||
radii=args.radii)
|
||||
else 'deepskyblue'
|
||||
for idx, _ in enumerate(dic_out['xyz_pred'])]
|
||||
|
||||
if 'front' in args.output_types:
|
||||
|
||||
# Resize back the tensor image to its original dimensions
|
||||
if not 0.99 < args.scale < 1.01:
|
||||
size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width
|
||||
image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width
|
||||
image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0)
|
||||
|
||||
# Prepare color for social distancing
|
||||
colors = ['r' if social_interactions(idx, xz_centers, angles, dds,
|
||||
stds=stds,
|
||||
threshold_prob=args.threshold_prob,
|
||||
threshold_dist=args.threshold_dist,
|
||||
radii=args.radii)
|
||||
else 'deepskyblue'
|
||||
for idx, _ in enumerate(dic_out['xyz_pred'])]
|
||||
# if not 0.99 < args.scale < 1.01:
|
||||
# size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width
|
||||
# image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width
|
||||
# image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0)
|
||||
|
||||
# Draw keypoints and orientation
|
||||
keypoint_sets, scores = get_pifpaf_outputs(annotations)
|
||||
|
||||
@ -135,7 +135,9 @@ def predict(args):
|
||||
dic_out = net.forward(keypoints, kk)
|
||||
reorder = False if args.social_distance else True
|
||||
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt, reorder=reorder)
|
||||
|
||||
if args.social_distance:
|
||||
# image_t = torchvision.transforms.functional.to_tensor(image).permute(1, 2, 0)
|
||||
show_social(args, cpu_image, output_path, pifpaf_out, dic_out)
|
||||
|
||||
else:
|
||||
@ -148,9 +150,9 @@ def predict(args):
|
||||
dic_out = defaultdict(list)
|
||||
kk = None
|
||||
|
||||
# TODO Clean
|
||||
factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out,
|
||||
dic_out=dic_out, kk=kk)
|
||||
if not args.social_distance:
|
||||
factory_outputs(args, annotation_painter, cpu_image, output_path, pifpaf_outputs, pifpaf_out,
|
||||
dic_out=dic_out, kk=kk)
|
||||
print('Image {}\n'.format(cnt) + '-' * 120)
|
||||
cnt += 1
|
||||
|
||||
@ -192,7 +194,7 @@ def factory_outputs(args, annotation_painter, cpu_image, output_path, pred, pifp
|
||||
print(output_path)
|
||||
if dic_out['boxes']: # Only print in case of detections
|
||||
printer = Printer(cpu_image, output_path, kk, args)
|
||||
figures, axes = printer.factory_axes()
|
||||
figures, axes = printer.factory_axes(dic_out)
|
||||
printer.draw(figures, axes, dic_out, cpu_image)
|
||||
|
||||
if 'json' in args.output_types:
|
||||
|
||||
@ -39,21 +39,20 @@ def canvas(fig_file=None, show=True, **kwargs):
|
||||
@contextmanager
|
||||
def image_canvas(image, fig_file=None, show=True, dpi_factor=1.0, fig_width=10.0, **kwargs):
|
||||
if 'figsize' not in kwargs:
|
||||
kwargs['figsize'] = (fig_width, fig_width * image.shape[0] / image.shape[1])
|
||||
kwargs['figsize'] = (fig_width, fig_width * image.size[1] / image.size[0])
|
||||
|
||||
fig = plt.figure(**kwargs)
|
||||
ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
|
||||
ax.set_axis_off()
|
||||
ax.set_xlim(0, image.shape[1])
|
||||
ax.set_ylim(image.shape[0], 0)
|
||||
ax.set_xlim(0, image.size[0])
|
||||
ax.set_ylim(image.size[1], 0)
|
||||
fig.add_axes(ax)
|
||||
image_2 = ndimage.gaussian_filter(image, sigma=2.5)
|
||||
ax.imshow(image_2, alpha=0.4)
|
||||
|
||||
yield ax
|
||||
|
||||
if fig_file:
|
||||
fig.savefig(fig_file, dpi=image.shape[1] / kwargs['figsize'][0] * dpi_factor)
|
||||
fig.savefig(fig_file, dpi=image.size[0] / kwargs['figsize'][0] * dpi_factor)
|
||||
print('keypoints image saved')
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
@ -58,7 +58,7 @@ class Printer:
|
||||
self.output_path = output_path
|
||||
self.kk = kk
|
||||
self.output_types = args.output_types
|
||||
self.z_max = args.z_max # To include ellipses in the image
|
||||
self.z_max = args.z_max # set max distance to show instances
|
||||
self.show_all = args.show_all
|
||||
self.show = args.show_all
|
||||
self.save = not args.no_save
|
||||
@ -74,13 +74,17 @@ class Printer:
|
||||
self.xx_gt = [xx[0] for xx in dic_ann['xyz_real']]
|
||||
self.xx_pred = [xx[0] for xx in dic_ann['xyz_pred']]
|
||||
|
||||
# Set maximum distance
|
||||
self.dd_pred = dic_ann['dds_pred']
|
||||
self.dd_real = dic_ann['dds_real']
|
||||
self.z_max = int(min(self.z_max + 4, max(max(self.dd_pred), max(self.dd_real, default=0))))
|
||||
|
||||
# Do not print instances outside z_max
|
||||
self.zz_gt = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
|
||||
for idx, xx in enumerate(dic_ann['xyz_real'])]
|
||||
self.zz_pred = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
|
||||
for idx, xx in enumerate(dic_ann['xyz_pred'])]
|
||||
self.dd_pred = dic_ann['dds_pred']
|
||||
self.dd_real = dic_ann['dds_real']
|
||||
|
||||
self.uv_heads = dic_ann['uv_heads']
|
||||
self.uv_shoulders = dic_ann['uv_shoulders']
|
||||
self.boxes = dic_ann['boxes']
|
||||
@ -97,11 +101,14 @@ class Printer:
|
||||
else:
|
||||
self.modes.append('stereo')
|
||||
|
||||
def factory_axes(self):
|
||||
def factory_axes(self, dic_out):
|
||||
"""Create axes for figures: front bird multi"""
|
||||
axes = []
|
||||
figures = []
|
||||
|
||||
# Process the annotation dictionary of monoloco
|
||||
self._process_results(dic_out)
|
||||
|
||||
# Initialize multi figure, resizing it for aesthetic proportion
|
||||
if 'multi' in self.output_types:
|
||||
assert 'bird' and 'front' not in self.output_types, \
|
||||
@ -160,9 +167,6 @@ class Printer:
|
||||
|
||||
def draw(self, figures, axes, dic_out, image):
|
||||
|
||||
# Process the annotation dictionary of monoloco
|
||||
self._process_results(dic_out)
|
||||
|
||||
# whether to include instances that don't match the ground-truth
|
||||
iterator = range(len(self.zz_pred)) if self.show_all else range(len(self.zz_gt))
|
||||
if not iterator:
|
||||
@ -171,9 +175,9 @@ class Printer:
|
||||
|
||||
# Draw the front figure
|
||||
number = dict(flag=False, num=97)
|
||||
if 'multi' in self.output_types:
|
||||
if any(xx in self.output_types for xx in ['front', 'multi']):
|
||||
number['flag'] = True # add numbers
|
||||
self.mpl_im0.set_data(image)
|
||||
self.mpl_im0.set_data(image)
|
||||
for idx in iterator:
|
||||
if any(xx in self.output_types for xx in ['front', 'multi']) and self.zz_pred[idx] > 0:
|
||||
self._draw_front(axes[0],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user