Cyclist intention recognition
This commit is contained in:
parent
a8da927658
commit
f2271229f6
@ -68,6 +68,139 @@ def social_interactions(idx, centers, angles, dds, stds=None, social_distance=Fa
|
||||
return False
|
||||
|
||||
|
||||
def is_turning(kp):
|
||||
"""
|
||||
Returns flag if a cyclist is turning
|
||||
"""
|
||||
x=0
|
||||
y=1
|
||||
|
||||
nose = 0
|
||||
l_ear = 3
|
||||
r_ear = 4
|
||||
l_shoulder = 5
|
||||
l_elbow = 7
|
||||
l_hand = 9
|
||||
r_shoulder = 6
|
||||
r_elbow = 8
|
||||
r_hand = 10
|
||||
|
||||
head_width = kp[x][l_ear]- kp[x][r_ear]
|
||||
head_top = (kp[y][nose] - head_width)
|
||||
|
||||
l_forearm = [kp[x][l_hand] - kp[x][l_elbow], kp[y][l_hand] - kp[y][l_elbow]]
|
||||
l_arm = [kp[x][l_shoulder] - kp[x][l_elbow], kp[y][l_shoulder] - kp[y][l_elbow]]
|
||||
|
||||
r_forearm = [kp[x][r_hand] - kp[x][r_elbow], kp[y][r_hand] - kp[y][r_elbow]]
|
||||
r_arm = [kp[x][r_shoulder] - kp[x][r_elbow], kp[y][r_shoulder] - kp[y][r_elbow]]
|
||||
|
||||
l_angle = (90/np.pi) * np.arccos(np.dot(l_forearm/np.linalg.norm(l_forearm), l_arm/np.linalg.norm(l_arm)))
|
||||
r_angle = (90/np.pi) * np.arccos(np.dot(r_forearm/np.linalg.norm(r_forearm), r_arm/np.linalg.norm(r_arm)))
|
||||
|
||||
if kp[x][l_shoulder] > kp[x][r_shoulder]:
|
||||
is_left = kp[x][l_hand] > kp[x][l_shoulder] + np.linalg.norm(l_arm)
|
||||
is_right = kp[x][r_hand] < kp[x][r_shoulder] - np.linalg.norm(r_arm)
|
||||
l_too_close = kp[x][l_hand] > kp[x][l_shoulder] and kp[y][l_hand]>=head_top
|
||||
r_too_close = kp[x][r_hand] < kp[x][r_shoulder] and kp[y][r_hand]>=head_top
|
||||
else:
|
||||
is_left = kp[x][l_hand] < kp[x][l_shoulder] - np.linalg.norm(l_arm)
|
||||
is_right = kp[x][r_hand] > kp[x][r_shoulder] + np.linalg.norm(r_arm)
|
||||
l_too_close = kp[x][l_hand] <= kp[x][l_shoulder] and kp[y][l_hand]>=head_top
|
||||
r_too_close = kp[x][r_hand] >= kp[x][r_shoulder] and kp[y][r_hand]>=head_top
|
||||
|
||||
|
||||
is_l_up = kp[y][l_hand] < kp[y][l_shoulder]
|
||||
is_r_up = kp[y][r_hand] < kp[y][r_shoulder]
|
||||
|
||||
is_l_down = kp[y][l_hand] > kp[y][l_elbow]
|
||||
is_r_down = kp[y][r_hand] > kp[y][r_elbow]
|
||||
|
||||
is_left_risen = is_l_up and l_angle >= 30 and not l_too_close
|
||||
is_right_risen = is_r_up and r_angle >= 30 and not r_too_close
|
||||
|
||||
is_left_down = is_l_up and l_angle >= 30 and not l_too_close
|
||||
is_right_down = is_r_up and r_angle >= 30 and not r_too_close
|
||||
|
||||
if is_left and l_angle >= 40 and not(is_left_risen or is_right_risen):
|
||||
return 'left'
|
||||
|
||||
if is_right and r_angle >= 40 or (is_left_risen or is_right_risen):
|
||||
return 'right'
|
||||
|
||||
if is_left_down or is_right_down:
|
||||
return 'stop'
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_phoning(kp):
|
||||
"""
|
||||
Returns flag of alert if someone is using their phone
|
||||
"""
|
||||
x=0
|
||||
y=1
|
||||
|
||||
nose = 0
|
||||
l_ear = 3
|
||||
l_shoulder = 5
|
||||
l_elbow = 7
|
||||
l_hand = 9
|
||||
r_ear = 4
|
||||
r_shoulder = 6
|
||||
r_elbow = 8
|
||||
r_hand = 10
|
||||
|
||||
head_width = kp[x][l_ear]- kp[x][r_ear]
|
||||
head_top = (kp[y][nose] - head_width)
|
||||
|
||||
l_forearm = [kp[x][l_hand] - kp[x][l_elbow], kp[y][l_hand] - kp[y][l_elbow]]
|
||||
l_arm = [kp[x][l_shoulder] - kp[x][l_elbow], kp[y][l_shoulder] - kp[y][l_elbow]]
|
||||
|
||||
r_forearm = [kp[x][r_hand] - kp[x][r_elbow], kp[y][r_hand] - kp[y][r_elbow]]
|
||||
r_arm = [kp[x][r_shoulder] - kp[x][r_elbow], kp[y][r_shoulder] - kp[y][r_elbow]]
|
||||
|
||||
l_angle = (90/np.pi) * np.arccos(np.dot(l_forearm/np.linalg.norm(l_forearm), l_arm/np.linalg.norm(l_arm)))
|
||||
r_angle = (90/np.pi) * np.arccos(np.dot(r_forearm/np.linalg.norm(r_forearm), r_arm/np.linalg.norm(r_arm)))
|
||||
|
||||
is_l_up = kp[y][l_hand] < kp[y][l_shoulder]
|
||||
is_r_up = kp[y][r_hand] < kp[y][r_shoulder]
|
||||
|
||||
l_too_close = kp[x][l_hand] <= kp[x][l_shoulder] and kp[y][l_hand]>=head_top
|
||||
r_too_close = kp[x][r_hand] >= kp[x][r_shoulder] and kp[y][r_hand]>=head_top
|
||||
|
||||
is_left_phone = is_l_up and l_angle <= 30 and l_too_close
|
||||
is_right_phone = is_r_up and r_angle <= 30 and r_too_close
|
||||
|
||||
print("Top of head y is :", head_top)
|
||||
print("Nose height :", kp[y][nose])
|
||||
print("Right elbow x: {} and y: {}".format(kp[x][r_elbow], kp[y][r_elbow]))
|
||||
print("Left elbow x: {} and y: {}".format(kp[x][l_elbow], kp[y][l_elbow]))
|
||||
|
||||
print("Right shoulder height :", kp[y][r_shoulder])
|
||||
print("Left shoulder height :", kp[y][l_shoulder])
|
||||
|
||||
print("Left hand x = ", kp[x][l_hand])
|
||||
print("Left hand y = ", kp[y][l_hand])
|
||||
|
||||
print("Is left hand up : ", is_l_up)
|
||||
|
||||
print("Right hand x = ", kp[x][r_hand])
|
||||
print("Right hand y = ", kp[y][r_hand])
|
||||
|
||||
print("Is right hand up : ", is_r_up)
|
||||
|
||||
print("Left arm angle :", l_angle)
|
||||
print("Right arm angle :", r_angle)
|
||||
|
||||
print("Is left hand close to head :", l_too_close)
|
||||
print("Is right hand close to head:", r_too_close)
|
||||
|
||||
if is_left_phone or is_right_phone:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_raising_hand(kp):
|
||||
"""
|
||||
Returns flag of alert if someone raises their hand
|
||||
@ -175,6 +308,8 @@ def show_activities(args, image_t, output_path, annotations, dic_out):
|
||||
if 'social_distance' in args.activities:
|
||||
colors = social_distance_colors(colors, dic_out)
|
||||
|
||||
print("Size of the image :", image_t.size)
|
||||
|
||||
angles = dic_out['angles']
|
||||
stds = dic_out['stds_ale']
|
||||
xz_centers = [[xx[0], xx[2]] for xx in dic_out['xyz_pred']]
|
||||
@ -190,7 +325,6 @@ def show_activities(args, image_t, output_path, annotations, dic_out):
|
||||
r_h = 'none'
|
||||
if 'raise_hand' in args.activities:
|
||||
r_h = dic_out['raising_hand']
|
||||
print("RAISE_HAND :", r_h)
|
||||
|
||||
with image_canvas(image_t,
|
||||
output_path + '.front.png',
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
|
||||
from .net import Loco
|
||||
from .process import unnormalize_bi, extract_outputs, extract_labels, extract_labels_aux
|
||||
from .process import unnormalize_bi, extract_outputs, extract_labels, extract_labels_aux, extract_labels_cyclist
|
||||
|
||||
@ -16,7 +16,7 @@ from ..utils import get_iou_matches, reorder_matches, get_keypoints, pixel_to_ca
|
||||
mask_joint_disparity
|
||||
from .process import preprocess_monstereo, preprocess_monoloco, extract_outputs, extract_outputs_mono,\
|
||||
filter_outputs, cluster_outputs, unnormalize_bi, laplace_sampling
|
||||
from ..activity import social_interactions, is_raising_hand
|
||||
from ..activity import social_interactions, is_raising_hand, is_phoning, is_turning
|
||||
from .architectures import MonolocoModel, LocoModel
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ class Loco:
|
||||
LINEAR_SIZE_MONO = 256
|
||||
N_SAMPLES = 100
|
||||
|
||||
def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024):
|
||||
def __init__(self, model, mode, net=None, device=None, n_dropout=0, p_dropout=0.2, linear_size=1024, casr='nonstd', casr_model=None):
|
||||
|
||||
# Select networks
|
||||
assert mode in ('mono', 'stereo'), "mode not recognized"
|
||||
@ -57,6 +57,19 @@ class Loco:
|
||||
input_size = 34
|
||||
output_size = 2
|
||||
|
||||
if casr == 'std':
|
||||
print("CASR with standard gestures")
|
||||
turning_output_size = 3
|
||||
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
||||
else:
|
||||
turning_output_size = 4
|
||||
if casr_model:
|
||||
turning_model_path = casr_model
|
||||
else:
|
||||
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr-210615-1128.pkl"
|
||||
|
||||
print('-'*10 + 'Output size :' + str(turning_output_size) + '-'*10)
|
||||
|
||||
if not device:
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
@ -70,15 +83,22 @@ class Loco:
|
||||
if net in ('monoloco', 'monoloco_p'):
|
||||
self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size,
|
||||
output_size=output_size)
|
||||
self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size,
|
||||
output_size=turning_output_size)
|
||||
else:
|
||||
self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
|
||||
linear_size=linear_size, device=self.device)
|
||||
self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size,
|
||||
linear_size=linear_size, device=self.device)
|
||||
|
||||
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
||||
self.turning_model.load_state_dict(torch.load(turning_model_path, map_location=lambda storage, loc: storage))
|
||||
else:
|
||||
self.model = model
|
||||
self.model.eval() # Default is train
|
||||
self.model.to(self.device)
|
||||
self.turning_model.eval() # Default is train
|
||||
self.turning_model.to(self.device)
|
||||
|
||||
def forward(self, keypoints, kk, keypoints_r=None):
|
||||
"""
|
||||
@ -271,6 +291,36 @@ class Loco:
|
||||
dic_out['raising_hand'] = [is_raising_hand(keypoint) for keypoint in keypoints]
|
||||
return dic_out
|
||||
|
||||
@staticmethod
|
||||
def using_phone(dic_out, keypoints):
|
||||
dic_out['using_phone'] = [is_phoning(keypoint) for keypoint in keypoints]
|
||||
return dic_out
|
||||
|
||||
@staticmethod
|
||||
def turning(dic_out, keypoints):
|
||||
dic_out['turning'] = [is_turning(keypoint) for keypoint in keypoints]
|
||||
return dic_out
|
||||
|
||||
def turning_forward(self, dic_out, keypoints):
|
||||
"""
|
||||
Forward pass of MonSter or monoloco network
|
||||
It includes preprocessing and postprocessing of data
|
||||
"""
|
||||
if not keypoints:
|
||||
return None
|
||||
|
||||
with torch.no_grad():
|
||||
keypoints = torch.tensor(keypoints).to(self.device)
|
||||
kk = torch.eye(3).to(self.device)
|
||||
|
||||
inputs = preprocess_monoloco(keypoints, kk, zero_center=False)
|
||||
outputs = self.turning_model(inputs)
|
||||
# bi = unnormalize_bi(outputs)
|
||||
dic = {'turning': [o for o in torch.argmax(outputs, axis=len(outputs.shape)-1).tolist()]}
|
||||
# dic = {key: el.detach().cpu() for key, el in dic.items()}
|
||||
dic_out['turning'] = dic['turning']
|
||||
|
||||
return dic_out
|
||||
|
||||
def median_disparity(dic_out, keypoints, keypoints_r, mask):
|
||||
"""
|
||||
|
||||
@ -237,7 +237,8 @@ def extract_outputs(outputs, tasks=()):
|
||||
'h': outputs[:, 4:5],
|
||||
'w': outputs[:, 5:6],
|
||||
'l': outputs[:, 6:7],
|
||||
'ori': outputs[:, 7:9]}
|
||||
'ori': outputs[:, 7:9],
|
||||
'cyclist': outputs}
|
||||
|
||||
if outputs.shape[1] == 10:
|
||||
dic_out['aux'] = outputs[:, 9:10]
|
||||
@ -283,6 +284,16 @@ def extract_labels_aux(labels, tasks=None):
|
||||
dic_gt_out = {key: el.detach().cpu() for key, el in dic_gt_out.items()}
|
||||
return dic_gt_out
|
||||
|
||||
def extract_labels_cyclist(labels, tasks=None):
|
||||
|
||||
dic_gt_out = {'cyclist': labels}
|
||||
|
||||
if tasks is not None:
|
||||
assert isinstance(tasks, tuple), "tasks need to be a tuple"
|
||||
return [dic_gt_out[task] for task in tasks]
|
||||
|
||||
dic_gt_out = {key: el.detach().cpu() for key, el in dic_gt_out.items()}
|
||||
return dic_gt_out
|
||||
|
||||
def extract_labels(labels, tasks=None):
|
||||
|
||||
|
||||
@ -55,6 +55,7 @@ def download_checkpoints(args):
|
||||
torch_dir = get_torch_checkpoints_dir()
|
||||
if args.checkpoint is None:
|
||||
pifpaf_model = os.path.join(torch_dir, 'shufflenetv2k30-201104-224654-cocokp-d75ed641.pkl')
|
||||
print(pifpaf_model)
|
||||
else:
|
||||
pifpaf_model = args.checkpoint
|
||||
dic_models = {'keypoints': pifpaf_model}
|
||||
@ -81,6 +82,7 @@ def download_checkpoints(args):
|
||||
name = 'monoloco_pp-201203-1424.pkl'
|
||||
|
||||
model = os.path.join(torch_dir, name)
|
||||
print(name)
|
||||
dic_models[args.mode] = model
|
||||
if not os.path.exists(model):
|
||||
assert DOWNLOAD is not None, "pip install gdown to download monoloco model, or pass it as --model"
|
||||
@ -124,6 +126,11 @@ def factory_from_args(args):
|
||||
else:
|
||||
args.batch_size = 1
|
||||
|
||||
if args.casr_std:
|
||||
args.casr = 'std'
|
||||
else:
|
||||
args.casr = 'nonstd'
|
||||
|
||||
# Patch for stereo images with batch_size = 2
|
||||
if args.batch_size == 2 and not args.long_edge:
|
||||
args.long_edge = 1238
|
||||
@ -155,7 +162,9 @@ def predict(args):
|
||||
mode=args.mode,
|
||||
device=args.device,
|
||||
n_dropout=args.n_dropout,
|
||||
p_dropout=args.dropout)
|
||||
p_dropout=args.dropout,
|
||||
casr=args.casr,
|
||||
casr_model=args.casr_model)
|
||||
|
||||
# data
|
||||
processor, pifpaf_model = processor_factory(args)
|
||||
@ -220,11 +229,15 @@ def predict(args):
|
||||
dic_out = net.forward(keypoints, kk)
|
||||
dic_out = net.post_process(
|
||||
dic_out, boxes, keypoints, kk, dic_gt)
|
||||
if args.activities and 'social_distance' in args.activities:
|
||||
dic_out = net.social_distance(dic_out, args)
|
||||
if args.activities and 'raise_hand' in args.activities:
|
||||
dic_out = net.raising_hand(dic_out, keypoints)
|
||||
|
||||
if args.activities:
|
||||
if 'social_distance' in args.activities:
|
||||
dic_out = net.social_distance(dic_out, args)
|
||||
if 'raise_hand' in args.activities:
|
||||
dic_out = net.raising_hand(dic_out, keypoints)
|
||||
if 'using_phone' in args.activities:
|
||||
dic_out = net.using_phone(dic_out, keypoints)
|
||||
if 'is_turning' in args.activities:
|
||||
dic_out = net.turning_forward(dic_out, keypoints)
|
||||
else:
|
||||
LOG.info("Prediction with MonStereo")
|
||||
_, keypoints_r = preprocess_pifpaf(pifpaf_outs['right'], im_size)
|
||||
|
||||
@ -1,2 +1,4 @@
|
||||
|
||||
from .preprocess_kitti import parse_ground_truth, factory_file
|
||||
from .casr_preprocess import create_dic
|
||||
from .casr_preprocess_standard import create_dic_std
|
||||
92
monoloco/prep/casr_preprocess.py
Normal file
92
monoloco/prep/casr_preprocess.py
Normal file
@ -0,0 +1,92 @@
|
||||
import pickle
|
||||
import re
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
from .. import __version__
|
||||
from .transforms import flip_inputs, flip_labels, height_augmentation
|
||||
from ..network.process import preprocess_monoloco
|
||||
|
||||
gt_path = '/scratch/izar/beauvill/casr/data/annotations/casr_annotation.pickle'
|
||||
res_path = '/scratch/izar/beauvill/casr/res_extended/casr*'
|
||||
|
||||
def bb_intersection_over_union(boxA, boxB):
|
||||
xA = max(boxA[0], boxB[0])
|
||||
yA = max(boxA[1], boxB[1])
|
||||
xB = min(boxA[2], boxB[2])
|
||||
yB = min(boxA[3], boxB[3])
|
||||
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
||||
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
||||
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||
return iou
|
||||
|
||||
def match_bboxes(bbox_gt, bbox_pred, IOU_THRESH=1):
|
||||
n_true = bbox_gt.shape[0]
|
||||
n_pred = bbox_pred.shape[0]
|
||||
MAX_DIST = 1.0
|
||||
MIN_IOU = 0.0
|
||||
|
||||
iou_matrix = np.zeros((n_true, n_pred))
|
||||
for i in range(n_true):
|
||||
for j in range(n_pred):
|
||||
iou_matrix[i, j] = bb_intersection_over_union(bbox_gt[i,:], bbox_pred[j,:])
|
||||
|
||||
return np.argmax(iou_matrix)
|
||||
|
||||
def standard_bbox(bbox):
|
||||
return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]]
|
||||
|
||||
def load_gt(path=gt_path):
|
||||
return pickle.load(open(path, 'rb'), encoding='latin1')
|
||||
|
||||
def load_res(path=res_path):
|
||||
mono = []
|
||||
for dir in sorted(glob.glob(path), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
data_list = []
|
||||
for file in sorted(os.listdir(dir), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
if 'json' in file:
|
||||
json_path = os.path.join(dir, file)
|
||||
json_data = json.load(open(json_path))
|
||||
json_data['filename'] = json_path
|
||||
data_list.append(json_data)
|
||||
mono.append(data_list)
|
||||
return mono
|
||||
|
||||
def create_dic(gt=load_gt(), res=load_res()):
|
||||
dic_jo = {
|
||||
'train': dict(X=[], Y=[], names=[], kps=[]),
|
||||
'val': dict(X=[], Y=[], names=[], kps=[]),
|
||||
'version': __version__,
|
||||
}
|
||||
split = ['3', '4']
|
||||
for i in range(len(res[:])):
|
||||
for j in range(len(res[i][:])):
|
||||
folder = gt[i][j]['video_folder']
|
||||
|
||||
phase = 'val'
|
||||
if folder[7] in split:
|
||||
phase = 'train'
|
||||
|
||||
if('boxes' in res[i][j]):
|
||||
gt_box = gt[i][j]['bbox_gt']
|
||||
|
||||
good_idx = match_bboxes(np.array([standard_bbox(gt_box)]), np.array(res[i][j]['boxes'])[:,:4])
|
||||
|
||||
keypoints = [res[i][j]['uv_kps'][good_idx]]
|
||||
|
||||
inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist()
|
||||
dic_jo[phase]['kps'].append(keypoints)
|
||||
dic_jo[phase]['X'].append(inp)
|
||||
dic_jo[phase]['Y'].append(gt[i][j]['left_or_right'])
|
||||
dic_jo[phase]['names'].append(folder+"_frame{}".format(j))
|
||||
|
||||
now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:]
|
||||
with open("/home/beauvill/joints-casr-right-" + split[0] + split[1] + "-" + now_time + ".json", 'w') as file:
|
||||
json.dump(dic_jo, file)
|
||||
return dic_jo
|
||||
99
monoloco/prep/casr_preprocess_standard.py
Normal file
99
monoloco/prep/casr_preprocess_standard.py
Normal file
@ -0,0 +1,99 @@
|
||||
import pickle
|
||||
import re
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
from .. import __version__
|
||||
from .transforms import flip_inputs, flip_labels, height_augmentation
|
||||
from ..network.process import preprocess_monoloco
|
||||
|
||||
gt_path = '/scratch/izar/beauvill/casr/data/annotations/casr_annotation.pickle'
|
||||
res_path = '/scratch/izar/beauvill/casr/res_extended/casr*'
|
||||
|
||||
def bb_intersection_over_union(boxA, boxB):
|
||||
xA = max(boxA[0], boxB[0])
|
||||
yA = max(boxA[1], boxB[1])
|
||||
xB = min(boxA[2], boxB[2])
|
||||
yB = min(boxA[3], boxB[3])
|
||||
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
||||
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
||||
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||
return iou
|
||||
|
||||
def match_bboxes(bbox_gt, bbox_pred, IOU_THRESH=1):
|
||||
n_true = bbox_gt.shape[0]
|
||||
n_pred = bbox_pred.shape[0]
|
||||
MAX_DIST = 1.0
|
||||
MIN_IOU = 0.0
|
||||
|
||||
iou_matrix = np.zeros((n_true, n_pred))
|
||||
for i in range(n_true):
|
||||
for j in range(n_pred):
|
||||
iou_matrix[i, j] = bb_intersection_over_union(bbox_gt[i,:], bbox_pred[j,:])
|
||||
|
||||
return np.argmax(iou_matrix)
|
||||
|
||||
def standard_bbox(bbox):
|
||||
return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]]
|
||||
|
||||
def load_gt():
|
||||
return pickle.load(open(gt_path, 'rb'), encoding='latin1')
|
||||
|
||||
def load_res():
|
||||
mono = []
|
||||
for dir in sorted(glob.glob(res_path), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
data_list = []
|
||||
for file in sorted(os.listdir(dir), key=lambda x:float(re.findall("(\d+)",x)[0])):
|
||||
if 'json' in file:
|
||||
json_path = os.path.join(dir, file)
|
||||
json_data = json.load(open(json_path))
|
||||
json_data['filename'] = json_path
|
||||
data_list.append(json_data)
|
||||
mono.append(data_list)
|
||||
return mono
|
||||
|
||||
def create_dic_std(gt=load_gt(), res=load_res()):
|
||||
dic_jo = {
|
||||
'train': dict(X=[], Y=[], names=[], kps=[]),
|
||||
'val': dict(X=[], Y=[], names=[], kps=[]),
|
||||
'version': __version__,
|
||||
}
|
||||
wrong = [6, 8, 9, 10, 11, 12, 14, 21, 40, 43, 55, 70, 76, 92, 109, 110, 112, 113, 121, 123, 124, 127, 128, 134, 136, 139, 165, 173]
|
||||
for i in range(len(res[:])):
|
||||
if(not(i in wrong)):
|
||||
for j in range(len(res[i][:])):
|
||||
phase = 'val'
|
||||
if (j % 10) > 1:
|
||||
phase = 'train'
|
||||
|
||||
folder = gt[i][j]['video_folder']
|
||||
|
||||
if('boxes' in res[i][j] and not(gt[i][j]['left_or_right'] == 2)):
|
||||
gt_box = gt[i][j]['bbox_gt']
|
||||
|
||||
good_idx = match_bboxes(np.array([standard_bbox(gt_box)]), np.array(res[i][j]['boxes'])[:,:4])
|
||||
|
||||
keypoints = [res[i][j]['uv_kps'][good_idx]]
|
||||
|
||||
gt_turn = gt[i][j]['left_or_right']
|
||||
if gt_turn == 3:
|
||||
gt_turn = 2
|
||||
|
||||
inp = preprocess_monoloco(keypoints, torch.eye(3)).view(-1).tolist()
|
||||
dic_jo[phase]['kps'].append(keypoints)
|
||||
dic_jo[phase]['X'].append(inp)
|
||||
dic_jo[phase]['Y'].append(gt_turn)
|
||||
dic_jo[phase]['names'].append(folder+"_frame{}".format(j))
|
||||
|
||||
now_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")[2:]
|
||||
with open("/home/beauvill/joints-casr-std-" + now_time + ".json", 'w') as file:
|
||||
json.dump(dic_jo, file)
|
||||
return dic_jo
|
||||
|
||||
create_dic_std()
|
||||
349
monoloco/prep/preprocess_casr.py
Normal file
349
monoloco/prep/preprocess_casr.py
Normal file
@ -0,0 +1,349 @@
|
||||
# pylint: disable=too-many-statements, too-many-branches, too-many-nested-blocks
|
||||
|
||||
"""Preprocess annotations with KITTI ground-truth"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import copy
|
||||
import math
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import warnings
|
||||
import datetime
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import split_training, get_iou_matches, append_cluster, get_calibration, open_annotations, \
|
||||
extract_stereo_matches, make_new_directory, \
|
||||
check_conditions, to_spherical, correct_angle
|
||||
from ..network.process import preprocess_pifpaf, preprocess_monoloco
|
||||
from .transforms import flip_inputs, flip_labels, height_augmentation
|
||||
|
||||
|
||||
class PreprocessCasr:
|
||||
"""Prepare arrays with same format as nuScenes preprocessing but using ground truth txt files"""
|
||||
|
||||
# KITTI Dataset files
|
||||
dir_gt = "/scratch/izar/beauvill/casr/annotations"
|
||||
dir_images = "/scratch/izar/beauvill/casr/images"
|
||||
# dir_kk = os.path.join('data', 'kitti', 'calib')
|
||||
|
||||
# SOCIAL DISTANCING PARAMETERS
|
||||
# THRESHOLD_DIST = 2 # Threshold to check distance of people
|
||||
# RADII = (0.3, 0.5, 1) # expected radii of the o-space
|
||||
# SOCIAL_DISTANCE = True
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
dic_jo = {
|
||||
'train': dict(X=[], Y=[], names=[], kps=[], K=[], clst=defaultdict(lambda: defaultdict(list))),
|
||||
'val': dict(X=[], Y=[], names=[], kps=[], K=[], clst=defaultdict(lambda: defaultdict(list))),
|
||||
'test': dict(X=[], Y=[], names=[], kps=[], K=[], clst=defaultdict(lambda: defaultdict(list))),
|
||||
'version': __version__,
|
||||
}
|
||||
dic_names = defaultdict(lambda: defaultdict(list))
|
||||
dic_std = defaultdict(lambda: defaultdict(list))
|
||||
# categories_gt = dict(train=['Pedestrian', 'Person_sitting'], val=['Pedestrian'])
|
||||
|
||||
def __init__(self, dir_ann, mode='mono', iou_min=0.3, sample=False):
|
||||
|
||||
self.dir_ann = dir_ann
|
||||
self.mode = mode
|
||||
self.iou_min = iou_min
|
||||
self.sample = sample
|
||||
|
||||
assert os.path.isdir(self.dir_ann), "Annotation directory not found"
|
||||
assert any(os.scandir(self.dir_ann)), "Annotation directory empty"
|
||||
assert os.path.isdir(self.dir_gt), "Ground truth directory not found"
|
||||
assert any(os.scandir(self.dir_gt)), "Ground-truth directory empty"
|
||||
# if self.mode == 'stereo':
|
||||
# assert os.path.isdir(self.dir_ann + '_right'), "Annotation directory for right images not found"
|
||||
# assert any(os.scandir(self.dir_ann + '_right')), "Annotation directory for right images empty"
|
||||
# elif not os.path.isdir(self.dir_ann + '_right') or not any(os.scandir(self.dir_ann + '_right')):
|
||||
# warnings.warn('Horizontal flipping not applied as annotation directory for right images not found/empty')
|
||||
assert self.mode in ('mono', 'stereo'), "modality not recognized"
|
||||
|
||||
self.names_gt = tuple(os.listdir(self.dir_gt))
|
||||
self.list_gt = glob.glob(self.dir_gt + '/*.txt')
|
||||
now = datetime.datetime.now()
|
||||
now_time = now.strftime("%Y%m%d-%H%M")[2:]
|
||||
dir_out = os.path.join('data', 'arrays')
|
||||
self.path_joints = os.path.join(dir_out, 'joints-kitti-' + self.mode + '-' + now_time + '.json')
|
||||
self.path_names = os.path.join(dir_out, 'names-kitti-' + self.mode + '-' + now_time + '.json')
|
||||
path_train = os.path.join('splits', 'kitti_train.txt')
|
||||
path_val = os.path.join('splits', 'kitti_val.txt')
|
||||
self.set_train, self.set_val = split_training(self.names_gt, path_train, path_val)
|
||||
self.phase, self.name = None, None
|
||||
self.stats = defaultdict(int)
|
||||
self.stats_stereo = defaultdict(int)
|
||||
|
||||
def run(self):
|
||||
# self.names_gt = ('002282.txt',)
|
||||
for self.name in self.names_gt:
|
||||
# Extract ground truth
|
||||
path_gt = os.path.join(self.dir_gt, self.name)
|
||||
basename, _ = os.path.splitext(self.name)
|
||||
self.phase, file_not_found = self._factory_phase(self.name)
|
||||
category = 'all' if self.phase == 'train' else 'pedestrian'
|
||||
if file_not_found:
|
||||
self.stats['fnf'] += 1
|
||||
continue
|
||||
|
||||
boxes_gt, labels, _, _, _ = parse_ground_truth(path_gt)
|
||||
self.stats['gt_' + self.phase] += len(boxes_gt)
|
||||
self.stats['gt_files'] += 1
|
||||
self.stats['gt_files_ped'] += min(len(boxes_gt), 1) # if no boxes 0 else 1
|
||||
self.dic_names[basename + '.png']['boxes'] = copy.deepcopy(boxes_gt)
|
||||
self.dic_names[basename + '.png']['ys'] = copy.deepcopy(labels)
|
||||
|
||||
# Extract annotations
|
||||
dic_boxes, dic_kps, dic_gt = self.parse_annotations(boxes_gt, labels, basename)
|
||||
if dic_boxes is None: # No annotations
|
||||
continue
|
||||
self.dic_names[basename + '.png']['K'] = copy.deepcopy(dic_gt['K'])
|
||||
self.dic_jo[self.phase]['K'].append(dic_gt['K'])
|
||||
|
||||
# Match each set of keypoint with a ground truth
|
||||
for ii, boxes_gt in enumerate(dic_boxes['gt']):
|
||||
kps, kps_r = torch.tensor(dic_kps['left'][ii]), torch.tensor(dic_kps['right'][ii])
|
||||
matches = get_iou_matches(dic_boxes['left'][ii], boxes_gt, self.iou_min)
|
||||
self.stats['flipping_match'] += len(matches) if ii == 1 else 0
|
||||
for (idx, idx_gt) in matches:
|
||||
cat_gt = dic_gt['labels'][ii][idx_gt][-1]
|
||||
if cat_gt not in self.categories_gt[self.phase]: # only for training as cyclists are also extracted
|
||||
continue
|
||||
kp = kps[idx:idx + 1]
|
||||
kk = dic_gt['K']
|
||||
label = dic_gt['labels'][ii][idx_gt][:-1]
|
||||
self.stats['match'] += 1
|
||||
assert len(label) == 10, 'dimensions of monocular label is wrong'
|
||||
|
||||
if self.mode == 'mono':
|
||||
self._process_annotation_mono(kp, kk, label)
|
||||
else:
|
||||
self._process_annotation_stereo(kp, kk, label, kps_r)
|
||||
|
||||
with open(self.path_joints, 'w') as file:
|
||||
json.dump(self.dic_jo, file)
|
||||
with open(os.path.join(self.path_names), 'w') as file:
|
||||
json.dump(self.dic_names, file)
|
||||
self._cout()
|
||||
|
||||
def parse_annotations(self, boxes_gt, labels, basename):
|
||||
|
||||
path_im = os.path.join(self.dir_images, basename + '.png')
|
||||
path_calib = os.path.join(self.dir_kk, basename + '.txt')
|
||||
min_conf = 0 if self.phase == 'train' else 0.1
|
||||
|
||||
# Check image size
|
||||
with Image.open(path_im) as im:
|
||||
width, height = im.size
|
||||
|
||||
# Extract left keypoints
|
||||
annotations, kk, _ = factory_file(path_calib, self.dir_ann, basename)
|
||||
boxes, keypoints = preprocess_pifpaf(annotations, im_size=(width, height), min_conf=min_conf)
|
||||
if not keypoints:
|
||||
return None, None, None
|
||||
|
||||
# Stereo-based horizontal flipping for training (obtaining ground truth for right images)
|
||||
self.stats['instances'] += len(keypoints)
|
||||
annotations_r, _, _ = factory_file(path_calib, self.dir_ann, basename, ann_type='right')
|
||||
boxes_r, keypoints_r = preprocess_pifpaf(annotations_r, im_size=(width, height), min_conf=min_conf)
|
||||
|
||||
if not keypoints_r: # Duplicate the left one(s)
|
||||
all_boxes_gt, all_labels = [boxes_gt], [labels]
|
||||
boxes_r, keypoints_r = boxes[0:1].copy(), keypoints[0:1].copy()
|
||||
all_boxes, all_keypoints = [boxes], [keypoints]
|
||||
all_keypoints_r = [keypoints_r]
|
||||
|
||||
elif self.phase == 'train':
|
||||
# GT)
|
||||
boxes_gt_flip, ys_flip = flip_labels(boxes_gt, labels, im_w=width)
|
||||
# New left
|
||||
boxes_flip = flip_inputs(boxes_r, im_w=width, mode='box')
|
||||
keypoints_flip = flip_inputs(keypoints_r, im_w=width)
|
||||
|
||||
# New right
|
||||
keypoints_r_flip = flip_inputs(keypoints, im_w=width)
|
||||
|
||||
# combine the 2 modes
|
||||
all_boxes_gt = [boxes_gt, boxes_gt_flip]
|
||||
all_labels = [labels, ys_flip]
|
||||
all_boxes = [boxes, boxes_flip]
|
||||
all_keypoints = [keypoints, keypoints_flip]
|
||||
all_keypoints_r = [keypoints_r, keypoints_r_flip]
|
||||
|
||||
else:
|
||||
all_boxes_gt, all_labels = [boxes_gt], [labels]
|
||||
all_boxes, all_keypoints = [boxes], [keypoints]
|
||||
all_keypoints_r = [keypoints_r]
|
||||
|
||||
dic_boxes = dict(left=all_boxes, gt=all_boxes_gt)
|
||||
dic_kps = dict(left=all_keypoints, right=all_keypoints_r)
|
||||
dic_gt = dict(K=kk, labels=all_labels)
|
||||
return dic_boxes, dic_kps, dic_gt
|
||||
|
||||
def _process_annotation_mono(self, kp, kk, label):
|
||||
"""For a single annotation, process all the labels and save them"""
|
||||
kp = kp.tolist()
|
||||
inp = preprocess_monoloco(kp, kk).view(-1).tolist()
|
||||
|
||||
# Save
|
||||
self.dic_jo[self.phase]['kps'].append(kp)
|
||||
self.dic_jo[self.phase]['X'].append(inp)
|
||||
self.dic_jo[self.phase]['Y'].append(label)
|
||||
self.dic_jo[self.phase]['names'].append(self.name) # One image name for each annotation
|
||||
append_cluster(self.dic_jo, self.phase, inp, label, kp)
|
||||
self.stats['total_' + self.phase] += 1
|
||||
|
||||
def _process_annotation_stereo(self, kp, kk, label, kps_r):
|
||||
"""For a reference annotation, combine it with some (right) annotations and save it"""
|
||||
|
||||
zz = label[2]
|
||||
stereo_matches, cnt_amb = extract_stereo_matches(kp, kps_r, zz,
|
||||
phase=self.phase,
|
||||
seed=self.stats_stereo['pair'])
|
||||
self.stats_stereo['ambiguous'] += cnt_amb
|
||||
|
||||
for idx_r, s_match in stereo_matches:
|
||||
label_s = label + [s_match] # add flag to distinguish "true pairs and false pairs"
|
||||
self.stats_stereo['true_pair'] += 1 if s_match > 0.9 else 0
|
||||
self.stats_stereo['pair'] += 1 # before augmentation
|
||||
|
||||
# ---> Remove noise of very far instances for validation
|
||||
# if (self.phase == 'val') and (label[3] >= 50):
|
||||
# continue
|
||||
|
||||
# ---> Save only positives unless there is no positive (keep positive flip and augm)
|
||||
# if num > 0 and s_match < 0.9:
|
||||
# continue
|
||||
|
||||
# Height augmentation
|
||||
flag_aug = False
|
||||
if self.phase == 'train' and 3 < label[2] < 30 and (s_match > 0.9 or self.stats_stereo['pair'] % 2 == 0):
|
||||
flag_aug = True
|
||||
|
||||
# Remove height augmentation
|
||||
# flag_aug = False
|
||||
|
||||
if flag_aug:
|
||||
kps_aug, labels_aug = height_augmentation(kp, kps_r[idx_r:idx_r + 1], label_s,
|
||||
seed=self.stats_stereo['pair'])
|
||||
else:
|
||||
kps_aug = [(kp, kps_r[idx_r:idx_r + 1])]
|
||||
labels_aug = [label_s]
|
||||
|
||||
for i, lab in enumerate(labels_aug):
|
||||
assert len(lab) == 11, 'dimensions of stereo label is wrong'
|
||||
self.stats_stereo['pair_aug'] += 1
|
||||
(kp_aug, kp_aug_r) = kps_aug[i]
|
||||
input_l = preprocess_monoloco(kp_aug, kk).view(-1)
|
||||
input_r = preprocess_monoloco(kp_aug_r, kk).view(-1)
|
||||
keypoint = torch.cat((kp_aug, kp_aug_r), dim=2).tolist()
|
||||
inp = torch.cat((input_l, input_l - input_r)).tolist()
|
||||
self.dic_jo[self.phase]['kps'].append(keypoint)
|
||||
self.dic_jo[self.phase]['X'].append(inp)
|
||||
self.dic_jo[self.phase]['Y'].append(lab)
|
||||
self.dic_jo[self.phase]['names'].append(self.name) # One image name for each annotation
|
||||
append_cluster(self.dic_jo, self.phase, inp, lab, keypoint)
|
||||
self.stats_stereo['total_' + self.phase] += 1 # including height augmentation
|
||||
|
||||
def _cout(self):
|
||||
print('-' * 100)
|
||||
print(f"Number of GT files: {self.stats['gt_files']} ")
|
||||
print(f"Files with at least one pedestrian/cyclist: {self.stats['gt_files_ped']}")
|
||||
print(f"Files not found: {self.stats['fnf']}")
|
||||
print('-' * 100)
|
||||
our = self.stats['match'] - self.stats['flipping_match']
|
||||
gt = self.stats['gt_train'] + self.stats['gt_val']
|
||||
print(f"Ground truth matches: {100 * our / gt:.1f} for left images (train and val)")
|
||||
print(f"Parsed instances: {self.stats['instances']}")
|
||||
print(f"Ground truth instances: {gt}")
|
||||
print(f"Matched instances: {our}")
|
||||
print(f"Including horizontal flipping: {self.stats['match']}")
|
||||
|
||||
if self.mode == 'stereo':
|
||||
print('-' * 100)
|
||||
print(f"Ambiguous instances removed: {self.stats_stereo['ambiguous']}")
|
||||
print(f"True pairs ratio: {100 * self.stats_stereo['true_pair'] / self.stats_stereo['pair']:.1f}% ")
|
||||
print(f"Height augmentation pairs: {self.stats_stereo['pair_aug'] - self.stats_stereo['pair']} ")
|
||||
print('-' * 100)
|
||||
total_train = self.stats_stereo['total_train'] if self.mode == 'stereo' else self.stats['total_train']
|
||||
total_val = self.stats_stereo['total_val'] if self.mode == 'stereo' else self.stats['total_val']
|
||||
print(f"Total annotations for TRAINING: {total_train}")
|
||||
print(f"Total annotations for VALIDATION: {total_val}")
|
||||
print('-' * 100)
|
||||
print(f"\nOutput files:\n{self.path_names}\n{self.path_joints}")
|
||||
print('-' * 100)
|
||||
|
||||
|
||||
def _factory_phase(self, name):
|
||||
"""Choose the phase"""
|
||||
phase = None
|
||||
flag = False
|
||||
if name in self.set_train:
|
||||
phase = 'train'
|
||||
elif name in self.set_val:
|
||||
phase = 'val'
|
||||
else:
|
||||
flag = True
|
||||
return phase, flag
|
||||
|
||||
|
||||
def parse_ground_truth(path_gt, spherical=False):
|
||||
"""Parse KITTI ground truth files"""
|
||||
|
||||
boxes_gt = []
|
||||
labels = []
|
||||
truncs_gt = [] # Float from 0 to 1
|
||||
occs_gt = [] # Either 0,1,2,3 fully visible, partly occluded, largely occluded, unknown
|
||||
lines = []
|
||||
|
||||
with open(path_gt, "r") as f_gt:
|
||||
for line_gt in f_gt:
|
||||
line = line_gt.split()
|
||||
truncs_gt.append(float(line[1]))
|
||||
occs_gt.append(int(line[2]))
|
||||
boxes_gt.append([float(x) for x in line[4:8]])
|
||||
xyz = [float(x) for x in line[11:14]]
|
||||
hwl = [float(x) for x in line[8:11]]
|
||||
dd = float(math.sqrt(xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2))
|
||||
yaw = float(line[14])
|
||||
assert - math.pi <= yaw <= math.pi
|
||||
alpha = float(line[3])
|
||||
sin, cos, yaw_corr = correct_angle(yaw, xyz)
|
||||
assert min(abs(-yaw_corr - alpha), (abs(yaw_corr - alpha))) < 0.15, "more than 10 degrees of error"
|
||||
if spherical:
|
||||
rtp = to_spherical(xyz)
|
||||
loc = rtp[1:3] + xyz[2:3] + rtp[0:1] # [theta, psi, z, r]
|
||||
else:
|
||||
loc = xyz + [dd]
|
||||
cat = line[0] # 'Pedestrian', or 'Person_sitting' for people
|
||||
output = loc + hwl + [sin, cos, yaw, cat]
|
||||
labels.append(output)
|
||||
lines.append(line_gt)
|
||||
return boxes_gt, labels, truncs_gt, occs_gt, lines
|
||||
|
||||
|
||||
def factory_file(path_calib, dir_ann, basename, ann_type='left'):
|
||||
"""Choose the annotation and the calibration files"""
|
||||
|
||||
assert ann_type in ('left', 'right')
|
||||
p_left, p_right = get_calibration(path_calib)
|
||||
|
||||
if ann_type == 'left':
|
||||
kk, tt = p_left[:]
|
||||
path_ann = os.path.join(dir_ann, basename + '.png.predictions.json')
|
||||
|
||||
# The right folder is called <NameOfLeftFolder>_right
|
||||
else:
|
||||
kk, tt = p_right[:]
|
||||
path_ann = os.path.join(dir_ann + '_right', basename + '.png.predictions.json')
|
||||
|
||||
annotations = open_annotations(path_ann)
|
||||
|
||||
return annotations, kk, tt
|
||||
@ -50,10 +50,11 @@ def cli():
|
||||
visualizer.cli(parser)
|
||||
|
||||
# Monoloco
|
||||
predict_parser.add_argument('--activities', nargs='+', choices=['raise_hand', 'social_distance'],
|
||||
predict_parser.add_argument('--activities', nargs='+', choices=['raise_hand', 'social_distance', 'using_phone', 'is_turning'],
|
||||
help='Choose activities to show: social_distance, raise_hand')
|
||||
predict_parser.add_argument('--mode', help='keypoints, mono, stereo', default='mono')
|
||||
predict_parser.add_argument('--model', help='path of MonoLoco/MonStereo model to load')
|
||||
predict_parser.add_argument('--casr_model', help='path of casr model to load')
|
||||
predict_parser.add_argument('--net', help='only to select older MonoLoco model, otherwise use --mode')
|
||||
predict_parser.add_argument('--path_gt', help='path of json file with gt 3d localization')
|
||||
#default='data/arrays/names-kitti-200615-1022.json')
|
||||
@ -61,6 +62,7 @@ def cli():
|
||||
predict_parser.add_argument('--n_dropout', type=int, help='Epistemic uncertainty evaluation', default=0)
|
||||
predict_parser.add_argument('--dropout', type=float, help='dropout parameter', default=0.2)
|
||||
predict_parser.add_argument('--show_all', help='only predict ground-truth matches or all', action='store_true')
|
||||
predict_parser.add_argument('--casr_std', help='run casr training', action='store_true')
|
||||
predict_parser.add_argument('--webcam', help='monstereo streaming', action='store_true')
|
||||
predict_parser.add_argument('--camera', help='device to use for webcam streaming', type=int, default=0)
|
||||
predict_parser.add_argument('--focal', help='focal length in mm for a sensor size of 7.2x5.4 mm. (nuScenes)',
|
||||
@ -96,6 +98,8 @@ def cli():
|
||||
training_parser.add_argument('--hidden_size', type=int, help='Number of hidden units in the model', default=1024)
|
||||
training_parser.add_argument('--n_stage', type=int, help='Number of stages in the model', default=3)
|
||||
training_parser.add_argument('--hyp', help='run hyperparameters tuning', action='store_true')
|
||||
training_parser.add_argument('--casr', help='run casr training', action='store_true')
|
||||
training_parser.add_argument('--casr_std', help='run casr training', action='store_true')
|
||||
training_parser.add_argument('--multiplier', type=int, help='Size of the grid of hyp search', default=1)
|
||||
training_parser.add_argument('--r_seed', type=int, help='specify the seed for training and hyp tuning', default=1)
|
||||
training_parser.add_argument('--print_loss', help='print training and validation losses', action='store_true')
|
||||
@ -146,6 +150,12 @@ def main():
|
||||
from .prep.preprocess_nu import PreprocessNuscenes
|
||||
prep = PreprocessNuscenes(args.dir_ann, args.dir_nuscenes, args.dataset, args.iou_min)
|
||||
prep.run()
|
||||
elif 'casr' in args.dataset:
|
||||
from .prep.casr_preprocess import create_dic
|
||||
create_dic()
|
||||
elif 'casr_std' in args.dataset:
|
||||
from .prep.casr_preprocess_standard import create_dic_std
|
||||
create_dic_std()
|
||||
else:
|
||||
from .prep.preprocess_kitti import PreprocessKitti
|
||||
prep = PreprocessKitti(args.dir_ann, mode=args.mode, iou_min=args.iou_min)
|
||||
@ -157,10 +167,27 @@ def main():
|
||||
elif args.command == 'train':
|
||||
from .train import HypTuning
|
||||
if args.hyp:
|
||||
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||
hyp_tuning.train(args)
|
||||
if args.casr:
|
||||
from .train import HypTuningCasr
|
||||
hyp_tuning_casr = HypTuningCasr(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||
hyp_tuning_casr.train(args)
|
||||
else:
|
||||
hyp_tuning = HypTuning(joints=args.joints, epochs=args.epochs,
|
||||
monocular=args.monocular, dropout=args.dropout,
|
||||
multiplier=args.multiplier, r_seed=args.r_seed)
|
||||
hyp_tuning.train(args)
|
||||
elif args.casr:
|
||||
from .train import CASRTrainer
|
||||
training = CASRTrainer(args)
|
||||
_ = training.train()
|
||||
_ = training.evaluate()
|
||||
elif args.casr_std:
|
||||
from .train import CASRTrainerStandard
|
||||
training = CASRTrainerStandard(args)
|
||||
_ = training.train()
|
||||
_ = training.evaluate()
|
||||
else:
|
||||
from .train import Trainer
|
||||
training = Trainer(args)
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
|
||||
from .hyp_tuning import HypTuning
|
||||
from .hyp_tuning_casr import HypTuningCasr
|
||||
from .trainer import Trainer
|
||||
from .trainer_casr import CASRTrainer
|
||||
from .trainer_casr_standard import CASRTrainerStandard
|
||||
|
||||
@ -63,7 +63,7 @@ class KeypointsDataset(Dataset):
|
||||
self.version = dic_jo['version']
|
||||
|
||||
# Extract annotations divided in clusters
|
||||
self.dic_clst = dic_jo[phase]['clst']
|
||||
# self.dic_clst = dic_jo[phase]['clst']
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
|
||||
124
monoloco/train/hyp_tuning_casr.py
Normal file
124
monoloco/train/hyp_tuning_casr.py
Normal file
@ -0,0 +1,124 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .trainer_casr import CASRTrainer
|
||||
|
||||
|
||||
class HypTuningCasr:
|
||||
|
||||
def __init__(self, joints, epochs, monocular, dropout, multiplier=1, r_seed=1):
|
||||
"""
|
||||
Initialize directories, load the data and parameters for the training
|
||||
"""
|
||||
|
||||
# Initialize Directories
|
||||
self.joints = joints
|
||||
self.monocular = monocular
|
||||
self.dropout = dropout
|
||||
self.num_epochs = epochs
|
||||
self.r_seed = r_seed
|
||||
dir_out = os.path.join('data', 'models')
|
||||
dir_logs = os.path.join('data', 'logs')
|
||||
assert os.path.exists(dir_out), "Output directory not found"
|
||||
if not os.path.exists(dir_logs):
|
||||
os.makedirs(dir_logs)
|
||||
|
||||
name_out = 'hyp-casr-'
|
||||
|
||||
self.path_log = os.path.join(dir_logs, name_out)
|
||||
self.path_model = os.path.join(dir_out, name_out)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize grid of parameters
|
||||
random.seed(r_seed)
|
||||
np.random.seed(r_seed)
|
||||
self.sched_gamma_list = [0.8, 0.9, 1, 0.8, 0.9, 1] * multiplier
|
||||
random.shuffle(self.sched_gamma_list)
|
||||
self.sched_step = [10, 20, 40, 60, 80, 100] * multiplier
|
||||
random.shuffle(self.sched_step)
|
||||
self.bs_list = [64, 128, 256, 512, 512, 1024] * multiplier
|
||||
random.shuffle(self.bs_list)
|
||||
self.hidden_list = [512, 1024, 2048, 512, 1024, 2048] * multiplier
|
||||
random.shuffle(self.hidden_list)
|
||||
self.n_stage_list = [3, 3, 3, 3, 3, 3] * multiplier
|
||||
random.shuffle(self.n_stage_list)
|
||||
# Learning rate
|
||||
aa = math.log(0.0005, 10)
|
||||
bb = math.log(0.01, 10)
|
||||
log_lr_list = np.random.uniform(aa, bb, int(6 * multiplier)).tolist()
|
||||
self.lr_list = [10 ** xx for xx in log_lr_list]
|
||||
# plt.hist(self.lr_list, bins=50)
|
||||
# plt.show()
|
||||
|
||||
def train(self, args):
|
||||
"""Train multiple times using log-space random search"""
|
||||
|
||||
best_acc_val = 20
|
||||
dic_best = {}
|
||||
dic_err_best = {}
|
||||
start = time.time()
|
||||
cnt = 0
|
||||
for idx, lr in enumerate(self.lr_list):
|
||||
bs = self.bs_list[idx]
|
||||
sched_gamma = self.sched_gamma_list[idx]
|
||||
sched_step = self.sched_step[idx]
|
||||
hidden_size = self.hidden_list[idx]
|
||||
n_stage = self.n_stage_list[idx]
|
||||
|
||||
training = CASRTrainer(args)
|
||||
|
||||
best_epoch = training.train()
|
||||
dic_err, model = training.evaluate()
|
||||
acc_val = dic_err['val']['all']['mean']
|
||||
cnt += 1
|
||||
print("Combination number: {}".format(cnt))
|
||||
|
||||
if acc_val < best_acc_val:
|
||||
dic_best['lr'] = lr
|
||||
dic_best['joints'] = self.joints
|
||||
dic_best['bs'] = bs
|
||||
dic_best['monocular'] = self.monocular
|
||||
dic_best['sched_gamma'] = sched_gamma
|
||||
dic_best['sched_step'] = sched_step
|
||||
dic_best['hidden_size'] = hidden_size
|
||||
dic_best['n_stage'] = n_stage
|
||||
dic_best['acc_val'] = dic_err['val']['all']['d']
|
||||
dic_best['best_epoch'] = best_epoch
|
||||
dic_best['random_seed'] = self.r_seed
|
||||
# dic_best['acc_test'] = dic_err['test']['all']['mean']
|
||||
|
||||
dic_err_best = dic_err
|
||||
best_acc_val = acc_val
|
||||
model_best = model
|
||||
|
||||
# Save model and log
|
||||
now = datetime.datetime.now()
|
||||
now_time = now.strftime("%Y%m%d-%H%M")[2:]
|
||||
self.path_model = self.path_model + now_time + '.pkl'
|
||||
torch.save(model_best.state_dict(), self.path_model)
|
||||
with open(self.path_log + now_time, 'w') as f:
|
||||
json.dump(dic_best, f)
|
||||
end = time.time()
|
||||
print('\n\n\n')
|
||||
self.logger.info(" Tried {} combinations".format(cnt))
|
||||
self.logger.info(" Total time for hyperparameters search: {:.2f} minutes".format((end - start) / 60))
|
||||
self.logger.info(" Best hyperparameters are:")
|
||||
for key, value in dic_best.items():
|
||||
self.logger.info(" {}: {}".format(key, value))
|
||||
|
||||
print()
|
||||
self.logger.info("Accuracy in each cluster:")
|
||||
|
||||
self.logger.info("Final accuracy Val: {:.2f}".format(dic_best['acc_val']))
|
||||
self.logger.info("\nSaved the model: {}".format(self.path_model))
|
||||
@ -11,7 +11,7 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..network import extract_labels, extract_labels_aux, extract_outputs
|
||||
from ..network import extract_labels, extract_labels_aux, extract_labels_cyclist, extract_outputs
|
||||
|
||||
|
||||
class AutoTuneMultiTaskLoss(torch.nn.Module):
|
||||
@ -53,8 +53,13 @@ class MultiTaskLoss(torch.nn.Module):
|
||||
self.tasks = tasks
|
||||
if len(self.tasks) == 1 and self.tasks[0] == 'aux':
|
||||
self.flag_aux = True
|
||||
self.flag_cyclist = False
|
||||
elif len(self.tasks) == 1 and self.tasks[0] == 'cyclist':
|
||||
self.flag_cyclist = True
|
||||
self.flag_aux = False
|
||||
else:
|
||||
self.flag_aux = False
|
||||
self.flag_cyclist = False
|
||||
|
||||
def forward(self, outputs, labels, phase='train'):
|
||||
|
||||
@ -62,6 +67,8 @@ class MultiTaskLoss(torch.nn.Module):
|
||||
out = extract_outputs(outputs, tasks=self.tasks)
|
||||
if self.flag_aux:
|
||||
gt_out = extract_labels_aux(labels, tasks=self.tasks)
|
||||
elif self.flag_cyclist:
|
||||
gt_out = extract_labels_cyclist(labels, tasks=self.tasks)
|
||||
else:
|
||||
gt_out = extract_labels(labels, tasks=self.tasks)
|
||||
loss_values = [lam * l(o, g) for lam, l, o, g in zip(self.lambdas, self.losses, out, gt_out)]
|
||||
@ -81,7 +88,8 @@ class CompositeLoss(torch.nn.Module):
|
||||
self.tasks = tasks
|
||||
self.multi_loss_tr = {task: (LaplacianLoss() if task == 'd'
|
||||
else (nn.BCEWithLogitsLoss() if task in ('aux', )
|
||||
else nn.L1Loss())) for task in tasks}
|
||||
else (nn.CrossEntropyLoss() if task == 'cyclist'
|
||||
else nn.L1Loss()))) for task in tasks}
|
||||
|
||||
self.multi_loss_val = {}
|
||||
for task in tasks:
|
||||
@ -91,6 +99,8 @@ class CompositeLoss(torch.nn.Module):
|
||||
loss = angle_loss
|
||||
elif task in ('aux', ):
|
||||
loss = nn.BCEWithLogitsLoss()
|
||||
elif task == 'cyclist':
|
||||
loss = nn.CrossEntropyLoss()
|
||||
else:
|
||||
loss = nn.L1Loss()
|
||||
self.multi_loss_val[task] = loss
|
||||
|
||||
364
monoloco/train/trainer_casr.py
Normal file
364
monoloco/train/trainer_casr.py
Normal file
@ -0,0 +1,364 @@
|
||||
# pylint: disable=too-many-statements
|
||||
|
||||
"""
|
||||
Training and evaluation of a neural network that, given 2D joints, estimates:
|
||||
- 3D localization and confidence intervals
|
||||
- Orientation
|
||||
- Bounding box dimensions
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import datetime
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from .. import __version__
|
||||
from .datasets import KeypointsDataset
|
||||
from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss
|
||||
from ..network import extract_outputs, extract_labels
|
||||
from ..network.architectures import LocoModel
|
||||
from ..utils import set_logger
|
||||
|
||||
|
||||
class CASRTrainer:
|
||||
# Constants
|
||||
VAL_BS = 10000
|
||||
|
||||
tasks = ('cyclist',)
|
||||
val_task = 'cyclist'
|
||||
lambdas = (1,)
|
||||
#clusters = ['10', '20', '30', '40']
|
||||
input_size = 34
|
||||
output_size = 4
|
||||
dir_figures = os.path.join('figures', 'losses')
|
||||
|
||||
def __init__(self, args):
|
||||
"""
|
||||
Initialize directories, load the data and parameters for the training
|
||||
"""
|
||||
|
||||
assert os.path.exists(args.joints), "Input file not found"
|
||||
self.mode = args.mode
|
||||
self.joints = args.joints
|
||||
self.num_epochs = args.epochs
|
||||
self.no_save = args.no_save
|
||||
self.print_loss = args.print_loss
|
||||
self.lr = args.lr
|
||||
self.sched_step = args.sched_step
|
||||
self.sched_gamma = args.sched_gamma
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_stage = args.n_stage
|
||||
self.r_seed = args.r_seed
|
||||
self.auto_tune_mtl = args.auto_tune_mtl
|
||||
|
||||
# Select path out
|
||||
if args.out:
|
||||
self.path_out = args.out # full path without extension
|
||||
dir_out, _ = os.path.split(self.path_out)
|
||||
else:
|
||||
dir_out = os.path.join('data', 'outputs')
|
||||
name = 'casr'
|
||||
now = datetime.datetime.now()
|
||||
now_time = now.strftime("%Y%m%d-%H%M")[2:]
|
||||
name_out = name + '-' + now_time + '.pkl'
|
||||
self.path_out = os.path.join(dir_out, name_out)
|
||||
assert os.path.exists(dir_out), "Directory to save the model not found"
|
||||
print(self.path_out)
|
||||
# Select the device
|
||||
use_cuda = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda" if use_cuda else "cpu")
|
||||
print('Device: ', self.device)
|
||||
torch.manual_seed(self.r_seed)
|
||||
if use_cuda:
|
||||
torch.cuda.manual_seed(self.r_seed)
|
||||
|
||||
losses_tr, losses_val = CompositeLoss(self.tasks)()
|
||||
|
||||
if self.auto_tune_mtl:
|
||||
self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
|
||||
else:
|
||||
self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
|
||||
self.mt_loss.to(self.device)
|
||||
|
||||
# Dataloader
|
||||
self.dataloaders = {phase: DataLoader(KeypointsDataset(self.joints, phase=phase),
|
||||
batch_size=args.bs, shuffle=True) for phase in ['train', 'val']}
|
||||
|
||||
self.dataset_sizes = {phase: len(KeypointsDataset(self.joints, phase=phase))
|
||||
for phase in ['train', 'val']}
|
||||
self.dataset_version = KeypointsDataset(self.joints, phase='train').get_version()
|
||||
|
||||
self._set_logger(args)
|
||||
|
||||
# Define the model
|
||||
self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes))
|
||||
print(">>> creating model")
|
||||
|
||||
self.model = LocoModel(
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
linear_size=args.hidden_size,
|
||||
p_dropout=args.dropout,
|
||||
num_stage=self.n_stage,
|
||||
device=self.device,
|
||||
)
|
||||
self.model.to(self.device)
|
||||
print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0))
|
||||
print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters())))
|
||||
|
||||
# Optimizer and scheduler
|
||||
all_params = chain(self.model.parameters(), self.mt_loss.parameters())
|
||||
self.optimizer = torch.optim.Adam(params=all_params, lr=args.lr)
|
||||
self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')
|
||||
self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma)
|
||||
|
||||
def train(self):
|
||||
since = time.time()
|
||||
best_model_wts = copy.deepcopy(self.model.state_dict())
|
||||
best_acc = 1e6
|
||||
best_training_acc = 1e6
|
||||
best_epoch = 0
|
||||
epoch_losses = defaultdict(lambda: defaultdict(list))
|
||||
for epoch in range(self.num_epochs):
|
||||
running_loss = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
# Each epoch has a training and validation phase
|
||||
for phase in ['train', 'val']:
|
||||
if phase == 'train':
|
||||
self.model.train() # Set model to training mode
|
||||
else:
|
||||
self.model.eval() # Set model to evaluate mode
|
||||
|
||||
for inputs, labels, _, _ in self.dataloaders[phase]:
|
||||
inputs = inputs.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
with torch.set_grad_enabled(phase == 'train'):
|
||||
if phase == 'train':
|
||||
self.optimizer.zero_grad()
|
||||
outputs = self.model(inputs)
|
||||
loss, _ = self.mt_loss(outputs, labels, phase=phase)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
|
||||
else:
|
||||
outputs = self.model(inputs)
|
||||
with torch.no_grad():
|
||||
loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val')
|
||||
self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss)
|
||||
|
||||
self.cout_values(epoch, epoch_losses, running_loss)
|
||||
|
||||
# deep copy the model
|
||||
|
||||
if epoch_losses['val'][self.val_task][-1] < best_acc:
|
||||
best_acc = epoch_losses['val'][self.val_task][-1]
|
||||
best_training_acc = epoch_losses['train']['all'][-1]
|
||||
best_epoch = epoch
|
||||
best_model_wts = copy.deepcopy(self.model.state_dict())
|
||||
|
||||
time_elapsed = time.time() - since
|
||||
print('\n\n' + '-' * 120)
|
||||
self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s'
|
||||
.format(time_elapsed // 60, time_elapsed % 60))
|
||||
self.logger.info('Best training Accuracy: {:.3f}'.format(best_training_acc))
|
||||
self.logger.info('Best validation Accuracy for {}: {:.3f}'.format(self.val_task, best_acc))
|
||||
self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch))
|
||||
|
||||
self._print_losses(epoch_losses)
|
||||
|
||||
# load best model weights
|
||||
self.model.load_state_dict(best_model_wts)
|
||||
return best_epoch
|
||||
|
||||
def epoch_logs(self, phase, loss, loss_values, inputs, running_loss):
|
||||
|
||||
running_loss[phase]['all'] += loss.item() * inputs.size(0)
|
||||
for i, task in enumerate(self.tasks):
|
||||
running_loss[phase][task] += loss_values[i].item() * inputs.size(0)
|
||||
|
||||
def evaluate(self, load=False, model=None, debug=False):
|
||||
|
||||
# To load a model instead of using the trained one
|
||||
if load:
|
||||
self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
|
||||
|
||||
# Average distance on training and test set after unnormalizing
|
||||
self.model.eval()
|
||||
dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) # initialized to zero
|
||||
dic_err['val']['sigmas'] = [0.] * len(self.tasks)
|
||||
dataset = KeypointsDataset(self.joints, phase='val')
|
||||
size_eval = len(dataset)
|
||||
start = 0
|
||||
with torch.no_grad():
|
||||
for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS):
|
||||
end = end if end < size_eval else size_eval
|
||||
inputs, labels, _, _ = dataset[start:end]
|
||||
start = end
|
||||
inputs = inputs.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
|
||||
# Debug plot for input-output distributions
|
||||
if debug:
|
||||
debug_plots(inputs, labels)
|
||||
sys.exit()
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
#self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
|
||||
|
||||
# self.cout_stats(dic_err['val'], size_eval, clst='all')
|
||||
# Evaluate performances on different clusters and save statistics
|
||||
|
||||
# Save the model and the results
|
||||
if not (self.no_save or load):
|
||||
torch.save(self.model.state_dict(), self.path_model)
|
||||
print('-' * 120)
|
||||
self.logger.info("\nmodel saved: {} \n".format(self.path_model))
|
||||
else:
|
||||
self.logger.info("\nmodel not saved\n")
|
||||
|
||||
return dic_err, self.model
|
||||
|
||||
def compute_stats(self, outputs, labels, dic_err, size_eval, clst):
|
||||
"""Compute mean, bi and max of torch tensors"""
|
||||
|
||||
_, loss_values = self.mt_loss(outputs, labels, phase='val')
|
||||
rel_frac = outputs.size(0) / size_eval
|
||||
|
||||
tasks = self.tasks # Exclude auxiliary
|
||||
|
||||
for idx, task in enumerate(tasks):
|
||||
dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval)
|
||||
|
||||
# Distance
|
||||
errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d'])
|
||||
assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation"
|
||||
|
||||
# Uncertainty
|
||||
bis = extract_outputs(outputs)['bi'].cpu()
|
||||
bi = float(torch.mean(bis).item())
|
||||
bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0]
|
||||
dic_err[clst]['bi'] += bi * rel_frac
|
||||
dic_err[clst]['bi%'] += bi_perc * rel_frac
|
||||
dic_err[clst]['std'] = errs.std()
|
||||
|
||||
# (Don't) Save auxiliary task results
|
||||
dic_err['sigmas'].append(0)
|
||||
|
||||
if self.auto_tune_mtl:
|
||||
assert len(loss_values) == 2 * len(self.tasks)
|
||||
for i, _ in enumerate(self.tasks):
|
||||
dic_err['sigmas'][i] += float(loss_values[len(tasks) + i + 1].item()) * rel_frac
|
||||
|
||||
def cout_stats(self, dic_err, size_eval, clst):
|
||||
if clst == 'all':
|
||||
print('-' * 120)
|
||||
self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n"
|
||||
"X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} "
|
||||
"\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm"
|
||||
"\nAuxiliary Task: {:.1f} %, "
|
||||
.format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
|
||||
dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
|
||||
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
|
||||
dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100))
|
||||
if self.auto_tune_mtl:
|
||||
self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}"
|
||||
" AUX:{:.2f}\n"
|
||||
.format(*dic_err['sigmas']))
|
||||
else:
|
||||
self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} "
|
||||
"Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. "
|
||||
.format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
|
||||
dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
|
||||
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
|
||||
dic_err[clst]['l'] * 100, size_eval))
|
||||
|
||||
def cout_values(self, epoch, epoch_losses, running_loss):
|
||||
|
||||
string = '\r' + '{:.0f} '
|
||||
format_list = [epoch]
|
||||
for phase in running_loss:
|
||||
string = string + phase[0:1].upper() + ':'
|
||||
for el in running_loss['train']:
|
||||
loss = running_loss[phase][el] / self.dataset_sizes[phase]
|
||||
epoch_losses[phase][el].append(loss)
|
||||
if el == 'all':
|
||||
string = string + ':{:.1f} '
|
||||
format_list.append(loss)
|
||||
elif el in ('ori', 'aux'):
|
||||
string = string + el + ':{:.1f} '
|
||||
format_list.append(loss)
|
||||
else:
|
||||
string = string + el + ':{:.0f} '
|
||||
format_list.append(loss * 100)
|
||||
|
||||
if epoch % 10 == 0:
|
||||
print(string.format(*format_list))
|
||||
|
||||
def _print_losses(self, epoch_losses):
|
||||
if not self.print_loss:
|
||||
return
|
||||
os.makedirs(self.dir_figures, exist_ok=True)
|
||||
for idx, phase in enumerate(epoch_losses):
|
||||
for idx_2, el in enumerate(epoch_losses['train']):
|
||||
plt.figure(idx + idx_2)
|
||||
plt.title(phase + '_' + el)
|
||||
plt.xlabel('epochs')
|
||||
plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el))
|
||||
plt.savefig(os.path.join(self.dir_figures, '{}_loss_{}.png'.format(phase, el)))
|
||||
plt.close()
|
||||
|
||||
def _set_logger(self, args):
|
||||
if self.no_save:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
else:
|
||||
self.path_model = self.path_out
|
||||
print(self.path_model)
|
||||
self.logger = set_logger(os.path.splitext(self.path_out)[0]) # remove .pkl
|
||||
self.logger.info( # pylint: disable=logging-fstring-interpolation
|
||||
f'\nVERSION: {__version__}\n'
|
||||
f'\nINPUT_FILE: {args.joints}'
|
||||
f'\nInput file version: {self.dataset_version}'
|
||||
f'\nTorch version: {torch.__version__}\n'
|
||||
f'\nTraining arguments:'
|
||||
f'\nmode: {self.mode} \nlearning rate: {args.lr} \nbatch_size: {args.bs}'
|
||||
f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
|
||||
f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
|
||||
f'\ninput_size: {self.input_size} \noutput_size: {self.output_size} '
|
||||
f'\nhidden_size: {args.hidden_size}'
|
||||
f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}'
|
||||
)
|
||||
|
||||
|
||||
def debug_plots(inputs, labels):
|
||||
inputs_shoulder = inputs.cpu().numpy()[:, 5]
|
||||
inputs_hip = inputs.cpu().numpy()[:, 11]
|
||||
labels = labels.cpu().numpy()
|
||||
heights = inputs_hip - inputs_shoulder
|
||||
plt.figure(1)
|
||||
plt.hist(heights, bins='auto')
|
||||
plt.show()
|
||||
plt.figure(2)
|
||||
plt.hist(labels, bins='auto')
|
||||
plt.show()
|
||||
|
||||
|
||||
def get_accuracy(outputs, labels):
|
||||
"""From Binary cross entropy outputs to accuracy"""
|
||||
|
||||
mask = outputs >= 0.5
|
||||
accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
|
||||
return accuracy
|
||||
365
monoloco/train/trainer_casr_standard.py
Normal file
365
monoloco/train/trainer_casr_standard.py
Normal file
@ -0,0 +1,365 @@
|
||||
# pylint: disable=too-many-statements
|
||||
|
||||
"""
|
||||
Training and evaluation of a neural network that, given 2D joints, estimates:
|
||||
- 3D localization and confidence intervals
|
||||
- Orientation
|
||||
- Bounding box dimensions
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import datetime
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from .. import __version__
|
||||
from .datasets import KeypointsDataset
|
||||
from .losses import CompositeLoss, MultiTaskLoss, AutoTuneMultiTaskLoss
|
||||
from ..network import extract_outputs, extract_labels
|
||||
from ..network.architectures import LocoModel
|
||||
from ..utils import set_logger
|
||||
|
||||
|
||||
class CASRTrainerStandard:
|
||||
# Constants
|
||||
VAL_BS = 10000
|
||||
|
||||
tasks = ('cyclist',)
|
||||
val_task = 'cyclist'
|
||||
lambdas = (1,)
|
||||
#clusters = ['10', '20', '30', '40']
|
||||
input_size = 34
|
||||
output_size = 3
|
||||
dir_figures = os.path.join('figures', 'losses')
|
||||
|
||||
def __init__(self, args):
|
||||
"""
|
||||
Initialize directories, load the data and parameters for the training
|
||||
"""
|
||||
|
||||
assert os.path.exists(args.joints), "Input file not found"
|
||||
self.mode = args.mode
|
||||
self.joints = args.joints
|
||||
self.num_epochs = args.epochs
|
||||
self.no_save = args.no_save
|
||||
self.print_loss = args.print_loss
|
||||
self.lr = args.lr
|
||||
self.sched_step = args.sched_step
|
||||
self.sched_gamma = args.sched_gamma
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_stage = args.n_stage
|
||||
self.r_seed = args.r_seed
|
||||
self.auto_tune_mtl = args.auto_tune_mtl
|
||||
|
||||
# Select path out
|
||||
if args.out:
|
||||
self.path_out = args.out # full path without extension
|
||||
dir_out, _ = os.path.split(self.path_out)
|
||||
else:
|
||||
dir_out = os.path.join('data', 'outputs')
|
||||
name = 'casr_standard'
|
||||
now = datetime.datetime.now()
|
||||
now_time = now.strftime("%Y%m%d-%H%M")[2:]
|
||||
name_out = name + '-' + now_time + '.pkl'
|
||||
self.path_out = os.path.join(dir_out, name_out)
|
||||
assert os.path.exists(dir_out), "Directory to save the model not found"
|
||||
print(self.path_out)
|
||||
# Select the device
|
||||
use_cuda = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda" if use_cuda else "cpu")
|
||||
print('Device: ', self.device)
|
||||
torch.manual_seed(self.r_seed)
|
||||
if use_cuda:
|
||||
torch.cuda.manual_seed(self.r_seed)
|
||||
|
||||
losses_tr, losses_val = CompositeLoss(self.tasks)()
|
||||
|
||||
if self.auto_tune_mtl:
|
||||
self.mt_loss = AutoTuneMultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
|
||||
else:
|
||||
self.mt_loss = MultiTaskLoss(losses_tr, losses_val, self.lambdas, self.tasks)
|
||||
self.mt_loss.to(self.device)
|
||||
|
||||
# Dataloader
|
||||
self.dataloaders = {phase: DataLoader(KeypointsDataset(self.joints, phase=phase),
|
||||
batch_size=args.bs, shuffle=True) for phase in ['train', 'val']}
|
||||
|
||||
self.dataset_sizes = {phase: len(KeypointsDataset(self.joints, phase=phase))
|
||||
for phase in ['train', 'val']}
|
||||
self.dataset_version = KeypointsDataset(self.joints, phase='train').get_version()
|
||||
|
||||
self._set_logger(args)
|
||||
|
||||
# Define the model
|
||||
self.logger.info('Sizes of the dataset: {}'.format(self.dataset_sizes))
|
||||
print(">>> creating model")
|
||||
|
||||
self.model = LocoModel(
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
linear_size=args.hidden_size,
|
||||
p_dropout=args.dropout,
|
||||
num_stage=self.n_stage,
|
||||
device=self.device,
|
||||
)
|
||||
self.model.to(self.device)
|
||||
print(">>> model params: {:.3f}M".format(sum(p.numel() for p in self.model.parameters()) / 1000000.0))
|
||||
print(">>> loss params: {}".format(sum(p.numel() for p in self.mt_loss.parameters())))
|
||||
|
||||
# Optimizer and scheduler
|
||||
all_params = chain(self.model.parameters(), self.mt_loss.parameters())
|
||||
self.optimizer = torch.optim.Adam(params=all_params, lr=args.lr)
|
||||
self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')
|
||||
self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=self.sched_step, gamma=self.sched_gamma)
|
||||
|
||||
def train(self):
|
||||
since = time.time()
|
||||
best_model_wts = copy.deepcopy(self.model.state_dict())
|
||||
best_acc = 1e6
|
||||
best_training_acc = 1e6
|
||||
best_epoch = 0
|
||||
epoch_losses = defaultdict(lambda: defaultdict(list))
|
||||
for epoch in range(self.num_epochs):
|
||||
running_loss = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
# Each epoch has a training and validation phase
|
||||
for phase in ['train', 'val']:
|
||||
if phase == 'train':
|
||||
self.model.train() # Set model to training mode
|
||||
else:
|
||||
self.model.eval() # Set model to evaluate mode
|
||||
|
||||
for inputs, labels, _, _ in self.dataloaders[phase]:
|
||||
inputs = inputs.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
with torch.set_grad_enabled(phase == 'train'):
|
||||
if phase == 'train':
|
||||
self.optimizer.zero_grad()
|
||||
outputs = self.model(inputs)
|
||||
loss, _ = self.mt_loss(outputs, labels, phase=phase)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
|
||||
else:
|
||||
outputs = self.model(inputs)
|
||||
with torch.no_grad():
|
||||
loss_eval, loss_values_eval = self.mt_loss(outputs, labels, phase='val')
|
||||
self.epoch_logs(phase, loss_eval, loss_values_eval, inputs, running_loss)
|
||||
|
||||
self.cout_values(epoch, epoch_losses, running_loss)
|
||||
|
||||
# deep copy the model
|
||||
|
||||
if epoch_losses['val'][self.val_task][-1] < best_acc:
|
||||
best_acc = epoch_losses['val'][self.val_task][-1]
|
||||
best_training_acc = epoch_losses['train']['all'][-1]
|
||||
best_epoch = epoch
|
||||
best_model_wts = copy.deepcopy(self.model.state_dict())
|
||||
|
||||
time_elapsed = time.time() - since
|
||||
print('\n\n' + '-' * 120)
|
||||
self.logger.info('Training:\nTraining complete in {:.0f}m {:.0f}s'
|
||||
.format(time_elapsed // 60, time_elapsed % 60))
|
||||
self.logger.info('Best training Accuracy: {:.3f}'.format(best_training_acc))
|
||||
self.logger.info('Best validation Accuracy for {}: {:.3f}'.format(self.val_task, best_acc))
|
||||
self.logger.info('Saved weights of the model at epoch: {}'.format(best_epoch))
|
||||
|
||||
self._print_losses(epoch_losses)
|
||||
|
||||
# load best model weights
|
||||
self.model.load_state_dict(best_model_wts)
|
||||
return best_epoch
|
||||
|
||||
def epoch_logs(self, phase, loss, loss_values, inputs, running_loss):
|
||||
|
||||
running_loss[phase]['all'] += loss.item() * inputs.size(0)
|
||||
for i, task in enumerate(self.tasks):
|
||||
running_loss[phase][task] += loss_values[i].item() * inputs.size(0)
|
||||
|
||||
def evaluate(self, load=False, model=None, debug=False):
|
||||
|
||||
# To load a model instead of using the trained one
|
||||
if load:
|
||||
self.model.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
|
||||
|
||||
# Average distance on training and test set after unnormalizing
|
||||
self.model.eval()
|
||||
dic_err = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) # initialized to zero
|
||||
dic_err['val']['sigmas'] = [0.] * len(self.tasks)
|
||||
dataset = KeypointsDataset(self.joints, phase='val')
|
||||
size_eval = len(dataset)
|
||||
start = 0
|
||||
with torch.no_grad():
|
||||
for end in range(self.VAL_BS, size_eval + self.VAL_BS, self.VAL_BS):
|
||||
end = end if end < size_eval else size_eval
|
||||
inputs, labels, _, _ = dataset[start:end]
|
||||
start = end
|
||||
inputs = inputs.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
|
||||
# Debug plot for input-output distributions
|
||||
if debug:
|
||||
debug_plots(inputs, labels)
|
||||
sys.exit()
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
#self.compute_stats(outputs, labels, dic_err['val'], size_eval, clst='all')
|
||||
|
||||
# self.cout_stats(dic_err['val'], size_eval, clst='all')
|
||||
# Evaluate performances on different clusters and save statistics
|
||||
|
||||
# Save the model and the results
|
||||
if not (self.no_save or load):
|
||||
torch.save(self.model.state_dict(), self.path_model)
|
||||
print('-' * 120)
|
||||
self.logger.info("\nmodel saved: {} \n".format(self.path_model))
|
||||
else:
|
||||
self.logger.info("\nmodel not saved\n")
|
||||
|
||||
return dic_err, self.model
|
||||
|
||||
def compute_stats(self, outputs, labels, dic_err, size_eval, clst):
|
||||
"""Compute mean, bi and max of torch tensors"""
|
||||
|
||||
_, loss_values = self.mt_loss(outputs, labels, phase='val')
|
||||
rel_frac = outputs.size(0) / size_eval
|
||||
|
||||
tasks = self.tasks # Exclude auxiliary
|
||||
|
||||
for idx, task in enumerate(tasks):
|
||||
dic_err[clst][task] += float(loss_values[idx].item()) * (outputs.size(0) / size_eval)
|
||||
|
||||
# Distance
|
||||
errs = torch.abs(extract_outputs(outputs)['d'] - extract_labels(labels)['d'])
|
||||
assert rel_frac > 0.99, "Variance of errors not supported with partial evaluation"
|
||||
|
||||
# Uncertainty
|
||||
bis = extract_outputs(outputs)['bi'].cpu()
|
||||
bi = float(torch.mean(bis).item())
|
||||
bi_perc = float(torch.sum(errs <= bis)) / errs.shape[0]
|
||||
dic_err[clst]['bi'] += bi * rel_frac
|
||||
dic_err[clst]['bi%'] += bi_perc * rel_frac
|
||||
dic_err[clst]['std'] = errs.std()
|
||||
|
||||
# (Don't) Save auxiliary task results
|
||||
dic_err['sigmas'].append(0)
|
||||
|
||||
if self.auto_tune_mtl:
|
||||
assert len(loss_values) == 2 * len(self.tasks)
|
||||
for i, _ in enumerate(self.tasks):
|
||||
dic_err['sigmas'][i] += float(loss_values[len(tasks) + i + 1].item()) * rel_frac
|
||||
|
||||
def cout_stats(self, dic_err, size_eval, clst):
|
||||
if clst == 'all':
|
||||
print('-' * 120)
|
||||
self.logger.info("Evaluation, val set: \nAv. dist D: {:.2f} m with bi {:.2f} ({:.1f}%), \n"
|
||||
"X: {:.1f} cm, Y: {:.1f} cm \nOri: {:.1f} "
|
||||
"\n H: {:.1f} cm, W: {:.1f} cm, L: {:.1f} cm"
|
||||
"\nAuxiliary Task: {:.1f} %, "
|
||||
.format(dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
|
||||
dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
|
||||
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
|
||||
dic_err[clst]['l'] * 100, dic_err[clst]['aux'] * 100))
|
||||
if self.auto_tune_mtl:
|
||||
self.logger.info("Sigmas: Z: {:.2f}, X: {:.2f}, Y:{:.2f}, H: {:.2f}, W: {:.2f}, L: {:.2f}, ORI: {:.2f}"
|
||||
" AUX:{:.2f}\n"
|
||||
.format(*dic_err['sigmas']))
|
||||
else:
|
||||
self.logger.info("Val err clust {} --> D:{:.2f}m, bi:{:.2f} ({:.1f}%), STD:{:.1f}m X:{:.1f} Y:{:.1f} "
|
||||
"Ori:{:.1f}d, H: {:.0f} W: {:.0f} L:{:.0f} for {} pp. "
|
||||
.format(clst, dic_err[clst]['d'], dic_err[clst]['bi'], dic_err[clst]['bi%'] * 100,
|
||||
dic_err[clst]['std'], dic_err[clst]['x'] * 100, dic_err[clst]['y'] * 100,
|
||||
dic_err[clst]['ori'], dic_err[clst]['h'] * 100, dic_err[clst]['w'] * 100,
|
||||
dic_err[clst]['l'] * 100, size_eval))
|
||||
|
||||
def cout_values(self, epoch, epoch_losses, running_loss):
|
||||
|
||||
string = '\r' + '{:.0f} '
|
||||
format_list = [epoch]
|
||||
for phase in running_loss:
|
||||
string = string + phase[0:1].upper() + ':'
|
||||
for el in running_loss['train']:
|
||||
loss = running_loss[phase][el] / self.dataset_sizes[phase]
|
||||
print("Loss = ", loss)
|
||||
epoch_losses[phase][el].append(loss)
|
||||
if el == 'all':
|
||||
string = string + ':{:.1f} '
|
||||
format_list.append(loss)
|
||||
elif el in ('ori', 'aux'):
|
||||
string = string + el + ':{:.1f} '
|
||||
format_list.append(loss)
|
||||
else:
|
||||
string = string + el + ':{:.0f} '
|
||||
format_list.append(loss * 100)
|
||||
|
||||
if epoch % 10 == 0:
|
||||
print(string.format(*format_list))
|
||||
|
||||
def _print_losses(self, epoch_losses):
|
||||
if not self.print_loss:
|
||||
return
|
||||
os.makedirs(self.dir_figures, exist_ok=True)
|
||||
for idx, phase in enumerate(epoch_losses):
|
||||
for idx_2, el in enumerate(epoch_losses['train']):
|
||||
plt.figure(idx + idx_2)
|
||||
plt.title(phase + '_' + el)
|
||||
plt.xlabel('epochs')
|
||||
plt.plot(epoch_losses[phase][el][10:], label='{} Loss: {}'.format(phase, el))
|
||||
plt.savefig(os.path.join(self.dir_figures, '{}_loss_{}.png'.format(phase, el)))
|
||||
plt.close()
|
||||
|
||||
def _set_logger(self, args):
|
||||
if self.no_save:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
else:
|
||||
self.path_model = self.path_out
|
||||
print(self.path_model)
|
||||
self.logger = set_logger(os.path.splitext(self.path_out)[0]) # remove .pkl
|
||||
self.logger.info( # pylint: disable=logging-fstring-interpolation
|
||||
f'\nVERSION: {__version__}\n'
|
||||
f'\nINPUT_FILE: {args.joints}'
|
||||
f'\nInput file version: {self.dataset_version}'
|
||||
f'\nTorch version: {torch.__version__}\n'
|
||||
f'\nTraining arguments:'
|
||||
f'\nmode: {self.mode} \nlearning rate: {args.lr} \nbatch_size: {args.bs}'
|
||||
f'\nepochs: {args.epochs} \ndropout: {args.dropout} '
|
||||
f'\nscheduler step: {args.sched_step} \nscheduler gamma: {args.sched_gamma} '
|
||||
f'\ninput_size: {self.input_size} \noutput_size: {self.output_size} '
|
||||
f'\nhidden_size: {args.hidden_size}'
|
||||
f' \nn_stages: {args.n_stage} \n r_seed: {args.r_seed} \nlambdas: {self.lambdas}'
|
||||
)
|
||||
|
||||
|
||||
def debug_plots(inputs, labels):
|
||||
inputs_shoulder = inputs.cpu().numpy()[:, 5]
|
||||
inputs_hip = inputs.cpu().numpy()[:, 11]
|
||||
labels = labels.cpu().numpy()
|
||||
heights = inputs_hip - inputs_shoulder
|
||||
plt.figure(1)
|
||||
plt.hist(heights, bins='auto')
|
||||
plt.show()
|
||||
plt.figure(2)
|
||||
plt.hist(labels, bins='auto')
|
||||
plt.show()
|
||||
|
||||
|
||||
def get_accuracy(outputs, labels):
|
||||
"""From Binary cross entropy outputs to accuracy"""
|
||||
|
||||
mask = outputs >= 0.5
|
||||
accuracy = 1. - torch.mean(torch.abs(mask.float() - labels)).item()
|
||||
return accuracy
|
||||
@ -132,6 +132,9 @@ class KeypointPainter:
|
||||
if 'raise_hand' in activities:
|
||||
c, linewidth = highlighted_arm(x, y, connection, c, linewidth,
|
||||
dic_out['raising_hand'][:][i], size=size)
|
||||
if 'is_turning' in activities:
|
||||
c, linewidth = highlighted_arm(x, y, connection, c, linewidth,
|
||||
dic_out['turning'][:][i], size=size)
|
||||
|
||||
if self.color_connections:
|
||||
c = matplotlib.cm.get_cmap('tab20')(ci / len(self.skeleton))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user