diff --git a/monoloco/prep/casr_preprocess.py b/monoloco/prep/casr_preprocess.py index 4cc9cfc..eab5f91 100644 --- a/monoloco/prep/casr_preprocess.py +++ b/monoloco/prep/casr_preprocess.py @@ -38,12 +38,12 @@ def match_bboxes(bbox_gt, bbox_pred): def standard_bbox(bbox): return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]] -def load_gt(path=gt_path): - return pickle.load(open(path, 'rb'), encoding='latin1') +def load_gt(): + return pickle.load(open(gt_path, 'rb'), encoding='latin1') -def load_res(path=res_path): +def load_res(): mono = [] - for folder in sorted(glob.glob(path), key=lambda x:float(re.findall(r"(\d+)",x)[0])): + for folder in sorted(glob.glob(res_path), key=lambda x:float(re.findall(r"(\d+)",x)[0])): data_list = [] for file in sorted(os.listdir(folder), key=lambda x:float(re.findall(r"(\d+)",x)[0])): if 'json' in file: @@ -54,7 +54,7 @@ def load_res(path=res_path): mono.append(data_list) return mono -def create_dic(): +def create_dic(std=False): gt=load_gt() res=load_res() dic_jo = { @@ -63,7 +63,14 @@ def create_dic(): 'version': __version__, } split = ['3', '4'] - for i in range(len(res[:])): + if std: + wrong = [6, 8, 9, 10, 11, 12, 14, 21, 40, 43, 55, 70, 76, 92, 109, + 110, 112, 113, 121, 123, 124, 127, 128, 134, 136, 139, 165, 173] + mode = 'std' + else: + wrong = [] + mode = '' + for i in [x for x in range(len(res[:])) if x not in wrong]: for j in [x for x in range(len(res[i][:])) if 'boxes' in res[i][x]]: folder = gt[i][j]['video_folder'] @@ -77,13 +84,16 @@ def create_dic(): keypoints = [res[i][j]['uv_kps'][good_idx]] + gt_turn = gt[i][j]['left_or_right'] + if std and gt_turn == 3: + gt_turn = 2 inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist() dic_jo[phase]['kps'].append(keypoints) dic_jo[phase]['X'].append(inp) - dic_jo[phase]['Y'].append(gt[i][j]['left_or_right']) + dic_jo[phase]['Y'].append(gt_turn) dic_jo[phase]['names'].append(folder+"_frame{}".format(j)) now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:] - with open("/home/beauvill/joints-casr-right-" + split[0] + split[1] + "-" + now_time + ".json", 'w') as file: + with open("/home/beauvill/joints-casr-" + mode + "-right-" + split[0] + split[1] + "-" + now_time + ".json", 'w') as file: json.dump(dic_jo, file) return dic_jo diff --git a/monoloco/prep/casr_preprocess_standard.py b/monoloco/prep/casr_preprocess_standard.py deleted file mode 100644 index 9716617..0000000 --- a/monoloco/prep/casr_preprocess_standard.py +++ /dev/null @@ -1,95 +0,0 @@ -import pickle -import re -import json -import os -import glob -import datetime -import numpy as np -import torch - -from .. import __version__ -from ..network.process import preprocess_monoloco - -gt_path = '/scratch/izar/beauvill/casr/data/annotations/casr_annotation.pickle' -res_path = '/scratch/izar/beauvill/casr/res_extended/casr*' - -def bb_intersection_over_union(boxA, boxB): - xA = max(boxA[0], boxB[0]) - yA = max(boxA[1], boxB[1]) - xB = min(boxA[2], boxB[2]) - yB = min(boxA[3], boxB[3]) - interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) - boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) - boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) - iou = interArea / float(boxAArea + boxBArea - interArea) - return iou - -def match_bboxes(bbox_gt, bbox_pred): - n_true = bbox_gt.shape[0] - n_pred = bbox_pred.shape[0] - - iou_matrix = np.zeros((n_true, n_pred)) - for i in range(n_true): - for j in range(n_pred): - iou_matrix[i, j] = bb_intersection_over_union(bbox_gt[i,:], bbox_pred[j,:]) - - return np.argmax(iou_matrix) - -def standard_bbox(bbox): - return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]] - -def load_gt(): - return pickle.load(open(gt_path, 'rb'), encoding='latin1') - -def load_res(): - mono = [] - for folder in sorted(glob.glob(res_path), key=lambda x:float(re.findall(r"(\d+)",x)[0])): - data_list = [] - for file in sorted(os.listdir(folder), key=lambda x:float(re.findall(r"(\d+)",x)[0])): - if 'json' in file: - json_path = os.path.join(folder, file) - json_data = json.load(open(json_path)) - json_data['filename'] = json_path - data_list.append(json_data) - mono.append(data_list) - return mono - -def create_dic_std(): - gt=load_gt() - res=load_res() - dic_jo = { - 'train': dict(X=[], Y=[], names=[], kps=[]), - 'val': dict(X=[], Y=[], names=[], kps=[]), - 'version': __version__, - } - wrong = [6, 8, 9, 10, 11, 12, 14, 21, 40, 43, 55, 70, 76, 92, 109, - 110, 112, 113, 121, 123, 124, 127, 128, 134, 136, 139, 165, 173] - for i in [x for x in range(len(res[:])) if x not in wrong]: - for j in range(len(res[i][:])): - phase = 'val' - if (j % 10) > 1: - phase = 'train' - - folder = gt[i][j]['video_folder'] - - if('boxes' in res[i][j] and gt[i][j]['left_or_right'] != 2): - gt_box = gt[i][j]['bbox_gt'] - - good_idx = match_bboxes(np.array([standard_bbox(gt_box)]), np.array(res[i][j]['boxes'])[:,:4]) - - keypoints = [res[i][j]['uv_kps'][good_idx]] - - gt_turn = gt[i][j]['left_or_right'] - if gt_turn == 3: - gt_turn = 2 - - inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist() - dic_jo[phase]['kps'].append(keypoints) - dic_jo[phase]['X'].append(inp) - dic_jo[phase]['Y'].append(gt_turn) - dic_jo[phase]['names'].append(folder+"_frame{}".format(j)) - - now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:] - with open("/home/beauvill/joints-casr-std-" + now_time + ".json", 'w') as file: - json.dump(dic_jo, file) - return dic_jo \ No newline at end of file diff --git a/monoloco/run.py b/monoloco/run.py index ab39a7a..d4079ce 100644 --- a/monoloco/run.py +++ b/monoloco/run.py @@ -151,10 +151,10 @@ def main(): prep.run() elif 'casr' in args.dataset: from .prep.casr_preprocess import create_dic - create_dic() + create_dic(std=False) elif 'casr_std' in args.dataset: - from .prep.casr_preprocess_standard import create_dic_std - create_dic_std() + from .prep.casr_preprocess import create_dic + create_dic(std=True) else: from .prep.preprocess_kitti import PreprocessKitti prep = PreprocessKitti(args.dir_ann, mode=args.mode, iou_min=args.iou_min) diff --git a/monoloco/train/hyp_tuning_casr.py b/monoloco/train/hyp_tuning_casr.py deleted file mode 100644 index 9c97650..0000000 --- a/monoloco/train/hyp_tuning_casr.py +++ /dev/null @@ -1,122 +0,0 @@ - -import math -import os -import json -import time -import logging -import random -import datetime - -import torch -import numpy as np - -from .trainer_casr import CASRTrainer - - -class HypTuningCasr: - - def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1): - """ - Initialize directories, load the data and parameters for the training - """ - - # Initialize Directories - self.joints = joints - self.monocular = monocular - self.dropout = dropout - self.num_epochs = epochs - self.r_seed = r_seed - dir_out = os.path.join('data', 'models') - dir_logs = os.path.join('data', 'logs') - assert os.path.exists(dir_out), "Output directory not found" - if not os.path.exists(dir_logs): - os.makedirs(dir_logs) - - name_out = 'hyp-casr-' - - self.path_log = os.path.join(dir_logs, name_out) - self.path_model = os.path.join(dir_out, name_out) - - logging.basicConfig(level=logging.INFO) - self.logger = logging.getLogger(__name__) - - # Initialize grid of parameters - random.seed(r_seed) - np.random.seed(r_seed) - self.sched_gamma_list = [0.8, 0.9, 1, 0.8, 0.9, 1] * multiplier - random.shuffle(self.sched_gamma_list) - self.sched_step = [10, 20, 40, 60, 80, 100] * multiplier - random.shuffle(self.sched_step) - self.bs_list = [64, 128, 256, 512, 512, 1024] * multiplier - random.shuffle(self.bs_list) - self.hidden_list = [512, 1024, 2048, 512, 1024, 2048] * multiplier - random.shuffle(self.hidden_list) - self.n_stage_list = [3, 3, 3, 3, 3, 3] * multiplier - random.shuffle(self.n_stage_list) - # Learning rate - aa = math.log(0.0005, 10) - bb = math.log(0.01, 10) - log_lr_list = np.random.uniform(aa, bb, int(6 * multiplier)).tolist() - self.lr_list = [10 ** xx for xx in log_lr_list] - # plt.hist(self.lr_list, bins=50) - # plt.show() - - def train(self, args): - """Train multiple times using log-space random search""" - - best_acc_val = 20 - dic_best = {} - start = time.time() - cnt = 0 - for idx, lr in enumerate(self.lr_list): - bs = self.bs_list[idx] - sched_gamma = self.sched_gamma_list[idx] - sched_step = self.sched_step[idx] - hidden_size = self.hidden_list[idx] - n_stage = self.n_stage_list[idx] - - training = CASRTrainer(args) - - best_epoch = training.train() - dic_err, model = training.evaluate() - acc_val = dic_err['val']['all']['mean'] - cnt += 1 - print("Combination number: {}".format(cnt)) - - if acc_val < best_acc_val: - dic_best['lr'] = lr - dic_best['joints'] = self.joints - dic_best['bs'] = bs - dic_best['monocular'] = self.monocular - dic_best['sched_gamma'] = sched_gamma - dic_best['sched_step'] = sched_step - dic_best['hidden_size'] = hidden_size - dic_best['n_stage'] = n_stage - dic_best['acc_val'] = dic_err['val']['all']['d'] - dic_best['best_epoch'] = best_epoch - dic_best['random_seed'] = self.r_seed - # dic_best['acc_test'] = dic_err['test']['all']['mean'] - - best_acc_val = acc_val - model_best = model - - # Save model and log - now = datetime.datetime.now() - now_time = now.strftime("%Y%m%d-%H%M")[2:] - self.path_model = self.path_model + now_time + '.pkl' - torch.save(model_best.state_dict(), self.path_model) - with open(self.path_log + now_time, 'w') as f: - json.dump(dic_best, f) - end = time.time() - print('\n\n\n') - self.logger.info(" Tried {} combinations".format(cnt)) - self.logger.info(" Total time for hyperparameters search: {:.2f} minutes".format((end - start) / 60)) - self.logger.info(" Best hyperparameters are:") - for key, value in dic_best.items(): - self.logger.info(" {}: {}".format(key, value)) - - print() - self.logger.info("Accuracy in each cluster:") - - self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val'])) - self.logger.info("\nSaved the model: {}".format(self.path_model)) diff --git a/monoloco/train/trainer_casr.py b/monoloco/train/trainer_casr.py deleted file mode 100644 index 6012635..0000000 --- a/monoloco/train/trainer_casr.py +++ /dev/null @@ -1,367 +0,0 @@ -# pylint: disable=too-many-statements - -""" -Training and evaluation of a neural network that, given 2D joints, estimates: -- 3D localization and confidence intervals -- Orientation -- Bounding box dimensions -""" - -import copy -import os -import datetime -import logging -from collections import defaultdict -import sys -import time -from itertools import chain - -import matplotlib.pyplot as plt -import torch -from torch.utils.data import DataLoader -from torch.optim import lr_scheduler - -from .. import __version__ -from .datasets import KeypointsDataset -from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss -from ..network import extract_outputs, extract_labels -from ..network.architectures import LocoModel -from ..utils import set_logger - - -class CASRTrainer: - # Constants - VAL_BS = 10000 - - tasks = ('cyclist',) - val_task = 'cyclist' - lambdas = (1,) - #clusters = ['10', '20', '30', '40'] - input_size = 34 - output_size = 4 - dir_figures = os.path.join('figures', 'losses') - - def __init__(self, args): - """ - Initialize directories, load the data and parameters for the training - """ - - assert os.path.exists(args.joints), "Input file not found" - self.mode = args.mode - self.joints = args.joints - self.num_epochs = args.epochs - self.no_save = args.no_save - self.print_loss = args.print_loss - self.lr = args.lr - self.sched_step = args.sched_step - self.sched_gamma = args.sched_gamma - self.hidden_size = args.hidden_size - self.n_stage = args.n_stage - self.r_seed = args.r_seed - self.auto_tune_mtl = args.auto_tune_mtl - - if args.std: - self.output_size = 3 - name = 'casr_standard' - else: - name = 'casr' - # Select path out - if args.out: - self.path_out = args.out # full path without extension - dir_out, _ = os.path.split(self.path_out) - else: - dir_out = os.path.join('data', 'outputs') - now = datetime.datetime.now() - now_time = now.strftime("%Y%m%d-%H%M")[2:] - name_out = name + '-' + now_time + '.pkl' - self.path_out = os.path.join(dir_out, name_out) - assert os.path.exists(dir_out), "Directory to save the model not found" - print(self.path_out) - # Select the device - use_cuda = torch.cuda.is_available() - self.device = torch.device("cuda" if use_cuda else "cpu") - print('Device: ', self.device) - torch.manual_seed(self.r_seed) - if use_cuda: - torch.cuda.manual_seed(self.r_seed) - - losses_tr, losses_val = CompositeLoss(self.tasks)() - - if self.auto_tune_mtl: - self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks) - else: - self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks) - self.mt_loss.to(self.device) - - # Dataloader - self.dataloaders = {phase: DataLoader(KeypointsDataset(self.joints, phase=phase), - batch_size=args.bs, shuffle=True) for phase in ['train', 'val']} - - self.dataset_sizes = {phase: len(KeypointsDataset(self.joints, phase=phase)) - for phase in ['train', 'val']} - self.dataset_version = KeypointsDataset(self.joints, phase='train').get_version() - - self._set_logger(args) - - # Define the model - self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes)) - print(">>> creating model") - - self.model = LocoModel( - input_size=self.input_size, - output_size=self.output_size, - linear_size=args.hidden_size, - p_dropout=args.dropout, - num_stage=self.n_stage, - device=self.device, - ) - self.model.to(self.device) - print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0)) - print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters()))) - - # Optimizer and scheduler - all_params = chain(self.model.parameters(), self.mt_loss.parameters()) - self.optimizer = torch.optim.Adam(params=all_params, lr=args.lr) - self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min') - self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma) - - def train(self): - since = time.time() - best_model_wts = copy.deepcopy(self.model.state_dict()) - best_acc = 1e6 - best_training_acc = 1e6 - best_epoch = 0 - epoch_losses = defaultdict(lambda: defaultdict(list)) - for epoch in range(self.num_epochs): - running_loss = defaultdict(lambda: defaultdict(int)) - - # Each epoch has a training and validation phase - for phase in ['train', 'val']: - if phase == 'train': - self.model.train() # Set model to training mode - else: - self.model.eval() # Set model to evaluate mode - - for inputs, labels, _, _ in self.dataloaders[phase]: - inputs = inputs.to(self.device) - labels = labels.to(self.device) - with torch.set_grad_enabled(phase == 'train'): - if phase == 'train': - self.optimizer.zero_grad() - outputs = self.model(inputs) - loss, _ = self.mt_loss(outputs, labels, phase=phase) - loss.backward() - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3) - self.optimizer.step() - self.scheduler.step() - - else: - outputs = self.model(inputs) - with torch.no_grad(): - loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val') - self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss) - - self.cout_values(epoch, epoch_losses, running_loss) - - # deep copy the model - - if epoch_losses['val'][self.val_task][-1] < best_acc: - best_acc = epoch_losses['val'][self.val_task][-1] - best_training_acc = epoch_losses['train']['all'][-1] - best_epoch = epoch - best_model_wts = copy.deepcopy(self.model.state_dict()) - - time_elapsed = time.time() - since - print('\n\n' + '-' * 120) - self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s' - .format(time_elapsed // 60, time_elapsed % 60)) - self.logger.info('Best training Accuracy: {:.3f}'.format(best_training_acc)) - self.logger.info('Best validation Accuracy for {}: {:.3f}'.format(self.val_task, best_acc)) - self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch)) - - self._print_losses(epoch_losses) - - # load best model weights - self.model.load_state_dict(best_model_wts) - return best_epoch - - def epoch_logs(self, phase, loss, loss_values, inputs, running_loss): - - running_loss[phase]['all'] += loss.item() * inputs.size(0) - for i, task in enumerate(self.tasks): - running_loss[phase][task] += loss_values[i].item() * inputs.size(0) - - def evaluate(self, load=False, model=None, debug=False): - - # To load a model instead of using the trained one - if load: - self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage)) - - # Average distance on training and test set after unnormalizing - self.model.eval() - dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) # initialized to zero - dic_err['val']['sigmas'] = [0.] * len(self.tasks) - dataset = KeypointsDataset(self.joints, phase='val') - size_eval = len(dataset) - start = 0 - with torch.no_grad(): - for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS): - end = end if end < size_eval else size_eval - inputs, labels, _, _ = dataset[start:end] - start = end - inputs = inputs.to(self.device) - labels = labels.to(self.device) - - # Debug plot for input-output distributions - if debug: - debug_plots(inputs, labels) - sys.exit() - - # Forward pass - # outputs = self.model(inputs) - #self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all') - - # self.cout_stats(dic_err['val'], size_eval, clst='all') - # Evaluate performances on different clusters and save statistics - - # Save the model and the results - if not (self.no_save or load): - torch.save(self.model.state_dict(), self.path_model) - print('-' * 120) - self.logger.info("\nmodel saved: {} \n".format(self.path_model)) - else: - self.logger.info("\nmodel not saved\n") - - return dic_err, self.model - - def compute_stats(self, outputs, labels, dic_err, size_eval, clst): - """Compute mean, bi and max of torch tensors""" - - _, loss_values = self.mt_loss(outputs, labels, phase='val') - rel_frac = outputs.size(0) / size_eval - - tasks = self.tasks # Exclude auxiliary - - for idx, task in enumerate(tasks): - dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval) - - # Distance - errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d']) - assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation" - - # Uncertainty - bis = extract_outputs(outputs)['bi'].cpu() - bi = float(torch.mean(bis).item()) - bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0] - dic_err[clst]['bi'] += bi * rel_frac - dic_err[clst]['bi%'] += bi_perc * rel_frac - dic_err[clst]['std'] = errs.std() - - # (Don't) Save auxiliary task results - dic_err['sigmas'].append(0) - - if self.auto_tune_mtl: - assert len(loss_values) == 2 * len(self.tasks) - for i, _ in enumerate(self.tasks): - dic_err['sigmas'][i] += float(loss_values[len(tasks) + i + 1].item()) * rel_frac - - def cout_stats(self, dic_err, size_eval, clst): - if clst == 'all': - print('-' * 120) - self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n" - "X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} " - "\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm" - "\nAuxiliary Task: {:.1f} %, " - .format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100, - dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100, - dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100, - dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100)) - if self.auto_tune_mtl: - self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}" - " AUX:{:.2f}\n" - .format(*dic_err['sigmas'])) - else: - self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} " - "Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. " - .format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100, - dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100, - dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100, - dic_err[clst]['l'] * 100, size_eval)) - - def cout_values(self, epoch, epoch_losses, running_loss): - - string = '\r' + '{:.0f} ' - format_list = [epoch] - for phase in running_loss: - string = string + phase[0:1].upper() + ':' - for el in running_loss['train']: - loss = running_loss[phase][el] / self.dataset_sizes[phase] - epoch_losses[phase][el].append(loss) - if el == 'all': - string = string + ':{:.1f} ' - format_list.append(loss) - elif el in ('ori', 'aux'): - string = string + el + ':{:.1f} ' - format_list.append(loss) - else: - string = string + el + ':{:.0f} ' - format_list.append(loss * 100) - - if epoch % 10 == 0: - print(string.format(*format_list)) - - def _print_losses(self, epoch_losses): - if not self.print_loss: - return - os.makedirs(self.dir_figures, exist_ok=True) - for idx, phase in enumerate(epoch_losses): - for idx_2, el in enumerate(epoch_losses['train']): - plt.figure(idx + idx_2) - plt.title(phase + '_' + el) - plt.xlabel('epochs') - plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el)) - plt.savefig(os.path.join(self.dir_figures, '{}_loss_{}.png'.format(phase, el))) - plt.close() - - def _set_logger(self, args): - if self.no_save: - logging.basicConfig(level=logging.INFO) - self.logger = logging.getLogger(__name__) - else: - self.path_model = self.path_out - print(self.path_model) - self.logger = set_logger(os.path.splitext(self.path_out)[0]) # remove .pkl - self.logger.info( # pylint: disable=logging-fstring-interpolation - f'\nVERSION: {__version__}\n' - f'\nINPUT_FILE: {args.joints}' - f'\nInput file version: {self.dataset_version}' - f'\nTorch version: {torch.__version__}\n' - f'\nTraining arguments:' - f'\nmode: {self.mode} \nlearning rate: {args.lr} \nbatch_size: {args.bs}' - f'\nepochs: {args.epochs} \ndropout: {args.dropout} ' - f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} ' - f'\ninput_size: {self.input_size} \noutput_size: {self.output_size} ' - f'\nhidden_size: {args.hidden_size}' - f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}' - ) - - -def debug_plots(inputs, labels): - inputs_shoulder = inputs.cpu().numpy()[:, 5] - inputs_hip = inputs.cpu().numpy()[:, 11] - labels = labels.cpu().numpy() - heights = inputs_hip - inputs_shoulder - plt.figure(1) - plt.hist(heights, bins='auto') - plt.show() - plt.figure(2) - plt.hist(labels, bins='auto') - plt.show() - - -def get_accuracy(outputs, labels): - """From Binary cross entropy outputs to accuracy""" - - mask = outputs >= 0.5 - accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item() - return accuracy