Simplifying
parent 0333295edb
commit 5ffc7dd0f6
@@ -87,7 +87,7 @@ def cli():

     # Training
     training_parser.add_argument('--joints', help='Json file with input joints', required=True)
-    training_parser.add_argument('--mode', help='mono, stereo', default='mono')
+    training_parser.add_argument('--mode', help='mono, stereo, casr, casr_std', default='mono')
     training_parser.add_argument('--out', help='output_path, e.g., data/outputs/test.pkl')
     training_parser.add_argument('-e', '--epochs', type=int, help='number of epochs to train for', default=500)
     training_parser.add_argument('--bs', type=int, default=512, help='input batch size')
@@ -99,9 +99,6 @@ def cli():
     training_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=1024)
     training_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
     training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true')
-    training_parser.add_argument('--casr', help='run casr training', action='store_true')
-    training_parser.add_argument('--std', help='run casr training with only standard gestures',
-                                 action='store_true')
    training_parser.add_argument('--multiplier', type=int, help='Size of the grid of hyp search', default=1)
     training_parser.add_argument('--r_seed', type=int, help='specify the seed for training and hyp tuning', default=1)
     training_parser.add_argument('--print_loss', help='print training and validation losses', action='store_true')
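For orientation, a hedged sketch of the sub-command interface after this change: CASR runs are now selected through the existing --mode flag rather than the removed --casr/--std switches. The argument definitions below are copied from the hunks above, while the top-level parser wiring is an assumption added only to make the snippet runnable.

# Hedged sketch: the sub-parser wiring is assumed, not taken from the commit.
import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
training_parser = subparsers.add_parser('train')
training_parser.add_argument('--joints', help='Json file with input joints', required=True)
training_parser.add_argument('--mode', help='mono, stereo, casr, casr_std', default='mono')
training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true')

# A CASR hyper-parameter search is now requested via --mode (the joints path is a placeholder):
args = parser.parse_args(['train', '--joints', 'joints.json', '--mode', 'casr', '--hyp'])
print(args.command, args.mode, args.hyp)  # train casr True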
@@ -169,22 +166,11 @@ def main():
     elif args.command == 'train':
         from .train import HypTuning
         if args.hyp:
-            if args.casr:
-                from .train import HypTuningCasr
-                hyp_tuning_casr = HypTuningCasr(joints=args.joints, epochs=args.epochs,
-                                                monocular=args.monocular, dropout=args.dropout,
-                                                multiplier=args.multiplier, r_seed=args.r_seed)
-                hyp_tuning_casr.train(args)
-            else:
-                hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
-                                       monocular=args.monocular, dropout=args.dropout,
-                                       multiplier=args.multiplier, r_seed=args.r_seed)
-                hyp_tuning.train(args)
-        elif args.casr:
-            from .train import CASRTrainer
-            training = CASRTrainer(args)
-            _ = training.train()
-            _ = training.evaluate()
+            hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
+                                   monocular=args.monocular, dropout=args.dropout,
+                                   multiplier=args.multiplier, r_seed=args.r_seed,
+                                   mode=args.mode)
+            hyp_tuning.train(args)
         else:
             from .train import Trainer
             training = Trainer(args)
@@ -1,5 +1,3 @@

 from .hyp_tuning import HypTuning
-from .hyp_tuning_casr import HypTuningCasr
 from .trainer import Trainer
-from .trainer_casr import CASRTrainer
@@ -15,7 +15,8 @@ from .trainer import Trainer

 class HypTuning:

-    def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1):
+    def __init__(self, joints, epochs, monocular,
+                 dropout, multiplier=1, r_seed=1, mode=None):
         """
         Initialize directories, load the data and parameters for the training
         """
@@ -33,6 +34,9 @@ class HypTuning:
         os.makedirs(dir_logs)

         name_out = 'hyp-monoloco-' if monocular else 'hyp-ms-'
+        if mode:
+            name_out = ('hyp-casr-' if mode == 'casr' else
+                        'hyp-casr_std-' if mode == 'casr_std' else name_out)

         self.path_log = os.path.join(dir_logs, name_out)
         self.path_model = os.path.join(dir_out, name_out)
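As a hedged illustration (not part of the commit), the chained conditional above leaves the mono/stereo prefixes untouched and only overrides the output name for the two CASR modes:

# Hedged illustration of the prefix selection added above; the loop exists only for demonstration.
monocular = True
for mode in ('mono', 'casr', 'casr_std'):
    name_out = 'hyp-monoloco-' if monocular else 'hyp-ms-'
    if mode:
        name_out = ('hyp-casr-' if mode == 'casr' else
                    'hyp-casr_std-' if mode == 'casr_std' else name_out)
    print(mode, name_out)
# mono hyp-monoloco-
# casr hyp-casr-
# casr_std hyp-casr_std-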
@@ -120,7 +124,8 @@ class HypTuning:
         print()
         self.logger.info("Accuracy in each cluster:")

-        for key in ('10', '20', '30', '>30', 'all'):
-            self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d']))
+        if args.mode in ['mono', 'stereo']:
+            for key in ('10', '20', '30', '>30', 'all'):
+                self.logger.info("Val: error in cluster {} = {} ".format(key, dic_err_best['val'][key]['d']))
         self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val']))
         self.logger.info("\nSaved the model: {}".format(self.path_model))
@@ -41,8 +41,8 @@ class Trainer:
     val_task = 'd'
     lambdas = (1, 1, 1, 1, 1, 1, 1, 1)
     clusters = ['10', '20', '30', '40']
-    input_size = dict(mono=34, stereo=68)
-    output_size = dict(mono=9, stereo=10)
+    input_size = dict(mono=34, stereo=68, casr=34, casr_std=34)
+    output_size = dict(mono=9, stereo=10, casr=4, casr_std=3)
     dir_figures = os.path.join('figures', 'losses')

     def __init__(self, args):
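A hedged aside on the mode-keyed sizes above: the extra entries let the single Trainer class pick its network dimensions from the --mode string, with both CASR variants reusing the 34-value monocular keypoint input and smaller output heads (4 outputs, or 3 for the standard-gestures-only variant). A minimal sketch, with the surrounding Trainer wiring simplified rather than copied from the commit:

# Hedged sketch using the dictionaries from the diff above.
input_size = dict(mono=34, stereo=68, casr=34, casr_std=34)
output_size = dict(mono=9, stereo=10, casr=4, casr_std=3)

mode = 'casr'
is_casr = mode in ['casr', 'casr_std']
print(input_size[mode], output_size[mode], is_casr)  # 34 4 True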
@@ -63,14 +63,21 @@ class Trainer:
         self.n_stage = args.n_stage
         self.r_seed = args.r_seed
         self.auto_tune_mtl = args.auto_tune_mtl
+        self.is_casr = self.mode in ['casr', 'casr_std']

+        if self.is_casr:
+            self.tasks = ('cyclist',)
+            self.val_task = 'cyclist'
+            self.lambdas = (1,)
         # Select path out
         if args.out:
             self.path_out = args.out  # full path without extension
             dir_out, _ = os.path.split(self.path_out)
         else:
             dir_out = os.path.join('data', 'outputs')
-            name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo'
+            name = ('monoloco_pp' if self.mode == 'mono' else
+                    'monstereo' if self.mode == 'stereo' else
+                    'casr' if self.mode == 'casr' else 'casr_std')
             now = datetime.datetime.now()
             now_time = now.strftime("%Y%m%d-%H%M")[2:]
             name_out = name + '-' + now_time + '.pkl'
@@ -224,18 +231,20 @@ class Trainer:

                 # Forward pass
                 outputs = self.model(inputs)
-                self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
+                if not self.is_casr:
+                    self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')

-            self.cout_stats(dic_err['val'], size_eval, clst='all')
-            # Evaluate performances on different clusters and save statistics
-            for clst in self.clusters:
-                inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
-                inputs, labels = inputs.to(self.device), labels.to(self.device)
+            if not self.is_casr:
+                self.cout_stats(dic_err['val'], size_eval, clst='all')
+                # Evaluate performances on different clusters and save statistics
+                for clst in self.clusters:
+                    inputs, labels, size_eval = dataset.get_cluster_annotations(clst)
+                    inputs, labels = inputs.to(self.device), labels.to(self.device)

-                # Forward pass on each cluster
-                outputs = self.model(inputs)
-                self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
-                self.cout_stats(dic_err['val'], size_eval, clst=clst)
+                    # Forward pass on each cluster
+                    outputs = self.model(inputs)
+                    self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst=clst)
+                    self.cout_stats(dic_err['val'], size_eval, clst=clst)

         # Save the model and the results
         if not (self.no_save or load):
@@ -274,7 +283,7 @@ class Trainer:
         if self.mode == 'mono':
             dic_err[clst]['aux'] = 0
             dic_err['sigmas'].append(0)
-        else:
+        elif not self.is_casr:
             acc_aux = get_accuracy(extract_outputs(outputs)['aux'], extract_labels(labels)['aux'])
             dic_err[clst]['aux'] += acc_aux * rel_frac