
import glob

import numpy as np
import torchvision
import torch
from PIL import Image, ImageFile
from openpifpaf.network import nets
from openpifpaf import decoder

from .process import image_transform


class ImageList(torch.utils.data.Dataset):
    """Dataset over image files on disk.

    Yields, per item: (image_path, original image as a 0-1 float tensor,
    pifpaf-preprocessed tensor). Paths are iterated in sorted order.
    """

    def __init__(self, image_paths, scale):
        # sorted() makes a copy: do not mutate the caller's list in place.
        self.image_paths = sorted(image_paths)
        self.scale = scale

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # Tolerate truncated image files instead of raising mid-iteration.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with open(image_path, 'rb') as f:
            image = Image.open(f).convert('RGB')

        # Resize only when the scale deviates from 1 by more than ~1%,
        # avoiding a pointless resample at scale == 1.0.
        if self.scale > 1.01 or self.scale < 0.99:
            image = torchvision.transforms.functional.resize(
                image,
                (round(self.scale * image.size[1]), round(self.scale * image.size[0])),
                interpolation=Image.BICUBIC)
        # PIL images are not iterables: keep a tensor copy of the original
        original_image = torchvision.transforms.functional.to_tensor(image)  # 0-255 --> 0-1
        image = image_transform(image)

        return image_path, original_image, image

    def __len__(self):
        return len(self.image_paths)


def factory_from_args(args):
    """Fill in pifpaf/runtime defaults on the parsed CLI namespace.

    The namespace is mutated in place and also returned for convenience.
    Raises ValueError when no input images are given.
    """

    # Merge the model_pifpaf argument
    if not args.checkpoint:
        args.checkpoint = 'resnet152'  # Default model Resnet 152

    # Expand a glob pattern into the explicit image list
    if args.glob:
        args.images += glob.glob(args.glob)
    if not args.images:
        # ValueError is more specific than a bare Exception and is still
        # caught by any caller handling Exception.
        raise ValueError("no image files given")

    # add args.device
    args.device = torch.device('cpu')
    args.pin_memory = False
    if torch.cuda.is_available():
        args.device = torch.device('cuda')
        args.pin_memory = True

    # Add num_workers
    args.loader_workers = 8

    # Add visualization defaults
    args.figure_width = 10
    args.dpi_factor = 1.0

    return args


class PifPaf:
    """Thin wrapper around the openpifpaf network (encoder) and decoder."""

    def __init__(self, args):
        """Instantiate the model and decoder from CLI args."""
        factory_from_args(args)
        model_pifpaf, _ = nets.factory_from_args(args)
        model_pifpaf = model_pifpaf.to(args.device)
        self.processor = decoder.factory_from_args(args, model_pifpaf)
        self.keypoints_whole = []

        # Scale the keypoints to the original image size for printing (if not webcam)
        self.scale_np = np.array([args.scale, args.scale, 1] * 17).reshape(17, 3)

    def fields(self, processed_images):
        """Encoder for pif and paf fields"""
        fields_batch = self.processor.fields(processed_images)
        return fields_batch

    def forward(self, image, processed_image_cpu, fields):
        """Decoder, from pif and paf fields to keypoints.

        Returns (keypoint_sets, scores, pifpaf_out) where pifpaf_out holds
        keypoints and bounding boxes rescaled back to the original
        (pre-resize) image frame.
        """
        self.processor.set_cpu_image(image, processed_image_cpu)
        keypoint_sets, scores = self.processor.keypoint_sets(fields)

        if keypoint_sets.size > 0:
            self.keypoints_whole.append(np.around((keypoint_sets / self.scale_np), 1)
                                        .reshape(keypoint_sets.shape[0], -1).tolist())

        # scale_np[0, 0] is the scalar image scale; dividing min/max of the
        # x and y columns undoes the resize applied in ImageList.
        pifpaf_out = [
            {'keypoints': np.around(kps / self.scale_np, 1).reshape(-1).tolist(),
             'bbox': [np.min(kps[:, 0]) / self.scale_np[0, 0], np.min(kps[:, 1]) / self.scale_np[0, 0],
                      np.max(kps[:, 0]) / self.scale_np[0, 0], np.max(kps[:, 1]) / self.scale_np[0, 0]]}
            for kps in keypoint_sets
        ]
        return keypoint_sets, scores, pifpaf_out

# pylint: disable=too-many-statements, too-many-branches, undefined-loop-variable

import os
import json
from collections import defaultdict


import torch
from PIL import Image

from .visuals.printer import Printer
from .visuals.pifpaf_show import KeypointPainter, image_canvas
from .network import PifPaf, ImageList, Loco
from .network.process import factory_for_gt, preprocess_pifpaf


