Standardizing data loader and pulling from split for adding custom dataset dwmw17 #268

Open. Wants to merge 4 commits into base: main.
81 changes: 44 additions & 37 deletions datasets/download_text_classification.sh
@@ -1,43 +1,50 @@
#!/bin/sh
DIR="./TextClassification"
DIR="./datasets/TextClassification"
mkdir $DIR
cd $DIR

rm -rf mnli
wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
tar -zxvf mnli.tar.gz
rm -rf mnli.tar.gz

rm -rf agnews
wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
tar -zxvf agnews.tar.gz
rm -rf agnews.tar.gz

rm -rf dbpedia
wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
tar -zxvf dbpedia.tar.gz
rm -rf dbpedia.tar.gz

rm -rf imdb
wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
tar -zxvf imdb.tar.gz
rm -rf imdb.tar.gz

rm -rf SST-2
wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
tar -zxvf SST-2.tar.gz
rm -rf SST-2.tar.gz

rm -rf amazon
wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
tar -zxvf amazon.tar.gz
mv datasets/amazon/ amazon
rm -rf ./datasets
rm -rf amazon.tar.gz

rm -rf yahoo_answers_topics
wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
tar -zxvf yahoo_answers_topics.tar.gz
rm -rf yahoo_answers_topics.tar.gz
# rm -rf mnli
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
# tar -zxvf mnli.tar.gz
# rm -rf mnli.tar.gz

# rm -rf agnews
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
# tar -zxvf agnews.tar.gz
# rm -rf agnews.tar.gz

# rm -rf dbpedia
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
# tar -zxvf dbpedia.tar.gz
# rm -rf dbpedia.tar.gz

# rm -rf imdb
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
# tar -zxvf imdb.tar.gz
# rm -rf imdb.tar.gz

# rm -rf SST-2
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
# tar -zxvf SST-2.tar.gz
# rm -rf SST-2.tar.gz

# rm -rf amazon
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
# tar -zxvf amazon.tar.gz
# mv datasets/amazon/ amazon
# rm -rf ./datasets
# rm -rf amazon.tar.gz

# rm -rf yahoo_answers_topics
# wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
# tar -zxvf yahoo_answers_topics.tar.gz
# rm -rf yahoo_answers_topics.tar.gz

rm -rf dwmw17
# Download dwmw17 from Google Drive: the inner wget fetches the confirm token, the outer wget downloads the archive.
FILEID="1FW_qQX8aubnuFy--y8cY8HW26CFixFei"
FILENAME="dwmw17.zip"
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILEID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O ${FILENAME} && rm -rf /tmp/cookies.txt
unzip dwmw17.zip
rm dwmw17.zip

cd ..
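
A quick way to sanity-check the download is to confirm the extracted split files exist. A minimal sketch, assuming the archive unpacks into train.csv, dev.csv, and test.csv under datasets/TextClassification/dwmw17 (the split names follow the processor added below; the archive layout itself is an assumption):

import os

# hypothetical check, run from the repository root after executing the script above
dwmw17_dir = "datasets/TextClassification/dwmw17"
for split in ("train", "dev", "test"):
    path = os.path.join(dwmw17_dir, "{}.csv".format(split))
    print(path, "found" if os.path.exists(path) else "MISSING")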
68 changes: 68 additions & 0 deletions experiments/classification_protoverb_dwmw17.yaml
@@ -0,0 +1,68 @@
dataset:
name: dwmw17
path: datasets/TextClassification/dwmw17

plm:
model_name: roberta
model_path: roberta-large
optimize:
freeze_para: False
lr: 0.00003
weight_decay: 0.01
scheduler:
type:
num_warmup_steps: 500

checkpoint:
save_latest: False
save_best: False

train:
batch_size: 2
num_epochs: 5
train_verblizer: post
clean: True

test:
batch_size: 2

template: manual_template
verbalizer: proto_verbalizer

manual_template:
choice: 0
file_path: scripts/TextClassification/dwmw17/manual_template.txt

proto_verbalizer:
parent_config: dwmw17
choice: 0
file_path: scripts/TextClassification/dwmw17/icl_verbalizer.json
lr: 0.01
mid_dim: 128
epochs: 30
multi_verb: multi



environment:
num_gpus: 1
cuda_visible_devices:
- 0
local_rank: 0

learning_setting: few_shot

few_shot:
parent_config: learning_setting
few_shot_sampling: sampling_from_train

sampling_from_train:
parent_config: few_shot_sampling
num_examples_per_label: 1
also_sample_dev: True
num_examples_per_label_dev: 1
seed:
- 123

reproduce: # seed for reproduction
seed: 123 # a seed for all random part
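
Assuming the repository's standard experiment entry point, this config would be launched with something like python experiments/cli.py --config_yaml experiments/classification_protoverb_dwmw17.yaml; the runner name is inferred from the other experiment configs rather than stated in this PR.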
40 changes: 40 additions & 0 deletions openprompt/data_utils/text_classification_dataset.py
@@ -17,6 +17,7 @@

import os
import json, csv
import pandas as pd
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable
@@ -27,6 +28,44 @@
from openprompt.data_utils.data_processor import DataProcessor


class Dwmw17Processor(DataProcessor):
    """
    Processor for the DWMW17 hate speech and offensive language Twitter dataset
    (Davidson et al., 2017), with the classes "hate speech", "offensive language",
    and "neither".

    Example::

        from openprompt.data_utils.text_classification_dataset import PROCESSORS
        import os

        # Get the absolute path of the parent directory of the current working directory
        root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

        # Set the base path to the 'datasets' directory located in that parent directory
        base_path = os.path.join(root_dir, 'datasets/TextClassification')

        dataset_name = "dwmw17"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        trainvalid_dataset = processor.get_train_examples(dataset_path)
        print(trainvalid_dataset)
    """
    def __init__(self):
        super().__init__()
        self.labels = ["hate speech", "offensive language", "neither"]

    def get_examples(self, data_dir, split):
        path = os.path.join(data_dir, "{}.csv".format(split))
        examples = []
        with open(path, encoding='utf8') as f:
            reader = csv.reader(f, delimiter=',')
            next(reader)  # skip the header row
            for row in reader:
                # columns: index, count, hate_speech, offensive_language, neither, class, tweet
                idx, _, _, _, _, label, tweet = row
                example = InputExample(
                    guid=str(idx), text_a=tweet, label=int(label))
                examples.append(example)
        return examples


class MnliProcessor(DataProcessor):
# TODO Test needed
def __init__(self):
@@ -358,4 +397,5 @@ def get_examples(self, data_dir, split):
"sst-2": SST2Processor,
"mnli": MnliProcessor,
"yahoo": YahooProcessor,
"dwmw17": Dwmw17Processor
}
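
With the registry entry in place, the new dataset loads through the shared DataProcessor split helpers. A minimal sketch, assuming get_train_examples / get_dev_examples / get_test_examples forward to get_examples with the split names 'train', 'dev', and 'test' (standard DataProcessor behaviour, not shown in this diff):

from openprompt.data_utils.text_classification_dataset import PROCESSORS

# path matches the dataset.path entry in the yaml config above
data_dir = "datasets/TextClassification/dwmw17"
processor = PROCESSORS["dwmw17"]()

train_examples = processor.get_train_examples(data_dir)  # reads train.csv
dev_examples = processor.get_dev_examples(data_dir)      # reads dev.csv
test_examples = processor.get_test_examples(data_dir)    # reads test.csv
print(len(train_examples), processor.get_labels())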
5 changes: 5 additions & 0 deletions scripts/TextClassification/dwmw17/icl_verbalizer.json
@@ -0,0 +1,5 @@
{
"hate speech": ["Hateful", "Malicious", "Malevolent", "Vicious", "Nefarious", "Sinister", "Discriminatory", "Harmful", "Abusive", "Prejudice"],
"offensive language": ["Offensive", "Insulting", "Rude", "Inappropriate", "Insensitive", "Controversial", "Obscenity", "Profanity"],
"neither": ["Harmless", "Innocent", "Benign", "Nonthreatening", "Inoffensive", "Amicable", "Acceptable", "Respectful", "Neutral"]
}
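
For illustration only, the label-word file can be exercised with a plain ManualVerbalizer as a stand-in for the proto verbalizer the config selects (load_plm and the ManualVerbalizer signature are assumed library APIs, not code added by this PR):

import json
from openprompt.plms import load_plm
from openprompt.prompts import ManualVerbalizer

plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large")

# load the class-name -> label-words mapping defined above
with open("scripts/TextClassification/dwmw17/icl_verbalizer.json") as f:
    label_words = json.load(f)

classes = ["hate speech", "offensive language", "neither"]
verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes=classes, label_words=label_words)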
3 changes: 3 additions & 0 deletions scripts/TextClassification/dwmw17/manual_template.txt
@@ -0,0 +1,3 @@
This tweet contains {"mask"} . {"placeholder": "text_a"}
This tweet is {"mask"} . {"placeholder": "text_a"}
A {"mask"} tweet : {"placeholder": "text_a"}