From 1e64b6ee1ab73315171b414f2f73ba512430425b Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Tue, 19 May 2026 22:16:57 +0300 Subject: [PATCH] add cuda aurora optimizer flag --- config/default.ini | 2 + src/bindings.cu | 21 ++++ src/muon.cu | 234 ++++++++++++++++++++++++++++++++++++++------- src/pufferlib.cu | 5 +- 4 files changed, 229 insertions(+), 33 deletions(-) diff --git a/config/default.ini b/config/default.ini index 29bc1808b7..122a677578 100644 --- a/config/default.ini +++ b/config/default.ini @@ -81,6 +81,8 @@ min_ent_coef_ratio = 0.1 beta1 = 0.95 beta2 = 0.999 eps = 1e-12 +aurora = 0 +aurora_weight_decay = 0.025 minibatch_size = 8192 horizon = 64 vtrace_rho_clip = 1.0 diff --git a/src/bindings.cu b/src/bindings.cu index 64be61194d..1e096b8d31 100644 --- a/src/bindings.cu +++ b/src/bindings.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include "pufferlib.cu" #define _PUFFER_STRINGIFY(x) #x @@ -286,6 +287,22 @@ double get_config(py::dict& kwargs, const char* key) { } } +bool get_config_flag(py::dict& kwargs, const char* key) { + double value = get_config(kwargs, key); + if (value != 0.0 && value != 1.0) { + throw std::runtime_error(std::string(key) + " must be 0 or 1"); + } + return value == 1.0; +} + +double get_config_nonnegative(py::dict& kwargs, const char* key) { + double value = get_config(kwargs, key); + if (!std::isfinite(value) || value < 0.0) { + throw std::runtime_error(std::string(key) + " must be finite and nonnegative"); + } + return value; +} + Dict* py_dict_to_c_dict(py::dict py_dict) { Dict* c_dict = create_dict(py_dict.size()); for (auto item : py_dict) { @@ -409,6 +426,8 @@ std::unique_ptr create_pufferl(py::dict args) { hypers.beta1 = get_config(train_kwargs, "beta1"); hypers.beta2 = get_config(train_kwargs, "beta2"); hypers.eps = get_config(train_kwargs, "eps"); + hypers.aurora = get_config_flag(train_kwargs, "aurora"); + hypers.aurora_weight_decay = get_config_nonnegative(train_kwargs, "aurora_weight_decay"); // Training hypers.minibatch_size = get_config(train_kwargs, "minibatch_size"); hypers.replay_ratio = get_config(train_kwargs, "replay_ratio"); @@ -547,6 +566,8 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("beta1", &HypersT::beta1) .def_readwrite("beta2", &HypersT::beta2) .def_readwrite("eps", &HypersT::eps) + .def_readwrite("aurora", &HypersT::aurora) + .def_readwrite("aurora_weight_decay", &HypersT::aurora_weight_decay) .def_readwrite("total_timesteps", &HypersT::total_timesteps) .def_readwrite("max_grad_norm", &HypersT::max_grad_norm) .def_readwrite("clip_coef", &HypersT::clip_coef) diff --git a/src/muon.cu b/src/muon.cu index 1665dcf1ee..f043161387 100644 --- a/src/muon.cu +++ b/src/muon.cu @@ -55,6 +55,17 @@ __global__ void muon_nesterov(float* __restrict__ mb, precision_t* __restrict__ } } +__global__ void aurora_nesterov(float* __restrict__ mb, precision_t* __restrict__ gc, + float mu, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float g = to_float(gc[idx]); + float m = mu * mb[idx] + (1.0f - mu) * g; + mb[idx] = m; + gc[idx] = from_float((1.0f - mu) * g + mu * m); + } +} + // Fused weight update: wb = wb * (1 - lr*wd) - lr * scale * update __global__ void muon_weight_update(float* __restrict__ wb, const precision_t* __restrict__ update, const float* __restrict__ lr_ptr, float wd, float scale, int n) { @@ -75,6 +86,67 @@ __global__ void muon_clip_norm(precision_t* __restrict__ dst, } } +__global__ void aurora_init_row_scales(float* __restrict__ row_scale, + const precision_t* __restrict__ src, int rows, int cols, + bool transposed, float eps) { + int row = blockIdx.x; + if (row >= rows) return; + float sum = 0.0f; + for (int col = threadIdx.x; col < cols; col += blockDim.x) { + int idx = transposed ? col * rows + row : row * cols + col; + float v = to_float(src[idx]); + sum += v * v; + } + __shared__ float sdata[256]; + int tid = threadIdx.x; + sdata[tid] = sum; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) sdata[tid] += sdata[tid + s]; + __syncthreads(); + } + if (tid == 0) { + row_scale[row] = 1.0f / fmaxf(sqrtf(sdata[0]), eps); + } +} + +__global__ void aurora_apply_row_scales(precision_t* __restrict__ dst, + const precision_t* __restrict__ src, const float* __restrict__ row_scale, + int R, int C, bool wide) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int n = R * C; + if (idx >= n) return; + int r = idx / C; + int c = idx - r * C; + float scale = wide ? row_scale[c] : row_scale[r]; + dst[idx] = from_float(to_float(src[idx]) * scale); +} + +__global__ void aurora_update_row_scales(float* __restrict__ row_scale, + const precision_t* __restrict__ update, int rows, int cols, + bool transposed, float target_row_sq, float beta, float eps_sq) { + int row = blockIdx.x; + if (row >= rows) return; + float sum = 0.0f; + for (int col = threadIdx.x; col < cols; col += blockDim.x) { + int idx = transposed ? col * rows + row : row * cols + col; + float v = to_float(update[idx]); + sum += v * v; + } + __shared__ float sdata[256]; + int tid = threadIdx.x; + sdata[tid] = sum; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) sdata[tid] += sdata[tid + s]; + __syncthreads(); + } + if (tid == 0) { + float row_sq = fmaxf(sdata[0], eps_sq); + row_scale[row] *= powf(target_row_sq / row_sq, beta); + } +} + static constexpr double ns_coeffs[5][3] = { {4.0848, -6.8946, 2.9270}, {3.9505, -6.3029, 2.6377}, @@ -83,16 +155,27 @@ static constexpr double ns_coeffs[5][3] = { {2.8366, -3.0525, 1.2012}, }; +static constexpr int muon_polar_iters = 5; +static constexpr int aurora_polar_iters = 12; +static constexpr double aurora_simple_quintic[3] = {2.0, -1.5, 0.5}; + +static bool aurora_matrix_eligible(const AllocEntry& e) { + if (ndim(e.shape) < 2) return false; + long R = e.shape[0], C = numel(e.shape) / R; + return min(R, C) >= 2; +} + struct Muon { - double momentum, weight_decay, eps; + double momentum, weight_decay, aurora_weight_decay, eps; + bool aurora; float lr_val_init; float* lr_ptr; float* lr_derived_ptr; float* norm_ptr; float* grad_norm_ptr; FloatTensor lr_puf, lr_derived_puf, ns_norm_puf, grad_norm_puf; - FloatTensor mb_puf; - PrecisionTensor gram, gram_buf, x_buf; + FloatTensor mb_puf, aurora_row_scale; + PrecisionTensor gram, gram_buf, x_buf, orig_buf; FloatTensor norm_partials; long max_M, max_N; Allocator* param_alloc; // params allocator — shapes used by muon_step @@ -101,11 +184,13 @@ struct Muon { }; void muon_init(Muon* m, Allocator* param_alloc, double lr_val, - double momentum, double eps, double weight_decay, - Allocator* alloc) { + double momentum, double eps, double weight_decay, double aurora_weight_decay, + Allocator* alloc, bool aurora = false) { m->momentum = momentum; m->weight_decay = weight_decay; + m->aurora_weight_decay = aurora_weight_decay; m->eps = eps; + m->aurora = aurora; m->lr_val_init = (float)lr_val; m->lr_ptr = nullptr; m->lr_derived_ptr = nullptr; @@ -124,13 +209,16 @@ void muon_init(Muon* m, Allocator* param_alloc, double lr_val, alloc_register(alloc, &m->mb_puf); alloc_register(alloc, &m->norm_partials); alloc_register(alloc, &m->grad_norm_puf); - long max_M = 0, max_N = 0; + long max_M = 0, max_N = 0, max_aurora_N = 0; for (int _i = 0; _i < param_alloc->num_regs; _i++) { AllocEntry& e = param_alloc->regs[_i]; if (ndim(e.shape) >= 2) { long R = e.shape[0], C = numel(e.shape) / R; max_M = max(max_M, min(R, C)); max_N = max(max_N, max(R, C)); + if (aurora && aurora_matrix_eligible(e) && R != C) { + max_aurora_N = max(max_aurora_N, max(R, C)); + } } } if (max_M > 0) { @@ -142,6 +230,12 @@ void muon_init(Muon* m, Allocator* param_alloc, double lr_val, alloc_register(alloc, &m->gram); alloc_register(alloc, &m->gram_buf); alloc_register(alloc, &m->x_buf); + if (max_aurora_N > 0) { + m->orig_buf = {.shape = {max_M, max_N}}; + m->aurora_row_scale = {.shape = {max_aurora_N}}; + alloc_register(alloc, &m->orig_buf); + alloc_register(alloc, &m->aurora_row_scale); + } alloc_register(alloc, &m->ns_norm_puf); } } @@ -156,6 +250,63 @@ void muon_post_create(Muon* m) { cudaMemset(m->mb_puf.data, 0, numel(m->mb_puf.shape) * sizeof(float)); } +PrecisionTensor muon_polar_project(Muon* m, PrecisionTensor x, PrecisionTensor x_buf, + PrecisionTensor gram, PrecisionTensor gram_buf, bool tall, + long R, long C, long M, long N, cudaStream_t stream) { + int nblk = min((int)grid_size(numel(x.shape)), 256); + muon_norm_partials<<>>( + m->norm_partials.data, x.data, numel(x.shape)); + muon_norm_reduce<<<1, 256, 0, stream>>>(m->norm_ptr, m->norm_partials.data, nblk); + muon_norm_apply<<>>( + x.data, m->norm_ptr, 1e-7f, numel(x.shape)); + + cublasOperation_t gram_op_a = tall ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t gram_op_b = tall ? CUBLAS_OP_N : CUBLAS_OP_T; + for (int i = 0; i < muon_polar_iters; ++i) { + PrecisionTensor& src = (i % 2 == 0) ? x : x_buf; + PrecisionTensor& dst = (i % 2 == 0) ? x_buf : x; + cublasGemmExDense(gram_op_a, gram_op_b, (int)M, (int)M, (int)N, + src.data, src.data, gram.data, stream); + puf_copy(&gram_buf, &gram, stream); + puf_addmm_nn(&gram, &gram, &gram_buf, ns_coeffs[i][2], ns_coeffs[i][1], stream); + puf_copy(&dst, &src, stream); + cublasGemmExDense(CUBLAS_OP_N, CUBLAS_OP_N, (int)R, (int)C, (int)M, + tall ? src.data : gram_buf.data, tall ? gram_buf.data : src.data, dst.data, + stream, 1.0f, ns_coeffs[i][0]); + } + + return (muon_polar_iters % 2 == 0) ? x : x_buf; +} + +PrecisionTensor aurora_polar_project(Muon* m, PrecisionTensor x, PrecisionTensor x_buf, + PrecisionTensor gram, PrecisionTensor gram_buf, bool tall, + long R, long C, long M, long N, cudaStream_t stream) { + int nblk = min((int)grid_size(numel(x.shape)), 256); + muon_norm_partials<<>>( + m->norm_partials.data, x.data, numel(x.shape)); + muon_norm_reduce<<<1, 256, 0, stream>>>(m->norm_ptr, m->norm_partials.data, nblk); + muon_norm_apply<<>>( + x.data, m->norm_ptr, 1e-7f, numel(x.shape)); + + cublasOperation_t gram_op_a = tall ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t gram_op_b = tall ? CUBLAS_OP_N : CUBLAS_OP_T; + for (int i = 0; i < aurora_polar_iters; ++i) { + PrecisionTensor& src = (i % 2 == 0) ? x : x_buf; + PrecisionTensor& dst = (i % 2 == 0) ? x_buf : x; + cublasGemmExDense(gram_op_a, gram_op_b, (int)M, (int)M, (int)N, + src.data, src.data, gram.data, stream); + puf_copy(&gram_buf, &gram, stream); + puf_addmm_nn(&gram, &gram, &gram_buf, aurora_simple_quintic[2], + aurora_simple_quintic[1], stream); + puf_copy(&dst, &src, stream); + cublasGemmExDense(CUBLAS_OP_N, CUBLAS_OP_N, (int)R, (int)C, (int)M, + tall ? src.data : gram_buf.data, tall ? gram_buf.data : src.data, dst.data, + stream, 1.0f, aurora_simple_quintic[0]); + } + + return (aurora_polar_iters % 2 == 0) ? x : x_buf; +} + void muon_step(Muon* m, FloatTensor weights, PrecisionTensor grads, float max_grad_norm, cudaStream_t stream = 0) { // Multi-GPU support: simple all-reduce over a contiguous grad buffer if (m->nccl_comm != nullptr && m->world_size > 1) { @@ -171,9 +322,10 @@ void muon_step(Muon* m, FloatTensor weights, PrecisionTensor grads, float max_gr muon_clip_norm<<>>( grads.data, m->grad_norm_ptr, max_grad_norm, 1e-6f, numel(grads.shape)); - // Nesterov momentum - muon_nesterov<<mb_puf.shape)), BLOCK_SIZE, 0, stream>>>( - m->mb_puf.data, grads.data, (float)m->momentum, numel(m->mb_puf.shape)); + if (!m->aurora) { + muon_nesterov<<mb_puf.shape)), BLOCK_SIZE, 0, stream>>>( + m->mb_puf.data, grads.data, (float)m->momentum, numel(m->mb_puf.shape)); + } long offset = 0; for (int _i = 0; _i < m->param_alloc->num_regs; _i++) { @@ -182,46 +334,64 @@ void muon_step(Muon* m, FloatTensor weights, PrecisionTensor grads, float max_gr float* wb_ptr = weights.data + offset; long ne = numel(e.shape); const precision_t* update_ptr = gc_ptr; + bool use_aurora_matrix_update = m->aurora && aurora_matrix_eligible(e); + float param_weight_decay = use_aurora_matrix_update + ? (float)m->aurora_weight_decay + : (float)m->weight_decay; float scale = 1.0f; - // Orthogonalize the update + if (m->aurora) { + if (use_aurora_matrix_update) { + aurora_nesterov<<>>( + m->mb_puf.data + offset, gc_ptr, (float)m->momentum, ne); + } else { + muon_nesterov<<>>( + m->mb_puf.data + offset, gc_ptr, (float)m->momentum, ne); + } + } + if (ndim(e.shape) >= 2) { long R = e.shape[0], C = ne / R; long M = min(R, C), N = max(R, C); bool tall = R > C; + bool wide = R < C; + bool use_aurora_preconditioner = use_aurora_matrix_update && R != C; PrecisionTensor x = {.data = gc_ptr, .shape = {R, C}}; PrecisionTensor x_buf = {.data = m->x_buf.data, .shape = {R, C}}; PrecisionTensor gram = {.data = m->gram.data, .shape = {M, M}}; PrecisionTensor gram_buf = {.data = m->gram_buf.data, .shape = {M, M}}; - int nblk = min((int)grid_size(numel(x.shape)), 256); - muon_norm_partials<<>>( - m->norm_partials.data, x.data, numel(x.shape)); - muon_norm_reduce<<<1, 256, 0, stream>>>(m->norm_ptr, m->norm_partials.data, nblk); - muon_norm_apply<<>>( - x.data, m->norm_ptr, 1e-7f, numel(x.shape)); - - cublasOperation_t gram_op_a = tall ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t gram_op_b = tall ? CUBLAS_OP_N : CUBLAS_OP_T; - for (int i = 0; i < 5; ++i) { - PrecisionTensor& src = (i % 2 == 0) ? x : x_buf; - PrecisionTensor& dst = (i % 2 == 0) ? x_buf : x; - cublasGemmExDense(gram_op_a, gram_op_b, (int)M, (int)M, (int)N, - src.data, src.data, gram.data, stream); - puf_copy(&gram_buf, &gram, stream); - puf_addmm_nn(&gram, &gram, &gram_buf, ns_coeffs[i][2], ns_coeffs[i][1], stream); - puf_copy(&dst, &src, stream); - cublasGemmExDense(CUBLAS_OP_N, CUBLAS_OP_N, (int)R, (int)C, (int)M, - tall ? src.data : gram_buf.data, tall ? gram_buf.data : src.data, dst.data, - stream, 1.0f, ns_coeffs[i][0]); + if (use_aurora_preconditioner) { + PrecisionTensor orig_buf = {.data = m->orig_buf.data, .shape = {R, C}}; + puf_copy(&orig_buf, &x, stream); + aurora_init_row_scales<<<(int)N, 256, 0, stream>>>( + m->aurora_row_scale.data, orig_buf.data, (int)N, (int)M, wide, 1e-7f); + aurora_apply_row_scales<<>>( + x.data, orig_buf.data, m->aurora_row_scale.data, (int)R, (int)C, wide); + PrecisionTensor aurora_update = aurora_polar_project( + m, x, x_buf, gram, gram_buf, tall, R, C, M, N, stream); + float target_row_sq = (float)M / (float)N; + aurora_update_row_scales<<<(int)N, 256, 0, stream>>>( + m->aurora_row_scale.data, aurora_update.data, (int)N, (int)M, wide, + target_row_sq, 0.5f, 1e-14f); + aurora_apply_row_scales<<>>( + x.data, orig_buf.data, m->aurora_row_scale.data, (int)R, (int)C, wide); } - update_ptr = x_buf.data; + if (use_aurora_preconditioner) { + PrecisionTensor aurora_update = aurora_polar_project( + m, x, x_buf, gram, gram_buf, tall, R, C, M, N, stream); + update_ptr = aurora_update.data; + } else { + PrecisionTensor muon_update = muon_polar_project( + m, x, x_buf, gram, gram_buf, tall, R, C, M, N, stream); + update_ptr = muon_update.data; + } scale = sqrtf(fmaxf(1.0f, (float)R / (float)C)); } muon_weight_update<<>>( - wb_ptr, update_ptr, m->lr_ptr, (float)m->weight_decay, scale, (int)ne); + wb_ptr, update_ptr, m->lr_ptr, param_weight_decay, scale, (int)ne); offset += ne; } } diff --git a/src/pufferlib.cu b/src/pufferlib.cu index 0014999e64..537abf7bf7 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -274,6 +274,8 @@ typedef struct { float beta1; float beta2; float eps; + float aurora_weight_decay; + bool aurora; // Training int minibatch_size; float replay_ratio; @@ -2026,7 +2028,8 @@ std::unique_ptr create_pufferl_impl(HypersT& hypers, pufferl->advantages_puf = {.shape = {total_agents, horizon}}; alloc_register(acts, &pufferl->advantages_puf); - muon_init(&pufferl->muon, params, hypers.lr, hypers.beta1, hypers.eps, 0.0, acts); + muon_init(&pufferl->muon, params, hypers.lr, hypers.beta1, hypers.eps, + 0.0, hypers.aurora_weight_decay, acts, hypers.aurora); pufferl->muon.nccl_comm = pufferl->nccl_comm; pufferl->muon.world_size = hypers.world_size;