def predict(args):
    """Run pifpaf on args.images and, depending on args.mode, MonoLoco++
    ('mono') or MonStereo ('stereo') 3D localization on top of the poses.

    In stereo mode images are loaded in pairs (left, right) and the batch of
    two is unpacked inside the loop; variables bound on the first (left)
    iteration (image_path_l, pifpaf_outs, ...) are deliberately reused after
    the inner loop — hence the undefined-loop-variable pylint disable.
    """

    cnt = 0

    # Load Models
    pifpaf = PifPaf(args)
    assert args.mode in ('mono', 'stereo', 'pifpaf')

    if 'mono' in args.mode:
        monoloco = Loco(model=args.model, net='monoloco_pp',
                        device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)

    if 'stereo' in args.mode:
        monstereo = Loco(model=args.model, net='monstereo',
                         device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)

    # data: batch size 2 pairs left/right images in stereo mode
    data = ImageList(args.images, scale=args.scale)
    if args.mode == 'stereo':
        assert len(data.image_paths) % 2 == 0, "Odd number of images in a stereo setting"
        bs = 2
    else:
        bs = 1
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=bs, shuffle=False,
        pin_memory=args.pin_memory, num_workers=args.loader_workers)

    for idx, (image_paths, image_tensors, processed_images_cpu) in enumerate(data_loader):
        images = image_tensors.permute(0, 2, 3, 1)  # NCHW -> NHWC

        processed_images = processed_images_cpu.to(args.device, non_blocking=True)
        fields_batch = pifpaf.fields(processed_images)

        # unbatch stereo pair
        for ii, (image_path, image, processed_image_cpu, fields) in enumerate(zip(
                image_paths, images, processed_images_cpu, fields_batch)):

            # Output name is derived from the first (left) image of the batch.
            if args.output_directory is None:
                splits = os.path.split(image_paths[0])
                output_path = os.path.join(splits[0], 'out_' + splits[1])
            else:
                file_name = os.path.basename(image_paths[0])
                output_path = os.path.join(args.output_directory, 'out_' + file_name)
            print('image', idx, image_path, output_path)
            keypoint_sets, scores, pifpaf_out = pifpaf.forward(image, processed_image_cpu, fields)

            if ii == 0:
                pifpaf_outputs = [keypoint_sets, scores, pifpaf_out]  # keypoints_sets and scores for pifpaf printing
                images_outputs = [image]  # List of 1 or 2 elements with pifpaf tensor and monoloco original image
                pifpaf_outs = {'left': pifpaf_out}
                image_path_l = image_path
            else:
                pifpaf_outs['right'] = pifpaf_out

        if args.mode in ('stereo', 'mono'):
            # Extract calibration matrix and ground truth file if present
            with open(image_path_l, 'rb') as f:
                pil_image = Image.open(f).convert('RGB')
            images_outputs.append(pil_image)

            im_name = os.path.basename(image_path_l)
            # image is (H, W, 3); undo the loader's resize to get the
            # original (width, height).
            im_size = (float(image.size()[1] / args.scale), float(image.size()[0] / args.scale))  # Original
            kk, dic_gt = factory_for_gt(im_size, name=im_name, path_gt=args.path_gt)

            # Preprocess pifpaf outputs and run monoloco
            boxes, keypoints = preprocess_pifpaf(pifpaf_outs['left'], im_size, enlarge_boxes=False)

            if args.mode == 'mono':
                print("Prediction with MonoLoco++")
                dic_out = monoloco.forward(keypoints, kk)
                dic_out = monoloco.post_process(dic_out, boxes, keypoints, kk, dic_gt)

            else:
                print("Prediction with MonStereo")
                boxes_r, keypoints_r = preprocess_pifpaf(pifpaf_outs['right'], im_size)
                dic_out = monstereo.forward(keypoints, kk, keypoints_r=keypoints_r)
                dic_out = monstereo.post_process(dic_out, boxes, keypoints, kk, dic_gt)

        else:
            # pifpaf-only mode: no 3D output, no calibration
            dic_out = defaultdict(list)
            kk = None

        factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=dic_out, kk=kk)
        print('Image {}\n'.format(cnt) + '-' * 120)
        cnt += 1


def factory_outputs(args, images_outputs, output_path, pifpaf_outputs, dic_out=None, kk=None):
    """Output json files or images according to the choice.

    In 'pifpaf' mode it saves/paints raw keypoints and skeletons; otherwise
    it renders the monoloco/monstereo figures and dumps their json.
    """

    # Save json file
    if args.mode == 'pifpaf':
        # No copy needed: the elements are only read, never rebound.
        keypoint_sets, scores, pifpaf_out = pifpaf_outputs

        # Visualizer
        keypoint_painter = KeypointPainter(show_box=False)
        skeleton_painter = KeypointPainter(show_box=False, color_connections=True, markersize=1, linewidth=4)

        if 'json' in args.output_types and keypoint_sets.size > 0:
            with open(output_path + '.pifpaf.json', 'w') as f:
                json.dump(pifpaf_out, f)

        if 'keypoints' in args.output_types:
            with image_canvas(images_outputs[0],
                              output_path + '.keypoints.png',
                              show=args.show,
                              fig_width=args.figure_width,
                              dpi_factor=args.dpi_factor) as ax:
                keypoint_painter.keypoints(ax, keypoint_sets)

        if 'skeleton' in args.output_types:
            with image_canvas(images_outputs[0],
                              output_path + '.skeleton.png',
                              show=args.show,
                              fig_width=args.figure_width,
                              dpi_factor=args.dpi_factor) as ax:
                skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)

    else:
        if any(xx in args.output_types for xx in ('front', 'bird', 'multi')):
            print(output_path)
            if dic_out['boxes']:  # Only print in case of detections
                printer = Printer(images_outputs[1], output_path, kk, args)
                figures, axes = printer.factory_axes()
                printer.draw(figures, axes, dic_out, images_outputs[1])

        if 'json' in args.output_types:
            # os.path.join with a single argument was a no-op; concatenation
            # of the suffix is all that is needed.
            with open(output_path + '.monoloco.json', 'w') as ff:
                json.dump(dic_out, ff)