Skip to content

Commit ec9a925

Browse files
committed
Update expand_args_to_dims
* Generalize type of dim (str -> Hashable) * Rename to_array -> to_list * Make error message clearer (only print unexpected inputs) * Allow lists of length one as inputs * Add tests
1 parent ef0ac63 commit ec9a925

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

xarray/core/utils.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,10 @@ def iterate_nested(nested_list):
921921

922922

923923
def expand_args_to_dims(
924-
dim: ItemOrSequence[str],
924+
dim: ItemOrSequence[Hashable],
925925
arg_names: Sequence[str],
926926
args: Sequence[ItemOrSequence[Any]],
927-
) -> Tuple[Sequence[str], Sequence[Sequence[Any]]]:
927+
) -> Tuple[Sequence[Hashable], Sequence[Sequence[Any]]]:
928928
"""Expand dims and all elements in args to be arrays of the length of the number of dimensions
929929
930930
Parameters
@@ -946,31 +946,33 @@ def expand_args_to_dims(
946946
list of dims, list of lists of arguments
947947
"""
948948
if is_scalar(dim):
949-
for name, arg in zip(arg_names, args):
950-
if not is_scalar(arg):
951-
raise ValueError(f"Expected {name}={arg!r} to be a scalar like 'dim'.")
952-
assert isinstance(dim, str) or isinstance(dim, bytes)
953-
dim = [dim]
949+
dim_list: Sequence[Hashable] = [dim]
950+
else:
951+
assert isinstance(dim, list)
952+
dim_list = dim
954953

955954
# dim is now a list
956-
nroll = len(dim)
955+
nroll = len(dim_list)
957956

958-
def to_array(arg):
957+
def to_list(arg):
959958
if is_scalar(arg):
960959
return [arg] * nroll
960+
elif isinstance(arg, list) and len(arg) == 1:
961+
return [arg[0]] * nroll
961962
return arg
962963

963-
arr_args = [to_array(arg) for arg in args]
964+
arr_args = [to_list(arg) for arg in args]
964965

965-
if any(len(dim) != len(arg) for arg in arr_args):
966-
names_vals = ", ".join(
967-
f"{name}={val!r}" for name, val in zip(arg_names, arr_args)
966+
if any(len(arg) != len(dim_list) for arg in arr_args):
967+
names_args_len = ", ".join(
968+
f"{name}={args!r} (len={len(args)})" for name, args in zip(arg_names, arr_args)
969+
if len(args) != len(dim_list)
968970
)
969971
raise ValueError(
970-
"Arguments must all be the same length. " f"Received {names_vals}."
972+
f"Expected all arguments to have len={len(dim_list)}. Received: {names_args_len}"
971973
)
972974

973-
return dim, arr_args
975+
return dim_list, arr_args
974976

975977

976978
def get_pads(

xarray/tests/test_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from xarray.coding.cftimeindex import CFTimeIndex
99
from xarray.core import duck_array_ops, utils
1010
from xarray.core.indexes import PandasIndex
11-
from xarray.core.utils import either_dict_or_kwargs, iterate_nested
11+
from xarray.core.utils import either_dict_or_kwargs, expand_args_to_dims, iterate_nested
1212

1313
from . import assert_array_equal, requires_cftime, requires_dask
1414
from .test_coding_times import _all_cftime_date_types
@@ -333,3 +333,24 @@ def test_infix_dims_errors(supplied, all_):
333333
)
334334
def test_iterate_nested(nested_list, expected):
335335
assert list(iterate_nested(nested_list)) == expected
336+
337+
338+
def test_expand_args_to_dims():
339+
dims, (arg1, arg2, arg3, arg4) = expanded_args = expand_args_to_dims(
340+
["a", "b"],
341+
["arg1", "arg2", "arg3", "arg4"],
342+
[1, ["val2.1", "val2.2"], False, [True, False]],
343+
)
344+
345+
assert dims == ["a", "b"]
346+
assert arg1 == [1, 1]
347+
assert arg2 == ["val2.1", "val2.2"]
348+
assert arg3 == [False, False]
349+
assert arg4 == [True, False]
350+
351+
with pytest.raises(ValueError, match="Expected all arguments"):
352+
expand_args_to_dims(
353+
["a", "b"],
354+
["arg1", "arg2", "arg3", "arg4"],
355+
["asdf", ["arg2.1", "arg2.2"], ["arg3.1"], ["arg4.1", "arg4.2", "arg4.3"]],
356+
)

0 commit comments

Comments
 (0)