@@ -4,8 +4,8 @@ use std::os::raw::c_char;
44use  std:: ptr; 
55
66use  cust:: stream:: Stream ; 
7- use  cust_raw:: cublas_sys ; 
8- use  cust_raw:: driver_sys ; 
7+ use  cust_raw:: cublas ; 
8+ use  cust_raw:: driver ; 
99
1010use  super :: error:: DropResult ; 
1111use  super :: error:: ToResult  as  _; 
@@ -73,7 +73,7 @@ bitflags::bitflags! {
7373/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm) 
7474#[ derive( Debug ) ]  
7575pub  struct  CublasContext  { 
76-     pub ( crate )  raw :  cublas_sys :: cublasHandle_t , 
76+     pub ( crate )  raw :  cublas :: cublasHandle_t , 
7777} 
7878
7979impl  CublasContext  { 
@@ -92,10 +92,10 @@ impl CublasContext {
9292pub  fn  new ( )  -> Result < Self >  { 
9393        let  mut  raw = MaybeUninit :: uninit ( ) ; 
9494        unsafe  { 
95-             cublas_sys :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?; 
96-             cublas_sys :: cublasSetPointerMode ( 
95+             cublas :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?; 
96+             cublas :: cublasSetPointerMode ( 
9797                raw. assume_init ( ) , 
98-                 cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE , 
98+                 cublas :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE , 
9999            ) 
100100            . to_result ( ) ?; 
101101            Ok ( Self  { 
@@ -112,7 +112,7 @@ impl CublasContext {
112112
113113        unsafe  { 
114114            let  inner = mem:: replace ( & mut  ctx. raw ,  ptr:: null_mut ( ) ) ; 
115-             match  cublas_sys :: cublasDestroy ( inner) . to_result ( )  { 
115+             match  cublas :: cublasDestroy ( inner) . to_result ( )  { 
116116                Ok ( ( ) )  => { 
117117                    mem:: forget ( ctx) ; 
118118                    Ok ( ( ) ) 
@@ -127,7 +127,7 @@ impl CublasContext {
127127        let  mut  raw = MaybeUninit :: < u32 > :: uninit ( ) ; 
128128        unsafe  { 
129129            // getVersion can't fail 
130-             cublas_sys :: cublasGetVersion ( self . raw ,  raw. as_mut_ptr ( ) . cast ( ) ) 
130+             cublas :: cublasGetVersion ( self . raw ,  raw. as_mut_ptr ( ) . cast ( ) ) 
131131                . to_result ( ) 
132132                . unwrap ( ) ; 
133133
@@ -145,17 +145,15 @@ impl CublasContext {
145145    )  -> Result < T >  { 
146146        unsafe  { 
147147            // cudaStream_t is the same as CUstream 
148-             cublas_sys :: cublasSetStream ( 
148+             cublas :: cublasSetStream ( 
149149                self . raw , 
150-                 mem:: transmute :: < * mut  driver_sys:: CUstream_st ,  * mut  cublas_sys:: CUstream_st > ( 
151-                     stream. as_inner ( ) , 
152-                 ) , 
150+                 mem:: transmute :: < driver:: CUstream ,  cublas:: cudaStream_t > ( stream. as_inner ( ) ) , 
153151            ) 
154152            . to_result ( ) ?; 
155153            let  res = func ( self ) ?; 
156154            // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to 
157155            // execute a raw sys function with the context's handle. 
158-             cublas_sys :: cublasSetStream ( self . raw ,  ptr:: null_mut ( ) ) . to_result ( ) ?; 
156+             cublas :: cublasSetStream ( self . raw ,  ptr:: null_mut ( ) ) . to_result ( ) ?; 
159157            Ok ( res) 
160158        } 
161159    } 
@@ -185,12 +183,12 @@ impl CublasContext {
185183/// ``` 
186184pub  fn  set_atomics_mode ( & self ,  allowed :  bool )  -> Result < ( ) >  { 
187185        unsafe  { 
188-             Ok ( cublas_sys :: cublasSetAtomicsMode ( 
186+             Ok ( cublas :: cublasSetAtomicsMode ( 
189187                self . raw , 
190188                if  allowed { 
191-                     cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED 
189+                     cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED 
192190                }  else  { 
193-                     cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED 
191+                     cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED 
194192                } , 
195193            ) 
196194            . to_result ( ) ?) 
@@ -215,10 +213,10 @@ impl CublasContext {
215213pub  fn  get_atomics_mode ( & self )  -> Result < bool >  { 
216214        let  mut  mode = MaybeUninit :: uninit ( ) ; 
217215        unsafe  { 
218-             cublas_sys :: cublasGetAtomicsMode ( self . raw ,  mode. as_mut_ptr ( ) ) . to_result ( ) ?; 
216+             cublas :: cublasGetAtomicsMode ( self . raw ,  mode. as_mut_ptr ( ) ) . to_result ( ) ?; 
219217            Ok ( match  mode. assume_init ( )  { 
220-                 cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED  => true , 
221-                 cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED  => false , 
218+                 cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED  => true , 
219+                 cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED  => false , 
222220            } ) 
223221        } 
224222    } 
@@ -238,9 +236,9 @@ impl CublasContext {
238236/// ``` 
239237pub  fn  set_math_mode ( & self ,  math_mode :  MathMode )  -> Result < ( ) >  { 
240238        unsafe  { 
241-             Ok ( cublas_sys :: cublasSetMathMode ( 
239+             Ok ( cublas :: cublasSetMathMode ( 
242240                self . raw , 
243-                 mem:: transmute :: < u32 ,  cublas_sys :: cublasMath_t > ( math_mode. bits ( ) ) , 
241+                 mem:: transmute :: < u32 ,  cublas :: cublasMath_t > ( math_mode. bits ( ) ) , 
244242            ) 
245243            . to_result ( ) ?) 
246244        } 
@@ -263,7 +261,7 @@ impl CublasContext {
263261pub  fn  get_math_mode ( & self )  -> Result < MathMode >  { 
264262        let  mut  mode = MaybeUninit :: uninit ( ) ; 
265263        unsafe  { 
266-             cublas_sys :: cublasGetMathMode ( self . raw ,  mode. as_mut_ptr ( ) ) . to_result ( ) ?; 
264+             cublas :: cublasGetMathMode ( self . raw ,  mode. as_mut_ptr ( ) ) . to_result ( ) ?; 
267265            Ok ( MathMode :: from_bits ( mode. assume_init ( )  as  u32 ) 
268266                . expect ( "Invalid MathMode from cuBLAS" ) ) 
269267        } 
@@ -303,7 +301,7 @@ impl CublasContext {
303301            let  path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ; 
304302            let  path_ptr = path. map_or ( ptr:: null ( ) ,  |s| s. as_ptr ( ) ) ; 
305303
306-             cublas_sys :: cublasLoggerConfigure ( 
304+             cublas :: cublasLoggerConfigure ( 
307305                enable as  i32 , 
308306                log_to_stdout as  i32 , 
309307                log_to_stderr as  i32 , 
@@ -320,7 +318,7 @@ impl CublasContext {
320318/// 
321319/// The callback must not panic and unwind. 
322320pub  unsafe  fn  set_logger_callback ( callback :  Option < unsafe  extern  "C"  fn ( * const  c_char ) > )  { 
323-         cublas_sys :: cublasSetLoggerCallback ( callback) 
321+         cublas :: cublasSetLoggerCallback ( callback) 
324322            . to_result ( ) 
325323            . unwrap ( ) ; 
326324    } 
@@ -329,7 +327,7 @@ impl CublasContext {
329327pub  fn  get_logger_callback ( )  -> Option < unsafe  extern  "C"  fn ( * const  c_char ) >  { 
330328        let  mut  cb = MaybeUninit :: uninit ( ) ; 
331329        unsafe  { 
332-             cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) ) 
330+             cublas :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) ) 
333331                . to_result ( ) 
334332                . unwrap ( ) ; 
335333            cb. assume_init ( ) 
@@ -340,7 +338,7 @@ impl CublasContext {
340338impl  Drop  for  CublasContext  { 
341339    fn  drop ( & mut  self )  { 
342340        unsafe  { 
343-             let  _ = cublas_sys :: cublasDestroy ( self . raw ) ; 
341+             let  _ = cublas :: cublasDestroy ( self . raw ) ; 
344342        } 
345343    } 
346344} 
0 commit comments