refactor trainer

This commit is contained in:
Lorenzo 2021-03-23 08:31:00 +01:00
parent 453e4b7b24
commit 75593fe3e0

View File

@ -50,11 +50,11 @@ class Trainer:
warnings.warn("Warning: default logs directory not found")
assert os.path.exists(args.joints), "Input file not found"
self.mode = args.mode
self.joints = args.joints
self.num_epochs = args.epochs
self.no_save = args.no_save
self.print_loss = args.print_loss
self.monocular = args.monocular
self.lr = args.lr
self.sched_step = args.sched_step
self.sched_gamma = args.sched_gamma
@ -73,7 +73,7 @@ class Trainer:
torch.cuda.manual_seed(self.r_seed)
# Remove auxiliary task if monocular
if self.monocular and self.tasks[-1] == 'aux':
if self.mode == 'mono' and self.tasks[-1] == 'aux':
self.tasks = self.tasks[:-1]
self.lambdas = self.lambdas[:-1]
@ -85,28 +85,28 @@ class Trainer:
self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
self.mt_loss.to(self.device)
if not self.monocular:
if not self.mode == 'stereo':
input_size = 68
output_size = 10
else:
input_size = 34
output_size = 9
name = 'monoloco_pp' if self.monocular else 'monstereo'
name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo'
now = datetime.datetime.now()
now_time = now.strftime("%Y%m%d-%H%M")[2:]
name_out = name + '-' + now_time
if not self.no_save:
self.path_model = os.path.join(dir_out, name_out + '.pkl')
self.logger = set_logger(os.path.join(dir_logs, name_out))
self.logger.info("Training arguments: \nepochs: {} \nbatch_size: {} \ndropout: {}"
"\nmonocular: {} \nlearning rate: {} \nscheduler step: {} \nscheduler gamma: {} "
"\ninput_size: {} \noutput_size: {}\nhidden_size: {} \nn_stages: {} "
"\nr_seed: {} \nlambdas: {} \ninput_file: {}"
.format(args.epochs, args.bs, args.dropout, self.monocular,
args.lr, args.sched_step, args.sched_gamma, input_size,
output_size, args.hidden_size, args.n_stage, args.r_seed,
self.lambdas, self.joints))
self.logger.info(
f'Training arguments: \ninput_file: {self.joints} \nmode: {self.mode} '
f'\nlearning rate: {args.lr} \nbatch_size: {args.bs}'
f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
f'\ninput_size: {input_size} \noutput_size: {output_size} \nhidden_size: {args.hidden_size}'
f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}'
)
else:
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
@ -276,7 +276,7 @@ class Trainer:
dic_err[clst]['std'] = errs.std()
# (Don't) Save auxiliary task results
if self.monocular:
if self.mode == 'mono':
dic_err[clst]['aux'] = 0
dic_err['sigmas'].append(0)
else: