-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
ENH: Plotting for groupby_bins #2152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
78c077c
7b400fa
4175bbf
2d11c10
e43f0b0
0a15f07
a63d68a
347740b
73f790a
ecb0935
b4d05e7
e77e996
6d9416d
ce407cd
389f63b
0dcbf50
3898394
0217b29
447aea3
b87d0f6
98bc369
87ef1cc
826df44
ea6f6df
1c2d6d6
a255857
e60728e
448d6b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,14 +48,23 @@ def _ensure_plottable(*args): | |
axis. | ||
""" | ||
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] | ||
other_types = [datetime] | ||
other_types = [datetime, pd.Interval] | ||
|
||
for x in args: | ||
if not (_valid_numpy_subdtype(np.array(x), numpy_types) or | ||
_valid_other_type(np.array(x), other_types)): | ||
raise TypeError('Plotting requires coordinates to be numeric ' | ||
'or dates of type np.datetime64 or ' | ||
'datetime.datetime.') | ||
'datetime.datetime or pd.Interval.') | ||
|
||
|
||
def _interval_to_mid_points(array): | ||
""" | ||
Helper function which returns an array | ||
with the Intervals' mid points. | ||
""" | ||
|
||
return np.asarray(list(map(lambda x: x.mid, array))) | ||
|
||
|
||
def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, | ||
|
@@ -267,7 +276,14 @@ def line(darray, *args, **kwargs): | |
|
||
_ensure_plottable(xplt) | ||
|
||
primitive = ax.plot(xplt, yplt, *args, **kwargs) | ||
# Remove pd.Intervals if contained in xplt.values. | ||
if _valid_other_type(xplt.values, [pd.Interval]): | ||
xplt_val = _interval_to_mid_points(xplt.values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might make sense to plot labels like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure, I guess in many case there is not enough space for all tick labels. And labeling only some intervals might be confusing? Maybe something like a step plot would be an alternative? https://matplotlib.org/gallery/lines_bars_and_markers/step_demo.html There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, this is probably better default behavior. Potentially there could be a flag to choose. |
||
xlabel += '_center' | ||
else: | ||
xplt_val = xplt.values | ||
|
||
primitive = ax.plot(xplt_val, yplt, *args, **kwargs) | ||
|
||
if xlabel is not None: | ||
ax.set_xlabel(xlabel) | ||
|
@@ -610,6 +626,16 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
|
||
_ensure_plottable(xval, yval) | ||
|
||
# Replace pd.Intervals if contained in xval or yval. | ||
if _valid_other_type(xval, [pd.Interval]): | ||
xplt = _interval_to_mid_points(xval) | ||
else: | ||
xplt = xval | ||
if _valid_other_type(yval, [pd.Interval]): | ||
yplt = _interval_to_mid_points(yval) | ||
else: | ||
yplt = yval | ||
|
||
if 'contour' in plotfunc.__name__ and levels is None: | ||
levels = 7 # this is the matplotlib default | ||
|
||
|
@@ -645,7 +671,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
"in xarray") | ||
|
||
ax = get_axis(figsize, size, aspect, ax) | ||
primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], | ||
primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], | ||
vmin=cmap_params['vmin'], | ||
vmax=cmap_params['vmax'], | ||
**kwargs) | ||
|
@@ -674,7 +700,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
_update_axes_limits(ax, xincrease, yincrease) | ||
|
||
# Rotate dates on xlabels | ||
if np.issubdtype(xval.dtype, np.datetime64): | ||
if np.issubdtype(xplt.dtype, np.datetime64): | ||
ax.get_figure().autofmt_xdate() | ||
|
||
return primitive | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,6 +297,19 @@ def test_convenient_facetgrid_4d(self): | |
with raises_regex(ValueError, '[Ff]acet'): | ||
d.plot(x='x', y='y', col='columns', ax=plt.gca()) | ||
|
||
def test_coord_with_interval(self): | ||
for dim in self.darray.dims: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure what the point of testing multiple dimensions is -- do you expect different behavior for different dimensions? If not, I would probably just pick one dimension. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point, that was basically a copy paste error from the 2d version. Will change that. |
||
for method in ['argmax', 'argmin', 'max', 'min', | ||
'mean', 'prod', 'sum', | ||
'std', 'var', 'median']: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to test all these different methods here. They all use the same logic internally, so just one groupby method should be enough. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, will use mean only. |
||
gp = self.darray.groupby_bins(dim, [-1, 0, 1, 2]) | ||
getattr(gp, method)().plot() | ||
|
||
def test_coord_with_interval_label_contains_center(self): | ||
for dim in self.darray.dims: | ||
self.darray.groupby_bins(dim, [-1, 0, 1, 2]).mean().plot() | ||
assert plt.gca().get_xlabel().endswith('_center') | ||
|
||
|
||
class TestPlot1D(PlotTestCase): | ||
def setUp(self): | ||
|
@@ -404,6 +417,14 @@ def test_plot_nans(self): | |
self.darray[0, 0, 0] = np.nan | ||
self.darray.plot.hist() | ||
|
||
def test_hist_coord_with_interval(self): | ||
for dim in self.darray.dims: | ||
for method in ['argmax', 'argmin', 'max', 'min', | ||
'mean', 'prod', 'sum', | ||
'std', 'var', 'median']: | ||
gp = self.darray.groupby_bins(dim, [-1, 0, 1, 2]) | ||
getattr(gp, method)().plot.hist(range=(-1, 2)) | ||
|
||
|
||
@requires_matplotlib | ||
class TestDetermineCmapParams(TestCase): | ||
|
@@ -959,6 +980,16 @@ def test_cmap_and_color_both(self): | |
with pytest.raises(ValueError): | ||
self.plotmethod(colors='k', cmap='RdBu') | ||
|
||
def test_2d_coord_with_interval(self): | ||
for dim in self.darray.dims: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left the loop here, because for the 2d plots, x and y axis are treated separately. |
||
gp = self.darray.groupby_bins(dim, range(15)) | ||
for method in ['argmax', 'argmin', 'max', 'min', | ||
'mean', 'prod', 'sum', | ||
'std', 'var', 'median']: | ||
gp_method = getattr(gp, method)(dim) | ||
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: | ||
getattr(gp_method.plot, kind)() | ||
|
||
|
||
@pytest.mark.slow | ||
class TestContourf(Common2dMixin, PlotTestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider writing these with list comprehensions, e.g.,
np.array([x.mid for x in array])