Skip to content

Commit a1f2bab

Browse files
committed
feat(databricks): add support for json via variant
1 parent cb18749 commit a1f2bab

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

ibis/backends/sql/compilers/databricks.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import ibis.expr.datatypes as dt
77
import ibis.expr.operations as ops
8+
from ibis.backends.sql.compilers.base import NULL
89
from ibis.backends.sql.compilers.pyspark import PySparkCompiler
910
from ibis.backends.sql.dialects import Databricks
1011

@@ -21,7 +22,14 @@ class DatabricksCompiler(PySparkCompiler):
2122
ops.BitXor: "bit_xor",
2223
ops.TypeOf: "typeof",
2324
ops.RandomUUID: "uuid",
25+
ops.JSONGetItem: "json_extract",
2426
}
27+
# The JSON unwrap operations need type-aware logic on Databricks (see the
# visit_UnwrapJSON* methods on this compiler), so remove the inherited
# one-function mappings for them from the simple-ops table.
for _unwrap_op in (
    ops.UnwrapJSONString,
    ops.UnwrapJSONInt64,
    ops.UnwrapJSONFloat64,
    ops.UnwrapJSONBoolean,
):
    del SIMPLE_OPS[_unwrap_op]
del _unwrap_op  # keep the loop variable out of the class namespace
2533

2634
UNSUPPORTED_OPS = (
2735
ops.ElementWiseVectorizedUDF,
@@ -31,6 +39,42 @@ class DatabricksCompiler(PySparkCompiler):
3139
ops.TimestampBucket,
3240
)
3341

42+
def visit_ToJSONArray(self, op, *, arg):
    """Compile ``ToJSONArray`` by casting the variant to an array of variants.

    ``try_variant_get`` returns NULL (rather than raising) when the value
    is not actually an array.
    """
    target = "ARRAY<VARIANT>"
    return self.f.try_variant_get(arg, "$", target)
44+
45+
def visit_ToJSONMap(self, op, *, arg):
    """Compile ``ToJSONMap`` by casting the variant to a string-keyed map.

    ``try_variant_get`` returns NULL (rather than raising) when the value
    is not actually an object/map.
    """
    target = "MAP<STRING, VARIANT>"
    return self.f.try_variant_get(arg, "$", target)
47+
48+
def visit_UnwrapJSONString(self, op, *, arg):
    """Return the variant's value as STRING when it holds a string, else NULL.

    The explicit ``schema_of_variant`` check avoids ``try_variant_get``'s
    implicit coercion of non-string values to their string representation.
    """
    holds_string = self.f.schema_of_variant(arg).eq(sge.convert("STRING"))
    as_string = self.f.try_variant_get(arg, "$", "STRING")
    return self.if_(holds_string, as_string, NULL)
54+
55+
def visit_UnwrapJSONInt64(self, op, *, arg):
    """Return the variant's value as BIGINT when it holds an integer, else NULL.

    The ``schema_of_variant`` guard ensures only values Databricks already
    types as BIGINT are unwrapped; everything else yields NULL.
    """
    holds_int = self.f.schema_of_variant(arg).eq(sge.convert("BIGINT"))
    as_bigint = self.f.try_variant_get(arg, "$", "BIGINT")
    return self.if_(holds_int, as_bigint, NULL)
61+
62+
def visit_UnwrapJSONFloat64(self, op, *, arg):
    """Return the variant's value as DOUBLE for numeric variants, else NULL.

    Unlike the other unwrap visitors this one checks the *negative* case:
    strings and booleans are rejected explicitly, so that integer variants
    are still allowed to widen to DOUBLE via ``try_variant_get``.
    """
    non_numeric = self.f.schema_of_variant(arg).isin(
        sge.convert("STRING"), sge.convert("BOOLEAN")
    )
    as_double = self.f.try_variant_get(arg, "$", "DOUBLE")
    return self.if_(non_numeric, NULL, as_double)
70+
71+
def visit_UnwrapJSONBoolean(self, op, *, arg):
    """Return the variant's value as BOOLEAN when it holds a boolean, else NULL.

    The ``schema_of_variant`` guard prevents ``try_variant_get`` from
    coercing non-boolean values.
    """
    holds_bool = self.f.schema_of_variant(arg).eq(sge.convert("BOOLEAN"))
    as_bool = self.f.try_variant_get(arg, "$", "BOOLEAN")
    return self.if_(holds_bool, as_bool, NULL)
77+
3478
def visit_NonNullLiteral(self, op, *, value, dtype):
3579
if dtype.is_binary():
3680
return self.f.unhex(value.hex())

ibis/backends/sql/datatypes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,16 @@ def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
13251325
class DatabricksType(SqlglotType):
13261326
dialect = "databricks"
13271327

1328+
@classmethod
def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType:
    """Ibis JSON values are stored as Databricks VARIANT columns."""
    return sge.DataType(this=typecode.VARIANT)

@classmethod
def _from_sqlglot_VARIANT(cls, nullable: bool | None = None) -> dt.JSON:
    # Fix: this converter returns an ibis dtype, not a sqlglot DataType;
    # the previous ``-> sge.DataType`` annotation was incorrect.
    """Databricks VARIANT columns are surfaced as the ibis JSON type."""
    return dt.JSON(nullable=nullable)

# Databricks treats JSON as a synonym for VARIANT when reflecting types.
_from_sqlglot_JSON = _from_sqlglot_VARIANT
1337+
13281338

13291339
class AthenaType(SqlglotType):
13301340
dialect = "athena"

ibis/backends/tests/test_json.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
pytestmark = [
1818
pytest.mark.never(["impala"], reason="doesn't support JSON and never will"),
1919
pytest.mark.notyet(["clickhouse"], reason="upstream is broken"),
20-
pytest.mark.notimpl(
21-
["datafusion", "exasol", "mssql", "druid", "oracle", "databricks"]
22-
),
20+
pytest.mark.notimpl(["datafusion", "exasol", "mssql", "druid", "oracle"]),
2321
]
2422

2523

0 commit comments

Comments
 (0)