-
Notifications
You must be signed in to change notification settings - Fork 237
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bril matmul 4x4 with more matrices, tested
- Loading branch information
Bennett Wineholt
committed
Sep 1, 2023
1 parent
fd91376
commit 785e829
Showing
8 changed files
with
13,317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,357 @@ | ||
from jinja2 import Template | ||
import numpy as np | ||
|
||
matrices = [ | ||
[[8, 3, 2, 4],[2, 7, 4, 5],[0, 1, 2, 3],[0, 1, 2, 3]], | ||
[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]], | ||
[[17, 18, 19, 20],[21, 22, 23, 24],[25, 26, 27, 28],[29, 30, 31, 32]], | ||
[[33, 34, 35, 36],[37, 38, 39, 40],[41, 42, 43, 44],[45, 46, 47, 48]], | ||
[[49, 50, 51, 52],[53, 54, 55, 56],[57, 58, 59, 60],[61, 62, 63, 64]], | ||
[[65, 66, 67, 68],[69, 70, 71, 72],[73, 74, 75, 76],[77, 78, 79, 80]], | ||
[[0.12, 0.45, 0.67, 0.89], [1.23, 4.56, 7.89, 0.12], [3.45, 6.78, 9.01, 2.34], [5.67, 8.90, 1.23, 4.56]], | ||
[[7.12, 8.34, 9.56, 0.78], [1.91, 2.34, 5.67, 8.90], [1.12, 3.45, 6.78, 9.01], [2.34, 5.67, 8.90, 1.23]], | ||
[[4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67], [8.90, 1.23, 4.56, 7.89], [0.12, 3.45, 6.78, 9.01]], | ||
[[2.34, 5.67, 8.90, 1.23], [4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67], [8.90, 1.23, 4.56, 7.89]], | ||
[[0.12, 3.45, 6.78, 9.01], [2.34, 5.67, 8.90, 1.23], [4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67]], | ||
[[8.90, 1.23, 4.56, 7.89], [0.12, 3.45, 6.78, 9.01], [2.34, 5.67, 8.90, 1.23], [4.56, 7.89, 0.12, 3.45]], | ||
[[6.78, 9.01, 2.34, 5.67], [8.90, 1.23, 4.56, 7.89], [0.12, 3.45, 6.78, 9.01], [2.34, 5.67, 8.90, 1.23]], | ||
[[4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67], [8.90, 1.23, 4.56, 7.89], [0.12, 3.45, 6.78, 9.01]], | ||
[[2.34, 5.67, 8.90, 1.23], [4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67], [8.90, 1.23, 4.56, 7.89]], | ||
[[0.12, 3.45, 6.78, 9.01], [2.34, 5.67, 8.90, 1.23], [4.56, 7.89, 0.12, 3.45], [6.78, 9.01, 2.34, 5.67]], | ||
[[9.87, 4.32, 1.09, 7.65], [2.98, 5.43, 8.76, 1.23], [4.56, 7.89, 0.11, 3.44], [6.77, 9.00, 2.33, 5.66]], | ||
[[3.21, 6.54, 9.87, 1.20], [4.53, 7.86, 0.19, 3.52], [6.85, 9.18, 2.51, 5.84], [8.17, 0.50, 4.83, 7.16]], | ||
[[1.49, 4.82, 8.15, 0.48], [3.81, 7.14, 0.47, 3.80], [6.13, 9.46, 2.79, 6.12], [8.45, 1.78, 5.11, 8.44]], | ||
[[0.77, 4.10, 7.43, 0.76], [3.09, 6.42, 9.75, 3.08], [5.41, 8.74, 2.07, 5.40], [7.73, 1.06, 4.39, 7.72]], | ||
[[2.05, 5.38, 8.71, 1.04], [4.37, 7.70, 0.03, 4.36], [6.69, 9.02, 2.35, 6.68], [8.01, 1.34, 4.67, 8.00]], | ||
[[3.33, 6.66, 9.99, 1.32], [4.65, 7.98, 1.31, 4.64], [6.97, 9.30, 2.63, 6.96], [9.29, 1.62, 4.95, 9.28]], | ||
[[1.61, 4.94, 8.27, 0.60], [3.93, 7.26, 0.59, 3.92], [6.25, 9.58, 2.91, 6.24], [8.57, 1.90, 5.23, 8.56]], | ||
[[0.89, 4.22, 7.55, 0.88], [3.21, 6.54, 9.87, 3.20], [5.53, 8.86, 2.19, 5.52], [7.85, 1.18, 4.51, 7.84]], | ||
[[2.17, 5.50, 8.83, 1.16], [4.49, 7.82, 1.15, 4.48], [6.81, 9.14, 2.47, 6.80], [9.13, 1.46, 4.79, 9.12]], | ||
[[3.45, 6.78, 0.11, 3.44], [5.77, 9.10, 2.43, 5.76], [8.09, 0.42, 4.75, 8.08], [0.41, 3.74, 7.07, 0.40]], | ||
[ | ||
[0.28457546132827105101, 0.88627381330549304117, 0.40320221937108335908, 0.30618816153541617009], | ||
[0.24348777186925174565, 0.50072784147668758514, 0.48842330783646209502, 0.19709424381285589600], | ||
[0.33713629192588301375, 0.91693290018374307149, 0.08621385671306092124, 0.78770454526290512032], | ||
[0.99145116252088472120, 0.30059001519913997047, 0.42804717767327410405, 0.64424647755355002321], | ||
], | ||
[ | ||
[0.48030881806754244234, 0.90259649449860590575, 0.67413645718059722611, 0.20232363120920468513], | ||
[0.36622239101634967984, 0.89406270127513465251, 0.95325893823456808729, 0.06421372832690142030], | ||
[0.24466327506236373868, 0.79081601532482548311, 0.75953013298928062635, 0.88266074171091057909], | ||
[0.41575699163745793996, 0.23821703956939355162, 0.58730964538244068152, 0.93577202842527273940], | ||
], | ||
[ | ||
[0.43989503735275387042, 0.86580910933685562014, 0.92210690549369112023, 0.82728956316327195708], | ||
[0.33625499552846527251, 0.09078152680887840997, 0.26865781289550894062, 0.85192843333728907051], | ||
[0.46296983660445939490, 0.05022771089161170988, 0.38981476933462932966, 0.67712015876507769541], | ||
[0.11654404309526489314, 0.35954340499047654500, 0.93342708656642436882, 0.46231808668802909512], | ||
], | ||
[ | ||
[0.36831544802535831629, 0.53198852149162323411, 0.01728710620287032818, 0.21467245598830841935], | ||
[0.16196980816140538195, 0.95281720171188366564, 0.33964868821846072588, 0.47546474995776200068], | ||
[0.09678521232130721241, 0.77746869576090238407, 0.63808430124154613683, 0.53161913932088511459], | ||
[0.59829189622507294999, 0.02930590131640231286, 0.80741837880772915348, 0.92115142723527643209], | ||
], | ||
[ | ||
[0.13573288328419841342, 0.24368095411816059759, 0.32119334495243034855, 0.26450568384973138780], | ||
[0.48174270084699799543, 0.04291930261229626176, 0.47157097095133587716, 0.35838537690657223944], | ||
[0.77553803969512746797, 0.92408995983327890666, 0.63204407523033578897, 0.68803668226751812931], | ||
[0.17434141787231605125, 0.86691470925322711150, 0.32403871366539904741, 0.54071477366476827786], | ||
], | ||
[ | ||
[0.14753592794468339822, 0.37806407892282123395, 0.25802388694408040504, 0.15776691010362897671], | ||
[0.75505246838259265640, 0.25212892065216507831, 0.62750681288657339518, 0.16400453516603197279], | ||
[0.66443125157348081888, 0.40209692978091393645, 0.30590717432995662151, 0.04215241442327544164], | ||
[0.63819395843850812433, 0.40249830989915769131, 0.11770185928804044462, 0.64643350018143419522], | ||
], | ||
[ | ||
[0.35923563112553247301, 0.68424507553533164828, 0.75062910406187488555, 0.70337448879700925630], | ||
[0.98750268901617577200, 0.10553122798450743913, 0.51443500172643952251, 0.68988281040887666773], | ||
[0.12312974174209968814, 0.12200042713941117167, 0.88343121075045694113, 0.48611329760158766833], | ||
[0.15648740315474041207, 0.09424294560749713057, 0.37418994981271902489, 0.35703367974888411407], | ||
], | ||
[ | ||
[0.40681021366070896361, 0.78620457480967287367, 0.68495447465686920552, 0.83102289006824336948], | ||
[0.93123388778495164164, 0.44595250556622301197, 0.15935476842946916243, 0.17258688305065117419], | ||
[0.75033536303401315859, 0.32787262043919784826, 0.39514666005546550398, 0.54373254260484937816], | ||
[0.78765577262333230646, 0.88051572263529320761, 0.10630017474477526651, 0.73095618076293988885], | ||
], | ||
[ | ||
[0.50792803247678042222, 0.56827175626472137271, 0.12678695892539348922, 0.55943568334025495226], | ||
[0.78925031930876643482, 0.52617888485160624334, 0.79965982298913795834, 0.35414076210837253100], | ||
[0.26044776325265950323, 0.23478698236881204164, 0.76698617153691495130, 0.39335203319833256241], | ||
[0.74837189726470898510, 0.03186591535360372429, 0.34182671149417553913, 0.61461244418348870422], | ||
], | ||
[ | ||
[0.74258329256046085032, 0.41690374497873239346, 0.01987092206255436366, 0.27300830543424725594], | ||
[0.83326637938900682823, 0.25703433984650458921, 0.07093626648498130294, 0.91022274932270630377], | ||
[0.59582503579951373585, 0.66916322732387678585, 0.21309321101146705413, 0.41487854039487981339], | ||
[0.43132289448645944052, 0.26974375677529716100, 0.22308283356038438594, 0.59828569565852096623], | ||
], | ||
[ | ||
[8.0, 3.0, 2.0, 4.0], | ||
[2.0, 7.0, 4.0, 5.0], | ||
[1.0, 1.0, 2.0, 3.0], | ||
[1.0, 1.0, 2.0, 3.0], | ||
], | ||
[ | ||
[1.0, 2.0, 3.0, 4.0], | ||
[5.0, 6.0, 7.0, 8.0], | ||
[9.0, 10.0, 11.0, 12.0], | ||
[13.0, 14.0, 15.0, 16.0], | ||
], | ||
[ | ||
[17.0, 18.0, 19.0, 20.0], | ||
[21.0, 22.0, 23.0, 24.0], | ||
[25.0, 26.0, 27.0, 28.0], | ||
[29.0, 30.0, 31.0, 32.0], | ||
], | ||
[ | ||
[33.0, 34.0, 35.0, 36.0], | ||
[37.0, 38.0, 39.0, 40.0], | ||
[41.0, 42.0, 43.0, 44.0], | ||
[45.0, 46.0, 47.0, 48.0], | ||
], | ||
[ | ||
[49.0, 50.0, 51.0, 52.0], | ||
[53.0, 54.0, 55.0, 56.0], | ||
[57.0, 58.0, 59.0, 60.0], | ||
[61.0, 62.0, 63.0, 64.0], | ||
], | ||
[ | ||
[65.0, 66.0, 67.0, 68.0], | ||
[69.0, 70.0, 71.0, 72.0], | ||
[73.0, 74.0, 75.0, 76.0], | ||
[77.0, 78.0, 79.0, 80.0], | ||
], | ||
[ | ||
[1.0, 1.0, 1.0, 1.0], | ||
[1.2, 4.6, 7.9, 1.0], | ||
[3.5, 6.8, 9.0, 2.3], | ||
[5.7, 8.9, 1.2, 4.6], | ||
], | ||
[ | ||
[7.1, 8.3, 9.6, 1.0], | ||
[1.9, 2.3, 5.7, 8.9], | ||
[1.1, 3.5, 6.8, 9.0], | ||
[2.3, 5.7, 8.9, 1.2], | ||
], | ||
[ | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
[8.9, 1.2, 4.6, 7.9], | ||
[1.0, 3.5, 6.8, 9.0], | ||
], | ||
[ | ||
[2.3, 5.7, 8.9, 1.2], | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
[8.9, 1.2, 4.6, 7.9], | ||
], | ||
[ | ||
[1.0, 3.5, 6.8, 9.0], | ||
[2.3, 5.7, 8.9, 1.2], | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
], | ||
[ | ||
[8.9, 1.2, 4.6, 7.9], | ||
[1.0, 3.5, 6.8, 9.0], | ||
[2.3, 5.7, 8.9, 1.2], | ||
[4.6, 7.9, 1.0, 3.5], | ||
], | ||
[ | ||
[6.8, 9.0, 2.3, 5.7], | ||
[8.9, 1.2, 4.6, 7.9], | ||
[1.0, 3.5, 6.8, 9.0], | ||
[2.3, 5.7, 8.9, 1.2], | ||
], | ||
[ | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
[8.9, 1.2, 4.6, 7.9], | ||
[1.0, 3.5, 6.8, 9.0], | ||
], | ||
[ | ||
[2.3, 5.7, 8.9, 1.2], | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
[8.9, 1.2, 4.6, 7.9], | ||
], | ||
[ | ||
[1.0, 3.5, 6.8, 9.0], | ||
[2.3, 5.7, 8.9, 1.2], | ||
[4.6, 7.9, 1.0, 3.5], | ||
[6.8, 9.0, 2.3, 5.7], | ||
], | ||
[ | ||
[9.9, 4.3, 1.1, 7.7], | ||
[3.0, 5.4, 8.8, 1.2], | ||
[4.6, 7.9, 1.0, 3.4], | ||
[6.8, 9.0, 2.3, 5.7], | ||
], | ||
[ | ||
[3.2, 6.5, 9.9, 1.2], | ||
[4.5, 7.9, 1.0, 3.5], | ||
[6.8, 9.2, 2.5, 5.8], | ||
[8.2, 1.0, 4.8, 7.2], | ||
], | ||
[ | ||
[1.5, 4.8, 8.2, 1.0], | ||
[3.8, 7.1, 1.0, 3.8], | ||
[6.1, 9.5, 2.8, 6.1], | ||
[8.4, 1.8, 5.1, 8.4], | ||
], | ||
[ | ||
[1.0, 4.1, 7.4, 1.0], | ||
[3.1, 6.4, 9.8, 3.1], | ||
[5.4, 8.7, 2.1, 5.4], | ||
[7.7, 1.1, 4.4, 7.7], | ||
], | ||
[ | ||
[2.0, 5.4, 8.7, 1.0], | ||
[4.4, 7.7, 1.0, 4.4], | ||
[6.7, 9.0, 2.4, 6.7], | ||
[8.0, 1.3, 4.7, 8.0], | ||
], | ||
[ | ||
[3.3, 6.7, 10.0, 1.3], | ||
[4.7, 8.0, 1.3, 4.6], | ||
[7.0, 9.3, 2.6, 7.0], | ||
[9.3, 1.6, 5.0, 9.3], | ||
], | ||
[ | ||
[1.6, 4.9, 8.3, 1.0], | ||
[3.9, 7.3, 1.0, 3.9], | ||
[6.2, 9.6, 2.9, 6.2], | ||
[8.6, 1.9, 5.2, 8.6], | ||
], | ||
[ | ||
[1.0, 4.2, 7.5, 1.0], | ||
[3.2, 6.5, 9.9, 3.2], | ||
[5.5, 8.9, 2.2, 5.5], | ||
[7.8, 1.2, 4.5, 7.8], | ||
], | ||
[ | ||
[2.2, 5.5, 8.8, 1.2], | ||
[4.5, 7.8, 1.1, 4.5], | ||
[6.8, 9.1, 2.5, 6.8], | ||
[9.1, 1.5, 4.8, 9.1], | ||
], | ||
[ | ||
[3.5, 6.8, 1.0, 3.4], | ||
[5.8, 9.1, 2.4, 5.8], | ||
[8.1, 1.0, 4.8, 8.1], | ||
[1.0, 3.7, 7.1, 1.0], | ||
], | ||
[ | ||
[28.5, 1.0, 1.0, 1.0], | ||
[24.3, 1.0, 100.0, 19.7], | ||
[33.7, 1.0, 86.2, 78.8], | ||
[99.1, 1.0, 1.0, 64.4], | ||
], | ||
[ | ||
[1.0, 1.0, 1.0, 1.0], | ||
[1.0, 1.0, 1.0, 64.2], | ||
[1.0, 1.0, 1.0, 88.3], | ||
[1.0, 1.0, 5.9, 9.4], | ||
], | ||
[ | ||
[1.0, 86.6, 1.0, 1.0], | ||
[1.0, 9.1, 26.9, 1.0], | ||
[46.3, 5.0, 39.0, 1.0], | ||
[11.7, 100.0, 93.3, 1.0], | ||
], | ||
[ | ||
[1.0, 53.2, 1.0, 21.5], | ||
[16.2, 9.5, 34.0, 47.5], | ||
[96.8, 77.7, 6.4, 5.3], | ||
[59.8, 29.3, 1.0, 1.0], | ||
], | ||
[ | ||
[13.6, 1.0, 1.0, 1.0], | ||
[48.2, 4.3, 1.0, 35.8], | ||
[77.6, 100.0, 1.0, 100.0], | ||
[17.4, 86.7, 1.0, 4.1], | ||
], | ||
[ | ||
[1.5, 1.0, 25.8, 1.0], | ||
[75.5, 100.0, 62.8, 16.4], | ||
[6.6, 40.2, 30.6, 100.0], | ||
[100.0, 4.0, 1.0, 1.0], | ||
], | ||
[ | ||
[1.0, 1.0, 1.0, 1.0], | ||
[98.8, 1.0, 1.0, 1.0], | ||
[12.3, 12.2, 1.0, 1.0], | ||
[1.0, 9.4, 1.0, 1.0], | ||
], | ||
[ | ||
[40.7, 1.0, 1.0, 1.0], | ||
[93.1, 44.6, 15.9, 17.3], | ||
[75.0, 3.3, 39.5, 54.4], | ||
[78.8, 88.1, 1.0, 7.3], | ||
], | ||
[ | ||
[1.0, 1.0, 12.7, 55.9], | ||
[7.9, 52.6, 8.0, 35.4], | ||
[26.0, 23.5, 7.7, 39.3], | ||
[7.5, 3.2, 3.4, 61.5], | ||
], | ||
[ | ||
[7.4, 41.7, 100.0, 2.7], | ||
[8.3, 25.7, 7.1, 9.1], | ||
[59.6, 66.9, 100.0, 41.5], | ||
[43.1, 2.7, 2.2, 59.8], | ||
] | ||
] | ||
|
||
num_examples = int(len(matrices) / 2) # Number of 4x4 matrix multiplication examples | ||
|
||
# Define the template for Bril code | ||
template_str = ''' | ||
@main { | ||
{% for i in range(num_examples) %} | ||
{% for row in range(4) %} | ||
{% for col in range(4) %} | ||
a{{row+1}}{{col+1}}_{{i}}: float = const {{matrices[2*i][row][col]}}; | ||
b{{row+1}}{{col+1}}_{{i}}: float = const {{matrices[2*i+1][row][col]}}; | ||
{% endfor %} | ||
{% endfor %} | ||
{% for row in range(4) %} | ||
{% for col in range(4) %} | ||
c{{row+1}}{{col+1}}_{{i}}: float = const 0; | ||
{% for k in range(4) %} | ||
temp_{{i}}_{{row}}_{{col}}_{{k}}: float = fmul a{{row+1}}{{k+1}}_{{i}} b{{k+1}}{{col+1}}_{{i}}; | ||
c{{row+1}}{{col+1}}_{{i}} = fadd c{{row+1}}{{col+1}}_{{i}} temp_{{i}}_{{row}}_{{col}}_{{k}}; | ||
{% endfor %} | ||
print c{{row+1}}{{col+1}}_{{i}}; | ||
{% endfor %} | ||
{% endfor %} | ||
{% endfor %} | ||
ret; | ||
} | ||
''' | ||
|
||
|
||
# Generate reference results using NumPy and write to matmulti4x4.ref | ||
with open("matmulti4x4.ref", "w") as ref_file: | ||
for i in range(num_examples): | ||
mat1 = np.array(matrices[2*i]) | ||
mat2 = np.array(matrices[2*i + 1]) | ||
result = np.matmul(mat1, mat2) | ||
for row in result: | ||
for element in row: | ||
ref_file.write(f"{element:.17f}\n") | ||
|
||
# Create a Jinja2 template and render it | ||
template = Template(template_str) | ||
rendered_str = template.render(matrices=matrices, num_examples=num_examples) | ||
|
||
# Write the rendered Bril code to matmulti4x4.bril | ||
with open("matmulti4x4.bril", "w") as bril_file: | ||
bril_file.write(rendered_str) | ||
|
||
|
||
|
Oops, something went wrong.