Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 2, 2024
1 parent 599ab11 commit d79cf06
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 61 deletions.
104 changes: 61 additions & 43 deletions crosstl/src/backend/Vulkan/VulkanAst.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
class ASTNode:
pass


class TernaryOpNode:
def __init__(self, condition, true_expr, false_expr):
self.condition = condition
Expand All @@ -9,7 +10,8 @@ def __init__(self, condition, true_expr, false_expr):

def __repr__(self):
return f"TernaryOpNode(condition={self.condition}, true_expr={self.true_expr}, false_expr={self.false_expr})"



class ShaderNode:
def __init__(
self,
Expand All @@ -18,14 +20,15 @@ def __init__(
shader_stages,
functions,
):
self.spirv_version = spirv_version
self.descriptor_sets = descriptor_sets
self.shader_stages = shader_stages
self.functions = functions
self.spirv_version = spirv_version
self.descriptor_sets = descriptor_sets
self.shader_stages = shader_stages
self.functions = functions

def __repr__(self):
return f"ShaderNode(spirv_version={self.spirv_version}, descriptor_sets={self.descriptor_sets}, shader_stages={self.shader_stages}, functions={self.functions})"



class IfNode(ASTNode):
def __init__(self, condition, if_body, else_body=None):
self.condition = condition
Expand All @@ -34,7 +37,8 @@ def __init__(self, condition, if_body, else_body=None):

def __repr__(self):
return f"IfNode(condition={self.condition}, if_body={self.if_body}, else_body={self.else_body})"



class ForNode(ASTNode):
def __init__(self, init, condition, update, body):
self.init = init
Expand All @@ -44,14 +48,16 @@ def __init__(self, init, condition, update, body):

def __repr__(self):
return f"ForNode(init={self.init}, condition={self.condition}, update={self.update}, body={self.body})"



class ReturnNode(ASTNode):
def __init__(self, value):
self.value = value

def __repr__(self):
return f"ReturnNode(value={self.value})"



class FunctionCallNode(ASTNode):
def __init__(self, name, args):
self.name = name
Expand All @@ -60,6 +66,7 @@ def __init__(self, name, args):
def __repr__(self):
return f"FunctionCallNode(name={self.name}, args={self.args})"


class BinaryOpNode(ASTNode):
def __init__(self, left, op, right):
self.left = left
Expand All @@ -68,7 +75,8 @@ def __init__(self, left, op, right):

def __repr__(self):
return f"BinaryOpNode(left={self.left}, op={self.op}, right={self.right})"



class UnaryOpNode(ASTNode):
def __init__(self, op, operand):
self.op = op
Expand All @@ -80,109 +88,121 @@ def __repr__(self):
def __str__(self):
return f"({self.op}{self.operand})"


class DescriptorSetNode(ASTNode):
def __init__(self, set_number, bindings):
self.set_number = set_number
self.bindings = bindings
self.set_number = set_number
self.bindings = bindings

def __repr__(self):
return f"DescriptorSetNode(set_number={self.set_number}, bindings={self.bindings})"

return (
f"DescriptorSetNode(set_number={self.set_number}, bindings={self.bindings})"
)


class LayoutNode(ASTNode):
def __init__(self, sets, push_constants):
self.sets = sets
self.push_constants = push_constants
self.sets = sets
self.push_constants = push_constants

def __repr__(self):
return f"LayoutNode(sets={self.sets}, push_constants={self.push_constants})"


class ShaderStageNode(ASTNode):
def __init__(self, stage, entry_point):
self.stage = stage
self.entry_point = entry_point
self.stage = stage
self.entry_point = entry_point

def __repr__(self):
return f"ShaderStageNode(stage={self.stage}, entry_point={self.entry_point})"



class PushConstantNode(ASTNode):
def __init__(self, size, values):
self.size = size
self.values = values
self.size = size
self.values = values

def __repr__(self):
return f"PushConstantNode(size={self.size}, values={self.values})"



class StructNode(ASTNode):
def __init__(self, name, members):
self.name = name
self.members = members
self.name = name
self.members = members

def __repr__(self):
return f"StructNode(name={self.name}, members={self.members})"


class FunctionNode(ASTNode):
def __init__(self, name, return_type, parameters, body):
self.name = name
self.return_type = return_type
self.parameters = parameters
self.body = body
self.name = name
self.return_type = return_type
self.parameters = parameters
self.body = body

def __repr__(self):
return f"FunctionNode(name={self.name}, return_type={self.return_type}, parameters={self.parameters}, body={self.body})"


class VariableNode(ASTNode):
def __init__(self, name, var_type, initializer=None):
self.name = name
self.var_type = var_type
self.initializer = initializer
self.name = name
self.var_type = var_type
self.initializer = initializer

def __repr__(self):
return f"VariableNode(name={self.name}, var_type={self.var_type}, initializer={self.initializer})"



class VariableDeclarationNode(ASTNode):
def __init__(self, name, var_type, initializer=None):
self.name = name
self.var_type = var_type
self.initializer = initializer
self.name = name
self.var_type = var_type
self.initializer = initializer

def __repr__(self):
return f"VariableDeclarationNode(name={self.name}, var_type={self.var_type}, initializer={self.initializer})"



class SwitchNode(ASTNode):
def __init__(self, expression, cases):
self.expression = expression
self.cases = cases

def __repr__(self):
return f"SwitchNode(expression={self.expression}, cases={self.cases})"



class CaseNode(ASTNode):
def __init__(self, value, body):
self.value = value
self.body = body

def __repr__(self):
return f"CaseNode(value={self.value}, body={self.body})"



class WhileNode(ASTNode):
def __init__(self, condition, body):
self.condition = condition
self.body = body

def __repr__(self):
return f"WhileNode(condition={self.condition}, body={self.body})"



class DoWhileNode(ASTNode):
def __init__(self, body, condition):
self.body = body
self.condition = condition

def __repr__(self):
return f"DoWhileNode(body={self.body}, condition={self.condition})"



class AssignmentNode(ASTNode):
def __init__(self, left, right, operator="="):
self.left = left
Expand All @@ -191,5 +211,3 @@ def __init__(self, left, right, operator="="):

def __repr__(self):
return f"AssignmentNode(left={self.left}, operator='{self.operator}', right={self.right})"


11 changes: 6 additions & 5 deletions crosstl/src/backend/Vulkan/VulkanLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
("COMMENT_SINGLE", r"//.*"),
("COMMENT_MULTI", r"/\*[\s\S]*?\*/"),
("WHITESPACE", r"\s+"),
("SEMANTIC", r":\w+"),
("IDENTIFIER", r"[a-zA-Z_][a-zA-Z0-9_]*"),
("SEMANTIC", r":\w+"),
("IDENTIFIER", r"[a-zA-Z_][a-zA-Z0-9_]*"),
("NUMBER", r"\d+(\.\d*)?|\.\d+"),
("SEMICOLON", r";"),
("LBRACE", r"\{"),
Expand All @@ -14,7 +14,7 @@
("RPAREN", r"\)"),
("COMMA", r","),
("DOT", r"\."),
("PLUS_EQUALS", r"\+="),
("PLUS_EQUALS", r"\+="),
("MINUS_EQUALS", r"-="),
("MULTIPLY_EQUALS", r"\*="),
("DIVIDE_EQUALS", r"/="),
Expand All @@ -23,7 +23,7 @@
("MINUS", r"-"),
("MULTIPLY", r"\*"),
("DIVIDE", r"/"),
("LESS_EQUAL", r"<="),
("LESS_EQUAL", r"<="),
("GREATER_EQUAL", r">="),
("NOT_EQUAL", r"!="),
("LESS_THAN", r"<"),
Expand Down Expand Up @@ -106,6 +106,7 @@
"atomic_uint": "ATOMICUINT",
}


