Skip to content

Commit

Permalink
chore: add unit tests for dbt cloud feat (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
datnguye authored Jan 13, 2024
1 parent 4eda823 commit d0da492
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 17 deletions.
3 changes: 1 addition & 2 deletions dbterd/adapters/dbt_cloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import json

import click
import requests

from dbterd.helpers import file
Expand Down Expand Up @@ -78,7 +77,7 @@ def download_artifact(self, artifact: str, artifacts_dir: str) -> bool:
data=json.dumps(r.json(), indent=2),
path=f"{artifacts_dir}/{artifact}.json",
)
except click.BadParameter as e:
except Exception as e:
logger.error(f"Error occurred while downloading [error: {str(e)}]")
return False

Expand Down
21 changes: 14 additions & 7 deletions tests/unit/adapters/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@


class TestBase:
# .run(
# target="dbml",
# algo="test_relationship",
# artifacts_dir="/",
# output="/"
# )

def test_worker(self):
worker = Executor(ctx=click.Context(command=click.BaseCommand("dummy")))
assert worker.filename_manifest == "manifest.json"
Expand Down Expand Up @@ -101,8 +94,19 @@ def test__get_selection__error(self, mock_dbt_invocation):
exclude=[],
),
),
(
dict(select=[], exclude=[], dbt_cloud=True),
dict(
dbt_cloud=True,
artifacts_dir="/path/dpd/target",
dbt_project_dir="/path/dpd",
select=[],
exclude=[],
),
),
],
)
@mock.patch("dbterd.adapters.dbt_cloud.DbtCloudArtifact.get")
@mock.patch("dbterd.adapters.base.Executor._Executor__get_dir")
@mock.patch("dbterd.adapters.base.Executor._Executor__get_selection")
@mock.patch("dbterd.adapters.base.DbtInvocation.get_artifacts_for_erd")
Expand All @@ -111,6 +115,7 @@ def test_evaluate_kwargs(
mock_get_artifacts_for_erd,
mock_get_selection,
mock_get_dir,
mock_dbt_cloud_get,
kwargs,
expected,
):
Expand All @@ -121,6 +126,8 @@ def test_evaluate_kwargs(
mock_get_dir.assert_called_once()
if kwargs.get("dbt_auto_artifacts"):
mock_get_artifacts_for_erd.assert_called_once()
if kwargs.get("dbt_cloud"):
mock_dbt_cloud_get.assert_called_once()

@pytest.mark.parametrize(
"kwargs, mock_isfile_se, expected",
Expand Down
148 changes: 140 additions & 8 deletions tests/unit/adapters/test_dbt_cloud.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,146 @@
# from unittest import mock
import json
from unittest import mock

# import click
# import pytest
import pytest
import requests

# from dbterd.adapters.dbt_cloud import DbtCloudArtifact
from dbterd.adapters.dbt_cloud import DbtCloudArtifact


class MockResponse:
def __init__(self, status_code, data=None) -> None:
self.status_code = status_code
self.data = data

def json(self):
return self.data


class TestDbtCloudArtifact:
def test_download_artifact(self):
pass
@pytest.fixture
def dbtCloudArtifact(self) -> DbtCloudArtifact:
return DbtCloudArtifact(
dbt_cloud_host_url="irrelevant_url",
dbt_cloud_service_token="irrelevant_st",
dbt_cloud_account_id="irrelevant_acc_id",
dbt_cloud_run_id="irrelevant_run_id",
dbt_cloud_api_version="irrelevant_v",
)

@pytest.mark.parametrize(
"kwargs, expected",
[
(
dict(),
dict(
host_url=None,
service_token=None,
account_id=None,
run_id=None,
api_version=None,
),
),
(
dict(
dbt_cloud_host_url="host_url",
dbt_cloud_service_token="service_token",
dbt_cloud_account_id="account_id",
dbt_cloud_run_id="run_id",
dbt_cloud_api_version="api_version",
),
dict(
host_url="host_url",
service_token="service_token",
account_id="account_id",
run_id="run_id",
api_version="api_version",
),
),
],
)
def test_init(self, kwargs, expected):
dbt_cloud = DbtCloudArtifact(**kwargs)
assert vars(dbt_cloud) == expected
assert dbt_cloud.request_headers == {
"Authorization": f"Token {kwargs.get('dbt_cloud_service_token')}"
}
assert dbt_cloud.api_endpoint == (
"https://{host_url}/api/{api_version}/"
"accounts/{account_id}/"
"runs/{run_id}/"
"artifacts/{{path}}"
).format(**expected)
assert dbt_cloud.manifest_api_endpoint == (
"https://{host_url}/api/{api_version}/"
"accounts/{account_id}/"
"runs/{run_id}/"
"artifacts/manifest.json"
).format(**expected)
assert dbt_cloud.catalog_api_endpoint == (
"https://{host_url}/api/{api_version}/"
"accounts/{account_id}/"
"runs/{run_id}/"
"artifacts/catalog.json"
).format(**expected)

@mock.patch("dbterd.adapters.dbt_cloud.file.write_json")
@mock.patch("dbterd.adapters.dbt_cloud.requests.get")
def test_download_artifact_ok(
self, mock_requests_get, mock_write_json, dbtCloudArtifact
):
mock_requests_get.return_value = MockResponse(status_code=200, data={})
assert dbtCloudArtifact.download_artifact(
artifact="manifest", artifacts_dir="/irrelevant/path"
)
mock_write_json.assert_called_once_with(
data=json.dumps({}, indent=2),
path="/irrelevant/path/manifest.json",
)

@mock.patch("dbterd.adapters.dbt_cloud.file.write_json")
def test_download_artifact_bad_parameters(self, mock_write_json, dbtCloudArtifact):
with pytest.raises(AttributeError):
dbtCloudArtifact.download_artifact(
artifact="irrelevant", artifacts_dir="/irrelevant/path"
)
assert mock_write_json.call_count == 0

@mock.patch("dbterd.adapters.dbt_cloud.file.write_json")
@mock.patch("dbterd.adapters.dbt_cloud.requests.get")
def test_download_artifact_network_failed(
self, mock_requests_get, mock_write_json, dbtCloudArtifact
):
mock_requests_get.side_effect = requests.exceptions.ConnectionError()
assert not dbtCloudArtifact.download_artifact(
artifact="manifest", artifacts_dir="/irrelevant/path"
)
assert mock_write_json.call_count == 0

@mock.patch("dbterd.adapters.dbt_cloud.file.write_json")
@mock.patch("dbterd.adapters.dbt_cloud.requests.get")
def test_download_artifact_failed_to_save_file(
self, mock_requests_get, mock_write_json, dbtCloudArtifact
):
mock_requests_get.return_value = MockResponse(status_code=200, data={})
mock_write_json.side_effect = Exception("any error")
assert not dbtCloudArtifact.download_artifact(
artifact="manifest", artifacts_dir="/irrelevant/path"
)
assert mock_write_json.call_count == 1

@mock.patch("dbterd.adapters.dbt_cloud.file.write_json")
@mock.patch("dbterd.adapters.dbt_cloud.requests.get")
def test_download_artifact_status_not_ok(
self, mock_requests_get, mock_write_json, dbtCloudArtifact
):
mock_requests_get.return_value = MockResponse(status_code=999)
assert not dbtCloudArtifact.download_artifact(
artifact="manifest", artifacts_dir="/irrelevant/path"
)
assert mock_write_json.call_count == 0

def test_get(self):
pass
@mock.patch("dbterd.adapters.dbt_cloud.DbtCloudArtifact.download_artifact")
def test_get(self, mock_download_artifact, dbtCloudArtifact):
mock_download_artifact.return_value = True
assert dbtCloudArtifact.get(artifacts_dir="/irrelevant/path")
assert mock_download_artifact.call_count == 2

0 comments on commit d0da492

Please sign in to comment.