diff --git a/docs/api.rst b/docs/api.rst index 4a904190..4bcb96aa 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -13,6 +13,7 @@ Dataset Dataset.pint.quantify Dataset.pint.dequantify + Dataset.pint.to Dataset.pint.to_base_units Dataset.pint.to_system diff --git a/docs/conf.py b/docs/conf.py index 7b9e884f..b0235b56 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ import sphinx_autosummary_accessors # need to import so accessors get registered -import pintxarray # noqa: F401 +import pint_xarray # noqa: F401 # -- Project information ----------------------------------------------------- diff --git a/licenses/XARRAY_LICENSE b/licenses/XARRAY_LICENSE new file mode 100644 index 00000000..978b509e --- /dev/null +++ b/licenses/XARRAY_LICENSE @@ -0,0 +1,194 @@ +Copyright 2014-2020, xarray developers + + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. 
+ +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." 
+ +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
+ +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. 
+ +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. 
+ +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. 
def is_dict_like(obj):
    """Tell whether ``obj`` offers the minimal mapping interface.

    Parameters
    ----------
    obj : object
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``obj`` has both a ``keys`` and a ``__getitem__``
        attribute, ``False`` otherwise.
    """
    required = ("keys", "__getitem__")
    return all(hasattr(obj, attribute) for attribute in required)


# based on xarray.core.utils.either_dict_or_kwargs
# https://github.com/pydata/xarray/blob/v0.15.1/xarray/core/utils.py#L249-L268
def either_dict_or_kwargs(positional, keywords, method_name):
    """Select between the positional (dict) and the keyword form of an argument.

    Parameters
    ----------
    positional : dict-like or None
        The value passed positionally. ``None`` means "not given".
    keywords : dict
        The values passed as keyword arguments.
    method_name : str
        Name of the calling method, only used in the error messages.

    Returns
    -------
    dict-like
        ``keywords`` if ``positional`` is ``None``, otherwise ``positional``.

    Raises
    ------
    ValueError
        If ``positional`` is given but is not dict-like, or if both
        ``positional`` and a non-empty ``keywords`` are given.
    """
    # nothing passed positionally: the keyword form is the only candidate
    if positional is None:
        return keywords

    if not is_dict_like(positional):
        message = f"the first argument to .{method_name} must be a dictionary"
        raise ValueError(message)

    # the two forms are mutually exclusive
    if keywords:
        message = (
            "cannot specify both keyword and positional "
            f"arguments to .{method_name}"
        )
        raise ValueError(message)

    return positional
