From 60da6548ba098c49886dedd3dff0ed2d945e3d6a Mon Sep 17 00:00:00 2001 From: DocGarbanzo Date: Thu, 31 Mar 2022 08:38:02 +0100 Subject: [PATCH] Fix torch imports for rpi (#1004) * Moving torch imports into functions such that donkeycar can continue to work w/o pytorch installation. This is the default setup on RPi. * Bumped version (cherry picked from commit 14db8a87b967a41e0a962f407a7cca06e0acdb43) --- donkeycar/parts/interpreter.py | 6 ++++-- donkeycar/utils.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/donkeycar/parts/interpreter.py b/donkeycar/parts/interpreter.py index a031e903c..d2912e190 100755 --- a/donkeycar/parts/interpreter.py +++ b/donkeycar/parts/interpreter.py @@ -5,7 +5,6 @@ from typing import Union, Sequence, List import tensorflow as tf -import torch from tensorflow import keras from tensorflow.python.framework.convert_to_constants import \ @@ -171,6 +170,7 @@ def load_weights(self, model_path: str, by_name: bool = True) -> \ def summary(self) -> str: return self.model.summary() + class FastAIInterpreter(Interpreter): def __init__(self): @@ -206,7 +206,7 @@ def invoke(self, inputs): def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \ -> Sequence[Union[float, np.ndarray]]: - + import torch inputs = torch.unsqueeze(img_arr, 0) if other_arr is not None: #other_arr = np.expand_dims(other_arr, axis=0) @@ -214,6 +214,7 @@ def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \ return self.invoke(inputs) def load(self, model_path: str) -> None: + import torch logger.info(f'Loading model {model_path}') if torch.cuda.is_available(): logger.info("using cuda for torch inference") @@ -228,6 +229,7 @@ def load(self, model_path: str) -> None: def summary(self) -> str: return self.model + class TfLite(Interpreter): """ This class wraps around the TensorFlow Lite interpreter. diff --git a/donkeycar/utils.py b/donkeycar/utils.py index 615d1f236..7d1b67604 100644 --- a/donkeycar/utils.py +++ b/donkeycar/utils.py @@ -439,8 +439,6 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa from donkeycar.parts.interpreter import KerasInterpreter, TfLite, TensorRT, \ FastAIInterpreter - from donkeycar.parts.fastai import FastAILinear - if model_type is None: model_type = cfg.DEFAULT_MODEL_TYPE logger.info(f'get_model_by_type: model type is: {model_type}') @@ -455,6 +453,7 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa interpreter = FastAIInterpreter() used_model_type = model_type.replace('fastai_', '') if used_model_type == "linear": + from donkeycar.parts.fastai import FastAILinear return FastAILinear(interpreter=interpreter, input_shape=input_shape) else: interpreter = KerasInterpreter()