Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataclasses refactor and add new to_dict function #221

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 66 additions & 22 deletions testsuite/objects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,50 @@
"""Module containing base classes for common objects"""
import abc
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass, fields
from copy import deepcopy
from functools import cached_property
from typing import Literal, List
from typing import Literal, Union

from testsuite.objects.sections import Metadata, Identities, Authorizations, Responses

pehala marked this conversation as resolved.
Show resolved Hide resolved
JSONValues = Union[None, str, int, bool, list["JSONValues"], dict[str, "JSONValues"]]


def asdict(obj) -> dict[str, JSONValues]:
"""
This function converts dataclass object to dictionary.
While it works similar to `dataclasses.asdict` a notable change is usage of
overriding `to_dict()` function if dataclass contains it.
This function works recursively in lists, tuples and dicts. All other values are passed to copy.deepcopy function.
"""
if not is_dataclass(obj):
raise TypeError("asdict() should be called on dataclass instances")
return _asdict_recurse(obj)


def _asdict_recurse(obj):
if hasattr(obj, "asdict"):
return obj.asdict()

if not is_dataclass(obj):
return deepcopy(obj)

result = {}
for field in fields(obj):
value = getattr(obj, field.name)
if value is None:
continue # do not include None values

if is_dataclass(value):
result[field.name] = _asdict_recurse(value)
elif isinstance(value, (list, tuple)):
result[field.name] = type(value)(_asdict_recurse(i) for i in value)
elif isinstance(value, dict):
result[field.name] = type(value)((_asdict_recurse(k), _asdict_recurse(v)) for k, v in value.items())
else:
result[field.name] = deepcopy(value)
return result


@dataclass
class MatchExpression:
Expand All @@ -15,7 +54,7 @@ class MatchExpression:
"""

operator: Literal["In", "NotIn", "Exists", "DoesNotExist"]
values: List[str]
values: list[str]
key: str = "group"


Expand All @@ -35,40 +74,45 @@ class Rule:
value: str


class Value:
"""Dataclass for specifying a Value in Authorization, can be either constant or value from AuthJson (jsonPath)"""
@dataclass
class ABCValue(abc.ABC):
"""
Abstract Dataclass for specifying a Value in Authorization,
can be either static or reference to value in AuthJson.
"""


# pylint: disable=invalid-name
def __init__(self, value=None, jsonPath=None) -> None:
super().__init__()
if not (value is None) ^ (jsonPath is None):
raise AttributeError("Exactly one of the `value` and `jsonPath` argument must be specified")
self.value = value
self.jsonPath = jsonPath
@dataclass
class Value(ABCValue):
"""Dataclass for static Value. Can be any value allowed in JSON: None, string, integer, bool, list, dict"""

def to_dict(self):
"""Returns dict representation of itself (shallow copy only)"""
return {"value": self.value} if self.value else {"valueFrom": {"authJson": self.jsonPath}}
value: JSONValues


@dataclass
class ValueFrom(ABCValue):
"""Dataclass for dynamic Value. It contains reference path to existing value in AuthJson."""

authJSON: str # pylint: disable=invalid-name

def asdict(self):
"""Override `asdict` function"""
return {"valueFrom": {"authJSON": self.authJSON}}


@dataclass
class Cache:
"""Dataclass for specifying Cache in Authorization"""

ttl: int
key: Value

def to_dict(self):
"""Returns dict representation of itself (shallow copy only)"""
return {"ttl": self.ttl, "key": self.key.to_dict()}
key: ABCValue


@dataclass
class PatternRef:
"""Dataclass for specifying Pattern reference in Authorization"""

# pylint: disable=invalid-name
patternRef: str
patternRef: str # pylint: disable=invalid-name


class LifecycleObject(abc.ABC):
Expand Down
4 changes: 2 additions & 2 deletions testsuite/objects/sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from testsuite.objects import Rule, Value
from testsuite.objects import Rule, ABCValue


class Authorizations(abc.ABC):
Expand All @@ -27,7 +27,7 @@ def auth_rule(self, name: str, rule: "Rule", **common_features):
"""Adds JSON pattern-matching authorization rule (authorization.json)"""

@abc.abstractmethod
def kubernetes(self, name: str, user: "Value", kube_attrs: dict, **common_features):
def kubernetes(self, name: str, user: "ABCValue", kube_attrs: dict, **common_features):
"""Adds kubernetes authorization rule."""


Expand Down
19 changes: 14 additions & 5 deletions testsuite/openshift/objects/auth_config/sections.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""AuthConfig CR object"""
from dataclasses import asdict
from typing import Dict, Literal, Iterable, TYPE_CHECKING

from testsuite.objects import Identities, Metadata, Responses, MatchExpression, Authorizations, Rule, Cache, Value
from testsuite.objects import (
asdict,
Identities,
Metadata,
Responses,
MatchExpression,
Authorizations,
Rule,
Cache,
ABCValue,
)
from testsuite.openshift.objects import modify

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,7 +54,7 @@ def add_item(
if metrics:
item["metrics"] = metrics
if cache:
item["cache"] = cache.to_dict()
item["cache"] = asdict(cache)
if priority:
item["priority"] = priority
self.section.append(item)
Expand Down Expand Up @@ -215,7 +224,7 @@ def external_opa_policy(self, name, endpoint, ttl=0, **common_features):
self.add_item(name, {"opa": {"externalRegistry": {"endpoint": endpoint, "ttl": ttl}}}, **common_features)

@modify
def kubernetes(self, name: str, user: Value, kube_attrs: dict, **common_features):
def kubernetes(self, name: str, user: ABCValue, kube_attrs: dict, **common_features):
"""Adds Kubernetes authorization

:param name: name of kubernetes authorization
Expand All @@ -226,7 +235,7 @@ def kubernetes(self, name: str, user: Value, kube_attrs: dict, **common_features
self.add_item(
name,
{
"kubernetes": {"user": user.to_dict(), "resourceAttributes": kube_attrs},
"kubernetes": {"user": asdict(user), "resourceAttributes": kube_attrs},
},
**common_features
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from testsuite.objects import Cache, Value
from testsuite.objects import Cache, ValueFrom
from testsuite.utils import extract_response


Expand All @@ -16,7 +16,7 @@ def cache_ttl():
@pytest.fixture(scope="module")
def authorization(authorization, module_label, expectation_path, cache_ttl):
"""Adds Cached Metadata to the AuthConfig"""
meta_cache = Cache(cache_ttl, Value(jsonPath="context.request.http.path"))
meta_cache = Cache(cache_ttl, ValueFrom("context.request.http.path"))
authorization.metadata.http_metadata(module_label, expectation_path, "GET", cache=meta_cache)
return authorization

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import openshift as oc
from openshift import OpenShiftPythonException

from testsuite.objects import Authorization, Rule, Value
from testsuite.objects import Authorization, Rule, ValueFrom
from testsuite.certificates import CertInfo
from testsuite.utils import cert_builder
from testsuite.openshift.objects.ingress import Ingress
Expand Down Expand Up @@ -78,7 +78,7 @@ def authorization(authorization, openshift, module_label, authorino_domain) -> A

# add OPA policy to process admission webhook request
authorization.authorization.opa_policy("features", OPA_POLICY)
user_value = Value(jsonPath="auth.identity.username")
user_value = ValueFrom("auth.identity.username")

when = [
Rule("auth.authorization.features.allow", "eq", "true"),
Expand Down