Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
a.korepanov committed Mar 13, 2024
1 parent 5c015ea commit dfa1227
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 8 additions & 3 deletions torchDF/test_torchdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
torch.set_num_interop_threads(1)
DEVICE = torch.device("cpu")

EXAMPLES_PATH = ""
AUDIO_EXAMPLE_PATH = (
"examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav"
)


class TestTorchStreaming:
def __reset(self):
Expand All @@ -29,14 +34,14 @@ def __reset(self):
self.torch_streaming_no_stages = (
self.torch_streaming_like_offline.torch_streaming_model
)
self.streaming_state_no_stages = self.torch_streaming_like_offline.states
self.streaming_state_no_stages = self.torch_streaming_like_offline.states[0]
self.atten_lim_db_no_stages = self.torch_streaming_like_offline.atten_lim_db

pipeline_for_streaming = TorchDFPipeline(
always_apply_all_stages=False, device=DEVICE
)
self.torch_streaming = pipeline_for_streaming.torch_streaming_model
self.streaming_state = pipeline_for_streaming.states
self.streaming_state = pipeline_for_streaming.states[0]
self.atten_lim_db = pipeline_for_streaming.atten_lim_db

pipeline_for_minmal_streaming = TorchDFMinimalPipeline(device=DEVICE)
Expand All @@ -51,7 +56,7 @@ def __reset(self):
self.df_tract = DFTractPy()

self.noisy_audio, self.audio_sr = torchaudio.load(
"examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav",
AUDIO_EXAMPLE_PATH,
channels_first=True,
)
self.noisy_audio = self.noisy_audio.mean(dim=0).unsqueeze(0).to(DEVICE)
Expand Down
4 changes: 3 additions & 1 deletion torchDF/torch_df_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ def main(args):
choices=["cuda", "cpu"],
help="Device to run on",
)
parser.add_argument("--always-apply-all-stages", action="store_true")
parser.add_argument(
"--always-apply-all-stages", action="store_true", help="Apply all stages or not"
)

main(parser.parse_args())

0 comments on commit dfa1227

Please sign in to comment.