From e25abaf1fd1488a2deef962fa1c996679378dac0 Mon Sep 17 00:00:00 2001 From: eshau Date: Thu, 23 Apr 2026 20:18:53 -0400 Subject: [PATCH] Add vectorized mingru scan forward and backwards + profiling and kernel selector depending on BTH --- src/models.cu | 1030 ++++++++++++++++++++--- tests/profile_kernels.cu | 1683 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 2545 insertions(+), 168 deletions(-) diff --git a/src/models.cu b/src/models.cu index 2ba14caa43..ee0d8e23cc 100644 --- a/src/models.cu +++ b/src/models.cu @@ -150,7 +150,11 @@ struct PrefixScan { // Checkpointing trades off partial recomputation for memory bandwidth. #define CHECKPOINT_INTERVAL 4 -__global__ void mingru_scan_forward(PrefixScan scan) { +constexpr int MINGRU_SCAN_VEC128_WIDTH = 16 / sizeof(precision_t); +constexpr int MINGRU_SCAN_VEC64_WIDTH = 8 / sizeof(precision_t); + +template +__device__ __forceinline__ void mingru_scan_forward_ckpt_tuned_body(PrefixScan scan) { int T_seq = scan.T, H = scan.H, B = scan.B; precision_t* __restrict__ out = scan.out.data; precision_t* __restrict__ next_state = scan.next_state.data; @@ -179,7 +183,6 @@ __global__ void mingru_scan_forward(PrefixScan scan) { float a_star = 0.0f; float log_value = 0.0f; - // Handle t=0 outside the loop: use log(state), coeff = 0 float s = __logf(to_float(state[bH + h])); log_value = s; @@ -194,21 +197,20 @@ __global__ void mingru_scan_forward(PrefixScan scan) { const precision_t* combined_g_base = &combined[cbase + H + h]; const precision_t* combined_p_base = &combined[cbase + H2 + h]; - // Loop t=1..T_seq with sparse checkpointing float scan_result = 0.0f; int out_curr = out_base; int t_offset = 0; - for (int t = 1; t < T_seq + 1; t++) { + for (int t = 1; t <= T_seq; t++) { float hidden_val = to_float(combined_h_base[t_offset]); float gate_val = to_float(combined_g_base[t_offset]); float proj_val = to_float(combined_p_base[t_offset]); - float x_val = to_float(input[out_base + (t - 1) * H]); + int input_idx = out_base + (t - 1) * H; + float x_val = to_float(input[input_idx]); float log_coeff_val; log_coeffs_and_values_fwd(gate_val, hidden_val, &log_coeff_val, &log_value); - // a_star[t] = sum_{i=0}^t log_coeffs[i] a_star += log_coeff_val; float z = log_value - a_star; @@ -217,26 +219,29 @@ __global__ void mingru_scan_forward(PrefixScan scan) { scan_result = __expf(a_star + s); float proj_sigmoid = sigmoid(proj_val); - // out = sigmoid(proj) * scan_result + (1 - sigmoid(proj)) * x (highway gate) out[out_curr] = from_float(proj_sigmoid * scan_result + (1.0f - proj_sigmoid) * x_val); buf_curr += H; out_curr += H; t_offset += H3; - if (t % CHECKPOINT_INTERVAL == 0) { + if (t % CKPT_INTERVAL == 0) { a_star_buf[buf_curr] = a_star; s_buf[buf_curr] = s; log_values_buf[buf_curr] = log_value; } } - // Write timestep T to next_state (raw scan_result, no proj, for recurrence) next_state[bH + h] = from_float(scan_result); } +__global__ void mingru_scan_forward(PrefixScan scan) { + mingru_scan_forward_ckpt_tuned_body(scan); +} + // Reads sparse checkpoints from forward pass, recomputes intermediate values in chunks -__global__ void mingru_scan_backward(PrefixScan scan, +template +__device__ __forceinline__ void mingru_scan_backward_ckpt_tuned_body(PrefixScan scan, const precision_t* __restrict__ grad_out, const precision_t* __restrict__ grad_next_state) { int T_seq = scan.T, H = scan.H, B = scan.B; @@ -280,16 +285,18 @@ __global__ void mingru_scan_backward(PrefixScan scan, float s_val_next = 0.0; float carry_grad_a = 0.0; - for (int chunk_end = T_seq; chunk_end > 0; chunk_end -= CHECKPOINT_INTERVAL) { - int chunk_start = (chunk_end > CHECKPOINT_INTERVAL) ? (chunk_end - CHECKPOINT_INTERVAL) : 0; + // Backward recompute must start each chunk at a stored checkpoint index. + // Floor-to-multiple handles ragged tails (e.g. with CKPT_INTERVAL=4: T=10 -> starts 8,4,0). + for (int chunk_end = T_seq; chunk_end > 0;) { + int chunk_start = ((chunk_end - 1) / CKPT_INTERVAL) * CKPT_INTERVAL; int chunk_len = chunk_end - chunk_start; // Chunk storage in registers - float chunk_a_star[CHECKPOINT_INTERVAL]; - float chunk_s[CHECKPOINT_INTERVAL]; - float chunk_log_values[CHECKPOINT_INTERVAL]; - float chunk_hidden[CHECKPOINT_INTERVAL]; - float chunk_gate[CHECKPOINT_INTERVAL]; + float chunk_a_star[CKPT_INTERVAL]; + float chunk_s[CKPT_INTERVAL]; + float chunk_log_values[CKPT_INTERVAL]; + float chunk_hidden[CKPT_INTERVAL]; + float chunk_gate[CKPT_INTERVAL]; // Load checkpoint from global memory int ckpt_buf_idx = buf_base + chunk_start * H; @@ -297,36 +304,38 @@ __global__ void mingru_scan_backward(PrefixScan scan, float recomp_s = s_buf[ckpt_buf_idx]; float recomp_log_value = log_values_buf[ckpt_buf_idx]; - // Recompute and store from chunk_start to chunk_end - for (int i = 0; i < chunk_len; ++i) { - int t = chunk_start + 1 + i; + // Phase 1: recompute per-timestep values from checkpoint start to chunk end. + for (int chunk_i = 0; chunk_i < chunk_len; ++chunk_i) { + int t = chunk_start + 1 + chunk_i; int t_offset = (t - 1) * H3; - float hv = to_float(combined_h_base[t_offset]); - float gv = to_float(combined_g_base[t_offset]); + float hidden_val = to_float(combined_h_base[t_offset]); + float gate_val = to_float(combined_g_base[t_offset]); - float lc; - log_coeffs_and_values_fwd(gv, hv, &lc, &recomp_log_value); - recomp_a_star += lc; + float log_coeff_val; + log_coeffs_and_values_fwd(gate_val, hidden_val, &log_coeff_val, &recomp_log_value); + recomp_a_star += log_coeff_val; float z = recomp_log_value - recomp_a_star; recomp_s = logaddexp(recomp_s, z); - chunk_a_star[i] = recomp_a_star; - chunk_s[i] = recomp_s; - chunk_log_values[i] = recomp_log_value; - chunk_hidden[i] = hv; - chunk_gate[i] = gv; + chunk_a_star[chunk_i] = recomp_a_star; + chunk_s[chunk_i] = recomp_s; + chunk_log_values[chunk_i] = recomp_log_value; + chunk_hidden[chunk_i] = hidden_val; + chunk_gate[chunk_i] = gate_val; } - for (int i = chunk_len - 1; i >= 0; --i) { - int t = chunk_start + 1 + i; + // Phase 2: backprop through the chunk in reverse time order. + for (int chunk_i = chunk_len - 1; chunk_i >= 0; --chunk_i) { + int t = chunk_start + 1 + chunk_i; int t_offset = (t - 1) * H3; + const bool is_last_t = (t == T_seq); - float a_star_t = chunk_a_star[i]; - float s_t = chunk_s[i]; - float log_value_t = chunk_log_values[i]; - float hidden_val = chunk_hidden[i]; - float gate_val = chunk_gate[i]; + float a_star_t = chunk_a_star[chunk_i]; + float s_t = chunk_s[chunk_i]; + float log_value_t = chunk_log_values[chunk_i]; + float hidden_val = chunk_hidden[chunk_i]; + float gate_val = chunk_gate[chunk_i]; float proj_val = to_float(combined_p_base[t_offset]); int input_idx = out_base + (t - 1) * H; @@ -336,7 +345,7 @@ __global__ void mingru_scan_backward(PrefixScan scan, float z = log_value_t - a_star_t; float grad_out_val = to_float(grad_out[input_idx]); - float grad_scan_from_next = (t == T_seq) ? to_float(grad_next_state[state_idx]) : 0.0f; + float grad_scan_from_next = is_last_t ? to_float(grad_next_state[state_idx]) : 0.0f; float proj_sigmoid = sigmoid(proj_val); // Highway gate gradients: out = sigmoid(proj) * scan_result + (1 - sigmoid(proj)) * x @@ -347,7 +356,7 @@ __global__ void mingru_scan_backward(PrefixScan scan, float grad_log_h = grad_scan_result * scan_result; float grad_s = grad_log_h; - if (t == T_seq) { + if (is_last_t) { acc = grad_s; } else { acc = grad_s + acc * __expf(s_t - s_val_next); @@ -365,26 +374,852 @@ __global__ void mingru_scan_backward(PrefixScan scan, grad_combined_g_base[t_offset] = from_float(grad_g); grad_combined_p_base[t_offset] = from_float(grad_proj); } + + chunk_end = chunk_start; } int ckpt_0_idx = buf_base; float a_star_0 = a_star_buf[ckpt_0_idx]; float s_0 = s_buf[ckpt_0_idx]; float log_value_0 = log_values_buf[ckpt_0_idx]; - - float scan_result_0 = __expf(a_star_0 + s_0); float z_0 = log_value_0 - a_star_0; - - float grad_scan_result_0 = 0.0f; - float grad_log_h_0 = grad_scan_result_0 * scan_result_0; - float grad_s_0 = grad_log_h_0; - - acc = grad_s_0 + acc * __expf(s_0 - s_val_next); + acc = acc * __expf(s_0 - s_val_next); float grad_z_0 = acc * __expf(z_0 - s_0); grad_state[state_idx] = from_float(grad_z_0 / to_float(state[state_idx])); } +__global__ void mingru_scan_backward(PrefixScan scan, + const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + mingru_scan_backward_ckpt_tuned_body(scan, grad_out, grad_next_state); +} + +// Shared packed/vector helper layer for log vec32/vec64/vec128 kernels. + +__device__ __forceinline__ float2 scan_load_pair(const float* ptr) { + return make_float2(ptr[0], ptr[1]); +} + +__device__ __forceinline__ void scan_store_pair(float* ptr, float2 v) { + ptr[0] = v.x; + ptr[1] = v.y; +} + +__device__ __forceinline__ float2 scan_logaddexp_pair(float2 a, float2 b) { + return make_float2(logaddexp(a.x, b.x), logaddexp(a.y, b.y)); +} + +__device__ __forceinline__ void scan_log_coeffs_and_values_fwd_pair( + float2 gate, float2 hidden, float2* log_coeff_out, float2* log_value_io) { + log_coeffs_and_values_fwd(gate.x, hidden.x, &log_coeff_out->x, &log_value_io->x); + log_coeffs_and_values_fwd(gate.y, hidden.y, &log_coeff_out->y, &log_value_io->y); +} + +__device__ __forceinline__ void scan_log_coeffs_and_values_bwd_pair( + float2 grad_log_coeffs, float2 grad_log_values, float2 gate, float2 hidden, + float2* grad_gate_out, float2* grad_hidden_out) { + log_coeffs_and_values_bwd( + grad_log_coeffs.x, grad_log_values.x, gate.x, hidden.x, + &grad_gate_out->x, &grad_hidden_out->x); + log_coeffs_and_values_bwd( + grad_log_coeffs.y, grad_log_values.y, gate.y, hidden.y, + &grad_gate_out->y, &grad_hidden_out->y); +} + +template +__device__ __forceinline__ void scan_load_precision_vec(const precision_t* ptr, float* out) { +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4, + "float build supports vec64/vec128 widths"); + if constexpr (VEC_WIDTH == 2) { + float2 v = *reinterpret_cast(ptr); + out[0] = v.x; + out[1] = v.y; + } else { + float4 v = *reinterpret_cast(ptr); + out[0] = v.x; + out[1] = v.y; + out[2] = v.z; + out[3] = v.w; + } +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4 || VEC_WIDTH == 8, + "bf16 build supports vec32/vec64/vec128 widths"); + if constexpr (VEC_WIDTH == 2) { + uint32_t raw = *reinterpret_cast(ptr); + const precision_t* bf = reinterpret_cast(&raw); + out[0] = to_float(bf[0]); + out[1] = to_float(bf[1]); + } else if constexpr (VEC_WIDTH == 4) { + uint2 raw = *reinterpret_cast(ptr); + const precision_t* bf = reinterpret_cast(&raw); + #pragma unroll + for (int lane = 0; lane < 4; lane++) { + out[lane] = to_float(bf[lane]); + } + } else { + uint4 raw = *reinterpret_cast(ptr); + const precision_t* bf = reinterpret_cast(&raw); + #pragma unroll + for (int lane = 0; lane < 8; lane++) { + out[lane] = to_float(bf[lane]); + } + } +#endif +} + +template +__device__ __forceinline__ void scan_store_precision_vec(precision_t* ptr, const float* in) { +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4, + "float build supports vec64/vec128 widths"); + if constexpr (VEC_WIDTH == 2) { + *reinterpret_cast(ptr) = make_float2(in[0], in[1]); + } else { + *reinterpret_cast(ptr) = make_float4(in[0], in[1], in[2], in[3]); + } +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4 || VEC_WIDTH == 8, + "bf16 build supports vec32/vec64/vec128 widths"); + if constexpr (VEC_WIDTH == 2) { + alignas(4) precision_t tmp[2]; + tmp[0] = from_float(in[0]); + tmp[1] = from_float(in[1]); + *reinterpret_cast(ptr) = *reinterpret_cast(tmp); + } else if constexpr (VEC_WIDTH == 4) { + alignas(8) precision_t tmp[4]; + #pragma unroll + for (int lane = 0; lane < 4; lane++) { + tmp[lane] = from_float(in[lane]); + } + *reinterpret_cast(ptr) = *reinterpret_cast(tmp); + } else { + alignas(16) precision_t tmp[8]; + #pragma unroll + for (int lane = 0; lane < 8; lane++) { + tmp[lane] = from_float(in[lane]); + } + *reinterpret_cast(ptr) = *reinterpret_cast(tmp); + } +#endif +} + +template +__device__ __forceinline__ void scan_load_float_vec(const float* ptr, float* out) { +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4, + "float build supports vec64/vec128 widths"); +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4 || VEC_WIDTH == 8, + "bf16 build supports vec32/vec64/vec128 widths"); +#endif + if constexpr (VEC_WIDTH == 2) { + float2 v = *reinterpret_cast(ptr); + out[0] = v.x; + out[1] = v.y; + } else if constexpr (VEC_WIDTH == 4) { + float4 v = *reinterpret_cast(ptr); + out[0] = v.x; + out[1] = v.y; + out[2] = v.z; + out[3] = v.w; + } else { + float4 lo = *reinterpret_cast(ptr); + float4 hi = *reinterpret_cast(ptr + 4); + out[0] = lo.x; + out[1] = lo.y; + out[2] = lo.z; + out[3] = lo.w; + out[4] = hi.x; + out[5] = hi.y; + out[6] = hi.z; + out[7] = hi.w; + } +} + +template +__device__ __forceinline__ void scan_store_float_vec(float* ptr, const float* in) { +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4, + "float build supports vec64/vec128 widths"); +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == 4 || VEC_WIDTH == 8, + "bf16 build supports vec32/vec64/vec128 widths"); +#endif + if constexpr (VEC_WIDTH == 2) { + *reinterpret_cast(ptr) = make_float2(in[0], in[1]); + } else if constexpr (VEC_WIDTH == 4) { + *reinterpret_cast(ptr) = make_float4(in[0], in[1], in[2], in[3]); + } else { + *reinterpret_cast(ptr) = make_float4(in[0], in[1], in[2], in[3]); + *reinterpret_cast(ptr + 4) = make_float4(in[4], in[5], in[6], in[7]); + } +} + +template +__global__ void mingru_scan_forward_ckpt_tuned(PrefixScan scan) { + mingru_scan_forward_ckpt_tuned_body(scan); +} + +template +__global__ void mingru_scan_backward_ckpt_tuned(PrefixScan scan, + const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + mingru_scan_backward_ckpt_tuned_body(scan, grad_out, grad_next_state); +} + +template +__device__ __forceinline__ void mingru_scan_forward_ckpt_tuned_vec_body(PrefixScan scan) { + static_assert((VEC_WIDTH % 2) == 0, "vectorized kernels require even width"); +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == MINGRU_SCAN_VEC64_WIDTH || VEC_WIDTH == MINGRU_SCAN_VEC128_WIDTH, + "float build supports vec64/vec128 widths"); +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == MINGRU_SCAN_VEC64_WIDTH || VEC_WIDTH == MINGRU_SCAN_VEC128_WIDTH, + "bf16 build supports vec32/vec64/vec128 widths"); +#endif + + int T_seq = scan.T, H = scan.H, B = scan.B; + int HW = H / VEC_WIDTH; + if (HW == 0) { + return; + } + + precision_t* __restrict__ out = scan.out.data; + precision_t* __restrict__ next_state = scan.next_state.data; + float* __restrict__ a_star_buf = scan.a_star.data; + float* __restrict__ s_buf = scan.s_vals.data; + float* __restrict__ log_values_buf = scan.log_values_buf.data; + const precision_t* __restrict__ combined = scan.combined_ptr; + const precision_t* __restrict__ state = scan.state_ptr; + const precision_t* __restrict__ input = scan.input_ptr; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * HW) { + return; + } + + int b = idx / HW; + int hw = idx % HW; + int h = hw * VEC_WIDTH; + int bH = b * H; + int bTH = b * T_seq * H; + int cbase = 3 * bTH; + int H3 = 3 * H; + int H2 = 2 * H; + int out_base = bTH + h; + + const precision_t* combined_h_base = &combined[cbase + h]; + const precision_t* combined_g_base = &combined[cbase + H + h]; + const precision_t* combined_p_base = &combined[cbase + H2 + h]; + + float h_state[VEC_WIDTH]; + float a_star[VEC_WIDTH]; + float s[VEC_WIDTH]; + float log_value[VEC_WIDTH]; + float scan_result[VEC_WIDTH]; + float hidden[VEC_WIDTH]; + float gate[VEC_WIDTH]; + float proj[VEC_WIDTH]; + float x[VEC_WIDTH]; + float out_chunk[VEC_WIDTH]; + + scan_load_precision_vec(&state[bH + h], h_state); + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane += 2) { + float2 state_v = scan_load_pair(&h_state[lane]); + float2 s_v = make_float2(__logf(state_v.x), __logf(state_v.y)); + scan_store_pair(&a_star[lane], make_float2(0.0f, 0.0f)); + scan_store_pair(&s[lane], s_v); + scan_store_pair(&log_value[lane], s_v); + scan_store_pair(&scan_result[lane], state_v); + } + + int T_out = T_seq + 1; + int buf_base = b * T_out * H + h; + scan_store_float_vec(&a_star_buf[buf_base], a_star); + scan_store_float_vec(&s_buf[buf_base], s); + scan_store_float_vec(&log_values_buf[buf_base], log_value); + + int out_curr = out_base; + int t_offset = 0; + for (int t = 1; t <= T_seq; t++) { + scan_load_precision_vec(&combined_h_base[t_offset], hidden); + scan_load_precision_vec(&combined_g_base[t_offset], gate); + scan_load_precision_vec(&combined_p_base[t_offset], proj); + scan_load_precision_vec(&input[out_curr], x); + + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane += 2) { + float2 gate_v = scan_load_pair(&gate[lane]); + float2 hidden_v = scan_load_pair(&hidden[lane]); + float2 proj_v = scan_load_pair(&proj[lane]); + float2 x_v = scan_load_pair(&x[lane]); + float2 a_star_v = scan_load_pair(&a_star[lane]); + float2 s_v = scan_load_pair(&s[lane]); + float2 log_value_v = scan_load_pair(&log_value[lane]); + float2 log_coeff_v; + scan_log_coeffs_and_values_fwd_pair(gate_v, hidden_v, &log_coeff_v, &log_value_v); + a_star_v.x += log_coeff_v.x; + a_star_v.y += log_coeff_v.y; + float2 z_v = make_float2(log_value_v.x - a_star_v.x, log_value_v.y - a_star_v.y); + s_v = scan_logaddexp_pair(s_v, z_v); + float2 scan_result_v = make_float2(__expf(a_star_v.x + s_v.x), __expf(a_star_v.y + s_v.y)); + float2 proj_sigmoid_v = make_float2(sigmoid(proj_v.x), sigmoid(proj_v.y)); + float2 out_v = make_float2( + proj_sigmoid_v.x * scan_result_v.x + (1.0f - proj_sigmoid_v.x) * x_v.x, + proj_sigmoid_v.y * scan_result_v.y + (1.0f - proj_sigmoid_v.y) * x_v.y); + scan_store_pair(&a_star[lane], a_star_v); + scan_store_pair(&s[lane], s_v); + scan_store_pair(&log_value[lane], log_value_v); + scan_store_pair(&scan_result[lane], scan_result_v); + scan_store_pair(&out_chunk[lane], out_v); + } + scan_store_precision_vec(&out[out_curr], out_chunk); + + if ((t % CKPT_INTERVAL) == 0) { + int buf_idx = buf_base + t * H; + scan_store_float_vec(&a_star_buf[buf_idx], a_star); + scan_store_float_vec(&s_buf[buf_idx], s); + scan_store_float_vec(&log_values_buf[buf_idx], log_value); + } + + out_curr += H; + t_offset += H3; + } + + scan_store_precision_vec(&next_state[bH + h], scan_result); +} + +template +__device__ __forceinline__ void mingru_scan_backward_ckpt_tuned_vec_body( + PrefixScan scan, const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + static_assert((VEC_WIDTH % 2) == 0, "vectorized kernels require even width"); +#ifdef PRECISION_FLOAT + static_assert(VEC_WIDTH == MINGRU_SCAN_VEC64_WIDTH || VEC_WIDTH == MINGRU_SCAN_VEC128_WIDTH, + "float build supports vec64/vec128 widths"); +#else + static_assert(VEC_WIDTH == 2 || VEC_WIDTH == MINGRU_SCAN_VEC64_WIDTH || VEC_WIDTH == MINGRU_SCAN_VEC128_WIDTH, + "bf16 build supports vec32/vec64/vec128 widths"); +#endif + + int T_seq = scan.T, H = scan.H, B = scan.B; + int HW = H / VEC_WIDTH; + if (HW == 0) { + return; + } + + precision_t* __restrict__ grad_combined = scan.grad_combined.data; + precision_t* __restrict__ grad_state = scan.grad_state.data; + precision_t* __restrict__ grad_input = scan.grad_input.data; + const precision_t* __restrict__ combined = scan.combined_ptr; + const precision_t* __restrict__ state = scan.state_ptr; + const precision_t* __restrict__ input = scan.input_ptr; + const float* __restrict__ a_star_buf = scan.a_star.data; + const float* __restrict__ s_buf = scan.s_vals.data; + const float* __restrict__ log_values_buf = scan.log_values_buf.data; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * HW) { + return; + } + + int b = idx / HW; + int hw = idx % HW; + int h = hw * VEC_WIDTH; + + int bH = b * H; + int bTH = b * T_seq * H; + int cbase = 3 * bTH; + int H3 = 3 * H; + int H2 = 2 * H; + int out_base = bTH + h; + + const precision_t* combined_h_base = &combined[cbase + h]; + const precision_t* combined_g_base = &combined[cbase + H + h]; + const precision_t* combined_p_base = &combined[cbase + H2 + h]; + + precision_t* grad_combined_h_base = &grad_combined[cbase + h]; + precision_t* grad_combined_g_base = &grad_combined[cbase + H + h]; + precision_t* grad_combined_p_base = &grad_combined[cbase + H2 + h]; + + int T_out = T_seq + 1; + int buf_base = b * T_out * H + h; + + float acc[VEC_WIDTH]; + float s_val_next[VEC_WIDTH]; + float carry_grad_a[VEC_WIDTH]; + float grad_next[VEC_WIDTH]; + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane++) { + acc[lane] = 0.0f; + s_val_next[lane] = 0.0f; + carry_grad_a[lane] = 0.0f; + } + scan_load_precision_vec(&grad_next_state[bH + h], grad_next); + + for (int chunk_end = T_seq; chunk_end > 0;) { + int chunk_start = ((chunk_end - 1) / CKPT_INTERVAL) * CKPT_INTERVAL; + int chunk_len = chunk_end - chunk_start; + + float chunk_a_star[CKPT_INTERVAL][VEC_WIDTH]; + float chunk_s[CKPT_INTERVAL][VEC_WIDTH]; + float chunk_log_values[CKPT_INTERVAL][VEC_WIDTH]; + + float recomp_a_star[VEC_WIDTH]; + float recomp_s[VEC_WIDTH]; + float recomp_log_value[VEC_WIDTH]; + int ckpt_buf_idx = buf_base + chunk_start * H; + scan_load_float_vec(&a_star_buf[ckpt_buf_idx], recomp_a_star); + scan_load_float_vec(&s_buf[ckpt_buf_idx], recomp_s); + scan_load_float_vec(&log_values_buf[ckpt_buf_idx], recomp_log_value); + + // Phase 1: recompute per-timestep values from checkpoint start to chunk end. + for (int chunk_i = 0; chunk_i < chunk_len; ++chunk_i) { + int t = chunk_start + 1 + chunk_i; + int t_offset = (t - 1) * H3; + float hidden[VEC_WIDTH]; + float gate[VEC_WIDTH]; + scan_load_precision_vec(&combined_h_base[t_offset], hidden); + scan_load_precision_vec(&combined_g_base[t_offset], gate); + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane += 2) { + float2 gate_v = scan_load_pair(&gate[lane]); + float2 hidden_v = scan_load_pair(&hidden[lane]); + float2 recomp_a_star_v = scan_load_pair(&recomp_a_star[lane]); + float2 recomp_s_v = scan_load_pair(&recomp_s[lane]); + float2 recomp_log_value_v = scan_load_pair(&recomp_log_value[lane]); + float2 log_coeff_v; + scan_log_coeffs_and_values_fwd_pair(gate_v, hidden_v, &log_coeff_v, &recomp_log_value_v); + recomp_a_star_v.x += log_coeff_v.x; + recomp_a_star_v.y += log_coeff_v.y; + float2 z_v = make_float2( + recomp_log_value_v.x - recomp_a_star_v.x, + recomp_log_value_v.y - recomp_a_star_v.y); + recomp_s_v = scan_logaddexp_pair(recomp_s_v, z_v); + scan_store_pair(&recomp_a_star[lane], recomp_a_star_v); + scan_store_pair(&recomp_s[lane], recomp_s_v); + scan_store_pair(&recomp_log_value[lane], recomp_log_value_v); + scan_store_pair(&chunk_a_star[chunk_i][lane], recomp_a_star_v); + scan_store_pair(&chunk_s[chunk_i][lane], recomp_s_v); + scan_store_pair(&chunk_log_values[chunk_i][lane], recomp_log_value_v); + } + } + + // Phase 2: backprop through the chunk in reverse time order. + for (int chunk_i = chunk_len - 1; chunk_i >= 0; --chunk_i) { + int t = chunk_start + 1 + chunk_i; + int t_offset = (t - 1) * H3; + int input_idx = out_base + (t - 1) * H; + const bool is_last_t = (t == T_seq); + + float proj[VEC_WIDTH]; + float x[VEC_WIDTH]; + float grad_out_chunk[VEC_WIDTH]; + float grad_hidden[VEC_WIDTH]; + float grad_gate[VEC_WIDTH]; + float grad_proj[VEC_WIDTH]; + float grad_in[VEC_WIDTH]; + float hidden[VEC_WIDTH]; + float gate[VEC_WIDTH]; + scan_load_precision_vec(&combined_p_base[t_offset], proj); + scan_load_precision_vec(&combined_h_base[t_offset], hidden); + scan_load_precision_vec(&combined_g_base[t_offset], gate); + scan_load_precision_vec(&input[input_idx], x); + scan_load_precision_vec(&grad_out[input_idx], grad_out_chunk); + + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane += 2) { + float2 a_star_t_v = scan_load_pair(&chunk_a_star[chunk_i][lane]); + float2 s_t_v = scan_load_pair(&chunk_s[chunk_i][lane]); + float2 log_value_t_v = scan_load_pair(&chunk_log_values[chunk_i][lane]); + float2 scan_result_v = make_float2( + __expf(a_star_t_v.x + s_t_v.x), + __expf(a_star_t_v.y + s_t_v.y)); + float2 z_v = make_float2( + log_value_t_v.x - a_star_t_v.x, + log_value_t_v.y - a_star_t_v.y); + + float2 proj_v = scan_load_pair(&proj[lane]); + float2 x_v = scan_load_pair(&x[lane]); + float2 grad_out_v = scan_load_pair(&grad_out_chunk[lane]); + float2 proj_sigmoid_v = make_float2(sigmoid(proj_v.x), sigmoid(proj_v.y)); + float2 grad_in_v = make_float2( + grad_out_v.x * (1.0f - proj_sigmoid_v.x), + grad_out_v.y * (1.0f - proj_sigmoid_v.y)); + float2 grad_proj_v = make_float2( + grad_out_v.x * (scan_result_v.x - x_v.x) * proj_sigmoid_v.x * (1.0f - proj_sigmoid_v.x), + grad_out_v.y * (scan_result_v.y - x_v.y) * proj_sigmoid_v.y * (1.0f - proj_sigmoid_v.y)); + scan_store_pair(&grad_in[lane], grad_in_v); + scan_store_pair(&grad_proj[lane], grad_proj_v); + + float2 grad_scan_from_next_v = is_last_t + ? scan_load_pair(&grad_next[lane]) + : make_float2(0.0f, 0.0f); + float2 grad_scan_result_v = make_float2( + grad_scan_from_next_v.x + grad_out_v.x * proj_sigmoid_v.x, + grad_scan_from_next_v.y + grad_out_v.y * proj_sigmoid_v.y); + float2 grad_log_h_v = make_float2( + grad_scan_result_v.x * scan_result_v.x, + grad_scan_result_v.y * scan_result_v.y); + + float2 acc_v; + if (is_last_t) { + acc_v = grad_log_h_v; + } else { + float2 acc_prev_v = scan_load_pair(&acc[lane]); + float2 s_val_next_v = scan_load_pair(&s_val_next[lane]); + acc_v = make_float2( + grad_log_h_v.x + acc_prev_v.x * __expf(s_t_v.x - s_val_next_v.x), + grad_log_h_v.y + acc_prev_v.y * __expf(s_t_v.y - s_val_next_v.y)); + } + scan_store_pair(&acc[lane], acc_v); + + float2 grad_z_v = make_float2( + acc_v.x * __expf(z_v.x - s_t_v.x), + acc_v.y * __expf(z_v.y - s_t_v.y)); + scan_store_pair(&s_val_next[lane], s_t_v); + + float2 carry_grad_a_v = scan_load_pair(&carry_grad_a[lane]); + float2 grad_a_v = make_float2( + grad_log_h_v.x + carry_grad_a_v.x - grad_z_v.x, + grad_log_h_v.y + carry_grad_a_v.y - grad_z_v.y); + scan_store_pair(&carry_grad_a[lane], grad_a_v); + + float2 gate_v = scan_load_pair(&gate[lane]); + float2 hidden_v = scan_load_pair(&hidden[lane]); + float2 grad_gate_v; + float2 grad_hidden_v; + scan_log_coeffs_and_values_bwd_pair( + grad_a_v, grad_z_v, gate_v, hidden_v, &grad_gate_v, &grad_hidden_v); + scan_store_pair(&grad_gate[lane], grad_gate_v); + scan_store_pair(&grad_hidden[lane], grad_hidden_v); + } + + scan_store_precision_vec(&grad_combined_h_base[t_offset], grad_hidden); + scan_store_precision_vec(&grad_combined_g_base[t_offset], grad_gate); + scan_store_precision_vec(&grad_combined_p_base[t_offset], grad_proj); + scan_store_precision_vec(&grad_input[input_idx], grad_in); + } + + chunk_end = chunk_start; + } + + float a_star_0[VEC_WIDTH]; + float s_0[VEC_WIDTH]; + float log_value_0[VEC_WIDTH]; + float grad_state_chunk[VEC_WIDTH]; + float state_chunk[VEC_WIDTH]; + int ckpt_0_idx = buf_base; + scan_load_float_vec(&a_star_buf[ckpt_0_idx], a_star_0); + scan_load_float_vec(&s_buf[ckpt_0_idx], s_0); + scan_load_float_vec(&log_values_buf[ckpt_0_idx], log_value_0); + scan_load_precision_vec(&state[bH + h], state_chunk); + #pragma unroll + for (int lane = 0; lane < VEC_WIDTH; lane += 2) { + float2 a_star_0_v = scan_load_pair(&a_star_0[lane]); + float2 s_0_v = scan_load_pair(&s_0[lane]); + float2 log_value_0_v = scan_load_pair(&log_value_0[lane]); + float2 acc_v = scan_load_pair(&acc[lane]); + float2 s_val_next_v = scan_load_pair(&s_val_next[lane]); + float2 state_chunk_v = scan_load_pair(&state_chunk[lane]); + + float2 z_0_v = make_float2( + log_value_0_v.x - a_star_0_v.x, + log_value_0_v.y - a_star_0_v.y); + acc_v.x = acc_v.x * __expf(s_0_v.x - s_val_next_v.x); + acc_v.y = acc_v.y * __expf(s_0_v.y - s_val_next_v.y); + float2 grad_z_0_v = make_float2( + acc_v.x * __expf(z_0_v.x - s_0_v.x), + acc_v.y * __expf(z_0_v.y - s_0_v.y)); + float2 grad_state_v = make_float2( + grad_z_0_v.x / state_chunk_v.x, + grad_z_0_v.y / state_chunk_v.y); + scan_store_pair(&grad_state_chunk[lane], grad_state_v); + } + scan_store_precision_vec(&grad_state[bH + h], grad_state_chunk); +} + +#ifndef PRECISION_FLOAT +template +__global__ void mingru_scan_forward_ckpt_tuned_vec32(PrefixScan scan) { + mingru_scan_forward_ckpt_tuned_vec_body(scan); +} + +template +__global__ void mingru_scan_backward_ckpt_tuned_vec32(PrefixScan scan, + const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + mingru_scan_backward_ckpt_tuned_vec_body( + scan, grad_out, grad_next_state); +} +#endif + +template +__global__ void mingru_scan_forward_ckpt_tuned_vec64(PrefixScan scan) { + mingru_scan_forward_ckpt_tuned_vec_body(scan); +} + +template +__global__ void mingru_scan_backward_ckpt_tuned_vec64(PrefixScan scan, + const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + mingru_scan_backward_ckpt_tuned_vec_body( + scan, grad_out, grad_next_state); +} + +template +__global__ void mingru_scan_forward_ckpt_tuned_vec128(PrefixScan scan) { + mingru_scan_forward_ckpt_tuned_vec_body(scan); +} + +template +__global__ void mingru_scan_backward_ckpt_tuned_vec128(PrefixScan scan, + const precision_t* __restrict__ grad_out, + const precision_t* __restrict__ grad_next_state) { + mingru_scan_backward_ckpt_tuned_vec_body( + scan, grad_out, grad_next_state); +} + +enum class MingruScanVariant : int { + kScalar = 0, + kVec32 = 1, + kVec64 = 2, + kVec128 = 3, +}; + +struct MingruScanKernelSelection { + int ckpt_interval; + MingruScanVariant fwd_variant; + MingruScanVariant bwd_variant; +}; + +static inline const char* mingru_scan_variant_name(MingruScanVariant variant) { + switch (variant) { + case MingruScanVariant::kScalar: return "log_scalar"; + case MingruScanVariant::kVec32: return "log_vec32"; + case MingruScanVariant::kVec64: return "log_vec64"; + case MingruScanVariant::kVec128: return "log_vec128"; + default: return "log_scalar"; + } +} + +static inline bool mingru_scan_variant_supported_for_hidden(MingruScanVariant variant, int H) { + switch (variant) { + case MingruScanVariant::kScalar: + return true; + case MingruScanVariant::kVec32: +#ifdef PRECISION_FLOAT + return false; +#else + return (H % 2) == 0; +#endif + case MingruScanVariant::kVec64: + return (H % MINGRU_SCAN_VEC64_WIDTH) == 0; + case MingruScanVariant::kVec128: + return (H % MINGRU_SCAN_VEC128_WIDTH) == 0; + default: + return false; + } +} + +static inline MingruScanVariant mingru_scan_resolve_variant(MingruScanVariant variant, int H) { + if (mingru_scan_variant_supported_for_hidden(variant, H)) { + return variant; + } + switch (variant) { + case MingruScanVariant::kVec128: + if (mingru_scan_variant_supported_for_hidden(MingruScanVariant::kVec64, H)) { + return MingruScanVariant::kVec64; + } +#ifndef PRECISION_FLOAT + if (mingru_scan_variant_supported_for_hidden(MingruScanVariant::kVec32, H)) { + return MingruScanVariant::kVec32; + } +#endif + return MingruScanVariant::kScalar; + case MingruScanVariant::kVec64: +#ifndef PRECISION_FLOAT + if (mingru_scan_variant_supported_for_hidden(MingruScanVariant::kVec32, H)) { + return MingruScanVariant::kVec32; + } +#endif + return MingruScanVariant::kScalar; + case MingruScanVariant::kVec32: + case MingruScanVariant::kScalar: + default: + return MingruScanVariant::kScalar; + } +} + +static inline MingruScanKernelSelection mingru_scan_select_baseline_policy(int /*B*/, int /*T*/, int H) { + MingruScanKernelSelection selection = { + CHECKPOINT_INTERVAL, + MingruScanVariant::kScalar, + MingruScanVariant::kScalar, + }; + selection.fwd_variant = mingru_scan_resolve_variant(selection.fwd_variant, H); + selection.bwd_variant = mingru_scan_resolve_variant(selection.bwd_variant, H); + return selection; +} + +// Depth-2 trees fitted on fused_scan sweep results (B, T, H grid). +static inline MingruScanKernelSelection mingru_scan_select_depth2_policy(int B, int T, int H) { + int64_t HB = (int64_t)B * (int64_t)H; + int64_t HT = (int64_t)H * (int64_t)T; + int64_t HBT = HB * (int64_t)T; + + int ckpt_interval = 4; + if (HBT < 4194304LL) { + ckpt_interval = (HBT < 1048576LL) ? 16 : 1; + } else { + ckpt_interval = (HB < 262144LL) ? 8 : 4; + } + + MingruScanVariant fwd_variant; + if (HB < 262144LL) { + fwd_variant = (HBT < 67108864LL) ? MingruScanVariant::kScalar : MingruScanVariant::kVec32; + } else { + fwd_variant = (HT < 262144LL) ? MingruScanVariant::kVec64 : MingruScanVariant::kVec128; + } + + MingruScanVariant bwd_variant; + if (HB < 262144LL) { + bwd_variant = (HB < 131072LL) ? MingruScanVariant::kScalar : MingruScanVariant::kVec32; + } else { + bwd_variant = (HB < 524288LL) ? MingruScanVariant::kVec64 : MingruScanVariant::kVec32; + } + + MingruScanKernelSelection selection = {ckpt_interval, fwd_variant, bwd_variant}; + selection.fwd_variant = mingru_scan_resolve_variant(selection.fwd_variant, H); + selection.bwd_variant = mingru_scan_resolve_variant(selection.bwd_variant, H); + return selection; +} + +template +static inline void mingru_scan_launch_forward_ckpt( + PrefixScan scan, MingruScanVariant variant, cudaStream_t stream) { + switch (variant) { + case MingruScanVariant::kVec128: + mingru_scan_forward_ckpt_tuned_vec128<<< + grid_size(scan.B * (scan.H / MINGRU_SCAN_VEC128_WIDTH)), + BLOCK_SIZE, 0, stream>>>(scan); + return; + case MingruScanVariant::kVec64: + mingru_scan_forward_ckpt_tuned_vec64<<< + grid_size(scan.B * (scan.H / MINGRU_SCAN_VEC64_WIDTH)), + BLOCK_SIZE, 0, stream>>>(scan); + return; + case MingruScanVariant::kVec32: +#ifndef PRECISION_FLOAT + mingru_scan_forward_ckpt_tuned_vec32<<< + grid_size(scan.B * (scan.H / 2)), + BLOCK_SIZE, 0, stream>>>(scan); + return; +#else + break; +#endif + case MingruScanVariant::kScalar: + default: + mingru_scan_forward_ckpt_tuned<<< + grid_size(scan.B * scan.H), + BLOCK_SIZE, 0, stream>>>(scan); + return; + } +} + +template +static inline void mingru_scan_launch_backward_ckpt( + PrefixScan scan, const precision_t* grad_out, const precision_t* grad_next_state, + MingruScanVariant variant, cudaStream_t stream) { + switch (variant) { + case MingruScanVariant::kVec128: + mingru_scan_backward_ckpt_tuned_vec128<<< + grid_size(scan.B * (scan.H / MINGRU_SCAN_VEC128_WIDTH)), + BLOCK_SIZE, 0, stream>>>( + scan, grad_out, grad_next_state); + return; + case MingruScanVariant::kVec64: + mingru_scan_backward_ckpt_tuned_vec64<<< + grid_size(scan.B * (scan.H / MINGRU_SCAN_VEC64_WIDTH)), + BLOCK_SIZE, 0, stream>>>( + scan, grad_out, grad_next_state); + return; + case MingruScanVariant::kVec32: +#ifndef PRECISION_FLOAT + mingru_scan_backward_ckpt_tuned_vec32<<< + grid_size(scan.B * (scan.H / 2)), + BLOCK_SIZE, 0, stream>>>( + scan, grad_out, grad_next_state); + return; +#else + break; +#endif + case MingruScanVariant::kScalar: + default: + mingru_scan_backward_ckpt_tuned<<< + grid_size(scan.B * scan.H), + BLOCK_SIZE, 0, stream>>>( + scan, grad_out, grad_next_state); + return; + } +} + +static inline void mingru_scan_launch_forward_selected( + PrefixScan scan, int ckpt_interval, MingruScanVariant variant, cudaStream_t stream) { + variant = mingru_scan_resolve_variant(variant, scan.H); + switch (ckpt_interval) { + case 1: mingru_scan_launch_forward_ckpt<1>(scan, variant, stream); return; + case 2: mingru_scan_launch_forward_ckpt<2>(scan, variant, stream); return; + case 4: mingru_scan_launch_forward_ckpt<4>(scan, variant, stream); return; + case 8: mingru_scan_launch_forward_ckpt<8>(scan, variant, stream); return; + case 16: mingru_scan_launch_forward_ckpt<16>(scan, variant, stream); return; + case 32: mingru_scan_launch_forward_ckpt<32>(scan, variant, stream); return; + default: + mingru_scan_launch_forward_ckpt(scan, variant, stream); + return; + } +} + +static inline void mingru_scan_launch_backward_selected( + PrefixScan scan, const precision_t* grad_out, const precision_t* grad_next_state, + int ckpt_interval, MingruScanVariant variant, cudaStream_t stream) { + variant = mingru_scan_resolve_variant(variant, scan.H); + switch (ckpt_interval) { + case 1: + mingru_scan_launch_backward_ckpt<1>( + scan, grad_out, grad_next_state, variant, stream); + return; + case 2: + mingru_scan_launch_backward_ckpt<2>( + scan, grad_out, grad_next_state, variant, stream); + return; + case 4: + mingru_scan_launch_backward_ckpt<4>( + scan, grad_out, grad_next_state, variant, stream); + return; + case 8: + mingru_scan_launch_backward_ckpt<8>( + scan, grad_out, grad_next_state, variant, stream); + return; + case 16: + mingru_scan_launch_backward_ckpt<16>( + scan, grad_out, grad_next_state, variant, stream); + return; + case 32: + mingru_scan_launch_backward_ckpt<32>( + scan, grad_out, grad_next_state, variant, stream); + return; + default: + mingru_scan_launch_backward_ckpt( + scan, grad_out, grad_next_state, variant, stream); + return; + } +} + + __global__ void sum_rows_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int R, int C) { int col = blockIdx.x * blockDim.x + threadIdx.x; @@ -598,16 +1433,16 @@ struct MinGRUWeights { PrecisionTensor* weights; // [num_layers] }; -static PrecisionTensor mingru_state_layer(MinGRUWeights* m, PrecisionTensor& state, int i) { +static PrecisionTensor mingru_state_layer(MinGRUWeights* m, PrecisionTensor& state, int layer_i) { long B = state.shape[1], H = state.shape[2]; - return {.data = state.data + i * B * H, .shape = {B, H}}; + return {.data = state.data + layer_i * B * H, .shape = {B, H}}; } static void mingru_init_weights(void* w, ulong* seed, cudaStream_t stream) { MinGRUWeights* m = (MinGRUWeights*)w; - for (int i = 0; i < m->num_layers; i++) { + for (int layer_i = 0; layer_i < m->num_layers; layer_i++) { PrecisionTensor w2d = { - .data = m->weights[i].data, + .data = m->weights[layer_i].data, .shape = {3 * m->hidden, m->hidden}, }; puf_kaiming_init(&w2d, 1.0f, (*seed)++, stream); @@ -616,9 +1451,9 @@ static void mingru_init_weights(void* w, ulong* seed, cudaStream_t stream) { static void mingru_reg_params(void* w, Allocator* alloc) { MinGRUWeights* m = (MinGRUWeights*)w; - for (int i = 0; i < m->num_layers; i++) { - m->weights[i] = {.shape = {3 * m->hidden, m->hidden}}; - alloc_register(alloc,&m->weights[i]); + for (int layer_i = 0; layer_i < m->num_layers; layer_i++) { + m->weights[layer_i] = {.shape = {3 * m->hidden, m->hidden}}; + alloc_register(alloc,&m->weights[layer_i]); } } @@ -635,8 +1470,8 @@ static void mingru_reg_train(void* w, void* activations, Allocator* acts, Alloca a->grad_next_state = {.shape = {B, 1, H}}; alloc_register(acts,&a->grad_input_buf); alloc_register(acts,&a->grad_next_state); - for (int i = 0; i < m->num_layers; i++) { - a->scan_bufs[i] = { + for (int layer_i = 0; layer_i < m->num_layers; layer_i++) { + a->scan_bufs[layer_i] = { .B = B, .T = TT, .H = H, .a_star = {.shape = {B, TT + 1, H}}, .s_vals = {.shape = {B, TT + 1, H}}, @@ -647,20 +1482,20 @@ static void mingru_reg_train(void* w, void* activations, Allocator* acts, Alloca .grad_state = {.shape = {B, 1, H}}, .grad_input = {.shape = {B, TT, H}}, }; - a->saved_inputs[i] = {.shape = {B, TT, H}}; - a->combined_bufs[i] = {.shape = {B_TT, 3 * H}}; - a->wgrad_scratch[i] = {.shape = {3 * H, H}}; - alloc_register(acts,&a->saved_inputs[i]); - alloc_register(acts,&a->combined_bufs[i]); - alloc_register(acts,&a->scan_bufs[i].out); - alloc_register(acts,&a->scan_bufs[i].next_state); - alloc_register(acts,&a->scan_bufs[i].a_star); - alloc_register(acts,&a->scan_bufs[i].s_vals); - alloc_register(acts,&a->scan_bufs[i].log_values_buf); - alloc_register(acts,&a->scan_bufs[i].grad_combined); - alloc_register(acts,&a->scan_bufs[i].grad_state); - alloc_register(acts,&a->scan_bufs[i].grad_input); - alloc_register(grads,&a->wgrad_scratch[i]); + a->saved_inputs[layer_i] = {.shape = {B, TT, H}}; + a->combined_bufs[layer_i] = {.shape = {B_TT, 3 * H}}; + a->wgrad_scratch[layer_i] = {.shape = {3 * H, H}}; + alloc_register(acts,&a->saved_inputs[layer_i]); + alloc_register(acts,&a->combined_bufs[layer_i]); + alloc_register(acts,&a->scan_bufs[layer_i].out); + alloc_register(acts,&a->scan_bufs[layer_i].next_state); + alloc_register(acts,&a->scan_bufs[layer_i].a_star); + alloc_register(acts,&a->scan_bufs[layer_i].s_vals); + alloc_register(acts,&a->scan_bufs[layer_i].log_values_buf); + alloc_register(acts,&a->scan_bufs[layer_i].grad_combined); + alloc_register(acts,&a->scan_bufs[layer_i].grad_state); + alloc_register(acts,&a->scan_bufs[layer_i].grad_input); + alloc_register(grads,&a->wgrad_scratch[layer_i]); } } @@ -670,9 +1505,9 @@ static void mingru_reg_rollout(void* weights, void* activations, Allocator* allo int H = w->hidden; a->num_layers = w->num_layers; a->combined = (PrecisionTensor*)calloc(w->num_layers, sizeof(PrecisionTensor)); - for (int i = 0; i < w->num_layers; i++) { - a->combined[i] = {.shape = {B_inf, 3 * H}}; - alloc_register(alloc,&a->combined[i]); + for (int layer_i = 0; layer_i < w->num_layers; layer_i++) { + a->combined[layer_i] = {.shape = {B_inf, 3 * H}}; + alloc_register(alloc,&a->combined[layer_i]); } a->out = {.shape = {B_inf, H}}; a->next_state = {.shape = {B_inf, H}}; @@ -683,7 +1518,9 @@ static void mingru_reg_rollout(void* weights, void* activations, Allocator* allo static void* mingru_create_weights(void* self) { Network* n = (Network*)self; MinGRUWeights* mw = (MinGRUWeights*)calloc(1, sizeof(MinGRUWeights)); - mw->hidden = n->hidden; mw->num_layers = n->num_layers; mw->horizon = n->horizon; + mw->hidden = n->hidden; + mw->num_layers = n->num_layers; + mw->horizon = n->horizon; mw->weights = (PrecisionTensor*)calloc(n->num_layers, sizeof(PrecisionTensor)); return mw; } @@ -706,12 +1543,12 @@ static PrecisionTensor mingru_forward(void* w, PrecisionTensor x, PrecisionTenso MinGRUActivations* a = (MinGRUActivations*)activations; int B = state.shape[1]; int H = state.shape[2]; - for (int i = 0; i < m->num_layers; i++) { - PrecisionTensor state_i = mingru_state_layer(m, state, i); - puf_mm(&x, &m->weights[i], &a->combined[i], stream); + for (int layer_i = 0; layer_i < m->num_layers; layer_i++) { + PrecisionTensor state_i = mingru_state_layer(m, state, layer_i); + puf_mm(&x, &m->weights[layer_i], &a->combined[layer_i], stream); mingru_gate<<>>( a->out.data, a->next_state.data, - a->combined[i].data, state_i.data, x.data, H, B); + a->combined[layer_i].data, state_i.data, x.data, H, B); puf_copy(&state_i, &a->next_state, stream); x = a->out; } @@ -722,16 +1559,18 @@ static PrecisionTensor mingru_forward_train(void* w, PrecisionTensor x, Precisio void* activations, cudaStream_t stream) { MinGRUWeights* m = (MinGRUWeights*)w; MinGRUActivations* a = (MinGRUActivations*)activations; - int B = x.shape[0]; - for (int i = 0; i < m->num_layers; i++) { - puf_copy(&a->saved_inputs[i], &x, stream); - PrecisionTensor state_i = mingru_state_layer(m, state, i); - puf_mm(&x, &m->weights[i], &a->combined_bufs[i], stream); - a->scan_bufs[i].combined_ptr = a->combined_bufs[i].data; - a->scan_bufs[i].state_ptr = state_i.data; - a->scan_bufs[i].input_ptr = a->saved_inputs[i].data; - mingru_scan_forward<<hidden), BLOCK_SIZE, 0, stream>>>(a->scan_bufs[i]); - x = a->scan_bufs[i].out; + MingruScanKernelSelection scan_selection = mingru_scan_select_depth2_policy( + state.shape[1], m->horizon, m->hidden); + for (int layer_i = 0; layer_i < m->num_layers; layer_i++) { + puf_copy(&a->saved_inputs[layer_i], &x, stream); + PrecisionTensor state_i = mingru_state_layer(m, state, layer_i); + puf_mm(&x, &m->weights[layer_i], &a->combined_bufs[layer_i], stream); + a->scan_bufs[layer_i].combined_ptr = a->combined_bufs[layer_i].data; + a->scan_bufs[layer_i].state_ptr = state_i.data; + a->scan_bufs[layer_i].input_ptr = a->saved_inputs[layer_i].data; + mingru_scan_launch_forward_selected( + a->scan_bufs[layer_i], scan_selection.ckpt_interval, scan_selection.fwd_variant, stream); + x = a->scan_bufs[layer_i].out; } return x; } @@ -739,12 +1578,15 @@ static PrecisionTensor mingru_forward_train(void* w, PrecisionTensor x, Precisio static PrecisionTensor mingru_backward(void* w, PrecisionTensor grad, void* activations, cudaStream_t stream) { MinGRUWeights* m = (MinGRUWeights*)w; MinGRUActivations* a = (MinGRUActivations*)activations; - for (int i = m->num_layers - 1; i >= 0; i--) { - PrefixScan& scan = a->scan_bufs[i]; - mingru_scan_backward<<>>( - scan, grad.data, a->grad_next_state.data); - puf_mm_tn(&scan.grad_combined, &a->saved_inputs[i], &a->wgrad_scratch[i], stream); - puf_mm_nn(&scan.grad_combined, &m->weights[i], &a->grad_input_buf, stream); + MingruScanKernelSelection scan_selection = mingru_scan_select_depth2_policy( + a->scan_bufs[0].B, a->scan_bufs[0].T, a->scan_bufs[0].H); + for (int layer_i = m->num_layers - 1; layer_i >= 0; layer_i--) { + PrefixScan& scan = a->scan_bufs[layer_i]; + mingru_scan_launch_backward_selected( + scan, grad.data, a->grad_next_state.data, + scan_selection.ckpt_interval, scan_selection.bwd_variant, stream); + puf_mm_tn(&scan.grad_combined, &a->saved_inputs[layer_i], &a->wgrad_scratch[layer_i], stream); + puf_mm_nn(&scan.grad_combined, &m->weights[layer_i], &a->grad_input_buf, stream); int n = numel(scan.grad_input.shape); add_kernel<<>>( a->grad_input_buf.data, scan.grad_input.data, n); diff --git a/tests/profile_kernels.cu b/tests/profile_kernels.cu index 880eeab2ee..532ab3b99d 100644 --- a/tests/profile_kernels.cu +++ b/tests/profile_kernels.cu @@ -4,12 +4,15 @@ #include #include #include +#include +#include #include "pufferlib.cu" #include "ini.h" const int WARMUP_ITERS = 100; const int TIMING_ITERS = 1000; +constexpr int kFusedscanBlockSweepBlockSize = 256; const int BUF = 2; const int BR = 4096; // Rollout batch (no T dim) @@ -17,7 +20,6 @@ const int BT = 512; // Train batch (with T dim) const int T_ = 64; // T_ to avoid collision with PrefixScan::T const int H_ = 128; const int A_ = 4; -const int INPUT_SIZE = 96; #ifndef ENV_NAME #error "ENV_NAME must be defined at compile time (e.g. -DENV_NAME=breakout)" @@ -26,22 +28,30 @@ const int INPUT_SIZE = 96; #define TOSTRING(x) STRINGIFY(x) typedef void (*kernel_fn)(void*); +void print_selected_sweep_sizes(); void print_usage(const char* prog) { printf("Usage: %s \n", prog); printf("\nProfiles:\n"); - printf(" kernels - All individual kernel microbenchmarks\n"); - printf(" mingrugate - MinGRU gate kernel only\n"); - printf(" logcoeffsvals - log_coeffs_and_values fwd+bwd\n"); - printf(" fusedscan - Fused scan (checkpointed) kernel only\n"); - printf(" samplelogits - Sample logits kernel only\n"); - printf(" ppoloss - PPO loss fused fwd+bwd kernel\n"); - printf(" im2col - im2col + col2im (nmmo3 conv sizes, B=1024)\n"); - printf(" envspeed - Environment step throughput\n"); - printf(" --buffers N - Number of buffers (default: %d)\n", BUF); - printf(" --threads N - Number of threads (default: 16)\n"); - printf(" --horizon N - Horizon length (default: %d)\n", T_); - printf(" all - Run all available profiles\n"); + printf(" kernels - All individual kernel microbenchmarks\n"); + printf(" mingrugate - MinGRU gate kernel only\n"); + printf(" logcoeffsvals - log_coeffs_and_values fwd+bwd\n"); + printf(" fusedscan - Fused scan (checkpointed) kernel only\n"); + printf(" fusedscan_correctness - Correctness-only check: log vec variants vs log_scalar\n"); + printf(" fusedscan_sweep - Fixed-block (256) sweep for log scalar/vec32/vec64/vec128 kernels\n"); + printf(" fusedscan_selector_bench - Baseline selector vs depth2 selector benchmark\n"); + printf(" samplelogits - Sample logits kernel only\n"); + printf(" ppoloss - PPO loss fused fwd+bwd kernel\n"); + printf(" im2col - im2col + col2im (nmmo3 conv sizes, B=1024)\n"); + printf(" envspeed - Environment step throughput\n"); + printf(" --ckpt-intervals CSV - Requested checkpoint intervals for fusedscan sweeps (e.g. 1,4,8)\n"); + printf(" --b-sizes CSV - Sweep B sizes for fusedscan *_sweep profiles (e.g. 64,128,256)\n"); + printf(" --t-sizes CSV - Sweep T sizes for fusedscan *_sweep profiles (e.g. 64,128,256)\n"); + printf(" --h-sizes CSV - Sweep H sizes for fusedscan *_sweep profiles (e.g. 128,256,512)\n"); + printf(" --buffers N - Number of buffers (default: %d)\n", BUF); + printf(" --threads N - Number of threads (default: 16)\n"); + printf(" --horizon N - Horizon length (default: %d)\n", T_); + printf(" all - Run all available profiles\n"); } inline void print_timing(const char* name, float ms, int N) { @@ -67,6 +77,38 @@ inline void float_to_device(precision_t* dst, const float* src, int count) { free(tmp); } +__device__ __forceinline__ uint32_t scan_hash_u32(uint32_t x) { + x ^= x >> 16; + x *= 0x7feb352dU; + x ^= x >> 15; + x *= 0x846ca68bU; + x ^= x >> 16; + return x; +} + +__global__ void fill_precision_pseudorand_signed_kernel( + precision_t* __restrict__ dst, int n, uint32_t seed, float scale) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + uint32_t h = scan_hash_u32((uint32_t)idx ^ seed); + float u = (float)(h & 0x00ffffffU) * (1.0f / 16777215.0f); // [0, 1] + float centered = 2.0f * u - 1.0f; // [-1, 1] + dst[idx] = from_float(centered * scale); +} + +__global__ void fill_precision_pseudorand_positive_kernel( + precision_t* __restrict__ dst, int n, uint32_t seed, float base, float span) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n) { + return; + } + uint32_t h = scan_hash_u32((uint32_t)idx ^ seed); + float u = (float)(h & 0x00ffffffU) * (1.0f / 16777215.0f); // [0, 1] + dst[idx] = from_float(base + span * u); +} + inline float profile_kernel(kernel_fn fn, void* args) { for (int i = 0; i < WARMUP_ITERS; ++i) fn(args); cudaDeviceSynchronize(); @@ -140,25 +182,6 @@ void profile_mingrugate(int B, int H) { free(p); } -__global__ void log_coeffs_and_values_fwd_kernel( - float* log_coeff_out, float* log_value_out, - const float* gate, const float* hidden, int N) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) return; - log_coeffs_and_values_fwd(gate[idx], hidden[idx], - &log_coeff_out[idx], &log_value_out[idx]); -} - -__global__ void log_coeffs_and_values_bwd_kernel( - float* grad_gate_out, float* grad_hidden_out, - const float* grad_log_coeffs, const float* grad_log_values, - const float* gate, const float* hidden, int N) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) return; - log_coeffs_and_values_bwd(grad_log_coeffs[idx], grad_log_values[idx], - gate[idx], hidden[idx], &grad_gate_out[idx], &grad_hidden_out[idx]); -} - struct LogCoeffsProfile { FloatTensor gate, hidden, log_coeff, log_value; FloatTensor grad_log_coeffs, grad_log_values, grad_gate, grad_hidden; @@ -166,8 +189,12 @@ struct LogCoeffsProfile { int N; }; -LogCoeffsProfile* create_logcoeffs(int N) { +LogCoeffsProfile* create_logcoeffs(int N, const char* tag = "logcoeffs") { auto* p = (LogCoeffsProfile*)calloc(1, sizeof(LogCoeffsProfile)); + if (!p) { + printf("%s N=%d SKIP: host allocation failed\n", tag, N); + return nullptr; + } p->N = N; p->gate = {.shape = {N}}; p->hidden = {.shape = {N}}; @@ -186,9 +213,24 @@ LogCoeffsProfile* create_logcoeffs(int N) { alloc_register(&p->alloc, &p->grad_log_values); alloc_register(&p->alloc, &p->grad_gate); alloc_register(&p->alloc, &p->grad_hidden); - alloc_create(&p->alloc); + cudaError_t alloc_err = alloc_create(&p->alloc); + if (alloc_err != cudaSuccess) { + double gib = (double)p->alloc.total_bytes / (1024.0 * 1024.0 * 1024.0); + printf("%s N=%d SKIP: alloc_create %.2f GiB failed (%s)\n", + tag, N, gib, cudaGetErrorString(alloc_err)); + cudaGetLastError(); // clear sticky runtime error so subsequent cases can proceed + alloc_free(&p->alloc); + free(p); + return nullptr; + } float* buf = (float*)malloc(N * sizeof(float)); + if (!buf) { + printf("%s N=%d SKIP: host input buffer allocation failed\n", tag, N); + alloc_free(&p->alloc); + free(p); + return nullptr; + } for (int i = 0; i < N; ++i) buf[i] = rand1() * 5.0f; cudaMemcpy(p->gate.data, buf, N * sizeof(float), cudaMemcpyHostToDevice); for (int i = 0; i < N; ++i) buf[i] = rand1() * 5.0f; @@ -201,6 +243,25 @@ LogCoeffsProfile* create_logcoeffs(int N) { return p; } +__global__ void log_coeffs_and_values_fwd_kernel( + float* log_coeff_out, float* log_value_out, + const float* gate, const float* hidden, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + log_coeffs_and_values_fwd(gate[idx], hidden[idx], + &log_coeff_out[idx], &log_value_out[idx]); +} + +__global__ void log_coeffs_and_values_bwd_kernel( + float* grad_gate_out, float* grad_hidden_out, + const float* grad_log_coeffs, const float* grad_log_values, + const float* gate, const float* hidden, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + log_coeffs_and_values_bwd(grad_log_coeffs[idx], grad_log_values[idx], + gate[idx], hidden[idx], &grad_gate_out[idx], &grad_hidden_out[idx]); +} + void run_logcoeffs_fwd(LogCoeffsProfile* p) { log_coeffs_and_values_fwd_kernel<<N), BLOCK_SIZE>>>( p->log_coeff.data, p->log_value.data, @@ -217,7 +278,11 @@ void run_logcoeffs_bwd(LogCoeffsProfile* p) { void profile_logcoeffs(int B, int T, int H) { int N = B * T * H; printf("log_coeffs_and_values (N=%d, %dx%dx%d)\n", N, B, T, H); - auto* p = create_logcoeffs(N); + auto* p = create_logcoeffs(N, "logcoeffsvals"); + if (!p) { + printf("\n"); + return; + } float fwd = profile_kernel((kernel_fn)run_logcoeffs_fwd, p); print_timing("forward", fwd, N); float bwd = profile_kernel((kernel_fn)run_logcoeffs_bwd, p); @@ -234,8 +299,25 @@ struct FusedScanProfile { int B, T, H; }; -FusedScanProfile* create_fusedscan(int B, int T, int H) { +static const int kSweepBValsDefault[] = {64, 128, 256, 512}; +static const int kSweepTValsDefault[] = {64, 128, 256, 512}; +static const int kSweepHValsDefault[] = {128, 256, 512, 1024}; +static constexpr int kSweepDefaultBCount = sizeof(kSweepBValsDefault) / sizeof(kSweepBValsDefault[0]); +static constexpr int kSweepDefaultTCount = sizeof(kSweepTValsDefault) / sizeof(kSweepTValsDefault[0]); +static constexpr int kSweepDefaultHCount = sizeof(kSweepHValsDefault) / sizeof(kSweepHValsDefault[0]); +static constexpr int kMaxSweepSizes = 64; +static int gSelectedBVals[kMaxSweepSizes] = {64, 128, 256, 512}; +static int gSelectedTVals[kMaxSweepSizes] = {64, 128, 256, 512}; +static int gSelectedHVals[kMaxSweepSizes] = {128, 256, 512, 1024}; +static int gSelectedBCount = kSweepDefaultBCount; +static int gSelectedTCount = kSweepDefaultTCount; +static int gSelectedHCount = kSweepDefaultHCount; +FusedScanProfile* create_fusedscan(int B, int T, int H, const char* tag = "fused_scan") { auto* p = (FusedScanProfile*)calloc(1, sizeof(FusedScanProfile)); + if (!p) { + printf("%s %dx%dx%d SKIP: host allocation failed\n", tag, B, T, H); + return nullptr; + } p->B = B; p->T = T; p->H = H; PrefixScan& s = p->scan; @@ -273,26 +355,104 @@ FusedScanProfile* create_fusedscan(int B, int T, int H) { alloc_register(&p->alloc, &s.grad_input); alloc_register(&p->alloc, &p->grad_out); alloc_register(&p->alloc, &p->grad_next_state); - alloc_create(&p->alloc); + cudaError_t alloc_err = alloc_create(&p->alloc); + if (alloc_err != cudaSuccess) { + double gib = (double)p->alloc.total_bytes / (1024.0 * 1024.0 * 1024.0); + printf("%s %dx%dx%d SKIP: alloc_create %.2f GiB failed (%s)\n", + tag, B, T, H, gib, cudaGetErrorString(alloc_err)); + cudaGetLastError(); // clear sticky runtime error so subsequent cases can proceed + alloc_free(&p->alloc); + free(p); + return nullptr; + } s.combined_ptr = combined_t.data; s.state_ptr = state_t.data; s.input_ptr = input_t.data; + // Use non-trivial deterministic values for correctness checks. + // Keep state strictly positive to avoid log(0) in log-space kernels. int N_combined = B * T * 3 * H; int N_state = B * H; - int N_out = B * T * H; - float* buf = (float*)malloc(N_combined * sizeof(float)); - for (int i = 0; i < N_combined; ++i) buf[i] = rand1() * 5.0f; - float_to_device(s.combined_ptr, buf, N_combined); - for (int i = 0; i < N_state; ++i) buf[i] = fabsf(rand1()) + 0.1f; - float_to_device(s.state_ptr, buf, N_state); - for (int i = 0; i < N_out; ++i) buf[i] = rand1(); - float_to_device(s.input_ptr, buf, N_out); - float_to_device(p->grad_out.data, buf, N_out); - for (int i = 0; i < N_state; ++i) buf[i] = rand1(); - float_to_device(p->grad_next_state.data, buf, N_state); - free(buf); + int N_input = B * T * H; + int N_grad_out = N_input; + int N_grad_next_state = N_state; + uint32_t seed_base = 0x9e3779b9U + ^ ((uint32_t)B * 73856093U) + ^ ((uint32_t)T * 19349663U) + ^ ((uint32_t)H * 83492791U); + + fill_precision_pseudorand_signed_kernel<<>>( + s.combined_ptr, N_combined, seed_base ^ 0x11111111U, 5.0f); + cudaError_t init_err = cudaGetLastError(); + if (init_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: combined init launch failed (%s)\n", + tag, B, T, H, cudaGetErrorString(init_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + + fill_precision_pseudorand_positive_kernel<<>>( + s.state_ptr, N_state, seed_base ^ 0x22222222U, 0.25f, 1.25f); + init_err = cudaGetLastError(); + if (init_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: state init launch failed (%s)\n", + tag, B, T, H, cudaGetErrorString(init_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + + fill_precision_pseudorand_signed_kernel<<>>( + s.input_ptr, N_input, seed_base ^ 0x33333333U, 1.0f); + init_err = cudaGetLastError(); + if (init_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: input init launch failed (%s)\n", + tag, B, T, H, cudaGetErrorString(init_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + + fill_precision_pseudorand_signed_kernel<<>>( + p->grad_out.data, N_grad_out, seed_base ^ 0x44444444U, 1.0f); + init_err = cudaGetLastError(); + if (init_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: grad_out init launch failed (%s)\n", + tag, B, T, H, cudaGetErrorString(init_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + + fill_precision_pseudorand_signed_kernel<<>>( + p->grad_next_state.data, N_grad_next_state, seed_base ^ 0x55555555U, 1.0f); + init_err = cudaGetLastError(); + if (init_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: grad_next_state init launch failed (%s)\n", + tag, B, T, H, cudaGetErrorString(init_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + + // Ensure all init kernels are complete before first use. + cudaError_t sync_err = cudaDeviceSynchronize(); + if (sync_err != cudaSuccess) { + printf("%s %dx%dx%d SKIP: init sync failed (%s)\n", + tag, B, T, H, cudaGetErrorString(sync_err)); + cudaGetLastError(); + alloc_free(&p->alloc); + free(p); + return nullptr; + } + return p; } @@ -305,9 +465,255 @@ void run_fusedscan_bwd(FusedScanProfile* p) { p->scan, p->grad_out.data, p->grad_next_state.data); } +template +void run_fusedscan_fwd_ckpt_tuned(FusedScanProfile* p); +template +void run_fusedscan_bwd_ckpt_tuned(FusedScanProfile* p); + +template +void run_fusedscan_fwd_vec32_ckpt_tuned(FusedScanProfile* p) { +#ifdef PRECISION_FLOAT + run_fusedscan_fwd_ckpt_tuned(p); +#else + if ((p->H & 1) == 0) { + int H2 = p->H >> 1; + mingru_scan_forward_ckpt_tuned_vec32<<B * H2), BLOCK_SIZE>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned(p); + } +#endif +} + +template +void run_fusedscan_bwd_vec32_ckpt_tuned(FusedScanProfile* p) { +#ifdef PRECISION_FLOAT + run_fusedscan_bwd_ckpt_tuned(p); +#else + if ((p->H & 1) == 0) { + int H2 = p->H >> 1; + mingru_scan_backward_ckpt_tuned_vec32<<B * H2), BLOCK_SIZE>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned(p); + } +#endif +} + +template +void run_fusedscan_fwd_vec64_ckpt_tuned(FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC64_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC64_WIDTH; + mingru_scan_forward_ckpt_tuned_vec64<<B * HW), BLOCK_SIZE>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned(p); + } +} + +template +void run_fusedscan_bwd_vec64_ckpt_tuned(FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC64_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC64_WIDTH; + mingru_scan_backward_ckpt_tuned_vec64<<B * HW), BLOCK_SIZE>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned(p); + } +} + +template +void run_fusedscan_fwd_vec128_ckpt_tuned(FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC128_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC128_WIDTH; + mingru_scan_forward_ckpt_tuned_vec128<<B * HW), BLOCK_SIZE>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned(p); + } +} + +template +void run_fusedscan_bwd_vec128_ckpt_tuned(FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC128_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC128_WIDTH; + mingru_scan_backward_ckpt_tuned_vec128<<B * HW), BLOCK_SIZE>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned(p); + } +} + +void run_fusedscan_fwd_vec32_ckpt4(FusedScanProfile* p) { + run_fusedscan_fwd_vec32_ckpt_tuned<4>(p); +} + +void run_fusedscan_bwd_vec32_ckpt4(FusedScanProfile* p) { + run_fusedscan_bwd_vec32_ckpt_tuned<4>(p); +} + +void run_fusedscan_fwd_vec64_ckpt4(FusedScanProfile* p) { + run_fusedscan_fwd_vec64_ckpt_tuned<4>(p); +} + +void run_fusedscan_bwd_vec64_ckpt4(FusedScanProfile* p) { + run_fusedscan_bwd_vec64_ckpt_tuned<4>(p); +} + +void run_fusedscan_fwd_vec128_ckpt4(FusedScanProfile* p) { + run_fusedscan_fwd_vec128_ckpt_tuned<4>(p); +} + +void run_fusedscan_bwd_vec128_ckpt4(FusedScanProfile* p) { + run_fusedscan_bwd_vec128_ckpt_tuned<4>(p); +} + +template +void run_fusedscan_fwd_ckpt_tuned(FusedScanProfile* p) { + mingru_scan_forward_ckpt_tuned<<B * p->H), BLOCK_SIZE>>>(p->scan); +} + +template +void run_fusedscan_bwd_ckpt_tuned(FusedScanProfile* p) { + mingru_scan_backward_ckpt_tuned<<B * p->H), BLOCK_SIZE>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); +} + +template +void print_kernel_attrs(const char* name, KernelT kernel) { + cudaFuncAttributes attrs = {}; + cudaError_t err = cudaFuncGetAttributes(&attrs, kernel); + if (err != cudaSuccess) { + printf(" %-34s unavailable (%s)\n", name, cudaGetErrorString(err)); + cudaGetLastError(); // clear sticky error so profiling can continue + return; + } + printf(" %-34s regs=%3d local=%5zuB smem=%5zuB maxT=%4d bin=%2d ptx=%2d\n", + name, + attrs.numRegs, + (size_t)attrs.localSizeBytes, + (size_t)attrs.sharedSizeBytes, + attrs.maxThreadsPerBlock, + attrs.binaryVersion, + attrs.ptxVersion); +} + +void print_fusedscan_kernel_diagnostics_once() { + static bool printed = false; + if (printed) { + return; + } + printed = true; + + int dev = 0; + cudaGetDevice(&dev); + cudaDeviceProp prop = {}; + cudaError_t prop_err = cudaGetDeviceProperties(&prop, dev); + + printf("fused_scan diagnostics\n"); + if (prop_err == cudaSuccess) { + printf(" device: %s (sm_%d%d)\n", prop.name, prop.major, prop.minor); + } else { + printf(" device: (%s)\n", cudaGetErrorString(prop_err)); + cudaGetLastError(); + } +#ifdef PRECISION_FLOAT + printf(" precision: float32 (PRECISION_FLOAT)\n"); +#else + printf(" precision: bfloat16 (default)\n"); +#endif + printf(" kernel attributes:\n"); + print_kernel_attrs("log_ckpt4_fwd", mingru_scan_forward); + print_kernel_attrs("log_ckpt4_bwd", mingru_scan_backward); +#ifdef PRECISION_FLOAT + print_kernel_attrs("log_vec32_ckpt4_fwd", mingru_scan_forward_ckpt_tuned<4>); + print_kernel_attrs("log_vec32_ckpt4_bwd", mingru_scan_backward_ckpt_tuned<4>); +#else + print_kernel_attrs("log_vec32_ckpt4_fwd", mingru_scan_forward_ckpt_tuned_vec32<4>); + print_kernel_attrs("log_vec32_ckpt4_bwd", mingru_scan_backward_ckpt_tuned_vec32<4>); +#endif + print_kernel_attrs("log_vec64_ckpt4_fwd", mingru_scan_forward_ckpt_tuned_vec64<4>); + print_kernel_attrs("log_vec64_ckpt4_bwd", mingru_scan_backward_ckpt_tuned_vec64<4>); + print_kernel_attrs("log_vec128_ckpt4_fwd", mingru_scan_forward_ckpt_tuned_vec128<4>); + print_kernel_attrs("log_vec128_ckpt4_bwd", mingru_scan_backward_ckpt_tuned_vec128<4>); + printf("\n"); +} + +void copy_fusedscan_inputs(FusedScanProfile* dst, FusedScanProfile* src) { + int N_combined = src->B * src->T * 3 * src->H; + int N_state = src->B * src->H; + int N_out = src->B * src->T * src->H; + cudaMemcpy(dst->scan.combined_ptr, src->scan.combined_ptr, + N_combined * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(dst->scan.state_ptr, src->scan.state_ptr, + N_state * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(dst->scan.input_ptr, src->scan.input_ptr, + N_out * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(dst->grad_out.data, src->grad_out.data, + N_out * sizeof(precision_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(dst->grad_next_state.data, src->grad_next_state.data, + N_state * sizeof(precision_t), cudaMemcpyDeviceToDevice); +} + +struct CompareStats { + float max_abs = 0.0f; + float max_rel = 0.0f; + float mean_abs = 0.0f; + float mean_rel = 0.0f; + long long mismatch_count = 0; + long long total_count = 0; +}; + +CompareStats compare_precision_tensors( + const PrecisionTensor& a, const PrecisionTensor& b, + float abs_tol, float rel_tol) { + long n = numel(a.shape); + precision_t* ah = (precision_t*)malloc(n * sizeof(precision_t)); + precision_t* bh = (precision_t*)malloc(n * sizeof(precision_t)); + cudaMemcpy(ah, a.data, n * sizeof(precision_t), cudaMemcpyDeviceToHost); + cudaMemcpy(bh, b.data, n * sizeof(precision_t), cudaMemcpyDeviceToHost); + + CompareStats stats; + stats.total_count = (long long)n; + double sum_abs = 0.0; + double sum_rel = 0.0; + for (long i = 0; i < n; i++) { + float av = to_float(ah[i]); + float bv = to_float(bh[i]); + float abs_err = fabsf(av - bv); + float denom = fmaxf(fabsf(av), fabsf(bv)); + float rel_err = abs_err / fmaxf(denom, 1e-6f); + sum_abs += (double)abs_err; + sum_rel += (double)rel_err; + stats.max_abs = fmaxf(stats.max_abs, abs_err); + stats.max_rel = fmaxf(stats.max_rel, rel_err); + if (abs_err > abs_tol && rel_err > rel_tol) { + stats.mismatch_count++; + } + } + if (n > 0) { + stats.mean_abs = (float)(sum_abs / (double)n); + stats.mean_rel = (float)(sum_rel / (double)n); + } + + free(ah); + free(bh); + return stats; +} + +bool print_compare_result(const char* name, CompareStats stats, float abs_tol, float rel_tol) { + bool ok = stats.max_abs <= abs_tol || stats.max_rel <= rel_tol; + printf(" %-16s %s max_abs=%9.4g max_rel=%9.4g avg_abs=%9.4g avg_rel=%9.4g mismatches=%lld/%lld\n", + name, ok ? "PASS" : "FAIL", + stats.max_abs, stats.max_rel, stats.mean_abs, stats.mean_rel, + stats.mismatch_count, stats.total_count); + return ok; +} + void profile_fusedscan(int B, int T, int H) { printf("fused_scan (N=%d, %dx%dx%d)\n", B*T*H, B, T, H); - auto* p = create_fusedscan(B, T, H); + auto* p = create_fusedscan(B, T, H, "fused_scan"); + if (!p) { + printf("\n"); + return; + } float fwd = profile_kernel((kernel_fn)run_fusedscan_fwd, p); print_timing("forward", fwd, B*T); float bwd = profile_kernel((kernel_fn)run_fusedscan_bwd, p); @@ -317,6 +723,1083 @@ void profile_fusedscan(int B, int T, int H) { free(p); } +typedef void (*fusedscan_run_fn)(FusedScanProfile*); + +struct FusedScanVariant { + const char* name; + fusedscan_run_fn fwd; + fusedscan_run_fn bwd; +}; + +static const FusedScanVariant kLogBaseVariant = { + "base_log_ckpt4", + run_fusedscan_fwd, + run_fusedscan_bwd, +}; + +#ifdef PRECISION_FLOAT +static const FusedScanVariant kLogVectorVariants[] = { + {"log_vec32", run_fusedscan_fwd_vec32_ckpt4, run_fusedscan_bwd_vec32_ckpt4}, + {"log_vec64", run_fusedscan_fwd_vec64_ckpt4, run_fusedscan_bwd_vec64_ckpt4}, + {"log_vec128", run_fusedscan_fwd_vec128_ckpt4, run_fusedscan_bwd_vec128_ckpt4}, +}; +#else +static const FusedScanVariant kLogVectorVariants[] = { + {"log_vec32", run_fusedscan_fwd_vec32_ckpt4, run_fusedscan_bwd_vec32_ckpt4}, + {"log_vec64", run_fusedscan_fwd_vec64_ckpt4, run_fusedscan_bwd_vec64_ckpt4}, + {"log_vec128", run_fusedscan_fwd_vec128_ckpt4, run_fusedscan_bwd_vec128_ckpt4}, +}; +#endif +static const int kLogVectorVariantCount = + sizeof(kLogVectorVariants) / sizeof(kLogVectorVariants[0]); + +bool check_variant_correctness_case(int B, int T, int H, const FusedScanVariant& variant) { + printf(" correctness-log %-20s %dx%dx%d\n", variant.name, B, T, H); + auto* ref = create_fusedscan(B, T, H, "fused_scan_log_base_ref"); + if (!ref) { + return true; + } + auto* test = create_fusedscan(B, T, H, "fused_scan_log_variant"); + if (!test) { + alloc_free(&ref->alloc); + free(ref); + return true; + } + + copy_fusedscan_inputs(test, ref); + kLogBaseVariant.fwd(ref); + variant.fwd(test); + cudaDeviceSynchronize(); + + kLogBaseVariant.bwd(ref); + variant.bwd(test); + cudaDeviceSynchronize(); + +#ifdef PRECISION_FLOAT + float abs_tol = 2e-4f, rel_tol = 2e-4f; +#else + float abs_tol = 5e-2f, rel_tol = 5e-2f; +#endif + bool ok = true; + ok &= print_compare_result("out", + compare_precision_tensors(ref->scan.out, test->scan.out, abs_tol, rel_tol), abs_tol, rel_tol); + ok &= print_compare_result("next_state", + compare_precision_tensors(ref->scan.next_state, test->scan.next_state, abs_tol, rel_tol), abs_tol, rel_tol); + ok &= print_compare_result("grad_combined", + compare_precision_tensors(ref->scan.grad_combined, test->scan.grad_combined, abs_tol, rel_tol), abs_tol, rel_tol); + ok &= print_compare_result("grad_input", + compare_precision_tensors(ref->scan.grad_input, test->scan.grad_input, abs_tol, rel_tol), abs_tol, rel_tol); + ok &= print_compare_result("grad_state", + compare_precision_tensors(ref->scan.grad_state, test->scan.grad_state, abs_tol, rel_tol), abs_tol, rel_tol); + + alloc_free(&ref->alloc); + alloc_free(&test->alloc); + free(ref); + free(test); + return ok; +} + +bool profile_fusedscan_correctness_suite() { + printf("fused_scan_correctness suite (log variants vs log_scalar, correctness-only)\n"); + print_fusedscan_kernel_diagnostics_once(); + print_selected_sweep_sizes(); + bool ok = true; + int nb = gSelectedBCount; + int nt = gSelectedTCount; + int nh = gSelectedHCount; + for (int ib = 0; ib < nb; ib++) { + for (int it = 0; it < nt; it++) { + for (int ih = 0; ih < nh; ih++) { + int B = gSelectedBVals[ib]; + int T = gSelectedTVals[it]; + int H = gSelectedHVals[ih]; + printf(" case %dx%dx%d\n", B, T, H); + for (int variant_i = 0; variant_i < kLogVectorVariantCount; variant_i++) { + ok &= check_variant_correctness_case(B, T, H, kLogVectorVariants[variant_i]); + } + } + } + } + printf("fused_scan_correctness suite: %s\n\n", ok ? "PASS" : "FAIL"); + return ok; +} + +static const int kCheckpointIntervalsAll[] = {1, 2, 4, 8, 16, 32}; + +static constexpr int kCheckpointIntervalOptions = + sizeof(kCheckpointIntervalsAll) / sizeof(kCheckpointIntervalsAll[0]); +static int gSelectedCheckpointVals[kMaxSweepSizes] = {1, 2, 4, 8, 16, 32}; +static int gSelectedCheckpointCount = kCheckpointIntervalOptions; + +inline void copy_selected_values(int* dst_vals, const int* src_vals, int count) { + for (int idx = 0; idx < count; idx++) { + dst_vals[idx] = src_vals[idx]; + } +} + +inline void print_selected_values(const char* label, const int* values, int count) { + printf(" %s:", label); + for (int idx = 0; idx < count; idx++) { + printf(" %d", values[idx]); + } + printf("\n"); +} + +void reset_b_size_selection() { + gSelectedBCount = kSweepDefaultBCount; + copy_selected_values(gSelectedBVals, kSweepBValsDefault, kSweepDefaultBCount); +} + +void reset_t_size_selection() { + gSelectedTCount = kSweepDefaultTCount; + copy_selected_values(gSelectedTVals, kSweepTValsDefault, kSweepDefaultTCount); +} + +void reset_h_size_selection() { + gSelectedHCount = kSweepDefaultHCount; + copy_selected_values(gSelectedHVals, kSweepHValsDefault, kSweepDefaultHCount); +} + +void reset_sweep_size_selection() { + reset_b_size_selection(); + reset_t_size_selection(); + reset_h_size_selection(); +} + +void reset_checkpoint_interval_selection() { + gSelectedCheckpointCount = kCheckpointIntervalOptions; + copy_selected_values(gSelectedCheckpointVals, kCheckpointIntervalsAll, kCheckpointIntervalOptions); +} + +void print_selected_checkpoint_intervals() { + print_selected_values("checkpoint intervals", gSelectedCheckpointVals, gSelectedCheckpointCount); +} + +void print_selected_sweep_sizes() { + print_selected_values("B sizes", gSelectedBVals, gSelectedBCount); + print_selected_values("T sizes", gSelectedTVals, gSelectedTCount); + print_selected_values("H sizes", gSelectedHVals, gSelectedHCount); +} + +#if defined(__GNUC__) && !defined(__clang__) +__attribute__((optimize("O0"))) +#endif +int parse_positive_int_csv(const char* csv, int* out_vals, int out_cap, const char* what) { + if (!csv || csv[0] == '\0') { + return 0; + } + + int out_count = 0; + const char* p = csv; + while (*p != '\0') { + while (*p == ' ' || *p == '\t' || *p == ',') { + p++; + } + if (*p == '\0') { + break; + } + + char* parse_end = nullptr; + long val = strtol(p, &parse_end, 10); + if (parse_end == p) { + printf("warning: invalid %s token near '%.16s' (ignored)\n", what, p); + while (*p != '\0' && *p != ',') { + p++; + } + continue; + } + + const char* tail = parse_end; + while (*tail == ' ' || *tail == '\t') { + tail++; + } + if (*tail != '\0' && *tail != ',') { + printf("warning: invalid %s token near '%.16s' (ignored)\n", what, p); + p = tail; + while (*p != '\0' && *p != ',') { + p++; + } + continue; + } + + if (val <= 0) { + printf("warning: invalid %s value %ld (ignored)\n", what, val); + } else { + bool exists = false; + for (int out_i = 0; out_i < out_count; out_i++) { + if (out_vals[out_i] == (int)val) { + exists = true; + break; + } + } + if (!exists) { + if (out_count < out_cap) { + out_vals[out_count++] = (int)val; + } else { + printf("warning: too many %s values; max=%d (remaining ignored)\n", what, out_cap); + break; + } + } + } + + p = tail; + if (*p == ',') { + p++; + } + } + + return out_count; +} + +typedef void (*reset_selection_fn)(); + +bool set_selection_from_csv(const char* csv, int* selected_vals, int* selected_count, + const char* what, reset_selection_fn reset_selection) { + if (!csv || csv[0] == '\0') { + reset_selection(); + return true; + } + + int parsed[kMaxSweepSizes]; + int parsed_count = parse_positive_int_csv(csv, parsed, kMaxSweepSizes, what); + if (parsed_count == 0) { + printf("warning: no valid %s parsed from '%s'; using defaults\n", what, csv); + reset_selection(); + return false; + } + + *selected_count = parsed_count; + copy_selected_values(selected_vals, parsed, parsed_count); + return true; +} + +#if defined(__GNUC__) && !defined(__clang__) +__attribute__((optimize("O0"))) +#endif +bool set_checkpoint_intervals_from_csv(const char* csv) { + return set_selection_from_csv(csv, gSelectedCheckpointVals, &gSelectedCheckpointCount, + "checkpoint intervals", reset_checkpoint_interval_selection); +} + +#if defined(__GNUC__) && !defined(__clang__) +__attribute__((optimize("O0"))) +#endif +bool set_b_sizes_from_csv(const char* csv) { + return set_selection_from_csv(csv, gSelectedBVals, &gSelectedBCount, + "B sizes", reset_b_size_selection); +} + +#if defined(__GNUC__) && !defined(__clang__) +__attribute__((optimize("O0"))) +#endif +bool set_t_sizes_from_csv(const char* csv) { + return set_selection_from_csv(csv, gSelectedTVals, &gSelectedTCount, + "T sizes", reset_t_size_selection); +} + +#if defined(__GNUC__) && !defined(__clang__) +__attribute__((optimize("O0"))) +#endif +bool set_h_sizes_from_csv(const char* csv) { + return set_selection_from_csv(csv, gSelectedHVals, &gSelectedHCount, + "H sizes", reset_h_size_selection); +} + +inline int grid_size_for_block(int N, int block_size) { + return (N + block_size - 1) / block_size; +} + +struct FusedScanBlockLaunch { + FusedScanProfile* p; + int block_size; +}; + +template +void run_fusedscan_fwd_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + mingru_scan_forward_ckpt_tuned<<B * p->H, bs), bs>>>(p->scan); +} + +template +void run_fusedscan_bwd_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + mingru_scan_backward_ckpt_tuned<<B * p->H, bs), bs>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); +} + +template +void run_fusedscan_fwd_vec32_ckpt_tuned_block(FusedScanBlockLaunch* args) { +#ifdef PRECISION_FLOAT + run_fusedscan_fwd_ckpt_tuned_block(args); +#else + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H & 1) == 0) { + int H2 = p->H >> 1; + mingru_scan_forward_ckpt_tuned_vec32<<B * H2, bs), bs>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned_block(args); + } +#endif +} + +template +void run_fusedscan_bwd_vec32_ckpt_tuned_block(FusedScanBlockLaunch* args) { +#ifdef PRECISION_FLOAT + run_fusedscan_bwd_ckpt_tuned_block(args); +#else + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H & 1) == 0) { + int H2 = p->H >> 1; + mingru_scan_backward_ckpt_tuned_vec32<<B * H2, bs), bs>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned_block(args); + } +#endif +} + +template +void run_fusedscan_fwd_vec64_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H % MINGRU_SCAN_VEC64_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC64_WIDTH; + mingru_scan_forward_ckpt_tuned_vec64<<B * HW, bs), bs>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned_block(args); + } +} + +template +void run_fusedscan_bwd_vec64_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H % MINGRU_SCAN_VEC64_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC64_WIDTH; + mingru_scan_backward_ckpt_tuned_vec64<<B * HW, bs), bs>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned_block(args); + } +} + +template +void run_fusedscan_fwd_vec128_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H % MINGRU_SCAN_VEC128_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC128_WIDTH; + mingru_scan_forward_ckpt_tuned_vec128<<B * HW, bs), bs>>>(p->scan); + } else { + run_fusedscan_fwd_ckpt_tuned_block(args); + } +} + +template +void run_fusedscan_bwd_vec128_ckpt_tuned_block(FusedScanBlockLaunch* args) { + FusedScanProfile* p = args->p; + int bs = args->block_size; + if ((p->H % MINGRU_SCAN_VEC128_WIDTH) == 0) { + int HW = p->H / MINGRU_SCAN_VEC128_WIDTH; + mingru_scan_backward_ckpt_tuned_vec128<<B * HW, bs), bs>>>( + p->scan, p->grad_out.data, p->grad_next_state.data); + } else { + run_fusedscan_bwd_ckpt_tuned_block(args); + } +} + +template +float profile_fusedscan_ckpt_tuned_speed_block(FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + (void)N; + FusedScanBlockLaunch args = {p, block_size}; + float fwd = profile_kernel((kernel_fn)run_fusedscan_fwd_ckpt_tuned_block, &args); + float bwd = profile_kernel((kernel_fn)run_fusedscan_bwd_ckpt_tuned_block, &args); + if (fwd_out) { + *fwd_out = fwd; + } + if (bwd_out) { + *bwd_out = bwd; + } + return fwd + bwd; +} + +template +float profile_fusedscan_vec32_ckpt_speed_block(FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + (void)N; + FusedScanBlockLaunch args = {p, block_size}; + float fwd = profile_kernel((kernel_fn)run_fusedscan_fwd_vec32_ckpt_tuned_block, &args); + float bwd = profile_kernel((kernel_fn)run_fusedscan_bwd_vec32_ckpt_tuned_block, &args); + if (fwd_out) { + *fwd_out = fwd; + } + if (bwd_out) { + *bwd_out = bwd; + } + return fwd + bwd; +} + +template +float profile_fusedscan_vec64_ckpt_speed_block(FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + (void)N; + FusedScanBlockLaunch args = {p, block_size}; + float fwd = profile_kernel((kernel_fn)run_fusedscan_fwd_vec64_ckpt_tuned_block, &args); + float bwd = profile_kernel((kernel_fn)run_fusedscan_bwd_vec64_ckpt_tuned_block, &args); + if (fwd_out) { + *fwd_out = fwd; + } + if (bwd_out) { + *bwd_out = bwd; + } + return fwd + bwd; +} + +template +float profile_fusedscan_vec128_ckpt_speed_block(FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + (void)N; + FusedScanBlockLaunch args = {p, block_size}; + float fwd = profile_kernel((kernel_fn)run_fusedscan_fwd_vec128_ckpt_tuned_block, &args); + float bwd = profile_kernel((kernel_fn)run_fusedscan_bwd_vec128_ckpt_tuned_block, &args); + if (fwd_out) { + *fwd_out = fwd; + } + if (bwd_out) { + *bwd_out = bwd; + } + return fwd + bwd; +} + +inline bool is_supported_checkpoint_interval(int ckpt_interval) { + switch (ckpt_interval) { + case 1: + case 2: + case 4: + case 8: + case 16: + case 32: + return true; + default: + return false; + } +} + +float dispatch_scalar_ckpt_speed(int ckpt_interval, FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_ckpt_tuned_speed_block<1>(p, N, block_size, fwd_out, bwd_out); + case 2: return profile_fusedscan_ckpt_tuned_speed_block<2>(p, N, block_size, fwd_out, bwd_out); + case 4: return profile_fusedscan_ckpt_tuned_speed_block<4>(p, N, block_size, fwd_out, bwd_out); + case 8: return profile_fusedscan_ckpt_tuned_speed_block<8>(p, N, block_size, fwd_out, bwd_out); + case 16: return profile_fusedscan_ckpt_tuned_speed_block<16>(p, N, block_size, fwd_out, bwd_out); + case 32: return profile_fusedscan_ckpt_tuned_speed_block<32>(p, N, block_size, fwd_out, bwd_out); + default: return -1.0f; + } +} + +float dispatch_vec32_ckpt_speed(int ckpt_interval, FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_vec32_ckpt_speed_block<1>(p, N, block_size, fwd_out, bwd_out); + case 2: return profile_fusedscan_vec32_ckpt_speed_block<2>(p, N, block_size, fwd_out, bwd_out); + case 4: return profile_fusedscan_vec32_ckpt_speed_block<4>(p, N, block_size, fwd_out, bwd_out); + case 8: return profile_fusedscan_vec32_ckpt_speed_block<8>(p, N, block_size, fwd_out, bwd_out); + case 16: return profile_fusedscan_vec32_ckpt_speed_block<16>(p, N, block_size, fwd_out, bwd_out); + case 32: return profile_fusedscan_vec32_ckpt_speed_block<32>(p, N, block_size, fwd_out, bwd_out); + default: return -1.0f; + } +} + +float dispatch_vec64_ckpt_speed(int ckpt_interval, FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_vec64_ckpt_speed_block<1>(p, N, block_size, fwd_out, bwd_out); + case 2: return profile_fusedscan_vec64_ckpt_speed_block<2>(p, N, block_size, fwd_out, bwd_out); + case 4: return profile_fusedscan_vec64_ckpt_speed_block<4>(p, N, block_size, fwd_out, bwd_out); + case 8: return profile_fusedscan_vec64_ckpt_speed_block<8>(p, N, block_size, fwd_out, bwd_out); + case 16: return profile_fusedscan_vec64_ckpt_speed_block<16>(p, N, block_size, fwd_out, bwd_out); + case 32: return profile_fusedscan_vec64_ckpt_speed_block<32>(p, N, block_size, fwd_out, bwd_out); + default: return -1.0f; + } +} + +float dispatch_vec128_ckpt_speed(int ckpt_interval, FusedScanProfile* p, int N, int block_size, + float* fwd_out = nullptr, float* bwd_out = nullptr) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_vec128_ckpt_speed_block<1>(p, N, block_size, fwd_out, bwd_out); + case 2: return profile_fusedscan_vec128_ckpt_speed_block<2>(p, N, block_size, fwd_out, bwd_out); + case 4: return profile_fusedscan_vec128_ckpt_speed_block<4>(p, N, block_size, fwd_out, bwd_out); + case 8: return profile_fusedscan_vec128_ckpt_speed_block<8>(p, N, block_size, fwd_out, bwd_out); + case 16: return profile_fusedscan_vec128_ckpt_speed_block<16>(p, N, block_size, fwd_out, bwd_out); + case 32: return profile_fusedscan_vec128_ckpt_speed_block<32>(p, N, block_size, fwd_out, bwd_out); + default: return -1.0f; + } +} + +template +int kernel_max_threads_per_block(KernelT kernel) { + cudaFuncAttributes attrs = {}; + cudaError_t err = cudaFuncGetAttributes(&attrs, kernel); + if (err != cudaSuccess) { + cudaGetLastError(); // clear sticky runtime error so profiling can continue + return 0; + } + return attrs.maxThreadsPerBlock; +} + +template +int max_block_scalar_ckpt(const FusedScanProfile* /*p*/) { + int fwd = kernel_max_threads_per_block(mingru_scan_forward_ckpt_tuned); + int bwd = kernel_max_threads_per_block(mingru_scan_backward_ckpt_tuned); + if (fwd <= 0 || bwd <= 0) { + return 0; + } + return std::min(fwd, bwd); +} + +template +int max_block_vec32_ckpt(const FusedScanProfile* p) { +#ifdef PRECISION_FLOAT + return max_block_scalar_ckpt(p); +#else + if ((p->H & 1) == 0) { + int fwd = kernel_max_threads_per_block(mingru_scan_forward_ckpt_tuned_vec32); + int bwd = kernel_max_threads_per_block(mingru_scan_backward_ckpt_tuned_vec32); + if (fwd <= 0 || bwd <= 0) { + return 0; + } + return std::min(fwd, bwd); + } + return max_block_scalar_ckpt(p); +#endif +} + +template +int max_block_vec64_ckpt(const FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC64_WIDTH) == 0) { + int fwd = kernel_max_threads_per_block(mingru_scan_forward_ckpt_tuned_vec64); + int bwd = kernel_max_threads_per_block(mingru_scan_backward_ckpt_tuned_vec64); + if (fwd <= 0 || bwd <= 0) { + return 0; + } + return std::min(fwd, bwd); + } + return max_block_scalar_ckpt(p); +} + +template +int max_block_vec128_ckpt(const FusedScanProfile* p) { + if ((p->H % MINGRU_SCAN_VEC128_WIDTH) == 0) { + int fwd = kernel_max_threads_per_block(mingru_scan_forward_ckpt_tuned_vec128); + int bwd = kernel_max_threads_per_block(mingru_scan_backward_ckpt_tuned_vec128); + if (fwd <= 0 || bwd <= 0) { + return 0; + } + return std::min(fwd, bwd); + } + return max_block_scalar_ckpt(p); +} + +int dispatch_scalar_ckpt_max_block(int ckpt_interval, const FusedScanProfile* p) { + switch (ckpt_interval) { + case 1: return max_block_scalar_ckpt<1>(p); + case 2: return max_block_scalar_ckpt<2>(p); + case 4: return max_block_scalar_ckpt<4>(p); + case 8: return max_block_scalar_ckpt<8>(p); + case 16: return max_block_scalar_ckpt<16>(p); + case 32: return max_block_scalar_ckpt<32>(p); + default: return 0; + } +} + +int dispatch_vec32_ckpt_max_block(int ckpt_interval, const FusedScanProfile* p) { + switch (ckpt_interval) { + case 1: return max_block_vec32_ckpt<1>(p); + case 2: return max_block_vec32_ckpt<2>(p); + case 4: return max_block_vec32_ckpt<4>(p); + case 8: return max_block_vec32_ckpt<8>(p); + case 16: return max_block_vec32_ckpt<16>(p); + case 32: return max_block_vec32_ckpt<32>(p); + default: return 0; + } +} + +int dispatch_vec64_ckpt_max_block(int ckpt_interval, const FusedScanProfile* p) { + switch (ckpt_interval) { + case 1: return max_block_vec64_ckpt<1>(p); + case 2: return max_block_vec64_ckpt<2>(p); + case 4: return max_block_vec64_ckpt<4>(p); + case 8: return max_block_vec64_ckpt<8>(p); + case 16: return max_block_vec64_ckpt<16>(p); + case 32: return max_block_vec64_ckpt<32>(p); + default: return 0; + } +} + +int dispatch_vec128_ckpt_max_block(int ckpt_interval, const FusedScanProfile* p) { + switch (ckpt_interval) { + case 1: return max_block_vec128_ckpt<1>(p); + case 2: return max_block_vec128_ckpt<2>(p); + case 4: return max_block_vec128_ckpt<4>(p); + case 8: return max_block_vec128_ckpt<8>(p); + case 16: return max_block_vec128_ckpt<16>(p); + case 32: return max_block_vec128_ckpt<32>(p); + default: return 0; + } +} + +template +float profile_fusedscan_fwd_variant_speed_block( + FusedScanProfile* p, int block_size, MingruScanVariant variant) { + FusedScanBlockLaunch args = {p, block_size}; + switch (variant) { + case MingruScanVariant::kVec128: + return profile_kernel((kernel_fn)run_fusedscan_fwd_vec128_ckpt_tuned_block, &args); + case MingruScanVariant::kVec64: + return profile_kernel((kernel_fn)run_fusedscan_fwd_vec64_ckpt_tuned_block, &args); + case MingruScanVariant::kVec32: + return profile_kernel((kernel_fn)run_fusedscan_fwd_vec32_ckpt_tuned_block, &args); + case MingruScanVariant::kScalar: + default: + return profile_kernel((kernel_fn)run_fusedscan_fwd_ckpt_tuned_block, &args); + } +} + +template +float profile_fusedscan_bwd_variant_speed_block( + FusedScanProfile* p, int block_size, MingruScanVariant variant) { + FusedScanBlockLaunch args = {p, block_size}; + switch (variant) { + case MingruScanVariant::kVec128: + return profile_kernel((kernel_fn)run_fusedscan_bwd_vec128_ckpt_tuned_block, &args); + case MingruScanVariant::kVec64: + return profile_kernel((kernel_fn)run_fusedscan_bwd_vec64_ckpt_tuned_block, &args); + case MingruScanVariant::kVec32: + return profile_kernel((kernel_fn)run_fusedscan_bwd_vec32_ckpt_tuned_block, &args); + case MingruScanVariant::kScalar: + default: + return profile_kernel((kernel_fn)run_fusedscan_bwd_ckpt_tuned_block, &args); + } +} + +float dispatch_variant_ckpt_fwd_speed(MingruScanVariant variant, int ckpt_interval, + FusedScanProfile* p, int block_size) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_fwd_variant_speed_block<1>(p, block_size, variant); + case 2: return profile_fusedscan_fwd_variant_speed_block<2>(p, block_size, variant); + case 4: return profile_fusedscan_fwd_variant_speed_block<4>(p, block_size, variant); + case 8: return profile_fusedscan_fwd_variant_speed_block<8>(p, block_size, variant); + case 16: return profile_fusedscan_fwd_variant_speed_block<16>(p, block_size, variant); + case 32: return profile_fusedscan_fwd_variant_speed_block<32>(p, block_size, variant); + default: return -1.0f; + } +} + +float dispatch_variant_ckpt_bwd_speed(MingruScanVariant variant, int ckpt_interval, + FusedScanProfile* p, int block_size) { + switch (ckpt_interval) { + case 1: return profile_fusedscan_bwd_variant_speed_block<1>(p, block_size, variant); + case 2: return profile_fusedscan_bwd_variant_speed_block<2>(p, block_size, variant); + case 4: return profile_fusedscan_bwd_variant_speed_block<4>(p, block_size, variant); + case 8: return profile_fusedscan_bwd_variant_speed_block<8>(p, block_size, variant); + case 16: return profile_fusedscan_bwd_variant_speed_block<16>(p, block_size, variant); + case 32: return profile_fusedscan_bwd_variant_speed_block<32>(p, block_size, variant); + default: return -1.0f; + } +} + +int dispatch_variant_ckpt_max_block(MingruScanVariant variant, int ckpt_interval, const FusedScanProfile* p) { + switch (variant) { + case MingruScanVariant::kVec128: + return dispatch_vec128_ckpt_max_block(ckpt_interval, p); + case MingruScanVariant::kVec64: + return dispatch_vec64_ckpt_max_block(ckpt_interval, p); + case MingruScanVariant::kVec32: + return dispatch_vec32_ckpt_max_block(ckpt_interval, p); + case MingruScanVariant::kScalar: + default: + return dispatch_scalar_ckpt_max_block(ckpt_interval, p); + } +} + +bool profile_fusedscan_selector_policy(FusedScanProfile* p, int H, int block_size, + MingruScanKernelSelection selection, MingruScanKernelSelection* resolved_out, + float* fwd_ms_out, float* bwd_ms_out, float* total_ms_out) { + selection.fwd_variant = mingru_scan_resolve_variant(selection.fwd_variant, H); + selection.bwd_variant = mingru_scan_resolve_variant(selection.bwd_variant, H); + if (resolved_out) { + *resolved_out = selection; + } + if (!is_supported_checkpoint_interval(selection.ckpt_interval)) { + return false; + } + int max_fwd_block = dispatch_variant_ckpt_max_block(selection.fwd_variant, selection.ckpt_interval, p); + int max_bwd_block = dispatch_variant_ckpt_max_block(selection.bwd_variant, selection.ckpt_interval, p); + if (max_fwd_block <= 0 || max_bwd_block <= 0) { + return false; + } + if (block_size > max_fwd_block || block_size > max_bwd_block) { + return false; + } + + float fwd_ms = dispatch_variant_ckpt_fwd_speed( + selection.fwd_variant, selection.ckpt_interval, p, block_size); + if (fwd_ms < 0.0f) { + return false; + } + + float bwd_ms = dispatch_variant_ckpt_bwd_speed( + selection.bwd_variant, selection.ckpt_interval, p, block_size); + if (bwd_ms < 0.0f) { + return false; + } + + if (fwd_ms_out) { + *fwd_ms_out = fwd_ms; + } + if (bwd_ms_out) { + *bwd_ms_out = bwd_ms; + } + if (total_ms_out) { + *total_ms_out = fwd_ms + bwd_ms; + } + return true; +} + +bool profile_fusedscan_selector_bench_case(int B, int T, int H, float* baseline_total_ms_out, + float* depth2_total_ms_out) { + printf("speed fused_scan selector bench (B=%d, T=%d, H=%d, N=%d)\n", + B, T, H, B * T * H); + auto* p = create_fusedscan(B, T, H, "fused_scan_selector_bench"); + if (!p) { + printf("\n"); + return false; + } + + const int block = kFusedscanBlockSweepBlockSize; + MingruScanKernelSelection baseline = mingru_scan_select_baseline_policy(B, T, H); + MingruScanKernelSelection depth2 = mingru_scan_select_depth2_policy(B, T, H); + MingruScanKernelSelection baseline_resolved = baseline; + MingruScanKernelSelection depth2_resolved = depth2; + float baseline_fwd = 0.0f; + float baseline_bwd = 0.0f; + float baseline_total = 0.0f; + float depth2_fwd = 0.0f; + float depth2_bwd = 0.0f; + float depth2_total = 0.0f; + + bool baseline_ok = profile_fusedscan_selector_policy( + p, H, block, baseline, &baseline_resolved, &baseline_fwd, &baseline_bwd, &baseline_total); + bool depth2_ok = profile_fusedscan_selector_policy( + p, H, block, depth2, &depth2_resolved, &depth2_fwd, &depth2_bwd, &depth2_total); + + printf(" %-12s %6s %-12s %-12s %10s %10s %10s %12s\n", + "policy", "ckpt", "fwd", "bwd", "fwd_us", "bwd_us", "total_us", "speedup_x"); + if (baseline_ok) { + printf(" %-12s %6d %-12s %-12s %10.1f %10.1f %10.1f %12s\n", + "baseline", baseline_resolved.ckpt_interval, + mingru_scan_variant_name(baseline_resolved.fwd_variant), + mingru_scan_variant_name(baseline_resolved.bwd_variant), + baseline_fwd * 1000.0f, baseline_bwd * 1000.0f, baseline_total * 1000.0f, "-"); + } else { + printf(" %-12s %6s %-12s %-12s %10s %10s %10s %12s\n", + "baseline", "-", "-", "-", "-", "-", "-", "-"); + } + if (depth2_ok && baseline_ok && depth2_total > 0.0f) { + printf(" %-12s %6d %-12s %-12s %10.1f %10.1f %10.1f %12.3f\n", + "depth2", depth2_resolved.ckpt_interval, + mingru_scan_variant_name(depth2_resolved.fwd_variant), + mingru_scan_variant_name(depth2_resolved.bwd_variant), + depth2_fwd * 1000.0f, depth2_bwd * 1000.0f, depth2_total * 1000.0f, + baseline_total / depth2_total); + printf(" %-12s %6s %-12s %-12s %10s %10s %10.1f %12s\n", + "delta", "-", "-", "-", "-", "-", + (depth2_total - baseline_total) * 1000.0f, "depth2-baseline"); + } else if (depth2_ok) { + printf(" %-12s %6d %-12s %-12s %10.1f %10.1f %10.1f %12s\n", + "depth2", depth2_resolved.ckpt_interval, + mingru_scan_variant_name(depth2_resolved.fwd_variant), + mingru_scan_variant_name(depth2_resolved.bwd_variant), + depth2_fwd * 1000.0f, depth2_bwd * 1000.0f, depth2_total * 1000.0f, "-"); + } else { + printf(" %-12s %6s %-12s %-12s %10s %10s %10s %12s\n", + "depth2", "-", "-", "-", "-", "-", "-", "-"); + } + printf("\n"); + + if (baseline_total_ms_out) { + *baseline_total_ms_out = baseline_ok ? baseline_total : 0.0f; + } + if (depth2_total_ms_out) { + *depth2_total_ms_out = depth2_ok ? depth2_total : 0.0f; + } + bool ok = baseline_ok && depth2_ok; + alloc_free(&p->alloc); + free(p); + return ok; +} + +bool profile_fusedscan_selector_bench() { + printf("fused_scan selector benchmark (baseline vs depth2 policy)\n"); + print_fusedscan_kernel_diagnostics_once(); + printf(" launch block size: %d\n", kFusedscanBlockSweepBlockSize); + print_selected_sweep_sizes(); + + int wins = 0; + int losses = 0; + int ties = 0; + int measured = 0; + double baseline_sum_ms = 0.0; + double depth2_sum_ms = 0.0; + double sum_log_speedup = 0.0; + const double tie_eps = 1e-4; + + int nb = gSelectedBCount; + int nt = gSelectedTCount; + int nh = gSelectedHCount; + for (int ib = 0; ib < nb; ib++) { + for (int it = 0; it < nt; it++) { + for (int ih = 0; ih < nh; ih++) { + float baseline_total_ms = 0.0f; + float depth2_total_ms = 0.0f; + bool ok = profile_fusedscan_selector_bench_case( + gSelectedBVals[ib], gSelectedTVals[it], gSelectedHVals[ih], + &baseline_total_ms, &depth2_total_ms); + if (!ok || baseline_total_ms <= 0.0f || depth2_total_ms <= 0.0f) { + continue; + } + + measured++; + baseline_sum_ms += baseline_total_ms; + depth2_sum_ms += depth2_total_ms; + if (depth2_total_ms < baseline_total_ms * (1.0 - tie_eps)) { + wins++; + } else if (depth2_total_ms > baseline_total_ms * (1.0 + tie_eps)) { + losses++; + } else { + ties++; + } + sum_log_speedup += log((double)baseline_total_ms / (double)depth2_total_ms); + } + } + } + + printf("selector benchmark summary\n"); + printf(" measured cases: %d\n", measured); + printf(" depth2 wins/losses/ties: %d/%d/%d\n", wins, losses, ties); + if (measured > 0 && depth2_sum_ms > 0.0) { + double overall_speedup = baseline_sum_ms / depth2_sum_ms; + double geomean_speedup = exp(sum_log_speedup / measured); + printf(" total baseline_us: %.1f\n", baseline_sum_ms * 1000.0); + printf(" total depth2_us: %.1f\n", depth2_sum_ms * 1000.0); + printf(" overall speedup: %.4fx\n", overall_speedup); + printf(" geomean speedup: %.4fx\n", geomean_speedup); + } + printf("\n"); + return measured > 0; +} + +void profile_fusedscan_sweep_speed_case(int B, int T, int H) { + printf("speed fused_scan sweep (B=%d, T=%d, H=%d, N=%d)\n", + B, T, H, B*T*H); + int N = B * T; + auto* p = create_fusedscan(B, T, H, "fused_scan_sweep"); + if (!p) { + printf("\n"); + return; + } + + const int intervals = gSelectedCheckpointCount; + float overall_best_total = std::numeric_limits::infinity(); + float overall_best_fwd = 0.0f; + float overall_best_bwd = 0.0f; + int overall_best_ckpt = -1; + const char* overall_best_variant = "n/a"; + float overall_best_combo_total = std::numeric_limits::infinity(); + float overall_best_combo_fwd = 0.0f; + float overall_best_combo_bwd = 0.0f; + int overall_best_combo_ckpt = -1; + const char* overall_best_combo_fwd_variant = "n/a"; + const char* overall_best_combo_bwd_variant = "n/a"; + bool overall_any_combo = false; + bool overall_any_legal = false; + const int block = kFusedscanBlockSweepBlockSize; + + for (int i = 0; i < intervals; i++) { + int ckpt_interval = gSelectedCheckpointVals[i]; + + float best_total = std::numeric_limits::infinity(); + float best_fwd = 0.0f; + float best_bwd = 0.0f; + const char* best_variant = "n/a"; + float best_combo_fwd = std::numeric_limits::infinity(); + float best_combo_bwd = std::numeric_limits::infinity(); + const char* best_combo_fwd_variant = "n/a"; + const char* best_combo_bwd_variant = "n/a"; + bool any_legal = false; + printf(" --------------------------------------------------------------------------------------------------------\n"); + printf(" ckpt=%d results:\n", ckpt_interval); + printf(" %-24s %10s %10s %10s\n", + "variant", "fwd_us", "bwd_us", "total_us"); + + if (!is_supported_checkpoint_interval(ckpt_interval)) { + printf(" %-24s %10s %10s %10s\n", "log_scalar", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "log_vec32", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "log_vec64", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "log_vec128", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "winner: (none)", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "best_combo: (none)", "-", "-", "-"); + continue; + } + + const char* variant_names[] = {"log_scalar", "log_vec32", "log_vec64", "log_vec128"}; + float (*variant_speed_fns[])(int, FusedScanProfile*, int, int, float*, float*) = { + dispatch_scalar_ckpt_speed, + dispatch_vec32_ckpt_speed, + dispatch_vec64_ckpt_speed, + dispatch_vec128_ckpt_speed, + }; + int (*variant_max_block_fns[])(int, const FusedScanProfile*) = { + dispatch_scalar_ckpt_max_block, + dispatch_vec32_ckpt_max_block, + dispatch_vec64_ckpt_max_block, + dispatch_vec128_ckpt_max_block, + }; + int variant_max_blocks[] = { + variant_max_block_fns[0](ckpt_interval, p), + variant_max_block_fns[1](ckpt_interval, p), + variant_max_block_fns[2](ckpt_interval, p), + variant_max_block_fns[3](ckpt_interval, p), + }; + static constexpr int kVariantCount = sizeof(variant_names) / sizeof(variant_names[0]); + + for (int variant_i = 0; variant_i < kVariantCount; variant_i++) { + if (variant_max_blocks[variant_i] <= 0 || block > variant_max_blocks[variant_i]) { + printf(" %-24s %10s %10s %10s\n", variant_names[variant_i], "-", "-", "-"); + continue; + } + + float fwd = 0.0f; + float bwd = 0.0f; + float total = variant_speed_fns[variant_i](ckpt_interval, p, N, block, &fwd, &bwd); + printf(" %-24s %10.1f %10.1f %10.1f\n", + variant_names[variant_i], fwd * 1000.0f, bwd * 1000.0f, total * 1000.0f); + any_legal = true; + + if (fwd < best_combo_fwd) { + best_combo_fwd = fwd; + best_combo_fwd_variant = variant_names[variant_i]; + } + if (bwd < best_combo_bwd) { + best_combo_bwd = bwd; + best_combo_bwd_variant = variant_names[variant_i]; + } + if (total < best_total) { + best_total = total; + best_fwd = fwd; + best_bwd = bwd; + best_variant = variant_names[variant_i]; + } + if (total < overall_best_total) { + overall_best_total = total; + overall_best_fwd = fwd; + overall_best_bwd = bwd; + overall_best_ckpt = ckpt_interval; + overall_best_variant = variant_names[variant_i]; + overall_any_legal = true; + } + } + + if (any_legal) { + char winner_label[64]; + snprintf(winner_label, sizeof(winner_label), "winner: %s", best_variant); + printf(" %-24s %10.1f %10.1f %10.1f\n", + winner_label, best_fwd * 1000.0f, best_bwd * 1000.0f, best_total * 1000.0f); + float best_combo_total = best_combo_fwd + best_combo_bwd; + printf(" %-24s %10.1f %10.1f %10.1f\n", + "best_combo", best_combo_fwd * 1000.0f, best_combo_bwd * 1000.0f, best_combo_total * 1000.0f); + printf(" (fwd from %s, bwd from %s)\n", + best_combo_fwd_variant, best_combo_bwd_variant); + if (best_combo_total < overall_best_combo_total) { + overall_best_combo_total = best_combo_total; + overall_best_combo_fwd = best_combo_fwd; + overall_best_combo_bwd = best_combo_bwd; + overall_best_combo_ckpt = ckpt_interval; + overall_best_combo_fwd_variant = best_combo_fwd_variant; + overall_best_combo_bwd_variant = best_combo_bwd_variant; + overall_any_combo = true; + } + } else { + printf(" %-24s %10s %10s %10s\n", "winner: (none)", "-", "-", "-"); + printf(" %-24s %10s %10s %10s\n", "best_combo: (none)", "-", "-", "-"); + } + } + + printf(" --------------------------------------------------------------------------------------------------------\n"); + if (overall_any_legal) { + printf(" overall winner:\n"); + printf(" %6s %-24s %10s %10s %10s\n", + "ckpt", "variant", "fwd_us", "bwd_us", "total_us"); + printf(" %6d %-24s %10.1f %10.1f %10.1f\n", + overall_best_ckpt, overall_best_variant, + overall_best_fwd * 1000.0f, overall_best_bwd * 1000.0f, overall_best_total * 1000.0f); + } else { + printf(" overall winner:\n"); + printf(" %6s %-24s %10s %10s %10s\n", + "ckpt", "variant", "fwd_us", "bwd_us", "total_us"); + printf(" %6s %-24s %10s %10s %10s\n", + "-", "(none)", "-", "-", "-"); + } + + if (overall_any_combo) { + printf(" overall best_combo (single ckpt):\n"); + printf(" %6s %-24s %10s %10s %10s\n", + "ckpt", "combo", "fwd_us", "bwd_us", "total_us"); + printf(" %6d %-24s %10.1f %10.1f %10.1f\n", + overall_best_combo_ckpt, "best_combo", + overall_best_combo_fwd * 1000.0f, + overall_best_combo_bwd * 1000.0f, + overall_best_combo_total * 1000.0f); + printf(" (fwd from %s, bwd from %s)\n", + overall_best_combo_fwd_variant, overall_best_combo_bwd_variant); + } else { + printf(" overall best_combo (single ckpt):\n"); + printf(" %6s %-24s %10s %10s %10s\n", + "ckpt", "combo", "fwd_us", "bwd_us", "total_us"); + printf(" %6s %-24s %10s %10s %10s\n", + "-", "(none)", "-", "-", "-"); + } + + printf("\n"); + alloc_free(&p->alloc); + free(p); +} + +bool profile_fusedscan_sweep() { + printf("fused_scan sweep speed (no correctness checks)\n"); + print_fusedscan_kernel_diagnostics_once(); + print_selected_checkpoint_intervals(); + printf(" launch block size: %d\n", kFusedscanBlockSweepBlockSize); + print_selected_sweep_sizes(); + int nb = gSelectedBCount; + int nt = gSelectedTCount; + int nh = gSelectedHCount; + for (int ib = 0; ib < nb; ib++) { + for (int it = 0; it < nt; it++) { + for (int ih = 0; ih < nh; ih++) { + profile_fusedscan_sweep_speed_case(gSelectedBVals[ib], gSelectedTVals[it], gSelectedHVals[ih]); + } + } + } + return true; +} + struct PPOProfile { PPOKernelArgs ka; PPOGraphArgs ga; @@ -693,54 +2176,106 @@ void profile_envspeed(int total_agents, int num_buffers, int num_threads, int ho printf("\n"); } +inline bool profile_is(const char* profile, const char* name) { + return strcmp(profile, name) == 0; +} + +inline bool should_run_kernel_profile(const char* profile, bool run_all, const char* name) { + return run_all || profile_is(profile, "kernels") || profile_is(profile, name); +} + +bool is_known_profile_name(const char* profile) { + static const char* kKnownProfiles[] = { + "all", + "kernels", + "mingrugate", + "logcoeffsvals", + "fusedscan", + "fusedscan_correctness", + "fusedscan_sweep", + "fusedscan_selector_bench", + "samplelogits", + "ppoloss", + "im2col", + "envspeed", + }; + static constexpr int kKnownProfileCount = sizeof(kKnownProfiles) / sizeof(kKnownProfiles[0]); + for (int profile_i = 0; profile_i < kKnownProfileCount; profile_i++) { + if (profile_is(profile, kKnownProfiles[profile_i])) { + return true; + } + } + return false; +} + int main(int argc, char** argv) { if (argc < 2) { print_usage(argv[0]); return 1; } const char* profile = argv[1]; + const char* ckpt_intervals_csv = nullptr; + const char* b_sizes_csv = nullptr; + const char* t_sizes_csv = nullptr; + const char* h_sizes_csv = nullptr; int buffers = BUF, threads = 16, horizon = T_; int total_agents = BR * buffers; - for (int i = 2; i < argc - 1; i++) { - if (strcmp(argv[i], "--buffers") == 0) buffers = atoi(argv[++i]); - else if (strcmp(argv[i], "--threads") == 0) threads = atoi(argv[++i]); - else if (strcmp(argv[i], "--horizon") == 0) horizon = atoi(argv[++i]); - else if (strcmp(argv[i], "--total-agents") == 0) total_agents = atoi(argv[++i]); + for (int arg_i = 2; arg_i < argc - 1; arg_i++) { + if (strcmp(argv[arg_i], "--buffers") == 0) buffers = atoi(argv[++arg_i]); + else if (strcmp(argv[arg_i], "--threads") == 0) threads = atoi(argv[++arg_i]); + else if (strcmp(argv[arg_i], "--horizon") == 0) horizon = atoi(argv[++arg_i]); + else if (strcmp(argv[arg_i], "--total-agents") == 0) total_agents = atoi(argv[++arg_i]); + else if (strcmp(argv[arg_i], "--ckpt-intervals") == 0) ckpt_intervals_csv = argv[++arg_i]; + else if (strcmp(argv[arg_i], "--b-sizes") == 0) b_sizes_csv = argv[++arg_i]; + else if (strcmp(argv[arg_i], "--t-sizes") == 0) t_sizes_csv = argv[++arg_i]; + else if (strcmp(argv[arg_i], "--h-sizes") == 0) h_sizes_csv = argv[++arg_i]; + } + reset_checkpoint_interval_selection(); + reset_sweep_size_selection(); + if (ckpt_intervals_csv) { + set_checkpoint_intervals_from_csv(ckpt_intervals_csv); + } + if (b_sizes_csv) { + set_b_sizes_from_csv(b_sizes_csv); + } + if (t_sizes_csv) { + set_t_sizes_from_csv(t_sizes_csv); + } + if (h_sizes_csv) { + set_h_sizes_from_csv(h_sizes_csv); } warmup_gpu(); - bool run_all = strcmp(profile, "all") == 0; + bool run_all = profile_is(profile, "all"); + bool ok = true; - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "mingrugate") == 0 || run_all) + if (should_run_kernel_profile(profile, run_all, "mingrugate")) profile_mingrugate(BR, H_); - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "logcoeffsvals") == 0 || run_all) + if (should_run_kernel_profile(profile, run_all, "logcoeffsvals")) profile_logcoeffs(BT, T_, H_); - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "fusedscan") == 0 || run_all) + if (should_run_kernel_profile(profile, run_all, "fusedscan")) profile_fusedscan(BT, T_, H_); - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "samplelogits") == 0 || run_all) + if (profile_is(profile, "fusedscan_correctness")) + ok &= profile_fusedscan_correctness_suite(); + if (profile_is(profile, "fusedscan_sweep")) + ok &= profile_fusedscan_sweep(); + if (profile_is(profile, "fusedscan_selector_bench")) + ok &= profile_fusedscan_selector_bench(); + if (should_run_kernel_profile(profile, run_all, "samplelogits")) profile_samplelogits(BR, A_); - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "ppoloss") == 0 || run_all) + if (should_run_kernel_profile(profile, run_all, "ppoloss")) profile_ppoloss(BT, T_, A_); - if (strcmp(profile, "kernels") == 0 || strcmp(profile, "im2col") == 0 || run_all) { + if (should_run_kernel_profile(profile, run_all, "im2col")) { profile_im2col(1024, N3_C1_IC, N3_MAP_H, N3_MAP_W, N3_C1_K, N3_C1_S, N3_C1_OH, N3_C1_OW); profile_im2col(1024, N3_C2_IC, N3_C1_OH, N3_C1_OW, N3_C2_K, N3_C2_S, N3_C2_OH, N3_C2_OW); } - if (strcmp(profile, "envspeed") == 0 || run_all) + if (profile_is(profile, "envspeed") || run_all) profile_envspeed(total_agents, buffers, threads, horizon); - if (!run_all - && strcmp(profile, "kernels") != 0 - && strcmp(profile, "mingrugate") != 0 - && strcmp(profile, "logcoeffsvals") != 0 - && strcmp(profile, "fusedscan") != 0 - && strcmp(profile, "samplelogits") != 0 - && strcmp(profile, "ppoloss") != 0 - && strcmp(profile, "im2col") != 0 - && strcmp(profile, "envspeed") != 0 - ) { + if (!is_known_profile_name(profile)) { printf("Unknown profile: %s\n\n", profile); print_usage(argv[0]); return 1; } - return 0; + return ok ? 0 : 1; }