|
| 1 | +const ROCTensorMap{T,S,N₁,N₂} = TensorMap{T,S,N₁,N₂, ROCVector{T,AMDGPU.DeviceMemory}} |
| 2 | +const ROCTensor{T, S, N} = ROCTensorMap{T, S, N, 0} |
| 3 | + |
| 4 | +const AdjointROCTensorMap{T,S,N₁,N₂} = AdjointTensorMap{T,S,N₁,N₂,ROCTensorMap{T,S,N₁,N₂}} |
| 5 | + |
| 6 | +function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<:StridedROCArray}) |
| 7 | + if TorA <: ROCArray |
| 8 | + return TensorMap{eltype(TorA),S,N₁,N₂,ROCVector{eltype(TorA), AMDGPU.DeviceMemory}} |
| 9 | + else |
| 10 | + throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:ROCVector{<:Number}`")) |
| 11 | + end |
| 12 | +end |
| 13 | + |
| 14 | +function ROCTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂} |
| 15 | + return ROCTensorMap{T,S,N₁,N₂}(undef, V) |
| 16 | +end |
| 17 | + |
| 18 | +function ROCTensorMap{T}(::UndefInitializer, codomain::TensorSpace{S}, |
| 19 | + domain::TensorSpace{S}) where {T,S} |
| 20 | + return ROCTensorMap{T}(undef, codomain ← domain) |
| 21 | +end |
| 22 | +function ROCTensor{T}(::UndefInitializer, V::TensorSpace{S}) where {T,S} |
| 23 | + return ROCTensorMap{T}(undef, V ← one(V)) |
| 24 | +end |
| 25 | +# constructor starting from block data |
| 26 | +""" |
| 27 | + ROCTensorMap(data::AbstractDict{<:Sector,<:ROCMatrix}, codomain::ProductSpace{S,N₁}, |
| 28 | + domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂} |
| 29 | + ROCTensorMap(data, codomain ← domain) |
| 30 | + ROCTensorMap(data, domain → codomain) |
| 31 | +
|
| 32 | +Construct a `ROCTensorMap` by explicitly specifying its block data. |
| 33 | +
|
| 34 | +## Arguments |
| 35 | +- `data::AbstractDict{<:Sector,<:ROCMatrix}`: dictionary containing the block data for |
| 36 | + each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`. |
| 37 | +- `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type |
| 38 | + `S<:ElementarySpace`. |
| 39 | +- `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type |
| 40 | + `S<:ElementarySpace`. |
| 41 | +
|
| 42 | +Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref) |
| 43 | +using the syntax `codomain ← domain` or `domain → codomain`. |
| 44 | +""" |
| 45 | +function ROCTensorMap(data::AbstractDict{<:Sector,<:ROCArray}, |
| 46 | + V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂} |
| 47 | + T = eltype(valtype(data)) |
| 48 | + t = ROCTensorMap{T}(undef, V) |
| 49 | + for (c, b) in blocks(t) |
| 50 | + haskey(data, c) || throw(SectorMismatch("no data for block sector $c")) |
| 51 | + datac = data[c] |
| 52 | + size(datac) == size(b) || |
| 53 | + throw(DimensionMismatch("wrong size of block for sector $c")) |
| 54 | + copy!(b, datac) |
| 55 | + end |
| 56 | + for (c, b) in data |
| 57 | + c ∈ blocksectors(t) || isempty(b) || |
| 58 | + throw(SectorMismatch("data for block sector $c not expected")) |
| 59 | + end |
| 60 | + return t |
| 61 | +end |
| 62 | +function ROCTensorMap{T}(data::DenseVector{T}, codomain::TensorSpace{S}, |
| 63 | + domain::TensorSpace{S}) where {T,S} |
| 64 | + return ROCTensorMap(data, codomain ← domain) |
| 65 | +end |
| 66 | +function ROCTensorMap(data::AbstractDict{<:Sector,<:ROCMatrix}, codom::TensorSpace{S}, |
| 67 | + dom::TensorSpace{S}) where {S} |
| 68 | + return ROCTensorMap(data, codom ← dom) |
| 69 | +end |
| 70 | + |
| 71 | +for (fname, felt) in ((:zeros, :zero), (:ones, :one)) |
| 72 | + @eval begin |
| 73 | + function AMDGPU.$fname(codomain::TensorSpace{S}, |
| 74 | + domain::TensorSpace{S}=one(codomain)) where {S<:IndexSpace} |
| 75 | + return AMDGPU.$fname(codomain ← domain) |
| 76 | + end |
| 77 | + function AMDGPU.$fname(::Type{T}, codomain::TensorSpace{S}, |
| 78 | + domain::TensorSpace{S}=one(codomain)) where {T,S<:IndexSpace} |
| 79 | + return AMDGPU.$fname(T, codomain ← domain) |
| 80 | + end |
| 81 | + AMDGPU.$fname(V::TensorMapSpace) = AMDGPU.$fname(Float64, V) |
| 82 | + function AMDGPU.$fname(::Type{T}, V::TensorMapSpace) where {T} |
| 83 | + t = ROCTensorMap{T}(undef, V) |
| 84 | + fill!(t, $felt(T)) |
| 85 | + return t |
| 86 | + end |
| 87 | + end |
| 88 | +end |
| 89 | + |
| 90 | +for randfun in (:curand, :curandn) |
| 91 | + randfun! = Symbol(randfun, :!) |
| 92 | + @eval begin |
| 93 | + # converting `codomain` and `domain` into `HomSpace` |
| 94 | + function $randfun(codomain::TensorSpace{S}, |
| 95 | + domain::TensorSpace{S}) where {S<:IndexSpace} |
| 96 | + return $randfun(codomain ← domain) |
| 97 | + end |
| 98 | + function $randfun(::Type{T}, codomain::TensorSpace{S}, |
| 99 | + domain::TensorSpace{S}) where {T,S<:IndexSpace} |
| 100 | + return $randfun(T, codomain ← domain) |
| 101 | + end |
| 102 | + function $randfun(rng::Random.AbstractRNG, ::Type{T}, |
| 103 | + codomain::TensorSpace{S}, |
| 104 | + domain::TensorSpace{S}) where {T,S<:IndexSpace} |
| 105 | + return $randfun(rng, T, codomain ← domain) |
| 106 | + end |
| 107 | + |
| 108 | + # accepting single `TensorSpace` |
| 109 | + $randfun(codomain::TensorSpace) = $randfun(codomain ← one(codomain)) |
| 110 | + function $randfun(::Type{T}, codomain::TensorSpace) where {T} |
| 111 | + return $randfun(T, codomain ← one(codomain)) |
| 112 | + end |
| 113 | + function $randfun(rng::Random.AbstractRNG, ::Type{T}, |
| 114 | + codomain::TensorSpace) where {T} |
| 115 | + return $randfun(rng, T, codomain ← one(domain)) |
| 116 | + end |
| 117 | + |
| 118 | + # filling in default eltype |
| 119 | + $randfun(V::TensorMapSpace) = $randfun(Float64, V) |
| 120 | + function $randfun(rng::Random.AbstractRNG, V::TensorMapSpace) |
| 121 | + return $randfun(rng, Float64, V) |
| 122 | + end |
| 123 | + |
| 124 | + # filling in default rng |
| 125 | + function $randfun(::Type{T}, V::TensorMapSpace) where {T} |
| 126 | + return $randfun(Random.default_rng(), T, V) |
| 127 | + end |
| 128 | + |
| 129 | + # implementation |
| 130 | + function $randfun(rng::Random.AbstractRNG, ::Type{T}, |
| 131 | + V::TensorMapSpace) where {T} |
| 132 | + t = ROCTensorMap{T}(undef, V) |
| 133 | + $randfun!(rng, t) |
| 134 | + return t |
| 135 | + end |
| 136 | + end |
| 137 | +end |
| 138 | + |
| 139 | +# converters |
| 140 | +# ---------- |
| 141 | +function Base.convert(::Type{ROCTensorMap}, d::Dict{Symbol,Any}) |
| 142 | + try |
| 143 | + codomain = eval(Meta.parse(d[:codomain])) |
| 144 | + domain = eval(Meta.parse(d[:domain])) |
| 145 | + data = SectorDict(eval(Meta.parse(c)) => ROCArray(b) for (c, b) in d[:data]) |
| 146 | + return ROCTensorMap(data, codomain, domain) |
| 147 | + catch e # sector unknown in TensorKit.jl; user-defined, hopefully accessible in Main |
| 148 | + codomain = Base.eval(Main, Meta.parse(d[:codomain])) |
| 149 | + domain = Base.eval(Main, Meta.parse(d[:domain])) |
| 150 | + data = SectorDict(Base.eval(Main, Meta.parse(c)) => ROCArray(b) |
| 151 | + for (c, b) in d[:data]) |
| 152 | + return ROCTensorMap(data, codomain, domain) |
| 153 | + end |
| 154 | +end |
| 155 | + |
| 156 | +function Base.convert(::Type{ROCTensorMap}, t::AbstractTensorMap) |
| 157 | + return copy!(ROCTensorMap{scalartype(t)}(undef, space(t)), t) |
| 158 | +end |
| 159 | + |
| 160 | +# Scalar implementation |
| 161 | +#----------------------- |
| 162 | +function TensorKit.scalar(t::ROCTensorMap) |
| 163 | + |
| 164 | + # TODO: should scalar only work if N₁ == N₂ == 0? |
| 165 | + return @allowscalar dim(codomain(t)) == dim(domain(t)) == 1 ? |
| 166 | + first(blocks(t))[2][1, 1] : throw(DimensionMismatch()) |
| 167 | +end |
| 168 | + |
| 169 | +TensorKit.scalartype(A::StridedROCArray{T}) where {T} = T |
| 170 | +vi_scalartype(::Type{<:ROCTensorMap{T}}) where {T} = T |
| 171 | +vi_scalartype(::Type{<:ROCArray{T}}) where {T} = T |
| 172 | + |
| 173 | +function TensorKit.similarstoragetype(TT::Type{<:ROCTensorMap{TTT,S,N₁,N₂}}, ::Type{T}) where {TTT,T,S,N₁,N₂} |
| 174 | + return ROCVector{T, AMDGPU.DeviceMemory} |
| 175 | +end |
| 176 | + |
| 177 | +function Base.convert(TT::Type{ROCTensorMap{T,S,N₁,N₂}}, |
| 178 | + t::AbstractTensorMap{<:Any,S,N₁,N₂}) where {T,S,N₁,N₂} |
| 179 | + if typeof(t) === TT |
| 180 | + return t |
| 181 | + else |
| 182 | + tnew = TT(undef, space(t)) |
| 183 | + return copy!(tnew, t) |
| 184 | + end |
| 185 | +end |
| 186 | + |
| 187 | +function Base.copy!(tdst::ROCTensorMap{T, S, N₁, N₂}, tsrc::ROCTensorMap{T, S, N₁, N₂}) where {T, S, N₁, N₂} |
| 188 | + space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst)) ≠ $(space(tsrc))")) |
| 189 | + for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc)) |
| 190 | + copy!(bdst, bsrc) |
| 191 | + end |
| 192 | + return tdst |
| 193 | +end |
| 194 | + |
| 195 | +function Base.copy!(tdst::ROCTensorMap, tsrc::TensorKit.AdjointTensorMap) |
| 196 | + space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst)) ≠ $(space(tsrc))")) |
| 197 | + for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc)) |
| 198 | + copy!(bdst, bsrc) |
| 199 | + end |
| 200 | + return tdst |
| 201 | +end |
| 202 | + |
| 203 | +function Base.promote_rule(::Type{<:TT₁}, |
| 204 | + ::Type{<:TT₂}) where {S,N₁,N₂, TTT₁, TTT₂, |
| 205 | + TT₁<:ROCTensorMap{TTT₁,S,N₁,N₂}, |
| 206 | + TT₂<:ROCTensorMap{TTT₂,S,N₁,N₂}} |
| 207 | + T = TensorKit.VectorInterface.promote_add(TTT₁, TTT₂) |
| 208 | + return ROCTensorMap{T,S,N₁,N₂} |
| 209 | +end |
| 210 | + |
| 211 | +function LinearAlgebra.isposdef(t::ROCTensorMap) |
| 212 | + domain(t) == codomain(t) || |
| 213 | + throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) |
| 214 | + InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false |
| 215 | + for (c, b) in blocks(t) |
| 216 | + # do our own hermitian check |
| 217 | + isherm = TensorKit.MatrixAlgebraKit.ishermitian(b; atol=eps(real(eltype(b))), rtol=eps(real(eltype(b)))) |
| 218 | + isherm || return false |
| 219 | + isposdef(Hermitian(b)) || return false |
| 220 | + end |
| 221 | + return true |
| 222 | +end |
0 commit comments