Skip to content

Commit 44fbdab

Browse files
authored
Switch from Requires.jl to extensions (#72)

* switch from Requires.jl to extensions
* update test
* update test
1 parent ca05f0f commit 44fbdab

File tree

10 files changed

+48
-55
lines changed

10 files changed

+48
-55
lines changed

Project.toml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "1.3.6"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
8-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
98
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
109
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1110
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -19,15 +18,20 @@ Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1918
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
2019
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2120
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
22-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2321
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
2422
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
2523
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2624
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
2725

26+
[weakdeps]
27+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
28+
29+
[extensions]
30+
GenericTensorNetworksCUDAExt = "CUDA"
31+
2832
[compat]
2933
AbstractTrees = "0.3, 0.4"
30-
CUDA = "4"
34+
CUDA = "4, 5"
3135
DelimitedFiles = "1"
3236
DocStringExtensions = "0.8, 0.9"
3337
FFTW = "1.4"
@@ -37,15 +41,14 @@ Mods = "1.3"
3741
OMEinsum = "0.7"
3842
Polynomials = "4"
3943
Primes = "0.5"
40-
Requires = "1"
4144
SIMDTypes = "0.1"
4245
StatsBase = "0.33, 0.34"
4346
TropicalNumbers = "0.4, 0.5, 0.6"
44-
julia = "1"
47+
julia = "1.9"
4548

4649
[extras]
4750
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4851
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4952

5053
[targets]
51-
test = ["Test", "Documenter"]
54+
test = ["Test", "Documenter", "CUDA"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ makedocs(;
5454
"References" => "ref.md",
5555
],
5656
doctest=false,
57+
warnonly = :missing_docs,
5758
)
5859

5960
deploydocs(;

docs/src/ref.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ get_weights
3333
chweights
3434
nflavor
3535
fixedvertices
36+
37+
extract_result
3638
```
3739

3840
#### Graph Problem Utilities
@@ -121,8 +123,6 @@ optimize_code
121123
getixsv
122124
getiyv
123125
contraction_complexity
124-
timespace_complexity
125-
timespacereadwrite_complexity
126126
estimate_memory
127127
@ein_str
128128
GreedyMethod

ext/GenericTensorNetworksCUDAExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module GenericTensorNetworksCUDAExt
2+
3+
using CUDA
4+
using GenericTensorNetworks
5+
import GenericTensorNetworks: onehotmask!, togpu
6+
7+
# patch
8+
# Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}}) = 0
9+
10+
togpu(x::AbstractArray) = CuArray(x)
11+
12+
function onehotmask!(A::CuArray{T}, X::CuArray{T}) where T
13+
@assert length(A) == length(X)
14+
mask = X .≈ inv.(A)
15+
ci = argmax(mask)
16+
mask .= false
17+
CUDA.@allowscalar mask[ci] = true
18+
# set some elements in X to zero to help back propagating.
19+
X[(!).(mask)] .= Ref(zero(T))
20+
return mask
21+
end
22+
23+
end

src/GenericTensorNetworks.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,4 @@ include("deprecate.jl")
7979
include("multiprocessing.jl")
8080
include("visualize.jl")
8181

82-
using Requires
83-
function __init__()
84-
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
85-
end
86-
8782
end

src/arithematics.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,22 +204,22 @@ Examples
204204
------------------------------
205205
```jldoctest; setup=(using GenericTensorNetworks)
206206
julia> x = ExtendedTropical{3}(Tropical.([1.0, 2, 3]))
207-
ExtendedTropical{3, TropicalF64}(TropicalF64[1.0ₜ, 2.0ₜ, 3.0ₜ])
207+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[1.0ₜ, 2.0ₜ, 3.0ₜ])
208208
209209
julia> y = ExtendedTropical{3}(Tropical.([-Inf, 2, 5]))
210-
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, 2.0ₜ, 5.0ₜ])
210+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[-Infₜ, 2.0ₜ, 5.0ₜ])
211211
212212
julia> x * y
213-
ExtendedTropical{3, TropicalF64}(TropicalF64[6.0ₜ, 7.0ₜ, 8.0ₜ])
213+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[6.0ₜ, 7.0ₜ, 8.0ₜ])
214214
215215
julia> x + y
216-
ExtendedTropical{3, TropicalF64}(TropicalF64[2.0ₜ, 3.0ₜ, 5.0ₜ])
216+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[2.0ₜ, 3.0ₜ, 5.0ₜ])
217217
218218
julia> one(x)
219-
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, -Infₜ, 0.0ₜ])
219+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[-Infₜ, -Infₜ, 0.0ₜ])
220220
221221
julia> zero(x)
222-
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, -Infₜ, -Infₜ])
222+
ExtendedTropical{3, Tropical{Float64}}(Tropical{Float64}[-Infₜ, -Infₜ, -Infₜ])
223223
```
224224
"""
225225
struct ExtendedTropical{K,TO} <: Number

src/configurations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false, invert=false
2727
xst = generate_tensors(_x(Tropical{T}; invert), gp)
2828
ymask = trues(fill(2, length(getiyv(gp.code)))...)
2929
if usecuda
30-
xst = CuArray.(xst)
31-
ymask = CuArray(ymask)
30+
xst = togpu.(xst)
31+
ymask = togpu(ymask)
3232
end
3333
if all
3434
# we use `Float64` as default because we want to support weighted graphs.

src/cuda.jl

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/networks/networks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ julia> getixsv(gp.code)
114114
[8, 10]
115115
116116
julia> gp.code(GenericTensorNetworks.generate_tensors(Tropical(1.0), gp)...)
117-
0-dimensional Array{TropicalF64, 0}:
117+
0-dimensional Array{Tropical{Float64}, 0}:
118118
4.0ₜ
119119
```
120120
"""
@@ -150,7 +150,7 @@ function contractx(gp::GraphProblem, x; usecuda=false)
150150
xs = generate_tensors(x, gp)
151151
@debug "contracting tensors ..."
152152
if usecuda
153-
gp.code([CuArray(x) for x in xs]...)
153+
gp.code([togpu(x) for x in xs]...)
154154
else
155155
gp.code(xs...)
156156
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ end
1414
# a unified interface to optimize the contraction code
1515
_optimize_code(code, size_dict, optimizer::Nothing, simplifier) = code
1616
_optimize_code(code, size_dict, optimizer, simplifier) = optimize_code(code, size_dict, optimizer, simplifier)
17+
18+
# upload tensors to GPU
19+
function togpu end

0 commit comments

Comments (0)