monoloco/tests/test_train_stereo.py
Lorenzo Bertoni 934622bc81
Lint (#50)
- Add continuous integration
- Add Versioneer
- Refactor of preprocessing
- Add tables of evaluation
2021-04-22 15:43:51 +02:00

60 lines
1.7 KiB
Python

"""
Adapted from https://github.com/openpifpaf/openpifpaf/blob/main/tests/test_train.py,
which is: 'Copyright 2019-2021 by Sven Kreiss and contributors. All rights reserved.'
and licensed under GNU AGPLv3
"""
import os
import subprocess
import sys

import gdown
# Google Drive link to the OpenPifPaf checkpoint passed to `predict` via
# --checkpoint; the test downloads it into its temporary folder at run time.
OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ'

# Both commands launch monoloco with the interpreter that is running the tests
# (sys.executable) instead of a hard-coded 'python3': that interpreter is the one
# where monoloco is installed, and a 'python3' executable may not exist on Windows.

# Train a stereo model on the sample joints file shipped with the tests.
# The test appends the output path (--out=...).
TRAIN_COMMAND = [
    sys.executable, '-m', 'monoloco.run',
    'train',
    '--mode=stereo',
    '--joints', 'tests/sample_joints-kitti-stereo.json',
    '--lr=0.001',
    '-e=20',
]

# Run stereo prediction on the docs/000840*.png images, writing 'multi' and
# 'json' outputs. The test appends --model, --checkpoint and -o.
PREDICT_COMMAND = [
    sys.executable, '-m', 'monoloco.run',
    'predict',
    '--mode=stereo',
    '--glob', 'docs/000840*.png',
    '--output_types', 'multi', 'json',
    '--decoder-workers=0',  # for Windows
]
def _run_checked(cmd):
    """Run ``cmd`` as a subprocess and fail the test with its output if it errors.

    ``subprocess.run(..., check=True, capture_output=True)`` raises a
    ``CalledProcessError`` whose message does not include the captured
    stdout/stderr, so the child's traceback would be hidden in the test report.
    Here the captured streams are put into the assertion message instead.
    """
    print(' '.join(cmd))
    result = subprocess.run(cmd, capture_output=True, check=False)
    assert result.returncode == 0, (
        'Command exited with status {}: {}\n'
        '--- stdout ---\n{}\n'
        '--- stderr ---\n{}'.format(
            result.returncode,
            ' '.join(cmd),
            result.stdout.decode(errors='replace'),
            result.stderr.decode(errors='replace'),
        )
    )
    return result


def test_train_stereo(tmp_path):
    """Train a stereo model, then run stereo prediction with it.

    End-to-end smoke test: trains into ``tmp_path``, downloads an OpenPifPaf
    checkpoint there, runs ``predict`` with both models, and checks that the
    'multi' output image for docs/000840 was written.
    """
    # train a model
    train_cmd = TRAIN_COMMAND + ['--out={}'.format(os.path.join(tmp_path, 'train_test.pkl'))]
    _run_checked(train_cmd)
    print(os.listdir(tmp_path))

    # find the trained model checkpoint (the only .pkl files so far come from training)
    trained = sorted(f for f in os.listdir(tmp_path) if f.endswith('.pkl'))
    assert trained, 'training did not write any .pkl file to {}'.format(tmp_path)
    final_model = trained[0]

    pifpaf_model = os.path.join(tmp_path, 'pifpaf_model.pkl')
    print('Downloading OpenPifPaf model in temporary folder')
    gdown.download(OPENPIFPAF_MODEL, pifpaf_model)
    # Depending on the gdown version, a failed download may not raise, so check
    # explicitly rather than letting `predict` fail with an unrelated error.
    assert os.path.isfile(pifpaf_model), 'failed to download {}'.format(OPENPIFPAF_MODEL)

    # run predictions with that model
    model = os.path.join(tmp_path, final_model)
    predict_cmd = PREDICT_COMMAND + [
        '--model={}'.format(model),
        '--checkpoint={}'.format(pifpaf_model),
        '-o={}'.format(tmp_path),
    ]
    _run_checked(predict_cmd)
    print(os.listdir(tmp_path))
    assert 'out_000840.png.multi.png' in os.listdir(tmp_path)