Skip to content

Commit

Permalink
Convert loopy's expression nodes to expr_dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 3, 2024
1 parent 251aec7 commit 63c9bc0
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 137 deletions.
18 changes: 4 additions & 14 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
"""


from typing import ClassVar, Tuple

import numpy as np

from pymbolic import var
from pymbolic.primitives import expr_dataclass

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand Down Expand Up @@ -276,14 +275,9 @@ def __call__(self, dtype, operand1, operand2, callables_table, target):

# {{{ base class for symbolic reduction ops

@expr_dataclass()
class ReductionOpFunction(FunctionIdentifier):
init_arg_names: ClassVar[Tuple[str, ...]] = ("reduction_op",)

def __init__(self, reduction_op):
self.reduction_op = reduction_op

def __getinitargs__(self):
return (self.reduction_op,)
reduction_op: ReductionOperation

@property
def name(self):
Expand All @@ -295,11 +289,6 @@ def copy(self, reduction_op=None):

return type(self)(reduction_op)

hash_fields = (
"reduction_op",)

update_persistent_hash = update_persistent_hash

# }}}


Expand Down Expand Up @@ -413,6 +402,7 @@ class SegmentedProductReductionOperation(_SegmentedScalarReductionOperation):

# {{{ argmin/argmax

@expr_dataclass()
class ArgExtOp(ReductionOpFunction):
pass

Expand Down
Loading

0 comments on commit 63c9bc0

Please sign in to comment.