Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 7, 2023
1 parent 502efdf commit 7db996f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
11 changes: 8 additions & 3 deletions k2/torch/bin/hlg_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def get_parser():
parser.add_argument(
"--output-file",
type=str,
help="The file to write out results to, only used when giving --wav-scp",
help="""
The file to write out results to, only used when giving --wav-scp
""",
)

parser.add_argument(
Expand Down Expand Up @@ -239,7 +241,10 @@ def main():
if args.method == "ctc-decoding":
logging.info("Use CTC decoding")
max_token_id = args.num_classes - 1
decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,)
decoding_graph = k2.ctc_topo(
max_token=max_token_id,
device=device,
)
token_sym_table = k2.SymbolTable.from_file(args.tokens)
else:
assert args.method == "1best", args.method
Expand All @@ -260,7 +265,7 @@ def main():

res = decode_one_batch(
params=args,
batch=wave_list[start : start + args.batch_size],
batch=wave_list[start: start + args.batch_size],
model=model,
feature_extractor=fbank,
decoding_graph=decoding_graph,
Expand Down
15 changes: 11 additions & 4 deletions k2/torch/bin/online_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def get_parser():
parser.add_argument(
"--output-file",
type=str,
help="The file to write out results to, only used when giving --wav-scp",
help="""
The file to write out results to, only used when giving --wav-scp
""",
)

parser.add_argument(
Expand Down Expand Up @@ -158,7 +160,8 @@ def decode_one_chunk(
current_num_frames.append(0)
current_nnet_outputs.append(
torch.zeros(
(params.chunk_size, params.num_classes), device=params.device,
(params.chunk_size, params.num_classes),
device=params.device,
)
)
current_state_infos.append(DecodeStateInfo())
Expand Down Expand Up @@ -211,7 +214,8 @@ def decode_dataset(
data, sample_rate = torchaudio.load(waves[wave_index][1])
assert (
sample_rate == params.sample_rate
), f"expected sample rate: {params.sample_rate}. Given: {sample_rate}"
), f"expected sample rate: {params.sample_rate}. "
f"Given: {sample_rate}"
data = data[0].to(params.device)
feature = feature_extractor(data)
nnet_output, _, _ = model(feature.unsqueeze(0))
Expand Down Expand Up @@ -318,7 +322,10 @@ def main():
if args.method == "ctc-decoding":
logging.info("Use CTC decoding")
max_token_id = args.num_classes - 1
decoding_graph = k2.ctc_topo(max_token=max_token_id, device=device,)
decoding_graph = k2.ctc_topo(
max_token=max_token_id,
device=device,
)
token_sym_table = k2.SymbolTable.from_file(args.tokens)
else:
assert args.method == "1best", args.method
Expand Down

0 comments on commit 7db996f

Please sign in to comment.