monoloco/monoloco/train/datasets.py
Lorenzo Bertoni 8968f3c8a2
Packaging (#6)
* add box visualization

* add box visualization and change thresholds for pif preprocessing

* refactor printer

* change default values

* change confidence definition

* remove redundant function

* add debug plot in preprocessing

* add task error in evaluation

* add horizontal flipping

* add evaluation table

* add evaluation table with verbosity

* add tabulate requirement and command line option verbose

* refactor evaluate

* add task error with mean absolute deviation

* add stereo baseline

* integrate stereo baseline

* refactor factory preprocessing

* add stereo command for evaluation

* fix category bug

* add interquartile range for stereo

* use left tt for translation

* refactor stereo functions

* remove redundant functions

* change names of constants

* add pixel error as function of depth

* fix bug on output directory

* add now time at the moment of saving

* add person sitting category

* remove box in pifpaf predictions

* fix printing name

* add printing of number of matches

* add cyclist category

* fix assertion error

* add travis file

* working eval

* working eval

* change source file

* renaming

* add pylint file

* fix pylint

* fix import

* add pyc files in gitignore

* pylint fix

* pylint fix

* add pytest cache

* update readme

* fix pylint

* fix pylint

* add travis file

* add pylint in pip install

* fix pylint
2019-07-19 15:39:03 +02:00

57 lines
1.7 KiB
Python

import json
import torch
from torch.utils.data import Dataset
class KeypointsDataset(Dataset):
    """
    Dataloader for nuscenes or kitti datasets.

    Wraps one phase ('train', 'val' or 'test') of a preprocessed JSON file and
    serves it as a torch Dataset of (inputs, outputs, names, kps) tuples.
    """

    # Phases accepted by __init__; each must be a top-level key of the JSON file.
    _PHASES = ('train', 'val', 'test')

    def __init__(self, joints, phase):
        """
        Load inputs and outputs of one phase from the JSON file of joints.

        :param joints: path to a JSON file whose ``phase`` entry holds the keys
            'X', 'Y', 'names', 'kps' and 'clst'
        :param phase: one of 'train', 'val', 'test'
        :raises ValueError: if ``phase`` is not one of the supported phases
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``
        # and fails before any file I/O happens.
        if phase not in self._PHASES:
            raise ValueError(
                "phase must be one of {}, got {!r}".format(self._PHASES, phase))

        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)
        dic_phase = dic_jo[phase]

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(dic_phase['X'])
        # Targets are reshaped into a column vector of shape (m, 1)
        self.outputs_all = torch.tensor(dic_phase['Y']).view(-1, 1)
        self.names_all = dic_phase['names']
        self.kps_all = torch.tensor(dic_phase['kps'])

        # Extract annotations divided in clusters (kept as raw JSON lists;
        # converted to tensors on demand in get_cluster_annotations)
        self.dic_clst = dic_phase['clst']

    def __len__(self):
        """
        :return: number of samples (m), i.e. the size of the first dimension of the inputs
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Read the tensors when required, e.g. one element or one batch at a time.

        :param idx: sample index, in the range [0, m)
        :return: tuple (inputs, outputs, names, kps) for that index
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        names = self.names_all[idx]
        kps = self.kps_all[idx, :]
        return inputs, outputs, names, kps

    def get_cluster_annotations(self, clst):
        """
        Return the annotations stored for a certain cluster.

        :param clst: key of the cluster inside the JSON 'clst' entry
        :return: tuple (inputs, outputs, count); outputs are cast to float32 and
            count is the number of targets in the cluster
        """
        cluster = self.dic_clst[clst]
        inputs = torch.tensor(cluster['X'])
        outputs = torch.tensor(cluster['Y']).float()
        count = len(cluster['Y'])
        return inputs, outputs, count