diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 9b1630c3624..ae44297058b 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -10,6 +10,7 @@ Iterable, Literal, MutableMapping, + cast, overload, ) @@ -925,7 +926,7 @@ def newplotfunc( _is_facetgrid = kwargs.pop("_is_facetgrid", False) - if markersize is not None: + if plotfunc.__name__ == "scatter": size_ = markersize size_r = _MARKERSIZE_RANGE else: @@ -960,7 +961,7 @@ def newplotfunc( cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( plotfunc, - hueplt_norm.values.data, + cast("DataArray", hueplt_norm.values).data, **locals(), ) @@ -1013,13 +1014,7 @@ def newplotfunc( ) if add_legend_: - if plotfunc.__name__ == "hist": - ax.legend( - handles=primitive[-1], - labels=list(hueplt_norm.values.to_numpy()), - title=label_from_attrs(hueplt_norm.data), - ) - elif plotfunc.__name__ in ["scatter", "line"]: + if plotfunc.__name__ in ["scatter", "line"]: _add_legend( hueplt_norm if add_legend or not add_colorbar_ @@ -1030,11 +1025,26 @@ def newplotfunc( plotfunc=plotfunc.__name__, ) else: - ax.legend( - handles=primitive, - labels=list(hueplt_norm.values.to_numpy()), - title=label_from_attrs(hueplt_norm.data), - ) + hueplt_norm_values: list[np.ndarray | None] + if hueplt_norm.data is not None: + hueplt_norm_values = list( + cast("DataArray", hueplt_norm.data).to_numpy() + ) + else: + hueplt_norm_values = [hueplt_norm.data] + + if plotfunc.__name__ == "hist": + ax.legend( + handles=primitive[-1], + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) + else: + ax.legend( + handles=primitive, + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) _update_axes( ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 8b37bd0bb04..c88fb8b9318 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -12,6 +12,7 @@ Iterable, Literal, TypeVar, + cast, ) import numpy as np @@ -41,6 +42,9 @@ from matplotlib.quiver import QuiverKey from matplotlib.text import Annotation + from ..core.dataarray import DataArray + + # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams _FONTSIZE = "small" @@ -402,18 +406,24 @@ def map_plot1d( hueplt_norm = _Normalize(hueplt) self._hue_var = hueplt cbar_kwargs = kwargs.pop("cbar_kwargs", {}) - if not hueplt_norm.data_is_numeric: - # TODO: Ticks seems a little too hardcoded, since it will always - # show all the values. But maybe it's ok, since plotting hundreds - # of categorical data isn't that meaningful anyway. - cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) - kwargs.update(levels=hueplt_norm.levels) - if "label" not in cbar_kwargs: - cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, hueplt_norm.values.to_numpy(), cbar_kwargs=cbar_kwargs, **kwargs - ) - self._cmap_extend = cmap_params.get("extend") + + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # TODO: Ticks seems a little too hardcoded, since it will always + # show all the values. But maybe it's ok, since plotting hundreds + # of categorical data isn't that meaningful anyway. + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + kwargs.update(levels=hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, + cast("DataArray", hueplt_norm.values).data, + cbar_kwargs=cbar_kwargs, + **kwargs, + ) + self._cmap_extend = cmap_params.get("extend") + else: + cmap_params = {} # Handle sizes: _size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE @@ -513,6 +523,9 @@ def map_plot1d( if add_colorbar: # Colorbar is after legend so it correctly fits the plot: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + self.add_colorbar(**cbar_kwargs) return self diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a19f3c54cdd..e27695c4347 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -13,7 +13,6 @@ Iterable, Mapping, Sequence, - TypeVar, overload, ) @@ -1369,9 +1368,6 @@ def _parse_size( return pd.Series(sizes) -T = TypeVar("T", np.ndarray, "DataArray") - - class _Normalize(Sequence): """ Normalize numerical or categorical values to numerical values. @@ -1389,19 +1385,19 @@ class _Normalize(Sequence): """ _data: DataArray | None + _data_unique: np.ndarray + _data_unique_index: np.ndarray + _data_unique_inverse: np.ndarray _data_is_numeric: bool _width: tuple[float, float] | None - _unique: np.ndarray - _unique_index: np.ndarray - _unique_inverse: np.ndarray | DataArray __slots__ = ( "_data", + "_data_unique", + "_data_unique_index", + "_data_unique_inverse", "_data_is_numeric", "_width", - "_unique", - "_unique_index", - "_unique_inverse", ) def __init__( @@ -1416,36 +1412,27 @@ def __init__( pint_array_type = DuckArrayModule("pint").type to_unique = ( data.to_numpy() # type: ignore[union-attr] - if isinstance(self._type, pint_array_type) + if isinstance(data if data is None else data.data, pint_array_type) else data ) - unique, unique_inverse = np.unique(to_unique, return_inverse=True) # type: ignore[call-overload] - self._unique = unique - self._unique_index = np.arange(0, unique.size) - if data is not None: - self._unique_inverse = data.copy(data=unique_inverse.reshape(data.shape)) - self._data_is_numeric = _is_numeric(data) - else: - self._unique_inverse = unique_inverse - self._data_is_numeric = False + data_unique, data_unique_inverse = np.unique(to_unique, return_inverse=True) # type: ignore[call-overload] + self._data_unique = data_unique + self._data_unique_index = np.arange(0, data_unique.size) + self._data_unique_inverse = data_unique_inverse + self._data_is_numeric = False if data is None else _is_numeric(data) def __repr__(self) -> str: with np.printoptions(precision=4, suppress=True, threshold=5): return ( f"<_Normalize(data, width={self._width})>\n" - f"{self._unique} -> {self.values_unique}" + f"{self._data_unique} -> {self._values_unique}" ) def __len__(self) -> int: - return len(self._unique) + return len(self._data_unique) def __getitem__(self, key): - return self._unique[key] - - @property - def _type(self) -> Any | None: # same as DataArray.data? - da = self.data - return da.data if da is not None else da + return self._data_unique[key] @property def data(self) -> DataArray | None: @@ -1461,11 +1448,23 @@ def data_is_numeric(self) -> bool: >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) >>> _Normalize(a).data_is_numeric False + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a).data_is_numeric + True """ return self._data_is_numeric - def _calc_widths(self, y: T | None) -> T | None: - if self._width is None or y is None: + @overload + def _calc_widths(self, y: np.ndarray) -> np.ndarray: + ... + + @overload + def _calc_widths(self, y: DataArray) -> DataArray: + ... + + def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: + if self._width is None: return y x0, x1 = self._width @@ -1475,17 +1474,23 @@ def _calc_widths(self, y: T | None) -> T | None: return widths - def _indexes_centered(self, x: T) -> T | None: + @overload + def _indexes_centered(self, x: np.ndarray) -> np.ndarray: + ... + + @overload + def _indexes_centered(self, x: DataArray) -> DataArray: + ... + + def _indexes_centered(self, x: np.ndarray | DataArray) -> np.ndarray | DataArray: """ Offset indexes to make sure being in the center of self.levels. ["a", "b", "c"] -> [1, 3, 5] """ - if self.data is None: - return None return x * 2 + 1 @property - def values(self): + def values(self) -> DataArray | None: """ Return a normalized number array for the unique levels. @@ -1513,43 +1518,52 @@ def values(self): array([27., 18., 18., 27., 54., 72.]) Dimensions without coordinates: dim_0 """ - return self._calc_widths( - self.data - if self.data_is_numeric - else self._indexes_centered(self._unique_inverse) - ) + if self.data is None: + return None - def _integers(self): - """ - Return integers. - ["a", "b", "c"] -> [1, 3, 5] - """ - return self._indexes_centered(self._unique_index) + val: DataArray + if self.data_is_numeric: + val = self.data + else: + arr = self._indexes_centered(self._data_unique_inverse) + val = self.data.copy(data=arr.reshape(self.data.shape)) + + return self._calc_widths(val) @property - def values_unique(self) -> np.ndarray: + def _values_unique(self) -> np.ndarray | None: """ Return unique values. Examples -------- >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) - >>> _Normalize(a).values_unique + >>> _Normalize(a)._values_unique array([1, 3, 5]) - >>> a = xr.DataArray([2, 1, 1, 2, 3]) - >>> _Normalize(a).values_unique - array([1, 2, 3]) - >>> _Normalize(a, width=[18, 72]).values_unique + + >>> _Normalize(a, width=[18, 72])._values_unique array([18., 45., 72.]) + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a)._values_unique + array([0. , 0.5, 2. , 3. ]) + + >>> _Normalize(a, width=[18, 72])._values_unique + array([18., 27., 54., 72.]) """ - return ( - self._integers() - if not self.data_is_numeric - else self._calc_widths(self._unique) - ) + if self.data is None: + return None + + val: np.ndarray + if self.data_is_numeric: + val = self._data_unique + else: + val = self._indexes_centered(self._data_unique_index) + + return self._calc_widths(val) @property - def ticks(self) -> None | np.ndarray: + def ticks(self) -> np.ndarray | None: """ Return ticks for plt.colorbar if the data is not numeric. @@ -1559,7 +1573,13 @@ def ticks(self) -> None | np.ndarray: >>> _Normalize(a).ticks array([1, 3, 5]) """ - return self._integers() if not self.data_is_numeric else None + val: None | np.ndarray + if self.data_is_numeric: + val = None + else: + val = self._indexes_centered(self._data_unique_index) + + return val @property def levels(self) -> np.ndarray: @@ -1573,11 +1593,16 @@ def levels(self) -> np.ndarray: >>> _Normalize(a).levels array([0, 2, 4, 6]) """ - return np.append(self._unique_index, np.max(self._unique_index) + 1) * 2 + return ( + np.append(self._data_unique_index, np.max(self._data_unique_index) + 1) * 2 + ) @property def _lookup(self) -> pd.Series: - return pd.Series(dict(zip(self.values_unique, self._unique))) + if self._values_unique is None: + raise ValueError("self.data can't be None.") + + return pd.Series(dict(zip(self._values_unique, self._data_unique))) def _lookup_arr(self, x) -> np.ndarray: # Use reindex to be less sensitive to float errors. reindex only @@ -1656,7 +1681,7 @@ def _determine_guide( else: add_colorbar = False - if (add_legend) and hueplt_norm.data is None and sizeplt_norm.data is None: + if add_legend and hueplt_norm.data is None and sizeplt_norm.data is None: raise KeyError("Cannot create a legend when hue and markersize is None.") if add_legend is None: if (