From 0be690fc3ed84f62deb4712082f9bc006ea8c1fc Mon Sep 17 00:00:00 2001 From: Eric Gouriou Date: Thu, 2 Mar 2023 13:45:20 -0800 Subject: [PATCH] Zvk: Implement Zvksh, vector SM3 Hash Function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement the Zvksh sub-extension, "ShangMi Suite: SM3 Hash Function Instructions": - vsm3me.vv, message expansion, - vsm3c.vi, compression rounds. This also introduces a SM3 specific header for common logic. Co-authored-by: Raghav Gupta Co-authored-by: Albert Jakieła Co-authored-by: Kornel Dulęba Signed-off-by: Eric Gouriou --- riscv/insns/vsm3c_vi.h | 89 ++++++++++++++++++++++++++++++++++++++++ riscv/insns/vsm3me_vv.h | 58 ++++++++++++++++++++++++++ riscv/riscv.mk.in | 5 +++ riscv/zvksh_ext_macros.h | 47 +++++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 riscv/insns/vsm3c_vi.h create mode 100644 riscv/insns/vsm3me_vv.h create mode 100644 riscv/zvksh_ext_macros.h diff --git a/riscv/insns/vsm3c_vi.h b/riscv/insns/vsm3c_vi.h new file mode 100644 index 0000000000..50f65c8b15 --- /dev/null +++ b/riscv/insns/vsm3c_vi.h @@ -0,0 +1,89 @@ +// vsm3c.vi vd, vs2, rnd + +#include "zvksh_ext_macros.h" + +// Uncomment to enable debug logging of invocations of this instruction. +//#define DLOG_INVOCATION + +#if defined(DLOG_INVOCATION) +#define DLOG(...) ZVK_DBG_LOG(__VA_ARGS__) +// Print format/value for "v()" +#define PRI_uR_xEG PRI_uREG_xEGU32x8 +#define PRV_R_EG(reg_num, reg) PRV_REG_EGU32x8_LE(reg_num, reg) +#else +#define DLOG(...) (void)(0) +#endif + +require_vsm3_constraints; + +// Rotates left a uint32_t value by N bits. +// uint32_t SM3_ROL32(uint32_t X, unsigned int N); +// This is a "safer" version of zvk_ext_macros.h's ROL32 that accepts +// a run-time shift-value between 0 and 31. ROL32 has Undefine Behavior +// when invoked with value 0. +#define SM3_ROL32(X, N) \ + ((N) == 0 ? (X) : (((X) << (N)) | ((X) >> (32 - (N))))) + +VI_ZVK_VD_VS2_ZIMM5_EGU32x8_NOVM_LOOP( + { + DLOG("-- vsm3c_vi " ZVK_PRI_REGNUMS_VD_VS2_ZIMM5, + ZVK_PRV_REGNUMS_VD_VS2_ZIMM5); + }, + // No need to validate or normalize 'zimm5' here as this is a 5 bits value + // and all values in 0-31 are valid. + const reg_t round = zimm5;, + { + DLOG("vsm3c " PRI_uR_xEG " " PRI_uR_xEG, + PRV_R_EG(vd_num, vd), PRV_R_EG(vs2_num, vs2)); + + // {H, G, F, E, D, C, B, A} <- vd + EXTRACT_EGU32x8_WORDS_BE_BSWAP(vd, H, G, F, E, D, C, B, A); + // {_, _, w5, w4, _, _, w1, w0} <- vs2 + EXTRACT_EGU32x8_WORDS_BE_BSWAP(vs2, + _unused_w7, _unused_w6, w5, w4, + _unused_w3, _unused_w2, w1, w0); + const uint32_t x0 = w0 ^ w4; // W'[0] in spec documentation. + const uint32_t x1 = w1 ^ w5; // W'[1] + + // Two rounds of compression. + uint32_t ss1; + uint32_t ss2; + uint32_t tt1; + uint32_t tt2; + uint32_t j; + + j = 2 * round; + ss1 = SM3_ROL32(SM3_ROL32(A, 12) + E + SM3_ROL32(ZVKSH_T(j), j % 32), 7); + ss2 = ss1 ^ SM3_ROL32(A, 12); + tt1 = ZVKSH_FF(A, B, C, j) + D + ss2 + x0; + tt2 = ZVKSH_GG(E, F, G, j) + H + ss1 + w0; + D = C; + const uint32_t C1 = SM3_ROL32(B, 9); + B = A; + const uint32_t A1 = tt1; + H = G; + const uint32_t G1 = SM3_ROL32(F, 19); + F = E; + const uint32_t E1 = ZVKSH_P0(tt2); + + j = 2 * round + 1; + ss1 = SM3_ROL32(SM3_ROL32(A1, 12) + E1 + SM3_ROL32(ZVKSH_T(j), j % 32), 7); + ss2 = ss1 ^ SM3_ROL32(A1, 12); + tt1 = ZVKSH_FF(A1, B, C1, j) + D + ss2 + x1; + tt2 = ZVKSH_GG(E1, F, G1, j) + H + ss1 + w1; + D = C1; + const uint32_t C2 = SM3_ROL32(B, 9); + B = A1; + const uint32_t A2 = tt1; + H = G1; + const uint32_t G2 = SM3_ROL32(F, 19); + F = E1; + const uint32_t E2 = ZVKSH_P0(tt2); + + // Update the destination register. + SET_EGU32x8_WORDS_BE_BSWAP(vd, G1, G2, E1, E2, C1, C2, A1, A2); + DLOG("= vsm3c " PRI_uR_xEG, PRV_R_EG(vd_num, vd)); + } +); + +#undef SM3_ROL32 diff --git a/riscv/insns/vsm3me_vv.h b/riscv/insns/vsm3me_vv.h new file mode 100644 index 0000000000..893f81a79e --- /dev/null +++ b/riscv/insns/vsm3me_vv.h @@ -0,0 +1,58 @@ +// vsm3me.vv vd, vs2, vs1 + +#include "zvk_ext_macros.h" +#include "zvksh_ext_macros.h" + +// Uncomment to enable debug logging of invocations of this instruction. +//#define DLOG_INVOCATION + +#if defined(DLOG_INVOCATION) +#define DLOG(...) ZVK_DBG_LOG(__VA_ARGS__) +// Print format/value for "v()" +#define PRI_uR_xEG PRI_uREG_xEGU32x8 +#define PRV_R_EG(reg_num, reg) PRV_REG_EGU32x8_LE(reg_num, reg) +#else +#define DLOG(...) (void)(0) +#endif + +// Per the SM3 spec, the message expansion computes new words Wi as: +// W[i] = ( P_1( W[i-16] xor W[i-9] xor ( W[i-3] <<< 15 ) ) +// xor ( W[i-13] <<< 7 ) +// xor W[i-6])) +// Using arguments M16 = W[i-16], M9 = W[i-9], etc., +// where Mk stands for "W[i Minus k]", we define the "W function": +#define ZVKSH_W(M16, M9, M3, M13, M6) \ + (ZVKSH_P1( (M16) ^ (M9) ^ ROL32((M3), 15) ) ^ ROL32((M13), 7) ^ (M6)) + +require_vsm3_constraints; + +VI_ZVK_VD_VS1_VS2_EGU32x8_NOVM_LOOP( + { + DLOG("-- vsm3me_vv " ZVK_PRI_REGNUMS_VD_VS2_VS1, + ZVK_PRV_REGNUMS_VD_VS2_VS1); + }, + { + DLOG("vsm3me " PRI_uR_xEG " " PRI_uR_xEG " " PRI_uR_xEG, + PRV_R_EG(vd_num, vd), PRV_R_EG(vs2_num, vs2), PRV_R_EG(vs1_num, vs1)); + + // {w7, w6, w5, w4, w3, w2, w1, w0} <- vs1 + EXTRACT_EGU32x8_WORDS_BE_BSWAP(vs1, w7, w6, w5, w4, w3, w2, w1, w0); + // {w15, w14, w13, w12, w11, w10, w9, w8} <- vs2 + EXTRACT_EGU32x8_WORDS_BE_BSWAP(vs2, w15, w14, w13, w12, w11, w10, w9, w8); + + // Arguments are W[i-16], W[i-9], W[i-13], W[i-6]. + // Note that some of the newly computed words are used in later invocations. + const uint32_t w16 = ZVKSH_W(w0, w7, w13, w3, w10); + const uint32_t w17 = ZVKSH_W(w1, w8, w14, w4, w11); + const uint32_t w18 = ZVKSH_W(w2, w9, w15, w5, w12); + const uint32_t w19 = ZVKSH_W(w3, w10, w16, w6, w13); + const uint32_t w20 = ZVKSH_W(w4, w11, w17, w7, w14); + const uint32_t w21 = ZVKSH_W(w5, w12, w18, w8, w15); + const uint32_t w22 = ZVKSH_W(w6, w13, w19, w9, w16); + const uint32_t w23 = ZVKSH_W(w7, w14, w20, w10, w17); + + // Update the destination register. + SET_EGU32x8_WORDS_BE_BSWAP(vd, w23, w22, w21, w20, w19, w18, w17, w16); + DLOG("= vsm3me " PRI_uR_xEG, PRV_R_EG(vd_num, vd)); + } +); diff --git a/riscv/riscv.mk.in b/riscv/riscv.mk.in index f6bc8be71c..d7ede03af3 100644 --- a/riscv/riscv.mk.in +++ b/riscv/riscv.mk.in @@ -1367,6 +1367,10 @@ riscv_insn_ext_zvksed = \ vsm4r_vs \ vsm4r_vv \ +riscv_insn_ext_zvksh = \ + vsm3c_vi \ + vsm3me_vv \ + riscv_insn_ext_zvk = \ $(riscv_insn_ext_zvbb) \ $(riscv_insn_ext_zvbc) \ @@ -1374,6 +1378,7 @@ riscv_insn_ext_zvk = \ $(riscv_insn_ext_zvkned) \ $(riscv_insn_ext_zvknh) \ $(riscv_insn_ext_zvksed) \ + $(riscv_insn_ext_zvksh) \ # Note that riscv_insn_ext_p and riscv_insn_ext_zvk contain instructions # that have conflicting encodings. They cannot be both included concurrently. diff --git a/riscv/zvksh_ext_macros.h b/riscv/zvksh_ext_macros.h new file mode 100644 index 0000000000..8260817965 --- /dev/null +++ b/riscv/zvksh_ext_macros.h @@ -0,0 +1,47 @@ +// Helper macros and functions to help implement instructions defined as part of +// the RISC-V Zvksh extension (vectorized SM3). + +#include "zvk_ext_macros.h" + +#ifndef RISCV_INSNS_ZVKSH_COMMON_H_ +#define RISCV_INSNS_ZVKSH_COMMON_H_ + +// Constraints common to all vsm3* instructions: +// - Zvksh is enabled +// - VSEW == 32 +// - EGW (256) <= LMUL * VLEN +// - No overlap of vd and vs2. +// +// The constraint that vstart and vl are both EGS (8) aligned +// is checked in the VI_ZVK_..._EGU32x8_..._LOOP macros. +#define require_vsm3_constraints \ + do { \ + require_zvksh; \ + require(P.VU.vsew == 32); \ + require_egw_fits(256); \ + require(insn.rd() != insn.rs2()); \ + } while (false) + +#define FF1(X, Y, Z) ((X) ^ (Y) ^ (Z)) +#define FF2(X, Y, Z) (((X) & (Y)) | ((X) & (Z)) | ((Y) & (Z))) + +// Boolean function FF_j - section 4.3. of the IETF draft. +#define ZVKSH_FF(X, Y, Z, J) (((J) <= 15) ? FF1(X, Y, Z) : FF2(X, Y, Z)) + +#define GG1(X, Y, Z) ((X) ^ (Y) ^ (Z)) +#define GG2(X, Y, Z) (((X) & (Y)) | ((~(X)) & (Z))) + +// Boolean function GG_j - section 4.3. of the IETF draft. +#define ZVKSH_GG(X, Y, Z, J) (((J) <= 15) ? GG1(X, Y, Z) : GG2(X, Y, Z)) + +#define T1 0x79CC4519 +#define T2 0x7A879D8A + +// T_j constant - section 4.2. of the IETF draft. +#define ZVKSH_T(J) (((J) <= 15) ? (T1) : (T2)) + +// Permutation functions P_0 and P_1 - section 4.4 of the IETF draft. +#define ZVKSH_P0(X) ((X) ^ ROL32((X), 9) ^ ROL32((X), 17)) +#define ZVKSH_P1(X) ((X) ^ ROL32((X), 15) ^ ROL32((X), 23)) + +#endif // RISCV_INSNS_ZVKSH_COMMON_H