This commit is contained in:
Charles Joseph Pierre Beauville 2021-06-28 01:04:35 +02:00
parent ce6d26d307
commit f20e17709c
2 changed files with 3 additions and 1 deletions

View File

@ -62,7 +62,7 @@ class Loco:
print("CASR with standard gestures") print("CASR with standard gestures")
turning_output_size = 3 turning_output_size = 3
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl" turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
elif casr== 'nonstd': elif casr == 'nonstd':
turning_output_size = 4 turning_output_size = 4
if casr_model: if casr_model:
turning_model_path = casr_model turning_model_path = casr_model
@ -96,6 +96,7 @@ class Loco:
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
if casr: if casr:
print("WTF")
self.turning_model.load_state_dict(torch.load(turning_model_path, self.turning_model.load_state_dict(torch.load(turning_model_path,
map_location=lambda storage, loc: storage)) map_location=lambda storage, loc: storage))
else: else:

View File

@ -161,6 +161,7 @@ def predict(args):
# Load Models # Load Models
if args.mode in ('mono', 'stereo'): if args.mode in ('mono', 'stereo'):
print(args.casr)
net = Loco( net = Loco(
model=dic_models[args.mode], model=dic_models[args.mode],
mode=args.mode, mode=args.mode,