11#pragma once
22
3- #include " debug.hpp"
43#include < cuda_runtime.h>
4+ #include < nvfunctional>
55#include < version>
66#include < cstddef>
77#include < cstdio>
88#include < cstdlib>
99#include < cstdarg>
10+ #include < cstdarg>
1011#include < memory>
1112#include < new>
1213#include < string>
@@ -187,6 +188,8 @@ private:
187188
188189public:
// Constructors visible for CudaMemPool in this hunk.
// The nullptr overload is noexcept and has an empty body; the resulting
// state depends on members/bases that are not visible here
// (presumably "holds no pool" -- TODO confirm against the base class).
189190 CudaMemPool (std::nullptr_t ) noexcept {}
// Move construction and move assignment are explicitly defaulted
// (both lines are additions in this change).
191+ CudaMemPool (CudaMemPool &&) = default ;
192+ CudaMemPool &operator =(CudaMemPool &&) = default ;
190193
191194 struct Builder {
192195 private:
@@ -259,12 +262,17 @@ private:
259262
260263public:
// Constructors visible for CudaEvent in this hunk.
// The nullptr overload is noexcept and has an empty body; the resulting
// state depends on members/bases that are not visible here
// (presumably "holds no event" -- TODO confirm against the base class).
261264 CudaEvent (std::nullptr_t ) noexcept {}
// Move construction and move assignment are explicitly defaulted
// (both lines are additions in this change).
265+ CudaEvent (CudaEvent &&) = default ;
266+ CudaEvent &operator =(CudaEvent &&) = default ;
262267
263268 struct Builder {
264269 private:
265270 int flags = cudaEventDefault;
266271
267272 public:
273+ Builder () = default ;
274+ explicit Builder (int flags) noexcept : flags(flags) {}
275+
268276 Builder &withBlockingSync (bool blockingSync = true ) noexcept {
269277 if (blockingSync) {
270278 flags |= cudaEventBlockingSync;
@@ -303,24 +311,28 @@ public:
303311 CHECK_CUDA (cudaEventSynchronize (*this ));
304312 }
305313
306- bool joinReady () const {
// Queries this event's status with cudaEventQuery (renamed from joinReady
// in this change).
// Returns true when the query yields cudaSuccess and false when it yields
// cudaErrorNotReady. Any other status is passed to CHECK_CUDA.
314+ bool poll () const {
307315 cudaError_t res = cudaEventQuery (*this );
308316 if (res == cudaSuccess) {
309317 return true ;
310318 }
311319 if (res == cudaErrorNotReady) {
312320 return false ;
313321 }
// Only unexpected statuses reach this point; the inline comment names the
// API call that produced `res`.
314- CHECK_CUDA (res);
322+ CHECK_CUDA (res /* cudaEventQuery */ );
// Reached only if CHECK_CUDA returns normally.
315323 return false ;
316324 }
317325
// Returns the value written by cudaEventElapsedTime for the pair
// (event, *this). After this change `event` is passed in the start position
// and `*this` in the end position of the CUDA runtime signature
// (ms, start, end); the removed line passed them the other way round.
// NOTE(review): `result` is not initialized, so an indeterminate value is
// returned if CHECK_CUDA returns normally on failure -- confirm that
// CHECK_CUDA aborts or throws.
318326 float elapsedMillis (CudaEvent const &event) const {
319327 float result;
320- CHECK_CUDA (cudaEventElapsedTime (&result, * this , event ));
328+ CHECK_CUDA (cudaEventElapsedTime (&result, event, * this ));
321329 return result;
322330 }
323331
// Subtraction shorthand: `a - b` forwards to a.elapsedMillis(b) and returns
// its float result unchanged.
332+ float operator -(CudaEvent const &event) const {
333+ return elapsedMillis (event);
334+ }
335+
324336 ~CudaEvent () {
325337 if (*this ) {
326338 CHECK_CUDA (cudaEventDestroy (*this ));
@@ -335,12 +347,17 @@ private:
335347
336348public:
// Constructors visible for CudaStream in this hunk.
// The nullptr overload is noexcept and has an empty body; the resulting
// state depends on members/bases that are not visible here
// (presumably "holds the null stream handle" -- TODO confirm).
337349 CudaStream (std::nullptr_t ) noexcept {}
// Move construction and move assignment are explicitly defaulted
// (both lines are additions in this change).
350+ CudaStream (CudaStream &&) = default ;
351+ CudaStream &operator =(CudaStream &&) = default ;
338352
339353 struct Builder {
340354 private:
341355 int flags = cudaStreamDefault;
342356
343357 public:
358+ Builder () = default ;
359+ explicit Builder (int flags) noexcept : flags(flags) {}
360+
344361 Builder &withNonBlocking (bool nonBlocking = true ) noexcept {
345362 if (nonBlocking) {
346363 flags |= cudaStreamNonBlocking;
@@ -357,10 +374,14 @@ public:
357374 }
358375 };
359376
360- static CudaStream nullStream () noexcept {
// Factory that returns a CudaStream constructed from nullptr (renamed from
// nullStream in this change).
// NOTE(review): presumably this stands for CUDA's default stream -- the
// nullptr constructor body is empty, so confirm what handle value it leaves.
377+ static CudaStream defaultStream () noexcept {
361378 return CudaStream (nullptr );
362379 }
363380
// Factory that returns a CudaStream constructed from the
// cudaStreamPerThread handle.
// The constructor taking that handle is not visible in this view.
381+ static CudaStream perThreadStream () noexcept {
382+ return CudaStream (cudaStreamPerThread);
383+ }
384+
// Issues cudaMemcpyAsync(dst, src, size, kind) on this stream and passes the
// returned status to CHECK_CUDA. No synchronization is performed here.
//   dst  - destination pointer
//   src  - source pointer
//   size - byte count forwarded to cudaMemcpyAsync
//   kind - copy direction forwarded to cudaMemcpyAsync
// NOTE(review): `src` is declared `void *` rather than `const void *`, so
// callers holding a pointer-to-const cannot call this without a cast.
364385 void copy (void *dst, void *src, size_t size, cudaMemcpyKind kind) const {
365386 CHECK_CUDA (cudaMemcpyAsync (dst, src, size, kind, *this ));
366387 }
@@ -381,11 +402,17 @@ public:
381402 copy (dst, src, size, cudaMemcpyHostToHost);
382403 }
383404
384- void record (CudaEvent const &event) const {
// Records the caller-supplied `event` on this stream via cudaEventRecord and
// passes the returned status to CHECK_CUDA (renamed from record in this
// change).
405+ void recordEvent (CudaEvent const &event) const {
385406 CHECK_CUDA (cudaEventRecord (event, *this ));
386407 }
387408
388- void wait (CudaEvent const &event,
// Convenience overload: builds a new event from a default-constructed
// CudaEvent::Builder, records it on this stream through the one-argument
// overload, and returns it by value.
409+ CudaEvent recordEvent () const {
410+ CudaEvent event = CudaEvent::Builder ().build ();
411+ recordEvent (event);
412+ return event;
413+ }
414+
// Calls cudaStreamWaitEvent(*this, event, flags) and passes the returned
// status to CHECK_CUDA (renamed from wait in this change).
//   event - event handed to cudaStreamWaitEvent
//   flags - forwarded unchanged; defaults to cudaEventWaitDefault
415+ void waitEvent (CudaEvent const &event,
389416 unsigned int flags = cudaEventWaitDefault) const {
390417 CHECK_CUDA (cudaStreamWaitEvent (*this , event, flags));
391418 }
@@ -403,22 +430,23 @@ public:
403430 auto userData = std::make_unique<Func>();
404431 cudaStreamCallback_t callback = [](cudaStream_t stream,
405432 cudaError_t status, void *userData) {
433+ CHECK_CUDA (status /* joinAsync cudaStreamCallback */ );
406434 std::unique_ptr<Func> func (static_cast <Func *>(userData));
407- (*func)(stream, status );
435+ (*func)();
408436 };
409437 joinAsync (callback, userData.get ());
410438 userData.release ();
411439 }
412440
413- bool joinReady () const {
// Queries this stream's status with cudaStreamQuery (renamed from joinReady
// in this change).
// Returns true when the query yields cudaSuccess and false when it yields
// cudaErrorNotReady. Any other status is passed to CHECK_CUDA.
441+ bool poll () const {
414442 cudaError_t res = cudaStreamQuery (*this );
415443 if (res == cudaSuccess) {
416444 return true ;
417445 }
418446 if (res == cudaErrorNotReady) {
419447 return false ;
420448 }
// Only unexpected statuses reach this point; the inline comment names the
// API call that produced `res`.
421- CHECK_CUDA (res);
449+ CHECK_CUDA (res /* cudaStreamQuery */ );
// Reached only if CHECK_CUDA returns normally.
422450 return false ;
423451 }
424452
@@ -428,7 +456,7 @@ public:
428456 }
429457
// Destructor: calls cudaStreamDestroy on the wrapped handle unless *this
// converts to false or compares equal to cudaStreamPerThread.
// The second condition is added by this change; it keeps wrappers returned
// by perThreadStream() from passing cudaStreamPerThread to cudaStreamDestroy.
430458 ~CudaStream () {
431- if (*this ) {
459+ if (*this && * this != cudaStreamPerThread ) {
432460 CHECK_CUDA (cudaStreamDestroy (*this ));
433461 }
434462 }
@@ -522,8 +550,8 @@ struct CudaAllocator : private Arena {
522550 };
523551};
524552
525- template <class T >
526- using CudaVector = std::vector<T, CudaAllocator<T>>;
// Alias for std::vector<T> whose allocator is CudaAllocator<T, Arena>.
// This change adds the Arena template parameter, defaulting to
// CudaManagedArena, so CudaVector<T> remains a valid spelling.
553+ template <class T , class Arena = CudaManagedArena >
554+ using CudaVector = std::vector<T, CudaAllocator<T, Arena >>;
527555
528556#if defined(__clang__) && defined(__CUDACC__) && defined(__GLIBCXX__)
529557__host__ __device__ static void printf (const char *fmt, ...) {
0 commit comments