diff --git a/cxxheaderparser/parser.py b/cxxheaderparser/parser.py index cb74124..a8cdf95 100644 --- a/cxxheaderparser/parser.py +++ b/cxxheaderparser/parser.py @@ -1912,12 +1912,10 @@ def _parse_fn_end(self, fn: Function) -> None: else: rtok = self.lex.token_if("requires") if rtok: - fn_template = fn.template - if fn_template is None: + # requires on a function must always be accompanied by a template + if fn.template is None: raise self._parse_error(rtok) - elif isinstance(fn_template, list): - fn_template = fn_template[0] - fn_template.raw_requires_post = self._parse_requires(rtok) + fn.raw_requires = self._parse_requires(rtok) if self.lex.token_if("ARROW"): self._parse_trailing_return_type(fn) @@ -1983,12 +1981,7 @@ def _parse_method_end(self, method: Method) -> None: toks = self._consume_balanced_tokens(otok)[1:-1] method.noexcept = self._create_value(toks) elif tok_value == "requires": - method_template = method.template - if method_template is None: - raise self._parse_error(tok) - elif isinstance(method_template, list): - method_template = method_template[0] - method_template.raw_requires_post = self._parse_requires(tok) + method.raw_requires = self._parse_requires(tok) else: self.lex.return_token(tok) break diff --git a/cxxheaderparser/types.py b/cxxheaderparser/types.py index 5ad2554..d15b76f 100644 --- a/cxxheaderparser/types.py +++ b/cxxheaderparser/types.py @@ -526,9 +526,6 @@ class Foo {}; #: template requires ... raw_requires_pre: typing.Optional[Value] = None - #: template int main() requires ... - raw_requires_post: typing.Optional[Value] = None - #: If no template, this is None. This is a TemplateDecl if this there is a single #: declaration: @@ -730,6 +727,13 @@ class Function: #: is the string "conversion" and the full Type is found in return_type operator: typing.Optional[str] = None + #: A requires constraint following the function declaration. If you need the + #: prior, look at TemplateDecl.raw_requires_pre. At the moment this is just + #: a raw value, if we interpret it in the future this will change. + #: + #: template int main() requires ... + raw_requires: typing.Optional[Value] = None + @dataclass class Method(Function): diff --git a/tests/test_concepts.py b/tests/test_concepts.py index 2be5a8c..85ceec8 100644 --- a/tests/test_concepts.py +++ b/tests/test_concepts.py @@ -6,6 +6,7 @@ Concept, Function, FundamentalSpecifier, + Method, MoveReference, NameSpecifier, PQName, @@ -495,15 +496,15 @@ def test_requires_last_elem() -> None: ) ], template=TemplateDecl( - params=[TemplateTypeParam(typekey="typename", name="T")], - raw_requires_post=Value( - tokens=[ - Token(value="Eq"), - Token(value="<"), - Token(value="T"), - Token(value=">"), - ] - ), + params=[TemplateTypeParam(typekey="typename", name="T")] + ), + raw_requires=Value( + tokens=[ + Token(value="Eq"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] ), ) ] @@ -752,14 +753,14 @@ def test_requires_both() -> None: Token(value=">"), ] ), - raw_requires_post=Value( - tokens=[ - Token(value="Subtractable"), - Token(value="<"), - Token(value="T"), - Token(value=">"), - ] - ), + ), + raw_requires=Value( + tokens=[ + Token(value="Subtractable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + ] ), ) ] @@ -791,20 +792,86 @@ def test_requires_paren() -> None: ) ], template=TemplateDecl( - params=[TemplateTypeParam(typekey="class", name="T")], - raw_requires_post=Value( - tokens=[ - Token(value="("), - Token(value="is_purrable"), - Token(value="<"), - Token(value="T"), - Token(value=">"), - Token(value="("), - Token(value=")"), - Token(value=")"), - ] + params=[TemplateTypeParam(typekey="class", name="T")] + ), + raw_requires=Value( + tokens=[ + Token(value="("), + Token(value="is_purrable"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + Token(value="("), + Token(value=")"), + Token(value=")"), + ] + ), + ) + ] + ) + ) + + +def test_non_template_requires() -> None: + content = """ + // clang-format off + + template + struct Payload + { + constexpr Payload(T v) + requires(std::is_pod_v) + : Value(v) + { + } + }; + """ + data = parse_string(content, cleandoc=True) + + assert data == ParsedData( + namespace=NamespaceScope( + classes=[ + ClassScope( + class_decl=ClassDecl( + typename=PQName( + segments=[NameSpecifier(name="Payload")], classkey="struct" + ), + template=TemplateDecl( + params=[TemplateTypeParam(typekey="class", name="T")] ), ), + methods=[ + Method( + return_type=None, + name=PQName(segments=[NameSpecifier(name="Payload")]), + parameters=[ + Parameter( + type=Type( + typename=PQName( + segments=[NameSpecifier(name="T")] + ) + ), + name="v", + ) + ], + constexpr=True, + has_body=True, + raw_requires=Value( + tokens=[ + Token(value="("), + Token(value="std"), + Token(value="::"), + Token(value="is_pod_v"), + Token(value="<"), + Token(value="T"), + Token(value=">"), + Token(value=")"), + ] + ), + access="public", + constructor=True, + ) + ], ) ] )