From 50f8a2000714754359e54051d31493cd9351c9ca Mon Sep 17 00:00:00 2001 From: paulruelle Date: Thu, 19 Sep 2024 11:23:16 +0200 Subject: [PATCH] feat(lab-3088): add unit tests for `project_model` methods --- tests/unit/llm/test_project_model.py | 86 ++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tests/unit/llm/test_project_model.py diff --git a/tests/unit/llm/test_project_model.py b/tests/unit/llm/test_project_model.py new file mode 100644 index 000000000..66d7a6941 --- /dev/null +++ b/tests/unit/llm/test_project_model.py @@ -0,0 +1,86 @@ +from kili.llm.presentation.client.llm import LlmClientMethods + +mock_list_project_models = [ + { + "id": "project_model_id_1", + "model": { + "credentials": { + "apiKey": "***", + "endpoint": "https://ai21-jamba-1-5-large-ykxca.eastus.models.ai.azure.com", + }, + "name": "Jamba (created by SDK)", + "type": "OPEN_AI_SDK", + }, + "configuration": {"model": "AI21-Jamba-1-5-Large-ykxca", "temperature": 0.5}, + }, + { + "id": "project_model_id_2", + "model": { + "credentials": { + "apiKey": "***", + "endpoint": "https://ai21-jamba-1-5-large-ykxca.eastus.models.ai.azure.com", + }, + "name": "Jamba (created by SDK)", + "type": "OPEN_AI_SDK", + }, + "configuration": {"model": "AI21-Jamba-1-5-Large-ykxca", "temperature": 0.7}, + }, +] + + +def test_list_project_models(mocker): + kili_api_gateway = mocker.MagicMock() + kili_api_gateway.list_project_models.return_value = mock_list_project_models + + kili_llm = LlmClientMethods(kili_api_gateway) + result = kili_llm.list_project_models(project_id="project_id") + + assert result == mock_list_project_models + + +def test_create_project_model(mocker): + mock_create_project_model = {"id": "new_project_model_id"} + + kili_api_gateway = mocker.MagicMock() + kili_api_gateway.create_project_model.return_value = mock_create_project_model + + kili_llm = LlmClientMethods(kili_api_gateway) + result = kili_llm.create_project_model( + project_id="project_id", + model_id="model_id", + configuration={ + "model": "AI21-Jamba-1-5-Large-ykxca", + "temperature": {"min": 0.2, "max": 0.8}, + }, + ) + + assert result == mock_create_project_model + + +def test_update_project_model(mocker): + mock_update_project_model = { + "id": "project_model_id", + "configuration": {"model": "AI21-Jamba-1-5-Large-ykxca", "temperature": 0.7}, + } + + kili_api_gateway = mocker.MagicMock() + kili_api_gateway.update_project_model.return_value = mock_update_project_model + + kili_llm = LlmClientMethods(kili_api_gateway) + result = kili_llm.update_project_model( + project_model_id="project_model_id", configuration={"temperature": 0.7} + ) + + assert result == mock_update_project_model + + +def test_delete_project_model(mocker): + delete_project_model_return_val = {"id": "project_model_id"} + + kili_api_gateway = mocker.MagicMock() + kili_api_gateway.delete_project_model.return_value = delete_project_model_return_val + + kili_llm = LlmClientMethods(kili_api_gateway) + result = kili_llm.delete_project_model(project_model_id="project_model_id") + + assert result == delete_project_model_return_val