# monoloco/monstereo/train/losses.py
"""Inspired by Openpifpaf"""
import math
import torch
import torch.nn as nn
import numpy as np
from ..network import extract_labels, extract_labels_aux, extract_outputs


class AutoTuneMultiTaskLoss(torch.nn.Module):

    def __init__(self, losses_tr, losses_val, lambdas, tasks):
        super().__init__()
        assert all(lam in (0.0, 1.0) for lam in lambdas)
        self.losses = torch.nn.ModuleList(losses_tr)
        self.losses_val = losses_val
        self.lambdas = lambdas
        self.tasks = tasks
        # One learnable log(sigma) per task, initialized to zero (sigma = 1)
        self.log_sigmas = torch.nn.Parameter(
            torch.zeros((len(lambdas),), dtype=torch.float32), requires_grad=True)

    def forward(self, outputs, labels, phase='train'):
        assert phase in ('train', 'val')
        out = extract_outputs(outputs, tasks=self.tasks)
        gt_out = extract_labels(labels, tasks=self.tasks)
        # Each task loss is scaled by its learned uncertainty: L_i / (2 * sigma_i^2)
        loss_values = [lam * l(o, g) / (2.0 * (log_sigma.exp() ** 2))
                       for lam, log_sigma, l, o, g
                       in zip(self.lambdas, self.log_sigmas, self.losses, out, gt_out)]
        # Regularizer: the sum of log(sigma_i) keeps the sigmas from growing
        # without bound just to shrink the weighted losses
        auto_reg = [log_sigma for log_sigma in self.log_sigmas]
        loss = sum(loss_values) + sum(auto_reg)
        if phase == 'val':
            loss_values_val = [l(o, g) for l, o, g in zip(self.losses_val, out, gt_out)]
            loss_values_val.extend([s.exp() for s in self.log_sigmas])
            return loss, loss_values_val
        return loss, loss_values
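

# The weighting above appears to follow the homoscedastic-uncertainty scheme of
# Kendall, Gal and Cipolla (CVPR 2018): with s_i = log(sigma_i),
#     loss = sum_i lambda_i * L_i / (2 * sigma_i**2) + sum_i s_i,
# so a task whose learned sigma_i grows is down-weighted, while the sum of
# log-sigmas prevents the network from inflating every sigma to zero out the loss.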


class MultiTaskLoss(torch.nn.Module):

    def __init__(self, losses_tr, losses_val, lambdas, tasks):
        super().__init__()
        self.losses = torch.nn.ModuleList(losses_tr)
        self.losses_val = losses_val
        self.lambdas = lambdas
        self.tasks = tasks
        self.flag_aux = len(self.tasks) == 1 and self.tasks[0] == 'aux'

    def forward(self, outputs, labels, phase='train'):
        assert phase in ('train', 'val')
        out = extract_outputs(outputs, tasks=self.tasks)
        if self.flag_aux:
            gt_out = extract_labels_aux(labels, tasks=self.tasks)
        else:
            gt_out = extract_labels(labels, tasks=self.tasks)
        loss_values = [lam * l(o, g) for lam, l, o, g in zip(self.lambdas, self.losses, out, gt_out)]
        loss = sum(loss_values)
        if phase == 'val':
            loss_values_val = [l(o, g) for l, o, g in zip(self.losses_val, out, gt_out)]
            return loss, loss_values_val
        return loss, loss_values
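

# Unlike AutoTuneMultiTaskLoss, the weights here are the fixed lambdas supplied
# by the caller; nothing about the task weighting is learned.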


class CompositeLoss(torch.nn.Module):

    def __init__(self, tasks):
        super(CompositeLoss, self).__init__()
        self.tasks = tasks

        # Training losses: Laplacian NLL for the distance task, binary
        # cross-entropy for the auxiliary task, L1 for everything else
        self.multi_loss_tr = {task: (LaplacianLoss() if task == 'd'
                                     else (nn.BCEWithLogitsLoss() if task in ('aux',)
                                           else nn.L1Loss()))
                              for task in tasks}

        # Validation losses: interpretable metrics for monitoring
        self.multi_loss_val = {}
        for task in tasks:
            if task == 'd':
                loss = l1_loss_from_laplace
            elif task == 'ori':
                loss = angle_loss
            elif task in ('aux',):
                loss = nn.BCEWithLogitsLoss()
            else:
                loss = nn.L1Loss()
            self.multi_loss_val[task] = loss

    def forward(self):
        losses_tr = [self.multi_loss_tr[l] for l in self.tasks]
        losses_val = [self.multi_loss_val[l] for l in self.tasks]
        return losses_tr, losses_val
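

# A minimal wiring sketch (not from the original file; the task names and the
# lambdas below are assumptions for illustration):
#
#     tasks = ('d', 'ori')                       # distance + orientation
#     losses_tr, losses_val = CompositeLoss(tasks)()
#     criterion = MultiTaskLoss(losses_tr, losses_val, lambdas=(1.0, 1.0), tasks=tasks)
#     loss, components = criterion(outputs, labels, phase='train')
#     loss.backward()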


class LaplacianLoss(torch.nn.Module):
    """1D Laplacian negative log-likelihood, with scale depending on the relative distance error"""

    def __init__(self, size_average=True, reduce=True, evaluate=False):
        super(LaplacianLoss, self).__init__()
        self.size_average = size_average
        self.reduce = reduce
        self.evaluate = evaluate

    def laplacian_1d(self, mu_si, xx):
        """
        1D Laplacian loss. f(x | mu, b). The network outputs mu and si = log(b).
        xx is the ground-truth distance. This supports backward().
        Inspired by
        https://github.com/naba89/RNN-Handwriting-Generation-Pytorch/blob/master/loss_functions.py
        """
        eps = 0.01  # to avoid 0/0 when there is no uncertainty
        mu, si = mu_si[:, 0:1], mu_si[:, 1:2]
        norm = 1 - mu / xx  # relative error
        const = 2
        term_a = torch.abs(norm) * torch.exp(-si) + eps
        term_b = si
        if self.evaluate:
            # Return the average relative error and the average scale instead of the loss
            norm_bi = (np.mean(np.abs(norm.cpu().detach().numpy())),
                       np.mean(torch.exp(si).cpu().detach().numpy()))
            return norm_bi
        return term_a + term_b + const

    def forward(self, outputs, targets):
        values = self.laplacian_1d(outputs, targets)
        if not self.reduce or self.evaluate:
            return values
        if self.size_average:
            return torch.mean(values)
        return torch.sum(values)
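

# Sketch of the objective above: for a Laplace density over the relative error
# r = 1 - mu/x with scale b = exp(si),
#     -log f(r) = |r| / b + log(b) + log(2),
# which matches term_a (= |r| * exp(-si), plus the eps stabilizer) and
# term_b (= si = log b); the code uses const = 2 rather than log(2), but an
# additive constant does not change the gradients.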


class GaussianLoss(torch.nn.Module):
    """1D Gaussian negative log-likelihood, with std depending on the absolute distance error"""

    def __init__(self, device, size_average=True, reduce=True, evaluate=False):
        super(GaussianLoss, self).__init__()
        self.size_average = size_average
        self.reduce = reduce
        self.evaluate = evaluate
        self.device = device

    def gaussian_1d(self, mu_si, xx):
        """
        1D Gaussian loss. f(x | mu, sigma). The network outputs mu and sigma.
        xx is the ground-truth distance. This supports backward().
        Inspired by
        https://github.com/naba89/RNN-Handwriting-Generation-Pytorch/blob/master/loss_functions.py
        """
        mu, si = mu_si[:, 0:1], mu_si[:, 1:2]
        si = torch.clamp(si, min=0.1)  # floor on sigma; also avoids a hard CUDA dependency
        norm = xx - mu
        term_a = (norm / si) ** 2 / 2
        term_b = torch.log(si * math.sqrt(2 * math.pi))
        if self.evaluate:
            # Return the average absolute error and the average sigma instead of the loss
            norm_si = (np.mean(np.abs(norm.cpu().detach().numpy())),
                       np.mean(si.cpu().detach().numpy()))
            return norm_si
        return term_a + term_b

    def forward(self, outputs, targets):
        values = self.gaussian_1d(outputs, targets)
        if not self.reduce or self.evaluate:
            return values
        if self.size_average:
            return torch.mean(values)
        return torch.sum(values)
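

# For reference, gaussian_1d above is the exact Gaussian negative
# log-likelihood up to the clamp on sigma:
#     -log N(x; mu, si) = (x - mu)**2 / (2 * si**2) + log(si * sqrt(2 * pi)),
# i.e. term_a + term_b.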


def angle_loss(orient, gt_orient):
    """Only for evaluation: mean absolute orientation error in degrees"""
    angles = torch.atan2(orient[:, 0], orient[:, 1])
    gt_angles = torch.atan2(gt_orient[:, 0], gt_orient[:, 1])
    # atan2 returns values in (-pi, pi]; note that the difference below is not
    # wrapped, so errors across the +/-pi boundary are overestimated
    loss = torch.mean(torch.abs(angles - gt_angles)) * 180 / math.pi
    return loss


def l1_loss_from_laplace(out, gt_out):
    """Only for evaluation: L1 error of the predicted mean, ignoring the uncertainty channel"""
    loss = torch.mean(torch.abs(out[:, 0:1] - gt_out))
    return loss
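

if __name__ == '__main__':
    # Quick self-check (not part of the original module): exercise two of the
    # losses on random data. Shapes are assumptions: the network outputs
    # (mu, si) pairs and the targets are positive distances. Run with
    # `python -m monstereo.train.losses` so the relative import resolves.
    torch.manual_seed(0)
    mu_si = torch.randn(8, 2, requires_grad=True)
    distances = torch.rand(8, 1) * 30 + 1  # fake ground-truth distances in metres

    laplace = LaplacianLoss()
    loss = laplace(mu_si, distances)
    loss.backward()
    print('Laplacian loss: {:.3f}'.format(loss.item()))

    orient, gt_orient = torch.randn(4, 2), torch.randn(4, 2)
    print('Angle error (deg): {:.1f}'.format(angle_loss(orient, gt_orient).item()))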