diff --git a/build.sh b/build.sh index 492c033686..fee36c4084 100755 --- a/build.sh +++ b/build.sh @@ -125,12 +125,15 @@ fi OUTPUT_NAME=${OUTPUT_NAME:-$ENV} # Standalone environment build +# -mavx2 enables AVX2 intrinsics (__m256, _mm256_*) which drive.h and +# src/bf16.h use directly. x86_64 only — strip if porting to ARM/Apple Silicon. +SIMD_FLAGS=(-mavx2 -mfma) if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then - CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}" "${SANITIZE_FLAGS[@]}") + CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}" "${SANITIZE_FLAGS[@]}" "${SIMD_FLAGS[@]}") NVCC_OPT="-O0 -g" LINK_OPT="-g" else - CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}") + CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}" "${SIMD_FLAGS[@]}") NVCC_OPT="-O2 --threads 0" LINK_OPT="-O2" fi diff --git a/src/bf16.h b/src/bf16.h new file mode 100644 index 0000000000..8f95360d22 --- /dev/null +++ b/src/bf16.h @@ -0,0 +1,44 @@ +// Usage: +// #include "bf16.h" +// +// Set env_name/binding.c obs to PrecisionTensor +// +// bf16* observations; +// observations[0] = f32_to_bf16(some_float); // scalar +// +// // SIMD fast-path for inner loops with 8 floats already in an __m256: +// __m256 v = _mm256_mul_ps(x, scale); +// store_f32x8_as_bf16(&observations[i], v); // 1 store, 8 vals +// +// // Reverse if you ever need to read back as float: +// float f = bf16_to_f32(observations[0]); +// +// x86_64 only — uses AVX intrinsics. To port to ARM/Apple Silicon, replace +// store_f32x8_as_bf16 with a NEON equivalent (or remove and use scalar +// f32_to_bf16 in a loop — the compiler auto-vectorizes well). + +#include +#include +#include + +typedef uint16_t bf16; + +static inline bf16 f32_to_bf16(float f) { + uint32_t bits; + memcpy(&bits, &f, 4); + return (uint16_t)(bits >> 16); +} + +static inline float bf16_to_f32(bf16 b) { + uint32_t bits = (uint32_t)b << 16; + float f; + memcpy(&f, &bits, 4); + return f; +} + +static inline void store_f32x8_as_bf16(bf16* dst, __m256 v) { + __m256i vi = _mm256_srli_epi32(_mm256_castps_si256(v), 16); + __m128i lo = _mm256_castsi256_si128(vi); + __m128i hi = _mm256_extracti128_si256(vi, 1); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(lo, hi)); +} diff --git a/src/kernels.cu b/src/kernels.cu index 0a6df9f45c..d9b25d8579 100644 --- a/src/kernels.cu +++ b/src/kernels.cu @@ -317,6 +317,13 @@ __global__ void cast(precision_t* __restrict__ dst, } } +inline void cast_dispatch(precision_t* dst, const precision_t* src, int n, cudaStream_t stream) { + cudaMemcpyAsync(dst, src, n * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream); +} +inline void cast_dispatch(precision_t* dst, const float* src, int n, cudaStream_t stream) { + cast<<>>(dst, src, n); +} + #ifndef PRECISION_FLOAT __global__ void cast(float* __restrict__ dst, const precision_t* __restrict__ src, int n) { diff --git a/src/pufferlib.cu b/src/pufferlib.cu index 0014999e64..583d9a115f 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -600,8 +600,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) { OBS_TENSOR_T& obs_env = env.obs; int n = block_size * obs_env.shape[1]; PrecisionTensor obs_dst = puf_slice(rollouts.observations, t, start, block_size); - cast<<>>( - obs_dst.data, obs_env.data + (long)start*obs_env.shape[1], n); + cast_dispatch(obs_dst.data, obs_env.data + (long)start*obs_env.shape[1], n, stream); PrecisionTensor rew_dst = puf_slice(rollouts.rewards, t, start, block_size); n = block_size; diff --git a/src/tensor.h b/src/tensor.h index 8bf00d1c5f..7904bb9b70 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -30,6 +30,20 @@ typedef struct { precision_t* data; int64_t shape[PUF_MAX_DIMS]; } PrecisionTensor; +#else +// C-compatible definition: precision_t is bf16 (uint16_t) or float depending on build mode. +// Only the element size matters here (used by obs_element_size() in vecenv.h). +#ifdef PRECISION_FLOAT +typedef struct { + float* data; + int64_t shape[PUF_MAX_DIMS]; +} PrecisionTensor; +#else +typedef struct { + uint16_t* data; + int64_t shape[PUF_MAX_DIMS]; +} PrecisionTensor; +#endif #endif #endif // PUFFERLIB_TENSOR_H