Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions bodo/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from numba.extending import lower_builtin, models, register_model

import bodo
from bodo.pandas_compat import bodo_pandas_udf_execution_engine

# Add Bodo's options to Numba's allowed options/flags
numba.core.cpu.CPUTargetOptions.all_args_distributed_block = _mapping(
Expand Down Expand Up @@ -328,6 +329,11 @@ def jit(signature_or_function=None, pipeline_class=None, **options):
# precedence)
disable_jit = os.environ.get("NUMBA_DISABLE_JIT", "0") == "1"
dist_mode = options.get("distributed", True) is not False

py_func = None
if isinstance(signature_or_function, pytypes.FunctionType):
py_func = signature_or_function

if options.get("spawn", bodo.spawn_mode) and not disable_jit and dist_mode:
from bodo.spawn.spawner import SpawnDispatcher
from bodo.spawn.worker_state import is_worker
Expand All @@ -346,18 +352,26 @@ def return_wrapped_fn(py_func):
submit_jit_args["pipeline_class"] = pipeline_class
return SpawnDispatcher(py_func, submit_jit_args)

if isinstance(signature_or_function, pytypes.FunctionType):
py_func = signature_or_function
if py_func is not None:
return return_wrapped_fn(py_func)

return return_wrapped_fn

bodo_jit = return_wrapped_fn
elif "propagate_env" in options:
raise bodo.utils.typing.BodoError(
"spawn=False while propagate_env is set. No worker to propagate env vars."
)
else:
bodo_jit = _jit(signature_or_function, pipeline_class, **options)

# Return jit decorator that can be used in Pandas UDF function. See definition of
# bodo_pandas_udf_execution_engine for more details.
if py_func is None:
bodo_jit.__pandas_udf__ = bodo_pandas_udf_execution_engine

return bodo_jit


return _jit(signature_or_function, pipeline_class, **options)
jit.__pandas_udf__ = bodo_pandas_udf_execution_engine


def _jit(signature_or_function=None, pipeline_class=None, **options):
Expand Down
110 changes: 110 additions & 0 deletions bodo/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,113 @@
else:
raise ValueError(f"Unsupported resolution {in_reso}")
return factor * value


# Class responsible for executing UDFs using Bodo as the engine in
# newer versions of Pandas. See:
# https://github.com/pandas-dev/pandas/pull/61032
bodo_pandas_udf_execution_engine = None

if pandas_version >= (3, 0):
from collections.abc import Callable
from typing import Any

Check warning on line 256 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L255-L256

Added lines #L255 - L256 were not covered by tests

from pandas._typing import AggFuncType, Axis
from pandas.core.apply import BaseExecutionEngine

Check warning on line 259 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L258-L259

Added lines #L258 - L259 were not covered by tests

def _prepare_function_arguments(

Check warning on line 261 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L261

Added line #L261 was not covered by tests
func: Callable, args: tuple, kwargs: dict, *, num_required_args: int
) -> tuple[tuple, dict]:
"""
Prepare arguments for a jitted function by trying to move keyword arguments
inside of args to eliminate kwargs.
This simplifies typing as well as catches keyword-only arguments,
which lead to unexpected behavior in Bodo. Copied from:
https://github.com/pandas-dev/pandas/blob/5fef9793dd23867e7b227a1df7aa60a283f6204e/pandas/core/util/numba_.py#L97
"""
_sentinel = object()

Check warning on line 272 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L272

Added line #L272 was not covered by tests

if not kwargs:
return args, kwargs

Check warning on line 275 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L275

Added line #L275 was not covered by tests

# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
arguments.apply_defaults()

Check warning on line 280 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L278-L280

Added lines #L278 - L280 were not covered by tests
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

Check warning on line 285 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L284-L285

Added lines #L284 - L285 were not covered by tests

if kwargs:
# Bodo change: error message
raise ValueError("Bodo does not support keyword only arguments.")

Check warning on line 289 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L289

Added line #L289 was not covered by tests

args = args[num_required_args:]
return args, kwargs

Check warning on line 292 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L291-L292

Added lines #L291 - L292 were not covered by tests

class BodoExecutionEngine(BaseExecutionEngine):
@staticmethod
def map(

Check warning on line 296 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L294-L296

Added lines #L294 - L296 were not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not implemented yet in Pandas so there is no way to test. Leaving as a followup

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll open a PR adding engine to Series.map shortly. I have the implementation already, but map and apply tests are structured very differently, and I need to see how to implement the common fixtures without making the mess bigger or refactoring all the tests.

data: pd.Series | pd.DataFrame | np.ndarray,
func: AggFuncType,
args: tuple,
kwargs: dict[str, Any],
decorator: Callable | None,
skip_na: bool,
):
raise NotImplementedError("BodoExecutionEngine: map not implemented yet.")

Check warning on line 304 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L304

Added line #L304 was not covered by tests

@staticmethod
def apply(

Check warning on line 307 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L306-L307

Added lines #L306 - L307 were not covered by tests
data: pd.Series | pd.DataFrame | np.ndarray,
func: AggFuncType,
args: tuple,
kwargs: dict[str, Any],
decorator: Callable,
axis: Axis,
):
from bodo import spawn_mode
from bodo.spawn import spawner
from bodo.utils.utils import bodo_exec

Check warning on line 317 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L315-L317

Added lines #L315 - L317 were not covered by tests

# raw = True converts data to ndarray first
if isinstance(data, np.ndarray):
raise ValueError(

Check warning on line 321 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L321

Added line #L321 was not covered by tests
"BodoExecutionEngine: does not support the raw=True for DataFrame.apply."
)

if isinstance(func, Callable):
args, _ = _prepare_function_arguments(

Check warning on line 326 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L326

Added line #L326 was not covered by tests
func, args, kwargs, num_required_args=1
)

# Embed args as a string e.g. (args[0], args[1], ...) in func text
# to avoid typing issues with Bodo.
args_str = ""

Check warning on line 332 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L332

Added line #L332 was not covered by tests
if len(args):
args_str = ", ".join(f"args[{i}]" for i in range(len(args)))
args_str += ","

Check warning on line 335 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L334-L335

Added lines #L334 - L335 were not covered by tests
else:
# Add dummy value for args for spawn mode compatibility.
# TODO: fix in spawn mode.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a PR to fix this here: #414

args = (0,)

Check warning on line 339 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L339

Added line #L339 was not covered by tests

apply_func_text = "def bodo_apply_func(data, axis, args):\n"
apply_func_text += f" return data.apply(udf, axis=axis, args=({args_str}))"

Check warning on line 342 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L341-L342

Added lines #L341 - L342 were not covered by tests

glbls = {"udf": func}

Check warning on line 344 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L344

Added line #L344 was not covered by tests
if spawn_mode:
# In the spawn mode case we need to bodo_exec on the workers as well
# so the code object is available to the caching infra.
def f(func_text, glbls, loc_vars, __name__):
bodo_exec(func_text, glbls, loc_vars, __name__)

Check warning on line 349 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L348-L349

Added lines #L348 - L349 were not covered by tests

spawner.submit_func_to_workers(f, [], apply_func_text, glbls, {}, __name__)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong indentation looks like. Should be inside if spawn_mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I just realized this codepath is taken even when engine=bodo.jit(spawn=False, distributed=False). Would there be a way to also case on the decorator?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will exec the func text on the workers even when spawn=False. Discussed offline it would be better to just avoid func_text altogether.

apply_func = decorator(bodo_exec(apply_func_text, glbls, {}, __name__))

Check warning on line 352 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L351-L352

Added lines #L351 - L352 were not covered by tests

return apply_func(data, axis, args)

Check warning on line 354 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L354

Added line #L354 was not covered by tests

bodo_pandas_udf_execution_engine = BodoExecutionEngine

Check warning on line 356 in bodo/pandas_compat.py

View check run for this annotation

Codecov / codecov/patch

bodo/pandas_compat.py#L356

Added line #L356 was not covered by tests
119 changes: 119 additions & 0 deletions bodo/tests/test_pandas_udf_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
This file tests the Bodo implementation of the Pandas UDF interface.
See https://github.com/pandas-dev/pandas/pull/61032 for more details.
This feature is only available on newer versions of Pandas (>=3.0)
"""

import numpy as np
import pandas as pd
import pytest

import bodo
from bodo.pandas_compat import pandas_version
from bodo.tests.utils import _test_equal, pytest_spawn_mode

pytestmark = [
pytest.mark.skipif(
pandas_version < (3, 0), reason="Third-party UDF engines requires Pandas >= 3.0"
)
] + pytest_spawn_mode


@pytest.fixture(
    scope="module",
    params=[
        pytest.param(bodo.jit, id="jit_no_kwargs"),
        pytest.param(bodo.jit(spawn=False, distributed=False), id="jit_no_spawn"),
        pytest.param(bodo.jit(cache=True), id="jit_with_cache"),
    ],
)
def engine(request):
    """Yield each Bodo JIT decorator variant to be passed as ``engine=``.

    Covers the bare ``bodo.jit`` decorator, a non-spawn/non-distributed
    configuration, and a configuration with caching enabled.
    """
    return request.param


def test_apply_basic(engine):
    """Smoke test that the Pandas UDF apply hook works with the Bodo engine."""
    frame = pd.DataFrame({"A": np.arange(30)})

    # Reference output from plain Pandas, then the same call through Bodo.
    expected = frame.apply(lambda x: x.A, axis=1)
    actual = frame.apply(lambda x: x.A, axis=1, engine=engine)

    _test_equal(actual, expected, check_pandas_types=False)


def test_apply_raw_error():
    """Check that raw=True is rejected with the expected ValueError message."""
    frame = pd.DataFrame({"A": np.arange(30)})
    expected_msg = (
        "BodoExecutionEngine: does not support the raw=True for DataFrame.apply."
    )

    with pytest.raises(ValueError, match=expected_msg):
        frame.apply(lambda x: x.A, axis=1, engine=bodo.jit, raw=True)


def test_udf_args(engine):
    """Extra positional arguments given via ``args`` are forwarded to the UDF."""
    frame = pd.DataFrame({"A": np.arange(30)})

    def udf(x, a):
        return x.A + a

    extra_args = (1,)
    expected = frame.apply(udf, axis=1, args=extra_args)
    actual = frame.apply(udf, axis=1, engine=engine, args=extra_args)

    _test_equal(actual, expected, check_pandas_types=False)


def test_udf_kwargs(engine):
    """Keyword arguments given to DataFrame.apply are forwarded to the UDF."""
    frame = pd.DataFrame({"A": np.arange(30), "B": ["hi", "hello", "goodbye"] * 10})

    def udf(x, a=1, b="goodbye", d=3):
        if b == x.B:
            return x.A + a
        else:
            return x.A + d

    # One positional extra plus two keyword overrides, same for both engines.
    positional = (4,)
    keyword = {"d": 16, "b": "hi"}

    expected = frame.apply(udf, axis=1, args=positional, **keyword)
    actual = frame.apply(udf, axis=1, args=positional, engine=engine, **keyword)

    _test_equal(actual, expected, check_pandas_types=False)


def test_udf_cache():
"""Tests that we can call the same UDF multiple times with cache flag on
without any errors. TODO: check cache."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you verified caching manually at least?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

engine = bodo.jit(cache=True)

df = pd.DataFrame({"A": np.arange(30)})

def udf(x, a):
return x.A + a

bodo_result = df.apply(udf, axis=1, engine=engine, args=(1,))

pandas_result = df.apply(udf, axis=1, args=(1,))

_test_equal(bodo_result, pandas_result, check_pandas_types=False)

bodo_result = df.apply(udf, axis=1, engine=engine, args=(1,))

_test_equal(bodo_result, pandas_result, check_pandas_types=False)


def test_udf_str(engine):
    """Passing a string function name as ``func`` works with the Bodo engine."""
    frame = pd.DataFrame({"A": np.arange(30)})
    func_name = "mean"

    expected = frame.apply(func_name, axis=1)
    actual = frame.apply(func_name, axis=1, engine=engine)

    _test_equal(actual, expected, check_pandas_types=False)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
dependencies = [
"numba==0.61.0",
"pyarrow==19.0.0",
"pandas>=2.2,<2.3",
"pandas>=2.2,<3.1",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert?

"numpy>=1.24,<2.2",
# fsspec >= 2021.09 because it includes Arrow filesystem wrappers (useful for fs.glob() for example)
"fsspec>=2021.09",
Expand Down
Loading