refactor trainer

commit 75593fe3e0
parent 453e4b7b24
@@ -50,11 +50,11 @@ class Trainer:
             warnings.warn("Warning: default logs directory not found")
         assert os.path.exists(args.joints), "Input file not found"

+        self.mode = args.mode
         self.joints = args.joints
         self.num_epochs = args.epochs
         self.no_save = args.no_save
         self.print_loss = args.print_loss
-        self.monocular = args.monocular
         self.lr = args.lr
         self.sched_step = args.sched_step
         self.sched_gamma = args.sched_gamma
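The hunk above replaces the boolean `args.monocular` with a string-valued `args.mode`. As a rough sketch of the CLI change this implies, the old `--monocular` flag would become a `--mode` choice; the flag names, choices, and defaults below are assumptions for illustration, not taken from this commit.

import argparse

# Hypothetical parser showing the boolean-to-string switch; only --mode and
# --joints are sketched, and their choices/defaults are assumed.
parser = argparse.ArgumentParser(description='train monoloco_pp / monstereo')
parser.add_argument('--mode', choices=('mono', 'stereo'), default='stereo',
                    help='training mode; replaces the former --monocular flag')
parser.add_argument('--joints', required=True, help='path to the input joints file')
args = parser.parse_args()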
@@ -73,7 +73,7 @@ class Trainer:
         torch.cuda.manual_seed(self.r_seed)

         # Remove auxiliary task if monocular
-        if self.monocular and self.tasks[-1] == 'aux':
+        if self.mode == 'mono' and self.tasks[-1] == 'aux':
             self.tasks = self.tasks[:-1]
             self.lambdas = self.lambdas[:-1]

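To see what the rewritten condition does, here is a minimal, self-contained rerun of the trimming logic with made-up task names and weights (the real `tasks` and `lambdas` come from the Trainer configuration, not from this sketch):

# Stand-in values; only the list trimming mirrors the hunk above.
tasks = ['loc', 'ori', 'aux']
lambdas = [1.0, 0.5, 0.1]
mode = 'mono'

if mode == 'mono' and tasks[-1] == 'aux':
    tasks = tasks[:-1]        # drop the auxiliary task...
    lambdas = lambdas[:-1]    # ...and its loss weight, keeping both lists aligned

print(tasks, lambdas)         # ['loc', 'ori'] [1.0, 0.5]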
@@ -85,28 +85,28 @@ class Trainer:
         self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
         self.mt_loss.to(self.device)

-        if not self.monocular:
+        if self.mode == 'stereo':
             input_size = 68
             output_size = 10
         else:
             input_size = 34
             output_size = 9

-        name = 'monoloco_pp' if self.monocular else 'monstereo'
+        name = 'monoloco_pp' if self.mode == 'mono' else 'monstereo'
         now = datetime.datetime.now()
         now_time = now.strftime("%Y%m%d-%H%M")[2:]
         name_out = name + '-' + now_time
         if not self.no_save:
             self.path_model = os.path.join(dir_out, name_out + '.pkl')
             self.logger = set_logger(os.path.join(dir_logs, name_out))
-            self.logger.info("Training arguments: \nepochs: {} \nbatch_size: {} \ndropout: {}"
-                             "\nmonocular: {} \nlearning rate: {} \nscheduler step: {} \nscheduler gamma: {} "
-                             "\ninput_size: {} \noutput_size: {}\nhidden_size: {} \nn_stages: {} "
-                             "\nr_seed: {} \nlambdas: {} \ninput_file: {}"
-                             .format(args.epochs, args.bs, args.dropout, self.monocular,
-                                     args.lr, args.sched_step, args.sched_gamma, input_size,
-                                     output_size, args.hidden_size, args.n_stage, args.r_seed,
-                                     self.lambdas, self.joints))
+            self.logger.info(
+                f'Training arguments: \ninput_file: {self.joints} \nmode: {self.mode} '
+                f'\nlearning rate: {args.lr} \nbatch_size: {args.bs} '
+                f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
+                f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
+                f'\ninput_size: {input_size} \noutput_size: {output_size} \nhidden_size: {args.hidden_size} '
+                f'\nn_stages: {args.n_stage} \nr_seed: {args.r_seed} \nlambdas: {self.lambdas}'
+            )
         else:
             logging.basicConfig(level=logging.INFO)
             self.logger = logging.getLogger(__name__)
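The size selection above reads as a small mode-to-dimensions mapping: stereo inputs are twice as wide as mono ones, and stereo has one extra output. The interpretation in the comments below (two concatenated 17-keypoint views) is an assumption based on the usual COCO skeleton, not something stated in this commit:

# Sketch of the dimension logic, with an assumed reading of where 68/34 come
# from (2 views x 17 keypoints x 2 coordinates vs. 1 view).
def network_sizes(mode: str) -> tuple:
    if mode == 'stereo':
        return 68, 10     # input_size, output_size
    return 34, 9

assert network_sizes('stereo') == (68, 10)
assert network_sizes('mono') == (34, 9)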
@@ -276,7 +276,7 @@ class Trainer:
             dic_err[clst]['std'] = errs.std()

             # (Don't) Save auxiliary task results
-            if self.monocular:
+            if self.mode == 'mono':
                 dic_err[clst]['aux'] = 0
                 dic_err['sigmas'].append(0)
             else:
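For context on the last hunk: in mono mode the auxiliary task does not exist, so its error and sigma entries are written as zero placeholders, which lets downstream reporting index the same keys in both modes. The `dic_err` layout below is a plausible stand-in, not copied from the repository:

# Hypothetical minimal dic_err, mirroring only the placeholder logic above.
dic_err = {'all': {}, 'sigmas': []}
clst, mode = 'all', 'mono'

if mode == 'mono':
    dic_err[clst]['aux'] = 0      # no auxiliary error in mono mode
    dic_err['sigmas'].append(0)   # keep the sigma list length consistent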