add pytest at package level (#10)

* add training pytests

* add json sample file

* add pytest compatibility

* add compatibility with package test

* add reference for pytorch version
Lorenzo Bertoni 2019-08-07 17:58:34 +02:00 committed by GitHub
parent f23a2e34f5
commit aef978d231
14 changed files with 157 additions and 100 deletions

.gitignore

@@ -7,4 +7,5 @@ Monoloco/*.pyc
 dist/
 build/
 *.egg-info
+tests/*.png


@@ -8,4 +8,4 @@ install:
 - pip install ".[test]"
 script:
 - pylint monoloco --disable=unused-variable,fixme
-- pytest -vv
+- pytest -v
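To reproduce the CI run locally, the same commands shown in the hunk above can be executed from the repository root (this is just the standard pylint/pytest workflow, not something introduced by the commit):

pip install ".[test]"
pylint monoloco --disable=unused-variable,fixme
pytest -v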


@@ -26,7 +26,7 @@ class GenerateKitti:
         # Load monoloco
         use_cuda = torch.cuda.is_available()
         device = torch.device("cuda" if use_cuda else "cpu")
-        self.monoloco = MonoLoco(model_path=model, device=device, n_dropout=n_dropout, p_dropout=p_dropout)
+        self.monoloco = MonoLoco(model=model, device=device, n_dropout=n_dropout, p_dropout=p_dropout)
         self.dir_out = os.path.join('data', 'kitti', 'monoloco')
         self.dir_ann = dir_ann


@@ -2,84 +2,13 @@
 import torch.nn as nn
-class TriLinear(nn.Module):
-    """
-    As Bilinear but without skip connection
-    """
-    def __init__(self, input_size, output_size, p_dropout, linear_size=1024):
-        super(TriLinear, self).__init__()
-        self.input_size = input_size
-        self.output_size = output_size
-        self.l_size = linear_size
-        self.relu = nn.ReLU(inplace=True)
-        self.dropout = nn.Dropout(p_dropout)
-        self.w1 = nn.Linear(self.input_size, self.l_size)
-        self.batch_norm1 = nn.BatchNorm1d(self.l_size)
-        self.w2 = nn.Linear(self.l_size, self.l_size)
-        self.batch_norm2 = nn.BatchNorm1d(self.l_size)
-        self.w3 = nn.Linear(self.l_size, self.output_size)
-    def forward(self, x):
-        y = self.w1(x)
-        y = self.batch_norm1(y)
-        y = self.relu(y)
-        y = self.dropout(y)
-        y = self.w2(y)
-        y = self.batch_norm2(y)
-        y = self.relu(y)
-        y = self.dropout(y)
-        y = self.w3(y)
-        return y
 def weight_init(batch):
     """TO initialize weights using kaiming initialization"""
     if isinstance(batch, nn.Linear):
         nn.init.kaiming_normal_(batch.weight)
-class Linear(nn.Module):
-    def __init__(self, linear_size, p_dropout=0.5):
-        super(Linear, self).__init__()
-        self.l_size = linear_size
-        self.relu = nn.ReLU(inplace=True)
-        self.dropout = nn.Dropout(p_dropout)
-        self.w1 = nn.Linear(self.l_size, self.l_size)
-        self.batch_norm1 = nn.BatchNorm1d(self.l_size)
-        self.w2 = nn.Linear(self.l_size, self.l_size)
-        self.batch_norm2 = nn.BatchNorm1d(self.l_size)
-    def forward(self, x):
-        y = self.w1(x)
-        y = self.batch_norm1(y)
-        y = self.relu(y)
-        y = self.dropout(y)
-        y = self.w2(y)
-        y = self.batch_norm2(y)
-        y = self.relu(y)
-        y = self.dropout(y)
-        out = x + y
-        return out
 class LinearModel(nn.Module):
-    """
-    Architecture inspired by https://github.com/una-dinosauria/3d-pose-baseline
-    Pytorch implementation from: https://github.com/weigq/3d_pose_baseline_pytorch
-    """
+    """Class from Simple yet effective baseline"""
-    def __init__(self, input_size, output_size, linear_size=256, p_dropout=0.2, num_stage=3):
+    def __init__(self, input_size, output_size=2, linear_size=256, p_dropout=0.2, num_stage=3):
         super(LinearModel, self).__init__()
         self.input_size = input_size
@@ -114,3 +43,33 @@ class LinearModel(nn.Module):
             y = self.linear_stages[i](y)
         y = self.w2(y)
         return y
+
+
+class Linear(nn.Module):
+    def __init__(self, linear_size, p_dropout=0.5):
+        super(Linear, self).__init__()
+        self.l_size = linear_size
+        self.relu = nn.ReLU(inplace=True)
+        self.dropout = nn.Dropout(p_dropout)
+        self.w1 = nn.Linear(self.l_size, self.l_size)
+        self.batch_norm1 = nn.BatchNorm1d(self.l_size)
+        self.w2 = nn.Linear(self.l_size, self.l_size)
+        self.batch_norm2 = nn.BatchNorm1d(self.l_size)
+
+    def forward(self, x):
+        y = self.w1(x)
+        y = self.batch_norm1(y)
+        y = self.relu(y)
+        y = self.dropout(y)
+        y = self.w2(y)
+        y = self.batch_norm2(y)
+        y = self.relu(y)
+        y = self.dropout(y)
+        out = x + y
+        return out
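As a side note on the new output_size=2 default, here is a minimal sketch (not part of the commit, assuming LinearModel is imported from this module and follows the usual 3d-pose-baseline layout shown above) of building the network with the constants MonoLoco uses below:

import torch

net = LinearModel(input_size=17 * 2, linear_size=256, p_dropout=0.2)  # output_size now defaults to 2
net.eval()
with torch.no_grad():
    dummy = torch.zeros(1, 17 * 2)   # one flattened set of 17 keypoints, presumably (u, v) pairs
    print(net(dummy).shape)          # expected: torch.Size([1, 2])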


@@ -17,22 +17,28 @@ class MonoLoco:
     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger(__name__)
     OUTPUT_SIZE = 2
     INPUT_SIZE = 17 * 2
     LINEAR_SIZE = 256
     N_SAMPLES = 100
-    def __init__(self, model_path, device, n_dropout=0, p_dropout=0.2):
+    def __init__(self, model, device=None, n_dropout=0, p_dropout=0.2):
-        self.device = device
+        if not device:
+            self.device = torch.device('cpu')
+        else:
+            self.device = device
         self.n_dropout = n_dropout
         self.epistemic = bool(self.n_dropout > 0)
-        # load the model parameters
-        self.model = LinearModel(p_dropout=p_dropout,
-                                 input_size=self.INPUT_SIZE, output_size=self.OUTPUT_SIZE, linear_size=self.LINEAR_SIZE,
-                                 )
-        self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
+        # if the path is provided load the model parameters
+        if isinstance(model, str):
+            model_path = model
+            self.model = LinearModel(p_dropout=p_dropout, input_size=self.INPUT_SIZE, linear_size=self.LINEAR_SIZE)
+            self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
+        # if the model is directly provided
+        else:
+            self.model = model
         self.model.eval()  # Default is train
         self.model.to(self.device)
@@ -63,7 +69,7 @@ class MonoLoco:
         return outputs, varss
     @staticmethod
-    def post_process(outputs, varss, boxes, keypoints, kk, dic_gt, iou_min=0.3):
+    def post_process(outputs, varss, boxes, keypoints, kk, dic_gt=None, iou_min=0.3):
         """Post process monoloco to output final dictionary with all information for visualizations"""
         dic_out = defaultdict(list)
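To make the relaxed constructor concrete, here is a minimal sketch (not from the diff) of the two ways MonoLoco can now be instantiated; the checkpoint path is a placeholder, and trained_model stands in for a network already in memory, such as the one returned by Trainer.evaluate() in the new package test:

import torch
from monoloco.network import MonoLoco

monoloco = MonoLoco(model='data/models/monoloco_checkpoint.pkl')       # placeholder path; device falls back to CPU
monoloco = MonoLoco(model=trained_model, device=torch.device('cuda'))  # trained_model: hypothetical nn.Module object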


@@ -17,7 +17,7 @@ def predict(args):
     # load pifpaf and monoloco models
     pifpaf = PifPaf(args)
-    monoloco = MonoLoco(model_path=args.model, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
+    monoloco = MonoLoco(model=args.model, device=args.device, n_dropout=args.n_dropout, p_dropout=args.dropout)
     # data
     data = ImageList(args.images, scale=args.scale)


@@ -6,9 +6,7 @@ from openpifpaf import decoder
 def cli():
-    parser = argparse.ArgumentParser(
-        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     # Subparser definition
     subparsers = parser.add_subparsers(help='Different parsers for main actions', dest='command')
@@ -22,8 +20,7 @@ def cli():
     prep_parser.add_argument('--dataset',
                              help='datasets to preprocess: nuscenes, nuscenes_teaser, nuscenes_mini, kitti',
                              default='nuscenes')
-    prep_parser.add_argument('--dir_nuscenes', help='directory of nuscenes devkit',
-                             default='data/nuscenes/')
+    prep_parser.add_argument('--dir_nuscenes', help='directory of nuscenes devkit', default='data/nuscenes/')
     prep_parser.add_argument('--iou_min', help='minimum iou to match ground truth', type=float, default=0.3)
     # Predict (2D pose and/or 3D location from images)


@@ -12,6 +12,7 @@ import logging
 from collections import defaultdict
 import sys
 import time
+import warnings
 import matplotlib.pyplot as plt
 import torch
@@ -36,10 +37,11 @@ class Trainer:
         # Initialize directories and parameters
         dir_out = os.path.join('data', 'models')
-        assert os.path.exists(dir_out), "Output directory not found"
+        if not os.path.exists(dir_out):
+            warnings.warn("Warning: output directory not found, the model will not be saved")
         dir_logs = os.path.join('data', 'logs')
         if not os.path.exists(dir_logs):
-            os.makedirs(dir_logs)
+            warnings.warn("Warning: default logs directory not found")
         assert os.path.exists(joints), "Input file not found"
         self.joints = joints


@@ -34,7 +34,7 @@ class Printer:
         self.fig_width = fig_width
         # Define the output dir
-        self.path_out = output_path
+        self.output_path = output_path
         self.cmap = cm.get_cmap('jet')
         self.extensions = []
@@ -54,9 +54,7 @@ class Printer:
         self.zz_pred = [xx[2] if xx[2] < self.z_max - self.stds_epi[idx] else 0
                         for idx, xx in enumerate(dic_ann['xyz_pred'])]
         self.dds_real = dic_ann['dds_real']
-        self.uv_centers = dic_ann['uv_centers']
         self.uv_shoulders = dic_ann['uv_shoulders']
-        self.uv_kps = dic_ann['uv_kps']
         self.boxes = dic_ann['boxes']
         self.boxes_gt = dic_ann['boxes_gt']
@@ -176,7 +174,7 @@ class Printer:
         for idx, fig in enumerate(figures):
             fig.canvas.draw()
             if save:
-                fig.savefig(self.path_out + self.extensions[idx], bbox_inches='tight')
+                fig.savefig(self.output_path + self.extensions[idx], bbox_inches='tight')
             if show:
                 fig.show()


@@ -28,7 +28,7 @@ def webcam(args):
     # load models
     args.camera = True
     pifpaf = PifPaf(args)
-    monoloco = MonoLoco(model_path=args.model, device=args.device)
+    monoloco = MonoLoco(model=args.model, device=args.device)
     # Start recording
     cam = cv2.VideoCapture(0)

File diff suppressed because one or more lines are too long

tests/joints_sample.json (new file; diff suppressed because one or more lines are too long)

tests/test_package.py (new file)

@@ -0,0 +1,69 @@
"""Test if the main modules of the package run correctly"""

import os
import sys
import json

# Python does not consider the current directory to be a package
sys.path.insert(0, os.path.join('..', 'monoloco'))

from PIL import Image

from monoloco.train import Trainer
from monoloco.network import MonoLoco
from monoloco.network.process import preprocess_pifpaf, factory_for_gt
from monoloco.visuals.printer import Printer

JOINTS = 'tests/joints_sample.json'
PIFPAF_KEYPOINTS = 'tests/002282.png.pifpaf.json'
IMAGE = 'docs/002282.png'


def tst_trainer(joints):
    trainer = Trainer(joints=joints, epochs=150, lr=0.01)
    _ = trainer.train()
    dic_err, model = trainer.evaluate()
    return dic_err['val']['all']['mean'], model


def tst_prediction(model, path_keypoints):
    with open(path_keypoints, 'r') as f:
        pifpaf_out = json.load(f)
    kk, _ = factory_for_gt(im_size=[1240, 340])

    # Preprocess pifpaf outputs and run monoloco
    boxes, keypoints = preprocess_pifpaf(pifpaf_out)
    monoloco = MonoLoco(model)
    outputs, varss = monoloco.forward(keypoints, kk)
    dic_out = monoloco.post_process(outputs, varss, boxes, keypoints, kk)
    return dic_out, kk


def tst_printer(dic_out, kk, image_path):
    """Draw a fake figure"""
    with open(image_path, 'rb') as f:
        pil_image = Image.open(f).convert('RGB')
    printer = Printer(image=pil_image, output_path='tests/test_image', kk=kk, output_types=['combined'], z_max=15)
    figures, axes = printer.factory_axes()
    printer.draw(figures, axes, dic_out, pil_image, save=True)


def test_package():
    # Training test
    val_acc, model = tst_trainer(JOINTS)
    assert val_acc < 2

    # Prediction test
    dic_out, kk = tst_prediction(model, PIFPAF_KEYPOINTS)
    assert dic_out['boxes'] and kk

    # Visualization test
    tst_printer(dic_out, kk, IMAGE)
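As a usage note (standard pytest file selection, not something stated by the diff), the new package test can also be run on its own from the repository root:

pytest tests/test_package.py -v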

tests/test_visuals.py (new file)

@@ -0,0 +1,23 @@
import os
import sys
from collections import defaultdict

from PIL import Image

# Python does not consider the current directory to be a package
sys.path.insert(0, os.path.join('..', 'monoloco'))


def test_printer():
    """Draw a fake figure"""
    from monoloco.visuals.printer import Printer

    test_list = [[718.3351, 0., 600.3891], [0., 718.3351, 181.5122], [0., 0., 1.]]
    boxes = [xx + [0] for xx in test_list]
    kk = test_list
    dict_ann = defaultdict(lambda: [1., 2., 3.], xyz_real=test_list, xyz_pred=test_list, uv_shoulders=test_list,
                           boxes=boxes, boxes_gt=boxes)

    with open('docs/002282.png', 'rb') as f:
        pil_image = Image.open(f).convert('RGB')
    printer = Printer(image=pil_image, output_path=None, kk=kk, output_types=['combined'])
    figures, axes = printer.factory_axes()
    printer.draw(figures, axes, dict_ann, pil_image)