diff --git a/monoloco/network/net.py b/monoloco/network/net.py index 217384c..9a39dd0 100644 --- a/monoloco/network/net.py +++ b/monoloco/network/net.py @@ -61,7 +61,10 @@ class Loco: if casr == 'std': print("CASR with standard gestures") turning_output_size = 3 - turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl" + if casr_model: + turning_model_path = casr_model + else: + turning_model_path = "/home/beauvill/Repos/monoloco/data/outputs/casr_standard-210613-0005.pkl" elif casr == 'nonstd': turning_output_size = 4 if casr_model: