Fix rnn example. (PaddlePaddle#971)
1. typo fix: stm_hidden_size -> lstm_hidden_size (see the dimension sketch below the file summary).
2. import fix: add missing imports in export_model.py
3. nit: remove unused imports.
songzy12 authored Sep 3, 2021
1 parent c7fb2b9 commit 98e3564
Showing 3 changed files with 4 additions and 5 deletions.
1 change: 0 additions & 1 deletion examples/text_classification/rnn/deploy/python/predict.py
@@ -16,7 +16,6 @@
 
 import numpy as np
-import paddle
 from paddle import inference
 from paddlenlp.data import JiebaTokenizer, Stack, Tuple, Pad, Vocab
 from scipy.special import softmax
 
5 changes: 3 additions & 2 deletions examples/text_classification/rnn/export_model.py
@@ -15,9 +15,10 @@
 import argparse
 
 import paddle
-import paddlenlp as ppnlp
+from paddlenlp.data import Vocab
 
+from model import BoWModel, BiLSTMAttentionModel, CNNModel, LSTMModel, GRUModel, RNNModel, SelfInteractiveAttention
 
 # yapf: disable
 parser = argparse.ArgumentParser(__doc__)
 parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to vocabulary.")
@@ -56,7 +57,7 @@ def main():
             padding_idx=pad_token_id)
     elif network == 'bilstm_attn':
         lstm_hidden_size = 196
-        attention = SelfInteractiveAttention(hidden_size=2 * stm_hidden_size)
+        attention = SelfInteractiveAttention(hidden_size=2 * lstm_hidden_size)
         model = BiLSTMAttentionModel(
             attention_layer=attention,
             vocab_size=vocab_size,
3 changes: 1 addition & 2 deletions examples/text_classification/rnn/predict.py
@@ -15,7 +15,6 @@
 
 import paddle
 import paddle.nn.functional as F
-import paddlenlp as ppnlp
 from paddlenlp.data import JiebaTokenizer, Stack, Tuple, Pad, Vocab
 
 from model import BoWModel, BiLSTMAttentionModel, CNNModel, LSTMModel, GRUModel, RNNModel, SelfInteractiveAttention
@@ -102,7 +101,7 @@ def predict(model, data, label_map, batch_size=1, pad_token_id=0):
             padding_idx=pad_token_id)
     elif network == 'bilstm_attn':
         lstm_hidden_size = 196
-        attention = SelfInteractiveAttention(hidden_size=2 * stm_hidden_size)
+        attention = SelfInteractiveAttention(hidden_size=2 * lstm_hidden_size)
         model = BiLSTMAttentionModel(
             attention_layer=attention,
             vocab_size=vocab_size,
