Fixed predict
This commit is contained in:
parent
1a2ec7a0ef
commit
ce6d26d307
@ -28,7 +28,7 @@ class Loco:
|
||||
N_SAMPLES = 100
|
||||
|
||||
def __init__(self, model, mode, net=None, device=None, n_dropout=0,
|
||||
p_dropout=0.2, linear_size=1024, casr='nonstd', casr_model=None):
|
||||
p_dropout=0.2, linear_size=1024, casr=None, casr_model=None):
|
||||
|
||||
# Select networks
|
||||
assert mode in ('mono', 'stereo'), "mode not recognized"
|
||||
@ -62,7 +62,7 @@ class Loco:
|
||||
print("CASR with standard gestures")
|
||||
turning_output_size = 3
|
||||
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
|
||||
else:
|
||||
elif casr== 'nonstd':
|
||||
turning_output_size = 4
|
||||
if casr_model:
|
||||
turning_model_path = casr_model
|
||||
@ -84,15 +84,18 @@ class Loco:
|
||||
if net in ('monoloco', 'monoloco_p'):
|
||||
self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size,
|
||||
output_size=output_size)
|
||||
if casr:
|
||||
self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size,
|
||||
output_size=turning_output_size)
|
||||
else:
|
||||
self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
|
||||
linear_size=linear_size, device=self.device)
|
||||
if casr:
|
||||
self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size,
|
||||
linear_size=linear_size, device=self.device)
|
||||
|
||||
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
|
||||
if casr:
|
||||
self.turning_model.load_state_dict(torch.load(turning_model_path,
|
||||
map_location=lambda storage, loc: storage))
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user