diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 0b1fa93279..3188c73ecf 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -402,6 +402,31 @@ class ProgramOptions: # Set arch to default if not provided if self.arch is None: self.arch = f"sm_{Device().arch}" + if self.extra_sources is not None: + if not is_sequence(self.extra_sources): + raise TypeError( + "extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)" + ) + for i, module in enumerate(self.extra_sources): + if not isinstance(module, tuple) or len(module) != 2: + raise TypeError( + f"Each extra module must be a 2-tuple (name, source)" + f", got {type(module).__name__} at index {i}" + ) + + module_name, module_source = module + + if not isinstance(module_name, str): + raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}") + + if not isinstance(module_source, (str, bytes, bytearray)): + raise TypeError( + f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), " + f"or bytearray, got {type(module_source).__name__}" + ) + + if len(module_source) == 0: + raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty") def _prepare_nvrtc_options(self) -> list[bytes]: return _prepare_nvrtc_options_impl(self) @@ -455,6 +480,23 @@ class ProgramOptions: def __repr__(self): return f"ProgramOptions(name={self.name!r}, arch={self.arch!r})" + def _prepare_extra_sources_bytes(self) -> list[tuple[bytes, bytes]] | None: + """Convert extra_sources to bytes format for NVVM.""" + if self.extra_sources is None: + return None + + result = [] + for module_name, module_source in self.extra_sources: + name_bytes = module_name.encode("utf-8") + if isinstance(module_source, str): + source_bytes = module_source.encode("utf-8") + elif isinstance(module_source, bytearray): + source_bytes = bytes(module_source) + else: + source_bytes = module_source + result.append((name_bytes, source_bytes)) + return result + # ============================================================================= # Private Classes and Helper Functions @@ -624,41 +666,11 @@ cdef inline int Program_init(Program self, object code, str code_type, object op # Add extra modules if provided if options.extra_sources is not None: - if not is_sequence(options.extra_sources): - raise TypeError( - "extra_sources must be a sequence of 2-tuples: ((name1, source1), (name2, source2), ...)" - ) - for i, module in enumerate(options.extra_sources): - if not isinstance(module, tuple) or len(module) != 2: - raise TypeError( - f"Each extra module must be a 2-tuple (name, source)" - f", got {type(module).__name__} at index {i}" - ) - - module_name, module_source = module - - if not isinstance(module_name, str): - raise TypeError(f"Module name at index {i} must be a string, got {type(module_name).__name__}") - - if isinstance(module_source, str): - # Textual LLVM IR - encode to UTF-8 bytes - module_source = module_source.encode("utf-8") - elif not isinstance(module_source, (bytes, bytearray)): - raise TypeError( - f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), " - f"or bytearray, got {type(module_source).__name__}" - ) - - if len(module_source) == 0: - raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty") - - # Add the module using NVVM API - module_bytes = module_source if isinstance(module_source, bytes) else bytes(module_source) + extra_sources_bytes = options._prepare_extra_sources_bytes() + for module_name_bytes, module_bytes in extra_sources_bytes: module_ptr = module_bytes module_len = len(module_bytes) - module_name_bytes = module_name.encode() module_name_ptr = module_name_bytes - with nogil: HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram( nvvm_prog, module_ptr, module_len, module_name_ptr)) diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 565305a1e3..d31bf607e0 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -697,7 +697,7 @@ def test_cpp_program_with_extra_sources(): # negative test with NVRTC with multiple sources code = 'extern "C" __global__ void my_kernel(){}' helper = 'extern "C" __global__ void helper(){}' - options = ProgramOptions(extra_sources=helper) + options = ProgramOptions(extra_sources=[("helper", helper)]) with pytest.raises(ValueError, match="extra_sources is not supported by the NVRTC backend"): Program(code, "c++", options)