pylint
This commit is contained in:
parent
3c6ebe22c9
commit
224ee0c3cd
@ -57,8 +57,13 @@ class ActivityEvaluator:
|
||||
device = torch.device('cpu')
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
self.monoloco = Loco(model=args.model, net=args.net,
|
||||
device=device, n_dropout=args.n_dropout, p_dropout=args.dropout)
|
||||
self.monoloco = Loco(
|
||||
model=args.model,
|
||||
mode=args.mode,
|
||||
device=device,
|
||||
n_dropout=args.n_dropout,
|
||||
p_dropout=args.dropout)
|
||||
|
||||
self.all_pred = defaultdict(list)
|
||||
self.all_gt = defaultdict(list)
|
||||
assert args.dataset in ('collective', 'kitti')
|
||||
|
||||
@ -13,7 +13,7 @@ from collections import defaultdict
|
||||
from tabulate import tabulate
|
||||
|
||||
from ..utils import get_iou_matches, get_task_error, get_pixel_error, check_conditions, \
|
||||
get_difficulty, split_training, parse_ground_truth, get_iou_matches_matrix
|
||||
get_difficulty, split_training, parse_ground_truth, get_iou_matches_matrix, average, find_cluster
|
||||
from ..visuals import show_results, show_spread, show_task_error, show_box_plot
|
||||
|
||||
|
||||
@ -417,15 +417,6 @@ def add_true_negatives(err, cnt_gt):
|
||||
err['matched'] = 100 * matched / cnt_gt
|
||||
|
||||
|
||||
def find_cluster(dd, clusters):
|
||||
"""Find the correct cluster. Above the last cluster goes into "excluded (together with the ones from kitti cat"""
|
||||
|
||||
for idx, clst in enumerate(clusters[:-1]):
|
||||
if int(clst) < dd <= int(clusters[idx+1]):
|
||||
return clst
|
||||
return 'excluded'
|
||||
|
||||
|
||||
def extract_indices(idx_to_check, *args):
|
||||
"""
|
||||
Look if a given index j_gt is present in all the other series of indices (_, j)
|
||||
@ -448,11 +439,6 @@ def extract_indices(idx_to_check, *args):
|
||||
return all(checks), indices
|
||||
|
||||
|
||||
def average(my_list):
|
||||
"""calculate mean of a list"""
|
||||
return sum(my_list) / len(my_list)
|
||||
|
||||
|
||||
def filter_directories(main_dir, methods):
|
||||
for method in methods:
|
||||
dir_method = os.path.join(main_dir, method)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# pylint: disable=too-many-statements,too-many-branches,cyclic-import
|
||||
# pylint: disable=too-many-statements,too-many-branches
|
||||
|
||||
"""Joints Analysis: Supplementary material of MonStereo"""
|
||||
|
||||
@ -9,7 +9,7 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from .eval_kitti import find_cluster, average
|
||||
from ..utils import find_cluster, average
|
||||
from ..visuals.figures import get_distances
|
||||
|
||||
COCO_KEYPOINTS = [
|
||||
|
||||
@ -63,6 +63,7 @@ class GenerateKitti:
|
||||
self.baselines['mono'] = ['monoloco', 'geometric']
|
||||
self.monoloco = Loco(
|
||||
model=self.monoloco_checkpoint,
|
||||
mode='mono',
|
||||
net='monoloco',
|
||||
device=device,
|
||||
n_dropout=args.n_dropout,
|
||||
|
||||
@ -12,7 +12,7 @@ class CustomL1Loss(torch.nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, dic_norm, device, beta=1):
|
||||
super(CustomL1Loss, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.dic_norm = dic_norm
|
||||
self.device = device
|
||||
@ -60,7 +60,7 @@ class LaplacianLoss(torch.nn.Module):
|
||||
"""1D Gaussian with std depending on the absolute distance
|
||||
"""
|
||||
def __init__(self, size_average=True, reduce=True, evaluate=False):
|
||||
super(LaplacianLoss, self).__init__()
|
||||
super().__init__()
|
||||
self.size_average = size_average
|
||||
self.reduce = reduce
|
||||
self.evaluate = evaluate
|
||||
@ -101,7 +101,7 @@ class GaussianLoss(torch.nn.Module):
|
||||
"""1D Gaussian with std depending on the absolute distance
|
||||
"""
|
||||
def __init__(self, device, size_average=True, reduce=True, evaluate=False):
|
||||
super(GaussianLoss, self).__init__()
|
||||
super().__init__()
|
||||
self.size_average = size_average
|
||||
self.reduce = reduce
|
||||
self.evaluate = evaluate
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# pylint: disable=too-many-statements
|
||||
# pylint: disable=too-many-statements, too-many-branches
|
||||
|
||||
"""
|
||||
Loco super class for MonStereo, MonoLoco, MonoLoco++ nets.
|
||||
|
||||
@ -1,105 +0,0 @@
|
||||
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import torchvision
|
||||
import torch
|
||||
from PIL import Image, ImageFile
|
||||
from openpifpaf.network import nets
|
||||
from openpifpaf import decoder
|
||||
|
||||
from .process import image_transform
|
||||
|
||||
|
||||
class ImageList(torch.utils.data.Dataset):
|
||||
"""It defines transformations to apply to images and outputs of the dataloader"""
|
||||
def __init__(self, image_paths, scale):
|
||||
self.image_paths = image_paths
|
||||
self.scale = scale
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_path = self.image_paths[index]
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
with open(image_path, 'rb') as f:
|
||||
image = Image.open(f).convert('RGB')
|
||||
|
||||
if self.scale > 1.01 or self.scale < 0.99:
|
||||
image = torchvision.transforms.functional.resize(image,
|
||||
(round(self.scale * image.size[1]),
|
||||
round(self.scale * image.size[0])),
|
||||
interpolation=Image.BICUBIC)
|
||||
# PIL images are not iterables
|
||||
original_image = torchvision.transforms.functional.to_tensor(image) # 0-255 --> 0-1
|
||||
image = image_transform(image)
|
||||
|
||||
return image_path, original_image, image
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
|
||||
def factory_from_args(args):
|
||||
|
||||
# Merge the model_pifpaf argument
|
||||
if not args.checkpoint:
|
||||
args.checkpoint = 'resnet152' # Default model Resnet 152
|
||||
# glob
|
||||
if not args.webcam:
|
||||
if args.glob:
|
||||
args.images += glob.glob(args.glob)
|
||||
if not args.images:
|
||||
raise Exception("no image files given")
|
||||
|
||||
# add args.device
|
||||
args.device = torch.device('cpu')
|
||||
args.pin_memory = False
|
||||
if torch.cuda.is_available():
|
||||
args.device = torch.device('cuda')
|
||||
args.pin_memory = True
|
||||
|
||||
# Add num_workers
|
||||
args.loader_workers = 8
|
||||
|
||||
# Add visualization defaults
|
||||
args.figure_width = 10
|
||||
args.dpi_factor = 1.0
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class PifPaf:
|
||||
def __init__(self, args):
|
||||
"""Instanciate the mdodel"""
|
||||
factory_from_args(args)
|
||||
model_pifpaf, _ = nets.factory_from_args(args)
|
||||
model_pifpaf = model_pifpaf.to(args.device)
|
||||
self.processor = decoder.factory_from_args(args, model_pifpaf)
|
||||
self.keypoints_whole = []
|
||||
|
||||
# Scale the keypoints to the original image size for printing (if not webcam)
|
||||
if not args.webcam:
|
||||
self.scale_np = np.array([args.scale, args.scale, 1] * 17).reshape(17, 3)
|
||||
else:
|
||||
self.scale_np = np.array([1, 1, 1] * 17).reshape(17, 3)
|
||||
|
||||
def fields(self, processed_images):
|
||||
"""Encoder for pif and paf fields"""
|
||||
fields_batch = self.processor.fields(processed_images)
|
||||
return fields_batch
|
||||
|
||||
def forward(self, image, processed_image_cpu, fields):
|
||||
"""Decoder, from pif and paf fields to keypoints"""
|
||||
self.processor.set_cpu_image(image, processed_image_cpu)
|
||||
keypoint_sets, scores = self.processor.keypoint_sets(fields)
|
||||
|
||||
if keypoint_sets.size > 0:
|
||||
self.keypoints_whole.append(np.around((keypoint_sets / self.scale_np), 1)
|
||||
.reshape(keypoint_sets.shape[0], -1).tolist())
|
||||
|
||||
pifpaf_out = [
|
||||
{'keypoints': np.around(kps / self.scale_np, 1).reshape(-1).tolist(),
|
||||
'bbox': [np.min(kps[:, 0]) / self.scale_np[0, 0], np.min(kps[:, 1]) / self.scale_np[0, 0],
|
||||
np.max(kps[:, 0]) / self.scale_np[0, 0], np.max(kps[:, 1]) / self.scale_np[0, 0]]}
|
||||
for kps in keypoint_sets
|
||||
]
|
||||
return keypoint_sets, scores, pifpaf_out
|
||||
@ -78,9 +78,9 @@ def factory_for_gt(im_size, focal_length=5.7, name=None, path_gt=None):
|
||||
|
||||
# Without ground-truth-file
|
||||
elif im_size[0] / im_size[1] > 2.5: # KITTI default
|
||||
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] # Kitti calibration
|
||||
dic_gt = None
|
||||
logger.info("Using KITTI calibration matrix...")
|
||||
kk = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]] # Kitti calibration
|
||||
dic_gt = None
|
||||
logger.info("Using KITTI calibration matrix...")
|
||||
else: # nuScenes camera parameters
|
||||
kk = [
|
||||
[im_size[0]*focal_length/Sx, 0., im_size[0]/2],
|
||||
|
||||
@ -50,7 +50,7 @@ def download_checkpoints(args):
|
||||
dic_models = {'keypoints': pifpaf_model}
|
||||
if not os.path.exists(pifpaf_model):
|
||||
import gdown
|
||||
LOG.info(f'Downloading OpenPifPaf model in {torch_dir}')
|
||||
LOG.info('Downloading OpenPifPaf model in %s', torch_dir)
|
||||
gdown.download(OPENPIFPAF_MODEL, pifpaf_model, quiet=False)
|
||||
|
||||
if args.mode == 'keypoints':
|
||||
@ -74,7 +74,7 @@ def download_checkpoints(args):
|
||||
dic_models[args.mode] = model
|
||||
if not os.path.exists(model):
|
||||
import gdown
|
||||
LOG.info(f'Downloading model (modality: {args.mode}) in {torch_dir}')
|
||||
LOG.info('Downloading model in %s', torch_dir)
|
||||
gdown.download(path, model, quiet=False)
|
||||
return dic_models
|
||||
|
||||
@ -117,7 +117,7 @@ def factory_from_args(args):
|
||||
# Patch for stereo images with batch_size = 2
|
||||
if args.batch_size == 2 and not args.long_edge:
|
||||
args.long_edge = 1238
|
||||
LOG.info("Long-edge set to %i".format(args.long_edge))
|
||||
LOG.info("Long-edge set to %i", args.long_edge)
|
||||
|
||||
# Make default pifpaf argument
|
||||
args.force_complete_pose = True
|
||||
@ -165,7 +165,7 @@ def predict(args):
|
||||
|
||||
# unbatch (only for MonStereo)
|
||||
for idx, (pred, meta) in enumerate(zip(pred_batch, meta_batch)):
|
||||
LOG.info('batch %d: %s'.format(batch_i, meta['file_name']))
|
||||
LOG.info('batch %d: %s', batch_i, meta['file_name'])
|
||||
pred = [ann.inverse_transform(meta) for ann in pred]
|
||||
|
||||
# Load image and collect pifpaf results
|
||||
@ -219,7 +219,7 @@ def predict(args):
|
||||
|
||||
# Outputs
|
||||
factory_outputs(args, pifpaf_outs, dic_out, output_path, kk=kk)
|
||||
LOG.info('Image {}\n'.format(cnt) + '-' * 120)
|
||||
print(f'Image {cnt}\n' + '-' * 120)
|
||||
cnt += 1
|
||||
|
||||
|
||||
|
||||
@ -99,7 +99,7 @@ class Trainer:
|
||||
if not self.no_save:
|
||||
self.path_model = os.path.join(dir_out, name_out + '.pkl')
|
||||
self.logger = set_logger(os.path.join(dir_logs, name_out))
|
||||
self.logger.info(
|
||||
self.logger.info( # pylint: disable=logging-fstring-interpolation
|
||||
f'Training arguments: \ninput_file: {self.joints} \nmode: {self.mode} '
|
||||
f'\nlearning rate: {args.lr} \nbatch_size: {args.bs}'
|
||||
f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
|
||||
from .iou import get_iou_matches, reorder_matches, get_iou_matrix, get_iou_matches_matrix
|
||||
from .misc import get_task_error, get_pixel_error, append_cluster, open_annotations, make_new_directory, normalize_hwl
|
||||
from .misc import get_task_error, get_pixel_error, append_cluster, open_annotations, make_new_directory,\
|
||||
normalize_hwl, average
|
||||
from .kitti import check_conditions, get_difficulty, split_training, parse_ground_truth, get_calibration, \
|
||||
factory_basename, factory_file, get_category, read_and_rewrite
|
||||
factory_basename, factory_file, get_category, read_and_rewrite, find_cluster
|
||||
from .camera import xyz_from_distance, get_keypoints, pixel_to_camera, project_3d, open_image, correct_angle,\
|
||||
to_spherical, to_cartesian, back_correct_angles, project_to_pixels
|
||||
from .logs import set_logger
|
||||
|
||||
@ -266,3 +266,12 @@ def read_and_rewrite(path_orig, path_new):
|
||||
except FileNotFoundError:
|
||||
ff = open(path_new, "a+")
|
||||
ff.close()
|
||||
|
||||
|
||||
def find_cluster(dd, clusters):
|
||||
"""Find the correct cluster. Above the last cluster goes into "excluded (together with the ones from kitti cat"""
|
||||
|
||||
for idx, clst in enumerate(clusters[:-1]):
|
||||
if int(clst) < dd <= int(clusters[idx+1]):
|
||||
return clst
|
||||
return 'excluded'
|
||||
|
||||
@ -72,3 +72,8 @@ def normalize_hwl(lab):
|
||||
hwl_new = list((np.array(hwl) - np.array([AV_H, AV_W, AV_L])) / HLW_STD)
|
||||
lab_new = lab[0:4] + hwl_new + lab[7:]
|
||||
return lab_new
|
||||
|
||||
|
||||
def average(my_list):
|
||||
"""calculate mean of a list"""
|
||||
return sum(my_list) / len(my_list)
|
||||
|
||||
@ -1,122 +0,0 @@
|
||||
# pylint: disable=W0212
|
||||
"""
|
||||
Webcam demo application
|
||||
|
||||
Implementation adapted from https://github.com/vita-epfl/openpifpaf/blob/master/openpifpaf/webcam.py
|
||||
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
from ..visuals import Printer
|
||||
from ..network import PifPaf, MonoLoco
|
||||
from ..network.process import preprocess_pifpaf, factory_for_gt, image_transform
|
||||
|
||||
|
||||
def webcam(args):
|
||||
|
||||
# add args.device
|
||||
args.device = torch.device('cpu')
|
||||
if torch.cuda.is_available():
|
||||
args.device = torch.device('cuda')
|
||||
|
||||
# load models
|
||||
args.camera = True
|
||||
pifpaf = PifPaf(args)
|
||||
monoloco = MonoLoco(model=args.model, device=args.device)
|
||||
|
||||
# Start recording
|
||||
cam = cv2.VideoCapture(0)
|
||||
visualizer_monoloco = None
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
ret, frame = cam.read()
|
||||
image = cv2.resize(frame, None, fx=args.scale, fy=args.scale)
|
||||
height, width, _ = image.shape
|
||||
print('resized image size: {}'.format(image.shape))
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_image_cpu = image_transform(image.copy())
|
||||
processed_image = processed_image_cpu.contiguous().to(args.device, non_blocking=True)
|
||||
fields = pifpaf.fields(torch.unsqueeze(processed_image, 0))[0]
|
||||
_, _, pifpaf_out = pifpaf.forward(image, processed_image_cpu, fields)
|
||||
|
||||
if not ret:
|
||||
break
|
||||
key = cv2.waitKey(1)
|
||||
|
||||
if key % 256 == 27:
|
||||
# ESC pressed
|
||||
print("Escape hit, closing...")
|
||||
break
|
||||
pil_image = Image.fromarray(image)
|
||||
intrinsic_size = [xx * 1.3 for xx in pil_image.size]
|
||||
kk, dict_gt = factory_for_gt(intrinsic_size) # better intrinsics for mac camera
|
||||
if visualizer_monoloco is None: # it is, at the beginning
|
||||
visualizer_monoloco = VisualizerMonoloco(kk, args)(pil_image) # create it with the first image
|
||||
visualizer_monoloco.send(None)
|
||||
|
||||
boxes, keypoints = preprocess_pifpaf(pifpaf_out, (width, height))
|
||||
outputs, varss = monoloco.forward(keypoints, kk)
|
||||
dic_out = monoloco.post_process(outputs, varss, boxes, keypoints, kk, dict_gt)
|
||||
print(dic_out)
|
||||
visualizer_monoloco.send((pil_image, dic_out))
|
||||
|
||||
end = time.time()
|
||||
print("run-time: {:.2f} ms".format((end-start)*1000))
|
||||
|
||||
cam.release()
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
class VisualizerMonoloco:
|
||||
def __init__(self, kk, args, epistemic=False):
|
||||
self.kk = kk
|
||||
self.args = args
|
||||
self.z_max = args.z_max
|
||||
self.epistemic = epistemic
|
||||
self.output_types = args.output_types
|
||||
|
||||
def __call__(self, first_image, fig_width=4.0, **kwargs):
|
||||
if 'figsize' not in kwargs:
|
||||
kwargs['figsize'] = (fig_width, fig_width * first_image.size[0] / first_image.size[1])
|
||||
|
||||
printer = Printer(first_image, output_path="", kk=self.kk, output_types=self.output_types,
|
||||
z_max=self.z_max, epistemic=self.epistemic)
|
||||
figures, axes = printer.factory_axes()
|
||||
|
||||
for fig in figures:
|
||||
fig.show()
|
||||
|
||||
while True:
|
||||
image, dict_ann = yield
|
||||
while axes and (axes[-1] and axes[-1].patches): # for front -1==0, for bird/combined -1 == 1
|
||||
if axes[0]:
|
||||
del axes[0].patches[0]
|
||||
del axes[0].texts[0]
|
||||
if len(axes) == 2:
|
||||
del axes[1].patches[0]
|
||||
del axes[1].patches[0] # the one became the 0
|
||||
if len(axes[1].lines) > 2:
|
||||
del axes[1].lines[2]
|
||||
if axes[1].texts: # in case of no text
|
||||
del axes[1].texts[0]
|
||||
printer.draw(figures, axes, dict_ann, image)
|
||||
mypause(0.01)
|
||||
|
||||
|
||||
def mypause(interval):
|
||||
manager = plt._pylab_helpers.Gcf.get_active()
|
||||
if manager is not None:
|
||||
canvas = manager.canvas
|
||||
if canvas.figure.stale:
|
||||
canvas.draw_idle()
|
||||
canvas.start_event_loop(interval)
|
||||
else:
|
||||
time.sleep(interval)
|
||||
Loading…
Reference in New Issue
Block a user