Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds ability to add list_fields to directives #85

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "turms"
version = "0.7.0"
version = "0.8.0"
description = "graphql-codegen powered by pydantic"
authors = ["jhnnsrs <[email protected]>"]
license = "MIT"
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def nested_input_schema():
def union_schema():
return build_schema_from_schema_type(build_relative_glob("/schemas/union.graphql"))

@pytest.fixture(scope="session")
def directive_schema():
return build_schema_from_schema_type(build_relative_glob("/schemas/list_field_directive.graphql"))


@pytest.fixture(scope="session")
def schema_directive_schema():
Expand Down
3 changes: 3 additions & 0 deletions tests/documents/directives/list_field_directive.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
query X {
x
}
Empty file added tests/plugins/__init__.py
Empty file.
Empty file.
36 changes: 36 additions & 0 deletions tests/plugins/strawberry/test_list_directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

from ...utils import build_relative_glob, unit_test_with
from turms.config import GeneratorConfig
from turms.run import generate_ast
from turms.plugins.enums import EnumsPlugin
from turms.plugins.inputs import InputsPlugin
from turms.plugins.fragments import FragmentsPlugin
from turms.plugins.operations import OperationsPlugin
from turms.plugins.funcs import (
FunctionDefinition,
FuncsPlugin,
FuncsPluginConfig,
)
from turms.plugins.strawberry import StrawberryPlugin
from turms.stylers.snake_case import SnakeCaseStyler
from turms.stylers.capitalize import CapitalizeStyler
from turms.run import generate_ast


def test_list_directive_funcs(directive_schema):
config = GeneratorConfig(
documents=build_relative_glob("/documents/directives/*.graphql"),
)
generated_ast = generate_ast(
config,
directive_schema,
stylers=[CapitalizeStyler(), SnakeCaseStyler()],
plugins=[
StrawberryPlugin(),
],
)

