diff --git a/models/DeepFilterNet3_torchDF.zip b/models/DeepFilterNet3_torchDF.zip
new file mode 100644
index 000000000..c4ed6825a
Binary files /dev/null and b/models/DeepFilterNet3_torchDF.zip differ
diff --git a/models/DeepFilterNet3_torchDF_onnx.tar b/models/DeepFilterNet3_torchDF_onnx.tar
new file mode 100644
index 000000000..5c25847de
Binary files /dev/null and b/models/DeepFilterNet3_torchDF_onnx.tar differ
diff --git a/torchDF/README.md b/torchDF/README.md
index 0c416f9b5..da9af6b14 100644
--- a/torchDF/README.md
+++ b/torchDF/README.md
@@ -37,6 +37,42 @@ To convert model to onnx and run tests:
 poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling --minimal
 ```
 
+I also changed the `hop_size` parameter from 480 to 512 to speed up the STFT, and then fine-tuned the model for 3 epochs to adapt it to the new hop size of 512. The new, faster model can be found in the `models/` dir:
+```
+../models
+├── DeepFilterNet3_torchDF
+│   ├── checkpoints/model_123.ckpt.best
+│   └── config.ini
+├── DeepFilterNet3_torchDF_onnx
+│   ├── denoiser_model.onnx
+│   ├── denoiser_model.ort
+│   └── denoiser_model.required_operators.config
+```
+
+How to convert the new model to ONNX:
+```sh
+# cd torchDF
+unzip ../models/DeepFilterNet3_torchDF.zip -d ../models/
+python model_onnx_export.py --performance --minimal --simplify --ort --model-base-dir ../models/DeepFilterNet3_torchDF
+```
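+
+For reference, the new checkpoint can also be loaded directly from Python. A minimal sketch (it assumes you run from the `torchDF/` directory after the unzip step above, and it only uses the `TorchDFPipeline` constructor arguments shown in `torch_df_streaming.py`):
+```python
+from torch_df_streaming import TorchDFPipeline
+
+# epoch="best" should pick up checkpoints/model_123.ckpt.best
+pipeline = TorchDFPipeline(
+    model_base_dir="../models/DeepFilterNet3_torchDF",
+    epoch="best",
+    always_apply_all_stages=True,
+    device="cpu",
+)
+
+# hop_size / fft_size / sample_rate are now read from the checkpoint's config.ini
+print(pipeline.hop_size, pipeline.fft_size, pipeline.sample_rate)
+```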
Can be one of ['best', 'latest', ].") main(parser.parse_args()) diff --git a/torchDF/torch_df_streaming.py b/torchDF/torch_df_streaming.py index 491d3c130..f85160847 100644 --- a/torchDF/torch_df_streaming.py +++ b/torchDF/torch_df_streaming.py @@ -12,8 +12,10 @@ from torch import nn from torch import Tensor from typing import Tuple +import numpy as np from df import init_df +from df.model import ModelParams class ExportableStreamingTorchDF(nn.Module): @@ -25,6 +27,7 @@ def __init__( enc, df_dec, erb_dec, + erb_indices, df_order=5, lookahead=2, conv_lookahead=2, @@ -60,44 +63,7 @@ def __init__( self.register_buffer("window", window) self.nb_df = nb_df - - # Initializing erb features - self.erb_indices = torch.tensor( - [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 5, - 5, - 7, - 7, - 8, - 10, - 12, - 13, - 15, - 18, - 20, - 24, - 28, - 31, - 37, - 42, - 50, - 56, - 67, - ] - ) + self.erb_indices = torch.from_numpy(erb_indices.astype(np.int64)) self.nb_bands = nb_bands self.register_buffer( @@ -758,39 +724,38 @@ def forward( class TorchDFPipeline(nn.Module): def __init__( self, - nb_bands=32, - hop_size=480, - fft_size=960, - df_order=5, - conv_lookahead=2, - nb_df=96, model_base_dir="DeepFilterNet3", - atten_lim_db=0.0, + epoch="best", always_apply_all_stages=False, + atten_lim_db=0.0, device="cpu", ): super().__init__() - self.hop_size = hop_size - self.fft_size = fft_size model, state, _ = init_df( - config_allow_defaults=True, model_base_dir=model_base_dir + config_allow_defaults=True, + model_base_dir=model_base_dir, + epoch=epoch, ) model.eval() - self.sample_rate = state.sr() + p = ModelParams() + self.hop_size = p.hop_size + self.fft_size = p.fft_size + self.sample_rate = p.sr + self.torch_streaming_model = ExportableStreamingTorchDF( - nb_bands=nb_bands, - hop_size=hop_size, - fft_size=fft_size, + nb_bands=p.nb_erb, + hop_size=p.hop_size, + fft_size=p.fft_size, enc=model.enc, df_dec=model.df_dec, erb_dec=model.erb_dec, - df_order=df_order, - always_apply_all_stages=always_apply_all_stages, - conv_lookahead=conv_lookahead, - nb_df=nb_df, + df_order=p.df_order, + conv_lookahead=p.conv_lookahead, + nb_df=p.nb_df, sr=self.sample_rate, + erb_indices=state.erb_widths() ) self.torch_streaming_model = self.torch_streaming_model.to(device) self.atten_lim_db = torch.tensor(atten_lim_db, device=device) diff --git a/torchDF/torch_df_streaming_minimal.py b/torchDF/torch_df_streaming_minimal.py index 219184b4a..cdc03c756 100644 --- a/torchDF/torch_df_streaming_minimal.py +++ b/torchDF/torch_df_streaming_minimal.py @@ -28,6 +28,7 @@ from typing_extensions import Final from torch.nn.parameter import Parameter from torch.nn import init +import numpy as np from torch.autograd import Function @@ -485,6 +486,7 @@ def __init__( enc, df_dec, erb_dec, + erb_indices, df_order=5, lookahead=2, conv_lookahead=2, @@ -520,44 +522,7 @@ def __init__( self.register_buffer("window", window) self.nb_df = nb_df - - # Initializing erb features - self.erb_indices = torch.tensor( - [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 5, - 5, - 7, - 7, - 8, - 10, - 12, - 13, - 15, - 18, - 20, - 24, - 28, - 31, - 37, - 42, - 50, - 56, - 67, - ] - ) + self.erb_indices = torch.from_numpy(erb_indices.astype(np.int64)) self.nb_bands = nb_bands self.register_buffer( @@ -1060,37 +1025,36 @@ def forward( class TorchDFMinimalPipeline(nn.Module): def __init__( self, - nb_bands=32, - hop_size=480, - fft_size=960, - df_order=5, - conv_lookahead=2, - nb_df=96, model_base_dir="DeepFilterNet3", + 
epoch="best", device="cpu", ): super().__init__() - self.hop_size = hop_size - self.fft_size = fft_size model, state, _ = init_df( config_allow_defaults=True, model_base_dir=model_base_dir, + epoch=epoch, ) model.eval() - self.sample_rate = state.sr() + p = ModelParams() + self.hop_size = p.hop_size + self.fft_size = p.fft_size + self.sample_rate = p.sr + self.torch_streaming_model = ExportableStreamingMinimalTorchDF( - nb_bands=nb_bands, - hop_size=hop_size, - fft_size=fft_size, + nb_bands=p.nb_erb, + hop_size=p.hop_size, + fft_size=p.fft_size, enc=model.enc, df_dec=model.df_dec, erb_dec=model.erb_dec, - df_order=df_order, - conv_lookahead=conv_lookahead, - nb_df=nb_df, + df_order=p.df_order, + conv_lookahead=p.conv_lookahead, + nb_df=p.nb_df, sr=self.sample_rate, + erb_indices=state.erb_widths() ) self.torch_streaming_model = self.torch_streaming_model.to(device)