Actual source code: cupmcontext.hpp
1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
2: #define PETSCDEVICECONTEXTCUPM_HPP
4: #include <petsc/private/deviceimpl.h>
5: #include <petsc/private/cupmblasinterface.hpp>
7: #include <array>
9: namespace Petsc
10: {
12: namespace Device
13: {
15: namespace CUPM
16: {
18: namespace Impl
19: {
21: namespace detail
22: {
// Empty tag type used to select an overload by handle type (tag dispatch). The wrapped
// handle type is exposed as the nested alias `type`.
template <typename T>
struct HandleTag
{
  using type = T;
};
27: } // namespace detail
29: // Forward declare
30: template <DeviceType T> class PETSC_VISIBILITY_INTERNAL DeviceContext;
32: template <DeviceType T>
33: class DeviceContext : Impl::BlasInterface<T>
34: {
35: public:
36: PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t,T);
38: private:
39: template <typename H> using HandleTag = typename detail::HandleTag<H>;
40: using stream_tag = HandleTag<cupmStream_t>;
41: using blas_tag = HandleTag<cupmBlasHandle_t>;
42: using solver_tag = HandleTag<cupmSolverHandle_t>;
44: public:
45: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46: // header, but since we are using the power of templates it must be declared part of
47: // this class to have easy access the same typedefs. Technically one can make a
48: // templated struct outside the class but it's more code for the same result.
49: struct PetscDeviceContext_IMPLS
50: {
51: cupmStream_t stream;
52: cupmEvent_t event;
53: cupmEvent_t begin; // timer-only
54: cupmEvent_t end; // timer-only
55: #if PetscDefined(USE_DEBUG)
56: PetscBool timerInUse;
57: #endif
58: cupmBlasHandle_t blas;
59: cupmSolverHandle_t solver;
61: PETSC_NODISCARD auto get(stream_tag) const -> decltype(this->stream) { return this->stream; }
62: PETSC_NODISCARD auto get(blas_tag) const -> decltype(this->blas) { return this->blas; }
63: PETSC_NODISCARD auto get(solver_tag) const -> decltype(this->solver) { return this->solver; }
64: };
66: private:
67: static bool initialized_;
68: static std::array<cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> blashandles_;
69: static std::array<cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> solverhandles_;
71: PETSC_CXX_COMPAT_DECL(constexpr PetscDeviceContext_IMPLS* impls_cast_(PetscDeviceContext ptr))
72: {
73: return static_cast<PetscDeviceContext_IMPLS*>(ptr->data);
74: }
76: PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(cupmBlasHandle_t &handle))
77: {
78: if (handle) return 0;
79: for (auto i = 0; i < 3; ++i) {
80: auto cberr = cupmBlasCreate(&handle);
81: if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
82: if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) cberr;
83: if (i != 2) {
84: PetscSleep(3);
85: continue;
86: }
88: }
89: return 0;
90: }
92: PETSC_CXX_COMPAT_DECL(PetscErrorCode set_handle_stream_(cupmBlasHandle_t &handle, cupmStream_t &stream))
93: {
94: cupmStream_t cupmStream;
96: cupmBlasGetStream(handle,&cupmStream);
97: if (cupmStream != stream) cupmBlasSetStream(handle,stream);
98: return 0;
99: }
101: PETSC_CXX_COMPAT_DECL(PetscErrorCode finalize_())
102: {
103: for (auto&& handle : blashandles_) {
104: if (handle) {
105: cupmBlasDestroy(handle);
106: handle = nullptr;
107: }
108: }
109: for (auto&& handle : solverhandles_) {
110: if (handle) {
111: cupmBlasInterface_t::DestroyHandle(handle);
112: handle = nullptr;
113: }
114: }
115: initialized_ = false;
116: return 0;
117: }
119: PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_(PetscInt id, PetscDeviceContext_IMPLS *dci))
120: {
121: PetscDeviceCheckDeviceCount_Internal(id);
122: if (!initialized_) {
123: initialized_ = true;
124: PetscRegisterFinalize(finalize_);
125: }
126: // use the blashandle as a canary
127: if (!blashandles_[id]) {
128: initialize_handle_(blashandles_[id]);
129: cupmBlasInterface_t::InitializeHandle(solverhandles_[id]);
130: }
131: set_handle_stream_(blashandles_[id],dci->stream);
132: cupmBlasInterface_t::SetHandleStream(solverhandles_[id],dci->stream);
133: dci->blas = blashandles_[id];
134: dci->solver = solverhandles_[id];
135: return 0;
136: }
138: public:
139: const struct _DeviceContextOps ops = {
140: destroy,
141: changeStreamType,
142: setUp,
143: query,
144: waitForContext,
145: synchronize,
146: getHandle<blas_tag>,
147: getHandle<solver_tag>,
148: getHandle<stream_tag>,
149: beginTimer,
150: endTimer,
151: };
153: // All of these functions MUST be static in order to be callable from C, otherwise they
154: // get the implicit 'this' pointer tacked on
155: PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
156: PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType));
157: PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
158: PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext,PetscBool*));
159: PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext));
160: PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
161: template <typename Handle_t>
162: PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext,void*));
163: PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
164: PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext,PetscLogDouble*));
165: };
167: template <DeviceType T>
168: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
169: {
170: auto dci = impls_cast_(dctx);
172: if (dci->stream) cupmStreamDestroy(dci->stream);
173: if (dci->event) cupmEventDestroy(dci->event);
174: if (dci->begin) cupmEventDestroy(dci->begin);
175: if (dci->end) cupmEventDestroy(dci->end);
176: PetscFree(dctx->data);
177: return 0;
178: }
180: template <DeviceType T>
181: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
182: {
183: auto dci = impls_cast_(dctx);
185: if (dci->stream) {
186: cupmStreamDestroy(dci->stream);
187: dci->stream = nullptr;
188: }
189: // set these to null so they aren't usable until setup is called again
190: dci->blas = nullptr;
191: dci->solver = nullptr;
192: return 0;
193: }
195: template <DeviceType T>
196: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
197: {
198: auto dci = impls_cast_(dctx);
200: if (dci->stream) {
201: cupmStreamDestroy(dci->stream);
202: dci->stream = nullptr;
203: }
204: switch (dctx->streamType) {
205: case PETSC_STREAM_GLOBAL_BLOCKING:
206: // don't create a stream for global blocking
207: break;
208: case PETSC_STREAM_DEFAULT_BLOCKING:
209: cupmStreamCreate(&dci->stream);
210: break;
211: case PETSC_STREAM_GLOBAL_NONBLOCKING:
212: cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);
213: break;
214: default:
215: SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %s",PetscStreamTypes[util::integral_value(dctx->streamType)]);
216: break;
217: }
218: if (!dci->event) cupmEventCreate(&dci->event);
219: #if PetscDefined(USE_DEBUG)
220: dci->timerInUse = PETSC_FALSE;
221: #endif
222: initialize_(dctx->device->deviceId,dci);
223: return 0;
224: }
226: template <DeviceType T>
227: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
228: {
229: cupmError_t cerr;
231: cerr = cupmStreamQuery(impls_cast_(dctx)->stream);
232: if (cerr == cupmSuccess) *idle = PETSC_TRUE;
233: else {
234: // somethings gone wrong
235: if (PetscUnlikely(cerr != cupmErrorNotReady)) cerr;
236: *idle = PETSC_FALSE;
237: }
238: return 0;
239: }
241: template <DeviceType T>
242: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
243: {
244: auto dcib = impls_cast_(dctxb);
246: cupmEventRecord(dcib->event,dcib->stream);
247: cupmStreamWaitEvent(impls_cast_(dctxa)->stream,dcib->event,0);
248: return 0;
249: }
251: template <DeviceType T>
252: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
253: {
254: auto dci = impls_cast_(dctx);
256: // in case anything was queued on the event
257: cupmStreamWaitEvent(dci->stream,dci->event,0);
258: cupmStreamSynchronize(dci->stream);
259: return 0;
260: }
262: template <DeviceType T>
263: template <typename handle_t>
264: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
265: {
266: *static_cast<typename handle_t::type*>(handle) = impls_cast_(dctx)->get(handle_t());
267: return 0;
268: }
270: template <DeviceType T>
271: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
272: {
273: auto dci = impls_cast_(dctx);
275: #if PetscDefined(USE_DEBUG)
277: dci->timerInUse = PETSC_TRUE;
278: #endif
279: if (!dci->begin) {
280: cupmEventCreate(&dci->begin);
281: cupmEventCreate(&dci->end);
282: }
283: cupmEventRecord(dci->begin,dci->stream);
284: return 0;
285: }
287: template <DeviceType T>
288: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
289: {
290: float gtime;
291: auto dci = impls_cast_(dctx);
293: #if PetscDefined(USE_DEBUG)
295: dci->timerInUse = PETSC_FALSE;
296: #endif
297: cupmEventRecord(dci->end,dci->stream);
298: cupmEventSynchronize(dci->end);
299: cupmEventElapsedTime(>ime,dci->begin,dci->end);
300: *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
301: return 0;
302: }
304: // initialize the static member variables
305: template <DeviceType T> bool DeviceContext<T>::initialized_ = false;
307: template <DeviceType T>
308: std::array<typename DeviceContext<T>::cupmBlasHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
310: template <DeviceType T>
311: std::array<typename DeviceContext<T>::cupmSolverHandle_t,PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
313: } // namespace Impl
315: // shorten this one up a bit (and instantiate the templates)
316: using CUPMContextCuda = Impl::DeviceContext<DeviceType::CUDA>;
317: using CUPMContextHip = Impl::DeviceContext<DeviceType::HIP>;
319: // shorthand for what is an EXTREMELY long name
320: #define PetscDeviceContext_(IMPLS) Petsc::Device::CUPM::Impl::DeviceContext<Petsc::Device::CUPM::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
322: } // namespace CUPM
324: } // namespace Device
326: } // namespace Petsc
#endif // PETSCDEVICECONTEXTCUPM_HPP