Commit 6bf6338

Use __pyx_capi__ for CUDA driver function pointers (#1450) (#1466)
Replace the _CUDA_DRIVER_API_V1 capsule with direct extraction of function pointers from cuda.bindings.cydriver.__pyx_capi__ at module import time. This simplifies the architecture by eliminating the custom capsule struct and its associated loading machinery (load_driver_api, ensure_driver_loaded, cuGetProcAddress resolution). The driver function pointers are now populated directly from Cython's built-in cross-module API mechanism.

Closes #1450
1 parent 0f2949a commit 6bf6338

4 files changed

Lines changed: 203 additions & 367 deletions

cuda_core/cuda/core/_cpp/DESIGN.md

Lines changed: 29 additions & 21 deletions
@@ -117,7 +117,8 @@ link against this code directly—they access it through a capsule mechanism

 ## Capsule Architecture

-The implementation uses **two separate capsule mechanisms** for different purposes:
+The implementation uses a capsule mechanism for cross-module C++ function sharing,
+and Cython's `__pyx_capi__` for CUDA driver function resolution:

 ### Capsule 1: C++ API Table (`_CXX_API`)

@@ -160,38 +161,45 @@ cdef inline StreamHandle create_stream_handle(...) except * nogil:
 Importing modules are expected to call `_init_handles_table()` prior to calling
 any wrapper functions.

-### Capsule 2: CUDA Driver API (`_CUDA_DRIVER_API_V1`)
+### CUDA Driver Function Pointers via `__pyx_capi__`

 **Problem**: cuda.core cannot directly call CUDA driver functions because:

 1. We don't want to link against `libcuda.so` at build time.
 2. The driver symbols must be resolved dynamically through cuda-bindings.

-**Solution**: `_resource_handles.pyx` creates a capsule containing CUDA driver
-function pointers obtained from cuda-bindings:
+**Solution**: The C++ code declares extern function pointer variables:

 ```cpp
-struct CudaDriverApiV1 {
-    uint32_t abi_version;
-    uint32_t struct_size;
-
-    uintptr_t cuDevicePrimaryCtxRetain;
-    uintptr_t cuDevicePrimaryCtxRelease;
-    uintptr_t cuStreamCreateWithPriority;
-    uintptr_t cuStreamDestroy;
-    // ... etc
-};
+// resource_handles.hpp
+extern decltype(&cuStreamCreateWithPriority) p_cuStreamCreateWithPriority;
+extern decltype(&cuMemPoolCreate) p_cuMemPoolCreate;
+// ... etc
 ```

-The C++ code retrieves this capsule once (via `load_driver_api()`) and caches the
-function pointers for subsequent use.
+At module import time, `_resource_handles.pyx` populates these pointers by
+extracting them from `cuda.bindings.cydriver.__pyx_capi__`:
+
+```cython
+import cuda.bindings.cydriver as cydriver
+
+cdef void* _get_driver_fn(str name):
+    capsule = cydriver.__pyx_capi__[name]
+    return PyCapsule_GetPointer(capsule, PyCapsule_GetName(capsule))
+
+p_cuStreamCreateWithPriority = _get_driver_fn("cuStreamCreateWithPriority")
+```

-### Why Two Capsules?
+The `__pyx_capi__` dictionary contains PyCapsules that Cython automatically
+generates for each `cdef` function declared in a `.pxd` file. Each capsule's
+name is the function's C signature; we query it with `PyCapsule_GetName()`
+rather than hardcoding signatures.

-| Capsule | Direction | Purpose |
-|---------|-----------|---------|
-| `_CXX_API` | C++ → Cython | Share handle functions across modules |
-| `_CUDA_DRIVER_API_V1` | Cython → C++ | Provide resolved driver symbols |
+This approach:
+- Avoids linking against `libcuda.so` at build time
+- Works on CPU-only machines (capsule extraction succeeds; actual driver calls
+  will return errors like `CUDA_ERROR_NO_DEVICE`)
+- Requires no custom capsule infrastructure—uses Cython's built-in mechanism

 ## Key Implementation Details
