Linting
This commit is contained in:
parent
f302fd5b86
commit
9fe42480c1
@ -1,15 +1,13 @@
|
||||
import pickle
|
||||
import re
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import __version__
|
||||
from .transforms import flip_inputs, flip_labels, height_augmentation
|
||||
from ..network.process import preprocess_monoloco
|
||||
|
||||
gt_path = '/scratch/izar/beauvill/casr/data/annotations/casr_annotation.pickle'
|
||||
@ -26,11 +24,9 @@ def bb_intersection_over_union(boxA, boxB):
|
||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||
return iou
|
||||
|
||||
def match_bboxes(bbox_gt, bbox_pred, IOU_THRESH=1):
|
||||
def match_bboxes(bbox_gt, bbox_pred):
|
||||
n_true = bbox_gt.shape[0]
|
||||
n_pred = bbox_pred.shape[0]
|
||||
MAX_DIST = 1.0
|
||||
MIN_IOU = 0.0
|
||||
|
||||
iou_matrix = np.zeros((n_true, n_pred))
|
||||
for i in range(n_true):
|
||||
@ -47,18 +43,20 @@ def load_gt(path=gt_path):
|
||||
|
||||
def load_res(path=res_path):
|
||||
mono = []
|
||||
for dir in sorted(glob.glob(path), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
for folder in sorted(glob.glob(path), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
|
||||
data_list = []
|
||||
for file in sorted(os.listdir(dir), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
for file in sorted(os.listdir(folder), key=lambda x:float(re.findall(r"(\d+)",x)[0])):
|
||||
if 'json' in file:
|
||||
json_path = os.path.join(dir, file)
|
||||
json_path = os.path.join(folder, file)
|
||||
json_data = json.load(open(json_path))
|
||||
json_data['filename'] = json_path
|
||||
data_list.append(json_data)
|
||||
mono.append(data_list)
|
||||
return mono
|
||||
|
||||
def create_dic(gt=load_gt(), res=load_res()):
|
||||
def create_dic():
|
||||
gt=load_gt()
|
||||
res=load_res()
|
||||
dic_jo = {
|
||||
'train': dict(X=[], Y=[], names=[], kps=[]),
|
||||
'val': dict(X=[], Y=[], names=[], kps=[]),
|
||||
@ -66,14 +64,13 @@ def create_dic(gt=load_gt(), res=load_res()):
|
||||
}
|
||||
split = ['3', '4']
|
||||
for i in range(len(res[:])):
|
||||
for j in range(len(res[i][:])):
|
||||
for j in [x for x in range(len(res[i][:])) if 'boxes' in res[i][x]]:
|
||||
folder = gt[i][j]['video_folder']
|
||||
|
||||
phase = 'val'
|
||||
if folder[7] in split:
|
||||
phase = 'train'
|
||||
|
||||
if('boxes' in res[i][j]):
|
||||
gt_box = gt[i][j]['bbox_gt']
|
||||
|
||||
good_idx = match_bboxes(np.array([standard_bbox(gt_box)]), np.array(res[i][j]['boxes'])[:,:4])
|
||||
|
||||
@ -66,7 +66,6 @@ class HypTuningCasr:
|
||||
|
||||
best_acc_val = 20
|
||||
dic_best = {}
|
||||
dic_err_best = {}
|
||||
start = time.time()
|
||||
cnt = 0
|
||||
for idx, lr in enumerate(self.lr_list):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user