From 73a992f0a63f4f7f5793b791347e7e7dd7b8d9fb Mon Sep 17 00:00:00 2001 From: Nikita Frolov Date: Tue, 16 May 2023 09:34:49 +0200 Subject: [PATCH] refactor(compiler): generalize noise calculation in FHE/FHELinalg Instead of having one `getSQManp` implementation per op with a lot of repetition, the noise calculation is now modular. - Ops that implement `UnaryEint`/`Binary`/`BinaryEint` interfaces share the operand noise presence check. - For many scalar ops no further calculation is needed. If it's not the case, an op can override `sqMANP`. - Integer operand type lookups are abstracted into `Binary::operandIntType()` - Finding the largest operand value for a type is abstracted into `Binary::operandMaxConstant` - Noise calculation for matmul ops is simplified and it's now general enough to work for `matmul_eint_int`, `matmul_int_eint` and `dot_eint_int` at once. --- .../concretelang/Dialect/FHE/CMakeLists.txt | 1 + .../Dialect/FHE/IR/CMakeLists.txt | 4 - .../Dialect/FHE/IR/FHEInterfaces.td | 32 - .../concretelang/Dialect/FHE/IR/FHEOps.td | 45 +- .../concretelang/Dialect/FHE/IR/FHETypes.h | 4 +- .../concretelang/Dialect/FHE/IR/FHETypes.td | 10 +- .../Dialect/FHE/Interfaces/CMakeLists.txt | 7 + .../Dialect/FHE/Interfaces/FHEInterfaces.h | 20 + .../Dialect/FHE/Interfaces/FHEInterfaces.td | 189 +++ .../FHE/Interfaces/FHEInterfacesInstances.h | 21 + .../Dialect/FHELinalg/IR/FHELinalgOps.td | 65 +- .../Dialect/Tracing/IR/TracingOps.td | 2 +- .../compiler/lib/Bindings/Rust/build.rs | 1 + .../compiler/lib/CAPI/Dialect/FHE/FHE.cpp | 6 +- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 18 +- .../lib/Dialect/FHE/Analysis/MANP.cpp | 1475 +++++------------ .../compiler/lib/Dialect/FHE/CMakeLists.txt | 1 + .../lib/Dialect/FHE/IR/CMakeLists.txt | 3 +- .../lib/Dialect/FHE/IR/FHEDialect.cpp | 9 +- .../compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 7 +- .../lib/Dialect/FHE/Interfaces/CMakeLists.txt | 11 + .../Dialect/FHE/Interfaces/FHEInterfaces.cpp | 9 + 
.../FHE/Interfaces/FHEInterfacesInstances.cpp | 24 + .../lib/Dialect/FHE/Transforms/BigInt.cpp | 35 +- .../lib/Dialect/FHE/Transforms/Boolean.cpp | 4 +- .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 4 +- .../Conversion/FHEToTFHEScalar/neg_eint.mlir | 8 - .../Dialect/FHE/Analysis/MANP.mlir | 2 +- .../Dialect/FHE/Analysis/MANP_conv2d.mlir | 6 +- .../Dialect/FHE/Analysis/MANP_linalg.mlir | 46 +- .../Dialect/FHE/Analysis/MANP_tensor.mlir | 8 +- 31 files changed, 850 insertions(+), 1227 deletions(-) delete mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfacesInstances.h create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfaces.cpp create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt index 306b439685..dd089989bc 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Interfaces) add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Transforms) diff --git 
a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt index 54e5a4799f..9760b5f5da 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt @@ -1,7 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td) -mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs) - set(LLVM_TARGET_DEFINITIONS FHEOps.td) mlir_tablegen(FHEOps.h.inc -gen-op-decls) mlir_tablegen(FHEOps.cpp.inc -gen-op-defs) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td deleted file mode 100644 index cb0afe5648..0000000000 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES -#define CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES - -include "mlir/IR/OpBase.td" - -def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> { - let cppNamespace = "mlir::concretelang::FHE"; - - let description = [{ - Interface for encapsulating the common properties of encrypted integer types. 
- }]; - - let methods = [ - InterfaceMethod< - /*description=*/"Get bit-width of the integer.", - /*retTy=*/"unsigned", - /*methodName=*/"getWidth" - >, - InterfaceMethod< - /*description=*/"Get whether the integer is signed.", - /*retTy=*/"bool", - /*methodName=*/"isSigned" - >, - InterfaceMethod< - /*description=*/"Get whether the integer is unsigned.", - /*retTy=*/"bool", - /*methodName=*/"isUnsigned" - > - ]; -} - -#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 0814248140..06d105cfaf 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -14,11 +14,12 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "concretelang/Dialect/FHE/IR/FHEDialect.td" include "concretelang/Dialect/FHE/IR/FHETypes.td" +include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td" class FHE_Op traits = []> : Op; -def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> { +def FHE_ZeroEintOp : FHE_Op<"zero", [Pure, ConstantNoise]> { let summary = "Returns a trivial encrypted integer of 0"; let description = [{ @@ -33,7 +34,7 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [Pure]> { let results = (outs FHE_AnyEncryptedInteger:$out); } -def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> { +def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure, ConstantNoise]> { let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; let description = [{ @@ -51,7 +52,7 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [Pure]> { let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); } -def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> { +def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Adds an encrypted integer and a clear integer"; 
let description = [{ @@ -84,7 +85,7 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [Pure]> { let hasFolder = 1; } -def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> { +def FHE_AddEintOp : FHE_Op<"add_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Adds two encrypted integers"; let description = [{ @@ -116,7 +117,7 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [Pure]> { let hasVerifier = 1; } -def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> { +def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure, BinaryIntEint]> { let summary = "Subtract an encrypted integer from a clear integer"; let description = [{ @@ -148,7 +149,7 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [Pure]> { let hasVerifier = 1; } -def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> { +def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Subtract a clear integer from an encrypted integer"; let description = [{ @@ -181,7 +182,7 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [Pure]> { let hasFolder = 1; } -def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> { +def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Subtract an encrypted integer from an encrypted integer"; let description = [{ @@ -213,7 +214,7 @@ def FHE_SubEintOp : FHE_Op<"sub_eint", [Pure]> { let hasVerifier = 1; } -def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> { +def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Negates an encrypted integer"; @@ -243,7 +244,7 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [Pure]> { let hasVerifier = 1; } -def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> { +def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Multiply an encrypted integer with a clear integer"; let description = [{ @@ -277,7 +278,7 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [Pure]> { let 
hasCanonicalizer = 1; } -def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> { +def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Multiplies two encrypted integers"; let description = [{ @@ -313,7 +314,7 @@ def FHE_MulEintOp : FHE_Op<"mul_eint", [Pure]> { let hasVerifier = 1; } -def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> { +def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Retrieve the maximum of two encrypted integers."; let description = [{ @@ -348,7 +349,7 @@ def FHE_MaxEintOp : FHE_Op<"max_eint", [Pure]> { let hasVerifier = 1; } -def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> { +def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Cast an unsigned integer to a signed one"; let description = [{ @@ -366,14 +367,14 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> { ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$input); + let arguments = (ins FHE_EncryptedUnsignedIntegerType:$input); let results = (outs FHE_EncryptedSignedIntegerType); let hasVerifier = 1; let hasCanonicalizer = 1; } -def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> { +def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Cast a signed integer to an unsigned one"; let description = [{ @@ -392,13 +393,13 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> { }]; let arguments = (ins FHE_EncryptedSignedIntegerType:$input); - let results = (outs FHE_EncryptedIntegerType); + let results = (outs FHE_EncryptedUnsignedIntegerType); let hasVerifier = 1; let hasCanonicalizer = 1; } -def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> { +def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure, ConstantNoise]> { let summary = "Applies a clear lookup table to an encrypted integer"; @@ -424,7 +425,7 @@ def FHE_ApplyLookupTableEintOp : 
FHE_Op<"apply_lookup_table", [Pure]> { let hasVerifier = 1; } -def FHE_RoundEintOp: FHE_Op<"round", [Pure]> { +def FHE_RoundEintOp: FHE_Op<"round", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Rounds a ciphertext to a smaller precision."; @@ -556,7 +557,7 @@ def FHE_BoolXorOp : FHE_Op<"xor", [Pure]> { let results = (outs FHE_EncryptedBooleanType); } -def FHE_BoolNotOp : FHE_Op<"not", [Pure]> { +def FHE_BoolNotOp : FHE_Op<"not", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Applies a NOT gate to an encrypted boolean value"; @@ -571,7 +572,7 @@ def FHE_BoolNotOp : FHE_Op<"not", [Pure]> { let results = (outs FHE_EncryptedBooleanType); } -def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> { +def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure, UnaryEint]> { let summary = "Cast an unsigned integer to a boolean"; let description = [{ @@ -589,13 +590,13 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [Pure]> { ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$input); + let arguments = (ins FHE_EncryptedUnsignedIntegerType:$input); let results = (outs FHE_EncryptedBooleanType); let hasVerifier = 1; } -def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> { +def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure, UnaryEint]> { let summary = "Cast a boolean to an unsigned integer"; let description = [{ @@ -608,7 +609,7 @@ def FHE_FromBoolOp : FHE_Op<"from_bool", [Pure]> { }]; let arguments = (ins FHE_EncryptedBooleanType:$input); - let results = (outs FHE_EncryptedIntegerType); + let results = (outs FHE_EncryptedUnsignedIntegerType); } diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h index 9db4a81c81..f2df299717 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h @@ -11,7 +11,9 @@ #include #include 
-#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.h.inc" +#include + +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" #define GET_TYPEDEF_CLASSES #include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc" diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index 1201061394..7859a96a26 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -10,20 +10,20 @@ #define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES include "concretelang/Dialect/FHE/IR/FHEDialect.td" -include "concretelang/Dialect/FHE/IR/FHEInterfaces.td" +include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td" include "mlir/IR/BuiltinTypes.td" class FHE_Type traits = []> : TypeDef { } -def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger", +def FHE_EncryptedUnsignedIntegerType : FHE_Type<"EncryptedUnsignedInteger", [MemRefElementTypeInterface, FheIntegerInterface]> { let mnemonic = "eint"; - let summary = "An encrypted integer"; + let summary = "An encrypted unsigned integer"; let description = [{ - An encrypted integer with `width` bits to performs FHE Operations. + An encrypted unsigned integer with `width` bits to performs FHE Operations. 
Examples: ```mlir @@ -73,7 +73,7 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger", } def FHE_AnyEncryptedInteger : Type>; diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/CMakeLists.txt new file mode 100644 index 0000000000..d0ec665ddd --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/CMakeLists.txt @@ -0,0 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td) +mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs) +mlir_tablegen(FHEOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(FHEOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRFHEInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRFHEInterfacesIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h new file mode 100644 index 0000000000..092f431b64 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h @@ -0,0 +1,20 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. 
+ +#ifndef CONCRETELANG_DIALECT_FHE_INTERFACES_FHEINTERFACES_H +#define CONCRETELANG_DIALECT_FHE_INTERFACES_FHEINTERFACES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "concretelang/Dialect/FHE/Interfaces/FHEOpsInterfaces.h.inc" +#include "concretelang/Dialect/FHE/Interfaces/FHETypesInterfaces.h.inc" + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td new file mode 100644 index 0000000000..dc9f531223 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td @@ -0,0 +1,189 @@ +#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES +#define CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES + +include "mlir/IR/OpBase.td" + +def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> { + let cppNamespace = "mlir::concretelang::FHE"; + + let description = [{ + Interface for encapsulating the common properties of encrypted integer types. + }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Get bit-width of the integer.", + /*retTy=*/"unsigned", + /*methodName=*/"getWidth" + >, + InterfaceMethod< + /*description=*/"Get whether the integer is signed.", + /*retTy=*/"bool", + /*methodName=*/"isSigned" + >, + InterfaceMethod< + /*description=*/"Get whether the integer is unsigned.", + /*retTy=*/"bool", + /*methodName=*/"isUnsigned" + > + ]; +} + +def ConstantNoise : OpInterface<"ConstantNoise"> { + let description = [{ + An operation which always has the same noise. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; +} + +def UnaryEint : OpInterface<"UnaryEint"> { + let description = [{ + A unary operation on scalars, with the operand encrypted. 
+ }]; + + let cppNamespace = "mlir::concretelang::FHE"; + + let methods = [ + InterfaceMethod< + /*description=*/"Calculate squared MANP", + /*retTy=*/"llvm::APInt", + /*methodName=*/"sqMANP", + /*ins=*/(ins "llvm::APInt":$a), + /*methodBody=*/"", + /*defaultImplementation=*/"return a;" + >, + InterfaceMethod< + /*description=*/"Get the underlying integer type of an operand", + /*retTy=*/"mlir::Type", + /*methodName=*/"operandIntType", + /*ins=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (auto operandTy = dyn_cast($_op->getOpOperand(0).get().getType())) { + return operandTy.getElementType(); + } else return $_op->getOpOperand(0).get().getType(); + }]> + ]; +} + +def Binary : OpInterface<"Binary"> { + let description = [{ + An operation with two operands + }]; + + let cppNamespace = "mlir::concretelang::FHE"; + + let methods = [ + InterfaceMethod< + /*description=*/"Calculate squared MANP", + /*retTy=*/"llvm::APInt", + /*methodName=*/"sqMANP", + /*ins=*/(ins "llvm::APInt":$a), + /*methodBody=*/"", + /*defaultImplementation=*/"return a;" + >, + InterfaceMethod< + /*description=*/"Get the underlying integer type of an operand", + /*retTy=*/"mlir::Type", + /*methodName=*/"operandIntType", + /*ins=*/(ins "unsigned":$opNum), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (auto operandTy = dyn_cast($_op->getOpOperand(opNum).get().getType())) { + return operandTy.getElementType(); + } else return $_op->getOpOperand(opNum).get().getType(); + }]>, + InterfaceMethod< + /*description=*/"Get the (largest) scalar value of an operand", + /*retTy=*/"std::optional", + /*methodName=*/"operandMaxConstant", + /*ins=*/(ins "unsigned":$opNum), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (auto cstOp = llvm::dyn_cast_or_null($_op-> + getOpOperand(opNum).get().getDefiningOp())) { + if (auto operandTy = dyn_cast($_op-> + getOpOperand(opNum).get().getType())) { + mlir::DenseIntElementsAttr denseVals = + cstOp->template getAttrOfType("value"); + 
return *(std::max_element(denseVals.begin(), denseVals.end(), + [](llvm::APInt a, llvm::APInt b) { + return a.ult(b); + })); + } else return cstOp->template getAttrOfType("value").getValue(); + } else return {}; + }]>, + InterfaceMethod< + /*description=*/"Get clear operand tensor value ", + /*retTy=*/"std::optional>", + /*methodName=*/"opTensorConstant", + /*ins=*/(ins "unsigned":$opNum), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if (auto cstOp = llvm::dyn_cast_or_null($_op-> + getOpOperand(opNum).get().getDefiningOp())) + return cstOp->template getAttrOfType("value").template getValues(); + else return {}; + }]>, + ]; +} + +def BinaryEintInt : OpInterface<"BinaryEintInt", [Binary]> { + let description = [{ + A binary operation on scalars, with the first operand encrypted and the + second clear. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; + + let methods = [ + InterfaceMethod< + /*description=*/"Get clear operand number", + /*retTy=*/"unsigned", + /*methodName=*/"getClearOperandNumber", + /*ins=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"return 1;" + > + ]; +} + +def BinaryIntEint : OpInterface<"BinaryIntEint", [Binary]> { + let description = [{ + A binary operation on scalars, with the first operand clear and the + second encrypted. + }]; + + let cppNamespace = "mlir::concretelang::FHE"; + + let methods = [ + InterfaceMethod< + /*description=*/"Get clear operand number", + /*retTy=*/"unsigned", + /*methodName=*/"getClearOperandNumber", + /*ins=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"return 0;" + > + ]; +} + +def BinaryEint : OpInterface<"BinaryEint"> { + let description = [{ + A binary operation on scalars, with both operands encrypted. 
+ }]; + + let cppNamespace = "mlir::concretelang::FHE"; + + let methods = [ + InterfaceMethod< + /*description=*/"Calculate squared MANP", + /*retTy=*/"llvm::APInt", + /*methodName=*/"sqMANP", + /*ins=*/(ins "llvm::APInt":$a, "llvm::APInt":$b) + > + ]; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfacesInstances.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfacesInstances.h new file mode 100644 index 0000000000..966304a160 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Interfaces/FHEInterfacesInstances.h @@ -0,0 +1,21 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_CONCRETE_FHEINTERFACESINSTANCES_H +#define CONCRETELANG_DIALECT_CONCRETE_FHEINTERFACESINSTANCES_H + +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" + +namespace mlir { +class DialectRegistry; + +namespace concretelang { +namespace FHE { +void registerFheInterfacesExternalModels(DialectRegistry ®istry); +} // namespace FHE +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index aae802de3b..efd858e3d1 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -12,6 +12,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.td" + include 
"concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td" @@ -26,7 +28,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">; def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">; -def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ @@ -81,7 +83,8 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcast let hasFolder = 1; } -def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure, + TensorBroadcastingRules, TensorBinaryEint, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers."; let description = [{ @@ -133,7 +136,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure, TensorBroadcastingRule ]; } -def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint]> { +def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint, BinaryIntEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers."; let description = [{ @@ -186,7 +189,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcast ]; } -def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, 
TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers."; let description = [{ @@ -242,7 +245,7 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcast let hasFolder = 1; } -def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers."; let description = [{ @@ -294,7 +297,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRule ]; } -def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint]> { +def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers."; let description = [{ @@ -326,7 +329,7 @@ def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint]> { ]; } -def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the multiplication of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ @@ -377,7 +380,7 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [Pure, TensorBroadcast let hasCanonicalizer = 1; } -def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [Pure, 
TensorBroadcastingRules, TensorBinaryEint, BinaryEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the multiplication of two tensor of encrypted integers."; let description = [{ @@ -429,8 +432,8 @@ def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [Pure, TensorBroadcastingRule ]; } -def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure]> { - let summary = "Returns a tensor that contains the result of a lookup table."; +def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure, ConstantNoise]> { + let summary = "Returns a tensor that contains the result of the lookup on a table."; let description = [{ For each encrypted index, performs a lookup table of clear integers. @@ -465,8 +468,8 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure] let hasVerifier = 1; } -def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure]> { - let summary = "Returns a tensor that contains the result of a lookup table, using a different lookup table for each element."; +def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure, ConstantNoise]> { + let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element."; let description = [{ Performs for each encrypted index a lookup table of clear integers. 
Multiple lookup tables are passed, and the application of lookup tables @@ -503,17 +506,17 @@ def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_tab }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$t, + Type.predicate, HasStaticShapePred]>>:$t, Type.predicate, HasStaticShapePred]>>:$luts ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; } -def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", [Pure]> { - let summary = "Returns a tensor that contains the result of a lookup table, using a different lookup table for each element, specified by a map."; +def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", [Pure, ConstantNoise]> { + let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map."; let description = [{ Performs for each encrypted index a lookup table of clear integers. 
Multiple lookup tables are passed, and the application of lookup tables @@ -566,7 +569,7 @@ def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_t let hasVerifier = 1; } -def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure]> { +def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers."; let description = [{ @@ -589,7 +592,6 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure]> { let hasVerifier = 1; } - def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint", [Pure]> { let summary = "Returns the encrypted dot product between two vectors of encrypted integers."; @@ -613,8 +615,7 @@ def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint", [Pure]> { let hasVerifier = 1; } - -def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [Pure, TensorBinaryEintInt]> { +def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [Pure, TensorBinaryEintInt, BinaryEintInt, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers."; let description = [{ @@ -759,7 +760,7 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [Pure, TensorBin }]; } -def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [Pure, TensorBinaryIntEint]> { +def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [Pure, TensorBinaryIntEint, BinaryIntEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers."; let description = [{ @@ -1167,8 +1168,9 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> { let hasVerifier = 1; } -def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure]> { - let summary = "Returns the 2D 
convolution of a tensor in NCHW form with weights in the form FCHW"; +def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure, BinaryEintInt, DeclareOpInterfaceMethods]> { + let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW"; + let arguments = (ins Type.predicate, HasStaticShapePred]>>:$input, Type.predicate, HasStaticShapePred]>>:$weight, @@ -1183,8 +1185,9 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure]> { let hasVerifier = 1; } -def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [Pure]> { - let summary = "Returns the 2D maxpool of a tensor in NCHW form"; +def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [UnaryEint, DeclareOpInterfaceMethods]> { + let summary = "Returns the 2D maxpool of a tensor in the form NCHW"; + let arguments = (ins Type.predicate, HasStaticShapePred]>>:$input, I64ElementsAttr:$kernel_shape, @@ -1195,7 +1198,7 @@ def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [Pure]> { let hasVerifier = 1; } -def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure]> { +def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Returns a tensor that contains the transposition of the input tensor."; let description = [{ @@ -1251,7 +1254,7 @@ def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure]> { let hasVerifier = 1; } -def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> { +def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Cast an unsigned integer tensor to a signed one"; let description = [{ @@ -1271,7 +1274,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> { }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$input + Type.predicate, HasStaticShapePred]>>:$input ); let results = (outs @@ -1282,7 +1285,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> { let hasCanonicalizer = 1; } -def 
FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> { +def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Cast a signed integer tensor to an unsigned one"; let description = [{ @@ -1306,14 +1309,14 @@ def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> { ); let results = (outs - Type.predicate, HasStaticShapePred]>>:$output + Type.predicate, HasStaticShapePred]>>:$output ); let hasVerifier = 1; let hasCanonicalizer = 1; } -def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint]> { +def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint, UnaryEint, DeclareOpInterfaceMethods]> { let summary = "Rounds a tensor of ciphertexts into a smaller precision."; let description = [{ diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td index ad9ca7ce82..aa073832e1 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td @@ -17,7 +17,7 @@ def Tracing_TraceCiphertextOp : Tracing_Op<"trace_ciphertext"> { let arguments = (ins Type.predicate, diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs index 7da8bb0e4a..1e1652a9ec 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs @@ -288,6 +288,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[ "CONCRETELANGCAPIFHELINALG", "FHELinalgDialectTransforms", "FHEDialect", + "FHEInterfaces", "FHEDialectTransforms", "TFHEToConcrete", "FHEToTFHECrt", diff --git a/compilers/concrete-compiler/compiler/lib/CAPI/Dialect/FHE/FHE.cpp 
b/compilers/concrete-compiler/compiler/lib/CAPI/Dialect/FHE/FHE.cpp index d06a45b76a..0bfb2fa9b2 100644 --- a/compilers/concrete-compiler/compiler/lib/CAPI/Dialect/FHE/FHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/CAPI/Dialect/FHE/FHE.cpp @@ -45,12 +45,12 @@ MlirTypeOrError IntegerTypeGetChecked(MlirContext ctx, unsigned width) { } bool fheTypeIsAnEncryptedIntegerType(MlirType type) { - return unwrap(type).isa(); + return unwrap(type).isa(); } MlirTypeOrError fheEncryptedIntegerTypeGetChecked(MlirContext ctx, unsigned width) { - return IntegerTypeGetChecked(ctx, width); + return IntegerTypeGetChecked(ctx, width); } bool fheTypeIsAnEncryptedSignedIntegerType(MlirType type) { @@ -64,7 +64,7 @@ MlirTypeOrError fheEncryptedSignedIntegerTypeGetChecked(MlirContext ctx, unsigned fheTypeIntegerWidthGet(MlirType integerType) { mlir::Type type = unwrap(integerType); - auto eint = type.dyn_cast_or_null(); + auto eint = type.dyn_cast_or_null(); if (eint) { return eint.getWidth(); } diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index d88de41a3d..2e35f1c8a7 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -652,10 +652,11 @@ struct ToBoolOpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite(FHE::ToBoolOp op, mlir::PatternRewriter &rewriter) const override { - auto width = op.getInput() - .getType() - .dyn_cast() - .getWidth(); + auto width = + op.getInput() + .getType() + .dyn_cast() + .getWidth(); if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) { rewriter.replaceOp(op, op.getInput()); return mlir::success(); @@ -675,10 +676,11 @@ struct FromBoolOpPattern : public mlir::OpRewritePattern { mlir::LogicalResult 
matchAndRewrite(FHE::FromBoolOp op, mlir::PatternRewriter &rewriter) const override { - auto width = op.getResult() - .getType() - .dyn_cast() - .getWidth(); + auto width = + op.getResult() + .getType() + .dyn_cast() + .getWidth(); if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) { rewriter.replaceOp(op, op.getInput()); return mlir::success(); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 11162672f5..85c8ae9c97 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h" #include #include #include @@ -210,79 +211,6 @@ static std::string APIntToStringValUnsigned(const llvm::APInt &i) { return std::string(s.c_str()); } -/// Calculates the square of the 2-norm of a tensor initialized with a -/// dense matrix of constant, signless integers. Aborts if the value -/// type or initialization of of `cstOp` is incorrect. 
-static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp, - llvm::APInt eNorm) { - mlir::DenseIntElementsAttr denseVals = - cstOp->getAttrOfType("value"); - - assert(denseVals && cstOp.getType().isa() && - "Constant must be a tensor initialized with `dense`"); - - mlir::TensorType tensorType = cstOp.getType().cast(); - - assert(tensorType.getElementType().isSignlessInteger() && - "Can only handle tensors with signless integer elements"); - - llvm::APInt accu{1, 0, false}; - - for (llvm::APInt val : denseVals.getValues()) { - llvm::APInt valSqNorm = APIntWidthExtendSqForConstant(val); - llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm); - accu = APIntWidthExtendUAdd(accu, mulSqNorm); - } - - return accu; -} - -/// Calculates the square of the 2-norm of a 1D tensor of signless -/// integers by conservatively assuming that the dynamic values are the -/// maximum for the integer width. Aborts if the tensor type `tTy` is -/// incorrect. -static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy, - llvm::APInt eNorm) { - assert(tTy && tTy.getElementType().isSignlessInteger() && - tTy.hasStaticShape() && tTy.getRank() == 1 && - "Plaintext operand must be a statically shaped 1D tensor of integers"); - - // Make sure the log2 of the number of elements fits into an - // unsigned - assert(std::numeric_limits::max() > 8 * sizeof(uint64_t)); - - unsigned elWidth = tTy.getElementTypeBitWidth(); - - llvm::APInt maxVal = APInt::getSignedMaxValue(elWidth); - llvm::APInt maxValSq = APIntWidthExtendUnsignedSq(maxVal); - - llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm); - - // Calculate number of bits for APInt to store number of elements - uint64_t nElts = (uint64_t)tTy.getNumElements(); - assert(std::numeric_limits::max() - nElts > 1); - unsigned nEltsBits = (unsigned)ceilLog2(nElts + 1); - - llvm::APInt nEltsAP{nEltsBits, nElts, false}; - - return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP); -} - -/// Returns the squared 2-norm of 
the maximum value of the dense values. -static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) { - auto denseValsAP = denseVals.getValues(); - - // For a constant operand use actual constant to calculate 2-norm - llvm::APInt maxCst = denseValsAP[0]; - for (int64_t i = 0; i < denseVals.getNumElements(); i++) { - llvm::APInt iCst = denseValsAP[i]; - if (maxCst.ult(iCst)) { - maxCst = iCst; - } - } - return APIntWidthExtendSqForConstant(maxCst); -} - /// Returns the squared 2-norm for a dynamic integer by conservatively /// assuming that the integer's value is the maximum for the integer /// width. @@ -290,7 +218,8 @@ static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { assert(t.isSignlessInteger() && "Type must be a signless integer type"); assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); - llvm::APInt maxVal = APInt::getMaxValue(t.getIntOrFloatBitWidth()); + // we consider the maximum value as a signed integer + llvm::APInt maxVal = APInt::getMaxValue(t.getIntOrFloatBitWidth() - 1); return APIntWidthExtendUnsignedSq(maxVal); } @@ -308,759 +237,256 @@ getNoOpSqMANP(llvm::ArrayRef operandMANPs) { return eNorm; } -/// Calculates the squared Minimal Arithmetic Noise Padding of an -/// `FHELinalg.dot_eint_int` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Dot op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - - if (cstOp) { - // Dot product between a vector of encrypted integers and a vector - // of plaintext constants -> return 2-norm of constant vector - return denseCstTensorNorm2Sq(cstOp, eNorm); - } else { - // Dot product between a vector of encrypted integers and a vector - // of dynamic plaintext values -> conservatively assume that all - // the values are the maximum possible value for the integer's - // width - mlir::TensorType tTy = op->getOpOperand(1) - .get() - .getType() - .dyn_cast_or_null(); - - return denseDynTensorNorm2Sq(tTy, eNorm); - } -} - /// Calculates the squared Minimal Arithmetic Noise Padding of an /// `FHELinalg.dot_eint_eint` operation. 
static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::DotEint op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value(); - - auto rhsType = - ((mlir::Type)op.getRhs().getType()).cast(); - - llvm::ArrayRef rhsShape = rhsType.getShape(); - - int64_t rhsDims = (int64_t)rhsShape.size(); - - assert(rhsDims == 1 && "In MANP computation dot product RHS expected to have " - "a single dimension"); - - int64_t N = rhsShape[0]; - - // Compute output MANP: - // Tlu output MANP is 1 - llvm::APInt tlu = {1, 1, false}; - // The element-wise multiplication is given by the - // subtraction of two TLU outputs. The MANP of the multiplication is thus - // the sum of the TLU MANPs - llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu); - - llvm::APInt accNorm = llvm::APInt{1, 0, false}; - - // For the total Dot product MANP, take the manp of the sum of products - for (int64_t i = 0; i < N; i++) { - accNorm = APIntWidthExtendUAdd(elemMulNorm, accNorm); - } - - return accNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of an -/// `FHE.add_eint_int` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintIntOp op, - llvm::ArrayRef operandMANPs) { - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.add_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::AddEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt a = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt b = operandMANPs[1]->getValue().getMANP().value(); - - return APIntWidthExtendUAdd(a, b); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.sub_int_eint` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubIntEintOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 2 && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.sub_eint_int` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintIntOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.sub_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::SubEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt a = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt b = operandMANPs[1]->getValue().getMANP().value(); - - return APIntWidthExtendUAdd(a, b); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.neg_eint` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHE::NegEintOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.not` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::BoolNotOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToSignedOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHE::ToUnsignedOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.mul_eint_int` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintIntOp op, - llvm::ArrayRef operandMANPs) { - mlir::Type iTy = op->getOpOperand(1).get().getType(); - - assert(iTy.isSignlessInteger() && - "Only multiplications with signless integers are currently allowed"); - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt sqNorm; - - if (cstOp) { - // For a constant operand use actual constant to calculate 2-norm - mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); - } else { - // For a dynamic operand conservatively assume that the value is - // the maximum for the integer width - sqNorm = conservativeIntNorm2Sq(iTy); - } - - return APIntWidthExtendUMul(sqNorm, eNorm); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of -/// `FHE.mul_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::MulEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) - - const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value(); - const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value(); - - // The MANP of this operation is simply the MANP after the TLUs - // which is equal to the sum of outputs of 2 TLUs - const llvm::APInt tlu = {1, 1, false}; - const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); - - return result; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.round` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHE::RoundEintOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - uint64_t inputWidth = - op.getOperand().getType().cast().getWidth(); - uint64_t outputWidth = - op.getResult().getType().cast().getWidth(); - uint64_t clearedBits = inputWidth - outputWidth; - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - eNorm += clearedBits; - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of -/// `FHE.max_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHE::MaxEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - // max(x, y) = max(x - y, 0) + y - - const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value(); - const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value(); - - const llvm::APInt sub = APIntWidthExtendUAdd(x, y); - const llvm::APInt tlu = {1, 1, false}; - const llvm::APInt add = APIntWidthExtendUAdd(tlu, y); - - // this is not optimal as it can increase the resulting noise unnecessarily - return APIntUMax(add, sub); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of an -/// `FHELinalg.add_eint_int` operation. -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintIntOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::AddEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt a = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt b = operandMANPs[1]->getValue().getMANP().value(); - - return APIntWidthExtendUAdd(a, b); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHELinalg.sub_int_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubIntEintOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 2 && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintIntOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SubEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - llvm::APInt a = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt b = operandMANPs[1]->getValue().getMANP().value(); - - return APIntWidthExtendUAdd(a, b); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHELinalg.neg_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::NegEintOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ToSignedOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ToUnsignedOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - - return eNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.mul_eint_int` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintIntOp op, - llvm::ArrayRef operandMANPs) { - - mlir::RankedTensorType op0Ty = - op->getOpOperand(1).get().getType().cast(); - - mlir::Type iTy = op0Ty.getElementType(); - - assert(iTy.isSignlessInteger() && - "Only multiplications with signless integers are currently allowed"); - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt sqNorm; - - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - if (denseVals) { - // For a constant operand use actual constant to calculate 2-norm - sqNorm = maxIntNorm2Sq(denseVals); - } else { - // For a dynamic operand conservatively assume that the value is - // the maximum for the integer width - sqNorm = conservativeIntNorm2Sq(iTy); - } - - return APIntWidthExtendUMul(sqNorm, eNorm); -} - -/// Calculates the squared Minimal Arithmetic Noise Padding -/// of `FHE.mul_eint` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintOp op, - llvm::ArrayRef operandMANPs) { - assert(operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - operandMANPs[1]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operands"); - - // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) - - const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value(); - const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value(); - - // The MANP of this operation is simply the MANP after the TLUs - // which is equal to the sum of outputs of 2 TLUs - const llvm::APInt tlu = {1, 1, false}; - const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); - - return result; -} - -static llvm::APInt computeVectorNorm( - llvm::ArrayRef shape, int64_t axis, - mlir::DenseIntElementsAttr denseValues, llvm::APInt encryptedOperandNorm, - llvm::SmallVector &elementSelector) { - - // The accumulator is initialized with 0s in all bits (not the encrypted 0) - // the there is no initial noise in the accumulator, so its - // MANP is initialized to 0 - llvm::APInt accumulationNorm = llvm::APInt{1, 0, false}; - for (int64_t i = 0; i < shape[axis]; i++) { - elementSelector[axis] = i; - - auto denseValuesAP = denseValues.getValues(); - llvm::APInt weight = denseValuesAP[elementSelector]; - llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight); - - llvm::APInt multiplicationNorm = - APIntWidthExtendUMul(encryptedOperandNorm, weightNorm); - accumulationNorm = - APIntWidthExtendUAdd(multiplicationNorm, accumulationNorm); - } - return accumulationNorm; -} - -static void determineNextVector( - llvm::ArrayRef shape, int64_t destroyedDimension, - llvm::SmallVector &vectorSelector) { - - for (int64_t i = shape.size() - 1; i >= 0; i--) { - if (i == destroyedDimension) { - continue; - } - - if (vectorSelector[i] + 1 < (uint64_t)shape[i]) { - vectorSelector[i]++; - 
break; - } - - vectorSelector[i] = 0; - } -} - -static llvm::APInt calculateSqManpForMatMulWithDenseValues( - llvm::ArrayRef shape, int64_t destroyedDimension, - mlir::DenseIntElementsAttr denseValues, llvm::APInt encryptedOperandNorm) { - - llvm::APInt maximumNorm = llvm::APInt{1, 1, false}; - - size_t numberOfVectorsToInspect = 1; - for (auto size : shape) { - numberOfVectorsToInspect *= size; - } - numberOfVectorsToInspect /= shape[destroyedDimension]; - - auto vectorSelector = - llvm::SmallVector(shape.size(), 0); - - auto elementSelector = vectorSelector; - for (size_t n = 0; n < numberOfVectorsToInspect; n++) { - elementSelector.assign(vectorSelector); - - llvm::APInt accumulationNorm = - computeVectorNorm(shape, destroyedDimension, denseValues, - encryptedOperandNorm, elementSelector); - maximumNorm = APIntUMax(maximumNorm, accumulationNorm); - - determineNextVector(shape, destroyedDimension, vectorSelector); - } - - return maximumNorm; -} - -/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation -/// that is equivalent to an `FHE.mul_eint_int` operation. 
-static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintIntOp op, - llvm::ArrayRef operandMANPs) { - - auto lhsType = - ((mlir::Type)op.getLhs().getType()).cast(); - auto rhsType = - ((mlir::Type)op.getRhs().getType()).cast(); - - llvm::ArrayRef lhsShape = lhsType.getShape(); - llvm::ArrayRef rhsShape = rhsType.getShape(); - - int64_t lhsDims = (int64_t)lhsShape.size(); - int64_t rhsDims = (int64_t)rhsShape.size(); - - mlir::Type rhsElementType = rhsType.getElementType(); - assert(rhsElementType.isSignlessInteger() && - "Only multiplications with signless integers are currently allowed"); - - assert( - operandMANPs.size() == 2 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value(); - llvm::APInt accNorm = llvm::APInt{1, 0, false}; - - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - int64_t N = rhsDims <= 2 ? 
rhsShape[0] : rhsShape[rhsDims - 2]; - - if (denseVals) { - auto denseValsAP = denseVals.getValues(); - - if (lhsDims == 2 && rhsDims == 2) { - // MxN @ NxP -> MxP - - int64_t M = lhsShape[0]; - int64_t P = rhsShape[1]; - for (int64_t m = 0; m < M; m++) { - for (int64_t p = 0; p < P; p++) { - // The accumulator is initialized with 0s in all bits (not the - // encrypted 0) the there is no initial noise in the accumulator, so - // its MANP is initialized to 0 - llvm::APInt tmpNorm = llvm::APInt{1, 0, false}; - for (int64_t n = 0; n < N; n++) { - llvm::APInt cst = denseValsAP[{(uint64_t)n, (uint64_t)p}]; - llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst); - llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); - tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); - } - accNorm = APIntUMax(accNorm, tmpNorm); - } - } - - } else if (rhsDims == 1) { + llvm::ArrayRef operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().has_value() && + operandMANPs[1]->getValue().getMANP().has_value() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); - // MxN @ N -> M - // LxMxN @ N -> LxM - // KxLxMxN @ N -> KxLxM + llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().value(); + llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value(); - for (int64_t i = 0; i < N; i++) { - llvm::APInt cst = denseValsAP[i]; - llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst); - llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); - accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); - } + auto rhsType = + ((mlir::Type)op.getRhs().getType()).cast(); - } else if (rhsDims >= 2) { + llvm::ArrayRef rhsShape = rhsType.getShape(); - // KxLxMxN @ NxP -> KxLxMxP - // KxLxMxN @ LxNxP -> KxLxMxP - // Kx1xMxN @ LxNxP -> KxLxMxP + int64_t rhsDims = (int64_t)rhsShape.size(); - // MxN @ KxLxNxP -> KxLxMxP - // LxMxN @ KxLxNxP -> KxLxMxP - // 1xMxN @ KxLxNxP -> KxLxMxP + assert(rhsDims == 1 && "In MANP 
computation dot product RHS expected to have " + "a single dimension"); - // N @ NxP -> P - // N @ LxNxP -> LxP - // N @ KxLxNxP -> KxLxP + int64_t N = rhsShape[0]; - accNorm = calculateSqManpForMatMulWithDenseValues(rhsShape, rhsDims - 2, - denseVals, lhsNorm); - } + // Compute output MANP: + // Tlu output MANP is 1 + llvm::APInt tlu = {1, 1, false}; + // The element-wise multiplication is given by the + // subtraction of two TLU outputs. The MANP of the multiplication is thus + // the sum of the TLU MANPs + llvm::APInt elemMulNorm = APIntWidthExtendUAdd(tlu, tlu); - } else { - llvm::APInt rhsNorm = conservativeIntNorm2Sq(rhsElementType); - for (int64_t i = 0; i < N; i++) { - llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); - accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); - } - } + llvm::APInt accNorm = llvm::APInt{1, 0, false}; - return accNorm; + return APIntWidthExtendUMul(APIntWidthExtendUAdd(tlu, tlu), + llvm::APInt(ceilLog2(N + 1), N, false)); } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulIntEintOp op, - llvm::ArrayRef operandMANPs) { - - auto lhsType = - ((mlir::Type)op.getLhs().getType()).cast(); - auto rhsType = - ((mlir::Type)op.getRhs().getType()).cast(); +/// Calculates the squared Minimal Arithmetic Noise Padding of an unary FHE +/// operation. 
+static std::optional +getSqMANP(mlir::concretelang::FHE::UnaryEint op, + llvm::ArrayRef operandMANPs) { + // not all unary ops taking an encrypted operand have a type signature + // reflecting that, a check might be required (FHELinalg.TransposeOp is one + // such known op) + if (op.operandIntType().isa()) { + assert(operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().has_value() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operand"); + return op.sqMANP(operandMANPs[0]->getValue().getMANP().value()); + } else + return {}; +} - llvm::ArrayRef lhsShape = lhsType.getShape(); - llvm::ArrayRef rhsShape = rhsType.getShape(); +/// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE +/// operation with the first operand encrypted. - int64_t lhsDims = (int64_t)lhsShape.size(); - int64_t rhsDims = (int64_t)rhsShape.size(); +static std::optional +getSqMANP(mlir::concretelang::FHE::BinaryEintInt op, + llvm::ArrayRef operandMANPs) { + assert(operandMANPs.size() >= 2 && // conv2d has an optional 3rd operand + operandMANPs[0]->getValue().getMANP().has_value() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operand"); + return op.sqMANP(operandMANPs[0]->getValue().getMANP().value()); +} - mlir::Type lhsElementType = lhsType.getElementType(); - assert(lhsElementType.isSignlessInteger() && - "Only multiplications with signless integers are currently allowed"); +/// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE +/// operation with the second operand encrypted. 
+static std::optional +getSqMANP(mlir::concretelang::FHE::BinaryIntEint op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() == 2 && operandMANPs[1]->getValue().getMANP().has_value() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + return op.sqMANP(operandMANPs[1]->getValue().getMANP().value()); +} - llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().value(); - llvm::APInt accNorm = llvm::APInt{1, 0, false}; - - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(0).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - int64_t N = rhsDims <= 2 ? rhsShape[0] : rhsShape[rhsDims - 2]; +/// Calculates the squared Minimal Arithmetic Noise Padding of a binary FHE +/// operation with both operands encrypted. - if (denseVals) { - auto denseValsAP = denseVals.getValues(); +static std::optional +getSqMANP(mlir::concretelang::FHE::BinaryEint op, + llvm::ArrayRef operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().has_value() && + operandMANPs[1]->getValue().getMANP().has_value() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); - if (lhsDims == 2 && rhsDims == 2) { + return op.sqMANP(operandMANPs[0]->getValue().getMANP().value(), + operandMANPs[1]->getValue().getMANP().value()); +} - // MxN @ NxP -> MxP +static llvm::APInt sqMANP_mul_eint_int(llvm::APInt a, mlir::Type iTy, + std::optional b) { + assert(iTy.isSignlessInteger() && + "Only multiplications with signless integers are currently allowed"); - int64_t M = lhsShape[0]; - int64_t P = rhsShape[1]; - for (int64_t m = 0; m < M; m++) { - for (int64_t p = 0; p < P; p++) { - llvm::APInt tmpNorm = llvm::APInt{1, 1, false}; - for (int64_t n = 0; n < N; n++) { - llvm::APInt cst = denseValsAP[{(uint64_t)m, (uint64_t)n}]; - llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst); - llvm::APInt mulNorm = 
APIntWidthExtendUMul(lhsNorm, rhsNorm); - tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); - } - accNorm = APIntUMax(accNorm, tmpNorm); - } - } + llvm::APInt sqNorm; + if (b.has_value()) { + // For a constant operand use actual constant to calculate 2-norm + sqNorm = APIntWidthExtendSqForConstant(b.value()); + } else { + // For a dynamic operand conservatively assume that the value is + // the maximum for the integer width + sqNorm = conservativeIntNorm2Sq(iTy); + } - } else if (lhsDims == 1) { + return APIntWidthExtendUMul(sqNorm, a); +} - // N @ NxP -> P - // N @ LxNxP -> LxP - // N @ KxLxNxP -> KxLxP +static llvm::APInt sqMANP_mul_eint(llvm::APInt a, llvm::APInt b) { + // a * b = ((a + b)^2 / 4) - ((a - b)^2 / 4) == tlu(a + b) - tlu(a - b) + const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(a, b); + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); - for (int64_t i = 0; i < N; i++) { - llvm::APInt cst = denseValsAP[i]; - llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst); - llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); - accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); - } + return result; +} - } else if (lhsDims >= 2) { +/// Computes the squared vector norm as the maximum of all dot products over the +/// destroyed dimension. The computation is recursive and can be seen as folding +/// a tree of values where leaves compute the dot products and nodes choose the +/// maximum of the children. There is a branching node at every shape dimension +/// with the fanout equal to that dimension range, except for the destroyed +/// dimension (where the fanout is 1). 
+/// +/// This function is expected to behave correctly on clear tensor shapes with +/// any dimensionality, for example: +/// +/// MxN @ N -> M +/// LxMxN @ N -> LxM +/// KxLxMxN @ N -> KxLxM +/// +/// N @ NxP -> P +/// N @ LxNxP -> LxP +/// N @ KxLxNxP -> KxLxP +/// +/// KxLxMxN @ NxP -> KxLxMxP +/// KxLxMxN @ LxNxP -> KxLxMxP +/// Kx1xMxN @ LxNxP -> KxLxMxP +/// +/// MxN @ KxLxNxP -> KxLxMxP +/// LxMxN @ KxLxNxP -> KxLxMxP +/// 1xMxN @ KxLxNxP -> KxLxMxP +/// +/// N @ NxP -> P +/// N @ LxNxP -> LxP +/// N @ KxLxNxP -> KxLxP +/// +/// MxN @ N -> M +/// LxMxN @ N -> LxM +/// KxLxMxN @ N -> KxLxM +/// +/// MxN @ NxP -> MxP + +static llvm::APInt sqMANP_matmul_internal( + llvm::ArrayRef shape, size_t destroyedDimension, + llvm::SmallVector iterPoint, + mlir::detail::ElementsAttrRange + clearValues, + llvm::APInt encryptedOperandNorm) { + assert(iterPoint.size() >= shape.size() && + "Tensor shape dimensionality is larger than iteration space " + "dimensionality"); + assert(destroyedDimension < iterPoint.size() && + "Destroyed dimension outside of iteration space dimensionality"); + size_t currentDimension = iterPoint.size() - shape.size(); + + if (currentDimension == destroyedDimension) { + // the dot product over destroyed dimension will sum products counting down + // from the largest index + iterPoint[currentDimension] = shape[0] - 1; + return sqMANP_matmul_internal(shape.drop_front(1), destroyedDimension, + iterPoint, clearValues, encryptedOperandNorm); + } - // KxLxMxN @ NxP -> KxLxMxP - // KxLxMxN @ LxNxP -> KxLxMxP - // Kx1xMxN @ LxNxP -> KxLxMxP + if (shape.size() == 0) { // `iterPoint` is defined in all indices, let's + // compute the dot product + llvm::APInt accumulationNorm = llvm::APInt{1, 0, false}; + for (int64_t i = iterPoint[destroyedDimension]; i >= 0; i--) { + iterPoint[destroyedDimension] = i; + + llvm::APInt weight = clearValues[iterPoint]; + llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight); + llvm::APInt multiplicationNorm = + 
APIntWidthExtendUMul(encryptedOperandNorm, weightNorm); + accumulationNorm = + APIntWidthExtendUAdd(multiplicationNorm, accumulationNorm); + } + return accumulationNorm; + } else { // descend into all indices in current dimension + llvm::APInt maximumNorm = llvm::APInt{1, 1, false}; + for (int64_t i = 0; i < shape[0]; i++) { + iterPoint[currentDimension] = i; + llvm::APInt accumulationNorm = + sqMANP_matmul_internal(shape.drop_front(1), destroyedDimension, + iterPoint, clearValues, encryptedOperandNorm); + maximumNorm = APIntUMax(maximumNorm, accumulationNorm); + } + return maximumNorm; + } +} - // MxN @ KxLxNxP -> KxLxMxP - // LxMxN @ KxLxNxP -> KxLxMxP - // 1xMxN @ KxLxNxP -> KxLxMxP +/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +/// that is equivalent to an `FHE.mul_eint_int` operation. +static llvm::APInt +sqMANP_matmul(llvm::APInt encryptedOperandNorm, + mlir::RankedTensorType clearOperandType, + std::optional> + clearVals, + unsigned clearOpNum) { + + assert(clearOperandType.getElementType().isSignlessInteger() && + "Only multiplications with signless integers are currently allowed"); - // MxN @ N -> M - // LxMxN @ N -> LxM - // KxLxMxN @ N -> KxLxM + llvm::ArrayRef clearOperandShape = clearOperandType.getShape(); + uint64_t clearOperandDims = (uint64_t)clearOperandShape.size(); + // if the clear operand is LHS (index 0), then the destroyed dimension is its + // last (dims-1) if the clear operand is RHS (index 1), then the destroyed + // dimension is its second to last (dims-2) + assert(clearOpNum <= 1 && "Cannot determine destroyed dimension: operation " + "has more than 2 operands"); + size_t destroyedDimension = + clearOperandDims == 1 ? 
0 : clearOperandDims - 1 - clearOpNum; - accNorm = calculateSqManpForMatMulWithDenseValues(lhsShape, lhsDims - 1, - denseVals, rhsNorm); - } + llvm::APInt accNorm = llvm::APInt{1, 0, false}; - } else { - llvm::APInt lhsNorm = conservativeIntNorm2Sq(lhsElementType); - for (int64_t i = 0; i < N; i++) { - llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); - accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); - } + if (clearVals.has_value()) + accNorm = + sqMANP_matmul_internal(clearOperandShape, destroyedDimension, + llvm::SmallVector( + clearOperandShape.size(), 0), + clearVals.value(), encryptedOperandNorm); + else { + llvm::APInt clearOperandNorm = + conservativeIntNorm2Sq(clearOperandType.getElementType()); + llvm::APInt mulNorm = + APIntWidthExtendUMul(encryptedOperandNorm, clearOperandNorm); + uint64_t N = clearOperandShape[destroyedDimension]; + unsigned int Nbits = ceilLog2(N + 1); + mulNorm = APIntWidthExtendUMul(mulNorm, APInt{Nbits, N, false}); + accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); } return accNorm; @@ -1106,17 +532,6 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MatMulEintEintOp op, return accNorm; } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::TransposeOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - return operandMANPs[0]->getValue().getMANP().value(); -} - static llvm::APInt getSqMANP(mlir::tensor::ExtractOp op, llvm::ArrayRef operandMANPs) { @@ -1129,19 +544,9 @@ static llvm::APInt getSqMANP(mlir::tensor::ExtractOp op, return eNorm; } -static llvm::APInt getSqMANP(FHELinalg::FromElementOp op, - llvm::ArrayRef operandMANPs) { - - auto manp = operandMANPs[0]->getValue().getMANP(); - if (manp.has_value()) { - return manp.value(); - } - - return llvm::APInt{1, 1, false}; -} - -static llvm::APInt getSqMANP(mlir::tensor::FromElementsOp 
op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::tensor::FromElementsOp op, + llvm::ArrayRef operandMANPs) { auto max = std::max_element(operandMANPs.begin(), operandMANPs.end(), [](const MANPLattice *a, const MANPLattice *b) { @@ -1152,8 +557,9 @@ static llvm::APInt getSqMANP(mlir::tensor::FromElementsOp op, return (*max)->getValue().getMANP().value(); } -static llvm::APInt getSqMANP(mlir::tensor::ExtractSliceOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::tensor::ExtractSliceOp op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs[0]->getValue().getMANP().has_value() && @@ -1162,8 +568,9 @@ static llvm::APInt getSqMANP(mlir::tensor::ExtractSliceOp op, return operandMANPs[0]->getValue().getMANP().value(); } -static llvm::APInt getSqMANP(mlir::tensor::InsertSliceOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::tensor::InsertSliceOp op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 2 && @@ -1175,8 +582,9 @@ static llvm::APInt getSqMANP(mlir::tensor::InsertSliceOp op, operandMANPs[1]->getValue().getMANP().value()); } -static llvm::APInt getSqMANP(mlir::tensor::InsertOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::tensor::InsertOp op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 2 && @@ -1188,8 +596,9 @@ static llvm::APInt getSqMANP(mlir::tensor::InsertOp op, operandMANPs[1]->getValue().getMANP().value()); } -static llvm::APInt getSqMANP(mlir::tensor::CollapseShapeOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::tensor::CollapseShapeOp op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 1 && @@ -1199,8 +608,9 @@ static llvm::APInt getSqMANP(mlir::tensor::CollapseShapeOp op, return operandMANPs[0]->getValue().getMANP().value(); } -static llvm::APInt getSqMANP(mlir::tensor::ExpandShapeOp op, - llvm::ArrayRef operandMANPs) { +static std::optional 
+getSqMANP(mlir::tensor::ExpandShapeOp op, + llvm::ArrayRef operandMANPs) { assert( operandMANPs.size() >= 1 && @@ -1210,8 +620,9 @@ static llvm::APInt getSqMANP(mlir::tensor::ExpandShapeOp op, return operandMANPs[0]->getValue().getMANP().value(); } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SumOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::concretelang::FHELinalg::SumOp op, + llvm::ArrayRef operandMANPs) { auto inputType = op.getOperand().getType().dyn_cast(); @@ -1252,8 +663,9 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::SumOp op, return APIntWidthExtendUMul(noiseMultiplier, operandMANP); } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op, - llvm::ArrayRef operandMANPs) { +static std::optional +getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op, + llvm::ArrayRef operandMANPs) { llvm::APInt result = llvm::APInt{1, 0, false}; for (const MANPLattice *operandMANP : operandMANPs) { @@ -1265,30 +677,11 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::ConcatOp op, return result; } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op, - llvm::ArrayRef operandMANPs) { - - mlir::RankedTensorType weightTy = - op.getWeight().getType().cast(); - - mlir::Type weightIntType = weightTy.getElementType(); - - // Bias is optional, so we can have both 2 or 3 operands - assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted " - "operand"); - - llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().value(); - - mlir::arith::ConstantOp weightCstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - mlir::DenseIntElementsAttr weightDenseVals = - weightCstOp - ? 
weightCstOp->getAttrOfType("value") - : nullptr; - +static llvm::APInt +sqMANP_conv2d(llvm::APInt inputNorm, mlir::RankedTensorType weightTy, + std::optional> + weightVals) { // Initial value of the accumulator to 0 llvm::APInt accNorm = llvm::APInt{1, 0, false}; @@ -1297,8 +690,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op, uint64_t C = weightTy.getShape()[1]; uint64_t H = weightTy.getShape()[2]; uint64_t W = weightTy.getShape()[3]; - if (weightDenseVals) { - auto weightDenseValsAP = weightDenseVals.getValues(); + if (weightVals.has_value()) { // For a constant weight kernel use actual constant to calculate 2-norm // input windows are being multiplied by a kernel and summed up for (uint64_t f = 0; f < F; f++) { @@ -1307,7 +699,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op, for (uint64_t c = 0; c < C; c++) { for (uint64_t h = 0; h < H; h++) { for (uint64_t w = 0; w < W; w++) { - llvm::APInt cst = weightDenseValsAP[{f, c, h, w}]; + llvm::APInt cst = weightVals.value()[{f, c, h, w}]; llvm::APInt weightNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm); tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); @@ -1320,7 +712,7 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op, } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width - llvm::APInt weightNorm = conservativeIntNorm2Sq(weightIntType); + llvm::APInt weightNorm = conservativeIntNorm2Sq(weightTy.getElementType()); // For a weight (kernel) of shape tensor, there is C*H*W // FHE.mul_eint_int and FHE.add_eint operations for each elements of the // result @@ -1333,57 +725,6 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Conv2dOp op, return accNorm; } -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::Maxpool2dOp op, - llvm::ArrayRef operandMANPs) { - - // maximum between two value is calculated 
using - // - max(x - y, 0) + y - - // max is calculated with a TLU so MANP is {1, 1, false} - // y on the other hand comes from the input or from the previous result - - // in the current implementation, it's the input - // so the resulting MANP is `{1, 1, false} + MANP input` - - const llvm::APInt tlu = {1, 1, false}; - const llvm::APInt input = operandMANPs[0]->getValue().getMANP().value(); - - const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, input); - const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, input); - - return APIntUMax(forIntermediate, forResult); -} - -static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::RoundOp op, - llvm::ArrayRef operandMANPs) { - - assert( - operandMANPs.size() == 1 && - operandMANPs[0]->getValue().getMANP().has_value() && - "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - - const uint64_t inputWidth = op.getOperand() - .getType() - .cast() - .getElementType() - .cast() - .getWidth(); - - const uint64_t outputWidth = op.getResult() - .getType() - .cast() - .getElementType() - .cast() - .getWidth(); - - const uint64_t clearedBits = inputWidth - outputWidth; - - llvm::APInt result = operandMANPs[0]->getValue().getMANP().value(); - result += clearedBits; - - return result; -} - class MANPAnalysis : public mlir::dataflow::SparseDataFlowAnalysis { public: @@ -1406,119 +747,34 @@ class MANPAnalysis ArrayRef results) override { MANPLattice *latticeRes = results[0]; - bool isDummy = false; - llvm::APInt norm2SqEquiv; - - // FHE Operators - if (auto addEintIntOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(addEintIntOp, operands); - } else if (auto addEintOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(addEintOp, operands); - } else if (auto subIntEintOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(subIntEintOp, operands); - } else if (auto subEintIntOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(subEintIntOp, operands); - } else if (auto subEintOp 
= - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(subEintOp, operands); - } else if (auto negEintOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(negEintOp, operands); - } else if (auto boolNotOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(boolNotOp, operands); - } else if (auto toSignedOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(toSignedOp, operands); - } else if (auto toUnsignedOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(toUnsignedOp, operands); - } else if (auto mulEintIntOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(mulEintIntOp, operands); - } else if (auto mulEintOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(mulEintOp, operands); - } else if (auto roundOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(roundOp, operands); - } else if (auto maxEintOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(maxEintOp, operands); - } else if (llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op)) { + std::optional norm2SqEquiv; + + if (auto cstNoiseOp = + llvm::dyn_cast(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = getNoOpSqMANP(operands); } - // FHELinalg Operators - else if (auto dotOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(dotOp, operands); + // FHE and FHELinalg Operators + else if (auto unaryEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(unaryEintOp, operands); + } else if (auto binaryEintIntOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(binaryEintIntOp, operands); + } else if (auto binaryIntEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(binaryIntEintOp, operands); + } else if (auto binaryEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(binaryEintOp, operands); } else if (auto dotEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(dotEintOp, operands); - } else if (auto addEintIntOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(addEintIntOp, operands); - } else 
if (auto addEintOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(addEintOp, operands); - } else if (auto subIntEintOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(subIntEintOp, operands); - } else if (auto subEintIntOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(subEintIntOp, operands); - } else if (auto subEintOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(subEintOp, operands); - } else if (auto negEintOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(negEintOp, operands); - } else if (auto toSignedOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(toSignedOp, operands); - } else if (auto toUnsignedOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(toUnsignedOp, operands); - } else if (auto mulEintIntOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(mulEintIntOp, operands); - } else if (auto mulEintOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(mulEintOp, operands); - } else if (auto matmulEintIntOp = llvm::dyn_cast< - mlir::concretelang::FHELinalg::MatMulEintIntOp>(op)) { - norm2SqEquiv = getSqMANP(matmulEintIntOp, operands); - } else if (auto matmulIntEintOp = llvm::dyn_cast< - mlir::concretelang::FHELinalg::MatMulIntEintOp>(op)) { - norm2SqEquiv = getSqMANP(matmulIntEintOp, operands); } else if (auto matmulEintEintOp = llvm::dyn_cast< mlir::concretelang::FHELinalg::MatMulEintEintOp>(op)) { norm2SqEquiv = getSqMANP(matmulEintEintOp, operands); - } else if (llvm::isa< - mlir::concretelang::FHELinalg::ApplyLookupTableEintOp, - mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp, - mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>( - op)) { - norm2SqEquiv = llvm::APInt{1, 1, false}; } else if (auto sumOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(sumOp, operands); @@ -1526,35 +782,14 @@ class MANPAnalysis llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(concatOp, operands); - } else if (auto conv2dOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = 
getSqMANP(conv2dOp, operands); - } else if (auto maxpool2dOp = - llvm::dyn_cast( - op)) { - norm2SqEquiv = getSqMANP(maxpool2dOp, operands); } else if (auto fromElementOp = llvm::dyn_cast( op)) { - norm2SqEquiv = getSqMANP(fromElementOp, operands); - } else if (auto transposeOp = - llvm::dyn_cast( - op)) { - if (transposeOp.getTensor() - .getType() - .cast() - .getElementType() - .isa()) { - norm2SqEquiv = getSqMANP(transposeOp, operands); - } else { - isDummy = true; - } - } else if (auto roundOp = - llvm::dyn_cast(op)) { - norm2SqEquiv = getSqMANP(roundOp, operands); + if (operands[0]->getValue().getMANP().has_value()) { + norm2SqEquiv = operands[0]->getValue().getMANP().value(); + } else + norm2SqEquiv = llvm::APInt{1, 1, false}; } - // Tensor Operators // ExtractOp else if (auto extractOp = llvm::dyn_cast(op)) { @@ -1563,7 +798,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(extractOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // ExtractSliceOp @@ -1576,7 +811,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(extractSliceOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // InsertOp @@ -1588,7 +823,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(insertOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // InsertSliceOp @@ -1601,7 +836,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(insertSliceOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // FromElementOp @@ -1613,7 +848,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(fromOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // TensorCollapseShapeOp @@ -1626,7 +861,7 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } // TensorExpandShapeOp @@ -1638,31 +873,31 @@ class MANPAnalysis .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { - isDummy = true; + norm2SqEquiv = {}; } } else if 
(llvm::isa(op)) { - isDummy = true; + norm2SqEquiv = {}; } else if (llvm::isa( *op->getDialect())) { op->emitError("Unsupported operation"); assert(false && "Unsupported operation"); } else { - isDummy = true; + norm2SqEquiv = {}; } - if (!isDummy) { + if (norm2SqEquiv.has_value()) { latticeRes->join(MANPLatticeValue{norm2SqEquiv}); op->setAttr("SMANP", mlir::IntegerAttr::get( mlir::IntegerType::get( - op->getContext(), norm2SqEquiv.getBitWidth(), + op->getContext(), norm2SqEquiv.value().getBitWidth(), mlir::IntegerType::SignednessSemantics::Unsigned), - norm2SqEquiv)); + norm2SqEquiv.value())); - llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv); + llvm::APInt norm2Equiv = APIntCeilSqrt(norm2SqEquiv.value()); op->setAttr("MANP", mlir::IntegerAttr::get( @@ -1673,7 +908,7 @@ class MANPAnalysis if (debug) { op->emitRemark("Squared Minimal Arithmetic Noise Padding: ") - << APIntToStringValUnsigned(norm2SqEquiv) << "\n"; + << APIntToStringValUnsigned(norm2SqEquiv.value()) << "\n"; } } else { latticeRes->join(MANPLatticeValue{}); @@ -1685,6 +920,148 @@ class MANPAnalysis }; } // namespace +namespace FHE { +llvm::APInt AddEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return APIntWidthExtendUAdd(a, b); +} + +llvm::APInt SubEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return APIntWidthExtendUAdd(a, b); +} + +llvm::APInt MulEintIntOp::sqMANP(llvm::APInt a) { + return sqMANP_mul_eint_int( + a, this->operandIntType(this->getClearOperandNumber()), + this->operandMaxConstant(this->getClearOperandNumber())); +} + +llvm::APInt MulEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return sqMANP_mul_eint(a, b); +} + +llvm::APInt MaxEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + // max(a, b) = max(a - b, 0) + b + const llvm::APInt sub = APIntWidthExtendUAdd(a, b); + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt add = APIntWidthExtendUAdd(tlu, b); + + // this is not optimal as it can increase the resulting noise unnecessarily + return APIntUMax(add, sub); +} + 
+llvm::APInt RoundEintOp::sqMANP(llvm::APInt a) { + uint64_t inputWidth = + this->getOperand().getType().cast().getWidth(); + uint64_t outputWidth = + this->getResult().getType().cast().getWidth(); + uint64_t clearedBits = inputWidth - outputWidth; + + return a + clearedBits; +} +} // namespace FHE + +namespace FHELinalg { +llvm::APInt AddEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return APIntWidthExtendUAdd(a, b); +} + +llvm::APInt SubEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return APIntWidthExtendUAdd(a, b); +} + +llvm::APInt MulEintIntOp::sqMANP(llvm::APInt a) { + return sqMANP_mul_eint_int( + a, this->operandIntType(this->getClearOperandNumber()), + this->operandMaxConstant(this->getClearOperandNumber())); +} + +llvm::APInt MulEintOp::sqMANP(llvm::APInt a, llvm::APInt b) { + return sqMANP_mul_eint(a, b); +} + +llvm::APInt Dot::sqMANP(llvm::APInt a) { + unsigned clearOpNum = this->getClearOperandNumber(); + auto clearOperandType = this->getOperation() + ->getOpOperand(clearOpNum) + .get() + .getType() + .cast(); + return sqMANP_matmul(a, clearOperandType, this->opTensorConstant(clearOpNum), + clearOpNum); +} + +llvm::APInt MatMulEintIntOp::sqMANP(llvm::APInt a) { + unsigned clearOpNum = this->getClearOperandNumber(); + auto clearOpType = this->getOperation() + ->getOpOperand(clearOpNum) + .get() + .getType() + .cast(); + return sqMANP_matmul(a, clearOpType, this->opTensorConstant(clearOpNum), + clearOpNum); +} + +llvm::APInt MatMulIntEintOp::sqMANP(llvm::APInt a) { + unsigned clearOpNum = this->getClearOperandNumber(); + assert(clearOpNum <= 1 && "Operation has more than 2 operands"); + auto clearOpType = this->getOperation() + ->getOpOperand(clearOpNum) + .get() + .getType() + .cast(); + return sqMANP_matmul(a, clearOpType, this->opTensorConstant(clearOpNum), + clearOpNum); +} + +llvm::APInt Conv2dOp::sqMANP(llvm::APInt a) { + unsigned clearOpNum = this->getClearOperandNumber(); + auto clearOpType = this->getOperation() + 
->getOpOperand(clearOpNum) + .get() + .getType() + .cast(); + return sqMANP_conv2d(a, clearOpType, this->opTensorConstant(clearOpNum)); +} + +llvm::APInt Maxpool2dOp::sqMANP(llvm::APInt a) { + // maximum between two value is calculated using + // - max(x - y, 0) + y + + // max is calculated with a TLU so MANP is {1, 1, false} + // y on the other hand comes from the input or from the previous result + + // in the current implementation, it's the input + // so the resulting MANP is `{1, 1, false} + MANP input` + + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt forResult = APIntWidthExtendUAdd(tlu, a); + const llvm::APInt forIntermediate = APIntWidthExtendUAdd(forResult, a); + + return APIntUMax(forIntermediate, forResult); +} + +llvm::APInt RoundOp::sqMANP(llvm::APInt a) { + const uint64_t inputWidth = this->getOperand() + .getType() + .cast() + .getElementType() + .cast() + .getWidth(); + + const uint64_t outputWidth = this->getResult() + .getType() + .cast() + .getElementType() + .cast() + .getWidth(); + + const uint64_t clearedBits = inputWidth - outputWidth; + + return a + clearedBits; +} + +} // namespace FHELinalg + namespace { /// For documentation see MANP.td struct MANPPass : public MANPBase { diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/CMakeLists.txt index 306b439685..de4a2530ae 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Analysis) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/CMakeLists.txt index eee3fb3a07..73999d5531 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/CMakeLists.txt +++ 
b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library( mlir-headers LINK_LIBS PUBLIC - MLIRIR) + MLIRIR + FHEInterfaces) target_link_libraries(FHEDialect PUBLIC MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp index 3bea0d3a95..bfa6e5121e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp @@ -6,8 +6,7 @@ #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" - -#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.cpp.inc" +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" #define GET_TYPEDEF_CLASSES #include "concretelang/Dialect/FHE/IR/FHEOpsTypes.cpp.inc" @@ -30,7 +29,7 @@ void FHEDialect::initialize() { >(); } -mlir::LogicalResult EncryptedIntegerType::verify( +mlir::LogicalResult EncryptedUnsignedIntegerType::verify( llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { if (p == 0) { emitError() << "FHE.eint doesn't support precision of 0"; @@ -39,11 +38,11 @@ mlir::LogicalResult EncryptedIntegerType::verify( return mlir::success(); } -void EncryptedIntegerType::print(mlir::AsmPrinter &p) const { +void EncryptedUnsignedIntegerType::print(mlir::AsmPrinter &p) const { p << "<" << getWidth() << ">"; } -mlir::Type EncryptedIntegerType::parse(mlir::AsmParser &p) { +mlir::Type EncryptedUnsignedIntegerType::parse(mlir::AsmParser &p) { if (p.parseLess()) return mlir::Type(); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index aadb6d0b01..844235c449 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ 
b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -209,7 +209,7 @@ mlir::LogicalResult MaxEintOp::verify() { } mlir::LogicalResult ToSignedOp::verify() { - auto input = this->getInput().getType().cast(); + auto input = this->getInput().getType().cast(); auto output = this->getResult().getType().cast(); if (input.getWidth() != output.getWidth()) { @@ -223,7 +223,8 @@ mlir::LogicalResult ToSignedOp::verify() { mlir::LogicalResult ToUnsignedOp::verify() { auto input = this->getInput().getType().cast(); - auto output = this->getResult().getType().cast(); + auto output = + this->getResult().getType().cast(); if (input.getWidth() != output.getWidth()) { this->emitOpError( @@ -235,7 +236,7 @@ mlir::LogicalResult ToUnsignedOp::verify() { } mlir::LogicalResult ToBoolOp::verify() { - auto input = this->getInput().getType().cast(); + auto input = this->getInput().getType().cast(); if (input.getWidth() != 1 && input.getWidth() != 2) { this->emitOpError("should have 1 or 2 as the width of encrypted input to " diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/CMakeLists.txt new file mode 100644 index 0000000000..cb1fa7c114 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library( + FHEInterfaces + FHEInterfaces.cpp + FHEInterfacesInstances.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfaces.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfaces.cpp new file mode 100644 index 0000000000..32d1a5c7d9 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfaces.cpp @@ -0,0 +1,9 @@ +// Part of the Concrete Compiler Project, under 
the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" + +#include "concretelang/Dialect/FHE/Interfaces/FHEOpsInterfaces.cpp.inc" +#include "concretelang/Dialect/FHE/Interfaces/FHETypesInterfaces.cpp.inc" diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp new file mode 100644 index 0000000000..5980246f6e --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Interfaces/FHEInterfacesInstances.cpp @@ -0,0 +1,24 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfacesInstances.h" +#include "concretelang/Dialect/FHE/IR/FHEDialect.h" +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +namespace mlir { +namespace concretelang { +namespace FHE { + +using namespace mlir::tensor; + +void registerFheInterfacesExternalModels(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { + ExtractOp::attachInterface(*ctx); + }); +} +} // namespace FHE +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp index 6d822c4877..7fa7b96353 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp @@ -49,11 +49,11 @@ namespace typing { /// Converts `FHE::ChunkedEncryptedInteger` into a 
tensor of /// `FHE::EncryptedInteger`. -mlir::RankedTensorType convertChunkedEint(mlir::MLIRContext *context, - FHE::EncryptedIntegerType chunkedEint, - unsigned int chunkSize, - unsigned int chunkWidth) { - auto eint = FHE::EncryptedIntegerType::get(context, chunkSize); +mlir::RankedTensorType +convertChunkedEint(mlir::MLIRContext *context, + FHE::EncryptedUnsignedIntegerType chunkedEint, + unsigned int chunkSize, unsigned int chunkWidth) { + auto eint = FHE::EncryptedUnsignedIntegerType::get(context, chunkSize); auto bigIntWidth = chunkedEint.getWidth(); assert(bigIntWidth % chunkWidth == 0 && "chunkWidth must divide width of the big integer"); @@ -68,14 +68,15 @@ class TypeConverter : public mlir::TypeConverter { public: TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) { addConversion([](mlir::Type type) { return type; }); - addConversion([chunkSize, chunkWidth](FHE::EncryptedIntegerType type) { - if (type.getWidth() > chunkSize) { - return (mlir::Type)convertChunkedEint(type.getContext(), type, - chunkSize, chunkWidth); - } else { - return (mlir::Type)type; - } - }); + addConversion( + [chunkSize, chunkWidth](FHE::EncryptedUnsignedIntegerType type) { + if (type.getWidth() > chunkSize) { + return (mlir::Type)convertChunkedEint(type.getContext(), type, + chunkSize, chunkWidth); + } else { + return (mlir::Type)type; + } + }); } }; @@ -100,7 +101,7 @@ class AddEintPattern "chunked integer should be converted to flat tensors, but tensor " "have more than one dimension"); auto eintChunkWidth = tensorType.getElementType() - .dyn_cast() + .dyn_cast() .getWidth(); assert(eintChunkWidth == chunkSize && "wrong tensor elements width"); auto numberOfChunks = shape[0]; @@ -108,7 +109,7 @@ class AddEintPattern mlir::Value carry = rewriter .create(op.getLoc(), - FHE::EncryptedIntegerType::get( + FHE::EncryptedUnsignedIntegerType::get( rewriter.getContext(), chunkSize)) .getResult(); @@ -142,8 +143,8 @@ class AddEintPattern carry = rewriter.create( op.getLoc(), - 
FHE::EncryptedIntegerType::get(rewriter.getContext(), - chunkSize), + FHE::EncryptedUnsignedIntegerType::get(rewriter.getContext(), + chunkSize), resultWithCarry, getTruthTableCarryExtract(rewriter, op.getLoc(), chunkSize, chunkWidth)); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp index 4ed18f3e32..a2e07cee96 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp @@ -29,7 +29,7 @@ class GenGatePattern mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHE::GenGateOp op, mlir::PatternRewriter &rewriter) const override { - auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( + auto eint2 = mlir::concretelang::FHE::EncryptedUnsignedIntegerType::get( rewriter.getContext(), 2); auto left = rewriter .create( @@ -104,7 +104,7 @@ class MuxOpPattern mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHE::MuxOp op, mlir::PatternRewriter &rewriter) const override { - auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( + auto eint2 = mlir::concretelang::FHE::EncryptedUnsignedIntegerType::get( rewriter.getContext(), 2); auto boolType = mlir::concretelang::FHE::EncryptedBooleanType::get( rewriter.getContext()); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 2c7dc02e52..f5e08169ba 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1350,7 +1350,7 @@ mlir::LogicalResult ToSignedOp::verify() { } auto inputElementType = - inputType.getElementType().cast(); + inputType.getElementType().cast(); auto outputElementType = outputType.getElementType().cast(); @@ 
-1381,7 +1381,7 @@ mlir::LogicalResult ToUnsignedOp::verify() { auto inputElementType = inputType.getElementType().cast(); auto outputElementType = - outputType.getElementType().cast(); + outputType.getElementType().cast(); if (inputElementType.getWidth() != outputElementType.getWidth()) { this->emitOpError() diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir index c1fd83034a..0a128806bb 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir @@ -8,11 +8,3 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } - -// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe) -> !TFHE.glwe -func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool { - // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe) -> !TFHE.glwe - // CHECK-NEXT: return %0 : !TFHE.glwe - %1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool - return %1: !FHE.ebool -} diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index 16ede990e8..eadd19c1d8 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -199,7 +199,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = 
"FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir index ba34a0f48d..b9de103ec6 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir @@ -25,7 +25,7 @@ func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : ten func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { %bias = arith.constant dense<[5]> : tensor<1xi3> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -35,7 +35,7 @@ func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tens // ----- func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : 
tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -45,7 +45,7 @@ func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weigh // ----- func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 25 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index e48c3d9f9d..4cf20197b7 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -207,7 +207,7 @@ func.func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.e func.func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> { %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) 
: (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> @@ -226,7 +226,7 @@ func.func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor< // ----- func.func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> { - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> @@ -253,7 +253,7 @@ func.func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> func.func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.eint<2> { - // sqrt(1*(2^2-1)^2*4) = 16 + // sqrt(1^2*(2^2-1)^2*4) = 6 // CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> %0 = "FHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> @@ -265,12 +265,12 @@ func.func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FH func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, 
%i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> %cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3> // sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 28 : ui{{[[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -281,12 +281,11 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> - // sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I:.*]]) {MANP = 42 : ui{{[0-9]+}}} + // sqrt(3^2*(2^2-1)^2*4) = sqrt(324) = 18 + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I:.*]]) {MANP = 18 : ui{{[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -304,8 +303,7 @@ func.func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tenso // mul = manp(mul_eint_int(eint<2>, i3) = 
1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 9 // ceil(sqrt(9)) = 3 - // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed - // CHECK: %[[V0:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // CHECK: %[[V0:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %0 : tensor<3x2x!FHE.eint<2>> } @@ -321,8 +319,7 @@ func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tenso // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 9 + 9 = 18 // ceil(sqrt(18)) = 5 - // FIXME: the dynamic clear value MANP computation is wrong, update the MANP to the correct one when it's fixed - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -587,9 +584,9 @@ func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE // p = 0 // acc = manp(0) = 0 // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 - // manp(add_eint(mul, acc)) = 64 - // ceil(sqrt(64)) = 8 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // manp(add_eint(mul, acc)) = 9 + // ceil(sqrt(9)) = 3 + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -598,14 +595,13 @@ func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, 
%arg1: tensor<1x2x!FHE func.func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> { // p = 0 - // acc = manp(0) = 1 + // acc = manp(0) = 0 // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 - // manp(add_eint(mul, acc)) = 64 + 1 = 10 - // p = 1 + // manp(add_eint(mul, acc)) = 0 + 9 = 9 // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 - // manp(add_eint(mul, acc)) = 10 + 9 = 19 - // ceil(sqrt(129)) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}} + // manp(add_eint(mul, acc)) = 9 + 9 = 18 + // ceil(sqrt(18)) = 5 + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -619,7 +615,7 @@ func.func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!FHE.eint<2>>) -> tensor<2x // mul = manp(mul_eint_int(eint<2>, 3) = 1^2 + 3^2 = 10 // manp(add_eint(mul, acc)) = 10 // ceil(sqrt(10)) = 4 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>> return %1 : tensor<2x3x!FHE.eint<2>> } @@ -637,7 +633,7 @@ func.func @matmul_int_eint_cst_p_2_n_0(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17 // manp(add_eint(mul, acc)) = 17 + 9 = 26 // ceil(sqrt(26)) = 6 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!FHE.eint<2>>) 
-> tensor<2x3x!FHE.eint<2>> return %1 : tensor<2x3x!FHE.eint<2>> } @@ -685,7 +681,7 @@ func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> (tensor<2x!FHE.e ] > : tensor<2x3xi8> - // CHECK: MANP = 8 : ui{{[0-9]+}} + // CHECK: MANP = 7 : ui{{[0-9]+}} %4 = "FHELinalg.matmul_int_eint"(%3, %0) : (tensor<2x3xi8>, tensor<3x2x!FHE.eint<7>>) -> tensor<2x2x!FHE.eint<7>> // =============================== diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir index f3ece12254..e9e137c118 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir @@ -99,9 +99,9 @@ func.func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8 func.func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>> { - // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>> - // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 7 : ui{{[0-9]+}}} + // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 3 : ui{{[0-9]+}}} %1 = tensor.collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>> return %1 : tensor<2x8x!FHE.eint<2>> } @@ -118,9 +118,9 @@ func.func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x! 
func.func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>> { - // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 7 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>> - // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 7 : ui{{[0-9]+}}} + // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 3 : ui{{[0-9]+}}} %1 = tensor.expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>> return %1 : tensor<2x2x4x!FHE.eint<2>> }