@@ -2,12 +2,13 @@ using oneAPI_Support_Headers_jll
22
33include (" generate_helpers.jl" )
44
5- blas = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " blas" , " buffer_decls.hpp" )]
6- lapack = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " lapack.hpp" ),
7- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " scratchpad.hpp" )]
8- sparse = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_structures.hpp" ),
9- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_auxiliary.hpp" ),
10- joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_operations.hpp" )]
5+ include_dir = joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" )
6+ blas = [joinpath (include_dir, " oneapi" , " mkl" , " blas" , " buffer_decls.hpp" )]
7+ lapack = [joinpath (include_dir, " oneapi" , " mkl" , " lapack" , " lapack.hpp" ),
8+ joinpath (include_dir, " oneapi" , " mkl" , " lapack" , " scratchpad.hpp" )]
9+ sparse = [joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_structures.hpp" ),
10+ joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_auxiliary.hpp" ),
11+ joinpath (include_dir, " oneapi" , " mkl" , " spblas" , " sparse_operations.hpp" )]
1112
1213dict_version = Dict {Int, Char} (1 => ' S' , 2 => ' D' , 3 => ' C' , 4 => ' Z' )
1314
@@ -23,7 +24,8 @@ version_types_header = Dict{Char, String}('S' => "float",
2324
2425comments = [" namespace" , " #" , " }" , " /*" , " *" , " //" , " [[" , " ONEMKL_DECLARE_" , " ONEMKL_INLINE_DECLARE" ]
2526
26- void_output = [" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" , " get_matmat_data" ]
27+ void_output = [" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" ,
28+ " get_matmat_data" , " init_omatadd_descr" , " init_omatconvert_desc" ]
2729
2830function generate_headers (library:: String , filename:: Vector{String} , output:: String ; pattern:: String = " " )
2931 routines = Dict {String,Int} ()
@@ -189,6 +191,8 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
189191 header = replace (header, " ,)" => " )" )
190192 header = replace (header, " void" => " void" )
191193 header = replace (header, " sycl::event" => " sycl::event" )
194+ header = replace (header, " * const* " => " **" )
195+ header = replace (header, " int64_t**" => " int64_t **" )
192196
193197 ind1 = findfirst (' ' , header)
194198 ind2 = findfirst (' (' , header)
@@ -245,6 +249,7 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
245249 (version == ' C' ) && (header = replace (header, " std::complex " => " float _Complex " ))
246250 (version == ' Z' ) && (header = replace (header, " std::complex " => " double _Complex " ))
247251 end
252+ header = replace (header, " omatconvert (" => " omatconvert(" )
248253 header = replace (header, " transpose " => " onemklTranspose " )
249254 header = replace (header, " uplo " => " onemklUplo " )
250255 header = replace (header, " diag " => " onemklDiag " )
@@ -255,6 +260,8 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
255260 header = replace (header, " sparse::matrix_view_descr " => " onemklMatrixView " )
256261 header = replace (header, " matrix_view_descr " => " onemklMatrixView " )
257262 header = replace (header, " sparse::matmat_request " => " onemklMatmatRequest " )
263+ header = replace (header, " omatconvert_alg " => " onemklOmatconvertAlg " )
264+ header = replace (header, " omatadd_alg " => " onemklOmataddAlg " )
258265 header = replace (header, name_routine => " sparse_" * name_routine)
259266 end
260267 push! (signatures, (header, name_routine, version, type_routine, template))
@@ -381,6 +388,10 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
381388 parameters = replace (parameters, " matrix_handle_t " => " (oneapi::mkl::sparse::matrix_handle_t) " )
382389 parameters = replace (parameters, " matmat_descr_t *" => " (oneapi::mkl::sparse::matmat_descr_t*) " )
383390 parameters = replace (parameters, " matmat_descr_t " => " (oneapi::mkl::sparse::matmat_descr_t) " )
391+ parameters = replace (parameters, " omatadd_descr_t *" => " (oneapi::mkl::sparse::omatadd_descr_t*) " )
392+ parameters = replace (parameters, " omatadd_descr_t " => " (oneapi::mkl::sparse::omatadd_descr_t) " )
393+ parameters = replace (parameters, " omatconvert_descr_t *" => " (oneapi::mkl::sparse::omatconvert_descr_t*) " )
394+ parameters = replace (parameters, " omatconvert_descr_t " => " (oneapi::mkl::sparse::omatconvert_descr_t) " )
384395 parameters = replace (parameters, " short **" => " reinterpret_cast<sycl::half **>" )
385396 parameters = replace (parameters, " float _Complex **" => " reinterpret_cast<std::complex<float> **>" )
386397 parameters = replace (parameters, " double _Complex **" => " reinterpret_cast<std::complex<double> **>" )
@@ -407,7 +418,8 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
407418
408419 for type in (" onemklTranspose" , " onemklSide" , " onemklUplo" , " onemklDiag" , " onemklGenerate" ,
409420 " onemklLayout" , " onemklJob" , " onemklJobsvd" , " onemklCompz" , " onemklRangev" ,
410- " onemklIndex" , " onemklProperty" , " onemklMatrixView" , " onemklMatmatRequest" )
421+ " onemklIndex" , " onemklProperty" , " onemklMatrixView" , " onemklMatmatRequest" ,
422+ " onemklOmatconvertAlg" , " onemklOmataddAlg" )
411423 parameters = replace (parameters, Regex (" $type ([A-Za-z0-9_]+)," ) => SubstitutionString (" convert(\\ 1)," ))
412424 parameters = replace (parameters, Regex (" , $type ([A-Za-z0-9_]+)" ) => SubstitutionString (" , convert(\\ 1)" ))
413425 end
0 commit comments