Skip to content

Commit

Permalink
Merge pull request #75 from rnovacek/default-for-optional
Browse files Browse the repository at this point in the history
Set default values for optional field in input types
  • Loading branch information
jhnnsrs authored May 14, 2024
2 parents b956225 + 4c8b689 commit 1df9492
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
47 changes: 47 additions & 0 deletions tests/test_optional_input_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import ast
from graphql import parse, build_ast_schema
from turms.config import GeneratorConfig
from turms.run import generate_ast
from turms.stylers.snake_case import SnakeCaseStyler
from turms.plugins.inputs import InputsPlugin
from turms.plugins.objects import ObjectsPlugin
from .utils import unit_test_with

inputs = '''
input X {
mandatory: String!
otherMandatory: String!
optional: String
otherOptional: String
}
'''


expected = '''class X(BaseModel):
mandatory: str
other_mandatory: str = Field(alias='otherMandatory')
optional: Optional[str] = None
other_optional: Optional[str] = Field(alias='otherOptional', default=None)'''


def test_generates_schema(snapshot):
config = GeneratorConfig()

schema = build_ast_schema(parse(inputs))

generated_ast = generate_ast(
config,
schema,
stylers=[SnakeCaseStyler()],
plugins=[
InputsPlugin(),
ObjectsPlugin(),
],
)

unit_test_with(generated_ast, '')

without_imports = [node for node in generated_ast if not isinstance(node, ast.ImportFrom)]
md = ast.Module(body=without_imports, type_ignores=[])
generated = ast.unparse(ast.fix_missing_locations(md))
assert generated == expected
18 changes: 13 additions & 5 deletions turms/plugins/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ def generate_inputs(

if field_name != value_key:
registry.register_import("pydantic.Field")
keywords = [
ast.keyword(
arg="alias", value=ast.Constant(value=value_key)
)
]
if not isinstance(value.type, GraphQLNonNull):
keywords.append(
ast.keyword(arg="default", value=ast.Constant(None))
)

assign = ast.AnnAssign(
target=ast.Name(field_name, ctx=ast.Store()),
annotation=generate_input_annotation(
Expand All @@ -188,14 +198,11 @@ def generate_inputs(
value=ast.Call(
func=ast.Name(id="Field", ctx=ast.Load()),
args=[],
keywords=[
ast.keyword(
arg="alias", value=ast.Constant(value=value_key)
)
],
keywords=keywords,
),
simple=1,
)

else:
assign = ast.AnnAssign(
target=ast.Name(value_key, ctx=ast.Store()),
Expand All @@ -208,6 +215,7 @@ def generate_inputs(
is_optional=True,
),
simple=1,
value=ast.Constant(None) if not isinstance(value.type, GraphQLNonNull) else None,
)

potential_comment = (
Expand Down

0 comments on commit 1df9492

Please sign in to comment.