Actual source code: cupmcontext.hpp
1: #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
2: #define PETSCDEVICECONTEXTCUPM_HPP
4: #include <petsc/private/deviceimpl.h>
5: #include <petsc/private/cupminterface.hpp>
7: #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
8: #error PetscDeviceContext backends for CUDA and HIP requires C++11
9: #endif
11: namespace Petsc {
13: // Forward declare
14: template <CUPMDeviceKind T> class CUPMContext;
16: template <CUPMDeviceKind T>
17: class CUPMContext : CUPMInterface<T>
18: {
19: public:
20: PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T)
22: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
23: // header, but since we are using the power of templates it must be declared part of
24: // this class to have easy access the same typedefs. Technically one can make a
25: // templated struct outside the class but it's more code for the same result.
26: struct PetscDeviceContext_IMPLS
27: {
28: cupmStream_t stream;
29: cupmEvent_t event;
30: cupmBlasHandle_t blas;
31: cupmSolverHandle_t solver;
32: };
34: private:
35: static cupmBlasHandle_t _blashandle;
36: static cupmSolverHandle_t _solverhandle;
38: PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept
39: {
43: cupmInterface_t::DestroyHandle(_blashandle);
44: return(0);
45: }
47: PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept
48: {
52: cupmInterface_t::DestroyHandle(_solverhandle);
53: return(0);
54: }
56: PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept
57: {
58: PetscErrorCode ierr;
61: if (!_blashandle) {
62: cupmInterface_t::InitializeHandle(_blashandle);
63: PetscRegisterFinalize(__finalizeBLASHandle);
64: }
65: if (!_solverhandle) {
66: cupmInterface_t::InitializeHandle(_solverhandle);
67: PetscRegisterFinalize(__finalizeSOLVERHandle);
68: }
69: cupmInterface_t::SetHandleStream(_blashandle,dci->stream);
70: cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);
71: dci->blas = _blashandle;
72: dci->solver = _solverhandle;
73: return(0);
74: }
76: public:
77: const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize};
79: // default constructor
80: constexpr CUPMContext() noexcept = default;
82: // All of these functions MUST be static in order to be callable from C, otherwise they
83: // get the implicit 'this' pointer tacked on
84: PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
85: PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
86: PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
87: PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
88: PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
89: PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
90: };
// Casts the `data` member of a device-context object to the PetscDeviceContext_IMPLS of
// the enclosing CUPMContext<T>. Helper for the member function definitions in this
// header only; it is #undef'd again before the Petsc namespace is closed.
92: #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data)
94: template <CUPMDeviceKind T>
95: inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
96: {
97: PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
98: cupmError_t cerr;
99: PetscErrorCode ierr;
102: if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
103: if (dci->event) {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);}
104: PetscFree(dctx->data);
105: return(0);
106: }
108: template <CUPMDeviceKind T>
109: inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept
110: {
111: PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
114: if (dci->stream) {
115: cupmError_t cerr;
117: cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
118: dci->stream = nullptr;
119: }
120: // set these to null so they aren't usable until setup is called again
121: dci->blas = nullptr;
122: dci->solver = nullptr;
123: return(0);
124: }
126: template <CUPMDeviceKind T>
127: inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
128: {
129: PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
130: PetscErrorCode ierr;
131: cupmError_t cerr;
134: if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
135: switch (dctx->streamType) {
136: case PETSC_STREAM_GLOBAL_BLOCKING:
137: // don't create a stream for global blocking
138: dci->stream = nullptr;
139: break;
140: case PETSC_STREAM_DEFAULT_BLOCKING:
141: cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
142: break;
143: case PETSC_STREAM_GLOBAL_NONBLOCKING:
144: cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
145: break;
146: default:
147: SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType);
148: break;
149: }
150: if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
151: __setupHandles(dci);
152: return(0);
153: }
155: template <CUPMDeviceKind T>
156: inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
157: {
158: cupmError_t cerr;
161: cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream);
162: if (cerr == cupmSuccess)
163: *idle = PETSC_TRUE;
164: else if (cerr == cupmErrorNotReady) {
165: *idle = PETSC_FALSE;
166: } else {
167: // somethings gone wrong
168: CHKERRCUPM(cerr);
169: }
170: return(0);
171: }
173: template <CUPMDeviceKind T>
174: inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
175: {
176: PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa);
177: PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb);
178: cupmError_t cerr;
181: cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
182: cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
183: return(0);
184: }
186: template <CUPMDeviceKind T>
187: inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
188: {
189: PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
190: cupmError_t cerr;
193: // in case anything was queued on the event
194: cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
195: cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
196: return(0);
197: }
199: // Out-of-class definitions of the static handle members declared in CUPMContext.
// Both start out as nullptr, which is the state __setupHandles() tests for to decide
// whether a handle still has to be created.
200: template <CUPMDeviceKind T>
201: typename CUPMContext<T>::cupmBlasHandle_t CUPMContext<T>::_blashandle = nullptr;
203: template <CUPMDeviceKind T>
204: typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr;
206: // Short aliases for the two instantiations of CUPMContext, one per CUPMDeviceKind
// enumerator named below (CUDA and HIP).
207: using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>;
208: using CUPMContextHip = CUPMContext<CUPMDeviceKind::HIP>;
210: // make sure these don't leak out
211: #undef CHKERRCUPM
212: #undef IMPLS_RCAST_
214: } // namespace Petsc
216: // shorthand for what is an EXTREMELY long name; impls_ is a CUPMDeviceKind enumerator,
// e.g. PetscDeviceContext_(CUDA) expands to
// Petsc::CUPMContext<Petsc::CUPMDeviceKind::CUDA>::PetscDeviceContext_IMPLS
217: #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS
219: // shorthand for casting dctx->data to the appropriate object to access the handles;
// impls_ is the CUPMDeviceKind enumerator (as for PetscDeviceContext_()) and obj_ is the
// object whose `data` member is reinterpreted as a pointer to that implementation struct
220: #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data)
222: #endif /* PETSCDEVICECONTEXTCUPM_HPP */