From 75593fe3e0f76ca4cd973309fd40a6866b7fd1b0 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Tue, 23 Mar 2021 08:31:00 +0100 Subject: [PATCH] refactor trainer --- monoloco/train/trainer.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/monoloco/train/trainer.py b/monoloco/train/trainer.py index cdbb632..088bd8a 100644 --- a/monoloco/train/trainer.py +++ b/monoloco/train/trainer.py @@ -50,11 +50,11 @@ class Trainer: warnings.warn("Warning: default logs directory not found") assert os.path.exists(args.joints), "Input file not found" + self.mode = args.mode self.joints = args.joints self.num_epochs = args.epochs self.no_save = args.no_save self.print_loss = args.print_loss - self.monocular = args.monocular self.lr = args.lr self.sched_step = args.sched_step self.sched_gamma = args.sched_gamma @@ -73,7 +73,7 @@ class Trainer: torch.cuda.manual_seed(self.r_seed) # Remove auxiliary task if monocular - if self.monocular and self.tasks[-1] == 'aux': + if self.mode == 'mono' and self.tasks[-1] == 'aux': self.tasks = self.tasks[:-1] self.lambdas = self.lambdas[:-1] @@ -85,28 +85,28 @@ class Trainer: self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks) self.mt_loss.to(self.device) - if not self.monocular: + if not self.mode == 'stereo': input_size = 68 output_size = 10 else: input_size = 34 output_size = 9 - name = 'monoloco_pp' if self.monocular else 'monstereo' + name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo' now = datetime.datetime.now() now_time = now.strftime("%Y%m%d-%H%M")[2:] name_out = name + '-' + now_time if not self.no_save: self.path_model = os.path.join(dir_out, name_out + '.pkl') self.logger = set_logger(os.path.join(dir_logs, name_out)) - self.logger.info("Training arguments: \nepochs: {} \nbatch_size: {} \ndropout: {}" - "\nmonocular: {} \nlearning rate: {} \nscheduler step: {} \nscheduler gamma: {} " - "\ninput_size: {} \noutput_size: {}\nhidden_size: {} \nn_stages: {} " - "\nr_seed: {} \nlambdas: {} \ninput_file: {}" - .format(args.epochs, args.bs, args.dropout, self.monocular, - args.lr, args.sched_step, args.sched_gamma, input_size, - output_size, args.hidden_size, args.n_stage, args.r_seed, - self.lambdas, self.joints)) + self.logger.info( + f'Training arguments: \ninput_file: {self.joints} \nmode: {self.mode} ' + f'\nlearning rate: {args.lr} \nbatch_size: {args.bs}' + f'\nepochs: {args.epochs} \ndropout: {args.dropout} ' + f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} ' + f'\ninput_size: {input_size} \noutput_size: {output_size} \nhidden_size: {args.hidden_size}' + f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}' + ) else: logging.basicConfig(level=logging.INFO) self.logger = logging.getLogger(__name__) @@ -276,7 +276,7 @@ class Trainer: dic_err[clst]['std'] = errs.std() # (Don't) Save auxiliary task results - if self.monocular: + if self.mode == 'mono': dic_err[clst]['aux'] = 0 dic_err['sigmas'].append(0) else: