Skip to content

Commit

Permalink
chore: patch for dbt 1.5+ compat
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Apr 13, 2024
1 parent 56ac1de commit d60e2bb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dbt_feature_flags/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.0"
__version__ = "0.5.2"
20 changes: 10 additions & 10 deletions dbt_feature_flags/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import abc
import logging
import typing as t
from functools import wraps
from typing import Any, Union, final


class BaseFeatureFlagsClient(abc.ABC):
Expand All @@ -28,49 +28,49 @@ class BaseFeatureFlagsClient(abc.ABC):
def __init__(self) -> None:
self._add_validators()

@final
@t.final
def _add_validators(self):
self.bool_variation = validate(types=(bool,))(self.bool_variation)
self.string_variation = validate(types=(str,))(self.string_variation)
self.number_variation = validate(types=(float, int))(self.number_variation)
self.json_variation = validate(types=(dict, list, None))(self.json_variation)
self.json_variation = validate(types=(dict, list, None))(self.json_variation) # type: ignore

@abc.abstractmethod
def bool_variation(self, flag: str, default: Any) -> bool:
def bool_variation(self, flag: str, default: t.Any) -> bool:
raise NotImplementedError(
"Boolean feature flags are not implemented for this driver"
)

@abc.abstractmethod
def string_variation(self, flag: str, default: Any) -> str:
def string_variation(self, flag: str, default: t.Any) -> str:
raise NotImplementedError(
"String feature flags are not implemented for this driver"
)

@abc.abstractmethod
def number_variation(self, flag: str, default: Any) -> Union[float, int]:
def number_variation(self, flag: str, default: t.Any) -> t.Union[float, int]:
raise NotImplementedError(
"Number feature flags are not implemented for this driver"
)

@abc.abstractmethod
def json_variation(self, flag: str, default: Any) -> Union[dict, list]:
def json_variation(self, flag: str, default: t.Any) -> t.Union[dict, list]:
raise NotImplementedError(
"JSON feature flags are not implemented for this driver"
)


def validate(types: Union[list, tuple]):
def validate(types: t.Tuple[t.Type[t.Any], ...]):
def _validate(v, flag_name, func_name):
if not isinstance(v, types):
if not isinstance(v, tuple(types)):
raise ValueError(
f"Invalid return value for {func_name}({flag_name}...) feature flag call. Found type {type(v).__name__}."
)
return v

def _main(func):
@wraps(func)
def _injected_validator(flag: str, default: Any = func.__defaults__[0]):
def _injected_validator(flag: str, default: t.Any = func.__defaults__[0]):
if not isinstance(default, types):
raise ValueError(
f"Invalid default value: {default} for {func.__name__}({flag}...) feature flag call. Found type {type(default).__name__}."
Expand Down
18 changes: 13 additions & 5 deletions dbt_feature_flags/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import typing as t
from enum import Enum
from functools import wraps
from types import SimpleNamespace

from dbt_feature_flags import base, harness, launchdarkly

_MOCK_CLIENT = object()
MockClient = t.NewType("MockClient", type(object()))

_MOCK_CLIENT = t.cast(MockClient, object())


class SupportedProviders(str, Enum):
Expand All @@ -34,7 +37,7 @@ def _is_truthy(value: str) -> bool:
return value.lower() in ("1", "true", "yes")


def _get_client() -> base.BaseFeatureFlagsClient | _MOCK_CLIENT:
def _get_client() -> base.BaseFeatureFlagsClient | MockClient | None:
"""Return the user specified client.
Valid implementations MUST inherit from BaseFeatureFlagsClient.
Expand Down Expand Up @@ -90,16 +93,21 @@ def _wrapped(
ctx["feature_flag_json"] = client.json_variation
return fn(string, ctx, node, capture_macros, native)

_wrapped.status = "patched"
_wrapped.status = "patched" # type: ignore
return _wrapped


def patch_dbt_environment() -> None:
"""Patch dbt's jinja environment to include feature flag functions."""
import dbt.flags
from dbt.clients import jinja

jinja._get_rendered = jinja.get_rendered
jinja.get_rendered = get_rendered(jinja._get_rendered, _get_client())
# small patch to make compatible with dbt 1.5+
g_flags = getattr(dbt.flags, "get_flags", lambda: SimpleNamespace())
g_flags().MACRO_DEBUGGING = False

jinja._get_rendered = jinja.get_rendered # type: ignore
jinja.get_rendered = get_rendered(jinja._get_rendered, _get_client()) # type: ignore


if __name__ == "__main__":
Expand Down

0 comments on commit d60e2bb

Please sign in to comment.