Test
This commit is contained in:
parent
ce6d26d307
commit
f20e17709c
@ -62,7 +62,7 @@ class Loco:
|
|||||||
print("CASR with standard gestures")
|
print("CASR with standard gestures")
|
||||||
turning_output_size = 3
|
turning_output_size = 3
|
||||||
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
||||||
elif casr== 'nonstd':
|
elif casr == 'nonstd':
|
||||||
turning_output_size = 4
|
turning_output_size = 4
|
||||||
if casr_model:
|
if casr_model:
|
||||||
turning_model_path = casr_model
|
turning_model_path = casr_model
|
||||||
@ -96,6 +96,7 @@ class Loco:
|
|||||||
|
|
||||||
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
||||||
if casr:
|
if casr:
|
||||||
|
print("WTF")
|
||||||
self.turning_model.load_state_dict(torch.load(turning_model_path,
|
self.turning_model.load_state_dict(torch.load(turning_model_path,
|
||||||
map_location=lambda storage, loc: storage))
|
map_location=lambda storage, loc: storage))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -161,6 +161,7 @@ def predict(args):
|
|||||||
|
|
||||||
# Load Models
|
# Load Models
|
||||||
if args.mode in ('mono', 'stereo'):
|
if args.mode in ('mono', 'stereo'):
|
||||||
|
print(args.casr)
|
||||||
net = Loco(
|
net = Loco(
|
||||||
model=dic_models[args.mode],
|
model=dic_models[args.mode],
|
||||||
mode=args.mode,
|
mode=args.mode,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user