diff --git a/DeepFilterNet/df/deepfilternet3.py b/DeepFilterNet/df/deepfilternet3.py index 94c155ac1..6985a0bbf 100644 --- a/DeepFilterNet/df/deepfilternet3.py +++ b/DeepFilterNet/df/deepfilternet3.py @@ -28,7 +28,9 @@ def __init__(self): self.conv_lookahead: int = config( "CONV_LOOKAHEAD", cast=int, default=0, section=self.section ) - self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section) + self.conv_ch: int = config( + "CONV_CH", cast=int, default=16, section=self.section + ) self.conv_depthwise: bool = config( "CONV_DEPTHWISE", cast=bool, default=True, section=self.section ) @@ -36,13 +38,22 @@ def __init__(self): "CONVT_DEPTHWISE", cast=bool, default=True, section=self.section ) self.conv_kernel: List[int] = config( - "CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore + "CONV_KERNEL", + cast=Csv(int), + default=(1, 3), + section=self.section, # type: ignore ) self.convt_kernel: List[int] = config( - "CONVT_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore + "CONVT_KERNEL", + cast=Csv(int), + default=(1, 3), + section=self.section, # type: ignore ) self.conv_kernel_inp: List[int] = config( - "CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section # type: ignore + "CONV_KERNEL_INP", + cast=Csv(int), + default=(3, 3), + section=self.section, # type: ignore ) self.emb_hidden_dim: int = config( "EMB_HIDDEN_DIM", cast=int, default=256, section=self.section @@ -53,31 +64,49 @@ def __init__(self): self.emb_gru_skip_enc: str = config( "EMB_GRU_SKIP_ENC", default="none", section=self.section ) - self.emb_gru_skip: str = config("EMB_GRU_SKIP", default="none", section=self.section) + self.emb_gru_skip: str = config( + "EMB_GRU_SKIP", default="none", section=self.section + ) self.df_hidden_dim: int = config( "DF_HIDDEN_DIM", cast=int, default=256, section=self.section ) - self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section) + self.df_gru_skip: str = config( + "DF_GRU_SKIP", default="none", section=self.section + ) self.df_pathway_kernel_size_t: int = config( "DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section ) - self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section) - self.df_num_layers: int = config("DF_NUM_LAYERS", cast=int, default=3, section=self.section) - self.df_n_iter: int = config("DF_N_ITER", cast=int, default=1, section=self.section) - self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section) + self.enc_concat: bool = config( + "ENC_CONCAT", cast=bool, default=False, section=self.section + ) + self.df_num_layers: int = config( + "DF_NUM_LAYERS", cast=int, default=3, section=self.section + ) + self.df_n_iter: int = config( + "DF_N_ITER", cast=int, default=1, section=self.section + ) + self.lin_groups: int = config( + "LINEAR_GROUPS", cast=int, default=1, section=self.section + ) self.enc_lin_groups: int = config( "ENC_LINEAR_GROUPS", cast=int, default=16, section=self.section ) - self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section) + self.mask_pf: bool = config( + "MASK_PF", cast=bool, default=False, section=self.section + ) self.lsnr_dropout: bool = config( "LSNR_DROPOUT", cast=bool, default=False, section=self.section ) -def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True): +def init_model( + df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True +): p = ModelParams() if df_state 
is None: - df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb) + df_state = DF( + sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb + ) erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False) erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True) model = DfNet(erb, erb_inverse, run_df, train_mask) @@ -119,7 +148,11 @@ def __init__(self): self.erb_conv3 = conv_layer(fstride=1) self.df_conv0_ch = p.conv_ch self.df_conv0 = Conv2dNormAct( - 2, self.df_conv0_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True + 2, + self.df_conv0_ch, + kernel_size=p.conv_kernel_inp, + bias=False, + separable=True, ) self.df_conv1 = conv_layer(fstride=2) self.erb_bins = p.nb_erb @@ -243,7 +276,15 @@ def __init__(self): p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid ) - def forward(self, emb: Tensor, e3: Tensor, e2: Tensor, e1: Tensor, e0: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]: + def forward( + self, + emb: Tensor, + e3: Tensor, + e2: Tensor, + e1: Tensor, + e0: Tensor, + hidden: Tensor, + ) -> Tuple[Tensor, Tensor]: # Estimates erb mask b, _, t, f8 = e3.shape emb, hidden = self.emb_gru(emb, hidden) @@ -294,7 +335,9 @@ def __init__(self): conv_layer = partial(Conv2dNormAct, separable=True, bias=False) kt = p.df_pathway_kernel_size_t self.conv_buffer_size = kt - 1 - self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1)) + self.df_convp = conv_layer( + layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1) + ) self.df_gru = SqueezedGRU_S( self.emb_in_dim, @@ -313,7 +356,9 @@ def __init__(self): assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match" self.df_skip = nn.Identity() elif p.df_gru_skip == "groupedlinear": - self.df_skip = GroupedLinearEinsum(self.emb_in_dim, self.emb_dim, groups=p.lin_groups) + self.df_skip = GroupedLinearEinsum( + self.emb_in_dim, self.emb_dim, groups=p.lin_groups + ) else: raise NotImplementedError() self.df_out: nn.Module @@ -356,11 +401,15 @@ def __init__( self.erb_bins: int = p.nb_erb if p.conv_lookahead > 0: assert p.conv_lookahead >= p.df_lookahead - self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0) + self.pad_feat = nn.ConstantPad2d( + (0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0 + ) else: self.pad_feat = nn.Identity() if p.df_lookahead > 0: - self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, p.df_lookahead - 1, -p.df_lookahead + 1), 0.0) + self.pad_spec = nn.ConstantPad3d( + (0, 0, 0, 0, p.df_lookahead - 1, -p.df_lookahead + 1), 0.0 + ) else: self.pad_spec = nn.Identity() self.register_buffer("erb_fb", erb_fb) @@ -369,7 +418,9 @@ def __init__( self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf) self.df_order = p.df_order - self.df_op = MF.DF(num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead) + self.df_op = MF.DF( + num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead + ) self.df_dec = DfDecoder() self.df_out_transform = DfOutputReshapeMF(self.df_order, p.nb_df) @@ -406,7 +457,7 @@ def forward( # feat_erb = self.pad_feat(feat_erb) # feat_spec = self.pad_feat(feat_spec) spec = self.pad_spec(spec) - + e0, e1, e2, e3, emb, c0, lsnr, _ = self.enc(feat_erb, feat_spec, hidden=None) if self.lsnr_droput: @@ -427,7 +478,7 @@ def forward( m[:, :, idcs], _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None) else: m, _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None) - + pad_spec = F.pad(spec, (0, 0, 0, 0, 1, -1, 0, 0), value=0) spec_m = self.mask(pad_spec, m) else: diff 
--git a/DeepFilterNet/df/modules.py b/DeepFilterNet/df/modules.py index 0b4612068..9e4c09614 100644 --- a/DeepFilterNet/df/modules.py +++ b/DeepFilterNet/df/modules.py @@ -1006,4 +1006,4 @@ def test_dfop(): dfop.freq_bins = F dfop.set_forward("real_hidden_state_loop") out6 = dfop(spec, coefs, alpha) - torch.testing.assert_allclose(out1, out6) + torch.testing.assert_allclose(out1, out6) \ No newline at end of file diff --git a/torchDF/README.md b/torchDF/README.md index a50fcf601..0c416f9b5 100644 --- a/torchDF/README.md +++ b/torchDF/README.md @@ -34,7 +34,7 @@ poetry run python torch_df_streaming_minimal.py --audio-path examples/A1CIM28ZUC To convert model to onnx and run tests: ``` -poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort +poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling --minimal ``` TODO: diff --git a/torchDF/model_onnx_export.py b/torchDF/model_onnx_export.py index 4c4ca142c..b08d72559 100644 --- a/torchDF/model_onnx_export.py +++ b/torchDF/model_onnx_export.py @@ -1,4 +1,3 @@ -import os import copy import onnx import argparse @@ -10,20 +9,15 @@ import onnxruntime as ort import torch.utils.benchmark as benchmark +from torch_df_streaming_minimal import TorchDFMinimalPipeline from torch_df_streaming import TorchDFPipeline from typing import Dict, Iterable +from torch.onnx._internal import jit_utils +from loguru import logger torch.manual_seed(0) -FRAME_SIZE = 480 -INPUT_NAMES = [ - 'input_frame', - 'states', - 'atten_lim_db' -] -OUTPUT_NAMES = [ - 'enhanced_audio_frame', 'out_states', 'lsnr' -] +OPSET_VERSION = 17 def onnx_simplify( @@ -31,7 +25,7 @@ def onnx_simplify( ) -> str: """ Simplify ONNX model using onnxsim and checking it - + Parameters: path: str - Path to ONNX model input_data: Dict[str, np.ndarray] - Input data for ONNX model @@ -53,7 +47,10 @@ def onnx_simplify( onnx.save_model(model_simp, path) return path -def test_onnx_model(torch_model, ort_session, states, atten_lim_db): + +def test_onnx_model( + torch_model, ort_session, states, frame_size, input_names, output_names +): """ Simple test that everything converted correctly @@ -66,30 +63,29 @@ def test_onnx_model(torch_model, ort_session, states, atten_lim_db): states_onnx = copy.deepcopy(states) for i in range(30): - input_frame = torch.randn(FRAME_SIZE) + input_frame = torch.randn(frame_size) # torch - output_torch = torch_model(input_frame, states_torch, atten_lim_db) + output_torch = torch_model(input_frame, *states_torch) # onnx output_onnx = ort_session.run( - OUTPUT_NAMES, - generate_onnx_features([input_frame, states_onnx, atten_lim_db]), + output_names, + generate_onnx_features([input_frame, *states_onnx], input_names), ) - for (x, y, name) in zip(output_torch, output_onnx, OUTPUT_NAMES): + for x, y, name in zip(output_torch, output_onnx, output_names): y_tensor = torch.from_numpy(y) - assert torch.allclose(x, y_tensor, atol=1e-3), f"out {name} - {i}, {x.flatten()[-5:]}, {y_tensor.flatten()[-5:]}" + assert torch.allclose( + x, y_tensor, atol=1e-2 + ), f"out {name} - {i}, {x.flatten()[-5:]}, {y_tensor.flatten()[-5:]}" -def generate_onnx_features(input_features): - return { - x: y.detach().cpu().numpy() - for x, y in zip(INPUT_NAMES, input_features) - } -def perform_benchmark( - ort_session, input_features: Dict[str, np.ndarray], -): +def generate_onnx_features(input_features, 
input_names): + return {x: y.detach().cpu().numpy() for x, y in zip(input_names, input_features)} + + +def perform_benchmark(ort_session, input_features: Dict[str, np.ndarray], output_names): """ Benchmark ONNX model performance @@ -97,138 +93,225 @@ def perform_benchmark( ort_session: onnxruntime.InferenceSession - Inference Session for converted ONNX model input_features: Dict[str, np.ndarray] - Input features """ + def run_onnx(): output = ort_session.run( - OUTPUT_NAMES, + output_names, input_features, ) - + t0 = benchmark.Timer( - stmt='run_onnx()', + stmt="run_onnx()", num_threads=1, - globals={'run_onnx': run_onnx}, + globals={"run_onnx": run_onnx}, + ) + logger.info( + f"Median iteration time: {t0.blocked_autorange(min_run_time=10).median * 1e3:6.2f} ms / {480 / 48000 * 1000} ms" ) - print(f"Median iteration time: {t0.blocked_autorange(min_run_time=10).median * 1e3:6.2f} ms / {480 / 48000 * 1000} ms") -def infer_onnx_model(streaming_pipeline, ort_session, inference_path): + +def infer_onnx_model( + streaming_pipeline, ort_session, inference_path, input_names, output_names +): """ Inference ONNX model with TorchDFPipeline """ del streaming_pipeline.torch_streaming_model streaming_pipeline.torch_streaming_model = lambda *features: ( - torch.from_numpy(x) for x in ort_session.run( - OUTPUT_NAMES, - generate_onnx_features(list(features)), + torch.from_numpy(x) + for x in ort_session.run( + output_names, + generate_onnx_features(list(features), input_names), ) ) noisy_audio, sr = torchaudio.load(inference_path, channels_first=True) - noisy_audio = noisy_audio.mean(dim=0).unsqueeze(0) # stereo to mono + noisy_audio = noisy_audio.mean(dim=0).unsqueeze(0) # stereo to mono enhanced_audio = streaming_pipeline(noisy_audio, sr) torchaudio.save( - inference_path.replace('.wav', '_onnx_infer.wav'), enhanced_audio, sr, - encoding="PCM_S", bits_per_sample=16 - ) + inference_path.replace(".wav", "_onnx_infer.wav"), + enhanced_audio, + sr, + encoding="PCM_S", + bits_per_sample=16, + ) + + +# setType API provides shape/type to ONNX shape/type inference +def custom_rfft(g: jit_utils.GraphContext, X, n, dim, norm): + x = g.op( + "Unsqueeze", + X, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + ) + x = g.op( + "Unsqueeze", + x, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + x = g.op("DFT", x, axis_i=1, inverse_i=0, onesided_i=1) + x = g.op( + "Squeeze", + x, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + return x + + +# setType API provides shape/type to ONNX shape/type inference +def custom_identity(g: jit_utils.GraphContext, X): + return X + def main(args): - streaming_pipeline = TorchDFPipeline(always_apply_all_stages=args.always_apply_all_stages, device='cpu') + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + if args.minimal: + streaming_pipeline = TorchDFMinimalPipeline(device="cpu") + else: + streaming_pipeline = TorchDFPipeline(device="cpu", always_apply_all_stages=True) + + frame_size = streaming_pipeline.hop_size + input_names = streaming_pipeline.input_names + output_names = streaming_pipeline.output_names + torch_df = streaming_pipeline.torch_streaming_model states = streaming_pipeline.states - atten_lim_db = streaming_pipeline.atten_lim_db - input_frame = torch.rand(FRAME_SIZE) - input_features = ( - input_frame, states, atten_lim_db + input_frame = torch.rand(frame_size) + input_features = (input_frame, *states) + torch_df(*input_features) # check model + + torch.onnx.register_custom_op_symbolic( + 
symbolic_name="aten::fft_rfft", + symbolic_fn=custom_rfft, + opset_version=OPSET_VERSION, + ) + # Only used with aten::fft_rfft, so it's useless in ONNX + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::view_as_real", + symbolic_fn=custom_identity, + opset_version=OPSET_VERSION, ) - torch_df(*input_features) # check model torch_df_script = torch.jit.script(torch_df) + torch.onnx.export( torch_df_script, input_features, args.output_path, verbose=False, - input_names=INPUT_NAMES, - output_names=OUTPUT_NAMES, - opset_version=14 + input_names=input_names, + output_names=output_names, + opset_version=OPSET_VERSION, ) - print(f'Model exported to {args.output_path}!') + logger.info(f"Model exported to {args.output_path}!") - input_features_onnx = generate_onnx_features(input_features) - input_shapes_dict = { - x: y.shape - for x, y in input_features_onnx.items() - } + input_features_onnx = generate_onnx_features(input_features, input_names) + input_shapes_dict = {x: y.shape for x, y in input_features_onnx.items()} - # Simplify not working! + # Simplify not working for not minimal! if args.simplify: - raise NotImplementedError("Simplify not working for flatten states!") + # raise NotImplementedError("Simplify not working for flatten states!") onnx_simplify(args.output_path, input_features_onnx, input_shapes_dict) - print(f'Model simplified! {args.output_path}') + logger.info(f"Model simplified! {args.output_path}") if args.ort: - if subprocess.run([ - 'python', '-m', 'onnxruntime.tools.convert_onnx_models_to_ort', - args.output_path, - '--optimization_style', 'Fixed', - ]).returncode != 0: + if ( + subprocess.run( + [ + "python", + "-m", + "onnxruntime.tools.convert_onnx_models_to_ort", + args.output_path, + "--optimization_style", + "Fixed", + ] + ).returncode + != 0 + ): raise RuntimeError("ONNX to ORT conversion failed!") - print('Model converted to ORT format!') + logger.info("Model converted to ORT format!") - print('Checking model...') + logger.info("Checking model...") sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED + sess_options.graph_optimization_level = ( + ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED + ) sess_options.optimized_model_filepath = args.output_path sess_options.intra_op_num_threads = 1 + sess_options.inter_op_num_threads = 1 sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + sess_options.enable_profiling = args.profiling - ort_session = ort.InferenceSession(args.output_path, sess_options, providers=['CPUExecutionProvider']) - - onnx_outputs = ort_session.run( - OUTPUT_NAMES, - input_features_onnx, + ort_session = ort.InferenceSession( + args.output_path, + sess_options, + providers=["CPUExecutionProvider"], ) - print(f'InferenceSession successful! Output shapes: {[x.shape for x in onnx_outputs]}') + for _ in range(3): + onnx_outputs = ort_session.run( + output_names, + input_features_onnx, + ) + + if args.profiling: + logger.info("Profiling enabled...") + ort_session.end_profiling() + + logger.info( + f"InferenceSession successful! 
Output shapes: {[x.shape for x in onnx_outputs]}" + ) if args.test: - test_onnx_model(torch_df, ort_session, input_features[1], input_features[2]) - print('Tests passed!') + logger.info("Testing...") + test_onnx_model( + torch_df, + ort_session, + input_features[1:], + frame_size, + input_names, + output_names, + ) + logger.info("Tests passed!") if args.performance: - print('Performanse check...') - perform_benchmark(ort_session, input_features_onnx) + logger.info("Performanse check...") + perform_benchmark(ort_session, input_features_onnx, output_names) if args.inference_path: - infer_onnx_model(streaming_pipeline, ort_session, args.inference_path) - print(f'Audio from {args.inference_path} enhanced!') + infer_onnx_model( + streaming_pipeline, + ort_session, + args.inference_path, + input_names, + output_names, + ) + logger.info(f"Audio from {args.inference_path} enhanced!") -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description='Exporting torchDF model to ONNX' - ) - parser.add_argument( - '--output-path', type=str, default='denoiser_model.onnx', help='Path to output onnx file' - ) - parser.add_argument( - '--simplify', action='store_true', help='Simplify the model' - ) - parser.add_argument( - '--test', action='store_true', help='Test the onnx model' - ) - parser.add_argument( - '--performance', action='store_true', help='Mesure median iteration time for onnx model' - ) - parser.add_argument( - '--inference-path', type=str, help='Run inference on example' - ) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Exporting torchDF model to ONNX") parser.add_argument( - '--ort', action='store_true', help='Save to ort format' + "--output-path", + type=str, + default="denoiser_model.onnx", + help="Path to output onnx file", ) + parser.add_argument("--simplify", action="store_true", help="Simplify the model") + parser.add_argument("--test", action="store_true", help="Test the onnx model") parser.add_argument( - '--always-apply-all-stages', action='store_true', help='Always apply stages' + "--performance", + action="store_true", + help="Mesure median iteration time for onnx model", ) + parser.add_argument("--inference-path", type=str, help="Run inference on example") + parser.add_argument("--ort", action="store_true", help="Save to ort format") + parser.add_argument("--profiling", action="store_true", help="Run ONNX profiler") + parser.add_argument("--minimal", action="store_true", help="Export minimal version") main(parser.parse_args()) diff --git a/torchDF/pyproject.toml b/torchDF/pyproject.toml index 8b99539ba..74ace5ecc 100644 --- a/torchDF/pyproject.toml +++ b/torchDF/pyproject.toml @@ -37,11 +37,13 @@ icecream = { version = ">=2,<3", optional = true } pystoi = { version = "^0.3", optional = true } pesq = { version = ">=0.0.3,<0.0.5", optional = true } scipy = { version = "^1", optional = true } -onnxruntime = { version = "^1.15"} +onnxruntime = "1.17" pytest = "^7.4.0" tqdm = "^4.65.0" onnx = "^1.14.0" onnxsim = "^0.4.33" +onnxscript = "^0.1.0.dev20240131" +onnxruntime-extensions = "^0.10.0" [tool.poetry.extras] train = ["deepfilterdataloader", "icecream"] @@ -57,8 +59,8 @@ deep-filter-py = "df.enhance:run" poethepoet = "^0.21" [tool.poe.tasks] -install-torch-cuda11 = "python -m pip install torch==2.1.0 torchaudio==2.1.0 --extra-index-url https://download.pytorch.org/whl/cu118/" -install-torch-cpu = "python -m pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu/" +install-torch-cuda11 = "python 
-m pip install torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu118/" +install-torch-cpu = "python -m pip install torch==2.1.0 torchaudio==2.1.0 --extra-index-url https://download.pytorch.org/whl/cpu/" install-eval-utils = "python -m pip install -r requirements_eval.txt" install-dnsmos-utils = "python -m pip install -r requirements_dnsmos.txt" diff --git a/torchDF/test_torchdf.py b/torchDF/test_torchdf.py index 14e34798a..bff821c28 100644 --- a/torchDF/test_torchdf.py +++ b/torchDF/test_torchdf.py @@ -6,33 +6,59 @@ from torch_df_offline import TorchDF from libdf import DFTractPy from torch_df_streaming import TorchDFPipeline +from torch_df_streaming_minimal import TorchDFMinimalPipeline torch.set_num_threads(1) torch.set_num_interop_threads(1) DEVICE = torch.device("cpu") +EXAMPLES_PATH = "" +AUDIO_EXAMPLE_PATH = ( + "examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav" +) -class TestTorchStreaming(): + +class TestTorchStreaming: def __reset(self): torch.manual_seed(23) - - self.model, self.rust_state, _ = init_df(config_allow_defaults=True, model_base_dir='DeepFilterNet3') - self.model.eval() - self.torch_streaming_like_offline = TorchDFPipeline(always_apply_all_stages=True, device=DEVICE) + self.model, self.rust_state, _ = init_df( + config_allow_defaults=True, model_base_dir="DeepFilterNet3" + ) + self.model.eval() - pipeline_for_streaming = TorchDFPipeline(always_apply_all_stages=False, device=DEVICE) + self.torch_streaming_like_offline = TorchDFPipeline( + always_apply_all_stages=True, device=DEVICE + ) + self.torch_streaming_no_stages = ( + self.torch_streaming_like_offline.torch_streaming_model + ) + self.streaming_state_no_stages = self.torch_streaming_like_offline.states[0] + self.atten_lim_db_no_stages = self.torch_streaming_like_offline.atten_lim_db + + pipeline_for_streaming = TorchDFPipeline( + always_apply_all_stages=False, device=DEVICE + ) self.torch_streaming = pipeline_for_streaming.torch_streaming_model - self.streaming_state = pipeline_for_streaming.states + self.streaming_state = pipeline_for_streaming.states[0] self.atten_lim_db = pipeline_for_streaming.atten_lim_db + pipeline_for_minmal_streaming = TorchDFMinimalPipeline(device=DEVICE) + self.torch_streaming_minimal = ( + pipeline_for_minmal_streaming.torch_streaming_model + ) + self.streaming_state_minimal = pipeline_for_minmal_streaming.states + self.torch_offline = TorchDF(copy.deepcopy(self.model)) self.torch_offline = self.torch_offline.to(DEVICE) self.df_tract = DFTractPy() - self.noisy_audio, self.audio_sr = torchaudio.load('examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav', channels_first=True) + self.noisy_audio, self.audio_sr = torchaudio.load( + AUDIO_EXAMPLE_PATH, + channels_first=True, + ) self.noisy_audio = self.noisy_audio.mean(dim=0).unsqueeze(0).to(DEVICE) def test_offline_with_enhance(self): @@ -40,8 +66,12 @@ def test_offline_with_enhance(self): Compare torchDF offline implementation to enhance method """ self.__reset() - enhanced_audio_torch = self.torch_streaming_like_offline(self.noisy_audio, self.audio_sr).to(DEVICE) - enhanced_audio_offline = enhance(self.model, self.rust_state, self.noisy_audio.cpu()).to(DEVICE) + enhanced_audio_torch = self.torch_streaming_like_offline( + self.noisy_audio, self.audio_sr + ).to(DEVICE) + enhanced_audio_offline = enhance( + self.model, self.rust_state, self.noisy_audio.cpu() + ).to(DEVICE) assert torch.allclose(enhanced_audio_torch, enhanced_audio_offline, atol=1e-3) @@ -51,15 +81,19 @@ def 
test_offline_with_streaming(self): """ self.__reset() - enhanced_audio_torch = self.torch_streaming_like_offline(self.noisy_audio, self.audio_sr).to(DEVICE) - enhanced_audio_offline = self.torch_offline(self.noisy_audio.squeeze(0)).to(DEVICE) + enhanced_audio_torch = self.torch_streaming_like_offline( + self.noisy_audio, self.audio_sr + ).to(DEVICE) + enhanced_audio_offline = self.torch_offline(self.noisy_audio.squeeze(0)).to( + DEVICE + ) assert torch.allclose(enhanced_audio_torch, enhanced_audio_offline, atol=1e-3) def test_streaming_torch_with_tract(self): """ Compare torchDF streaming implementation to tract streaming implementation - + always_apply_all_stages = False """ self.__reset() @@ -67,7 +101,41 @@ def test_streaming_torch_with_tract(self): chunked_audio = torch.split(self.noisy_audio.squeeze(0), 480) for i, chunk in enumerate(chunked_audio): - onnx_output, self.streaming_state, _ = self.torch_streaming(chunk, self.streaming_state, self.atten_lim_db) - rust_output = torch.from_numpy(self.df_tract.process(chunk.unsqueeze(0).cpu().numpy())) + onnx_output, self.streaming_state, _ = self.torch_streaming( + chunk, self.streaming_state, self.atten_lim_db + ) + rust_output = torch.from_numpy( + self.df_tract.process(chunk.unsqueeze(0).cpu().numpy()) + ) + + assert torch.allclose( + onnx_output.to(DEVICE), rust_output.to(DEVICE), atol=1e-3 + ), f"process failed - {i} iteration" + + def test_streaming_torch_with_streaming_torch_minimal(self): + """ + Compare base torchDF streaming implementation with graph optimized version + """ + self.__reset() - assert torch.allclose(onnx_output.to(DEVICE), rust_output.to(DEVICE), atol=1e-3), f'process failed - {i} iteration' + chunked_audio = torch.split(self.noisy_audio.squeeze(0), 480) + + for i, chunk in enumerate(chunked_audio): + ( + base_output, + self.streaming_state_no_stages, + _, + ) = self.torch_streaming_no_stages( + chunk, self.streaming_state_no_stages, self.atten_lim_db_no_stages + ) + ( + minimal_output, + *self.streaming_state_minimal, + ) = self.torch_streaming_minimal(chunk, *self.streaming_state_minimal) + + minimal_output = minimal_output.to(DEVICE) + base_output = base_output.to(DEVICE) + + assert torch.allclose( + minimal_output, base_output, atol=1e-2 + ), f"process failed - {i} iteration, {torch.max(torch.abs(minimal_output - base_output))}" diff --git a/torchDF/torch_df_streaming.py b/torchDF/torch_df_streaming.py index 02a8dc7ae..491d3c130 100644 --- a/torchDF/torch_df_streaming.py +++ b/torchDF/torch_df_streaming.py @@ -1,6 +1,7 @@ """ ONNX exportable classes """ + import math import torch import argparse @@ -16,26 +17,36 @@ class ExportableStreamingTorchDF(nn.Module): - def __init__(self, fft_size, hop_size, nb_bands, - enc, df_dec, erb_dec, df_order=5, lookahead=2, - conv_lookahead=2, nb_df=96, alpha=0.99, - min_db_thresh=-10.0, - max_db_erb_thresh=30.0, - max_db_df_thresh=20.0, - normalize_atten_lim=20.0, - silence_thresh=1e-7, - sr=48000, - always_apply_all_stages=False, - ): + def __init__( + self, + fft_size, + hop_size, + nb_bands, + enc, + df_dec, + erb_dec, + df_order=5, + lookahead=2, + conv_lookahead=2, + nb_df=96, + alpha=0.99, + min_db_thresh=-10.0, + max_db_erb_thresh=30.0, + max_db_df_thresh=20.0, + normalize_atten_lim=20.0, + silence_thresh=1e-7, + sr=48000, + always_apply_all_stages=False, + ): # All complex numbers are stored as floats for ONNX compatibility super().__init__() - + self.fft_size = fft_size - self.frame_size = hop_size # dimension "f" in Float[f] - self.window_size = fft_size + 
self.frame_size = hop_size # dimension "f" in Float[f] + self.window_size = fft_size self.window_size_h = fft_size // 2 - self.freq_size = fft_size // 2 + 1 # dimension "F" in Float[F] - self.wnorm = 1. / (self.window_size ** 2 / (2 * self.frame_size)) + self.freq_size = fft_size // 2 + 1 # dimension "F" in Float[F] + self.wnorm = 1.0 / (self.window_size**2 / (2 * self.frame_size)) self.df_order = df_order self.lookahead = lookahead self.always_apply_all_stages = torch.tensor(always_apply_all_stages) @@ -45,20 +56,58 @@ def __init__(self, fft_size, hop_size, nb_bands, window = torch.sin( 0.5 * torch.pi * (torch.arange(self.fft_size) + 0.5) / self.window_size_h ) - window = torch.sin(0.5 * torch.pi * window ** 2) - self.register_buffer('window', window) - + window = torch.sin(0.5 * torch.pi * window**2) + self.register_buffer("window", window) + self.nb_df = nb_df # Initializing erb features - self.erb_indices = torch.tensor([ - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 5, 7, 7, 8, - 10, 12, 13, 15, 18, 20, 24, 28, 31, 37, 42, 50, 56, 67 - ]) + self.erb_indices = torch.tensor( + [ + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 5, + 5, + 7, + 7, + 8, + 10, + 12, + 13, + 15, + 18, + 20, + 24, + 28, + 31, + 37, + 42, + 50, + 56, + 67, + ] + ) self.nb_bands = nb_bands - self.register_buffer('forward_erb_matrix', self.erb_fb(self.erb_indices, normalized=True, inverse=False)) - self.register_buffer('inverse_erb_matrix', self.erb_fb(self.erb_indices, normalized=True, inverse=True)) + self.register_buffer( + "forward_erb_matrix", + self.erb_fb(self.erb_indices, normalized=True, inverse=False), + ) + self.register_buffer( + "inverse_erb_matrix", + self.erb_fb(self.erb_indices, normalized=True, inverse=True), + ) # Model self.enc = enc @@ -76,30 +125,74 @@ def __init__(self, fft_size, hop_size, nb_bands, # RFFT # FFT operations are performed as matmuls for ONNX compatability - self.register_buffer('rfft_matrix', torch.view_as_real(torch.fft.rfft(torch.eye(self.window_size))).transpose(0, 1)) - self.register_buffer('irfft_matrix', torch.linalg.pinv(self.rfft_matrix)) + self.register_buffer( + "rfft_matrix", + torch.view_as_real(torch.fft.rfft(torch.eye(self.window_size))).transpose( + 0, 1 + ), + ) + self.register_buffer("irfft_matrix", torch.linalg.pinv(self.rfft_matrix)) # Thresholds - self.register_buffer('min_db_thresh', torch.tensor([min_db_thresh])) - self.register_buffer('max_db_erb_thresh', torch.tensor([max_db_erb_thresh])) - self.register_buffer('max_db_df_thresh', torch.tensor([max_db_df_thresh])) + self.register_buffer("min_db_thresh", torch.tensor([min_db_thresh])) + self.register_buffer("max_db_erb_thresh", torch.tensor([max_db_erb_thresh])) + self.register_buffer("max_db_df_thresh", torch.tensor([max_db_df_thresh])) self.normalize_atten_lim = torch.tensor(normalize_atten_lim) self.silence_thresh = torch.tensor(silence_thresh) - self.linspace_erb = [-60., -90.] 
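[Editor's note — illustrative sketch, not part of the patch] The rfft_matrix / irfft_matrix buffers registered just above let the streaming model run the (i)RFFT as plain matmuls, which keeps torch.fft ops out of the exported ONNX graph. A minimal check of that equivalence, assuming a 960-sample window as in this model:

    import torch

    N = 960
    rfft_matrix = torch.view_as_real(torch.fft.rfft(torch.eye(N))).transpose(0, 1)  # [N // 2 + 1, N, 2]
    x = torch.randn(N)
    via_matmul = torch.matmul(x, rfft_matrix)           # real/imag stacked in the last dim
    reference = torch.view_as_real(torch.fft.rfft(x))   # [N // 2 + 1, 2]
    assert torch.allclose(via_matmul, reference, atol=1e-3)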
+ self.linspace_erb = [-60.0, -90.0] self.linspace_df = [0.001, 0.0001] - self.erb_norm_state_shape = (self.nb_bands, ) - self.band_unit_norm_state_shape = (1, self.nb_df, 1) # [bs=1, nb_df, mean of complex value = 1] - self.analysis_mem_shape = (self.frame_size, ) - self.synthesis_mem_shape = (self.frame_size, ) - self.rolling_erb_buf_shape = (1, 1, conv_lookahead + 1, self.nb_bands) # [B, 1, conv kernel size, nb_bands] - self.rolling_feat_spec_buf_shape = (1, 2, conv_lookahead + 1, self.nb_df) # [B, 2 - complex, conv kernel size, nb_df] - self.rolling_c0_buf_shape = (1, self.enc.df_conv0_ch, self.df_order, self.nb_df) # [B, conv hidden, df_order, nb_df] - self.rolling_spec_buf_x_shape = (max(self.df_order, conv_lookahead), self.freq_size, 2) # [number of specs to save, ...] - self.rolling_spec_buf_y_shape = (self.df_order + conv_lookahead, self.freq_size, 2) # [number of specs to save, ...] - self.enc_hidden_shape = (1, 1, self.enc.emb_dim) # [n_layers=1, batch_size=1, emb_dim] - self.erb_dec_hidden_shape = (2, 1, self.erb_dec.emb_dim) # [n_layers=2, batch_size=1, emb_dim] - self.df_dec_hidden_shape = (2, 1, self.df_dec.emb_dim) # [n_layers=2, batch_size=1, emb_dim] + self.erb_norm_state_shape = (self.nb_bands,) + self.band_unit_norm_state_shape = ( + 1, + self.nb_df, + 1, + ) # [bs=1, nb_df, mean of complex value = 1] + self.analysis_mem_shape = (self.frame_size,) + self.synthesis_mem_shape = (self.frame_size,) + self.rolling_erb_buf_shape = ( + 1, + 1, + conv_lookahead + 1, + self.nb_bands, + ) # [B, 1, conv kernel size, nb_bands] + self.rolling_feat_spec_buf_shape = ( + 1, + 2, + conv_lookahead + 1, + self.nb_df, + ) # [B, 2 - complex, conv kernel size, nb_df] + self.rolling_c0_buf_shape = ( + 1, + self.enc.df_conv0_ch, + self.df_order, + self.nb_df, + ) # [B, conv hidden, df_order, nb_df] + self.rolling_spec_buf_x_shape = ( + max(self.df_order, conv_lookahead), + self.freq_size, + 2, + ) # [number of specs to save, ...] + self.rolling_spec_buf_y_shape = ( + self.df_order + conv_lookahead, + self.freq_size, + 2, + ) # [number of specs to save, ...] 
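[Editor's note — illustrative sketch, not part of the patch] Each rolling_*_buf above is a fixed-size FIFO stored as a dense tensor, so every shape in the exported graph stays static; a step drops the oldest slot and concatenates the newest frame, which is the update pattern used later in forward(). For example:

    import torch

    df_order, freq_size = 5, 481                                # values used by this model
    rolling_spec_buf_x = torch.zeros(df_order, freq_size, 2)    # oldest ... newest
    new_spec = torch.randn(1, freq_size, 2)                     # current frame spectrum
    rolling_spec_buf_x = torch.cat([rolling_spec_buf_x[1:], new_spec])  # shape stays (5, 481, 2)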
+ self.enc_hidden_shape = ( + 1, + 1, + self.enc.emb_dim, + ) # [n_layers=1, batch_size=1, emb_dim] + self.erb_dec_hidden_shape = ( + 2, + 1, + self.erb_dec.emb_dim, + ) # [n_layers=2, batch_size=1, emb_dim] + self.df_dec_hidden_shape = ( + 2, + 1, + self.df_dec.emb_dim, + ) # [n_layers=2, batch_size=1, emb_dim] # States state_shapes = [ @@ -114,16 +207,16 @@ def __init__(self, fft_size, hop_size, nb_bands, self.rolling_spec_buf_y_shape, self.enc_hidden_shape, self.erb_dec_hidden_shape, - self.df_dec_hidden_shape - ] - self.state_lens = [ - math.prod(x) for x in state_shapes + self.df_dec_hidden_shape, ] + self.state_lens = [math.prod(x) for x in state_shapes] self.states_full_len = sum(self.state_lens) # Zero buffers - self.register_buffer('zero_gains', torch.zeros(self.nb_bands)) - self.register_buffer('zero_coefs', torch.zeros(self.rolling_c0_buf_shape[2], self.nb_df, 2)) + self.register_buffer("zero_gains", torch.zeros(self.nb_bands)) + self.register_buffer( + "zero_coefs", torch.zeros(self.rolling_c0_buf_shape[2], self.nb_df, 2) + ) @staticmethod def remove_conv_block_padding(original_conv: nn.Module) -> nn.Module: @@ -132,7 +225,7 @@ def remove_conv_block_padding(original_conv: nn.Module) -> nn.Module: Parameters: original_conv: nn.Module - original convolution module - + Returns: output: nn.Module - new convolution module without paddings """ @@ -141,10 +234,12 @@ def remove_conv_block_padding(original_conv: nn.Module) -> nn.Module: for module in original_conv: if not isinstance(module, nn.ConstantPad2d): new_modules.append(module) - + return nn.Sequential(*new_modules) - - def erb_fb(self, widths: Tensor, normalized: bool = True, inverse: bool = False) -> Tensor: + + def erb_fb( + self, widths: Tensor, normalized: bool = True, inverse: bool = False + ) -> Tensor: """ Generate the erb filterbank Taken from https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/modules.py#L206 @@ -161,7 +256,9 @@ def erb_fb(self, widths: Tensor, normalized: bool = True, inverse: bool = False) n_freqs = int(torch.sum(widths)) all_freqs = torch.linspace(0, self.sr // 2, n_freqs + 1)[:-1] - b_pts = torch.cumsum(torch.cat([torch.tensor([0]), widths]), dtype=torch.int32, dim=0)[:-1] + b_pts = torch.cumsum( + torch.cat([torch.tensor([0]), widths]), dtype=torch.int32, dim=0 + )[:-1] fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0])) for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())): @@ -186,36 +283,44 @@ def mul_complex(t1, t2): Parameters: t1: Float[F, 2] - First number t2: Float[F, 2] - Second number - + Returns: output: Float[F, 2] - final multiplication of two complex numbers """ - t1_real = t1[..., 0] + t1_real = t1[..., 0] t1_imag = t1[..., 1] t2_real = t2[..., 0] t2_imag = t2[..., 1] - return torch.stack((t1_real * t2_real - t1_imag * t2_imag, t1_real * t2_imag + t1_imag * t2_real), dim=-1) - + return torch.stack( + ( + t1_real * t2_real - t1_imag * t2_imag, + t1_real * t2_imag + t1_imag * t2_real, + ), + dim=-1, + ) + def erb(self, input_data: Tensor, erb_eps: float = 1e-10) -> Tensor: """ Original code - pyDF/src/lib.rs - erb() Calculating ERB features for each frame. 
Parameters: - input_data: Float[T, F] or Float[F] - audio spectrogram + input_data: Float[T, F] or Float[F] - audio spectrogram Returns: erb_features: Float[T, ERB] or Float[ERB] - erb features for given spectrogram """ - magnitude_squared = torch.sum(input_data ** 2, dim=-1) + magnitude_squared = torch.sum(input_data**2, dim=-1) erb_features = magnitude_squared.matmul(self.forward_erb_matrix) erb_features_db = 10.0 * torch.log10(erb_features + erb_eps) return erb_features_db - + @staticmethod - def band_mean_norm_erb(xs: Tensor, erb_norm_state: Tensor, alpha: float, denominator: float = 40.0) -> Tuple[Tensor, Tensor]: + def band_mean_norm_erb( + xs: Tensor, erb_norm_state: Tensor, alpha: float, denominator: float = 40.0 + ) -> Tuple[Tensor, Tensor]: """ Original code - libDF/src/lib.rs - band_mean_norm() Normalizing ERB features. And updates the normalization state. @@ -232,11 +337,13 @@ def band_mean_norm_erb(xs: Tensor, erb_norm_state: Tensor, alpha: float, denomin """ new_erb_norm_state = torch.lerp(xs, erb_norm_state, alpha) output = (xs - new_erb_norm_state) / denominator - + return output, new_erb_norm_state - @staticmethod - def band_unit_norm(xs: Tensor, band_unit_norm_state, alpha: float) -> Tuple[Tensor, Tensor]: + @staticmethod + def band_unit_norm( + xs: Tensor, band_unit_norm_state, alpha: float + ) -> Tuple[Tensor, Tensor]: """ Original code - libDF/src/lib.rs - band_unit_norm() Normalizing Deep Filtering features. And updates the normalization state. @@ -250,13 +357,15 @@ def band_unit_norm(xs: Tensor, band_unit_norm_state, alpha: float) -> Tuple[Tens output: Float[1, DF] - normalized deep filtering features band_unit_norm_state: Float[1, DF, 1] - updated normalization state """ - xs_abs = torch.linalg.norm(xs, dim=-1, keepdim=True) # xs.abs() from complex + xs_abs = torch.linalg.norm(xs, dim=-1, keepdim=True) # xs.abs() from complex new_band_unit_norm_state = torch.lerp(xs_abs, band_unit_norm_state, alpha) output = xs / new_band_unit_norm_state.sqrt() - + return output, new_band_unit_norm_state - def frame_analysis(self, input_frame: Tensor, analysis_mem: Tensor) -> Tuple[Tensor, Tensor]: + def frame_analysis( + self, input_frame: Tensor, analysis_mem: Tensor + ) -> Tuple[Tensor, Tensor]: """ Original code - libDF/src/lib.rs - frame_analysis() Calculating spectrograme for one frame. Every frame is concated with buffer from previous frame. @@ -264,7 +373,7 @@ def frame_analysis(self, input_frame: Tensor, analysis_mem: Tensor) -> Tuple[Ten Parameters: input_frame: Float[f] - Input raw audio frame analysis_mem: Float[f] - Previous frame - + Returns: output: Float[F, 2] - Spectrogram analysis_mem: Float[f] - Saving current frame for next iteration @@ -274,10 +383,12 @@ def frame_analysis(self, input_frame: Tensor, analysis_mem: Tensor) -> Tuple[Ten buf = torch.cat([analysis_mem, input_frame]) * self.window rfft_buf = torch.matmul(buf, self.rfft_matrix) * self.wnorm - # Copy input to analysis_mem for next iteration + # Copy input to analysis_mem for next iteration return rfft_buf, input_frame - - def frame_synthesis(self, x: Tensor, synthesis_mem: Tensor) -> Tuple[Tensor, Tensor]: + + def frame_synthesis( + self, x: Tensor, synthesis_mem: Tensor + ) -> Tuple[Tensor, Tensor]: """ Original code - libDF/src/lib.rs - frame_synthesis() Inverse rfft for one frame. Every frame is summarized with buffer from previous frame. 
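[Editor's note — illustrative aside, not part of the patch] band_mean_norm_erb and band_unit_norm above keep their running statistics with torch.lerp(xs, state, alpha), which is an exponential moving average where alpha (0.99 in this model) weights the previous state:

    import torch

    alpha = 0.99
    xs = torch.tensor([10.0])     # statistic from the current frame
    state = torch.tensor([2.0])   # running state carried over from previous frames
    new_state = torch.lerp(xs, state, alpha)
    assert torch.allclose(new_state, (1 - alpha) * xs + alpha * state)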
@@ -294,10 +405,16 @@ def frame_synthesis(self, x: Tensor, synthesis_mem: Tensor) -> Tuple[Tensor, Ten # x - [F=481, 2] # self.irfft_matrix - [fft_size=481, 2, f=960] # [f=960] - x = torch.einsum('fi,fij->j', x, self.irfft_matrix) * self.fft_size * self.window + x = ( + torch.einsum("fi,fij->j", x, self.irfft_matrix) + * self.fft_size + * self.window + ) - x_first, x_second = torch.split(x, [self.frame_size, self.window_size - self.frame_size]) - output = x_first + synthesis_mem + x_first, x_second = torch.split( + x, [self.frame_size, self.window_size - self.frame_size] + ) + output = x_first + synthesis_mem return output, x_second @@ -314,9 +431,11 @@ def is_apply_gains(self, lsnr: Tensor) -> Tensor: """ if self.always_apply_all_stages: return torch.ones_like(lsnr, dtype=torch.bool) - - return torch.le(lsnr, self.max_db_erb_thresh) * torch.ge(lsnr, self.min_db_thresh) - + + return torch.le(lsnr, self.max_db_erb_thresh) * torch.ge( + lsnr, self.min_db_thresh + ) + def is_apply_gain_zeros(self, lsnr: Tensor) -> Tensor: """ Original code - libDF/src/tract.rs - is_apply_stages() @@ -330,10 +449,10 @@ def is_apply_gain_zeros(self, lsnr: Tensor) -> Tensor: """ if self.always_apply_all_stages: return torch.zeros_like(lsnr, dtype=torch.bool) - + # Only noise detected, just apply a zero mask return torch.ge(self.min_db_thresh, lsnr) - + def is_apply_df(self, lsnr: Tensor) -> Tensor: """ Original code - libDF/src/tract.rs - is_apply_stages() @@ -347,8 +466,10 @@ def is_apply_df(self, lsnr: Tensor) -> Tensor: """ if self.always_apply_all_stages: return torch.ones_like(lsnr, dtype=torch.bool) - - return torch.le(lsnr, self.max_db_df_thresh) * torch.ge(lsnr, self.min_db_thresh) + + return torch.le(lsnr, self.max_db_df_thresh) * torch.ge( + lsnr, self.min_db_thresh + ) def apply_mask(self, spec: Tensor, gains: Tensor) -> Tensor: """ @@ -365,10 +486,12 @@ def apply_mask(self, spec: Tensor, gains: Tensor) -> Tensor: """ gains = gains.matmul(self.inverse_erb_matrix) spec = spec * gains.unsqueeze(-1) - + return spec - - def deep_filter(self, gain_spec: Tensor, coefs: Tensor, rolling_spec_buf_x: Tensor) -> Tensor: + + def deep_filter( + self, gain_spec: Tensor, coefs: Tensor, rolling_spec_buf_x: Tensor + ) -> Tensor: """ Original code - libDF/src/tract.rs - df() @@ -379,16 +502,31 @@ def deep_filter(self, gain_spec: Tensor, coefs: Tensor, rolling_spec_buf_x: Tens gain_spec: Float[F, 2] - spectrogram after ERB gains applied coefs: Float[DF, BUF, 2] - coefficients for deep filtering from df decoder rolling_spec_buf_x: Float[buffer_size, F, 2] - spectrograms from past / future - + Returns: gain_spec: Float[F, 2] - spectrogram after deep filtering """ - stacked_input_specs = rolling_spec_buf_x[:, :self.nb_df] + stacked_input_specs = rolling_spec_buf_x[:, : self.nb_df] mult = self.mul_complex(stacked_input_specs, coefs) - gain_spec[:self.nb_df] = torch.sum(mult, dim=0) + gain_spec[: self.nb_df] = torch.sum(mult, dim=0) return gain_spec - - def unpack_states(self, states: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + + def unpack_states( + self, states: Tensor + ) -> Tuple[ + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + ]: splitted_states = torch.split(states, self.state_lens) erb_norm_state = splitted_states[0].view(self.erb_norm_state_shape) @@ -396,7 +534,9 @@ def unpack_states(self, states: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, analysis_mem = 
splitted_states[2].view(self.analysis_mem_shape) synthesis_mem = splitted_states[3].view(self.synthesis_mem_shape) rolling_erb_buf = splitted_states[4].view(self.rolling_erb_buf_shape) - rolling_feat_spec_buf = splitted_states[5].view(self.rolling_feat_spec_buf_shape) + rolling_feat_spec_buf = splitted_states[5].view( + self.rolling_feat_spec_buf_shape + ) rolling_c0_buf = splitted_states[6].view(self.rolling_c0_buf_shape) rolling_spec_buf_x = splitted_states[7].view(self.rolling_spec_buf_x_shape) rolling_spec_buf_y = splitted_states[8].view(self.rolling_spec_buf_y_shape) @@ -404,40 +544,57 @@ def unpack_states(self, states: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, erb_dec_hidden = splitted_states[10].view(self.erb_dec_hidden_shape) df_dec_hidden = splitted_states[11].view(self.df_dec_hidden_shape) - new_erb_norm_state = torch.linspace( - self.linspace_erb[0], self.linspace_erb[1], self.nb_bands, device=erb_norm_state.device - ).view(self.erb_norm_state_shape).to(torch.float32) # float() to fix export issue - new_band_unit_norm_state = torch.linspace( - self.linspace_df[0], self.linspace_df[1], self.nb_df, device=band_unit_norm_state.device - ).view(self.band_unit_norm_state_shape).to(torch.float32) # float() to fix export issue + new_erb_norm_state = ( + torch.linspace( + self.linspace_erb[0], + self.linspace_erb[1], + self.nb_bands, + device=erb_norm_state.device, + ) + .view(self.erb_norm_state_shape) + .to(torch.float32) + ) # float() to fix export issue + new_band_unit_norm_state = ( + torch.linspace( + self.linspace_df[0], + self.linspace_df[1], + self.nb_df, + device=band_unit_norm_state.device, + ) + .view(self.band_unit_norm_state_shape) + .to(torch.float32) + ) # float() to fix export issue erb_norm_state = torch.where( torch.tensor(torch.nonzero(erb_norm_state).shape[0] == 0), new_erb_norm_state, - erb_norm_state + erb_norm_state, ) - + band_unit_norm_state = torch.where( torch.tensor(torch.nonzero(band_unit_norm_state).shape[0] == 0), new_band_unit_norm_state, - band_unit_norm_state + band_unit_norm_state, ) return ( - erb_norm_state, band_unit_norm_state, - analysis_mem, synthesis_mem, - rolling_erb_buf, rolling_feat_spec_buf, rolling_c0_buf, - rolling_spec_buf_x, rolling_spec_buf_y, - enc_hidden, erb_dec_hidden, df_dec_hidden + erb_norm_state, + band_unit_norm_state, + analysis_mem, + synthesis_mem, + rolling_erb_buf, + rolling_feat_spec_buf, + rolling_c0_buf, + rolling_spec_buf_x, + rolling_spec_buf_y, + enc_hidden, + erb_dec_hidden, + df_dec_hidden, ) - - def forward(self, - input_frame: Tensor, - states: Tensor, - atten_lim_db: Tensor - ) -> Tuple[ - Tensor, Tensor, Tensor - ]: + + def forward( + self, input_frame: Tensor, states: Tensor, atten_lim_db: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ Enhancing input audio frame @@ -452,49 +609,70 @@ def forward(self, lsnr: Float[1] - Estimated lsnr of input frame """ - assert input_frame.ndim == 1, 'only bs=1 and t=frame_size supported' - assert input_frame.shape[0] == self.frame_size, 'input_frame must be bs=1 and t=frame_size' + assert input_frame.ndim == 1, "only bs=1 and t=frame_size supported" + assert ( + input_frame.shape[0] == self.frame_size + ), "input_frame must be bs=1 and t=frame_size" ( - erb_norm_state, band_unit_norm_state, - analysis_mem, synthesis_mem, - rolling_erb_buf, rolling_feat_spec_buf, rolling_c0_buf, - rolling_spec_buf_x, rolling_spec_buf_y, - enc_hidden, erb_dec_hidden, df_dec_hidden + erb_norm_state, + band_unit_norm_state, + analysis_mem, + synthesis_mem, + rolling_erb_buf, + 
rolling_feat_spec_buf, + rolling_c0_buf, + rolling_spec_buf_x, + rolling_spec_buf_y, + enc_hidden, + erb_dec_hidden, + df_dec_hidden, ) = self.unpack_states(states) # If input_frame is silent, then do nothing and return zeros - rms_non_silence_condition = (input_frame ** 2).sum() / self.frame_size >= self.silence_thresh - rms_non_silence_condition = torch.logical_or(rms_non_silence_condition, self.always_apply_all_stages) + rms_non_silence_condition = ( + input_frame**2 + ).sum() / self.frame_size >= self.silence_thresh + rms_non_silence_condition = torch.logical_or( + rms_non_silence_condition, self.always_apply_all_stages + ) spectrogram, new_analysis_mem = self.frame_analysis(input_frame, analysis_mem) - spectrogram = spectrogram.unsqueeze(0) # [1, freq_size, 2] reshape needed for easier stacking buffers - new_rolling_spec_buf_x = torch.cat([rolling_spec_buf_x[1:, ...], spectrogram]) # [n_frames=5, 481, 2] + spectrogram = spectrogram.unsqueeze( + 0 + ) # [1, freq_size, 2] reshape needed for easier stacking buffers + new_rolling_spec_buf_x = torch.cat( + [rolling_spec_buf_x[1:, ...], spectrogram] + ) # [n_frames=5, 481, 2] # rolling_spec_buf_y - [n_frames=7, 481, 2] n_frames=7 for compatability with original code, but in code we use only one frame new_rolling_spec_buf_y = torch.cat([rolling_spec_buf_y[1:, ...], spectrogram]) erb_feat, new_erb_norm_state = self.band_mean_norm_erb( self.erb(spectrogram).squeeze(0), erb_norm_state, alpha=self.alpha - ) # [ERB] + ) # [ERB] spec_feat, new_band_unit_norm_state = self.band_unit_norm( - spectrogram[:, :self.nb_df], band_unit_norm_state, alpha=self.alpha - ) # [1, DF, 2] + spectrogram[:, : self.nb_df], band_unit_norm_state, alpha=self.alpha + ) # [1, DF, 2] - erb_feat = erb_feat[None, None, None, ...] # [b=1, conv_input_dim=1, t=1, n_erb=32] - spec_feat = spec_feat[None, ...].permute(0, 3, 1, 2) # [bs=1, conv_input_dim=2, t=1, df_order=96] + erb_feat = erb_feat[ + None, None, None, ... 
+ ] # [b=1, conv_input_dim=1, t=1, n_erb=32] + spec_feat = spec_feat[None, ...].permute( + 0, 3, 1, 2 + ) # [bs=1, conv_input_dim=2, t=1, df_order=96] # (1, 1, T, self.nb_bands) new_rolling_erb_buf = torch.cat([rolling_erb_buf[:, :, 1:, :], erb_feat], dim=2) # (1, 2, T, self.nb_df) - new_rolling_feat_spec_buf = torch.cat([rolling_feat_spec_buf[:, :, 1:, :], spec_feat], dim=2) + new_rolling_feat_spec_buf = torch.cat( + [rolling_feat_spec_buf[:, :, 1:, :], spec_feat], dim=2 + ) e0, e1, e2, e3, emb, c0, lsnr, new_enc_hidden = self.enc( - new_rolling_erb_buf, - new_rolling_feat_spec_buf, - enc_hidden + new_rolling_erb_buf, new_rolling_feat_spec_buf, enc_hidden ) - lsnr = lsnr.flatten() # [b=1, t=1, 1] -> 1 + lsnr = lsnr.flatten() # [b=1, t=1, 1] -> 1 apply_gains = self.is_apply_gains(lsnr) apply_df = self.is_apply_df(lsnr) @@ -502,24 +680,26 @@ def forward(self, # erb_dec # [BS=1, 1, T=1, ERB] - new_gains, new_erb_dec_hidden = self.erb_dec(emb, e3, e2, e1, e0, erb_dec_hidden) + new_gains, new_erb_dec_hidden = self.erb_dec( + emb, e3, e2, e1, e0, erb_dec_hidden + ) gains = torch.where(apply_gains, new_gains.view(self.nb_bands), self.zero_gains) - new_erb_dec_hidden = torch.where(apply_gains, new_erb_dec_hidden, erb_dec_hidden) + new_erb_dec_hidden = torch.where( + apply_gains, new_erb_dec_hidden, erb_dec_hidden + ) # df_dec new_rolling_c0_buf = torch.cat([rolling_c0_buf[:, :, 1:, :], c0], dim=2) # new_coefs - [BS=1, T=1, F, O*2] new_coefs, new_df_dec_hidden = self.df_dec( - emb, - new_rolling_c0_buf, - df_dec_hidden + emb, new_rolling_c0_buf, df_dec_hidden ) new_rolling_c0_buf = torch.where(apply_df, new_rolling_c0_buf, rolling_c0_buf) new_df_dec_hidden = torch.where(apply_df, new_df_dec_hidden, df_dec_hidden) coefs = torch.where( - apply_df, - new_coefs.view(self.nb_df, -1, 2).permute(1, 0, 2), - self.zero_coefs + apply_df, + new_coefs.view(self.nb_df, -1, 2).permute(1, 0, 2), + self.zero_coefs, ) # Applying features @@ -527,62 +707,100 @@ def forward(self, current_spec = torch.where( torch.logical_or(apply_gains, apply_gain_zeros), self.apply_mask(current_spec.clone(), gains), - current_spec + current_spec, ) current_spec = torch.where( - apply_df, + apply_df, self.deep_filter(current_spec.clone(), coefs, new_rolling_spec_buf_x), - current_spec + current_spec, ) # Mixing some noisy channel # taken from https://github.com/Rikorose/DeepFilterNet/blob/59789e135cb5ed0eb86bb50e8f1be09f60859d5c/DeepFilterNet/df/enhance.py#L237 if torch.abs(atten_lim_db) > 0: - spec_noisy = rolling_spec_buf_x[max(self.lookahead, self.df_order) - self.lookahead - 1] + spec_noisy = rolling_spec_buf_x[ + max(self.lookahead, self.df_order) - self.lookahead - 1 + ] lim = 10 ** (-torch.abs(atten_lim_db) / self.normalize_atten_lim) current_spec = torch.lerp(current_spec, spec_noisy, lim) - enhanced_audio_frame, new_synthesis_mem = self.frame_synthesis(current_spec, synthesis_mem) + enhanced_audio_frame, new_synthesis_mem = self.frame_synthesis( + current_spec, synthesis_mem + ) new_states = [ - new_erb_norm_state, new_band_unit_norm_state, - new_analysis_mem, new_synthesis_mem, - new_rolling_erb_buf, new_rolling_feat_spec_buf, new_rolling_c0_buf, - new_rolling_spec_buf_x, new_rolling_spec_buf_y, - new_enc_hidden, new_erb_dec_hidden, new_df_dec_hidden + new_erb_norm_state, + new_band_unit_norm_state, + new_analysis_mem, + new_synthesis_mem, + new_rolling_erb_buf, + new_rolling_feat_spec_buf, + new_rolling_c0_buf, + new_rolling_spec_buf_x, + new_rolling_spec_buf_y, + new_enc_hidden, + new_erb_dec_hidden, + new_df_dec_hidden, 
] new_states = torch.cat([x.flatten() for x in new_states]) # RMS conditioning for better ONNX graph - enhanced_audio_frame = torch.where(rms_non_silence_condition, enhanced_audio_frame, torch.zeros_like(enhanced_audio_frame)) + enhanced_audio_frame = torch.where( + rms_non_silence_condition, + enhanced_audio_frame, + torch.zeros_like(enhanced_audio_frame), + ) new_states = torch.where(rms_non_silence_condition, new_states, states) return enhanced_audio_frame, new_states, lsnr + class TorchDFPipeline(nn.Module): def __init__( - self, nb_bands=32, hop_size=480, fft_size=960, - df_order=5, conv_lookahead=2, nb_df=96, model_base_dir='DeepFilterNet3', - atten_lim_db=0.0, always_apply_all_stages=False, device='cpu' - ): + self, + nb_bands=32, + hop_size=480, + fft_size=960, + df_order=5, + conv_lookahead=2, + nb_df=96, + model_base_dir="DeepFilterNet3", + atten_lim_db=0.0, + always_apply_all_stages=False, + device="cpu", + ): super().__init__() self.hop_size = hop_size self.fft_size = fft_size - model, state, _ = init_df(config_allow_defaults=True, model_base_dir=model_base_dir) + model, state, _ = init_df( + config_allow_defaults=True, model_base_dir=model_base_dir + ) model.eval() self.sample_rate = state.sr() self.torch_streaming_model = ExportableStreamingTorchDF( - nb_bands=nb_bands, hop_size=hop_size, fft_size=fft_size, - enc=model.enc, df_dec=model.df_dec, erb_dec=model.erb_dec, df_order=df_order, + nb_bands=nb_bands, + hop_size=hop_size, + fft_size=fft_size, + enc=model.enc, + df_dec=model.df_dec, + erb_dec=model.erb_dec, + df_order=df_order, always_apply_all_stages=always_apply_all_stages, - conv_lookahead=conv_lookahead, nb_df=nb_df, sr=self.sample_rate + conv_lookahead=conv_lookahead, + nb_df=nb_df, + sr=self.sample_rate, ) self.torch_streaming_model = self.torch_streaming_model.to(device) - self.states = torch.zeros(self.torch_streaming_model.states_full_len, device=device) - self.atten_lim_db = torch.tensor(atten_lim_db, device=device) + self.states = [ + torch.zeros(self.torch_streaming_model.states_full_len, device=device), + self.atten_lim_db, + ] + + self.input_names = ["input_frame", "states", "atten_lim_db"] + self.output_names = ["enhanced_audio_frame", "new_states", "lsnr"] def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor: """ @@ -595,36 +813,43 @@ def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor: Returns: enhanced_audio: Float[1, t] - Enhanced input audio """ - assert input_audio.shape[0] == 1, f'Only mono supported! Got wrong shape! {input_audio.shape}' - assert sample_rate == self.sample_rate, f'Only {self.sample_rate} supported! Got wrong sample rate! {sample_rate}' + assert ( + input_audio.shape[0] == 1 + ), f"Only mono supported! Got wrong shape! {input_audio.shape}" + assert ( + sample_rate == self.sample_rate + ), f"Only {self.sample_rate} supported! Got wrong sample rate! 
{sample_rate}" input_audio = input_audio.squeeze(0) orig_len = input_audio.shape[0] # padding taken from # https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/enhance.py#L229 - hop_size_divisible_padding_size = (self.hop_size - orig_len % self.hop_size) % self.hop_size + hop_size_divisible_padding_size = ( + self.hop_size - orig_len % self.hop_size + ) % self.hop_size orig_len += hop_size_divisible_padding_size - input_audio = F.pad(input_audio, (0, self.fft_size + hop_size_divisible_padding_size)) - + input_audio = F.pad( + input_audio, (0, self.fft_size + hop_size_divisible_padding_size) + ) + chunked_audio = torch.split(input_audio, self.hop_size) output_frames = [] for input_frame in chunked_audio: - ( - enhanced_audio_frame, self.states, lsnr - ) = self.torch_streaming_model( - input_frame, - self.states, - self.atten_lim_db + enhanced_audio_frame, states, _ = self.torch_streaming_model( + input_frame, *self.states ) - + self.states = [states, self.atten_lim_db] + output_frames.append(enhanced_audio_frame) - enhanced_audio = torch.cat(output_frames).unsqueeze(0) # [t] -> [1, t] typical mono format + enhanced_audio = torch.cat(output_frames).unsqueeze( + 0 + ) # [t] -> [1, t] typical mono format - # taken from + # taken from # https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/enhance.py#L248 d = self.fft_size - self.hop_size enhanced_audio = enhanced_audio[:, d : orig_len + d] @@ -636,36 +861,46 @@ def main(args): torch.set_num_threads(1) torch.set_num_interop_threads(1) - torch_df = TorchDFPipeline(device=args.device, always_apply_all_stages=args.always_apply_all_stages) + torch_df = TorchDFPipeline( + device=args.device, always_apply_all_stages=args.always_apply_all_stages + ) # torchaudio normalize=True, fp32 return noisy_audio, sr = torchaudio.load(args.audio_path, channels_first=True) - noisy_audio = noisy_audio.mean(dim=0).unsqueeze(0).to(args.device) # stereo to mono + if sr != torch_df.sample_rate: + resample = torchaudio.transforms.Resample( + orig_freq=sr, new_freq=torch_df.sample_rate + ) + noisy_audio = resample(noisy_audio) + sr = torch_df.sample_rate + noisy_audio = noisy_audio.mean(dim=0).unsqueeze(0).to(args.device) # stereo to mono enhanced_audio = torch_df(noisy_audio, sr).detach().cpu() torchaudio.save( - args.output_path, enhanced_audio, sr, - encoding="PCM_S", bits_per_sample=16 + args.output_path, enhanced_audio, sr, encoding="PCM_S", bits_per_sample=16 ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='Denoising one audio with DF3 model using torch only' + description="Denoising one audio with DF3 model using torch only" ) parser.add_argument( - '--audio-path', type=str, required=True, help='Path to audio file' + "--audio-path", type=str, required=True, help="Path to audio file" ) parser.add_argument( - '--output-path', type=str, required=True, help='Path to output file' + "--output-path", type=str, required=True, help="Path to output file" ) parser.add_argument( - '--device', type=str, default='cpu', choices=['cuda', 'cpu'], help='Device to run on' + "--device", + type=str, + default="cpu", + choices=["cuda", "cpu"], + help="Device to run on", ) parser.add_argument( - '--always-apply-all-stages', action='store_true' + "--always-apply-all-stages", action="store_true", help="Apply all stages or not" ) - main(parser.parse_args()) - \ No newline at end of file + main(parser.parse_args()) diff --git 
a/torchDF/torch_df_streaming_minimal.py b/torchDF/torch_df_streaming_minimal.py new file mode 100644 index 000000000..219184b4a --- /dev/null +++ b/torchDF/torch_df_streaming_minimal.py @@ -0,0 +1,1268 @@ +""" +ONNX exportable classes +""" + +import math +import torch +import argparse +import torchaudio + +from torch.nn import functional as F + +from torch import nn +from torch import Tensor +from typing import Tuple, List + +from df import init_df + +from functools import partial +from typing import Callable, List, Optional, Tuple + +import torch +from torch import Tensor, nn +import torch.nn.functional as F + +from df.config import Csv, DfParams, config +from df.modules import Conv2dNormAct, ConvTranspose2dNormAct + +from typing_extensions import Final +from torch.nn.parameter import Parameter +from torch.nn import init + + +from torch.autograd import Function + + +class OnnxComplexMul(Function): + """Auto-grad function to mimic irfft for ONNX exporting""" + + @staticmethod + def forward(ctx, input_0: torch.Tensor, input_1: torch.Tensor) -> torch.Tensor: + return torch.view_as_real( + torch.view_as_complex(input_0) * torch.view_as_complex(input_1) + ) + + @staticmethod + def symbolic( + g: torch.Graph, input_0: torch.Value, input_1: torch.Value + ) -> torch.Value: + """Symbolic representation for onnx graph""" + return g.op("ai.onnx.contrib::ComplexMul", input_0, input_1) + + +class SqueezedGRU_S(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, + linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.linear_in = nn.Sequential( + GroupedLinearEinsum( + input_size, hidden_size, linear_groups, linear_act_layer() + ), + ) + self.gru = nn.GRU( + hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first + ) + self.gru_skip = gru_skip_op() if gru_skip_op is not None else None + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinearEinsum( + hidden_size, output_size, linear_groups, linear_act_layer() + ), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, input: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]: + x = self.linear_in(input) + x, h = self.gru(x, h) + x = self.linear_out(x) + if self.gru_skip is not None: + x = x + self.gru_skip(input) + return x, h + + +class GroupedLinearEinsum(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + groups: Final[int] + + def __init__( + self, + input_size: int, + hidden_size: int, + groups: int = 1, + activation=nn.Identity(), + ): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert ( + input_size % groups == 0 + ), f"Input size {input_size} not divisible by {groups}" + assert ( + hidden_size % groups == 0 + ), f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), + requires_grad=True, + ), + ) + self.reset_parameters() + self.activation = activation + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + 
x = x.reshape(self.groups, 1, self.ws) + x = torch.matmul(x, self.weight) + x = self.activation(x) + return x.view(1, 1, -1) + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class ModelParams(DfParams): + section = "deepfilternet" + + def __init__(self): + super().__init__() + self.conv_lookahead: int = config( + "CONV_LOOKAHEAD", cast=int, default=0, section=self.section + ) + self.conv_ch: int = config( + "CONV_CH", cast=int, default=16, section=self.section + ) + self.conv_depthwise: bool = config( + "CONV_DEPTHWISE", cast=bool, default=True, section=self.section + ) + self.convt_depthwise: bool = config( + "CONVT_DEPTHWISE", cast=bool, default=True, section=self.section + ) + self.conv_kernel: List[int] = config( + "CONV_KERNEL", + cast=Csv(int), + default=(1, 3), + section=self.section, # type: ignore + ) + self.convt_kernel: List[int] = config( + "CONVT_KERNEL", + cast=Csv(int), + default=(1, 3), + section=self.section, # type: ignore + ) + self.conv_kernel_inp: List[int] = config( + "CONV_KERNEL_INP", + cast=Csv(int), + default=(3, 3), + section=self.section, # type: ignore + ) + self.emb_hidden_dim: int = config( + "EMB_HIDDEN_DIM", cast=int, default=256, section=self.section + ) + self.emb_num_layers: int = config( + "EMB_NUM_LAYERS", cast=int, default=2, section=self.section + ) + self.emb_gru_skip_enc: str = config( + "EMB_GRU_SKIP_ENC", default="none", section=self.section + ) + self.emb_gru_skip: str = config( + "EMB_GRU_SKIP", default="none", section=self.section + ) + self.df_hidden_dim: int = config( + "DF_HIDDEN_DIM", cast=int, default=256, section=self.section + ) + self.df_gru_skip: str = config( + "DF_GRU_SKIP", default="none", section=self.section + ) + self.df_pathway_kernel_size_t: int = config( + "DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section + ) + self.enc_concat: bool = config( + "ENC_CONCAT", cast=bool, default=False, section=self.section + ) + self.df_num_layers: int = config( + "DF_NUM_LAYERS", cast=int, default=3, section=self.section + ) + self.df_n_iter: int = config( + "DF_N_ITER", cast=int, default=1, section=self.section + ) + self.lin_groups: int = config( + "LINEAR_GROUPS", cast=int, default=1, section=self.section + ) + self.enc_lin_groups: int = config( + "ENC_LINEAR_GROUPS", cast=int, default=16, section=self.section + ) + self.mask_pf: bool = config( + "MASK_PF", cast=bool, default=False, section=self.section + ) + self.lsnr_dropout: bool = config( + "LSNR_DROPOUT", cast=bool, default=False, section=self.section + ) + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self): + super().__init__() + p = ModelParams() + assert p.nb_erb % 4 == 0, "erb_bins should be divisible by 4" + + self.erb_conv0 = Conv2dNormAct( + 1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True + ) + self.conv_buffer_size = p.conv_kernel_inp[0] - 1 + self.conv_ch = p.conv_ch + + conv_layer = partial( + Conv2dNormAct, + in_ch=p.conv_ch, + out_ch=p.conv_ch, + kernel_size=p.conv_kernel, + bias=False, + separable=True, + ) + self.erb_conv1 = conv_layer(fstride=2) + self.erb_conv2 = conv_layer(fstride=2) + self.erb_conv3 = conv_layer(fstride=1) + self.df_conv0_ch = p.conv_ch + self.df_conv0 = Conv2dNormAct( + 2, + self.df_conv0_ch, + kernel_size=p.conv_kernel_inp, + 
bias=False, + separable=True, + ) + self.df_conv1 = conv_layer(fstride=2) + self.erb_bins = p.nb_erb + self.emb_in_dim = p.conv_ch * p.nb_erb // 4 + self.emb_dim = p.emb_hidden_dim + self.emb_out_dim = p.conv_ch * p.nb_erb // 4 + df_fc_emb = GroupedLinearEinsum( + p.conv_ch * p.nb_df // 2, + self.emb_in_dim, + groups=p.enc_lin_groups, + activation=nn.ReLU(inplace=True), + ) + self.df_fc_emb = nn.Sequential(df_fc_emb) + if p.enc_concat: + self.emb_in_dim *= 2 + self.combine = Concat() + else: + self.combine = Add() + self.emb_n_layers = p.emb_num_layers + if p.emb_gru_skip_enc == "none": + skip_op = None + elif p.emb_gru_skip_enc == "identity": + assert self.emb_in_dim == self.emb_out_dim, "Dimensions do not match" + skip_op = partial(nn.Identity) + elif p.emb_gru_skip_enc == "groupedlinear": + skip_op = partial( + GroupedLinearEinsum, + input_size=self.emb_out_dim, + hidden_size=self.emb_out_dim, + groups=p.lin_groups, + ) + else: + raise NotImplementedError() + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_dim, + output_size=self.emb_out_dim, + num_layers=1, + batch_first=False, + gru_skip_op=skip_op, + linear_groups=p.lin_groups, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + self.lsnr_fc = nn.Sequential(nn.Linear(self.emb_out_dim, 1), nn.Sigmoid()) + self.lsnr_scale = p.lsnr_max - p.lsnr_min + self.lsnr_offset = p.lsnr_min + + def forward( + self, feat_erb: Tensor, feat_spec: Tensor, hidden: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands. + # erb: [B, 1, T, Fe] + # spec: [B, 2, T, Fc] + # b, _, t, _ = feat_erb.shape + + # feat erb branch + e0 = self.erb_conv0(feat_erb) # [B, C, T, F] + e1 = self.erb_conv1(e0) # [B, C*2, T, F/2] + e2 = self.erb_conv2(e1) # [B, C*4, T, F/4] + e3 = self.erb_conv3(e2) # [B, C*4, T, F/4] + emb = e3.permute(0, 2, 3, 1).flatten(2, 3) # [B, T, C * F] + + # feat spec branch + c0 = self.df_conv0(feat_spec) # [B, C, T, Fc] + c1 = self.df_conv1(c0) # [B, C*2, T, Fc/2] + cemb = c1.permute(0, 2, 3, 1) # [B, T, -1] + cemb = self.df_fc_emb(cemb) # [T, B, C * F/4] + + # combine + emb = self.combine(emb, cemb) + emb, hidden = self.emb_gru(emb, hidden) # [B, T, -1] + return e0, e1, e2, e3, emb, c0, hidden + + +class ErbDecoder(nn.Module): + def __init__(self): + super().__init__() + p = ModelParams() + assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8" + + self.emb_in_dim = p.conv_ch * p.nb_erb // 4 + self.emb_dim = p.emb_hidden_dim + self.emb_out_dim = p.conv_ch * p.nb_erb // 4 + + if p.emb_gru_skip == "none": + skip_op = None + elif p.emb_gru_skip == "identity": + assert self.emb_in_dim == self.emb_out_dim, "Dimensions do not match" + skip_op = partial(nn.Identity) + elif p.emb_gru_skip == "groupedlinear": + skip_op = partial( + GroupedLinearEinsum, + input_size=self.emb_in_dim, + hidden_size=self.emb_out_dim, + groups=p.lin_groups, + ) + else: + raise NotImplementedError() + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_dim, + output_size=self.emb_out_dim, + num_layers=p.emb_num_layers - 1, + batch_first=False, + gru_skip_op=skip_op, + linear_groups=p.lin_groups, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + tconv_layer = partial( + ConvTranspose2dNormAct, + kernel_size=p.convt_kernel, + bias=False, + separable=True, + ) + conv_layer = partial( + Conv2dNormAct, + bias=False, + separable=True, + ) + # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions + self.conv3p = 
conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt3 = conv_layer(p.conv_ch, p.conv_ch, kernel_size=p.conv_kernel) + self.conv2p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt2 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2) + self.conv1p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt1 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2) + self.conv0p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.conv0_out = conv_layer( + p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid + ) + + def forward( + self, + emb: Tensor, + e3: Tensor, + e2: Tensor, + e1: Tensor, + e0: Tensor, + hidden: Tensor, + ) -> Tuple[Tensor, Tensor]: + # Estimates erb mask + b, _, t, f8 = e3.shape + emb, hidden = self.emb_gru(emb, hidden) + emb = emb.view(1, 1, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8] + e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4] + e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2] + e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F] + m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F] + return m, hidden + + +class DfDecoder(nn.Module): + def __init__(self): + super().__init__() + p = ModelParams() + layer_width = p.conv_ch + + self.emb_in_dim = p.conv_ch * p.nb_erb // 4 + self.emb_dim = p.df_hidden_dim + + self.df_n_hidden = p.df_hidden_dim + self.df_n_layers = p.df_num_layers + self.df_order = p.df_order + self.df_bins = p.nb_df + self.df_out_ch = p.df_order * 2 + + conv_layer = partial(Conv2dNormAct, separable=True, bias=False) + kt = p.df_pathway_kernel_size_t + self.conv_buffer_size = kt - 1 + self.df_convp = conv_layer( + layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1) + ) + + self.df_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_dim, + num_layers=self.df_n_layers, + batch_first=True, + gru_skip_op=None, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + p.df_gru_skip = p.df_gru_skip.lower() + assert p.df_gru_skip in ("none", "identity", "groupedlinear") + self.df_skip: Optional[nn.Module] + if p.df_gru_skip == "none": + self.df_skip = None + elif p.df_gru_skip == "identity": + assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match" + self.df_skip = nn.Identity() + elif p.df_gru_skip == "groupedlinear": + self.df_skip = GroupedLinearEinsum( + self.emb_in_dim, self.emb_dim, groups=p.lin_groups + ) + else: + raise NotImplementedError() + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + df_out = GroupedLinearEinsum( + self.df_n_hidden, out_dim, groups=p.lin_groups, activation=nn.Tanh() + ) + self.df_out = nn.Sequential(df_out) + self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid()) + + def forward(self, emb: Tensor, c0: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]: + b, t, _ = emb.shape + c, hidden = self.df_gru(emb, hidden) # [B, T, H], H: df_n_hidden + if self.df_skip is not None: + c = c + self.df_skip(emb) + c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last + c = self.df_out(c) # [B, T, F*O*2], O: df_order + c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2] + return c, hidden + + +class ExportableStreamingMinimalTorchDF(nn.Module): + def __init__( + self, + fft_size, + hop_size, + nb_bands, + enc, + df_dec, + erb_dec, + df_order=5, + lookahead=2, + conv_lookahead=2, + nb_df=96, + alpha=0.99, + min_db_thresh=-10.0, + max_db_erb_thresh=30.0, + max_db_df_thresh=20.0, + normalize_atten_lim=20.0, + silence_thresh=1e-7, + sr=48000, + ): + # All complex numbers are stored as floats for 
ONNX compatibility + super().__init__() + + self.fft_size = torch.tensor(fft_size, dtype=torch.float32) + self.frame_size = hop_size # dimension "f" in Float[f] + self.window_size = fft_size + self.window_size_h = fft_size // 2 + self.freq_size = fft_size // 2 + 1 # dimension "F" in Float[F] + self.wnorm = 1.0 / (self.window_size**2 / (2 * self.frame_size)) + self.df_order = df_order + self.lookahead = lookahead + self.sr = sr + + # Initialize the vorbis window: sin(pi/2*sin^2(pi*n/N)) + window = torch.sin( + 0.5 * torch.pi * (torch.arange(self.fft_size) + 0.5) / self.window_size_h + ) + window = torch.sin( + 0.5 * torch.pi * window**2, + ) + self.register_buffer("window", window) + + self.nb_df = nb_df + + # Initializing erb features + self.erb_indices = torch.tensor( + [ + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 5, + 5, + 7, + 7, + 8, + 10, + 12, + 13, + 15, + 18, + 20, + 24, + 28, + 31, + 37, + 42, + 50, + 56, + 67, + ] + ) + self.nb_bands = nb_bands + + self.register_buffer( + "forward_erb_matrix", + self.erb_fb(self.erb_indices, normalized=True, inverse=False), + ) + self.register_buffer( + "inverse_erb_matrix", + self.erb_fb(self.erb_indices, normalized=True, inverse=True), + ) + + ### Model + self.enc = Encoder() + self.enc.load_state_dict(enc.state_dict()) + self.enc.eval() + + # Instead of padding we put tensor with buffers into encoder + # I didn't checked receptived fields of convolution, but equallity tests are working + self.enc.erb_conv0 = self.remove_conv_block_padding(self.enc.erb_conv0) + self.enc.df_conv0 = self.remove_conv_block_padding(self.enc.df_conv0) + + # Instead of padding we put tensor with buffers into df_decoder + self.df_dec = DfDecoder() + self.df_dec.load_state_dict(df_dec.state_dict()) + self.df_dec.eval() + self.df_dec.df_convp = self.remove_conv_block_padding(self.df_dec.df_convp) + + self.erb_dec = ErbDecoder() + self.erb_dec.load_state_dict(erb_dec.state_dict()) + self.erb_dec.eval() + ### End Model + + self.alpha = alpha + + # RFFT + # FFT operations are performed as matmuls for ONNX compatability + self.register_buffer( + "rfft_matrix", + torch.view_as_real(torch.fft.rfft(torch.eye(self.window_size))).transpose( + 0, 1 + ), + ) + self.register_buffer("irfft_matrix", torch.linalg.pinv(self.rfft_matrix)) + + # Thresholds + self.register_buffer("min_db_thresh", torch.tensor([min_db_thresh])) + self.register_buffer("max_db_erb_thresh", torch.tensor([max_db_erb_thresh])) + self.register_buffer("max_db_df_thresh", torch.tensor([max_db_df_thresh])) + self.normalize_atten_lim = torch.tensor(normalize_atten_lim) + self.silence_thresh = torch.tensor(silence_thresh) + self.linspace_erb = [-60.0, -90.0] + self.linspace_df = [0.001, 0.0001] + + self.erb_norm_state_shape = (self.nb_bands,) + self.band_unit_norm_state_shape = ( + 1, + self.nb_df, + 1, + ) # [bs=1, nb_df, mean of complex value = 1] + self.analysis_mem_shape = (self.frame_size,) + self.synthesis_mem_shape = (self.frame_size,) + self.rolling_erb_buf_shape = ( + 1, + 1, + conv_lookahead + 1, + self.nb_bands, + ) # [B, 1, conv kernel size, nb_bands] + self.rolling_feat_spec_buf_shape = ( + 1, + 2, + conv_lookahead + 1, + self.nb_df, + ) # [B, 2 - complex, conv kernel size, nb_df] + self.rolling_c0_buf_shape = ( + 1, + self.enc.df_conv0_ch, + self.df_order, + self.nb_df, + ) # [B, conv hidden, df_order, nb_df] + self.rolling_spec_buf_x_shape = ( + max(self.df_order, conv_lookahead), + self.freq_size, + 2, + ) # [number of specs to save, ...] 
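+        # Output-side spectrum buffer: the frame at index df_order - 1 (conv_lookahead frames behind the newest) is the one enhanced and synthesized each step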
+ self.rolling_spec_buf_y_shape = ( + self.df_order + conv_lookahead, + self.freq_size, + 2, + ) # [number of specs to save, ...] + self.enc_hidden_shape = ( + 1, + 1, + self.enc.emb_dim, + ) # [n_layers=1, batch_size=1, emb_dim] + self.erb_dec_hidden_shape = ( + 2, + 1, + self.erb_dec.emb_dim, + ) # [n_layers=2, batch_size=1, emb_dim] + self.df_dec_hidden_shape = ( + 2, + 1, + self.df_dec.emb_dim, + ) # [n_layers=2, batch_size=1, emb_dim] + + # States + state_shapes = [ + self.erb_norm_state_shape, + self.band_unit_norm_state_shape, + self.analysis_mem_shape, + self.synthesis_mem_shape, + self.rolling_erb_buf_shape, + self.rolling_feat_spec_buf_shape, + self.rolling_c0_buf_shape, + self.rolling_spec_buf_x_shape, + self.rolling_spec_buf_y_shape, + self.enc_hidden_shape, + self.erb_dec_hidden_shape, + self.df_dec_hidden_shape, + ] + self.state_lens = [math.prod(x) for x in state_shapes] + self.states_full_len = sum(self.state_lens) + + # Zero buffers + self.register_buffer("zero_gains", torch.zeros(self.nb_bands)) + self.register_buffer( + "zero_coefs", torch.zeros(self.rolling_c0_buf_shape[2], self.nb_df, 2) + ) + + @staticmethod + def remove_conv_block_padding(original_conv: nn.Module) -> nn.Module: + """ + Remove paddings for convolutions in the original model + + Parameters: + original_conv: nn.Module - original convolution module + + Returns: + output: nn.Module - new convolution module without paddings + """ + new_modules = [] + + for module in original_conv: + if not isinstance(module, nn.ConstantPad2d): + new_modules.append(module) + + return nn.Sequential(*new_modules) + + def erb_fb( + self, widths: Tensor, normalized: bool = True, inverse: bool = False + ) -> Tensor: + """ + Generate the erb filterbank + Taken from https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/modules.py#L206 + Numpy removed from original code + + Parameters: + widths: Tensor - widths of the erb bands + normalized: bool - normalize to constant energy per band + inverse: bool - inverse erb filterbank + + Returns: + fb: Tensor - erb filterbank + """ + n_freqs = int(torch.sum(widths)) + all_freqs = torch.linspace(0, self.sr // 2, n_freqs + 1)[:-1] + + b_pts = torch.cumsum( + torch.cat([torch.tensor([0]), widths]), dtype=torch.int32, dim=0 + )[:-1] + + fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0])) + for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())): + fb[b : b + w, i] = 1 + + # Normalize to constant energy per resulting band + if inverse: + fb = fb.t() + if not normalized: + fb /= fb.sum(dim=1, keepdim=True) + else: + if normalized: + fb /= fb.sum(dim=0) + + return fb + + @staticmethod + def mul_complex(t1, t2): + """ + Compute multiplication of two complex numbers in view_as_real format. + + Parameters: + t1: Float[F, 2] - First number + t2: Float[F, 2] - Second number + + Returns: + output: Float[F, 2] - final multiplication of two complex numbers + """ + # if not torch.onnx.is_in_onnx_export(): + t1_real = t1[..., 0] + t1_imag = t1[..., 1] + t2_real = t2[..., 0] + t2_imag = t2[..., 1] + return torch.stack( + ( + t1_real * t2_real - t1_imag * t2_imag, + t1_real * t2_imag + t1_imag * t2_real, + ), + dim=-1, + ) + # return t1 * t2 + # return OnnxComplexMul.apply(t1, t2) + + def erb(self, input_data: Tensor, erb_eps: float = 1e-10) -> Tensor: + """ + Original code - pyDF/src/lib.rs - erb() + Calculating ERB features for each frame. 
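+        Energies are obtained by projecting the squared magnitudes onto the ERB filterbank and converting to dB (10 * log10).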
+ + Parameters: + input_data: Float[T, F] or Float[F] - audio spectrogram + + Returns: + erb_features: Float[T, ERB] or Float[ERB] - erb features for given spectrogram + """ + + magnitude_squared = torch.sum(input_data**2, dim=-1) + erb_features = magnitude_squared.matmul(self.forward_erb_matrix) + erb_features_db = 10.0 * torch.log10(erb_features + erb_eps) + + return erb_features_db + + @staticmethod + def band_mean_norm_erb( + xs: Tensor, erb_norm_state: Tensor, alpha: float, denominator: float = 40.0 + ) -> Tuple[Tensor, Tensor]: + """ + Original code - libDF/src/lib.rs - band_mean_norm() + Normalizing ERB features. And updates the normalization state. + + Parameters: + xs: Float[ERB] - erb features + erb_norm_state: Float[ERB] - normalization state from previous step + alpha: float - alpha value which is needed for adaptation of the normalization state for given scale. + denominator: float - denominator for normalization + + Returns: + output: Float[ERB] - normalized erb features + erb_norm_state: Float[ERB] - updated normalization state + """ + new_erb_norm_state = xs * (1 - alpha) + erb_norm_state * alpha + output = (xs - new_erb_norm_state) / denominator + + return output, new_erb_norm_state + + @staticmethod + def band_unit_norm( + xs: Tensor, band_unit_norm_state, alpha: float + ) -> Tuple[Tensor, Tensor]: + """ + Original code - libDF/src/lib.rs - band_unit_norm() + Normalizing Deep Filtering features. And updates the normalization state. + + Parameters: + xs: Float[1, DF, 2] - deep filtering features + band_unit_norm_state: Float[1, DF, 1] - normalization state from previous step + alpha: float - alpha value which is needed for adaptation of the normalization state for given scale. + + Returns: + output: Float[1, DF] - normalized deep filtering features + band_unit_norm_state: Float[1, DF, 1] - updated normalization state + """ + xs_abs = torch.linalg.norm(xs, dim=-1, keepdim=True) # xs.abs() from complexxs + new_band_unit_norm_state = xs_abs * (1 - alpha) + band_unit_norm_state * alpha + output = xs / new_band_unit_norm_state.sqrt() + + return output, new_band_unit_norm_state + + def frame_analysis( + self, input_frame: Tensor, analysis_mem: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Original code - libDF/src/lib.rs - frame_analysis() + Calculating spectrograme for one frame. Every frame is concated with buffer from previous frame. + + Parameters: + input_frame: Float[f] - Input raw audio frame + analysis_mem: Float[f] - Previous frame + + Returns: + output: Float[F, 2] - Spectrogram + analysis_mem: Float[f] - Saving current frame for next iteration + """ + # First part of the window on the previous frame + # Second part of the window on the new input frame + buf = torch.cat([analysis_mem, input_frame]) * self.window + # rfft_buf = torch.matmul(buf, self.rfft_matrix) * self.wnorm + rfft_buf = torch.view_as_real(torch.fft.rfft(buf)) * self.wnorm + + # Copy input to analysis_mem for next iteration + return rfft_buf, input_frame + + def frame_synthesis( + self, x: Tensor, synthesis_mem: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Original code - libDF/src/lib.rs - frame_synthesis() + Inverse rfft for one frame. Every frame is summarized with buffer from previous frame. + And saving buffer for next frame. 
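+        The inverse RFFT is computed as an einsum with the precomputed pseudo-inverse of the RFFT matrix (ONNX friendly); the result is windowed and its first frame_size samples are overlap-added with the tail saved from the previous frame.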
+ + Parameters: + x: Float[F, 2] - Enhanced audio spectrogram + synthesis_mem: Float[f] - Previous synthesis frame + + Returns: + output: Float[f] - Enhanced audio + synthesis_mem: Float[f] - Saving current frame + """ + # x - [F=481, 2] + # self.irfft_matrix - [fft_size=481, 2, f=960] + # [f=960] + x = ( + torch.einsum("fi,fij->j", x, self.irfft_matrix) + * self.fft_size + * self.window + ) + # x = torch.cat([x[:, 0], torch.zeros(479)]) + # x = torch.fft.irfft(torch.view_as_complex(x)) * self.fft_size * self.window + + x_first, x_second = torch.split( + x, [self.frame_size, self.window_size - self.frame_size] + ) + + output = x_first + synthesis_mem + + return output, x_second.view(self.window_size - self.frame_size) + + def apply_mask(self, spec: Tensor, gains: Tensor) -> Tensor: + """ + Original code - libDF/src/lib.rs - apply_interp_band_gain() + + Applying ERB Gains for input spectrogram + + Parameters: + spec: Float[F, 2] - Input frame spectrogram + gains: Float[ERB] - ERB gains from erb decoder + + Returns: + spec: Float[F] - Spectrogram with applyed ERB gains + """ + gains = gains.matmul(self.inverse_erb_matrix) + spec = spec * gains.unsqueeze(-1) + + return spec + + def deep_filter( + self, gain_spec: Tensor, coefs: Tensor, rolling_spec_buf_x: Tensor + ) -> Tensor: + """ + Original code - libDF/src/tract.rs - df() + + Applying Deep Filtering to gained spectrogram by multiplying coefs to rolling_buffer_x (spectrograms from past / future). + Deep Filtering replacing lower self.nb_df spec bands. + + Parameters: + gain_spec: Float[F, 2] - spectrogram after ERB gains applied + coefs: Float[DF, BUF, 2] - coefficients for deep filtering from df decoder + rolling_spec_buf_x: Float[buffer_size, F, 2] - spectrograms from past / future + + Returns: + gain_spec: Float[F, 2] - spectrogram after deep filtering + """ + stacked_input_specs = rolling_spec_buf_x[:, : self.nb_df] + mult = self.mul_complex(stacked_input_specs, coefs) + gain_spec[: self.nb_df] = torch.sum(mult, dim=0) + return gain_spec + + def forward( + self, + input_frame: Tensor, + erb_norm_state: Tensor, + band_unit_norm_state: Tensor, + analysis_mem: Tensor, + synthesis_mem: Tensor, + rolling_erb_buf: Tensor, + rolling_feat_spec_buf: Tensor, + rolling_c0_buf: Tensor, + rolling_spec_buf_x: Tensor, + rolling_spec_buf_y: Tensor, + enc_hidden: Tensor, + erb_dec_hidden: Tensor, + df_dec_hidden: Tensor, + ) -> Tuple[ + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + ]: + """ + Enhancing input audio frame + + Parameters: + input_frame: Float[t] - Input raw audio frame + + Returns: + enhanced_frame: Float[t] - Enhanced audio frame + """ + assert input_frame.ndim == 1, "only bs=1 and t=frame_size supported" + assert ( + input_frame.shape[0] == self.frame_size + ), "input_frame must be bs=1 and t=frame_size" + + spectrogram, new_analysis_mem = self.frame_analysis(input_frame, analysis_mem) + spectrogram = spectrogram.unsqueeze( + 0 + ) # [1, freq_size, 2] reshape needed for easier stacking buffers + new_rolling_spec_buf_x = torch.cat( + [rolling_spec_buf_x[1:, ...], spectrogram] + ) # [n_frames=5, 481, 2] + # rolling_spec_buf_y - [n_frames=7, 481, 2] n_frames=7 for compatability with original code, but in code we use only one frame + new_rolling_spec_buf_y = torch.cat([rolling_spec_buf_y[1:, ...], spectrogram]) + + erb_feat, new_erb_norm_state = self.band_mean_norm_erb( + self.erb(spectrogram).squeeze(0), erb_norm_state, alpha=self.alpha + ) # [ERB] + 
spec_feat, new_band_unit_norm_state = self.band_unit_norm( + spectrogram[:, : self.nb_df], band_unit_norm_state, alpha=self.alpha + ) # [1, DF, 2] + + erb_feat = erb_feat[ + None, None, None, ... + ] # [b=1, conv_input_dim=1, t=1, n_erb=32] + spec_feat = spec_feat[None, ...].permute( + 0, 3, 1, 2 + ) # [bs=1, conv_input_dim=2, t=1, df_order=96] + + # (1, 1, T, self.nb_bands) + new_rolling_erb_buf = torch.cat([rolling_erb_buf[:, :, 1:, :], erb_feat], dim=2) + + # (1, 2, T, self.nb_df) + new_rolling_feat_spec_buf = torch.cat( + [rolling_feat_spec_buf[:, :, 1:, :], spec_feat], dim=2 + ) + + e0, e1, e2, e3, emb, c0, new_enc_hidden = self.enc( + new_rolling_erb_buf, new_rolling_feat_spec_buf, enc_hidden + ) + + # erb_dec + # [BS=1, 1, T=1, ERB] + new_gains, new_erb_dec_hidden = self.erb_dec( + emb, e3, e2, e1, e0, erb_dec_hidden + ) + gains = new_gains.view(self.nb_bands) + + # df_dec + new_rolling_c0_buf = torch.cat([rolling_c0_buf[:, :, 1:, :], c0], dim=2) + # new_coefs - [BS=1, T=1, F, O*2] + new_coefs, new_df_dec_hidden = self.df_dec( + emb, new_rolling_c0_buf, df_dec_hidden + ) + coefs = new_coefs.view(self.nb_df, -1, 2).permute(1, 0, 2) + + # Applying features + current_spec = new_rolling_spec_buf_y[self.df_order - 1] + current_spec = self.apply_mask(current_spec.clone(), gains) + current_spec = self.deep_filter( + current_spec.clone(), coefs, new_rolling_spec_buf_x + ) + + enhanced_audio_frame, new_synthesis_mem = self.frame_synthesis( + current_spec, synthesis_mem + ) + + return ( + enhanced_audio_frame, + new_erb_norm_state, + new_band_unit_norm_state, + new_analysis_mem, + new_synthesis_mem, + new_rolling_erb_buf, + new_rolling_feat_spec_buf, + new_rolling_c0_buf, + new_rolling_spec_buf_x, + new_rolling_spec_buf_y, + new_enc_hidden, + new_erb_dec_hidden, + new_df_dec_hidden, + ) + + +class TorchDFMinimalPipeline(nn.Module): + def __init__( + self, + nb_bands=32, + hop_size=480, + fft_size=960, + df_order=5, + conv_lookahead=2, + nb_df=96, + model_base_dir="DeepFilterNet3", + device="cpu", + ): + super().__init__() + self.hop_size = hop_size + self.fft_size = fft_size + + model, state, _ = init_df( + config_allow_defaults=True, + model_base_dir=model_base_dir, + ) + model.eval() + self.sample_rate = state.sr() + + self.torch_streaming_model = ExportableStreamingMinimalTorchDF( + nb_bands=nb_bands, + hop_size=hop_size, + fft_size=fft_size, + enc=model.enc, + df_dec=model.df_dec, + erb_dec=model.erb_dec, + df_order=df_order, + conv_lookahead=conv_lookahead, + nb_df=nb_df, + sr=self.sample_rate, + ) + self.torch_streaming_model = self.torch_streaming_model.to(device) + + analysis_mem = torch.zeros(self.torch_streaming_model.analysis_mem_shape) + synthesis_mem = torch.zeros(self.torch_streaming_model.synthesis_mem_shape) + rolling_erb_buf = torch.zeros(self.torch_streaming_model.rolling_erb_buf_shape) + rolling_feat_spec_buf = torch.zeros( + self.torch_streaming_model.rolling_feat_spec_buf_shape + ) + rolling_c0_buf = torch.zeros(self.torch_streaming_model.rolling_c0_buf_shape) + rolling_spec_buf_x = torch.zeros( + self.torch_streaming_model.rolling_spec_buf_x_shape + ) + rolling_spec_buf_y = torch.zeros( + self.torch_streaming_model.rolling_spec_buf_y_shape + ) + enc_hidden = torch.zeros(self.torch_streaming_model.enc_hidden_shape) + erb_dec_hidden = torch.zeros(self.torch_streaming_model.erb_dec_hidden_shape) + df_dec_hidden = torch.zeros(self.torch_streaming_model.df_dec_hidden_shape) + + erb_norm_state = ( + torch.linspace( + self.torch_streaming_model.linspace_erb[0], + 
self.torch_streaming_model.linspace_erb[1], + self.torch_streaming_model.nb_bands, + ) + .view(self.torch_streaming_model.erb_norm_state_shape) + .to(torch.float32) + ) # float() to fix export issue + + band_unit_norm_state = ( + torch.linspace( + self.torch_streaming_model.linspace_df[0], + self.torch_streaming_model.linspace_df[1], + self.torch_streaming_model.nb_df, + ) + .view(self.torch_streaming_model.band_unit_norm_state_shape) + .to(torch.float32) + ) # float() to fix export issue + + self.states = [ + erb_norm_state, + band_unit_norm_state, + analysis_mem, + synthesis_mem, + rolling_erb_buf, + rolling_feat_spec_buf, + rolling_c0_buf, + rolling_spec_buf_x, + rolling_spec_buf_y, + enc_hidden, + erb_dec_hidden, + df_dec_hidden, + ] + self.input_names = [ + "input_frame", + "erb_norm_state", + "band_unit_norm_state", + "analysis_mem", + "synthesis_mem", + "rolling_erb_buf", + "rolling_feat_spec_buf", + "rolling_c0_buf", + "rolling_spec_buf_x", + "rolling_spec_buf_y", + "enc_hidden", + "erb_dec_hidden", + "df_dec_hidden", + ] + self.output_names = [ + "enhanced_audio_frame", + "new_erb_norm_state", + "new_band_unit_norm_state", + "new_analysis_mem", + "new_synthesis_mem", + "new_rolling_erb_buf", + "new_rolling_feat_spec_buf", + "new_rolling_c0_buf", + "new_rolling_spec_buf_x", + "new_rolling_spec_buf_y", + "new_enc_hidden", + "new_erb_dec_hidden", + "new_df_dec_hidden", + ] + + def forward(self, input_audio: Tensor, sample_rate: int) -> Tensor: + """ + Denoising audio frame using exportable fully torch model. + + Parameters: + input_audio: Float[1, t] - Input audio + sample_rate: Int - Sample rate + + Returns: + enhanced_audio: Float[1, t] - Enhanced input audio + """ + assert ( + input_audio.shape[0] == 1 + ), f"Only mono supported! Got wrong shape! {input_audio.shape}" + assert ( + sample_rate == self.sample_rate + ), f"Only {self.sample_rate} supported! Got wrong sample rate! 
{sample_rate}" + + input_audio = input_audio.squeeze(0) + orig_len = input_audio.shape[0] + + # padding taken from + # https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/enhance.py#L229 + hop_size_divisible_padding_size = ( + self.hop_size - orig_len % self.hop_size + ) % self.hop_size + orig_len += hop_size_divisible_padding_size + input_audio = F.pad( + input_audio, (0, self.fft_size + hop_size_divisible_padding_size) + ) + + chunked_audio = torch.split(input_audio, self.hop_size) + + output_frames = [] + + for input_frame in chunked_audio: + enhanced_audio_frame, *self.states = self.torch_streaming_model( + input_frame, *self.states + ) + + output_frames.append(enhanced_audio_frame) + + enhanced_audio = torch.cat(output_frames).unsqueeze( + 0 + ) # [t] -> [1, t] typical mono format + + # taken from + # https://github.com/Rikorose/DeepFilterNet/blob/fa926662facea33657c255fd1f3a083ddc696220/DeepFilterNet/df/enhance.py#L248 + d = self.fft_size - self.hop_size + enhanced_audio = enhanced_audio[:, d : orig_len + d] + + return enhanced_audio + + +def main(args): + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + torch_df = TorchDFMinimalPipeline(device=args.device) + + # torchaudio normalize=True, fp32 return + noisy_audio, sr = torchaudio.load(args.audio_path, channels_first=True) + noisy_audio = noisy_audio.mean(dim=0).unsqueeze(0).to(args.device) # stereo to mono + + enhanced_audio = torch_df(noisy_audio, sr).detach().cpu() + + torchaudio.save( + args.output_path, enhanced_audio, sr, encoding="PCM_S", bits_per_sample=16 + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Denoising one audio with DF3 model using torch only" + ) + parser.add_argument( + "--audio-path", type=str, required=True, help="Path to audio file" + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Path to output file" + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cuda", "cpu"], + help="Device to run on", + ) + + main(parser.parse_args())