monoloco/monstereo/network/pifpaf.py
2020-08-20 11:33:19 +02:00

103 lines
3.5 KiB
Python

import glob
import numpy as np
import torchvision
import torch
from PIL import Image, ImageFile
from openpifpaf.network import nets
from openpifpaf import decoder
from .process import image_transform
class ImageList(torch.utils.data.Dataset):
    """Dataset over a collection of image files.

    Each item is loaded from disk as an RGB PIL image, optionally rescaled by
    ``scale``, and returned both as a plain tensor and after the project's
    ``image_transform``.
    """

    def __init__(self, image_paths, scale):
        """Store the paths in sorted order together with the rescaling factor.

        Args:
            image_paths: iterable of image file paths.
            scale: multiplicative factor applied to the image height and width.
        """
        # Sort a copy: the caller's sequence is left untouched, and any
        # iterable (not only a list) is accepted.
        self.image_paths = sorted(image_paths)
        self.scale = scale

    def __getitem__(self, index):
        """Load the image at ``index``.

        Returns:
            A tuple ``(image_path, original_image, image)`` where
            ``original_image`` is the (possibly rescaled) image converted with
            ``to_tensor`` and ``image`` is the output of ``image_transform``.
        """
        image_path = self.image_paths[index]

        # Let PIL decode files whose data stream is cut short instead of raising.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with open(image_path, 'rb') as f:
            image = Image.open(f).convert('RGB')

        # Scales within 1% of 1.0 are treated as "no rescaling".
        if self.scale > 1.01 or self.scale < 0.99:
            # PIL's ``size`` is (width, height) while ``resize`` expects (height, width).
            target_size = (round(self.scale * image.size[1]),
                           round(self.scale * image.size[0]))
            image = torchvision.transforms.functional.resize(
                image, target_size, interpolation=Image.BICUBIC)

        # PIL images are not iterables
        original_image = torchvision.transforms.functional.to_tensor(image)  # 0-255 --> 0-1
        image = image_transform(image)

        return image_path, original_image, image

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)
def factory_from_args(args):
    """Complete the argument namespace with the defaults used by this module.

    ``args`` is modified in place and also returned. The function:

    * falls back to the ``'resnet152'`` checkpoint when none is given;
    * extends ``args.images`` with the files matching ``args.glob``;
    * sets ``args.device`` / ``args.pin_memory`` depending on CUDA availability;
    * sets ``args.loader_workers``, ``args.figure_width`` and ``args.dpi_factor``.

    Args:
        args: namespace exposing at least ``checkpoint``, ``glob`` and ``images``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if no image is left after the glob expansion.
    """
    # Merge the model_pifpaf argument
    if not args.checkpoint:
        args.checkpoint = 'resnet152'  # Default model Resnet 152

    # glob
    if args.glob:
        args.images += glob.glob(args.glob)
    if not args.images:
        # ValueError is still caught by callers that use ``except Exception``.
        raise ValueError("no image files given")

    # add args.device: memory is pinned only when a GPU is available
    cuda_available = torch.cuda.is_available()
    args.device = torch.device('cuda' if cuda_available else 'cpu')
    args.pin_memory = cuda_available

    # Add num_workers
    args.loader_workers = 8

    # Add visualization defaults
    args.figure_width = 10
    args.dpi_factor = 1.0

    return args
class PifPaf:
    """Thin wrapper around the openpifpaf network and decoder.

    ``fields`` runs the encoder on a batch of images, ``forward`` decodes the
    fields of one image into keypoints and rescales them by the inverse of
    ``args.scale``.
    """

    def __init__(self, args):
        """Instantiate the model.

        ``args`` is first completed by ``factory_from_args`` and then passed to
        the openpifpaf ``nets`` and ``decoder`` factories.
        """
        factory_from_args(args)
        model_pifpaf, _ = nets.factory_from_args(args)
        model_pifpaf = model_pifpaf.to(args.device)
        self.processor = decoder.factory_from_args(args, model_pifpaf)

        # Rescaled keypoints of every processed image, one entry per ``forward`` call
        # that produced at least one keypoint set.
        self.keypoints_whole = []

        # Scale the keypoints to the original image size for printing (if not webcam).
        # Every row is [scale, scale, 1]: the first two columns are rescaled,
        # the third one is left unchanged.
        self.scale_np = np.array([args.scale, args.scale, 1] * 17).reshape(17, 3)

    def fields(self, processed_images):
        """Encoder for pif and paf fields"""
        return self.processor.fields(processed_images)

    def forward(self, image, processed_image_cpu, fields):
        """Decoder, from pif and paf fields to keypoints.

        Args:
            image: image handed to the processor together with ``processed_image_cpu``.
            processed_image_cpu: preprocessed image handed to the processor.
            fields: fields of a single image, as produced by ``fields``.

        Returns:
            A tuple ``(keypoint_sets, scores, pifpaf_out)``. ``keypoint_sets`` and
            ``scores`` are returned as given by the processor (not rescaled).
            ``pifpaf_out`` is a list with one dictionary per keypoint set holding
            the flattened, rescaled and rounded ``'keypoints'`` and the rescaled
            ``'bbox'`` as ``[x_min, y_min, x_max, y_max]``.
        """
        self.processor.set_cpu_image(image, processed_image_cpu)
        keypoint_sets, scores = self.processor.keypoint_sets(fields)
        # presumably keypoint_sets is (n_instances, n_keypoints, 3) — TODO confirm
        # against the openpifpaf version in use.

        pifpaf_out = []
        if keypoint_sets.size > 0:
            # All rows of ``scale_np`` are identical, so a single row broadcasts
            # over any number of keypoints (not only the 17 rows of the table).
            scale_row = self.scale_np[0]
            xy_scale = self.scale_np[0, 0]

            # Rescale and round once for the whole batch, then reuse per instance.
            rescaled = np.around(keypoint_sets / scale_row, 1)
            self.keypoints_whole.append(
                rescaled.reshape(keypoint_sets.shape[0], -1).tolist())

            # The box spans the min/max of columns 0 and 1 over ALL keypoint rows.
            # NOTE(review): if undetected joints are encoded as zeros they pull
            # the box towards the origin — confirm this is intended.
            pifpaf_out = [
                {'keypoints': kps_rescaled.reshape(-1).tolist(),
                 'bbox': [np.min(kps[:, 0]) / xy_scale, np.min(kps[:, 1]) / xy_scale,
                          np.max(kps[:, 0]) / xy_scale, np.max(kps[:, 1]) / xy_scale]}
                for kps, kps_rescaled in zip(keypoint_sets, rescaled)
            ]

        return keypoint_sets, scores, pifpaf_out