diff --git a/README.rst b/README.rst
index 020cf4f..c27ebf7 100644
--- a/README.rst
+++ b/README.rst
@@ -89,6 +89,22 @@ Notes on training parameters:
   or ``--n-ctx``: loss is already scaled appropriately.
 
+Inference
++++++++++
+
+Example command::
+
+    gpt-2-gen run-root "Artificial intelligence"
+
+``run-root`` is the directory containing model checkpoints.
+``"Artificial intelligence"`` is the text prefix used as a starting point for generating tokens.
+
+Notes on inference parameters:
+
+- ``--tokens-to-generate``: number of tokens to generate, default is 42
+- ``--top-k``: number of most likely token candidates to sample from at each
+  position (top-k sampling), default is 8
+
 
 License & credits
 -----------------
diff --git a/lm/inference.py b/lm/inference.py
index 94f040b..8768cb7 100644
--- a/lm/inference.py
+++ b/lm/inference.py
@@ -4,6 +4,9 @@
 import sentencepiece as spm
 import torch
+import numpy as np
+import fire
+from .fire_utils import only_allow_defined_args
 from .model import Model, HParams
 from .common import END_OF_LINE, END_OF_TEXT
@@ -73,9 +76,45 @@
              for i in next_log_probs.argsort()[-top_k:]],
             reverse=True)
 
+    def generate_tokens(self, tokens_prefix: List[str],
+                        tokens_to_generate: int, top_k: int) -> List[str]:
+        tokens = list(tokens_prefix)
+        for _ in range(tokens_to_generate):
+            # Get the top_k most likely next tokens with their log-probs.
+            ntk = self.get_next_top_k(tokens, top_k)
+            # Softmax over the top-k log-probs (shifted by the max for
+            # numerical stability) to get a proper probability distribution.
+            log_probs = np.array([log_prob for log_prob, _ in ntk])
+            probs = np.exp(log_probs - log_probs.max())
+            probs /= probs.sum()
+            # Sample the next token index from that distribution.
+            next_token_n = np.random.choice(top_k, p=probs)
+            tokens.append(ntk[next_token_n][1])
+        return tokens
+
 
 def fixed_state_dict(state_dict):
     if all(k.startswith('module.') for k in state_dict):
         # legacy multi-GPU format
         state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
     return state_dict
+
+
+def gen_main(model_path, prefix, tokens_to_generate=42, top_k=8):
+    print("loading model from %s" % model_path)
+    mw = ModelWrapper.load(Path(model_path))
+
+    print("generating text for prefix %s" % prefix)
+    tokens = mw.tokenize(prefix)
+
+    tokens_gen = mw.generate_tokens(tokens, tokens_to_generate, top_k)
+    print(mw.sp_model.DecodePieces(tokens_gen))
+
+
+def fire_gen_main():
+    fire.Fire(only_allow_defined_args(gen_main))
diff --git a/setup.py b/setup.py
index 514fa61..f65ef2c 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,7 @@
         'sp-encode = lm.data:sp_encode',
         'gpt-2-tf-train = lm.gpt_2_tf.train:main',
         'gpt-2 = lm.main:fire_main',
+        'gpt-2-gen = lm.inference:fire_gen_main',
         'lm-web-ui = lm_web_ui.main:main',
     ],
 }
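
A minimal sketch of driving the new inference API from Python, assuming the
package is installed and ``run-root`` holds a trained checkpoint (the path and
prefix are placeholders; ``ModelWrapper.load``, ``tokenize``,
``generate_tokens`` and ``sp_model.DecodePieces`` are the methods added or
used in the diff above)::

    from pathlib import Path
    from lm.inference import ModelWrapper

    # Load the model checkpoint and sentencepiece model from the run directory.
    mw = ModelWrapper.load(Path('run-root'))
    # Encode the text prefix into sentencepiece tokens.
    tokens = mw.tokenize('Artificial intelligence')
    # Sample 42 more tokens, choosing among the top 8 candidates at each step.
    tokens_gen = mw.generate_tokens(tokens, tokens_to_generate=42, top_k=8)
    print(mw.sp_model.DecodePieces(tokens_gen))

The ``gpt-2-gen`` console script registered in ``setup.py`` wraps the same
calls; Python Fire exposes ``gen_main``'s keyword arguments as the
``--tokens-to-generate`` and ``--top-k`` flags documented in the README.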