Skip to content

Commit

Permalink
* Bump version -> 4.2.1 (#881)
Browse files Browse the repository at this point in the history
* Revert to passing full path to model in training call which got accidentally broken in 4.2 master.
  • Loading branch information
DocGarbanzo authored Jun 10, 2021
1 parent 5c34d23 commit 81c2ea3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion donkeycar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from pyfiglet import Figlet

__version__ = '4.2.0'
__version__ = '4.2.1'
f = Figlet(font='speed')

print(f.renderText('Donkey Car'))
Expand Down
4 changes: 2 additions & 2 deletions donkeycar/pipeline/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def generate_model_name(self) -> Tuple[str, int]:
else:
this_num = 0
date = time.strftime('%y-%m-%d')
name = 'pilot_' + date + '_' + str(this_num)
return name, this_num
name = f'pilot_{date}_{this_num}.h5'
return os.path.join(self.cfg.MODELS_PATH, name), this_num

def to_df(self) -> pd.DataFrame:
if self.entries:
Expand Down
29 changes: 19 additions & 10 deletions donkeycar/pipeline/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def create_tf_data(self) -> tf.data.Dataset:
def get_model_train_details(cfg: Config, database: PilotDatabase,
model: str = None, model_type: str = None) \
-> Tuple[str, int, str, bool]:
"""
Returns automatic model name if none is given
:param cfg: donkey config
:param database: model database with existing training data
:param model: model path
:param model_type: type of model, like 'linear', 'tflite_linear', etc
:return: tuple of model path, number, training type, and if
tflite is requested
"""
if not model_type:
model_type = cfg.DEFAULT_MODEL_TYPE
train_type = model_type
Expand All @@ -90,12 +99,13 @@ def get_model_train_details(cfg: Config, database: PilotDatabase,
is_tflite = True
model_num = 0
if not model:
model_name, model_num = database.generate_model_name()
model_path, model_num = database.generate_model_name()
else:
model_name, model_ext = os.path.splitext(model)
_, model_ext = os.path.splitext(model)
model_path = model
is_tflite = model_ext == '.tflite'

return model_name, model_num, train_type, is_tflite
return model_path, model_num, train_type, is_tflite


def train(cfg: Config, tub_paths: str, model: str = None,
Expand All @@ -105,10 +115,9 @@ def train(cfg: Config, tub_paths: str, model: str = None,
Train the model
"""
database = PilotDatabase(cfg)
model_name, model_num, train_type, is_tflite = \
model_path, model_num, train_type, is_tflite = \
get_model_train_details(cfg, database, model, model_type)

output_path = os.path.join(cfg.MODELS_PATH, model_name + '.h5')
kl = get_model_by_type(train_type, cfg)
if transfer:
kl.load(transfer)
Expand All @@ -135,7 +144,7 @@ def train(cfg: Config, tub_paths: str, model: str = None,
assert val_size > 0, "Not enough validation data, decrease the batch " \
"size or add more data."

history = kl.train(model_path=output_path,
history = kl.train(model_path=model_path,
train_data=dataset_train,
train_steps=train_size,
batch_size=cfg.BATCH_SIZE,
Expand All @@ -146,14 +155,14 @@ def train(cfg: Config, tub_paths: str, model: str = None,
min_delta=cfg.MIN_DELTA,
patience=cfg.EARLY_STOP_PATIENCE,
show_plot=cfg.SHOW_PLOT)

base_path = os.path.splitext(model_path)[0]
if is_tflite:
tf_lite_model_path = f'{os.path.splitext(output_path)[0]}.tflite'
keras_model_to_tflite(output_path, tf_lite_model_path)
tf_lite_model_path = f'{base_path}.tflite'
keras_model_to_tflite(model_path, tf_lite_model_path)

database_entry = {
'Number': model_num,
'Name': model_name,
'Name': os.path.basename(base_path),
'Type': str(kl),
'Tubs': tub_paths,
'Time': time(),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def package_files(directory, strip_leading):
long_description = fh.read()

setup(name='donkeycar',
version='4.2.0',
version='4.2.1',
long_description=long_description,
description='Self driving library for python.',
url='https://github.com/autorope/donkeycar',
Expand Down

0 comments on commit 81c2ea3

Please sign in to comment.