Skip to content

Commit 4915524

Browse files
authored
Add MVDR module to example (#1708)
- Support three solutions for MVDR beamforming ("ref_channel", "stv_evd", "stv_power"). - Support single-channel and multi-channel time-frequency masks - Add unit tests
1 parent 38528cf commit 4915524

12 files changed

+779
-0
lines changed

examples/beamforming/mvdr.py

Lines changed: 445 additions & 0 deletions
Large diffs are not rendered by default.

test/torchaudio_unittest/example/beamforming/__init__.py

Whitespace-only changes.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torchaudio_unittest.common_utils import PytorchTestCase
2+
from .autograd_test_impl import AutogradTestMixin
3+
4+
5+
class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
6+
device = 'cpu'
7+
8+
9+
class AutogradRNNTCPUTest(PytorchTestCase):
10+
device = 'cpu'
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from torchaudio_unittest.common_utils import (
2+
PytorchTestCase,
3+
skipIfNoCuda,
4+
)
5+
from .autograd_test_impl import AutogradTestMixin
6+
7+
8+
@skipIfNoCuda
9+
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
10+
device = 'cuda'
11+
12+
13+
@skipIfNoCuda
14+
class AutogradRNNTCUDATest(PytorchTestCase):
15+
device = 'cuda'
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import List
2+
3+
import torch
4+
from beamforming.mvdr import PSD, MVDR
5+
from parameterized import parameterized, param
6+
from torch.autograd import gradcheck, gradgradcheck
7+
8+
from torchaudio_unittest.common_utils import (
9+
TestBaseMixin,
10+
get_whitenoise,
11+
get_spectrogram,
12+
)
13+
14+
15+
class AutogradTestMixin(TestBaseMixin):
16+
def assert_grad(
17+
self,
18+
transform: torch.nn.Module,
19+
inputs: List[torch.Tensor],
20+
*,
21+
nondet_tol: float = 0.0,
22+
):
23+
transform = transform.to(dtype=torch.float64, device=self.device)
24+
25+
# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
26+
# `torch.cdouble`, when the default eps and tolerance values are used.
27+
inputs_ = []
28+
for i in inputs:
29+
if torch.is_tensor(i):
30+
i = i.to(
31+
dtype=torch.cdouble if i.is_complex() else torch.double,
32+
device=self.device)
33+
i.requires_grad = True
34+
inputs_.append(i)
35+
assert gradcheck(transform, inputs_)
36+
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
37+
38+
def test_psd(self):
39+
transform = PSD()
40+
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
41+
spectrogram = get_spectrogram(waveform, n_fft=400)
42+
self.assert_grad(transform, [spectrogram])
43+
44+
@parameterized.expand([
45+
[True],
46+
[False],
47+
])
48+
def test_psd_with_mask(self, multi_mask):
49+
transform = PSD(multi_mask=multi_mask)
50+
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
51+
spectrogram = get_spectrogram(waveform, n_fft=400)
52+
if multi_mask:
53+
mask = torch.rand(spectrogram.shape[-3:])
54+
else:
55+
mask = torch.rand(spectrogram.shape[-2:])
56+
57+
self.assert_grad(transform, [spectrogram, mask])
58+
59+
@parameterized.expand([
60+
param(solution="ref_channel"),
61+
param(solution="stv_power"),
62+
# evd will fail since the eigenvalues are not distinct
63+
# param(solution="stv_evd"),
64+
])
65+
def test_mvdr(self, solution):
66+
transform = MVDR(solution=solution)
67+
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
68+
spectrogram = get_spectrogram(waveform, n_fft=400)
69+
mask = torch.rand(spectrogram.shape[-2:])
70+
self.assert_grad(transform, [spectrogram, mask])
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Test numerical consistency among single input and batched input."""
2+
import torch
3+
from beamforming.mvdr import PSD, MVDR
4+
from parameterized import parameterized
5+
6+
from torchaudio_unittest import common_utils
7+
8+
9+
class TestTransforms(common_utils.TorchaudioTestCase):
10+
def test_batch_PSD(self):
11+
spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)
12+
13+
# Single then transform then batch
14+
expected = []
15+
for i in range(4):
16+
expected.append(PSD()(spec[i]))
17+
expected = torch.stack(expected)
18+
19+
# Batch then transform
20+
computed = PSD()(spec)
21+
22+
self.assertEqual(computed, expected)
23+
24+
def test_batch_PSD_with_mask(self):
25+
spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)
26+
mask = torch.rand((4, 201, 100))
27+
28+
# Single then transform then batch
29+
expected = []
30+
for i in range(4):
31+
expected.append(PSD()(spec[i], mask[i]))
32+
expected = torch.stack(expected)
33+
34+
# Batch then transform
35+
computed = PSD()(spec, mask)
36+
37+
self.assertEqual(computed, expected)
38+
39+
@parameterized.expand([
40+
[True],
41+
[False],
42+
])
43+
def test_MVDR(self, multi_mask):
44+
spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)
45+
if multi_mask:
46+
mask = torch.rand((4, 6, 201, 100))
47+
else:
48+
mask = torch.rand((4, 201, 100))
49+
50+
# Single then transform then batch
51+
expected = []
52+
for i in range(4):
53+
expected.append(MVDR(multi_mask=multi_mask)(spec[i], mask[i]))
54+
expected = torch.stack(expected)
55+
56+
# Batch then transform
57+
computed = MVDR(multi_mask=multi_mask)(spec, mask)
58+
59+
self.assertEqual(computed, expected)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
from torchaudio_unittest.common_utils import PytorchTestCase
4+
from .torchscript_consistency_impl import Transforms, TransformsFloat64Only
5+
6+
7+
class TestTransformsFloat32(Transforms, PytorchTestCase):
8+
dtype = torch.float32
9+
device = torch.device('cpu')
10+
11+
12+
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
13+
dtype = torch.float64
14+
device = torch.device('cpu')
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
3+
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
4+
from .torchscript_consistency_impl import Transforms, TransformsFloat64Only
5+
6+
7+
@skipIfNoCuda
8+
class TestTransformsFloat32(Transforms, PytorchTestCase):
9+
dtype = torch.float32
10+
device = torch.device('cuda')
11+
12+
13+
@skipIfNoCuda
14+
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
15+
dtype = torch.float64
16+
device = torch.device('cuda')
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Test suites for jit-ability and its numerical compatibility"""
2+
3+
import torch
4+
from beamforming.mvdr import PSD, MVDR
5+
from parameterized import parameterized, param
6+
7+
from torchaudio_unittest import common_utils
8+
from torchaudio_unittest.common_utils import (
9+
TempDirMixin,
10+
TestBaseMixin,
11+
)
12+
13+
14+
class Transforms(TempDirMixin, TestBaseMixin):
15+
"""Implements test for Transforms that are performed for different devices"""
16+
def _assert_consistency_complex(self, transform, tensors):
17+
assert tensors[0].is_complex()
18+
tensors = [tensor.to(device=self.device, dtype=self.complex_dtype) for tensor in tensors]
19+
transform = transform.to(device=self.device, dtype=self.dtype)
20+
21+
path = self.get_temp_path('func.zip')
22+
torch.jit.script(transform).save(path)
23+
ts_transform = torch.jit.load(path)
24+
25+
output = transform(*tensors)
26+
ts_output = ts_transform(*tensors)
27+
self.assertEqual(ts_output, output)
28+
29+
def test_PSD(self):
30+
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
31+
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
32+
self._assert_consistency_complex(PSD(), (spectrogram,))
33+
34+
def test_PSD_with_mask(self):
35+
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
36+
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
37+
mask = torch.rand(spectrogram.shape[-2:])
38+
self._assert_consistency_complex(PSD(), (spectrogram, mask))
39+
40+
41+
class TransformsFloat64Only(TestBaseMixin):
42+
@parameterized.expand([
43+
param(solution="ref_channel", online=True),
44+
param(solution="stv_evd", online=True),
45+
param(solution="stv_power", online=True),
46+
param(solution="ref_channel", online=False),
47+
param(solution="stv_evd", online=False),
48+
param(solution="stv_power", online=False),
49+
])
50+
def test_MVDR(self, solution, online):
51+
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
52+
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
53+
mask = torch.rand(spectrogram.shape[-2:])
54+
self._assert_consistency_complex(
55+
MVDR(solution=solution, online=online),
56+
(spectrogram, mask)
57+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
from torchaudio_unittest.common_utils import PytorchTestCase
4+
from . transforms_test_impl import TransformsTestBase
5+
6+
7+
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
8+
device = 'cpu'
9+
dtype = torch.float32
10+
11+
12+
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
13+
device = 'cpu'
14+
dtype = torch.float64

0 commit comments

Comments
 (0)