@@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint,
55 Hermitian, Symmetric,
66 LowerTriangular, UnitLowerTriangular,
77 UpperTriangular, UnitUpperTriangular,
8- MulAddMul, wrap
8+ UpperOrLowerTriangular, MulAddMul, wrap
99
1010#
1111# BLAS 1
@@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
163163 GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
164164end
165165
166+ const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<: T ,<: oneStridedMatrix }}
167+
168+ function LinearAlgebra. generic_trimatmul! (
169+ C:: oneStridedMatrix{T} , uplocA, isunitcA,
170+ tfunA:: Function , A:: oneStridedMatrix{T} ,
171+ triB:: UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}} ,
172+ ) where {T<: onemklFloat }
173+ uplocB = LinearAlgebra. uplo_char (triB)
174+ isunitcB = LinearAlgebra. isunit_char (triB)
175+ B = parent (triB)
176+ tfunB = LinearAlgebra. wrapperop (B)
177+ transa = tfunA === identity ? ' N' : tfunA === transpose ? ' T' : ' C'
178+ transb = tfunB === identity ? ' N' : tfunB === transpose ? ' T' : ' C'
179+ if uplocA == ' L' && tfunA === identity && tfunB === identity && uplocB == ' U' && isunitcB == ' N' # lower * upper
180+ triu! (B)
181+ trmm! (' L' , uplocA, transa, isunitcA, one (T), A, B, C)
182+ elseif uplocA == ' U' && tfunA === identity && tfunB === identity && uplocB == ' L' && isunitcB == ' N' # upper * lower
183+ tril! (B)
184+ trmm! (' L' , uplocA, transa, isunitcA, one (T), A, B, C)
185+ elseif uplocA == ' U' && tfunA === identity && tfunB != = identity && uplocB == ' U' && isunitcA == ' N'
186+ # operation is reversed to avoid executing the tranpose
187+ triu! (A)
188+ trmm! (' R' , uplocB, transb, isunitcB, one (T), parent (B), A, C)
189+ elseif uplocA == ' L' && tfunA != = identity && tfunB === identity && uplocB == ' L' && isunitcB == ' N'
190+ tril! (B)
191+ trmm! (' L' , uplocA, transa, isunitcA, one (T), A, B, C)
192+ elseif uplocA == ' U' && tfunA != = identity && tfunB === identity && uplocB == ' U' && isunitcB == ' N'
193+ triu! (B)
194+ trmm! (' L' , uplocA, transa, isunitcA, one (T), A, B, C)
195+ elseif uplocA == ' L' && tfunA === identity && tfunB != = identity && uplocB == ' L' && isunitcA == ' N'
196+ tril! (A)
197+ trmm! (' R' , uplocB, transb, isunitcB, one (T), parent (B), A, C)
198+ else
199+ throw (" mixed triangular-triangular multiplication" ) # TODO : rethink
200+ end
201+ return C
202+ end
203+
166204# triangular
167205LinearAlgebra. generic_trimatmul! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: oneStridedMatrix{T} , B:: oneStridedMatrix{T} ) where {T<: onemklFloat } =
168- trmm! (' L' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), A, C === B ? C : copyto! (C, B) )
206+ trmm! (' L' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), A, B, C )
169207LinearAlgebra. generic_mattrimul! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: oneStridedMatrix{T} , B:: oneStridedMatrix{T} ) where {T<: onemklFloat } =
170- trmm! (' R' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), B, C === A ? C : copyto! (C, A) )
171- LinearAlgebra. generic_trimatdiv! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: oneStridedMatrix{T} , B:: oneStridedMatrix {T} ) where {T<: onemklFloat } =
172- trsm! (' L' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), A, C === B ? C : copyto! (C, B) )
173- LinearAlgebra. generic_mattridiv! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: oneStridedMatrix {T} , B:: oneStridedMatrix{T} ) where {T<: onemklFloat } =
174- trsm! (' R' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), B, C === A ? C : copyto! (C, A) )
208+ trmm! (' R' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), B, A, C )
209+ LinearAlgebra. generic_trimatdiv! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: oneStridedMatrix{T} , B:: AbstractMatrix {T} ) where {T<: onemklFloat } =
210+ trsm! (' L' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), A, B, C )
211+ LinearAlgebra. generic_mattridiv! (C:: oneStridedMatrix{T} , uploc, isunitc, tfun:: Function , A:: AbstractMatrix {T} , B:: oneStridedMatrix{T} ) where {T<: onemklFloat } =
212+ trsm! (' R' , uploc, tfun === identity ? ' N' : tfun === transpose ? ' T' : ' C' , isunitc, one (T), B, A, C )
0 commit comments