Skip to content

Commit 8632186

Browse files
Added facade functions to_zarr and from_zarr (#2236)
* Added facade functions `to_zarr` and `from_zarr`
* black
* added to changelog
* update PR with review comments
* fix rebase issues with changelog
* black

---------

Co-authored-by: Oriol (ZBook) <[email protected]>
1 parent 7fb2257 commit 8632186

File tree

5 files changed

+91
-1
lines changed

5 files changed

+91
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
- Add InferenceData<->DataTree conversion functions ([2253](https://github.com/arviz-devs/arviz/pull/2253))
77
- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
88
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))
9+
- Added facade functions `az.to_zarr` and `az.from_zarr` ([2236](https://github.com/arviz-devs/arviz/pull/2236))
910

1011
### Maintenance and fixes
1112
- Replace deprecated np.product with np.prod ([2249](https://github.com/arviz-devs/arviz/pull/2249))
1213
- Fix numba deprecation warning ([2246](https://github.com/arviz-devs/arviz/pull/2246))
1314
- Fixes for creating numpy object array ([2233](https://github.com/arviz-devs/arviz/pull/2233) and [2239](https://github.com/arviz-devs/arviz/pull/2239))
1415
- Adapt histograms generated by plot_dist to input dtype ([2247](https://github.com/arviz-devs/arviz/pull/2247))
1516

16-
1717
### Deprecation
1818

1919
### Documentation

arviz/data/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .io_pyjags import from_pyjags
1616
from .io_pyro import from_pyro
1717
from .io_pystan import from_pystan
18+
from .io_zarr import from_zarr, to_zarr
1819
from .utils import extract, extract_dataset
1920

2021
__all__ = [
@@ -44,6 +45,8 @@
4445
"to_datatree",
4546
"to_json",
4647
"to_netcdf",
48+
"from_zarr",
49+
"to_zarr",
4750
"CoordSpec",
4851
"DimSpec",
4952
]

arviz/data/io_zarr.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Input and output support for zarr data."""
2+
3+
from .converters import convert_to_inference_data
4+
from .inference_data import InferenceData
5+
6+
7+
def from_zarr(store):
8+
return InferenceData.from_zarr(store)
9+
10+
11+
from_zarr.__doc__ = InferenceData.from_zarr.__doc__
12+
13+
14+
def to_zarr(data, store=None, **kwargs):
15+
"""
16+
Convert data to InferenceData and write it to a zarr store.
17+
18+
The zarr storage uses the same group names as the InferenceData.
19+
20+
Parameters
21+
----------
22+
store : zarr.storage, MutableMapping or str, optional
23+
Zarr storage class or path to desired DirectoryStore.
24+
If None (the default), a store is created in a temporary directory.
25+
**kwargs : dict, optional
26+
Passed to :py:func:`convert_to_inference_data`.
27+
28+
Returns
29+
-------
30+
zarr.hierarchy.group
31+
A zarr hierarchy group containing the InferenceData.
32+
33+
Raises
34+
------
35+
TypeError
36+
If no valid store is found.
37+
38+
39+
References
40+
----------
41+
https://zarr.readthedocs.io/
42+
43+
"""
44+
inference_data = convert_to_inference_data(data, **kwargs)
45+
zarr_group = inference_data.to_zarr(store=store)
46+
return zarr_group

arviz/tests/base_tests/test_data_zarr.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from ... import InferenceData, from_dict
11+
from ... import to_zarr, from_zarr
1112

1213
from ..helpers import ( # pylint: disable=unused-import
1314
chains,
@@ -103,3 +104,41 @@ def test_io_method(self, data, eight_schools_params, store, fill_attrs):
103104
assert inference_data2.attrs["test"] == 1
104105
else:
105106
assert "test" not in inference_data2.attrs
107+
108+
def test_io_function(self, data, eight_schools_params):
109+
# create InferenceData and check it has been properly created
110+
inference_data = self.get_inference_data( # pylint: disable=W0612
111+
data,
112+
eight_schools_params,
113+
fill_attrs=True,
114+
)
115+
test_dict = {
116+
"posterior": ["eta", "theta", "mu", "tau"],
117+
"posterior_predictive": ["eta", "theta", "mu", "tau"],
118+
"sample_stats": ["eta", "theta", "mu", "tau"],
119+
"prior": ["eta", "theta", "mu", "tau"],
120+
"prior_predictive": ["eta", "theta", "mu", "tau"],
121+
"sample_stats_prior": ["eta", "theta", "mu", "tau"],
122+
"observed_data": ["J", "y", "sigma"],
123+
}
124+
fails = check_multiple_attrs(test_dict, inference_data)
125+
assert not fails
126+
127+
assert inference_data.attrs["test"] == 1
128+
129+
# write to a fresh path inside a temporary directory using the to_zarr function
130+
with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
131+
filepath = os.path.join(tmp_dir, "zarr")
132+
133+
to_zarr(inference_data, store=filepath)
134+
# assert file has been saved correctly
135+
assert os.path.exists(filepath)
136+
assert os.path.getsize(filepath) > 0
137+
138+
inference_data2 = from_zarr(filepath)
139+
140+
# check that everything in test_dict is still available in inference_data2
141+
fails = check_multiple_attrs(test_dict, inference_data2)
142+
assert not fails
143+
144+
assert inference_data2.attrs["test"] == 1

doc/source/api/data.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ IO / General conversion
3737
to_datatree
3838
to_json
3939
to_netcdf
40+
from_zarr
41+
to_zarr
4042

4143

4244
General functions

0 commit comments

Comments
 (0)