diff --git a/doc/plotting.rst b/doc/plotting.rst index 43faa83b9da..271d63c37ab 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -222,9 +222,41 @@ It is also possible to make line plots such that the data are on the x-axis and @savefig plotting_example_xy_kwarg.png air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') +Step plots +~~~~~~~~~~ + +As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be +made using 1D data. + +.. ipython:: python + + @savefig plotting_example_step.png width=4in + air1d[:20].plot.step(where='mid') + +The argument ``where`` defines where the steps should be placed, options are +``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy +when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`. + +.. ipython:: python + + air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90]) + air_mean = air_grp.mean() + air_std = air_grp.std() + air_mean.plot.step() + (air_mean + air_std).plot.step(ls=':') + (air_mean - air_std).plot.step(ls=':') + plt.ylim(-20,30) + @savefig plotting_example_step_groupby.png width=4in + plt.title('Zonal mean temperature') + +In this case, the actual boundaries of the bins are used and the ``where`` argument +is ignored. + + Other axes kwargs ----------------- + The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. .. ipython:: python diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d0fec7b0778..92356b1f1ff 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,9 @@ Enhancements - Added support for Python 3.7. (:issue:`2271`). By `Joe Hamman `_. +- Added support for plotting data with `pandas.Interval` coordinates, such as those + created by :py:meth:`~xarray.DataArray.groupby_bins` + By `Maximilian Maahn `_. - Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark `_. diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index fe2c604a89e..4b53b22243c 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from .plot import (plot, line, contourf, contour, +from .plot import (plot, line, step, contourf, contour, hist, imshow, pcolormesh) from .facetgrid import FacetGrid @@ -9,6 +9,7 @@ __all__ = [ 'plot', 'line', + 'step', 'contour', 'contourf', 'hist', diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index b44ae7b3856..8cfa0bc1fd4 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -19,7 +19,9 @@ from .facetgrid import FacetGrid from .utils import ( - ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, get_axis, + ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, + _interval_to_double_bound_points, _interval_to_mid_points, + _resolve_intervals_2dplot, _valid_other_type, get_axis, import_matplotlib_pyplot, label_from_attrs) @@ -35,27 +37,20 @@ def _valid_numpy_subdtype(x, numpy_types): return any(np.issubdtype(x.dtype, t) for t in numpy_types) -def _valid_other_type(x, types): - """ - Do all elements of x have a type from types? - """ - return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) - - def _ensure_plottable(*args): """ Raise exception if there is anything in args that can't be plotted on an - axis. + axis by matplotlib. """ numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] other_types = [datetime] for x in args: - if not (_valid_numpy_subdtype(np.array(x), numpy_types) or - _valid_other_type(np.array(x), other_types)): + if not (_valid_numpy_subdtype(np.array(x), numpy_types) + or _valid_other_type(np.array(x), other_types)): raise TypeError('Plotting requires coordinates to be numeric ' 'or dates of type np.datetime64 or ' - 'datetime.datetime.') + 'datetime.datetime or pd.Interval.') def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, @@ -326,9 +321,30 @@ def line(darray, *args, **kwargs): xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ _infer_line_data(darray, x, y, hue) - _ensure_plottable(xplt) + # Remove pd.Intervals if contained in xplt.values. + if _valid_other_type(xplt.values, [pd.Interval]): + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get('linestyle', '').startswith('steps-'): + xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, + yplt.values) + # Remove steps-* to be sure that matplotlib is not confused + kwargs['linestyle'] = (kwargs['linestyle'] + .replace('steps-pre', '') + .replace('steps-post', '') + .replace('steps-mid', '')) + if kwargs['linestyle'] == '': + kwargs.pop('linestyle') + else: + xplt_val = _interval_to_mid_points(xplt.values) + yplt_val = yplt.values + xlabel += '_center' + else: + xplt_val = xplt.values + yplt_val = yplt.values - primitive = ax.plot(xplt, yplt, *args, **kwargs) + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) if _labels: if xlabel is not None: @@ -359,6 +375,46 @@ def line(darray, *args, **kwargs): return primitive +def step(darray, *args, **kwargs): + """ + Step plot of DataArray index against values + + Similar to :func:`matplotlib:matplotlib.pyplot.step` + + Parameters + ---------- + where : {'pre', 'post', 'mid'}, optional, default 'pre' + Define where the steps should be placed: + - 'pre': The y value is continued constantly to the left from + every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the + value ``y[i]``. + - 'post': The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the + value ``y[i]``. + - 'mid': Steps occur half-way between the *x* positions. + Note that this parameter is ignored if the x coordinate consists of + :py:func:`pandas.Interval` values, e.g. as a result of + :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual + boundaries of the interval are used. + + *args, **kwargs : optional + Additional arguments following :py:func:`xarray.plot.line` + + """ + if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()): + kwargs['linestyle'] = kwargs.pop('ls') + + where = kwargs.pop('where', 'pre') + + if where not in ('pre', 'post', 'mid'): + raise ValueError("'where' argument to step must be " + "'pre', 'post' or 'mid'") + + kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '') + + return line(darray, *args, **kwargs) + + def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ Histogram of DataArray @@ -476,6 +532,10 @@ def hist(self, ax=None, **kwargs): def line(self, *args, **kwargs): return line(self._da, *args, **kwargs) + @functools.wraps(step) + def step(self, *args, **kwargs): + return step(self._da, *args, **kwargs) + def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None @@ -716,7 +776,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Pass the data as a masked ndarray too zval = darray.to_masked_array(copy=False) - _ensure_plottable(xval, yval) + # Replace pd.Intervals if contained in xval or yval. + xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) + yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) + + _ensure_plottable(xplt, yplt) if 'contour' in plotfunc.__name__ and levels is None: levels = 7 # this is the matplotlib default @@ -756,7 +820,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, "in xarray") ax = get_axis(figsize, size, aspect, ax) - primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], + primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], norm=cmap_params['norm'], @@ -764,8 +828,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Label the plot with metadata if add_labels: - ax.set_xlabel(label_from_attrs(darray[xlab])) - ax.set_ylabel(label_from_attrs(darray[ylab])) + ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) + ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if add_colorbar: @@ -794,7 +858,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Do this without calling autofmt_xdate so that x-axes ticks # on other subplots (if any) are not deleted. # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xval.dtype, np.datetime64): + if np.issubdtype(xplt.dtype, np.datetime64): for xlabels in ax.get_xticklabels(): xlabels.set_rotation(30) xlabels.set_ha('right') @@ -995,14 +1059,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if infer_intervals: + if (infer_intervals and + ((np.shape(x)[0] == np.shape(z)[1]) or + ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))): if len(x.shape) == 1: x = _infer_interval_breaks(x, check_monotonic=True) - y = _infer_interval_breaks(y, check_monotonic=True) else: # we have to infer the intervals on both axes x = _infer_interval_breaks(x, axis=1) x = _infer_interval_breaks(x, axis=0) + + if (infer_intervals and + (np.shape(y)[0] == np.shape(z)[0])): + if len(y.shape) == 1: + y = _infer_interval_breaks(y, check_monotonic=True) + else: + # we have to infer the intervals on both axes y = _infer_interval_breaks(y, axis=1) y = _infer_interval_breaks(y, axis=0) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index be38a6d7a4c..2ed115c85c4 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,9 +1,11 @@ from __future__ import absolute_import, division, print_function +import itertools import textwrap import warnings import numpy as np +import pandas as pd from ..core.options import OPTIONS from ..core.pycompat import basestring @@ -367,7 +369,7 @@ def get_axis(figsize, size, aspect, ax): return ax -def label_from_attrs(da): +def label_from_attrs(da, extra=''): ''' Makes informative labels if variable metadata (attrs) follows CF conventions. ''' @@ -385,4 +387,66 @@ def label_from_attrs(da): else: units = '' - return '\n'.join(textwrap.wrap(name + units, 30)) + return '\n'.join(textwrap.wrap(name + extra + units, 30)) + + +def _interval_to_mid_points(array): + """ + Helper function which returns an array + with the Intervals' mid points. + """ + + return np.array([x.mid for x in array]) + + +def _interval_to_bound_points(array): + """ + Helper function which returns an array + with the Intervals' boundaries. + """ + + array_boundaries = np.array([x.left for x in array]) + array_boundaries = np.concatenate( + (array_boundaries, np.array([array[-1].right]))) + + return array_boundaries + + +def _interval_to_double_bound_points(xarray, yarray): + """ + Helper function to deal with a xarray consisting of pd.Intervals. Each + interval is replaced with both boundaries. I.e. the length of xarray + doubles. yarray is modified so it matches the new shape of xarray. + """ + + xarray1 = np.array([x.left for x in xarray]) + xarray2 = np.array([x.right for x in xarray]) + + xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2))) + yarray = list(itertools.chain.from_iterable(zip(yarray, yarray))) + + return xarray, yarray + + +def _resolve_intervals_2dplot(val, func_name): + """ + Helper function to replace the values of a coordinate array containing + pd.Interval with their mid-points or - for pcolormesh - boundaries which + increases length by 1. + """ + label_extra = '' + if _valid_other_type(val, [pd.Interval]): + if func_name == 'pcolormesh': + val = _interval_to_bound_points(val) + else: + val = _interval_to_mid_points(val) + label_extra = '_center' + + return val, label_extra + + +def _valid_other_type(x, types): + """ + Do all elements of x have a type from types? + """ + return all(any(isinstance(el, t) for t in types) for el in np.ravel(x)) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 53f6077ee4f..0f03e7e233e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -345,6 +345,10 @@ def test_convenient_facetgrid_4d(self): with raises_regex(ValueError, '[Ff]acet'): d.plot(x='x', y='y', col='columns', ax=plt.gca()) + def test_coord_with_interval(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot() + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -419,6 +423,20 @@ def test_slice_in_title(self): assert 'd = 10' == title +class TestPlotStep(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = DataArray(easy_array((2, 3, 4))) + + def test_step(self): + self.darray[0, 0].plot.step() + + def test_coord_with_interval_step(self): + bins = [-1, 0, 1, 2] + self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS).plot.step() + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): @@ -454,6 +472,10 @@ def test_plot_nans(self): self.darray[0, 0, 0] = np.nan self.darray.plot.hist() + def test_hist_coord_with_interval(self): + (self.darray.groupby_bins('dim_0', [-1, 0, 1, 2]).mean(xr.ALL_DIMS) + .plot.hist(range=(-1, 2))) + @requires_matplotlib class TestDetermineCmapParams(object): @@ -1110,6 +1132,12 @@ def test_cmap_and_color_both(self): with pytest.raises(ValueError): self.plotmethod(colors='k', cmap='RdBu') + def test_2d_coord_with_interval(self): + for dim in self.darray.dims: + gp = self.darray.groupby_bins(dim, range(15)).mean(dim) + for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: + getattr(gp.plot, kind)() + def test_colormap_error_norm_and_vmin_vmax(self): norm = mpl.colors.LogNorm(0.1, 1e1)