Simplifying

This commit is contained in:
Charles Joseph Pierre Beauville 2021-06-28 00:30:42 +02:00
parent 0333295edb
commit 5ffc7dd0f6
4 changed files with 37 additions and 39 deletions

View File

@ -87,7 +87,7 @@ def cli():
# Training # Training
training_parser.add_argument('--joints', help='Json file with input joints', required=True) training_parser.add_argument('--joints', help='Json file with input joints', required=True)
training_parser.add_argument('--mode', help='mono, stereo', default='mono') training_parser.add_argument('--mode', help='mono, stereo, casr, casr_std', default='mono')
training_parser.add_argument('--out', help='output_path, e.g., data/outputs/test.pkl') training_parser.add_argument('--out', help='output_path, e.g., data/outputs/test.pkl')
training_parser.add_argument('-e', '--epochs', type=int, help='number of epochs to train for', default=500) training_parser.add_argument('-e', '--epochs', type=int, help='number of epochs to train for', default=500)
training_parser.add_argument('--bs', type=int, default=512, help='input batch size') training_parser.add_argument('--bs', type=int, default=512, help='input batch size')
@ -99,9 +99,6 @@ def cli():
training_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=1024) training_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=1024)
training_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3) training_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true') training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true')
training_parser.add_argument('--casr', help='run casr training', action='store_true')
training_parser.add_argument('--std', help='run casr training with only standard gestures',
action='store_true')
training_parser.add_argument('--multiplier', type=int, help='Size of the grid of hyp search', default=1) training_parser.add_argument('--multiplier', type=int, help='Size of the grid of hyp search', default=1)
training_parser.add_argument('--r_seed', type=int, help='specify the seed for training and hyp tuning', default=1) training_parser.add_argument('--r_seed', type=int, help='specify the seed for training and hyp tuning', default=1)
training_parser.add_argument('--print_loss', help='print training and validation losses', action='store_true') training_parser.add_argument('--print_loss', help='print training and validation losses', action='store_true')
@ -169,22 +166,11 @@ def main():
elif args.command == 'train': elif args.command == 'train':
from .train import HypTuning from .train import HypTuning
if args.hyp: if args.hyp:
if args.casr: hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
from .train import HypTuningCasr monocular=args.monocular, dropout=args.dropout,
hyp_tuning_casr = HypTuningCasr(joints=args.joints, epochs=args.epochs, multiplier=args.multiplier, r_seed=args.r_seed,
monocular=args.monocular, dropout=args.dropout, mode=args.mode)
multiplier=args.multiplier, r_seed=args.r_seed) hyp_tuning.train(args)
hyp_tuning_casr.train(args)
else:
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
monocular=args.monocular, dropout=args.dropout,
multiplier=args.multiplier, r_seed=args.r_seed)
hyp_tuning.train(args)
elif args.casr:
from .train import CASRTrainer
training = CASRTrainer(args)
_ = training.train()
_ = training.evaluate()
else: else:
from .train import Trainer from .train import Trainer
training = Trainer(args) training = Trainer(args)

View File

@ -1,5 +1,3 @@
from .hyp_tuning import HypTuning from .hyp_tuning import HypTuning
from .hyp_tuning_casr import HypTuningCasr
from .trainer import Trainer from .trainer import Trainer
from .trainer_casr import CASRTrainer

View File

@ -15,7 +15,8 @@ from .trainer import Trainer
class HypTuning: class HypTuning:
def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1): def __init__(self, joints, epochs, monocular,
dropout, multiplier=1, r_seed=1, mode=None):
""" """
Initialize directories, load the data and parameters for the training Initialize directories, load the data and parameters for the training
""" """
@ -33,6 +34,9 @@ class HypTuning:
os.makedirs(dir_logs) os.makedirs(dir_logs)
name_out = 'hyp-monoloco-' if monocular else 'hyp-ms-' name_out = 'hyp-monoloco-' if monocular else 'hyp-ms-'
if mode:
name_out = ('hyp-casr-' if mode == 'casr' else
'hyp-casr_std-' if mode == 'casr_std' else name_out)
self.path_log = os.path.join(dir_logs, name_out) self.path_log = os.path.join(dir_logs, name_out)
self.path_model = os.path.join(dir_out, name_out) self.path_model = os.path.join(dir_out, name_out)
@ -120,7 +124,8 @@ class HypTuning:
print() print()
self.logger.info("Accuracy in each cluster:") self.logger.info("Accuracy in each cluster:")
for key in ('10', '20', '30', '>30', 'all'): if args.mode in ['mono', 'stereo']:
self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d'])) for key in ('10', '20', '30', '>30', 'all'):
self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d']))
self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val'])) self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val']))
self.logger.info("\nSaved the model: {}".format(self.path_model)) self.logger.info("\nSaved the model: {}".format(self.path_model))

