Skip to content

Commit 66873e8

Browse files
authored
backwards compat for torchaudio<=0.11
1 parent 1e534ec commit 66873e8

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ __pycache__
33
shifted*.wav
44
*.egg-info
55
build
6-
dist
6+
dist
7+
.venv

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch>=1.7.0
22
torchaudio>=0.7.0
3-
primePy>=1.3
3+
primePy>=1.3
4+
packaging>=21.3

torch_pitch_shift/main.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
import warnings
12
from collections import Counter
23
from fractions import Fraction
34
from functools import reduce
45
from itertools import chain, count, islice, repeat
5-
from typing import Union, Callable, List, Optional
6-
from torch.nn.functional import pad
6+
from math import log2
7+
from typing import Callable, List, Optional, Union
8+
79
import torch
10+
import torchaudio
811
import torchaudio.transforms as T
12+
from packaging import version
913
from primePy import primes
10-
from math import log2
11-
import warnings
14+
from torch.nn.functional import pad
1215

1316
warnings.simplefilter("ignore")
1417

@@ -149,7 +152,8 @@ def pitch_shift(
149152
resampler = T.Resample(sample_rate, int(sample_rate / shift)).to(input.device)
150153
output = input
151154
output = output.reshape(batch_size * channels, samples)
152-
output = torch.stft(output, n_fft, hop_length, return_complex=True)[None, ...]
155+
v011 = version.parse(torchaudio.__version__) >= version.parse("0.11.0")
156+
output = torch.stft(output, n_fft, hop_length, return_complex=v011)[None, ...]
153157
stretcher = T.TimeStretch(
154158
fixed_rate=float(1 / shift), n_freq=output.shape[2], hop_length=hop_length
155159
).to(input.device)

0 commit comments

Comments
 (0)