Simplifying
This commit is contained in:
parent
0333295edb
commit
5ffc7dd0f6
@ -87,7 +87,7 @@ def cli():
|
||||
|
||||
# Training
|
||||
training_parser.add_argument('--joints', help='Json file with input joints', required=True)
|
||||
training_parser.add_argument('--mode', help='mono, stereo', default='mono')
|
||||
training_parser.add_argument('--mode', help='mono, stereo, casr, casr_std', default='mono')
|
||||
training_parser.add_argument('--out', help='output_path, e.g., data/outputs/test.pkl')
|
||||
training_parser.add_argument('-e', '--epochs', type=int, help='number of epochs to train for', default=500)
|
||||
training_parser.add_argument('--bs', type=int, default=512, help='input batch size')
|
||||
@ -99,9 +99,6 @@ def cli():
|
||||
training_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=1024)
|
||||
training_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
|
||||
training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true')
|
||||
training_parser.add_argument('--casr', help='run casr training', action='store_true')
|
||||
training_parser.add_argument('--std', help='run casr training with only standard gestures',
|
||||
action='store_true')
|
||||
training_parser.add_argument('--multiplier', type=int, help='Size of the grid of hyp search', default=1)
|
||||
training_parser.add_argument('--r_seed', type=int, help='specify the seed for training and hyp tuning', default=1)
|
||||
training_parser.add_argument('--print_loss', help='print training and validation losses', action='store_true')
|
||||
@ -169,22 +166,11 @@ def main():
|
||||
elif args.command == 'train':
|
||||
from .train import HypTuning
|
||||
if args.hyp:
|
||||
if args.casr:
|
||||
from .train import HypTuningCasr
|
||||
hyp_tuning_casr = HypTuningCasr(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||
hyp_tuning_casr.train(args)
|
||||
else:
|
||||
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||
hyp_tuning.train(args)
|
||||
elif args.casr:
|
||||
from .train import CASRTrainer
|
||||
training = CASRTrainer(args)
|
||||
_ = training.train()
|
||||
_ = training.evaluate()
|
||||
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed,
|
||||
mode=args.mode)
|
||||
hyp_tuning.train(args)
|
||||
else:
|
||||
from .train import Trainer
|
||||
training = Trainer(args)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
from .hyp_tuning import HypTuning
|
||||
from .hyp_tuning_casr import HypTuningCasr
|
||||
from .trainer import Trainer
|
||||
from .trainer_casr import CASRTrainer
|
||||
|
||||
@ -15,7 +15,8 @@ from .trainer import Trainer
|
||||
|
||||
class HypTuning:
|
||||
|
||||
def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1):
|
||||
def __init__(self, joints, epochs, monocular,
|
||||
dropout, multiplier=1, r_seed=1, mode=None):
|
||||
"""
|
||||
Initialize directories, load the data and parameters for the training
|
||||
"""
|
||||
@ -33,6 +34,9 @@ class HypTuning:
|
||||
os.makedirs(dir_logs)
|
||||
|
||||
name_out = 'hyp-monoloco-' if monocular else 'hyp-ms-'
|
||||
if mode:
|
||||
name_out = ('hyp-casr-' if mode == 'casr' else
|
||||
'hyp-casr_std-' if mode == 'casr_std' else name_out)
|
||||
|
||||
self.path_log = os.path.join(dir_logs, name_out)
|
||||
self.path_model = os.path.join(dir_out, name_out)
|
||||
@ -120,7 +124,8 @@ class HypTuning:
|
||||
print()
|
||||
self.logger.info("Accuracy in each cluster:")
|
||||
|
||||
for key in ('10', '20', '30', '>30', 'all'):
|
||||
self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d']))
|
||||
if args.mode in ['mono', 'stereo']:
|
||||
for key in ('10', '20', '30', '>30', 'all'):
|
||||
self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d']))
|
||||
self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val']))
|
||||
self.logger.info("\nSaved the model: {}".format(self.path_model))
|
||||
|
||||
@ -41,8 +41,8 @@ class Trainer:
|
||||
val_task = 'd'
|
||||
lambdas = (1, 1, 1, 1, 1, 1, 1, 1)
|
||||
clusters = ['10', '20', '30', '40']
|
||||
input_size = dict(mono=34, stereo=68)
|
||||
output_size = dict(mono=9, stereo=10)
|
||||
input_size = dict(mono=34, stereo=68, casr=34, casr_std=34)
|
||||
output_size = dict(mono=9, stereo=10, casr=4, casr_std=3)
|
||||
dir_figures = os.path.join('figures', 'losses')
|
||||
|
||||
def __init__(self, args):
|
||||
@ -63,14 +63,21 @@ class Trainer:
|
||||
self.n_stage = args.n_stage
|
||||
self.r_seed = args.r_seed
|
||||
self.auto_tune_mtl = args.auto_tune_mtl
|
||||
self.is_casr = self.mode in ['casr', 'casr_std']
|
||||
|
||||
if self.is_casr:
|
||||
self.tasks = ('cyclist',)
|
||||
self.val_task = 'cyclist'
|
||||
self.lambdas = (1,)
|
||||
# Select path out
|
||||
if args.out:
|
||||
self.path_out = args.out # full path without extension
|
||||
dir_out, _ = os.path.split(self.path_out)
|
||||
else:
|
||||
dir_out = os.path.join('data', 'outputs')
|
||||
name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo'
|
||||
name = ('monoloco_pp' if self.mode == 'mono' else
|
||||
'monstereo' if self.mode == 'stereo' else
|
||||
'casr' if self.mode == 'casr' else 'casr_std')
|
||||
now = datetime.datetime.now()
|
||||
now_time = now.strftime("%Y%m%d-%H%M")[2:]
|
||||
name_out = name + '-' + now_time + '.pkl'
|
||||
@ -224,18 +231,20 @@ class Trainer:
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
|
||||
if not self.is_casr:
|
||||
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
|
||||
|
||||
self.cout_stats(dic_err['val'], size_eval, clst='all')
|
||||
# Evaluate performances on different clusters and save statistics
|
||||
for clst in self.clusters:
|
||||
inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
|
||||
inputs, labels = inputs.to(self.device), labels.to(self.device)
|
||||
if not self.is_casr:
|
||||
self.cout_stats(dic_err['val'], size_eval, clst='all')
|
||||
# Evaluate performances on different clusters and save statistics
|
||||
for clst in self.clusters:
|
||||
inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
|
||||
inputs, labels = inputs.to(self.device), labels.to(self.device)
|
||||
|
||||
# Forward pass on each cluster
|
||||
outputs = self.model(inputs)
|
||||
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
|
||||
self.cout_stats(dic_err['val'], size_eval, clst=clst)
|
||||
# Forward pass on each cluster
|
||||
outputs = self.model(inputs)
|
||||
self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
|
||||
self.cout_stats(dic_err['val'], size_eval, clst=clst)
|
||||
|
||||
# Save the model and the results
|
||||
if not (self.no_save or load):
|
||||
@ -274,7 +283,7 @@ class Trainer:
|
||||
if self.mode == 'mono':
|
||||
dic_err[clst]['aux'] = 0
|
||||
dic_err['sigmas'].append(0)
|
||||
else:
|
||||
elif not self.is_casr:
|
||||
acc_aux = get_accuracy(extract_outputs(outputs)['aux'], extract_labels(labels)['aux'])
|
||||
dic_err[clst]['aux'] += acc_aux * rel_frac
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user