Skip to content

Switch to use CUDA driver APIs in Device constructor #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cuda_core/cuda/core/experimental/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,34 +957,42 @@ def __new__(cls, device_id=None):

# important: creating a Device instance does not initialize the GPU!
if device_id is None:
device_id = handle_return(runtime.cudaGetDevice())
assert_type(device_id, int)
err, dev = driver.cuCtxGetDevice()
if err == 0:
device_id = int(dev)
else:
ctx = handle_return(driver.cuCtxGetCurrent())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's going on here? Is there some requirement from cudart which requires CtxGetCurrent() to be called before the device can be queried?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be helpful to raise a more specific error here?

It might be helpful to add a comment right after the else:

            else:
                # Emulate cudart behavior
                err, ctx = driver.cuCtxGetCurrent()
                if err != 0:
                    raise <Informative Error, what we really want is the current device (not primarily current context)>

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handles the case that no context is set to current. The logic is as follows:

  • if a context is set to current, we can easily get the device ID associated with the current context
  • if no context is set to current (which can happen right after cuInit(0) and before anything else is called), we confirm it is the case by checking ctx pointer is zero (err will always succeed), and then pick device 0

assert int(ctx) == 0
device_id = 0 # cudart behavior
assert isinstance(device_id, int), f"{device_id=}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally not using the

             assert_type(device_id, int)

helper?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this PR predates the helper, I'll add it

else:
total = handle_return(runtime.cudaGetDeviceCount())
assert_type(device_id, int)
if not (0 <= device_id < total):
total = handle_return(driver.cuDeviceGetCount())
if not isinstance(device_id, int) or not (0 <= device_id < total):
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

# ensure Device is singleton
if not hasattr(_tls, "devices"):
total = handle_return(runtime.cudaGetDeviceCount())
total = handle_return(driver.cuDeviceGetCount())
_tls.devices = []
for dev_id in range(total):
dev = super().__new__(cls)

dev._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
if (
handle_return(
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
)
)
) == 1:
dev._mr = _DefaultAsyncMempool(dev_id)
else:
dev._mr = _SynchronousMemoryResource(dev_id)

dev._has_inited = False
dev._properties = None

_tls.devices.append(dev)

return _tls.devices[device_id]
Expand Down
Loading