Skip to content

added 'storage_transformers' to valid_encodings #7540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from xarray import backends, conventions
from xarray.backends import plugins
from xarray.backends.common import AbstractDataStore, ArrayWriter, _normalize_path
from xarray.backends.locks import _get_scheduler
from xarray.backends.locks import _get_scheduler, get_write_lock
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -1650,7 +1650,11 @@ def to_zarr(
"mode='r+'. To allow writing new variables, set mode='a'."
)

writer = ArrayWriter()
if any(["storage_transformers" in encoding[var] for var in encoding]):
writer = ArrayWriter(lock=get_write_lock("ZARR_SHARDING_LOCK"))
else:
writer = ArrayWriter()

# TODO: figure out how to properly handle unlimited_dims
dump_to_store(dataset, zstore, writer, encoding=encoding)
writes = writer.sync(compute=compute)
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Neither HDF5 nor the netCDF-C library are thread-safe.
HDF5_LOCK = SerializableLock()
NETCDFC_LOCK = SerializableLock()

ZARR_SHARDING_LOCK = SerializableLock()

_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary()

Expand Down
1 change: 1 addition & 0 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def extract_zarr_variable_encoding(
"filters",
"cache_metadata",
"write_empty_chunks",
"storage_transformers",
}

for k in safe_to_drop:
Expand Down
50 changes: 49 additions & 1 deletion xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,27 @@
KVStore = None

have_zarr_v3 = False
have_sharding_v3 = False
try:
    # as of Zarr v2.14 these imports require environment variable
    # ZARR_V3_EXPERIMENTAL_API=1
    from zarr import DirectoryStoreV3, KVStoreV3

    have_zarr_v3 = True
except ImportError:
    KVStoreV3 = None

if have_zarr_v3:
    # Probe for sharding support separately so that a failure here cannot
    # clobber KVStoreV3 (the old single try/except reset KVStoreV3 to None
    # even when the base v3 imports above had already succeeded).
    try:
        # TODO: change to a plain `try: from zarr import ...` probe once these
        # names are available at the top-level zarr namespace.
        from zarr._storage.v3_storage_transformers import v3_sharding_available

        if v3_sharding_available:
            # as of Zarr v2.14 these imports require environment variable
            # ZARR_V3_SHARDING=1
            from zarr._storage.v3_storage_transformers import (
                ShardingStorageTransformer,
            )

            have_sharding_v3 = True
    except ImportError:
        pass

Expand Down Expand Up @@ -2660,6 +2675,39 @@ def create_zarr_target(self):
yield tmp


@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3")
class TestZarrStorageTransformersV3(TestZarrDirectoryStoreV3):
    @pytest.mark.skipif(not have_sharding_v3, reason="requires sharding")
    def test_sharding_storage_transformer(self):
        """Round-trip a chunked dataset whose variables each carry a
        sharding storage transformer in their encoding."""
        original = create_test_data().chunk({"dim1": 2, "dim2": 3, "dim3": 2})

        # One "indexed" sharding transformer per variable, each with a
        # different chunks-per-shard layout.
        shard_layouts = {
            "var1": (2, 1),
            "var2": (2, 2),
            "var3": (1, 1),
        }
        encoding = {
            name: {
                "storage_transformers": [
                    ShardingStorageTransformer("indexed", chunks_per_shard=layout)
                ]
            }
            for name, layout in shard_layouts.items()
        }

        with self.roundtrip(
            original, save_kwargs={"encoding": encoding}, open_kwargs={"chunks": {}}
        ) as roundtripped:
            assert_identical(roundtripped, original)

        with self.roundtrip_append(original, open_kwargs={"chunks": {}}) as appended:
            assert_identical(appended, original)


@requires_zarr
@requires_fsspec
def test_zarr_storage_options() -> None:
Expand Down