Merge pull request #9 from gooofy/master
add text generation (inference) methods and command line tool
lopuhin authored Jul 16, 2019
2 parents be11c00 + 561d6aa commit 0638ca5
Showing 3 changed files with 56 additions and 0 deletions.
16 changes: 16 additions & 0 deletions README.rst
@@ -89,6 +89,22 @@ Notes on training parameters:
or ``--n-ctx``: loss is already scaled appropriately.


Inference
+++++++++

Example command::

gpt-2-gen run-root "Artificial intelligence"

``run-root`` is the run directory that should contain model checkpoints;
``"Artificial intelligence"`` is the text prefix used as a starting point for generating tokens.

Notes on inference parameters (see the example below):

- ``--tokens-to-generate``: number of tokens to generate; the default is 42.
- ``--top-k``: number of token candidates to sample from at each position (beam width); the default is 8.
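
For example, to generate a longer continuation while sampling from fewer
candidates (the values here are illustrative, not recommendations)::

    gpt-2-gen run-root "Artificial intelligence" --tokens-to-generate=64 --top-k=4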


License & credits
-----------------

Expand Down
39 changes: 39 additions & 0 deletions lm/inference.py
@@ -4,6 +4,9 @@

import sentencepiece as spm
import torch
import numpy as np
import fire
from .fire_utils import only_allow_defined_args

from .model import Model, HParams
from .common import END_OF_LINE, END_OF_TEXT
@@ -73,9 +76,45 @@ def get_next_top_k(
             for i in next_log_probs.argsort()[-top_k:]],
            reverse=True)

    def generate_tokens(self, tokens_prefix: List[str], tokens_to_generate: int, top_k: int) -> List[str]:
        tokens = list(tokens_prefix)

        for _ in range(tokens_to_generate):
            # get the top_k most likely next tokens and their log-probs
            ntk = self.get_next_top_k(tokens, top_k)

            # softmax over the top-k log-probs to turn them into a
            # probability distribution over the candidates
            log_probs = np.array([log_prob for log_prob, _ in ntk])
            probs = np.exp(log_probs) / np.exp(log_probs).sum()

            # sample the next token according to that distribution
            next_token = ntk[np.random.choice(top_k, p=probs)][1]
            tokens.append(next_token)

        return tokens


def fixed_state_dict(state_dict):
    if all(k.startswith('module.') for k in state_dict):
        # legacy multi-GPU format
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict

def gen_main(model_path, prefix, tokens_to_generate=42, top_k=8):
    print("loading model from %s" % model_path)
    mw = ModelWrapper.load(Path(model_path))

    print("generating text for prefix %s" % prefix)
    tokens = mw.tokenize(prefix)
    tokens_gen = mw.generate_tokens(tokens, tokens_to_generate, top_k)
    print(mw.sp_model.DecodePieces(tokens_gen))


def fire_gen_main():
    fire.Fire(only_allow_defined_args(gen_main))
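
For reference, a minimal Python sketch of driving the new inference API directly, mirroring what ``gen_main`` does above (it assumes a trained model lives in ``run-root``, which is a placeholder path):

    from pathlib import Path
    from lm.inference import ModelWrapper

    # load checkpoint and sentencepiece model from the run directory
    mw = ModelWrapper.load(Path('run-root'))  # placeholder path
    # tokenize the prefix, then sample 42 more tokens from top-8 candidates
    tokens = mw.tokenize('Artificial intelligence')
    tokens_gen = mw.generate_tokens(tokens, tokens_to_generate=42, top_k=8)
    print(mw.sp_model.DecodePieces(tokens_gen))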
1 change: 1 addition & 0 deletions setup.py
@@ -17,6 +17,7 @@
        'sp-encode = lm.data:sp_encode',
        'gpt-2-tf-train = lm.gpt_2_tf.train:main',
        'gpt-2 = lm.main:fire_main',
        'gpt-2-gen = lm.inference:fire_gen_main',
        'lm-web-ui = lm_web_ui.main:main',
    ],
}
