Skip to content

Commit

Permalink
test: add (disabled) python-lang-id test
Browse files Browse the repository at this point in the history
Will need #314 to enable this test.

Signed-off-by: Nick Mitchell <[email protected]>
  • Loading branch information
starpit committed Oct 3, 2024
1 parent 4524d0e commit ecb5252
Show file tree
Hide file tree
Showing 18 changed files with 272 additions and 0 deletions.
Empty file.
3 changes: 3 additions & 0 deletions tests/tests/python-lang-id/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Source: https://github.com/IBM/data-prep-kit/tree/dev/transforms/language/lang_id/python

See [this README](../python-pii-redactor/README) for details. In addition to what is documented there, for this test we also had to make some minor changes to [nlp.py](./pail/src/nlp.py) to decouple it from the TransformUtils class.
2 changes: 2 additions & 0 deletions tests/tests/python-lang-id/pail/.helmignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test-data/
requirements.txt
17 changes: 17 additions & 0 deletions tests/tests/python-lang-id/pail/app.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Lunchpail Application that runs the lang_id transform in worker role.
# NOTE(review): indentation appears to have been stripped by the page scrape;
# the original chart file presumably nests these keys conventionally.
apiVersion: lunchpail.io/v1alpha1
kind: Application
metadata:
name: lang_id
spec:
role: worker
# Entry point executed by each worker.
command: python3 ./main.py
# Inline the three Python sources from the chart's src/ directory.
code:
- name: main.py
source: |
{{ .Files.Get "src/main.py" | indent 8 }}
- name: lang_models.py
source: |
{{ .Files.Get "src/lang_models.py" | indent 8 }}
- name: nlp.py
source: |
{{ .Files.Get "src/nlp.py" | indent 8 }}
11 changes: 11 additions & 0 deletions tests/tests/python-lang-id/pail/pool1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Emit .Values.pools (default 1) WorkerPool documents, named pool1, pool2, ...
# NOTE(review): indentation appears stripped by the page scrape.
{{- range until (.Values.pools | default 1 | int) }}
---
apiVersion: lunchpail.io/v1alpha1
kind: WorkerPool
metadata:
name: {{ print "pool" (add 1 .) }}
spec:
workers:
# Per-pool worker count and size, overridable via Helm values.
count: {{ $.Values.workers | default 1 }}
size: {{ $.Values.size | default "xxs" }}
{{- end }}
5 changes: 5 additions & 0 deletions tests/tests/python-lang-id/pail/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fasttext==0.9.2
langcodes==3.3.0
huggingface-hub >= 0.21.4, <1.0.0
numpy==1.26.4
pyarrow
52 changes: 52 additions & 0 deletions tests/tests/python-lang-id/pail/src/lang_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# (C) Copyright IBM Corp. 2024.
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import math
from abc import ABCMeta, abstractmethod

import fasttext
from huggingface_hub import hf_hub_download
from langcodes import standardize_tag


KIND_FASTTEXT = "fasttext"


class LangModel(metaclass=ABCMeta):
    """Abstract interface for language-identification backends."""

    @abstractmethod
    def detect_lang(self, text: str) -> tuple[str, float]:
        """Return a (language tag, confidence score) pair for ``text``."""
        pass


class NoopModel:
    """Fallback detector that reports English with zero confidence.

    The original declared ``metaclass=ABCMeta`` (an apparent copy-paste from
    ``LangModel``) without declaring any abstract method or subclassing
    ``LangModel``; a plain class has identical behavior and is duck-type
    compatible with the ``LangModel`` interface.
    """

    def detect_lang(self, text: str) -> tuple[str, float]:
        """Always return ``("en", 0.0)`` regardless of ``text``."""
        return "en", 0.0


class FastTextModel(LangModel):
    """Language detector backed by a fasttext model from the Hugging Face Hub."""

    def __init__(self, url, credential):
        # Fetch (or reuse a cached copy of) model.bin from the given repo id.
        local_path = hf_hub_download(repo_id=url, filename="model.bin", token=credential)
        self.nlp = fasttext.load_model(local_path)

    def detect_lang(self, text: str) -> tuple[str, float]:
        """Return a standardized language tag and a truncated confidence score."""
        # fasttext's predict processes a single line at a time, so collapse
        # newlines first to avoid its per-line error/skip behavior.
        single_line = text.replace("\n", " ")
        labels, scores = self.nlp.predict(single_line, 1)
        tag = standardize_tag(labels[0].replace("__label__", ""))
        # Truncate (not round) the score to three decimal places.
        truncated_score = math.floor(scores[0] * 1000) / 1000
        return tag, truncated_score


