Skip to content

Commit f0b9a8d

Browse files
committed
Add dask support for timedelta encoding and more tests
1 parent e5150c9 commit f0b9a8d

File tree

3 files changed

+214
-72
lines changed

3 files changed

+214
-72
lines changed

xarray/coding/times.py

Lines changed: 89 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -674,31 +674,54 @@ def encode_cf_datetime(
674674
calendar: str | None = None,
675675
dtype: np.dtype | None = None,
676676
):
677+
"""Given an array of datetime objects, returns the tuple `(num, units,
678+
calendar)` suitable for a CF compliant time variable.
679+
680+
Unlike `date2num`, this function can handle datetime64 arrays.
681+
682+
See Also
683+
--------
684+
cftime.date2num
685+
"""
677686
dates = asarray(dates)
678687
if isinstance(dates, np.ndarray):
679688
return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)
680689
elif is_duck_dask_array(dates):
681690
return _lazily_encode_cf_datetime(dates, units, calendar, dtype)
682691

683692

684-
def _cast_to_dtype_safe(num, dtype) -> np.ndarray:
685-
cast_num = np.asarray(num, dtype=dtype)
693+
def _cast_to_dtype_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
694+
with warnings.catch_warnings():
695+
warnings.filterwarnings("ignore", message="overflow")
696+
cast_num = np.asarray(num, dtype=dtype)
686697

687698
if np.issubdtype(dtype, np.integer):
688699
if not (num == cast_num).all():
689-
raise ValueError(
690-
f"Not possible to cast all encoded times from dtype {num.dtype!r} "
691-
f"to dtype {dtype!r} without changing any of their values. "
692-
f"Consider removing the dtype encoding or explicitly switching to "
693-
f"a dtype encoding with a higher precision."
694-
)
700+
if np.issubdtype(num.dtype, np.floating):
701+
raise ValueError(
702+
f"Not possible to cast all encoded times from dtype "
703+
f"{num.dtype!r} to integer dtype {dtype!r} without losing "
704+
f"precision. Consider modifying the units such that "
705+
f"integer values can be used, or removing the units and "
706+
f"dtype encoding, at which point xarray will make an "
707+
f"appropriate choice."
708+
)
709+
else:
710+
raise OverflowError(
711+
f"Not possible to cast encoded times from dtype "
712+
f"{num.dtype!r} to dtype {dtype!r} without overflow. "
713+
f"Consider removing the dtype encoding, at which point "
714+
f"xarray will make an appropriate choice, or explicitly "
715+
f"switching to a larger integer dtype."
716+
)
695717
else:
696718
if np.isinf(cast_num).any():
697719
raise OverflowError(
698720
f"Not possible to cast encoded times from dtype {num.dtype!r} "
699721
f"to dtype {dtype!r} without overflow. Consider removing the "
700-
f"dtype encoding or explicitly switching to a dtype encoding "
701-
f"with a higher precision."
722+
f"dtype encoding, at which point xarray will make an "
723+
f"appropriate choice, or explicitly switching to a larger "
724+
f"floating point dtype."
702725
)
703726

704727
return cast_num
@@ -711,15 +734,6 @@ def _eagerly_encode_cf_datetime(
711734
dtype: np.dtype | None = None,
712735
called_via_map_blocks: bool = False,
713736
) -> tuple[np.ndarray, str, str]:
714-
"""Given an array of datetime objects, returns the tuple `(num, units,
715-
calendar)` suitable for a CF compliant time variable.
716-
717-
Unlike `date2num`, this function can handle datetime64 arrays.
718-
719-
See Also
720-
--------
721-
cftime.date2num
722-
"""
723737
dates = np.asarray(dates)
724738

