5
5
6
6
import ibis .expr .datatypes as dt
7
7
import ibis .expr .operations as ops
8
+ from ibis .backends .sql .compilers .base import NULL
8
9
from ibis .backends .sql .compilers .pyspark import PySparkCompiler
9
10
from ibis .backends .sql .dialects import Databricks
10
11
@@ -21,7 +22,14 @@ class DatabricksCompiler(PySparkCompiler):
21
22
ops .BitXor : "bit_xor" ,
22
23
ops .TypeOf : "typeof" ,
23
24
ops .RandomUUID : "uuid" ,
25
+ ops .JSONGetItem : "json_extract" ,
24
26
}
27
+ del (
28
+ SIMPLE_OPS [ops .UnwrapJSONString ],
29
+ SIMPLE_OPS [ops .UnwrapJSONInt64 ],
30
+ SIMPLE_OPS [ops .UnwrapJSONFloat64 ],
31
+ SIMPLE_OPS [ops .UnwrapJSONBoolean ],
32
+ )
25
33
26
34
UNSUPPORTED_OPS = (
27
35
ops .ElementWiseVectorizedUDF ,
@@ -31,6 +39,42 @@ class DatabricksCompiler(PySparkCompiler):
31
39
ops .TimestampBucket ,
32
40
)
33
41
42
+ def visit_ToJSONArray (self , op , * , arg ):
43
+ return self .f .try_variant_get (arg , "$" , "ARRAY<VARIANT>" )
44
+
45
+ def visit_ToJSONMap (self , op , * , arg ):
46
+ return self .f .try_variant_get (arg , "$" , "MAP<STRING, VARIANT>" )
47
+
48
+ def visit_UnwrapJSONString (self , op , * , arg ):
49
+ return self .if_ (
50
+ self .f .schema_of_variant (arg ).eq (sge .convert ("STRING" )),
51
+ self .f .try_variant_get (arg , "$" , "STRING" ),
52
+ NULL ,
53
+ )
54
+
55
+ def visit_UnwrapJSONInt64 (self , op , * , arg ):
56
+ return self .if_ (
57
+ self .f .schema_of_variant (arg ).eq (sge .convert ("BIGINT" )),
58
+ self .f .try_variant_get (arg , "$" , "BIGINT" ),
59
+ NULL ,
60
+ )
61
+
62
+ def visit_UnwrapJSONFloat64 (self , op , * , arg ):
63
+ return self .if_ (
64
+ self .f .schema_of_variant (arg ).isin (
65
+ sge .convert ("STRING" ), sge .convert ("BOOLEAN" )
66
+ ),
67
+ NULL ,
68
+ self .f .try_variant_get (arg , "$" , "DOUBLE" ),
69
+ )
70
+
71
+ def visit_UnwrapJSONBoolean (self , op , * , arg ):
72
+ return self .if_ (
73
+ self .f .schema_of_variant (arg ).eq (sge .convert ("BOOLEAN" )),
74
+ self .f .try_variant_get (arg , "$" , "BOOLEAN" ),
75
+ NULL ,
76
+ )
77
+
34
78
def visit_NonNullLiteral (self , op , * , value , dtype ):
35
79
if dtype .is_binary ():
36
80
return self .f .unhex (value .hex ())
0 commit comments