
Merge pull request #6 from WoodieDudy/torchDF-speedup-stft
created checkpoint with 512 hop_size
grazder authored May 23, 2024
2 parents 3453346 + 1204593 commit 408414f
Showing 6 changed files with 74 additions and 112 deletions.
Binary file added models/DeepFilterNet3_torchDF.zip
Binary file added models/DeepFilterNet3_torchDF_onnx.tar
22 changes: 21 additions & 1 deletion torchDF/README.md
@@ -37,6 +37,26 @@ To convert model to onnx and run tests:
poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling --minimal
```

I also changed the `hop_size` parameter from 480 to 512 to speed up the STFT, and then fine-tuned for 3 epochs to adapt the model to the new size. The new, faster model can be found in the `models/` directory.
```
../models
├── DeepFilterNet3_torchDF
│   ├── checkpoints/model_123.ckpt.best
│   └── config.ini
├── DeepFilterNet3_torchDF_onnx
│   ├── denoiser_model.onnx
│   ├── denoiser_model.ort
│   └── denoiser_model.required_operators.config
```
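
For intuition on why the hop change helps, here is a rough timing sketch (not part of this PR). The `n_fft` values are assumptions for illustration only; power-of-two FFT sizes are generally cheaper to compute.

```python
# Rough micro-benchmark sketch: compare torch.stft cost for the old and new hop sizes.
# The n_fft values below are assumptions for illustration, not the model's actual config.
import time
import torch

signal = torch.randn(48000 * 10)  # 10 s of 48 kHz audio

def bench(n_fft: int, hop: int, n_runs: int = 50) -> float:
    window = torch.hann_window(n_fft)
    # warm-up so one-time initialization does not skew the timing
    torch.stft(signal, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
    start = time.perf_counter()
    for _ in range(n_runs):
        torch.stft(signal, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
    return (time.perf_counter() - start) / n_runs

print("hop 480 / n_fft 960 :", bench(960, 480))
print("hop 512 / n_fft 1024:", bench(1024, 512))
```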

How to convert new model to onnx:
```sh
# cd torchDF
unzip ../models/DeepFilterNet3_torchDF.zip -d ../models/
python model_onnx_export.py --performance --minimal --simplify --ort --model-base-dir ../models/DeepFilterNet3_torchDF
```
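
After export, the model can be sanity-checked with onnxruntime. This sketch assumes the `denoiser_model.onnx` location shown in the directory listing above.

```python
# Sanity-check sketch for the exported model (path assumed from the listing above).
import onnxruntime as ort

session = ort.InferenceSession("../models/DeepFilterNet3_torchDF_onnx/denoiser_model.onnx")
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)   # audio frame plus streaming states
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)
```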


TODO:
* Issues about split + simplify
* Thinking of offline method exportability + compatibility with streaming functions
@@ -46,4 +66,4 @@ TODO:
* rfft hacks tests
* torch.nonzero thinking
* rfft nn.module
* more static methods
17 changes: 15 additions & 2 deletions torchDF/model_onnx_export.py
@@ -15,6 +15,8 @@
from torch.onnx._internal import jit_utils
from loguru import logger

from df.enhance import parse_epoch_type

torch.manual_seed(0)

OPSET_VERSION = 17
@@ -171,9 +173,18 @@ def main(args):
torch.set_num_interop_threads(1)

if args.minimal:
streaming_pipeline = TorchDFMinimalPipeline(device="cpu")
streaming_pipeline = TorchDFMinimalPipeline(
device="cpu",
model_base_dir=args.model_base_dir,
epoch=args.epoch
)
else:
streaming_pipeline = TorchDFPipeline(device="cpu", always_apply_all_stages=True)
streaming_pipeline = TorchDFPipeline(
device="cpu",
always_apply_all_stages=True,
model_base_dir=args.model_base_dir,
epoch=args.epoch
)

frame_size = streaming_pipeline.hop_size
input_names = streaming_pipeline.input_names
@@ -314,4 +325,6 @@ def main(args):
parser.add_argument("--ort", action="store_true", help="Save to ort format")
parser.add_argument("--profiling", action="store_true", help="Run ONNX profiler")
parser.add_argument("--minimal", action="store_true", help="Export minimal version")
parser.add_argument("--model-base-dir", type=str, default='DeepFilterNet3', help="Path to model base dir with \"checkpoints\" subdir")
parser.add_argument("-e", "--epoch", type=parse_epoch_type, default='best', help="Epoch for checkpoint loading. Can be one of ['best', 'latest', <int>].")
main(parser.parse_args())
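
The new `--epoch` flag is parsed with `parse_epoch_type`, imported from `df.enhance`. Below is a hedged sketch of the `'best' / 'latest' / <int>` contract described in the help text; the actual parser lives in DeepFilterNet and may differ.

```python
# Sketch only: illustrates the epoch-argument contract, not df.enhance.parse_epoch_type itself.
import argparse

def parse_epoch(value: str):
    """Accept 'best', 'latest', or an integer epoch number."""
    if value in ("best", "latest"):
        return value
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"epoch must be 'best', 'latest' or an integer, got {value!r}"
        )
```
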
77 changes: 21 additions & 56 deletions torchDF/torch_df_streaming.py
@@ -12,8 +12,10 @@
from torch import nn
from torch import Tensor
from typing import Tuple
import numpy as np

from df import init_df
from df.model import ModelParams


class ExportableStreamingTorchDF(nn.Module):
@@ -25,6 +27,7 @@ def __init__(
enc,
df_dec,
erb_dec,
erb_indices,
df_order=5,
lookahead=2,
conv_lookahead=2,
@@ -60,44 +63,7 @@ def __init__(
self.register_buffer("window", window)

self.nb_df = nb_df

# Initializing erb features
self.erb_indices = torch.tensor(
[
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
5,
5,
7,
7,
8,
10,
12,
13,
15,
18,
20,
24,
28,
31,
37,
42,
50,
56,
67,
]
)
self.erb_indices = torch.from_numpy(erb_indices.astype(np.int64))
self.nb_bands = nb_bands

self.register_buffer(
@@ -758,39 +724,38 @@ def forward(
class TorchDFPipeline(nn.Module):
def __init__(
self,
nb_bands=32,
hop_size=480,
fft_size=960,
df_order=5,
conv_lookahead=2,
nb_df=96,
model_base_dir="DeepFilterNet3",
atten_lim_db=0.0,
epoch="best",
always_apply_all_stages=False,
atten_lim_db=0.0,
device="cpu",
):
super().__init__()
self.hop_size = hop_size
self.fft_size = fft_size

model, state, _ = init_df(
config_allow_defaults=True, model_base_dir=model_base_dir
config_allow_defaults=True,
model_base_dir=model_base_dir,
epoch=epoch,
)
model.eval()
self.sample_rate = state.sr()
p = ModelParams()

self.hop_size = p.hop_size
self.fft_size = p.fft_size
self.sample_rate = p.sr

self.torch_streaming_model = ExportableStreamingTorchDF(
nb_bands=nb_bands,
hop_size=hop_size,
fft_size=fft_size,
nb_bands=p.nb_erb,
hop_size=p.hop_size,
fft_size=p.fft_size,
enc=model.enc,
df_dec=model.df_dec,
erb_dec=model.erb_dec,
df_order=df_order,
always_apply_all_stages=always_apply_all_stages,
conv_lookahead=conv_lookahead,
nb_df=nb_df,
df_order=p.df_order,
conv_lookahead=p.conv_lookahead,
nb_df=p.nb_df,
sr=self.sample_rate,
erb_indices=state.erb_widths()
)
self.torch_streaming_model = self.torch_streaming_model.to(device)
self.atten_lim_db = torch.tensor(atten_lim_db, device=device)
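
With this change the pipeline no longer takes `hop_size`, `fft_size`, or band counts as constructor arguments; they are read from the checkpoint's `ModelParams`, and the ERB widths come from the loaded state. A usage sketch under the new signature (the import path and model path are illustrative):

```python
# Usage sketch for the updated constructor; the paths here are illustrative only.
from torch_df_streaming import TorchDFPipeline

pipeline = TorchDFPipeline(
    model_base_dir="../models/DeepFilterNet3_torchDF",  # unzipped checkpoint dir
    epoch="best",                       # 'best', 'latest', or an integer epoch
    always_apply_all_stages=True,
    atten_lim_db=0.0,
    device="cpu",
)
# hop_size, fft_size and sample_rate are now taken from the checkpoint's config
print(pipeline.hop_size, pipeline.fft_size, pipeline.sample_rate)
```
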
70 changes: 17 additions & 53 deletions torchDF/torch_df_streaming_minimal.py
@@ -28,6 +28,7 @@
from typing_extensions import Final
from torch.nn.parameter import Parameter
from torch.nn import init
import numpy as np


from torch.autograd import Function
@@ -485,6 +486,7 @@ def __init__(
enc,
df_dec,
erb_dec,
erb_indices,
df_order=5,
lookahead=2,
conv_lookahead=2,
@@ -520,44 +522,7 @@ def __init__(
self.register_buffer("window", window)

self.nb_df = nb_df

# Initializing erb features
self.erb_indices = torch.tensor(
[
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
2,
5,
5,
7,
7,
8,
10,
12,
13,
15,
18,
20,
24,
28,
31,
37,
42,
50,
56,
67,
]
)
self.erb_indices = torch.from_numpy(erb_indices.astype(np.int64))
self.nb_bands = nb_bands

self.register_buffer(
@@ -1060,37 +1025,36 @@ def forward(
class TorchDFMinimalPipeline(nn.Module):
def __init__(
self,
nb_bands=32,
hop_size=480,
fft_size=960,
df_order=5,
conv_lookahead=2,
nb_df=96,
model_base_dir="DeepFilterNet3",
epoch="best",
device="cpu",
):
super().__init__()
self.hop_size = hop_size
self.fft_size = fft_size

model, state, _ = init_df(
config_allow_defaults=True,
model_base_dir=model_base_dir,
epoch=epoch,
)
model.eval()
self.sample_rate = state.sr()
p = ModelParams()

self.hop_size = p.hop_size
self.fft_size = p.fft_size
self.sample_rate = p.sr

self.torch_streaming_model = ExportableStreamingMinimalTorchDF(
nb_bands=nb_bands,
hop_size=hop_size,
fft_size=fft_size,
nb_bands=p.nb_erb,
hop_size=p.hop_size,
fft_size=p.fft_size,
enc=model.enc,
df_dec=model.df_dec,
erb_dec=model.erb_dec,
df_order=df_order,
conv_lookahead=conv_lookahead,
nb_df=nb_df,
df_order=p.df_order,
conv_lookahead=p.conv_lookahead,
nb_df=p.nb_df,
sr=self.sample_rate,
erb_indices=state.erb_widths()
)
self.torch_streaming_model = self.torch_streaming_model.to(device)
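
Both pipelines now pass `erb_indices=state.erb_widths()` instead of a hard-coded list. A sketch of the invariant the old list encoded (values copied from the removed code above): the per-band widths must cover every FFT bin, i.e. sum to `fft_size // 2 + 1`, which is why they are now read from the loaded checkpoint rather than baked in.

```python
# Sketch of the invariant behind erb_widths: the per-band bin widths must cover
# all fft_size // 2 + 1 frequency bins. Values below are the removed hard-coded list.
import numpy as np

old_widths = np.array(
    [2] * 13 + [5, 5, 7, 7, 8, 10, 12, 13, 15, 18, 20, 24, 28, 31, 37, 42, 50, 56, 67]
)
assert old_widths.sum() == 960 // 2 + 1                     # 481 bins for the old 960-point FFT
band_edges = np.concatenate(([0], np.cumsum(old_widths)))   # bin boundary per ERB band
print(len(old_widths), band_edges[-1])                      # 32 bands covering all 481 bins
```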

