Skip to content

Commit

Permalink
Update script
Browse files (browse the repository at this point in the history)
  • Loading branch information
michaelbenayoun committed Apr 5, 2024
1 parent 6b240f7 commit e2d185d
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tools/create_examples_from_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def transform_file_content(

def prepare_speech_script(file_content: str, seq2seq_or_ctc: str):
assert seq2seq_or_ctc in ["seq2seq", "ctc"]
+ is_seq2seq = seq2seq_or_ctc == "seq2seq"
max_label_length_data_argument = """
max_label_length: int = field(
default=128,
Expand Down Expand Up @@ -356,13 +357,10 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
"""
file_content = transform_file_content(
file_content,
- lambda n: isinstance(n, ast.ClassDef) and n.name == "DataCollatorSpeechSeq2SeqWithPadding",
+ lambda n: isinstance(n, ast.ClassDef)
+ and n.name in ["DataCollatorSpeechSeq2SeqWithPadding", "DataCollatorCTCWithPadding"],
  InsertPosition.BETWEEN,
- (
- xla_compatible_data_collator_for_seq2seq
- if seq2seq_or_ctc == "seq2seq"
- else xla_compatible_data_collator_for_ctc
- ),
+ (xla_compatible_data_collator_for_seq2seq if is_seq2seq else xla_compatible_data_collator_for_ctc),
)

import_partial_from_functools = "from functools import partial"
Expand Down Expand Up @@ -419,7 +417,7 @@ def is_labels_in_length_range(labels):
InsertPosition.BETWEEN,
(
data_collator_with_padding_and_max_length_for_seq2seq
- if seq2seq_or_ctc == "seq2seq"
+ if is_seq2seq
else data_collator_with_padding_and_max_length_for_ctc
),
)
Expand Down

0 comments on commit e2d185d

Please sign in to comment.