Skip to content

Commit 674d9b6

Browse files
committed
add backend test
1 parent 6ca31cb commit 674d9b6

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

xarray/backends/rasterio_.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def __init__(self, manager, lock, vrt_params=None):
3939
dtypes = riods.dtypes
4040
if not np.all(np.asarray(dtypes) == dtypes[0]):
4141
raise ValueError("All bands should have the same dtype")
42-
if dtypes[0] == "complex_int16":
43-
self._dtype = np.complex
44-
else:
42+
if dtypes[0] != 'complex_int16' :
4543
self._dtype = np.dtype(dtypes[0])
44+
else :
45+
self._dtype = complex
4646

4747
@property
4848
def dtype(self):

xarray/tests/test_backends.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4137,6 +4137,7 @@ def create_tmp_geotiff(
41374137
crs=default_value,
41384138
open_kwargs=None,
41394139
additional_attrs=None,
4140+
specific_dtype=None
41404141
):
41414142
if transform_args is default_value:
41424143
transform_args = [5000, 80000, 1000, 2000.0]
@@ -4163,7 +4164,11 @@ def create_tmp_geotiff(
41634164
else:
41644165
data_shape = nz, ny, nx
41654166
write_kwargs = {}
4166-
data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape)
4167+
if specific_dtype is None:
4168+
specific_dtype = rasterio.float32
4169+
data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape)
4170+
else:
4171+
data = np.arange(nz * ny * nx,dtype=rasterio.float32).reshape(*data_shape)
41674172
if transform is None:
41684173
transform = from_origin(*transform_args)
41694174
if additional_attrs is None:
@@ -4180,7 +4185,7 @@ def create_tmp_geotiff(
41804185
count=nz,
41814186
crs=crs,
41824187
transform=transform,
4183-
dtype=rasterio.float32,
4188+
dtype=specific_dtype,
41844189
**open_kwargs,
41854190
) as s:
41864191
for attr, val in additional_attrs.items():
@@ -4697,6 +4702,14 @@ def test_rasterio_vrt_network(self):
46974702
assert actual_res == expected_res
46984703
assert expected_val == actual_val
46994704

4705+
def test_rasterio_complex_dtype( self ):
4706+
import rasterio
4707+
with create_tmp_geotiff(specific_dtype='complex_int16',
4708+
) as (tmp_file, _):
4709+
with rasterio.open(tmp_file) as riobj:
4710+
assert riobj.dtypes[0]=='complex_int16'
4711+
with xr.open_rasterio(tmp_file) as rioda:
4712+
assert rioda.dtype==complex
47004713

47014714
class TestEncodingInvalid:
47024715
def test_extract_nc4_variable_encoding(self):

0 commit comments

Comments
 (0)