Skip to content

Commit e95e68f

Browse files
committed
update plot
1 parent 5cb542a commit e95e68f

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ CUDA = "3"
3131
DocStringExtensions = "0.8, 0.9"
3232
FFTW = "1.4"
3333
Graphs = "1.7"
34-
LuxorGraphPlot = "0.1"
34+
LuxorGraphPlot = "0.2"
3535
Mods = "1.3"
3636
OMEinsum = "0.7"
3737
Polynomials = "2.0, 3"

src/arithematics.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ end
138138
return TruncatedPoly(a.coeffs .+ b.coeffs, a.maxorder)
139139
elseif a.maxorder > b.maxorder
140140
offset = a.maxorder - b.maxorder
141-
return TruncatedPoly((@ntuple $K i->i+offset <= $K ? a.coeffs[i] + b.coeffs[i+offset] : a.coeffs[i]), a.maxorder)
141+
return TruncatedPoly((@ntuple $K i->i+offset <= $K ? a.coeffs[convert(Int, i)] + b.coeffs[convert(Int, i+offset)] : a.coeffs[convert(Int, i)]), a.maxorder)
142142
else
143143
offset = b.maxorder - a.maxorder
144-
return TruncatedPoly((@ntuple $K i->i+offset <= $K ? b.coeffs[i] + a.coeffs[i+offset] : b.coeffs[i]), b.maxorder)
144+
return TruncatedPoly((@ntuple $K i->i+offset <= $K ? b.coeffs[convert(Int, i)] + a.coeffs[convert(Int, i+offset)] : b.coeffs[convert(Int, i)]), b.maxorder)
145145
end
146146
end
147147
end
@@ -796,7 +796,7 @@ for F in [:set_type, :sampler_type, :treeset_type]
796796
function $F(::Type{T}, n::Int, nflavor::Int) where {TV, T<:CountingTropical{TV}}
797797
CountingTropical{TV, $F(n,nflavor)}
798798
end
799-
function $F(::Type{Real}, n::Int, nflavor::Int) where {TV}
799+
function $F(::Type{Real}, n::Int, nflavor::Int)
800800
$F(n, nflavor)
801801
end
802802
end

src/interfaces.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,14 +438,18 @@ end
438438
for (PROP, ET) in [
439439
(:(PartitionFunction{T}), :(T)),
440440
(:(SizeMax{Single}), :(Tropical{T})), (:(SizeMin{Single}), :(Tropical{T})),
441-
(:(SizeMax{K}), :(ExtendedTropical{K,Tropical{T}})), (:(SizeMin{K}), :(ExtendedTropical{K,Tropical{T}})),
442441
(:(CountingAll), :T), (:(CountingMax{Single}), :(CountingTropical{T,T})), (:(CountingMin{Single}), :(CountingTropical{T,T})),
443-
(:(CountingMax{K}), :(TruncatedPoly{K,T,T})), (:(CountingMin{K}), :(TruncatedPoly{K,T,T})),
444-
(:(GraphPolynomial{:finitefield}), :(Mod{N,Int32} where N)), (:(GraphPolynomial{:fft}), :(Complex{T})),
445442
(:(GraphPolynomial{:polynomial}), :(Polynomial{T, :x})), (:(GraphPolynomial{:fitting}), :T),
446-
(:(GraphPolynomial{:laurent}), :(LaurentPolynomial{T, :x}))
443+
(:(GraphPolynomial{:laurent}), :(LaurentPolynomial{T, :x})), (:(GraphPolynomial{:fft}), :(Complex{T})),
444+
(:(GraphPolynomial{:finitefield}), :(Mod{N,Int32} where N))
445+
]
446+
@eval tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T} = $ET
447+
end
448+
for (PROP, ET) in [
449+
(:(SizeMax{K}), :(ExtendedTropical{K,Tropical{T}})), (:(SizeMin{K}), :(ExtendedTropical{K,Tropical{T}})),
450+
(:(CountingMax{K}), :(TruncatedPoly{K,T,T})), (:(CountingMin{K}), :(TruncatedPoly{K,T,T})),
447451
]
448-
@eval tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T,K} = $ET
452+
@eval tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T, K} = $ET
449453
end
450454

451455
function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::PROP) where {T, K, BOUNDED, PROP<:Union{SingleConfigMax{K,BOUNDED},SingleConfigMin{K,BOUNDED}}}
@@ -461,10 +465,17 @@ end
461465

462466
for (PROP, ET) in [
463467
(:(ConfigsMax{Single}), :(CountingTropical{T,T})), (:(ConfigsMin{Single}), :(CountingTropical{T,T})),
464-
(:(ConfigsMax{K}), :(TruncatedPoly{K,T,T})), (:(ConfigsMin{K}), :(TruncatedPoly{K,T,T})),
465468
(:(ConfigsAll), :(Real))
466469
]
467-
@eval function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T,K}
470+
@eval function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T}
471+
set_type($ET, n, nflavor)
472+
end
473+
end
474+
475+
for (PROP, ET) in [
476+
(:(ConfigsMax{K}), :(TruncatedPoly{K,T,T})), (:(ConfigsMin{K}), :(TruncatedPoly{K,T,T})),
477+
]
478+
@eval function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T, K}
468479
set_type($ET, n, nflavor)
469480
end
470481
end

src/visualize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function show_einsum(ein::AbstractEinsum;
5454
annotate_tensors=false,
5555
tensor_locs=nothing,
5656
label_locs=nothing, # dict
57-
spring::Bool=true,
57+
layout::Symbol=:auto,
5858
optimal_distance=1.0,
5959
kwargs...
6060
)
@@ -75,7 +75,7 @@ function show_einsum(ein::AbstractEinsum;
7575
end
7676
end
7777
if label_locs === nothing && tensor_locs === nothing
78-
locs = LuxorGraphPlot.autolocs(graph, nothing, spring, optimal_distance, trues(nv(graph)))
78+
locs = LuxorGraphPlot.autolocs(graph, nothing, layout, optimal_distance, trues(nv(graph)))
7979
elseif label_locs === nothing
8080
# infer label locs from tensor locs
8181
label_locs = [(lst = [iloc for (iloc,ix) in zip(tensor_locs, ixs) if l ix]; reduce((x,y)->x .+ y, lst) ./ length(lst)) for l in labels]

0 commit comments

Comments
 (0)