From e8ebbe5486f53dbdf911b50ee8574c7084877971 Mon Sep 17 00:00:00 2001 From: Junghwan Park Date: Wed, 12 Jun 2024 23:40:17 +0900 Subject: [PATCH] fix typo --- .build/requirements-minimal.txt | 2 +- beginner_source/chatbot_tutorial.py | 2 +- intermediate_source/seq2seq_translation_tutorial.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.build/requirements-minimal.txt b/.build/requirements-minimal.txt index 6de5b47d..7a77da83 100644 --- a/.build/requirements-minimal.txt +++ b/.build/requirements-minimal.txt @@ -12,7 +12,7 @@ sphinx-sitemap sphinxext-opengraph sphinxcontrib-katex plotly==5.14.0 -torch==2.3 +torch torchvision torchtext torchaudio diff --git a/beginner_source/chatbot_tutorial.py b/beginner_source/chatbot_tutorial.py index 5d63c527..b1c5bf03 100644 --- a/beginner_source/chatbot_tutorial.py +++ b/beginner_source/chatbot_tutorial.py @@ -523,7 +523,7 @@ def outputVar(l, voc): max_target_len = max([len(indexes) for indexes in indexes_batch]) padList = zeroPadding(indexes_batch) mask = binaryMatrix(padList) - mask = torch.ByteTensor(mask) + mask = torch.BoolTensor(mask) padVar = torch.LongTensor(padList) return padVar, mask, max_target_len diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py index 31143483..bcd1a5b1 100644 --- a/intermediate_source/seq2seq_translation_tutorial.py +++ b/intermediate_source/seq2seq_translation_tutorial.py @@ -750,7 +750,7 @@ def evaluateRandomly(encoder, decoder, n=10): pair = random.choice(pairs) print('>', pair[0]) print('=', pair[1]) - output_words, attentions = evaluate(encoder, decoder, pair[0]) + output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang) output_sentence = ' '.join(output_words) print('<', output_sentence) print('')