From d0da492f9ee5495b9dcf2559507774b996fbb7b1 Mon Sep 17 00:00:00 2001 From: Dat Nguyen Date: Sat, 13 Jan 2024 10:46:35 +0700 Subject: [PATCH] chore: add unit tests for dbt cloud feat (#69) --- dbterd/adapters/dbt_cloud.py | 3 +- tests/unit/adapters/test_base.py | 21 ++-- tests/unit/adapters/test_dbt_cloud.py | 148 ++++++++++++++++++++++++-- 3 files changed, 155 insertions(+), 17 deletions(-) diff --git a/dbterd/adapters/dbt_cloud.py b/dbterd/adapters/dbt_cloud.py index 99542a1..51512fd 100644 --- a/dbterd/adapters/dbt_cloud.py +++ b/dbterd/adapters/dbt_cloud.py @@ -1,7 +1,6 @@ import os import json -import click import requests from dbterd.helpers import file @@ -78,7 +77,7 @@ def download_artifact(self, artifact: str, artifacts_dir: str) -> bool: data=json.dumps(r.json(), indent=2), path=f"{artifacts_dir}/{artifact}.json", ) - except click.BadParameter as e: + except Exception as e: logger.error(f"Error occurred while downloading [error: {str(e)}]") return False diff --git a/tests/unit/adapters/test_base.py b/tests/unit/adapters/test_base.py index 178095b..2ee343f 100644 --- a/tests/unit/adapters/test_base.py +++ b/tests/unit/adapters/test_base.py @@ -10,13 +10,6 @@ class TestBase: - # .run( - # target="dbml", - # algo="test_relationship", - # artifacts_dir="/", - # output="/" - # ) - def test_worker(self): worker = Executor(ctx=click.Context(command=click.BaseCommand("dummy"))) assert worker.filename_manifest == "manifest.json" @@ -101,8 +94,19 @@ def test__get_selection__error(self, mock_dbt_invocation): exclude=[], ), ), + ( + dict(select=[], exclude=[], dbt_cloud=True), + dict( + dbt_cloud=True, + artifacts_dir="/path/dpd/target", + dbt_project_dir="/path/dpd", + select=[], + exclude=[], + ), + ), ], ) + @mock.patch("dbterd.adapters.dbt_cloud.DbtCloudArtifact.get") @mock.patch("dbterd.adapters.base.Executor._Executor__get_dir") @mock.patch("dbterd.adapters.base.Executor._Executor__get_selection") @mock.patch("dbterd.adapters.base.DbtInvocation.get_artifacts_for_erd") @@ -111,6 +115,7 @@ def test_evaluate_kwargs( mock_get_artifacts_for_erd, mock_get_selection, mock_get_dir, + mock_dbt_cloud_get, kwargs, expected, ): @@ -121,6 +126,8 @@ def test_evaluate_kwargs( mock_get_dir.assert_called_once() if kwargs.get("dbt_auto_artifacts"): mock_get_artifacts_for_erd.assert_called_once() + if kwargs.get("dbt_cloud"): + mock_dbt_cloud_get.assert_called_once() @pytest.mark.parametrize( "kwargs, mock_isfile_se, expected", diff --git a/tests/unit/adapters/test_dbt_cloud.py b/tests/unit/adapters/test_dbt_cloud.py index ee477e0..7b903d0 100644 --- a/tests/unit/adapters/test_dbt_cloud.py +++ b/tests/unit/adapters/test_dbt_cloud.py @@ -1,14 +1,146 @@ -# from unittest import mock +import json +from unittest import mock -# import click -# import pytest +import pytest +import requests -# from dbterd.adapters.dbt_cloud import DbtCloudArtifact +from dbterd.adapters.dbt_cloud import DbtCloudArtifact + + +class MockResponse: + def __init__(self, status_code, data=None) -> None: + self.status_code = status_code + self.data = data + + def json(self): + return self.data class TestDbtCloudArtifact: - def test_download_artifact(self): - pass + @pytest.fixture + def dbtCloudArtifact(self) -> DbtCloudArtifact: + return DbtCloudArtifact( + dbt_cloud_host_url="irrelevant_url", + dbt_cloud_service_token="irrelevant_st", + dbt_cloud_account_id="irrelevant_acc_id", + dbt_cloud_run_id="irrelevant_run_id", + dbt_cloud_api_version="irrelevant_v", + ) + + @pytest.mark.parametrize( + "kwargs, expected", + [ + ( + dict(), + dict( + host_url=None, + service_token=None, + account_id=None, + run_id=None, + api_version=None, + ), + ), + ( + dict( + dbt_cloud_host_url="host_url", + dbt_cloud_service_token="service_token", + dbt_cloud_account_id="account_id", + dbt_cloud_run_id="run_id", + dbt_cloud_api_version="api_version", + ), + dict( + host_url="host_url", + service_token="service_token", + account_id="account_id", + run_id="run_id", + api_version="api_version", + ), + ), + ], + ) + def test_init(self, kwargs, expected): + dbt_cloud = DbtCloudArtifact(**kwargs) + assert vars(dbt_cloud) == expected + assert dbt_cloud.request_headers == { + "Authorization": f"Token {kwargs.get('dbt_cloud_service_token')}" + } + assert dbt_cloud.api_endpoint == ( + "https://{host_url}/api/{api_version}/" + "accounts/{account_id}/" + "runs/{run_id}/" + "artifacts/{{path}}" + ).format(**expected) + assert dbt_cloud.manifest_api_endpoint == ( + "https://{host_url}/api/{api_version}/" + "accounts/{account_id}/" + "runs/{run_id}/" + "artifacts/manifest.json" + ).format(**expected) + assert dbt_cloud.catalog_api_endpoint == ( + "https://{host_url}/api/{api_version}/" + "accounts/{account_id}/" + "runs/{run_id}/" + "artifacts/catalog.json" + ).format(**expected) + + @mock.patch("dbterd.adapters.dbt_cloud.file.write_json") + @mock.patch("dbterd.adapters.dbt_cloud.requests.get") + def test_download_artifact_ok( + self, mock_requests_get, mock_write_json, dbtCloudArtifact + ): + mock_requests_get.return_value = MockResponse(status_code=200, data={}) + assert dbtCloudArtifact.download_artifact( + artifact="manifest", artifacts_dir="/irrelevant/path" + ) + mock_write_json.assert_called_once_with( + data=json.dumps({}, indent=2), + path="/irrelevant/path/manifest.json", + ) + + @mock.patch("dbterd.adapters.dbt_cloud.file.write_json") + def test_download_artifact_bad_parameters(self, mock_write_json, dbtCloudArtifact): + with pytest.raises(AttributeError): + dbtCloudArtifact.download_artifact( + artifact="irrelevant", artifacts_dir="/irrelevant/path" + ) + assert mock_write_json.call_count == 0 + + @mock.patch("dbterd.adapters.dbt_cloud.file.write_json") + @mock.patch("dbterd.adapters.dbt_cloud.requests.get") + def test_download_artifact_network_failed( + self, mock_requests_get, mock_write_json, dbtCloudArtifact + ): + mock_requests_get.side_effect = requests.exceptions.ConnectionError() + assert not dbtCloudArtifact.download_artifact( + artifact="manifest", artifacts_dir="/irrelevant/path" + ) + assert mock_write_json.call_count == 0 + + @mock.patch("dbterd.adapters.dbt_cloud.file.write_json") + @mock.patch("dbterd.adapters.dbt_cloud.requests.get") + def test_download_artifact_failed_to_save_file( + self, mock_requests_get, mock_write_json, dbtCloudArtifact + ): + mock_requests_get.return_value = MockResponse(status_code=200, data={}) + mock_write_json.side_effect = Exception("any error") + assert not dbtCloudArtifact.download_artifact( + artifact="manifest", artifacts_dir="/irrelevant/path" + ) + assert mock_write_json.call_count == 1 + + @mock.patch("dbterd.adapters.dbt_cloud.file.write_json") + @mock.patch("dbterd.adapters.dbt_cloud.requests.get") + def test_download_artifact_status_not_ok( + self, mock_requests_get, mock_write_json, dbtCloudArtifact + ): + mock_requests_get.return_value = MockResponse(status_code=999) + assert not dbtCloudArtifact.download_artifact( + artifact="manifest", artifacts_dir="/irrelevant/path" + ) + assert mock_write_json.call_count == 0 - def test_get(self): - pass + @mock.patch("dbterd.adapters.dbt_cloud.DbtCloudArtifact.download_artifact") + def test_get(self, mock_download_artifact, dbtCloudArtifact): + mock_download_artifact.return_value = True + assert dbtCloudArtifact.get(artifacts_dir="/irrelevant/path") + assert mock_download_artifact.call_count == 2