Skip to content

Commit

Permalink
Merge pull request #291 from iver56/ij/increase-threshold-for-wrong-s…
Browse files Browse the repository at this point in the history
…hape-exception

Increase the threshold for raising WrongMultichannelAudioShape
  • Loading branch information
iver56 authored Jul 31, 2023
2 parents 4d0cd0c + 8e652aa commit 5c8180b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
3 changes: 2 additions & 1 deletion audiomentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __call__(self, samples: np.ndarray, sample_rate: int) -> np.ndarray:
self.randomize_parameters(samples, sample_rate)
if self.parameters["should_apply"] and len(samples) > 0:
if self.is_multichannel(samples):
if samples.shape[0] > samples.shape[1]:
# Note: We multiply by 8 here to allow big batches of very short audio
if samples.shape[0] > samples.shape[1] * 8:
raise WrongMultichannelAudioShape(
"Multichannel audio must have channels first, not channels last. In"
" other words, the shape must be (channels, samples), not"
Expand Down
11 changes: 4 additions & 7 deletions tests/test_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,11 @@ def test_gain_multichannel(self):
assert processed_samples.dtype == np.float32

def test_gain_multichannel_with_wrong_dimension_ordering(self):
samples = np.array(
[[1.0, 0.5, -0.25, -0.125, 0.0], [1.0, 0.5, -0.25, -0.125, 0.0]],
dtype=np.float32,
).T
print(samples.shape)
sample_rate = 16000
samples = np.random.uniform(low=-0.5, high=0.5, size=(2000, 2)).astype(
np.float32
)

augment = Gain(min_gain_db=-6, max_gain_db=-6, p=1.0)

with pytest.raises(WrongMultichannelAudioShape):
processed_samples = augment(samples=samples, sample_rate=sample_rate)
augment(samples=samples, sample_rate=16000)

0 comments on commit 5c8180b

Please sign in to comment.