Skip to content

Commit

Permalink
alphatensor bril benchmark, tested against numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Bennett Wineholt committed Sep 8, 2023
1 parent 785e829 commit 7c2ad91
Show file tree
Hide file tree
Showing 9 changed files with 32,051 additions and 0 deletions.
106 changes: 106 additions & 0 deletions benchmarks/float/atmulti4x4/at_np_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# test alpha tensor 4x4 against numpy matmul for correctness

import numpy as np

def main():
# [A,B]
matrices = [
[[8, 3, 2, 4],[2, 7, 4, 5],[0, 1, 2, 3],[0, 1, 2, 3]],
[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]],
]
A = np.array(matrices[0])
B = np.array(matrices[1])
res_np = np.matmul(A,B)
res_at = at4x4(A,B)
if np.array_equal(res_np, res_at):
print("Arrays equal")
print(res_np)
else:
print("Arrays are not equal")
print("np")
print(res_np)
print("at")
print(res_at)



def at4x4(A,B):
'''Alphatensor Multiply 4-by-4 matrix A with 4-by-4 matrix B, return result C'''

a11, a12, a13, a14, a21, a22, a23, a24, a31, a32, a33, a34, a41, a42, a43, a44 = A.ravel()
b11, b12, b13, b14, b21, b22, b23, b24, b31, b32, b33, b34, b41, b42, b43, b44 = B.ravel()
h1 = (a11 + a31) * (b11 + b31)
h2 = (a11 - a13 + a31) * (b11 - b13 + b31)
h3 = - a13 * (b11 - b13 + b31 - b33)
h4 = - a33 * - b33
h5 = - a31 * - b13
h6 = (a11 - a13 + a31 - a33) * - b31
h7 = (- a21 + a22 - a23 - a24) * (- b21 + b22 - b23 - b24)
h8 = (- a21 + a22 - a23 - a24 - a41 + a42) * (- b21 + b22 - b23 - b24 - b41 + b42)
h9 = (a11 - a13) * (b11 - b13)
h10 = (- a21 + a22 - a41 + a42) * (- b21 + b22 - b41 + b42)
h11 = (a41 - a42) * (- b23 - b24)
h12 = (- a21 + a22 - a23 - a24 - a41 + a42 - a43 - a44) * (b41 - b42)
h13 = (- a23 - a24) * (- b21 + b22 - b23 - b24 - b41 + b42 - b43 - b44)
h14 = (a11 - a12 + a21 - a22) * (- b12 - b14)
h15 = (- a12 - a14) * - b21
h16 = (a12 + a14 - a21 + a22 + a23 + a24) * (b12 + b14 - b21 + b22 + b23 + b24)
h17 = (a12 + a14 - a21 + a22 + a23 + a24 + a32 + a41 - a42) * (b12 + b14 - b21 + b22 + b23 + b24 + b32 + b41 - b42)
h18 = (a12 - a21 + a22 + a32 + a41 - a42) * (b12 - b21 + b22 + b32 + b41 - b42)
h19 = (a14 + a23 + a24) * (b12 + b14 - b21 + b22 + b23 + b24 + b32 + b34 + b41 - b42 - b43 - b44)
h20 = (a12 + a14 - a21 + a22 + a23 + a24 + a32 + a34 + a41 - a42 - a43 - a44) * (b32 + b41 - b42)
h21 = (a32 + a41 - a42) * (b14 + b23 + b24)
h22 = (a12 + a14 + a22 + a24) * (b12 + b14 + b22 + b24)
h23 = (a12 + a14 + a22 + a24 + a32 - a42) * (b12 + b14 + b22 + b24 + b32 - b42)
h24 = (a14 + a24) * (b12 + b14 + b22 + b24 + b32 + b34 - b42 - b44)
h25 = (a12 + a14 + a22 + a24 + a32 + a34 - a42 - a44) * (b32 - b42)
h26 = (a32 - a42) * (b14 + b24)
h27 = (a34 - a44) * (b34 - b44)
h28 = (a34 - a43 - a44) * (b34 - b43 - b44)
h29 = (a14 + a34) * - b43
h30 = (a13 + a14 + a23 + a24 + a33 + a34 - a43 - a44) * (b14 + b34)
h31 = (a11 - a12 - a13 - a14 + a21 - a22 - a23 - a24 + a31 - a32 - a33 - a34 - a41 + a42 + a43 + a44) * b14
h32 = - a43 * (b13 + b14 + b23 + b24 + b33 + b34 - b43 - b44)
h33 = a14 * (- b21 + b41)
h34 = (a14 - a32) * (- b21 + b41 - b43)
h35 = (a13 + a14 + a23 + a24 - a31 + a32 + a33 + a34 + a41 - a42 - a43 - a44) * (b14 - b32)
h36 = (- a31 + a32 + a33 + a34 + a41 - a42 - a43 - a44) * b32
h37 = (- a12 - a32) * - b23
h38 = (a32 + a34) * (b41 - b43)
h39 = (- a13 - a14 - a23 - a24) * (b32 + b34)
h40 = a32 * (- b21 + b23 + b41 - b43)
h41 = - a21 * (b11 - b12 + b21 - b22)
h42 = (- a21 + a41) * (b11 - b12 - b13 - b14 + b21 - b22 - b23 - b24 + b31 - b32 - b33 - b34 - b41 + b42 + b43 + b44)
h43 = (- a21 + a41 - a43) * (b13 + b14 + b23 + b24 - b31 + b32 + b33 + b34 + b41 - b42 - b43 - b44)
h44 = (a12 + a22 + a32 - a42) * (b12 + b22 + b32 - b42)
h45 = (- a21 + a23 + a41 - a43) * (- b31 + b32 + b33 + b34 + b41 - b42 - b43 - b44)
h46 = (- a31 + a32 + a41 - a42) * (- b12 - b32)
h47 = (a41 - a43) * (- b13 - b14 - b23 - b24)
h48 = (- a43 - a44) * (- b43 - b44)
h49 = - a23 * (- b31 + b32 + b41 - b42)
c11 = h1 - h2 - h5 + h9 + h15 + h33
c12 = - h15 - h16 + h17 - h18 - h21 + h22 - h23 + h26 - h33 - h41 + h44 + h49
c13 = h2 + h5 + h6 - h9 - h29 - h33 + h34 + h38
c14 = - h16 + h17 - h20 - h21 + h22 - h23 + h25 + h26 - h29 - h32 - h33 + h34 + h38 - h41 + h42 + h43
c21 = - h7 + h8 - h10 + h11 - h14 + h15 + h16 - h17 + h18 + h21 - h31 + h33 - h35 - h36
c22 = h7 - h8 + h10 - h11 - h15 - h16 + h17 - h18 - h21 + h22 - h23 + h26 - h33 + h44
c23 = - h7 + h8 + h11 + h12 - h16 + h17 - h20 - h21 - h29 - h33 + h34 + h36 + h38 + h46
c24 = - h7 + h8 + h11 + h12 - h16 + h17 - h20 - h21 + h22 - h23 + h25 + h26 - h29 - h33 + h34 + h38
c31 = h1 - h2 + h3 - h5 + h33 - h34 + h37 - h40
c32 = h17 - h18 - h19 - h21 - h23 + h24 + h26 - h33 + h34 - h37 + h40 - h43 + h44 + h45 - h47 + h49
c33 = h4 + h5 - h29 - h33 + h34 + h40
c34 = - h21 + h26 - h27 + h28 - h29 - h32 - h33 + h34 + h40 - h47
c41 = h8 - h10 + h11 - h13 + h17 - h18 - h19 - h21 + h31 - h33 + h34 + h35 + h36 - h37 - h39 + h40
c42 = - h8 + h10 - h11 + h13 - h17 + h18 + h19 + h21 + h23 - h24 - h26 + h33 - h34 + h37 - h40 - h44
c43 = h11 + h21 - h28 + h29 + h30 + h33 - h34 - h35 - h36 + h39 - h40 + h48
c44 = h11 + h21 - h26 + h27 - h28 + h29 + h33 - h34 - h40 + h48
C = np.array([c11, c12, c13, c14, c21, c22, c23, c24, c31, c32, c33, c34, c41, c42, c43, c44]).reshape(4, 4).T
return C




