
ENH: Plotting for groupby_bins #2152

Merged
merged 28 commits into from
Oct 23, 2018
28 commits
78c077c
ENH: Plotting for groupby_bins
May 17, 2018
7b400fa
changed pd._libs.interval.Interval to pd.Interval
May 23, 2018
4175bbf
Assign new variable with _interval_to_mid_points instead of mutating …
May 23, 2018
2d11c10
'_center' added to label only for 1d plot
May 23, 2018
e43f0b0
added tests
May 23, 2018
0a15f07
Merge branch 'master' into groupy_plot2
May 23, 2018
a63d68a
missing whitespace
May 23, 2018
347740b
Simplified test
May 29, 2018
73f790a
simplified tests once more
May 29, 2018
ecb0935
1d plots now defaults to step plot
May 29, 2018
b4d05e7
non-uniform bin spacing for pcolormesh
May 29, 2018
e77e996
Added step plot function
Jun 5, 2018
6d9416d
bugfix: linestyle == '' results in no line plotted
Jun 5, 2018
ce407cd
Merge branch 'master' into groupy_plot2
Jun 5, 2018
389f63b
Adapted to upstream changes
Jun 5, 2018
0dcbf50
Added _resolve_intervals_2dplot function, simplified code
Jun 8, 2018
3898394
Merge branch 'master' into groupy_plot2
Jun 8, 2018
0217b29
Added documentation
Jun 13, 2018
447aea3
typo in documentation
Jun 29, 2018
b87d0f6
Merge branch 'master' into groupy_plot2
Aug 9, 2018
98bc369
Fixed bug introduced by upstream change
Aug 9, 2018
87ef1cc
Merge branch 'master' into groupy_plot2
Aug 14, 2018
826df44
Merge branch 'master' into maahn-groupy_plot2
Oct 10, 2018
ea6f6df
Refactor out utility functions.
Oct 10, 2018
1c2d6d6
Fix test.
Oct 10, 2018
a255857
Add whats-new.
Oct 10, 2018
e60728e
Remove duplicate whats new entry. :/
Oct 10, 2018
448d6b7
Make things neater.
Oct 23, 2018
32 changes: 32 additions & 0 deletions doc/plotting.rst
@@ -222,9 +222,41 @@ It is also possible to make line plots such that the data are on the x-axis and
     @savefig plotting_example_xy_kwarg.png
     air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon')

+Step plots
+~~~~~~~~~~
+
+As an alternative, a step plot similar to matplotlib's ``plt.step`` can be
+made from 1D data.
+
+.. ipython:: python
+
+    @savefig plotting_example_step.png width=4in
+    air1d[:20].plot.step(where='mid')
+
+The argument ``where`` defines where the steps should be placed; the options are
+``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy
+when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`.
+
+.. ipython:: python
+
+    air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90])
+    air_mean = air_grp.mean()
+    air_std = air_grp.std()
+    air_mean.plot.step()
+    (air_mean + air_std).plot.step(ls=':')
+    (air_mean - air_std).plot.step(ls=':')
+    plt.ylim(-20,30)
+    @savefig plotting_example_step_groupby.png width=4in
+    plt.title('Zonal mean temperature')
+
+In this case, the actual boundaries of the bins are used and the ``where`` argument
+is ignored.
+
+
 Other axes kwargs
 -----------------

+
 The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction.

 .. ipython:: python
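To make the ``where`` options documented in this hunk concrete, here is a minimal pure-Python sketch of the vertices a step line passes through for each option. It follows the semantics spelled out in the ``step`` docstring added by this PR; `step_vertices` is a hypothetical helper written for illustration and is not part of xarray or matplotlib.

```python
def step_vertices(x, y, where='pre'):
    """Return the (x, y) vertices of a step line (illustrative sketch only)."""
    if where not in ('pre', 'post', 'mid'):
        raise ValueError("'where' must be 'pre', 'post' or 'mid'")
    x, y = list(x), list(y)
    if len(x) != len(y):
        raise ValueError('x and y must have the same length')
    if not x:
        return []

    if where == 'mid':
        # Steps change half-way between neighbouring x positions.
        mids = [(a + b) / 2 for a, b in zip(x[:-1], x[1:])]
        edges = [x[0]] + mids + [x[-1]]
        verts = []
        for i, yi in enumerate(y):
            verts.append((edges[i], yi))
            verts.append((edges[i + 1], yi))
        return verts

    verts = [(x[0], y[0])]
    for i in range(1, len(x)):
        if where == 'pre':
            # y[i] already applies on the interval (x[i-1], x[i]].
            verts.append((x[i - 1], y[i]))
        else:
            # 'post': y[i-1] is held on the interval [x[i-1], x[i]).
            verts.append((x[i], y[i - 1]))
        verts.append((x[i], y[i]))
    return verts


if __name__ == '__main__':
    for option in ('pre', 'post', 'mid'):
        print(option, step_vertices([0, 1, 2], [5, 6, 7], where=option))
```

With binned data none of these rules is needed, because each bin already carries its own left and right edge.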
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -59,6 +59,9 @@ Enhancements

 - Added support for Python 3.7. (:issue:`2271`).
   By `Joe Hamman <https://github.com/jhamman>`_.
+- Added support for plotting data with `pandas.Interval` coordinates, such as those
+  created by :py:meth:`~xarray.DataArray.groupby_bins`.
+  By `Maximilian Maahn <https://github.com/maahn>`_.
 - Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a
   CFTimeIndex by a specified frequency. (:issue:`2244`).
   By `Spencer Clark <https://github.com/spencerkclark>`_.
3 changes: 2 additions & 1 deletion xarray/plot/__init__.py
@@ -1,14 +1,15 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from .plot import (plot, line, contourf, contour,
+from .plot import (plot, line, step, contourf, contour,
                    hist, imshow, pcolormesh)

 from .facetgrid import FacetGrid

 __all__ = [
     'plot',
     'line',
+    'step',
     'contour',
     'contourf',
     'hist',
114 changes: 93 additions & 21 deletions xarray/plot/plot.py
@@ -19,7 +19,9 @@

 from .facetgrid import FacetGrid
 from .utils import (
-    ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, get_axis,
+    ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels,
+    _interval_to_double_bound_points, _interval_to_mid_points,
+    _resolve_intervals_2dplot, _valid_other_type, get_axis,
     import_matplotlib_pyplot, label_from_attrs)


@@ -35,27 +37,20 @@ def _valid_numpy_subdtype(x, numpy_types):
     return any(np.issubdtype(x.dtype, t) for t in numpy_types)


-def _valid_other_type(x, types):
-    """
-    Do all elements of x have a type from types?
-    """
-    return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
-
-
 def _ensure_plottable(*args):
     """
     Raise exception if there is anything in args that can't be plotted on an
-    axis.
+    axis by matplotlib.
     """
     numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64]
     other_types = [datetime]

     for x in args:
-        if not (_valid_numpy_subdtype(np.array(x), numpy_types) or
-                _valid_other_type(np.array(x), other_types)):
+        if not (_valid_numpy_subdtype(np.array(x), numpy_types)
+                or _valid_other_type(np.array(x), other_types)):
             raise TypeError('Plotting requires coordinates to be numeric '
                             'or dates of type np.datetime64 or '
-                            'datetime.datetime.')
+                            'datetime.datetime or pd.Interval.')


 def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None,
@@ -326,9 +321,30 @@ def line(darray, *args, **kwargs):
     xplt, yplt, hueplt, xlabel, ylabel, huelabel = \
         _infer_line_data(darray, x, y, hue)

-    _ensure_plottable(xplt)
+    # Remove pd.Intervals if contained in xplt.values.
+    if _valid_other_type(xplt.values, [pd.Interval]):
+        # Is it a step plot? (see matplotlib.Axes.step)
+        if kwargs.get('linestyle', '').startswith('steps-'):
+            xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values,
+                                                                  yplt.values)
+            # Remove steps-* to be sure that matplotlib is not confused
+            kwargs['linestyle'] = (kwargs['linestyle']
+                                   .replace('steps-pre', '')
+                                   .replace('steps-post', '')
+                                   .replace('steps-mid', ''))
+            if kwargs['linestyle'] == '':
+                kwargs.pop('linestyle')
+        else:
+            xplt_val = _interval_to_mid_points(xplt.values)
+            yplt_val = yplt.values
+            xlabel += '_center'
+    else:
+        xplt_val = xplt.values
+        yplt_val = yplt.values

-    primitive = ax.plot(xplt, yplt, *args, **kwargs)
+    _ensure_plottable(xplt_val, yplt_val)
+
+    primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs)

     if _labels:
         if xlabel is not None:
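The linestyle handling in this hunk can be read in isolation: when the x coordinate holds intervals and a `steps-*` prefix is present, the prefix is stripped before the keyword arguments reach matplotlib, and an empty remainder is dropped entirely. The sketch below restates that logic under one deliberate difference: it returns a cleaned copy instead of mutating `kwargs` in place. `strip_step_prefix` is a hypothetical name used only here.

```python
STEP_PREFIXES = ('steps-pre', 'steps-post', 'steps-mid')


def strip_step_prefix(kwargs):
    """Return (is_step, cleaned_kwargs); the input dict is left untouched."""
    cleaned = dict(kwargs)
    linestyle = cleaned.get('linestyle', '')
    is_step = linestyle.startswith('steps-')
    if is_step:
        for prefix in STEP_PREFIXES:
            linestyle = linestyle.replace(prefix, '')
        if linestyle:
            # Something like ':' or '--' is left over: keep it as the style.
            cleaned['linestyle'] = linestyle
        else:
            # Nothing left: remove the key so the default line style applies.
            cleaned.pop('linestyle')
    return is_step, cleaned


if __name__ == '__main__':
    print(strip_step_prefix({'linestyle': 'steps-pre:'}))
    print(strip_step_prefix({'linestyle': 'steps-mid'}))
```

Dropping the empty string matters because, per commit 6d9416d in the list above, `linestyle == ''` results in no line being plotted.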
@@ -359,6 +375,46 @@ def line(darray, *args, **kwargs):
     return primitive


+def step(darray, *args, **kwargs):
+    """
+    Step plot of DataArray index against values
+
+    Similar to :func:`matplotlib:matplotlib.pyplot.step`
+
+    Parameters
+    ----------
+    where : {'pre', 'post', 'mid'}, optional, default 'pre'
+        Define where the steps should be placed:
+        - 'pre': The y value is continued constantly to the left from
+          every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the
+          value ``y[i]``.
+        - 'post': The y value is continued constantly to the right from
+          every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the
+          value ``y[i]``.
+        - 'mid': Steps occur half-way between the *x* positions.
+        Note that this parameter is ignored if the x coordinate consists of
+        :py:func:`pandas.Interval` values, e.g. as a result of
+        :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual
+        boundaries of the interval are used.
+
+    *args, **kwargs : optional
+        Additional arguments following :py:func:`xarray.plot.line`
+
+    """
+    if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()):
+        kwargs['linestyle'] = kwargs.pop('ls')
+
+    where = kwargs.pop('where', 'pre')
+
+    if where not in ('pre', 'post', 'mid'):
+        raise ValueError("'where' argument to step must be "
+                         "'pre', 'post' or 'mid'")
+
+    kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '')
+
+    return line(darray, *args, **kwargs)
+
+
 def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs):
     """
     Histogram of DataArray
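The keyword handling in `step` can be exercised without any plotting. The sketch below copies the normalisation logic from the hunk into a standalone function so the resulting `linestyle` string is easy to inspect; `build_step_kwargs` is a hypothetical name, and the final call to `line` is left out.

```python
def build_step_kwargs(**kwargs):
    """Mimic how step() rewrites its keyword arguments before calling line()."""
    # 'ls' is accepted as an alias, but only when 'linestyle' is absent.
    if 'ls' in kwargs and 'linestyle' not in kwargs:
        kwargs['linestyle'] = kwargs.pop('ls')

    where = kwargs.pop('where', 'pre')
    if where not in ('pre', 'post', 'mid'):
        raise ValueError("'where' argument to step must be "
                         "'pre', 'post' or 'mid'")

    # The step mode is encoded as a prefix of the line style string.
    kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '')
    return kwargs


if __name__ == '__main__':
    print(build_step_kwargs(ls=':'))
    print(build_step_kwargs(where='mid', linestyle='--'))
```

One consequence visible in the sketch: if a caller passes both `ls` and `linestyle`, the `ls` entry is not consumed and is forwarded unchanged.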
@@ -476,6 +532,10 @@ def hist(self, ax=None, **kwargs):
     def line(self, *args, **kwargs):
         return line(self._da, *args, **kwargs)

+    @functools.wraps(step)
+    def step(self, *args, **kwargs):
+        return step(self._da, *args, **kwargs)
+

 def _rescale_imshow_rgb(darray, vmin, vmax, robust):
     assert robust or vmin is not None or vmax is not None
@@ -716,7 +776,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
         # Pass the data as a masked ndarray too
         zval = darray.to_masked_array(copy=False)

-        _ensure_plottable(xval, yval)
+        # Replace pd.Intervals if contained in xval or yval.
+        xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
+        yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__)
+
+        _ensure_plottable(xplt, yplt)

         if 'contour' in plotfunc.__name__ and levels is None:
             levels = 7  # this is the matplotlib default
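For 2D plots the interval coordinates are resolved per plot function: `pcolormesh` receives the N + 1 bin boundaries, while every other function receives N mid-points and a `'_center'` label suffix. The following is a minimal sketch of that rule. To stay dependency-free it uses a small stand-in `Interval` class (an assumption made for this example); the real `pandas.Interval` exposes the same `left`, `right` and `mid` attributes. `resolve_intervals` is likewise a hypothetical name.

```python
from typing import NamedTuple


class Interval(NamedTuple):
    """Stand-in for pandas.Interval, for illustration only."""
    left: float
    right: float

    @property
    def mid(self):
        return (self.left + self.right) / 2


def resolve_intervals(values, func_name):
    """Return (plottable_values, label_suffix) for a coordinate."""
    if not values or not all(isinstance(v, Interval) for v in values):
        return list(values), ''
    if func_name == 'pcolormesh':
        # N intervals -> N + 1 edges: all left edges plus the last right edge.
        return [v.left for v in values] + [values[-1].right], ''
    # Other plot functions get one point per interval.
    return [v.mid for v in values], '_center'


if __name__ == '__main__':
    bins = [Interval(0, 23.5), Interval(23.5, 66.5), Interval(66.5, 90)]
    print(resolve_intervals(bins, 'pcolormesh'))
    print(resolve_intervals(bins, 'contourf'))
```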
@@ -756,16 +820,16 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
                              "in xarray")

         ax = get_axis(figsize, size, aspect, ax)
-        primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'],
+        primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'],
                              vmin=cmap_params['vmin'],
                              vmax=cmap_params['vmax'],
                              norm=cmap_params['norm'],
                              **kwargs)

         # Label the plot with metadata
         if add_labels:
-            ax.set_xlabel(label_from_attrs(darray[xlab]))
-            ax.set_ylabel(label_from_attrs(darray[ylab]))
+            ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
+            ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
             ax.set_title(darray._title_for_slice())

         if add_colorbar:
@@ -794,7 +858,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
         # Do this without calling autofmt_xdate so that x-axes ticks
         # on other subplots (if any) are not deleted.
         # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
-        if np.issubdtype(xval.dtype, np.datetime64):
+        if np.issubdtype(xplt.dtype, np.datetime64):
             for xlabels in ax.get_xticklabels():
                 xlabels.set_rotation(30)
                 xlabels.set_ha('right')
@@ -995,14 +1059,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
         else:
             infer_intervals = True

-    if infer_intervals:
+    if (infer_intervals and
+            ((np.shape(x)[0] == np.shape(z)[1]) or
+             ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))):
         if len(x.shape) == 1:
             x = _infer_interval_breaks(x, check_monotonic=True)
-            y = _infer_interval_breaks(y, check_monotonic=True)
         else:
             # we have to infer the intervals on both axes
             x = _infer_interval_breaks(x, axis=1)
             x = _infer_interval_breaks(x, axis=0)
+
+    if (infer_intervals and
+            (np.shape(y)[0] == np.shape(z)[0])):
+        if len(y.shape) == 1:
+            y = _infer_interval_breaks(y, check_monotonic=True)
+        else:
+            # we have to infer the intervals on both axes
             y = _infer_interval_breaks(y, axis=1)
             y = _infer_interval_breaks(y, axis=0)
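The added shape checks exist because a coordinate that came from `pd.Interval` bins already has one more entry than there are cells along that axis, so inferring breaks again would produce one edge too many. The 1D sketch below illustrates the idea in plain Python. It assumes that interval breaks are placed half-way between neighbouring centres and extrapolated by half a spacing at both ends; the monotonicity check is omitted, and `infer_interval_breaks` and `edges_for` are illustrative names rather than xarray functions.

```python
def infer_interval_breaks(centers):
    """Turn N cell centres (N >= 2) into N + 1 cell edges."""
    deltas = [(b - a) / 2 for a, b in zip(centers[:-1], centers[1:])]
    first = centers[0] - deltas[0]
    last = centers[-1] + deltas[-1]
    inner = [c + d for c, d in zip(centers[:-1], deltas)]
    return [first] + inner + [last]


def edges_for(coord, n_cells):
    """Infer edges only when coord holds one value per cell."""
    if len(coord) == n_cells:
        return infer_interval_breaks(coord)
    # Already boundaries (e.g. resolved from interval bins): leave unchanged.
    return list(coord)


if __name__ == '__main__':
    print(edges_for([10, 20, 30], n_cells=3))
    print(edges_for([0, 23.5, 66.5, 90], n_cells=3))
```

Handling x and y separately, as the hunk does, lets one axis be binned while the other still has ordinary cell centres.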

68 changes: 66 additions & 2 deletions xarray/plot/utils.py
@@ -1,9 +1,11 @@
 from __future__ import absolute_import, division, print_function

+import itertools
 import textwrap
 import warnings

 import numpy as np
+import pandas as pd

 from ..core.options import OPTIONS
 from ..core.pycompat import basestring
@@ -367,7 +369,7 @@ def get_axis(figsize, size, aspect, ax):
     return ax


-def label_from_attrs(da):
+def label_from_attrs(da, extra=''):
     ''' Makes informative labels if variable metadata (attrs) follows
         CF conventions. '''
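The new `extra` argument is inserted between the variable name and the units before the label is wrapped at 30 characters. A small standalone sketch of just that last step follows; `compose_label` and its inputs are hypothetical, and the lookup of `name` and `units` from the attrs is not reproduced here.

```python
import textwrap


def compose_label(name, units='', extra=''):
    """Join name, suffix and units, wrapping the result at 30 characters."""
    return '\n'.join(textwrap.wrap(name + extra + units, 30))


if __name__ == '__main__':
    print(compose_label('lat_bins', ' [deg]', extra='_center'))
    print(compose_label('lat_bins', ' [degrees_north]', extra='_center'))
```

Because the suffix lengthens the label, a name that previously fit on one line may now wrap onto a second one.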

@@ -385,4 +387,66 @@ def label_from_attrs(da):
     else:
         units = ''

-    return '\n'.join(textwrap.wrap(name + units, 30))
+    return '\n'.join(textwrap.wrap(name + extra + units, 30))
+
+
+def _interval_to_mid_points(array):
+    """
+    Helper function which returns an array
+    with the Intervals' mid points.
+    """
+
+    return np.array([x.mid for x in array])
+
+
+def _interval_to_bound_points(array):
+    """
+    Helper function which returns an array
+    with the Intervals' boundaries.
+    """
+
+    array_boundaries = np.array([x.left for x in array])
+    array_boundaries = np.concatenate(
+        (array_boundaries, np.array([array[-1].right])))
+
+    return array_boundaries
+
+
+def _interval_to_double_bound_points(xarray, yarray):
+    """
+    Helper function to deal with a xarray consisting of pd.Intervals. Each
+    interval is replaced with both boundaries. I.e. the length of xarray
+    doubles. yarray is modified so it matches the new shape of xarray.
+    """
+
+    xarray1 = np.array([x.left for x in xarray])
+    xarray2 = np.array([x.right for x in xarray])
+
+    xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2)))
+    yarray = list(itertools.chain.from_iterable(zip(yarray, yarray)))
+
+    return xarray, yarray
+
+
+def _resolve_intervals_2dplot(val, func_name):
+    """
+    Helper function to replace the values of a coordinate array containing
+    pd.Interval with their mid-points or - for pcolormesh - boundaries which
+    increases length by 1.
+    """
+    label_extra = ''
+    if _valid_other_type(val, [pd.Interval]):
+        if func_name == 'pcolormesh':
+            val = _interval_to_bound_points(val)
+        else:
+            val = _interval_to_mid_points(val)
+            label_extra = '_center'
+
+    return val, label_extra
+
+
+def _valid_other_type(x, types):
+    """
+    Do all elements of x have a type from types?
+    """
+    return all(any(isinstance(el, t) for t in types) for el in np.ravel(x))
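To see why `_interval_to_double_bound_points` yields a step line without any `where` logic, it helps to look at its output for a few bins: each interval contributes its left and right edge, and each y value is repeated once, so every bin becomes a horizontal segment. The sketch below reproduces this with plain lists. The `Interval` tuple is a stand-in for `pandas.Interval` (only `left` and `right` are needed), and `double_bound_points` and `is_contiguous` are hypothetical names.

```python
from itertools import chain
from typing import NamedTuple


class Interval(NamedTuple):
    """Stand-in for pandas.Interval, for illustration only."""
    left: float
    right: float


def double_bound_points(intervals, values):
    """Replace each interval by both edges and duplicate each y value."""
    xs = list(chain.from_iterable((iv.left, iv.right) for iv in intervals))
    ys = list(chain.from_iterable((v, v) for v in values))
    return xs, ys


def is_contiguous(intervals):
    """True if every interval starts where the previous one ends."""
    return all(a.right == b.left for a, b in zip(intervals, intervals[1:]))


if __name__ == '__main__':
    bins = [Interval(0, 23.5), Interval(23.5, 66.5), Interval(66.5, 90)]
    print(double_bound_points(bins, [1.0, 2.0, 3.0]))
    print(is_contiguous(bins))
```

The contiguity check is not part of the PR. It is included because `_interval_to_bound_points` keeps only the left edges plus the final right edge, which reproduces the bins exactly only when they are contiguous, as bins produced by `groupby_bins` are.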