View File

@ -41,8 +41,8 @@ class Trainer:
val_task = 'd' val_task = 'd'
lambdas = (1, 1, 1, 1, 1, 1, 1, 1) lambdas = (1, 1, 1, 1, 1, 1, 1, 1)
clusters = ['10', '20', '30', '40'] clusters = ['10', '20', '30', '40']
input_size = dict(mono=34, stereo=68) input_size = dict(mono=34, stereo=68, casr=34, casr_std=34)
output_size = dict(mono=9, stereo=10) output_size = dict(mono=9, stereo=10, casr=4, casr_std=3)
dir_figures = os.path.join('figures', 'losses') dir_figures = os.path.join('figures', 'losses')
def __init__(self, args): def __init__(self, args):
@ -63,14 +63,21 @@ class Trainer:
self.n_stage = args.n_stage self.n_stage = args.n_stage
self.r_seed = args.r_seed self.r_seed = args.r_seed
self.auto_tune_mtl = args.auto_tune_mtl self.auto_tune_mtl = args.auto_tune_mtl
self.is_casr = self.mode in ['casr', 'casr_std']
if self.is_casr:
self.tasks = ('cyclist',)
self.val_task = 'cyclist'
self.lambdas = (1,)
# Select path out # Select path out
if args.out: if args.out:
self.path_out = args.out # full path without extension self.path_out = args.out # full path without extension
dir_out, _ = os.path.split(self.path_out) dir_out, _ = os.path.split(self.path_out)
else: else:
dir_out = os.path.join('data', 'outputs') dir_out = os.path.join('data', 'outputs')
name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo' name = ('monoloco_pp' if self.mode == 'mono' else
'monstereo' if self.mode == 'stereo' else
'casr' if self.mode == 'casr' else 'casr_std')
now = datetime.datetime.now() now = datetime.datetime.now()
now_time = now.strftime("%Y%m%d-%H%M")[2:] now_time = now.strftime("%Y%m%d-%H%M")[2:]
name_out = name + '-' + now_time + '.pkl' name_out = name + '-' + now_time + '.pkl'
@ -224,18 +231,20 @@ class Trainer:
# Forward pass # Forward pass
outputs = self.model(inputs) outputs = self.model(inputs)
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all') if not self.is_casr:
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
self.cout_stats(dic_err['val'], size_eval, clst='all') if not self.is_casr:
# Evaluate performances on different clusters and save statistics self.cout_stats(dic_err['val'], size_eval, clst='all')
for clst in self.clusters: # Evaluate performances on different clusters and save statistics
inputs, labels, size_eval = dataset.get_cluster_annotations(clst) for clst in self.clusters:
inputs, labels = inputs.to(self.device), labels.to(self.device) inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
inputs, labels = inputs.to(self.device), labels.to(self.device)
# Forward pass on each cluster # Forward pass on each cluster
outputs = self.model(inputs) outputs = self.model(inputs)
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst) self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
self.cout_stats(dic_err['val'], size_eval, clst=clst) self.cout_stats(dic_err['val'], size_eval, clst=clst)
# Save the model and the results # Save the model and the results
if not (self.no_save or load): if not (self.no_save or load):
@ -274,7 +283,7 @@ class Trainer:
if self.mode == 'mono': if self.mode == 'mono':
dic_err[clst]['aux'] = 0 dic_err[clst]['aux'] = 0
dic_err['sigmas'].append(0) dic_err['sigmas'].append(0)
else: elif not self.is_casr:
acc_aux = get_accuracy(extract_outputs(outputs)['aux'], extract_labels(labels)['aux']) acc_aux = get_accuracy(extract_outputs(outputs)['aux'], extract_labels(labels)['aux'])
dic_err[clst]['aux'] += acc_aux * rel_frac dic_err[clst]['aux'] += acc_aux * rel_frac