Skip to content

Fix DataArray transpose inconsistent with Dataset Ellipsis usage #4767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,4 @@ re-open it directly with Zarr:

zgroup = zarr.open("rasm.zarr")
print(zgroup.tree())
dict(zgroup["Tair"].attrs)
dict(zgroup["Tair"].attrs)
2 changes: 1 addition & 1 deletion doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -955,4 +955,4 @@ One can also make line plots with multidimensional coordinates. In this case, ``
f, ax = plt.subplots(2, 1)
da.plot.line(x="lon", hue="y", ax=ax[0])
@savefig plotting_example_2d_hue_xy.png
da.plot.line(x="lon", hue="x", ax=ax[1])
da.plot.line(x="lon", hue="x", ax=ax[1])
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Bug fixes
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.

Documentation
~~~~~~~~~~~~~
Expand All @@ -76,6 +77,7 @@ Internal Changes
- Run the tests in parallel using pytest-xdist (:pull:`4694`).

By `Justus Magin <https://github.com/keewis>`_ and `Mathias Hauser <https://github.com/mathause>`_.

- Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)``
for clearer error messages.
(:pull:`4752`);
Expand Down
15 changes: 13 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,12 @@ def to_unstacked_dataset(self, dim, level=0):
# unstacked dataset
return Dataset(data_dict)

def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArray":
def transpose(
self,
*dims: Hashable,
transpose_coords: bool = True,
missing_dims: str = "raise",
) -> "DataArray":
"""Return a new DataArray object with transposed dimensions.

Parameters
Expand All @@ -2130,6 +2135,12 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
dimensions to this order.
transpose_coords : bool, default: True
If True, also transpose the coordinates of this DataArray.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
DataArray:
- "raise": raise an exception
- "warning": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions

Returns
-------
Expand All @@ -2148,7 +2159,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
Dataset.transpose
"""
if dims:
dims = tuple(utils.infix_dims(dims, self.dims))
dims = tuple(utils.infix_dims(dims, self.dims, missing_dims))
variable = self.variable.transpose(*dims)
if transpose_coords:
coords: Dict[Hashable, Variable] = {}
Expand Down
64 changes: 55 additions & 9 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,28 +744,32 @@ def __len__(self) -> int:
return len(self._data) - num_hidden


def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
def infix_dims(
dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise"
) -> Iterator:
"""
Resolves a supplied list containing an ellispsis representing other items, to
Resolves a supplied list containing an ellipsis representing other items, to
a generator with the 'realized' list of all items
"""
if ... in dims_supplied:
if len(set(dims_all)) != len(dims_all):
raise ValueError("Cannot use ellipsis with repeated dims")
if len([d for d in dims_supplied if d == ...]) > 1:
if list(dims_supplied).count(...) > 1:
raise ValueError("More than one ellipsis supplied")
other_dims = [d for d in dims_all if d not in dims_supplied]
for d in dims_supplied:
if d == ...:
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
for d in existing_dims:
if d is ...:
yield from other_dims
else:
yield d
else:
if set(dims_supplied) ^ set(dims_all):
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
if set(existing_dims) ^ set(dims_all):
raise ValueError(
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
)
yield from dims_supplied
yield from existing_dims


