Skip to content

Commit 9dd1255

Browse files
authored
upgrade omeinsum (#51)
* upgrade omeinsum
1 parent b9027a7 commit 9dd1255

File tree

9 files changed

+22
-25
lines changed

9 files changed

+22
-25
lines changed

Project.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GenericTensorNetworks"
22
uuid = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
33
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
4-
version = "1.2.0"
4+
version = "1.2.1"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -14,7 +14,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"
1515
Mods = "7475f97c-0381-53b1-977b-4c60186c8d62"
1616
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
17-
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
1817
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1918
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
2019
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -26,14 +25,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2625
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
2726

2827
[compat]
29-
AbstractTrees = "0.3, 0.4"
28+
AbstractTrees = "0.4"
3029
CUDA = "3"
3130
FFTW = "1.4"
3231
Graphs = "1.7"
3332
LuxorGraphPlot = "0.1"
3433
Mods = "1.3"
35-
OMEinsum = "0.6.1"
36-
OMEinsumContractionOrders = "0.6"
34+
OMEinsum = "0.7"
3735
Polynomials = "2.0, 3"
3836
Primes = "0.5"
3937
Requires = "1"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Pkg
22
using GenericTensorNetworks
3-
using GenericTensorNetworks: TropicalNumbers, Polynomials, Mods, OMEinsum, OMEinsumContractionOrders, LuxorGraphPlot
3+
using GenericTensorNetworks: TropicalNumbers, Polynomials, Mods, OMEinsum, OMEinsum.OMEinsumContractionOrders, LuxorGraphPlot
44
using Documenter
55
using DocThemeIndigo
66
using PlutoStaticHTML

docs/src/gist.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ They can be installed in a similar way to `GenericTensorNetworks`.
1313
After installing the required packages, one can open a Julia REPL, and copy-paste the following code snippet into it.
1414

1515
```julia
16-
using OMEinsum, OMEinsumContractionOrders
16+
using OMEinsum
1717
using Graphs
1818
using Random
1919

src/GenericTensorNetworks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module GenericTensorNetworks
22

3-
using OMEinsumContractionOrders
43
using Core: Argument
54
using TropicalNumbers
65
using OMEinsum

src/bounding.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
6464
end
6565
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6666
if OMEinsum.isleaf(code)
67-
y = xs[code.tensorindex]
67+
y = xs[OMEinsum.tensorindex(code)]
6868
return CacheTree(y, CacheTree{eltype(y)}[])
6969
else
70-
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
71-
y = einsum(code.eins, ntuple(i->caches[i].content, length(caches)), size_dict)
70+
caches = [cached_einsum(arg, xs, size_dict) for arg in OMEinsum.siblings(code)]
71+
y = einsum(OMEinsum.rootcode(code), ntuple(i->caches[i].content, length(caches)), size_dict)
7272
return CacheTree(y, caches)
7373
end
7474
end
@@ -84,8 +84,9 @@ function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
8484
if OMEinsum.isleaf(code)
8585
return CacheTree(mask, CacheTree{Bool}[])
8686
else
87-
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
88-
return CacheTree(mask, generate_masktree.(Ref(mode), code.args, cache.siblings, submasks, Ref(size_dict)))
87+
eins = OMEinsum.rootcode(code)
88+
submasks = backward_tropical(mode, getixs(eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(eins), cache.content, mask, size_dict)
89+
return CacheTree(mask, generate_masktree.(Ref(mode), OMEinsum.siblings(code), cache.siblings, submasks, Ref(size_dict)))
8990
end
9091
end
9192

@@ -98,12 +99,12 @@ function masked_einsum(se::SlicedEinsum, @nospecialize(xs), masks, size_dict)
9899
end
99100
function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
100101
if OMEinsum.isleaf(code)
101-
y = copy(xs[code.tensorindex])
102+
y = copy(xs[OMEinsum.tensorindex(code)])
102103
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y)))
103104
return y
104105
else
105-
xs = [masked_einsum(arg, xs, mask, size_dict) for (arg, mask) in zip(code.args, masks.siblings)]
106-
y = einsum(code.eins, (xs...,), size_dict)
106+
xs = [masked_einsum(arg, xs, mask, size_dict) for (arg, mask) in zip(OMEinsum.siblings(code), masks.siblings)]
107+
y = einsum(OMEinsum.rootcode(code), (xs...,), size_dict)
107108
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y)))
108109
return y
109110
end
@@ -121,10 +122,10 @@ Contraction method with bounding.
121122
"""
122123
function bounding_contract(mode::AllConfigs, code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
123124
LT = OMEinsum.labeltype(code)
124-
bounding_contract(mode, NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
125+
bounding_contract(mode, DynamicNestedEinsum(DynamicNestedEinsum{LT}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
125126
end
126127
function bounding_contract(mode::AllConfigs, code::Union{NestedEinsum,SlicedEinsum}, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
127-
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
128+
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code),Int}() : copy(size_info)
128129
OMEinsum.get_size_dict!(code, xsa, size_dict)
129130
# compute intermediate tensors
130131
@debug "caching einsum..."
@@ -139,11 +140,11 @@ end
139140
# get the optimal solution with automatic differentiation.
140141
function solution_ad(code::EinCode, @nospecialize(xsa), ymask; size_info=nothing)
141142
LT = OMEinsum.labeltype(code)
142-
solution_ad(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask; size_info=size_info)
143+
solution_ad(DynamicNestedEinsum(DynamicNestedEinsum{LT}.(1:length(xsa)), code), xsa, ymask; size_info=size_info)
143144
end
144145

145146
function solution_ad(code::Union{NestedEinsum,SlicedEinsum}, @nospecialize(xsa), ymask; size_info=nothing)
146-
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
147+
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code),Int}() : copy(size_info)
147148
OMEinsum.get_size_dict!(code, xsa, size_dict)
148149
# compute intermediate tensors
149150
@debug "caching einsum..."
@@ -165,7 +166,7 @@ function read_config!(code::SlicedEinsum, mt, out)
165166
end
166167

167168
function read_config!(code::NestedEinsum, mt, out)
168-
for (arg, ix, sibling) in zip(code.args, getixs(code.eins), mt.siblings)
169+
for (arg, ix, sibling) in zip(OMEinsum.siblings(code), getixs(OMEinsum.rootcode(code)), mt.siblings)
169170
if OMEinsum.isleaf(arg)
170171
mask = convert(Array, sibling.content) # note: the content can be CuArray
171172
for ci in CartesianIndices(mask)

test/arithematics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GenericTensorNetworks, Test, OMEinsum, OMEinsumContractionOrders
1+
using GenericTensorNetworks, Test, OMEinsum
22
using Mods, Polynomials, TropicalNumbers
33
using Graphs, Random
44
using GenericTensorNetworks: StaticBitVector

test/configurations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using GenericTensorNetworks, Test, Graphs
22
using OMEinsum
33
using TropicalNumbers: CountingTropicalF64
4-
using OMEinsumContractionOrders: uniformsize
54
using GenericTensorNetworks: _onehotv, _x, sampler_type, set_type, best_solutions, best2_solutions, solutions, all_solutions, bestk_solutions, AllConfigs, SingleConfig, max_size, max_size_count
65

76
@testset "Config types" begin

test/graph_polynomials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GenericTensorNetworks, Test, OMEinsum, OMEinsumContractionOrders
1+
using GenericTensorNetworks, Test, OMEinsum
22
using Mods, Polynomials, TropicalNumbers
33
using Graphs, Random
44
using GenericTensorNetworks: StaticBitVector, graph_polynomial

test/networks/MaximalIS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ end
3939
@test min2.coeffs == (2, 150, 7510)
4040

4141
for bounded in [false, true]
42-
println("bounded = ", bounded, ", configs max1")
42+
@info("bounded = ", bounded, ", configs max1")
4343
@test length(solve(MaximalIS(g), ConfigsMin(; bounded=bounded))[].c) == 2
4444
println("bounded = ", bounded, ", configs max3")
4545
cmin2 = solve(MaximalIS(g), ConfigsMin(3; bounded=bounded))[]

0 commit comments

Comments
 (0)