Commit: Change the utils.download_emo_models (#199)
* Change the utils.download_emo_models: change
  utils.download_emo_models(config.mirror, model_name, REPO_ID) to
  utils.download_emo_models(config.mirror, REPO_ID, model_name).
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2c528ce · commit badb125
Showing 1 changed file with 162 additions and 161 deletions.
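The argument order matters because, after this commit, utils.download_emo_models takes the Hugging Face repo id before the local target directory. As a hedged sketch only (the real helper lives in the project's utils module and its body may differ), it plausibly wraps huggingface_hub along these lines, assuming config.mirror is an optional endpoint override:

```python
# Hypothetical sketch of utils.download_emo_models; NOT the repository's
# actual implementation. Assumes `mirror` is an optional Hugging Face
# endpoint override and that huggingface_hub is installed.
import os

from huggingface_hub import snapshot_download


def download_emo_models(mirror: str, repo_id: str, model_name: str) -> None:
    if mirror:
        os.environ["HF_ENDPOINT"] = mirror  # assumed mirror convention
    # Fetch the `repo_id` snapshot into the local `model_name` directory.
    snapshot_download(repo_id=repo_id, local_dir=model_name)
```

Under a signature like that, the old call handed the local path ./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim to the Hub as a repo id; swapping the two arguments makes the call site match.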
emo_gen.py (diff hunk @@ -1,162 +1,163 @@). The file as it reads after this commit, with the functional change marked inline:
```python
import argparse
import os
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

import utils
from config import config


class RegressionHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        return x


class EmotionModel(Wav2Vec2PreTrainedModel):
    r"""Speech emotion classifier."""

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        # Mean-pool the frame-level hidden states into one utterance-level vector.
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits = self.classifier(hidden_states)

        return hidden_states, logits


class AudioDataset(Dataset):
    def __init__(self, list_of_wav_files, sr, processor):
        self.list_of_wav_files = list_of_wav_files
        self.processor = processor
        self.sr = sr

    def __len__(self):
        return len(self.list_of_wav_files)

    def __getitem__(self, idx):
        wav_file = self.list_of_wav_files[idx]
        audio_data, _ = librosa.load(wav_file, sr=self.sr)
        processed_data = self.processor(audio_data, sampling_rate=self.sr)[
            "input_values"
        ][0]
        return torch.from_numpy(processed_data)


def process_func(
    x: np.ndarray,
    sampling_rate: int,
    model: EmotionModel,
    processor: Wav2Vec2Processor,
    device: str,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict emotions or extract embeddings from raw audio signal."""
    model = model.to(device)
    y = processor(x, sampling_rate=sampling_rate)
    y = y["input_values"][0]
    y = torch.from_numpy(y).unsqueeze(0).to(device)

    # Run through the model: index 0 is the pooled hidden state (embedding),
    # index 1 is the emotion logits.
    with torch.no_grad():
        y = model(y)[0 if embeddings else 1]

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y


def get_emo(path):
    # sr must be passed as a keyword argument; recent librosa versions make
    # every argument after the path keyword-only.
    wav, sr = librosa.load(path, sr=16000)
    device = config.bert_gen_config.device
    # `model` and `processor` are the module-level globals created in __main__.
    return process_func(
        np.expand_dims(wav, 0).astype(np.float64),
        sr,
        model,
        processor,
        device,
        embeddings=True,
    ).squeeze(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default=config.bert_gen_config.config_path
    )
    parser.add_argument(
        "--num_processes", type=int, default=config.bert_gen_config.num_processes
    )
    args, _ = parser.parse_known_args()
    config_path = args.config
    hps = utils.get_hparams_from_file(config_path)

    device = config.bert_gen_config.device

    model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    if not Path(model_name).joinpath("pytorch_model.bin").exists():
        # The fix in this commit (#199): the repo id and the local target
        # directory were being passed in the wrong order. Before:
        #   utils.download_emo_models(config.mirror, model_name, REPO_ID)
        utils.download_emo_models(config.mirror, REPO_ID, model_name)

    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = EmotionModel.from_pretrained(model_name).to(device)

    lines = []
    with open(hps.data.training_files, encoding="utf-8") as f:
        lines.extend(f.readlines())

    with open(hps.data.validation_files, encoding="utf-8") as f:
        lines.extend(f.readlines())

    # Each filelist line is "wavpath|...", so keep only the wav path.
    wavnames = [line.split("|")[0] for line in lines]
    dataset = AudioDataset(wavnames, 16000, processor)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=16)

    with torch.no_grad():
        for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
            wavname = wavnames[i]
            emo_path = wavname.replace(".wav", ".emo.npy")
            if os.path.exists(emo_path):
                continue
            # Save the mean-pooled wav2vec2 embedding next to the wav file.
            emb = model(data.to(device))[0].detach().cpu().numpy()
            np.save(emo_path, emb)

    print("Emo vec 生成完毕!")  # "Emo vec generation finished!"
```
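The return convention of EmotionModel.forward (hidden states first, logits second) is what the `[0 if embeddings else 1]` indexing in process_func relies on. A toy illustration with dummy tensors, independent of the real checkpoint (the 1024 hidden size is assumed from the -large model):

```python
# Toy illustration of the (hidden_states, logits) return convention that
# process_func indexes into; all shapes are dummy stand-ins.
import torch

frame_states = torch.randn(1, 50, 1024)   # batch=1, 50 frames, hidden=1024
pooled = torch.mean(frame_states, dim=1)  # (1, 1024): the "embedding"
head = torch.nn.Linear(1024, 3)           # stand-in for RegressionHead
logits = head(pooled)                     # (1, 3)

outputs = (pooled, logits)
for embeddings in (True, False):
    print(outputs[0 if embeddings else 1].shape)
# torch.Size([1, 1024]), then torch.Size([1, 3])
```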