This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit 7dcdc32
Merge pull request #529 from microsoft/bleik/optim-patch
bleik/common transformers utils update
saidbleik authored Jan 25, 2020
2 parents abeb88a + 6b35c49 commit 7dcdc32
Showing 28 changed files with 1,099 additions and 1,213 deletions.
@@ -233,7 +233,7 @@
"source": [
"with Timer() as t:\n",
" preds = model.predict(\n",
" eval_dataloader=test_dataloader,\n",
" test_dataloader=test_dataloader,\n",
" num_gpus=None,\n",
" verbose=True\n",
" )\n",
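The hunk above renames `predict`'s keyword argument from `eval_dataloader` to `test_dataloader`. A minimal sketch of the updated call, assuming a fitted classifier `model` and a `test_dataloader` built as in the rest of this PR (the `Timer` import path is the one the notebooks use):

```python
from utils_nlp.common.timer import Timer

# Score the test set; the keyword is now test_dataloader.
with Timer() as t:
    preds = model.predict(
        test_dataloader=test_dataloader,
        num_gpus=None,  # None lets the library use the available GPUs
        verbose=True,
    )
print(f"scoring took {t.interval:.1f}s")
```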
205 changes: 146 additions & 59 deletions examples/text_classification/tc_mnli_transformers.ipynb
@@ -32,6 +32,7 @@
"from sklearn.preprocessing import LabelEncoder\n",
"from tqdm import tqdm\n",
"from utils_nlp.common.timer import Timer\n",
"from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
"from utils_nlp.dataset.multinli import load_pandas_df\n",
"from utils_nlp.models.transformers.sequence_classification import (\n",
" Processor, SequenceClassifier)"
@@ -93,7 +94,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 222k/222k [01:25<00:00, 2.60kKB/s] \n"
"100%|██████████| 222k/222k [01:20<00:00, 2.74kKB/s] \n"
]
}
],
@@ -196,7 +197,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/media/bleik2/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
"/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
" FutureWarning)\n"
]
}
@@ -232,11 +233,11 @@
{
"data": {
"text/plain": [
"telephone 1055\n",
"slate 1003\n",
"travel 961\n",
"fiction 952\n",
"government 938\n",
"telephone 1043\n",
"slate 989\n",
"fiction 968\n",
"travel 964\n",
"government 945\n",
"Name: genre, dtype: int64"
]
},
@@ -385,32 +386,108 @@
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>roberta-base</td>\n",
" <td>bert-base-japanese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>roberta-large</td>\n",
" <td>bert-base-japanese-whole-word-masking</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>roberta-large-mnli</td>\n",
" <td>bert-base-japanese-char</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>xlnet-base-cased</td>\n",
" <td>bert-base-japanese-char-whole-word-masking</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>xlnet-large-cased</td>\n",
" <td>bert-base-finnish-cased-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>distilbert-base-uncased</td>\n",
" <td>bert-base-finnish-uncased-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>roberta-base</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>roberta-large</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>roberta-large-mnli</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>distilroberta-base</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>roberta-base-openai-detector</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>roberta-large-openai-detector</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>xlnet-base-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>xlnet-large-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>distilbert-base-uncased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>distilbert-base-uncased-distilled-squad</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>distilbert-base-german-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>distilbert-base-multilingual-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33</th>\n",
" <td>albert-base-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34</th>\n",
" <td>albert-large-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>albert-xlarge-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>albert-xxlarge-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>albert-base-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>albert-large-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>albert-xlarge-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40</th>\n",
" <td>albert-xxlarge-v2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
@@ -432,13 +509,32 @@
"12 bert-base-cased-finetuned-mrpc\n",
"13 bert-base-german-dbmdz-cased\n",
"14 bert-base-german-dbmdz-uncased\n",
"15 roberta-base\n",
"16 roberta-large\n",
"17 roberta-large-mnli\n",
"18 xlnet-base-cased\n",
"19 xlnet-large-cased\n",
"20 distilbert-base-uncased\n",
"21 distilbert-base-uncased-distilled-squad"
"15 bert-base-japanese\n",
"16 bert-base-japanese-whole-word-masking\n",
"17 bert-base-japanese-char\n",
"18 bert-base-japanese-char-whole-word-masking\n",
"19 bert-base-finnish-cased-v1\n",
"20 bert-base-finnish-uncased-v1\n",
"21 roberta-base\n",
"22 roberta-large\n",
"23 roberta-large-mnli\n",
"24 distilroberta-base\n",
"25 roberta-base-openai-detector\n",
"26 roberta-large-openai-detector\n",
"27 xlnet-base-cased\n",
"28 xlnet-large-cased\n",
"29 distilbert-base-uncased\n",
"30 distilbert-base-uncased-distilled-squad\n",
"31 distilbert-base-german-cased\n",
"32 distilbert-base-multilingual-cased\n",
"33 albert-base-v1\n",
"34 albert-large-v1\n",
"35 albert-xlarge-v1\n",
"36 albert-xxlarge-v1\n",
"37 albert-base-v2\n",
"38 albert-large-v2\n",
"39 albert-xlarge-v2\n",
"40 albert-xxlarge-v2"
]
},
"execution_count": 10,
@@ -492,18 +588,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 231508/231508 [00:00<00:00, 15545441.79B/s]\n",
"100%|██████████| 492/492 [00:00<00:00, 560455.61B/s]\n",
"100%|██████████| 267967963/267967963 [00:04<00:00, 61255588.46B/s]\n",
"/media/bleik2/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"100%|██████████| 898823/898823 [00:00<00:00, 23932308.55B/s]\n",
"100%|██████████| 456318/456318 [00:00<00:00, 23321916.66B/s]\n",
"100%|██████████| 473/473 [00:00<00:00, 477015.10B/s]\n",
"100%|██████████| 501200538/501200538 [00:07<00:00, 64332558.45B/s]\n",
"100%|██████████| 798011/798011 [00:00<00:00, 25002433.16B/s]\n",
"100%|██████████| 641/641 [00:00<00:00, 695974.34B/s]\n",
"100%|██████████| 467042463/467042463 [00:08<00:00, 55154509.21B/s]\n"
"/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
}
],
@@ -518,11 +604,17 @@
" to_lower=model_name.endswith(\"uncased\"),\n",
" cache_dir=CACHE_DIR,\n",
" )\n",
" train_dataloader = processor.create_dataloader_from_df(\n",
" df_train, TEXT_COL, LABEL_COL, max_len=MAX_LEN, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True\n",
" train_dataset = processor.dataset_from_dataframe(\n",
" df_train, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
" )\n",
" test_dataloader = processor.create_dataloader_from_df(\n",
" df_test, TEXT_COL, LABEL_COL, max_len=MAX_LEN, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False\n",
" train_dataloader = dataloader_from_dataset(\n",
" train_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True\n",
" )\n",
" test_dataset = processor.dataset_from_dataframe(\n",
" df_test, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
" )\n",
" test_dataloader = dataloader_from_dataset(\n",
" test_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False\n",
" )\n",
"\n",
" # fine-tune\n",
@@ -531,17 +623,12 @@
" )\n",
" with Timer() as t:\n",
" classifier.fit(\n",
" train_dataloader,\n",
" num_epochs=NUM_EPOCHS,\n",
" num_gpus=NUM_GPUS,\n",
" verbose=False,\n",
" train_dataloader, num_epochs=NUM_EPOCHS, num_gpus=NUM_GPUS, verbose=False,\n",
" )\n",
" train_time = t.interval / 3600\n",
"\n",
" # predict\n",
" preds = classifier.predict(\n",
" test_dataloader, num_gpus=NUM_GPUS, verbose=False\n",
" )\n",
" preds = classifier.predict(test_dataloader, num_gpus=NUM_GPUS, verbose=False)\n",
"\n",
" # eval\n",
" accuracy = accuracy_score(df_test[LABEL_COL], preds)\n",
@@ -600,31 +687,31 @@
" <tbody>\n",
" <tr>\n",
" <th>accuracy</th>\n",
" <td>0.895477</td>\n",
" <td>0.879584</td>\n",
" <td>0.894866</td>\n",
" <td>0.889364</td>\n",
" <td>0.885697</td>\n",
" <td>0.886308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>f1-score</th>\n",
" <td>0.896656</td>\n",
" <td>0.881218</td>\n",
" <td>0.896108</td>\n",
" <td>0.885225</td>\n",
" <td>0.880926</td>\n",
" <td>0.881819</td>\n",
" </tr>\n",
" <tr>\n",
" <th>time(hrs)</th>\n",
" <td>0.021865</td>\n",
" <td>0.035351</td>\n",
" <td>0.046295</td>\n",
" <td>0.023326</td>\n",
" <td>0.044209</td>\n",
" <td>0.052801</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" distilbert-base-uncased roberta-base xlnet-base-cased\n",
"accuracy 0.895477 0.879584 0.894866\n",
"f1-score 0.896656 0.881218 0.896108\n",
"time(hrs) 0.021865 0.035351 0.046295"
"accuracy 0.889364 0.885697 0.886308\n",
"f1-score 0.885225 0.880926 0.881819\n",
"time(hrs) 0.023326 0.044209 0.052801"
]
},
"execution_count": 13,
@@ -645,7 +732,7 @@
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.8899755501222494,
"data": 0.887123064384678,
"encoder": "json",
"name": "accuracy",
"version": 1
@@ -663,7 +750,7 @@
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.8913273009038569,
"data": 0.8826569624491233,
"encoder": "json",
"name": "f1",
"version": 1
@@ -688,9 +775,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "nlp_gpu",
"display_name": "Python 3.6.8 64-bit ('nlp_gpu': conda)",
"language": "python",
"name": "nlp_gpu"
"name": "python36864bitnlpgpucondaa579511bcea84c65877ff3dca4205921"
},
"language_info": {
"codemirror_mode": {
(Diff truncated: the remaining changed files were not loaded on this page.)