725739
data_units = infer_datetime_units(dates)
@@ -796,7 +810,7 @@ def _eagerly_encode_cf_datetime(
796810
if called_via_map_blocks:
797811
return num
798812
else:
799-
return (num, units, calendar)
813+
return num, units, calendar
800814

801815

802816
def _lazily_encode_cf_datetime(
@@ -821,10 +835,10 @@ def _lazily_encode_cf_datetime(
821835

822836
if units is None or dtype is None:
823837
raise ValueError(
824-
f"When encoding chunked arrays of datetime values, both the units and "
825-
f"dtype must be prescribed or both must be unprescribed. Prescribing "
826-
f"only one or the other is not currently supported. Got a units "
827-
f"encoding of {units} and a dtype encoding of {dtype}."
838+
f"When encoding chunked arrays of datetime values, both the units "
839+
f"and dtype must be prescribed or both must be unprescribed. "
840+
f"Prescribing only one or the other is not currently supported. "
841+
f"Got a units encoding of {units} and a dtype encoding of {dtype}."
828842
)
829843

830844
num = dask.array.map_blocks(
@@ -841,6 +855,19 @@ def _lazily_encode_cf_datetime(
841855

842856
def encode_cf_timedelta(
843857
timedeltas, units: str | None = None, dtype: np.dtype | None = None
858+
):
859+
timedeltas = asarray(timedeltas)
860+
if is_duck_dask_array(timedeltas):
861+
return _lazily_encode_cf_timedelta(timedeltas, units, dtype)
862+
else:
863+
return _eagerly_encode_cf_timedelta(timedeltas, units, dtype)
864+
865+
866+
def _eagerly_encode_cf_timedelta(
867+
timedeltas,
868+
units: str | None = None,
869+
dtype: np.dtype | None = None,
870+
called_via_map_blocks: bool = False,
844871
) -> tuple[np.ndarray, str]:
845872
data_units = infer_timedelta_units(timedeltas)
846873

@@ -868,7 +895,7 @@ def encode_cf_timedelta(
868895
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
869896
f"Set encoding['dtype'] to floating point dtype to silence this warning."
870897
)
871-
elif np.issubdtype(dtype, np.integer):
898+
elif np.issubdtype(dtype, np.integer) and not called_via_map_blocks:
872899
emit_user_level_warning(
873900
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
874901
f"Serializing with units {needed_units!r} instead. "
@@ -881,7 +908,43 @@ def encode_cf_timedelta(
881908

882909
num = _division(time_deltas, time_delta, floor_division)
883910
num = num.values.reshape(timedeltas.shape)
884-
return (num, units)
911+
912+
if dtype is not None:
913+
num = _cast_to_dtype_safe(num, dtype)
914+
915+
if called_via_map_blocks:
916+
return num
917+
else:
918+
return num, units
919+
920+
921+
def _lazily_encode_cf_timedelta(
922+
timedeltas, units: str | None = None, dtype: np.dtype | None = None
923+
):
924+
import dask.array
925+
926+
if units is None and dtype is None:
927+
units = "nanoseconds"
928+
dtype = np.dtype("int64")
929+
930+
if units is None or dtype is None:
931+
raise ValueError(
932+
f"When encoding chunked arrays of timedelta values, both the "
933+
f"units and dtype must be prescribed or both must be "
934+
f"unprescribed. Prescribing only one or the other is not "
935+
f"currently supported. Got a units encoding of {units} and a "
936+
f"dtype encoding of {dtype}."
937+
)
938+
939+
num = dask.array.map_blocks(
940+
_eagerly_encode_cf_timedelta,
941+
timedeltas,
942+
units,
943+
dtype,
944+
called_via_map_blocks=True,
945+
dtype=dtype,
946+
)
947+
return num, units
885948

886949

887950
class CFDatetimeCoder(VariableCoder):

xarray/tests/test_backends.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2818,6 +2818,15 @@ def test_chunked_datetime64(self) -> None:
28182818
assert original[name].chunks == actual_var.chunks
28192819
assert original.chunks == actual.chunks
28202820

2821+
@requires_dask
2822+
def test_chunked_timedelta64(self) -> None:
2823+
# Based @malmans2's datetime64[ns] test in PR #8253
2824+
original = create_test_data().astype("timedelta64[ns]").chunk(1)
2825+
with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
2826+
for name, actual_var in actual.variables.items():
2827+
assert original[name].chunks == actual_var.chunks
2828+
assert original.chunks == actual.chunks
2829+
28212830
def test_vectorized_indexing_negative_step(self) -> None:
28222831
if not has_dask:
28232832
pytest.xfail(

0 commit comments

Comments
 (0)