remove double loading for evaluation of Trainer

This commit is contained in:
lorenzo 2019-07-26 10:27:42 +02:00
parent 0ea3ae811f
commit 87ec4c441e
2 changed files with 33 additions and 60 deletions

View File

@ -19,16 +19,16 @@ import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from train.datasets import KeypointsDataset
from .datasets import KeypointsDataset
from ..network import LaplacianLoss
from ..network.process import laplace_sampling, unnormalize_bi
from ..network.process import unnormalize_bi
from ..network.architectures import LinearModel
from ..utils import set_logger
class Trainer:
def __init__(self, joints, epochs=100, bs=256, dropout=0.2, lr=0.002,
sched_step=20, sched_gamma=1, hidden_size=256, n_stage=3, r_seed=1, n_dropout=0, n_samples=100,
sched_step=20, sched_gamma=1, hidden_size=256, n_stage=3, r_seed=1, n_samples=100,
baseline=False, save=False, print_loss=False):
"""
Initialize directories, load the data and parameters for the training
@ -57,7 +57,6 @@ class Trainer:
self.hidden_size = hidden_size
self.n_stage = n_stage
self.dir_out = dir_out
self.n_dropout = n_dropout
self.n_samples = n_samples
self.r_seed = r_seed
@ -215,71 +214,40 @@ class Trainer:
self.model.eval()
dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) # initialized to zero
phase = 'val'
dataloader_eval = DataLoader(KeypointsDataset(self.joints, phase=phase),
batch_size=5000, shuffle=True)
size_eval = len(KeypointsDataset(self.joints, phase=phase))
batch_size = 5000
dataset = KeypointsDataset(self.joints, phase=phase)
size_eval = len(dataset)
start = 0
with torch.no_grad():
for inputs, labels, _, _ in dataloader_eval:
for end in range(batch_size, size_eval+batch_size, batch_size):
end = end if end < size_eval else size_eval
inputs, labels, _, _ = dataset[start:end]
start = end
inputs = inputs.to(self.device)
labels = labels.to(self.device)
# Debug plot for input-output distributions
if debug:
inputs_shoulder = inputs.cpu().numpy()[:, 5]
inputs_hip = inputs.cpu().numpy()[:, 11]
labels = labels.cpu().numpy()
heights = inputs_hip - inputs_shoulder
plt.figure(1)
plt.hist(heights, bins='auto')
plt.show()
plt.figure(2)
plt.hist(labels, bins='auto')
plt.show()
debug_plots(inputs, labels)
exit()
# Manually reactivate dropout in eval
self.model.dropout.training = True
total_outputs = torch.empty((0, len(labels))).to(self.device)
if self.n_dropout > 0:
for _ in range(self.n_dropout):
outputs = self.model(inputs)
outputs = unnormalize_bi(outputs)
samples = laplace_sampling(outputs, self.n_samples)
total_outputs = torch.cat((total_outputs, samples), 0)
varss = total_outputs.std(0)
else:
varss = [0]
# Don't use dropout for the mean prediction
self.model.dropout.training = False
# Forward pass
outputs = self.model(inputs)
if not self.baseline:
outputs = unnormalize_bi(outputs)
dic_err[phase]['all'] = self.compute_stats(outputs, labels, varss, dic_err[phase]['all'], size_eval)
dic_err[phase]['all'] = self.compute_stats(outputs, labels, dic_err[phase]['all'], size_eval)
print('-'*120)
self.logger.info("Evaluation:\nAverage distance on the {} set: {:.2f}"
.format(phase, dic_err[phase]['all']['mean']))
self.logger.info("Aleatoric Uncertainty: {:.2f}, inside the interval: {:.1f}%"
self.logger.info("Aleatoric Uncertainty: {:.2f}, inside the interval: {:.1f}%\n"
.format(dic_err[phase]['all']['bi'], dic_err[phase]['all']['conf_bi']*100))
# TODO Add evaluation variance
# self.logger.info("Combined Uncertainty: {:.2f} with a max of {:.2f}, inside the interval: {:.1f}%\n"
# .format(stat_var[0], stat_var[1], stat_var[2]*100))
# Evaluate performances on different clusters and save statistics
nuscenes = KeypointsDataset(self.joints, phase=phase)
for clst in self.clusters:
inputs, labels, size_eval = nuscenes.get_cluster_annotations(clst)
inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
inputs, labels = inputs.to(self.device), labels.to(self.device)
# Forward pass on each cluster
@ -287,7 +255,7 @@ class Trainer:
if not self.baseline:
outputs = unnormalize_bi(outputs)
dic_err[phase][clst] = self.compute_stats(outputs, labels, [0], dic_err[phase][clst], size_eval)
dic_err[phase][clst] = self.compute_stats(outputs, labels, dic_err[phase][clst], size_eval)
self.logger.info("{} error in cluster {} = {:.2f} for {} instances. "
"Aleatoric of {:.2f} with {:.1f}% inside the interval"
@ -295,6 +263,7 @@ class Trainer:
dic_err[phase][clst]['bi'], dic_err[phase][clst]['conf_bi'] * 100))
# Save the model and the results
self.save = False
if self.save and not load:
torch.save(self.model.state_dict(), self.path_model)
print('-'*120)
@ -304,8 +273,8 @@ class Trainer:
return dic_err, self.model
def compute_stats(self, outputs, labels_orig, varss, dic_err, size_eval):
"""Compute mean std (aleatoric) and max of torch tensors"""
def compute_stats(self, outputs, labels_orig, dic_err, size_eval):
"""Compute mean, bi and max of torch tensors"""
labels = labels_orig.view(-1, )
mean_mu = float(self.criterion_eval(outputs[:, 0], labels).item())
@ -321,18 +290,22 @@ class Trainer:
bools_bi = low_bound_bi & up_bound_bi
conf_bi = float(torch.sum(bools_bi)) / float(bools_bi.shape[0])
# if varss[0] >= 0:
# mean_var = torch.mean(varss).item()
# max_var = torch.max(varss).item()
#
# low_bound_var = labels >= (outputs[:, 0] - varss)
# up_bound_var = labels <= (outputs[:, 0] + varss)
# bools_var = low_bound_var & up_bound_var
# conf_var = float(torch.sum(bools_var)) / float(bools_var.shape[0])
dic_err['mean'] += mean_mu * (outputs.size(0) / size_eval)
dic_err['bi'] += mean_bi * (outputs.size(0) / size_eval)
dic_err['count'] += (outputs.size(0) / size_eval)
dic_err['conf_bi'] += conf_bi * (outputs.size(0) / size_eval)
return dic_err
def debug_plots(inputs, labels):
inputs_shoulder = inputs.cpu().numpy()[:, 5]
inputs_hip = inputs.cpu().numpy()[:, 11]
labels = labels.cpu().numpy()
heights = inputs_hip - inputs_shoulder
plt.figure(1)
plt.hist(heights, bins='auto')
plt.show()
plt.figure(2)
plt.hist(labels, bins='auto')
plt.show()

View File

@ -27,7 +27,7 @@ setup(
zip_safe=False,
install_requires=[
'openpifpaf',
'openpifpaf==0.8.0',
'tabulate', # For evaluation
],
extras_require={