93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
|
|
import json
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class ActivityDataset(Dataset):
    """
    Dataset of (input, label) pairs for activity data, read from a JSON file.

    The JSON file maps each phase name ('train', 'val', 'test') to a dict
    holding the inputs under key 'X' and the labels under key 'Y'.
    """

    # Phases accepted by the constructor (top-level keys of the JSON file).
    _PHASES = ('train', 'val', 'test')

    def __init__(self, joints, phase):
        """
        Load the inputs and outputs of one phase from a JSON file.

        :param joints: path to the JSON file with the preprocessed joints
        :param phase: one of 'train', 'val' or 'test'
        :raises ValueError: if ``phase`` is not one of the supported phases
        """
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if phase not in self._PHASES:
            raise ValueError(f"phase must be one of {self._PHASES}, got {phase!r}")

        # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(dic_jo[phase]['X'])
        # Labels are flattened into a single column of shape (m, 1).
        self.outputs_all = torch.tensor(dic_jo[phase]['Y']).view(-1, 1)

    def __len__(self):
        """
        :return: number of samples (m)
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Return the tensors for one sample (or a batch, if ``idx`` is a slice or index tensor).

        :param idx: sample index along the first dimension (m)
        :return: tuple ``(inputs, outputs)``
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        return inputs, outputs
|
|
|
|
|
|
class KeypointsDataset(Dataset):
    """
    Dataset for nuscenes or kitti keypoint data, read from a JSON file.

    The JSON file maps each phase name ('train', 'val', 'test') to a dict
    holding inputs ('X'), outputs ('Y'), sample names ('names'), keypoints
    ('kps') and per-cluster annotations ('clst').
    """

    # Phases accepted by the constructor (top-level keys of the JSON file).
    _PHASES = ('train', 'val', 'test')

    def __init__(self, joints, phase):
        """
        Load inputs, outputs, names, keypoints and cluster annotations of one phase.

        :param joints: path to the JSON file with the preprocessed joints
        :param phase: one of 'train', 'val' or 'test'
        :raises ValueError: if ``phase`` is not one of the supported phases
        """
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if phase not in self._PHASES:
            raise ValueError(f"phase must be one of {self._PHASES}, got {phase!r}")

        # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
        with open(joints, 'r', encoding='utf-8') as f:
            dic_jo = json.load(f)

        # Define input and output for normal training and inference
        self.inputs_all = torch.tensor(dic_jo[phase]['X'])
        self.outputs_all = torch.tensor(dic_jo[phase]['Y'])
        # Names stay a plain Python list (not a tensor).
        self.names_all = dic_jo[phase]['names']
        self.kps_all = torch.tensor(dic_jo[phase]['kps'])

        # Extract annotations divided in clusters
        self.dic_clst = dic_jo[phase]['clst']

    def __len__(self):
        """
        :return: number of samples (m)
        """
        return self.inputs_all.shape[0]

    def __getitem__(self, idx):
        """
        Return the data for one sample.

        :param idx: sample index along the first dimension (m)
        :return: tuple ``(inputs, outputs, names, kps)``
        """
        inputs = self.inputs_all[idx, :]
        outputs = self.outputs_all[idx]
        # names_all is a list, so idx must be an int or slice here (not an index tensor).
        names = self.names_all[idx]
        kps = self.kps_all[idx, :]
        return inputs, outputs, names, kps

    def get_cluster_annotations(self, clst):
        """
        Return the annotations stored for one cluster.

        The original documentation describes these as normalized annotations;
        no normalization happens here, so presumably it is done upstream.

        :param clst: cluster key; keys loaded from JSON are always strings
        :return: tuple ``(inputs, outputs, count)``, where ``outputs`` is cast
            to float and ``count`` is the number of annotations in the cluster
        """
        cluster = self.dic_clst[clst]
        inputs = torch.tensor(cluster['X'])
        outputs = torch.tensor(cluster['Y']).float()
        count = len(cluster['Y'])
        return inputs, outputs, count
|