diff --git a/lm/main.py b/lm/main.py index 06cb248..5ca11ac 100644 --- a/lm/main.py +++ b/lm/main.py @@ -227,6 +227,8 @@ def get_valid_loss(): with torch.no_grad(): for ctx in _valid_batch_iter( valid_dataset, batch_size=batch_size, n_ctx=n_ctx): + if not ctx: + continue ctx = torch.LongTensor(ctx).to(device) logits = model(ctx)['logits'] loss = loss_fn(logits, ctx)