class VulkanLexer:
def __init__(self, code):
self.code = code
Expand Down Expand Up @@ -143,4 +144,4 @@ def tokenize(self):
f"Illegal character '{unmatched_char}' at position {pos}\n{highlighted_code}"
)

self.tokens.append(("EOF", None))
self.tokens.append(("EOF", None))
24 changes: 11 additions & 13 deletions crosstl/src/backend/Vulkan/VulkanParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_layout(self):
if self.current_token[0] == "COMMA":
self.eat("COMMA")
self.eat("RPAREN")
self.eat("IDENTIFIER")
self.eat("IDENTIFIER")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
Expand Down Expand Up @@ -138,27 +138,27 @@ def parse_statement(self):
return self.parse_for_statement()
else:
return self.parse_expression_statement()

def parse_if_statement(self):
self.eat("IF")
self.eat("LPAREN")
condition = self.parse_expression()
condition = self.parse_expression()
self.eat("RPAREN")
if_body = self.parse_block()
if_body = self.parse_block()
else_body = None
if self.current_token[0] == "ELSE":
self.eat("ELSE")
else_body = self.parse_block()
else_body = self.parse_block()
return IfNode(condition, if_body, else_body)

def parse_for_statement(self):
self.eat("FOR")
self.eat("LPAREN")
initialization = self.parse_expression_statement()
initialization = self.parse_expression_statement()
condition = self.parse_expression()
self.eat("SEMICOLON")
self.eat("SEMICOLON")
increment = self.parse_expression()
self.eat("RPAREN")
self.eat("RPAREN")
body = self.parse_block()
return ForNode(initialization, condition, increment, body)

Expand Down Expand Up @@ -194,7 +194,7 @@ def parse_primary(self):
return value
else:
raise SyntaxError(f"Unexpected token: {self.current_token[0]}")

def parse_while_statement(self):
self.eat("WHILE")
self.eat("LPAREN")
Expand All @@ -212,7 +212,7 @@ def parse_do_while_statement(self):
self.eat("RPAREN")
self.eat("SEMICOLON")
return DoWhileNode(condition, body)

def parse_switch_statement(self):
self.eat("SWITCH")
self.eat("LPAREN")
Expand All @@ -238,5 +238,3 @@ def parse_case_statement(self):
while self.current_token[0] not in ["CASE", "DEFAULT", "RBRACE"]:
statements.append(self.parse_statement())
return CaseNode(value, statements)


0 comments on commit d79cf06

Please sign in to comment.