Skip to content

Commit 513ff50

Browse files
Fix(tridiagonal_matrices): spmv y vector overwrite and missing error check in impure init functions (#1054)
* Fix: spmv y-reset bug and add error guard in tridiagonal init * minor change * test(tridiagonal): added test cases for all alpha and beta combinations under spmv * use allocate for matrix elements, move call spmv to separate line
1 parent 03736fa commit 513ff50

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

src/stdlib_specialmatrices_tridiagonal.fypp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
6969
call linalg_error_handling(err0)
7070
endif
7171
! Matrix elements.
72-
A%dl = [(dl, i = 1, n-1)]
73-
A%dv = [(dv, i = 1, n)]
74-
A%du = [(du, i = 1, n-1)]
72+
allocate( A%dl(n-1), source = dl )
73+
allocate( A%dv(n), source= dv )
74+
allocate( A%du(n-1), source = du )
7575
end function
7676

7777
module function initialize_tridiagonal_impure_${s1}$(dl, dv, du, err) result(A)
@@ -103,10 +103,12 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
103103
call linalg_error_handling(err0, err)
104104
endif
105105

106-
! Description of the matrix.
107-
A%n = n
108-
! Matrix elements.
109-
A%dl = dl ; A%dv = dv ; A%du = du
106+
if(err0%ok()) then
107+
! Description of the matrix.
108+
A%n = n
109+
! Matrix elements.
110+
A%dl = dl ; A%dv = dv ; A%du = du
111+
endif
110112
end function
111113

112114
module function initialize_constant_tridiagonal_impure_${s1}$(dl, dv, du, n, err) result(A)
@@ -124,16 +126,19 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
124126
integer(ilp) :: i
125127
type(linalg_state_type) :: err0
126128

127-
! Description of the matrix.
128-
A%n = n
129129
if (n <= 0) then
130130
err0 = linalg_state_type(this, LINALG_VALUE_ERROR, "Matrix size needs to be positive, n = ", n, ".")
131131
call linalg_error_handling(err0, err)
132132
endif
133-
! Matrix elements.
134-
A%dl = [(dl, i = 1, n-1)]
135-
A%dv = [(dv, i = 1, n)]
136-
A%du = [(du, i = 1, n-1)]
133+
134+
if(err0%ok()) then
135+
! Description of the matrix.
136+
A%n = n
137+
! Matrix elements.
138+
allocate( A%dl(n-1), source = dl )
139+
allocate( A%dv(n), source= dv )
140+
allocate( A%du(n-1), source = du )
141+
endif
137142
end function
138143
#:endfor
139144

@@ -168,7 +173,7 @@ submodule (stdlib_specialmatrices) tridiagonal_matrices
168173
op_ = "N" ; if (present(op)) op_ = op
169174

170175
! Prepare Lapack arguments.
171-
n = A%n ; ldx = n ; ldy = n ; y = 0.0_${k1}$
176+
n = A%n ; ldx = n ; ldy = n ;
172177
nrhs = #{if rank==1}# 1 #{else}# size(x, dim=2, kind=ilp) #{endif}#
173178

174179
#:if rank == 1

test/linalg/test_linalg_specialmatrices.fypp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ contains
3535
${t1}$, allocatable :: Amat(:,:), dl(:), dv(:), du(:)
3636
${t1}$, allocatable :: x(:)
3737
${t1}$, allocatable :: y1(:), y2(:)
38+
${t1}$ :: alpha, beta
39+
40+
integer :: i, j
41+
${t1}$, parameter :: coeffs(3) = [-1.0_wp, 0.0_wp, 1.0_wp]
3842

3943
! Initialize matrix.
4044
allocate(dl(n-1), dv(n), du(n-1))
@@ -56,13 +60,28 @@ contains
5660
call check(error, all_close(y1, y2), .true.)
5761
if (allocated(error)) return
5862

59-
#:if t1.startswith('complex')
63+
#:if t1.startswith('complex')
6064
! Test y = A.H @ x
6165
y1 = 0.0_wp ; y2 = 0.0_wp
6266
y1 = matmul(hermitian(Amat), x) ; call spmv(A, x, y2, op="H")
6367
call check(error, all_close(y1, y2), .true.)
6468
if (allocated(error)) return
6569
#:endif
70+
71+
! Test y = alpha * A @ x + beta * y for alpha,beta in {-1,0,1}
72+
do i = 1, 3
73+
do j = 1,3
74+
alpha = coeffs(i)
75+
beta = coeffs(j)
76+
77+
y1 = 0.0_wp
78+
call random_number(y2)
79+
y1 = alpha * matmul(Amat, x) + beta * y2
80+
call spmv(A, x, y2, alpha=alpha, beta=beta)
81+
call check(error, all_close(y1, y2), .true.)
82+
if (allocated(error)) return
83+
end do
84+
end do
6685
end block
6786
#:endfor
6887
end subroutine
@@ -91,7 +110,7 @@ contains
91110
call check(error, state%ok(), .false.)
92111
if (allocated(error)) return
93112
end block
94-
#:endfor
113+
#:endfor
95114
end subroutine
96115

97116
end module

0 commit comments

Comments
 (0)