unit_test_with(
generated_ast,
""
)
22 changes: 22 additions & 0 deletions tests/schemas/list_field_directive.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
The directive is responsible for authorization check.
"""
directive @auth(
"""
Permissions which are required for field access.
"""
permissions: [String!]

"""
The list of roles that an authorized user should have to get the access.
"""
roles: [String!] = []
) on FIELD_DEFINITION

type X {
name: String! @auth(permissions: ["read"])
}

type Query {
x: [X!]! @auth(permissions: ["read"])
}
2 changes: 1 addition & 1 deletion turms/plugins/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class FuncsPluginConfig(PluginConfig):
definitions: List[FunctionDefinition] = []
extract_documentation: bool = True
argument_key_is_styled: bool = False
expand_input_types: List[str] = ["input"]
expand_input_types: List[str] = []


def camel_to_snake(name):
Expand Down
192 changes: 180 additions & 12 deletions turms/plugins/strawberry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from graphql import (
BooleanValueNode,
ConstListValueNode,
ConstObjectValueNode,
ConstValueNode,
EnumValueNode,
FloatValueNode,
GraphQLField,
GraphQLInputObjectType,
GraphQLInterfaceType,
Expand All @@ -8,6 +14,11 @@
GraphQLScalarType,
GraphQLType,
GraphQLUnionType,
IntValueNode,
ListValueNode,
NullValueNode,
ObjectValueNode,
StringValueNode,
Undefined,
GraphQLArgument,
ObjectTypeDefinitionNode,
Expand Down Expand Up @@ -43,6 +54,133 @@ def __call__(
) -> List[ast.AST]: ... # pragma: no cover


def build_directive_type_annotation(value: GraphQLType, registry: ClassRegistry, is_optional=True):

if isinstance(value, GraphQLScalarType):
if is_optional:
registry.register_import("typing.Optional")
return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=registry.reference_scalar(value.name),
ctx=ast.Load(),
)

return registry.reference_scalar(value.name)
if isinstance(value, GraphQLObjectType):
raise NotImplementedError("Object types cannot be used as arguments")
if isinstance(value, GraphQLInterfaceType):
raise NotImplementedError("Interface types cannot be used as arguments")
if isinstance(value, GraphQLUnionType):
raise NotImplementedError("Union types cannot be used as arguments")
if isinstance(value, GraphQLEnumType):
if is_optional:
registry.register_import("typing.Optional")
return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=registry.reference_enum(value.name),
ctx=ast.Load(),
)

return registry.reference_enum(value.name)
if isinstance(value, GraphQLNonNull):
return build_directive_type_annotation(value.of_type, registry, is_optional=False)
if isinstance(value, GraphQLList):
registry.register_import("typing.List")

if is_optional:
registry.register_import("typing.Optional")

return ast.Subscript(
value=ast.Name("Optional", ctx=ast.Load()),
slice=ast.Subscript(
value=ast.Name("List", ctx=ast.Load()),
slice=build_directive_type_annotation(value.of_type, registry, is_optional=True),
ctx=ast.Load(),
),
ctx=ast.Load(),
)

return ast.Subscript(
value=ast.Name("List", ctx=ast.Load()),
slice=build_directive_type_annotation(value.of_type, registry, is_optional=True),
ctx=ast.Load(),
)
if isinstance(value, GraphQLInputObjectType):
raise NotImplementedError("Input types cannot be used as arguments")

raise NotImplementedError(f"Unknown type {repr(value)}")



def convert_valuenode_to_ast(value: ConstValueNode):
if isinstance(value, NullValueNode):
return ast.Constant(value=None)
if isinstance(value, StringValueNode):
return ast.Constant(value=value.value)
if isinstance(value, IntValueNode):
return ast.Constant(value=value.value)
if isinstance(value, FloatValueNode):
return ast.Constant(value=value.value)
if isinstance(value, BooleanValueNode):
return ast.Constant(value=value.value)

if isinstance(value, EnumValueNode):
return ast.Constant(value=value)
if isinstance(value, ListValueNode):
return ast.List(elts=[convert_valuenode_to_ast(x) for x in value.values], ctx=ast.Load())
if isinstance(value, ObjectValueNode):

keys = []
values = []

for field in value.fields:
keys.append(field.name.value)
values.append(convert_valuenode_to_ast(field.value))

return ast.Dict(
keys=keys,
values=values,
)

raise NotImplementedError(f"Unknown default value {repr(value)}")



def convert_default_value_to_ast(value):
if value is Undefined:
return None
if value is None:
return ast.Constant(value=None)
if isinstance(value, str):
return ast.Constant(value=value)
if isinstance(value, int):
return ast.Constant(value=value)
if isinstance(value, float):
return ast.Constant(value=value)
if isinstance(value, bool):
return ast.Constant(value=value)
if isinstance(value, list):
return ast.List(elts=[convert_default_value_to_ast(x) for x in value], ctx=ast.Load())
if isinstance(value, dict):
keys = []
values = []

for key, value in value.items():
keys.append(key)
values.append(convert_default_value_to_ast(value))

return ast.Dict(
keys=keys,
values=values,
)
raise NotImplementedError(f"Unknown default value {repr(value)}")







def default_generate_directives(
client_schema: GraphQLSchema,
config: GeneratorConfig,
Expand Down Expand Up @@ -96,23 +234,53 @@ def default_generate_directives(

type = value.type

if isinstance(value.type, GraphQLNonNull):
type = value.type.of_type

assert isinstance(
type, GraphQLScalarType
), "Only scalar (or nonnull version of this) are supported"

if value.default_value:
default = ast.Constant(value=value.default_value)
if value.default_value is not None:
default = convert_default_value_to_ast(value.default_value)
else:
default = None

needs_factory = False
if isinstance(default, ast.List):
needs_factory = True
if isinstance(default, ast.Dict):
needs_factory = True


field_value = None

if default:
if needs_factory:
field_value = ast.Call(
func=ast.Name(id="strawberry.field", ctx=ast.Load()),
keywords=[
ast.keyword(
arg="default_factory",
value=ast.Lambda(
args=[], body=default
),
),
],
args=[],
)
else:
field_value = ast.Call(
func=ast.Name(id="strawberry.field", ctx=ast.Load()),
keywords=[
ast.keyword(
arg="default",
value=default,
),
],
args=[],
)


assign = ast.AnnAssign(
target=ast.Name(
id=registry.generate_node_name(value_key), ctx=ast.Store()
),
annotation=registry.reference_scalar(type.name),
value=default,
annotation=build_directive_type_annotation(type, registry),
value=field_value,
simple=1,
)

Expand Down Expand Up @@ -560,7 +728,7 @@ def generate_directive_keywords(
ctx=ast.Load(),
),
keywords=[
ast.keyword(arg=arg.name.value, value=ast.Constant(arg.value.value))
ast.keyword(arg=arg.name.value, value=convert_valuenode_to_ast(arg.value))
for arg in directive.arguments
],
args=[],
Expand Down
Loading