def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
Expand Down Expand Up @@ -805,7 +809,7 @@ def drop_dims_from_indexers(
invalid = indexers.keys() - set(dims)
if invalid:
raise ValueError(
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return indexers
Expand All @@ -818,7 +822,7 @@ def drop_dims_from_indexers(
invalid = indexers.keys() - set(dims)
if invalid:
warnings.warn(
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)
for key in invalid:
indexers.pop(key)
Expand All @@ -834,6 +838,48 @@ def drop_dims_from_indexers(
)


def drop_missing_dims(
supplied_dims: Collection, dims: Collection, missing_dims: str
) -> Collection:
"""Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
are not present in dims.

Parameters
----------
supplied_dims : dict
dims : sequence
missing_dims : {"raise", "warn", "ignore"}
"""

if missing_dims == "raise":
supplied_dims_set = set(val for val in supplied_dims if val is not ...)
invalid = supplied_dims_set - set(dims)
if invalid:
raise ValueError(
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return supplied_dims

elif missing_dims == "warn":

invalid = set(supplied_dims) - set(dims)
if invalid:
warnings.warn(
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
)

return [val for val in supplied_dims if val in dims or val is ...]

elif missing_dims == "ignore":
return [val for val in supplied_dims if val in dims or val is ...]

else:
raise ValueError(
f"Unrecognised option {missing_dims} for missing_dims argument"
)


class UncachedAccessor:
"""Acts like a property, but on both classes and class instances

Expand Down
23 changes: 14 additions & 9 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,13 +797,13 @@ def test_isel(self):
assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5)))
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected "
r"Dimensions {'not_a_dim'} do not exist. Expected "
r"one or more of \('x', 'y'\)",
):
self.dv.isel(not_a_dim=0)
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. "
match=r"Dimensions {'not_a_dim'} do not exist. "
r"Expected one or more of \('x', 'y'\)",
):
self.dv.isel(not_a_dim=0, missing_dims="warn")
Expand Down Expand Up @@ -2231,9 +2231,21 @@ def test_transpose(self):
actual = da.transpose("z", ..., "x", transpose_coords=True)
assert_equal(expected, actual)

# same as previous but with a missing dimension
actual = da.transpose(
"z", "y", "x", "not_a_dim", transpose_coords=True, missing_dims="ignore"
)
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose("x", "y")

with pytest.raises(ValueError):
da.transpose("not_a_dim", "z", "x", ...)

with pytest.warns(UserWarning):
da.transpose("not_a_dim", "y", "x", ..., missing_dims="warn")

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)

Expand Down Expand Up @@ -6227,7 +6239,6 @@ def da_dask(seed=123):

@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True)
def test_isin(da):

expected = DataArray(
np.asarray([[0, 0, 0], [1, 0, 0]]),
dims=list("yx"),
Expand Down Expand Up @@ -6277,7 +6288,6 @@ def test_coarsen_keep_attrs():

@pytest.mark.parametrize("da", (1, 2), indirect=True)
def test_rolling_iter(da):

rolling_obj = da.rolling(time=7)
rolling_obj_mean = rolling_obj.mean()

Expand Down Expand Up @@ -6452,7 +6462,6 @@ def test_rolling_construct(center, window):
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
def test_rolling_reduce(da, center, min_periods, window, name):

if min_periods is not None and window < min_periods:
min_periods = window

Expand Down Expand Up @@ -6491,7 +6500,6 @@ def test_rolling_reduce_nonnumeric(center, min_periods, window, name):


def test_rolling_count_correct():

da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

kwargs = [
Expand Down Expand Up @@ -6579,7 +6587,6 @@ def test_ndrolling_construct(center, fill_value):
],
)
def test_rolling_keep_attrs(funcname, argument):

attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
Expand Down Expand Up @@ -6623,7 +6630,6 @@ def test_rolling_keep_attrs(funcname, argument):


def test_rolling_keep_attrs_deprecated():

attrs_da = {"da_attr": "test"}

data = np.linspace(10, 15, 100)
Expand Down Expand Up @@ -6957,7 +6963,6 @@ def test_rolling_exp(da, dim, window_type, window):

@requires_numbagg
def test_rolling_exp_keep_attrs(da):

attrs = {"attrs": "da"}
da.attrs = attrs

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,14 +1024,14 @@ def test_isel(self):
data.isel(not_a_dim=slice(0, 2))
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected "
r"Dimensions {'not_a_dim'} do not exist. Expected "
r"one or more of "
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
):
data.isel(not_a_dim=slice(0, 2))
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. "
match=r"Dimensions {'not_a_dim'} do not exist. "
r"Expected one or more of "
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
):
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,13 +1270,13 @@ def test_isel(self):
assert_identical(v.isel(time=[]), v[[]])
with raises_regex(
ValueError,
r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"\('time', 'x'\)",
):
v.isel(not_a_dim=0)
with pytest.warns(
UserWarning,
match=r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
r"\('time', 'x'\)",
):
v.isel(not_a_dim=0, missing_dims="warn")
Expand Down