@@ -1087,8 +1087,9 @@ class SpectralCentroid(torch.nn.Module):
1087
1087
win_length (int or None, optional): Window size. (Default: ``n_fft``)
1088
1088
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
1089
1089
pad (int, optional): Two sided padding of signal. (Default: ``0``)
1090
- window(Tensor, optional): A window tensor that is applied/multiplied to each frame.
1091
- (Default: ``torch.hann_window(win_length)``)
1090
+ window_fn (Callable[..., Tensor], optional): A function to create a window tensor
1091
+ that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
1092
+ wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
1092
1093
1093
1094
Example
1094
1095
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
@@ -1102,14 +1103,14 @@ def __init__(self,
1102
1103
win_length : Optional [int ] = None ,
1103
1104
hop_length : Optional [int ] = None ,
1104
1105
pad : int = 0 ,
1105
- window : Optional [Tensor ] = None ) -> None :
1106
+ window_fn : Callable [..., Tensor ] = torch .hann_window ,
1107
+ wkwargs : Optional [dict ] = None ) -> None :
1106
1108
super (SpectralCentroid , self ).__init__ ()
1107
1109
self .sample_rate = sample_rate
1108
1110
self .n_fft = n_fft
1109
1111
self .win_length = win_length if win_length is not None else n_fft
1110
1112
self .hop_length = hop_length if hop_length is not None else self .win_length // 2
1111
- if window is None :
1112
- window = torch .hann_window (self .win_length )
1113
+ window = window_fn (self .win_length ) if wkwargs is None else window_fn (self .win_length , ** wkwargs )
1113
1114
self .register_buffer ('window' , window )
1114
1115
self .pad = pad
1115
1116
0 commit comments