# monoloco/monoloco/train/trainer_casr.py
# pylint: disable=too-many-statements
"""
Training and evaluation of a neural network that, given 2D joints, estimates:
- 3D localization and confidence intervals
- Orientation
- Bounding box dimensions
"""
import copy
import os
import datetime
import logging
from collections import defaultdict
import sys
import time
from itertools import chain
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from .. import __version__
from .datasets import KeypointsDataset
from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss
from ..network import extract_outputs, extract_labels
from ..network.architectures import LocoModel
from ..utils import set_logger


class CASRTrainer:
    # Constants
    VAL_BS = 10000  # batch size for evaluation
    tasks = ('cyclist',)
    val_task = 'cyclist'
    lambdas = (1,)  # one loss weight per task
    # clusters = ['10', '20', '30', '40']
    input_size = 34  # 17 COCO keypoints x 2 coordinates
    output_size = 4
    dir_figures = os.path.join('figures', 'losses')

    def __init__(self, args):
        """
        Initialize directories, load the data, and set the parameters for training
        """
assert os.path.exists(args.joints), "Input file not found"
self.mode = args.mode
self.joints = args.joints
self.num_epochs = args.epochs
self.no_save = args.no_save
self.print_loss = args.print_loss
self.lr = args.lr
self.sched_step = args.sched_step
self.sched_gamma = args.sched_gamma
self.hidden_size = args.hidden_size
self.n_stage = args.n_stage
self.r_seed = args.r_seed
self.auto_tune_mtl = args.auto_tune_mtl
# Select path out
if args.out:
self.path_out = args.out # full path without extension
dir_out, _ = os.path.split(self.path_out)
else:
dir_out = os.path.join('data', 'outputs')
name = 'casr'
now = datetime.datetime.now()
now_time = now.strftime("%Y%m%d-%H%M")[2:]
name_out = name + '-' + now_time + '.pkl'
self.path_out = os.path.join(dir_out, name_out)
assert os.path.exists(dir_out), "Directory to save the model not found"
print(self.path_out)
# Select the device
use_cuda = torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
print('Device: ', self.device)
torch.manual_seed(self.r_seed)
if use_cuda:
torch.cuda.manual_seed(self.r_seed)
losses_tr, losses_val = CompositeLoss(self.tasks)()
if self.auto_tune_mtl:
self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
else:
self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
self.mt_loss.to(self.device)
        # Dataloaders: build each dataset once and reuse it for the sizes and the version
        datasets = {phase: KeypointsDataset(self.joints, phase=phase) for phase in ['train', 'val']}
        self.dataloaders = {phase: DataLoader(datasets[phase], batch_size=args.bs, shuffle=True)
                            for phase in ['train', 'val']}
        self.dataset_sizes = {phase: len(datasets[phase]) for phase in ['train', 'val']}
        self.dataset_version = datasets['train'].get_version()
self._set_logger(args)
# Define the model
self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes))
print(">>> creating model")
self.model = LocoModel(
input_size=self.input_size,
output_size=self.output_size,
linear_size=args.hidden_size,
p_dropout=args.dropout,
num_stage=self.n_stage,
device=self.device,
)
self.model.to(self.device)
print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0))
print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters())))
# Optimizer and scheduler
all_params = chain(self.model.parameters(), self.mt_loss.parameters())
self.optimizer = torch.optim.Adam(params=all_params, lr=args.lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma)
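        # With StepLR the learning rate at epoch e is lr * gamma ** (e // sched_step);
        # e.g. lr=1e-3, sched_step=20, gamma=0.9 yields 1e-3 for epochs 0-19, 9e-4 for 20-39, and so on.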

    def train(self):
        since = time.time()
        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_loss = 1e6  # running minimum of the validation loss
        best_training_loss = 1e6
        best_epoch = 0
        epoch_losses = defaultdict(lambda: defaultdict(list))
for epoch in range(self.num_epochs):
running_loss = defaultdict(lambda: defaultdict(int))
# Each epoch has a training and validation phase
for phase in ['train', 'val']:
if phase == 'train':
self.model.train() # Set model to training mode
else:
self.model.eval() # Set model to evaluate mode
for inputs, labels, _, _ in self.dataloaders[phase]:
inputs = inputs.to(self.device)
labels = labels.to(self.device)
with torch.set_grad_enabled(phase == 'train'):
if phase == 'train':
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss, _ = self.mt_loss(outputs, labels, phase=phase)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
self.optimizer.step()
else:
outputs = self.model(inputs)
with torch.no_grad():
loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val')
self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss)
            # StepLR counts epochs, so step the scheduler once per epoch rather than once per batch
            self.scheduler.step()
            self.cout_values(epoch, epoch_losses, running_loss)
            # Deep copy of the model with the lowest validation loss so far
            if epoch_losses['val'][self.val_task][-1] < best_loss:
                best_loss = epoch_losses['val'][self.val_task][-1]
                best_training_loss = epoch_losses['train']['all'][-1]
                best_epoch = epoch
                best_model_wts = copy.deepcopy(self.model.state_dict())
time_elapsed = time.time() - since
print('\n\n' + '-' * 120)
self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s'
.format(time_elapsed // 60, time_elapsed % 60))
        self.logger.info('Best training loss: {:.3f}'.format(best_training_loss))
        self.logger.info('Best validation loss for {}: {:.3f}'.format(self.val_task, best_loss))
self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch))
self._print_losses(epoch_losses)
# load best model weights
self.model.load_state_dict(best_model_wts)
return best_epoch

    def epoch_logs(self, phase, loss, loss_values, inputs, running_loss):
        # Accumulate batch losses weighted by the batch size, so that dividing by the
        # dataset size in cout_values() gives the mean loss per sample
        running_loss[phase]['all'] += loss.item() * inputs.size(0)
        for i, task in enumerate(self.tasks):
            running_loss[phase][task] += loss_values[i].item() * inputs.size(0)
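    # For example, a batch of 32 samples with loss 0.5 adds 16.0 to running_loss[phase]['all'];
    # a dataset of 3200 samples then reports a mean epoch loss of 0.5 once all batches are summed.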

    def evaluate(self, load=False, model=None, debug=False):
# To load a model instead of using the trained one
if load:
self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
# Average distance on training and test set after unnormalizing
self.model.eval()
dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) # initialized to zero
dic_err['val']['sigmas'] = [0.] * len(self.tasks)
dataset = KeypointsDataset(self.joints, phase='val')
size_eval = len(dataset)
start = 0
with torch.no_grad():
            for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS):
                end = min(end, size_eval)  # clamp the last batch to the dataset size
inputs, labels, _, _ = dataset[start:end]
start = end
inputs = inputs.to(self.device)
labels = labels.to(self.device)
# Debug plot for input-output distributions
if debug:
debug_plots(inputs, labels)
sys.exit()
# Forward pass
outputs = self.model(inputs)
#self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
# self.cout_stats(dic_err['val'], size_eval, clst='all')
# Evaluate performances on different clusters and save statistics
# Save the model and the results
if not (self.no_save or load):
torch.save(self.model.state_dict(), self.path_model)
print('-' * 120)
self.logger.info("\nmodel saved: {} \n".format(self.path_model))
else:
self.logger.info("\nmodel not saved\n")
return dic_err, self.model
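
    # evaluate() can also score a saved checkpoint directly; the path below is
    # only a hypothetical example:
    #   trainer.evaluate(load=True, model='data/outputs/casr-210626-1550.pkl')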

    def compute_stats(self, outputs, labels, dic_err, size_eval, clst):
        """Accumulate per-task losses, distance errors and uncertainty (bi) statistics for a cluster
        (currently unused: its calls in evaluate() are commented out)"""
_, loss_values = self.mt_loss(outputs, labels, phase='val')
rel_frac = outputs.size(0) / size_eval
tasks = self.tasks # Exclude auxiliary
for idx, task in enumerate(tasks):
dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval)
# Distance
errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d'])
assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation"
# Uncertainty
bis = extract_outputs(outputs)['bi'].cpu()
bi = float(torch.mean(bis).item())
bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0]
dic_err[clst]['bi'] += bi * rel_frac
dic_err[clst]['bi%'] += bi_perc * rel_frac
dic_err[clst]['std'] = errs.std()
# (Don't) Save auxiliary task results
dic_err['sigmas'].append(0)
if self.auto_tune_mtl:
assert len(loss_values) == 2 * len(self.tasks)
for i, _ in enumerate(self.tasks):
                # The learned sigmas follow the per-task losses in loss_values
                dic_err['sigmas'][i] += float(loss_values[len(tasks) + i].item()) * rel_frac

    def cout_stats(self, dic_err, size_eval, clst):
if clst == 'all':
print('-' * 120)
self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n"
"X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} "
"\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm"
"\nAuxiliary Task: {:.1f} %, "
.format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100))
if self.auto_tune_mtl:
self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}"
" AUX:{:.2f}\n"
.format(*dic_err['sigmas']))
else:
self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} "
"Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. "
.format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
dic_err[clst]['l'] * 100, size_eval))

    def cout_values(self, epoch, epoch_losses, running_loss):
        string = '\r' + '{:.0f} '
        format_list = [epoch]
        for phase in running_loss:
            string = string + phase[0:1].upper() + ':'
            for el in running_loss['train']:
                loss = running_loss[phase][el] / self.dataset_sizes[phase]
                epoch_losses[phase][el].append(loss)
                if el == 'all':
                    string = string + '{:.1f} '
                    format_list.append(loss)
                elif el in ('ori', 'aux'):
                    string = string + el + ':{:.1f} '
                    format_list.append(loss)
                else:
                    string = string + el + ':{:.0f} '
                    format_list.append(loss * 100)
        if epoch % 10 == 0:
            print(string.format(*format_list))

    def _print_losses(self, epoch_losses):
        if not self.print_loss:
            return
        os.makedirs(self.dir_figures, exist_ok=True)
        for phase in epoch_losses:
            for el in epoch_losses['train']:
                plt.figure()  # a fresh figure per plot avoids collisions between phases
                plt.title(phase + '_' + el)
                plt.xlabel('epochs')
                plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el))
                plt.savefig(os.path.join(self.dir_figures, '{}_loss_{}.png'.format(phase, el)))
                plt.close()

    def _set_logger(self, args):
if self.no_save:
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
else:
self.path_model = self.path_out
print(self.path_model)
self.logger = set_logger(os.path.splitext(self.path_out)[0]) # remove .pkl
self.logger.info( # pylint: disable=logging-fstring-interpolation
f'\nVERSION: {__version__}\n'
f'\nINPUT_FILE: {args.joints}'
f'\nInput file version: {self.dataset_version}'
f'\nTorch version: {torch.__version__}\n'
f'\nTraining arguments:'
f'\nmode: {self.mode} \nlearning rate: {args.lr} \nbatch_size: {args.bs}'
f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
f'\ninput_size: {self.input_size} \noutput_size: {self.output_size} '
f'\nhidden_size: {args.hidden_size}'
                f'\nn_stages: {args.n_stage} \nr_seed: {args.r_seed} \nlambdas: {self.lambdas}'
)


def debug_plots(inputs, labels):
    """Histogram the hip-minus-shoulder input components and the labels, for debugging"""
    inputs_shoulder = inputs.cpu().numpy()[:, 5]
    inputs_hip = inputs.cpu().numpy()[:, 11]
    labels = labels.cpu().numpy()
    heights = inputs_hip - inputs_shoulder
    plt.figure(1)
    plt.hist(heights, bins='auto')
    plt.show()
    plt.figure(2)
    plt.hist(labels, bins='auto')
    plt.show()


def get_accuracy(outputs, labels):
    """Convert binary cross-entropy outputs into a classification accuracy"""
    mask = outputs >= 0.5
    accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
    return accuracy
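
# A worked example of get_accuracy with illustrative values only:
# outputs = torch.tensor([0.8, 0.3, 0.6]) and labels = torch.tensor([1., 0., 0.])
# give the thresholded mask [1., 0., 1.], a mean absolute difference of 1/3,
# and therefore an accuracy of 2/3.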