This repository has been archived by the owner on Sep 6, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 79
/
streaming_transcribe.py
97 lines (63 loc) · 2.1 KB
/
streaming_transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from argparse import ArgumentParser
import os
import time
import pyaudio
import tensorflow as tf
# Quiet TensorFlow before the project modules below are imported: the TF
# Python logger only emits ERROR and above, and AutoGraph logging is set to
# verbosity 0.
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)
from utils import preprocessing, encoding, decoding
from utils import model as model_utils
from model import build_keras_model
from hparams import *
# --- Audio capture settings for the PyAudio input stream opened in main() ---
# Sample rate in Hz; also handed to preprocessing.preprocess_audio.
SAMPLE_RATE = 16000
# Single (mono) input channel.
NUM_CHANNELS = 1
# Frames delivered per stream callback (1024 / 16000 Hz = 64 ms of audio).
CHUNK_SIZE = 1024

# --- Mutable module state ---
# Running transcription: the stream callback in main() appends each chunk's
# decoded text here and prints the result whenever it changes.
LAST_OUTPUT = ''
def main(args):
    """Transcribe live microphone audio with a checkpointed model.

    Loads hparams, the token-to-text encoder and the model weights from the
    directory containing ``args.checkpoint``. It then opens a float32 PyAudio
    input stream and decodes each captured chunk inside the stream callback.
    The accumulated transcription (module-level ``LAST_OUTPUT``) is printed
    whenever a chunk changes it.

    Listening continues until the stream becomes inactive or the user
    presses Ctrl-C. On every exit path the stream is stopped and closed and
    the PyAudio instance is terminated.

    Args:
        args: Parsed command-line namespace; only ``args.checkpoint`` is read.
    """
    model_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(model_dir)

    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=model_dir,
        hparams=hparams)
    # Overwrite the vocab size with the one reported by the saved encoder.
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    # NOTE(review): stateful=True presumably carries model state across
    # successive chunks so decoding continues between callbacks -- confirm
    # in model.build_keras_model.
    model = build_keras_model(hparams, stateful=True)
    model.load_weights(args.checkpoint)

    decoder_fn = decoding.greedy_decode_fn(model, hparams)

    def listen_callback(in_data, frame_count, time_info, status):
        """PyAudio stream callback: decode one audio chunk and print new text.

        ``frame_count``, ``time_info`` and ``status`` are required by the
        PyAudio callback signature but are not used here.
        """
        # NOTE(review): model inference runs inside the stream callback,
        # which PyAudio invokes on a separate thread.
        # If decoding one chunk takes longer than CHUNK_SIZE / SAMPLE_RATE
        # seconds, captured audio may be dropped. Consider handing chunks to
        # the main loop through a queue.
        global LAST_OUTPUT

        # The stream is opened with paFloat32, so reinterpret the raw bytes
        # as float32 samples.
        audio = tf.io.decode_raw(in_data, out_type=tf.float32)
        log_melspec = preprocessing.preprocess_audio(
            audio=audio,
            sample_rate=SAMPLE_RATE,
            hparams=hparams)
        # Add a leading axis (presumably batch) for the decoder; the single
        # result is taken back out with [0] below.
        log_melspec = tf.expand_dims(log_melspec, axis=0)

        decoded = decoder_fn(log_melspec)[0]
        chunk_text = tok_to_text(decoded).numpy().decode('utf8')
        transcription = LAST_OUTPUT + chunk_text

        if transcription != LAST_OUTPUT:
            LAST_OUTPUT = transcription
            print(transcription)

        # paContinue keeps the stream running; this callback never asks the
        # stream to finish.
        return in_data, pyaudio.paContinue

    p = pyaudio.PyAudio()
    stream = None
    try:
        stream = p.open(
            format=pyaudio.paFloat32,
            channels=NUM_CHANNELS,
            rate=SAMPLE_RATE,
            input=True,
            frames_per_buffer=CHUNK_SIZE,
            stream_callback=listen_callback)

        print('Listening...')
        stream.start_stream()

        # The callback always returns paContinue, so this loop normally ends
        # only via Ctrl-C. The cleanup therefore lives in `finally` so it also
        # runs on that path.
        while stream.is_active():
            time.sleep(0.1)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop listening; exit quietly.
        pass
    finally:
        if stream is not None:
            stream.stop_stream()
            stream.close()
        p.terminate()
def parse_args(argv=None):
    """Parse the command-line options for live transcription.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the previous behavior.

    Returns:
        argparse.Namespace whose required ``checkpoint`` attribute is the
        checkpoint path string.
    """
    ap = ArgumentParser()
    ap.add_argument('--checkpoint', type=str, required=True,
                    help='Checkpoint to load.')
    return ap.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: read the CLI options, then start live transcription.
    main(parse_args())