Skip to content

Commit 9ac4f24

Browse files
authored
Merge pull request #95 from martindurant/str_sugar
Add mul/add convenience to .str
2 parents b2d4952 + 9b9fa23 commit 9ac4f24

File tree

10 files changed: +68 additions, -17 deletions (lines changed)

src/akimbo/apply_tree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@ def dec(
7171
match: Callable[[ak.contents.Content], bool] = leaf,
7272
outtype: Callable[[ak.contents.Content], ak.contents.Content] | None = None,
7373
inmode: Literal["arrow", "numpy", "ak", "other"] = "ak",
74+
match_kwargs=None,
7475
):
7576
"""Make a nested/ragged version of an operation to apply throughout a tree"""
7677

7778
@functools.wraps(func)
78-
def f(arr, *args, where=None, match_kwargs=None, **kwargs):
79+
def f(arr, *args, where=None, **kwargs):
7980
others = []
8081
if args:
8182
sig = list(inspect.signature(func).parameters)[1:]

src/akimbo/dask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def to_dask_awkward(self):
5757
)
5858

5959
def __getattr__(self, item):
60-
if self.subaccessor:
60+
if self.subaccessor and isinstance(item, str):
6161
item = getattr(self.subaccessors[self.subaccessor], item)
6262
elif isinstance(item, str) and item in self.subaccessors:
6363
return DaskAwkwardAccessor(

src/akimbo/mixin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def unexplode(self, *cols: tuple[str, ...], outname="grouped") -> ak.Array:
176176
@classmethod
177177
def _create_op(cls, op):
178178
def run(self, *args, **kwargs):
179+
if self.subaccessor:
180+
# defer
181+
op2 = op(self.subaccessors[self.subaccessor](), None)
182+
return self.__getattr__(op2)(*args, **kwargs)
179183
args = [
180184
to_ak_layout(_) if isinstance(_, (str, int, float, np.number)) else _
181185
for _ in args

src/akimbo/polars.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __getattr__(self, item: str, **flags) -> callable:
5050
)
5151

5252
def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs):
53-
if subaccessor:
53+
if subaccessor and isinstance(item, str):
5454
func0 = getattr(self.subaccessors[subaccessor](), item)
5555
elif callable(item):
5656
func0 = item
@@ -100,18 +100,6 @@ def f(batch):
100100
raise NotImplementedError
101101
else:
102102
obj = self._obj
103-
# def map_batches(
104-
# self,
105-
# function: Callable[[DataFrame], DataFrame],
106-
# *,
107-
# predicate_pushdown: bool = True,
108-
# projection_pushdown: bool = True,
109-
# slice_pushdown: bool = True,
110-
# no_optimizations: bool = False,
111-
# schema: None | SchemaDict = None,
112-
# validate_output_schema: bool = True,
113-
# streamable: bool = False,
114-
# ) -> LazyFrame:
115103
arrow_type = polars_to_arrow_schema(obj.collect_schema())
116104
arr = pa.table([[]] * len(arrow_type), schema=arrow_type)
117105
out1 = f(pl.from_arrow(arr))

src/akimbo/ray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __getattr__(self, item: str) -> callable:
3434
return RayAccessor(self._obj, subaccessor=item, behavior=self._behavior)
3535

3636
def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs):
37-
if subaccessor:
37+
if subaccessor and isinstance(item, str):
3838
func0 = getattr(self.subaccessors[subaccessor](), item)
3939
elif callable(item):
4040
func0 = item

src/akimbo/spark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __getattr__(self, item: str) -> sdf:
3838
return SparkAccessor(self._obj, subaccessor=item, behavior=self._behavior)
3939

4040
def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs):
41-
if subaccessor:
41+
if subaccessor and isinstance(item, str):
4242
func0 = getattr(self.subaccessors[subaccessor](), item)
4343
elif callable(item):
4444
func0 = item

src/akimbo/strings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44

55
import awkward as ak
6+
import pyarrow as pa
67
import pyarrow.compute as pc
78

