monoloco/monstereo/train/datasets.py
2020-08-20 11:33:19 +02:00

93 lines
2.7 KiB
Python

import json
import torch
from torch.utils.data import Dataset
class ActivityDataset(Dataset):
    """
    Dataset for the activity task, loaded from a JSON file.

    The file is read once at construction time; the inputs ('X') and the
    targets ('Y') of the requested phase are kept in memory as tensors.
    Keypoints ('kps') are not loaded by this dataset.
    """

    # Splits that may be requested from the JSON file
    _PHASES = ('train', 'val', 'test')

    def __init__(self, joints, phase):
        """
        Load inputs and outputs of one phase from a JSON file.

        :param joints: path of the JSON file holding the data of every phase
        :param phase: split to load, one of 'train', 'val' or 'test'
        :raises ValueError: if phase is not one of the supported splits
        """
        # Explicit check instead of `assert`, which is removed under `python -O`
        if phase not in self._PHASES:
            raise ValueError(
                "phase must be one of {}, got {!r}".format(self._PHASES, phase))
        super().__init__()

        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(dic_jo[phase]['X'])
        # Targets are stored as a column: one row per sample
        self.outputs_all = torch.tensor(dic_jo[phase]['Y']).view(-1, 1)

    def __len__(self):
        """
        :return: number of samples (m)
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Read the tensors when required, e.g. one element or one batch at a time.

        :param idx: index (or indices) along the sample dimension m
        :return: tuple (inputs, outputs) selected at idx
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        return inputs, outputs
class KeypointsDataset(Dataset):
    """
    Dataset for nuscenes or kitti data, loaded from a JSON file.

    The file is read once at construction time. For the requested phase the
    inputs ('X'), outputs ('Y') and keypoints ('kps') are kept as tensors,
    the names ('names') as loaded from JSON, and the per-cluster annotations
    ('clst') as a dictionary queried by `get_cluster_annotations`.
    """

    # Splits that may be requested from the JSON file
    _PHASES = ('train', 'val', 'test')

    def __init__(self, joints, phase):
        """
        Load inputs, outputs, names, keypoints and clusters of one phase.

        :param joints: path of the JSON file holding the data of every phase
        :param phase: split to load, one of 'train', 'val' or 'test'
        :raises ValueError: if phase is not one of the supported splits
        """
        # Explicit check instead of `assert`, which is removed under `python -O`
        if phase not in self._PHASES:
            raise ValueError(
                "phase must be one of {}, got {!r}".format(self._PHASES, phase))
        super().__init__()

        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(dic_jo[phase]['X'])
        self.outputs_all = torch.tensor(dic_jo[phase]['Y'])
        self.names_all = dic_jo[phase]['names']
        self.kps_all = torch.tensor(dic_jo[phase]['kps'])

        # Extract annotations divided in clusters
        self.dic_clst = dic_jo[phase]['clst']

    def __len__(self):
        """
        :return: number of samples (m)
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Read the tensors when required, e.g. one element or one batch at a time.

        :param idx: index along the sample dimension m
        :return: tuple (inputs, outputs, names, kps) selected at idx
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        names = self.names_all[idx]
        kps = self.kps_all[idx, :]
        return inputs, outputs, names, kps

    def get_cluster_annotations(self, clst):
        """
        Return the annotations stored for a given cluster.

        :param clst: key of the cluster in the 'clst' dictionary of the phase
        :return: tuple (inputs, outputs, count) where inputs and outputs are
                 tensors built from the cluster's 'X' and 'Y' entries (outputs
                 cast to float) and count is the number of 'Y' entries
        """
        inputs = torch.tensor(self.dic_clst[clst]['X'])
        outputs = torch.tensor(self.dic_clst[clst]['Y']).float()
        count = len(self.dic_clst[clst]['Y'])
        return inputs, outputs, count