Fixed predict
This commit is contained in:
parent
1a2ec7a0ef
commit
ce6d26d307
@ -28,7 +28,7 @@ class Loco:
|
|||||||
N_SAMPLES = 100
|
N_SAMPLES = 100
|
||||||
|
|
||||||
def __init__(self, model, mode, net=None, device=None, n_dropout=0,
|
def __init__(self, model, mode, net=None, device=None, n_dropout=0,
|
||||||
p_dropout=0.2, linear_size=1024, casr='nonstd', casr_model=None):
|
p_dropout=0.2, linear_size=1024, casr=None, casr_model=None):
|
||||||
|
|
||||||
# Select networks
|
# Select networks
|
||||||
assert mode in ('mono', 'stereo'), "mode not recognized"
|
assert mode in ('mono', 'stereo'), "mode not recognized"
|
||||||
@ -62,7 +62,7 @@ class Loco:
|
|||||||
print("CASR with standard gestures")
|
print("CASR with standard gestures")
|
||||||
turning_output_size = 3
|
turning_output_size = 3
|
||||||
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
||||||
else:
|
elif casr== 'nonstd':
|
||||||
turning_output_size = 4
|
turning_output_size = 4
|
||||||
if casr_model:
|
if casr_model:
|
||||||
turning_model_path = casr_model
|
turning_model_path = casr_model
|
||||||
@ -84,17 +84,20 @@ class Loco:
|
|||||||
if net in ('monoloco', 'monoloco_p'):
|
if net in ('monoloco', 'monoloco_p'):
|
||||||
self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size,
|
self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size,
|
||||||
output_size=output_size)
|
output_size=output_size)
|
||||||
self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size,
|
if casr:
|
||||||
output_size=turning_output_size)
|
self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size,
|
||||||
|
output_size=turning_output_size)
|
||||||
else:
|
else:
|
||||||
self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
|
self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
|
||||||
linear_size=linear_size, device=self.device)
|
linear_size=linear_size, device=self.device)
|
||||||
self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size,
|
if casr:
|
||||||
linear_size=linear_size, device=self.device)
|
self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size,
|
||||||
|
linear_size=linear_size, device=self.device)
|
||||||
|
|
||||||
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
||||||
self.turning_model.load_state_dict(torch.load(turning_model_path,
|
if casr:
|
||||||
map_location=lambda storage, loc: storage))
|
self.turning_model.load_state_dict(torch.load(turning_model_path,
|
||||||
|
map_location=lambda storage, loc: storage))
|
||||||
else:
|
else:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model.eval() # Default is train
|
self.model.eval() # Default is train
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user