Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
remove signature
  • Loading branch information
Aman123lug authored Aug 28, 2023
1 parent 8134bdc commit baf7ab1
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,28 @@ def training(self, path:Path):
logger.info(f" MLflow log_metric saved !")

# Model registry does not work with file store
if tracking_url_type_store != "file":
# Register the model
# There are other ways to use the Model Registry, which depends on the use case,
# please refer to the doc for more information:
# https://mlflow.org/docs/latest/model-registry.html#api-workflow
mlflow.sklearn.log_model(
dtree, "model/model.pkl", registered_model_name="DecisionTree", signature=signature
)
else:
mlflow.sklearn.log_model(dtree, "model/model.pkl", signature=signature)
# if tracking_url_type_store != "file":
# # Register the model
# # There are other ways to use the Model Registry, which depends on the use case,
# # please refer to the doc for more information:
# # https://mlflow.org/docs/latest/model-registry.html#api-workflow
# mlflow.sklearn.log_model(
# dtree, r"model/model.pkl", registered_model_name="DecisionTree", signature=signature
# )
# else:
# mlflow.sklearn.log_model(dtree, r"model/model.pkl", signature=signature)

with open("metrics.txt", "w+") as f:
f.write(str(test_accracy))
logger.info(f" metrics saved !")
with open("metrics.txt", "w+") as f:
f.write(str(test_accracy))
logger.info(f" metrics saved !")

# save model
pickle.dump(dtree, open(r"model/model.pkl", "wb"))
logger.info(f" mode saved !")
IC = type('IdentityClassifier', (), {"predict": lambda i : i, "_estimator_type": "classifier"})
cm=ConfusionMatrixDisplay.from_estimator(IC, y_pred, y_test, normalize='true', values_format='.2%')

cm.figure_.savefig('images/confusion_matrix.png')
pickle.dump(dtree, open(r"model/model.pkl", "wb"))
logger.info(f" model saved !")
# IC = type('IdentityClassifier', (), {"predict": lambda i : i, "_estimator_type": "classifier"})
# cm=ConfusionMatrixDisplay.from_estimator(IC, y_pred, y_test, normalize='true', values_format='.2%')
cm.figure_.savefig('images/confusion_matrix.png')



Expand Down

0 comments on commit baf7ab1

Please sign in to comment.