test: add (disabled) python-lang-id test
Will need #314 to enable this test.

Signed-off-by: Nick Mitchell <[email protected]>
Showing 18 changed files with 272 additions and 0 deletions.
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Source: https://github.com/IBM/data-prep-kit/tree/dev/transforms/language/lang_id/python | ||
|
||
See [this README](../python-pii-redactor/README) for details. Beyond what is documented there, this test also required minor changes to [nlp.py](./pail/src/nlp.py) to decouple it from the TransformUtils class. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
test-data/ | ||
requirements.txt |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
apiVersion: lunchpail.io/v1alpha1 | ||
kind: Application | ||
metadata: | ||
name: lang_id | ||
spec: | ||
role: worker | ||
command: python3 ./main.py | ||
code: | ||
- name: main.py | ||
source: | | ||
{{ .Files.Get "src/main.py" | indent 8 }} | ||
- name: lang_models.py | ||
source: | | ||
{{ .Files.Get "src/lang_models.py" | indent 8 }} | ||
- name: nlp.py | ||
source: | | ||
{{ .Files.Get "src/nlp.py" | indent 8 }} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{{- range until (.Values.pools | default 1 | int) }} | ||
--- | ||
apiVersion: lunchpail.io/v1alpha1 | ||
kind: WorkerPool | ||
metadata: | ||
name: {{ print "pool" (add 1 .) }} | ||
spec: | ||
workers: | ||
count: {{ $.Values.workers | default 1 }} | ||
size: {{ $.Values.size | default "xxs" }} | ||
{{- end }} |
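
As a rough illustration of what the WorkerPool template above renders, the following Python sketch mimics its `range until` loop and its `default` fallbacks. The function name `render_worker_pools` and the dict layout are hypothetical and exist only for this illustration; the actual rendering is done by the template engine, not by Python.

```python
def render_worker_pools(values: dict) -> list[dict]:
    """Mimic the WorkerPool template: one pool per index, named pool1..poolN."""
    # `default` in the template falls back when the value is unset or empty,
    # which `or` approximates here.
    pools = int(values.get("pools") or 1)
    workers = values.get("workers") or 1
    size = values.get("size") or "xxs"
    return [
        {
            "apiVersion": "lunchpail.io/v1alpha1",
            "kind": "WorkerPool",
            # `add 1 .` shifts the zero-based loop index to start at 1.
            "metadata": {"name": f"pool{i + 1}"},
            "spec": {"workers": {"count": workers, "size": size}},
        }
        for i in range(pools)
    ]


for pool in render_worker_pools({"pools": 2, "workers": 3}):
    print(pool["metadata"]["name"], pool["spec"]["workers"])
```

With no values supplied, the sketch yields a single pool named `pool1` with one worker of size `xxs`, matching the defaults in the template.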
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
fasttext==0.9.2 | ||
langcodes==3.3.0 | ||
huggingface-hub >= 0.21.4, <1.0.0 | ||
numpy==1.26.4 | ||
pyarrow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# (C) Copyright IBM Corp. 2024. | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
################################################################################ | ||
|
||
import math | ||
from abc import ABCMeta, abstractmethod | ||
|
||
import fasttext | ||
from huggingface_hub import hf_hub_download | ||
from langcodes import standardize_tag | ||
|
||
|
||
KIND_FASTTEXT = "fasttext" | ||
|
||
|
||
class LangModel(metaclass=ABCMeta): | ||
@abstractmethod | ||
def detect_lang(self, text: str) -> tuple[str, float]: | ||
pass | ||
|
||
|
||
class NoopModel(LangModel): | ||
def detect_lang(self, text: str) -> tuple[str, float]: | ||
return "en", 0.0 | ||
|
||
|
||
class FastTextModel(LangModel): | ||
def __init__(self, url, credential): | ||
model_path = hf_hub_download(repo_id=url, filename="model.bin", token=credential) | ||
self.nlp = fasttext.load_model(model_path) | ||
|
||
def detect_lang(self, text: str) -> tuple[str, float]: | ||
label, score = self.nlp.predict( | ||
text.replace("\n", " "), 1 | ||
) # replace newline to avoid ERROR: predict processes one line at a time (remove '\n') skipping the file | ||
return standardize_tag(label[0].replace("__label__", "")), math.floor(score[0] * 1000) / 1000 | ||
|
||
|
||
class LangModelFactory: | ||
@staticmethod | ||
def create_model(kind: str, url: str, credential: str) -> LangModel: | ||
if kind == KIND_FASTTEXT: | ||
return FastTextModel(url, credential) | ||
else: | ||
return NoopModel() |
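
To make the post-processing in `FastTextModel.detect_lang` easier to follow, here is a minimal standard-library sketch of the same steps: replace newlines, take the top prediction, strip the `__label__` prefix, and truncate the score to three decimals. `StubModel` and `normalize_prediction` are hypothetical names introduced for this illustration, and the fixed prediction is made-up sample data. The sketch also omits the `standardize_tag` call from `langcodes` that the file above applies to the label.

```python
import math

LABEL_PREFIX = "__label__"


def normalize_prediction(labels, scores):
    """Strip the fastText label prefix and truncate the top score to 3 decimals."""
    lang = labels[0].replace(LABEL_PREFIX, "")
    score = math.floor(scores[0] * 1000) / 1000
    return lang, score


class StubModel:
    """Hypothetical stand-in for a loaded fastText model; returns a fixed prediction."""

    def predict(self, text, k=1):
        return ("__label__eng_Latn",), (0.98765,)


def detect_lang(model, text):
    # predict handles one line at a time, so newlines are replaced with spaces.
    labels, scores = model.predict(text.replace("\n", " "), 1)
    return normalize_prediction(labels, scores)


print(detect_lang(StubModel(), "hello\nworld"))  # ('eng_Latn', 0.987)
```

Truncating with `math.floor` rather than rounding means a score is never reported higher than the model produced.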
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# (C) Copyright IBM Corp. 2024. | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
################################################################################ | ||
|
||
import sys | ||
import pyarrow.parquet as pq | ||
|
||
from os import getenv | ||
import logging | ||
|
||
import pyarrow as pa | ||
from lang_models import LangModelFactory | ||
from nlp import get_lang_ds_pa | ||
|
||
from lang_models import KIND_FASTTEXT | ||
|
||
short_name = "lang_id" | ||
cli_prefix = f"{short_name}_" | ||
model_credential_key = "model_credential" | ||
model_kind_key = "model_kind" | ||
model_url_key = "model_url" | ||
content_column_name_key = "content_column_name" | ||
output_lang_column_name_key = "output_lang_column_name" | ||
output_score_column_name_key = "output_score_column_name" | ||
model_credential_cli_param = f"{cli_prefix}{model_credential_key}" | ||
model_kind_cli_param = f"{cli_prefix}{model_kind_key}" | ||
model_url_cli_param = f"{cli_prefix}{model_url_key}" | ||
content_column_name_cli_param = f"{cli_prefix}{content_column_name_key}" | ||
output_lang_column_name_cli_param = f"{cli_prefix}{output_lang_column_name_key}" | ||
output_score_column_name_cli_param = f"{cli_prefix}{output_score_column_name_key}" | ||
|
||
default_content_column_name = "text" | ||
default_output_lang_column_name = "lang" | ||
default_output_score_column_name = "score" | ||
|
||
model_kind = getenv(model_kind_key, KIND_FASTTEXT) | ||
model_url = getenv(model_url_key, "facebook/fasttext-language-identification") | ||
model_credential = getenv(model_credential_key, "PUT YOUR OWN HUGGINGFACE CREDENTIAL") | ||
|
||
def validate_columns(table: pa.Table, required: list[str]) -> None: | ||
""" | ||
Check if required columns exist in the table | ||
:param table: table | ||
:param required: list of required columns | ||
:return: None | ||
""" | ||
columns = table.schema.names | ||
result = True | ||
for r in required: | ||
if r not in columns: | ||
result = False | ||
break | ||
if not result: | ||
raise Exception( | ||
f"Not all required columns are present in the table - required {required}, present {columns}" | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
nlp_langid = LangModelFactory.create_model( | ||
model_kind, model_url, model_credential | ||
) | ||
content_column_name = getenv(content_column_name_key, default_content_column_name) | ||
output_lang_column_name = getenv(output_lang_column_name_key, default_output_lang_column_name) | ||
output_score_column_name = getenv(output_score_column_name_key, default_output_score_column_name) | ||
|
||
try: | ||
print(f"Reading in parquet file {sys.argv[1]}") | ||
table = pq.read_table(sys.argv[1]) | ||
except Exception as e: | ||
print(f"Error reading table from {sys.argv[1]}: {e}", file=sys.stderr) | ||
sys.exit(1) | ||
print(f"Done Reading in parquet file {sys.argv[1]}") | ||
|
||
validate_columns(table, [content_column_name]) | ||
if output_lang_column_name in table.schema.names: | ||
raise Exception(f"column to store identified language ({output_lang_column_name}) already exists") | ||
if output_score_column_name in table.schema.names: | ||
raise Exception( | ||
f"column to store score of language identification ({output_score_column_name}) already exists" | ||
) | ||
print(f"Transforming one table with {len(table)} rows") | ||
table, stats = get_lang_ds_pa( | ||
table, nlp_langid, content_column_name, output_lang_column_name, output_score_column_name) | ||
print(f"Transformed one table with {len(table)} rows") | ||
|
||
print(f"Done. Writing output to {sys.argv[2]}") | ||
pq.write_table(table, sys.argv[2]) |
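
The pre-flight checks in this script (the content column must exist, and the two output columns must not) can be sketched without pyarrow by working on plain lists of column names. This is an illustrative sketch only: `check_columns` is a hypothetical helper, not part of the commit, and it raises `ValueError` where the script raises a bare `Exception`.

```python
def check_columns(present, required, outputs):
    """Raise ValueError if a required column is missing or an output column already exists."""
    missing = [c for c in required if c not in present]
    if missing:
        raise ValueError(f"missing required columns {missing}, present {list(present)}")
    clashes = [c for c in outputs if c in present]
    if clashes:
        raise ValueError(f"output columns already exist: {clashes}")


# Passes: "text" is present and neither output column exists yet.
check_columns(["text", "id"], required=["text"], outputs=["lang", "score"])

# Fails: the table already has a "lang" column.
try:
    check_columns(["text", "lang"], required=["text"], outputs=["lang", "score"])
except ValueError as err:
    print(err)
```

Collecting every missing or clashing name before raising gives a more complete error message than stopping at the first problem.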
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# (C) Copyright IBM Corp. 2024. | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
################################################################################ | ||
|
||
from typing import Any | ||
|
||
import logging | ||
import pyarrow as pa | ||
from lang_models import LangModel | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def add_column(table: pa.Table, name: str, content: list[Any]) -> pa.Table: | ||
""" | ||
Add column to the table | ||
:param table: original table | ||
:param name: column name | ||
:param content: content of the column | ||
:return: updated table, containing new column | ||
""" | ||
# check if column already exists and drop it | ||
if name in table.schema.names: | ||
table = table.drop(columns=[name]) | ||
# append column | ||
return table.append_column(field_=name, column=[content]) | ||
|
||
def get_lang_ds_pa( | ||
table: pa.Table, | ||
nlp: LangModel, | ||
content_column_name: str, | ||
output_lang_column_name: str, | ||
output_score_column_name: str, | ||
) -> tuple[pa.Table, dict[str, Any]]: | ||
detected_language = pa.Table.from_pylist( | ||
list( | ||
map( | ||
lambda r: {"lang": r[0], "score": r[1]}, | ||
map(lambda x: nlp.detect_lang(x), table[content_column_name].to_pylist()), | ||
) | ||
) | ||
) | ||
stats = pa.table([detected_language["lang"]], names=["lang"]).group_by("lang").aggregate([("lang", "count")]) | ||
stats_dict = {} | ||
for batch in stats.to_batches(): | ||
d = batch.to_pydict() | ||
for lang, count in zip(d["lang"], d["lang_count"]): | ||
stats_dict[lang] = count | ||
result = add_column(table=table, name=output_lang_column_name, content=detected_language["lang"]) | ||
result = add_column(table=result, name=output_score_column_name, content=detected_language["score"]) | ||
return result, stats_dict |
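
The core of `get_lang_ds_pa` is to run the detector over every row, collect a language column and a score column, and count rows per language. A minimal sketch of that flow using only the standard library follows. It swaps the pyarrow `group_by`/`aggregate` step for `collections.Counter`, and `detect_all` and `toy_detect` are hypothetical names with made-up behavior, introduced only for this illustration.

```python
from collections import Counter


def detect_all(texts, detect):
    """Return (langs, scores, per-language counts) for a list of texts."""
    results = [detect(t) for t in texts]
    langs = [lang for lang, _ in results]
    scores = [score for _, score in results]
    return langs, scores, dict(Counter(langs))


def toy_detect(text):
    # Hypothetical detector standing in for LangModel.detect_lang.
    return ("fr", 0.9) if "bonjour" in text else ("en", 0.8)


langs, scores, stats = detect_all(["hello", "bonjour", "hi"], toy_detect)
print(langs)
print(stats)
```

The two lists correspond to the new `lang` and `score` columns, and the dictionary corresponds to `stats_dict`.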
Binary file added (+773 Bytes): tests/tests/python-lang-id/pail/test-data/sm/expected/metadata.json.gz
Binary file added (+32.5 KB): tests/tests/python-lang-id/pail/test-data/sm/expected/test_01.parquet.gz
Binary file added (+87.5 KB): tests/tests/python-lang-id/pail/test-data/sm/expected/test_02.parquet.gz
Binary file added (+257 KB): tests/tests/python-lang-id/pail/test-data/sm/expected/test_03.parquet.gz
Binary file not shown.
Binary file added (+86.1 KB): tests/tests/python-lang-id/pail/test-data/sm/input/test_02.parquet.gz
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#!/bin/sh | ||
|
||
DATA="$TEST_PATH"/pail/test-data/sm | ||
|
||
for i in $(seq 1 3) | ||
do | ||
actual="$DATA"/input/test_0$i.output.parquet | ||
expected="$DATA"/expected/test_0$i.parquet.gz | ||
|
||
if [ -f "$actual" ] | ||
then echo "✅ PASS found local task output file=$actual test=$TEST_NAME" | ||
else echo "❌ FAIL cannot find local task output file=$actual test=$TEST_NAME" && exit 1 | ||
fi | ||
|
||
actual_sha256=$(cat "$actual" | sha256) | ||
expected_sha256=$(gzcat "$expected" | sha256 ) | ||
|
||
if [ "$actual_sha256" = "$expected_sha256" ] | ||
then echo "✅ PASS local task output matches expected file=$actual test=$TEST_NAME" && rm -f "$actual" | ||
else echo "❌ FAIL local task output does not match expected file=$actual test=$TEST_NAME" && exit 1 | ||
fi | ||
done |
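
The comparison this script performs (hash the produced file, hash the decompressed expected file, compare the digests) can be sketched in Python with `hashlib` and `gzip`. This is an illustrative equivalent, not part of the commit; `sha256_of` and `matches_expected` are hypothetical helper names.

```python
import gzip
import hashlib
from pathlib import Path


def sha256_of(data: bytes) -> str:
    """Hex SHA-256 digest of a byte string."""
    return hashlib.sha256(data).hexdigest()


def matches_expected(actual_path, expected_gz_path) -> bool:
    """Compare a plain output file against a gzip-compressed expected file."""
    actual = Path(actual_path).read_bytes()
    with gzip.open(expected_gz_path, "rb") as f:
        expected = f.read()  # decompressed bytes, like `gzcat`
    return sha256_of(actual) == sha256_of(expected)
```

Note that the output file has to be read before it is deleted, which is why cleanup belongs after the comparison. A byte-for-byte comparison of Parquet files may also be sensitive to writer version differences, so a check like this is strictest when the expected files were produced in the same environment.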
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
api=workqueue | ||
|
||
expected=("Transforming one table") | ||
NUM_DESIRED_OUTPUTS=1 | ||
|
||
up_args='<(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_01.parquet.gz) <(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_02.parquet.gz) <(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_03.parquet.gz)' |