class LangModelFactory:
    """Builds the language-identification backend for a given kind."""

    # BUGFIX: `create_model` was defined without @staticmethod. Calling it on
    # the class (as main.py does) happened to work in Python 3, but calling it
    # on an instance would have bound `kind` to the instance. @staticmethod is
    # backward-compatible with the existing class-level call sites.
    @staticmethod
    def create_model(kind: str, url: str, credential: str) -> LangModel:
        """
        Create a language model.
        :param kind: backend kind; KIND_FASTTEXT selects fasttext
        :param url: Hugging Face repo id holding the model (fasttext only)
        :param credential: Hugging Face access token (fasttext only)
        :return: a FastTextModel for KIND_FASTTEXT, otherwise a NoopModel
        """
        if kind == KIND_FASTTEXT:
            return FastTextModel(url, credential)
        else:
            return NoopModel()
95 changes: 95 additions & 0 deletions tests/tests/python-lang-id/pail/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# (C) Copyright IBM Corp. 2024.
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import sys
import pyarrow.parquet as pq

from os import getenv
import logging

import pyarrow as pa
from lang_models import LangModelFactory
from nlp import get_lang_ds_pa

from lang_models import KIND_FASTTEXT

# Transform identity and the prefix used to namespace its CLI parameters.
short_name = "lang_id"
cli_prefix = f"{short_name}_"
# Configuration keys (also used as environment-variable names via getenv below).
model_credential_key = "model_credential"
model_kind_key = "model_kind"
model_url_key = "model_url"
content_column_name_key = "content_column_name"
output_lang_column_name_key = "output_lang_column_name"
output_score_column_name_key = "output_score_column_name"
# Fully-prefixed CLI parameter names (kept for parity with the upstream
# data-prep-kit transform; not read anywhere in this script).
model_credential_cli_param = f"{cli_prefix}{model_credential_key}"
model_kind_cli_param = f"{cli_prefix}{model_kind_key}"
model_url_cli_param = f"{cli_prefix}{model_url_key}"
content_column_name_cli_param = f"{cli_prefix}{content_column_name_key}"
output_lang_column_name_cli_param = f"{cli_prefix}{output_lang_column_name_key}"
output_score_column_name_cli_param = f"{cli_prefix}{output_score_column_name_key}"

# Column-name defaults: read text from "text", write "lang" and "score".
default_content_column_name = "text"
default_output_lang_column_name = "lang"
default_output_score_column_name = "score"

# Model selection, overridable via environment variables of the same names.
model_kind = getenv(model_kind_key, KIND_FASTTEXT)
model_url = getenv(model_url_key, "facebook/fasttext-language-identification")
model_credential = getenv(model_credential_key, "PUT YOUR OWN HUGGINGFACE CREDENTIAL")

def validate_columns(table: pa.Table, required: list[str]) -> None:
    """
    Verify that every column in ``required`` exists in ``table``.
    :param table: table to inspect
    :param required: list of required columns
    :return: None
    :raises Exception: if at least one required column is missing
    """
    columns = table.schema.names
    missing = [name for name in required if name not in columns]
    if missing:
        raise Exception(
            f"Not all required columns are present in the table - " f"required {required}, present {columns}"
        )

logger = logging.getLogger(__name__)

# Build the language-identification backend once, up front.
nlp_langid = LangModelFactory.create_model(
    model_kind, model_url, model_credential
)
# Resolve column names from the environment, falling back to the defaults.
content_column_name = getenv(content_column_name_key, default_content_column_name)
output_lang_column_name = getenv(output_lang_column_name_key, default_output_lang_column_name)
output_score_column_name = getenv(output_score_column_name_key, default_output_score_column_name)

# Read the input parquet (argv[1]); exit non-zero on any read failure.
try:
    print(f"Reading in parquet file {sys.argv[1]}")
    table = pq.read_table(sys.argv[1])
except Exception as e:
    # BUGFIX: the message previously interpolated an undefined name `path`,
    # so the handler itself raised NameError and masked the real read error.
    print(f"Error reading table from {sys.argv[1]}: {e}", file=sys.stderr)
    sys.exit(1)
print(f"Done Reading in parquet file {sys.argv[1]}")

# The content column must exist, and the output columns must not (we refuse
# to clobber pre-existing data).
validate_columns(table, [content_column_name])
if output_lang_column_name in table.schema.names:
    raise Exception(f"column to store identified language ({output_lang_column_name}) already exist")
if output_score_column_name in table.schema.names:
    raise Exception(
        f"column to store score of language identification ({output_score_column_name}) already exist"
    )
