Skip to content

Refactor: Encapsulate the BLAS and LAPACK interfaces #6415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,14 @@ if(ENABLE_ASAN)
target_link_libraries(${ABACUS_BIN_NAME} -fsanitize=address)
endif()

add_library(math_connector OBJECT
${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_base.cpp
${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l1.cpp
${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l2.cpp
${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l3.cpp
${ABACUS_SOURCE_DIR}/source_base/module_external/lapack_connector.cpp
)
list(APPEND math_libs math_connector)
if(DEFINED ENV{MKLROOT} AND NOT DEFINED MKLROOT)
set(MKLROOT "$ENV{MKLROOT}")
endif()
Expand All @@ -479,7 +487,7 @@ else()
find_package(FFTW3 REQUIRED)
find_package(Lapack REQUIRED)
include_directories(${FFTW3_INCLUDE_DIRS})
list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS)
list(APPEND math_libs FFTW3::FFTW3 LAPACKE::LAPACKE BLAS::BLAS)
find_package(ScaLAPACK REQUIRED)
list(APPEND math_libs ScaLAPACK::ScaLAPACK)
if(USE_OPENMP)
Expand Down
25 changes: 20 additions & 5 deletions cmake/FindLapack.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,24 @@ endif()
find_package(Blas REQUIRED)
find_package(LAPACK REQUIRED)

if(NOT TARGET LAPACK::LAPACK)
add_library(LAPACK::LAPACK UNKNOWN IMPORTED)
set_target_properties(LAPACK::LAPACK PROPERTIES
IMPORTED_LINK_INTERFACE_LANGUAGES "C"
IMPORTED_LOCATION "${LAPACK_LIBRARIES}")
find_path(LAPACKE_INCLUDE_DIR
NAMES lapacke.h
PATHS ${LAPACK_DIR} ${LAPACKE_DIR} ${CMAKE_PREFIX_PATH}
PATH_SUFFIXES include include/lapacke
DOC "Path to LAPACKE include directory"
)
# find LAPACKE libraries
find_library(LAPACKE_LIBRARY
NAMES lapacke
PATHS ${LAPACK_DIR} ${LAPACK_LIBRARIES} ${CMAKE_PREFIX_PATH}
PATH_SUFFIXES lib lib64
DOC "Path to LAPACKE library"
)

if(NOT TARGET LAPACKE::LAPACKE)
add_library(LAPACKE::LAPACKE UNKNOWN IMPORTED)
set_target_properties(LAPACKE::LAPACKE PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${LAPACKE_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "C"
IMPORTED_LOCATION "${LAPACKE_LIBRARY}")
endif()
6 changes: 4 additions & 2 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ OBJS_MAIN=main.o\
OBJS_BASE=abfs-vector3_order.o\
assoc_laguerre.o\
blas_connector_base.o\
blas_connector_vector.o\
blas_connector_matrix.o\
blas_connector_l1.o\
blas_connector_l2.o\
blas_connector_l3.o\
lapack_connector.o\
complexarray.o\
complexmatrix.o\
clebsch_gordan_coeff.o\
Expand Down
3 changes: 0 additions & 3 deletions source/source_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ add_library(
base
OBJECT
assoc_laguerre.cpp
module_external/blas_connector_base.cpp
module_external/blas_connector_vector.cpp
module_external/blas_connector_matrix.cpp
clebsch_gordan_coeff.cpp
complexarray.cpp
complexmatrix.cpp
Expand Down
14 changes: 3 additions & 11 deletions source/source_base/cubic_spline.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cubic_spline.h"
#include "source_base/module_external/lapack_connector.h"

#include <cassert>
#include <algorithm>
Expand All @@ -8,13 +9,6 @@

using ModuleBase::CubicSpline;

extern "C"
{
// solve a tridiagonal linear system
void dgtsv_(int* N, int* NRHS, double* DL, double* D, double* DU, double* B, int* LDB, int* INFO);
};


CubicSpline::BoundaryCondition::BoundaryCondition(BoundaryType type)
: type(type)
{
Expand Down Expand Up @@ -477,8 +471,7 @@ void CubicSpline::_build(

int nrhs = 1;
int ldb = n;
int info = 0;
dgtsv_(&n, &nrhs, l, d, u, dy, &ldb, &info);
LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, dy, ldb);
}
}
}
Expand Down Expand Up @@ -552,9 +545,8 @@ void CubicSpline::_solve_cyctri(int n, double* d, double* u, double* l, double*
d[n - 1] -= l[n - 1] * alpha / beta;

int nrhs = 2;
int info = 0;
int ldb = n;
dgtsv_(&n, &nrhs, l, d, u, bp.data(), &ldb, &info);
LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, bp.data(), ldb);

double fac = (beta * u[n - 1] * bp[0] + alpha * l[n - 1] * bp[n - 1])
/ (1. + beta * u[n - 1] * bp[n] + alpha * l[n - 1] * bp[2 * n - 1]);
Expand Down
67 changes: 33 additions & 34 deletions source/source_base/gather_math_lib_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void zgemm_i(const char *transa,
GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
GlobalV::ofs_info << "zgemm " << *transa << " " << *transb << " " << *m << " " << *n << " "
<< *k << " " << *alpha << " " << *lda << " " << *ldb << " " << *beta << " " << *ldc << std::endl;
zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
BlasConnector::gemm_cm(*transa, *transb, *m, *n, *k, *alpha, a, *lda, b, *ldb, *beta, c, *ldc);
}

void zaxpy_i(const int *N,
Expand All @@ -43,38 +43,37 @@ void zaxpy_i(const int *N,
// std::cout << "zaxpy " << *N << std::endl;
// alpha is a coefficient
// incX and incY are always 1
zaxpy_(N, alpha, X, incX, Y, incY);
BlasConnector::axpy(*N, *alpha, X, *incX, Y, *incY);
}

void zhegvx_i(const int *itype,
const char *jobz,
const char *range,
const char *uplo,
const int *n,
std::complex<double> *a,
const int *lda,
std::complex<double> *b,
const int *ldb,
const double *vl,
const double *vu,
const int *il,
const int *iu,
const double *abstol,
const int *m,
double *w,
std::complex<double> *z,
const int *ldz,
std::complex<double> *work,
const int *lwork,
double *rwork,
int *iwork,
int *ifail,
int *info)
{
GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo
<< " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu
<< " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl;
zhegvx_(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork,
iwork, ifail, info);
}
// void zhegvx_i(const int *itype,
// const char *jobz,
// const char *range,
// const char *uplo,
// const int *n,
// std::complex<double> *a,
// const int *lda,
// std::complex<double> *b,
// const int *ldb,
// const double *vl,
// const double *vu,
// const int *il,
// const int *iu,
// const double *abstol,
// int *m,
// double *w,
// std::complex<double> *z,
// const int *ldz,
// std::complex<double> *work,
// const int *lwork,
// double *rwork,
// int *iwork,
// int *ifail,
// int *info)
// {
// GlobalV::ofs_info.unsetf(std::ios_base::floatfield);
// GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo
// << " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu
// << " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl;
// LapackConnector::hegvx(LapackConnector::ColMajor, *itype, *jobz, *range, *uplo, *n, a, *lda, b, *ldb, *vl, *vu, *il, *iu, *abstol, m, w, z, *ldz, ifail);
// }
13 changes: 2 additions & 11 deletions source/source_base/global_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,12 @@ inline void DCOPY(const T* a, T* b, const int& dim) {
}

template <typename T>
inline void COPYARRAY(const T* a, T* b, const long dim);

template <>
inline void COPYARRAY(const std::complex<double>* a, std::complex<double>* b, const long dim)
inline void COPYARRAY(const T* a, T* b, const long dim)
{
const int one = 1;
zcopy_(&dim, a, &one, b, &one);
BlasConnector::copy(dim, a, one, b, one);
}

template <>
inline void COPYARRAY(const double* a, double* b, const long dim)
{
const int one = 1;
dcopy_(&dim, a, &one, b, &one);
}

void BLOCK_HERE(const std::string& description);

Expand Down
35 changes: 4 additions & 31 deletions source/source_base/inverse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ Inverse_Matrix_Complex::~Inverse_Matrix_Complex()
if(allocate)
{
delete[] e; //mohan fix bug 2012-04-02
delete[] work2;
delete[] rwork;
allocate=false;
}
}
Expand All @@ -28,23 +26,14 @@ void Inverse_Matrix_Complex::init(const int &dim_in)
if(allocate)
{
delete[] e; //mohan fix bug 2012-04-02
delete[] work2;
delete[] rwork;
allocate=false;
}

this->dim = dim_in;

assert(dim>0);
this->e = new double[dim];
this->lwork = 2*dim;

assert(lwork>0);
this->work2 = new std::complex<double>[lwork];

assert(3*dim-2>0);
this->rwork = new double[3*dim-2];
this->info = 0;
this->A.create(dim, dim);
this->EA.create(dim, dim);

Expand All @@ -59,7 +48,7 @@ void Inverse_Matrix_Complex::using_zheev( const ModuleBase::ComplexMatrix &Sin,
ModuleBase::timer::tick("Inverse","using_zheev");
this->A = Sin;

LapackConnector::zheev('V', 'U', dim, this->A, dim, e, work2, lwork, rwork, &info);
LapackConnector::heev(LapackConnector::RowMajor, 'V', 'U', dim, this->A.c, dim, e);

for(int i=0; i<dim; i++)
{
Expand All @@ -76,11 +65,8 @@ void Inverse_Matrix_Complex::using_zheev( const ModuleBase::ComplexMatrix &Sin,

void Inverse_Matrix_Real(const int dim, const double* in, double* out)
{
int info = 0;
int lda = dim;
int lwork = 64 * dim;
int* ipiv = new int[dim];
double* work = new double[lwork];
std::vector<int> ipiv(dim);

for (int i = 0; i < dim; i++)
{
Expand All @@ -90,20 +76,7 @@ void Inverse_Matrix_Real(const int dim, const double* in, double* out)
}
}

dgetrf_(&dim, &dim, out, &lda, ipiv, &info);
if (info != 0)
{
std::cout << "ERROR: LAPACK dgetrf error, info = " << info << std::endl;
exit(1);
}
dgetri_(&dim, out, &lda, ipiv, work, &lwork, &info);
if (info != 0)
{
std::cout << "ERROR: LAPACK dgetri error, info = " << info << std::endl;
exit(1);
}

delete[] ipiv;
delete[] work;
LapackConnector::getrf(LapackConnector::ColMajor, dim, dim, out, lda, ipiv.data());
LapackConnector::getri(LapackConnector::ColMajor, dim, out, lda, ipiv.data());
}
}
4 changes: 0 additions & 4 deletions source/source_base/inverse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ class Inverse_Matrix_Complex
private:
int dim=0;
double *e=nullptr;
int lwork=0;
std::complex<double> *work2=nullptr;
double* rwork=nullptr;
int info=0;
bool allocate=false; //mohan add 2012-04-02

ModuleBase::ComplexMatrix EA;
Expand Down
2 changes: 1 addition & 1 deletion source/source_base/kernels/math_kernel_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct gemv_op<T, base_device::DEVICE_CPU>
T* Y,
const int& incy)
{
BlasConnector::gemv(trans, m, n, *alpha, A, lda, X, incx, *beta, Y, incy);
BlasConnector::gemv_cm(trans, m, n, *alpha, A, lda, X, incx, *beta, Y, incy);
}
};

Expand Down
34 changes: 18 additions & 16 deletions source/source_base/kernels/test/math_kernel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,17 +347,18 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_cpu)
int inc = 1;
int row = 2;
int col = 3;
zgemv_(&trans,
&row,
&col,
&ModuleBase::ONE,
BlasConnector::gemv_cm(
trans,
row,
col,
ModuleBase::ONE,
A_gemv.data(),
&row,
row,
X_gemv.data(),
&inc,
&ModuleBase::ONE,
inc,
ModuleBase::ONE,
Y_test_gemv.data(),
&inc);
inc);
for (int i = 0; i < Y_gemv.size(); i++)
{
EXPECT_LT(fabs(Y_gemv[i].imag() - Y_test_gemv[i].imag()), 1e-12);
Expand Down Expand Up @@ -607,17 +608,18 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu)
int inc = 1;
int row = 2;
int col = 3;
zgemv_(&trans,
&row,
&col,
&ModuleBase::ONE,
BlasConnector::gemv(
trans,
row,
col,
ModuleBase::ONE,
A_gemv.data(),
&row,
row,
X_gemv.data(),
&inc,
&ModuleBase::ONE,
inc,
ModuleBase::ONE,
Y_test_gemv.data(),
&inc);
inc);

for (int i = 0; i < Y_gemv.size(); i++)
{
Expand Down
Loading
Loading