From 11ccff6c940c5b3048b97772cf6bf885bd67b917 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Mon, 21 Jul 2025 20:59:33 +0800 Subject: [PATCH 1/5] refactor blas and lapack connector --- source/Makefile.Objects | 6 +- source/source_base/CMakeLists.txt | 6 +- source/source_base/blas_connector_l1.cpp | 506 +++++++++++++++ source/source_base/blas_connector_l2.cpp | 142 +++++ source/source_base/blas_connector_l3.cpp | 598 ++++++++++++++++++ source/source_base/cubic_spline.cpp | 14 +- source/source_base/gather_math_lib_info.cpp | 67 +- source/source_base/global_function.h | 13 +- source/source_base/inverse_matrix.cpp | 32 +- source/source_base/inverse_matrix.h | 4 - source/source_base/kernels/math_kernel_op.cpp | 2 +- source/source_base/lapack_connector.cpp | 461 ++++++++++++++ .../module_container/ATen/kernels/blas.cpp | 26 +- .../ATen/kernels/cuda/lapack.cu | 2 +- .../module_container/ATen/kernels/lapack.cpp | 70 +- .../module_container/ATen/kernels/lapack.h | 2 +- .../module_container/base/third_party/blas.h | 343 +--------- .../base/third_party/lapack.h | 402 ++---------- .../module_external/blas_connector.h | 182 ++---- .../module_external/blas_connector_matrix.cpp | 153 +++-- .../module_external/blas_connector_vector.cpp | 129 ++-- .../module_external/lapack_connector.h | 542 +++------------- .../module_external/lapack_wrapper.h | 484 -------------- source/source_base/module_grid/batch.cpp | 9 +- .../module_mixing/broyden_mixing.cpp | 34 +- source/source_base/module_mixing/mixing.cpp | 48 +- .../module_mixing/pulay_mixing.cpp | 14 +- .../source_base/test/blas_connector_test.cpp | 6 +- .../test/lapack_connector_test.cpp | 48 +- source/source_estate/math_tools.h | 42 +- source/source_estate/module_dm/cal_dm_psi.cpp | 42 +- .../source_estate/module_dm/cal_edm_tddft.cpp | 10 +- source/source_hsolver/diago_david.cpp | 2 +- source/source_hsolver/diago_david.h | 2 +- source/source_hsolver/diago_iter_assist.cpp | 12 +- source/source_hsolver/diago_iter_assist.h | 12 +- 
source/source_hsolver/diago_lapack.cpp | 199 +----- source/source_hsolver/diago_lapack.h | 8 +- .../source_hsolver/kernels/cuda/dngvd_op.cu | 6 +- source/source_hsolver/kernels/dngvd_op.cpp | 223 +------ source/source_hsolver/kernels/dngvd_op.h | 6 +- .../kernels/rocm/dngvd_op.hip.cu | 18 +- .../source_hsolver/test/diago_bpcg_test.cpp | 8 +- .../test/diago_cg_float_test.cpp | 8 +- .../test/diago_cg_real_test.cpp | 13 +- source/source_hsolver/test/diago_cg_test.cpp | 9 +- .../test/diago_david_float_test.cpp | 12 +- .../test/diago_david_real_test.cpp | 13 +- .../source_hsolver/test/diago_david_test.cpp | 14 +- source/source_hsolver/test/diago_elpa_utils.h | 38 +- source/source_io/berryphase.cpp | 2 +- source/source_io/write_vxc.hpp | 6 +- source/source_io/write_vxc_lip.hpp | 10 +- .../module_deepks/deepks_orbpre.cpp | 21 +- .../source_lcao/module_deepks/deepks_pdm.cpp | 21 +- .../source_lcao/module_gint/gint_rho_old.cpp | 2 +- .../source_lcao/module_gint/gint_tau_old.cpp | 6 +- .../source_lcao/module_gint/gint_vl_old.cpp | 8 +- .../source_lcao/module_gint/mult_psi_dmr.cpp | 4 +- .../ao_to_mo_transformer/ao_to_mo_serial.cpp | 24 +- .../module_lr/dm_trans/dm_trans_serial.cpp | 24 +- .../module_lr/ri_benchmark/ri_benchmark.hpp | 26 +- .../source_lcao/module_lr/utils/lr_util.cpp | 40 +- .../module_operator_lcao/deepks_lcao.cpp | 21 +- .../module_ri/ABFs_Construct-PCA.cpp | 9 +- .../source_lcao/module_ri/Inverse_Matrix.hpp | 11 +- source/source_lcao/module_ri/exx_lip.hpp | 2 +- source/source_pw/module_pwdft/VNL_in_pw.cpp | 23 +- source/source_pw/module_pwdft/forces_us.cpp | 21 +- .../module_pwdft/operator_pw/op_exx_pw.cpp | 19 +- .../source_pw/module_pwdft/stress_func_us.cpp | 42 +- source/source_pw/module_stodft/sto_dos.cpp | 2 +- source/source_relax/bfgs.cpp | 6 +- 73 files changed, 2472 insertions(+), 2920 deletions(-) create mode 100644 source/source_base/blas_connector_l1.cpp create mode 100644 source/source_base/blas_connector_l2.cpp create mode 100644 
source/source_base/blas_connector_l3.cpp create mode 100644 source/source_base/lapack_connector.cpp delete mode 100644 source/source_base/module_external/lapack_wrapper.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 2a3d6fa918..75bf4e6da7 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -128,8 +128,10 @@ OBJS_MAIN=main.o\ OBJS_BASE=abfs-vector3_order.o\ assoc_laguerre.o\ blas_connector_base.o\ - blas_connector_vector.o\ - blas_connector_matrix.o\ + blas_connector_l1.o\ + blas_connector_l2.o\ + blas_connector_l3.o\ + lapack_connector.o\ complexarray.o\ complexmatrix.o\ clebsch_gordan_coeff.o\ diff --git a/source/source_base/CMakeLists.txt b/source/source_base/CMakeLists.txt index 717ef9bf44..9b0cc046f0 100644 --- a/source/source_base/CMakeLists.txt +++ b/source/source_base/CMakeLists.txt @@ -11,8 +11,10 @@ add_library( OBJECT assoc_laguerre.cpp module_external/blas_connector_base.cpp - module_external/blas_connector_vector.cpp - module_external/blas_connector_matrix.cpp + module_external/blas_connector_l1.cpp + module_external/blas_connector_l2.cpp + module_external/blas_connector_l3.cpp + module_external/lapack_connector.cpp clebsch_gordan_coeff.cpp complexarray.cpp complexmatrix.cpp diff --git a/source/source_base/blas_connector_l1.cpp b/source/source_base/blas_connector_l1.cpp new file mode 100644 index 0000000000..f9c5925143 --- /dev/null +++ b/source/source_base/blas_connector_l1.cpp @@ -0,0 +1,506 @@ +/* level 1: std::vector-std::vector operations, O(n) data and O(n) work. + * This file contains the implementation of the BLAS level 1 operations. + * These operations include vector scaling, vector addition, vector dot product, and vector norm calculations. 
+ */ +#include "blas_connector.h" +#include "../macros.h" + +#include +#ifdef __DSP +#include "source_base/kernels/dsp/dsp_connector.h" +#include "source_base/global_variable.h" +#endif + +#ifdef __CUDA +#include +#include +#include "cublas_v2.h" +#include "source_base/kernels/math_kernel_op.h" +#include "source_base/module_device/memory_op.h" +#endif + +extern "C" +{ + // level 1: std::vector-std::vector operations, O(n) data and O(n) work. + // Peize Lin add ?scal 2016-08-04, to compute x=a*x + void sscal_(const int *N, const float *alpha, float *X, const int *incX); + void dscal_(const int *N, const double *alpha, double *X, const int *incX); + void cscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); + void zscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); + + // Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y + void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY); + void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY); + void caxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); + void zaxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); + + void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); + void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); + + //reason for passing results as argument instead of returning it: + //see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/ + // void zdotc_(std::complex *result, const int *n, const std::complex *zx, + // const int *incx, const std::complex *zy, const int *incy); + // Peize Lin add ?dot 2017-10-27, to compute d=x*y + float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY); 
+ double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY); + + // Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 } + float snrm2_( const int *n, const float *X, const int *incX ); + double dnrm2_( const int *n, const double *X, const int *incX ); + double dznrm2_( const int *n, const std::complex *X, const int *incX ); +} + +// x=a*x +void BlasConnector::scal( const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + sscal_(&n, &alpha, X, &incX); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dscal_(&n, &alpha, X, &incX); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + cscal_(&n, &alpha, X, &incX); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX)); + } +#endif + else { + throw 
std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zscal_(&n, &alpha, X, &incX); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + saxpy_(&n, &alpha, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + daxpy_(&n, &alpha, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void 
BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + caxpy_(&n, &alpha, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zaxpy_(&n, &alpha, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// copies a into b +void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dcopy_(&n, a, &incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zcopy_(&n, a, 
&incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// d=x*y +float BlasConnector::dot( const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return sdot_(&n, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ + float result = 0.0; + cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); + return result; + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +double BlasConnector::dot( const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return ddot_(&n, X, &incX, Y, &incY); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ + double result = 0.0; + cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); + return result; + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// d=x*y +float BlasConnector::dotu(const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + return BlasConnector::dot(n, X, incX, Y, incY, device_type); +} + +double BlasConnector::dotu(const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + return BlasConnector::dot(n, X, incX, Y, incY, device_type); +} + +std::complex 
BlasConnector::dotu(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + const int incX2 = 2 * incX; + const int incY2 = 2 * incY; + const float*const x = reinterpret_cast(X); + const float*const y = reinterpret_cast(Y); + //Re(result)=Re(x)*Re(y)-Im(x)*Im(y) + //Im(result)=Re(x)*Im(y)+Im(x)*Re(y) + return std::complex( + BlasConnector::dot(n, x, incX2, y, incY2, device_type) - dot(n, x+1, incX2, y+1, incY2, device_type), + BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) + dot(n, x+1, incX2, y, incY2, device_type)); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +std::complex BlasConnector::dotu(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + const int incX2 = 2 * incX; + const int incY2 = 2 * incY; + const double*const x = reinterpret_cast(X); + const double*const y = reinterpret_cast(Y); + //Re(result)=Re(x)*Re(y)-Im(x)*Im(y) + //Im(result)=Re(x)*Im(y)+Im(x)*Re(y) + return std::complex( + BlasConnector::dot(n, x, incX2, y, incY2, device_type) - dot(n, x+1, incX2, y+1, incY2, device_type), + BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) + dot(n, x+1, incX2, y, incY2, device_type)); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// d = x.conj() * Vy +float BlasConnector::dotc(const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + return BlasConnector::dot(n, X, incX, Y, incY, device_type); +} + +double 
BlasConnector::dotc(const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + return BlasConnector::dot(n, X, incX, Y, incY, device_type); +} + +std::complex BlasConnector::dotc(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + const int incX2 = 2 * incX; + const int incY2 = 2 * incY; + const float*const x = reinterpret_cast(X); + const float*const y = reinterpret_cast(Y); + // Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y) + // Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y) + return std::complex( + BlasConnector::dot(n, x, incX2, y, incY2, device_type) + dot(n, x+1, incX2, y+1, incY2, device_type), + BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) - dot(n, x+1, incX2, y, incY2, device_type)); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +std::complex BlasConnector::dotc(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + const int incX2 = 2 * incX; + const int incY2 = 2 * incY; + const double*const x = reinterpret_cast(X); + const double*const y = reinterpret_cast(Y); + // Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y) + // Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y) + return std::complex( + BlasConnector::dot(n, x, incX2, y, incY2, device_type) + dot(n, x+1, incX2, y+1, incY2, device_type), + BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) - dot(n, x+1, incX2, y, incY2, device_type)); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// out = ||x||_2 +float 
BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type ) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return snrm2_( &n, X, &incX ); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ + float result = 0.0; + cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + + +double BlasConnector::nrm2( const int n, const double *X, const int incX, base_device::AbacusDevice_t device_type ) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return dnrm2_( &n, X, &incX ); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ + double result = 0.0; + cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); + return result; + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + + +double BlasConnector::nrm2( const int n, const std::complex *X, const int incX, base_device::AbacusDevice_t device_type ) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + return dznrm2_( &n, X, &incX ); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice){ + double result = 0.0; + cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result)); + return result; + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +template +void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ + using Real = typename GetTypeReal::type; + if 
(device_type == base_device::AbacusDevice_t::CpuDevice) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] * vector2[i]; + } + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + ModuleBase::vector_mul_vector_op()(dim, result, vector1, vector2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + + +template +void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ + using Real = typename GetTypeReal::type; + if (device_type == base_device::AbacusDevice_t::CpuDevice) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] / vector2[i]; + } + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + ModuleBase::vector_div_vector_op()(dim, result, vector1, vector2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::CpuDevice){ +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 8192 / sizeof(float)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] * constant1 + vector2[i] * constant2; + } + } +#ifdef __CUDA + else if (device_type == base_device::GpuDevice) { + ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + 
" in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::CpuDevice){ +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 8192 / sizeof(double)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] * constant1 + vector2[i] * constant2; + } + } +#ifdef __CUDA + else if (device_type == base_device::GpuDevice) { + ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const float constant1, const std::complex *vector2, const float constant2, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::CpuDevice){ +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] * constant1 + vector2[i] * constant2; + } + } +#ifdef __CUDA + else if (device_type == base_device::GpuDevice) { + ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const double constant1, const std::complex *vector2, const double constant2, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::CpuDevice){ +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 8192 
/ sizeof(std::complex)) +#endif + for (int i = 0; i < dim; i++) + { + result[i] = vector1[i] * constant1 + vector2[i] * constant2; + } + } +#ifdef __CUDA + else if (device_type == base_device::GpuDevice) { + ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} \ No newline at end of file diff --git a/source/source_base/blas_connector_l2.cpp b/source/source_base/blas_connector_l2.cpp new file mode 100644 index 0000000000..2966e2118b --- /dev/null +++ b/source/source_base/blas_connector_l2.cpp @@ -0,0 +1,142 @@ +/* level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work. + * This file contains the implementation of the BLAS level 2 operations. + * These operations include matrix-vector multiplication and related operations. + */ +#include "blas_connector.h" +#include "macros.h" +#include + +#ifdef __DSP +#include "source_base/kernels/dsp/dsp_connector.h" +#include "source_base/global_variable.h" +#endif + +#ifdef __CUDA +#include +#include +#include "cublas_v2.h" +#include "source_base/kernels/math_kernel_op.h" +#include "source_base/module_device/memory_op.h" +#endif + +extern "C" +{ + // level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work. 
+ void sgemv_(const char*const transa, const int*const m, const int*const n, + const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx, + const float*const beta, float*const y, const int*const incy); + void dgemv_(const char*const transa, const int*const m, const int*const n, + const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx, + const double*const beta, double*const y, const int*const incy); + + void cgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, + const std::complex *a, const int *lda, const std::complex *x, const int *incx, + const std::complex *beta, std::complex *y, const int *incy); + + void zgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, + const std::complex *a, const int *lda, const std::complex *x, const int *incx, + const std::complex *beta, std::complex *y, const int *incy); + + void dsymv_(const char *uplo, const int *n, + const double *alpha, const double *a, const int *lda, + const double *x, const int *incx, + const double *beta, double *y, const int *incy); + + // A := alpha x * y.T + A + void dger_(const int* m, + const int* n, + const double* alpha, + const double* x, + const int* incx, + const double* y, + const int* incy, + double* a, + const int* lda); +} + +void BlasConnector::gemv_cm(const char trans, const int m, const int n, + const float alpha, const float* A, const int lda, const float* X, const int incx, + const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); + cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, 
&alpha, A, lda, X, incx, &beta, Y, incy)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemv_cm(const char trans, const int m, const int n, + const double alpha, const double* A, const int lda, const double* X, const int incx, + const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); + cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemv_cm(const char trans, const int m, const int n, + const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, + const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag()); + cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy)); + } +#endif + else { + 
throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemv_cm(const char trans, const int m, const int n, + const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, + const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag()); + cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag()); + cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); + cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::ger_cm(int m, int n, double alpha, const double* x, + int incx, const double* y, const int incy, double a, int lda, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dger_(&m, &n, &alpha, x, &incx, y, &incy, &a, &lda); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} \ No newline at end of file diff --git a/source/source_base/blas_connector_l3.cpp b/source/source_base/blas_connector_l3.cpp new file mode 100644 index 0000000000..2c749552b1 --- /dev/null +++ b/source/source_base/blas_connector_l3.cpp @@ -0,0 +1,598 @@ +/* level 3: matrix-matrix 
operations, O(n^2) data and O(n^3) work. + * This file contains the implementation of the BLAS level 3 operations. + * These operations include matrix-matrix multiplication and related operations. + */ +#include "blas_connector.h" +#include "../macros.h" + +#ifdef __DSP +#include "source_base/kernels/dsp/dsp_connector.h" +#include "source_base/global_variable.h" +#endif + +#ifdef __CUDA +#include +#include +#include "cublas_v2.h" +#include "source_base/kernels/math_kernel_op.h" +#include "source_base/module_device/memory_op.h" +#endif + +extern "C" +{ + // level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. + + // Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? + b * C + // A is general + void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, + const float *beta, float *c, const int *ldc); + void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, + const double *beta, double *c, const int *ldc); + void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + + // A is symmetric. C = a * A.? * B.? 
+ b * C + void ssymm_(const char *side, const char *uplo, const int *m, const int *n, + const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, + const float *beta, float *c, const int *ldc); + void dsymm_(const char *side, const char *uplo, const int *m, const int *n, + const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, + const double *beta, double *c, const int *ldc); + void csymm_(const char *side, const char *uplo, const int *m, const int *n, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + void zsymm_(const char *side, const char *uplo, const int *m, const int *n, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + + // A is hermitian. C = a * A.? * B.? + b * C + void chemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, + std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); + void zhemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, + std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); + + // symmetric rank-k update + void dsyrk_( + const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* beta, + double* c, + const int* ldc + ); +} + +// C = a * A.? * B.? 
+ b * C +// Row-Major part +void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, + const float alpha, const float *a, const int lda, const float *b, const int ldb, + const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + sgemm_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = 
BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); +#endif + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (float2*)&alpha, + (float2*)b, + ldb, + (float2*)a, + lda, + (float2*)&beta, + (float2*)c, + ldc)); +#endif + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex 
beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (double2*)&alpha, + (double2*)b, + ldb, + (double2*)a, + lda, + (double2*)&beta, + (double2*)c, + ldc)); +#endif + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// Col-Major part +void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, + const float alpha, const float *a, const int lda, const float *b, const int ldb, + const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + sgemm_(&transa, &transb, &m, &n, &k, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + 
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, 
&ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + (float2*)&alpha, + (float2*)a, + lda, + (float2*)b, + ldb, + (float2*)&beta, + (float2*)c, + ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + (double2*)&alpha, + (double2*)a, + lda, + (double2*)b, + ldb, + (double2*)&beta, + (double2*)c, + ldc)); + } +#endif + 
else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +// Symm and Hemm part. Only col-major is supported. + +void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, + const float alpha, const float *a, const int lda, const float *b, const int ldb, + const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + ssymm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, + const double alpha, const double *a, const int lda, const double *b, const int ldb, + const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dsymm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " 
line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, + const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, + const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + csymm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, + const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, + const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zsymm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + 
std::to_string(__LINE__)); + } +} + +void BlasConnector::hemm_cm(const char side, const char uplo, const int m, const int n, + const float alpha, const float *a, const int lda, const float *b, const int ldb, + const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + symm_cm(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc, device_type); +} + +void BlasConnector::hemm_cm(const char side, const char uplo, const int m, const int n, + const double alpha, const double *a, const int lda, const double *b, const int ldb, + const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +{ + symm_cm(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc, device_type); +} + +void BlasConnector::hemm_cm(char side, char uplo, int m, int n, + std::complex alpha, std::complex *a, int lda, std::complex *b, int ldb, + std::complex beta, std::complex *c, int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + chemm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, &ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::hemm_cm(char side, char uplo, int m, int n, + std::complex alpha, std::complex *a, int lda, std::complex *b, int ldb, + std::complex beta, std::complex *c, int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zhemm_(&side, &uplo, &m, &n, + &alpha, a, &lda, b, 
&ldb, + &beta, c, &ldc); + } +#ifdef __CUDA + else if (device_type == base_device::AbacusDevice_t::GpuDevice) { + cublasSideMode_t sideMode = BlasUtils::judge_side(side); + cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); + cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); + } +#endif + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::syrk(char uplo, char trans, int n, int k, + double alpha, const double* a, int lda, double beta, double* c, int ldc, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dsyrk_(&uplo, &trans, &n, &k, &alpha, a, &lda, &beta, c, &ldc); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::herk(char uplo, char trans, int n, int k, float alpha, + const std::complex *A, int lda, float beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) +{ + auto cblas_uplo = BlasUtils::toCblasUplo(uplo); + auto cblas_trans = BlasUtils::toCblasTrans(trans); + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cblas_cherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); + } + else + { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::herk(char uplo, char trans, int n, int k, double alpha, + const std::complex *A, int lda, double beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) +{ + auto cblas_uplo = BlasUtils::toCblasUplo(uplo); + auto cblas_trans = BlasUtils::toCblasTrans(trans); + if 
(device_type == base_device::AbacusDevice_t::CpuDevice) + { + cblas_zherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); + } + else + { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} \ No newline at end of file diff --git a/source/source_base/cubic_spline.cpp b/source/source_base/cubic_spline.cpp index cfd9f57943..8d28bbf1c0 100644 --- a/source/source_base/cubic_spline.cpp +++ b/source/source_base/cubic_spline.cpp @@ -1,4 +1,5 @@ #include "cubic_spline.h" +#include "source_base/lapack_connector.h" #include #include @@ -8,13 +9,6 @@ using ModuleBase::CubicSpline; -extern "C" -{ - // solve a tridiagonal linear system - void dgtsv_(int* N, int* NRHS, double* DL, double* D, double* DU, double* B, int* LDB, int* INFO); -}; - - CubicSpline::BoundaryCondition::BoundaryCondition(BoundaryType type) : type(type) { @@ -477,8 +471,7 @@ void CubicSpline::_build( int nrhs = 1; int ldb = n; - int info = 0; - dgtsv_(&n, &nrhs, l, d, u, dy, &ldb, &info); + LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, dy, ldb); } } } @@ -552,9 +545,8 @@ void CubicSpline::_solve_cyctri(int n, double* d, double* u, double* l, double* d[n - 1] -= l[n - 1] * alpha / beta; int nrhs = 2; - int info = 0; int ldb = n; - dgtsv_(&n, &nrhs, l, d, u, bp.data(), &ldb, &info); + LapackConnector::gtsv(LapackConnector::ColMajor, n, nrhs, l, d, u, bp.data(), ldb); double fac = (beta * u[n - 1] * bp[0] + alpha * l[n - 1] * bp[n - 1]) / (1. 
+ beta * u[n - 1] * bp[n] + alpha * l[n - 1] * bp[2 * n - 1]); diff --git a/source/source_base/gather_math_lib_info.cpp b/source/source_base/gather_math_lib_info.cpp index 825aaa3163..c8a7b23dde 100644 --- a/source/source_base/gather_math_lib_info.cpp +++ b/source/source_base/gather_math_lib_info.cpp @@ -30,7 +30,7 @@ void zgemm_i(const char *transa, GlobalV::ofs_info.unsetf(std::ios_base::floatfield); GlobalV::ofs_info << "zgemm " << *transa << " " << *transb << " " << *m << " " << *n << " " << *k << " " << *alpha << " " << *lda << " " << *ldb << " " << *beta << " " << *ldc << std::endl; - zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + BlasConnector::gemm_cm(*transa, *transb, *m, *n, *k, *alpha, a, *lda, b, *ldb, *beta, c, *ldc); } void zaxpy_i(const int *N, @@ -43,38 +43,37 @@ void zaxpy_i(const int *N, // std::cout << "zaxpy " << *N << std::endl; // alpha is a coefficient // incX, incY is always 1 - zaxpy_(N, alpha, X, incX, Y, incY); + BlasConnector::axpy(*N, *alpha, X, *incX, Y, *incY); } -void zhegvx_i(const int *itype, - const char *jobz, - const char *range, - const char *uplo, - const int *n, - std::complex *a, - const int *lda, - std::complex *b, - const int *ldb, - const double *vl, - const double *vu, - const int *il, - const int *iu, - const double *abstol, - const int *m, - double *w, - std::complex *z, - const int *ldz, - std::complex *work, - const int *lwork, - double *rwork, - int *iwork, - int *ifail, - int *info) -{ - GlobalV::ofs_info.unsetf(std::ios_base::floatfield); - GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo - << " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu - << " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl; - zhegvx_(itype, jobz, range, uplo, n, a, lda, b, ldb, vl, vu, il, iu, abstol, m, w, z, ldz, work, lwork, rwork, - iwork, ifail, info); -} +// void zhegvx_i(const int *itype, +// const 
char *jobz, +// const char *range, +// const char *uplo, +// const int *n, +// std::complex *a, +// const int *lda, +// std::complex *b, +// const int *ldb, +// const double *vl, +// const double *vu, +// const int *il, +// const int *iu, +// const double *abstol, +// int *m, +// double *w, +// std::complex *z, +// const int *ldz, +// std::complex *work, +// const int *lwork, +// double *rwork, +// int *iwork, +// int *ifail, +// int *info) +// { +// GlobalV::ofs_info.unsetf(std::ios_base::floatfield); +// GlobalV::ofs_info << "zhegvx " << *itype << " " << *jobz << " " << *range << " " << *uplo +// << " " << *n << " " << *lda << " " << *ldb << " " << *vl << " " << *vu << " " << *il << " " << *iu +// << " " << *abstol << " " << *m << " " << *lwork << " " << *info << std::endl; +// LapackConnector::hegvx(LapackConnector::ColMajor, *itype, *jobz, *range, *uplo, *n, a, *lda, b, *ldb, *vl, *vu, *il, *iu, *abstol, m, w, z, *ldz, ifail); +// } diff --git a/source/source_base/global_function.h b/source/source_base/global_function.h index 7981fb79bd..318f37ae0b 100644 --- a/source/source_base/global_function.h +++ b/source/source_base/global_function.h @@ -182,21 +182,12 @@ inline void DCOPY(const T* a, T* b, const int& dim) { } template -inline void COPYARRAY(const T* a, T* b, const long dim); - -template <> -inline void COPYARRAY(const std::complex* a, std::complex* b, const long dim) +inline void COPYARRAY(const T* a, T* b, const long dim) { const int one = 1; - zcopy_(&dim, a, &one, b, &one); + BlasConnector::copy(dim, a, one, b, one); } -template <> -inline void COPYARRAY(const double* a, double* b, const long dim) -{ - const int one = 1; - dcopy_(&dim, a, &one, b, &one); -} void BLOCK_HERE(const std::string& description); diff --git a/source/source_base/inverse_matrix.cpp b/source/source_base/inverse_matrix.cpp index 2fb130759c..55ad954104 100644 --- a/source/source_base/inverse_matrix.cpp +++ b/source/source_base/inverse_matrix.cpp @@ -16,8 +16,6 @@ 
Inverse_Matrix_Complex::~Inverse_Matrix_Complex() if(allocate) { delete[] e; //mohan fix bug 2012-04-02 - delete[] work2; - delete[] rwork; allocate=false; } } @@ -28,8 +26,6 @@ void Inverse_Matrix_Complex::init(const int &dim_in) if(allocate) { delete[] e; //mohan fix bug 2012-04-02 - delete[] work2; - delete[] rwork; allocate=false; } @@ -37,14 +33,10 @@ void Inverse_Matrix_Complex::init(const int &dim_in) assert(dim>0); this->e = new double[dim]; - this->lwork = 2*dim; assert(lwork>0); - this->work2 = new std::complex[lwork]; assert(3*dim-2>0); - this->rwork = new double[3*dim-2]; - this->info = 0; this->A.create(dim, dim); this->EA.create(dim, dim); @@ -59,7 +51,7 @@ void Inverse_Matrix_Complex::using_zheev( const ModuleBase::ComplexMatrix &Sin, ModuleBase::timer::tick("Inverse","using_zheev"); this->A = Sin; - LapackConnector::zheev('V', 'U', dim, this->A, dim, e, work2, lwork, rwork, &info); + LapackConnector::heev(LapackConnector::RowMajor, 'V', 'U', dim, this->A.c, dim, e); for(int i=0; i ipiv(dim); for (int i = 0; i < dim; i++) { @@ -90,20 +79,7 @@ void Inverse_Matrix_Real(const int dim, const double* in, double* out) } } - dgetrf_(&dim, &dim, out, &lda, ipiv, &info); - if (info != 0) - { - std::cout << "ERROR: LAPACK dgetrf error, info = " << info << std::endl; - exit(1); - } - dgetri_(&dim, out, &lda, ipiv, work, &lwork, &info); - if (info != 0) - { - std::cout << "ERROR: LAPACK dgetri error, info = " << info << std::endl; - exit(1); - } - - delete[] ipiv; - delete[] work; + LapackConnector::getrf(LapackConnector::ColMajor, dim, dim, out, lda, ipiv.data()); + LapackConnector::getri(LapackConnector::ColMajor, dim, out, lda, ipiv.data()); } } \ No newline at end of file diff --git a/source/source_base/inverse_matrix.h b/source/source_base/inverse_matrix.h index d49e109e15..d3380e6bba 100644 --- a/source/source_base/inverse_matrix.h +++ b/source/source_base/inverse_matrix.h @@ -22,10 +22,6 @@ class Inverse_Matrix_Complex private: int dim=0; double 
*e=nullptr; - int lwork=0; - std::complex *work2=nullptr; - double* rwork=nullptr; - int info=0; bool allocate=false; //mohan add 2012-04-02 ModuleBase::ComplexMatrix EA; diff --git a/source/source_base/kernels/math_kernel_op.cpp b/source/source_base/kernels/math_kernel_op.cpp index dad790e463..e07594404d 100644 --- a/source/source_base/kernels/math_kernel_op.cpp +++ b/source/source_base/kernels/math_kernel_op.cpp @@ -22,7 +22,7 @@ struct gemv_op T* Y, const int& incy) { - BlasConnector::gemv(trans, m, n, *alpha, A, lda, X, incx, *beta, Y, incy); + BlasConnector::gemv_cm(trans, m, n, *alpha, A, lda, X, incx, *beta, Y, incy); } }; diff --git a/source/source_base/lapack_connector.cpp b/source/source_base/lapack_connector.cpp new file mode 100644 index 0000000000..41d3b6c016 --- /dev/null +++ b/source/source_base/lapack_connector.cpp @@ -0,0 +1,461 @@ +#include +#include "lapack_connector.h" +#include "source_base/tool_quit.h" + +namespace LapackConnector +{ +int toLapackLayout(MatrixLayout layout) +{ + return (layout == MatrixLayout::RowMajor) ? 
LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; +} + +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, float* w) +{ + + int info = LAPACKE_chegv(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegv failed with info = " + std::to_string(info)); + } +} + +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, double* w) +{ + int info = LAPACKE_zhegv(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegv failed with info = " + std::to_string(info)); + } +} + +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* a, int lda, double* b, int ldb, double* w) +{ + int info = LAPACKE_dsygv(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygv failed with info = " + std::to_string(info)); + } +} + +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, float* a, int lda, float* b, int ldb, float* w) +{ + int info = LAPACKE_ssygvd(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssygvd failed with info = " + std::to_string(info)); + } +} + +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* a, int lda, double* b, int ldb, double* w) +{ + int info = LAPACKE_dsygvd(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygvd failed with info = " + std::to_string(info)); + } +} + +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, 
int lda, std::complex* b, int ldb, float* w) +{ + int info = LAPACKE_chegvd(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegvd failed with info = " + std::to_string(info)); + } +} + +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, double* w) +{ + int info = LAPACKE_zhegvd(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegvd failed with info = " + std::to_string(info)); + } +} + +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + std::complex* a, int lda, std::complex* b, int ldb, + float vl, float vu, int il, int iu, float abstol, int* m, + float* w, std::complex* z, int ldz, int* ifail) +{ + int info = LAPACKE_chegvx(toLapackLayout(layout), itype, jobz, range, uplo, n, + reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, + vl, vu, il, iu, abstol, m, w, + reinterpret_cast(z), ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegvx failed with info = " + std::to_string(info)); + } +} + +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + std::complex* a, int lda, std::complex* b, int ldb, + double vl, double vu, int il, int iu, double abstol, int* m, + double* w, std::complex* z, int ldz, int* ifail) +{ + int info = LAPACKE_zhegvx(toLapackLayout(layout), itype, jobz, range, uplo, n, + reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, + vl, vu, il, iu, abstol, m, w, + reinterpret_cast(z), ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegvx failed with info = " + std::to_string(info)); + } +} + +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + double* 
a, int lda, double* b, int ldb, + double vl, double vu, int il, int iu, double abstol, int* m, + double* w, double* z, int ldz, int* ifail) +{ + int info = LAPACKE_dsygvx(toLapackLayout(layout), itype, jobz, range, uplo, n, + a, lda, b, ldb, + vl, vu, il, iu, abstol, m, w, + z, ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygvx failed with info = " + std::to_string(info)); + } +} + +void potrf(MatrixLayout layout, char uplo, int n, float* a, int lda) +{ + int info = LAPACKE_spotrf(toLapackLayout(layout), uplo, n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK spotrf failed with info = " + std::to_string(info)); + } +} + +void potrf(MatrixLayout layout, char uplo, int n, double* a, int lda) +{ + int info = LAPACKE_dpotrf(toLapackLayout(layout), uplo, n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dpotrf failed with info = " + std::to_string(info)); + } +} + +void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int lda) +{ + int info = LAPACKE_cpotrf(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cpotrf failed with info = " + std::to_string(info)); + } +} + +void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int lda) +{ + int info = LAPACKE_zpotrf(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zpotrf failed with info = " + std::to_string(info)); + } +} + +void potri(MatrixLayout layout, char uplo, int n, float* a, int lda) +{ + int info = LAPACKE_spotri(toLapackLayout(layout), uplo, n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK spotri failed with info = " + std::to_string(info)); + } +} + +void potri(MatrixLayout layout, char uplo, int n, double* a, int lda) +{ + int info = LAPACKE_dpotri(toLapackLayout(layout), uplo, 
n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dpotri failed with info = " + std::to_string(info)); + } +} + +void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int lda) +{ + int info = LAPACKE_cpotri(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cpotri failed with info = " + std::to_string(info)); + } +} + +void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int lda) +{ + int info = LAPACKE_zpotri(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zpotri failed with info = " + std::to_string(info)); + } +} + +void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex* a, int lda, float* w) +{ + int info = LAPACKE_cheev(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheev failed with info = " + std::to_string(info)); + } +} + +void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex* a, int lda, double* w) +{ + int info = LAPACKE_zheev(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheev failed with info = " + std::to_string(info)); + } +} + +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, float* a, int lda, float vl, + float vu, int il, int iu, float abstol, int* m, float* w, float* z, int ldz, int* ifail) +{ + int info = LAPACKE_ssyevx(toLapackLayout(layout), jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssyevx failed with info = " + std::to_string(info)); + } +} + +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, double* a, int lda, double vl, + double vu, int il, 
int iu, double abstol, int* m, double* w, double* z, int ldz, int* ifail) +{ + int info = LAPACKE_dsyevx(toLapackLayout(layout), jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsyevx failed with info = " + std::to_string(info)); + } +} + +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::complex* a, int lda, + float vl, float vu, int il, int iu, float abstol, int* m, float* w, + std::complex* z, int ldz, int* ifail) +{ + int info = LAPACKE_cheevx(toLapackLayout(layout), jobz, range, uplo, n, + reinterpret_cast(a), lda, + vl, vu, il, iu, abstol, m, w, + reinterpret_cast(z), ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheevx failed with info = " + std::to_string(info)); + } +} + +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::complex* a, int lda, + double vl, double vu, int il, int iu, double abstol, int* m, double* w, + std::complex* z, int ldz, int* ifail) +{ + int info = LAPACKE_zheevx(toLapackLayout(layout), jobz, range, uplo, n, + reinterpret_cast(a), lda, + vl, vu, il, iu, abstol, m, w, + reinterpret_cast(z), ldz, ifail); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheevx failed with info = " + std::to_string(info)); + } +} + +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + float* a, int lda, float* w) +{ + int info = LAPACKE_ssyevd(toLapackLayout(layout), jobz, uplo, n, a, lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssyevd failed with info = " + std::to_string(info)); + } +} + +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + double* a, int lda, double* w) +{ + int info = LAPACKE_dsyevd(toLapackLayout(layout), jobz, uplo, n, a, lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsyevd failed with info = " + std::to_string(info)); + } +} 
+ +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + std::complex* a, int lda, float* w) +{ + int info = LAPACKE_cheevd(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheevd failed with info = " + std::to_string(info)); + } +} + +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + std::complex* a, int lda, double* w) +{ + int info = LAPACKE_zheevd(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheevd failed with info = " + std::to_string(info)); + } +} + +void syev(MatrixLayout layout, char jobz, char uplo, int n, double* a, int lda, double* w) +{ + int info = LAPACKE_dsyev(toLapackLayout(layout), jobz, uplo, n, a, lda, w); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsyev failed with info = " + std::to_string(info)); + } +} + +void geev(MatrixLayout layout, char jobvl, char jobvr, int n, double* a, int lda, + double* wr, double* wi, double* vl, int ldvl, double* vr, int ldvr) +{ + int info = LAPACKE_dgeev(toLapackLayout(layout), jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgeev failed with info = " + std::to_string(info)); + } +} + +void geev(MatrixLayout layout, char jobvl, char jobvr, int n, std::complex* a, int lda, + std::complex* w, std::complex* vl, int ldvl, std::complex* vr, int ldvr) +{ + int info = LAPACKE_zgeev(toLapackLayout(layout), jobvl, jobvr, n, reinterpret_cast(a), lda, + reinterpret_cast(w), reinterpret_cast(vl), ldvl, + reinterpret_cast(vr), ldvr); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgeev failed with info = " + std::to_string(info)); + } +} + +void getrf(MatrixLayout layout, int m, int n, float* a, int lda, int* ipiv) +{ + int info = LAPACKE_sgetrf(toLapackLayout(layout), m, n, a, lda, 
ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetrf failed with info = " + std::to_string(info)); + } +} + +void getrf(MatrixLayout layout, int m, int n, double* a, int lda, int* ipiv) +{ + int info = LAPACKE_dgetrf(toLapackLayout(layout), m, n, a, lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetrf failed with info = " + std::to_string(info)); + } +} + +void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, int* ipiv) +{ + int info = LAPACKE_cgetrf(toLapackLayout(layout), m, n, reinterpret_cast(a), lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetrf failed with info = " + std::to_string(info)); + } +} + +void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, int* ipiv) +{ + int info = LAPACKE_zgetrf(toLapackLayout(layout), m, n, reinterpret_cast(a), lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetrf failed with info = " + std::to_string(info)); + } +} + +void getri(MatrixLayout layout, int n, float* a, int lda, const int* ipiv) +{ + int info = LAPACKE_sgetri(toLapackLayout(layout), n, a, lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetri failed with info = " + std::to_string(info)); + } +} + +void getri(MatrixLayout layout, int n, double* a, int lda, const int* ipiv) +{ + int info = LAPACKE_dgetri(toLapackLayout(layout), n, a, lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetri failed with info = " + std::to_string(info)); + } +} + +void getri(MatrixLayout layout, int n, std::complex* a, int lda, const int* ipiv) +{ + int info = LAPACKE_cgetri(toLapackLayout(layout), n, reinterpret_cast(a), lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetri failed with info = " + std::to_string(info)); + } +} + +void getri(MatrixLayout layout, int n, std::complex* a, int 
lda, const int* ipiv) +{ + int info = LAPACKE_zgetri(toLapackLayout(layout), n, reinterpret_cast(a), lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetri failed with info = " + std::to_string(info)); + } +} + +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const float* a, int lda, const int* ipiv, float* b, int ldb) +{ + int info = LAPACKE_sgetrs(toLapackLayout(layout), trans, n, nrhs, a, lda, ipiv, b, ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetrs failed with info = " + std::to_string(info)); + } +} + +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const double* a, int lda, const int* ipiv, double* b, int ldb) +{ + int info = LAPACKE_dgetrs(toLapackLayout(layout), trans, n, nrhs, a, lda, ipiv, b, ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetrs failed with info = " + std::to_string(info)); + } +} + +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex* a, int lda, const int* ipiv, std::complex* b, int ldb) +{ + int info = LAPACKE_cgetrs(toLapackLayout(layout), trans, n, nrhs, reinterpret_cast(a), lda, ipiv, reinterpret_cast(b), ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetrs failed with info = " + std::to_string(info)); + } +} + +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex* a, int lda, const int* ipiv, std::complex* b, int ldb) +{ + int info = LAPACKE_zgetrs(toLapackLayout(layout), trans, n, nrhs, reinterpret_cast(a), lda, ipiv, reinterpret_cast(b), ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetrs failed with info = " + std::to_string(info)); + } +} + +void sytrf(MatrixLayout layout, char uplo, int n, double* a, int lda, int* ipiv) +{ + int info = LAPACKE_dsytrf(toLapackLayout(layout), uplo, n, a, lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsytrf 
failed with info = " + std::to_string(info)); + } +} + +void sytri(MatrixLayout layout, char uplo, int n, double* a, int lda, const int* ipiv) +{ + int info = LAPACKE_dsytri(toLapackLayout(layout), uplo, n, a, lda, ipiv); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsytri failed with info = " + std::to_string(info)); + } +} + +void trtri(MatrixLayout layout, char uplo, char diag, int n, float* a, int lda) +{ + int info = LAPACKE_strtri(toLapackLayout(layout), uplo, diag, n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK strtri failed with info = " + std::to_string(info)); + } +} + +void trtri(MatrixLayout layout, char uplo, char diag, int n, double* a, int lda) +{ + int info = LAPACKE_dtrtri(toLapackLayout(layout), uplo, diag, n, a, lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dtrtri failed with info = " + std::to_string(info)); + } +} + +void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex* a, int lda) +{ + int info = LAPACKE_ztrtri(toLapackLayout(layout), uplo, diag, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ztrtri failed with info = " + std::to_string(info)); + } +} + +void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex* a, int lda) +{ + int info = LAPACKE_ctrtri(toLapackLayout(layout), uplo, diag, n, reinterpret_cast(a), lda); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ctrtri failed with info = " + std::to_string(info)); + } +} + +void gtsv(MatrixLayout layout, int n, int nrhs, double* dl, double* d, double* du, double* b, int ldb) +{ + int info = LAPACKE_dgtsv(toLapackLayout(layout), n, nrhs, dl, d, du, b, ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgtsv failed with info = " + std::to_string(info)); + } +} + +void sysv(MatrixLayout layout, char uplo, int n, int nrhs, double* a, int lda, int* ipiv, 
double* b, int ldb) +{ + int info = LAPACKE_dsysv(toLapackLayout(layout), uplo, n, nrhs, a, lda, ipiv, b, ldb); + if (info != 0) { + ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsysv failed with info = " + std::to_string(info)); + } +} +} // namespace LapackConnector + diff --git a/source/source_base/module_container/ATen/kernels/blas.cpp b/source/source_base/module_container/ATen/kernels/blas.cpp index eb192a7c9e..0b6f349a0e 100644 --- a/source/source_base/module_container/ATen/kernels/blas.cpp +++ b/source/source_base/module_container/ATen/kernels/blas.cpp @@ -13,7 +13,7 @@ struct blas_dot { const int& incy, T* result) { - *result = BlasConnector::dot(n, x, incx, y, incy); + *result = BlasConnector::dotc(n, x, incx, y, incy); } }; @@ -58,7 +58,7 @@ struct blas_gemv { T* y, const int& incy) { - BlasConnector::gemv(trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy); + BlasConnector::gemv_cm(trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy); } }; @@ -78,7 +78,10 @@ struct blas_gemv_batched { const int& incy, const int& batch_size) { - BlasConnector::gemv_batched(trans, m, n, *alpha, A, lda, x, incx, *beta, y, incy, batch_size); + for (int ii = 0; ii < batch_size; ++ii) { + // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] + BlasConnector::gemv_cm(trans, m, n, *alpha, A[ii], lda, x[ii], incy, *beta, y[ii], incy); + } } }; @@ -102,7 +105,10 @@ struct blas_gemv_batched_strided { const int64_t& stride_y, const int& batch_size) { - BlasConnector::gemv_batched_strided(trans, m, n, *alpha, A, lda, stride_a, x, incx, stride_x, *beta, y, incy, stride_y, batch_size); + for (int ii = 0; ii < batch_size; ii++) { + // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] + BlasConnector::gemv_cm(trans, m, n, *alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, *beta, y + ii * stride_y, incy); + } } }; @@ -123,7 +129,7 @@ struct blas_gemm { T* C, const int& ldc) { - BlasConnector::gemm(transa, transb, m, n, k, *alpha, A, lda, 
B, ldb, *beta, C, ldc); + BlasConnector::gemm_cm(transa, transb, m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc); } }; @@ -145,7 +151,10 @@ struct blas_gemm_batched { const int& ldc, const int& batch_size) { - BlasConnector::gemm_batched(transa, transb, m, n, k, *alpha, A, lda, B, ldb, *beta, C, ldc, batch_size); + for (int ii = 0; ii < batch_size; ++ii) { + // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] + BlasConnector::gemm_cm(transa, transb, m, n, k, *alpha, A[ii], lda, B[ii], ldb, *beta, C[ii], ldc); + } } }; @@ -170,7 +179,10 @@ struct blas_gemm_batched_strided { const int& stride_c, const int& batch_size) { - BlasConnector::gemm_batched_strided(transa, transb, m, n, k, *alpha, A, lda, stride_a, B, ldb, stride_b, *beta, C, ldc, stride_c, batch_size); + for (int ii = 0; ii < batch_size; ii++) { + // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] + BlasConnector::gemm_cm(transa, transb, m, n, k, *alpha, A + ii * stride_a, lda, B + ii * stride_b, ldb, *beta, C + ii * stride_c, ldc); + } } }; diff --git a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu index 6300df953f..58ce40fab9 100644 --- a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu +++ b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu @@ -152,7 +152,7 @@ struct lapack_getrs { const int& nrhs, T* A, const int& lda, - const int* ipiv, + int* ipiv, T* B, const int& ldb) { diff --git a/source/source_base/module_container/ATen/kernels/lapack.cpp b/source/source_base/module_container/ATen/kernels/lapack.cpp index 2369306309..f1345804a9 100644 --- a/source/source_base/module_container/ATen/kernels/lapack.cpp +++ b/source/source_base/module_container/ATen/kernels/lapack.cpp @@ -38,11 +38,7 @@ struct lapack_trtri { T* Mat, const int& lda) { - int info = 0; - lapackConnector::trtri(uplo, diag, dim, Mat, lda, info); - if (info != 0) { - throw 
std::runtime_error("potrf failed with info = " + std::to_string(info)); - } + lapackConnector::trtri(uplo, diag, dim, Mat, lda); } }; @@ -54,11 +50,7 @@ struct lapack_potrf { T* Mat, const int& lda) { - int info = 0; - lapackConnector::potrf(uplo, dim, Mat, dim, info); - if (info != 0) { - throw std::runtime_error("potrf failed with info = " + std::to_string(info)); - } + lapackConnector::potrf(uplo, dim, Mat, dim); } }; @@ -72,23 +64,7 @@ struct lapack_dnevd { const int& dim, Real* eigen_val) { - int info = 0; - int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim); - Tensor work(DataTypeToEnum::value, DeviceType::CpuDevice, {lwork}); - work.zero(); - - int lrwork = 1 + 5 * dim + 2 * dim * dim; - Tensor rwork(DataTypeToEnum::value, DeviceType::CpuDevice, {lrwork}); - rwork.zero(); - - int liwork = 3 + 5 * dim; - Tensor iwork(DataTypeToEnum::value, DeviceType::CpuDevice, {liwork}); - iwork.zero(); - - lapackConnector::dnevd(jobz, uplo, dim, Mat, dim, eigen_val, work.data(), lwork, rwork.data(), lrwork, iwork.data(), liwork, info); - if (info != 0) { - throw std::runtime_error("dnevd failed with info = " + std::to_string(info)); - } + lapackConnector::dnevd(jobz, uplo, dim, Mat, dim, eigen_val); } }; @@ -104,23 +80,7 @@ struct lapack_dngvd { const int& dim, Real* eigen_val) { - int info = 0; - int lwork = std::max(2 * dim + dim * dim, 1 + 6 * dim + 2 * dim * dim); - Tensor work(DataTypeToEnum::value, DeviceType::CpuDevice, {lwork}); - work.zero(); - - int lrwork = 1 + 5 * dim + 2 * dim * dim; - Tensor rwork(DataTypeToEnum::value, DeviceType::CpuDevice, {lrwork}); - rwork.zero(); - - int liwork = 3 + 5 * dim; - Tensor iwork(DataType::DT_INT, DeviceType::CpuDevice, {liwork}); - iwork.zero(); - - lapackConnector::dngvd(itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val, work.data(), lwork, rwork.data(), lrwork, iwork.data(), liwork, info); - if (info != 0) { - throw std::runtime_error("dngvd failed with info = " + std::to_string(info)); - } + 
lapackConnector::dngvd(itype, jobz, uplo, dim, Mat_A, dim, Mat_B, dim, eigen_val); } }; @@ -133,11 +93,7 @@ struct lapack_getrf { const int& lda, int* ipiv) { - int info = 0; - lapackConnector::getrf(m, n, Mat, lda, ipiv, info); - if (info != 0) { - throw std::runtime_error("getrf failed with info = " + std::to_string(info)); - } + lapackConnector::getrf(m, n, Mat, lda, ipiv); } }; @@ -147,15 +103,11 @@ struct lapack_getri { const int& n, T* Mat, const int& lda, - const int* ipiv, + int* ipiv, T* work, const int& lwork) { - int info = 0; - lapackConnector::getri(n, Mat, lda, ipiv, work, lwork, info); - if (info != 0) { - throw std::runtime_error("getri failed with info = " + std::to_string(info)); - } + lapackConnector::getri(n, Mat, lda, ipiv); } }; @@ -167,15 +119,11 @@ struct lapack_getrs { const int& nrhs, T* A, const int& lda, - const int* ipiv, + int* ipiv, T* B, const int& ldb) { - int info = 0; - lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb, info); - if (info != 0) { - throw std::runtime_error("getrs failed with info = " + std::to_string(info)); - } + lapackConnector::getrs(trans, n, nrhs, A, lda, ipiv, B, ldb); } }; diff --git a/source/source_base/module_container/ATen/kernels/lapack.h b/source/source_base/module_container/ATen/kernels/lapack.h index cf164dec10..c5e252918d 100644 --- a/source/source_base/module_container/ATen/kernels/lapack.h +++ b/source/source_base/module_container/ATen/kernels/lapack.h @@ -96,7 +96,7 @@ struct lapack_getrs { const int& nrhs, T* A, const int& lda, - const int* ipiv, + int* ipiv, T* B, const int& ldb); }; diff --git a/source/source_base/module_container/base/third_party/blas.h b/source/source_base/module_container/base/third_party/blas.h index 5c73032e05..216608aecb 100644 --- a/source/source_base/module_container/base/third_party/blas.h +++ b/source/source_base/module_container/base/third_party/blas.h @@ -2,6 +2,7 @@ #define BASE_THIRD_PARTY_BLAS_H_ #include +#include "source_base/blas_connector.h" #if 
defined(__CUDA) #include @@ -9,347 +10,5 @@ #include #endif -extern "C" -{ -// level 1: std::vector-std::vector operations, O(n) data and O(n) work. - -// Peize Lin add ?scal 2016-08-04, to compute x=a*x -void sscal_(const int *N, const float *alpha, float *x, const int *incx); -void dscal_(const int *N, const double *alpha, double *x, const int *incx); -void cscal_(const int *N, const std::complex *alpha, std::complex *x, const int *incx); -void zscal_(const int *N, const std::complex *alpha, std::complex *x, const int *incx); - -// Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y -void saxpy_(const int *N, const float *alpha, const float *x, const int *incx, float *y, const int *incy); -void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy); -void caxpy_(const int *N, const std::complex *alpha, const std::complex *x, const int *incx, std::complex *y, const int *incy); -void zaxpy_(const int *N, const std::complex *alpha, const std::complex *x, const int *incx, std::complex *y, const int *incy); - -void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); -void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); - -//reason for passing results as argument instead of returning it: -//see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/ -void cdotc_(const int *n, const std::complex *zx, const int *incx, - const std::complex *zy, const int *incy, std::complex *result); -void zdotc_(const int *n, const std::complex *zx, const int *incx, - const std::complex *zy, const int *incy, std::complex *result); -// Peize Lin add ?dot 2017-10-27, to compute d=x*y -float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy); -double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy); - -// Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 
} -float snrm2_( const int *n, const float *x, const int *incx ); -double dnrm2_( const int *n, const double *x, const int *incx ); -double dznrm2_( const int *n, const std::complex *x, const int *incx ); - -// level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work. -void sgemv_(const char*const transa, const int*const m, const int*const n, - const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx, - const float*const eta, float*const y, const int*const incy); -void dgemv_(const char*const transa, const int*const m, const int*const n, - const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx, - const double*const beta, double*const y, const int*const incy); - -void cgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, - const std::complex *a, const int *lda, const std::complex *x, const int *incx, - const std::complex *beta, std::complex *y, const int *incy); - -void zgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, - const std::complex *a, const int *lda, const std::complex *x, const int *incx, - const std::complex *beta, std::complex *y, const int *incy); - -void dsymv_(const char *uplo, const int *n, - const double *alpha, const double *a, const int *lda, - const double *x, const int *incx, - const double *beta, double *y, const int *incy); - -// A := alpha x * y.T + A -void dger_(const int* m, - const int* n, - const double* alpha, - const double* x, - const int* incx, - const double* y, - const int* incy, - double* a, - const int* lda); -void zgerc_(const int* m, - const int* n, - const std::complex* alpha, - const std::complex* x, - const int* incx, - const std::complex* y, - const int* incy, - std::complex* a, - const int* lda); - -// level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. - -// Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? 
+ b * C -// A is general -void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); -void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); -void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); -void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - - -//a is symmetric -void dsymm_(const char *side, const char *uplo, const int *m, const int *n, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); -//a is hermitian -void zhemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, - std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); - -//solving triangular matrix with multiple right hand sides -void dtrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n, - double* alpha, double* a, int *lda, double*b, int *ldb); -void ztrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n, - std::complex* alpha, std::complex* a, int *lda, std::complex*b, int *ldb); - -} - -namespace container { - -// Class BlasConnector provide the connector to fortran lapack routine. -// The entire function in this class are static and inline function. 
-// Usage example: BlasConnector::functionname(parameter list). -namespace BlasConnector { - -static inline -void axpy( const int& n, const float& alpha, const float *x, const int& incx, float *y, const int& incy) -{ - saxpy_(&n, &alpha, x, &incx, y, &incy); -} -static inline -void axpy( const int& n, const double& alpha, const double *x, const int& incx, double *y, const int& incy) -{ - daxpy_(&n, &alpha, x, &incx, y, &incy); -} -static inline -void axpy( const int& n, const std::complex& alpha, const std::complex *x, const int& incx, std::complex *y, const int& incy) -{ - caxpy_(&n, &alpha, x, &incx, y, &incy); -} -static inline -void axpy( const int& n, const std::complex& alpha, const std::complex *x, const int& incx, std::complex *y, const int& incy) -{ - zaxpy_(&n, &alpha, x, &incx, y, &incy); -} - -// Peize Lin add 2016-08-04 -// x=a*x -static inline -void scal( const int& n, const float& alpha, float *x, const int& incx) -{ - sscal_(&n, &alpha, x, &incx); -} -static inline -void scal( const int& n, const double& alpha, double *x, const int& incx) -{ - dscal_(&n, &alpha, x, &incx); -} -static inline -void scal( const int& n, const std::complex& alpha, std::complex *x, const int& incx) -{ - cscal_(&n, &alpha, x, &incx); -} -static inline -void scal( const int& n, const std::complex& alpha, std::complex *x, const int& incx) -{ - zscal_(&n, &alpha, x, &incx); -} - -// Peize Lin add 2017-10-27 -// d=x*y -static inline -float dot( const int& n, const float *x, const int& incx, const float *y, const int& incy) -{ - return sdot_(&n, x, &incx, y, &incy); -} -static inline -double dot( const int& n, const double *x, const int& incx, const double *y, const int& incy) -{ - return ddot_(&n, x, &incx, y, &incy); -} -// Denghui Lu add 2023-8-01 -static inline -std::complex dot(const int& n, const std::complex *x, const int& incx, const std::complex *y, const int& incy) -{ - std::complex result = {0, 0}; - // cdotc_(&n, x, &incx, y, &incy, &result); - for (int ii = 0; ii < 
n; ii++) { - result += std::conj(x[ii * incx]) * y[ii * incy]; - } - return result; -} -static inline -std::complex dot(const int& n, const std::complex *x, const int& incx, const std::complex *y, const int& incy) -{ - std::complex result = {0, 0}; - // zdotc_(&n, x, &incx, y, &incy, &result); - for (int ii = 0; ii < n; ii++) { - result += std::conj(x[ii * incx]) * y[ii * incy]; - } - return result; -} - -// Peize Lin add 2017-10-27, fix bug trans 2019-01-17 -// C = a * A.? * B.? + b * C -static inline -void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k, - const float& alpha, const float* A, const int& lda, const float* B, const int& ldb, - const float& beta, float* C, const int& ldc) -{ - sgemm_(&transa, &transb, &m, &n, &k, - &alpha, A, &lda, B, &ldb, - &beta, C, &ldc); -} -static inline -void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k, - const double& alpha, const double* A, const int& lda, const double* B, const int& ldb, - const double& beta, double* C, const int& ldc) -{ - dgemm_(&transa, &transb, &m, &n, &k, - &alpha, A, &lda, B, &ldb, - &beta, C, &ldc); -} -static inline -void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k, - const std::complex& alpha, const std::complex* A, const int& lda, const std::complex* B, const int& ldb, - const std::complex& beta, std::complex* C, const int& ldc) -{ - cgemm_(&transa, &transb, &m, &n, &k, - &alpha, A, &lda, B, &ldb, - &beta, C, &ldc); -} -static inline -void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k, - const std::complex& alpha, const std::complex* A, const int& lda, const std::complex* B, const int& ldb, - const std::complex& beta, std::complex* C, const int& ldc) -{ - zgemm_(&transa, &transb, &m, &n, &k, - &alpha, A, &lda, B, &ldb, - &beta, C, &ldc); -} - -template -static inline -void gemm_batched(const char& transa, const char& transb, const int& m, 
const int& n, const int& k, - const T& alpha, T** A, const int& lda, T** B, const int& ldb, - const T& beta, T** C, const int& ldc, const int& batch_size) -{ - for (int ii = 0; ii < batch_size; ++ii) { - // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] - BlasConnector::gemm(transa, transb, m, n, k, alpha, A[ii], lda, B[ii], ldb, beta, C[ii], ldc); - } -} - -template -static inline -void gemm_batched_strided(const char& transa, const char& transb, const int& m, const int& n, const int& k, - const T& alpha, const T* A, const int& lda, const int& stride_a, const T* B, const int& ldb, const int& stride_b, - const T& beta, T* C, const int& ldc, const int& stride_c, const int& batch_size) -{ - for (int ii = 0; ii < batch_size; ii++) { - // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] - BlasConnector::gemm(transa, transb, m, n, k, alpha, A + ii * stride_a, lda, B + ii * stride_b, ldb, beta, C + ii * stride_c, ldc); - } -} - -static inline -void gemv(const char& trans, const int& m, const int& n, - const float& alpha, const float *A, const int& lda, const float *x, const int& incx, - const float& beta, float *y, const int& incy) -{ - sgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy); -} -static inline -void gemv(const char& trans, const int& m, const int& n, - const double& alpha, const double *A, const int& lda, const double *x, const int& incx, - const double& beta, double *y, const int& incy) -{ - dgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy); -} -static inline -void gemv(const char& trans, const int& m, const int& n, - const std::complex& alpha, const std::complex *A, const int& lda, const std::complex *x, const int& incx, - const std::complex& beta, std::complex *y, const int& incy) -{ - cgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy); -} -static inline -void gemv(const char& trans, const int& m, const int& n, - const std::complex& alpha, const std::complex *A, const 
int& lda, const std::complex *x, const int& incx, - const std::complex& beta, std::complex *y, const int& incy) -{ - zgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy); -} - -template -static inline -void gemv_batched(const char& trans, const int& m, const int& n, - const T& alpha, T** A, const int& lda, T** x, const int& incx, - const T& beta, T** y, const int& incy, const int& batch_size) -{ - for (int ii = 0; ii < batch_size; ++ii) { - // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] - BlasConnector::gemv(trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy); - } -} - -template -static inline -void gemv_batched_strided(const char& transa, const int& m, const int& n, - const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x, - const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size) -{ - for (int ii = 0; ii < batch_size; ii++) { - // Call the single GEMV for each pair of matrix A[ii] and vector x[ii] - BlasConnector::gemv(transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy); - } -} - -// Peize Lin add 2018-06-12 -// out = ||x||_2 -static inline -float nrm2( const int n, const float *x, const int incx ) -{ - return snrm2_( &n, x, &incx ); -} -static inline -double nrm2( const int n, const double *x, const int incx ) -{ - return dnrm2_( &n, x, &incx ); -} -static inline -double nrm2( const int n, const std::complex *x, const int incx ) -{ - return dznrm2_( &n, x, &incx ); -} - -// copies a into b -static inline -void copy(const long n, const double *a, const int incx, double *b, const int incy) -{ - dcopy_(&n, a, &incx, b, &incy); -} -static inline -void copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy) -{ - zcopy_(&n, a, &incx, b, &incy); -} - -} // namespace BlasConnector -} // namespace container #endif // BASE_THIRD_PARTY_BLAS_H_ diff --git 
a/source/source_base/module_container/base/third_party/lapack.h b/source/source_base/module_container/base/third_party/lapack.h index 7452dc1835..8ae1ecd350 100644 --- a/source/source_base/module_container/base/third_party/lapack.h +++ b/source/source_base/module_container/base/third_party/lapack.h @@ -2,7 +2,8 @@ #define BASE_THIRD_PARTY_LAPACK_H_ #include - +#include "source_base/macros.h" +#include "source_base/lapack_connector.h" #if defined(__CUDA) #include @@ -10,392 +11,79 @@ #include #endif -//Naming convention of lapack subroutines : ammxxx, where -//"a" specifies the data type: -// - d stands for double -// - z stands for complex double -//"mm" specifies the type of matrix, for example: -// - he stands for hermitian -// - sy stands for symmetric -//"xxx" specifies the type of problem, for example: -// - gv stands for generalized eigenvalue - -extern "C" -{ -int ilaenv_(int* ispec,const char* name,const char* opts, - const int* n1,const int* n2,const int* n3,const int* n4); - - -// solve the generalized eigenproblem Ax=eBx, where A is Hermitian and complex couble -// zhegv_ & zhegvd_ returns all eigenvalues while zhegvx_ returns selected ones -void ssygvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - float* a, const int* lda, - const float* b, const int* ldb, float* w, - float* work, int* lwork, - int* iwork, int* liwork, int* info); - -void dsygvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - double* a, const int* lda, - const double* b, const int* ldb, double* w, - double* work, int* lwork, - int* iwork, int* liwork, int* info); - -void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* lda, - const std::complex* b, const int* ldb, float* w, - std::complex* work, int* lwork, float* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - -void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* 
lda, - const std::complex* b, const int* ldb, double* w, - std::complex* work, int* lwork, double* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - -void ssyevx_(const char* jobz, const char* range, const char* uplo, const int* n, - float *a, const int* lda, - const float* vl, const float* vu, const int* il, const int* iu, const float* abstol, - const int* m, float* w, float *z, const int *ldz, - float *work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info); -void dsyevx_(const char* jobz, const char* range, const char* uplo, const int* n, - double *a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, double *z, const int *ldz, - double *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); -void cheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const float* vl, const float* vu, const int* il, const int* iu, const float* abstol, - const int* m, float* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info); -void zheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); - -void ssyevd_(const char *jobz, const char *uplo, const int *n, - float *a, const int *lda, float *w, - float *work, int *lwork, - int *iwork, int *liwork, int *info); -void dsyevd_(const char *jobz, const char *uplo, const int *n, - double *a, const int *lda, double *w, - double *work, int *lwork, - int *iwork, int *liwork, int *info); -void cheevd_(const char *jobz, const char *uplo, const int *n, - std::complex *a, const int *lda, float *w, - 
std::complex *work, int *lwork, float *rwork, int *lrwork, - int *iwork, int *liwork, int *info); -void zheevd_(const char *jobz, const char *uplo, const int *n, - std::complex *a, const int *lda, double *w, - std::complex *work, int *lwork, double *rwork, int *lrwork, - int *iwork, int *liwork, int *info); - -void spotrf_(const char*const uplo, const int*const n, float*const A, const int*const lda, int*const info); -void dpotrf_(const char*const uplo, const int*const n, double*const A, const int*const lda, int*const info); -void cpotrf_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); -void zpotrf_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - -void spotri_(const char*const uplo, const int*const n, float*const A, const int*const lda, int*const info); -void dpotri_(const char*const uplo, const int*const n, double*const A, const int*const lda, int*const info); -void cpotri_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); -void zpotri_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - -void strtri_(const char* uplo, const char* diag, const int* n, float* a, const int* lda, int* info); -void dtrtri_(const char* uplo, const char* diag, const int* n, double* a, const int* lda, int* info); -void ctrtri_(const char* uplo, const char* diag, const int* n, std::complex* a, const int* lda, int* info); -void ztrtri_(const char* uplo, const char* diag, const int* n, std::complex* a, const int* lda, int* info); - -void sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv, int* info); -void dgetrf_(const int* m, const int* n, double* a, const int* lda, int* ipiv, int* info); -void cgetrf_(const int* m, const int* n, std::complex* a, const int* lda, int* ipiv, int* info); -void zgetrf_(const int* m, const int* n, std::complex* a, const int* lda, 
int* ipiv, int* info); - -void sgetri_(const int* n, float* A, const int* lda, const int* ipiv, float* work, const int* lwork, int* info); -void dgetri_(const int* n, double* A, const int* lda, const int* ipiv, double* work, const int* lwork, int* info); -void cgetri_(const int* n, std::complex* A, const int* lda, const int* ipiv, std::complex* work, const int* lwork, int* info); -void zgetri_(const int* n, std::complex* A, const int* lda, const int* ipiv, std::complex* work, const int* lwork, int* info); - -void sgetrs_(const char* trans, const int* n, const int* nrhs, const float* A, const int* lda, const int* ipiv, float* B, const int* ldb, int* info); -void dgetrs_(const char* trans, const int* n, const int* nrhs, const double* A, const int* lda, const int* ipiv, double* B, const int* ldb, int* info); -void cgetrs_(const char* trans, const int* n, const int* nrhs, const std::complex* A, const int* lda, const int* ipiv, std::complex* B, const int* ldb, int* info); -void zgetrs_(const char* trans, const int* n, const int* nrhs, const std::complex* A, const int* lda, const int* ipiv, std::complex* B, const int* ldb, int* info); -} - // Class LapackConnector provide the connector to fortran lapack routine. // The entire function in this class are static and inline function. // Usage example: LapackConnector::functionname(parameter list). namespace container { namespace lapackConnector { -static inline -int ilaenv( int ispec, const char *name,const char *opts,const int n1,const int n2, - const int n3,const int n4) -{ - const int nb = ilaenv_(&ispec, name, opts, &n1, &n2, &n3, &n4); - return nb; -} + +template +using Real = typename GetTypeReal::type; // wrap function of fortran lapack routine zhegvd. 
(pointer version) -static inline -void dngvd(const int itype, const char jobz, const char uplo, const int n, - float* a, const int lda, - const float* b, const int ldb, float* w, - float* work, int lwork, float* rwork, int lrwork, - int* iwork, int liwork, int info) +template +void dngvd(int itype, char jobz, char uplo, int n, + T* a, int lda, + T* b, int ldb, Real* w) { - // call the fortran routine - ssygvd_(&itype, &jobz, &uplo, &n, - a, &lda, b, &ldb, w, - work, &lwork, - iwork, &liwork, &info); -} -// wrap function of fortran lapack routine zhegvd. -static inline -void dngvd(const int itype, const char jobz, const char uplo, const int n, - double* a, const int lda, - const double* b, const int ldb, double* w, - double* work, int lwork, double* rwork, int lrwork, - int* iwork, int liwork, int info) -{ - // call the fortran routine - dsygvd_(&itype, &jobz, &uplo, &n, - a, &lda, b, &ldb, w, - work, &lwork, - iwork, &liwork, &info); -} -static inline -void dngvd(const int itype, const char jobz, const char uplo, const int n, - std::complex* a, const int lda, - const std::complex* b, const int ldb, float* w, - std::complex* work, int lwork, float* rwork, int lrwork, - int* iwork, int liwork, int info) -{ - // call the fortran routine - chegvd_(&itype, &jobz, &uplo, &n, - a, &lda, b, &ldb, w, - work, &lwork, rwork, &lrwork, - iwork, &liwork, &info); -} -// wrap function of fortran lapack routine zhegvd. -static inline -void dngvd(const int itype, const char jobz, const char uplo, const int n, - std::complex* a, const int lda, - const std::complex* b, const int ldb, double* w, - std::complex* work, int lwork, double* rwork, int lrwork, - int* iwork, int liwork, int info) -{ - // call the fortran routine - zhegvd_(&itype, &jobz, &uplo, &n, - a, &lda, b, &ldb, w, - work, &lwork, rwork, &lrwork, - iwork, &liwork, &info); + LapackConnector::hegvd( + LapackConnector::ColMajor, itype, jobz, uplo, n, + a, lda, b, ldb, w); } // wrap function of fortran lapack routine zheevx. 
-static inline -void dnevx( const int itype, const char jobz, const char range, const char uplo, const int n, - float* a, const int lda, - const float vl, const float vu, const int il, const int iu, const float abstol, - const int m, float* w, float* z, const int ldz, - float* work, const int lwork, float* rwork, int* iwork, int* ifail, int info) -{ - ssyevx_(&jobz, &range, &uplo, &n, - a, &lda, &vl, &vu, &il, &iu, - &abstol, &m, w, z, &ldz, - work, &lwork, rwork, iwork, ifail, &info); -} -// wrap function of fortran lapack routine zheevx. -static inline -void dnevx( const int itype, const char jobz, const char range, const char uplo, const int n, - double* a, const int lda, - const double vl, const double vu, const int il, const int iu, const double abstol, - const int m, double* w, double* z, const int ldz, - double* work, const int lwork, double* rwork, int* iwork, int* ifail, int info) +template +void dnevx( int itype, char jobz, char range, char uplo, int n, + T* a, int lda, + Real vl, Real vu, int il, int iu, Real abstol, + int m, Real* w, T* z, int ldz, int* ifail) { - dsyevx_(&jobz, &range, &uplo, &n, - a, &lda, &vl, &vu, &il, &iu, - &abstol, &m, w, z, &ldz, - work, &lwork, rwork, iwork, ifail, &info); -} -static inline -void dnevx( const int itype, const char jobz, const char range, const char uplo, const int n, - std::complex* a, const int lda, - const float vl, const float vu, const int il, const int iu, const float abstol, - const int m, float* w, std::complex* z, const int ldz, - std::complex* work, const int lwork, float* rwork, int* iwork, int* ifail, int info) -{ - cheevx_(&jobz, &range, &uplo, &n, - a, &lda, &vl, &vu, &il, &iu, - &abstol, &m, w, z, &ldz, - work, &lwork, rwork, iwork, ifail, &info); -} -// wrap function of fortran lapack routine zheevx. 
-static inline -void dnevx( const int itype, const char jobz, const char range, const char uplo, const int n, - std::complex* a, const int lda, - const double vl, const double vu, const int il, const int iu, const double abstol, - const int m, double* w, std::complex* z, const int ldz, - std::complex* work, const int lwork, double* rwork, int* iwork, int* ifail, int info) -{ - zheevx_(&jobz, &range, &uplo, &n, - a, &lda, &vl, &vu, &il, &iu, - &abstol, &m, w, z, &ldz, - work, &lwork, rwork, iwork, ifail, &info); + LapackConnector::heevx( + LapackConnector::ColMajor, jobz, range, uplo, n, + a, lda, vl, vu, il, iu, + abstol, m, w, z, ldz, ifail); } -static inline -void dnevd(const char jobz, const char uplo, const int n, - float* a, const int lda, float* w, - float* work, int lwork, float* rwork, int lrwork, - int* iwork, int liwork, int& info) -{ - // call the fortran routine - ssyevd_( &jobz, &uplo, &n, - a, &lda, w, - work, &lwork, - iwork, &liwork, &info); -} -// wrap function of fortran lapack routine zhegvd. -static inline -void dnevd(const char jobz, const char uplo, const int n, - double* a, const int lda, double* w, - double* work, int lwork, double* rwork, int lrwork, - int* iwork, int liwork, int& info) +template +void dnevd(char jobz, char uplo, int n, + T* a, int lda, Real* w) { - // call the fortran routine - dsyevd_( &jobz, &uplo, &n, - a, &lda, w, - work, &lwork, - iwork, &liwork, &info); -} -static inline -void dnevd(const char jobz, const char uplo, const int n, - std::complex* a, const int lda, float* w, - std::complex* work, int lwork, float* rwork, int lrwork, - int* iwork, int liwork, int& info) -{ - // call the fortran routine - cheevd_( &jobz, &uplo, &n, - a, &lda, w, - work, &lwork, rwork, &lrwork, - iwork, &liwork, &info); -} -// wrap function of fortran lapack routine zhegvd. 
-static inline -void dnevd(const char jobz, const char uplo, const int n, - std::complex* a, const int lda, double* w, - std::complex* work, int lwork, double* rwork, int lrwork, - int* iwork, int liwork, int& info) -{ - // call the fortran routine - zheevd_( &jobz, &uplo, &n, - a, &lda, w, - work, &lwork, rwork, &lrwork, - iwork, &liwork, &info); + LapackConnector::heevd( + LapackConnector::ColMajor, + jobz, uplo, n, + a, lda, w); } -static inline -void potrf( const char &uplo, const int &n, float* A, const int &lda, int &info ) +template +void potrf(char uplo, int n, T* A, int lda) { - spotrf_(&uplo, &n, A, &lda, &info ); + LapackConnector::potrf(LapackConnector::ColMajor, uplo, n, A, lda); } -static inline -void potrf( const char &uplo, const int &n, double* A, const int &lda, int &info ) -{ - dpotrf_(&uplo, &n, A, &lda, &info ); -} -static inline -void potrf( const char &uplo, const int &n, std::complex* A, const int &lda, int &info ) -{ - cpotrf_(&uplo, &n, A, &lda, &info ); -} -static inline -void potrf( const char &uplo, const int &n, std::complex* A, const int &lda, int &info ) -{ - zpotrf_( &uplo, &n, A, &lda, &info ); -} -static inline -void trtri( const char &uplo, const char &diag, const int &n, float* A, const int &lda, int &info ) -{ - strtri_( &uplo, &diag, &n, A, &lda, &info); -} -static inline -void trtri( const char &uplo, const char &diag, const int &n, double* A, const int &lda, int &info) -{ - dtrtri_( &uplo, &diag, &n, A, &lda, &info); -} -static inline -void trtri( const char &uplo, const char &diag, const int &n, std::complex* A, const int &lda, int &info ) -{ - ctrtri_( &uplo, &diag, &n, A, &lda, &info); -} -static inline -void trtri( const char &uplo, const char &diag, const int &n, std::complex* A, const int &lda, int &info) -{ - ztrtri_( &uplo, &diag, &n, A, &lda, &info); -} -static inline -void getrf(const int m, const int n, float* A, const int lda, int* ipiv, int &info) +template +void trtri(char uplo, char diag, int n, T* A, int lda) 
{ - sgetrf_(&m, &n, A, &lda, ipiv, &info); -} -static inline -void getrf(const int m, const int n, double* A, const int lda, int* ipiv, int &info) -{ - dgetrf_(&m, &n, A, &lda, ipiv, &info); -} -static inline -void getrf(const int m, const int n, std::complex* A, const int lda, int* ipiv, int &info) -{ - cgetrf_(&m, &n, A, &lda, ipiv, &info); -} -static inline -void getrf(const int m, const int n, std::complex* A, const int lda, int* ipiv, int &info) -{ - zgetrf_(&m, &n, A, &lda, ipiv, &info); + LapackConnector::trtri(LapackConnector::ColMajor, uplo, diag, n, A, lda); } -static inline -void getri(const int n, float* A, const int lda, const int* ipiv, float* work, const int lwork, int& info) -{ - sgetri_(&n, A, &lda, ipiv, work, &lwork, &info); -} -static inline -void getri(const int n, double* A, const int lda, const int* ipiv, double* work, const int lwork, int& info) -{ - dgetri_(&n, A, &lda, ipiv, work, &lwork, &info); -} -static inline -void getri(const int n, std::complex* A, const int lda, const int* ipiv, std::complex* work, const int lwork, int& info) -{ - cgetri_(&n, A, &lda, ipiv, work, &lwork, &info); -} -static inline -void getri(const int n, std::complex* A, const int lda, const int* ipiv, std::complex* work, const int lwork, int& info) -{ - zgetri_(&n, A, &lda, ipiv, work, &lwork, &info); -} -static inline -void getrs(const char& trans, const int n, const int nrhs, float* A, const int lda, const int* ipiv, float* B, const int ldb, int& info) +template +void getrf(int m, int n, T* A, int lda, int* ipiv) { - sgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info); + LapackConnector::getrf(LapackConnector::ColMajor, m, n, A, lda, ipiv); } -static inline -void getrs(const char& trans, const int n, const int nrhs, double* A, const int lda, const int* ipiv, double* B, const int ldb, int& info) -{ - dgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info); -} -static inline -void getrs(const char& trans, const int n, const int nrhs, std::complex* A, const 
int lda, const int* ipiv, std::complex* B, const int ldb, int& info) + +template +void getri(int n, T* A, int lda, int* ipiv) { - cgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info); + LapackConnector::getri(LapackConnector::ColMajor, n, A, lda, ipiv); } -static inline -void getrs(const char& trans, const int n, const int nrhs, std::complex* A, const int lda, const int* ipiv, std::complex* B, const int ldb, int& info) + +template +void getrs(char trans, int n, int nrhs, T* A, int lda, int* ipiv, T* B, int ldb) { - zgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info); + LapackConnector::getrs(LapackConnector::ColMajor, trans, n, nrhs, A, lda, ipiv, B, ldb); } } // namespace lapackConnector diff --git a/source/source_base/module_external/blas_connector.h b/source/source_base/module_external/blas_connector.h index 921f94ddb9..fee0995bd2 100644 --- a/source/source_base/module_external/blas_connector.h +++ b/source/source_base/module_external/blas_connector.h @@ -2,144 +2,10 @@ #define BLAS_CONNECTOR_H #include +#include #include "source_base/module_device/types.h" #include "../macros.h" -// These still need to be linked in the header file -// Because quite a lot of code will directly use the original cblas kernels. - -extern "C" -{ - // level 1: std::vector-std::vector operations, O(n) data and O(n) work. 
- - // Peize Lin add ?scal 2016-08-04, to compute x=a*x - void sscal_(const int *N, const float *alpha, float *X, const int *incX); - void dscal_(const int *N, const double *alpha, double *X, const int *incX); - void cscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); - void zscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); - - // Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y - void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY); - void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY); - void caxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); - void zaxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); - - void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); - void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); - - //reason for passing results as argument instead of returning it: - //see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/ - // void zdotc_(std::complex *result, const int *n, const std::complex *zx, - // const int *incx, const std::complex *zy, const int *incy); - // Peize Lin add ?dot 2017-10-27, to compute d=x*y - float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY); - double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY); - - // Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 } - float snrm2_( const int *n, const float *X, const int *incX ); - double dnrm2_( const int *n, const double *X, const int *incX ); - double dznrm2_( const int *n, const std::complex *X, const int *incX ); - - // symmetric rank-k update - void dsyrk_( - 
const char* uplo, - const char* trans, - const int* n, - const int* k, - const double* alpha, - const double* a, - const int* lda, - const double* beta, - double* c, - const int* ldc - ); - - // level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work. - void sgemv_(const char*const transa, const int*const m, const int*const n, - const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx, - const float*const beta, float*const y, const int*const incy); - void dgemv_(const char*const transa, const int*const m, const int*const n, - const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx, - const double*const beta, double*const y, const int*const incy); - - void cgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, - const std::complex *a, const int *lda, const std::complex *x, const int *incx, - const std::complex *beta, std::complex *y, const int *incy); - - void zgemv_(const char *trans, const int *m, const int *n, const std::complex *alpha, - const std::complex *a, const int *lda, const std::complex *x, const int *incx, - const std::complex *beta, std::complex *y, const int *incy); - - void dsymv_(const char *uplo, const int *n, - const double *alpha, const double *a, const int *lda, - const double *x, const int *incx, - const double *beta, double *y, const int *incy); - - // A := alpha x * y.T + A - void dger_(const int* m, - const int* n, - const double* alpha, - const double* x, - const int* incx, - const double* y, - const int* incy, - double* a, - const int* lda); - void zgerc_(const int* m, - const int* n, - const std::complex* alpha, - const std::complex* x, - const int* incx, - const std::complex* y, - const int* incy, - std::complex* a, - const int* lda); - - // level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. - - // Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? 
+ b * C - // A is general - void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); - void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); - void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - - // A is symmetric. C = a * A.? * B.? 
+ b * C - void ssymm_(const char *side, const char *uplo, const int *m, const int *n, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); - void dsymm_(const char *side, const char *uplo, const int *m, const int *n, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); - void csymm_(const char *side, const char *uplo, const int *m, const int *n, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - void zsymm_(const char *side, const char *uplo, const int *m, const int *n, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - - // A is hermitian. C = a * A.? * B.? + b * C - void chemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, - std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); - void zhemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, - std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); - - //solving triangular matrix with multiple right hand sides - void dtrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n, - double* alpha, double* a, int *lda, double*b, int *ldb); - void ztrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n, - std::complex* alpha, std::complex* a, int *lda, std::complex*b, int *ldb); - -} - // Class BlasConnector provide the connector to fortran lapack routine. // The entire function in this class are static and inline function. // Usage example: BlasConnector::functionname(parameter list). 
@@ -211,6 +77,10 @@ class BlasConnector static std::complex dotc( const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static + void ger_cm(int m, int n, double alpha, const double* x, + int incx, const double* y, const int incy, double a, int lda, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + // Peize Lin add 2017-10-27, fix bug trans 2019-01-17 // C = a * A.? * B.? + b * C // Row Major by default @@ -302,25 +172,34 @@ class BlasConnector void hemm_cm(char side, char uplo, int m, int n, std::complex alpha, std::complex *a, int lda, std::complex *b, int ldb, std::complex beta, std::complex *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + + static + void herk(char uplo, char trans, int n, int k, float alpha, const std::complex *A, int lda, float beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static + void herk(char uplo, char trans, int n, int k, double alpha, const std::complex *A, int lda, double beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + static + void syrk(char uplo, char trans, int n, int k, + double alpha, const double* a, int lda, double beta, double* c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); + // y = A*x + beta*y static - void gemv(const char trans, const int m, const int n, + void gemv_cm(const char trans, const int m, const int n, const float alpha, const float* A, const int lda, const float* X, const int incx, const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); static - void gemv(const char trans, const int m, const int n, + void gemv_cm(const char trans, const int m, 
const int n, const double alpha, const double* A, const int lda, const double* X, const int incx, const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); static - void gemv(const char trans, const int m, const int n, + void gemv_cm(const char trans, const int m, const int n, const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); static - void gemv(const char trans, const int m, const int n, + void gemv_cm(const char trans, const int m, const int n, const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); @@ -366,6 +245,34 @@ class BlasConnector void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const double constant1, const std::complex *vector2, const double constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); }; +namespace BlasUtils { + // Map a BLAS uplo character ('U'/'u' or 'L'/'l') to the CBLAS enum; any other value throws std::invalid_argument. + static inline + CBLAS_UPLO toCblasUplo(const char uplo) + { + if (uplo == 'U' || uplo == 'u') + return CblasUpper; + else if (uplo == 'L' || uplo == 'l') + return CblasLower; + else + throw std::invalid_argument("Invalid uplo argument"); + } + + // Map a BLAS trans character to the CBLAS enum: 'N'/'n' -> CblasNoTrans, 'T'/'t' -> CblasTrans, 'C'/'c' -> CblasConjTrans (the conjugate-transpose option used by herk); any other value throws std::invalid_argument. + static inline + CBLAS_TRANSPOSE toCblasTrans(const char trans) + { + if (trans == 'N' || trans == 'n') + return CblasNoTrans; + else if (trans == 'T' || trans == 't') + return CblasTrans; + else if (trans == 'C' || trans == 'c') + return CblasConjTrans; + else + throw std::invalid_argument("Invalid trans argument"); + } +} + #ifdef __CUDA #include @@ -387,7 +294,6 @@ namespace BlasUtils{ cublasSideMode_t judge_side(const char& trans); // Translate a normal side parameter to a cublas/hipblas type.
cublasFillMode_t judge_fill(const char& trans); // Translate a normal fill parameter to a cublas/hipblas type. - } #endif diff --git a/source/source_base/module_external/blas_connector_matrix.cpp b/source/source_base/module_external/blas_connector_matrix.cpp index de1e839ad0..2c749552b1 100644 --- a/source/source_base/module_external/blas_connector_matrix.cpp +++ b/source/source_base/module_external/blas_connector_matrix.cpp @@ -1,3 +1,7 @@ +/* level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. + * This file contains the implementation of the BLAS level 3 operations. + * These operations include matrix-matrix multiplication and related operations. + */ #include "blas_connector.h" #include "../macros.h" @@ -14,6 +18,59 @@ #include "source_base/module_device/memory_op.h" #endif +extern "C" +{ + // level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. + + // Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? + b * C + // A is general + void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, + const float *beta, float *c, const int *ldc); + void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, + const double *beta, double *c, const int *ldc); + void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + + // A is symmetric. 
C = a * A.? * B.? + b * C + void ssymm_(const char *side, const char *uplo, const int *m, const int *n, + const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, + const float *beta, float *c, const int *ldc); + void dsymm_(const char *side, const char *uplo, const int *m, const int *n, + const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, + const double *beta, double *c, const int *ldc); + void csymm_(const char *side, const char *uplo, const int *m, const int *n, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + void zsymm_(const char *side, const char *uplo, const int *m, const int *n, + const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, + const std::complex *beta, std::complex *c, const int *ldc); + + // A is hermitian. C = a * A.? * B.? + b * C + void chemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, + std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); + void zhemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, + std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); + + // symmetric rank-k update + void dsyrk_( + const char* uplo, + const char* trans, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* beta, + double* c, + const int* ldc + ); +} // C = a * A.? * B.? 
+ b * C // Row-Major part @@ -498,78 +555,44 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n, } } -void BlasConnector::gemv(const char trans, const int m, const int n, - const float alpha, const float* A, const int lda, const float* X, const int incx, - const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); - cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemv(const char trans, const int m, const int n, - const double alpha, const double* A, const int lda, const double* X, const int incx, - const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type) +void BlasConnector::syrk(char uplo, char trans, int n, int k, + double alpha, const double* a, int lda, double beta, double* c, int ldc, base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); - cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy)); - } -#endif - else { + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dsyrk_(&uplo, &trans, &n, &k, &alpha, a, &lda, &beta, c, &ldc); + } + else { throw 
std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); } } -void BlasConnector::gemv(const char trans, const int m, const int n, - const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, - const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type) +void BlasConnector::herk(char uplo, char trans, int n, int k, float alpha, + const std::complex *A, int lda, float beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag()); - cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag()); - cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); - cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy)); - } -#endif - else { + auto cblas_uplo = BlasUtils::toCblasUplo(uplo); + auto cblas_trans = BlasUtils::toCblasTrans(trans); + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cblas_cherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); + } + else + { throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } + } } -void BlasConnector::gemv(const char trans, const int m, const int n, - const std::complex alpha, const std::complex *A, const int lda, const std::complex *X, const int incx, - const std::complex beta, std::complex *Y, const int incy, base_device::AbacusDevice_t device_type) +void BlasConnector::herk(char uplo, 
char trans, int n, int k, double alpha, + const std::complex *A, int lda, double beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag()); - cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag()); - cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op"); - cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} \ No newline at end of file + auto cblas_uplo = BlasUtils::toCblasUplo(uplo); + auto cblas_trans = BlasUtils::toCblasTrans(trans); + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cblas_zherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); + } + else + { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} \ No newline at end of file diff --git a/source/source_base/module_external/blas_connector_vector.cpp b/source/source_base/module_external/blas_connector_vector.cpp index b5e0972946..f9c5925143 100644 --- a/source/source_base/module_external/blas_connector_vector.cpp +++ b/source/source_base/module_external/blas_connector_vector.cpp @@ -1,6 +1,11 @@ +/* level 1: vector-vector operations, O(n) data and O(n) work. + * This file contains the implementation of the BLAS level 1 operations.
+ * These operations include vector scaling, vector addition, vector dot product, and vector norm calculations. + */ #include "blas_connector.h" #include "../macros.h" +#include #ifdef __DSP #include "source_base/kernels/dsp/dsp_connector.h" #include "source_base/global_variable.h" @@ -14,15 +19,47 @@ #include "source_base/module_device/memory_op.h" #endif +extern "C" +{ + // level 1: vector-vector operations, O(n) data and O(n) work. + // Peize Lin add ?scal 2016-08-04, to compute x=a*x + void sscal_(const int *N, const float *alpha, float *X, const int *incX); + void dscal_(const int *N, const double *alpha, double *X, const int *incX); + void cscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); + void zscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); + + // Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y + void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY); + void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY); + void caxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); + void zaxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); + + void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); + void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); + + //reason for passing results as argument instead of returning it: + //see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/ + // void zdotc_(std::complex *result, const int *n, const std::complex *zx, + // const int *incx, const std::complex *zy, const int *incy); + // Peize Lin add ?dot 2017-10-27, to compute d=x*y + float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int
*incY); + double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY); + + // Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 } + float snrm2_( const int *n, const float *X, const int *incX ); + double dnrm2_( const int *n, const double *X, const int *incX ); + double dznrm2_( const int *n, const std::complex *X, const int *incX ); +} -void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) +// x=a*x +void BlasConnector::scal( const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - saxpy_(&n, &alpha, X, &incX, Y, &incY); + sscal_(&n, &alpha, X, &incX); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); + cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); } #endif else { @@ -30,14 +67,14 @@ void BlasConnector::axpy( const int n, const float alpha, const float *X, const } } -void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type) +void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - daxpy_(&n, &alpha, X, &incX, Y, &incY); + dscal_(&n, &alpha, X, &incX); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); + cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); } #endif else { @@ -45,14 +82,14 @@ void BlasConnector::axpy( const int n, const double alpha, const double *X, cons } } -void 
BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) +void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - caxpy_(&n, &alpha, X, &incX, Y, &incY); + cscal_(&n, &alpha, X, &incX); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY)); + cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX)); } #endif else { @@ -60,14 +97,14 @@ void BlasConnector::axpy( const int n, const std::complex alpha, const st } } -void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) +void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zaxpy_(&n, &alpha, X, &incX, Y, &incY); + zscal_(&n, &alpha, X, &incX); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY)); + cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX)); } #endif else { @@ -75,16 +112,14 @@ void BlasConnector::axpy( const int n, const std::complex alpha, const s } } - -// x=a*x -void BlasConnector::scal( const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type) +void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, 
base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sscal_(&n, &alpha, X, &incX); + saxpy_(&n, &alpha, X, &incX, Y, &incY); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); + cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); } #endif else { @@ -92,14 +127,14 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i } } -void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type) +void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dscal_(&n, &alpha, X, &incX); + daxpy_(&n, &alpha, X, &incX, Y, &incY); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); + cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); } #endif else { @@ -107,14 +142,14 @@ void BlasConnector::scal( const int n, const double alpha, double *X, const int } } -void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) +void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cscal_(&n, &alpha, X, &incX); + caxpy_(&n, &alpha, X, &incX, Y, &incY); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX)); + 
cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY)); } #endif else { @@ -122,14 +157,14 @@ void BlasConnector::scal( const int n, const std::complex alpha, std::com } } -void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) +void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zscal_(&n, &alpha, X, &incX); + zaxpy_(&n, &alpha, X, &incX, Y, &incY); } #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX)); + cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY)); } #endif else { @@ -137,6 +172,26 @@ void BlasConnector::scal( const int n, const std::complex alpha, std::co } } +// copies a into b +void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + dcopy_(&n, a, &incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} + +void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) { + zcopy_(&n, a, &incx, b, &incy); + } + else { + throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); + } +} // d=x*y float BlasConnector::dot( const int n, 
const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) @@ -321,28 +376,6 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in } } -// copies a into b -void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dcopy_(&n, a, &incx, b, &incy); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zcopy_(&n, a, &incx, b, &incy); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - - template void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ using Real = typename GetTypeReal::type; diff --git a/source/source_base/module_external/lapack_connector.h b/source/source_base/module_external/lapack_connector.h index 1f691fe3c2..e1b2844f2f 100644 --- a/source/source_base/module_external/lapack_connector.h +++ b/source/source_base/module_external/lapack_connector.h @@ -9,463 +9,97 @@ #include "../complexmatrix.h" #include "../global_function.h" -//Naming convention of lapack subroutines : ammxxx, where -//"a" specifies the data type: -// - d stands for double -// - z stands for complex double -//"mm" specifies the type of matrix, for example: -// - he stands for hermitian -// - sy stands for symmetric -//"xxx" specifies the type of problem, for example: -// - gv stands for generalized eigenvalue - -extern "C" -{ -// solve the 
generalized eigenproblem Ax=eBx, where A is Hermitian and complex couble - // zhegv_ & zhegvd_ returns all eigenvalues while zhegvx_ returns selected ones - void dsygvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - double* a, const int* lda, - const double* b, const int* ldb, double* w, - double* work, int* lwork, - int* iwork, int* liwork, int* info); - - void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* lda, - const std::complex* b, const int* ldb, float* w, - std::complex* work, int* lwork, float* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - - void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* lda, - const std::complex* b, const int* ldb, double* w, - std::complex* work, int* lwork, double* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - - void dsyevx_(const char* jobz, const char* range, const char* uplo, const int* n, - double* a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, double* z, const int* ldz, - double* work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); - - void cheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const float* vl, const float* vu, const int* il, const int* iu, const float* abstol, - const int* m, float* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info); - - void zheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); - - - void 
dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo, - const int* n, double* A, const int* lda, double* B, const int* ldb, - const double* vl, const double* vu, const int* il, const int* iu, - const double* abstol, const int* m, double* w, double* Z, const int* ldz, - double* work, const int* lwork, int* iwork, int* ifail, int* info); - - void chegvx_(const int* itype,const char* jobz,const char* range,const char* uplo, - const int* n,std::complex *a,const int* lda,std::complex *b, - const int* ldb,const float* vl,const float* vu,const int* il, - const int* iu,const float* abstol,const int* m,float* w, - std::complex *z,const int *ldz,std::complex *work,const int* lwork, - float* rwork,int* iwork,int* ifail,int* info); - - void zhegvx_(const int* itype,const char* jobz,const char* range,const char* uplo, - const int* n,std::complex *a,const int* lda,std::complex *b, - const int* ldb,const double* vl,const double* vu,const int* il, - const int* iu,const double* abstol,const int* m,double* w, - std::complex *z,const int *ldz,std::complex *work,const int* lwork, - double* rwork,int* iwork,int* ifail,int* info); - - void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n, - std::complex* a,const int* lda,std::complex* b,const int* ldb, - double* w,std::complex* work,int* lwork,double* rwork,int* info); - void chegv_(const int* itype,const char* jobz,const char* uplo,const int* n, - std::complex* a,const int* lda,std::complex* b,const int* ldb, - float* w,std::complex* work,int* lwork,float* rwork,int* info); - void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n, - double* a,const int* lda,double* b,const int* ldb, - double* w,double* work,int* lwork,int* info); - - // solve the eigenproblem Ax=ex, where A is Hermitian and complex couble - // zheev_ returns all eigenvalues while zheevx_ returns selected ones - void zheev_(const char* jobz,const char* uplo,const int* n,std::complex *a, - const int* 
lda,double* w,std::complex* work,const int* lwork, - double* rwork,int* info); - void cheev_(const char* jobz,const char* uplo,const int* n,std::complex *a, - const int* lda,float* w,std::complex* work,const int* lwork, - float* rwork,int* info); - void dsyev_(const char* jobz,const char* uplo,const int* n,double *a, - const int* lda,double* w,double* work,const int* lwork, int* info); - - // solve the eigenproblem Ax=ex, where A is a general matrix - void dgeev_(const char* jobvl, const char* jobvr, const int* n, double* a, const int* lda, - double* wr, double* wi, double* vl, const int* ldvl, double* vr, const int* ldvr, - double* work, const int* lwork, int* info); - void zgeev_(const char* jobvl, const char* jobvr, const int* n, std::complex* a, const int* lda, - std::complex* w, std::complex* vl, const int* ldvl, std::complex* vr, const int* ldvr, - std::complex* work, const int* lwork, double* rwork, int* info); - // liuyu add 2023-10-03 - // dgetri and dgetrf computes the inverse of a n*n real matrix - void dgetri_(const int* n, double* a, const int* lda, const int* ipiv, double* work, const int* lwork, int* info); - void dgetrf_(const int* m, const int* n, double* a, const int* lda, int* ipiv, int* info); - - // dsytrf_ computes the Bunch-Kaufman factorization of a double precision - // symmetric matrix, while dsytri takes its output to perform martrix inversion - void dsytrf_(const char* uplo, const int* n, double * a, const int* lda, - int *ipiv,double *work, const int* lwork ,int *info); - void dsytri_(const char* uplo,const int* n,double *a, const int *lda, - int *ipiv, double * work,int *info); - // Peize Lin add dsptrf and dsptri 2016-06-21, to compute inverse real symmetry indefinit matrix. 
- // dpotrf computes the Cholesky factorization of a real symmetric positive definite matrix - // while dpotri taks its output to perform matrix inversion - void spotrf_(const char*const uplo, const int*const n, float*const A, const int*const lda, int*const info); - void dpotrf_(const char*const uplo, const int*const n, double*const A, const int*const lda, int*const info); - void cpotrf_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - void zpotrf_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - void spotri_(const char*const uplo, const int*const n, float*const A, const int*const lda, int*const info); - void dpotri_(const char*const uplo, const int*const n, double*const A, const int*const lda, int*const info); - void cpotri_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - void zpotri_(const char*const uplo, const int*const n, std::complex*const A, const int*const lda, int*const info); - - // zgetrf computes the LU factorization of a general matrix - // while zgetri takes its output to perform matrix inversion - void zgetrf_(const int* m, const int *n, std::complex *A, const int *lda, int *ipiv, int* info); - void zgetri_(const int* n, std::complex* A, const int* lda, const int* ipiv, std::complex* work, const int* lwork, int* info); - - // if trans=='N': C = alpha * A * A.H + beta * C - // if trans=='C': C = alpha * A.H * A + beta * C - void zherk_(const char *uplo, const char *trans, const int *n, const int *k, - const double *alpha, const std::complex *A, const int *lda, - const double *beta, std::complex *C, const int *ldc); - void cherk_(const char* uplo, const char* trans, const int* n, const int* k, - const float* alpha, const std::complex* A, const int* lda, - const float* beta, std::complex* C, const int* ldc); - - // computes all eigenvalues of a symmetric tridiagonal matrix - // using the 
Pal-Walker-Kahan variant of the QL or QR algorithm. - void dsterf_(int *n, double *d, double *e, int *info); - // computes the eigenvectors of a real symmetric tridiagonal - // matrix T corresponding to specified eigenvalues - void dstein_(int *n, double* d, double *e, int *m, double *w, - int* block, int* isplit, double* z, int *lda, double *work, - int* iwork, int* ifail, int *info); - // computes the eigenvectors of a complex symmetric tridiagonal - // matrix T corresponding to specified eigenvalues - void zstein_(int *n, double* d, double *e, int *m, double *w, - int* block, int* isplit, std::complex* z, int *lda, double *work, - int* iwork, int* ifail, int *info); - - // computes the Cholesky factorization of a symmetric - // positive definite matrix A. - void dpotf2_(char *uplo, int *n, double *a, int *lda, int *info); - void zpotf2_(char *uplo,int *n,std::complex *a, int *lda, int *info); - - // reduces a symmetric definite generalized eigenproblem to standard form - // using the factorization results obtained from spotrf - void dsygs2_(int *itype, char *uplo, int *n, double *a, int *lda, double *b, int *ldb, int *info); - void zhegs2_(int *itype, char *uplo, int *n, std::complex *a, int *lda, std::complex *b, int *ldb, int *info); - - // copies a into b - void dlacpy_(char *uplo, int *m, int *n, double* a, int *lda, double *b, int *ldb); - void zlacpy_(char *uplo, int *m, int *n, std::complex* a, int *lda, std::complex *b, int *ldb); - - // generates a real elementary reflector H of order n, such that - // H * ( alpha ) = ( beta ), H is unitary. 
- // ( x ) ( 0 ) - void dlarfg_(int *n, double *alpha, double *x, int *incx, double *tau); - void zlarfg_(int *n, std::complex *alpha, std::complex *x, int *incx, std::complex *tau); - - // solve a tridiagonal linear system - void dgtsv_(int* N, int* NRHS, double* DL, double* D, double* DU, double* B, int* LDB, int* INFO); - - // solve Ax = b - void dsysv_(const char* uplo, const int* n, const int* m, double * a, const int* lda, - int *ipiv, double * b, const int* ldb, double *work, const int* lwork ,int *info); -} - -#ifdef GATHER_INFO -#define zhegvx_ zhegvx_i -void zhegvx_i(const int* itype, - const char* jobz, - const char* range, - const char* uplo, - const int* n, - std::complex* a, - const int* lda, - std::complex* b, - const int* ldb, - const double* vl, - const double* vu, - const int* il, - const int* iu, - const double* abstol, - const int* m, - double* w, - std::complex* z, - const int* ldz, - std::complex* work, - const int* lwork, - double* rwork, - int* iwork, - int* ifail, - int* info); -#endif // GATHER_INFO - // Class LapackConnector provide the connector to fortran lapack routine. // The entire function in this class are static and inline function. // Usage example: LapackConnector::functionname(parameter list). -class LapackConnector +namespace LapackConnector { -private: - // Transpose the std::complex matrix to the fortran-form real-std::complex array. - static inline - std::complex* transpose(const ModuleBase::ComplexMatrix& a, const int n, const int lda) - { - std::complex* aux = new std::complex[lda*n]; - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - aux[i*lda+j] = a(j,i); // aux[i*lda+j] means aux[i][j] in semantic, not in syntax! 
- } - } - return aux; - } - - static inline - std::complex* transpose(const std::complex* a, const int n, const int lda, const int nbase_x) - { - std::complex* aux = new std::complex[lda*n]; - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - aux[j * n + i] = a[i * nbase_x + j]; - } - } - return aux; - } - - static inline - std::complex* transpose(const std::complex* a, const int n, const int lda, const int nbase_x) - { - std::complex* aux = new std::complex[lda*n]; - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - aux[j * n + i] = a[i * nbase_x + j]; - } - } - return aux; - } - - // Transpose the fortran-form real-std::complex array to the std::complex matrix. - static inline - void transpose(const std::complex* aux, ModuleBase::ComplexMatrix& a, const int n, const int lda) - { - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - a(j, i) = aux[i*lda+j]; // aux[i*lda+j] means aux[i][j] in semantic, not in syntax! - } - } - } - - // Transpose the fortran-form real-std::complex array to the std::complex matrix. - static inline - void transpose(const std::complex* aux, std::complex* a, const int n, const int lda, const int nbase_x) - { - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - a[j * nbase_x + i] = aux[i * lda + j]; // aux[i*lda+j] means aux[i][j] in semantic, not in syntax! - } - } - } - - // Transpose the fortran-form real-std::complex array to the std::complex matrix. - static inline - void transpose(const std::complex* aux, std::complex* a, const int n, const int lda, const int nbase_x) - { - for (int i = 0; i < n; ++i) - { - for (int j = 0; j < lda; ++j) - { - a[j * nbase_x + i] = aux[i * lda + j]; // aux[i*lda+j] means aux[i][j] in semantic, not in syntax! 
- } - } - } - - // Peize Lin add 2015-12-27 - static inline - char change_uplo(const char &uplo) - { - switch(uplo) - { - case 'U': return 'L'; - case 'L': return 'U'; - default: throw std::invalid_argument("uplo must be 'U' or 'L'"); - } - } - - // Peize Lin add 2019-04-14 - static inline - char change_trans_NC(const char &trans) - { - switch(trans) - { - case 'N': return 'C'; - case 'C': return 'N'; - default: throw std::invalid_argument("trans must be 'N' or 'C'"); - } - } - -public: - // wrap function of fortran lapack routine zheev. - static inline - void zheev( const char jobz, - const char uplo, - const int n, - ModuleBase::ComplexMatrix& a, - const int lda, - double* w, - std::complex< double >* work, - const int lwork, - double* rwork, - int *info ) - { // Transpose the std::complex matrix to the fortran-form real-std::complex array. - std::complex *aux = LapackConnector::transpose(a, n, lda); - // call the fortran routine - zheev_(&jobz, &uplo, &n, aux, &lda, w, work, &lwork, rwork, info); - // Transpose the fortran-form real-std::complex array to the std::complex matrix. - LapackConnector::transpose(aux, a, n, lda); - // free the memory. 
- delete[] aux; - } - - static inline - void zgetrf(int m, int n, ModuleBase::ComplexMatrix &a, const int lda, int *ipiv, int *info) - { - std::complex *aux = LapackConnector::transpose(a, n, lda); - zgetrf_( &m, &n, aux, &lda, ipiv, info); - LapackConnector::transpose(aux, a, n, lda); - delete[] aux; - return; - } - static inline - void zgetri(int n, ModuleBase::ComplexMatrix &a, int lda, int *ipiv, std::complex * work, int lwork, int *info) - { - std::complex *aux = LapackConnector::transpose(a, n, lda); - zgetri_( &n, aux, &lda, ipiv, work, &lwork, info); - LapackConnector::transpose(aux, a, n, lda); - delete[] aux; - return; - } - - // Peize Lin add 2016-07-09 - static inline - void potrf( const char &uplo, const int &n, float*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - spotrf_( &uplo_changed, &n, A, &lda, &info ); - } - static inline - void potrf( const char &uplo, const int &n, double*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - dpotrf_( &uplo_changed, &n, A, &lda, &info ); - } - static inline - void potrf( const char &uplo, const int &n, std::complex*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - cpotrf_( &uplo_changed, &n, A, &lda, &info ); - } - static inline - void potrf( const char &uplo, const int &n, std::complex*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - zpotrf_( &uplo_changed, &n, A, &lda, &info ); - } - - - // Peize Lin add 2016-07-09 - static inline - void potri( const char &uplo, const int &n, float*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - spotri_( &uplo_changed, &n, A, &lda, &info); - } - static inline - void potri( const char &uplo, const int &n, double*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - dpotri_( &uplo_changed, &n, A, &lda, &info); - } - static inline - void potri( 
const char &uplo, const int &n, std::complex*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - cpotri_( &uplo_changed, &n, A, &lda, &info); - } - static inline - void potri( const char &uplo, const int &n, std::complex*const A, const int &lda, int &info ) - { - const char uplo_changed = change_uplo(uplo); - zpotri_( &uplo_changed, &n, A, &lda, &info); - } - - // Peize Lin add 2016-07-09 - static inline - void potrf( const char &uplo, const int &n, ModuleBase::matrix &A, const int &lda, int &info ) - { - potrf( uplo, n, A.c, lda, info ); - } - static inline - void potrf( const char &uplo, const int &n, ModuleBase::ComplexMatrix &A, const int &lda, int &info ) - { - potrf( uplo, n, A.c, lda, info ); - } - - // Peize Lin add 2016-07-09 - static inline - void potri( const char &uplo, const int &n, ModuleBase::matrix &A, const int &lda, int &info ) - { - potri( uplo, n, A.c, lda, info); - } - static inline - void potri( const char &uplo, const int &n, ModuleBase::ComplexMatrix &A, const int &lda, int &info ) - { - potri( uplo, n, A.c, lda, info); - } - - // Peize Lin add 2019-04-14 - // if trans=='N': C = a * A * A.H + b * C - // if trans=='C': C = a * A.H * A + b * C - static inline - void herk(const char uplo, const char trans, const int n, const int k, - const double alpha, const std::complex *A, const int lda, - const double beta, std::complex *C, const int ldc) - { - const char uplo_changed = change_uplo(uplo); - const char trans_changed = change_trans_NC(trans); - zherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc); - } - static inline - void herk(const char uplo, const char trans, const int n, const int k, - const float alpha, const std::complex* A, const int lda, - const float beta, std::complex* C, const int ldc) - { - const char uplo_changed = change_uplo(uplo); - const char trans_changed = change_trans_NC(trans); - cherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc); 
- } +enum MatrixLayout +{ + RowMajor, + ColMajor }; + +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, float* w); +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, double* w); +void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* a, int lda, double* b, int ldb, double* w); +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, float* a, int lda, float* b, int ldb, float* w); +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* a, int lda, double* b, int ldb, double* w); +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, float* w); +void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::complex* a, int lda, std::complex* b, int ldb, double* w); +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + std::complex* a, int lda, std::complex* b, int ldb, + float vl, float vu, int il, int iu, float abstol, int* m, + float* w, std::complex* z, int ldz, int* ifail); +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + std::complex* a, int lda, std::complex* b, int ldb, + double vl, double vu, int il, int iu, double abstol, int* m, + double* w, std::complex* z, int ldz, int* ifail); +void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int n, + double* a, int lda, double* b, int ldb, + double vl, double vu, int il, int iu, double abstol, int* m, + double* w, double* z, int ldz, int* ifail); +void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, int* ipiv); +void getri(MatrixLayout layout, std::complex* a, int lda, const int* ipiv); + +void potrf(MatrixLayout layout, char uplo, int n, float* a, int lda); +void potrf(MatrixLayout layout, char uplo, int n, 
double* a, int lda); +void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int lda); +void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int lda); + +void potri(MatrixLayout layout, char uplo, int n, float* a, int lda); +void potri(MatrixLayout layout, char uplo, int n, double* a, int lda); +void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int lda); +void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int lda); + +void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex* a, int lda, float* w); +void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex* a, int lda, double* w); +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, float* a, int lda, float vl, + float vu, int il, int iu, float abstol, int* m, float* w, float* z, int ldz, int* ifail); +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, double* a, int lda, double vl, + double vu, int il, int iu, double abstol, int* m, double* w, double* z, int ldz, int* ifail); +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::complex* a, int lda, + float vl, float vu, int il, int iu, float abstol, int* m, float* w, + std::complex* z, int ldz, int* ifail); +void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::complex* a, int lda, + double vl, double vu, int il, int iu, double abstol, int* m, double* w, + std::complex* z, int ldz, int* ifail); +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + float* a, int lda, float* w); +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + double* a, int lda, double* w); +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + std::complex* a, int lda, float* w); +void heevd(MatrixLayout layout, char jobz, char uplo, int n, + std::complex* a, int lda, double* w); +void syev(MatrixLayout layout, char jobz, char uplo, int n, double* a, int lda, double* w); + +void 
geev(MatrixLayout layout, char jobvl, char jobvr, int n, double* a, int lda, + double* wr, double* wi, double* vl, int ldvl, double* vr, int ldvr); +void geev(MatrixLayout layout, char jobvl, char jobvr, int n, std::complex* a, int lda, + std::complex* w, std::complex* vl, int ldvl, std::complex* vr, int ldvr); + +void getrf(MatrixLayout layout, int m, int n, float* a, int lda, int* ipiv); +void getrf(MatrixLayout layout, int m, int n, double* a, int lda, int* ipiv); +void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, int* ipiv); +void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, int* ipiv); +void getri(MatrixLayout layout, int n, float* a, int lda, const int* ipiv); +void getri(MatrixLayout layout, int n, double* a, int lda, const int* ipiv); +void getri(MatrixLayout layout, int n, std::complex* a, int lda, const int* ipiv); +void getri(MatrixLayout layout, int n, std::complex* a, int lda, const int* ipiv); +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const float* a, int lda, const int* ipiv, float* b, int ldb); +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const double* a, int lda, const int* ipiv, double* b, int ldb); +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex* a, int lda, const int* ipiv, std::complex* b, int ldb); +void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex* a, int lda, const int* ipiv, std::complex* b, int ldb); +void sytrf(MatrixLayout layout, char uplo, int n, double* a, int lda, int* ipiv); +void sytri(MatrixLayout layout, char uplo, int n, double* a, int lda, const int* ipiv); + +void gtsv(MatrixLayout layout, int n, int nrhs, double* dl, double* d, double* du, double* b, int ldb); +void sysv(MatrixLayout layout, char uplo, int n, int nrhs, double* a, int lda, int* ipiv, double* b, int ldb); + +void trtri(MatrixLayout layout, char uplo, char diag, int n, float* a, int lda); +void trtri(MatrixLayout layout, 
char uplo, char diag, int n, double* a, int lda); +void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex* a, int lda); +void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex* a, int lda); +} #endif // LAPACKCONNECTOR_HPP diff --git a/source/source_base/module_external/lapack_wrapper.h b/source/source_base/module_external/lapack_wrapper.h deleted file mode 100644 index acccdc0454..0000000000 --- a/source/source_base/module_external/lapack_wrapper.h +++ /dev/null @@ -1,484 +0,0 @@ -#ifndef LAPACK_HPP -#define LAPACK_HPP -#include -extern "C" -{ - // ================================================================================= - // gvd: - void dsygvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - double* a, const int* lda, - const double* b, const int* ldb, double* w, - double* work, int* lwork, - int* iwork, int* liwork, int* info); - - void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* lda, - const std::complex* b, const int* ldb, float* w, - std::complex* work, int* lwork, float* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - - void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n, - std::complex* a, const int* lda, - const std::complex* b, const int* ldb, double* w, - std::complex* work, int* lwork, double* rwork, int* lrwork, - int* iwork, int* liwork, int* info); - // ================================================================================= - - // ================================================================================= - // evx - void dsyevx_(const char* jobz, const char* range, const char* uplo, const int* n, - double* a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, double* z, const int* ldz, - double* work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); - - void 
cheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const float* vl, const float* vu, const int* il, const int* iu, const float* abstol, - const int* m, float* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info); - - void zheevx_(const char* jobz, const char* range, const char* uplo, const int* n, - std::complex *a, const int* lda, - const double* vl, const double* vu, const int* il, const int* iu, const double* abstol, - const int* m, double* w, std::complex *z, const int *ldz, - std::complex *work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info); - // ================================================================================= - - - // ================================================================================= - // gvx - void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo, - const int* n, double* A, const int* lda, double* B, const int* ldb, - const double* vl, const double* vu, const int* il, const int* iu, - const double* abstol, const int* m, double* w, double* Z, const int* ldz, - double* work, const int* lwork, int* iwork, int* ifail, int* info); - - void chegvx_(const int* itype,const char* jobz,const char* range,const char* uplo, - const int* n,std::complex *a,const int* lda,std::complex *b, - const int* ldb,const float* vl,const float* vu,const int* il, - const int* iu,const float* abstol,const int* m,float* w, - std::complex *z,const int *ldz,std::complex *work,const int* lwork, - float* rwork,int* iwork,int* ifail,int* info); - - void zhegvx_(const int* itype,const char* jobz,const char* range,const char* uplo, - const int* n,std::complex *a,const int* lda,std::complex *b, - const int* ldb,const double* vl,const double* vu,const int* il, - const int* iu,const double* abstol,const int* m,double* w, - std::complex *z,const int *ldz,std::complex *work,const 
int* lwork, - double* rwork,int* iwork,int* ifail,int* info); - // ================================================================================= - - // ================================================================================= - // gv - void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n, - std::complex* a,const int* lda,std::complex* b,const int* ldb, - double* w,std::complex* work,int* lwork,double* rwork,int* info); - void chegv_(const int* itype,const char* jobz,const char* uplo,const int* n, - std::complex* a,const int* lda,std::complex* b,const int* ldb, - float* w,std::complex* work,int* lwork,float* rwork,int* info); - void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n, - double* a,const int* lda,double* b,const int* ldb, - double* w,double* work,int* lwork,int* info); - // ================================================================================= - -} - -class LapackWrapper -{ - private: - public: - // wrap function of fortran lapack routine zhegvd. (pointer version) - static inline void xhegvd(const int itype, - const char jobz, - const char uplo, - const int n, - double* a, - const int lda, - const double* b, - const int ldb, - double* w, - double* work, - int lwork, - double* rwork, - int lrwork, - int* iwork, - int liwork, - int& info) - { - // call the fortran routine - dsygvd_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, iwork, &liwork, &info); - } - - // wrap function of fortran lapack routine zhegvd. 
(pointer version) - static inline void xhegvd(const int itype, - const char jobz, - const char uplo, - const int n, - std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - float* w, - std::complex* work, - int lwork, - float* rwork, - int lrwork, - int* iwork, - int liwork, - int& info) - { - // call the fortran routine - chegvd_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, rwork, &lrwork, iwork, &liwork, &info); - } - - // wrap function of fortran lapack routine zhegvd. - static inline void xhegvd(const int itype, - const char jobz, - const char uplo, - const int n, - std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - double* w, - std::complex* work, - int lwork, - double* rwork, - int lrwork, - int* iwork, - int liwork, - int& info) - { - // call the fortran routine - zhegvd_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, rwork, &lrwork, iwork, &liwork, &info); - } - - // wrap function of fortran lapack routine dsyevx. - static inline void xheevx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - double* a, - const int lda, - const double vl, - const double vu, - const int il, - const int iu, - const double abstol, - const int m, - double* w, - double* z, - const int ldz, - double* work, - const int lwork, - double* rwork, - int* iwork, - int* ifail, - int& info) - { - dsyevx_(&jobz, - &range, - &uplo, - &n, - a, - &lda, - &vl, - &vu, - &il, - &iu, - &abstol, - &m, - w, - z, - &ldz, - work, - &lwork, - rwork, - iwork, - ifail, - &info); - } - - // wrap function of fortran lapack routine cheevx. 
- static inline void xheevx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - std::complex* a, - const int lda, - const float vl, - const float vu, - const int il, - const int iu, - const float abstol, - const int m, - float* w, - std::complex* z, - const int ldz, - std::complex* work, - const int lwork, - float* rwork, - int* iwork, - int* ifail, - int& info) - { - cheevx_(&jobz, - &range, - &uplo, - &n, - a, - &lda, - &vl, - &vu, - &il, - &iu, - &abstol, - &m, - w, - z, - &ldz, - work, - &lwork, - rwork, - iwork, - ifail, - &info); - } - - // wrap function of fortran lapack routine zheevx. - static inline void xheevx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - std::complex* a, - const int lda, - const double vl, - const double vu, - const int il, - const int iu, - const double abstol, - const int m, - double* w, - std::complex* z, - const int ldz, - std::complex* work, - const int lwork, - double* rwork, - int* iwork, - int* ifail, - int& info) - { - zheevx_(&jobz, - &range, - &uplo, - &n, - a, - &lda, - &vl, - &vu, - &il, - &iu, - &abstol, - &m, - w, - z, - &ldz, - work, - &lwork, - rwork, - iwork, - ifail, - &info); - } - - // wrap function of fortran lapack routine xhegvx ( pointer version ). 
- static inline void xhegvx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - std::complex* a, - const int lda, - std::complex* b, - const int ldb, - const float vl, - const float vu, - const int il, - const int iu, - const float abstol, - const int m, - float* w, - std::complex* z, - const int ldz, - std::complex* work, - const int lwork, - float* rwork, - int* iwork, - int* ifail, - int& info) - { - chegvx_(&itype, - &jobz, - &range, - &uplo, - &n, - a, - &lda, - b, - &ldb, - &vl, - &vu, - &il, - &iu, - &abstol, - &m, - w, - z, - &ldz, - work, - &lwork, - rwork, - iwork, - ifail, - &info); - } - - // wrap function of fortran lapack routine xhegvx ( pointer version ). - static inline void xhegvx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - std::complex* a, - const int lda, - std::complex* b, - const int ldb, - const double vl, - const double vu, - const int il, - const int iu, - const double abstol, - const int m, - double* w, - std::complex* z, - const int ldz, - std::complex* work, - const int lwork, - double* rwork, - int* iwork, - int* ifail, - int& info) - { - zhegvx_(&itype, - &jobz, - &range, - &uplo, - &n, - a, - &lda, - b, - &ldb, - &vl, - &vu, - &il, - &iu, - &abstol, - &m, - w, - z, - &ldz, - work, - &lwork, - rwork, - iwork, - ifail, - &info); - } - // wrap function of fortran lapack routine xhegvx ( pointer version ). 
- static inline void xhegvx(const int itype, - const char jobz, - const char range, - const char uplo, - const int n, - double* a, - const int lda, - double* b, - const int ldb, - const double vl, - const double vu, - const int il, - const int iu, - const double abstol, - const int m, - double* w, - double* z, - const int ldz, - double* work, - const int lwork, - double* rwork, - int* iwork, - int* ifail, - int& info) - { - dsygvx_(&itype, &jobz, &range, &uplo, &n, a, &lda, b, &ldb, &vl, - &vu, &il, &iu, &abstol, &m, w, z, &ldz, work, &lwork, iwork, ifail, &info); - } - - // wrap function of fortran lapack routine xhegvx ( pointer version ). - static inline void xhegv(const int itype, - const char jobz, - const char uplo, - const int n, - double* a, - const int lda, - double* b, - const int ldb, - double* w, - double* work, - int lwork, - double* rwork, - int& info) - { - // TODO - } - - // wrap function of fortran lapack routine xhegvx ( pointer version ). - static inline void xhegv(const int itype, - const char jobz, - const char uplo, - const int n, - std::complex* a, - const int lda, - std::complex* b, - const int ldb, - float* w, - std::complex* work, - int lwork, - float* rwork, - int& info) - { - // TODO - } - // wrap function of fortran lapack routine xhegvx ( pointer version ). 
- static inline void xhegv(const int itype, - const char jobz, - const char uplo, - const int n, - std::complex* a, - const int lda, - std::complex* b, - const int ldb, - double* w, - std::complex* work, - int lwork, - double* rwork, - int& info) - { - zhegv_(&itype, &jobz, &uplo, &n, a, &lda, b, &ldb, w, work, &lwork, rwork, &info); - } -}; -#endif // LAPACK_HPP \ No newline at end of file diff --git a/source/source_base/module_grid/batch.cpp b/source/source_base/module_grid/batch.cpp index 718d9277d3..36f8ea6ace 100644 --- a/source/source_base/module_grid/batch.cpp +++ b/source/source_base/module_grid/batch.cpp @@ -66,13 +66,10 @@ int _maxmin_divide(const double* grid, int* idx, int m) { // The normal vector of the cut plane is taken to be the eigenvector // corresponding to the largest eigenvalue of the 3x3 matrix A = R*R^T. std::vector A(9, 0.0); - int i3 = 3, i1 = 1; - double d0 = 0.0, d1 = 1.0; - dsyrk_("U", "N", &i3, &m, &d1, R.data(), &i3, &d0, A.data(), &i3); + BlasConnector::syrk('U', 'N', 3, m, 1.0, R.data(), 3, 0.0, A.data(), 3); - int info = 0, lwork = 102 /* determined by a work space query */; - std::vector e(3), work(lwork); - dsyev_("V", "U", &i3, A.data(), &i3, e.data(), work.data(), &lwork, &info); + std::vector e(3); + LapackConnector::syev(LapackConnector::ColMajor, 'V', 'U', 3, A.data(), 3, e.data()); double* n = A.data() + 6; // normal vector of the cut plane // Rearrange the indices to put points in each subset together by diff --git a/source/source_base/module_mixing/broyden_mixing.cpp b/source/source_base/module_mixing/broyden_mixing.cpp index c5f8e5e025..16fad64ddf 100644 --- a/source/source_base/module_mixing/broyden_mixing.cpp +++ b/source/source_base/module_mixing/broyden_mixing.cpp @@ -2,7 +2,6 @@ #include "source_base/module_external/lapack_connector.h" #include "source_base/memory.h" -#include "source_base/module_container/base/third_party/blas.h" #include "source_base/timer.h" #include "source_base/tool_title.h" namespace 
Base_Mixing @@ -142,10 +141,8 @@ void Broyden_Mixing::tem_cal_coef(const Mixing_Data& mdata, std::function iwork(ndim_cal_dF); // ipiv char uu = 'U'; - int info = 0; int m = 1; // gamma means the coeficients for mixing // but now gamma store , namely c @@ -157,22 +154,16 @@ void Broyden_Mixing::tem_cal_coef(const Mixing_Data& mdata, std::function* data_mix) { @@ -98,16 +98,16 @@ void Mixing::mix_data(const Mixing_Data& mdata, std::complex* data_mix) std::vector> coef_complex(coef.size()); for (int i = 0; i < coef.size(); ++i) coef_complex[i] = coef[i]; - container::BlasConnector::gemv('N', - mdata.length, - mdata.ndim_use, - 1.0, - FP_data, - mdata.length, - coef_complex.data(), - 1, - 0.0, - data_mix, - 1); + BlasConnector::gemv_cm('N', + mdata.length, + mdata.ndim_use, + 1.0, + FP_data, + mdata.length, + coef_complex.data(), + 1, + 0.0, + data_mix, + 1); } -} // namespace Base_Mixing \ No newline at end of file +}// namespace Base_Mixing \ No newline at end of file diff --git a/source/source_base/module_mixing/pulay_mixing.cpp b/source/source_base/module_mixing/pulay_mixing.cpp index c283a5c2e7..4fc2f8c229 100644 --- a/source/source_base/module_mixing/pulay_mixing.cpp +++ b/source/source_base/module_mixing/pulay_mixing.cpp @@ -132,16 +132,10 @@ void Pulay_Mixing::tem_cal_coef(const Mixing_Data& mdata, std::function iwork(ndim_use); char uu = 'U'; - int info; - dsytrf_(&uu, &ndim_use, beta_tmp.c, &ndim_use, iwork, work, &ndim_use, &info); - if (info != 0) - ModuleBase::WARNING_QUIT("Charge_Mixing", "Error when factorizing beta."); - dsytri_(&uu, &ndim_use, beta_tmp.c, &ndim_use, iwork, work, &info); - if (info != 0) - ModuleBase::WARNING_QUIT("Charge_Mixing", "Error when DSYTRI beta."); + LapackConnector::sytrf(LapackConnector::ColMajor, uu, ndim_use, beta_tmp.c, ndim_use, iwork.data()); + LapackConnector::sytri(LapackConnector::ColMajor, uu, ndim_use, beta_tmp.c, ndim_use, iwork.data()); for (int i = 0; i < ndim_use; ++i) { for (int j = i + 1; j < ndim_use; ++j) 
@@ -168,8 +162,6 @@ void Pulay_Mixing::tem_cal_coef(const Mixing_Data& mdata, std::function w; }; -// Test the zhegv_ function +// Test the zhegv function TEST_F(LapackConnectorTest, ZHEGV) { - // First, query the optimal size of the work array - std::complex work_query; - double rwork_query; - zhegv_(&itype, - &jobz, - &uplo, - &n, + LapackConnector::hegv(LapackConnector::ColMajor, + itype, + jobz, + uplo, + n, A.data(), - &lda, + lda, B.data(), - &ldb, - w.data(), - &work_query, - &lwork, - &rwork_query, - &info); - lwork = static_cast(work_query.real()); - std::vector> work(lwork); - // std::vector rwork(static_cast(rwork_query)); - // the above line is not working as rwork_query will return -nan - // std::vector rwork(7 * lwork); - std::vector rwork(7 * n); - - // Now, call zhegv_ with the optimal work array size - zhegv_(&itype, - &jobz, - &uplo, - &n, - A.data(), - &lda, - B.data(), - &ldb, - w.data(), - work.data(), - &lwork, - rwork.data(), - &info); - - // Check that the function completed successfully - ASSERT_EQ(info, 0); + ldb, + w.data()); // Check the computed eigenvalues and eigenvectors // (Use appropriate values for your test case) diff --git a/source/source_estate/math_tools.h b/source/source_estate/math_tools.h index 66b8468f97..a46f3ad909 100644 --- a/source/source_estate/math_tools.h +++ b/source/source_estate/math_tools.h @@ -82,19 +82,20 @@ inline void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2 const char N_char = 'N', T_char = 'T'; const int nlocal = psi1.get_nbasis(); const int nbands = psi1.get_nbands(); - dgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_float, + BlasConnector::gemm_cm( + N_char, + T_char, + nlocal, + nlocal, + nbands, + one_float, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_float, + nlocal, + zero_float, dm_out.c, - &nlocal); + nlocal); } inline void psiMulPsi(const psi::Psi>& psi1, @@ -106,19 +107,20 @@ inline void psiMulPsi(const psi::Psi>& psi1, const int 
nlocal = psi1.get_nbasis(); const int nbands = psi1.get_nbands(); const std::complex one_complex = {1.0, 0.0}, zero_complex = {0.0, 0.0}; - zgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_complex, + BlasConnector::gemm_cm( + N_char, + T_char, + nlocal, + nlocal, + nbands, + one_complex, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_complex, + nlocal, + zero_complex, dm_out.c, - &nlocal); + nlocal); } #endif \ No newline at end of file diff --git a/source/source_estate/module_dm/cal_dm_psi.cpp b/source/source_estate/module_dm/cal_dm_psi.cpp index 7f68838c94..6c3e858812 100644 --- a/source/source_estate/module_dm/cal_dm_psi.cpp +++ b/source/source_estate/module_dm/cal_dm_psi.cpp @@ -231,19 +231,20 @@ void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, doubl const char N_char = 'N', T_char = 'T'; const int nlocal = psi1.get_nbasis(); const int nbands = psi1.get_nbands(); - dgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_float, + BlasConnector::gemm_cm( + N_char, + T_char, + nlocal, + nlocal, + nbands, + one_float, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_float, + nlocal, + zero_float, dm_out, - &nlocal); + nlocal); } void psiMulPsi(const psi::Psi>& psi1, @@ -256,19 +257,20 @@ void psiMulPsi(const psi::Psi>& psi1, const int nbands = psi1.get_nbands(); const std::complex one_complex = {1.0, 0.0}; const std::complex zero_complex = {0.0, 0.0}; - zgemm_(&N_char, - &T_char, - &nlocal, - &nlocal, - &nbands, - &one_complex, + BlasConnector::gemm_cm( + N_char, + T_char, + nlocal, + nlocal, + nbands, + one_complex, psi1.get_pointer(), - &nlocal, + nlocal, psi2.get_pointer(), - &nlocal, - &zero_complex, + nlocal, + zero_complex, dm_out, - &nlocal); + nlocal); } } // namespace elecstate diff --git a/source/source_estate/module_dm/cal_edm_tddft.cpp b/source/source_estate/module_dm/cal_edm_tddft.cpp index dd934d8a00..dfd85bece4 100644 --- 
a/source/source_estate/module_dm/cal_edm_tddft.cpp +++ b/source/source_estate/module_dm/cal_edm_tddft.cpp @@ -54,8 +54,8 @@ void cal_edm_tddft(Parallel_Orbitals& pv, hamilt::MatrixBlock> s_mat; p_hamilt->matrix(h_mat, s_mat); - zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc); - zcopy_(&nloc, s_mat.p, &inc, Sinv, &inc); + BlasConnector::copy(nloc, h_mat.p, inc, Htmp, inc); + BlasConnector::copy(nloc, s_mat.p, inc, Sinv, inc); vector ipiv(nloc, 0); int info = 0; @@ -201,7 +201,7 @@ void cal_edm_tddft(Parallel_Orbitals& pv, &one_int, pv.desc); - zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc); + BlasConnector::copy(nloc, tmp4, inc, tmp_edmk.c, inc); delete[] Htmp; delete[] Sinv; @@ -236,8 +236,8 @@ void cal_edm_tddft(Parallel_Orbitals& pv, int IPIV[nlocal]; - LapackConnector::zgetrf(nlocal, nlocal, Sinv, nlocal, IPIV, &INFO); - LapackConnector::zgetri(nlocal, Sinv, nlocal, IPIV, work, lwork, &INFO); + LapackConnector::getrf(LapackConnector::RowMajor, nlocal, nlocal, Sinv, nlocal, IPIV); + LapackConnector::getri(LapackConnector::RowMajor, nlocal, Sinv, nlocal, IPIV, work, lwork); // I just use ModuleBase::ComplexMatrix temporarily, and will change it // to std::complex* ModuleBase::ComplexMatrix tmp_dmk_base(nlocal, nlocal); diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index ef7dd07423..d19ece31e0 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -605,7 +605,7 @@ void DiagoDavid::cal_elem(const int& dim, template void DiagoDavid::diag_zhegvx(const int& nbase, const int& nband, - const T* hcc, + T* hcc, const int& nbase_x, Real* eigenvalue, // in CPU T* vcc) diff --git a/source/source_hsolver/diago_david.h b/source/source_hsolver/diago_david.h index 5b65f22cec..249da57229 100644 --- a/source/source_hsolver/diago_david.h +++ b/source/source_hsolver/diago_david.h @@ -297,7 +297,7 @@ class DiagoDavid void diag_zhegvx(const int& nbase, const int& nband, - const T* hcc, + T* hcc, const int& 
nbase_x, Real* eigenvalue, T* vcc); diff --git a/source/source_hsolver/diago_iter_assist.cpp b/source/source_hsolver/diago_iter_assist.cpp index 916ea0d3fc..2831b39324 100644 --- a/source/source_hsolver/diago_iter_assist.cpp +++ b/source/source_hsolver/diago_iter_assist.cpp @@ -359,8 +359,8 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* template void DiagoIterAssist::diagH_LAPACK(const int nstart, const int nbands, - const T* hcc, - const T* scc, + T* hcc, + T* scc, const int ldh, // nstart Real* e, // always in CPU T* vcc) @@ -477,8 +477,8 @@ void DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt } template -void DiagoIterAssist::diag_responce( const T* hcc, - const T* scc, +void DiagoIterAssist::diag_responce( T* hcc, + T* scc, const int nbands, const T* mat_in, // [out] target matrix to be multiplied T* mat_out, @@ -520,8 +520,8 @@ void DiagoIterAssist::diag_responce( const T* hcc, } template -void DiagoIterAssist::diag_subspace_psi(const T* hcc, - const T* scc, +void DiagoIterAssist::diag_subspace_psi(T* hcc, + T* scc, const int dim_subspace, psi::Psi& evc, Real* en diff --git a/source/source_hsolver/diago_iter_assist.h b/source/source_hsolver/diago_iter_assist.h index 696c3c2862..b4e7c95521 100644 --- a/source/source_hsolver/diago_iter_assist.h +++ b/source/source_hsolver/diago_iter_assist.h @@ -56,8 +56,8 @@ class DiagoIterAssist static void diagH_LAPACK(const int nstart, const int nbands, - const T* hcc, - const T* sc, + T* hcc, + T* sc, const int ldh, // nstart Real* e, T* vcc); @@ -80,8 +80,8 @@ class DiagoIterAssist /// @param mat_out : output matrix to be rotated /// @param mat_col : number of columns of target matrix /// @param en : eigenvalues - static void diag_responce(const T* hcc, - const T* scc, + static void diag_responce(T* hcc, + T* scc, const int nbands, const T* mat_in, T* mat_out, @@ -89,8 +89,8 @@ class DiagoIterAssist Real* en); /// @brief calculate the response wavefunction psi from rotation matrix solved by 
diagonalization of H and S matrix - static void diag_subspace_psi(const T* hcc, - const T* scc, + static void diag_subspace_psi(T* hcc, + T* scc, const int dim_subspace, psi::Psi& evc, Real* en); diff --git a/source/source_hsolver/diago_lapack.cpp b/source/source_hsolver/diago_lapack.cpp index 90018c288d..8c418320fe 100644 --- a/source/source_hsolver/diago_lapack.cpp +++ b/source/source_hsolver/diago_lapack.cpp @@ -56,7 +56,7 @@ void DiagoLapack>::diag(hamilt::Hamilt } template -int DiagoLapack::dsygvx_once(const int ncol, +void DiagoLapack::dsygvx_once(const int ncol, const int nrow, const double* const h_mat, const double* const s_mat, @@ -72,85 +72,16 @@ int DiagoLapack::dsygvx_once(const int ncol, memcpy(s_tmp.c, s_mat, sizeof(double) * ncol * nrow); // Prepare caculate parameters - const char jobz = 'V', range = 'I', uplo = 'U'; - const int itype = 1, il = 1, iu = PARAM.inp.nbands, one = 1; - int M = 0, info = 0; - double vl = 0, vu = 0; - const double abstol = 0; + const char jobz = 'V', uplo = 'U'; + const int itype = 1; - int lwork = (ncol + 2) * ncol; + std::vector ev(ncol * ncol); - std::vector work(3, 0); - std::vector iwork(1, 0); - std::vector ifail(PARAM.globalv.nlocal, 0); - - // Original Lapack caculate, obelsete - /*dsygvx_(&itype, - &jobz, - &range, - &uplo, - &PARAM.globalv.nlocal, - h_tmp.c, - &ncol, - s_tmp.c, - &ncol, - &vl, - &vu, - &il, - &iu, - &abstol, - &M, - ekb, - wfc_2d.get_pointer(), - &ncol, - work.data(), - &lwork, - iwork.data(), - ifail.data(), - &info); - - // Throw error if it returns info - if (info) - throw std::runtime_error("info = " + ModuleBase::GlobalFunc::TO_STRING(info) + ".\n" - + std::string(__FILE__) + " line " - + std::to_string(__LINE__)); - //lwork = work[0]; - //work.resize(std::max(lwork, 3), 0); - //iwork.resize(iwork[0], 0); - - dsygvx_(&itype, - &jobz, - &range, - &uplo, - &PARAM.globalv.nlocal, - h_tmp.c, - &PARAM.globalv.nlocal, - s_tmp.c, - &PARAM.globalv.nlocal, - &vl, - &vu, - &il, - &iu, - &abstol, - 
&M, - ekb, - wfc_2d.get_pointer(), - &ncol, - work.data(), - &lwork, - iwork.data(), - ifail.data(), - &info);*/ - - double *ev = new double[ncol * ncol]; - - dsygv_(&itype, &jobz, &uplo, &PARAM.globalv.nlocal, h_tmp.c, &ncol, s_tmp.c, &ncol, ekb, ev, &lwork, &info); - - return info; + LapackConnector::sygv(itype, jobz, uplo, PARAM.globalv.nlocal, h_tmp.c, ncol, s_tmp.c, ncol, ekb, ev.data()); } template -int DiagoLapack::zhegvx_once(const int ncol, +void DiagoLapack::zhegvx_once(const int ncol, const int nrow, const std::complex* const h_mat, const std::complex* const s_mat, @@ -163,91 +94,12 @@ int DiagoLapack::zhegvx_once(const int ncol, ModuleBase::ComplexMatrix s_tmp(ncol, nrow, false); memcpy(s_tmp.c, s_mat, sizeof(std::complex) * ncol * nrow); - const char jobz = 'V', range = 'I', uplo = 'U'; - const int itype = 1, il = 1, iu = PARAM.inp.nbands, one = 1; - int M = 0, lrwork = -1, info = 0; - const double abstol = 0; - - int lwork = (ncol + 2) * ncol; - - const double vl = 0, vu = 0; - std::vector> work(1, 0); - double *rwork = new double[3 * ncol - 2]; - std::vector iwork(1, 0); - std::vector ifail(PARAM.globalv.nlocal, 0); - - // Original Lapack caculate, obelsete - /* - zhegvx_(&itype, - &jobz, - &range, - &uplo, - &PARAM.globalv.nlocal, - h_tmp.c, - &PARAM.globalv.nlocal, - s_tmp.c, - &PARAM.globalv.nlocal, - &vl, - &vu, - &il, - &iu, - &abstol, - &M, - ekb, - wfc_2d.get_pointer(), - &ncol, - work.data(), - &lwork, - rwork.data(), - iwork.data(), - ifail.data(), - &info); - - if (info) - throw std::runtime_error("info=" + ModuleBase::GlobalFunc::TO_STRING(info) + ". 
" - + std::string(__FILE__) + " line " - + std::to_string(__LINE__)); - - // GlobalV::ofs_running<<"lwork="< *ev = new std::complex[ncol * ncol]; + std::vector> ev(ncol * ncol); - zhegv_(&itype, &jobz, &uplo, &PARAM.globalv.nlocal, h_tmp.c, &ncol, s_tmp.c, &ncol, ekb, ev, &lwork, rwork, &info); - - return info; + LapackConnector::hegv(LapackConnector::ColMajor, itype, jobz, uplo, PARAM.globalv.nlocal, h_tmp.c, ncol, s_tmp.c, ncol, ekb, ev.data()); } template @@ -258,14 +110,7 @@ void DiagoLapack::dsygvx_diag(const int ncol, double* const ekb, psi::Psi& wfc_2d) { - while (true) - { - - int info_result = dsygvx_once(ncol, nrow, h_mat, s_mat, ekb, wfc_2d); - if (info_result == 0) { - break; - } - } + dsygvx_once(ncol, nrow, h_mat, s_mat, ekb, wfc_2d); } template @@ -276,26 +121,6 @@ void DiagoLapack::zhegvx_diag(const int ncol, double* const ekb, psi::Psi>& wfc_2d) { - while (true) - { - int info_result = zhegvx_once(ncol, nrow, h_mat, s_mat, ekb, wfc_2d); - if (info_result == 0) { - break; - } - } -} - -template -void DiagoLapack::post_processing(const int info, const std::vector& vec) -{ - const std::string str_info = "info = " + ModuleBase::GlobalFunc::TO_STRING(info) + ".\n"; - const std::string str_FILE - = std::string(__FILE__) + " line " + std::to_string(__LINE__) + ".\n"; - const std::string str_info_FILE = str_info + str_FILE; - - if (info == 0) - { - return; - } + zhegvx_once(ncol, nrow, h_mat, s_mat, ekb, wfc_2d); } } // namespace hsolver \ No newline at end of file diff --git a/source/source_hsolver/diago_lapack.h b/source/source_hsolver/diago_lapack.h index 53b710ae63..ec59880475 100644 --- a/source/source_hsolver/diago_lapack.h +++ b/source/source_hsolver/diago_lapack.h @@ -41,22 +41,18 @@ class DiagoLapack double* const ekb, psi::Psi>& wfc_2d); - int dsygvx_once(const int ncol, + void dsygvx_once(const int ncol, const int nrow, const double* const h_mat, const double* const s_mat, double* const ekb, psi::Psi& wfc_2d) const; - int zhegvx_once(const int 
ncol, + void zhegvx_once(const int ncol, const int nrow, const std::complex* const h_mat, const std::complex* const s_mat, double* const ekb, psi::Psi>& wfc_2d) const; - - int degeneracy_max = 12; // For reorthogonalized memory. 12 followes siesta. - - void post_processing(const int info, const std::vector& vec); }; } // namespace hsolver diff --git a/source/source_hsolver/kernels/cuda/dngvd_op.cu b/source/source_hsolver/kernels/cuda/dngvd_op.cu index 4ce3d9a1d0..db8ff7cbbf 100644 --- a/source/source_hsolver/kernels/cuda/dngvd_op.cu +++ b/source/source_hsolver/kernels/cuda/dngvd_op.cu @@ -211,8 +211,8 @@ struct dngvd_op void operator()(const base_device::DEVICE_GPU* d, const int nstart, const int ldh, - const T* A, // hcc - const T* B, // scc + T* A, // hcc + T* B, // scc Real* W, // eigenvalue T* V) { @@ -231,7 +231,7 @@ struct dnevx_op void operator()(const base_device::DEVICE_GPU* d, const int nstart, const int ldh, - const T* A, // hcc + T* A, // hcc const int m, Real* W, // eigenvalue T* V) diff --git a/source/source_hsolver/kernels/dngvd_op.cpp b/source/source_hsolver/kernels/dngvd_op.cpp index 66cb3c1233..ca98411363 100644 --- a/source/source_hsolver/kernels/dngvd_op.cpp +++ b/source/source_hsolver/kernels/dngvd_op.cpp @@ -14,8 +14,8 @@ struct dngvd_op void operator()(const base_device::DEVICE_CPU* d, const int nstart, const int ldh, - const T* hcc, - const T* scc, + T* hcc, + T* scc, Real* eigenvalue, T* vcc) { @@ -23,62 +23,21 @@ struct dngvd_op { vcc[i] = hcc[i]; } - int info = 0; - int lwork = 2 * nstart + nstart * nstart; - T* work = new T[lwork]; - Parallel_Reduce::ZEROS(work, lwork); - - int lrwork = 1 + 5 * nstart + 2 * nstart * nstart; - Real* rwork = new Real[lrwork]; - Parallel_Reduce::ZEROS(rwork, lrwork); - - int liwork = 3 + 5 * nstart; - int* iwork = new int[liwork]; - Parallel_Reduce::ZEROS(iwork, liwork); //=========================== // calculate all eigenvalues //=========================== - LapackWrapper::xhegvd(1, - 'V', - 'U', - 
nstart, - vcc, - ldh, - scc, - ldh, - eigenvalue, - work, - lwork, - rwork, - lrwork, - iwork, - liwork, - info); - - if (info != 0) - { - std::cout << "Error: xhegvd failed, linear dependent basis functions\n" - << ", wrong initialization of wavefunction, or wavefunction information loss\n" - << ", output overlap matrix scc.txt to check\n" - << std::endl; - // print scc to file scc.txt - std::ofstream ofs("scc.txt"); - for (int i = 0; i < nstart; i++) - { - for (int j = 0; j < nstart; j++) - { - ofs << scc[i * ldh + j] << " "; - } - ofs << std::endl; - } - ofs.close(); - } - assert(0 == info); - - delete[] work; - delete[] rwork; - delete[] iwork; + LapackConnector::hegvd( + LapackConnector::ColMajor, + 1, + 'V', + 'U', + nstart, + vcc, + ldh, + scc, + ldh, + eigenvalue); } }; @@ -89,7 +48,7 @@ struct dngv_op void operator()(const base_device::DEVICE_CPU* d, const int nbase, const int ldh, - const T* hcc, + T* hcc, T* scc, Real* eigenvalue, T* vcc) @@ -99,43 +58,10 @@ struct dngv_op vcc[i] = hcc[i]; } - int info = 0; - - int lwork = 2 * nbase - 1; - T* work = new T[lwork]; - Parallel_Reduce::ZEROS(work, lwork); - - int lrwork = 3 * nbase - 2; - Real* rwork = new Real[lrwork]; - Parallel_Reduce::ZEROS(rwork, lrwork); - //=========================== // calculate all eigenvalues //=========================== - LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info); - - if (info != 0) - { - std::cout << "Error: xhegv failed, linear dependent basis functions\n" - << ", wrong initialization of wavefunction, or wavefunction information loss\n" - << ", output overlap matrix scc.txt to check\n" - << std::endl; - // print scc to file scc.txt - std::ofstream ofs("scc.txt"); - for (int i = 0; i < nbase; i++) - { - for (int j = 0; j < nbase; j++) - { - ofs << scc[i * ldh + j] << " "; - } - ofs << std::endl; - } - ofs.close(); - } - assert(0 == info); - - delete[] work; - delete[] rwork; + 
LapackConnector::hegv(LapackConnector::ColMajor, 1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue); } }; @@ -146,7 +72,7 @@ struct dnevx_op void operator()(const base_device::DEVICE_CPU* /*ctx*/, const int nstart, const int ldh, - const T* hcc, // hcc + T* hcc, // hcc const int nbands, // nbands Real* eigenvalue, // eigenvalue T* vcc) // vcc @@ -156,52 +82,16 @@ struct dnevx_op { aux[ii] = hcc[ii]; } - - int info = 0; - int lwork = -1; - T* work = new T[1]; - Real* rwork = new Real[7 * nstart]; - int* iwork = new int[5 * nstart]; - int* ifail = new int[nstart]; - - // When lwork = -1, the demension of work will be assumed - // Assume the denmension of work by output work[0] - LapackWrapper::xheevx( - 1, // ITYPE = 1: A*x = (lambda)*B*x - 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. - 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. - 'L', // UPLO = 'L': Lower triangles of A and B are stored. - nstart, // N = base - aux, // A is COMPLEX*16 array dimension (LDA, N) - ldh, // LDA = base - 0.0, // Not referenced if RANGE = 'A' or 'I'. - 0.0, // Not referenced if RANGE = 'A' or 'I'. - 1, // IL: If RANGE='I', the index of the smallest eigenvalue to be returned. 1 <= IL <= IU <= N, - nbands, // IU: If RANGE='I', the index of the largest eigenvalue to be returned. 1 <= IL <= IU <= N, - 0.0, // ABSTOL - nbands, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1. - eigenvalue, // W store eigenvalues - vcc, // store eigenvector - ldh, // LDZ: The leading dimension of the array Z. - work, - lwork, - rwork, - iwork, - ifail, - info); - - lwork = int(get_real(work[0])); - delete[] work; - work = new T[lwork]; - + std::vector ifail(nstart); + int m = 0; // The A and B storage space is (nstart * ldh), and the data that really participates in the zhegvx // operation is (nstart * nstart). 
In this function, the data that A and B participate in the operation will // be extracted into the new local variables aux and bux (the internal of the function). // V is the output of the function, the storage space is also (nstart * ldh), and the data size of valid V // obtained by the zhegvx operation is (nstart * nstart) and stored in zux (internal to the function). When // the function is output, the data of zux will be mapped to the corresponding position of V. - LapackWrapper::xheevx( - 1, // ITYPE = 1: A*x = (lambda)*B*x + LapackConnector::heevx( + LapackConnector::ColMajor, // matrix_layout 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. 'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. 'L', // UPLO = 'L': Lower triangles of A and B are stored. @@ -213,24 +103,11 @@ struct dnevx_op 1, // IL: If RANGE='I', the index of the smallest eigenvalue to be returned. 1 <= IL <= IU <= N, nbands, // IU: If RANGE='I', the index of the largest eigenvalue to be returned. 1 <= IL <= IU <= N, 0.0, // ABSTOL - nbands, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1. + &m, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1. eigenvalue, // W store eigenvalues vcc, // store eigenvector ldh, // LDZ: The leading dimension of the array Z. - work, - lwork, - rwork, - iwork, - ifail, - info); - - delete[] aux; - delete[] work; - delete[] rwork; - delete[] iwork; - delete[] ifail; - - assert(0 == info); + ifail.data()); } }; @@ -247,19 +124,12 @@ struct dngvx_op Real* eigenvalue, T* vcc) { - - int info = 0; - int mm = m; - int lwork = -1; - - T* work = new T[1]; - Real* rwork = new Real[7 * nbase]; - int* iwork = new int[5 * nbase]; - int* ifail = new int[nbase]; + std::vector ifail(nbase); - LapackWrapper::xhegvx( + LapackConnector::hegvx( + LapackConnector::ColMajor, // matrix_layout 1, // ITYPE = 1: A*x = (lambda)*B*x 'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors. 
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found. @@ -274,50 +144,11 @@ struct dngvx_op 1, // IL: If RANGE='I', the index of the smallest eigenvalue to be returned. 1 <= IL <= IU <= N, m, // IU: If RANGE='I', the index of the largest eigenvalue to be returned. 1 <= IL <= IU <= N, 0.0, // ABSTOL - mm, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1. + &mm, // M: The total number of eigenvalues found. 0 <= M <= N. if RANGE = 'I', M = IU-IL+1. eigenvalue, // W store eigenvalues vcc, // store eigenvector ldh, // LDZ: The leading dimension of the array Z. - work, - lwork, - rwork, - iwork, - ifail, - info); - - lwork = int(get_real(work[0])); - delete[] work; - work = new T[lwork]; - - LapackWrapper::xhegvx(1, - 'V', - 'I', - 'U', - nbase, - hcc, - ldh, - scc, - ldh, - 0.0, - 0.0, - 1, - m, - 0.0, - mm, - eigenvalue, - vcc, - ldh, - work, - lwork, - rwork, - iwork, - ifail, - info); - - delete[] work; - delete[] rwork; - delete[] iwork; - delete[] ifail; + ifail.data()); } }; diff --git a/source/source_hsolver/kernels/dngvd_op.h b/source/source_hsolver/kernels/dngvd_op.h index c48cd576b5..6a75db1802 100644 --- a/source/source_hsolver/kernels/dngvd_op.h +++ b/source/source_hsolver/kernels/dngvd_op.h @@ -4,7 +4,7 @@ #define MODULE_HSOLVER_DNGVD_H #include "source_base/macros.h" -#include "source_base/module_external/lapack_wrapper.h" +#include "source_base/module_external/lapack_connector.h" #include "source_base/parallel_reduce.h" #include "source_base/module_device/types.h" @@ -43,7 +43,7 @@ struct dngvd_op /// Output Parameter /// @param W : calculated eigenvalues /// @param V : calculated eigenvectors (col major) - void operator()(const Device* d, const int nstart, const int ldh, const T* A, const T* B, Real* W, T* V); + void operator()(const Device* d, const int nstart, const int ldh, T* A, T* B, Real* W, T* V); }; template @@ -103,7 +103,7 @@ struct dnevx_op /// Output Parameter /// @param W : calculated 
eigenvalues /// @param V : calculated eigenvectors (row major) - void operator()(const Device* d, const int nstart, const int ldh, const T* A, const int m, Real* W, T* V); + void operator()(const Device* d, const int nstart, const int ldh, T* A, const int m, Real* W, T* V); }; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM diff --git a/source/source_hsolver/kernels/rocm/dngvd_op.hip.cu b/source/source_hsolver/kernels/rocm/dngvd_op.hip.cu index a359ccda87..ee3aecd323 100644 --- a/source/source_hsolver/kernels/rocm/dngvd_op.hip.cu +++ b/source/source_hsolver/kernels/rocm/dngvd_op.hip.cu @@ -31,8 +31,8 @@ template <> void dngvd_op::operator()(const base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const double* _hcc, - const double* _scc, + double* _hcc, + double* _scc, double* _eigenvalue, double* _vcc) { @@ -105,8 +105,8 @@ template <> void dngvd_op, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const std::complex* _hcc, - const std::complex* _scc, + std::complex* _hcc, + std::complex* _scc, float* _eigenvalue, std::complex* _vcc) { @@ -177,8 +177,8 @@ template <> void dngvd_op, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const std::complex* _hcc, - const std::complex* _scc, + std::complex* _hcc, + std::complex* _scc, double* _eigenvalue, std::complex* _vcc ) @@ -261,7 +261,7 @@ template <> void dnevx_op::operator()(const base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const double* _hcc, + double* _hcc, const int m, double* _eigenvalue, double* _vcc) @@ -281,7 +281,7 @@ template <> void dnevx_op, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const std::complex* _hcc, + std::complex* _hcc, const int m, float* _eigenvalue, std::complex* _vcc) @@ -306,7 +306,7 @@ template <> void dnevx_op, base_device::DEVICE_GPU>::operator()(const 
base_device::DEVICE_GPU* ctx, const int nstart, const int ldh, - const std::complex* _hcc, + std::complex* _hcc, const int m, double* _eigenvalue, std::complex* _vcc) diff --git a/source/source_hsolver/test/diago_bpcg_test.cpp b/source/source_hsolver/test/diago_bpcg_test.cpp index 93e1147ccf..cfbb5b032a 100644 --- a/source/source_hsolver/test/diago_bpcg_test.cpp +++ b/source/source_hsolver/test/diago_bpcg_test.cpp @@ -39,18 +39,12 @@ void lapackEigen(int &npw, std::vector> &hm, double *e, boo { clock_t start, end; start = clock(); - int lwork = 2 * npw; - std::complex *work2 = new std::complex[lwork]; - double *rwork = new double[3 * npw - 2]; - int info = 0; char tmp_c1 = 'V', tmp_c2 = 'U'; - zheev_(&tmp_c1, &tmp_c2, &npw, hm.data(), &npw, e, work2, &lwork, rwork, &info); + LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, hm.data(), npw, e); end = clock(); if (outtime) { std::cout << "Lapack Run time: " << (double)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; } - delete[] rwork; - delete[] work2; } class DiagoBPCGPrepare diff --git a/source/source_hsolver/test/diago_cg_float_test.cpp b/source/source_hsolver/test/diago_cg_float_test.cpp index 60d9a34313..663eca490d 100644 --- a/source/source_hsolver/test/diago_cg_float_test.cpp +++ b/source/source_hsolver/test/diago_cg_float_test.cpp @@ -44,18 +44,12 @@ void lapackEigen(int &npw, std::vector> &hm, float *e, bool { clock_t start, end; start = clock(); - int lwork = 2 * npw; - std::complex *work2 = new std::complex[lwork]; - float *rwork = new float[3 * npw - 2]; - int info = 0; char tmp_c1 = 'V', tmp_c2 = 'U'; - cheev_(&tmp_c1, &tmp_c2, &npw, hm.data(), &npw, e, work2, &lwork, rwork, &info); + LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, hm.data(), npw, e); end = clock(); if (outtime) { std::cout << "Lapack Run time: " << (float)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; } - delete[] rwork; - delete[] work2; } class DiagoCGPrepare diff --git 
a/source/source_hsolver/test/diago_cg_real_test.cpp b/source/source_hsolver/test/diago_cg_real_test.cpp index 48e0793fa2..1eaeb4454a 100644 --- a/source/source_hsolver/test/diago_cg_real_test.cpp +++ b/source/source_hsolver/test/diago_cg_real_test.cpp @@ -41,25 +41,14 @@ // call lapack in order to compare to cg void lapackEigen(int& npw, std::vector& hm, double* e, bool outtime = false) { - int info = 0; auto tmp = hm; clock_t start, end; start = clock(); char tmp_c1 = 'V', tmp_c2 = 'U'; - - double work_tmp; - constexpr int minus_one = -1; - dsyev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, &work_tmp, &minus_one, &info); // get best lwork - - const int lwork = work_tmp; - double* work2 = new double[lwork]; - dsyev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, &info); + LapackConnector::syev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, tmp.data(), npw, e); end = clock(); - if (info) { std::cout << "ERROR: Lapack solver, info=" << info << std::endl; -} if (outtime) { std::cout << "Lapack Run time: " << (double)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; } - delete[] work2; } class DiagoCGPrepare diff --git a/source/source_hsolver/test/diago_cg_test.cpp b/source/source_hsolver/test/diago_cg_test.cpp index 20b115d058..d65b53cfbe 100644 --- a/source/source_hsolver/test/diago_cg_test.cpp +++ b/source/source_hsolver/test/diago_cg_test.cpp @@ -44,18 +44,11 @@ void lapackEigen(int &npw, std::vector> &hm, double *e, boo { clock_t start, end; start = clock(); - int lwork = 2 * npw; - std::complex *work2 = new std::complex[lwork]; - double *rwork = new double[3 * npw - 2]; - int info = 0; - char tmp_c1 = 'V', tmp_c2 = 'U'; - zheev_(&tmp_c1, &tmp_c2, &npw, hm.data(), &npw, e, work2, &lwork, rwork, &info); + LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, hm.data(), npw, e); end = clock(); if (outtime) { std::cout << "Lapack Run time: " << (double)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; } - delete[] rwork; - 
delete[] work2; } class DiagoCGPrepare diff --git a/source/source_hsolver/test/diago_david_float_test.cpp b/source/source_hsolver/test/diago_david_float_test.cpp index f907f939e4..405d07bb85 100644 --- a/source/source_hsolver/test/diago_david_float_test.cpp +++ b/source/source_hsolver/test/diago_david_float_test.cpp @@ -34,25 +34,15 @@ //NOTE: after finish this function, hm stores the eigen vectors. void lapackEigen(int &npw, std::vector> &hm, float * e, bool outtime=false) { - int lwork = 2 * npw; - std::complex *work2= new std::complex[lwork]; - float* rwork = new float[3*npw-2]; - int info = 0; - auto tmp = hm; clock_t start,end; start = clock(); char tmp_c1 = 'V', tmp_c2 = 'U'; - cheev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, rwork, &info); + LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, tmp.data(), npw, e); end = clock(); - if(info) { std::cout << "ERROR: Lapack solver, info=" << info <& hm, double* e, bool outtime = false) { - int info = 0; auto tmp = hm; clock_t start, end; start = clock(); char tmp_c1 = 'V', tmp_c2 = 'U'; - - double work_tmp; - constexpr int minus_one = -1; - dsyev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, &work_tmp, &minus_one, &info); // get best lwork - - const int lwork = work_tmp; - double* work2 = new double[lwork]; - dsyev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, &info); + LapackConnector::syev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, tmp.data(), npw, e); end = clock(); - if (info) { std::cout << "ERROR: Lapack solver, info=" << info << std::endl; -} if (outtime) { std::cout << "Lapack Run time: " << (double)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; } - delete[] work2; } class DiagoDavPrepare { diff --git a/source/source_hsolver/test/diago_david_test.cpp b/source/source_hsolver/test/diago_david_test.cpp index 643eeed4bf..a7189cda8d 100644 --- a/source/source_hsolver/test/diago_david_test.cpp +++ b/source/source_hsolver/test/diago_david_test.cpp @@ 
-36,25 +36,13 @@ void lapackEigen(int& npw, std::vector>& hm, double* e, bool outtime = false) { - int lwork = 2 * npw; - std::complex *work2= new std::complex[lwork]; - double* rwork = new double[3*npw-2]; - int info = 0; - - auto tmp = hm; - clock_t start,end; start = clock(); char tmp_c1 = 'V', tmp_c2 = 'U'; - zheev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, rwork, &info); + LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, tmp.data(), npw, e); end = clock(); - if(info) { std::cout << "ERROR: Lapack solver, info=" << info < ev(nFull * nFull); - double *a = new double[nFull * nFull]; - double *b = new double[nFull * nFull]; + std::vector a(nFull * nFull); + std::vector b(nFull * nFull); for (int i = 0; i < nFull * nFull; i++) { a[i] = hmatrix[i]; b[i] = smatrix[i]; } - dsygv_(&itype, &jobz, &uplo, &nFull, a, &nFull, b, &nFull, e, ev, &lwork, &info); - if (info != 0) - { - std::cout << "ERROR: solvered by LAPACK error, info=" << info << std::endl; - exit(1); - } - - delete[] a; - delete[] b; - delete[] ev; + LapackConnector::sygv(LapackConnector::ColMajor, itypr, jobz, uplo, nFull, a.data(), nFull, b.data(), nFull, e, ev.data()); } void lapack_diago(std::complex *hmatrix, std::complex *smatrix, double *e, int &nFull) @@ -194,29 +184,17 @@ void lapack_diago(std::complex *hmatrix, std::complex *smatrix, const int itype = 1; // solve A*X=(lambda)*B*X const char jobz = 'V'; // 'N':only calc eigenvalue, 'V': eigenvalues and eigenvectors const char uplo = 'U'; // Upper triangles - int lwork = (nFull + 1) * nFull, info = 0; - double *rwork = new double[3 * nFull - 2]; - std::complex *ev = new std::complex[nFull * nFull]; + std::vector> ev(nFull * nFull); - std::complex *a = new std::complex[nFull * nFull]; - std::complex *b = new std::complex[nFull * nFull]; + std::vector> a(nFull * nFull); + std::vector> b(nFull * nFull); for (int i = 0; i < nFull * nFull; i++) { a[i] = hmatrix[i]; b[i] = smatrix[i]; } - zhegv_(&itype, &jobz, &uplo, 
&nFull, a, &nFull, b, &nFull, e, ev, &lwork, rwork, &info); - if (info != 0) - { - std::cout << "ERROR: solvered by LAPACK error, info=" << info << std::endl; - exit(1); - } - - delete[] a; - delete[] b; - delete[] ev; - delete[] rwork; + LapackConnector::hegv(LapackConnector::ColMajor, itype, jobz, uplo, nFull, a.data(), nFull, b.data(), nFull, e, ev.data()); } } // namespace LCAO_DIAGO_TEST diff --git a/source/source_io/berryphase.cpp b/source/source_io/berryphase.cpp index 8d31edec91..f48ac6deef 100644 --- a/source/source_io/berryphase.cpp +++ b/source/source_io/berryphase.cpp @@ -313,7 +313,7 @@ double berryphase::stringPhase(const UnitCell& ucell, std::complex det(1.0, 0.0); int info = 0; std::vector ipiv(nbands); - LapackConnector::zgetrf(nbands, nbands, mat, nbands, ipiv.data(), &info); + LapackConnector::getrf(LapackConnector::RowMajor, nbands, nbands, mat.c, nbands, ipiv.data()); for (int ib = 0; ib < nbands; ib++) { if (ipiv[ib] != (ib + 1)) { diff --git a/source/source_io/write_vxc.hpp b/source/source_io/write_vxc.hpp index 43fd803bb7..cfc9c3643e 100644 --- a/source/source_io/write_vxc.hpp +++ b/source/source_io/write_vxc.hpp @@ -2,7 +2,7 @@ #define __WRITE_VXC_H_ #include "source_io/module_parameter/parameter.h" #include "source_base/parallel_reduce.h" -#include "source_base/module_container/base/third_party/blas.h" +#include "source_base/module_external/blas_connector.h" #include "source_base/module_external/scalapack_connector.h" #include "source_lcao/module_operator_lcao/op_dftu_lcao.h" #include "source_lcao/module_operator_lcao/veff_lcao.h" @@ -62,7 +62,7 @@ inline std::vector cVc(T* V, c, i1, i1, pv.desc_wfc, beta, Vc.data(), i1, i1, pv.desc_wfc); #else - container::BlasConnector::gemm(transa, transb, nbasis, nbands, nbasis, alpha, V, nbasis, c, nbasis, beta, Vc.data(), nbasis); + BlasConnector::gemm_cm(transa, transb, nbasis, nbands, nbasis, alpha, V, nbasis, c, nbasis, beta, Vc.data(), nbasis); #endif std::vector cVc(p2d.nloc, 0.0); transa = 
(std::is_same::value ? 'T' : 'C'); @@ -73,7 +73,7 @@ inline std::vector cVc(T* V, Vc.data(), i1, i1, pv.desc_wfc, beta, cVc.data(), i1, i1, p2d.desc); #else - container::BlasConnector::gemm(transa, transb, nbands, nbands, nbasis, alpha, c, nbasis, Vc.data(), nbasis, beta, cVc.data(), nbasis); + BlasConnector::gemm_cm(transa, transb, nbands, nbands, nbasis, alpha, c, nbasis, Vc.data(), nbasis, beta, cVc.data(), nbasis); #endif return cVc; } diff --git a/source/source_io/write_vxc_lip.hpp b/source/source_io/write_vxc_lip.hpp index 3705993022..6e4916b68e 100644 --- a/source/source_io/write_vxc_lip.hpp +++ b/source/source_io/write_vxc_lip.hpp @@ -2,7 +2,7 @@ #define __WRITE_VXC_LIP_H_ #include "source_io/module_parameter/parameter.h" #include "source_base/parallel_reduce.h" -#include "source_base/module_container/base/third_party/blas.h" +#include "source_base/blas_connector.h" #include "source_pw/module_pwdft/operator_pw/veff_pw.h" #include "source_psi/psi.h" #include "source_cell/unitcell.h" @@ -36,12 +36,12 @@ namespace ModuleIO char transb = 'N'; const T alpha(1.0, 0.0); const T beta(0.0, 0.0); - container::BlasConnector::gemm(transa, transb, nbasis, nbands, nbasis, + BlasConnector::gemm_cm(transa, transb, nbasis, nbands, nbasis, alpha, V, nbasis, c, nbasis, beta, Vc.data(), nbasis); std::vector cVc(nbands * nbands, 0.0); transa = ((std::is_same::value || std::is_same::value) ? 
'T' : 'C'); - container::BlasConnector::gemm(transa, transb, nbands, nbands, nbasis, + BlasConnector::gemm_cm(transa, transb, nbands, nbands, nbasis, alpha, c, nbasis, Vc.data(), nbasis, beta, cVc.data(), nbands); return cVc; } @@ -54,7 +54,7 @@ namespace ModuleIO std::vector cVc(nbands * nbands, (T)0.0); const T alpha(1.0, 0.0); const T beta(0.0, 0.0); - container::BlasConnector::gemm('C', 'N', nbands, nbands, nbasis, alpha, + BlasConnector::gemm_cm('C', 'N', nbands, nbands, nbasis, alpha, psi, nbasis, hpsi, nbasis, beta, cVc.data(), nbands); return cVc; } @@ -204,7 +204,7 @@ namespace ModuleIO // std::cout << "exx_energy from orbitals: " << all_band_energy(ik, e_orb_exx.at(ik), wg) << std::endl; // std::cout << "exx_energy from exx_lip: " << GlobalC::exx_info.info_global.hybrid_alpha * exx_lip.get_exx_energy() << std::endl; // ======test======= - container::BlasConnector::axpy(nbands * nbands, 1.0, vexx_k_mo.data(), 1, vxc_tot_k_mo.data(), 1); + BlasConnector::axpy(nbands * nbands, 1.0, vexx_k_mo.data(), 1, vxc_tot_k_mo.data(), 1); } #endif diff --git a/source/source_lcao/module_deepks/deepks_orbpre.cpp b/source/source_lcao/module_deepks/deepks_orbpre.cpp index 90cb40ea95..462565ed18 100644 --- a/source/source_lcao/module_deepks/deepks_orbpre.cpp +++ b/source/source_lcao/module_deepks/deepks_orbpre.cpp @@ -215,19 +215,20 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector& dm_hl, gemm_alpha = 2.0; } - dgemm_(&transa, - &transb, - &row_size_nks, - &trace_alpha_size, - &col_size, - &gemm_alpha, + BlasConnector::gemm_cm( + transa, + transb, + row_size_nks, + trace_alpha_size, + col_size, + gemm_alpha, dm_array.data(), - &col_size, + col_size, s_2t.data(), - &col_size, - &gemm_beta, + col_size, + gemm_beta, g_1dmt.data(), - &row_size_nks); + row_size_nks); } // ad2 for (int ik = 0; ik < nks; ik++) diff --git a/source/source_lcao/module_deepks/deepks_pdm.cpp b/source/source_lcao/module_deepks/deepks_pdm.cpp index ffc0efc8fb..9fdc602ef5 100644 --- 
a/source/source_lcao/module_deepks/deepks_pdm.cpp +++ b/source/source_lcao/module_deepks/deepks_pdm.cpp @@ -378,19 +378,20 @@ void DeePKS_domain::cal_pdm(bool& init_pdm, // all the input should be data pointer constexpr char transa = 'T', transb = 'N'; const double gemm_alpha = 1.0, gemm_beta = 1.0; - dgemm_(&transa, - &transb, - &row_size, - &trace_alpha_size, - &col_size, - &gemm_alpha, + BlasConnector::gemm_cm( + transa, + transb, + row_size, + trace_alpha_size, + col_size, + gemm_alpha, dm_current, - &col_size, + col_size, s_2t.data(), - &col_size, - &gemm_beta, + col_size, + gemm_beta, g_1dmt.data(), - &row_size); + row_size); } // ad2 if (!PARAM.inp.deepks_equiv) { diff --git a/source/source_lcao/module_gint/gint_rho_old.cpp b/source/source_lcao/module_gint/gint_rho_old.cpp index b3027d6b12..c05d681841 100644 --- a/source/source_lcao/module_gint/gint_rho_old.cpp +++ b/source/source_lcao/module_gint/gint_rho_old.cpp @@ -20,7 +20,7 @@ void Gint::cal_meshball_rho(const int na_grid, // sum over mu to get density on grid for (int ib = 0; ib < this->bxyz; ++ib) { - const double r = ddot_(&block_index[na_grid], psir_ylm[ib], &inc, psir_DMR[ib], &inc); + const double r = BlasConnector::dot(block_index[na_grid], psir_ylm[ib], inc, psir_DMR[ib], inc); const int grid = vindex[ib]; rho[grid] += r; } diff --git a/source/source_lcao/module_gint/gint_tau_old.cpp b/source/source_lcao/module_gint/gint_tau_old.cpp index adf20d45b5..69a221c9e6 100644 --- a/source/source_lcao/module_gint/gint_tau_old.cpp +++ b/source/source_lcao/module_gint/gint_tau_old.cpp @@ -29,9 +29,9 @@ void Gint::cal_meshball_tau( // sum over mu to get density on grid for(int ib=0; ibbxyz; ++ib) { - double rx=ddot_(&block_index[na_grid], dpsix[ib], &inc, dpsix_dm[ib], &inc); - double ry=ddot_(&block_index[na_grid], dpsiy[ib], &inc, dpsiy_dm[ib], &inc); - double rz=ddot_(&block_index[na_grid], dpsiz[ib], &inc, dpsiz_dm[ib], &inc); + double rx=BlasConnector::dot(block_index[na_grid], dpsix[ib], inc, 
dpsix_dm[ib], inc); + double ry=BlasConnector::dot(block_index[na_grid], dpsiy[ib], inc, dpsiy_dm[ib], inc); + double rz=BlasConnector::dot(block_index[na_grid], dpsiz[ib], inc, dpsiz_dm[ib], inc); const int grid = vindex[ib]; rho[ grid ] += rx + ry + rz; } diff --git a/source/source_lcao/module_gint/gint_vl_old.cpp b/source/source_lcao/module_gint/gint_vl_old.cpp index 9ebc341d7f..25824b1866 100644 --- a/source/source_lcao/module_gint/gint_vl_old.cpp +++ b/source/source_lcao/module_gint/gint_vl_old.cpp @@ -83,10 +83,10 @@ void Gint::cal_meshball_vlocal( hr_tmp.resize(m * n); ModuleBase::GlobalFunc::ZEROS(hr_tmp.data(), m*n); - dgemm_(&transa, &transb, &n, &m, &ib_length, &alpha, - &psir_vlbr3[first_ib][block_index[ia2]], &LD_pool, - &psir_ylm[first_ib][block_index[ia1]], &LD_pool, - &beta, hr_tmp.data(), &n); + BlasConnector::gemm('T', 'N', m, n, ib_length, 1, + &psir_ylm[first_ib][block_index[ia1]], LD_pool, + &psir_vlbr3[first_ib][block_index[ia2]], LD_pool, + 1, hr_tmp.data(), n); tmp_matrix->add_array_ts(hr_tmp.data()); } } diff --git a/source/source_lcao/module_gint/mult_psi_dmr.cpp b/source/source_lcao/module_gint/mult_psi_dmr.cpp index fab47c1aee..8144d9ffc6 100644 --- a/source/source_lcao/module_gint/mult_psi_dmr.cpp +++ b/source/source_lcao/module_gint/mult_psi_dmr.cpp @@ -95,8 +95,8 @@ void mult_psi_DMR( const int idx1 = block_index[ia1]; const int idx2 = block_index[ia2]; - dgemm_(&trans, &trans, &block_size[ia2], &ib_len, &block_size[ia1], &alpha1, tmp_matrix_ptr, &block_size[ia2], - &psi[ib_start][idx1], &LD_pool, &beta, &psi_DMR[ib_start][idx2], &LD_pool); + BlasConnector::gemm_cm(trans, trans, block_size[ia2], ib_len, block_size[ia1], alpha1, tmp_matrix_ptr, block_size[ia2], + &psi[ib_start][idx1], LD_pool, beta, &psi_DMR[ib_start][idx2], LD_pool); } // ia2 } // ia1 diff --git a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp index 1ddec7f8da..cadaa29893 
100644 --- a/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp +++ b/source/source_lcao/module_lr/ao_to_mo_transformer/ao_to_mo_serial.cpp @@ -110,15 +110,15 @@ namespace LR char transb = 'N'; //coeff is col major const double alpha = 1.0; const double beta = add_on ? 1.0 : 0.0; - dgemm_(&transa, &transb, &naos, &nmo1, &naos, &alpha, - mat_ao[isk].data(), &naos, coeff.get_pointer(imo1), &naos, &beta, - Vc.data(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, nmo1, naos, alpha, + mat_ao[isk].data(), naos, coeff.get_pointer(imo1), naos, beta, + Vc.data(), naos); transa = 'T'; //mat_mo=coeff^TVc (nvirt major) - dgemm_(&transa, &transb, &nmo2, &nmo1, &naos, &alpha, - coeff.get_pointer(imo2), &naos, Vc.data(), &naos, &beta, - mat_mo + start, &nmo2); + BlasConnector::gemm_cm(transa, transb, nmo2, nmo1, naos, alpha, + coeff.get_pointer(imo2), naos, Vc.data(), naos, beta, + mat_mo + start, nmo2); } } template<> @@ -151,15 +151,15 @@ namespace LR char transb = 'N'; //coeff is col major const std::complex alpha(1.0, 0.0); const std::complex beta = add_on ? 
std::complex(1.0, 0.0) : std::complex(0.0, 0.0); - zgemm_(&transa, &transb, &naos, &nmo1, &naos, &alpha, - mat_ao[isk].data>(), &naos, coeff.get_pointer(imo1), &naos, &beta, - Vc.data>(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, nmo1, naos, alpha, + mat_ao[isk].data>(), naos, coeff.get_pointer(imo1), naos, beta, + Vc.data>(), naos); transa = 'C'; //mat_mo=coeff^\dagger Vc (nvirt major) - zgemm_(&transa, &transb, &nmo2, &nmo1, &naos, &alpha, - coeff.get_pointer(imo2), &naos, Vc.data>(), &naos, &beta, - mat_mo + start, &nmo2); + BlasConnector::gemm_cm(transa, transb, nmo2, nmo1, naos, alpha, + coeff.get_pointer(imo2), naos, Vc.data>(), naos, beta, + mat_mo + start, nmo2); } } } \ No newline at end of file diff --git a/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp b/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp index 7509760345..d24c8a617b 100644 --- a/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp +++ b/source/source_lcao/module_lr/dm_trans/dm_trans_serial.cpp @@ -112,13 +112,13 @@ namespace LR const double alpha = 1.0; const double beta = 0.0; container::Tensor Xc(DAT::DT_DOUBLE, DEV::CpuDevice, { nmo2, naos }); - dgemm_(&transa, &transb, &naos, &nmo2, &nmo1, &alpha, - c.get_pointer(imo1), &naos, X_istate + x_start, &nmo2, - &beta, Xc.data(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, nmo2, nmo1, alpha, + c.get_pointer(imo1), naos, X_istate + x_start, nmo2, + beta, Xc.data(), naos); // 2. C_virt*[X*C_occ^T] - dgemm_(&transa, &transb, &naos, &naos, &nmo2, &factor, - c.get_pointer(imo2), &naos, Xc.data(), &naos, &beta, - dm_trans[isk].data(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, naos, nmo2, factor, + c.get_pointer(imo2), naos, Xc.data(), naos, beta, + dm_trans[isk].data(), naos); } return dm_trans; } @@ -166,14 +166,14 @@ namespace LR // ============== = [C_occ^* * X^T * C_virt^T]^T============= // 1. 
X*C_occ^\dagger container::Tensor Xc(DAT::DT_COMPLEX_DOUBLE, DEV::CpuDevice, { naos, nmo2 }); - zgemm_(&transa, &transb, &nmo2, &naos, &nmo1, &alpha, - X_istate + x_start, &nmo2, c.get_pointer(imo1), &naos, - &beta, Xc.data>(), &nmo2); + BlasConnector::gemm_cm(transa, transb, nmo2, naos, nmo1, alpha, + X_istate + x_start, nmo2, c.get_pointer(imo1), naos, + beta, Xc.data>(), nmo2); // 2. [X*C_occ^\dagger]^TC_virt^T transa = transb = 'T'; - zgemm_(&transa, &transb, &naos, &naos, &nmo2, &factor, - Xc.data>(), &nmo2, c.get_pointer(imo2), &naos, &beta, - dm_trans[isk].data>(), &naos); + BlasConnector::gemm_cm(transa, transb, naos, naos, nmo2, factor, + Xc.data>(), nmo2, c.get_pointer(imo2), naos, beta, + dm_trans[isk].data>(), naos); } return dm_trans; } diff --git a/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp b/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp index 501e3cf5f3..c1b2a7c347 100644 --- a/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp +++ b/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp @@ -1,6 +1,6 @@ #pragma once #include "ri_benchmark.h" -#include "source_base/module_container/base/third_party/blas.h" +#include "source_base/blas_connector.h" namespace RI_Benchmark { // std::cout << "the size of Cs:" << std::endl; @@ -91,13 +91,13 @@ namespace RI_Benchmark // caution: Cs are row-major (ia2 contiguous) if (occ_first) { - container::BlasConnector::gemm('T', 'N', nvirt, nw1, nw2, 1.0, psi_a2.data(), nw2, ptr, nw2, 0.0, tmp.data(), nvirt); - container::BlasConnector::gemm('N', 'N', nvirt, nocc, nw1, 1.0, tmp.data(), nvirt, psi_a1.data(), nw1, 0.0, &Cs_mo[c1.first][c2.first](iabf, 0, 0), nvirt); + BlasConnector::gemm_cm('T', 'N', nvirt, nw1, nw2, 1.0, psi_a2.data(), nw2, ptr, nw2, 0.0, tmp.data(), nvirt); + BlasConnector::gemm_cm('N', 'N', nvirt, nocc, nw1, 1.0, tmp.data(), nvirt, psi_a1.data(), nw1, 0.0, &Cs_mo[c1.first][c2.first](iabf, 0, 0), nvirt); } else { - container::BlasConnector::gemm('T', 'N', nw1, 
nocc, nw2, 1.0, ptr, nw2, psi_a2.data(), nw2, 0.0, tmp.data(), nw1); - container::BlasConnector::gemm('T', 'N', nvirt, nocc, nw1, 1.0, psi_a1.data(), nw1, tmp.data(), nw1, 0.0, &Cs_mo[c1.first][c2.first](iabf, 0, 0), nvirt); + BlasConnector::gemm_cm('T', 'N', nw1, nocc, nw2, 1.0, ptr, nw2, psi_a2.data(), nw2, 0.0, tmp.data(), nw1); + BlasConnector::gemm_cm('T', 'N', nvirt, nocc, nw1, 1.0, psi_a1.data(), nw1, tmp.data(), nw1, 0.0, &Cs_mo[c1.first][c2.first](iabf, 0, 0), nvirt); } } iw2 += nw2; @@ -144,8 +144,8 @@ namespace RI_Benchmark assert(tensor_ca.shape[0] == nabf1); //abf1 assert(tensor_cb.shape[0] == nabf2); //abf2 std::vector tmp(npairs * nabf1); - container::BlasConnector::gemm('T', 'T', nabf1, npairs, nabf2, 1.0, tensor_v.ptr(), nabf2, tensor_cb.ptr(), npairs, 0.0, tmp.data(), nabf1); - container::BlasConnector::gemm('N', 'N', npairs, npairs, nabf1, 2.0/*Hartree to Ry*/, tensor_ca.ptr(), npairs, tmp.data(), nabf1, 1.0, Amat_full.data(), npairs); + BlasConnector::gemm_cm('T', 'T', nabf1, npairs, nabf2, 1.0, tensor_v.ptr(), nabf2, tensor_cb.ptr(), npairs, 0.0, tmp.data(), nabf1); + BlasConnector::gemm_cm('N', 'N', npairs, npairs, nabf1, 2.0/*Hartree to Ry*/, tensor_ca.ptr(), npairs, tmp.data(), nabf1, 1.0, Amat_full.data(), npairs); } } } @@ -167,7 +167,7 @@ namespace RI_Benchmark const int& npairs = tensor_c.shape[1] * tensor_c.shape[2]; std::vector CX(nabf); for (int iabf = 0;iabf < nabf;++iabf) - CX[iabf] = container::BlasConnector::dot(npairs, &tensor_c(iabf, 0, 0), 1, X, 1); + CX[iabf] = BlasConnector::dot(npairs, &tensor_c(iabf, 0, 0), 1, X, 1); CsX[iat1][it2.first] = CX; } } @@ -201,12 +201,12 @@ namespace RI_Benchmark if (CV.count(iat2) && CV.at(iat2).count({ iat3, {0, 0, 0} })) // add-up, sum over iat1 { auto& tensor_cv = CV.at(iat2).at({ iat3, {0, 0, 0} }); - container::BlasConnector::gemm('N', 'T', npairs, nabf2, nabf1, 1.0, tensor_c.ptr(), npairs, tensor_v.ptr(), nabf2, 1.0, tensor_cv.ptr(), npairs); + BlasConnector::gemm_cm('N', 'T', npairs, 
nabf2, nabf1, 1.0, tensor_c.ptr(), npairs, tensor_v.ptr(), nabf2, 1.0, tensor_cv.ptr(), npairs); } else { RI::Tensor tmp({ nabf2, tensor_c.shape[1], tensor_c.shape[2] }); // (nabf2, nocc, nvirt) - container::BlasConnector::gemm('N', 'T', npairs, nabf2, nabf1, 1.0, tensor_c.ptr(), npairs, tensor_v.ptr(), nabf2, 0.0, tmp.ptr(), npairs); + BlasConnector::gemm_cm('N', 'T', npairs, nabf2, nabf1, 1.0, tensor_c.ptr(), npairs, tensor_v.ptr(), nabf2, 0.0, tmp.ptr(), npairs); CV[iat2][{iat3, { 0, 0, 0 }}] = tmp; } } @@ -242,8 +242,8 @@ namespace RI_Benchmark assert(tensor_ca.shape[0] == nabf1); //abf1 assert(vector_cb.size() == nabf2); //abf2 std::vector tmp(nabf1); - container::BlasConnector::gemv('T', nabf1, nabf2, 1.0, tensor_v.ptr(), nabf2, vector_cb.data(), 1, 0.0, tmp.data(), 1); - container::BlasConnector::gemv('N', npairs, nabf1, scale/*Hartree to Ry; singlet*/, tensor_ca.ptr(), npairs, tmp.data(), 1, 1.0, AX, 1); + BlasConnector::gemv_cm('T', nabf1, nabf2, 1.0, tensor_v.ptr(), nabf2, vector_cb.data(), 1, 0.0, tmp.data(), 1); + BlasConnector::gemv_cm('N', npairs, nabf1, scale/*Hartree to Ry; singlet*/, tensor_ca.ptr(), npairs, tmp.data(), 1, 1.0, AX, 1); } } } @@ -270,7 +270,7 @@ namespace RI_Benchmark const auto& vector_cx = it3.second; // (nabf) const int& nabf = tensor_cv.shape[0]; assert(vector_cx.size() == nabf); //abf on at2 - container::BlasConnector::gemv('N', npairs, nabf, scale/*Hartree to Ry; singlet*/, tensor_cv.ptr(), npairs, vector_cx.data(), 1, 1.0, AX, 1); + BlasConnector::gemv_cm('N', npairs, nabf, scale/*Hartree to Ry; singlet*/, tensor_cv.ptr(), npairs, vector_cx.data(), 1, 1.0, AX, 1); } } } diff --git a/source/source_lcao/module_lr/utils/lr_util.cpp b/source/source_lcao/module_lr/utils/lr_util.cpp index ae2f00ffa3..9c15f2057a 100644 --- a/source/source_lcao/module_lr/utils/lr_util.cpp +++ b/source/source_lcao/module_lr/utils/lr_util.cpp @@ -118,66 +118,38 @@ namespace LR_Util void diag_lapack(const int& n, double* mat, double* eig) { 
ModuleBase::TITLE("LR_Util", "diag_lapack"); - int info = 0; char jobz = 'V', uplo = 'U'; - double work_tmp; - const int minus_one = -1; - dsyev_(&jobz, &uplo, &n, mat, &n, eig, &work_tmp, &minus_one, &info); // get best lwork - const int lwork = work_tmp; - double* work2 = new double[lwork]; - dsyev_(&jobz, &uplo, &n, mat, &n, eig, work2, &lwork, &info); - if (info) { std::cout << "ERROR: Lapack solver, info=" << info << std::endl; } - delete[] work2; + LapackConnector::syev(LapackConnector::ColMajor, jobz, uplo, n, mat, n, eig); } void diag_lapack(const int& n, std::complex* mat, double* eig) { ModuleBase::TITLE("LR_Util", "diag_lapack >"); - int lwork = 2 * n; - std::complex* work2 = new std::complex[lwork]; - double* rwork = new double[3 * n - 2]; - int info = 0; char jobz = 'V', uplo = 'U'; - zheev_(&jobz, &uplo, &n, mat, &n, eig, work2, &lwork, rwork, &info); - if (info) { std::cout << "ERROR: Lapack solver, info=" << info << std::endl; } - delete[] rwork; - delete[] work2; + LapackConnector::heev(LapackConnector::ColMajor, jobz, uplo, n, mat, n, eig); } void diag_lapack_nh(const int& n, double* mat, std::complex* eig) { ModuleBase::TITLE("LR_Util", "diag_lapack_nh"); - int info = 0; char jobvl = 'N', jobvr = 'V'; //calculate right eigenvectors - double work_tmp; - const int minus_one = -1; std::vector eig_real(n); std::vector eig_imag(n); const int ldvl = 1, ldvr = n; std::vector vl(ldvl * n), vr(ldvr * n); - dgeev_(&jobvl, &jobvr, &n, mat, &n, eig_real.data(), eig_imag.data(), - vl.data(), &ldvl, vr.data(), &ldvr, &work_tmp, &minus_one /*lwork*/, &info); // get best lwork - const int lwork = work_tmp; - std::vector work2(lwork); - dgeev_(&jobvl, &jobvr, &n, mat, &n, eig_real.data(), eig_imag.data(), - vl.data(), &ldvl, vr.data(), &ldvr, work2.data(), &lwork, &info); - if (info) { std::cout << "ERROR: Lapack solver dgeev, info=" << info << std::endl; } + LapackConnector::geev(LapackConnector::ColMajor, jobvl, jobvr, n, mat, n, eig_real.data(), 
eig_imag.data(), + vl.data(), ldvl, vr.data(), ldvr); for (int i = 0;i < n;++i) { eig[i] = std::complex(eig_real[i], eig_imag[i]); } } void diag_lapack_nh(const int& n, std::complex* mat, std::complex* eig) { ModuleBase::TITLE("LR_Util", "diag_lapack_nh >"); - int lwork = 2 * n; - std::vector> work2(lwork); - std::vector rwork(3 * n - 2); - int info = 0; char jobvl = 'N', jobvr = 'V'; const int ldvl = 1, ldvr = n; std::vector> vl(ldvl * n), vr(ldvr * n); - zgeev_(&jobvl, &jobvr, &n, mat, &n, eig, - vl.data(), &ldvl, vr.data(), &ldvr, work2.data(), &lwork, rwork.data(), &info); - if (info) { std::cout << "ERROR: Lapack solver zgeev, info=" << info << std::endl; } + LapackConnector::geev(LapackConnector::ColMajor, jobvl, jobvr, n, mat, n, eig, + vl.data(), ldvl, vr.data(), ldvr); } std::string tolower(const std::string& str) diff --git a/source/source_lcao/module_operator_lcao/deepks_lcao.cpp b/source/source_lcao/module_operator_lcao/deepks_lcao.cpp index 5e75cc9b40..17c7043c4e 100644 --- a/source/source_lcao/module_operator_lcao/deepks_lcao.cpp +++ b/source/source_lcao/module_operator_lcao/deepks_lcao.cpp @@ -364,19 +364,20 @@ void hamilt::DeePKS>::calculate_HR() constexpr char transa = 'T', transb = 'N'; const double gemm_alpha = 1.0, gemm_beta = 1.0; - dgemm_(&transa, - &transb, - &col_size, - &row_size, - &trace_alpha_size, - &gemm_alpha, + BlasConnector::gemm_cm( + transa, + transb, + col_size, + row_size, + trace_alpha_size, + gemm_alpha, s_2t.data(), - &trace_alpha_size, + trace_alpha_size, s_1t.data(), - &trace_alpha_size, - &gemm_beta, + trace_alpha_size, + gemm_beta, hr_current.data(), - &col_size); + col_size); // add data of HR to target BaseMatrix #pragma omp critical diff --git a/source/source_lcao/module_ri/ABFs_Construct-PCA.cpp b/source/source_lcao/module_ri/ABFs_Construct-PCA.cpp index 9c944f77db..f75e30821d 100644 --- a/source/source_lcao/module_ri/ABFs_Construct-PCA.cpp +++ b/source/source_lcao/module_ri/ABFs_Construct-PCA.cpp @@ -22,14 +22,7 @@ 
namespace PCA assert(a.shape[0] == a.shape[1]); const int nr = a.shape[0]; const int nc = a.shape[1]; - - double work_tmp=0.0; - constexpr int minus_one = -1; - dsyev_(&jobz, &uplo, &nr, a.ptr(), &nc, w, &work_tmp, &minus_one, &info); // get best lwork - - const int lwork = work_tmp; - std::vector work(std::max(1, lwork)); - dsyev_(&jobz, &uplo, &nr, a.ptr(), &nc, w, work.data(), &lwork, &info); + LapackConnector::syev(LapackConnector::ColMajor, jobz, uplo, nr, a.ptr(), nc); } RI::Tensor get_sub_matrix( diff --git a/source/source_lcao/module_ri/Inverse_Matrix.hpp b/source/source_lcao/module_ri/Inverse_Matrix.hpp index 20ef034239..e4f729f30b 100644 --- a/source/source_lcao/module_ri/Inverse_Matrix.hpp +++ b/source/source_lcao/module_ri/Inverse_Matrix.hpp @@ -24,15 +24,8 @@ void Inverse_Matrix::cal_inverse( const Method &method ) template void Inverse_Matrix::using_potrf() { - int info; - LapackConnector::potrf('U', A.shape[0], A.ptr(), A.shape[0], info); - if(info) - throw std::range_error("info="+std::to_string(info)+"\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__)); - - LapackConnector::potri('U', A.shape[0], A.ptr(), A.shape[0], info); - if(info) - throw std::range_error("info="+std::to_string(info)+"\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__)); - + LapackConnector::potrf(LapackConnector::RowMajor, 'U', A.shape[0], A.ptr(), A.shape[0]); + LapackConnector::potri(LapackConnector::RowMajor, 'U', A.shape[0], A.ptr(), A.shape[0]); copy_down_triangle(); } diff --git a/source/source_lcao/module_ri/exx_lip.hpp b/source/source_lcao/module_ri/exx_lip.hpp index aad7b3494c..c1e1343023 100644 --- a/source/source_lcao/module_ri/exx_lip.hpp +++ b/source/source_lcao/module_ri/exx_lip.hpp @@ -408,7 +408,7 @@ void Exx_Lip::b_sum(const int iq, const int ib) // Peize Lin change { ModuleBase::timer::tick("Exx_Lip", "b_sum"); // this->sum1[iw_l,iw_r] += \sum_{ig} this->b[iw_l,ig] * conj(this->b[iw_r,ig]) * this->q_pack->wf_wg(iq,ib) - LapackConnector::herk( 
+ BlasConnector::herk( 'U','N', PARAM.globalv.nlocal, this->rho_basis->npw, (Treal)this->q_pack->wf_wg(iq, ib), this->b.data(), this->rho_basis->npw, diff --git a/source/source_pw/module_pwdft/VNL_in_pw.cpp b/source/source_pw/module_pwdft/VNL_in_pw.cpp index e26e09fa3e..3547bb889b 100644 --- a/source/source_pw/module_pwdft/VNL_in_pw.cpp +++ b/source/source_pw/module_pwdft/VNL_in_pw.cpp @@ -1468,24 +1468,25 @@ void pseudopot_cell_vnl::newq(const ModuleBase::matrix& veff, const ModulePW::PW double* qg_ptr = reinterpret_cast(qg.c); double* aux_ptr = reinterpret_cast(aux.c); - dgemm_(&transa, - &transb, - &nij, - &natom, - &complex_npw, - &fact, + BlasConnector::gemm_cm( + transa, + transb, + nij, + natom, + complex_npw, + fact, qg_ptr, - &complex_npw, + complex_npw, aux_ptr, - &complex_npw, - &zero, + complex_npw, + zero, deeaux.c, - &nij); + nij); // I'm not sure if this is correct for gamma_only if (rho_basis->gamma_only && rho_basis->ig_gge0 >= 0) { const double neg = -1.0; - dger_(&nij, &natom, &neg, qg_ptr, &complex_npw, aux_ptr, &complex_npw, deeaux.c, &nij); + BlasConnector::ger_cm(nij, natom, neg, qg_ptr, complex_npw, aux_ptr, complex_npw, *(deeaux.c), nij); } for (int ia = 0; ia < natom; ia++) diff --git a/source/source_pw/module_pwdft/forces_us.cpp b/source/source_pw/module_pwdft/forces_us.cpp index dd2fbc5f8d..767d29dddc 100644 --- a/source/source_pw/module_pwdft/forces_us.cpp +++ b/source/source_pw/module_pwdft/forces_us.cpp @@ -98,19 +98,20 @@ void Forces::cal_force_us(ModuleBase::matrix& forcenl, const double zero = 0; for (int ipol = 0; ipol < 3; ipol++) { - dgemm_(&transa, - &transb, - &nij, - &atom->na, - &dim, - &(ucell.omega), + BlasConnector::gemm_cm( + transa, + transb, + nij, + atom->na, + dim, + (ucell.omega), qgm_data, - &dim, + dim, &aux1_data[ipol * dim * atom->na], - &dim, - &zero, + dim, + zero, &ddeeq(is, ipol, 0, 0), - &nij); + nij); } } diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp 
b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp index 7e294754b8..9a78e5085d 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp @@ -8,6 +8,8 @@ #include "source_cell/klist.h" #include "source_hamilt/operator.h" #include "source_psi/psi.h" +#include "source_base/tool_quit.h" +#include "source_base/lapack_connector.h" #include #include @@ -15,15 +17,6 @@ #include #include -extern "C" -{ - void ztrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); - void ctrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); -} - -//extern "C" void zpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); -//extern "C" void cpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); - #include "op_exx_pw.h" #include "source_pw/module_pwdft/global.h" @@ -1062,25 +1055,25 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c template <> void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) { - ctrtri_(uplo, diag, n, a, lda, info); + LapackConnector::trtri(LapackConnector::ColMajor, *uplo, *diag, *n, a, *lda); } template <> void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) { - ztrtri_(uplo, diag, n, a, lda, info); + LapackConnector::trtri(LapackConnector::ColMajor, *uplo, *diag, *n, a, *lda); } template <> void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) { - cpotrf_(uplo, n, a, lda, info); + LapackConnector::potrf(LapackConnector::ColMajor, *uplo, *n, a, *lda); } template <> void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) { - zpotrf_(uplo, n, a, lda, info); + LapackConnector::potrf(LapackConnector::ColMajor, *uplo, *n, a, *lda); } template class 
OperatorEXXPW, base_device::DEVICE_CPU>; diff --git a/source/source_pw/module_pwdft/stress_func_us.cpp b/source/source_pw/module_pwdft/stress_func_us.cpp index adb8745901..a5552593f8 100644 --- a/source/source_pw/module_pwdft/stress_func_us.cpp +++ b/source/source_pw/module_pwdft/stress_func_us.cpp @@ -111,19 +111,20 @@ void Stress_PW::stress_us(ModuleBase::matrix& sigma, const int dim = 2 * npw; const double one = 1; const double zero = 0; - dgemm_(&transa, - &transb, - &dim, - &PARAM.inp.nspin, - &nij, - &one, + BlasConnector::gemm_cm( + transa, + transb, + dim, + PARAM.inp.nspin, + nij, + one, qgm_data, - &dim, + dim, tbecsum.c, - &nij, - &zero, + nij, + zero, aux2_data, - &dim); + dim); for (int is = 0; is < PARAM.inp.nspin; is++) { @@ -148,19 +149,20 @@ void Stress_PW::stress_us(ModuleBase::matrix& sigma, ModuleBase::matrix fac(PARAM.inp.nspin, 3); const char transc = 'T'; const int three = 3; - dgemm_(&transc, - &transb, - &three, - &PARAM.inp.nspin, - &dim, - &one, + BlasConnector::gemm_cm( + transc, + transb, + three, + PARAM.inp.nspin, + dim, + one, aux1_data, - &dim, + dim, aux2_data, - &dim, - &zero, + dim, + zero, fac.c, - &three); + three); for (int is = 0; is < PARAM.inp.nspin; is++) { diff --git a/source/source_pw/module_stodft/sto_dos.cpp b/source/source_pw/module_stodft/sto_dos.cpp index ffcc24203f..d61242be32 100644 --- a/source/source_pw/module_stodft/sto_dos.cpp +++ b/source/source_pw/module_stodft/sto_dos.cpp @@ -157,7 +157,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, cons double* vec_all = (double*)allorderchi.data(); int LDA = npwx * nchipk_new * 2; int M = npwx * nchipk_new * 2; - dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv.data(), &N); + BlasConnector::gemm_cm(trans, normal, N, N, M, kweight, vec_all, LDA, vec_all, LDA, one, spolyv.data(), N); } } } diff --git a/source/source_relax/bfgs.cpp b/source/source_relax/bfgs.cpp index 829af3fa1c..a2667e5bb4 100644 --- 
a/source/source_relax/bfgs.cpp +++ b/source/source_relax/bfgs.cpp @@ -120,9 +120,6 @@ void BFGS::PrepareStep(std::vector>& force, //! call dysev std::vector omega(3*size); - std::vector work(3*size*3*size); - int lwork=3*size*3*size; - int info=0; std::vector H_flat; for(const auto& row : H) @@ -131,8 +128,7 @@ void BFGS::PrepareStep(std::vector>& force, } int value=3*size; - int* ptr=&value; - dsyev_("V","U",ptr,H_flat.data(),ptr,omega.data(),work.data(),&lwork,&info); + LapackConnector::syev(LapackConnector::ColMajor, 'V','U', value, H_flat.data(), value, omega.data()); std::vector> V(3*size, std::vector(3*size, 0.0)); for(int i = 0; i < 3*size; i++) { From 78f6485b355de93d92367dd586c83c566e44ca21 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Tue, 22 Jul 2025 14:37:05 +0800 Subject: [PATCH 2/5] fix some compilation error --- source/source_base/cubic_spline.cpp | 2 +- .../module_container/base/third_party/blas.h | 2 +- .../base/third_party/lapack.h | 2 +- .../blas_connector_l1.cpp | 0 .../blas_connector_l2.cpp | 2 +- .../blas_connector_l3.cpp | 0 .../module_external/blas_connector_matrix.cpp | 598 ------------------ .../module_external/blas_connector_vector.cpp | 506 --------------- .../lapack_connector.cpp | 0 source/source_base/module_mixing/mixing.cpp | 2 +- source/source_io/write_vxc_lip.hpp | 2 +- .../module_lr/ri_benchmark/ri_benchmark.hpp | 2 +- .../module_pwdft/operator_pw/op_exx_pw.cpp | 2 +- 13 files changed, 8 insertions(+), 1112 deletions(-) rename source/source_base/{ => module_external}/blas_connector_l1.cpp (100%) rename source/source_base/{ => module_external}/blas_connector_l2.cpp (99%) rename source/source_base/{ => module_external}/blas_connector_l3.cpp (100%) delete mode 100644 source/source_base/module_external/blas_connector_matrix.cpp delete mode 100644 source/source_base/module_external/blas_connector_vector.cpp rename source/source_base/{ => module_external}/lapack_connector.cpp (100%) diff --git a/source/source_base/cubic_spline.cpp 
b/source/source_base/cubic_spline.cpp index 8d28bbf1c0..9c03f0b861 100644 --- a/source/source_base/cubic_spline.cpp +++ b/source/source_base/cubic_spline.cpp @@ -1,5 +1,5 @@ #include "cubic_spline.h" -#include "source_base/lapack_connector.h" +#include "source_base/module_external/lapack_connector.h" #include #include diff --git a/source/source_base/module_container/base/third_party/blas.h b/source/source_base/module_container/base/third_party/blas.h index 216608aecb..e1b73ade90 100644 --- a/source/source_base/module_container/base/third_party/blas.h +++ b/source/source_base/module_container/base/third_party/blas.h @@ -2,7 +2,7 @@ #define BASE_THIRD_PARTY_BLAS_H_ #include -#include "source_base/blas_connector.h" +#include "source_base/module_external/blas_connector.h" #if defined(__CUDA) #include diff --git a/source/source_base/module_container/base/third_party/lapack.h b/source/source_base/module_container/base/third_party/lapack.h index 8ae1ecd350..e2b60805c2 100644 --- a/source/source_base/module_container/base/third_party/lapack.h +++ b/source/source_base/module_container/base/third_party/lapack.h @@ -3,7 +3,7 @@ #include #include "source_base/macros.h" -#include "source_base/lapack_connector.h" +#include "source_base/module_external/lapack_connector.h" #if defined(__CUDA) #include diff --git a/source/source_base/blas_connector_l1.cpp b/source/source_base/module_external/blas_connector_l1.cpp similarity index 100% rename from source/source_base/blas_connector_l1.cpp rename to source/source_base/module_external/blas_connector_l1.cpp diff --git a/source/source_base/blas_connector_l2.cpp b/source/source_base/module_external/blas_connector_l2.cpp similarity index 99% rename from source/source_base/blas_connector_l2.cpp rename to source/source_base/module_external/blas_connector_l2.cpp index 2966e2118b..d94e641e03 100644 --- a/source/source_base/blas_connector_l2.cpp +++ b/source/source_base/module_external/blas_connector_l2.cpp @@ -3,7 +3,7 @@ * These operations 
include matrix-vector multiplication and related operations. */ #include "blas_connector.h" -#include "macros.h" +#include "source_base/macros.h" #include #ifdef __DSP diff --git a/source/source_base/blas_connector_l3.cpp b/source/source_base/module_external/blas_connector_l3.cpp similarity index 100% rename from source/source_base/blas_connector_l3.cpp rename to source/source_base/module_external/blas_connector_l3.cpp diff --git a/source/source_base/module_external/blas_connector_matrix.cpp b/source/source_base/module_external/blas_connector_matrix.cpp deleted file mode 100644 index 2c749552b1..0000000000 --- a/source/source_base/module_external/blas_connector_matrix.cpp +++ /dev/null @@ -1,598 +0,0 @@ -/* level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. - * This file contains the implementation of the BLAS level 3 operations. - * These operations include matrix-matrix multiplication and related operations. - */ -#include "blas_connector.h" -#include "../macros.h" - -#ifdef __DSP -#include "source_base/kernels/dsp/dsp_connector.h" -#include "source_base/global_variable.h" -#endif - -#ifdef __CUDA -#include -#include -#include "cublas_v2.h" -#include "source_base/kernels/math_kernel_op.h" -#include "source_base/module_device/memory_op.h" -#endif - -extern "C" -{ - // level 3: matrix-matrix operations, O(n^2) data and O(n^3) work. - - // Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? 
+ b * C - // A is general - void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); - void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); - void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - - // A is symmetric. C = a * A.? * B.? 
+ b * C - void ssymm_(const char *side, const char *uplo, const int *m, const int *n, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); - void dsymm_(const char *side, const char *uplo, const int *m, const int *n, - const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, - const double *beta, double *c, const int *ldc); - void csymm_(const char *side, const char *uplo, const int *m, const int *n, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - void zsymm_(const char *side, const char *uplo, const int *m, const int *n, - const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, - const std::complex *beta, std::complex *c, const int *ldc); - - // A is hermitian. C = a * A.? * B.? + b * C - void chemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, - std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); - void zhemm_(char *side, char *uplo, int *m, int *n,std::complex *alpha, - std::complex *a, int *lda, std::complex *b, int *ldb, std::complex *beta, std::complex *c, int *ldc); - - // symmetric rank-k update - void dsyrk_( - const char* uplo, - const char* trans, - const int* n, - const int* k, - const double* alpha, - const double* a, - const int* lda, - const double* beta, - double* c, - const int* ldc - ); -} - -// C = a * A.? * B.? 
+ b * C -// Row-Major part -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const float alpha, const float *a, const int lda, const float *b, const int ldb, - const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const double alpha, - const double* a, - const int lda, - const double* b, - const int ldb, - const double beta, - double* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { -#ifdef __CUDA - cublasOperation_t cutransA = 
BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck( - cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); -#endif - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const std::complex alpha, - const std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - const std::complex beta, - std::complex* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, - cutransA, - cutransB, - n, - m, - k, - (float2*)&alpha, - (float2*)b, - ldb, - (float2*)a, - lda, - (float2*)&beta, - (float2*)c, - ldc)); -#endif - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const std::complex alpha, - const std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - const std::complex 
beta, - std::complex* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, - cutransA, - cutransB, - n, - m, - k, - (double2*)&alpha, - (double2*)b, - ldb, - (double2*)a, - lda, - (double2*)&beta, - (double2*)c, - ldc)); -#endif - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// Col-Major part -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const float alpha, const float *a, const int lda, const float *b, const int ldb, - const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - 
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm_cm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const double alpha, - const double* a, - const int lda, - const double* b, - const int ldb, - const double beta, - double* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck( - cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm_cm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const std::complex alpha, - const std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - const std::complex beta, - std::complex* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, 
&ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, - cutransA, - cutransB, - m, - n, - k, - (float2*)&alpha, - (float2*)a, - lda, - (float2*)b, - ldb, - (float2*)&beta, - (float2*)c, - ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::gemm_cm(const char transa, - const char transb, - const int m, - const int n, - const int k, - const std::complex alpha, - const std::complex* a, - const int lda, - const std::complex* b, - const int ldb, - const std::complex beta, - std::complex* c, - const int ldc, - base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); - } -#ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) - { - mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) - { - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, - cutransA, - cutransB, - m, - n, - k, - (double2*)&alpha, - (double2*)a, - lda, - (double2*)b, - ldb, - (double2*)&beta, - (double2*)c, - ldc)); - } -#endif - 
else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// Symm and Hemm part. Only col-major is supported. - -void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, - const float alpha, const float *a, const int lda, const float *b, const int ldb, - const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - ssymm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dsymm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " 
line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - csymm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zsymm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + 
std::to_string(__LINE__)); - } -} - -void BlasConnector::hemm_cm(const char side, const char uplo, const int m, const int n, - const float alpha, const float *a, const int lda, const float *b, const int ldb, - const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - symm_cm(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc, device_type); -} - -void BlasConnector::hemm_cm(const char side, const char uplo, const int m, const int n, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - symm_cm(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc, device_type); -} - -void BlasConnector::hemm_cm(char side, char uplo, int m, int n, - std::complex alpha, std::complex *a, int lda, std::complex *b, int ldb, - std::complex beta, std::complex *c, int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - chemm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::hemm_cm(char side, char uplo, int m, int n, - std::complex alpha, std::complex *a, int lda, std::complex *b, int ldb, - std::complex beta, std::complex *c, int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zhemm_(&side, &uplo, &m, &n, - &alpha, a, &lda, b, 
&ldb, - &beta, c, &ldc); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasSideMode_t sideMode = BlasUtils::judge_side(side); - cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo); - cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::syrk(char uplo, char trans, int n, int k, - double alpha, const double* a, int lda, double beta, double* c, int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - dsyrk_(&uplo, &trans, &n, &k, &alpha, a, &lda, &beta, c, &ldc); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::herk(char uplo, char trans, int n, int k, float alpha, - const std::complex *A, int lda, float beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) -{ - auto cblas_uplo = BlasUtils::toCblasUplo(uplo); - auto cblas_trans = BlasUtils::toCblasTrans(trans); - if (device_type == base_device::AbacusDevice_t::CpuDevice) - { - cblas_cherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); - } - else - { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::herk(char uplo, char trans, int n, int k, double alpha, - const std::complex *A, int lda, double beta, std::complex *C, int ldc, base_device::AbacusDevice_t device_type) -{ - auto cblas_uplo = BlasUtils::toCblasUplo(uplo); - auto cblas_trans = BlasUtils::toCblasTrans(trans); - if 
(device_type == base_device::AbacusDevice_t::CpuDevice) - { - cblas_zherk(CblasRowMajor, cblas_uplo, cblas_trans, n, k, alpha, A, lda, beta, C, ldc); - } - else - { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} \ No newline at end of file diff --git a/source/source_base/module_external/blas_connector_vector.cpp b/source/source_base/module_external/blas_connector_vector.cpp deleted file mode 100644 index f9c5925143..0000000000 --- a/source/source_base/module_external/blas_connector_vector.cpp +++ /dev/null @@ -1,506 +0,0 @@ -/* level 1: std::vector-std::vector operations, O(n) data and O(n) work. - * This file contains the implementation of the BLAS level 1 operations. - * These operations include vector scaling, vector addition, vector dot product, and vector norm calculations. - */ -#include "blas_connector.h" -#include "../macros.h" - -#include -#ifdef __DSP -#include "source_base/kernels/dsp/dsp_connector.h" -#include "source_base/global_variable.h" -#endif - -#ifdef __CUDA -#include -#include -#include "cublas_v2.h" -#include "source_base/kernels/math_kernel_op.h" -#include "source_base/module_device/memory_op.h" -#endif - -extern "C" -{ - // level 1: std::vector-std::vector operations, O(n) data and O(n) work. 
- // Peize Lin add ?scal 2016-08-04, to compute x=a*x - void sscal_(const int *N, const float *alpha, float *X, const int *incX); - void dscal_(const int *N, const double *alpha, double *X, const int *incX); - void cscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); - void zscal_(const int *N, const std::complex *alpha, std::complex *X, const int *incX); - - // Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y - void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY); - void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY); - void caxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); - void zaxpy_(const int *N, const std::complex *alpha, const std::complex *X, const int *incX, std::complex *Y, const int *incY); - - void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy); - void zcopy_(long const *n, const std::complex *a, int const *incx, std::complex *b, int const *incy); - - //reason for passing results as argument instead of returning it: - //see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/ - // void zdotc_(std::complex *result, const int *n, const std::complex *zx, - // const int *incx, const std::complex *zy, const int *incy); - // Peize Lin add ?dot 2017-10-27, to compute d=x*y - float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY); - double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY); - - // Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 } - float snrm2_( const int *n, const float *X, const int *incX ); - double dnrm2_( const int *n, const double *X, const int *incX ); - double dznrm2_( const int *n, const std::complex *X, const int *incX ); -} - -// x=a*x -void BlasConnector::scal( const int 
n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - sscal_(&n, &alpha, X, &incX); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dscal_(&n, &alpha, X, &incX); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cscal_(&n, &alpha, X, &incX); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::scal( const int n, const std::complex alpha, std::complex *X, const int incX, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zscal_(&n, &alpha, X, &incX); - } -#ifdef __CUDA - else if (device_type == 
base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - saxpy_(&n, &alpha, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - daxpy_(&n, &alpha, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - caxpy_(&n, &alpha, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, 
(float2*)&alpha, (float2*)X, incX, (float2*)Y, incY)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::axpy( const int n, const std::complex alpha, const std::complex *X, const int incX, std::complex *Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zaxpy_(&n, &alpha, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY)); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// copies a into b -void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dcopy_(&n, a, &incx, b, &incy); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void BlasConnector::copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zcopy_(&n, a, &incx, b, &incy); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// d=x*y -float BlasConnector::dot( const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - 
return sdot_(&n, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ - float result = 0.0; - cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); - return result; - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -double BlasConnector::dot( const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - return ddot_(&n, X, &incX, Y, &incY); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ - double result = 0.0; - cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result)); - return result; - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// d=x*y -float BlasConnector::dotu(const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - return BlasConnector::dot(n, X, incX, Y, incY, device_type); -} - -double BlasConnector::dotu(const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - return BlasConnector::dot(n, X, incX, Y, incY, device_type); -} - -std::complex BlasConnector::dotu(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - const int incX2 = 2 * incX; - const int incY2 = 2 * incY; - const float*const x = reinterpret_cast(X); - const float*const y = reinterpret_cast(Y); - //Re(result)=Re(x)*Re(y)-Im(x)*Im(y) - 
//Im(result)=Re(x)*Im(y)+Im(x)*Re(y) - return std::complex( - BlasConnector::dot(n, x, incX2, y, incY2, device_type) - dot(n, x+1, incX2, y+1, incY2, device_type), - BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) + dot(n, x+1, incX2, y, incY2, device_type)); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -std::complex BlasConnector::dotu(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - const int incX2 = 2 * incX; - const int incY2 = 2 * incY; - const double*const x = reinterpret_cast(X); - const double*const y = reinterpret_cast(Y); - //Re(result)=Re(x)*Re(y)-Im(x)*Im(y) - //Im(result)=Re(x)*Im(y)+Im(x)*Re(y) - return std::complex( - BlasConnector::dot(n, x, incX2, y, incY2, device_type) - dot(n, x+1, incX2, y+1, incY2, device_type), - BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) + dot(n, x+1, incX2, y, incY2, device_type)); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// d = x.conj() * Vy -float BlasConnector::dotc(const int n, const float*const X, const int incX, const float*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - return BlasConnector::dot(n, X, incX, Y, incY, device_type); -} - -double BlasConnector::dotc(const int n, const double*const X, const int incX, const double*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - return BlasConnector::dot(n, X, incX, Y, incY, device_type); -} - -std::complex BlasConnector::dotc(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == 
base_device::AbacusDevice_t::CpuDevice) { - const int incX2 = 2 * incX; - const int incY2 = 2 * incY; - const float*const x = reinterpret_cast(X); - const float*const y = reinterpret_cast(Y); - // Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y) - // Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y) - return std::complex( - BlasConnector::dot(n, x, incX2, y, incY2, device_type) + dot(n, x+1, incX2, y+1, incY2, device_type), - BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) - dot(n, x+1, incX2, y, incY2, device_type)); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -std::complex BlasConnector::dotc(const int n, const std::complex*const X, const int incX, const std::complex*const Y, const int incY, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - const int incX2 = 2 * incX; - const int incY2 = 2 * incY; - const double*const x = reinterpret_cast(X); - const double*const y = reinterpret_cast(Y); - // Re(result)=Re(X)*Re(Y)+Im(X)*Im(Y) - // Im(result)=Re(X)*Im(Y)-Im(X)*Re(Y) - return std::complex( - BlasConnector::dot(n, x, incX2, y, incY2, device_type) + dot(n, x+1, incX2, y+1, incY2, device_type), - BlasConnector::dot(n, x, incX2, y+1, incY2, device_type) - dot(n, x+1, incX2, y, incY2, device_type)); - } - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -// out = ||x||_2 -float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type ) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - return snrm2_( &n, X, &incX ); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ - float result = 0.0; - cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); - return result; - } 
-#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - - -double BlasConnector::nrm2( const int n, const double *X, const int incX, base_device::AbacusDevice_t device_type ) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - return dnrm2_( &n, X, &incX ); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ - double result = 0.0; - cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result)); - return result; - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - - -double BlasConnector::nrm2( const int n, const std::complex *X, const int incX, base_device::AbacusDevice_t device_type ) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - return dznrm2_( &n, X, &incX ); - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ - double result = 0.0; - cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result)); - return result; - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -template -void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ - using Real = typename GetTypeReal::type; - if (device_type == base_device::AbacusDevice_t::CpuDevice) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * vector2[i]; - } - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - ModuleBase::vector_mul_vector_op()(dim, result, vector1, vector2); - } -#endif - else { - throw 
std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - - -template -void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ - using Real = typename GetTypeReal::type; - if (device_type == base_device::AbacusDevice_t::CpuDevice) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] / vector2[i]; - } - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - ModuleBase::vector_div_vector_op()(dim, result, vector1, vector2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(float)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(double)) -#endif 
- for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const float constant1, const std::complex *vector2, const float constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const double constant1, const std::complex *vector2, const double constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + 
std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} \ No newline at end of file diff --git a/source/source_base/lapack_connector.cpp b/source/source_base/module_external/lapack_connector.cpp similarity index 100% rename from source/source_base/lapack_connector.cpp rename to source/source_base/module_external/lapack_connector.cpp diff --git a/source/source_base/module_mixing/mixing.cpp b/source/source_base/module_mixing/mixing.cpp index e6c816f0b3..4d0e349de7 100644 --- a/source/source_base/module_mixing/mixing.cpp +++ b/source/source_base/module_mixing/mixing.cpp @@ -1,6 +1,6 @@ #include "mixing.h" -#include "source_base/blas_connector.h" +#include "source_base/module_external/blas_connector.h" namespace Base_Mixing { diff --git a/source/source_io/write_vxc_lip.hpp b/source/source_io/write_vxc_lip.hpp index 6e4916b68e..f5cb9681cd 100644 --- a/source/source_io/write_vxc_lip.hpp +++ b/source/source_io/write_vxc_lip.hpp @@ -2,7 +2,7 @@ #define __WRITE_VXC_LIP_H_ #include "source_io/module_parameter/parameter.h" #include "source_base/parallel_reduce.h" -#include "source_base/blas_connector.h" +#include "source_base/module_external/blas_connector.h" #include "source_pw/module_pwdft/operator_pw/veff_pw.h" #include "source_psi/psi.h" #include "source_cell/unitcell.h" diff --git a/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp b/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp index c1b2a7c347..ea74a6c270 100644 --- a/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp +++ b/source/source_lcao/module_lr/ri_benchmark/ri_benchmark.hpp @@ -1,6 +1,6 @@ #pragma once #include "ri_benchmark.h" -#include "source_base/blas_connector.h" +#include "source_base/module_external/blas_connector.h" namespace RI_Benchmark { // std::cout << "the size of Cs:" << std::endl; diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp index 
9a78e5085d..a6d541f197 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp @@ -9,7 +9,7 @@ #include "source_hamilt/operator.h" #include "source_psi/psi.h" #include "source_base/tool_quit.h" -#include "source_base/lapack_connector.h" +#include "source_base/module_external/lapack_connector.h" #include #include From 25c752fd0fed6aa821015501f08c204325f93628 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Tue, 22 Jul 2025 02:09:53 +0800 Subject: [PATCH 3/5] modify cmakelist --- CMakeLists.txt | 2 +- cmake/FindLapack.cmake | 19 +++++++++++++++++++ source/source_base/module_grid/batch.cpp | 2 +- .../module_grid/test/CMakeLists.txt | 6 +++++- .../module_ao/test/CMakeLists.txt | 6 ++++-- .../module_pw/kernels/test/CMakeLists.txt | 6 +++++- .../module_pw/test/CMakeLists.txt | 6 +++++- .../module_xc/test/CMakeLists.txt | 12 ++++++++++-- source/source_md/test/CMakeLists.txt | 6 ++++-- .../module_pwdft/test/CMakeLists.txt | 6 +++++- source/source_relax/test/CMakeLists.txt | 6 +++++- 11 files changed, 64 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 277f1924ec..c99dfc5c21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -479,7 +479,7 @@ else() find_package(FFTW3 REQUIRED) find_package(Lapack REQUIRED) include_directories(${FFTW3_INCLUDE_DIRS}) - list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS) + list(APPEND math_libs FFTW3::FFTW3 LAPACKE::LAPACKE LAPACK::LAPACK BLAS::BLAS) find_package(ScaLAPACK REQUIRED) list(APPEND math_libs ScaLAPACK::ScaLAPACK) if(USE_OPENMP) diff --git a/cmake/FindLapack.cmake b/cmake/FindLapack.cmake index 15c3976d64..b6f5dba604 100644 --- a/cmake/FindLapack.cmake +++ b/cmake/FindLapack.cmake @@ -9,9 +9,28 @@ endif() find_package(Blas REQUIRED) find_package(LAPACK REQUIRED) +# find LAPACKE libraries +find_library(LAPACKE_LIBRARY + NAMES lapacke + PATHS ${LAPACK_DIR} ${LAPACK_LIBRARIES} ${CMAKE_PREFIX_PATH} + PATH_SUFFIXES lib 
lib64 + DOC "Path to LAPACKE library" +) + if(NOT TARGET LAPACK::LAPACK) add_library(LAPACK::LAPACK UNKNOWN IMPORTED) set_target_properties(LAPACK::LAPACK PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${LAPACK_LIBRARIES}") endif() + +if(NOT TARGET LAPACKE::LAPACKE) + add_library(LAPACKE::LAPACKE UNKNOWN IMPORTED) + set_target_properties(LAPACKE::LAPACKE PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${LAPACKE_LIBRARY}") + + set_target_properties(LAPACKE::LAPACKE PROPERTIES + INTERFACE_LINK_LIBRARIES "LAPACK::LAPACK" + ) +endif() diff --git a/source/source_base/module_grid/batch.cpp b/source/source_base/module_grid/batch.cpp index 36f8ea6ace..5291f8fe6d 100644 --- a/source/source_base/module_grid/batch.cpp +++ b/source/source_base/module_grid/batch.cpp @@ -75,7 +75,7 @@ int _maxmin_divide(const double* grid, int* idx, int m) { // Rearrange the indices to put points in each subset together by // examining the signed distances of points to the cut plane (R^T*n). 
std::vector dist(m); - dgemv_("T", &i3, &m, &d1, R.data(), &i3, n, &i1, &d0, dist.data(), &i1); + BlasConnector::gemv_cm('T', 3, m, 1.0, R.data(), 3, n, 1, 0.0, dist.data(), 1); int *head = idx; std::reverse_iterator tail(idx + m), rend(idx); diff --git a/source/source_base/module_grid/test/CMakeLists.txt b/source/source_base/module_grid/test/CMakeLists.txt index 068feb9634..6e08a90308 100644 --- a/source/source_base/module_grid/test/CMakeLists.txt +++ b/source/source_base/module_grid/test/CMakeLists.txt @@ -25,5 +25,9 @@ AddTest( TARGET MODULE_BASE_GRID_test_batch SOURCES test_batch.cpp ../batch.cpp - LIBS ${math_libs} + ../../blas_connector_base.cpp + ../../blas_connector_l1.cpp + ../../blas_connector_l2.cpp + ../../blas_connector_l3.cpp + ../../lapack_connector.cpp ) diff --git a/source/source_basis/module_ao/test/CMakeLists.txt b/source/source_basis/module_ao/test/CMakeLists.txt index b7e1e31191..ef769ec459 100644 --- a/source/source_basis/module_ao/test/CMakeLists.txt +++ b/source/source_basis/module_ao/test/CMakeLists.txt @@ -8,8 +8,10 @@ list(APPEND depend_files ../../../source_base/ylm.cpp ../../../source_base/memory.cpp ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_vector.cpp - ../../../source_base/module_external/blas_connector_matrix.cpp + ../../../source_base/module_external/blas_connector_l1.cpp + ../../../source_base/module_external/blas_connector_l2.cpp + ../../../source_base/module_external/blas_connector_l3.cpp + ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/complexarray.cpp ../../../source_base/complexmatrix.cpp ../../../source_base/matrix.cpp diff --git a/source/source_basis/module_pw/kernels/test/CMakeLists.txt b/source/source_basis/module_pw/kernels/test/CMakeLists.txt index 151010241b..f90d15e9c5 100644 --- a/source/source_basis/module_pw/kernels/test/CMakeLists.txt +++ b/source/source_basis/module_pw/kernels/test/CMakeLists.txt @@ -9,5 +9,9 
@@ AddTest( ../../../../source_base/parallel_comm.cpp ../../../../source_base/complexmatrix.cpp ../../../../source_base/matrix.cpp ../../../../source_base/memory.cpp ../../../../source_base/libm/branred.cpp ../../../../source_base/libm/sincos.cpp - ../../../../source_base/module_external/blas_connector_base.cpp ../../../../source_base/module_external/blas_connector_vector.cpp ../../../../source_base/module_external/blas_connector_matrix.cpp + ../../../../source_base/module_external/blas_connector_base.cpp + ../../../../source_base/module_external/blas_connector_l1.cpp + ../../../../source_base/module_external/blas_connector_l2.cpp + ../../../../source_base/module_external/blas_connector_l3.cpp + ../../../../source_base/module_external/lapack_connector.cpp ) \ No newline at end of file diff --git a/source/source_basis/module_pw/test/CMakeLists.txt b/source/source_basis/module_pw/test/CMakeLists.txt index c73f5549fd..7359383905 100644 --- a/source/source_basis/module_pw/test/CMakeLists.txt +++ b/source/source_basis/module_pw/test/CMakeLists.txt @@ -4,7 +4,11 @@ AddTest( LIBS parameter ${math_libs} planewave device SOURCES ../../../source_base/matrix.cpp ../../../source_base/complexmatrix.cpp ../../../source_base/matrix3.cpp ../../../source_base/tool_quit.cpp ../../../source_base/mymath.cpp ../../../source_base/timer.cpp ../../../source_base/memory.cpp - ../../../source_base/module_external/blas_connector_base.cpp ../../../source_base/module_external/blas_connector_vector.cpp ../../../source_base/module_external/blas_connector_matrix.cpp + ../../../source_base/module_external/blas_connector_base.cpp + ../../../source_base/module_external/blas_connector_l1.cpp + ../../../source_base/module_external/blas_connector_l2.cpp + ../../../source_base/module_external/blas_connector_l3.cpp + ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/libm/branred.cpp ../../../source_base/libm/sincos.cpp ../../../source_base/module_device/memory_op.cpp 
depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp diff --git a/source/source_hamilt/module_xc/test/CMakeLists.txt b/source/source_hamilt/module_xc/test/CMakeLists.txt index eb24bfec27..857bd7516e 100644 --- a/source/source_hamilt/module_xc/test/CMakeLists.txt +++ b/source/source_hamilt/module_xc/test/CMakeLists.txt @@ -40,7 +40,11 @@ AddTest( ../../../source_base/memory.cpp ../../../source_base/libm/branred.cpp ../../../source_base/libm/sincos.cpp - ../../../source_base/module_external/blas_connector_base.cpp ../../../source_base/module_external/blas_connector_vector.cpp ../../../source_base/module_external/blas_connector_matrix.cpp + ../../../source_base/module_external/blas_connector_base.cpp + ../../../source_base/module_external/blas_connector_l1.cpp + ../../../source_base/module_external/blas_connector_l2.cpp + ../../../source_base/module_external/blas_connector_l3.cpp + ../../../source_base/module_external/lapack_connector.cpp ../../../source_basis/module_pw/module_fft/fft_bundle.cpp ../../../source_basis/module_pw/module_fft/fft_cpu.cpp ${FFT_SRC} @@ -73,7 +77,11 @@ AddTest( ../xc_functional_vxc.cpp ../xc_functional_libxc_vxc.cpp ../xc_functional_libxc_tools.cpp - ../../../source_base/module_external/blas_connector_base.cpp ../../../source_base/module_external/blas_connector_vector.cpp ../../../source_base/module_external/blas_connector_matrix.cpp + ../../../source_base/module_external/blas_connector_base.cpp + ../../../source_base/module_external/blas_connector_l1.cpp + ../../../source_base/module_external/blas_connector_l2.cpp + ../../../source_base/module_external/blas_connector_l3.cpp + ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/matrix.cpp ../../../source_base/memory.cpp ../../../source_base/timer.cpp diff --git a/source/source_md/test/CMakeLists.txt b/source/source_md/test/CMakeLists.txt index dee30a6ab7..c91c449f79 100644 --- a/source/source_md/test/CMakeLists.txt 
+++ b/source/source_md/test/CMakeLists.txt @@ -23,8 +23,10 @@ list(APPEND depend_files ../../source_base/matrix.cpp ../../source_base/timer.cpp ../../source_base/module_external/blas_connector_base.cpp - ../../source_base/module_external/blas_connector_matrix.cpp - ../../source_base/module_external/blas_connector_vector.cpp + ../../source_base/module_external/blas_connector_l1.cpp + ../../source_base/module_external/blas_connector_l2.cpp + ../../source_base/module_external/blas_connector_l3.cpp + ../../source_base/module_external/lapack_connector.cpp ../../source_base/memory.cpp ../../source_base/global_variable.cpp ../../source_base/global_function.cpp diff --git a/source/source_pw/module_pwdft/test/CMakeLists.txt b/source/source_pw/module_pwdft/test/CMakeLists.txt index 32009ba9e7..6477b3c480 100644 --- a/source/source_pw/module_pwdft/test/CMakeLists.txt +++ b/source/source_pw/module_pwdft/test/CMakeLists.txt @@ -15,7 +15,11 @@ AddTest( ../../../source_base/global_file.cpp ../../../source_base/memory.cpp ../../../source_base/timer.cpp - ../../../source_base/module_external/blas_connector_base.cpp ../../../source_base/module_external/blas_connector_vector.cpp ../../../source_base/module_external/blas_connector_matrix.cpp + ../../../source_base/module_external/blas_connector_base.cpp + ../../../source_base/module_external/blas_connector_l1.cpp + ../../../source_base/module_external/blas_connector_l2.cpp + ../../../source_base/module_external/blas_connector_l3.cpp + ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/parallel_global.cpp ../../../source_base/parallel_comm.cpp ../../../source_base/parallel_common.cpp diff --git a/source/source_relax/test/CMakeLists.txt b/source/source_relax/test/CMakeLists.txt index 5f8a600cc6..5cab0093f6 100644 --- a/source/source_relax/test/CMakeLists.txt +++ b/source/source_relax/test/CMakeLists.txt @@ -18,7 +18,11 @@ AddTest( ../../source_base/matrix3.cpp ../../source_base/intarray.cpp 
../../source_base/tool_title.cpp ../../source_base/global_function.cpp ../../source_base/complexmatrix.cpp ../../source_base/matrix.cpp ../../source_base/complexarray.cpp ../../source_base/tool_quit.cpp ../../source_base/realarray.cpp - ../../source_base/module_external/blas_connector_base.cpp ../../source_base/module_external/blas_connector_vector.cpp ../../source_base/module_external/blas_connector_matrix.cpp + ../../source_base/module_external/blas_connector_base.cpp + ../../source_base/module_external/blas_connector_l1.cpp + ../../source_base/module_external/blas_connector_l2.cpp + ../../source_base/module_external/blas_connector_l3.cpp + ../../source_base/module_external/lapack_connector.cpp ../../source_cell/update_cell.cpp ../../source_cell/print_cell.cpp ../../source_cell/bcast_cell.cpp ../../source_io/output.cpp LIBS parameter ${math_libs} ) From 933fc9537c407067e17da54c8e03b9220b3437d1 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Tue, 22 Jul 2025 16:03:43 +0800 Subject: [PATCH 4/5] fix a bug --- source/source_base/inverse_matrix.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/source/source_base/inverse_matrix.cpp b/source/source_base/inverse_matrix.cpp index 55ad954104..6f45521d0f 100644 --- a/source/source_base/inverse_matrix.cpp +++ b/source/source_base/inverse_matrix.cpp @@ -33,9 +33,6 @@ void Inverse_Matrix_Complex::init(const int &dim_in) assert(dim>0); this->e = new double[dim]; - - assert(lwork>0); - assert(3*dim-2>0); this->A.create(dim, dim); this->EA.create(dim, dim); From 4942cb32681d88ab7556681aacfda917a84da3e7 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Fri, 25 Jul 2025 16:30:57 +0800 Subject: [PATCH 5/5] update cmakelist --- CMakeLists.txt | 10 +- cmake/FindLapack.cmake | 18 +- source/source_base/CMakeLists.txt | 5 - .../kernels/test/math_kernel_test.cpp | 34 +-- .../module_external/blas_connector.h | 22 -- .../module_external/blas_connector_l1.cpp | 129 ---------- .../module_external/lapack_connector.cpp | 103 ++++---- 
.../module_grid/test/CMakeLists.txt | 9 +- source/source_base/test/CMakeLists.txt | 2 +- .../source_base/test/blas_connector_test.cpp | 243 ++---------------- .../module_ao/test/CMakeLists.txt | 5 - .../module_pw/kernels/test/CMakeLists.txt | 5 - .../module_pw/test/CMakeLists.txt | 5 - .../module_xc/test/CMakeLists.txt | 10 - source/source_hsolver/test/diago_cg_test.cpp | 2 +- .../source_hsolver/test/diago_david_test.cpp | 1 + source/source_hsolver/test/diago_elpa_utils.h | 4 +- source/source_md/test/CMakeLists.txt | 5 - .../module_pwdft/test/CMakeLists.txt | 5 - source/source_relax/test/CMakeLists.txt | 5 - 20 files changed, 119 insertions(+), 503 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c99dfc5c21..8b8c76ecaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -459,6 +459,14 @@ if(ENABLE_ASAN) target_link_libraries(${ABACUS_BIN_NAME} -fsanitize=address) endif() +add_library(math_connector OBJECT +${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_base.cpp +${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l1.cpp +${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l2.cpp +${ABACUS_SOURCE_DIR}/source_base/module_external/blas_connector_l3.cpp +${ABACUS_SOURCE_DIR}/source_base/module_external/lapack_connector.cpp +) +list(APPEND math_libs math_connector) if(DEFINED ENV{MKLROOT} AND NOT DEFINED MKLROOT) set(MKLROOT "$ENV{MKLROOT}") endif() @@ -479,7 +487,7 @@ else() find_package(FFTW3 REQUIRED) find_package(Lapack REQUIRED) include_directories(${FFTW3_INCLUDE_DIRS}) - list(APPEND math_libs FFTW3::FFTW3 LAPACKE::LAPACKE LAPACK::LAPACK BLAS::BLAS) + list(APPEND math_libs FFTW3::FFTW3 LAPACKE::LAPACKE BLAS::BLAS) find_package(ScaLAPACK REQUIRED) list(APPEND math_libs ScaLAPACK::ScaLAPACK) if(USE_OPENMP) diff --git a/cmake/FindLapack.cmake b/cmake/FindLapack.cmake index b6f5dba604..8a67fa53ae 100644 --- a/cmake/FindLapack.cmake +++ b/cmake/FindLapack.cmake @@ -9,6 +9,12 @@ endif() find_package(Blas REQUIRED) 
find_package(LAPACK REQUIRED) +find_path(LAPACKE_INCLUDE_DIR + NAMES lapacke.h + PATHS ${LAPACK_DIR} ${LAPACKE_DIR} ${CMAKE_PREFIX_PATH} + PATH_SUFFIXES include include/lapacke + DOC "Path to LAPACKE include directory" +) # find LAPACKE libraries find_library(LAPACKE_LIBRARY NAMES lapacke @@ -17,20 +23,10 @@ find_library(LAPACKE_LIBRARY DOC "Path to LAPACKE library" ) -if(NOT TARGET LAPACK::LAPACK) - add_library(LAPACK::LAPACK UNKNOWN IMPORTED) - set_target_properties(LAPACK::LAPACK PROPERTIES - IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${LAPACK_LIBRARIES}") -endif() - if(NOT TARGET LAPACKE::LAPACKE) add_library(LAPACKE::LAPACKE UNKNOWN IMPORTED) set_target_properties(LAPACKE::LAPACKE PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${LAPACKE_INCLUDE_DIR}" IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${LAPACKE_LIBRARY}") - - set_target_properties(LAPACKE::LAPACKE PROPERTIES - INTERFACE_LINK_LIBRARIES "LAPACK::LAPACK" - ) endif() diff --git a/source/source_base/CMakeLists.txt b/source/source_base/CMakeLists.txt index 9b0cc046f0..e2f08b7ccf 100644 --- a/source/source_base/CMakeLists.txt +++ b/source/source_base/CMakeLists.txt @@ -10,11 +10,6 @@ add_library( base OBJECT assoc_laguerre.cpp - module_external/blas_connector_base.cpp - module_external/blas_connector_l1.cpp - module_external/blas_connector_l2.cpp - module_external/blas_connector_l3.cpp - module_external/lapack_connector.cpp clebsch_gordan_coeff.cpp complexarray.cpp complexmatrix.cpp diff --git a/source/source_base/kernels/test/math_kernel_test.cpp b/source/source_base/kernels/test/math_kernel_test.cpp index 69bcce784e..2590a1e08b 100644 --- a/source/source_base/kernels/test/math_kernel_test.cpp +++ b/source/source_base/kernels/test/math_kernel_test.cpp @@ -347,17 +347,18 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_cpu) int inc = 1; int row = 2; int col = 3; - zgemv_(&trans, - &row, - &col, - &ModuleBase::ONE, + BlasConnector::gemv_cm( + trans, + row, + col, + ModuleBase::ONE, 
A_gemv.data(), - &row, + row, X_gemv.data(), - &inc, - &ModuleBase::ONE, + inc, + ModuleBase::ONE, Y_test_gemv.data(), - &inc); + inc); for (int i = 0; i < Y_gemv.size(); i++) { EXPECT_LT(fabs(Y_gemv[i].imag() - Y_test_gemv[i].imag()), 1e-12); @@ -607,17 +608,18 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu) int inc = 1; int row = 2; int col = 3; - zgemv_(&trans, - &row, - &col, - &ModuleBase::ONE, + BlasConnector::gemv( + trans, + row, + col, + ModuleBase::ONE, A_gemv.data(), - &row, + row, X_gemv.data(), - &inc, - &ModuleBase::ONE, + inc, + ModuleBase::ONE, Y_test_gemv.data(), - &inc); + inc); for (int i = 0; i < Y_gemv.size(); i++) { diff --git a/source/source_base/module_external/blas_connector.h b/source/source_base/module_external/blas_connector.h index fee0995bd2..3221bb93d3 100644 --- a/source/source_base/module_external/blas_connector.h +++ b/source/source_base/module_external/blas_connector.h @@ -221,28 +221,6 @@ class BlasConnector static void copy(const long n, const std::complex *a, const int incx, std::complex *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - // There is some other operators needed, so implemented manually here - template - static - void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - template - static - void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - // y = alpha * x + beta * y - static - void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - static - void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const 
double *vector2, const double constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - static - void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const float constant1, const std::complex *vector2, const float constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); - - static - void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const double constant1, const std::complex *vector2, const double constant2, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice); }; namespace BlasUtils { diff --git a/source/source_base/module_external/blas_connector_l1.cpp b/source/source_base/module_external/blas_connector_l1.cpp index f9c5925143..9090f71539 100644 --- a/source/source_base/module_external/blas_connector_l1.cpp +++ b/source/source_base/module_external/blas_connector_l1.cpp @@ -374,133 +374,4 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in else { throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); } -} - -template -void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ - using Real = typename GetTypeReal::type; - if (device_type == base_device::AbacusDevice_t::CpuDevice) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * vector2[i]; - } - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - ModuleBase::vector_mul_vector_op()(dim, result, vector1, vector2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - - -template -void 
vector_div_vector(const int& dim, T* result, const T* vector1, const T* vector2, base_device::AbacusDevice_t device_type){ - using Real = typename GetTypeReal::type; - if (device_type == base_device::AbacusDevice_t::CpuDevice) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096 / sizeof(Real)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] / vector2[i]; - } - } -#ifdef __CUDA - else if (device_type == base_device::AbacusDevice_t::GpuDevice) { - ModuleBase::vector_div_vector_op()(dim, result, vector1, vector2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(float)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(double)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == 
base_device::GpuDevice) { - ModuleBase::vector_add_vector_op()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const float constant1, const std::complex *vector2, const float constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } -} - -void vector_add_vector(const int& dim, std::complex *result, const std::complex *vector1, const double constant1, const std::complex *vector2, const double constant2, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::CpuDevice){ -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 8192 / sizeof(std::complex)) -#endif - for (int i = 0; i < dim; i++) - { - result[i] = vector1[i] * constant1 + vector2[i] * constant2; - } - } -#ifdef __CUDA - else if (device_type == base_device::GpuDevice) { - ModuleBase::vector_add_vector_op, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2); - } -#endif - else { - throw std::invalid_argument("device_type = " + std::to_string(device_type) + " in " + std::string(__FILE__) + " line " + std::to_string(__LINE__)); - } } \ No newline at end of file diff --git 
a/source/source_base/module_external/lapack_connector.cpp b/source/source_base/module_external/lapack_connector.cpp index 41d3b6c016..772eb87f93 100644 --- a/source/source_base/module_external/lapack_connector.cpp +++ b/source/source_base/module_external/lapack_connector.cpp @@ -1,6 +1,5 @@ #include #include "lapack_connector.h" -#include "source_base/tool_quit.h" namespace LapackConnector { @@ -14,7 +13,7 @@ void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::comp int info = LAPACKE_chegv(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegv failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK chegv failed with info = " + std::to_string(info)); } } @@ -22,7 +21,7 @@ void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::comp { int info = LAPACKE_zhegv(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegv failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zhegv failed with info = " + std::to_string(info)); } } @@ -30,7 +29,7 @@ void hegv(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* a { int info = LAPACKE_dsygv(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygv failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsygv failed with info = " + std::to_string(info)); } } @@ -38,7 +37,7 @@ void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, float* a { int info = LAPACKE_ssygvd(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssygvd failed with info = " + 
std::to_string(info)); + throw std::invalid_argument("LAPACK ssygvd failed with info = " + std::to_string(info)); } } @@ -46,7 +45,7 @@ void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, double* { int info = LAPACKE_dsygvd(toLapackLayout(layout), itype, jobz, uplo, n, a, lda, b, ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygvd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsygvd failed with info = " + std::to_string(info)); } } @@ -54,7 +53,7 @@ void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::com { int info = LAPACKE_chegvd(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegvd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK chegvd failed with info = " + std::to_string(info)); } } @@ -62,7 +61,7 @@ void hegvd(MatrixLayout layout, int itype, char jobz, char uplo, int n, std::com { int info = LAPACKE_zhegvd(toLapackLayout(layout), itype, jobz, uplo, n, reinterpret_cast(a), lda, reinterpret_cast(b), ldb, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegvd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zhegvd failed with info = " + std::to_string(info)); } } @@ -77,7 +76,7 @@ void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int vl, vu, il, iu, abstol, m, w, reinterpret_cast(z), ldz, ifail); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK chegvx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK chegvx failed with info = " + std::to_string(info)); } } @@ -92,7 +91,7 @@ void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int vl, vu, il, iu, abstol, m, w, reinterpret_cast(z), ldz, ifail); if (info != 0) { - 
ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zhegvx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zhegvx failed with info = " + std::to_string(info)); } } @@ -106,7 +105,7 @@ void hegvx(MatrixLayout layout, int itype, char jobz, char range, char uplo, int vl, vu, il, iu, abstol, m, w, z, ldz, ifail); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsygvx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsygvx failed with info = " + std::to_string(info)); } } @@ -114,7 +113,7 @@ void potrf(MatrixLayout layout, char uplo, int n, float* a, int lda) { int info = LAPACKE_spotrf(toLapackLayout(layout), uplo, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK spotrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK spotrf failed with info = " + std::to_string(info)); } } @@ -122,7 +121,7 @@ void potrf(MatrixLayout layout, char uplo, int n, double* a, int lda) { int info = LAPACKE_dpotrf(toLapackLayout(layout), uplo, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dpotrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dpotrf failed with info = " + std::to_string(info)); } } @@ -130,7 +129,7 @@ void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int ld { int info = LAPACKE_cpotrf(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cpotrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cpotrf failed with info = " + std::to_string(info)); } } @@ -138,7 +137,7 @@ void potrf(MatrixLayout layout, char uplo, int n, std::complex* a, int l { int info = LAPACKE_zpotrf(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zpotrf failed with info 
= " + std::to_string(info)); + throw std::invalid_argument("LAPACK zpotrf failed with info = " + std::to_string(info)); } } @@ -146,7 +145,7 @@ void potri(MatrixLayout layout, char uplo, int n, float* a, int lda) { int info = LAPACKE_spotri(toLapackLayout(layout), uplo, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK spotri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK spotri failed with info = " + std::to_string(info)); } } @@ -154,7 +153,7 @@ void potri(MatrixLayout layout, char uplo, int n, double* a, int lda) { int info = LAPACKE_dpotri(toLapackLayout(layout), uplo, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dpotri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dpotri failed with info = " + std::to_string(info)); } } @@ -162,7 +161,7 @@ void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int ld { int info = LAPACKE_cpotri(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cpotri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cpotri failed with info = " + std::to_string(info)); } } @@ -170,7 +169,7 @@ void potri(MatrixLayout layout, char uplo, int n, std::complex* a, int l { int info = LAPACKE_zpotri(toLapackLayout(layout), uplo, n, reinterpret_cast(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zpotri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zpotri failed with info = " + std::to_string(info)); } } @@ -178,7 +177,7 @@ void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex* { int info = LAPACKE_cheev(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheev failed with info = " + std::to_string(info)); + 
throw std::invalid_argument("LAPACK cheev failed with info = " + std::to_string(info)); } } @@ -186,7 +185,7 @@ void heev(MatrixLayout layout, char jobz, char uplo, int n, std::complex { int info = LAPACKE_zheev(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheev failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zheev failed with info = " + std::to_string(info)); } } @@ -195,7 +194,7 @@ void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, float* { int info = LAPACKE_ssyevx(toLapackLayout(layout), jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, ifail); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssyevx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK ssyevx failed with info = " + std::to_string(info)); } } @@ -204,7 +203,7 @@ void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, double* { int info = LAPACKE_dsyevx(toLapackLayout(layout), jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, ifail); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsyevx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsyevx failed with info = " + std::to_string(info)); } } @@ -217,7 +216,7 @@ void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::co vl, vu, il, iu, abstol, m, w, reinterpret_cast(z), ldz, ifail); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheevx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cheevx failed with info = " + std::to_string(info)); } } @@ -230,7 +229,7 @@ void heevx(MatrixLayout layout, char jobz, char range, char uplo, int n, std::co vl, vu, il, iu, abstol, m, w, reinterpret_cast(z), ldz, ifail); if (info != 0) { - 
ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheevx failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zheevx failed with info = " + std::to_string(info)); } } @@ -239,7 +238,7 @@ void heevd(MatrixLayout layout, char jobz, char uplo, int n, { int info = LAPACKE_ssyevd(toLapackLayout(layout), jobz, uplo, n, a, lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ssyevd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK ssyevd failed with info = " + std::to_string(info)); } } @@ -248,7 +247,7 @@ void heevd(MatrixLayout layout, char jobz, char uplo, int n, { int info = LAPACKE_dsyevd(toLapackLayout(layout), jobz, uplo, n, a, lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsyevd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsyevd failed with info = " + std::to_string(info)); } } @@ -257,7 +256,7 @@ void heevd(MatrixLayout layout, char jobz, char uplo, int n, { int info = LAPACKE_cheevd(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cheevd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cheevd failed with info = " + std::to_string(info)); } } @@ -266,7 +265,7 @@ void heevd(MatrixLayout layout, char jobz, char uplo, int n, { int info = LAPACKE_zheevd(toLapackLayout(layout), jobz, uplo, n, reinterpret_cast(a), lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zheevd failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zheevd failed with info = " + std::to_string(info)); } } @@ -274,7 +273,7 @@ void syev(MatrixLayout layout, char jobz, char uplo, int n, double* a, int lda, { int info = LAPACKE_dsyev(toLapackLayout(layout), jobz, uplo, n, a, lda, w); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", 
"LAPACK dsyev failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsyev failed with info = " + std::to_string(info)); } } @@ -283,7 +282,7 @@ void geev(MatrixLayout layout, char jobvl, char jobvr, int n, double* a, int lda { int info = LAPACKE_dgeev(toLapackLayout(layout), jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgeev failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dgeev failed with info = " + std::to_string(info)); } } @@ -294,7 +293,7 @@ void geev(MatrixLayout layout, char jobvl, char jobvr, int n, std::complex(w), reinterpret_cast(vl), ldvl, reinterpret_cast(vr), ldvr); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgeev failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zgeev failed with info = " + std::to_string(info)); } } @@ -302,7 +301,7 @@ void getrf(MatrixLayout layout, int m, int n, float* a, int lda, int* ipiv) { int info = LAPACKE_sgetrf(toLapackLayout(layout), m, n, a, lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK sgetrf failed with info = " + std::to_string(info)); } } @@ -310,7 +309,7 @@ void getrf(MatrixLayout layout, int m, int n, double* a, int lda, int* ipiv) { int info = LAPACKE_dgetrf(toLapackLayout(layout), m, n, a, lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dgetrf failed with info = " + std::to_string(info)); } } @@ -318,7 +317,7 @@ void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, i { int info = LAPACKE_cgetrf(toLapackLayout(layout), m, n, reinterpret_cast(a), lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetrf failed 
with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cgetrf failed with info = " + std::to_string(info)); } } @@ -326,7 +325,7 @@ void getrf(MatrixLayout layout, int m, int n, std::complex* a, int lda, { int info = LAPACKE_zgetrf(toLapackLayout(layout), m, n, reinterpret_cast(a), lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zgetrf failed with info = " + std::to_string(info)); } } @@ -334,7 +333,7 @@ void getri(MatrixLayout layout, int n, float* a, int lda, const int* ipiv) { int info = LAPACKE_sgetri(toLapackLayout(layout), n, a, lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK sgetri failed with info = " + std::to_string(info)); } } @@ -342,7 +341,7 @@ void getri(MatrixLayout layout, int n, double* a, int lda, const int* ipiv) { int info = LAPACKE_dgetri(toLapackLayout(layout), n, a, lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dgetri failed with info = " + std::to_string(info)); } } @@ -350,7 +349,7 @@ void getri(MatrixLayout layout, int n, std::complex* a, int lda, const in { int info = LAPACKE_cgetri(toLapackLayout(layout), n, reinterpret_cast(a), lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cgetri failed with info = " + std::to_string(info)); } } @@ -358,7 +357,7 @@ void getri(MatrixLayout layout, int n, std::complex* a, int lda, const i { int info = LAPACKE_zgetri(toLapackLayout(layout), n, reinterpret_cast(a), lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetri failed with info = " + 
std::to_string(info)); + throw std::invalid_argument("LAPACK zgetri failed with info = " + std::to_string(info)); } } @@ -366,7 +365,7 @@ void getrs(MatrixLayout layout, char trans, int n, int nrhs, const float* a, int { int info = LAPACKE_sgetrs(toLapackLayout(layout), trans, n, nrhs, a, lda, ipiv, b, ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK sgetrs failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK sgetrs failed with info = " + std::to_string(info)); } } @@ -374,7 +373,7 @@ void getrs(MatrixLayout layout, char trans, int n, int nrhs, const double* a, in { int info = LAPACKE_dgetrs(toLapackLayout(layout), trans, n, nrhs, a, lda, ipiv, b, ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgetrs failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dgetrs failed with info = " + std::to_string(info)); } } @@ -382,7 +381,7 @@ void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex< { int info = LAPACKE_cgetrs(toLapackLayout(layout), trans, n, nrhs, reinterpret_cast(a), lda, ipiv, reinterpret_cast(b), ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK cgetrs failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK cgetrs failed with info = " + std::to_string(info)); } } @@ -390,7 +389,7 @@ void getrs(MatrixLayout layout, char trans, int n, int nrhs, const std::complex< { int info = LAPACKE_zgetrs(toLapackLayout(layout), trans, n, nrhs, reinterpret_cast(a), lda, ipiv, reinterpret_cast(b), ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK zgetrs failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK zgetrs failed with info = " + std::to_string(info)); } } @@ -398,7 +397,7 @@ void sytrf(MatrixLayout layout, char uplo, int n, double* a, int lda, int* ipiv) { int info = LAPACKE_dsytrf(toLapackLayout(layout), uplo, n, a, 
lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsytrf failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsytrf failed with info = " + std::to_string(info)); } } @@ -406,7 +405,7 @@ void sytri(MatrixLayout layout, char uplo, int n, double* a, int lda, const int* { int info = LAPACKE_dsytri(toLapackLayout(layout), uplo, n, a, lda, ipiv); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsytri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsytri failed with info = " + std::to_string(info)); } } @@ -414,7 +413,7 @@ void trtri(MatrixLayout layout, char uplo, char diag, int n, float* a, int lda) { int info = LAPACKE_strtri(toLapackLayout(layout), uplo, diag, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK strtri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK strtri failed with info = " + std::to_string(info)); } } @@ -422,7 +421,7 @@ void trtri(MatrixLayout layout, char uplo, char diag, int n, double* a, int lda) { int info = LAPACKE_dtrtri(toLapackLayout(layout), uplo, diag, n, a, lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dtrtri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dtrtri failed with info = " + std::to_string(info)); } } @@ -430,7 +429,7 @@ void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK ztrtri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK ztrtri failed with info = " + std::to_string(info)); } } @@ -438,7 +437,7 @@ void trtri(MatrixLayout layout, char uplo, char diag, int n, std::complex { int info = LAPACKE_ctrtri(toLapackLayout(layout), uplo, diag, n, reinterpret_cast(a), lda); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK 
ctrtri failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK ctrtri failed with info = " + std::to_string(info)); } } @@ -446,7 +445,7 @@ void gtsv(MatrixLayout layout, int n, int nrhs, double* dl, double* d, double* d { int info = LAPACKE_dgtsv(toLapackLayout(layout), n, nrhs, dl, d, du, b, ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dgtsv failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dgtsv failed with info = " + std::to_string(info)); } } @@ -454,7 +453,7 @@ void sysv(MatrixLayout layout, char uplo, int n, int nrhs, double* a, int lda, i { int info = LAPACKE_dsysv(toLapackLayout(layout), uplo, n, nrhs, a, lda, ipiv, b, ldb); if (info != 0) { - ModuleBase::WARNING_QUIT("lapackConnector", "LAPACK dsysv failed with info = " + std::to_string(info)); + throw std::invalid_argument("LAPACK dsysv failed with info = " + std::to_string(info)); } } } // namespace LapackConnector diff --git a/source/source_base/module_grid/test/CMakeLists.txt b/source/source_base/module_grid/test/CMakeLists.txt index 6e08a90308..7cc3b2a5a4 100644 --- a/source/source_base/module_grid/test/CMakeLists.txt +++ b/source/source_base/module_grid/test/CMakeLists.txt @@ -23,11 +23,6 @@ AddTest( AddTest( TARGET MODULE_BASE_GRID_test_batch - SOURCES test_batch.cpp - ../batch.cpp - ../../blas_connector_base.cpp - ../../blas_connector_l1.cpp - ../../blas_connector_l2.cpp - ../../blas_connector_l3.cpp - ../../lapack_connector.cpp + SOURCES test_batch.cpp ../batch.cpp + LIBS parameter ${math_libs} ) diff --git a/source/source_base/test/CMakeLists.txt b/source/source_base/test/CMakeLists.txt index a39ba8436a..915c837752 100644 --- a/source/source_base/test/CMakeLists.txt +++ b/source/source_base/test/CMakeLists.txt @@ -86,7 +86,7 @@ AddTest( ) AddTest( TARGET MODULE_BASE_math_sphbes - LIBS parameter + LIBS parameter SOURCES math_sphbes_test.cpp ../math_sphbes.cpp ../timer.cpp ) AddTest( diff --git 
a/source/source_base/test/blas_connector_test.cpp b/source/source_base/test/blas_connector_test.cpp index 43a6728cc3..19dc29bfc0 100644 --- a/source/source_base/test/blas_connector_test.cpp +++ b/source/source_base/test/blas_connector_test.cpp @@ -6,7 +6,7 @@ #include #include #include -TEST(blas_connector, sscal_) { +TEST(blas_connector, sscal) { typedef float T; const int size = 8; const T scale = 2; @@ -16,12 +16,12 @@ TEST(blas_connector, sscal_) { []() { return std::rand() / T(RAND_MAX); }); for (int i = 0; i < size; i++) answer[i] = result[i] * scale; - sscal_(&size, &scale, result.data(), &incx); + BlasConnector::scal(size, scale, result.data(), incx); for (int i = 0; i < size; i++) EXPECT_FLOAT_EQ(answer[i], result[i]); } -TEST(blas_connector, dscal_) { +TEST(blas_connector, dscal) { typedef double T; const int size = 8; const T scale = 2; @@ -31,12 +31,12 @@ TEST(blas_connector, dscal_) { []() { return std::rand() / T(RAND_MAX); }); for (int i = 0; i < size; i++) answer[i] = result[i] * scale; - dscal_(&size, &scale, result.data(), &incx); + BlasConnector::scal(size, scale, result.data(), incx); for (int i = 0; i < size; i++) EXPECT_DOUBLE_EQ(answer[i], result[i]); } -TEST(blas_connector, cscal_) { +TEST(blas_connector, cscal) { typedef std::complex T; const int size = 8; const T scale = {2, 3}; @@ -48,14 +48,14 @@ TEST(blas_connector, cscal_) { }); for (int i = 0; i < size; i++) answer[i] = result[i] * scale; - cscal_(&size, &scale, result.data(), &incx); + BlasConnector::scal(size, scale, result.data(), incx); for (int i = 0; i < size; i++) { EXPECT_FLOAT_EQ(answer[i].real(), result[i].real()); EXPECT_FLOAT_EQ(answer[i].imag(), result[i].imag()); } } -TEST(blas_connector, zscal_) { +TEST(blas_connector, zscal) { typedef std::complex T; const int size = 8; const T scale = {2, 3}; @@ -67,26 +67,7 @@ TEST(blas_connector, zscal_) { }); for (int i = 0; i < size; i++) answer[i] = result[i] * scale; - zscal_(&size, &scale, result.data(), &incx); - for (int i = 
0; i < size; i++) { - EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); - EXPECT_DOUBLE_EQ(answer[i].imag(), result[i].imag()); - } -} - -TEST(blas_connector, Scal) { - const int size = 8; - const std::complex scale = {2, 3}; - const int incx = 1; - std::complex result[8], answer[8]; - for (int i=0; i< size; i++) { - result[i] = std::complex{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }; - for (int i = 0; i < size; i++) - answer[i] = result[i] * scale; - BlasConnector::scal(size,scale,result,incx); - // incx is the spacing between elements if result + BlasConnector::scal(size, scale, result.data(), incx); for (int i = 0; i < size; i++) { EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); EXPECT_DOUBLE_EQ(answer[i].imag(), result[i].imag()); @@ -121,7 +102,7 @@ TEST(blas_connector, ScalGpu) { #endif -TEST(blas_connector, daxpy_) { +TEST(blas_connector, daxpy) { typedef double T; const int size = 8; const T scale = 2; @@ -134,36 +115,12 @@ TEST(blas_connector, daxpy_) { []() { return std::rand() / double(RAND_MAX); }); for (int i = 0; i < size; i++) answer[i] = x_const[i] * scale + result[i]; - daxpy_(&size, &scale, x_const.data(), &incx, result.data(), &incy); + BlasConnector::axpy(size, scale, x_const.data(), incx, result.data(), incy); for (int i = 0; i < size; i++) EXPECT_DOUBLE_EQ(answer[i], result[i]); } -TEST(blas_connector, zaxpy_) { - typedef std::complex T; - const int size = 8; - const T scale = {2, 3}; - const int incx = 1; - const int incy = 1; - std::array x_const, result, answer; - std::generate(x_const.begin(), x_const.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(result.begin(), result.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - for (int i = 0; i < size; i++) - answer[i] = x_const[i] * scale + result[i]; - zaxpy_(&size, 
&scale, x_const.data(), &incx, result.data(), &incy); - for (int i = 0; i < size; i++) { - EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); - EXPECT_DOUBLE_EQ(answer[i].imag(), result[i].imag()); - } -} - -TEST(blas_connector, Axpy) { +TEST(blas_connector, zaxpy) { typedef std::complex T; const int size = 8; const T scale = {2, 3}; @@ -224,7 +181,7 @@ TEST(blas_connector, AxpyGpu) { #endif -TEST(blas_connector, dcopy_) { +TEST(blas_connector, dcopy) { typedef double T; long const size = 8; int const incx = 1; @@ -234,12 +191,12 @@ TEST(blas_connector, dcopy_) { []() { return std::rand() / double(RAND_MAX); }); for (int i = 0; i < size; i++) answer[i] = x_const[i]; - dcopy_(&size, x_const.data(), &incx, result.data(), &incy); + BlasConnector::copy(size, x_const.data(), incx, result.data(), incy); for (int i = 0; i < size; i++) EXPECT_DOUBLE_EQ(answer[i], result[i]); } -TEST(blas_connector, zcopy_) { +TEST(blas_connector, zcopy) { typedef std::complex T; long const size = 8; int const incx = 1; @@ -251,32 +208,14 @@ TEST(blas_connector, zcopy_) { }); for (int i = 0; i < size; i++) answer[i] = x_const[i]; - zcopy_(&size, x_const.data(), &incx, result.data(), &incy); - for (int i = 0; i < size; i++) { - EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); - EXPECT_DOUBLE_EQ(answer[i].imag(), result[i].imag()); - } -} - -TEST(blas_connector, copy) { - long const size = 8; - int const incx = 1; - int const incy = 1; - std::complex result[8], answer[8]; - for (int i = 0; i < size; i++) - { - answer[i] = std::complex{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - } - BlasConnector bs; - bs.copy(size, answer, incx, result, incy); + BlasConnector::copy(size, x_const.data(), incx, result.data(), incy); for (int i = 0; i < size; i++) { EXPECT_DOUBLE_EQ(answer[i].real(), result[i].real()); EXPECT_DOUBLE_EQ(answer[i].imag(), result[i].imag()); } } -TEST(blas_connector, dgemv_) { +TEST(blas_connector, dgemv) { typedef double 
T; const char transa_m = 'N'; const char transa_n = 'T'; @@ -306,8 +245,8 @@ TEST(blas_connector, dgemv_) { } answer_m[i] = alpha_const * c_dot_m[i] + beta_const * result_m[i]; } - dgemv_(&transa_m, &size_m, &size_n, &alpha_const, a_const.data(), &lda, - x_const_n.data(), &incx, &beta_const, result_m.data(), &incy); + BlasConnector::gemv_cm(transa_m, size_m, size_n, alpha_const, a_const.data(), lda, + x_const_n.data(), incx, beta_const, result_m.data(), incy); for (int j = 0; j < size_n; j++) { for (int i = 0; i < size_m; i++) { @@ -315,8 +254,8 @@ TEST(blas_connector, dgemv_) { } answer_n[j] = alpha_const * c_dot_n[j] + beta_const * result_n[j]; } - dgemv_(&transa_n, &size_m, &size_n, &alpha_const, a_const.data(), &lda, - x_const_m.data(), &incx, &beta_const, result_n.data(), &incy); + BlasConnector::gemv_cm(transa_n, size_m, size_n, alpha_const, a_const.data(), lda, + x_const_m.data(), incx, beta_const, result_n.data(), incy); for (int i = 0; i < size_m; i++) EXPECT_DOUBLE_EQ(answer_m[i], result_m[i]); @@ -324,84 +263,7 @@ TEST(blas_connector, dgemv_) { EXPECT_DOUBLE_EQ(answer_n[j], result_n[j]); } -TEST(blas_connector, zgemv_) { - typedef std::complex T; - const char transa_m = 'N'; - const char transa_n = 'T'; - const char transa_h = 'C'; - const int size_m = 3; - const int size_n = 4; - const T alpha_const = {2, 3}; - const T beta_const = {3, 4}; - const int lda = 5; - const int incx = 1; - const int incy = 1; - std::array x_const_m, x_const_c, result_m, answer_m, c_dot_m{}; - std::array x_const_n, result_n, result_c, answer_n, answer_c, - c_dot_n{}, c_dot_c{}; - std::generate(x_const_n.begin(), x_const_n.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(result_n.begin(), result_n.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(x_const_m.begin(), x_const_m.end(), []() { - return 
T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(result_m.begin(), result_m.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::array a_const; - std::generate(a_const.begin(), a_const.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - for (int i = 0; i < size_m; i++) { - for (int j = 0; j < size_n; j++) { - c_dot_m[i] += a_const[i + j * lda] * x_const_n[j]; - } - answer_m[i] = alpha_const * c_dot_m[i] + beta_const * result_m[i]; - } - zgemv_(&transa_m, &size_m, &size_n, &alpha_const, a_const.data(), &lda, - x_const_n.data(), &incx, &beta_const, result_m.data(), &incy); - - for (int j = 0; j < size_n; j++) { - for (int i = 0; i < size_m; i++) { - c_dot_n[j] += a_const[i + j * lda] * x_const_m[i]; - } - answer_n[j] = alpha_const * c_dot_n[j] + beta_const * result_n[j]; - } - zgemv_(&transa_n, &size_m, &size_n, &alpha_const, a_const.data(), &lda, - x_const_m.data(), &incx, &beta_const, result_n.data(), &incy); - - for (int j = 0; j < size_n; j++) { - for (int i = 0; i < size_m; i++) { - c_dot_c[j] += conj(a_const[i + j * lda]) * x_const_c[i]; - } - answer_c[j] = alpha_const * c_dot_c[j] + beta_const * result_c[j]; - } - zgemv_(&transa_h, &size_m, &size_n, &alpha_const, a_const.data(), &lda, - x_const_c.data(), &incx, &beta_const, result_c.data(), &incy); - - for (int i = 0; i < size_m; i++) { - EXPECT_DOUBLE_EQ(answer_m[i].real(), result_m[i].real()); - EXPECT_DOUBLE_EQ(answer_m[i].imag(), result_m[i].imag()); - } - for (int j = 0; j < size_n; j++) { - EXPECT_DOUBLE_EQ(answer_n[j].real(), result_n[j].real()); - EXPECT_DOUBLE_EQ(answer_n[j].imag(), result_n[j].imag()); - } - for (int j = 0; j < size_n; j++) { - EXPECT_DOUBLE_EQ(answer_c[j].real(), result_c[j].real()); - EXPECT_DOUBLE_EQ(answer_c[j].imag(), result_c[j].imag()); - } -} - 
-TEST(blas_connector, Gemv) { +TEST(blas_connector, gemv) { typedef std::complex T; const char transa_m = 'N'; const char transa_n = 'T'; @@ -478,8 +340,7 @@ TEST(blas_connector, Gemv) { } } - -TEST(blas_connector, dgemm_) { +TEST(blas_connector, dgemm) { typedef double T; const char transa_m = 'N'; const char transb_m = 'N'; @@ -510,9 +371,9 @@ TEST(blas_connector, dgemm_) { beta_const * result[i + j * ldc]; } } - dgemm_(&transa_m, &transb_m, &size_m, &size_n, &size_k, &alpha_const, - a_const.data(), &lda, b_const.data(), &ldb, &beta_const, - result.data(), &ldc); + BlasConnector::gemm(transa_m, transb_m, size_m, size_n, size_k, alpha_const, + a_const.data(), lda, b_const.data(), ldb, beta_const, + result.data(), ldc); for (int i = 0; i < size_m; i++) for (int j = 0; j < size_n; j++) { @@ -520,57 +381,7 @@ TEST(blas_connector, dgemm_) { } } -TEST(blas_connector, zgemm_) { - typedef std::complex T; - const char transa_m = 'N'; - const char transb_m = 'N'; - const int size_m = 3; - const int size_n = 4; - const int size_k = 5; - const T alpha_const = {2, 3}; - const T beta_const = {3, 4}; - const int lda = 6; - const int ldb = 5; - const int ldc = 4; - std::array a_const; - std::array b_const; - std::array c_dot{}, answer, result; - std::generate(a_const.begin(), a_const.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(b_const.begin(), b_const.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - std::generate(result.begin(), result.end(), []() { - return T{static_cast(std::rand() / double(RAND_MAX)), - static_cast(std::rand() / double(RAND_MAX))}; - }); - for (int i = 0; i < size_m; i++) { - for (int j = 0; j < size_n; j++) { - for (int k = 0; k < size_k; k++) { - c_dot[i + j * ldc] += - a_const[i + k * lda] * b_const[k + j * ldb]; - } - answer[i + j * ldc] = alpha_const * c_dot[i + j * ldc] + - 
beta_const * result[i + j * ldc]; - } - } - zgemm_(&transa_m, &transb_m, &size_m, &size_n, &size_k, &alpha_const, - a_const.data(), &lda, b_const.data(), &ldb, &beta_const, - result.data(), &ldc); - - for (int i = 0; i < size_m; i++) - for (int j = 0; j < size_n; j++) { - EXPECT_DOUBLE_EQ(answer[i + j * ldc].real(), - result[i + j * ldc].real()); - EXPECT_DOUBLE_EQ(answer[i + j * ldc].imag(), - result[i + j * ldc].imag()); - } -} - -TEST(blas_connector, Gemm) { +TEST(blas_connector, zgemm) { typedef std::complex T; const char transa_m = 'N'; const char transb_m = 'N'; @@ -607,7 +418,7 @@ TEST(blas_connector, Gemm) { beta_const * result[i + j * ldc]; } } - BlasConnector::gemm_cm(transa_m, transb_m, size_m, size_n, size_k, alpha_const, + BlasConnector::gemm(transa_m, transb_m, size_m, size_n, size_k, alpha_const, a_const.data(), lda, b_const.data(), ldb, beta_const, result.data(), ldc); diff --git a/source/source_basis/module_ao/test/CMakeLists.txt b/source/source_basis/module_ao/test/CMakeLists.txt index ef769ec459..bd4393f5a9 100644 --- a/source/source_basis/module_ao/test/CMakeLists.txt +++ b/source/source_basis/module_ao/test/CMakeLists.txt @@ -7,11 +7,6 @@ list(APPEND depend_files ../../../source_base/math_ylmreal.cpp ../../../source_base/ylm.cpp ../../../source_base/memory.cpp - ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_l1.cpp - ../../../source_base/module_external/blas_connector_l2.cpp - ../../../source_base/module_external/blas_connector_l3.cpp - ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/complexarray.cpp ../../../source_base/complexmatrix.cpp ../../../source_base/matrix.cpp diff --git a/source/source_basis/module_pw/kernels/test/CMakeLists.txt b/source/source_basis/module_pw/kernels/test/CMakeLists.txt index f90d15e9c5..7e5555bf37 100644 --- a/source/source_basis/module_pw/kernels/test/CMakeLists.txt +++ 
b/source/source_basis/module_pw/kernels/test/CMakeLists.txt @@ -9,9 +9,4 @@ AddTest( ../../../../source_base/parallel_comm.cpp ../../../../source_base/complexmatrix.cpp ../../../../source_base/matrix.cpp ../../../../source_base/memory.cpp ../../../../source_base/libm/branred.cpp ../../../../source_base/libm/sincos.cpp - ../../../../source_base/module_external/blas_connector_base.cpp - ../../../../source_base/module_external/blas_connector_l1.cpp - ../../../../source_base/module_external/blas_connector_l2.cpp - ../../../../source_base/module_external/blas_connector_l3.cpp - ../../../../source_base/module_external/lapack_connector.cpp ) \ No newline at end of file diff --git a/source/source_basis/module_pw/test/CMakeLists.txt b/source/source_basis/module_pw/test/CMakeLists.txt index 7359383905..af887e604b 100644 --- a/source/source_basis/module_pw/test/CMakeLists.txt +++ b/source/source_basis/module_pw/test/CMakeLists.txt @@ -4,11 +4,6 @@ AddTest( LIBS parameter ${math_libs} planewave device SOURCES ../../../source_base/matrix.cpp ../../../source_base/complexmatrix.cpp ../../../source_base/matrix3.cpp ../../../source_base/tool_quit.cpp ../../../source_base/mymath.cpp ../../../source_base/timer.cpp ../../../source_base/memory.cpp - ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_l1.cpp - ../../../source_base/module_external/blas_connector_l2.cpp - ../../../source_base/module_external/blas_connector_l3.cpp - ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/libm/branred.cpp ../../../source_base/libm/sincos.cpp ../../../source_base/module_device/memory_op.cpp depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp diff --git a/source/source_hamilt/module_xc/test/CMakeLists.txt b/source/source_hamilt/module_xc/test/CMakeLists.txt index 857bd7516e..57e25e2950 100644 --- a/source/source_hamilt/module_xc/test/CMakeLists.txt +++ 
b/source/source_hamilt/module_xc/test/CMakeLists.txt @@ -40,11 +40,6 @@ AddTest( ../../../source_base/memory.cpp ../../../source_base/libm/branred.cpp ../../../source_base/libm/sincos.cpp - ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_l1.cpp - ../../../source_base/module_external/blas_connector_l2.cpp - ../../../source_base/module_external/blas_connector_l3.cpp - ../../../source_base/module_external/lapack_connector.cpp ../../../source_basis/module_pw/module_fft/fft_bundle.cpp ../../../source_basis/module_pw/module_fft/fft_cpu.cpp ${FFT_SRC} @@ -77,11 +72,6 @@ AddTest( ../xc_functional_vxc.cpp ../xc_functional_libxc_vxc.cpp ../xc_functional_libxc_tools.cpp - ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_l1.cpp - ../../../source_base/module_external/blas_connector_l2.cpp - ../../../source_base/module_external/blas_connector_l3.cpp - ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/matrix.cpp ../../../source_base/memory.cpp ../../../source_base/timer.cpp diff --git a/source/source_hsolver/test/diago_cg_test.cpp b/source/source_hsolver/test/diago_cg_test.cpp index d65b53cfbe..e5285cdf58 100644 --- a/source/source_hsolver/test/diago_cg_test.cpp +++ b/source/source_hsolver/test/diago_cg_test.cpp @@ -44,7 +44,7 @@ void lapackEigen(int &npw, std::vector> &hm, double *e, boo { clock_t start, end; start = clock(); - LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, hm.data(), npw, e); + LapackConnector::heev(LapackConnector::ColMajor, 'V', 'U', npw, hm.data(), npw, e); end = clock(); if (outtime) { std::cout << "Lapack Run time: " << (double)(end - start) / CLOCKS_PER_SEC << " S" << std::endl; diff --git a/source/source_hsolver/test/diago_david_test.cpp b/source/source_hsolver/test/diago_david_test.cpp index a7189cda8d..6b1a5ab41a 100644 --- 
a/source/source_hsolver/test/diago_david_test.cpp +++ b/source/source_hsolver/test/diago_david_test.cpp @@ -39,6 +39,7 @@ void lapackEigen(int& npw, std::vector>& hm, double* e, boo clock_t start,end; start = clock(); char tmp_c1 = 'V', tmp_c2 = 'U'; + auto tmp = hm; LapackConnector::heev(LapackConnector::ColMajor, tmp_c1, tmp_c2, npw, tmp.data(), npw, e); end = clock(); if (outtime) { std::cout<<"Lapack Run time: "<<(double)(end - start) / CLOCKS_PER_SEC<<" S"< *hmatrix, std::complex *smatrix, double *e, int &nFull) @@ -194,7 +194,7 @@ void lapack_diago(std::complex *hmatrix, std::complex *smatrix, b[i] = smatrix[i]; } - LapackConnector::hegv(LapackConnector::ColMajor, itype, jobz, uplo, nFull, a.data(), nFull, b.data(), nFull, e, ev.data()); + LapackConnector::hegv(LapackConnector::ColMajor, itype, jobz, uplo, nFull, a.data(), nFull, b.data(), nFull, e); } } // namespace LCAO_DIAGO_TEST diff --git a/source/source_md/test/CMakeLists.txt b/source/source_md/test/CMakeLists.txt index c91c449f79..0a4c0de37a 100644 --- a/source/source_md/test/CMakeLists.txt +++ b/source/source_md/test/CMakeLists.txt @@ -22,11 +22,6 @@ list(APPEND depend_files ../../source_base/matrix3.cpp ../../source_base/matrix.cpp ../../source_base/timer.cpp - ../../source_base/module_external/blas_connector_base.cpp - ../../source_base/module_external/blas_connector_l1.cpp - ../../source_base/module_external/blas_connector_l2.cpp - ../../source_base/module_external/blas_connector_l3.cpp - ../../source_base/module_external/lapack_connector.cpp ../../source_base/memory.cpp ../../source_base/global_variable.cpp ../../source_base/global_function.cpp diff --git a/source/source_pw/module_pwdft/test/CMakeLists.txt b/source/source_pw/module_pwdft/test/CMakeLists.txt index 6477b3c480..6c8d794ed1 100644 --- a/source/source_pw/module_pwdft/test/CMakeLists.txt +++ b/source/source_pw/module_pwdft/test/CMakeLists.txt @@ -15,11 +15,6 @@ AddTest( ../../../source_base/global_file.cpp ../../../source_base/memory.cpp 
../../../source_base/timer.cpp - ../../../source_base/module_external/blas_connector_base.cpp - ../../../source_base/module_external/blas_connector_l1.cpp - ../../../source_base/module_external/blas_connector_l2.cpp - ../../../source_base/module_external/blas_connector_l3.cpp - ../../../source_base/module_external/lapack_connector.cpp ../../../source_base/parallel_global.cpp ../../../source_base/parallel_comm.cpp ../../../source_base/parallel_common.cpp diff --git a/source/source_relax/test/CMakeLists.txt b/source/source_relax/test/CMakeLists.txt index 5cab0093f6..a8b7ceb4ee 100644 --- a/source/source_relax/test/CMakeLists.txt +++ b/source/source_relax/test/CMakeLists.txt @@ -18,11 +18,6 @@ AddTest( ../../source_base/matrix3.cpp ../../source_base/intarray.cpp ../../source_base/tool_title.cpp ../../source_base/global_function.cpp ../../source_base/complexmatrix.cpp ../../source_base/matrix.cpp ../../source_base/complexarray.cpp ../../source_base/tool_quit.cpp ../../source_base/realarray.cpp - ../../source_base/module_external/blas_connector_base.cpp - ../../source_base/module_external/blas_connector_l1.cpp - ../../source_base/module_external/blas_connector_l2.cpp - ../../source_base/module_external/blas_connector_l3.cpp - ../../source_base/module_external/lapack_connector.cpp ../../source_cell/update_cell.cpp ../../source_cell/print_cell.cpp ../../source_cell/bcast_cell.cpp ../../source_io/output.cpp LIBS parameter ${math_libs} )