pylint
This commit is contained in:
parent
943b07f58c
commit
a2dc7f160d
@ -9,9 +9,7 @@ Good-names=xx,dd,zz,hh,ww,pp,kk,lr,w1,w2,w3,mm,im,uv,ax,COV_MIN,CONF_MIN
|
|||||||
|
|
||||||
[TYPECHECK]
|
[TYPECHECK]
|
||||||
|
|
||||||
disable=import-error,invalid-name,unused-variable,
|
disable=import-error,invalid-name,unused-variable,E1102,missing-docstring,useless-object-inheritance,duplicate-code,too-many-arguments,too-many-instance-attributes,too-many-locals,too-few-public-methods,arguments-differ,logging-format-interpolation,import-outside-toplevel
|
||||||
E1102,missing-docstring,useless-object-inheritance,duplicate-code,too-many-arguments,
|
|
||||||
too-many-instance-attributes,too-many-locals,too-few-public-methods,arguments-differ,logging-format-interpolation
|
|
||||||
|
|
||||||
|
|
||||||
# List of members which are set dynamically and missed by pylint inference
|
# List of members which are set dynamically and missed by pylint inference
|
||||||
|
|||||||
@ -124,15 +124,8 @@ def show_social(args, image_t, output_path, annotations, dic_out):
|
|||||||
else 'deepskyblue'
|
else 'deepskyblue'
|
||||||
for idx, _ in enumerate(dic_out['xyz_pred'])]
|
for idx, _ in enumerate(dic_out['xyz_pred'])]
|
||||||
|
|
||||||
|
# Draw keypoints and orientation
|
||||||
if 'front' in args.output_types:
|
if 'front' in args.output_types:
|
||||||
|
|
||||||
# Resize back the tensor image to its original dimensions
|
|
||||||
# if not 0.99 < args.scale < 1.01:
|
|
||||||
# size = (round(image_t.shape[0] / args.scale), round(image_t.shape[1] / args.scale)) # height width
|
|
||||||
# image_t = image_t.permute(2, 0, 1).unsqueeze(0) # batch x channels x height x width
|
|
||||||
# image_t = F.interpolate(image_t, size=size).squeeze().permute(1, 2, 0)
|
|
||||||
|
|
||||||
# Draw keypoints and orientation
|
|
||||||
keypoint_sets, scores = get_pifpaf_outputs(annotations)
|
keypoint_sets, scores = get_pifpaf_outputs(annotations)
|
||||||
uv_centers = dic_out['uv_heads']
|
uv_centers = dic_out['uv_heads']
|
||||||
sizes = [abs(dic_out['uv_heads'][idx][1] - uv_s[1]) / 1.5 for idx, uv_s in
|
sizes = [abs(dic_out['uv_heads'][idx][1] - uv_s[1]) / 1.5 for idx, uv_s in
|
||||||
@ -155,14 +148,13 @@ def show_social(args, image_t, output_path, annotations, dic_out):
|
|||||||
|
|
||||||
|
|
||||||
def get_pifpaf_outputs(annotations):
|
def get_pifpaf_outputs(annotations):
|
||||||
|
# TODO extract direct from predictions with pifpaf 0.11+
|
||||||
"""Extract keypoints sets and scores from output dictionary"""
|
"""Extract keypoints sets and scores from output dictionary"""
|
||||||
if not annotations:
|
if not annotations:
|
||||||
return [], []
|
return [], []
|
||||||
keypoints_sets = np.array([dic['keypoints'] for dic in annotations]).reshape(-1, 17, 3)
|
keypoints_sets = np.array([dic['keypoints'] for dic in annotations]).reshape((-1, 17, 3))
|
||||||
score_weights = np.ones((keypoints_sets.shape[0], 17))
|
score_weights = np.ones((keypoints_sets.shape[0], 17))
|
||||||
score_weights[:, 3] = 3.0
|
score_weights[:, 3] = 3.0
|
||||||
# score_weights[:, 5:] = 0.1
|
|
||||||
# score_weights[:, -2:] = 0.0 # ears are not annotated
|
|
||||||
score_weights /= np.sum(score_weights[0, :])
|
score_weights /= np.sum(score_weights[0, :])
|
||||||
kps_scores = keypoints_sets[:, :, 2]
|
kps_scores = keypoints_sets[:, :, 2]
|
||||||
ordered_kps_scores = np.sort(kps_scores, axis=1)[:, ::-1]
|
ordered_kps_scores = np.sort(kps_scores, axis=1)[:, ::-1]
|
||||||
|
|||||||
@ -60,10 +60,12 @@ class EvalKitti:
|
|||||||
self.path_results = os.path.join(self.dir_logs, 'eval-' + now_time + '.json')
|
self.path_results = os.path.join(self.dir_logs, 'eval-' + now_time + '.json')
|
||||||
|
|
||||||
# Set thresholds for comparable recalls
|
# Set thresholds for comparable recalls
|
||||||
self.dic_thresh_iou = {method: (self.thresh_iou_monoloco if method in self.OUR_METHODS else self.thresh_iou_base)
|
self.dic_thresh_iou = {method: (self.thresh_iou_monoloco if method in self.OUR_METHODS
|
||||||
for method in self.methods}
|
else self.thresh_iou_base)
|
||||||
self.dic_thresh_conf = {method: (self.thresh_conf_monoloco if method in self.OUR_METHODS else self.thresh_conf_base)
|
for method in self.methods}
|
||||||
for method in self.methods}
|
self.dic_thresh_conf = {method: (self.thresh_conf_monoloco if method in self.OUR_METHODS
|
||||||
|
else self.thresh_conf_base)
|
||||||
|
for method in self.methods}
|
||||||
|
|
||||||
# Set thresholds to obtain comparable recall
|
# Set thresholds to obtain comparable recall
|
||||||
self.dic_thresh_conf['monopsr'] += 0.4
|
self.dic_thresh_conf['monopsr'] += 0.4
|
||||||
@ -108,7 +110,7 @@ class EvalKitti:
|
|||||||
methods_out = defaultdict(tuple) # Save all methods for comparison
|
methods_out = defaultdict(tuple) # Save all methods for comparison
|
||||||
|
|
||||||
# Count ground_truth:
|
# Count ground_truth:
|
||||||
boxes_gt, ys, truncs_gt, occs_gt = out_gt
|
boxes_gt, ys, truncs_gt, occs_gt = out_gt # pylint: disable=unbalanced-tuple-unpacking
|
||||||
for idx, box in enumerate(boxes_gt):
|
for idx, box in enumerate(boxes_gt):
|
||||||
mode = get_difficulty(box, truncs_gt[idx], occs_gt[idx])
|
mode = get_difficulty(box, truncs_gt[idx], occs_gt[idx])
|
||||||
self.cnt_gt[mode] += 1
|
self.cnt_gt[mode] += 1
|
||||||
@ -371,7 +373,7 @@ class EvalKitti:
|
|||||||
self.name = name
|
self.name = name
|
||||||
# Iterate over each line of the gt file and save box location and distances
|
# Iterate over each line of the gt file and save box location and distances
|
||||||
out_gt = parse_ground_truth(path_gt, 'pedestrian')
|
out_gt = parse_ground_truth(path_gt, 'pedestrian')
|
||||||
boxes_gt, ys, truncs_gt, occs_gt = out_gt
|
boxes_gt, ys, truncs_gt, occs_gt = out_gt # pylint: disable=unbalanced-tuple-unpacking
|
||||||
for label in ys:
|
for label in ys:
|
||||||
heights.append(label[4])
|
heights.append(label[4])
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
|
# pylint: disable=too-many-statements,too-many-branches,cyclic-import
|
||||||
# pylint: disable=too-many-statements,cyclic-import, too-many-branches
|
|
||||||
|
|
||||||
"""Joints Analysis: Supplementary material of MonStereo"""
|
"""Joints Analysis: Supplementary material of MonStereo"""
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
"""
|
"""
|
||||||
Run MonoLoco/MonStereo and converts annotations into KITTI format
|
Run MonoLoco/MonStereo and converts annotations into KITTI format
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
|||||||
@ -29,7 +29,7 @@ def get_reid_features(reid_net, boxes, boxes_r, path_image, path_image_r):
|
|||||||
|
|
||||||
class ReID(object):
|
class ReID(object):
|
||||||
def __init__(self, weights_path, device, num_classes=751, height=256, width=128):
|
def __init__(self, weights_path, device, num_classes=751, height=256, width=128):
|
||||||
super(ReID, self).__init__()
|
super().__init__()
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ class ReID(object):
|
|||||||
|
|
||||||
class ResNet50(nn.Module):
|
class ResNet50(nn.Module):
|
||||||
def __init__(self, num_classes, loss):
|
def __init__(self, num_classes, loss):
|
||||||
super(ResNet50, self).__init__()
|
super().__init__()
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
resnet50 = torchvision.models.resnet50(pretrained=True)
|
resnet50 = torchvision.models.resnet50(pretrained=True)
|
||||||
self.base = nn.Sequential(*list(resnet50.children())[:-2])
|
self.base = nn.Sequential(*list(resnet50.children())[:-2])
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from .net import Loco
|
from .net import Loco
|
||||||
from .pifpaf import PifPaf, ImageList
|
|
||||||
from .process import unnormalize_bi, extract_outputs, extract_labels, extract_labels_aux
|
from .process import unnormalize_bi, extract_outputs, extract_labels, extract_labels_aux
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||||||
class MonStereoModel(nn.Module):
|
class MonStereoModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_size, output_size=2, linear_size=512, p_dropout=0.2, num_stage=3, device='cuda'):
|
def __init__(self, input_size, output_size=2, linear_size=512, p_dropout=0.2, num_stage=3, device='cuda'):
|
||||||
super(MonStereoModel, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_stage = num_stage
|
self.num_stage = num_stage
|
||||||
self.stereo_size = input_size
|
self.stereo_size = input_size
|
||||||
@ -73,7 +73,7 @@ class MonStereoModel(nn.Module):
|
|||||||
|
|
||||||
class MyLinearSimple(nn.Module):
|
class MyLinearSimple(nn.Module):
|
||||||
def __init__(self, linear_size, p_dropout=0.5):
|
def __init__(self, linear_size, p_dropout=0.5):
|
||||||
super(MyLinearSimple, self).__init__()
|
super().__init__()
|
||||||
self.l_size = linear_size
|
self.l_size = linear_size
|
||||||
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
@ -109,7 +109,7 @@ class MonolocoModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_size, output_size=2, linear_size=256, p_dropout=0.2, num_stage=3):
|
def __init__(self, input_size, output_size=2, linear_size=256, p_dropout=0.2, num_stage=3):
|
||||||
super(MonolocoModel, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
@ -147,7 +147,7 @@ class MonolocoModel(nn.Module):
|
|||||||
|
|
||||||
class MyLinear(nn.Module):
|
class MyLinear(nn.Module):
|
||||||
def __init__(self, linear_size, p_dropout=0.5):
|
def __init__(self, linear_size, p_dropout=0.5):
|
||||||
super(MyLinear, self).__init__()
|
super().__init__()
|
||||||
self.l_size = linear_size
|
self.l_size = linear_size
|
||||||
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|||||||
@ -1,97 +0,0 @@
|
|||||||
|
|
||||||
import glob
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torchvision
|
|
||||||
import torch
|
|
||||||
from PIL import Image, ImageFile
|
|
||||||
from openpifpaf.network import nets
|
|
||||||
from openpifpaf import decoder
|
|
||||||
|
|
||||||
from .process import image_transform
|
|
||||||
|
|
||||||
|
|
||||||
class ImageList(torch.utils.data.Dataset):
|
|
||||||
"""It defines transformations to apply to images and outputs of the dataloader"""
|
|
||||||
def __init__(self, image_paths, scale):
|
|
||||||
self.image_paths = image_paths
|
|
||||||
self.image_paths.sort()
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
image_path = self.image_paths[index]
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
||||||
with open(image_path, 'rb') as f:
|
|
||||||
image = Image.open(f).convert('RGB')
|
|
||||||
|
|
||||||
# PIL images are not iterables
|
|
||||||
original_image = torchvision.transforms.functional.to_tensor(image) # 0-255 --> 0-1
|
|
||||||
image = image_transform(image)
|
|
||||||
|
|
||||||
return image_path, original_image, image
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.image_paths)
|
|
||||||
|
|
||||||
|
|
||||||
def factory_from_args(args):
|
|
||||||
|
|
||||||
# Merge the model_pifpaf argument
|
|
||||||
if not args.checkpoint:
|
|
||||||
args.checkpoint = 'resnet152' # Default model Resnet 152
|
|
||||||
# glob
|
|
||||||
if args.glob:
|
|
||||||
args.images += glob.glob(args.glob)
|
|
||||||
if not args.images:
|
|
||||||
raise Exception("no image files given")
|
|
||||||
|
|
||||||
# add args.device
|
|
||||||
args.device = torch.device('cpu')
|
|
||||||
args.pin_memory = False
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
args.device = torch.device('cuda')
|
|
||||||
args.pin_memory = True
|
|
||||||
|
|
||||||
# Add num_workers
|
|
||||||
args.loader_workers = 8
|
|
||||||
|
|
||||||
# Add visualization defaults
|
|
||||||
args.figure_width = 10
|
|
||||||
args.dpi_factor = 1.0
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
class PifPaf:
|
|
||||||
def __init__(self, args):
|
|
||||||
"""Instanciate the mdodel"""
|
|
||||||
factory_from_args(args)
|
|
||||||
model_pifpaf, _ = nets.factory_from_args(args)
|
|
||||||
model_pifpaf = model_pifpaf.to(args.device)
|
|
||||||
self.processor = decoder.factory_from_args(args, model_pifpaf)
|
|
||||||
self.keypoints_whole = []
|
|
||||||
|
|
||||||
# Scale the keypoints to the original image size for printing (if not webcam)
|
|
||||||
self.scale_np = np.array([args.scale, args.scale, 1] * 17).reshape(17, 3)
|
|
||||||
|
|
||||||
def fields(self, processed_images):
|
|
||||||
"""Encoder for pif and paf fields"""
|
|
||||||
fields_batch = self.processor.fields(processed_images)
|
|
||||||
return fields_batch
|
|
||||||
|
|
||||||
def forward(self, image, processed_image_cpu, fields):
|
|
||||||
"""Decoder, from pif and paf fields to keypoints"""
|
|
||||||
self.processor.set_cpu_image(image, processed_image_cpu)
|
|
||||||
keypoint_sets, scores = self.processor.keypoint_sets(fields)
|
|
||||||
|
|
||||||
if keypoint_sets.size > 0:
|
|
||||||
self.keypoints_whole.append(np.around((keypoint_sets / self.scale_np), 1)
|
|
||||||
.reshape(keypoint_sets.shape[0], -1).tolist())
|
|
||||||
|
|
||||||
pifpaf_out = [
|
|
||||||
{'keypoints': np.around(kps / self.scale_np, 1).reshape(-1).tolist(),
|
|
||||||
'bbox': [np.min(kps[:, 0]) / self.scale_np[0, 0], np.min(kps[:, 1]) / self.scale_np[0, 0],
|
|
||||||
np.max(kps[:, 0]) / self.scale_np[0, 0], np.max(kps[:, 1]) / self.scale_np[0, 0]]}
|
|
||||||
for kps in keypoint_sets
|
|
||||||
]
|
|
||||||
return keypoint_sets, scores, pifpaf_out
|
|
||||||
@ -10,7 +10,6 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import PIL
|
import PIL
|
||||||
from PIL import Image
|
|
||||||
import openpifpaf
|
import openpifpaf
|
||||||
import openpifpaf.datasets as datasets
|
import openpifpaf.datasets as datasets
|
||||||
from openpifpaf.predict import processor_factory, preprocess_factory
|
from openpifpaf.predict import processor_factory, preprocess_factory
|
||||||
@ -132,8 +131,7 @@ def predict(args):
|
|||||||
if args.net == 'monoloco_pp':
|
if args.net == 'monoloco_pp':
|
||||||
print("Prediction with MonoLoco++")
|
print("Prediction with MonoLoco++")
|
||||||
dic_out = net.forward(keypoints, kk)
|
dic_out = net.forward(keypoints, kk)
|
||||||
reorder = False if args.social_distance else True
|
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt, reorder=not args.social_distance)
|
||||||
dic_out = net.post_process(dic_out, boxes, keypoints, kk, dic_gt, reorder=reorder)
|
|
||||||
|
|
||||||
if args.social_distance:
|
if args.social_distance:
|
||||||
show_social(args, cpu_image, output_path, pifpaf_out, dic_out)
|
show_social(args, cpu_image, output_path, pifpaf_out, dic_out)
|
||||||
@ -168,7 +166,7 @@ def factory_outputs(args, annotation_painter, cpu_image, output_path, pred, dic_
|
|||||||
if dic_out['boxes']: # Only print in case of detections
|
if dic_out['boxes']: # Only print in case of detections
|
||||||
printer = Printer(cpu_image, output_path, kk, args)
|
printer = Printer(cpu_image, output_path, kk, args)
|
||||||
figures, axes = printer.factory_axes(dic_out)
|
figures, axes = printer.factory_axes(dic_out)
|
||||||
printer.draw(figures, axes, dic_out, cpu_image)
|
printer.draw(figures, axes, cpu_image)
|
||||||
|
|
||||||
if 'json' in args.output_types:
|
if 'json' in args.output_types:
|
||||||
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:
|
with open(os.path.join(output_path + '.monoloco.json'), 'w') as ff:
|
||||||
|
|||||||
@ -93,7 +93,9 @@ class PreprocessKitti:
|
|||||||
category = 'pedestrian'
|
category = 'pedestrian'
|
||||||
|
|
||||||
# Extract ground truth
|
# Extract ground truth
|
||||||
boxes_gt, ys, _, _ = parse_ground_truth(path_gt, category=category, spherical=True)
|
boxes_gt, ys, _, _ = parse_ground_truth(path_gt, # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
category=category,
|
||||||
|
spherical=True)
|
||||||
cnt_gt[phase] += len(boxes_gt)
|
cnt_gt[phase] += len(boxes_gt)
|
||||||
cnt_files += 1
|
cnt_files += 1
|
||||||
cnt_files_ped += min(len(boxes_gt), 1) # if no boxes 0 else 1
|
cnt_files_ped += min(len(boxes_gt), 1) # if no boxes 0 else 1
|
||||||
|
|||||||
@ -136,7 +136,7 @@ def main():
|
|||||||
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
||||||
monocular=args.monocular, dropout=args.dropout,
|
monocular=args.monocular, dropout=args.dropout,
|
||||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||||
hyp_tuning.train()
|
hyp_tuning.train(args)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
from .train import Trainer
|
from .train import Trainer
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class HypTuning:
|
|||||||
# plt.hist(self.lr_list, bins=50)
|
# plt.hist(self.lr_list, bins=50)
|
||||||
# plt.show()
|
# plt.show()
|
||||||
|
|
||||||
def train(self):
|
def train(self, args):
|
||||||
"""Train multiple times using log-space random search"""
|
"""Train multiple times using log-space random search"""
|
||||||
|
|
||||||
best_acc_val = 20
|
best_acc_val = 20
|
||||||
@ -76,10 +76,7 @@ class HypTuning:
|
|||||||
hidden_size = self.hidden_list[idx]
|
hidden_size = self.hidden_list[idx]
|
||||||
n_stage = self.n_stage_list[idx]
|
n_stage = self.n_stage_list[idx]
|
||||||
|
|
||||||
training = Trainer(joints=self.joints, epochs=self.num_epochs,
|
training = Trainer(args)
|
||||||
bs=bs, monocular=self.monocular, dropout=self.dropout, lr=lr, sched_step=sched_step,
|
|
||||||
sched_gamma=sched_gamma, hidden_size=hidden_size, n_stage=n_stage,
|
|
||||||
save=False, print_loss=False, r_seed=self.r_seed)
|
|
||||||
|
|
||||||
best_epoch = training.train()
|
best_epoch = training.train()
|
||||||
dic_err, model = training.evaluate()
|
dic_err, model = training.evaluate()
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class AutoTuneMultiTaskLoss(torch.nn.Module):
|
|||||||
loss_values = [lam * l(o, g) / (2.0 * (log_sigma.exp() ** 2))
|
loss_values = [lam * l(o, g) / (2.0 * (log_sigma.exp() ** 2))
|
||||||
for lam, log_sigma, l, o, g in zip(self.lambdas, self.log_sigmas, self.losses, out, gt_out)]
|
for lam, log_sigma, l, o, g in zip(self.lambdas, self.log_sigmas, self.losses, out, gt_out)]
|
||||||
|
|
||||||
auto_reg = [log_sigma for log_sigma in self.log_sigmas]
|
auto_reg = [log_sigma for log_sigma in self.log_sigmas] # pylint: disable=unnecessary-comprehension
|
||||||
|
|
||||||
loss = sum(loss_values) + sum(auto_reg)
|
loss = sum(loss_values) + sum(auto_reg)
|
||||||
if phase == 'val':
|
if phase == 'val':
|
||||||
@ -70,7 +70,7 @@ class MultiTaskLoss(torch.nn.Module):
|
|||||||
class CompositeLoss(torch.nn.Module):
|
class CompositeLoss(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, tasks):
|
def __init__(self, tasks):
|
||||||
super(CompositeLoss, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.tasks = tasks
|
self.tasks = tasks
|
||||||
self.multi_loss_tr = {task: (LaplacianLoss() if task == 'd'
|
self.multi_loss_tr = {task: (LaplacianLoss() if task == 'd'
|
||||||
@ -98,7 +98,7 @@ class CompositeLoss(torch.nn.Module):
|
|||||||
class LaplacianLoss(torch.nn.Module):
|
class LaplacianLoss(torch.nn.Module):
|
||||||
"""1D Gaussian with std depending on the absolute distance"""
|
"""1D Gaussian with std depending on the absolute distance"""
|
||||||
def __init__(self, size_average=True, reduce=True, evaluate=False):
|
def __init__(self, size_average=True, reduce=True, evaluate=False):
|
||||||
super(LaplacianLoss, self).__init__()
|
super().__init__()
|
||||||
self.size_average = size_average
|
self.size_average = size_average
|
||||||
self.reduce = reduce
|
self.reduce = reduce
|
||||||
self.evaluate = evaluate
|
self.evaluate = evaluate
|
||||||
@ -140,7 +140,7 @@ class GaussianLoss(torch.nn.Module):
|
|||||||
"""1D Gaussian with std depending on the absolute distance
|
"""1D Gaussian with std depending on the absolute distance
|
||||||
"""
|
"""
|
||||||
def __init__(self, device, size_average=True, reduce=True, evaluate=False):
|
def __init__(self, device, size_average=True, reduce=True, evaluate=False):
|
||||||
super(GaussianLoss, self).__init__()
|
super().__init__()
|
||||||
self.size_average = size_average
|
self.size_average = size_average
|
||||||
self.reduce = reduce
|
self.reduce = reduce
|
||||||
self.evaluate = evaluate
|
self.evaluate = evaluate
|
||||||
|
|||||||
@ -165,7 +165,7 @@ class Printer:
|
|||||||
axes.append(ax1)
|
axes.append(ax1)
|
||||||
return figures, axes
|
return figures, axes
|
||||||
|
|
||||||
def draw(self, figures, axes, dic_out, image):
|
def draw(self, figures, axes, image):
|
||||||
|
|
||||||
# whether to include instances that don't match the ground-truth
|
# whether to include instances that don't match the ground-truth
|
||||||
iterator = range(len(self.zz_pred)) if self.show_all else range(len(self.zz_gt))
|
iterator = range(len(self.zz_pred)) if self.show_all else range(len(self.zz_gt))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user