fix pylint

This commit is contained in:
lorenzo 2019-05-16 10:30:52 +02:00
parent bf8fbf5234
commit 3243da72cd
5 changed files with 24 additions and 23 deletions

View File

@ -79,8 +79,6 @@ class PreprocessKitti:
for name in self.names_gt:
# Extract ground truth
if name == '004223.txt':
aa = 5
path_gt = os.path.join(self.dir_gt, name)
basename, _ = os.path.splitext(name)
boxes_gt = []
@ -117,7 +115,7 @@ class PreprocessKitti:
with open(os.path.join(self.dir_ann, basename + '.png.pifpaf.json'), 'r') as f:
annotations = json.load(f)
boxes, keypoints = self.preprocess_pif(annotations)
(inputs, xy_kps), (uv_kps, uv_boxes, _, _) = self.get_input_data(boxes, keypoints, kk)
(inputs, _), (uv_kps, uv_boxes, _, _) = self.get_input_data(boxes, keypoints, kk)
except FileNotFoundError:
uv_boxes = []
@ -138,10 +136,10 @@ class PreprocessKitti:
boxes_gt.pop(idx_max)
dds.pop(idx_max)
with open(self.path_joints, 'w') as f:
json.dump(self.dic_jo, f)
with open(os.path.join(self.path_names), 'w') as f:
json.dump(self.dic_names, f)
with open(self.path_joints, 'w') as file:
json.dump(self.dic_jo, file)
with open(os.path.join(self.path_names), 'w') as file:
json.dump(self.dic_names, file)
for phase in ['train', 'val', 'test']:
print("Saved {} annotations for phase {}"
.format(self.dic_cnt[phase], phase))

View File

@ -1,4 +1,6 @@
import numpy as np
"""Extract joints annotations and match with nuScenes ground truths
"""
import os
import sys
import time
@ -7,6 +9,8 @@ import logging
from collections import defaultdict
import datetime
import numpy as np
class PreprocessNuscenes:
"""
@ -48,7 +52,7 @@ class PreprocessNuscenes:
# Initialize dicts to save joints for training
self.dic_jo = {'train': dict(X=[], Y=[], names=[], kps=[], boxes_3d=[], K=[],
clst=defaultdict(lambda: defaultdict(list))),
'val': dict(X=[], Y=[], names=[], kps=[], boxes_3d=[], K=[],
'val': dict(X=[], Y=[], names=[], kps=[], boxes_3d=[], K=[],
clst=defaultdict(lambda: defaultdict(list))),
'test': dict(X=[], Y=[], names=[], kps=[], boxes_3d=[], K=[],
clst=defaultdict(lambda: defaultdict(list)))
@ -73,12 +77,12 @@ class PreprocessNuscenes:
elif dataset == 'nuscenes_teaser':
self.nusc = NuScenes(version='v1.0-trainval', dataroot=dir_nuscenes, verbose=True)
with open("splits/nuscenes_teaser_scenes.txt", "r") as ff:
teaser_scenes = ff.read().splitlines()
with open("splits/nuscenes_teaser_scenes.txt", "r") as file:
teaser_scenes = file.read().splitlines()
self.scenes = self.nusc.scene
self.scenes = [scene for scene in self.scenes if scene['token'] in teaser_scenes]
with open("splits/split_nuscenes_teaser.json", "r") as ff:
dic_split = json.load(ff)
with open("splits/split_nuscenes_teaser.json", "r") as file:
dic_split = json.load(file)
self.split_train = [scene['name'] for scene in self.scenes if scene['token'] in dic_split['train']]
self.split_val = [scene['name'] for scene in self.scenes if scene['token'] in dic_split['val']]
@ -147,11 +151,11 @@ class PreprocessNuscenes:
exists = os.path.isfile(path_pif)
if exists:
with open(path_pif, 'r') as f:
annotations = json.load(f)
with open(path_pif, 'r') as file:
annotations = json.load(file)
boxes, keypoints = self.preprocess_pif(annotations, im_size=None)
(inputs, xy_kps), (uv_kps, uv_boxes, _, _) = self.get_input_data(boxes, keypoints, kk)
(inputs, _), (uv_kps, uv_boxes, _, _) = self.get_input_data(boxes, keypoints, kk)
for ii, box in enumerate(uv_boxes):
idx_max, iou_max = self.get_idx_max(box, boxes_gt)

View File

@ -70,7 +70,7 @@ def cli():
# Training
training_parser.add_argument('--joints', help='Json file with input joints',
default='data/arrays/joints-nuscenes-190507-0852.json')
default='data/arrays/joints-nuscenes_teaser-190513-1846.json')
training_parser.add_argument('--save', help='whether to not save model and log file', action='store_false')
training_parser.add_argument('-e', '--epochs', type=int, help='number of epochs to train for', default=150)
training_parser.add_argument('--bs', type=int, default=256, help='input batch size')

View File

@ -1,16 +1,16 @@
import torch
import torch.nn as nn
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import datetime
import logging
from collections import defaultdict
import json
import sys
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

View File

@ -1,6 +1,5 @@
import os
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse