Fixed predict

This commit is contained in:
Charles Joseph Pierre Beauville 2021-06-28 00:58:47 +02:00
parent 1a2ec7a0ef
commit ce6d26d307

View File

@ -28,7 +28,7 @@ class Loco:
N_SAMPLES = 100 N_SAMPLES = 100
def __init__(self, model, mode, net=None, device=None, n_dropout=0, def __init__(self, model, mode, net=None, device=None, n_dropout=0,
p_dropout=0.2, linear_size=1024, casr='nonstd', casr_model=None): p_dropout=0.2, linear_size=1024, casr=None, casr_model=None):
# Select networks # Select networks
assert mode in ('mono', 'stereo'), "mode not recognized" assert mode in ('mono', 'stereo'), "mode not recognized"
@ -62,7 +62,7 @@ class Loco:
print("CASR with standard gestures") print("CASR with standard gestures")
turning_output_size = 3 turning_output_size = 3
turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl" turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl"
else: elif casr== 'nonstd':
turning_output_size = 4 turning_output_size = 4
if casr_model: if casr_model:
turning_model_path = casr_model turning_model_path = casr_model
@ -84,17 +84,20 @@ class Loco:
if net in ('monoloco', 'monoloco_p'): if net in ('monoloco', 'monoloco_p'):
self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size, self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size,
output_size=output_size) output_size=output_size)
self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size, if casr:
output_size=turning_output_size) self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size,
output_size=turning_output_size)
else: else:
self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size, self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size,
linear_size=linear_size, device=self.device) linear_size=linear_size, device=self.device)
self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size, if casr:
linear_size=linear_size, device=self.device) self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size,
linear_size=linear_size, device=self.device)
self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
self.turning_model.load_state_dict(torch.load(turning_model_path, if casr:
map_location=lambda storage, loc: storage)) self.turning_model.load_state_dict(torch.load(turning_model_path,
map_location=lambda storage, loc: storage))
else: else:
self.model = model self.model = model
self.model.eval() # Default is train self.model.eval() # Default is train