From ce6d26d307b868647ab04dfcef6b730fd71dbb54 Mon Sep 17 00:00:00 2001 From: Charles Joseph Pierre Beauville Date: Mon, 28 Jun 2021 00:58:47 +0200 Subject: [PATCH] Fixed predict --- monoloco/network/net.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/monoloco/network/net.py b/monoloco/network/net.py index 2e166de..338f531 100644 --- a/monoloco/network/net.py +++ b/monoloco/network/net.py @@ -28,7 +28,7 @@ class Loco: N_SAMPLES = 100 def __init__(self, model, mode, net=None, device=None, n_dropout=0, - p_dropout=0.2, linear_size=1024, casr='nonstd', casr_model=None): + p_dropout=0.2, linear_size=1024, casr=None, casr_model=None): # Select networks assert mode in ('mono', 'stereo'), "mode not recognized" @@ -62,7 +62,7 @@ class Loco: print("CASR with standard gestures") turning_output_size = 3 turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl" - else: + elif casr== 'nonstd': turning_output_size = 4 if casr_model: turning_model_path = casr_model @@ -84,17 +84,20 @@ class Loco: if net in ('monoloco', 'monoloco_p'): self.model = MonolocoModel(p_dropout=p_dropout, input_size=input_size, linear_size=linear_size, output_size=output_size) - self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size, - output_size=turning_output_size) + if casr: + self.turning_model = MonolocoModel(p_dropout=p_dropout, input_size=34, linear_size=linear_size, + output_size=turning_output_size) else: self.model = LocoModel(p_dropout=p_dropout, input_size=input_size, output_size=output_size, linear_size=linear_size, device=self.device) - self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size, - linear_size=linear_size, device=self.device) + if casr: + self.turning_model = LocoModel(p_dropout=p_dropout, input_size=34, output_size=turning_output_size, + linear_size=linear_size, device=self.device) self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) - self.turning_model.load_state_dict(torch.load(turning_model_path, - map_location=lambda storage, loc: storage)) + if casr: + self.turning_model.load_state_dict(torch.load(turning_model_path, + map_location=lambda storage, loc: storage)) else: self.model = model self.model.eval() # Default is train