Skip to content

Commit 476353e

Browse files
committed
feat(polars): support is_in queries from uncorrelated tables
1 parent c83f842 commit 476353e

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

ibis/backends/polars/compiler.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,29 @@ def greatest(op, **kw):
477477

478478
@translate.register(ops.InSubquery)
479479
def in_column(op, **kw):
480-
value = translate(op.value, **kw)
480+
if op.value.dtype.is_struct():
481+
raise com.UnsupportedOperationError(
482+
"polars doesn't support contains with struct elements"
483+
)
484+
481485
needle = translate(op.needle, **kw)
482-
return needle.is_in(value)
486+
value = translate(op.value, **kw)
487+
(rel,) = op.value.relations
488+
# The `collect` triggers computation, but there appears to be no way to
489+
# spell this operation in a polars-native way that's
490+
#
491+
# 1. not deprecated
492+
# 2. operates only using pl.Expr objects and methods
493+
#
494+
# In other words, we need to either rearchitect the polars compiler to
495+
# operate only with DataFrames, or compute first. I chose computing first
496+
# since it is less effort and we don't know how impactful it is.
497+
value = translate(rel, **kw).select(value).collect().to_series()
498+
499+
return needle.map_batches(
500+
lambda needle, value=value: needle.is_in(value),
501+
return_dtype=pl.Boolean(),
502+
)
483503

484504

485505
@translate.register(ops.InValues)

ibis/backends/tests/test_generic.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,32 @@ def test_typeof(con):
11281128
assert result is not None
11291129

11301130

1131-
@pytest.mark.notimpl(["polars"], reason="incorrect answer")
1131+
@pytest.mark.notyet(["impala"], reason="can't find table in subquery")
1132+
@pytest.mark.notimpl(["datafusion", "druid"])
1133+
@pytest.mark.xfail_version(pyspark=["pyspark<3.5"])
1134+
@pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="not supported by exasol")
1135+
@pytest.mark.notyet(
1136+
["risingwave"],
1137+
raises=PsycoPg2InternalError,
1138+
reason="https://github.com/risingwavelabs/risingwave/issues/1343",
1139+
)
1140+
@pytest.mark.notyet(
1141+
["mssql"],
1142+
raises=PyODBCProgrammingError,
1143+
reason="naked IN queries are not supported",
1144+
)
1145+
def test_isin_uncorrelated_simple(con):
1146+
u1 = ibis.memtable({"id": [1, 2, 3]})
1147+
a = ibis.memtable({"id": [1, 2]})
1148+
1149+
u2 = u1.mutate(in_a=u1["id"].isin(a["id"]))
1150+
final = u2.order_by("id")
1151+
1152+
result = con.to_pyarrow(final)
1153+
expected = pa.table({"id": [1, 2, 3], "in_a": [True, True, False]})
1154+
assert result.equals(expected)
1155+
1156+
11321157
@pytest.mark.notyet(["impala"], reason="can't find table in subquery")
11331158
@pytest.mark.notimpl(["datafusion", "druid"])
11341159
@pytest.mark.xfail_version(pyspark=["pyspark<3.5"])
@@ -1161,7 +1186,6 @@ def test_isin_uncorrelated(
11611186
backend.assert_series_equal(result, expected)
11621187

11631188

1164-
@pytest.mark.notimpl(["polars"], reason="incorrect answer")
11651189
def test_isin_uncorrelated_filter(
11661190
backend, batting, awards_players, batting_df, awards_players_df
11671191
):

ibis/backends/tests/test_struct.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ibis.backends.tests.conftest import NO_STRUCT_SUPPORT_MARKS
1313
from ibis.backends.tests.errors import (
1414
DatabricksServerOperationError,
15-
PolarsColumnNotFoundError,
1615
PsycoPg2InternalError,
1716
PsycoPg2ProgrammingError,
1817
PsycoPgSyntaxError,
@@ -21,7 +20,7 @@
2120
PyAthenaOperationalError,
2221
PySparkAnalysisException,
2322
)
24-
from ibis.common.exceptions import IbisError
23+
from ibis.common.exceptions import IbisError, UnsupportedOperationError
2524

2625
np = pytest.importorskip("numpy")
2726
pd = pytest.importorskip("pandas")
@@ -240,8 +239,8 @@ def test_keyword_fields(con, nullable):
240239
)
241240
@pytest.mark.notyet(
242241
["polars"],
243-
raises=PolarsColumnNotFoundError,
244-
reason="doesn't seem to support IN-style subqueries on structs",
242+
raises=UnsupportedOperationError,
243+
reason="doesn't support IN-style subqueries on structs",
245244
)
246245
@pytest.mark.xfail_version(
247246
pyspark=["pyspark<3.5"],

0 commit comments

Comments
 (0)