diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4149432b117..7f4dd464118 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,9 @@ Enhancements - Experimental support for parsing ENVI metadata to coordinates and attributes in :py:func:`xarray.open_rasterio`. By `Matti Eskelinen `_. +- :py:func:`~plot.line()` learned to draw multiple lines if provided with a + 2D variable. + By `Deepak Cherian `_. .. _Zarr: http://zarr.readthedocs.io/ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 94556d70f6c..2952bd14c51 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -156,7 +156,7 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, # matplotlib format strings def line(darray, *args, **kwargs): """ - Line plot of 1 dimensional DataArray index against values + Line plot of DataArray index against values Wraps :func:`matplotlib:matplotlib.pyplot.plot` @@ -176,6 +176,12 @@ def line(darray, *args, **kwargs): ax : matplotlib axes object, optional Axis on which to plot this figure. By default, use the current axis. Mutually exclusive with ``size`` and ``figsize``. + hue : string, optional + Coordinate for which you want multiple lines plotted (2D inputs only). + x : string, optional + Coordinate for x axis. + add_legend : boolean, optional + Add legend with y axis coordinates (2D inputs only). *args, **kwargs : optional Additional arguments to matplotlib.pyplot.plot @@ -183,8 +189,8 @@ def line(darray, *args, **kwargs): plt = import_matplotlib_pyplot() ndims = len(darray.dims) - if ndims != 1: - raise ValueError('Line plots are for 1 dimensional DataArrays. ' + if ndims > 2: + raise ValueError('Line plots are for 1- or 2-dimensional DataArrays. ' 'Passed DataArray has {ndims} ' 'dimensions'.format(ndims=ndims)) @@ -193,11 +199,27 @@ def line(darray, *args, **kwargs): aspect = kwargs.pop('aspect', None) size = kwargs.pop('size', None) ax = kwargs.pop('ax', None) + hue = kwargs.pop('hue', None) + x = kwargs.pop('x', None) + add_legend = kwargs.pop('add_legend', True) ax = get_axis(figsize, size, aspect, ax) - xlabel, = darray.dims - x = darray.coords[xlabel] + if ndims == 1: + xlabel, = darray.dims + if x is not None and xlabel != x: + raise ValueError('Input does not have specified dimension' + + ' {!r}'.format(x)) + + x = darray.coords[xlabel] + + else: + if x is None and hue is None: + raise ValueError('For 2D inputs, please specify either hue or x.') + + xlabel, huelabel = _infer_xy_labels(darray=darray, x=x, y=hue) + x = darray.coords[xlabel] + darray = darray.transpose(xlabel, huelabel) _ensure_plottable(x) @@ -209,6 +231,11 @@ def line(darray, *args, **kwargs): if darray.name is not None: ax.set_ylabel(darray.name) + if darray.ndim == 2 and add_legend: + ax.legend(handles=primitive, + labels=list(darray.coords[huelabel].values), + title=huelabel) + # Rotate dates on xlabels if np.issubdtype(x.dtype, np.datetime64): plt.gcf().autofmt_xdate() diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 1d62f7856f9..dd589eb3765 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -94,6 +94,41 @@ def setUp(self): def test1d(self): self.darray[:, 0, 0].plot() + with raises_regex(ValueError, 'dimension'): + self.darray[:, 0, 0].plot(x='dim_1') + + def test_2d_line(self): + with raises_regex(ValueError, 'hue'): + self.darray[:, :, 0].plot.line() + + self.darray[:, :, 0].plot.line(hue='dim_1') + + def test_2d_line_accepts_legend_kw(self): + self.darray[:, :, 0].plot.line(x='dim_0', add_legend=False) + self.assertFalse(plt.gca().get_legend()) + plt.cla() + self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) + self.assertTrue(plt.gca().get_legend()) + # check whether legend title is set + self.assertTrue(plt.gca().get_legend().get_title().get_text() + == 'dim_1') + + def test_2d_line_accepts_x_kw(self): + self.darray[:, :, 0].plot.line(x='dim_0') + self.assertTrue(plt.gca().get_xlabel() == 'dim_0') + plt.cla() + self.darray[:, :, 0].plot.line(x='dim_1') + self.assertTrue(plt.gca().get_xlabel() == 'dim_1') + + def test_2d_line_accepts_hue_kw(self): + self.darray[:, :, 0].plot.line(hue='dim_0') + self.assertTrue(plt.gca().get_legend().get_title().get_text() + == 'dim_0') + plt.cla() + self.darray[:, :, 0].plot.line(hue='dim_1') + self.assertTrue(plt.gca().get_legend().get_title().get_text() + == 'dim_1') + def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) a.plot() @@ -243,11 +278,6 @@ def test_ylabel_is_data_name(self): self.darray.plot() self.assertEqual(self.darray.name, plt.gca().get_ylabel()) - def test_wrong_dims_raises_valueerror(self): - twodims = DataArray(easy_array((2, 5))) - with pytest.raises(ValueError): - twodims.plot.line() - def test_format_string(self): self.darray.plot.line('ro')