if __name__ == "__main__":
main()


236 changes: 236 additions & 0 deletions benchmarks/float/atmulti4x4/at_partial_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# test alpha tensor 4x4 against numpy matmul for correctness

import numpy as np

def main():
# [A,B]
matrices = [
[[8, 3, 2, 4],[2, 7, 4, 5],[0, 1, 2, 3],[0, 1, 2, 3]],
[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]],
]
A = np.array(matrices[0])
B = np.array(matrices[1])
#res_np = np.matmul(A,B)
res_at = at4x4(A,B)
#if np.array_equal(res_np, res_at):
# print("Arrays equal")
# print(res_np)
#else:
# print("Arrays are not equal")
# print("np")
# print(res_np)
# print("at")
# print(res_at)



def at4x4(A,B):
'''Alphatensor Multiply 4-by-4 matrix A with 4-by-4 matrix B, return result C'''

a11, a12, a13, a14, a21, a22, a23, a24, a31, a32, a33, a34, a41, a42, a43, a44 = A.ravel()
b11, b12, b13, b14, b21, b22, b23, b24, b31, b32, b33, b34, b41, b42, b43, b44 = B.ravel()
h1 = (a11 + a31) * (b11 + b31)
print("h1")
print(h1)
h2 = (a11 - a13 + a31) * (b11 - b13 + b31)
print("h2")
print(h2)
h3 = - a13 * (b11 - b13 + b31 - b33)
print("h3")
print(h3)
h4 = - a33 * - b33
print("h4")
print(h4)
h5 = - a31 * - b13
print("h5")
print(h5)
h6 = (a11 - a13 + a31 - a33) * - b31
print("h6")
print(h6)
h7 = (- a21 + a22 - a23 - a24) * (- b21 + b22 - b23 - b24)
print("h7")
print(h7)
h8 = (- a21 + a22 - a23 - a24 - a41 + a42) * (- b21 + b22 - b23 - b24 - b41 + b42)
print("h8")
print(h8)
h9 = (a11 - a13) * (b11 - b13)
print("h9")
print(h9)
h10 = (- a21 + a22 - a41 + a42) * (- b21 + b22 - b41 + b42)
print("h10")
print(h10)
h11 = (a41 - a42) * (- b23 - b24)
print("h11")
print(h11)
h12 = (- a21 + a22 - a23 - a24 - a41 + a42 - a43 - a44) * (b41 - b42)
print("h12")
print(h12)
h13 = (- a23 - a24) * (- b21 + b22 - b23 - b24 - b41 + b42 - b43 - b44)
print("h13")
print(h13)
h14 = (a11 - a12 + a21 - a22) * (- b12 - b14)
print("h14")
print(h14)
h15 = (- a12 - a14) * - b21
print("h15")
print(h15)
h16 = (a12 + a14 - a21 + a22 + a23 + a24) * (b12 + b14 - b21 + b22 + b23 + b24)
print("h16")
print(h16)
h17 = (a12 + a14 - a21 + a22 + a23 + a24 + a32 + a41 - a42) * (b12 + b14 - b21 + b22 + b23 + b24 + b32 + b41 - b42)
print("h17")
print(h17)
h18 = (a12 - a21 + a22 + a32 + a41 - a42) * (b12 - b21 + b22 + b32 + b41 - b42)
print("h18")
print(h18)
h19 = (a14 + a23 + a24) * (b12 + b14 - b21 + b22 + b23 + b24 + b32 + b34 + b41 - b42 - b43 - b44)
print("h19")
print(h19)
h20 = (a12 + a14 - a21 + a22 + a23 + a24 + a32 + a34 + a41 - a42 - a43 - a44) * (b32 + b41 - b42)
print("h20")
print(h20)
h21 = (a32 + a41 - a42) * (b14 + b23 + b24)
print("h21")
print(h21)
h22 = (a12 + a14 + a22 + a24) * (b12 + b14 + b22 + b24)
print("h22")
print(h22)
h23 = (a12 + a14 + a22 + a24 + a32 - a42) * (b12 + b14 + b22 + b24 + b32 - b42)
print("h23")
print(h23)
h24 = (a14 + a24) * (b12 + b14 + b22 + b24 + b32 + b34 - b42 - b44)
print("h24")
print(h24)
h25 = (a12 + a14 + a22 + a24 + a32 + a34 - a42 - a44) * (b32 - b42)
print("h25")
print(h25)
h26 = (a32 - a42) * (b14 + b24)
print("h26")
print(h26)
h27 = (a34 - a44) * (b34 - b44)
print("h27")
print(h27)
h28 = (a34 - a43 - a44) * (b34 - b43 - b44)
print("h28")
print(h28)
h29 = (a14 + a34) * - b43
print("h29")
print(h29)
h30 = (a13 + a14 + a23 + a24 + a33 + a34 - a43 - a44) * (b14 + b34)
print("h30")
print(h30)
h31 = (a11 - a12 - a13 - a14 + a21 - a22 - a23 - a24 + a31 - a32 - a33 - a34 - a41 + a42 + a43 + a44) * b14
print("h31")
print(h31)
h32 = - a43 * (b13 + b14 + b23 + b24 + b33 + b34 - b43 - b44)
print("h32")
print(h32)
h33 = a14 * (- b21 + b41)
print("h33")
print(h33)
h34 = (a14 - a32) * (- b21 + b41 - b43)
print("h34")
print(h34)
h35 = (a13 + a14 + a23 + a24 - a31 + a32 + a33 + a34 + a41 - a42 - a43 - a44) * (b14 - b32)
print("h35")
print(h35)
h36 = (- a31 + a32 + a33 + a34 + a41 - a42 - a43 - a44) * b32
print("h36")
print(h36)
h37 = (- a12 - a32) * - b23
print("h37")
print(h37)
h38 = (a32 + a34) * (b41 - b43)
print("h38")
print(h38)
h39 = (- a13 - a14 - a23 - a24) * (b32 + b34)
print("h39")
print(h39)
h40 = a32 * (- b21 + b23 + b41 - b43)
print("h40")
print(h40)
h41 = - a21 * (b11 - b12 + b21 - b22)
print("h41")
print(h41)
h42 = (- a21 + a41) * (b11 - b12 - b13 - b14 + b21 - b22 - b23 - b24 + b31 - b32 - b33 - b34 - b41 + b42 + b43 + b44)
print("h42")
print(h42)
h43 = (- a21 + a41 - a43) * (b13 + b14 + b23 + b24 - b31 + b32 + b33 + b34 + b41 - b42 - b43 - b44)
print("h43")
print(h43)
h44 = (a12 + a22 + a32 - a42) * (b12 + b22 + b32 - b42)
print("h44")
print(h44)
h45 = (- a21 + a23 + a41 - a43) * (- b31 + b32 + b33 + b34 + b41 - b42 - b43 - b44)
print("h45")
print(h45)
h46 = (- a31 + a32 + a41 - a42) * (- b12 - b32)
print("h46")
print(h46)
h47 = (a41 - a43) * (- b13 - b14 - b23 - b24)
print("h47")
print(h47)
h48 = (- a43 - a44) * (- b43 - b44)
print("h48")
print(h48)
h49 = - a23 * (- b31 + b32 + b41 - b42)
print("h49")
print(h49)
c11 = h1 - h2 - h5 + h9 + h15 + h33
print("c11")
print(c11)
c12 = - h15 - h16 + h17 - h18 - h21 + h22 - h23 + h26 - h33 - h41 + h44 + h49
print("c12")
print(c12)
c13 = h2 + h5 + h6 - h9 - h29 - h33 + h34 + h38
print("c13")
print(c13)
c14 = - h16 + h17 - h20 - h21 + h22 - h23 + h25 + h26 - h29 - h32 - h33 + h34 + h38 - h41 + h42 + h43
print("c14")
print(c14)
c21 = - h7 + h8 - h10 + h11 - h14 + h15 + h16 - h17 + h18 + h21 - h31 + h33 - h35 - h36
print("c21")
print(c21)
c22 = h7 - h8 + h10 - h11 - h15 - h16 + h17 - h18 - h21 + h22 - h23 + h26 - h33 + h44
print("c22")
print(c22)
c23 = - h7 + h8 + h11 + h12 - h16 + h17 - h20 - h21 - h29 - h33 + h34 + h36 + h38 + h46
print("c23")
print(c23)
c24 = - h7 + h8 + h11 + h12 - h16 + h17 - h20 - h21 + h22 - h23 + h25 + h26 - h29 - h33 + h34 + h38
print("c24")
print(c24)
c31 = h1 - h2 + h3 - h5 + h33 - h34 + h37 - h40
print("c31")
print(c31)
c32 = h17 - h18 - h19 - h21 - h23 + h24 + h26 - h33 + h34 - h37 + h40 - h43 + h44 + h45 - h47 + h49
print("c32")
print(c32)
c33 = h4 + h5 - h29 - h33 + h34 + h40
print("c33")
print(c33)
c34 = - h21 + h26 - h27 + h28 - h29 - h32 - h33 + h34 + h40 - h47
print("c34")
print(c34)
c41 = h8 - h10 + h11 - h13 + h17 - h18 - h19 - h21 + h31 - h33 + h34 + h35 + h36 - h37 - h39 + h40
print("c41")
print(c41)
c42 = - h8 + h10 - h11 + h13 - h17 + h18 + h19 + h21 + h23 - h24 - h26 + h33 - h34 + h37 - h40 - h44
print("c42")
print(c42)
c43 = h11 + h21 - h28 + h29 + h30 + h33 - h34 - h35 - h36 + h39 - h40 + h48
print("c43")
print(c43)
c44 = h11 + h21 - h26 + h27 - h28 + h29 + h33 - h34 - h40 + h48
print("c44")
print(c44)
C = np.array([c11, c12, c13, c14, c21, c22, c23, c24, c31, c32, c33, c34, c41, c42, c43, c44]).reshape(4, 4).T
return C




if __name__ == "__main__":
main()


Loading

0 comments on commit 7c2ad91

Please sign in to comment.