Merge pull request #3 from grazder/torchDF_changes_for_graph_optimization

Torch df changes for graph optimization
grazder authored Mar 19, 2024
2 parents 1097015 + 6106c5a commit 3453346
Showing 8 changed files with 2,037 additions and 330 deletions.
95 changes: 73 additions & 22 deletions DeepFilterNet/df/deepfilternet3.py
@@ -28,21 +28,32 @@ def __init__(self):
self.conv_lookahead: int = config(
"CONV_LOOKAHEAD", cast=int, default=0, section=self.section
)
self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section)
self.conv_ch: int = config(
"CONV_CH", cast=int, default=16, section=self.section
)
self.conv_depthwise: bool = config(
"CONV_DEPTHWISE", cast=bool, default=True, section=self.section
)
self.convt_depthwise: bool = config(
"CONVT_DEPTHWISE", cast=bool, default=True, section=self.section
)
self.conv_kernel: List[int] = config(
"CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore
"CONV_KERNEL",
cast=Csv(int),
default=(1, 3),
section=self.section, # type: ignore
)
self.convt_kernel: List[int] = config(
"CONVT_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore
"CONVT_KERNEL",
cast=Csv(int),
default=(1, 3),
section=self.section, # type: ignore
)
self.conv_kernel_inp: List[int] = config(
"CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section # type: ignore
"CONV_KERNEL_INP",
cast=Csv(int),
default=(3, 3),
section=self.section, # type: ignore
)
self.emb_hidden_dim: int = config(
"EMB_HIDDEN_DIM", cast=int, default=256, section=self.section
@@ -53,31 +53,49 @@ def __init__(self):
self.emb_gru_skip_enc: str = config(
"EMB_GRU_SKIP_ENC", default="none", section=self.section
)
self.emb_gru_skip: str = config("EMB_GRU_SKIP", default="none", section=self.section)
self.emb_gru_skip: str = config(
"EMB_GRU_SKIP", default="none", section=self.section
)
self.df_hidden_dim: int = config(
"DF_HIDDEN_DIM", cast=int, default=256, section=self.section
)
self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section)
self.df_gru_skip: str = config(
"DF_GRU_SKIP", default="none", section=self.section
)
self.df_pathway_kernel_size_t: int = config(
"DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section
)
self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section)
self.df_num_layers: int = config("DF_NUM_LAYERS", cast=int, default=3, section=self.section)
self.df_n_iter: int = config("DF_N_ITER", cast=int, default=1, section=self.section)
self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section)
self.enc_concat: bool = config(
"ENC_CONCAT", cast=bool, default=False, section=self.section
)
self.df_num_layers: int = config(
"DF_NUM_LAYERS", cast=int, default=3, section=self.section
)
self.df_n_iter: int = config(
"DF_N_ITER", cast=int, default=1, section=self.section
)
self.lin_groups: int = config(
"LINEAR_GROUPS", cast=int, default=1, section=self.section
)
self.enc_lin_groups: int = config(
"ENC_LINEAR_GROUPS", cast=int, default=16, section=self.section
)
self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section)
self.mask_pf: bool = config(
"MASK_PF", cast=bool, default=False, section=self.section
)
self.lsnr_dropout: bool = config(
"LSNR_DROPOUT", cast=bool, default=False, section=self.section
)


def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True):
def init_model(
df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True
):
p = ModelParams()
if df_state is None:
df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
df_state = DF(
sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb
)
erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False)
erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
model = DfNet(erb, erb_inverse, run_df, train_mask)
@@ -119,7 +148,11 @@ def __init__(self):
self.erb_conv3 = conv_layer(fstride=1)
self.df_conv0_ch = p.conv_ch
self.df_conv0 = Conv2dNormAct(
2, self.df_conv0_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
2,
self.df_conv0_ch,
kernel_size=p.conv_kernel_inp,
bias=False,
separable=True,
)
self.df_conv1 = conv_layer(fstride=2)
self.erb_bins = p.nb_erb
@@ -243,7 +276,15 @@ def __init__(self):
p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid
)

def forward(self, emb: Tensor, e3: Tensor, e2: Tensor, e1: Tensor, e0: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self,
emb: Tensor,
e3: Tensor,
e2: Tensor,
e1: Tensor,
e0: Tensor,
hidden: Tensor,
) -> Tuple[Tensor, Tensor]:
# Estimates erb mask
b, _, t, f8 = e3.shape
emb, hidden = self.emb_gru(emb, hidden)
@@ -294,7 +335,9 @@ def __init__(self):
conv_layer = partial(Conv2dNormAct, separable=True, bias=False)
kt = p.df_pathway_kernel_size_t
self.conv_buffer_size = kt - 1
self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1))
self.df_convp = conv_layer(
layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1)
)

self.df_gru = SqueezedGRU_S(
self.emb_in_dim,
@@ -313,7 +356,9 @@ def __init__(self):
assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match"
self.df_skip = nn.Identity()
elif p.df_gru_skip == "groupedlinear":
self.df_skip = GroupedLinearEinsum(self.emb_in_dim, self.emb_dim, groups=p.lin_groups)
self.df_skip = GroupedLinearEinsum(
self.emb_in_dim, self.emb_dim, groups=p.lin_groups
)
else:
raise NotImplementedError()
self.df_out: nn.Module
@@ -356,11 +401,15 @@ def __init__(
self.erb_bins: int = p.nb_erb
if p.conv_lookahead > 0:
assert p.conv_lookahead >= p.df_lookahead
self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0)
self.pad_feat = nn.ConstantPad2d(
(0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0
)
else:
self.pad_feat = nn.Identity()
if p.df_lookahead > 0:
self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, p.df_lookahead - 1, -p.df_lookahead + 1), 0.0)
self.pad_spec = nn.ConstantPad3d(
(0, 0, 0, 0, p.df_lookahead - 1, -p.df_lookahead + 1), 0.0
)
else:
self.pad_spec = nn.Identity()
self.register_buffer("erb_fb", erb_fb)
@@ -369,7 +418,9 @@ def __init__(
self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf)

self.df_order = p.df_order
self.df_op = MF.DF(num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead)
self.df_op = MF.DF(
num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead
)
self.df_dec = DfDecoder()
self.df_out_transform = DfOutputReshapeMF(self.df_order, p.nb_df)

@@ -406,7 +457,7 @@ def forward(
# feat_erb = self.pad_feat(feat_erb)
# feat_spec = self.pad_feat(feat_spec)
spec = self.pad_spec(spec)

e0, e1, e2, e3, emb, c0, lsnr, _ = self.enc(feat_erb, feat_spec, hidden=None)

if self.lsnr_droput:
Expand All @@ -427,7 +478,7 @@ def forward(
m[:, :, idcs], _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None)
else:
m, _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None)

pad_spec = F.pad(spec, (0, 0, 0, 0, 1, -1, 0, 0), value=0)
spec_m = self.mask(pad_spec, m)
else:
2 changes: 1 addition & 1 deletion DeepFilterNet/df/modules.py
@@ -1006,4 +1006,4 @@ def test_dfop():
dfop.freq_bins = F
dfop.set_forward("real_hidden_state_loop")
out6 = dfop(spec, coefs, alpha)
torch.testing.assert_allclose(out1, out6)
torch.testing.assert_allclose(out1, out6)
2 changes: 1 addition & 1 deletion torchDF/README.md
@@ -34,7 +34,7 @@ poetry run python torch_df_streaming_minimal.py --audio-path examples/A1CIM28ZUC

To convert model to onnx and run tests:
```
poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort
poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort --simplify --profiling --minimal
```

TODO:

0 comments on commit 3453346
