Skip to content

Commit

Permalink
Support Whisper-PMFA
Browse files: browse the repository at this point in the history.
  • Loading branch information
Aurora1818 committed Aug 30, 2024
1 parent fa0179b commit c6fd891
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
19 changes: 11 additions & 8 deletions wespeaker/frontend/whisper_encoder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,14 +80,15 @@ def forward(
q = self.query(x)

if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or
# hooks, if installed (i.e. kv_cache is not None),
# will prepend the cached kv tensors; otherwise,
# perform key/value projections for self- or
# cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in
# subsequent calls.
# for cross-attention, calculate keys and values once
# and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]

Expand Down Expand Up @@ -192,9 +193,9 @@ def forward(self, x: Tensor):
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)

# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# ----------------------Change:Tailor the positional_embedding----------
assert x.shape[2:] == self.positional_embedding.shape[1:], "incorrect audio shape"
# ------------Change:Tailor the positional_embedding----------
assert x.shape[2:] == self.positional_embedding.shape[1:], \
"incorrect audio shape"
if self.positional_embedding.shape[0] > x.shape[1]:
temp_positional_embedding = self.positional_embedding[:x.shape[1], :]
elif self.positional_embedding.shape[0] < x.shape[1]:
Expand Down Expand Up @@ -266,7 +267,9 @@ def _download_whisper_model(self, model_path='whisper_hub/large-v2.pt'):
os.makedirs(download_dir)
if not os.path.isfile(model_path):
print("Downloading large-v2.pt ...")
url = 'https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt'
url = 'https://openaipublic.azureedge.net/main/whisper/models/' \
'81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/' \
'large-v2.pt'

urllib.request.urlretrieve(url, model_path)

Expand Down
3 changes: 0 additions & 3 deletions wespeaker/models/whisper_PMFA.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn

import wespeaker.models.pooling_layers as pooling_layers
Expand Down

0 comments on commit c6fd891

Please sign in to comment.