- ebm2onnx version: 3.3.0
- python version: 3.9
- Operating System: Windows
Description
I would like to convert an ExplainableBoostingClassifier trained on a multiclass problem (with interactions set to a value greater than 0) into an ONNX model. The resulting onnx_model cannot be instantiated in onnxruntime, because some Reshape nodes have an incorrect shape tensor.
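The session error quoted in the repro is raised by ONNX shape inference on a Concat node. As a rough illustration, here is a hypothetical pure-Python helper (concat_shape is not part of onnx or ebm2onnx) that mirrors the rule Concat enforces: all inputs must agree on every dimension except the concatenation axis. It assumes the per-term scores are concatenated along axis 1, which fits the [-1, 1, n_classes] reshape targets and the reported Dimension=2; a concrete batch size stands in for -1.

```python
from typing import List, Sequence


def concat_shape(shapes: Sequence[Sequence[int]], axis: int) -> List[int]:
    """Return the output shape of concatenating tensors along `axis`.

    Mirrors the Concat shape rule: every dimension except `axis` must be
    equal across inputs, and the `axis` dimensions are summed.
    """
    first = list(shapes[0])
    out = first.copy()
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same rank")
        for dim, (a, b) in enumerate(zip(first, shape)):
            if dim != axis and a != b:
                raise ValueError(f"dimension {dim} differs: {a} vs {b}")
        out[axis] += shape[axis]
    return out


N = 32  # concrete batch size in place of -1

# Three terms, all reshaped to [N, 1, 3]: concatenation works.
print(concat_shape([[N, 1, 3], [N, 1, 3], [N, 1, 3]], axis=1))  # [32, 3, 3]

# The last term reshaped to [N, 1, 1] instead: dimension 2 no longer matches.
try:
    concat_shape([[N, 1, 3], [N, 1, 3], [N, 1, 1]], axis=1)
except ValueError as err:
    print(err)  # dimension 2 differs: 3 vs 1
```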
What I Did
Here is a minimal example demonstrating the issue:
# /// script
# requires-python = ">=3.9"
# dependencies = [
#   "ebm2onnx==3.3.0",
#   "interpret==0.6.14",
#   "onnx==1.18.0",
#   "onnxruntime==1.19.2",
#   "pandas==2.3.0",
#   "skl2onnx==1.19.1",
# ]
# ///
if __name__ == "__main__":
    import pandas as pd
    from interpret.glassbox import ExplainableBoostingClassifier
    import ebm2onnx

    # Small dataset
    X = pd.DataFrame(
        {
            "feature1": [0, 0, 1, 1] * 8,
            "feature2": [0] * 16 + [1] * 16,
        }
    )
    y = pd.Series([0] * 24 + [1] * 4 + [2] * 4)

    # Train model
    model = ExplainableBoostingClassifier(interactions=2)
    model.fit(X=X, y=y)

    # Convert model to ONNX
    onnx_model = ebm2onnx.to_onnx(
        model,
        ebm2onnx.get_dtype_from_pandas(X),
        predict_proba=True,
    )
    onnx_model.ir_version = 10  # Issue #22

    # Inspect the shape tensors of the score Reshape nodes
    # (done before creating the session, which raises)
    print(
        list(
            filter(
                lambda i: i.name.startswith("score_reshape"),
                onnx_model.graph.initializer,
            )
        )
    )
    # Provides the following output:
    # [
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 3\nname: \"score_reshape_0\"\n",
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 3\nname: \"score_reshape_1\"\n",
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 1\nname: \"score_reshape_2\"\n"
    # ]
    # The tensor "score_reshape_2" is faulty: after manually changing it to
    # [-1, 1, 3], the session can be created and scoring works.

    # Score with the ONNX model
    import onnxruntime as rt

    session = rt.InferenceSession(onnx_model.SerializeToString())
    # Creating the session fails with the following error:
    # onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (Concat_1) Op (Concat) [ShapeInferenceError] Can't merge shape info. Both inferred and declared dimension have values but they differ. Inferred=1 Declared=3 Dimension=2