Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions bodo/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from numba.extending import lower_builtin, models, register_model

import bodo
from bodo.pandas_compat import bodo_pandas_udf_execution_engine

# Add Bodo's options to Numba's allowed options/flags
numba.core.cpu.CPUTargetOptions.all_args_distributed_block = _mapping(
Expand Down Expand Up @@ -328,6 +329,11 @@ def jit(signature_or_function=None, pipeline_class=None, **options):
# precedence)
disable_jit = os.environ.get("NUMBA_DISABLE_JIT", "0") == "1"
dist_mode = options.get("distributed", True) is not False

py_func = None
if isinstance(signature_or_function, pytypes.FunctionType):
py_func = signature_or_function

if options.get("spawn", bodo.spawn_mode) and not disable_jit and dist_mode:
from bodo.spawn.spawner import SpawnDispatcher
from bodo.spawn.worker_state import is_worker
Expand All @@ -346,18 +352,26 @@ def return_wrapped_fn(py_func):
submit_jit_args["pipeline_class"] = pipeline_class
return SpawnDispatcher(py_func, submit_jit_args)

if isinstance(signature_or_function, pytypes.FunctionType):
py_func = signature_or_function
if py_func is not None:
return return_wrapped_fn(py_func)

return return_wrapped_fn

bodo_jit = return_wrapped_fn
elif "propagate_env" in options:
raise bodo.utils.typing.BodoError(
"spawn=False while propagate_env is set. No worker to propagate env vars."
)
else:
bodo_jit = _jit(signature_or_function, pipeline_class, **options)

# Return jit decorator that can be used in Pandas UDF function. See definition of
# bodo_pandas_udf_execution_engine for more details.
if py_func is None:
bodo_jit.__pandas_udf__ = bodo_pandas_udf_execution_engine

return bodo_jit


return _jit(signature_or_function, pipeline_class, **options)
jit.__pandas_udf__ = bodo_pandas_udf_execution_engine


def _jit(signature_or_function=None, pipeline_class=None, **options):
Expand Down
112 changes: 112 additions & 0 deletions bodo/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,115 @@ def get_conversion_factor_to_ns(in_reso: str) -> int:
else:
raise ValueError(f"Unsupported resolution {in_reso}")
return factor * value


# Class responsible for executing UDFs using Bodo as the engine in
# newer versions of Pandas. See:
# https://github.com/pandas-dev/pandas/pull/61032
bodo_pandas_udf_execution_engine = None

if pandas_version >= (3, 0):
from collections.abc import Callable
from typing import Any

from pandas._typing import AggFuncType, Axis
from pandas.core.apply import BaseExecutionEngine

def _prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict, *, num_required_args: int
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function by trying to move keyword arguments inside
of args to eliminate kwargs.

This simplifies typing as well as catches keyword-only arguments,
which lead to unexpected behavior in Bodo. Copied from:
https://github.com/pandas-dev/pandas/blob/5fef9793dd23867e7b227a1df7aa60a283f6204e/pandas/core/util/numba_.py#L97
"""
_sentinel = object()

if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

if kwargs:
# Bodo change: error message
raise ValueError("Bodo does not support keyword only arguments.")

args = args[num_required_args:]
return args, kwargs

class BodoExecutionEngine(BaseExecutionEngine):
    """
    Execution engine that runs Pandas UDFs through Bodo.

    Implements the ``BaseExecutionEngine`` interface from
    ``pandas.core.apply`` (see https://github.com/pandas-dev/pandas/pull/61032).
    Only ``apply`` is implemented; ``map`` always raises.
    """

    @staticmethod
    def map(
        data: pd.Series | pd.DataFrame | np.ndarray,
        func: AggFuncType,
        args: tuple,
        kwargs: dict[str, Any],
        decorator: Callable | None,
        skip_na: bool,
    ):
        """
        Placeholder for the engine ``map`` hook.

        Raises:
            NotImplementedError: always.
        """
        # NOTE(review): per the PR discussion, engine support for map is not
        # implemented yet in Pandas, so there is no way to test this; it is
        # left as a followup. A Pandas PR adding ``engine`` to Series.map was
        # announced by a Pandas contributor in the same discussion.
        raise NotImplementedError("BodoExecutionEngine: map not implemented yet.")

    @staticmethod
    def apply(
        data: pd.Series | pd.DataFrame | np.ndarray,
        func: AggFuncType,
        args: tuple,
        kwargs: dict[str, Any],
        decorator: Callable,
        axis: Axis,
    ):
        """
        Run ``data.apply(func, axis=axis, args=...)`` inside a generated
        function wrapped with ``decorator``.

        Args:
            data: object to call ``apply`` on. An ndarray (what Pandas passes
                for ``raw=True``) is rejected.
            func: the UDF, or a non-callable spec such as a string.
            args: extra positional arguments for ``func``.
            kwargs: extra keyword arguments for ``func``; they are folded into
                ``args`` and are only supported when ``func`` is callable.
            decorator: callable applied to the generated wrapper function
                (presumably a ``bodo.jit`` decorator -- confirm with caller).
            axis: forwarded to ``data.apply``.

        Returns:
            Whatever the decorated generated function returns.

        Raises:
            ValueError: if ``data`` is an ndarray, if keyword arguments cannot
                be converted to positional ones, or if ``kwargs`` is non-empty
                while ``func`` is not callable.
        """
        from bodo import spawn_mode
        from bodo.spawn import spawner
        from bodo.utils.utils import bodo_exec

        # raw = True converts data to ndarray first
        if isinstance(data, np.ndarray):
            raise ValueError(
                "BodoExecutionEngine: does not support the raw=True for DataFrame.apply."
            )

        if callable(func):
            args, _ = _prepare_function_arguments(
                func, args, kwargs, num_required_args=1
            )
        elif kwargs:
            # kwargs are never forwarded to the generated function, so for a
            # non-callable func they would otherwise be dropped silently.
            raise ValueError(
                "BodoExecutionEngine: keyword arguments are only supported "
                "when func is a callable."
            )

        # Embed args as a string e.g. (args[0], args[1], ...) in func text
        # to avoid typing issues with Bodo.
        if args:
            args_str = ", ".join(f"args[{i}]" for i in range(len(args))) + ","
        else:
            args_str = ""
            # Add dummy value for args for spawn mode compatibility.
            # TODO: fix in spawn mode (the PR author references PR #414 as
            # the fix).
            args = (0,)

        apply_func_text = "def bodo_apply_func(data, axis, args):\n"
        apply_func_text += f" return data.apply(udf, axis=axis, args=({args_str}))"

        glbls = {"udf": func}
        if spawn_mode:
            # In the spawn mode case we need to bodo_exec on the workers as well
            # so the code object is available to the caching infra.
            def f(func_text, glbls, loc_vars, __name__):
                bodo_exec(func_text, glbls, loc_vars, __name__)

            spawner.submit_func_to_workers(
                f, [], apply_func_text, glbls, {}, __name__
            )
        apply_func = decorator(bodo_exec(apply_func_text, glbls, {}, __name__))

        return apply_func(data, axis, args)

bodo_pandas_udf_execution_engine = BodoExecutionEngine
119 changes: 119 additions & 0 deletions bodo/tests/test_pandas_udf_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
This file tests the Bodo implementation of the Pandas UDF interface.
See https://github.com/pandas-dev/pandas/pull/61032 for more details.
This feature is only available on newer versions of Pandas (>=3.0)
"""

import numpy as np
import pandas as pd
import pytest

import bodo
from bodo.pandas_compat import pandas_version
from bodo.tests.utils import _test_equal, pytest_spawn_mode

# Module-wide marks: skip everything on Pandas versions without third-party
# UDF engine support, then add the marks from pytest_spawn_mode.
pytestmark = [
    pytest.mark.skipif(
        pandas_version < (3, 0),
        reason="Third-party UDF engines requires Pandas >= 3.0",
    ),
    *pytest_spawn_mode,
]


@pytest.fixture(
    scope="module",
    params=[
        pytest.param(bodo.jit, id="jit_no_kwargs"),
        pytest.param(bodo.jit(spawn=False, distributed=False), id="jit_no_spawn"),
        pytest.param(bodo.jit(cache=True), id="jit_with_cache"),
    ],
)
def engine(request):
    """Provide each bodo.jit variant to be passed as ``engine=`` to Pandas."""
    return request.param


def test_apply_basic(engine):
    """Simplest test to check Pandas UDF apply hook is set up properly"""
    frame = pd.DataFrame({"A": np.arange(30)})

    def pick_a(row):
        return row.A

    # Same UDF through the Bodo engine and through plain Pandas.
    actual = frame.apply(pick_a, axis=1, engine=engine)
    expected = frame.apply(pick_a, axis=1)

    _test_equal(actual, expected, check_pandas_types=False)


def test_apply_raw_error():
    """Test passing raw=True raises appropriate error message."""
    frame = pd.DataFrame({"A": np.arange(30)})
    expected_msg = (
        "BodoExecutionEngine: does not support the raw=True for DataFrame.apply."
    )

    with pytest.raises(ValueError, match=expected_msg):
        frame.apply(lambda x: x.A, axis=1, engine=bodo.jit, raw=True)


def test_udf_args(engine):
    """Check that extra positional ``args`` reach the UDF under the engine."""
    frame = pd.DataFrame({"A": np.arange(30)})
    extra_args = (1,)

    def udf(x, a):
        return x.A + a

    actual = frame.apply(udf, axis=1, engine=engine, args=extra_args)
    expected = frame.apply(udf, axis=1, args=extra_args)

    _test_equal(actual, expected, check_pandas_types=False)


def test_udf_kwargs(engine):
    """Check that keyword arguments mixed with ``args`` reach the UDF."""
    frame = pd.DataFrame(
        {"A": np.arange(30), "B": ["hi", "hello", "goodbye"] * 10}
    )

    def udf(x, a=1, b="goodbye", d=3):
        if b == x.B:
            return x.A + a
        return x.A + d

    actual = frame.apply(udf, axis=1, args=(4,), d=16, b="hi", engine=engine)
    expected = frame.apply(udf, axis=1, args=(4,), d=16, b="hi")

    _test_equal(actual, expected, check_pandas_types=False)


def test_udf_cache():
    """Tests that we can call the same UDF multiple times with cache flag on
    without any errors. TODO: check cache."""
    # NOTE(review): per the PR discussion, caching itself was verified
    # manually by the author; this test only checks repeated calls succeed.
    cached_engine = bodo.jit(cache=True)
    frame = pd.DataFrame({"A": np.arange(30)})

    def udf(x, a):
        return x.A + a

    first_run = frame.apply(udf, axis=1, engine=cached_engine, args=(1,))
    expected = frame.apply(udf, axis=1, args=(1,))
    _test_equal(first_run, expected, check_pandas_types=False)

    # Second call with the same engine and UDF.
    second_run = frame.apply(udf, axis=1, engine=cached_engine, args=(1,))
    _test_equal(second_run, expected, check_pandas_types=False)


def test_udf_str(engine):
    """Test passing string as func works properly."""
    frame = pd.DataFrame({"A": np.arange(30)})
    func_name = "mean"

    actual = frame.apply(func_name, axis=1, engine=engine)
    expected = frame.apply(func_name, axis=1)

    _test_equal(actual, expected, check_pandas_types=False)