This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

WIP: add compute-post. #210

Open
wants to merge 4 commits into master
1 change: 1 addition & 0 deletions snowfall/bin/modes/__init__.py
@@ -1,2 +1,3 @@
from .ali import ali
from .cli_base import cli
from .net import net
8 changes: 1 addition & 7 deletions snowfall/bin/modes/ali.py
@@ -3,8 +3,6 @@
from pathlib import Path
from typing import Optional

import sys

import click
import k2
import torch
@@ -18,7 +16,7 @@
@cli.group()
def ali():
'''
Alignment tools in snowfall
Alignment tools in snowfall.
'''
pass

@@ -79,8 +77,6 @@ def edit_distance(ref: str,
type=type,
output_file=output_file)

print(f'Saved to {output_file}', file=sys.stderr)


@ali.command()
@click.option('-i',
@@ -157,5 +153,3 @@ def visualize(input: str,
width=width,
height=height,
font_size=font_size)

print(f'Saved to {output_file}', file=sys.stderr)
126 changes: 126 additions & 0 deletions snowfall/bin/modes/net.py
@@ -0,0 +1,126 @@
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)

from pathlib import Path

import click
import torch

from .cli_base import cli
from snowfall.tools.net import compute_post as compute_post_impl
from snowfall.tools.net import decode as decode_impl


@cli.group()
def net():
'''
Neural network tools in snowfall.
'''
pass


@net.command()
@click.option('-m',
'--model',
type=click.Path(exists=True, dir_okay=False),
required=True,
              help='Path to a TorchScript module')
@click.option('-f',
'--feats',
type=click.Path(exists=True, dir_okay=False),
required=True,
              help='Path to a FeatureSet manifest')
@click.option('-o',
'--output-dir',
type=click.Path(dir_okay=True),
required=True,
help='Output directory')
@click.option('-l',
'--max-duration',
default=200,
type=int,
show_default=True,
              help='Maximum total duration (in seconds) of utterances in a batch')
@click.option('-i',
'--device-id',
default=0,
type=int,
show_default=True,
help='-1 to use CPU. Otherwise, it is the GPU device ID')
def compute_post(model: str,
feats: str,
output_dir: str,
max_duration: int = 200,
device_id: int = 0):
'''Compute posteriors given a model and a FeatureSet.
'''
if device_id < 0:
print('Use CPU')
device = torch.device('cpu')
else:
print(f'Use GPU {device_id}')
device = torch.device('cuda', device_id)

compute_post_impl(model=model,
feats=feats,
output_dir=output_dir,
max_duration=max_duration,
device=device)
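
For context, here is a minimal sketch of how the underlying implementation could be driven directly from Python. It is not part of this PR, and the file paths are made-up placeholders; only the keyword arguments mirror the call above.

```python
import torch

from snowfall.tools.net import compute_post

# Placeholder paths; compute_post is the implementation wrapped by the
# CLI command above.
compute_post(model='exp/model.pt',        # TorchScript module
             feats='exp/feats.json',      # Lhotse FeatureSet manifest
             output_dir='exp/posts',
             max_duration=200,            # seconds of audio per batch
             device=torch.device('cuda', 0))
```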


@net.command()
@click.option('-l',
'--lang-dir',
type=click.Path(exists=True, dir_okay=True, file_okay=False),
required=True,
              help='The language directory. It is expected to '
                   'contain the following files:\n'
                   ' - words.txt\n'
                   ' - phones.txt\n'
                   ' - HLG.pt (or L_disambig.fst.txt and G.fst.txt)\n')
@click.option('-p',
'--posts',
type=click.Path(exists=True, dir_okay=False),
required=True,
help='Path to Posteriors manifest')
@click.option('-o',
'--output-dir',
type=click.Path(dir_okay=True),
required=True,
help='Output directory')
@click.option('-i',
'--device-id',
default=0,
type=int,
show_default=True,
help='-1 to use CPU. Otherwise, it is the GPU device ID')
@click.option('-m',
'--max-duration',
default=200,
type=int,
show_default=True,
              help='Maximum total duration (in seconds) of utterances in a batch')
@click.option('-b',
'--output-beam-size',
default=8.0,
type=float,
show_default=True,
              help='Output beam size used in decoding')
def decode(lang_dir: str,
posts: str,
output_dir: str,
device_id: int = 0,
max_duration: int = 200,
output_beam_size: float = 8.0):
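    '''Decode posteriors using the decoding graph found in (or compiled
    from) the given language directory.
    '''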
if device_id < 0:
print('Use CPU')
device = torch.device('cpu')
else:
print(f'Use GPU {device_id}')
device = torch.device('cuda', device_id)

decode_impl(lang_dir=lang_dir,
posts=posts,
output_dir=output_dir,
device=device,
max_duration=max_duration,
output_beam_size=output_beam_size)
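
A similar sketch for the decode implementation (again not part of this PR; the directory and manifest names are placeholders):

```python
import torch

from snowfall.tools.net import decode

# Placeholder paths; the keyword arguments mirror the CLI wrapper above.
decode(lang_dir='data/lang_nosp',      # words.txt, phones.txt, HLG.pt or FST texts
       posts='exp/posts/posts.json',   # posteriors manifest from compute-post
       output_dir='exp/decode',
       device=torch.device('cpu'),
       max_duration=200,
       output_beam_size=8.0)
```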
74 changes: 69 additions & 5 deletions snowfall/decoding/graph.py
@@ -1,17 +1,22 @@
from pathlib import Path

import logging

import k2
import torch
from k2 import Fsa

from snowfall.common import find_first_disambig_symbol
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.mmi_graph import get_phone_symbols


def compile_HLG(
L: Fsa,
G: Fsa,
H: Fsa,
L: k2.Fsa,
G: k2.Fsa,
H: k2.Fsa,
labels_disambig_id_start: int,
aux_labels_disambig_id_start: int
) -> Fsa:
) -> k2.Fsa:
"""
Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
Involves arc sorting, intersection, determinization, removal of disambiguation symbols
@@ -87,3 +92,62 @@ def compile_HLG(
# NOTE: we assume that both kinds of scores are in log-space.
HLG.lm_scores = HLG.scores.clone()
return HLG


def load_or_compile_HLG(lang_dir: str) -> k2.Fsa:
'''Build an HLG graph from a given directory.

The following files are expected to be available in
the given directory:

- HLG.pt
- If HLG.pt does not exist, then the following must
be available:

- words.txt
- phones.txt
- L_disambig.fst.txt
- G.fst.txt

Args:
lang_dir:
Path to the language directory.
Returns:
Return an HLG graph.
'''
lang_dir = Path(lang_dir)

if (lang_dir / 'HLG.pt').exists():
logging.info('Loading pre-compiled HLG')
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)
return HLG

word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
phone_ids = get_phone_symbols(phone_symbol_table)

phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

logging.info('Loading L_disambig.fst.txt')
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)

logging.info('Loading G.fst.txt')
with open(lang_dir / 'G.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)

first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

first_word_disambig_id = find_first_disambig_symbol(word_symbol_table)

HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)

torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')

return HLG
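
Usage is a one-liner. A minimal sketch, assuming the directory is laid out as the docstring above describes ('data/lang_nosp' is a placeholder path):

```python
from snowfall.decoding.graph import load_or_compile_HLG

# The first call compiles HLG from L_disambig.fst.txt and G.fst.txt and
# caches it as HLG.pt; subsequent calls load the cached graph.
HLG = load_or_compile_HLG('data/lang_nosp')
```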
55 changes: 52 additions & 3 deletions snowfall/tools/ali.py
@@ -1,17 +1,19 @@
# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang)

from dataclasses import dataclass
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
from pathlib import Path

import os
import shutil
import sys
import tempfile

import k2
import lhotse
import torch

from snowfall.common import write_error_stats

@@ -27,6 +29,10 @@ class Alignment:
# one used in Lhotse
value: Dict[str, Union[List[int], List[str]]]

@staticmethod
def from_list(type: str, v: Union[List[int], List[str]]) -> 'Alignment':
return Alignment(value={type: v})


# The alignment in a dataset can be represented by
#
@@ -110,6 +116,8 @@ def compute_edit_distance(ref_ali: Dict[str, Alignment],
with open(output_file, 'w') as f:
write_error_stats(f, 'test', pairs)

print(f'Saved to {output_file}', file=sys.stderr)


def visualize(input: str,
text_grid: str,
@@ -128,7 +136,7 @@ def visualize(input: str,
The filename of the text grid.
output_file:
Filename of the output file. Currently, it requires that the name ends
with `.pdf`.
with one of the following: `.pdf`, `.png`, or `.eps`.
start:
The start time in seconds.
end:
@@ -195,4 +203,45 @@ def visualize(input: str,
Path(tmp_name).unlink()
if ret != 0:
raise Exception(f'Failed to run\n{cmd}\n'
f'The praat script content is:\n{command}')
f'The Praat script content is:\n{command}')

print(f'Saved to {output_file}', file=sys.stderr)


def get_phone_alignment(best_paths: k2.Fsa) -> List[Alignment]:
'''Get phone alignment from 1-best path.

Args:
best_paths:
An FsaVec. Each single FSA in it is expected to contain only
one path.

Returns:
Return a list of :class:`Alignment`. `len(ans) == num_fsas`.
Each element in the returned list corresponds to the result of
a single FSA in the given FsaVec.
'''
assert len(best_paths.shape) == 3

# labels is a k2.RaggedInt with shape [batch][state][phone]
labels = k2.RaggedInt(best_paths.arcs.shape(), best_paths.labels.clone())

# remove the axis `state`
labels = k2.ragged.remove_axis(labels, 1)

# Now labels has axis [batch][phone]

labels = k2.ragged.to_list(labels)

# Now labels is
# [
# [phone1, phone2, phone3, ...],
# [phone1, phone2, phone3, ...],
# ...
# ]
# len(labels) == num_fsas

ans = []
for v in labels:
ans.append(Alignment.from_list('phone', v))
return ans
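
A minimal sketch of how this helper might be used, assuming `lats` is a decoding lattice (an FsaVec) obtained elsewhere, e.g. from k2.intersect_dense_pruned:

```python
import k2

# k2.shortest_path keeps exactly one path per FSA, which is the input
# that get_phone_alignment expects.
best_paths = k2.shortest_path(lats, use_double_scores=True)
alignments = get_phone_alignment(best_paths)

# alignments[i].value['phone'] is the label sequence of the i-th path.
```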