* add box visualization * add box visualization and change thresholds for pif preprocessing * refactor printer * change default values * change confidence definition * remove redundant function * add debug plot in preprocessing * add task error in evaluation * add horizontal flipping * add evaluation table * add evaluation table with verbosity * add tabulate requirement and command line option verbose * refactor evaluate * add task error with mean absolute deviation * add stereo baseline * integrate stereo baseline * refactor factory preprocessing * add stereo command for evaluation * fix category bug * add interquartile range for stereo * use left tt for translation * refactor stereo functions * remvove redundant functions * change names of constants * add pixel error as function of depth * fix bug on output directory * add now time at the moment of saving * add person sitting category * remove box in pifpaf predictions * fix printing name * add printing of number of matches * add cyclist category * fix assertion error * add travis file * working eval * working eval * change source file * renaming * add pylint file * fix pylint * fix import * add pyc files in gitignore * pylint fix * pylint fix * add pytest cache * update readme * fix pylint * fix pylint * add travis file * add pylint in pip install * fix pylint
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
|
|
import json
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class KeypointsDataset(Dataset):
    """
    Dataloader for nuScenes or KITTI datasets.

    Reads a JSON file of pre-processed keypoint annotations and exposes one
    split ('train', 'val' or 'test') of it as a torch ``Dataset``.
    """

    def __init__(self, joints, phase):
        """
        Load inputs, outputs, names and keypoints of one split from a JSON file.

        :param joints: path to a JSON file containing a top-level entry for
            ``phase``; that entry must provide the keys 'X', 'Y', 'names',
            'kps' and 'clst'.
        :param phase: one of 'train', 'val' or 'test'
        :raises AssertionError: if ``phase`` is not one of the known splits
        """
        assert phase in ('train', 'val', 'test'), \
            "phase must be 'train', 'val' or 'test', got {!r}".format(phase)

        # JSON text is UTF-8 by specification; do not depend on the platform's
        # default encoding (e.g. cp1252 on Windows) when decoding it.
        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)
        split = dic_jo[phase]

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(split['X'])
        # Cast targets to float32 (as get_cluster_annotations does) so that
        # integer-valued targets in the JSON do not produce an int64 tensor,
        # then reshape them into a single column: shape (N, 1).
        self.outputs_all = torch.tensor(split['Y']).float().view(-1, 1)
        self.names_all = split['names']
        self.kps_all = torch.tensor(split['kps'])

        # Extract annotations divided in clusters
        self.dic_clst = split['clst']

    def __len__(self):
        """
        :return: number of samples (m), i.e. the first dimension of the inputs
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Return the sample at position ``idx``.

        :param idx: index of the sample along the first dimension (m)
        :return: tuple (inputs, outputs, names, kps) for that sample
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        names = self.names_all[idx]
        kps = self.kps_all[idx, :]

        return inputs, outputs, names, kps

    def get_cluster_annotations(self, clst):
        """
        Return the annotations stored for cluster ``clst``.

        :param clst: key of the cluster inside the split's 'clst' entry
            (NOTE(review): JSON object keys are strings, so callers presumably
            pass the cluster id as a str — confirm against callers)
        :return: tuple (inputs, outputs, count) where ``outputs`` is a float32
            tensor and ``count`` is the number of targets in the cluster
        """
        cluster = self.dic_clst[clst]
        inputs = torch.tensor(cluster['X'])
        outputs = torch.tensor(cluster['Y']).float()
        count = len(cluster['Y'])

        return inputs, outputs, count
|