diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 8891ac2986b..6162ae829a1 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -13,7 +13,7 @@
 from xarray import backends, conventions
 from xarray.backends import plugins
 from xarray.backends.common import AbstractDataStore, ArrayWriter, _normalize_path
-from xarray.backends.locks import _get_scheduler
+from xarray.backends.locks import _get_scheduler, get_write_lock
 from xarray.core import indexing
 from xarray.core.combine import (
     _infer_concat_order_from_positions,
@@ -1650,7 +1650,11 @@ def to_zarr(
                     "mode='r+'. To allow writing new variables, set mode='a'."
                 )
 
-    writer = ArrayWriter()
+    if any(["storage_transformers" in encoding[var] for var in encoding]):
+        writer = ArrayWriter(lock=get_write_lock("ZARR_SHARDING_LOCK"))
+    else:
+        writer = ArrayWriter()
+
     # TODO: figure out how to properly handle unlimited_dims
     dump_to_store(dataset, zstore, writer, encoding=encoding)
     writes = writer.sync(compute=compute)
diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py
index bba12a29609..ddf0bdf742e 100644
--- a/xarray/backends/locks.py
+++ b/xarray/backends/locks.py
@@ -17,7 +17,7 @@
 # Neither HDF5 nor the netCDF-C library are thread-safe.
 HDF5_LOCK = SerializableLock()
 NETCDFC_LOCK = SerializableLock()
-
+ZARR_SHARDING_LOCK = SerializableLock()
 
 _FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary()
 
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 3b0335aa5a6..42d32892db7 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -237,6 +237,7 @@ def extract_zarr_variable_encoding(
         "filters",
         "cache_metadata",
         "write_empty_chunks",
+        "storage_transformers",
     }
 
     for k in safe_to_drop:
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 91daabd12d5..705a5ef8757 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -109,12 +109,27 @@
     KVStore = None
 
 have_zarr_v3 = False
+have_sharding_v3 = False
 try:
-    # as of Zarr v2.13 these imports require environment variable
+    # as of Zarr v2.14 these imports require environment variable
     # ZARR_V3_EXPERIMENTAL_API=1
     from zarr import DirectoryStoreV3, KVStoreV3
 
     have_zarr_v3 = True
+
+    from zarr._storage.v3_storage_transformers import v3_sharding_available
+
+    # TODO: change to try except ImportError when available at the top-level zarr namespace
+    if v3_sharding_available:
+        # as of Zarr v2.14 these imports require environment variable
+        # ZARR_V3_SHARDING=1
+        # TODO: change import to
+        # from zarr import ShardingStorageTransformer
+        # when ShardingStorageTransformer becomes available at the top-level zarr namespace
+        from zarr._storage.v3_storage_transformers import ShardingStorageTransformer
+
+        have_sharding_v3 = True
+
 except ImportError:
     KVStoreV3 = None
 
@@ -2660,6 +2675,39 @@ def create_zarr_target(self):
             yield tmp
 
 
+@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3")
+class TestZarrStorageTransformersV3(TestZarrDirectoryStoreV3):
+    @pytest.mark.skipif(not have_sharding_v3, reason="requires sharding")
+    def test_sharding_storage_transformer(self):
+        original = create_test_data().chunk({"dim1": 2, "dim2": 3, "dim3": 2})
+
+        encoding = {
+            "var1": {
+                "storage_transformers": [
+                    ShardingStorageTransformer("indexed", chunks_per_shard=(2, 1))
+                ],
+            },
+            "var2": {
+                "storage_transformers": [
+                    ShardingStorageTransformer("indexed", chunks_per_shard=(2, 2))
+                ],
+            },
+            "var3": {
+                "storage_transformers": [
+                    ShardingStorageTransformer("indexed", chunks_per_shard=(1, 1))
+                ],
+            },
+        }
+
+        with self.roundtrip(
+            original, save_kwargs={"encoding": encoding}, open_kwargs={"chunks": {}}
+        ) as ds1:
+            assert_identical(ds1, original)
+
+        with self.roundtrip_append(original, open_kwargs={"chunks": {}}) as ds2:
+            assert_identical(ds2, original)
+
+
 @requires_zarr
 @requires_fsspec
 def test_zarr_storage_options() -> None: