Skip to content

Commit

Permalink
Added dedicated function for rendering a jinja string value with a pr…
Browse files Browse the repository at this point in the history
…ovider child. New functions for aggregating detect messages.
  • Loading branch information
Will-NOQ committed Jul 14, 2023
1 parent 740e331 commit 8d84d0c
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 27 deletions.
40 changes: 40 additions & 0 deletions iambic/core/detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections import defaultdict
from typing import Type

from iambic.core.models import ProviderChild, BaseTemplate
from iambic.core.utils import evaluate_on_provider


def group_detect_messages(group_by: str, messages: list) -> dict:
"""Group messages by a key in the message dict.
Args:
group_by (str): The key to group by.
messages (list): The messages to group.
Returns:
dict: The grouped messages.
"""
grouped_messages = defaultdict(list)
for message in messages:
grouped_messages[getattr(message, group_by)].append(message)

return grouped_messages


def generate_template_output(
excluded_provider_ids: list[str],
provider_child_map: dict[str, ProviderChild],
template: Type[BaseTemplate]
) -> dict[str, dict]:
provider_children_value_map = dict()
for provider_child_id, provider_child in provider_child_map.items():
if provider_child_id in excluded_provider_ids:
continue
elif not evaluate_on_provider(template, provider_child, exclude_import_only=False):
continue

if provider_child_value := template.apply_resource_dict(provider_child):
provider_children_value_map[provider_child_id] = provider_child_value

return provider_children_value_map
38 changes: 12 additions & 26 deletions iambic/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import dateparser
from deepdiff.model import PrettyOrderedSet
from git import Repo
from jinja2 import BaseLoader, Environment
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Extra, Field, root_validator, schema, validate_model, validator
from pydantic.fields import ModelField
Expand All @@ -41,17 +40,15 @@
apply_to_provider,
create_commented_map,
get_writable_directory,
sanitize_string,
simplify_dt,
snake_to_camelcap,
sort_dict,
transform_comments,
yaml,
yaml, get_rendered_template_str_value,
)

if TYPE_CHECKING:
from iambic.config.dynamic_config import Config
from iambic.plugins.v0_1_0.aws.models import AWSAccount

MappingIntStrAny = typing.Mapping[int | str, Any]
AbstractSetIntStr = typing.AbstractSet[int | str]
Expand Down Expand Up @@ -159,14 +156,14 @@ def get_field_type(field: Any) -> Any:

def get_attribute_val_for_account(
self,
aws_account: AWSAccount,
provider_child: Type[ProviderChild],
attr: str,
as_boto_dict: bool = True,
):
"""
Retrieve the value of an attribute for a specific AWS account.
:param aws_account: The AWSAccount object for which the attribute value should be retrieved.
:param provider_child: The ProviderChild object for which the attribute value should be retrieved.
:param attr: The attribute name (supports nested attributes via dot notation, e.g., properties.tags).
:param as_boto_dict: If True, the value will be transformed to a boto dictionary if applicable.
:return: The attribute value for the specified AWS account.
Expand All @@ -177,12 +174,12 @@ def get_attribute_val_for_account(
attr_val = getattr(attr_val, attr_key)

if as_boto_dict and hasattr(attr_val, "_apply_resource_dict"):
return attr_val._apply_resource_dict(aws_account)
return attr_val._apply_resource_dict(provider_child)
elif not isinstance(attr_val, list):
return attr_val

matching_definitions = [
val for val in attr_val if apply_to_provider(val, aws_account)
val for val in attr_val if apply_to_provider(val, provider_child)
]
if len(matching_definitions) == 0:
# Fallback to the default definition
Expand All @@ -194,15 +191,15 @@ def get_attribute_val_for_account(
return field.__fields__[split_key[-1]].default
elif as_boto_dict:
return [
match._apply_resource_dict(aws_account)
match._apply_resource_dict(provider_child)
if hasattr(match, "_apply_resource_dict")
else match
for match in matching_definitions
]
else:
return matching_definitions

def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:
def _apply_resource_dict(self, provider_child: Type[ProviderChild] = None) -> dict:
exclude_keys = {
"deleted",
"expires_at",
Expand All @@ -220,10 +217,10 @@ def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:
exclude_keys.update(self.exclude_keys)
has_properties = hasattr(self, "properties")
properties = getattr(self, "properties", self)
if aws_account:
if provider_child:
resource_dict = {
k: self.get_attribute_val_for_account(
aws_account,
provider_child,
f"properties.{k}" if has_properties else k,
)
for k in properties.__dict__.keys()
Expand All @@ -239,20 +236,9 @@ def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:

return {self.case_convention(k): v for k, v in resource_dict.items()}

def apply_resource_dict(self, aws_account: AWSAccount) -> dict:
response = self._apply_resource_dict(aws_account)
variables = {var.key: var.value for var in aws_account.variables}
variables["account_id"] = aws_account.account_id
variables["account_name"] = aws_account.account_name
if hasattr(self, "owner") and (owner := getattr(self, "owner", None)):
variables["owner"] = owner

rtemplate = Environment(loader=BaseLoader()).from_string(json.dumps(response))
valid_characters_re = r"[\w_+=,.@-]"
variables = {
k: sanitize_string(v, valid_characters_re) for k, v in variables.items()
}
data = rtemplate.render(var=variables)
def apply_resource_dict(self, provider_child: Type[ProviderChild]) -> dict:
response = self._apply_resource_dict(provider_child)
data = get_rendered_template_str_value(json.dumps(response), provider_child)
return json.loads(data)

async def remove_expired_resources(self):
Expand Down
28 changes: 27 additions & 1 deletion iambic/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import aiofiles
import jwt
from asgiref.sync import sync_to_async
from jinja2 import BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from ruamel.yaml import YAML, scalarstring

from iambic.core import noq_json as json
Expand All @@ -26,7 +28,7 @@
from iambic.core.logger import log

if TYPE_CHECKING:
from iambic.core.models import ProposedChange
from iambic.core.models import ProposedChange, ProviderChild


NOQ_TEMPLATE_REGEX = r".*template_type:\n?.*NOQ::"
Expand Down Expand Up @@ -892,3 +894,27 @@ def decode_with_reference_time(encoded_jwt, public_key, algorithms, reference_ti
)

return payload


def get_rendered_template_str_value(
template_value: str, provider_child: typing.Type[ProviderChild]
) -> str:
"""
Render a template string with the variables from the provider child.
"""
valid_characters_re = r"[\w_+=,.@-]"
variables = {var.key: var.value for var in getattr(provider_child, "variables", [])}
for extra_attr in {"account_id", "account_name", "owner"}:
if attr_val := getattr(provider_child, extra_attr, None):
variables[extra_attr] = attr_val

if not variables:
return template_value

variables = {
k: sanitize_string(v, valid_characters_re) for k, v in variables.items()
}
rtemplate = ImmutableSandboxedEnvironment(loader=BaseLoader()).from_string(
template_value
)
return rtemplate.render(var=variables)

0 comments on commit 8d84d0c

Please sign in to comment.