Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Breaking changes
Enhancements
~~~~~~~~~~~~

- :py:meth:`DataArray.transpose` now accepts a keyword argument
``transpose_coords`` which enables transposition of coordinates in the
same way as :py:meth:`Dataset.transpose`.
By `Peter Hausamann <http://github.com/phausamann>`_.

Bug fixes
~~~~~~~~~

Expand Down
27 changes: 25 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,14 +1357,16 @@ def unstack(self, dim=None):
ds = self._to_temp_dataset().unstack(dim)
return self._from_temp_dataset(ds)

def transpose(self, *dims):
def transpose(self, *dims, **kwargs):
"""Return a new DataArray object with transposed dimensions.

Parameters
----------
*dims : str, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
transpose_coords : boolean, optional
If True, also transpose the coordinates of this DataArray.

Returns
-------
Expand All @@ -1381,8 +1383,29 @@ def transpose(self, *dims):
numpy.transpose
Dataset.transpose
"""
if dims:
if set(dims) ^ set(self.dims):
raise ValueError('arguments to transpose (%s) must be '
'permuted array dimensions (%s)'
% (dims, tuple(self.dims)))

transpose_coords = kwargs.pop('transpose_coords', None)
variable = self.variable.transpose(*dims)
return self._replace(variable)
if transpose_coords:
coords = {}
for name, coord in iteritems(self.coords):
coord_dims = tuple(dim for dim in dims if dim in coord.dims)
coords[name] = coord.variable.transpose(*coord_dims)
return self._replace(variable, coords)
else:
if transpose_coords is None \
and any(self[c].ndim > 1 for c in self.coords):
warnings.warn('This DataArray contains multi-dimensional '
'coordinates. In the future, these coordinates '
'will be transposed as well unless you specify '
'transpose_coords=False.',
FutureWarning, stacklevel=2)
return self._replace(variable)

def drop(self, labels, dim=None):
"""Drop coordinates or index labels from this DataArray.
Expand Down
20 changes: 18 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,8 +1709,24 @@ def test_stack_nonunique_consistency(self):
assert_identical(expected, actual)

def test_transpose(self):
assert_equal(self.dv.variable.transpose(),
self.dv.transpose().variable)
da = DataArray(np.random.randn(3, 4, 5), dims=('x', 'y', 'z'),
coords={'x': range(3), 'y': range(4), 'z': range(5),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = da.transpose(transpose_coords=False)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords=da.coords)
assert_equal(expected, actual)

actual = da.transpose('z', 'y', 'x', transpose_coords=True)
expected = DataArray(da.values.T, dims=('z', 'y', 'x'),
coords={'x': da.x.values, 'y': da.y.values,
'z': da.z.values,
'xy': (('y', 'x'), da.xy.values.T)})
assert_equal(expected, actual)

with pytest.raises(ValueError):
da.transpose('x', 'y')

def test_squeeze(self):
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
Expand Down
9 changes: 7 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3799,10 +3799,15 @@ def test_dataset_math_errors(self):

def test_dataset_transpose(self):
ds = Dataset({'a': (('x', 'y'), np.random.randn(3, 4)),
'b': (('y', 'x'), np.random.randn(4, 3))})
'b': (('y', 'x'), np.random.randn(4, 3))},
coords={'x': range(3), 'y': range(4),
'xy': (('x', 'y'), np.random.randn(3, 4))})

actual = ds.transpose()
expected = ds.apply(lambda x: x.transpose())
expected = Dataset({'a': (('y', 'x'), ds.a.values.T),
'b': (('x', 'y'), ds.b.values.T)},
coords={'x': ds.x.values, 'y': ds.y.values,
'xy': (('y', 'x'), ds.xy.values.T)})
assert_identical(expected, actual)

actual = ds.transpose('x', 'y')
Expand Down