diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 49a5a9ec7ae..3da40688a47 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -39,7 +39,10 @@ def __init__(self, manager, lock, vrt_params=None): dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError("All bands should have the same dtype") - self._dtype = np.dtype(dtypes[0]) + if dtypes[0] != 'complex_int16' : + self._dtype = np.dtype(dtypes[0]) + else : + self._dtype = complex @property def dtype(self): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5079cd390f1..675026eff75 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4137,6 +4137,7 @@ def create_tmp_geotiff( crs=default_value, open_kwargs=None, additional_attrs=None, + specific_dtype=None ): if transform_args is default_value: transform_args = [5000, 80000, 1000, 2000.0] @@ -4163,7 +4164,11 @@ def create_tmp_geotiff( else: data_shape = nz, ny, nx write_kwargs = {} - data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape) + if specific_dtype is None: + specific_dtype = rasterio.float32 + data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape) + else: + data = np.arange(nz * ny * nx,dtype=rasterio.float32).reshape(*data_shape) if transform is None: transform = from_origin(*transform_args) if additional_attrs is None: @@ -4180,7 +4185,7 @@ def create_tmp_geotiff( count=nz, crs=crs, transform=transform, - dtype=rasterio.float32, + dtype=specific_dtype, **open_kwargs, ) as s: for attr, val in additional_attrs.items(): @@ -4697,6 +4702,14 @@ def test_rasterio_vrt_network(self): assert actual_res == expected_res assert expected_val == actual_val + def test_rasterio_complex_dtype( self ): + import rasterio + with create_tmp_geotiff(specific_dtype='complex_int16', + ) as (tmp_file, _): + with rasterio.open(tmp_file) as riobj: + assert riobj.dtypes[0]=='complex_int16' + with xr.open_rasterio(tmp_file) as rioda: + assert rioda.dtype==complex class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self):