Preprocessing for std and non-std CASR in one file

Merge the standard and non-standard CASR preprocessing paths into a single
create_dic(std=False) entry point: the std flag selects the list of excluded
sequence indices, the fold of label class 3 into class 2, and the output file
name. The separate casr_preprocess_standard module becomes redundant and is
deleted, and the standalone CASR hyperparameter-search and trainer modules are
dropped in the same commit (shown as full-file deletions below).

parent 5ffc7dd0f6
commit 148d3f2843
@@ -38,12 +38,12 @@ def match_bboxes(bbox_gt, bbox_pred):
 def standard_bbox(bbox):
     return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]]
 
-def load_gt(path=gt_path):
-    return pickle.load(open(path, 'rb'), encoding='latin1')
+def load_gt():
+    return pickle.load(open(gt_path, 'rb'), encoding='latin1')
 
-def load_res(path=res_path):
+def load_res():
     mono = []
-    for folder in sorted(glob.glob(path), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
+    for folder in sorted(glob.glob(res_path), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
         data_list = []
         for file in sorted(os.listdir(folder), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
             if 'json' in file:
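For context, `glob.glob` returns paths in arbitrary order, so `load_res()` sorts result folders (and the frame files inside them) by the first integer found in the name rather than lexicographically. A minimal illustration of the sort key, with hypothetical folder names:

```python
import re

# Hypothetical folder names following the casr* pattern used above.
names = ["casr2", "casr10", "casr1"]
print(sorted(names))
# ['casr1', 'casr10', 'casr2']  -- plain lexicographic order
print(sorted(names, key=lambda x: float(re.findall(r"(\d+)", x)[0])))
# ['casr1', 'casr2', 'casr10']  -- numeric order, as in load_res()
```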
@@ -54,7 +54,7 @@ def load_res(path=res_path):
         mono.append(data_list)
     return mono
 
-def create_dic():
+def create_dic(std=False):
     gt=load_gt()
     res=load_res()
     dic_jo = {
@@ -63,7 +63,14 @@ def create_dic():
         'version': __version__,
     }
     split = ['3', '4']
-    for i in range(len(res[:])):
+    if std:
+        wrong = [6, 8, 9, 10, 11, 12, 14, 21, 40, 43, 55, 70, 76, 92, 109,
+                 110, 112, 113, 121, 123, 124, 127, 128, 134, 136, 139, 165, 173]
+        mode = 'std'
+    else:
+        wrong = []
+        mode = ''
+    for i in [x for x in range(len(res[:])) if x not in wrong]:
         for j in [x for x in range(len(res[i][:])) if 'boxes' in res[i][x]]:
             folder = gt[i][j]['video_folder']
 
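The `wrong` list holds sequence indices to drop in std mode; with `std=False` it is empty, so the comprehension keeps every recording. A tiny standalone illustration of the filtering (hypothetical sizes):

```python
res = [[] for _ in range(6)]   # pretend: six recorded sequences
wrong = [1, 4]                 # indices excluded in std mode
kept = [x for x in range(len(res)) if x not in wrong]
print(kept)                    # [0, 2, 3, 5]
```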
@@ -77,13 +84,16 @@ def create_dic():
 
             keypoints = [res[i][j]['uv_kps'][good_idx]]
 
+            gt_turn = gt[i][j]['left_or_right']
+            if std and gt_turn == 3:
+                gt_turn = 2
             inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist()
             dic_jo[phase]['kps'].append(keypoints)
             dic_jo[phase]['X'].append(inp)
-            dic_jo[phase]['Y'].append(gt[i][j]['left_or_right'])
+            dic_jo[phase]['Y'].append(gt_turn)
             dic_jo[phase]['names'].append(folder+"_frame{}".format(j))
 
     now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:]
-    with open("/home/beauvill/joints-casr-right-" + split[0] + split[1] + "-" + now_time + ".json", 'w') as file:
+    with open("/home/beauvill/joints-casr-" + mode + "-right-" + split[0] + split[1] + "-" + now_time + ".json", 'w') as file:
         json.dump(dic_jo, file)
     return dic_jo
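With this hunk the two preprocessing variants differ only in the exclusion list, the label remap, and the output file name. The remap folds ground-truth class 3 into class 2 when `std=True`, matching the three-class head (`output_size = 3`) that the deleted CASRTrainer below selects for `args.std`. A minimal sketch of the remap as a standalone helper (hypothetical name):

```python
def remap_turn(gt_turn, std=False):
    # create_dic() folds class 3 into class 2 in std mode only.
    if std and gt_turn == 3:
        return 2
    return gt_turn

assert remap_turn(3, std=True) == 2   # std: four classes become three
assert remap_turn(3, std=False) == 3  # non-std: labels pass through
```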
@@ -1,95 +0,0 @@
-import pickle
-import re
-import json
-import os
-import glob
-import datetime
-import numpy as np
-import torch
-
-from .. import __version__
-from ..network.process import preprocess_monoloco
-
-gt_path = '/scratch/izar/beauvill/casr/data/annotations/casr_annotation.pickle'
-res_path = '/scratch/izar/beauvill/casr/res_extended/casr*'
-
-def bb_intersection_over_union(boxA, boxB):
-    xA = max(boxA[0], boxB[0])
-    yA = max(boxA[1], boxB[1])
-    xB = min(boxA[2], boxB[2])
-    yB = min(boxA[3], boxB[3])
-    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
-    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
-    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
-    iou = interArea / float(boxAArea + boxBArea - interArea)
-    return iou
-
-def match_bboxes(bbox_gt, bbox_pred):
-    n_true = bbox_gt.shape[0]
-    n_pred = bbox_pred.shape[0]
-
-    iou_matrix = np.zeros((n_true, n_pred))
-    for i in range(n_true):
-        for j in range(n_pred):
-            iou_matrix[i, j] = bb_intersection_over_union(bbox_gt[i,:], bbox_pred[j,:])
-
-    return np.argmax(iou_matrix)
-
-def standard_bbox(bbox):
-    return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]]
-
-def load_gt():
-    return pickle.load(open(gt_path, 'rb'), encoding='latin1')
-
-def load_res():
-    mono = []
-    for folder in sorted(glob.glob(res_path), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
-        data_list = []
-        for file in sorted(os.listdir(folder), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
-            if 'json' in file:
-                json_path = os.path.join(folder, file)
-                json_data = json.load(open(json_path))
-                json_data['filename'] = json_path
-                data_list.append(json_data)
-        mono.append(data_list)
-    return mono
-
-def create_dic_std():
-    gt=load_gt()
-    res=load_res()
-    dic_jo = {
-        'train': dict(X=[], Y=[], names=[], kps=[]),
-        'val': dict(X=[], Y=[], names=[], kps=[]),
-        'version': __version__,
-    }
-    wrong = [6, 8, 9, 10, 11, 12, 14, 21, 40, 43, 55, 70, 76, 92, 109,
-             110, 112, 113, 121, 123, 124, 127, 128, 134, 136, 139, 165, 173]
-    for i in [x for x in range(len(res[:])) if x not in wrong]:
-        for j in range(len(res[i][:])):
-            phase = 'val'
-            if (j % 10) > 1:
-                phase = 'train'
-
-            folder = gt[i][j]['video_folder']
-
-            if('boxes' in res[i][j] and gt[i][j]['left_or_right'] != 2):
-                gt_box = gt[i][j]['bbox_gt']
-
-                good_idx = match_bboxes(np.array([standard_bbox(gt_box)]), np.array(res[i][j]['boxes'])[:,:4])
-
-                keypoints = [res[i][j]['uv_kps'][good_idx]]
-
-                gt_turn = gt[i][j]['left_or_right']
-                if gt_turn == 3:
-                    gt_turn = 2
-
-                inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist()
-                dic_jo[phase]['kps'].append(keypoints)
-                dic_jo[phase]['X'].append(inp)
-                dic_jo[phase]['Y'].append(gt_turn)
-                dic_jo[phase]['names'].append(folder+"_frame{}".format(j))
-
-    now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:]
-    with open("/home/beauvill/joints-casr-std-" + now_time + ".json", 'w') as file:
-        json.dump(dic_jo, file)
-    return dic_jo
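For reference, the IoU helper deleted here lives on unchanged in the merged module (the hunk context above shows `match_bboxes`, which calls it). A self-contained worked example with two concrete boxes in `[x1, y1, x2, y2]` form:

```python
def bb_intersection_over_union(boxA, boxB):
    # Same logic as the deleted helper; the +1 terms treat
    # coordinates as inclusive pixel indices.
    xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    areaA = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    areaB = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    return inter / float(areaA + areaB - inter)

# Two 10x10 boxes overlapping on a 5x5 patch:
print(bb_intersection_over_union([0, 0, 9, 9], [5, 5, 14, 14]))
# 25 / (100 + 100 - 25) = 0.142857...
```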
@@ -151,10 +151,10 @@ def main():
         prep.run()
     elif 'casr' in args.dataset:
         from .prep.casr_preprocess import create_dic
-        create_dic()
+        create_dic(std=False)
     elif 'casr_std' in args.dataset:
-        from .prep.casr_preprocess_standard import create_dic_std
-        create_dic_std()
+        from .prep.casr_preprocess import create_dic
+        create_dic(std=True)
     else:
         from .prep.preprocess_kitti import PreprocessKitti
         prep = PreprocessKitti(args.dir_ann, mode=args.mode, iou_min=args.iou_min)
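One thing worth flagging in this dispatch: `'casr' in args.dataset` is a substring test, so a dataset string like `'casr_std'` already matches the first `elif` and the `'casr_std'` branch is unreachable. A sketch of one possible fix (not part of this commit), checking the more specific name first:

```python
# Hypothetical reordering -- assumes args.dataset is a plain string.
if 'casr_std' in args.dataset:
    from .prep.casr_preprocess import create_dic
    create_dic(std=True)
elif 'casr' in args.dataset:
    from .prep.casr_preprocess import create_dic
    create_dic(std=False)
```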
@@ -1,122 +0,0 @@
-
-import math
-import os
-import json
-import time
-import logging
-import random
-import datetime
-
-import torch
-import numpy as np
-
-from .trainer_casr import CASRTrainer
-
-
-class HypTuningCasr:
-
-    def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1):
-        """
-        Initialize directories, load the data and parameters for the training
-        """
-
-        # Initialize Directories
-        self.joints = joints
-        self.monocular = monocular
-        self.dropout = dropout
-        self.num_epochs = epochs
-        self.r_seed = r_seed
-        dir_out = os.path.join('data', 'models')
-        dir_logs = os.path.join('data', 'logs')
-        assert os.path.exists(dir_out), "Output directory not found"
-        if not os.path.exists(dir_logs):
-            os.makedirs(dir_logs)
-
-        name_out = 'hyp-casr-'
-
-        self.path_log = os.path.join(dir_logs, name_out)
-        self.path_model = os.path.join(dir_out, name_out)
-
-        logging.basicConfig(level=logging.INFO)
-        self.logger = logging.getLogger(__name__)
-
-        # Initialize grid of parameters
-        random.seed(r_seed)
-        np.random.seed(r_seed)
-        self.sched_gamma_list = [0.8, 0.9, 1, 0.8, 0.9, 1] * multiplier
-        random.shuffle(self.sched_gamma_list)
-        self.sched_step = [10, 20, 40, 60, 80, 100] * multiplier
-        random.shuffle(self.sched_step)
-        self.bs_list = [64, 128, 256, 512, 512, 1024] * multiplier
-        random.shuffle(self.bs_list)
-        self.hidden_list = [512, 1024, 2048, 512, 1024, 2048] * multiplier
-        random.shuffle(self.hidden_list)
-        self.n_stage_list = [3, 3, 3, 3, 3, 3] * multiplier
-        random.shuffle(self.n_stage_list)
-        # Learning rate
-        aa = math.log(0.0005, 10)
-        bb = math.log(0.01, 10)
-        log_lr_list = np.random.uniform(aa, bb, int(6 * multiplier)).tolist()
-        self.lr_list = [10 ** xx for xx in log_lr_list]
-        # plt.hist(self.lr_list, bins=50)
-        # plt.show()
-
-    def train(self, args):
-        """Train multiple times using log-space random search"""
-
-        best_acc_val = 20
-        dic_best = {}
-        start = time.time()
-        cnt = 0
-        for idx, lr in enumerate(self.lr_list):
-            bs = self.bs_list[idx]
-            sched_gamma = self.sched_gamma_list[idx]
-            sched_step = self.sched_step[idx]
-            hidden_size = self.hidden_list[idx]
-            n_stage = self.n_stage_list[idx]
-
-            training = CASRTrainer(args)
-
-            best_epoch = training.train()
-            dic_err, model = training.evaluate()
-            acc_val = dic_err['val']['all']['mean']
-            cnt += 1
-            print("Combination number: {}".format(cnt))
-
-            if acc_val < best_acc_val:
-                dic_best['lr'] = lr
-                dic_best['joints'] = self.joints
-                dic_best['bs'] = bs
-                dic_best['monocular'] = self.monocular
-                dic_best['sched_gamma'] = sched_gamma
-                dic_best['sched_step'] = sched_step
-                dic_best['hidden_size'] = hidden_size
-                dic_best['n_stage'] = n_stage
-                dic_best['acc_val'] = dic_err['val']['all']['d']
-                dic_best['best_epoch'] = best_epoch
-                dic_best['random_seed'] = self.r_seed
-                # dic_best['acc_test'] = dic_err['test']['all']['mean']
-
-                best_acc_val = acc_val
-                model_best = model
-
-        # Save model and log
-        now = datetime.datetime.now()
-        now_time = now.strftime("%Y%m%d-%H%M")[2:]
-        self.path_model = self.path_model + now_time + '.pkl'
-        torch.save(model_best.state_dict(), self.path_model)
-        with open(self.path_log + now_time, 'w') as f:
-            json.dump(dic_best, f)
-        end = time.time()
-        print('\n\n\n')
-        self.logger.info(" Tried {} combinations".format(cnt))
-        self.logger.info(" Total time for hyperparameters search: {:.2f} minutes".format((end - start) / 60))
-        self.logger.info(" Best hyperparameters are:")
-        for key, value in dic_best.items():
-            self.logger.info(" {}: {}".format(key, value))
-
-        print()
-        self.logger.info("Accuracy in each cluster:")
-
-        self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val']))
-        self.logger.info("\nSaved the model: {}".format(self.path_model))
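The deleted search drew learning rates log-uniformly between 5e-4 and 1e-2 and shuffled fixed candidate lists for the remaining hyperparameters. A self-contained sketch of just the learning-rate sampling:

```python
import math

import numpy as np

np.random.seed(1)                              # r_seed, as in __init__ above
aa = math.log(0.0005, 10)                      # log10 of the lower bound
bb = math.log(0.01, 10)                        # log10 of the upper bound
log_lr_list = np.random.uniform(aa, bb, 6).tolist()
lr_list = [10 ** xx for xx in log_lr_list]     # six rates in [5e-4, 1e-2]
print(["{:.2e}".format(lr) for lr in lr_list])
```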
@@ -1,367 +0,0 @@
-# pylint: disable=too-many-statements
-
-"""
-Training and evaluation of a neural network that, given 2D joints, estimates:
-- 3D localization and confidence intervals
-- Orientation
-- Bounding box dimensions
-"""
-
-import copy
-import os
-import datetime
-import logging
-from collections import defaultdict
-import sys
-import time
-from itertools import chain
-
-import matplotlib.pyplot as plt
-import torch
-from torch.utils.data import DataLoader
-from torch.optim import lr_scheduler
-
-from .. import __version__
-from .datasets import KeypointsDataset
-from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss
-from ..network import extract_outputs, extract_labels
-from ..network.architectures import LocoModel
-from ..utils import set_logger
-
-
-class CASRTrainer:
-    # Constants
-    VAL_BS = 10000
-
-    tasks = ('cyclist',)
-    val_task = 'cyclist'
-    lambdas = (1,)
-    #clusters = ['10', '20', '30', '40']
-    input_size = 34
-    output_size = 4
-    dir_figures = os.path.join('figures', 'losses')
-
-    def __init__(self, args):
-        """
-        Initialize directories, load the data and parameters for the training
-        """
-
-        assert os.path.exists(args.joints), "Input file not found"
-        self.mode = args.mode
-        self.joints = args.joints
-        self.num_epochs = args.epochs
-        self.no_save = args.no_save
-        self.print_loss = args.print_loss
-        self.lr = args.lr
-        self.sched_step = args.sched_step
-        self.sched_gamma = args.sched_gamma
-        self.hidden_size = args.hidden_size
-        self.n_stage = args.n_stage
-        self.r_seed = args.r_seed
-        self.auto_tune_mtl = args.auto_tune_mtl
-
-        if args.std:
-            self.output_size = 3
-            name = 'casr_standard'
-        else:
-            name = 'casr'
-        # Select path out
-        if args.out:
-            self.path_out = args.out  # full path without extension
-            dir_out, _ = os.path.split(self.path_out)
-        else:
-            dir_out = os.path.join('data', 'outputs')
-            now = datetime.datetime.now()
-            now_time = now.strftime("%Y%m%d-%H%M")[2:]
-            name_out = name + '-' + now_time + '.pkl'
-            self.path_out = os.path.join(dir_out, name_out)
-        assert os.path.exists(dir_out), "Directory to save the model not found"
-        print(self.path_out)
-        # Select the device
-        use_cuda = torch.cuda.is_available()
-        self.device = torch.device("cuda" if use_cuda else "cpu")
-        print('Device: ', self.device)
-        torch.manual_seed(self.r_seed)
-        if use_cuda:
-            torch.cuda.manual_seed(self.r_seed)
-
-        losses_tr, losses_val = CompositeLoss(self.tasks)()
-
-        if self.auto_tune_mtl:
-            self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
-        else:
-            self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
-        self.mt_loss.to(self.device)
-
-        # Dataloader
-        self.dataloaders = {phase: DataLoader(KeypointsDataset(self.joints, phase=phase),
-                                              batch_size=args.bs, shuffle=True) for phase in ['train', 'val']}
-
-        self.dataset_sizes = {phase: len(KeypointsDataset(self.joints, phase=phase))
-                              for phase in ['train', 'val']}
-        self.dataset_version = KeypointsDataset(self.joints, phase='train').get_version()
-
-        self._set_logger(args)
-
-        # Define the model
-        self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes))
-        print(">>> creating model")
-
-        self.model = LocoModel(
-            input_size=self.input_size,
-            output_size=self.output_size,
-            linear_size=args.hidden_size,
-            p_dropout=args.dropout,
-            num_stage=self.n_stage,
-            device=self.device,
-        )
-        self.model.to(self.device)
-        print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0))
-        print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters())))
-
-        # Optimizer and scheduler
-        all_params = chain(self.model.parameters(), self.mt_loss.parameters())
-        self.optimizer = torch.optim.Adam(params=all_params, lr=args.lr)
-        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')
-        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma)
-
-    def train(self):
-        since = time.time()
-        best_model_wts = copy.deepcopy(self.model.state_dict())
-        best_acc = 1e6
-        best_training_acc = 1e6
-        best_epoch = 0
-        epoch_losses = defaultdict(lambda: defaultdict(list))
-        for epoch in range(self.num_epochs):
-            running_loss = defaultdict(lambda: defaultdict(int))
-
-            # Each epoch has a training and validation phase
-            for phase in ['train', 'val']:
-                if phase == 'train':
-                    self.model.train()  # Set model to training mode
-                else:
-                    self.model.eval()  # Set model to evaluate mode
-
-                for inputs, labels, _, _ in self.dataloaders[phase]:
-                    inputs = inputs.to(self.device)
-                    labels = labels.to(self.device)
-                    with torch.set_grad_enabled(phase == 'train'):
-                        if phase == 'train':
-                            self.optimizer.zero_grad()
-                            outputs = self.model(inputs)
-                            loss, _ = self.mt_loss(outputs, labels, phase=phase)
-                            loss.backward()
-                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
-                            self.optimizer.step()
-                            self.scheduler.step()
-
-                        else:
-                            outputs = self.model(inputs)
-                        with torch.no_grad():
-                            loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val')
-                            self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss)
-
-            self.cout_values(epoch, epoch_losses, running_loss)
-
-            # deep copy the model
-
-            if epoch_losses['val'][self.val_task][-1] < best_acc:
-                best_acc = epoch_losses['val'][self.val_task][-1]
-                best_training_acc = epoch_losses['train']['all'][-1]
-                best_epoch = epoch
-                best_model_wts = copy.deepcopy(self.model.state_dict())
-
-        time_elapsed = time.time() - since
-        print('\n\n' + '-' * 120)
-        self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s'
-                         .format(time_elapsed // 60, time_elapsed % 60))
-        self.logger.info('Best training Accuracy: {:.3f}'.format(best_training_acc))
-        self.logger.info('Best validation Accuracy for {}: {:.3f}'.format(self.val_task, best_acc))
-        self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch))
-
-        self._print_losses(epoch_losses)
-
-        # load best model weights
-        self.model.load_state_dict(best_model_wts)
-        return best_epoch
-
-    def epoch_logs(self, phase, loss, loss_values, inputs, running_loss):
-
-        running_loss[phase]['all'] += loss.item() * inputs.size(0)
-        for i, task in enumerate(self.tasks):
-            running_loss[phase][task] += loss_values[i].item() * inputs.size(0)
-
-    def evaluate(self, load=False, model=None, debug=False):
-
-        # To load a model instead of using the trained one
-        if load:
-            self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
-
-        # Average distance on training and test set after unnormalizing
-        self.model.eval()
-        dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0)))  # initialized to zero
-        dic_err['val']['sigmas'] = [0.] * len(self.tasks)
-        dataset = KeypointsDataset(self.joints, phase='val')
-        size_eval = len(dataset)
-        start = 0
-        with torch.no_grad():
-            for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS):
-                end = end if end < size_eval else size_eval
-                inputs, labels, _, _ = dataset[start:end]
-                start = end
-                inputs = inputs.to(self.device)
-                labels = labels.to(self.device)
-
-                # Debug plot for input-output distributions
-                if debug:
-                    debug_plots(inputs, labels)
-                    sys.exit()
-
-                # Forward pass
-                # outputs = self.model(inputs)
-                #self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
-
-        # self.cout_stats(dic_err['val'], size_eval, clst='all')
-        # Evaluate performances on different clusters and save statistics
-
-        # Save the model and the results
-        if not (self.no_save or load):
-            torch.save(self.model.state_dict(), self.path_model)
-            print('-' * 120)
-            self.logger.info("\nmodel saved: {} \n".format(self.path_model))
-        else:
-            self.logger.info("\nmodel not saved\n")
-
-        return dic_err, self.model
-
-    def compute_stats(self, outputs, labels, dic_err, size_eval, clst):
-        """Compute mean, bi and max of torch tensors"""
-
-        _, loss_values = self.mt_loss(outputs, labels, phase='val')
-        rel_frac = outputs.size(0) / size_eval
-
-        tasks = self.tasks  # Exclude auxiliary
-
-        for idx, task in enumerate(tasks):
-            dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval)
-
-        # Distance
-        errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d'])
-        assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation"
-
-        # Uncertainty
-        bis = extract_outputs(outputs)['bi'].cpu()
-        bi = float(torch.mean(bis).item())
-        bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0]
-        dic_err[clst]['bi'] += bi * rel_frac
-        dic_err[clst]['bi%'] += bi_perc * rel_frac
-        dic_err[clst]['std'] = errs.std()
-
-        # (Don't) Save auxiliary task results
-        dic_err['sigmas'].append(0)
-
-        if self.auto_tune_mtl:
-            assert len(loss_values) == 2 * len(self.tasks)
-            for i, _ in enumerate(self.tasks):
-                dic_err['sigmas'][i] += float(loss_values[len(tasks) + i + 1].item()) * rel_frac
-
-    def cout_stats(self, dic_err, size_eval, clst):
-        if clst == 'all':
-            print('-' * 120)
-            self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n"
-                             "X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} "
-                             "\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm"
-                             "\nAuxiliary Task: {:.1f} %, "
-                             .format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
-                                     dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
-                                     dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
-                                     dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100))
-            if self.auto_tune_mtl:
-                self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}"
-                                 " AUX:{:.2f}\n"
-                                 .format(*dic_err['sigmas']))
-        else:
-            self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} "
-                             "Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. "
-                             .format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
-                                     dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
-                                     dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
-                                     dic_err[clst]['l'] * 100, size_eval))
-
-    def cout_values(self, epoch, epoch_losses, running_loss):
-
-        string = '\r' + '{:.0f} '
-        format_list = [epoch]
-        for phase in running_loss:
-            string = string + phase[0:1].upper() + ':'
-            for el in running_loss['train']:
-                loss = running_loss[phase][el] / self.dataset_sizes[phase]
-                epoch_losses[phase][el].append(loss)
-                if el == 'all':
-                    string = string + ':{:.1f} '
-                    format_list.append(loss)
-                elif el in ('ori', 'aux'):
-                    string = string + el + ':{:.1f} '
-                    format_list.append(loss)
-                else:
-                    string = string + el + ':{:.0f} '
-                    format_list.append(loss * 100)
-
-        if epoch % 10 == 0:
-            print(string.format(*format_list))
-
-    def _print_losses(self, epoch_losses):
-        if not self.print_loss:
-            return
-        os.makedirs(self.dir_figures, exist_ok=True)
-        for idx, phase in enumerate(epoch_losses):
-            for idx_2, el in enumerate(epoch_losses['train']):
-                plt.figure(idx + idx_2)
-                plt.title(phase + '_' + el)
-                plt.xlabel('epochs')
-                plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el))
-                plt.savefig(os.path.join(self.dir_figures, '{}_loss_{}.png'.format(phase, el)))
-                plt.close()
-
-    def _set_logger(self, args):
-        if self.no_save:
-            logging.basicConfig(level=logging.INFO)
-            self.logger = logging.getLogger(__name__)
-        else:
-            self.path_model = self.path_out
-            print(self.path_model)
-            self.logger = set_logger(os.path.splitext(self.path_out)[0])  # remove .pkl
-            self.logger.info(  # pylint: disable=logging-fstring-interpolation
-                f'\nVERSION: {__version__}\n'
-                f'\nINPUT_FILE: {args.joints}'
-                f'\nInput file version: {self.dataset_version}'
-                f'\nTorch version: {torch.__version__}\n'
-                f'\nTraining arguments:'
-                f'\nmode: {self.mode} \nlearning rate: {args.lr} \nbatch_size: {args.bs}'
-                f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
-                f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
-                f'\ninput_size: {self.input_size} \noutput_size: {self.output_size} '
-                f'\nhidden_size: {args.hidden_size}'
-                f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}'
-            )
-
-
-def debug_plots(inputs, labels):
-    inputs_shoulder = inputs.cpu().numpy()[:, 5]
-    inputs_hip = inputs.cpu().numpy()[:, 11]
-    labels = labels.cpu().numpy()
-    heights = inputs_hip - inputs_shoulder
-    plt.figure(1)
-    plt.hist(heights, bins='auto')
-    plt.show()
-    plt.figure(2)
-    plt.hist(labels, bins='auto')
-    plt.show()
-
-
-def get_accuracy(outputs, labels):
-    """From Binary cross entropy outputs to accuracy"""
-
-    mask = outputs >= 0.5
-    accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
-    return accuracy
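The module-level `get_accuracy` helper deleted above thresholds post-sigmoid outputs at 0.5 and scores agreement with binary labels. A quick standalone check with concrete tensors:

```python
import torch

outputs = torch.tensor([0.9, 0.2, 0.7, 0.4])  # post-sigmoid scores
labels = torch.tensor([1., 0., 0., 0.])       # binary ground truth
mask = outputs >= 0.5                         # predictions: 1, 0, 1, 0
accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
print(accuracy)                               # 0.75 -- three of four correct
```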