diff --git a/mkl_fft/_numpy_fft.py b/mkl_fft/_numpy_fft.py index 6f1de34..a694097 100644 --- a/mkl_fft/_numpy_fft.py +++ b/mkl_fft/_numpy_fft.py @@ -1039,10 +1039,10 @@ def rfftn(a, s=None, axes=None, norm=None): if unitary: a = asarray(a) s, axes = _cook_nd_args(a, s, axes) - n_tot = numpy.prod([ s[ai] for ai in axes]) output = mkl_fft.rfftn_numpy(a, s, axes) if unitary: + n_tot = prod(asarray(s, dtype=output.dtype)) output *= 1 / sqrt(n_tot) return output