data=quantity, - coords=self.da.coords, - attrs=self.da.attrs, - encoding=self.da.encoding, - ) + def to(self, units=None, **unit_kwargs): + """ convert the quantities in a DataArray + + Parameters + ---------- + units : str or pint.Unit or mapping of hashable to str or pint.Unit, optional + The units to convert to. If a unit name or + :py:class`pint.Unit` object, convert the DataArray's + data. If a dict-like, it has to map a variable name to a + unit name or :py:class:`pint.Unit` object. + **unit_kwargs + The kwargs form of ``units``. Can only be used for + variable names that are strings and valid python identifiers. + + Returns + ------- + object : DataArray + A new object with converted units. + + Examples + -------- + >>> da = xr.DataArray( + ... data=np.linspace(0, 1, 5) * ureg.m, + ... coords={"u": ("x", np.arange(5) * ureg.s)}, + ... dims="x", + ... name="arr", + ... ) + >>> da + + + Coordinates: + u (x) int64 + Dimensions without coordinates: x + + Convert the data + + >>> da.pint.to("mm") + + + Coordinates: + u (x) int64 + Dimensions without coordinates: x + >>> da.pint.to(ureg.mm) + + + Coordinates: + u (x) int64 + Dimensions without coordinates: x + >>> da.pint.to({da.name: "mm"}) + + + Coordinates: + u (x) int64 + Dimensions without coordinates: x + + Convert coordinates + + >>> da.pint.to({"u": ureg.ms}) + + + Coordinates: + u (x) float64 >> da.pint.to(u="ms") + + + Coordinates: + u (x) float64 >> da.pint.to("mm", u="ms") + + + Coordinates: + u (x) float64 >> da.pint.to({"arr": ureg.mm, "u": ureg.ms}) + + + Coordinates: + u (x) float64 >> da.pint.to(arr="mm", u="ms") + + + Coordinates: + u (x) float64 >> ds = xr.Dataset( + ... data_vars={ + ... "a": ("x", np.linspace(0, 1, 5) * ureg.m), + ... "b": ("x", np.linspace(-1, 0, 5) * ureg.kg), + ... }, + ... coords={"u": ("x", np.arange(5) * ureg.s)}, + ... 
) + >>> ds + + Dimensions: (x: 5) + Coordinates: + u (x) int64 + Dimensions without coordinates: x + Data variables: + a (x) float64 + b (x) float64 + + Convert the data + + >>> ds.pint.to({"a": "mm", "b": ureg.g}) + + Dimensions: (x: 5) + Coordinates: + u (x) int64 + Dimensions without coordinates: x + Data variables: + a (x) float64 >> ds.pint.to(a=ureg.mm, b="g") + + Dimensions: (x: 5) + Coordinates: + u (x) int64 + Dimensions without coordinates: x + Data variables: + a (x) float64 >> ds.pint.to({"u": ureg.ms}) + + Dimensions: (x: 5) + Coordinates: + u (x) float64 + b (x) float64 + >>> ds.pint.to(u="ms") + + Dimensions: (x: 5) + Coordinates: + u (x) float64 + b (x) float64 + + Convert both simultaneously + + >>> ds.pint.to(a=ureg.mm, b=ureg.g, u="ms") + + Dimensions: (x: 5) + Coordinates: + u (x) float64 >> ds.pint.to({"a": "mm", "b": "g", "u": ureg.ms}) + + Dimensions: (x: 5) + Coordinates: + u (x) float64 " + ds = obj.rename(new_name).to_dataset() + units = units.copy() + units[new_name] = units.get(old_name) + + new_ds = attach_units(ds, units, registry=registry) + new_obj = new_ds.get(new_name).rename(old_name) + elif isinstance(obj, Dataset): + data_vars = { + name: attach_units( + array.variable, {None: units.get(name)}, registry=registry + ) + for name, array in obj.data_vars.items() + } + coords = { + name: attach_units( + array.variable, {None: units.get(name)}, registry=registry + ) + for name, array in obj.coords.items() + } + + new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) + elif isinstance(obj, Variable): + new_data = array_attach_units(obj.data, units.get(None), registry=registry) + new_obj = obj.copy(data=new_data) + else: + raise ValueError(f"cannot attach units to {obj!r}: unknown type") + + return new_obj + + +def convert_units(obj, units): + if not isinstance(units, dict): + units = {None: units} + + if isinstance(obj, Variable): + new_data = array_convert_units(obj.data, units.get(None)) + new_obj = 
def strip_units(obj):
    """Return a copy of an xarray object with the units stripped from its data.

    Parameters
    ----------
    obj : Variable or DataArray or Dataset
        The object to strip. A ``DataArray`` is temporarily converted to a
        ``Dataset`` so that its data and its coordinates are handled by the
        same code path.

    Returns
    -------
    Variable or DataArray or Dataset
        A new object of the same kind as ``obj`` where the data of every
        variable has been passed through ``array_strip_units``.

    Raises
    ------
    ValueError
        If ``obj`` is not a ``Variable``, ``DataArray`` or ``Dataset``.
    """
    if isinstance(obj, Variable):
        data = array_strip_units(obj.data)
        new_obj = obj.copy(data=data)
    elif isinstance(obj, DataArray):
        original_name = obj.name
        # an unnamed array needs a placeholder name to be stored in a Dataset
        # NOTE(review): the empty-string placeholder may be extraction residue
        # of a longer sentinel name -- confirm against the upstream file
        name = obj.name if obj.name is not None else ""
        ds = obj.rename(name).to_dataset()
        stripped = strip_units(ds)

        # restore the original (possibly ``None``) name
        new_obj = stripped[name].rename(original_name)
    elif isinstance(obj, Dataset):
        data_vars = {
            name: strip_units(array.variable) for name, array in obj.data_vars.items()
        }
        coords = {
            name: strip_units(array.variable) for name, array in obj.coords.items()
        }

        new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
    else:
        # the ``f`` prefix is required, otherwise "{obj!r}" is emitted literally
        raise ValueError(f"cannot strip units from {obj!r}: unknown type")

    return new_obj
@pytest.mark.parametrize(
    "unit",
    (
        pytest.param(1, id="not a unit"),
        pytest.param(None, id="no unit"),
        pytest.param("mm", id="string"),
        pytest.param(unit_registry.mm, id="unit object"),
    ),
)
@pytest.mark.parametrize(
    "data",
    (
        pytest.param(np.array([0, 1]), id="array_like"),
        pytest.param(np.array([1, 2]) * unit_registry.m, id="quantity"),
    ),
)
def test_array_convert_units(self, data, unit):
    """Check ``conversion.array_convert_units`` for each data / unit pairing.

    Expected outcomes, by parameter combination:

    - ``unit == 1``: ``ValueError`` ("cannot use ... as a unit")
    - non-quantity data with a string unit: ``ValueError``
    - non-quantity data with a unit object: ``pint.DimensionalityError``
    - ``unit is None``: the data is returned as-is
    - otherwise: the quantity converted to millimeters
    """
    if unit == 1:
        error = ValueError
        match = "cannot use .+ as a unit"
    elif not isinstance(data, pint.Quantity) and isinstance(unit, str):
        error = ValueError
        match = "cannot convert a non-quantity using .+ as unit"
    elif not isinstance(data, pint.Quantity) and unit is not None:
        error = pint.DimensionalityError
        match = None
    else:
        error = None
        match = None

    if error is not None:
        with pytest.raises(error, match=match):
            conversion.array_convert_units(data, unit)

        return

    expected = (
        unit_registry.Quantity(np.array([1000, 2000]), "mm")
        if unit is not None
        else data
    )
    actual = conversion.array_convert_units(data, unit)

    # ``assert_array_equal`` compares magnitudes only, so the units have to be
    # checked separately (same pattern as ``test_array_attach_units``)
    assert_array_units_equal(expected, actual)
    assert_array_equal(expected, actual)
@pytest.mark.parametrize(
    "obj",
    (
        pytest.param(Variable("x", np.linspace(0, 1, 5)), id="Variable"),
        pytest.param(
            DataArray(
                data=np.linspace(0, 1, 5),
                dims="x",
                coords={"u": ("x", np.arange(5))},
            ),
            id="DataArray",
        ),
        pytest.param(
            Dataset(
                {
                    "a": ("x", np.linspace(-1, 1, 5)),
                    "b": ("x", np.linspace(0, 1, 5)),
                },
                coords={"u": ("x", np.arange(5))},
            ),
            id="Dataset",
        ),
    ),
)
@pytest.mark.parametrize(
    "units",
    (
        pytest.param({None: None, "u": None}, id="no units"),
        pytest.param({None: unit_registry.m, "u": None}, id="data units"),
        pytest.param({None: None, "u": unit_registry.s}, id="coord units"),
    ),
)
def test_attach_units(self, obj, units):
    """Check that ``conversion.attach_units`` round-trips with ``extract_units``.

    The ``units`` parameter maps ``None`` to the unit of the data and ``"u"``
    to the unit of the coordinate; it is adapted to the object under test
    before being passed on.
    """
    if isinstance(obj, Variable):
        # every parametrized mapping contains the key "u", so testing for the
        # key would skip all Variable cases; only a coordinate unit that is
        # actually set is meaningless for a Variable
        if units.get("u") is not None:
            pytest.skip("variables don't have coordinates")

        # a Variable only reports the unit of its data
        units = {None: units.get(None)}

    if isinstance(obj, Dataset):
        # datasets have no unnamed data: apply the data unit to both variables
        units = units.copy()
        data_units = units.pop(None)
        units.update({"a": data_units, "b": data_units})

    actual = conversion.attach_units(obj, units)

    assert conversion.extract_units(actual) == units
("Variable", "DataArray", "Dataset")) + def test_convert_units(self, typename, variant): + if typename == "Variable": + if variant != "data": + pytest.skip("Variable doesn't store coordinates") + + data = np.linspace(0, 1, 3) * unit_registry.m + obj = Variable(dims="x", data=data) + units = {None: unit_registry.mm} + expected_units = units + elif typename == "DataArray": + unit_variants = { + "data": (unit_registry.Pa, 1, 1), + "dims": (1, unit_registry.s, 1), + "coords": (1, 1, unit_registry.m), + } + data_unit, dim_unit, coord_unit = unit_variants.get(variant) + + coords = { + "data": {}, + "dims": {"x": [0, 1, 2] * dim_unit}, + "coords": {"u": ("x", [10, 3, 4] * coord_unit)}, + } + + obj = DataArray( + dims="x", + data=np.linspace(0, 1, 3) * data_unit, + coords=coords.get(variant), + ) + template = { + **{obj.name: None}, + **{name: None for name in obj.coords}, + } + units = { + "data": {None: unit_registry.hPa}, + "dims": {"x": unit_registry.ms}, + "coords": {"u": unit_registry.mm}, + }.get(variant) + + expected_units = {**template, **units} + elif typename == "Dataset": + unit_variants = { + "data": ((unit_registry.s, unit_registry.kg), 1, 1), + "dims": ((1, 1), unit_registry.s, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (data_unit1, data_unit2), dim_unit, coord_unit = unit_variants.get(variant) + + coords = { + "data": {}, + "dims": {"x": [0, 1, 2] * dim_unit}, + "coords": {"u": ("x", [10, 3, 4] * coord_unit)}, + } + + obj = Dataset( + data_vars={ + "a": ("x", np.linspace(-1, 1, 3) * data_unit1), + "b": ("x", np.linspace(1, 2, 3) * data_unit2), + }, + coords=coords.get(variant), + ) + + template = { + **{name: None for name in obj.data_vars.keys()}, + **{name: None for name in obj.coords.keys()}, + } + units = { + "data": {"a": unit_registry.ms, "b": unit_registry.g}, + "dims": {"x": unit_registry.ms}, + "coords": {"u": unit_registry.mm}, + }.get(variant) + expected_units = {**template, **units} + + actual = conversion.convert_units(obj, units) + + 
@pytest.mark.parametrize(
    "units",
    (
        pytest.param({None: None, "u": None}, id="no units"),
        pytest.param({None: unit_registry.m, "u": None}, id="data units"),
        pytest.param({None: None, "u": unit_registry.s}, id="coord units"),
        pytest.param(
            {None: unit_registry.m, "u": unit_registry.s}, id="data and coord units"
        ),
    ),
)
@pytest.mark.parametrize("typename", ("Variable", "DataArray", "Dataset"))
def test_extract_units(self, typename, units):
    """Check ``conversion.extract_units`` on each supported xarray type.

    ``units`` maps ``None`` to the unit of the data and ``"u"`` to the unit
    of the coordinate. A missing unit (``None``) is replaced by the plain
    factor ``1`` when building the test object.
    """
    data_unit = units.get(None)
    coord_unit = units.get("u")

    # multiplying by 1 leaves the array without units
    data_factor = data_unit or 1
    coord_factor = coord_unit or 1

    if typename == "Variable":
        obj = Variable("x", np.linspace(0, 1, 2) * data_factor)

        # variables have no coordinates, so "u" is not reported
        expected = {key: unit for key, unit in units.items() if key != "u"}
    elif typename == "DataArray":
        obj = DataArray(
            np.linspace(0, 1, 2) * data_factor,
            dims="x",
            coords={"u": ("x", np.arange(2) * coord_factor)},
        )
        expected = units
    elif typename == "Dataset":
        obj = Dataset(
            {
                "a": ("x", np.linspace(-1, 1, 2) * data_factor),
                "b": ("x", np.linspace(0, 1, 2) * data_factor),
            },
            coords={"u": ("x", np.arange(2) * coord_factor)},
        )

        # the data unit is reported once per data variable instead of under None
        expected = {key: unit for key, unit in units.items() if key is not None}
        expected.update({"a": data_unit, "b": data_unit})

    assert conversion.extract_units(obj) == expected
def attach_units(obj, units):
    """Attach units to an xarray object or to a plain array (test helper).

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset or xr.Variable or array-like
        The object to attach units to.
    units : dict
        Maps variable names to units. The key ``None`` holds the unit used
        for a ``Variable`` or a plain array; for a ``DataArray`` the unit of
        the data is looked up under the array's name. The mapping is not
        modified.

    Returns
    -------
    object
        A new object of the same kind as ``obj``; plain arrays are wrapped in
        a ``Quantity``.

    Raises
    ------
    ValueError
        If ``obj`` already is a ``Quantity``.
    """
    if isinstance(obj, xr.DataArray):
        ds = obj._to_temp_dataset()
        new_name = list(ds.data_vars.keys())[0]
        # register the temporary name on a copy: writing to ``units`` itself
        # would leak the temporary key into the caller's dict
        units = {**units, new_name: units.get(obj.name)}
        new_ds = attach_units(ds, units)
        new_obj = obj._from_temp_dataset(new_ds)
    elif isinstance(obj, xr.Dataset):
        data_vars = {
            name: attach_units(array.variable, {None: units.get(name)})
            for name, array in obj.data_vars.items()
        }

        coords = {
            name: attach_units(array.variable, {None: units.get(name)})
            for name, array in obj.coords.items()
        }

        new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
    elif isinstance(obj, xr.Variable):
        # recurse on the wrapped array; it ends up in one of the branches below
        new_data = attach_units(obj.data, units)
        new_obj = obj.copy(data=new_data)
    elif isinstance(obj, Quantity):
        raise ValueError(
            f"cannot attach {units.get(None)} to {obj}: already a quantity"
        )
    else:
        new_obj = Quantity(obj, units.get(None))

    return new_obj
def assert_units_equal(a, b):
    """Assert that ``extract_units`` returns equal results for ``a`` and ``b``.

    Parameters
    ----------
    a, b : object
        The objects to compare; both are passed to ``extract_units``.

    Raises
    ------
    AssertionError
        If the extracted units differ.
    """
    # hide this helper's frame from pytest tracebacks
    __tracebackhide__ = True

    units_a = extract_units(a)
    units_b = extract_units(b)

    assert units_a == units_b