89
from akimbo.apply_tree import dec
@@ -59,6 +60,16 @@ def strptime(*args, format="%FT%T", unit="us", error_is_null=True, **kw):
5960
return out
6061

6162

63+
def repeat(arr, count):
64+
return pc.binary_repeat(arr, count)
65+
66+
67+
def concat(arr, arr2, sep=""):
68+
return pc.binary_join_element_wise(
69+
arr.cast(pa.string()), arr2.cast(pa.string()), sep
70+
)
71+
72+
6273
class StringAccessor:
6374
"""String operations on nested/var-length data"""
6475

@@ -92,6 +103,14 @@ def __getattr__(self, attr: str) -> callable:
92103
return getattr(ak.str, attr)
93104

94105
strptime = staticmethod(dec(strptime, match=match_string, inmode="arrow"))
106+
repeat = staticmethod(dec(repeat, match=match_string, inmode="arrow"))
107+
join_el = staticmethod(dec(concat, match=match_string, inmode="arrow"))
108+
109+
def __add__(self, *_):
110+
return dec(concat, match=match_string, inmode="arrow")
111+
112+
def __mul__(self, *_):
113+
return dec(repeat, match=match_string, inmode="arrow")
95114

96115
def __dir__(self) -> list[str]:
97116
return sorted(methods + ["strptime", "encode", "decode"])

tests/test_polars.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ def test_binary():
7575
assert s3.to_list() == [[0, 0, 0], [], [0, 0]]
7676

7777

78+
def test_str_sugar():
79+
s = pl.Series([["hay", "hi", "hola"], [], ["bye", "yo"]])
80+
out = s.ak.str.repeat(3)
81+
out2 = s.ak.str * 3
82+
expected = [["hayhayhay", "hihihi", "holaholahola"], [], ["byebyebye", "yoyoyo"]]
83+
assert out.ak.tolist() == out2.ak.tolist() == expected
84+
85+
out = s.ak.str.join_el(s)
86+
out2 = s.ak.str + s
87+
expected = [["hayhay", "hihi", "holahola"], [], ["byebye", "yoyo"]]
88+
assert out.ak.tolist() == out2.ak.tolist() == expected
89+
90+
7891
def test_unexplode():
7992
df = pl.DataFrame(
8093
{

tests/test_ray.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,16 @@ def f2(data):
150150
out = df.ak.apply(f2)
151151
result = out.ak.to_output()
152152
assert result.ak.tolist() == [6, None] * 100
153+
154+
155+
def test_str_sugar(df):
156+
s = df.ak["y"]
157+
out = s.ak.str.repeat(3)
158+
out2 = s.ak.str * 3
159+
expected = [["heyheyhey", None], ["hihihi", "hohoho"]] * 100
160+
assert out.ak.to_output().tolist() == out2.ak.to_output().tolist() == expected
161+
162+
out = s.ak.str.join_el(s)
163+
out2 = s.ak.str + s
164+
expected = [["heyhey", None], ["hihi", "hoho"]] * 100
165+
assert out.ak.to_output().tolist() == out2.ak.to_output().tolist() == expected

tests/test_str.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,16 @@ def test_split():
4040
s = pd.Series([b"hello world", b"oio", b""])
4141
s2 = s.ak.str.split_whitespace()
4242
assert s2.tolist() == [[b"hello", b"world"], [b"oio"], [b""]]
43+
44+
45+
def test_str_sugar():
46+
s = pd.Series([["hay", "hi", "hola"], [], ["bye", "yo"]])
47+
out = s.ak.str.repeat(3)
48+
out2 = s.ak.str * 3
49+
expected = [["hayhayhay", "hihihi", "holaholahola"], [], ["byebyebye", "yoyoyo"]]
50+
assert out.ak.tolist() == out2.ak.tolist() == expected
51+
52+
out = s.ak.str.join_el(s)
53+
out2 = s.ak.str + s
54+
expected = [["hayhay", "hihi", "holahola"], [], ["byebye", "yoyo"]]
55+
assert out.ak.tolist() == out2.ak.tolist() == expected

0 commit comments

Comments
 (0)