@@ -4137,6 +4137,7 @@ def create_tmp_geotiff(
4137
4137
crs = default_value ,
4138
4138
open_kwargs = None ,
4139
4139
additional_attrs = None ,
4140
+ specific_dtype = None
4140
4141
):
4141
4142
if transform_args is default_value :
4142
4143
transform_args = [5000 , 80000 , 1000 , 2000.0 ]
@@ -4163,7 +4164,11 @@ def create_tmp_geotiff(
4163
4164
else :
4164
4165
data_shape = nz , ny , nx
4165
4166
write_kwargs = {}
4166
- data = np .arange (nz * ny * nx , dtype = rasterio .float32 ).reshape (* data_shape )
4167
+ if specific_dtype is None :
4168
+ specific_dtype = rasterio .float32
4169
+ data = np .arange (nz * ny * nx , dtype = rasterio .float32 ).reshape (* data_shape )
4170
+ else :
4171
+ data = np .arange (nz * ny * nx ,dtype = rasterio .float32 ).reshape (* data_shape )
4167
4172
if transform is None :
4168
4173
transform = from_origin (* transform_args )
4169
4174
if additional_attrs is None :
@@ -4180,7 +4185,7 @@ def create_tmp_geotiff(
4180
4185
count = nz ,
4181
4186
crs = crs ,
4182
4187
transform = transform ,
4183
- dtype = rasterio . float32 ,
4188
+ dtype = specific_dtype ,
4184
4189
** open_kwargs ,
4185
4190
) as s :
4186
4191
for attr , val in additional_attrs .items ():
@@ -4697,6 +4702,14 @@ def test_rasterio_vrt_network(self):
4697
4702
assert actual_res == expected_res
4698
4703
assert expected_val == actual_val
4699
4704
4705
+ def test_rasterio_complex_dtype ( self ):
4706
+ import rasterio
4707
+ with create_tmp_geotiff (specific_dtype = 'complex_int16' ,
4708
+ ) as (tmp_file , _ ):
4709
+ with rasterio .open (tmp_file ) as riobj :
4710
+ assert riobj .dtypes [0 ]== 'complex_int16'
4711
+ with xr .open_rasterio (tmp_file ) as rioda :
4712
+ assert rioda .dtype == complex
4700
4713
4701
4714
class TestEncodingInvalid :
4702
4715
def test_extract_nc4_variable_encoding (self ):
0 commit comments