|
| 1 | +import warnings |
1 | 2 | from collections import Counter
|
2 | 3 | from fractions import Fraction
|
3 | 4 | from functools import reduce
|
4 | 5 | from itertools import chain, count, islice, repeat
|
5 |
| -from typing import Union, Callable, List, Optional |
6 |
| -from torch.nn.functional import pad |
| 6 | +from math import log2 |
| 7 | +from typing import Callable, List, Optional, Union |
| 8 | + |
7 | 9 | import torch
|
| 10 | +import torchaudio |
8 | 11 | import torchaudio.transforms as T
|
| 12 | +from packaging import version |
9 | 13 | from primePy import primes
|
10 |
| -from math import log2 |
11 |
| -import warnings |
| 14 | +from torch.nn.functional import pad |
12 | 15 |
|
13 | 16 | warnings.simplefilter("ignore")
|
14 | 17 |
|
@@ -149,7 +152,8 @@ def pitch_shift(
|
149 | 152 | resampler = T.Resample(sample_rate, int(sample_rate / shift)).to(input.device)
|
150 | 153 | output = input
|
151 | 154 | output = output.reshape(batch_size * channels, samples)
|
152 |
| - output = torch.stft(output, n_fft, hop_length, return_complex=True)[None, ...] |
| 155 | + v011 = version.parse(torchaudio.__version__) >= version.parse("0.11.0") |
| 156 | + output = torch.stft(output, n_fft, hop_length, return_complex=v011)[None, ...] |
153 | 157 | stretcher = T.TimeStretch(
|
154 | 158 | fixed_rate=float(1 / shift), n_freq=output.shape[2], hop_length=hop_length
|
155 | 159 | ).to(input.device)
|
|
0 commit comments