Possible gpt-2-gen bug: assertion error in inference.py #10

Open
strumke opened this issue Jul 17, 2019 · 10 comments

strumke commented Jul 17, 2019

Thanks for adding the generative functionality! Is there a bug, or am I doing something wrong?
See the command and output below (test dataset, after encoding and training as per the README).

➜ transformer-lm git:(master) gpt-2-gen tests/shakespeare-test-run "Artificial intelligence"
loading model from tests/shakespeare-test-run
generating text for prefix Artificial intelligence
Traceback (most recent call last):
File "/Users/.../anaconda3/bin/gpt-2-gen", line 11, in
load_entry_point('lm', 'console_scripts', 'gpt-2-gen')()
File "/Users/.../transformer-lm/lm/inference.py", line 120, in fire_gen_main
fire.Fire(only_allow_defined_args(gen_main))
File "/Users/.../anaconda3/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
component_trace = _Fire(component, args, context, name)
File "/Users/.../anaconda3/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
component, remaining_args)
File "/Users/.../anaconda3/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
result = fn(*varargs, **kwargs)
File "/Users/.../transformer-lm/lm/fire_utils.py", line 30, in _return_wrapped
return function_to_decorate(*args, **kwargs)
File "/Users/.../transformer-lm/lm/inference.py", line 116, in gen_main
tokens_gen = mw.generate_tokens(tokens, tokens_to_generate, top_k)
File "/Users/.../transformer-lm/lm/inference.py", line 86, in generate_tokens
ntk = self.get_next_top_k(tokens, top_k)
File "/Users/.../transformer-lm/lm/inference.py", line 74, in get_next_top_k
next_log_probs = self.get_log_probs(tokens)[-1]
File "/Users/.../transformer-lm/lm/inference.py", line 51, in get_log_probs
assert len(tokens) <= self.model.hparams.n_ctx # TODO
AssertionError

gooofy commented Jul 17, 2019

Hmm - interesting! :o)

A bit of debug output could be enlightening here - could you apply this patch and run it again?

diff --git a/lm/inference.py b/lm/inference.py
index 8768cb7..5a4b78b 100644
--- a/lm/inference.py
+++ b/lm/inference.py
@@ -78,6 +78,8 @@ class ModelWrapper:
 
     def generate_tokens(self, tokens_prefix: List[str], tokens_to_generate: int, top_k: int) -> List[str]:
 
+        print ("self.model.hparams.n_ctx: %d, tokens_to_generate: %d" % (self.model.hparams.n_ctx, tokens_to_generate))
+
         tokens = list(tokens_prefix)
 
         for i in range(tokens_to_generate):
@@ -92,7 +94,7 @@ class ModelWrapper:
             # pick next token randomly according to probs distribution
             next_token_n = np.random.choice(top_k, p=probs)
             next_token = ntk[next_token_n][1]
-            # print (next_token)
+            print ("Token # %d: %s" % (i, next_token))
             
             tokens.append(next_token)

output should look similar to this:

(torch) [bofh@hal transformer-lm]$ gpt-2-gen gpt2-german "Ursula von der Leyen"
loading model from gpt2-german
generating text for prefix Ursula von der Leyen
self.model.hparams.n_ctx: 1024, tokens_to_generate: 42
Token # 0: ▁im
Token # 1: ▁Bundestag

strumke commented Jul 17, 2019 via email

gooofy commented Jul 17, 2019

Not sure about the TODO (my best guess is that a nicer error message would be an improvement there - a sketch of such a check is below), but the real issue seems to be the model you're using, which has a pretty small context length.
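
For illustration, a friendlier check than the bare assert in get_log_probs could look roughly like this (check_context_length is a hypothetical helper, not the repository's actual code):

# illustrative sketch only, not the repository's actual fix
def check_context_length(n_tokens: int, n_ctx: int) -> None:
    # raise a descriptive error instead of a bare AssertionError
    if n_tokens > n_ctx:
        raise ValueError(
            "got %d tokens, but the model context length (hparams.n_ctx) is %d; "
            "shorten the prefix or pass a smaller --tokens-to-generate"
            % (n_tokens, n_ctx))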

Try to generate fewer tokens (--tokens-to-generate 38) to see if that works.
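
For example, with the test model from the README:

gpt-2-gen tests/shakespeare-test-run "Artificial intelligence" --tokens-to-generate 38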

strumke commented Jul 17, 2019 via email

gooofy commented Jul 17, 2019

Yes, the context length is a model hyperparameter - see params.json in the model directory.
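
For instance, a quick way to check it from Python (assuming the test model directory used above, and that params.json keeps the hyperparameters under an "hparams" key - the exact layout may differ between versions):

import json
from pathlib import Path

params = json.loads(Path("tests/shakespeare-test-run/params.json").read_text())
hparams = params.get("hparams", params)  # fall back to the top level if there is no "hparams" key
print("n_ctx =", hparams["n_ctx"])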

strumke commented Jul 17, 2019 via email

hbajohr commented Dec 6, 2019

So does that mean that unless I retrain the model with a higher context length, I cannot generate any output longer than 48 characters? Or is there a way to do that without retraining?

lopuhin commented Dec 6, 2019

@ceprun for reference, 48 is the number of tokens, not characters, and the default values for gpt-2 are much larger than that - 48 is just for the integration test. Still, there is some context length limit, but in theory you can generate sequences longer than it; say the context length is 4, then you generate:

input | output
-----------------
s     | a
s a   | b
s a b | c
a b c | d
b c d | e

I hope the idea is clear - you truncate the previously generated text, keeping only the rightmost (most recent) tokens, when feeding it back in as context.
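
A minimal sketch of that sliding-window idea, assuming a loaded ModelWrapper mw with the get_next_top_k(tokens, top_k) method seen in the tracebacks above (returning (log_prob, token) pairs); generate_with_truncation is a hypothetical helper, not the library's built-in behaviour:

import numpy as np

def generate_with_truncation(mw, tokens_prefix, tokens_to_generate, top_k=8):
    # never feed more than n_ctx - 1 tokens, so the model input stays
    # within the context length it was trained with
    n_ctx = mw.model.hparams.n_ctx
    tokens = list(tokens_prefix)
    for _ in range(tokens_to_generate):
        context = tokens[-(n_ctx - 1):]
        ntk = mw.get_next_top_k(context, top_k)   # list of (log_prob, token)
        probs = np.exp(np.array([log_prob for log_prob, _ in ntk]))
        probs /= probs.sum()
        next_token = ntk[np.random.choice(len(ntk), p=probs)][1]
        tokens.append(next_token)
    return tokens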

hafsahabib-educator commented

I have n_ctx = 1024 and am still facing the assertion error:

AssertionError Traceback (most recent call last)
in ()
25
26 #f = open(output_csv, "a+", encoding='utf-8')
---> 27 tokens=mw.generate_tokens(txt,30,3)
28
29 st =mw.sp_model.DecodePieces(tokens)

10 frames
in generate_tokens(self, tokens_prefix, tokens_to_generate, top_k, top_p, temperature)
145 if top_p <= 0.0:
146 # generate TOP_K potential next tokens
--> 147 ntk, presents = self._get_next_top_k(tokens, top_k, past=past)
148
149 # convert log probs to real probs

in _get_next_top_k(self, tokens, top_k, past)
98 past: Optional[torch.Tensor],
99 ) -> Tuple[List[Tuple[float, str]], torch.Tensor]:
--> 100 next_log_probs, presents = self._get_log_probs(tokens, past=past)
101 next_log_probs = next_log_probs[-1]
102 result = sorted(

in _get_log_probs(self, tokens, past)
71 ctx = torch.LongTensor(ids).unsqueeze(0)
72 with torch.no_grad():
---> 73 output = self.model(ctx, past=past)
74 logits = output['logits'].squeeze(0)
75 return torch.log_softmax(logits, dim=1), output['presents']

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

/content/drive/My Drive/Colab Notebooks/transformer-lm/lm/model.py in forward(self, x, past)
51 for i, block in enumerate(self.blocks):
52 if self.hparams.gradient_checkpointing:
---> 53 h, present = torch.utils.checkpoint.checkpoint(block, h, past[:, i] if past is not None else None)
54 else:
55 h, present = block(h, past=past[:, i] if past is not None else None)

/usr/local/lib/python3.6/dist-packages/torch/utils/checkpoint.py in checkpoint(function, *args, **kwargs)
153 raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
154
--> 155 return CheckpointFunction.apply(function, preserve, *args)
156
157

/usr/local/lib/python3.6/dist-packages/torch/utils/checkpoint.py in forward(ctx, run_function, preserve_rng_state, *args)
72 ctx.save_for_backward(*args)
73 with torch.no_grad():
---> 74 outputs = run_function(*args)
75 return outputs
76

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

/content/drive/My Drive/Colab Notebooks/transformer-lm/lm/model.py in forward(self, x, past)
77
78 def forward(self, x, past):
---> 79 a, present = self.attn(self.ln_1(x), past=past)
80 x = x + a
81 m = self.mlp(self.ln_2(x))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

/content/drive/My Drive/Colab Notebooks/transformer-lm/lm/model.py in forward(self, x, past)
129 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
130 assert len(past.shape) == 5
--> 131 assert past.shape[-1] == self.hparams.n_hidden
132 c = self.c_attn(x)
133 q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))

AssertionError:

lopuhin commented Jul 30, 2020

@hafsabukhary ah, this is a different error - that assert is indeed incorrect. It was fixed recently in 4c18649, so it should work once you update to recent master. Also check out the updated generate_tokens method - it now implements more efficient generation using "past" (which seems to be what you are doing here).
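
A rough usage sketch after updating to recent master, based only on the call signatures visible in the traceback above (generate_tokens(tokens_prefix, tokens_to_generate, top_k, ...) and the SentencePiece sp_model attribute); mw is assumed to be an already loaded lm.inference.ModelWrapper:

# encode the prefix with the model's SentencePiece tokenizer, generate, decode back
prefix_tokens = mw.sp_model.EncodeAsPieces("Artificial intelligence")
tokens = mw.generate_tokens(prefix_tokens, tokens_to_generate=30, top_k=3)
print(mw.sp_model.DecodePieces(tokens))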