print(f"Transforming one table with {len(table)} rows")
table, stats = get_lang_ds_pa(
    table, nlp_langid, content_column_name, output_lang_column_name, output_score_column_name)
print(f"Transformed one table with {len(table)} rows")

# Write the annotated table to argv[2].
print(f"Done. Writing output to {sys.argv[2]}")
pq.write_table(table, sys.argv[2])
59 changes: 59 additions & 0 deletions tests/tests/python-lang-id/pail/src/nlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# (C) Copyright IBM Corp. 2024.
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

from typing import Any

import logging
import pyarrow as pa
from lang_models import LangModel


logger = logging.getLogger(__name__)

def add_column(table: pa.Table, name: str, content: list[Any]) -> pa.Table:
    """
    Add column to the table, replacing any existing column of the same name.
    :param table: original table
    :param name: column name
    :param content: content of the column
    :return: updated table, containing new column
    """
    # Replace rather than duplicate: drop a pre-existing column of this name.
    already_present = name in table.schema.names
    if already_present:
        table = table.drop(columns=[name])
    return table.append_column(field_=name, column=[content])

def get_lang_ds_pa(
    table: pa.table,
    nlp: LangModel,
    content_column_name: str,
    output_lang_column_name: str,
    output_score_column_name: str,
) -> tuple[pa.table, dict[str, Any]]:
    """
    Detect the language of each row's content and annotate the table.
    :param table: input table
    :param nlp: language model used to detect the language of each value
    :param content_column_name: column holding the text to classify
    :param output_lang_column_name: column to receive the language tag
    :param output_score_column_name: column to receive the confidence score
    :return: (annotated table, per-language row counts, e.g. {"en": 3})
    """
    # One detection per row of the content column.
    detections = [
        {"lang": lang, "score": score}
        for lang, score in (
            nlp.detect_lang(text) for text in table[content_column_name].to_pylist()
        )
    ]
    detected_language = pa.Table.from_pylist(detections)
    # Aggregate row counts per detected language.
    stats = pa.table([detected_language["lang"]], names=["lang"]).group_by("lang").aggregate([("lang", "count")])
    stats_dict = {}
    for batch in stats.to_batches():
        cols = batch.to_pydict()
        for lang, count in zip(cols["lang"], cols["lang_count"]):
            stats_dict[lang] = count
    # Append (or replace) the two output columns.
    result = add_column(table=table, name=output_lang_column_name, content=detected_language["lang"])
    result = add_column(table=result, name=output_score_column_name, content=detected_language["score"])
    return result, stats_dict
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/tests/python-lang-id/post.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/sh

# Post-check: verify each task's output parquet exists and matches the
# expected (gzipped) copy byte-for-byte via sha256.
DATA="$TEST_PATH"/pail/test-data/sm

for i in $(seq 1 3)
do
    actual="$DATA"/input/test_0$i.output.parquet
    expected="$DATA"/expected/test_0$i.parquet.gz

    # BUGFIX: the original removed $actual here on the PASS branch, so the
    # checksum below was computed over a file that no longer existed; it also
    # referenced an undefined variable $f in the final compare block.
    if [ -f "$actual" ]
    then echo "✅ PASS found local task output file=$actual test=$TEST_NAME"
    else echo "❌ FAIL cannot find local task output file=$actual test=$TEST_NAME" && exit 1
    fi

    # Compare checksums BEFORE deleting the output file.
    actual_sha256=$(cat "$actual" | sha256)
    expected_sha256=$(gzcat "$expected" | sha256)

    if [ "$actual_sha256" = "$expected_sha256" ]
    then echo "✅ PASS task output matches expected file=$actual test=$TEST_NAME" && rm -f "$actual"
    else echo "❌ FAIL task output does not match expected file=$actual test=$TEST_NAME" && exit 1
    fi
done
6 changes: 6 additions & 0 deletions tests/tests/python-lang-id/settings.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Harness settings for the python-lang-id test.

# Drive the test through the workqueue API flavor.
api=workqueue

# Log line(s) the harness presumably greps for to declare success — verify
# against the harness implementation.
expected=("Transforming one table")
NUM_DESIRED_OUTPUTS=1

# Feed the three gzipped sample parquet files as task inputs; process
# substitution hands the harness plain .parquet streams.
up_args='<(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_01.parquet.gz) <(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_02.parquet.gz) <(gzcat "$TEST_PATH"/pail/test-data/sm/input/test_03.parquet.gz)'

0 comments on commit ecb5252

Please sign in to comment.