# pylint: disable=too-many-statements

"""
Training and evaluation of a neural network which predicts 3D localization and confidence intervals
given 2d joints
"""

import copy
import os
import datetime
import logging
from collections import defaultdict
import sys
import time
import warnings
from itertools import chain

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

from .datasets import KeypointsDataset
from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss
from ..network import extract_outputs, extract_labels
from ..network.architectures import SimpleModel
from ..utils import set_logger


class Trainer:
    # Constants
    VAL_BS = 10000

    tasks = ('d', 'x', 'y', 'h', 'w', 'l', 'ori', 'aux')
    val_task = 'd'
    lambdas = (1, 1, 1, 1, 1, 1, 1, 1)

    def __init__(self, joints, epochs=100, bs=256, dropout=0.2, lr=0.002,
                 sched_step=20, sched_gamma=1, hidden_size=256, n_stage=3, r_seed=1, n_samples=100,
                 monocular=False, save=False, print_loss=True):
        """
        Initialize directories, load the data and parameters for the training
        """

        # Initialize directories and parameters
        dir_out = os.path.join('data', 'models')
        if not os.path.exists(dir_out):
            warnings.warn("Warning: output directory not found, the model will not be saved")
        dir_logs = os.path.join('data', 'logs')
        if not os.path.exists(dir_logs):
            warnings.warn("Warning: default logs directory not found")
        assert os.path.exists(joints), "Input file not found"

        self.joints = joints
        self.num_epochs = epochs
        self.save = save
        self.print_loss = print_loss
        self.monocular = monocular
        self.lr = lr
        self.sched_step = sched_step
        self.sched_gamma = sched_gamma
        self.clusters = ['10', '20', '30', '50', '>50']
        self.hidden_size = hidden_size
        self.n_stage = n_stage
        self.dir_out = dir_out
        self.n_samples = n_samples
        self.r_seed = r_seed
        self.auto_tune_mtl = False

        # Select the device
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        print('Device: ', self.device)
        torch.manual_seed(r_seed)
        if use_cuda:
            torch.cuda.manual_seed(r_seed)

        # Remove auxiliary task if monocular
        if self.monocular and self.tasks[-1] == 'aux':
            self.tasks = self.tasks[:-1]
            self.lambdas = self.lambdas[:-1]

        losses_tr, losses_val = CompositeLoss(self.tasks)()

        if self.auto_tune_mtl:
            self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
        else:
            self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
        self.mt_loss.to(self.device)

        if not self.monocular:
            input_size = 68
            output_size = 10
        else:
            input_size = 34
            output_size = 9

        now = datetime.datetime.now()
        now_time = now.strftime("%Y%m%d-%H%M")[2:]
        name_out = 'ms-' + now_time
        if self.save:
            self.path_model = os.path.join(dir_out, name_out + '.pkl')
            self.logger = set_logger(os.path.join(dir_logs, name_out))
            self.logger.info("Training arguments: \nepochs: {} \nbatch_size: {} \ndropout: {}"
                             "\nmonocular: {} \nlearning rate: {} \nscheduler step: {} \nscheduler gamma: {} "
                             "\ninput_size: {} \noutput_size: {}\nhidden_size: {} \nn_stages: {} "
                             "\nr_seed: {} \nlambdas: {} \ninput_file: {}"
                             .format(epochs, bs, dropout, self.monocular, lr, sched_step, sched_gamma, input_size,
                                     output_size, hidden_size, n_stage, r_seed, self.lambdas, self.joints))
        else:
            logging.basicConfig(level=logging.INFO)
            self.logger = logging.getLogger(__name__)

        # Dataloader
        self.dataloaders = {phase: DataLoader(KeypointsDataset(self.joints, phase=phase),
                                              batch_size=bs, shuffle=True) for phase in ['train', 'val']}

        self.dataset_sizes = {phase: len(KeypointsDataset(self.joints, phase=phase))
                              for phase in ['train', 'val']}

        # Define the model
        self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes))
        print(">>> creating model")

        self.model = SimpleModel(input_size=input_size, output_size=output_size, linear_size=hidden_size,
                                 p_dropout=dropout, num_stage=self.n_stage, device=self.device)
        self.model.to(self.device)
        print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0))
        print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters())))

        # Optimizer and scheduler
        all_params = chain(self.model.parameters(), self.mt_loss.parameters())
        self.optimizer = torch.optim.Adam(params=all_params, lr=lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma)

    def train(self):
        since = time.time()
        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_acc = 1e6
        best_training_acc = 1e6
        best_epoch = 0
        epoch_losses = defaultdict(lambda: defaultdict(list))
        for epoch in range(self.num_epochs):
            running_loss = defaultdict(lambda: defaultdict(int))

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                for inputs, labels, _, _ in self.dataloaders[phase]:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)
                    with torch.set_grad_enabled(phase == 'train'):
                        if phase == 'train':
                            outputs = self.model(inputs)
                            loss, loss_values = self.mt_loss(outputs, labels, phase=phase)
                            self.optimizer.zero_grad()
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 2)
                            self.optimizer.step()
                        else:
                            outputs = self.model(inputs)
                        with torch.no_grad():
                            loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val')
                            self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss)

            # Step the learning-rate scheduler once per epoch
            self.scheduler.step()

            self.cout_values(epoch, epoch_losses, running_loss)

            # deep copy the model
            if epoch_losses['val'][self.val_task][-1] < best_acc:
                best_acc = epoch_losses['val'][self.val_task][-1]
                best_training_acc = epoch_losses['train']['all'][-1]
                best_epoch = epoch
                best_model_wts = copy.deepcopy(self.model.state_dict())

        time_elapsed = time.time() - since
        print('\n\n' + '-' * 120)
        self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s'
                         .format(time_elapsed // 60, time_elapsed % 60))
        self.logger.info('Best training Accuracy: {:.3f}'.format(best_training_acc))
        self.logger.info('Best validation Accuracy for {}: {:.3f}'.format(self.val_task, best_acc))
        self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch))

        if self.print_loss:
            print_losses(epoch_losses)

        # load best model weights
        self.model.load_state_dict(best_model_wts)
        return best_epoch

    def epoch_logs(self, phase, loss, loss_values, inputs, running_loss):

        running_loss[phase]['all'] += loss.item() * inputs.size(0)
        for i, task in enumerate(self.tasks):
            running_loss[phase][task] += loss_values[i].item() * inputs.size(0)

    def evaluate(self, load=False, model=None, debug=False):

        # To load a model instead of using the trained one
        if load:
            self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))

        # Average distance on training and test set after unnormalizing
        self.model.eval()
        dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))  # initialized to zero
        dic_err['val']['sigmas'] = [0.] * len(self.tasks)
        dataset = KeypointsDataset(self.joints, phase='val')
        size_eval = len(dataset)
        start = 0
        with torch.no_grad():
            for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS):
                end = end if end < size_eval else size_eval
                inputs, labels, _, _ = dataset[start:end]
                start = end
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # Debug plot for input-output distributions
                if debug:
                    debug_plots(inputs, labels)
                    sys.exit()

                # Forward pass
                outputs = self.model(inputs)
                self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')

            self.cout_stats(dic_err['val'], size_eval, clst='all')

            # Evaluate performances on different clusters and save statistics
            for clst in self.clusters:
                inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Forward pass on each cluster
                outputs = self.model(inputs)
                self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
                self.cout_stats(dic_err['val'], size_eval, clst=clst)

        # Save the model and the results
        if self.save and not load:
            torch.save(self.model.state_dict(), self.path_model)
            print('-' * 120)
            self.logger.info("\nmodel saved: {} \n".format(self.path_model))
        else:
            self.logger.info("\nmodel not saved\n")

        return dic_err, self.model

    def compute_stats(self, outputs, labels, dic_err, size_eval, clst):
        """Accumulate per-task errors, uncertainty (bi) statistics and auxiliary-task accuracy
        for a cluster, weighted by the fraction of the evaluation set covered by the batch"""

        loss, loss_values = self.mt_loss(outputs, labels, phase='val')
        rel_frac = outputs.size(0) / size_eval

        tasks = self.tasks[:-1] if self.tasks[-1] == 'aux' else self.tasks  # Exclude auxiliary

        for idx, task in enumerate(tasks):
            dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval)

        # Distance
        errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d'])

        assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation"

        # Uncertainty
        bis = extract_outputs(outputs)['bi'].cpu()
        bi = float(torch.mean(bis).item())
        bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0]
        dic_err[clst]['bi'] += bi * rel_frac
        dic_err[clst]['bi%'] += bi_perc * rel_frac
        dic_err[clst]['std'] = errs.std()

        # (Don't) Save auxiliary task results
        if self.monocular:
            dic_err[clst]['aux'] = 0
            dic_err['sigmas'].append(0)
        else:
            acc_aux = get_accuracy(extract_outputs(outputs)['aux'], extract_labels(labels)['aux'])
            dic_err[clst]['aux'] += acc_aux * rel_frac

        if self.auto_tune_mtl:
            assert len(loss_values) == 2 * len(self.tasks)
            for i, _ in enumerate(self.tasks):
                dic_err['sigmas'][i] += float(loss_values[len(tasks) + i + 1].item()) * rel_frac

    def cout_stats(self, dic_err, size_eval, clst):
        if clst == 'all':
            print('-' * 120)
            self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n"
                             "X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} "
                             "\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm"
                             "\nAuxiliary Task: {:.1f} %, "
                             .format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
                                     dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
                                     dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
                                     dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100))
            if self.auto_tune_mtl:
                self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}"
                                 " AUX:{:.2f}\n"
                                 .format(*dic_err['sigmas']))
        else:
            self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} "
                             "Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. "
                             .format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
                                     dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
                                     dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
                                     dic_err[clst]['l'] * 100, size_eval))

    def cout_values(self, epoch, epoch_losses, running_loss):

        string = '\r' + '{:.0f} '
        format_list = [epoch]
        for phase in running_loss:
            string = string + phase[0:1].upper() + ':'
            for el in running_loss['train']:
                loss = running_loss[phase][el] / self.dataset_sizes[phase]
                epoch_losses[phase][el].append(loss)
                if el == 'all':
                    string = string + ':{:.1f} '
                    format_list.append(loss)
                elif el in ('ori', 'aux'):
                    string = string + el + ':{:.1f} '
                    format_list.append(loss)
                else:
                    string = string + el + ':{:.0f} '
                    format_list.append(loss * 100)

        if epoch % 10 == 0:
            print(string.format(*format_list))


def debug_plots(inputs, labels):
    inputs_shoulder = inputs.cpu().numpy()[:, 5]
    inputs_hip = inputs.cpu().numpy()[:, 11]
    labels = labels.cpu().numpy()
    heights = inputs_hip - inputs_shoulder
    plt.figure(1)
    plt.hist(heights, bins='auto')
    plt.show()
    plt.figure(2)
    plt.hist(labels, bins='auto')
    plt.show()


def print_losses(epoch_losses):
    for idx, phase in enumerate(epoch_losses):
        for idx_2, el in enumerate(epoch_losses['train']):
            plt.figure(idx + idx_2)
            plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el))
            plt.savefig('figures/{}_loss_{}.png'.format(phase, el))
            plt.close()


def get_accuracy(outputs, labels):
    """From Binary cross entropy outputs to accuracy"""

    mask = outputs >= 0.5
    accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
    return accuracy
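# Worked example (illustrative only, values assumed): with outputs = torch.tensor([0.7, 0.2, 0.9])
# and labels = torch.tensor([1., 0., 0.]), the mask is [True, False, True], the absolute
# differences are [0., 0., 1.], and the returned accuracy is 1 - 1/3 ≈ 0.67.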
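

# Minimal usage sketch (not part of the original module). It assumes the file is executed as part of
# its package (e.g. via `python -m <package>.trainer`) so that the relative imports above resolve,
# and that a preprocessed joints file exists at the hypothetical path below; adjust both as needed.
if __name__ == '__main__':
    example_joints = os.path.join('data', 'arrays', 'joints-train.json')  # hypothetical input file
    if os.path.exists(example_joints):
        trainer = Trainer(joints=example_joints, epochs=5, save=False, print_loss=False)
        best_epoch = trainer.train()          # epoch with the lowest validation loss on task 'd'
        dic_err, model = trainer.evaluate()   # per-cluster error statistics on the validation set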