"""Inspired by Openpifpaf"""

import math

import torch
import torch.nn as nn
import numpy as np

from ..network import extract_labels, extract_labels_aux, extract_outputs


class AutoTuneMultiTaskLoss(torch.nn.Module):
    """Multi-task loss with learned homoscedastic uncertainty weighting:
    each task loss is scaled by 1 / (2 * sigma^2), and a log(sigma) penalty
    keeps the sigmas from growing unbounded (Kendall et al., CVPR 2018)."""

    def __init__(self, losses_tr, losses_val, lambdas, tasks):
        super().__init__()

        # Lambdas act as binary task masks here; the relative weighting
        # itself is learned through the log-sigmas below.
        assert all(lam in (0.0, 1.0) for lam in lambdas)
        self.losses = torch.nn.ModuleList(losses_tr)
        self.losses_val = losses_val
        self.lambdas = lambdas
        self.tasks = tasks
        # One learnable log(sigma) per task, initialized to zero (sigma = 1)
        self.log_sigmas = torch.nn.Parameter(
            torch.zeros((len(lambdas),), dtype=torch.float32), requires_grad=True)

    def forward(self, outputs, labels, phase='train'):
        assert phase in ('train', 'val')
        out = extract_outputs(outputs, tasks=self.tasks)
        gt_out = extract_labels(labels, tasks=self.tasks)
        # Scale each task loss by 1 / (2 * sigma^2)
        loss_values = [lam * l(o, g) / (2.0 * (log_sigma.exp() ** 2))
                       for lam, log_sigma, l, o, g
                       in zip(self.lambdas, self.log_sigmas, self.losses, out, gt_out)]

        # Regularizer: the sum of log-sigmas penalizes large uncertainties
        auto_reg = [log_sigma for log_sigma in self.log_sigmas]

        loss = sum(loss_values) + sum(auto_reg)
        if phase == 'val':
            loss_values_val = [l(o, g) for l, o, g in zip(self.losses_val, out, gt_out)]
            loss_values_val.extend([s.exp() for s in self.log_sigmas])
            return loss, loss_values_val
        return loss, loss_values
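

# Illustrative sketch (not part of the original module): the weighting applied
# in AutoTuneMultiTaskLoss.forward on two dummy task losses. With log_sigma = 0
# (sigma = 1) each loss is halved and the regularizer adds nothing.
def _example_autotune_weighting():
    log_sigmas = torch.zeros(2)  # one learnable log(sigma) per task
    raw_losses = [torch.tensor(1.5), torch.tensor(0.8)]
    weighted = [l / (2.0 * (s.exp() ** 2)) for l, s in zip(raw_losses, log_sigmas)]
    return sum(weighted) + log_sigmas.sum()  # tensor(1.1500)
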

class MultiTaskLoss(torch.nn.Module):
    """Multi-task loss with fixed per-task weights (lambdas)."""

    def __init__(self, losses_tr, losses_val, lambdas, tasks):
        super().__init__()

        self.losses = torch.nn.ModuleList(losses_tr)
        self.losses_val = losses_val
        self.lambdas = lambdas
        self.tasks = tasks
        # The auxiliary task uses its own label extraction
        self.flag_aux = len(self.tasks) == 1 and self.tasks[0] == 'aux'

    def forward(self, outputs, labels, phase='train'):
        assert phase in ('train', 'val')
        out = extract_outputs(outputs, tasks=self.tasks)
        if self.flag_aux:
            gt_out = extract_labels_aux(labels, tasks=self.tasks)
        else:
            gt_out = extract_labels(labels, tasks=self.tasks)
        # Weighted sum of the individual task losses
        loss_values = [lam * l(o, g) for lam, l, o, g in zip(self.lambdas, self.losses, out, gt_out)]
        loss = sum(loss_values)

        if phase == 'val':
            loss_values_val = [l(o, g) for l, o, g in zip(self.losses_val, out, gt_out)]
            return loss, loss_values_val
        return loss, loss_values
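

# Illustrative sketch (not part of the original module): with fixed lambdas the
# total is a plain weighted sum, so a zero lambda disables a task entirely.
def _example_fixed_weighting():
    lambdas = [1.0, 0.0]  # second task masked out
    raw_losses = [torch.tensor(1.5), torch.tensor(0.8)]
    return sum(lam * l for lam, l in zip(lambdas, raw_losses))  # tensor(1.5000)
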

class CompositeLoss(torch.nn.Module):
    """Build the per-task training and validation losses: Laplacian loss for
    distance ('d'), BCE for the auxiliary task ('aux'), L1 for everything else.
    Validation uses human-readable metrics (metres, degrees) instead."""

    def __init__(self, tasks):
        super().__init__()

        self.tasks = tasks
        self.multi_loss_tr = {task: (LaplacianLoss() if task == 'd'
                                     else (nn.BCEWithLogitsLoss() if task in ('aux',)
                                           else nn.L1Loss())) for task in tasks}

        self.multi_loss_val = {}
        for task in tasks:
            if task == 'd':
                loss = l1_loss_from_laplace
            elif task == 'ori':
                loss = angle_loss
            elif task in ('aux',):
                loss = nn.BCEWithLogitsLoss()
            else:
                loss = nn.L1Loss()
            self.multi_loss_val[task] = loss

    def forward(self):
        losses_tr = [self.multi_loss_tr[task] for task in self.tasks]
        losses_val = [self.multi_loss_val[task] for task in self.tasks]
        return losses_tr, losses_val
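

# Illustrative usage sketch (the task names follow the conventions above; the
# lambdas are assumptions): build the per-task losses, then wrap them in a
# multi-task criterion.
def _example_composite_wiring():
    tasks = ['d', 'ori']
    losses_tr, losses_val = CompositeLoss(tasks)()
    return MultiTaskLoss(losses_tr, losses_val, lambdas=[1.0, 1.0], tasks=tasks)
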

class LaplacianLoss(torch.nn.Module):
    """1D Laplacian negative log-likelihood with a scale that depends on the
    relative distance error"""

    def __init__(self, size_average=True, reduce=True, evaluate=False):
        super().__init__()
        self.size_average = size_average
        self.reduce = reduce
        self.evaluate = evaluate

    def laplacian_1d(self, mu_si, xx):
        """
        1D Laplacian loss, f(x | mu, b). The network outputs mu and the log-scale si;
        xx is the ground-truth distance.
        This supports backward().
        Inspired by
        https://github.com/naba89/RNN-Handwriting-Generation-Pytorch/blob/master/loss_functions.py
        """
        eps = 0.01  # To avoid 0/0 when there is no uncertainty
        mu, si = mu_si[:, 0:1], mu_si[:, 1:2]
        norm = 1 - mu / xx  # Relative error
        const = 2  # Constant offset; does not affect gradients

        term_a = torch.abs(norm) * torch.exp(-si) + eps
        term_b = si
        norm_bi = (np.mean(np.abs(norm.cpu().detach().numpy())),
                   np.mean(torch.exp(si).cpu().detach().numpy()))

        if self.evaluate:
            return norm_bi
        return term_a + term_b + const

    def forward(self, outputs, targets):
        values = self.laplacian_1d(outputs, targets)

        if not self.reduce or self.evaluate:
            return values
        if self.size_average:
            return torch.mean(values)
        return torch.sum(values)
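

# Illustrative numeric check (not part of the original module): with a perfect
# prediction (mu = x) and log-scale si = 0, the per-sample loss reduces to
# eps + const = 2.01.
def _example_laplacian_value():
    loss_fn = LaplacianLoss()
    mu_si = torch.tensor([[10.0, 0.0]])  # predicted [mu, log-scale]
    xx = torch.tensor([[10.0]])          # ground-truth distance
    return loss_fn(mu_si, xx)            # tensor(2.0100)
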

class GaussianLoss(torch.nn.Module):
    """1D Gaussian negative log-likelihood with a standard deviation that
    depends on the absolute distance error"""

    def __init__(self, device, size_average=True, reduce=True, evaluate=False):
        super().__init__()
        self.size_average = size_average
        self.reduce = reduce
        self.evaluate = evaluate
        self.device = device

    def gaussian_1d(self, mu_si, xx):
        """
        1D Gaussian loss, f(x | mu, sigma). The network outputs mu and sigma;
        xx is the ground-truth distance.
        This supports backward().
        Inspired by
        https://github.com/naba89/RNN-Handwriting-Generation-Pytorch/blob/master/loss_functions.py
        """
        mu, si = mu_si[:, 0:1], mu_si[:, 1:2]

        # Floor sigma at 0.1 to keep the loss finite. Equivalent to the original
        # torch.max against a 0.1 tensor, but device-agnostic instead of the
        # .cuda() call that broke on CPU.
        si = torch.clamp(si, min=0.1)
        norm = xx - mu
        term_a = (norm / si) ** 2 / 2
        term_b = torch.log(si * math.sqrt(2 * math.pi))

        norm_si = (np.mean(np.abs(norm.cpu().detach().numpy())),
                   np.mean(si.cpu().detach().numpy()))

        if self.evaluate:
            return norm_si

        return term_a + term_b

    def forward(self, outputs, targets):
        values = self.gaussian_1d(outputs, targets)

        if not self.reduce or self.evaluate:
            return values
        if self.size_average:
            return torch.mean(values)
        return torch.sum(values)
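

# Illustrative numeric check (not part of the original module): with mu = x and
# sigma = 1, the negative log-likelihood is log(sqrt(2 * pi)) ~= 0.9189.
def _example_gaussian_value():
    loss_fn = GaussianLoss(device=None)  # device kept only for API compatibility
    mu_si = torch.tensor([[10.0, 1.0]])  # predicted [mu, sigma]
    xx = torch.tensor([[10.0]])
    return loss_fn(mu_si, xx)            # tensor(0.9189)
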

def angle_loss(orient, gt_orient):
    """Mean absolute angular error in degrees. Only for evaluation"""
    angles = torch.atan2(orient[:, 0], orient[:, 1])
    gt_angles = torch.atan2(gt_orient[:, 0], gt_orient[:, 1])
    # atan2 returns angles in (-pi, pi]; wrap the difference back into that
    # range so errors across the +/-pi boundary are not overcounted
    diff = torch.atan2(torch.sin(angles - gt_angles), torch.cos(angles - gt_angles))
    loss = torch.mean(torch.abs(diff)) * 180 / math.pi
    return loss
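

# Illustrative check (not part of the original module): two orientations two
# degrees apart across the +/-pi boundary. Without wrapping, the naive
# |difference| would report ~358 degrees instead of ~2.
def _example_angle_wraparound():
    deg179 = math.radians(179.0)
    orient = torch.tensor([[math.sin(deg179), math.cos(deg179)]])       # +179 deg
    gt_orient = torch.tensor([[math.sin(-deg179), math.cos(-deg179)]])  # -179 deg
    return angle_loss(orient, gt_orient)  # tensor(2.0000)
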

def l1_loss_from_laplace(out, gt_out):
    """L1 distance error, ignoring the predicted log-scale column.
    Only for evaluation"""
    loss = torch.mean(torch.abs(out[:, 0:1] - gt_out))
    return loss
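

# Illustrative check (not part of the original module): only the first column
# (mu) of the [mu, log-scale] pairs enters the L1 metric.
def _example_l1_from_laplace():
    out = torch.tensor([[10.0, 0.3], [20.0, 0.7]])  # [mu, log-scale] per sample
    gt = torch.tensor([[12.0], [19.0]])
    return l1_loss_from_laplace(out, gt)  # tensor(1.5000)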