@@ -64,11 +64,11 @@ function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
64
64
end
65
65
function cached_einsum (code:: NestedEinsum , @nospecialize (xs), size_dict)
66
66
if OMEinsum. isleaf (code)
67
- y = xs[code . tensorindex]
67
+ y = xs[OMEinsum . tensorindex (code) ]
68
68
return CacheTree (y, CacheTree{eltype (y)}[])
69
69
else
70
- caches = [cached_einsum (arg, xs, size_dict) for arg in code . args ]
71
- y = einsum (code . eins , ntuple (i-> caches[i]. content, length (caches)), size_dict)
70
+ caches = [cached_einsum (arg, xs, size_dict) for arg in OMEinsum . siblings (code) ]
71
+ y = einsum (OMEinsum . rootcode (code) , ntuple (i-> caches[i]. content, length (caches)), size_dict)
72
72
return CacheTree (y, caches)
73
73
end
74
74
end
@@ -84,8 +84,9 @@ function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
84
84
if OMEinsum. isleaf (code)
85
85
return CacheTree (mask, CacheTree{Bool}[])
86
86
else
87
- submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
88
- return CacheTree (mask, generate_masktree .(Ref (mode), code. args, cache. siblings, submasks, Ref (size_dict)))
87
+ eins = OMEinsum. rootcode (code)
88
+ submasks = backward_tropical (mode, getixs (eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (eins), cache. content, mask, size_dict)
89
+ return CacheTree (mask, generate_masktree .(Ref (mode), OMEinsum. siblings (code), cache. siblings, submasks, Ref (size_dict)))
89
90
end
90
91
end
91
92
@@ -98,12 +99,12 @@ function masked_einsum(se::SlicedEinsum, @nospecialize(xs), masks, size_dict)
98
99
end
99
100
function masked_einsum (code:: NestedEinsum , @nospecialize (xs), masks, size_dict)
100
101
if OMEinsum. isleaf (code)
101
- y = copy (xs[code . tensorindex])
102
+ y = copy (xs[OMEinsum . tensorindex (code) ])
102
103
y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
103
104
return y
104
105
else
105
- xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code . args , masks. siblings)]
106
- y = einsum (code . eins , (xs... ,), size_dict)
106
+ xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (OMEinsum . siblings (code) , masks. siblings)]
107
+ y = einsum (OMEinsum . rootcode (code) , (xs... ,), size_dict)
107
108
y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
108
109
return y
109
110
end
@@ -121,10 +122,10 @@ Contraction method with bounding.
121
122
"""
122
123
function bounding_contract (mode:: AllConfigs , code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
123
124
LT = OMEinsum. labeltype (code)
124
- bounding_contract (mode, NestedEinsum ( NestedEinsum {DynamicEinCode{LT} } .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
125
+ bounding_contract (mode, DynamicNestedEinsum ( DynamicNestedEinsum {LT } .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
125
126
end
126
127
function bounding_contract (mode:: AllConfigs , code:: Union{NestedEinsum,SlicedEinsum} , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
127
- size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins ),Int} () : copy (size_info)
128
+ size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code),Int} () : copy (size_info)
128
129
OMEinsum. get_size_dict! (code, xsa, size_dict)
129
130
# compute intermediate tensors
130
131
@debug " caching einsum..."
@@ -139,11 +140,11 @@ end
139
140
# get the optimal solution with automatic differentiation.
140
141
function solution_ad (code:: EinCode , @nospecialize (xsa), ymask; size_info= nothing )
141
142
LT = OMEinsum. labeltype (code)
142
- solution_ad (NestedEinsum ( NestedEinsum {DynamicEinCode{LT} } .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
143
+ solution_ad (DynamicNestedEinsum ( DynamicNestedEinsum {LT } .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
143
144
end
144
145
145
146
function solution_ad (code:: Union{NestedEinsum,SlicedEinsum} , @nospecialize (xsa), ymask; size_info= nothing )
146
- size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins ),Int} () : copy (size_info)
147
+ size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code),Int} () : copy (size_info)
147
148
OMEinsum. get_size_dict! (code, xsa, size_dict)
148
149
# compute intermediate tensors
149
150
@debug " caching einsum..."
@@ -165,7 +166,7 @@ function read_config!(code::SlicedEinsum, mt, out)
165
166
end
166
167
167
168
function read_config! (code:: NestedEinsum , mt, out)
168
- for (arg, ix, sibling) in zip (code . args , getixs (code . eins ), mt. siblings)
169
+ for (arg, ix, sibling) in zip (OMEinsum . siblings (code) , getixs (OMEinsum . rootcode (code) ), mt. siblings)
169
170
if OMEinsum. isleaf (arg)
170
171
mask = convert (Array, sibling. content) # note: the content can be CuArray
171
172
for ci in CartesianIndices (mask)
0 commit comments