refactor parser
This commit is contained in:
parent
75593fe3e0
commit
3c6ebe22c9
@ -81,7 +81,7 @@ class ActivityEvaluator:
|
||||
extension = '.predictions.json'
|
||||
path_pif = os.path.join(self.dir_ann, basename + extension)
|
||||
annotations = open_annotations(path_pif)
|
||||
kk, _ = factory_for_gt(im_size, verbose=False)
|
||||
kk, _ = factory_for_gt(im_size)
|
||||
|
||||
# Collect corresponding gt files (ys_gt: 1 or 0)
|
||||
boxes_gt, ys_gt = parse_gt_collective(self.dir_data, seq, path_pif)
|
||||
|
||||
@ -32,14 +32,13 @@ class GenerateKitti:
|
||||
def __init__(self, args):
|
||||
|
||||
# Load Network
|
||||
self.net = args.net
|
||||
assert args.net in ('monstereo', 'monoloco_pp'), "net not recognized"
|
||||
|
||||
assert args.mode in ('mono', 'stereo'), "mode not recognized"
|
||||
self.net = 'monstereo' if args.mode == 'mono' else 'monoloco_pp'
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
self.model = Loco(
|
||||
model=args.model,
|
||||
net=args.net,
|
||||
mode=args.mode,
|
||||
device=device,
|
||||
n_dropout=args.n_dropout,
|
||||
p_dropout=args.dropout,
|
||||
@ -60,7 +59,6 @@ class GenerateKitti:
|
||||
|
||||
# Add monocular and stereo baselines (they require monoloco as backbone)
|
||||
if args.baselines:
|
||||
|
||||
# Load MonoLoco
|
||||
self.baselines['mono'] = ['monoloco', 'geometric']
|
||||
self.monoloco = Loco(
|
||||
@ -72,7 +70,7 @@ class GenerateKitti:
|
||||
linear_size=256
|
||||
)
|
||||
# Stereo baselines
|
||||
if args.net == 'monstereo':
|
||||
if args.mode == 'stereo':
|
||||
self.baselines['stereo'] = ['pose', 'reid']
|
||||
self.cnt_disparity = defaultdict(int)
|
||||
self.cnt_no_stereo = 0
|
||||
|
||||
@ -45,11 +45,12 @@ class PreprocessKitti:
|
||||
dic_names = defaultdict(lambda: defaultdict(list))
|
||||
dic_std = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
def __init__(self, dir_ann, iou_min, monocular=False):
|
||||
def __init__(self, dir_ann, mode='mono', iou_min=0.3):
|
||||
|
||||
self.dir_ann = dir_ann
|
||||
self.iou_min = iou_min
|
||||
self.monocular = monocular
|
||||
self.mode = mode
|
||||
assert self.mode in ('mono', 'stereo'), "modality not recognized"
|
||||
self.names_gt = tuple(os.listdir(self.dir_gt))
|
||||
self.dir_kk = os.path.join('data', 'kitti', 'calib')
|
||||
self.list_gt = glob.glob(self.dir_gt + '/*.txt')
|
||||
@ -160,7 +161,7 @@ class PreprocessKitti:
|
||||
lab = ys[idx_gt][:-1]
|
||||
|
||||
# Preprocess MonoLoco++
|
||||
if self.monocular:
|
||||
if self.mode == 'mono':
|
||||
inp = preprocess_monoloco(keypoint, kk).view(-1).tolist()
|
||||
lab = normalize_hwl(lab)
|
||||
if ys[idx_gt][10] < 0.5:
|
||||
@ -270,7 +271,7 @@ class PreprocessKitti:
|
||||
print("Ambiguous instances removed: {}".format(cnt_ambiguous))
|
||||
print("Extra pairs created with horizontal flipping: {}\n".format(cnt_extra_pair))
|
||||
|
||||
if not self.monocular:
|
||||
if self.mode == 'stereo':
|
||||
print('Instances with stereo correspondence: {:.1f}% '.format(100 * cnt_pair / cnt_pair_tot))
|
||||
for phase in ['train', 'val']:
|
||||
cnt = cnt_mono[phase] + cnt_stereo[phase]
|
||||
|
||||
@ -62,6 +62,7 @@ def cli():
|
||||
|
||||
# Preprocess input data
|
||||
prep_parser.add_argument('--dir_ann', help='directory of annotations of 2d joints', required=True)
|
||||
prep_parser.add_argument('--mode', help='mono, stereo', default='mono')
|
||||
prep_parser.add_argument('--dataset',
|
||||
help='datasets to preprocess: nuscenes, nuscenes_teaser, nuscenes_mini, kitti',
|
||||
default='kitti')
|
||||
@ -69,7 +70,6 @@ def cli():
|
||||
prep_parser.add_argument('--iou_min', help='minimum iou to match ground truth', type=float, default=0.3)
|
||||
prep_parser.add_argument('--variance', help='new', action='store_true')
|
||||
prep_parser.add_argument('--activity', help='new', action='store_true')
|
||||
prep_parser.add_argument('--monocular', help='new', action='store_true')
|
||||
|
||||
# Training
|
||||
training_parser.add_argument('--joints', help='Json file with input joints', required=True)
|
||||
@ -132,7 +132,7 @@ def main():
|
||||
prep.run()
|
||||
else:
|
||||
from .prep.prep_kitti import PreprocessKitti
|
||||
prep = PreprocessKitti(args.dir_ann, args.iou_min, args.monocular)
|
||||
prep = PreprocessKitti(args.dir_ann, mode=args.mode, iou_min=args.iou_min)
|
||||
if args.activity:
|
||||
prep.prep_activity()
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user