Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions cuda_core/cuda/core/_program.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,31 @@ class ProgramOptions:
# Set arch to default if not provided
if self.arch is None:
self.arch = f"sm_{Device().arch}"
if self.extra_sources is not None:
if not is_sequence(self.extra_sources):
raise TypeError(
"extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
)
for i, module in enumerate(self.extra_sources):
if not isinstance(module, tuple) or len(module) != 2:
raise TypeError(
f"Each extra module must be a 2-tuple (name, source)"
f", got {type(module).__name__} at index {i}"
)

module_name, module_source = module

if not isinstance(module_name, str):
raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}")

if not isinstance(module_source, (str, bytes, bytearray)):
raise TypeError(
f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
f"or bytearray, got {type(module_source).__name__}"
)

if len(module_source) == 0:
raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty")

def _prepare_nvrtc_options(self) -> list[bytes]:
return _prepare_nvrtc_options_impl(self)
Expand Down Expand Up @@ -624,33 +649,9 @@ cdef inline int Program_init(Program self, object code, str code_type, object op

# Add extra modules if provided
if options.extra_sources is not None:
if not is_sequence(options.extra_sources):
raise TypeError(
"extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)"
)
for i, module in enumerate(options.extra_sources):
if not isinstance(module, tuple) or len(module) != 2:
raise TypeError(
f"Each extra module must be a 2-tuple (name, source)"
f", got {type(module).__name__} at index {i}"
)

module_name, module_source = module

if not isinstance(module_name, str):
raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}")

for module_name, module_source in options.extra_sources:
if isinstance(module_source, str):
# Textual LLVM IR - encode to UTF-8 bytes
module_source = module_source.encode("utf-8")
Comment on lines 653 to 654
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: I assume we could move the str->bytes conversion to earlier (perhaps do an in-place change in _prepare_nvvm_options_impl). Not a blocker!

elif not isinstance(module_source, (bytes, bytearray)):
raise TypeError(
f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
f"or bytearray, got {type(module_source).__name__}"
)

if len(module_source) == 0:
raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty")

# Add the module using NVVM API
module_bytes = module_source if isinstance(module_source, bytes) else bytes(module_source)
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def test_cpp_program_with_extra_sources():
# negative test with NVRTC with multiple sources
code = 'extern "C" __global__ void my_kernel(){}'
helper = 'extern "C" __global__ void helper(){}'
options = ProgramOptions(extra_sources=helper)
options = ProgramOptions(extra_sources=[("helper", helper)])
with pytest.raises(ValueError, match="extra_sources is not supported by the NVRTC backend"):
Program(code, "c++", options)

Expand Down
Loading