monoloco/tests/test_train_mono.py
Charles Beauville 8c0ac3c0c5
Better GitHub workflow (#59)
* Update tests.yml

* Renamed test images

* Corrected test

* Fixed README

* Better images names
2021-05-18 10:33:15 +02:00

80 lines
2.5 KiB
Python

"""
Adapted from https://github.com/openpifpaf/openpifpaf/blob/main/tests/test_train.py,
which is: 'Copyright 2019-2021 by Sven Kreiss and contributors. All rights reserved.'
and licensed under GNU AGPLv3
"""
import os
import subprocess
import sys

import gdown
# Google Drive download link of the OpenPifPaf checkpoint used for prediction.
OPENPIFPAF_MODEL = 'https://drive.google.com/uc?id=1b408ockhh29OLAED8Tysd2yGZOo0N_SQ'

# Commands use the interpreter that is running pytest instead of whatever
# ``python3`` resolves to on PATH: on Windows ``python3`` is often absent, and
# inside virtualenvs it may point at an interpreter without monoloco installed.
# Trailing commas guard against accidental implicit string concatenation when
# new arguments are appended.

# Train a MonoLoco model on the sample KITTI joints file (the caller appends --out).
TRAIN_COMMAND = [
    sys.executable, '-m', 'monoloco.run',
    'train',
    '--joints', 'tests/sample_joints-kitti-mono.json',
    '--lr=0.001',
    '-e=10',
]

# Predict on a test image, producing 'multi' and 'json' outputs
# (the caller appends --model, --checkpoint and -o).
PREDICT_COMMAND = [
    sys.executable, '-m', 'monoloco.run',
    'predict',
    'docs/test_002282.png',
    '--output_types', 'multi', 'json',
    '--decoder-workers=0',  # for windows
]

# Predict with the social-distance activity, producing 'front' and 'bird'
# outputs (the caller appends --model, --checkpoint and -o).
PREDICT_COMMAND_SOCIAL_DISTANCE = [
    sys.executable, '-m', 'monoloco.run',
    'predict',
    'docs/test_frame0032.jpg',
    '--activities', 'social_distance',
    '--output_types', 'front', 'bird',
    '--decoder-workers=0',  # for windows
]
def _run(cmd):
    """Run ``cmd`` in a subprocess and fail loudly if it exits non-zero.

    Output is captured to keep the test log short on success. On failure the
    captured stdout/stderr are printed before re-raising, because the
    ``CalledProcessError`` raised by ``check=True`` only reports the command
    and exit code, not what the command printed.

    Args:
        cmd: argument list passed to ``subprocess.run`` (no shell).

    Raises:
        subprocess.CalledProcessError: if the command exits with non-zero status.
    """
    print(' '.join(cmd))
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        # Streams are bytes (no ``text=True``); decode leniently for display.
        for stream in (err.stdout, err.stderr):
            if stream:
                print(stream.decode(errors='replace'))
        raise


def test_train_mono(tmp_path):
    """End-to-end check: train a MonoLoco model, then predict with it.

    Trains on the sample joints file, downloads an OpenPifPaf checkpoint into
    ``tmp_path``, then runs ``predict`` twice (multi/json outputs, and the
    social-distance front/bird outputs). Each run must write the expected
    output files to ``tmp_path``.
    """
    # train a model
    train_cmd = TRAIN_COMMAND + ['--out={}'.format(os.path.join(tmp_path, 'train_test.pkl'))]
    _run(train_cmd)
    print(os.listdir(tmp_path))

    # find the trained model checkpoint; this must happen before the OpenPifPaf
    # checkpoint (also a .pkl) is downloaded into the same folder
    checkpoints = sorted(f for f in os.listdir(tmp_path) if f.endswith('.pkl'))
    assert checkpoints, 'training wrote no .pkl checkpoint to {}'.format(tmp_path)
    model = os.path.join(tmp_path, checkpoints[0])
    print(model)

    # download the OpenPifPaf checkpoint
    pifpaf_model = os.path.join(tmp_path, 'pifpaf_model.pkl')
    print('Downloading OpenPifPaf model in temporary folder')
    gdown.download(OPENPIFPAF_MODEL, pifpaf_model)
    # Some gdown versions return None instead of raising when a download fails.
    # Check the file here so the failure is not reported later as an opaque
    # predict error.
    assert os.path.isfile(pifpaf_model), 'could not download {}'.format(OPENPIFPAF_MODEL)

    # arguments shared by both prediction runs
    model_args = [
        '--model={}'.format(model),
        '--checkpoint={}'.format(pifpaf_model),
        '-o={}'.format(tmp_path),
    ]

    # run predictions with that model: 'multi' image and json outputs
    _run(PREDICT_COMMAND + model_args)
    outputs = os.listdir(tmp_path)
    print(outputs)
    assert 'out_test_002282.png.multi.png' in outputs
    assert 'out_test_002282.png.monoloco.json' in outputs

    # social-distance activity with 'front' and 'bird' outputs
    _run(PREDICT_COMMAND_SOCIAL_DISTANCE + model_args)
    outputs = os.listdir(tmp_path)
    print(outputs)
    assert 'out_test_frame0032.jpg.front.png' in outputs
    assert 'out_test_frame0032.jpg.bird.png' in outputs