Skip to content
Open

Bf16 #566

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,15 @@ fi
OUTPUT_NAME=${OUTPUT_NAME:-$ENV}

# Standalone environment build
# -mavx2 enables AVX2 intrinsics (__m256, _mm256_*) which drive.h and
# src/bf16.h use directly. x86_64 only — strip if porting to ARM/Apple Silicon.
SIMD_FLAGS=(-mavx2 -mfma)
if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then
CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}" "${SANITIZE_FLAGS[@]}")
CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}" "${SANITIZE_FLAGS[@]}" "${SIMD_FLAGS[@]}")
NVCC_OPT="-O0 -g"
LINK_OPT="-g"
else
CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}")
CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}" "${SIMD_FLAGS[@]}")
NVCC_OPT="-O2 --threads 0"
LINK_OPT="-O2"
fi
Expand Down
44 changes: 44 additions & 0 deletions src/bf16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Usage:
// #include "bf16.h"
//
// Set env_name/binding.c obs to PrecisionTensor
//
// bf16* observations;
// observations[0] = f32_to_bf16(some_float); // scalar
//
// // SIMD fast-path for inner loops with 8 floats already in an __m256:
// __m256 v = _mm256_mul_ps(x, scale);
// store_f32x8_as_bf16(&observations[i], v); // 1 store, 8 vals
//
// // Reverse if you ever need to read back as float:
// float f = bf16_to_f32(observations[0]);
//
// x86_64 only — uses AVX intrinsics. To port to ARM/Apple Silicon, replace
// store_f32x8_as_bf16 with a NEON equivalent (or remove and use scalar
// f32_to_bf16 in a loop — the compiler auto-vectorizes well).

#include <stdint.h>
#include <string.h>
#include <immintrin.h>

typedef uint16_t bf16;

static inline bf16 f32_to_bf16(float f) {
uint32_t bits;
memcpy(&bits, &f, 4);
return (uint16_t)(bits >> 16);
}

static inline float bf16_to_f32(bf16 b) {
uint32_t bits = (uint32_t)b << 16;
float f;
memcpy(&f, &bits, 4);
return f;
}

static inline void store_f32x8_as_bf16(bf16* dst, __m256 v) {
__m256i vi = _mm256_srli_epi32(_mm256_castps_si256(v), 16);
__m128i lo = _mm256_castsi256_si128(vi);
__m128i hi = _mm256_extracti128_si256(vi, 1);
_mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(lo, hi));
}
7 changes: 7 additions & 0 deletions src/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ __global__ void cast(precision_t* __restrict__ dst,
}
}

inline void cast_dispatch(precision_t* dst, const precision_t* src, int n, cudaStream_t stream) {
cudaMemcpyAsync(dst, src, n * sizeof(precision_t), cudaMemcpyDeviceToDevice, stream);
}
inline void cast_dispatch(precision_t* dst, const float* src, int n, cudaStream_t stream) {
cast<<<grid_size(n), BLOCK_SIZE, 0, stream>>>(dst, src, n);
}

#ifndef PRECISION_FLOAT
__global__ void cast(float* __restrict__ dst,
const precision_t* __restrict__ src, int n) {
Expand Down
3 changes: 1 addition & 2 deletions src/pufferlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
OBS_TENSOR_T& obs_env = env.obs;
int n = block_size * obs_env.shape[1];
PrecisionTensor obs_dst = puf_slice(rollouts.observations, t, start, block_size);
cast<<<grid_size(n), BLOCK_SIZE, 0, stream>>>(
obs_dst.data, obs_env.data + (long)start*obs_env.shape[1], n);
cast_dispatch(obs_dst.data, obs_env.data + (long)start*obs_env.shape[1], n, stream);

PrecisionTensor rew_dst = puf_slice(rollouts.rewards, t, start, block_size);
n = block_size;
Expand Down
14 changes: 14 additions & 0 deletions src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ typedef struct {
precision_t* data;
int64_t shape[PUF_MAX_DIMS];
} PrecisionTensor;
#else
// C-compatible definition: precision_t is bf16 (uint16_t) or float depending on build mode.
// Only the element size matters here (used by obs_element_size() in vecenv.h).
#ifdef PRECISION_FLOAT
typedef struct {
float* data;
int64_t shape[PUF_MAX_DIMS];
} PrecisionTensor;
#else
typedef struct {
uint16_t* data;
int64_t shape[PUF_MAX_DIMS];
} PrecisionTensor;
#endif
#endif

#endif // PUFFERLIB_TENSOR_H
Loading