Skip to content

Commit

Permalink
add low-level apply method
Browse files Browse the repository at this point in the history
  • Loading branch information
sumiya11 committed Jan 29, 2024
1 parent f1180b8 commit 9c0e087
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 8 deletions.
22 changes: 15 additions & 7 deletions src/arithmetic/Zp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ less_than_half(p, ::Type{T}) where {T} = p < (typemax(T) >> ((8 >> 1) * sizeof(T

# Modular arithmetic based on builtin classes in `Base.MultiplicativeInverses`.
struct ArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, CoeffType}
# magic contains the precomputed multiplicative inverse of the divisor
magic::UnsignedMultiplicativeInverse{AccumType}
multiplier::AccumType
shift::UInt8
divisor::AccumType
add::Bool

ArithmeticZp(::Type{A}, ::Type{C}, p) where {A, C} = ArithmeticZp(A, C, C(p))

Expand All @@ -40,15 +42,21 @@ struct ArithmeticZp{AccumType, CoeffType} <: AbstractArithmeticZp{AccumType, Coe
) where {AccumType <: CoeffZp, CoeffType <: CoeffZp}
@invariant less_than_half(p, AccumType)
@invariant Primes.isprime(p)
new{AccumType, CoeffType}(
UnsignedMultiplicativeInverse{AccumType}(convert(AccumType, p))
)
uinv = UnsignedMultiplicativeInverse{AccumType}(convert(AccumType, p))
# Further in the code we need the guarantee that the shift is < 64
@invariant uinv.shift < 8 * sizeof(AccumType)
new{AccumType, CoeffType}(uinv.multiplier, uinv.shift, uinv.divisor, uinv.add)
end
end

divisor(arithm::ArithmeticZp) = arithm.magic.divisor
divisor(arithm::ArithmeticZp) = arithm.divisor

@inline mod_p(a::T, arithm::ArithmeticZp{T}) where {T} = a % arithm.magic
@inline function mod_p(a::T, mod::ArithmeticZp{T}) where {T}
x = _mul_high(a, mod.multiplier)
x = ifelse(mod.add, convert(T, convert(T, (convert(T, a - x) >>> UInt8(1))) + x), x)
unsafe_assume(mod.shift < 8 * sizeof(T))
a - (x >>> mod.shift) * mod.divisor
end

inv_mod_p(a::T, arithm::ArithmeticZp{T}) where {T} = invmod(a, divisor(arithm))

Expand Down
41 changes: 41 additions & 0 deletions src/groebner/learn-apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,44 @@ function _groebner_apply1!(ring, trace, params)

flag, gb_monoms, gb_coeffs
end

#=
Several assumptions are in place:
- input contains no zero polynomials,
- input coefficients are non-negative,
- input coefficients are smaller than modulo
=#
function groebner_applyX!(
wrapped_trace::WrappedTraceF4,
coeffs_zp::Vector{Vector{UInt32}},
modulo::UInt32;
options...
)
kws = KeywordsHandler(:groebner_apply!, options)

logging_setup(kws)
statistics_setup(kws)

trace = get_default_trace(wrapped_trace)
@log level = -5 "Selected trace" trace.representation.coefftype

ring = extract_coeffs_raw_X!(trace, trace.representation, coeffs_zp, modulo, kws)

# TODO: this is a bit hacky
params = AlgorithmParameters(
ring,
trace.representation,
kws,
orderings=(trace.params.original_ord, trace.params.target_ord)
)
ring = PolyRing(trace.ring.nvars, trace.ring.ord, ring.ch)

flag, gb_monoms, gb_coeffs = _groebner_apply1!(ring, trace, params)

if trace.params.homogenize
ring, gb_monoms, gb_coeffs =
dehomogenize_generators!(ring, gb_monoms, gb_coeffs, params)
end

flag, gb_coeffs::Vector{Vector{UInt32}}
end
87 changes: 87 additions & 0 deletions src/input-output/AbstractAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,93 @@ function extract_coeffs_raw!(
ring
end

function extract_coeffs_raw_X!(
trace,
representation::PolynomialRepresentation,
coeffs_zp,
modulo,
kws::KeywordsHandler
)
ring = PolyRing(trace.ring.nvars, trace.ring.ord, UInt64(modulo))

basis = trace.buf_basis
input_polys_perm = trace.input_permutation
term_perms = trace.term_sorting_permutations
homog_term_perm = trace.term_homogenizing_permutations
CoeffType = representation.coefftype

_extract_coeffs_raw_X!(
basis,
input_polys_perm,
term_perms,
homog_term_perm,
coeffs_zp,
CoeffType
)

# a hack for homogenized inputs
if trace.homogenize
@assert length(basis.monoms[length(polys) + 1]) ==
length(basis.coeffs[length(polys) + 1]) ==
2
# TODO: !! incorrect if there are zeros in the input
@invariant !iszero(ring.ch)
C = eltype(basis.coeffs[length(polys) + 1][1])
basis.coeffs[length(polys) + 1][1] = one(C)
basis.coeffs[length(polys) + 1][2] =
iszero(ring.ch) ? -one(C) : (ring.ch - one(ring.ch))
end

@log level = -6 "Extracted coefficients from $(length(polys)) polynomials." basis
@log level = -8 "Extracted coefficients" basis.coeffs
ring
end

function _extract_coeffs_raw_X!(
basis,
input_polys_perm::Vector{Int},
term_perms::Vector{Vector{Int}},
homog_term_perms::Vector{Vector{Int}},
coeffs_zp,
::Type{CoeffsType}
) where {CoeffsType}
# write new coefficients directly to trace.buf_basis
permute_input_terms = !isempty(term_perms)
permute_homogenizing_terms = !isempty(homog_term_perms)

@log level = -2 """
Permuting input terms: $permute_input_terms
Permuting for homogenization: $permute_homogenizing_terms"""
@log level = -7 """Permutations:
Of polynomials: $input_polys_perm
Of terms (change of ordering): $term_perms
Of terms (homogenization): $homog_term_perms"""
@inbounds for i in 1:length(coeffs_zp)
basis_cfs = basis.coeffs[i]
poly_index = input_polys_perm[i]
poly = coeffs_zp[poly_index]
if !(length(poly) == length(basis_cfs))
__throw_input_not_supported(
"Potential coefficient cancellation in input polynomial at index $i on apply stage.",
poly
)
end
for j in 1:length(poly)
coeff_index = j
if permute_input_terms
coeff_index = term_perms[poly_index][coeff_index]
end
if permute_homogenizing_terms
coeff_index = homog_term_perms[poly_index][coeff_index]
end
coeff = poly[coeff_index]
basis_cfs[j] = convert(CoeffsType, coeff)
end
end

nothing
end

function io_extract_coeffs_raw_batched!(
trace,
representation::PolynomialRepresentation,
Expand Down
5 changes: 4 additions & 1 deletion src/utils/simd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ typemax_saturated(_::Type{T}, N) where {T <: BitInteger} = typemax(T) ⊻ (N - 1
# If this is not possible, returns N = 1.
function cutoff8_pick_vector_width(::Type{T}) where {T}
N = pick_vector_width(T)
if N in (8, 16, 32, 64)
if N in (8, 16, 32)
return Int(N)
end
if N == 64
return 32
end
1
end

Expand Down

0 comments on commit 9c0e087

Please sign in to comment.