diff --git a/examples/bfloat16.jl b/examples/bfloat16.jl new file mode 100644 index 00000000..4b359682 --- /dev/null +++ b/examples/bfloat16.jl @@ -0,0 +1,49 @@ +using oneAPI, Test + +@static if !isdefined(Core, :BFloat16) + @info "BFloat16 requires Julia 1.12+, skipping." + exit() +end + +bfloat16_supported = oneAPI._device_supports_bfloat16() + +@info "BFloat16 support: $bfloat16_supported" + +if !bfloat16_supported + @info "Device does not support BFloat16, skipping." + exit() +end + +# Conversions: Core.BFloat16 in Julia 1.12 may not have Float32 constructors yet +float32_to_bf16(x::Float32) = reinterpret(Core.BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16) +bf16_to_float32(x::Core.BFloat16) = reinterpret(Float32, UInt32(reinterpret(UInt16, x)) << 16) + +# Simple kernel: scale BFloat16 values by 2 via Float32 round-trip +# (BFloat16 arithmetic is done by promoting to Float32 on device) +function scale_bf16(input, output) + i = get_global_id() + @inbounds begin + val = reinterpret(UInt16, input[i]) + # BFloat16 -> Float32: shift left 16 bits + f = reinterpret(Float32, UInt32(val) << 16) + f *= 2.0f0 + # Float32 -> BFloat16: take upper 16 bits + output[i] = reinterpret(Core.BFloat16, (reinterpret(UInt32, f) >> 16) % UInt16) + end + return +end + +n = 1024 +a = float32_to_bf16.(rand(Float32, n)) + +d_a = oneArray(a) +d_out = oneArray{Core.BFloat16}(undef, n) + +@oneapi items=n scale_bf16(d_a, d_out) +result = Array(d_out) + +# Verify: each output should be 2x the input (in Float32 space) +result_f32 = bf16_to_float32.(result) +expected_f32 = bf16_to_float32.(a) .* 2.0f0 +@test result_f32 ≈ expected_f32 +@info "BFloat16 scale-by-2 kernel passed!" diff --git a/src/array.jl b/src/array.jl index f219ad5f..c23a46f5 100644 --- a/src/array.jl +++ b/src/array.jl @@ -28,6 +28,22 @@ function contains_eltype(T, X) return false end +function _device_supports_bfloat16() + # check the driver extension first + if haskey(oneL0.extension_properties(driver()), + oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME) + return true + end + # some drivers (e.g. older versions on PVC/Max) don't advertise the extension, + # but the hardware supports BFloat16 natively. fall back to checking device ID. + dev_id = oneL0.properties(device()).deviceId + # Intel Data Center GPU Max (Ponte Vecchio): device IDs 0x0BD0-0x0BDB + if 0x0BD0 <= dev_id <= 0x0BDB + return true + end + return false +end + function check_eltype(T) Base.allocatedinline(T) || error("oneArray only supports element types that are stored inline") Base.isbitsunion(T) && error("oneArray does not yet support isbits-union arrays") @@ -39,6 +55,11 @@ function check_eltype(T) oneL0.ZE_DEVICE_MODULE_FLAG_FP64 contains_eltype(T, Float64) && error("Float64 is not supported on this device") end + @static if isdefined(Core, :BFloat16) + if !_device_supports_bfloat16() + contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device") + end + end end """ diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index 4ca04b0d..8015248f 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -50,6 +50,14 @@ function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module, # indices (e.g., "1 0") corrupts adjacent struct fields. flatten_nested_insertvalue!(mod) + # When the device supports BFloat16 but the SPIR-V runtime doesn't accept + # SPV_KHR_bfloat16, lower all bfloat types to i16 so the translator can + # handle the module without the extension. + if @static(isdefined(Core, :BFloat16) && isdefined(LLVM, :BFloatType)) && + _device_supports_bfloat16() && !_driver_supports_bfloat16_spirv() + lower_bfloat_to_i16!(mod) + end + return entry end @@ -158,6 +166,105 @@ function flatten_insert!(inst::LLVM.Instruction) end +# Lower bfloat types to i16 in the LLVM IR. +# This is needed when the device supports BFloat16 but the SPIR-V runtime/translator +# doesn't support SPV_KHR_bfloat16. Since sizeof(bfloat)==sizeof(i16)==2, the memory +# layout is identical. +# +# TODO: Julia 1.12's Core.BFloat16 is a bare primitive (no Float32 conversion, no +# arithmetic), so fptrunc/fpext instructions never appear in practice. If Julia adds +# BFloat16 conversion methods in the future, this pass should be extended to handle +# fptrunc float→bfloat and fpext bfloat→float, either via inline RNE bit manipulation +# or calls to __devicelib_ConvertFToBF16INTEL / __devicelib_ConvertBF16ToFINTEL. +function lower_bfloat_to_i16!(mod::LLVM.Module) + T_bf16 = LLVM.BFloatType() + T_i16 = LLVM.Int16Type() + + # Phase 1: Eliminate all bitcasts between i16 and bfloat (same bit width). + eliminate_bf16_bitcasts!(mod, T_bf16, T_i16) + + # Phase 2: Replace remaining bfloat GEPs, loads, and stores with i16 equivalents. + for f in functions(mod) + isempty(blocks(f)) && continue + for bb in blocks(f) + to_replace = LLVM.Instruction[] + for inst in instructions(bb) + opcode = LLVM.API.LLVMGetInstructionOpcode(inst) + if opcode == LLVM.API.LLVMGetElementPtr + src_ty = LLVMType(LLVM.API.LLVMGetGEPSourceElementType(inst)) + src_ty == T_bf16 && push!(to_replace, inst) + elseif opcode == LLVM.API.LLVMLoad + value_type(inst) == T_bf16 && push!(to_replace, inst) + elseif opcode == LLVM.API.LLVMStore + value_type(LLVM.operands(inst)[1]) == T_bf16 && push!(to_replace, inst) + end + end + + for inst in to_replace + opcode = LLVM.API.LLVMGetInstructionOpcode(inst) + builder = LLVM.IRBuilder() + LLVM.position!(builder, inst) + + if opcode == LLVM.API.LLVMGetElementPtr + ptr = LLVM.operands(inst)[1] + indices = LLVM.Value[LLVM.operands(inst)[i] for i in 2:length(LLVM.operands(inst))] + new_gep = if LLVM.API.LLVMIsInBounds(inst) != 0 + LLVM.inbounds_gep!(builder, T_i16, ptr, indices) + else + LLVM.gep!(builder, T_i16, ptr, indices) + end + LLVM.replace_uses!(inst, new_gep) + elseif opcode == LLVM.API.LLVMLoad + ptr = LLVM.operands(inst)[1] + new_load = LLVM.load!(builder, T_i16, ptr) + LLVM.replace_uses!(inst, new_load) + elseif opcode == LLVM.API.LLVMStore + val = LLVM.operands(inst)[1] + ptr = LLVM.operands(inst)[2] + LLVM.store!(builder, val, ptr) + end + + LLVM.API.LLVMInstructionEraseFromParent(inst) + LLVM.dispose(builder) + end + end + end + + return true +end + +# Iteratively eliminate bitcasts between i16 and bfloat (same bit representation). +function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLVMType) + changed = true + while changed + changed = false + for f in functions(mod) + isempty(blocks(f)) && continue + for bb in blocks(f) + to_delete = LLVM.Instruction[] + for inst in instructions(bb) + if LLVM.API.LLVMGetInstructionOpcode(inst) == LLVM.API.LLVMBitCast + src = LLVM.operands(inst)[1] + src_ty = value_type(src) + dst_ty = value_type(inst) + if (src_ty == T_i16 && dst_ty == T_bf16) || + (src_ty == T_bf16 && dst_ty == T_i16) || + (src_ty == dst_ty) + LLVM.replace_uses!(inst, src) + push!(to_delete, inst) + changed = true + end + end + end + for inst in to_delete + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end + end + end +end + + ## compiler implementation (cache, configure, compile, and link) # cache of compilation caches, per device @@ -183,18 +290,35 @@ function compiler_config(dev; kwargs...) end return config end +# Whether the driver's SPIR-V runtime accepts the SPV_KHR_bfloat16 extension. +function _driver_supports_bfloat16_spirv() + @static if isdefined(Core, :BFloat16) + haskey(oneL0.extension_properties(driver()), + oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME) + else + false + end +end + @noinline function _compiler_config(dev; kernel=true, name=nothing, always_inline=false, kwargs...) supports_fp16 = oneL0.module_properties(device()).fp16flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP16 == oneL0.ZE_DEVICE_MODULE_FLAG_FP16 supports_fp64 = oneL0.module_properties(device()).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + # Allow BFloat16 in IR if the device supports it (even if the SPIR-V runtime doesn't + # advertise the extension). We lower bfloat→i16 in finish_ir! when needed. + supports_bfloat16 = _device_supports_bfloat16() # TODO: emit printf format strings in constant memory extensions = String[ "SPV_EXT_relaxed_printf_string_address_space", "SPV_EXT_shader_atomic_float_add" ] + # Only add the SPIR-V extension if the runtime actually supports it + if _driver_supports_bfloat16_spirv() + push!(extensions, "SPV_KHR_bfloat16") + end # create GPUCompiler objects - target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, kwargs...) + target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, supports_bfloat16, kwargs...) params = oneAPICompilerParams() CompilerConfig(target, params; kernel, name, always_inline) end diff --git a/test/setup.jl b/test/setup.jl index 269d5b9c..2c35aa2a 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -22,6 +22,14 @@ const float64_supported = oneL0.module_properties(device()).fp64flags & oneL0.ZE if float64_supported append!(eltypes, [Float64, ComplexF64]) end +@static if isdefined(Core, :BFloat16) + const bfloat16_supported = oneAPI._device_supports_bfloat16() + if bfloat16_supported + push!(eltypes, Core.BFloat16) + end +else + const bfloat16_supported = false +end TestSuite.supported_eltypes(::Type{<:oneArray}) = eltypes const validation_layer = parse(Bool, get(ENV, "ZE_ENABLE_VALIDATION_LAYER", "false"))