Skip to content

Commit 95a1edd

Browse files
committed
Add band keyword argument to read_raster
It simply delegates to `read_raster_band`.
1 parent b3037d4 commit 95a1edd

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ parallel using [Rasterio](https://github.com/mapbox/rasterio) and
1010

1111
## Usage
1212

13-
Read a multiband raster with `read_raster`:
13+
#### Read a multiband raster
1414

1515
```python
1616
>>> from dask_rasterio import read_raster
@@ -25,11 +25,21 @@ dask.array<mean_agg-aggregate, shape=(), dtype=float64, chunksize=()>
2525
40.858976977533935
2626
```
2727

28-
Write a singleband or multiband raster with `write_raster`:
28+
#### Read a single band from a raster
2929

3030
```python
3131
>>> from dask_rasterio import read_raster
3232

33+
>>> array = read_raster('tests/data/RGB.byte.tif', band=3)
34+
>>> array
35+
dask.array<raster, shape=(718, 791), dtype=uint8, chunksize=(3, 791)>
36+
```
37+
38+
#### Write a singleband or multiband raster
39+
40+
```python
41+
>>> from dask_rasterio import read_raster, write_raster
42+
3343
>>> array = read_raster('tests/data/RGB.byte.tif')
3444

3545
>>> new_array = array & (array > 100)

dask_rasterio/read.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,43 @@
55
from rasterio.windows import Window
66

77

8-
def read_raster(path, block_size=1):
9-
"""Read all bands from raster"""
10-
bands = range(1, get_band_count(path) + 1)
11-
return da.stack([
12-
read_raster_band(path, band=band, block_size=block_size)
13-
for band in bands
14-
])
8+
def read_raster(path, band=None, block_size=1):
9+
"""Read all or some bands from raster
10+
11+
Arguments:
12+
path {string} -- path to raster file
13+
14+
Keyword Arguments:
15+
band {int, iterable(int)} -- band number or iterable of bands.
16+
When passing None, it reads all bands (default: {None})
17+
block_size {int} -- block size multiplier (default: {1})
18+
19+
Returns:
20+
dask.array.Array -- a Dask array
21+
"""
22+
23+
if isinstance(band, int):
24+
return read_raster_band(path, band=band, block_size=block_size)
25+
else:
26+
if band is None:
27+
bands = range(1, get_band_count(path) + 1)
28+
else:
29+
bands = list(band)
30+
return da.stack([
31+
read_raster_band(path, band=band, block_size=block_size)
32+
for band in bands
33+
])
1534

1635

1736
def read_raster_band(path, band=1, block_size=1):
1837
"""Read a raster band and return a Dask array
1938
2039
Arguments:
21-
path {string} -- Path to the raster file
40+
path {string} -- path to the raster file
2241
2342
Keyword Arguments:
24-
band {int} -- Number of band to read (default: {1})
25-
block_size {int} -- Multiplier for block size (default: {1})
43+
band {int} -- number of band to read (default: {1})
44+
block_size {int} -- block size multiplier (default: {1})
2645
2746
"""
2847

tests/test_dask_rasterio.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,29 @@ def test_read_raster_band(some_raster_path):
5656
assert_array_equal(g_array.compute(), expected_array)
5757

5858

59+
def test_read_raster_single_band(some_raster_path):
60+
array = read_raster(some_raster_path, band=3)
61+
assert isinstance(array, da.Array)
62+
63+
expected_array = read_raster_band(some_raster_path, band=3)
64+
assert array.shape == expected_array.shape
65+
assert array.dtype == expected_array.dtype
66+
assert_array_equal(array.compute(), expected_array.compute())
67+
68+
69+
def test_read_raster_multi_band(some_raster_path):
70+
array = read_raster(some_raster_path, band=(1, 3))
71+
assert isinstance(array, da.Array)
72+
73+
expected_array = da.stack([
74+
read_raster_band(some_raster_path, band=1),
75+
read_raster_band(some_raster_path, band=3)
76+
])
77+
assert array.shape == expected_array.shape
78+
assert array.dtype == expected_array.dtype
79+
assert_array_equal(array.compute(), expected_array.compute())
80+
81+
5982
def test_do_calcs_on_array(some_raster_path):
6083
r_array = read_raster_band(some_raster_path, 1)
6184
mean = np.mean(r_array)

0 commit comments

Comments
 (0)