From dfa1227668c9dfea42fe410b80aea88ebd746db1 Mon Sep 17 00:00:00 2001 From: "a.korepanov" Date: Wed, 13 Mar 2024 16:14:03 +0300 Subject: [PATCH] small updates --- torchDF/test_torchdf.py | 11 ++++++++--- torchDF/torch_df_streaming.py | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/torchDF/test_torchdf.py b/torchDF/test_torchdf.py index c8046446b..bff821c28 100644 --- a/torchDF/test_torchdf.py +++ b/torchDF/test_torchdf.py @@ -13,6 +13,11 @@ torch.set_num_interop_threads(1) DEVICE = torch.device("cpu") +EXAMPLES_PATH = "" +AUDIO_EXAMPLE_PATH = ( + "examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav" +) + class TestTorchStreaming: def __reset(self): @@ -29,14 +34,14 @@ def __reset(self): self.torch_streaming_no_stages = ( self.torch_streaming_like_offline.torch_streaming_model ) - self.streaming_state_no_stages = self.torch_streaming_like_offline.states + self.streaming_state_no_stages = self.torch_streaming_like_offline.states[0] self.atten_lim_db_no_stages = self.torch_streaming_like_offline.atten_lim_db pipeline_for_streaming = TorchDFPipeline( always_apply_all_stages=False, device=DEVICE ) self.torch_streaming = pipeline_for_streaming.torch_streaming_model - self.streaming_state = pipeline_for_streaming.states + self.streaming_state = pipeline_for_streaming.states[0] self.atten_lim_db = pipeline_for_streaming.atten_lim_db pipeline_for_minmal_streaming = TorchDFMinimalPipeline(device=DEVICE) @@ -51,7 +56,7 @@ def __reset(self): self.df_tract = DFTractPy() self.noisy_audio, self.audio_sr = torchaudio.load( - "examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav", + AUDIO_EXAMPLE_PATH, channels_first=True, ) self.noisy_audio = self.noisy_audio.mean(dim=0).unsqueeze(0).to(DEVICE) diff --git a/torchDF/torch_df_streaming.py b/torchDF/torch_df_streaming.py index 3d0c6661a..491d3c130 100644 --- a/torchDF/torch_df_streaming.py +++ b/torchDF/torch_df_streaming.py @@ -899,6 +899,8 @@ def main(args): choices=["cuda", 
"cpu"], help="Device to run on", ) - parser.add_argument("--always-apply-all-stages", action="store_true") + parser.add_argument( + "--always-apply-all-stages", action="store_true", help="Always apply all enhancement stages" + ) main(parser.parse_args())