diff --git a/src/arithmetic/Zp.jl b/src/arithmetic/Zp.jl
index cb3f90c6..68aa50ef 100644
--- a/src/arithmetic/Zp.jl
+++ b/src/arithmetic/Zp.jl
@@ -28,8 +28,11 @@ less_than_half(p, ::Type{T}) where {T} = p < (typemax(T) >> ((8 >> 1) * sizeof(T
 
 # Modular arithmetic based on builtin classes in `Base.MultiplicativeInverses`.
 struct ArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, CoeffType}
-    # magic contains the precomputed multiplicative inverse of the divisor
-    magic::UnsignedMultiplicativeInverse{AccumType}
+    # Fields copied from the precomputed multiplicative inverse of the divisor
+    multiplier::AccumType
+    shift::UInt8
+    divisor::AccumType
+    add::Bool
 
     ArithmeticZp(::Type{A}, ::Type{C}, p) where {A, C} = ArithmeticZp(A, C, C(p))
 
@@ -40,15 +43,24 @@ struct ArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, Coe
     ) where {AccumType <: CoeffZp, CoeffType <: CoeffZp}
         @invariant less_than_half(p, AccumType)
         @invariant Primes.isprime(p)
-        new{AccumType, CoeffType}(
-            UnsignedMultiplicativeInverse{AccumType}(convert(AccumType, p))
-        )
+        uinv = UnsignedMultiplicativeInverse{AccumType}(convert(AccumType, p))
+        # `mod_p` relies on the shift being smaller than the bit width of AccumType
+        @invariant uinv.shift < 8 * sizeof(AccumType)
+        new{AccumType, CoeffType}(uinv.multiplier, uinv.shift, uinv.divisor, uinv.add)
     end
 end
 
-divisor(arithm::ArithmeticZp) = arithm.magic.divisor
+divisor(arithm::ArithmeticZp) = arithm.divisor
 
-@inline mod_p(a::T, arithm::ArithmeticZp{T}) where {T} = a % arithm.magic
+# Remainder of `a` modulo `mod.divisor`, computed from the precomputed fields of `mod`
+# in place of `a % UnsignedMultiplicativeInverse`. The `unsafe_assume` restates the
+# bound on the shift that the `ArithmeticZp` constructor checks.
+@inline function mod_p(a::T, mod::ArithmeticZp{T}) where {T}
+    x = _mul_high(a, mod.multiplier)
+    x = ifelse(mod.add, convert(T, convert(T, (convert(T, a - x) >>> UInt8(1))) + x), x)
+    unsafe_assume(mod.shift < 8 * sizeof(T))
+    a - (x >>> mod.shift) * mod.divisor
+end
 
 inv_mod_p(a::T, arithm::ArithmeticZp{T}) where {T} = invmod(a, divisor(arithm))
 
diff --git a/src/groebner/learn-apply.jl b/src/groebner/learn-apply.jl
index 3974093d..9ebf38fa 100644
--- a/src/groebner/learn-apply.jl
+++ b/src/groebner/learn-apply.jl
@@ -167,3 +167,48 @@ function _groebner_apply1!(ring, trace, params)
 
     flag, gb_monoms, gb_coeffs
 end
+
+#=
+Runs the apply stage on raw coefficient vectors `coeffs_zp` given modulo `modulo`.
+Keyword options are handled as those of `groebner_apply!`.
+Returns `flag` from `_groebner_apply1!` together with the basis coefficients.
+
+Several assumptions are in place:
+- input contains no zero polynomials,
+- input coefficients are non-negative,
+- input coefficients are smaller than modulo
+=#
+function groebner_applyX!(
+    wrapped_trace::WrappedTraceF4,
+    coeffs_zp::Vector{Vector{UInt32}},
+    modulo::UInt32;
+    options...
+)
+    kws = KeywordsHandler(:groebner_apply!, options)
+
+    logging_setup(kws)
+    statistics_setup(kws)
+
+    trace = get_default_trace(wrapped_trace)
+    @log level = -5 "Selected trace" trace.representation.coefftype
+
+    ring = extract_coeffs_raw_X!(trace, trace.representation, coeffs_zp, modulo, kws)
+
+    # TODO: this is a bit hacky
+    params = AlgorithmParameters(
+        ring,
+        trace.representation,
+        kws,
+        orderings=(trace.params.original_ord, trace.params.target_ord)
+    )
+    ring = PolyRing(trace.ring.nvars, trace.ring.ord, ring.ch)
+
+    flag, gb_monoms, gb_coeffs = _groebner_apply1!(ring, trace, params)
+
+    if trace.params.homogenize
+        ring, gb_monoms, gb_coeffs =
+            dehomogenize_generators!(ring, gb_monoms, gb_coeffs, params)
+    end
+
+    flag, gb_coeffs::Vector{Vector{UInt32}}
+end
diff --git a/src/input-output/AbstractAlgebra.jl b/src/input-output/AbstractAlgebra.jl
index bcca50d0..6c7ced1b 100644
--- a/src/input-output/AbstractAlgebra.jl
+++ b/src/input-output/AbstractAlgebra.jl
@@ -282,6 +282,99 @@ function extract_coeffs_raw!(
     ring
 end
 
+# Writes the raw coefficients `coeffs_zp` into `trace.buf_basis` and returns a
+# `PolyRing` built from `trace.ring.nvars`, `trace.ring.ord` and `UInt64(modulo)`.
+# `coeffs_zp` holds one coefficient vector per input polynomial.
+function extract_coeffs_raw_X!(
+    trace,
+    representation::PolynomialRepresentation,
+    coeffs_zp,
+    modulo,
+    kws::KeywordsHandler
+)
+    ring = PolyRing(trace.ring.nvars, trace.ring.ord, UInt64(modulo))
+
+    basis = trace.buf_basis
+    input_polys_perm = trace.input_permutation
+    term_perms = trace.term_sorting_permutations
+    homog_term_perm = trace.term_homogenizing_permutations
+    CoeffType = representation.coefftype
+
+    _extract_coeffs_raw_X!(
+        basis,
+        input_polys_perm,
+        term_perms,
+        homog_term_perm,
+        coeffs_zp,
+        CoeffType
+    )
+
+    # a hack for homogenized inputs
+    if trace.homogenize
+        @assert length(basis.monoms[length(coeffs_zp) + 1]) ==
+                length(basis.coeffs[length(coeffs_zp) + 1]) ==
+                2
+        # TODO: !! incorrect if there are zeros in the input
+        @invariant !iszero(ring.ch)
+        C = eltype(basis.coeffs[length(coeffs_zp) + 1][1])
+        basis.coeffs[length(coeffs_zp) + 1][1] = one(C)
+        basis.coeffs[length(coeffs_zp) + 1][2] =
+            iszero(ring.ch) ? -one(C) : (ring.ch - one(ring.ch))
+    end
+
+    @log level = -6 "Extracted coefficients from $(length(coeffs_zp)) polynomials." basis
+    @log level = -8 "Extracted coefficients" basis.coeffs
+    ring
+end
+
+# Copies the coefficients of each input polynomial into `basis.coeffs`, applying the
+# polynomial permutation and, when non-empty, the term permutations. Calls
+# `__throw_input_not_supported` if a polynomial's length differs from the stored one.
+function _extract_coeffs_raw_X!(
+    basis,
+    input_polys_perm::Vector{Int},
+    term_perms::Vector{Vector{Int}},
+    homog_term_perms::Vector{Vector{Int}},
+    coeffs_zp,
+    ::Type{CoeffsType}
+) where {CoeffsType}
+    # write new coefficients directly to trace.buf_basis
+    permute_input_terms = !isempty(term_perms)
+    permute_homogenizing_terms = !isempty(homog_term_perms)
+
+    @log level = -2 """
+    Permuting input terms: $permute_input_terms
+    Permuting for homogenization: $permute_homogenizing_terms"""
+    @log level = -7 """Permutations:
+      Of polynomials: $input_polys_perm
+      Of terms (change of ordering): $term_perms
+      Of terms (homogenization): $homog_term_perms"""
+    @inbounds for i in 1:length(coeffs_zp)
+        basis_cfs = basis.coeffs[i]
+        poly_index = input_polys_perm[i]
+        poly = coeffs_zp[poly_index]
+        if !(length(poly) == length(basis_cfs))
+            __throw_input_not_supported(
+                "Potential coefficient cancellation in input polynomial at index $i on apply stage.",
+                poly
+            )
+        end
+        for j in 1:length(poly)
+            coeff_index = j
+            if permute_input_terms
+                coeff_index = term_perms[poly_index][coeff_index]
+            end
+            if permute_homogenizing_terms
+                coeff_index = homog_term_perms[poly_index][coeff_index]
+            end
+            coeff = poly[coeff_index]
+            basis_cfs[j] = convert(CoeffsType, coeff)
+        end
+    end
+
+    nothing
+end
+
 function io_extract_coeffs_raw_batched!(
     trace,
     representation::PolynomialRepresentation,
diff --git a/src/utils/simd.jl b/src/utils/simd.jl
index 42cf5b92..dbb0d57b 100644
--- a/src/utils/simd.jl
+++ b/src/utils/simd.jl
@@ -35,9 +35,13 @@ typemax_saturated(_::Type{T}, N) where {T <: BitInteger} = typemax(T) ⊻ (N - 1
 # If this is not possible, returns N = 1.
 function cutoff8_pick_vector_width(::Type{T}) where {T}
     N = pick_vector_width(T)
-    if N in (8, 16, 32, 64)
+    if N in (8, 16, 32)
         return Int(N)
     end
+    # A picked width of 64 is reduced to 32.
+    if N == 64
+        return 32
+    end
     1
 end
 