Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/onnx/onnxmltools into xgbbug
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Oct 2, 2023
2 parents e9bf791 + 79c34e3 commit 28a515f
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 26 deletions.
8 changes: 4 additions & 4 deletions .azure-pipelines/linux-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:

Python311-1140-RT1151-xgb175:
python.version: '3.11'
ONNX_PATH: 'onnx==1.14.0' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
Expand All @@ -27,7 +27,7 @@ jobs:

Python310-1140-RT1151-xgb175:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.0' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '<4.0'
Expand All @@ -37,7 +37,7 @@ jobs:

Python310-1140-RT1140-xgb175:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.0' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: onnxruntime==1.14.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
COREML_PATH: NONE
lightgbm.version: '<4.0'
Expand All @@ -47,7 +47,7 @@ jobs:

Python39-1140-RT1151-xgb175-scipy180:
python.version: '3.9'
ONNX_PATH: 'onnx==1.14.0' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
lightgbm.version: '>=4.0'
Expand Down
6 changes: 3 additions & 3 deletions .azure-pipelines/win32-conda-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ jobs:

Python311-1140-RT1151:
python.version: '3.11'
ONNX_PATH: 'onnx==1.14.0' # '-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
numpy.version: ''

Python310-1140-RT1151:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.0' # '-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: 'onnxruntime==1.15.1'
COREML_PATH: NONE
numpy.version: ''

Python310-1140-RT1140:
python.version: '3.10'
ONNX_PATH: 'onnx==1.14.0' # '-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNX_PATH: 'onnx==1.14.1' #'-i https://test.pypi.org/simple/ onnx==1.14.0rc3'
ONNXRT_PATH: onnxruntime==1.14.0 #'-i https://test.pypi.org/simple/ ort-nightly==1.11.0.dev20220311003'
COREML_PATH: NONE
numpy.version: ''
Expand Down
1 change: 1 addition & 0 deletions onnxmltools/convert/sparkml/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..common.onnx_ex import get_maximum_opset_supported
from ..common._topology import convert_topology
from ._parse import parse_sparkml
from . import operator_converters # noqa: F401


def convert(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
import numpy
import re
from pyspark.sql import SparkSession


Expand Down Expand Up @@ -47,19 +48,65 @@ def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier):


def save_read_sparkml_model_data(spark: SparkSession, model):
tdir = tempfile.tempdir
if tdir is None:
local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
spark._jsc.sc().conf()
)
tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
local_dir, "onnx"
).getAbsolutePath()
if tdir is None:
raise FileNotFoundError(
"Unable to create a temporary directory for model '{}'"
".".format(type(model).__name__)
)
# Get the value of spark.master
spark_mode = spark.conf.get("spark.master")

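# Detect a standalone cluster master ("spark://...") with plain substring tests.
# NOTE(review): the `or` below is true unless BOTH "localhost" and "127.0.0.1"
# appear in spark.master, so "spark://localhost:7077" is treated as cluster
# mode; `and` is probably what was intended - confirm.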
if "spark://" in spark_mode and (
"localhost" not in spark_mode or "127.0.0.1" not in spark_mode
):
dfs_key = "ONNX_DFS_PATH"
try:
dfs_path = spark.conf.get("ONNX_DFS_PATH")
except Exception:
raise ValueError(
"Configuration property '{}' does not exist for SparkSession. \
Please set this variable to a root distributed file system path to allow \
for saving and reading of spark models in cluster mode. \
You can set this in your SparkConfig \
by setting sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(
dfs_key
)
)
if dfs_path is None:
# If dfs_path is not specified, throw an error message
# dfs_path arg is required for cluster mode
raise ValueError(
"Argument dfs_path is required for saving model '{}' in cluster mode. \
You can set this in your SparkConfig by \
setting sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(
type(model).__name__
)
)
else:
# Check that the dfs_path is a valid distributed file system path
# This can be hdfs, wabs, s3, etc.
if re.match(r"^[a-zA-Z]+://", dfs_path) is None:
raise ValueError(
"Argument dfs_path '{}' is not a valid distributed path".format(
dfs_path
)
)
else:
# dfs_path looks like a valid distributed path: save the model under
# "<dfs_path>/tmp/onnx", i.e. dfs_path acts as the root of the tmp directory.
tdir = os.path.join(dfs_path, "tmp/onnx")
else:
# If spark.master is not set or set to local, save the model to a local path.
tdir = tempfile.tempdir
if tdir is None:
local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
spark._jsc.sc().conf()
)
tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
local_dir, "onnx"
).getAbsolutePath()
if tdir is None:
raise FileNotFoundError(
"Unable to create a temporary directory for model '{}'"
".".format(type(model).__name__)
)

path = os.path.join(tdir, type(model).__name__ + "_" + str(time.time()))
model.write().overwrite().save(path)
df = spark.read.parquet(os.path.join(path, "data"))
Expand Down
11 changes: 5 additions & 6 deletions onnxmltools/convert/sparkml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def buildInitialTypesSimple(dataframe):


def getTensorTypeFromSpark(sparktype):
if sparktype == "StringType" or sparktype == "StringType()":
if sparktype in ("StringType", "StringType()"):
return StringTensorType([1, 1])
elif (
if (
sparktype == "DecimalType"
or sparktype == "DecimalType()"
or sparktype == "DoubleType"
Expand All @@ -34,17 +34,16 @@ def getTensorTypeFromSpark(sparktype):
or sparktype == "BooleanType"
or sparktype == "BooleanType()"
):
return FloatTensorType([1, 1])
else:
raise TypeError("Cannot map this type to Onnx types: " + sparktype)
return FloatTensorType([None, 1])
raise TypeError(f"Cannot map this type to Onnx types: {sparktype}.")


def buildInputDictSimple(dataframe):
import numpy

result = {}
for field in dataframe.schema.fields:
if str(field.dataType) == "StringType":
if str(field.dataType) in ("StringType", "StringType()"):
result[field.name] = dataframe.select(field.name).toPandas().values
else:
result[field.name] = (
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ scikit-learn>=1.2.0
scipy
wheel
xgboost==1.7.5
onnxruntime

0 comments on commit 28a515f

Please sign in to comment.