Skip to content

Commit

Permalink
Fix torch imports for rpi (#1004)
Browse files Browse the repository at this point in the history
* Moving torch imports into functions such that donkeycar can continue to work w/o pytorch installation. This is the default setup on RPi.

* Bumped version

(cherry picked from commit 14db8a8)
  • Loading branch information
DocGarbanzo committed Mar 31, 2022
1 parent 2a6e12d commit 60da654
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 4 additions & 2 deletions donkeycar/parts/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Union, Sequence, List

import tensorflow as tf
import torch
from tensorflow import keras

from tensorflow.python.framework.convert_to_constants import \
Expand Down Expand Up @@ -171,6 +170,7 @@ def load_weights(self, model_path: str, by_name: bool = True) -> \
def summary(self) -> str:
return self.model.summary()


class FastAIInterpreter(Interpreter):

def __init__(self):
Expand Down Expand Up @@ -206,14 +206,15 @@ def invoke(self, inputs):

def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \
-> Sequence[Union[float, np.ndarray]]:

import torch
inputs = torch.unsqueeze(img_arr, 0)
if other_arr is not None:
#other_arr = np.expand_dims(other_arr, axis=0)
inputs = [img_arr, other_arr]
return self.invoke(inputs)

def load(self, model_path: str) -> None:
import torch
logger.info(f'Loading model {model_path}')
if torch.cuda.is_available():
logger.info("using cuda for torch inference")
Expand All @@ -228,6 +229,7 @@ def load(self, model_path: str) -> None:
def summary(self) -> str:
return self.model


class TfLite(Interpreter):
"""
This class wraps around the TensorFlow Lite interpreter.
Expand Down
3 changes: 1 addition & 2 deletions donkeycar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,6 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa
from donkeycar.parts.interpreter import KerasInterpreter, TfLite, TensorRT, \
FastAIInterpreter

from donkeycar.parts.fastai import FastAILinear

if model_type is None:
model_type = cfg.DEFAULT_MODEL_TYPE
logger.info(f'get_model_by_type: model type is: {model_type}')
Expand All @@ -455,6 +453,7 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa
interpreter = FastAIInterpreter()
used_model_type = model_type.replace('fastai_', '')
if used_model_type == "linear":
from donkeycar.parts.fastai import FastAILinear
return FastAILinear(interpreter=interpreter, input_shape=input_shape)
else:
interpreter = KerasInterpreter()
Expand Down

0 comments on commit 60da654

Please sign in to comment.