diff --git a/src/MPoly.jl b/src/MPoly.jl index 44bbf9f8a3..6bd2e47b66 100644 --- a/src/MPoly.jl +++ b/src/MPoly.jl @@ -500,8 +500,9 @@ end Return an iterator for the coefficients of the given polynomial. To retrieve an array of the coefficients, use `collect(coefficients(a))`. """ -function coefficients(a::MPolyRingElem{T}) where T <: RingElement - return Generic.MPolyCoeffs(a) +function coefficients(a::MPolyRingElem{T}; inplace::Bool = false) where T <: RingElement + t = zero(coefficient_ring(parent(a))) + return Generic.MPolyCoeffs(a, inplace, t) end @doc raw""" @@ -511,8 +512,14 @@ Return an iterator for the exponent vectors of the given polynomial. To retrieve an array of the exponent vectors, use `collect(exponent_vectors(a))`. """ -function exponent_vectors(a::MPolyRingElem{T}) where T <: RingElement - return Generic.MPolyExponentVectors(a) +function exponent_vectors(a::MPolyRingElem{T}; inplace::Bool = false) where T <: RingElement + return exponent_vectors(Vector{Int}, a, inplace=inplace) +end + +function exponent_vectors(::Type{Vector{S}}, a::MPolyRingElem{T}; inplace::Bool = false) where {T <: RingElement, S} + # Don't use `zeros`: If S === ZZRingElem, then all the entries would be identical + t = [zero(S) for _ in 1:nvars(parent(a))] + return Generic.MPolyExponentVectors(a, inplace, t) end @doc raw""" @@ -521,8 +528,9 @@ end Return an iterator for the monomials of the given polynomial. To retrieve an array of the monomials, use `collect(monomials(a))`. """ -function monomials(a::MPolyRingElem{T}) where T <: RingElement - return Generic.MPolyMonomials(a) +function monomials(a::MPolyRingElem{T}; inplace::Bool = false) where T <: RingElement + t = zero(parent(a)) + return Generic.MPolyMonomials(a, inplace, t) end @doc raw""" @@ -531,8 +539,9 @@ end Return an iterator for the terms of the given polynomial. To retrieve an array of the terms, use `collect(terms(a))`. """ -function terms(a::MPolyRingElem{T}) where T <: RingElement - return Generic.MPolyTerms(a) +function terms(a::MPolyRingElem{T}; inplace::Bool = false) where T <: RingElement + t = zero(parent(a)) + return Generic.MPolyTerms(a, inplace, t) end ############################################################################### diff --git a/src/exports.jl b/src/exports.jl index d019f0db32..9b03d00b59 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -154,6 +154,7 @@ export check_composable export check_parent export codomain export coeff +export coeff! export coefficient_ring export coefficient_ring_type export coefficients @@ -213,6 +214,7 @@ export evaluate export exp_gcd export exponent export exponent_vector +export exponent_vector! export exponent_vectors export exponent_word export exponent_words @@ -579,6 +581,7 @@ export symbols export tail export term export terms +export term! export to_univariate export total_degree export total_ring_of_fractions diff --git a/src/generic/GenericTypes.jl b/src/generic/GenericTypes.jl index bdb61fde40..be65203391 100644 --- a/src/generic/GenericTypes.jl +++ b/src/generic/GenericTypes.jl @@ -385,20 +385,46 @@ end # Iterators -struct MPolyCoeffs{T <: AbstractAlgebra.NCRingElem} +struct MPolyCoeffs{T <: AbstractAlgebra.NCRingElem, S <: AbstractAlgebra.RingElement} poly::T + inplace::Bool + temp::S # only used if inplace == true end -struct MPolyExponentVectors{T <: AbstractAlgebra.RingElem} +function MPolyCoeffs(f::AbstractAlgebra.NCRingElem) + return MPolyCoeffs(f, false, zero(coefficient_ring(parent(f)))) +end + +# S may be the type of anything that can store an exponent vector, for example +# Vector{Int}, ZZMatrix, ... +struct MPolyExponentVectors{T <: AbstractAlgebra.RingElem, S} poly::T + inplace::Bool + temp::S # only used if inplace == true +end + +function MPolyExponentVectors(f::AbstractAlgebra.RingElem) + return MPolyExponentVectors(f, false, Vector{Int}()) end struct MPolyTerms{T <: AbstractAlgebra.NCRingElem} poly::T + inplace::Bool + temp::T # only used if inplace == true +end + +function MPolyTerms(f::AbstractAlgebra.NCRingElem) + return MPolyTerms(f, false, zero(parent(f))) end struct MPolyMonomials{T <: AbstractAlgebra.NCRingElem} poly::T + inplace::Bool + temp::T # only used if inplace == true +end + +function MPolyMonomials(f::NCRingElem) + return MPolyMonomials(f, false, zero(parent(f))) end mutable struct MPolyBuildCtx{T, S} diff --git a/src/generic/MPoly.jl b/src/generic/MPoly.jl index 29a5b337da..f8e37b72ae 100644 --- a/src/generic/MPoly.jl +++ b/src/generic/MPoly.jl @@ -125,19 +125,31 @@ are given in the order of the variables for the ring, as supplied when the ring was created. """ function exponent_vector(a::MPoly{T}, i::Int) where T <: RingElement + e = Vector{Int}(undef, nvars(parent(a))) + return exponent_vector!(e, a, i) +end + +function exponent_vector!(e::Vector{S}, a::MPoly{T}, i::Int) where {T <: RingElement, S} + @assert length(e) == nvars(parent(a)) A = a.exps N = size(A, 1) ord = internal_ordering(parent(a)) if ord == :lex - return [Int(A[j, i]) for j in N:-1:1] + range = N:-1:1 elseif ord == :deglex - return [Int(A[j, i]) for j in N - 1:-1:1] + range = N - 1:-1:1 elseif ord == :degrevlex - return [Int(A[j, i]) for j in 1:N - 1] + range = 1:N - 1 else error("invalid ordering") end + k = 1 + for j in range + e[k] = S(A[j, i]) + k += 1 + end + return e end @doc raw""" @@ -635,6 +647,11 @@ function coeff(x::MPoly, i::Int) return x.coeffs[i] end +# Only for compatibility, we can't do anything in place here +function coeff!(c::T, x::MPoly{T}, i::Int) where T <: RingElement + return x.coeffs[i] +end + function trailing_coefficient(p::MPoly{T}) where T <: RingElement @req !iszero(p) "Zero polynomial does not have a leading monomial" return coeff(p, length(p)) @@ -664,7 +681,9 @@ function monomial!(m::MPoly{T}, x::MPoly{T}, i::Int) where T <: RingElement N = size(x.exps, 1) fit!(m, 1) monomial_set!(m.exps, 1, x.exps, i, N) - m.coeffs[1] = one(base_ring(x)) + if !isassigned(m.coeffs, 1) || !is_one(m.coeffs[1]) + m.coeffs[1] = one(base_ring(x)) + end m.length = 1 return m end @@ -675,11 +694,17 @@ end Return the $i$-th nonzero term of the polynomial $x$ (as a polynomial). """ function term(x::MPoly, i::Int) - R = base_ring(x) + y = zero(parent(x)) + return term!(y, x, i) +end + +function term!(y::T, x::T, i::Int) where T <: MPoly N = size(x.exps, 1) - exps = Matrix{UInt}(undef, N, 1) - monomial_set!(exps, 1, x.exps, i, N) - return parent(x)([deepcopy(x.coeffs[i])], exps) + fit!(y, 1) + monomial_set!(y.exps, 1, x.exps, i, N) + y.coeffs[1] = deepcopy(x.coeffs[i]) + y.length = 1 + return y end @doc raw""" @@ -804,69 +829,41 @@ Base.copy(f::Generic.MPoly) = deepcopy(f) # ############################################################################### -function Base.iterate(x::MPolyCoeffs) - if length(x.poly) >= 1 - return coeff(x.poly, 1), 1 - else - return nothing - end -end - -function Base.iterate(x::MPolyCoeffs, state) - state += 1 - if length(x.poly) >= state - return coeff(x.poly, state), state - else - return nothing - end -end - -function Base.iterate(x::MPolyExponentVectors) - if length(x.poly) >= 1 - return exponent_vector(x.poly, 1), 1 - else - return nothing - end -end - -function Base.iterate(x::MPolyExponentVectors, state) - state += 1 - if length(x.poly) >= state - return exponent_vector(x.poly, state), state - else - return nothing - end -end - -function Base.iterate(x::MPolyTerms) - if length(x.poly) >= 1 - return term(x.poly, 1), 1 +function Base.iterate(x::MPolyCoeffs, state::Union{Nothing, Int} = nothing) + s = isnothing(state) ? 1 : state + 1 + if length(x.poly) >= s + c = x.inplace ? coeff!(x.temp, x.poly, s) : coeff(x.poly, s) + return c, s else return nothing end end -function Base.iterate(x::MPolyTerms, state) - state += 1 - if length(x.poly) >= state - return term(x.poly, state), state +function Base.iterate(x::MPolyExponentVectors, state::Union{Nothing, Int} = nothing) + s = isnothing(state) ? 1 : state + 1 + if length(x.poly) >= s + v = x.inplace ? exponent_vector!(x.temp, x.poly, s) : exponent_vector(x.poly, s) + return v, s else return nothing end end -function Base.iterate(x::MPolyMonomials) - if length(x.poly) >= 1 - return monomial(x.poly, 1), 1 +function Base.iterate(x::MPolyTerms, state::Union{Nothing, Int} = nothing) + s = isnothing(state) ? 1 : state + 1 + if length(x.poly) >= s + t = x.inplace ? term!(x.temp, x.poly, s) : term(x.poly, s) + return t, s else return nothing end end -function Base.iterate(x::MPolyMonomials, state) - state += 1 - if length(x.poly) >= state - return monomial(x.poly, state), state +function Base.iterate(x::MPolyMonomials, state::Union{Nothing, Int} = nothing) + s = isnothing(state) ? 1 : state + 1 + if length(x.poly) >= s + m = x.inplace ? monomial!(x.temp, x.poly, s) : monomial(x.poly, s) + return m, s else return nothing end @@ -876,12 +873,12 @@ function Base.length(x::Union{MPolyCoeffs, MPolyExponentVectors, MPolyTerms, MPo return length(x.poly) end -function Base.eltype(::Type{MPolyCoeffs{T}}) where T <: AbstractAlgebra.MPolyRingElem{S} where S <: RingElement +function Base.eltype(::Type{MPolyCoeffs{T, S}}) where {T <: AbstractAlgebra.MPolyRingElem, S <: RingElement} return S end -function Base.eltype(::Type{MPolyExponentVectors{T}}) where T <: AbstractAlgebra.MPolyRingElem{S} where S <: RingElement - return Vector{Int} +function Base.eltype(::Type{MPolyExponentVectors{T, V}}) where {V, T <: AbstractAlgebra.MPolyRingElem{S} where S <: RingElement} + return V end function Base.eltype(::Type{MPolyMonomials{T}}) where T <: AbstractAlgebra.MPolyRingElem{S} where S <: RingElement diff --git a/src/generic/exports.jl b/src/generic/exports.jl index 62b978d326..938c95db7a 100644 --- a/src/generic/exports.jl +++ b/src/generic/exports.jl @@ -12,6 +12,7 @@ export abs_series_type export base_field export basis export character +export coeff! export collength export combine_like_terms! export cycles @@ -25,6 +26,7 @@ export enable_cache! export exp_gcd export exponent export exponent_vector +export exponent_vector! export exponent_word export falling_factorial export finish @@ -122,6 +124,7 @@ export summands export supermodule export term export terms +export term! export to_univariate export total_degree export trailing_coefficient diff --git a/test/generic/MPoly-test.jl b/test/generic/MPoly-test.jl index 8e7d36cd91..4a3a10bd0e 100644 --- a/test/generic/MPoly-test.jl +++ b/test/generic/MPoly-test.jl @@ -1877,3 +1877,20 @@ end R2, (x2, y2) = polynomial_ring(QQ, [:x, :y]) @test_throws ErrorException z1 + y2 end + +@testset "Generic.MPoly.Iterators" begin + R, (x, y, z) = polynomial_ring(QQ, [:x, :y, :z]) + f = x * y + 2 * x - 3 * z + + @test @inferred collect(exponent_vectors(f)) == [[1, 1, 0], [1, 0, 0], [0, 0, 1]] + @test @inferred collect(exponent_vectors(Vector{UInt}, f)) == [UInt[1, 1, 0], UInt[1, 0, 0], UInt[0, 0, 1]] + @test @inferred collect(coefficients(f)) == [QQ(1), QQ(2), QQ(-3)] + @test @inferred collect(terms(f)) == [x * y, 2 * x, -3 * z] + @test @inferred collect(monomials(f)) == [x * y, x, z] + + @test @inferred first(exponent_vectors(f, inplace = true)) == [1, 1, 0] + @test @inferred first(exponent_vectors(Vector{UInt}, f, inplace = true)) == UInt[1, 1, 0] + @test @inferred first(coefficients(f, inplace = true)) == QQ(1) + @test @inferred first(monomials(f, inplace = true)) == x * y + @test @inferred first(terms(f, inplace = true)) == x * y +end