Multiclass Model with interactions leads to invalid model #24

@pschleiter

  • ebm2onnx version: 3.3.0
  • python version: 3.9
  • Operating System: Windows

Description

I would like to convert an ExplainableBoostingClassifier for a multiclass problem (with interactions set to a value greater than 0) into an ONNX model. The resulting onnx_model cannot be loaded into an onnxruntime InferenceSession, because one of the Reshape nodes is given a wrong shape tensor.

What I Did

Here is a minimal example demonstrating the issue:

# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "ebm2onnx==3.3.0",
#     "interpret==0.6.14",
#     "onnx==1.18.0",
#     "onnxruntime==1.19.2",
#     "pandas==2.3.0",
#     "skl2onnx==1.19.1",
# ]
# ///

if __name__ == "__main__":
    import pandas as pd
    from interpret.glassbox import ExplainableBoostingClassifier
    import ebm2onnx

    # Small dataset
    X = pd.DataFrame(
        {
            "feature1": [0, 0, 1, 1] * 8,
            "feature2": [0] * 16 + [1] * 16,
        }
    )
    y = pd.Series([0] * 24 + [1] * 4 + [2] * 4)

    # Train model
    model = ExplainableBoostingClassifier(interactions=2)
    model.fit(X=X, y=y)

    # Convert model to ONNX
    onnx_model = ebm2onnx.to_onnx(
        model,
        ebm2onnx.get_dtype_from_pandas(X),
        predict_proba=True,
    )
    onnx_model.ir_version = 10  # Issue #22

    # Score with the onnx model
    import onnxruntime as rt

    session = rt.InferenceSession(onnx_model.SerializeToString())
    # Creating the session fails with the following error:
    #   onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (Concat_1) Op (Concat) [ShapeInferenceError] Can't merge shape info. Both inferred and declared dimension have values but they differ. Inferred=1 Declared=3 Dimension=2

    print(
        list(
            filter(
                lambda i: i.name.startswith("score_reshape"),
                onnx_model.graph.initializer,
            )
        )
    )
    # Provides the following output:
    #   [
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 3\nname: \"score_reshape_0\"\n",
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 3\nname: \"score_reshape_1\"\n",
    #   "dims: 3\ndata_type: 7\nint64_data: -1\nint64_data: 1\nint64_data: 1\nname: \"score_reshape_2\"\n"
    #   ]
    # The tensor "score_reshape_2" is faulty: after manually changing it to [-1, 1, 3], the session can be created and scoring works.
