Skip to content

Commit eec0866

Browse files
authored
fix: preserve structure of jagged subfields when accessing the arrays directly (#1476)
* Fix jagged subfields * Rename tuple subfields to match Awkward * Added test * Minor refactoring
1 parent 7739ce7 commit eec0866

File tree

5 files changed

+213
-31
lines changed

5 files changed

+213
-31
lines changed

src/uproot/_dask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,7 +1543,7 @@ def real_filter_branch(branch):
15431543
entry_stop = ttree.num_entries
15441544

15451545
if isinstance(ttree, HasFields):
1546-
akform = ttree.to_akform(filter_name=common_keys)
1546+
akform, _ = ttree.to_akform(filter_name=common_keys)
15471547
ttree_step = _RNTuple_regularize_step_size(
15481548
ttree, akform, step_size, entry_start, entry_stop
15491549
)
@@ -1593,7 +1593,7 @@ def real_filter_branch(branch):
15931593
partition_args.append((i, start, stop))
15941594

15951595
if isinstance(ttrees[0], HasFields):
1596-
base_form = ttrees[0].to_akform(filter_name=common_keys)
1596+
base_form, _ = ttrees[0].to_akform(filter_name=common_keys)
15971597
else:
15981598
base_form = _get_ttree_form(
15991599
awkward, ttrees[0], common_keys, interp_options.get("ak_add_doc")

src/uproot/behaviors/RNTuple.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ def to_akform(
525525
for compatibility with software that was used for :doc:`uproot.behaviors.TBranch.TBranch`. This argument should not be used
526526
and will be removed in a future version.
527527
528-
Returns the an Awkward Form with the structure of the data in the ``RNTuple`` or ``RField``.
528+
Returns a 2-tuple where the first entry is the Awkward Form with the structure of the data in the ``RNTuple`` or ``RField``,
529+
and the second entry is the relative path of the requested RField. The second entry is needed in cases where the requested RField
530+
is a subfield of a collection, which requires constructing the form with information about the parent field.
529531
"""
530532
ak = uproot.extras.awkward()
531533

@@ -540,6 +542,7 @@ def to_akform(
540542
top_names = []
541543
record_list = []
542544
if self is rntuple:
545+
field_path = None
543546
for field in self.fields:
544547
# the field needs to be in the keys or be a parent of a field in the keys
545548
if any(
@@ -551,14 +554,39 @@ def to_akform(
551554
rntuple.field_form(field.field_id, keys, ak_add_doc=ak_add_doc)
552555
)
553556
else:
557+
# If it is a subfield of a collection, we need to include the collection in the keys
558+
path_keys = self.path.split(".")
559+
top_collection = None
560+
tmp_field = self.ntuple
561+
field_path = [self.name]
562+
for i, key in enumerate(path_keys):
563+
tmp_field = tmp_field[key]
564+
if (
565+
tmp_field.record.struct_role
566+
== uproot.const.RNTupleFieldRole.COLLECTION
567+
):
568+
top_collection = tmp_field
569+
field_path = path_keys[i:]
570+
break
554571
# Always use the full path for keys
555572
# Also include the field itself
556573
keys = [self.path] + [f"{self.path}.{k}" for k in keys]
557-
# The field needs to be in the keys or be a parent of a field in the keys
558-
if any(key.startswith(f"{self.path}.") or key == self.path for key in keys):
559-
top_names.append(self.name)
574+
if top_collection is None:
575+
# The field needs to be in the keys or be a parent of a field in the keys
576+
if any(
577+
key.startswith(f"{self.path}.") or key == self.path for key in keys
578+
):
579+
top_names.append(self.name)
580+
record_list.append(
581+
rntuple.field_form(self.field_id, keys, ak_add_doc=ak_add_doc)
582+
)
583+
else:
584+
keys += [top_collection.path]
585+
top_names.append(top_collection.name)
560586
record_list.append(
561-
rntuple.field_form(self.field_id, keys, ak_add_doc=ak_add_doc)
587+
rntuple.field_form(
588+
top_collection.field_id, keys, ak_add_doc=ak_add_doc
589+
)
562590
)
563591

564592
parameters = None
@@ -572,7 +600,7 @@ def to_akform(
572600
form = ak.forms.RecordForm(
573601
record_list, top_names, form_key="toplevel", parameters=parameters
574602
)
575-
return form
603+
return (form, field_path)
576604

577605
def arrays(
578606
self,
@@ -697,7 +725,7 @@ def arrays(
697725
[c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]]
698726
)
699727

700-
form = self.to_akform(
728+
form, field_path = self.to_akform(
701729
filter_name=filter_name,
702730
filter_typename=filter_typename,
703731
filter_field=filter_field,
@@ -755,6 +783,20 @@ def arrays(
755783
# no longer needed; save memory
756784
del container_dict
757785

786+
# If we constructed some parent fields, we need to get back to the requested field
787+
if field_path is not None:
788+
for field in field_path[:-1]:
789+
if field in arrays.fields:
790+
arrays = arrays[field]
791+
# tuples are a trickier since indices no longer match
792+
else:
793+
if field.isdigit() and arrays.fields == ["0"]:
794+
arrays = arrays["0"]
795+
else:
796+
raise AssertionError(
797+
"The array was not constructed correctly. Please report this issue."
798+
)
799+
758800
# FIXME: This is not right, but it might temporarily work
759801
if library.name == "np":
760802
return arrays.to_numpy()
@@ -896,7 +938,7 @@ def iterate(
896938
)
897939
)
898940

899-
akform = self.to_akform(
941+
akform, _ = self.to_akform(
900942
filter_name=filter_name,
901943
filter_typename=filter_typename,
902944
filter_field=filter_field,
@@ -1408,7 +1450,7 @@ def num_entries_for(
14081450
)
14091451
)
14101452

1411-
akform = self.to_akform(
1453+
akform, _ = self.to_akform(
14121454
filter_name=filter_name,
14131455
filter_typename=filter_typename,
14141456
filter_field=filter_field,

src/uproot/models/RNTuple.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def field_form(self, this_id, keys, ak_add_doc=False):
562562
):
563563
recordlist.append(self.field_form(i, keys, ak_add_doc=ak_add_doc))
564564
namelist.append(field_records[i].field_name)
565-
if all(name == f"_{i}" for i, name in enumerate(namelist)):
565+
if all(re.fullmatch(r"_[0-9]+", name) is not None for name in namelist):
566566
namelist = None
567567
return ak.forms.RecordForm(
568568
recordlist, namelist, form_key="whatever", parameters=parameters
@@ -1485,7 +1485,15 @@ def name(self):
14851485
"""
14861486
Name of the ``RField``.
14871487
"""
1488-
return self._ntuple.field_records[self._fid].field_name
1488+
# We rename subfields of tuples to match Awkward
1489+
name = self._ntuple.field_records[self._fid].field_name
1490+
if (
1491+
not self.top_level
1492+
and self.parent.record.struct_role == uproot.const.RNTupleFieldRole.RECORD
1493+
and re.fullmatch(r"_[0-9]+", name) is not None
1494+
):
1495+
name = name[1:]
1496+
return name
14891497

14901498
@property
14911499
def description(self):
@@ -1637,14 +1645,25 @@ def array(
16371645
See also :ref:`uproot.behaviors.RNTuple.HasFields.arrays` to read
16381646
multiple ``RFields`` into a group of arrays or an array-group.
16391647
"""
1640-
return self.arrays(
1648+
arrays = self.arrays(
16411649
entry_start=entry_start,
16421650
entry_stop=entry_stop,
16431651
library=library,
16441652
interpreter=interpreter,
16451653
backend=backend,
16461654
ak_add_doc=ak_add_doc,
1647-
)[self.name]
1655+
)
1656+
if self.name in arrays.fields:
1657+
arrays = arrays[self.name]
1658+
# tuples are a trickier since indices no longer match
1659+
else:
1660+
if self.name.isdigit() and arrays.fields == ["0"]:
1661+
arrays = arrays["0"]
1662+
else:
1663+
raise AssertionError(
1664+
"The array was not constructed correctly. Please report this issue."
1665+
)
1666+
return arrays
16481667

16491668

16501669
# No cupy version of numpy.insert() provided

tests/test_1406_improved_rntuple_methods.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def test_getitem(tmp_path):
8080
assert obj["struct1"]["x"] is obj[r"struct1\x"]
8181

8282
# Make sure it accesses the grandchildren field instead of the "real" _0
83-
assert obj["struct5._0"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
84-
assert obj["struct5._1"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
85-
assert obj["struct5._2"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
86-
assert obj["struct6._0"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
83+
assert obj["struct5.0"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
84+
assert obj["struct5.1"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
85+
assert obj["struct5.2"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
86+
assert obj["struct6.0"].record.struct_role == uproot.const.RNTupleFieldRole.LEAF
8787

8888

8989
def test_to_akform(tmp_path):
@@ -94,21 +94,22 @@ def test_to_akform(tmp_path):
9494

9595
obj = uproot.open(filepath)["ntuple"]
9696

97-
akform = obj.to_akform()
97+
akform, field_path = obj.to_akform()
9898
assert akform == data.layout.form
99+
assert field_path is None
99100

100-
assert obj["struct1"].to_akform() == akform.select_columns("struct1")
101-
assert obj["struct2"].to_akform() == akform.select_columns("struct2")
102-
assert obj["struct3"].to_akform() == akform.select_columns("struct3")
103-
assert obj["struct4"].to_akform() == akform.select_columns("struct4")
104-
assert obj["struct5"].to_akform() == akform.select_columns("struct5")
101+
assert obj["struct1"].to_akform() == (akform.select_columns("struct1"), ["struct1"])
102+
assert obj["struct2"].to_akform() == (akform.select_columns("struct2"), ["struct2"])
103+
assert obj["struct3"].to_akform() == (akform.select_columns("struct3"), ["struct3"])
104+
assert obj["struct4"].to_akform() == (akform.select_columns("struct4"), ["struct4"])
105+
assert obj["struct5"].to_akform() == (akform.select_columns("struct5"), ["struct5"])
105106

106-
assert obj["struct1"].to_akform(filter_name="x") == akform.select_columns(
107+
assert obj["struct1"].to_akform(filter_name="x")[0] == akform.select_columns(
107108
["struct1.x"]
108109
)
109-
assert obj["struct3"].to_akform(filter_typename="double") == akform.select_columns(
110-
["struct3.t"]
111-
)
110+
assert obj["struct3"].to_akform(filter_typename="double")[
111+
0
112+
] == akform.select_columns(["struct3.t"])
112113

113114

114115
def test_iterate_and_concatenate(tmp_path):
@@ -144,5 +145,5 @@ def test_array(tmp_path):
144145

145146
obj = uproot.open(filepath)["ntuple"]
146147

147-
assert obj["struct5._0"].array().tolist() == [1, 4]
148-
# assert obj["struct6._0"].array().tolist() == [[1, 4], [7]] # TODO: Need to fix this
148+
assert obj["struct5.0"].array().tolist() == [1, 4]
149+
assert obj["struct6.0"].array().tolist() == [[1, 4], [7]]
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE
2+
3+
import os
4+
5+
import pytest
6+
7+
import uproot
8+
9+
ak = pytest.importorskip("awkward")
10+
11+
data = ak.Array(
12+
{
13+
"field1": [[(0, 1), (2, 3)], [], [(4, 5)]],
14+
"field2": [[{"x": 0, "y": 1}, {"x": 2, "y": 3}], [], [{"x": 4, "y": 5}]],
15+
"field3": [
16+
{"x": [1, 2, 3], "y": [(1, 2), (2, 3), (3, 4)]},
17+
{"x": [], "y": [(7, 8)]},
18+
{"x": [9, 10], "y": [(11, 12), (13, 14)]},
19+
],
20+
"field4": [(1, [(1, 2), (3, 4)]), (2, []), (3, [(5, 6)])],
21+
"field5": [
22+
{
23+
"x": [
24+
{"up": [(0, 2), (1, 2)], "down": [[1, 2, 3], []]},
25+
{"up": [], "down": [[4]]},
26+
],
27+
"y": [
28+
(
29+
{"left": [[0, "hi", 2.3]], "right": 6},
30+
{"left": [[], [""]], "right": 8.0},
31+
)
32+
],
33+
},
34+
{
35+
"x": [
36+
{"up": [(10, 2), (12, 2)], "down": []},
37+
{"up": [(1, 2)], "down": [[], [4, 2, 3]]},
38+
],
39+
"y": [
40+
(
41+
{"left": [[23, 4.1, "hello"]], "right": 14},
42+
{"left": [], "right": 16.0},
43+
)
44+
],
45+
},
46+
{
47+
"x": [
48+
{"up": [], "down": [[4, 5], [1, 2, 3], []]},
49+
{"up": [(0, 2), (1, 2)], "down": []},
50+
],
51+
"y": [
52+
(
53+
{"left": [[]], "right": 14},
54+
{"left": [[2, 3], ["bye", ""]], "right": 16.0},
55+
)
56+
],
57+
},
58+
],
59+
"field6": [(1, 2, 3, 4), (2, 5, 6, 7), ("hello", 2.3, 8, "")],
60+
}
61+
)
62+
63+
64+
def test_jagged_subfields(tmp_path):
65+
filepath = os.path.join(tmp_path, "test.root")
66+
67+
with uproot.recreate(filepath) as file:
68+
obj = file.mkrntuple("ntuple", data)
69+
70+
obj = uproot.open(filepath)["ntuple"]
71+
72+
assert ak.array_equal(obj["field1"].array(), data["field1"])
73+
assert ak.array_equal(obj["field1.0"].array(), data["field1"]["0"])
74+
assert ak.array_equal(obj["field1.1"].array(), data["field1"]["1"])
75+
76+
assert ak.array_equal(obj["field2"].array(), data["field2"])
77+
assert ak.array_equal(obj["field2.x"].array(), data["field2"]["x"])
78+
assert ak.array_equal(obj["field2.y"].array(), data["field2"]["y"])
79+
80+
assert ak.array_equal(obj["field3"].array(), data["field3"])
81+
assert ak.array_equal(obj["field3.x"].array(), data["field3"]["x"])
82+
assert ak.array_equal(obj["field3.y"].array(), data["field3"]["y"])
83+
assert ak.array_equal(obj["field3.y.0"].array(), data["field3"]["y"]["0"])
84+
assert ak.array_equal(obj["field3.y.1"].array(), data["field3"]["y"]["1"])
85+
86+
assert ak.array_equal(obj["field4"].array(), data["field4"])
87+
assert ak.array_equal(obj["field4.0"].array(), data["field4"]["0"])
88+
assert ak.array_equal(obj["field4.1"].array(), data["field4"]["1"])
89+
assert ak.array_equal(obj["field4.1.0"].array(), data["field4"]["1"]["0"])
90+
assert ak.array_equal(obj["field4.1.1"].array(), data["field4"]["1"]["1"])
91+
92+
assert obj["field5"].array().tolist() == data["field5"].tolist()
93+
assert obj["field5.x"].array().tolist() == data["field5"]["x"].tolist()
94+
assert obj["field5.x.up"].array().tolist() == data["field5"]["x"]["up"].tolist()
95+
assert obj["field5.x.down"].array().tolist() == data["field5"]["x"]["down"].tolist()
96+
assert obj["field5.y"].array().tolist() == data["field5"]["y"].tolist()
97+
assert obj["field5.y.0"].array().tolist() == data["field5"]["y"]["0"].tolist()
98+
assert (
99+
obj["field5.y.0.left"].array().tolist()
100+
== data["field5"]["y"]["0"]["left"].tolist()
101+
)
102+
assert (
103+
obj["field5.y.0.right"].array().tolist()
104+
== data["field5"]["y"]["0"]["right"].tolist()
105+
)
106+
assert obj["field5.y.1"].array().tolist() == data["field5"]["y"]["1"].tolist()
107+
assert (
108+
obj["field5.y.1.left"].array().tolist()
109+
== data["field5"]["y"]["1"]["left"].tolist()
110+
)
111+
assert (
112+
obj["field5.y.1.right"].array().tolist()
113+
== data["field5"]["y"]["1"]["right"].tolist()
114+
)
115+
116+
assert obj["field6"].array().tolist() == data["field6"].tolist()
117+
assert obj["field6.0"].array().tolist() == data["field6"]["0"].tolist()
118+
assert obj["field6.1"].array().tolist() == data["field6"]["1"].tolist()
119+
assert obj["field6.2"].array().tolist() == data["field6"]["2"].tolist()
120+
assert obj["field6.3"].array().tolist() == data["field6"]["3"].tolist()

0 commit comments

Comments
 (0)