diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index b584ceed..61feb6b4 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,543 +1,83 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') +from __future__ import annotations +from typing import Optional -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. 
- offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) +import torch - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) +_TRITON: Optional[object] +try: + # Triton kernels (CUDA fast-path). + from megablocks.backend import triton_kernels as _TRITON # type: ignore +except Exception: + _TRITON = None - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a +def _fallback(): + # Pure-PyTorch implementation used on Ascend NPU (and when Triton isn't available). + from megablocks.backend import npu_kernels as _NPU # type: ignore - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) + return _NPU - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - offsets += BLOCK_X +def _should_fallback(x: torch.Tensor) -> bool: + # NPU runtime doesn't support Triton kernels. + return x.device.type == "npu" or _TRITON is None def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. 
- assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out + if _should_fallback(x): + return _fallback().padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k) + return _TRITON.padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k) # type: ignore[union-attr] def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. 
- output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out + if _should_fallback(x): + return _fallback().gather(x, indices, bin_ids, weights, bins, top_k) + return _TRITON.gather(x, indices, bin_ids, weights, bins, top_k) # type: ignore[union-attr] def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + if _should_fallback(x): + return _fallback().padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k) + return _TRITON.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k) # type: ignore[union-attr] def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. 
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. 
- out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) + if _should_fallback(x): + return _fallback().scatter(x, indices, bin_ids, weights, bins, top_k) + return _TRITON.scatter(x, indices, bin_ids, weights, bins, top_k) # type: ignore[union-attr] def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out + if _should_fallback(x): + return _fallback().padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k) + return _TRITON.padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k) # type: ignore[union-attr] def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. 
- iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X + if _should_fallback(x): + return _fallback().scatter_wgrad(x, grad, indices, bin_ids, bins, top_k) + return _TRITON.scatter_wgrad(x, grad, indices, bin_ids, bins, top_k) # type: ignore[union-attr] def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out + if _should_fallback(x): + return _fallback().binned_gather(x, indices, weights, bins, expert_capacity, top_k) + return _TRITON.binned_gather(x, indices, weights, bins, expert_capacity, top_k) # type: ignore[union-attr] def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. 
- if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) + if _should_fallback(x): + return _fallback().binned_scatter(x, indices, weights, bins, top_k) + return _TRITON.binned_scatter(x, indices, weights, bins, top_k) # type: ignore[union-attr] def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) + if _should_fallback(x): + return _fallback().binned_scatter_wgrad(x, grad, indices, bins, top_k) + return _TRITON.binned_scatter_wgrad(x, grad, indices, bins, top_k) # type: ignore[union-attr] - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/megablocks/backend/npu_kernels.py b/megablocks/backend/npu_kernels.py new file mode 100644 index 00000000..0138a2f0 --- /dev/null +++ b/megablocks/backend/npu_kernels.py @@ -0,0 +1,365 @@ +import torch + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + +# ------------------------------------------------------------------------- +# Padded Operations +# ------------------------------------------------------------------------- + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + """ + Gathers tokens from 'x' based on 'indices' and organizes them into + a padded layout defined by 'padded_bins'. 
+ """ + # Validate inputs + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + output_rows = padded_bins[-1].item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + + # We iterate over experts (bins). + # Since num_experts is usually small (8-64), this loop is negligible compared to data movement. + # 'bins' acts as a CSR pointer array. + + num_experts = len(bins) + current_idx_start = 0 + current_padded_start = 0 + + for i in range(num_experts): + # Determine the range of indices for this expert + bin_end = bins[i].item() + padded_end = padded_bins[i].item() + + count = bin_end - current_idx_start + + if count > 0: + # 1. Get the source token indices for this expert + # indices are sorted by expert in standard MoE usage + src_indices = indices[current_idx_start:bin_end] + + # 2. Gather data: x[src_indices] + gathered = x[src_indices] + + # 3. Apply weights if present + if weights is not None: + w = weights[current_idx_start:bin_end].unsqueeze(1) + gathered = gathered * w + + # 4. Place into the padded output buffer + # We copy 'count' rows into the allocated slot for this expert + out[current_padded_start : current_padded_start + count] = gathered + + current_idx_start = bin_end + current_padded_start = padded_end + + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + """ + Standard gather without padding gaps. + Equivalent to padded_gather where bins == padded_bins. + """ + # Optimization: If no padding logic is needed, we can do a bulk gather + # provided we just want the data in the order of 'indices'. 
+ + # Validate inputs + assert_is_matrix(x) + assert_is_vector(indices) + + # Bulk operation + out = x[indices] + + if weights is not None: + out = out * weights.unsqueeze(1) + + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + """ + Scatters data from the padded expert output 'x' back to original token locations. + Accumulates results if top_k > 1. + """ + # Validate inputs + assert_is_matrix(x) + assert_is_vector(indices) + + tokens = indices.shape[0] // top_k + hidden_size = x.shape[1] + + # We scatter into a flattened view first: (tokens * top_k, hidden) + # Then we will reduce. + # Note: Using index_add_ on a zero tensor handles the accumulation logic. + scatter_buffer = torch.zeros((tokens * top_k, hidden_size), dtype=x.dtype, device=x.device) + + num_experts = len(bins) + current_idx_start = 0 + current_padded_start = 0 + + # Generate a range tensor to map the slice logic to absolute positions in scatter_buffer + # This avoids creating a massive index tensor for the whole batch. + + for i in range(num_experts): + bin_end = bins[i].item() + padded_end = padded_bins[i].item() + count = bin_end - current_idx_start + + if count > 0: + # Data coming from the expert (removing padding implicitly by slicing) + expert_out = x[current_padded_start : current_padded_start + count] + + # Scale by weights (if we are scattering A_TO_B=False, weights are applied here) + if weights is not None: + w = weights[current_idx_start:bin_end].unsqueeze(1) + expert_out = expert_out * w + + # Target positions in the flattened (tokens*topk) array + # We are writing to the range [current_idx_start : bin_end] + # strictly speaking, in standard MoE, 'indices' maps specific slots. + # However, padded_scatter logic in Triton implies reversing the mapping. + # In Megablocks, 'indices' defines the token index for the sorted buffer. 
+ # So we map: scatter_buffer[range] = expert_out + + # Wait, standard scatter logic: out[indices[i]] += x[i] + # Here 'indices' contains the original token row IDs. + target_rows = indices[current_idx_start:bin_end] + + # Add to output (Atomic add equivalent) + scatter_buffer.index_add_(0, target_rows, expert_out) # Maps to (tokens, hidden) effectively? + + # Logic Correction: + # The Triton kernel does: `index_a = indices[idx]`. `out[index_a] = x[...]`. + # But the output 'out' in Triton is shape (tokens, top_k, hidden). + # The python wrapper calculates `out.sum(dim=1)` at the end. + # To match the "out" shape of (tokens, top_k, hidden), we need to know + # which "k" slot a specific token-expert pair occupies. + # However, standard PyTorch implementation simplifies this: + # We can scatter directly to (tokens, hidden) via index_add_ if we don't strictly need the intermediate (tokens, top_k, hidden). + + # Re-reading the Triton wrapper: + # It constructs `out` as (tokens, top_k, hidden). + # Then `out.sum(dim=1)` or `view`. + # Since `indices` usually points to the Token ID (0..N), it doesn't encode the "top_k slot". + # Standard Megablocks usage: indices is (tokens * top_k). + # If we want exact parity with the Intermediate Tensor shape: + + # Optimized PyTorch Logic: + # 1. We essentially want to perform: out_tensor[indices] += expert_outputs + # 2. But `indices` has repeats (a token appears top_k times). + # 3. Direct scatter_add/index_add to (tokens, hidden) is mathematically equivalent to (tokens, topk, hidden).sum(1). 
+ + final_out = torch.zeros((tokens, hidden_size), dtype=x.dtype, device=x.device) + + for i in range(num_experts): + bin_end = bins[i].item() + padded_end = padded_bins[i].item() + count = bin_end - current_idx_start + + if count > 0: + expert_out = x[current_padded_start : current_padded_start + count] + if weights is not None: + w = weights[current_idx_start:bin_end].unsqueeze(1) + expert_out = expert_out * w + + target_indices = indices[current_idx_start:bin_end] + final_out.index_add_(0, target_indices, expert_out) + + current_idx_start = bin_end + current_padded_start = padded_end + + # Return shape behavior matching the original wrapper + if top_k > 1: + return final_out + else: + return final_out.view(tokens, hidden_size) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# ------------------------------------------------------------------------- +# Gradient Operations (wgrad) +# ------------------------------------------------------------------------- + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + """ + Computes gradients for router weights. + Dot product between expert_output (x) and upstream gradient (grad). + """ + assert_is_matrix(x) + assert_is_matrix(grad) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + + num_experts = len(bins) + current_idx_start = 0 + current_padded_start = 0 + + for i in range(num_experts): + bin_end = bins[i].item() + padded_end = padded_bins[i].item() + count = bin_end - current_idx_start + + if count > 0: + # 1. Get expert outputs (padded source) + x_expert = x[current_padded_start : current_padded_start + count] + + # 2. Get corresponding gradients from tokens + # indices points to the token row. + token_indices = indices[current_idx_start:bin_end] + grad_tokens = grad[token_indices] + + # 3. 
Compute Dot Product (sum over hidden dim) + # (Batch, Hidden) * (Batch, Hidden) -> (Batch, 1) -> Squeeze + dot_prod = (x_expert * grad_tokens).sum(dim=1) + + # 4. Store in output (which is aligned with 'indices') + out[current_idx_start:bin_end] = dot_prod + + current_idx_start = bin_end + current_padded_start = padded_end + + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# ------------------------------------------------------------------------- +# Binned Operations (3D fixed layout: Experts, Capacity, Hidden) +# ------------------------------------------------------------------------- + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + """ + Gathers tokens from 'x' into a fixed 3D tensor organized by expert and capacity. + """ + assert_is_matrix(x) + assert_is_vector(indices) + + num_experts = bins.shape[0] + hidden_size = x.shape[1] + + out = torch.zeros((num_experts, expert_capacity, hidden_size), dtype=x.dtype, device=x.device) + + current_idx_start = 0 + + for i in range(num_experts): + bin_end = bins[i].item() + count = bin_end - current_idx_start + + # Clamp count to capacity (though typical usage implies count <= capacity) + valid_count = min(count, expert_capacity) + + if valid_count > 0: + src_indices = indices[current_idx_start : current_idx_start + valid_count] + + gathered = x[src_indices] + + if weights is not None: + w = weights[current_idx_start : current_idx_start + valid_count].unsqueeze(1) + gathered = gathered * w + + # Copy into the 3D tensor slice + out[i, :valid_count, :] = gathered + + current_idx_start = bin_end + + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + """ + Scatters from 3D expert buffer back to token space. 
+ """ + assert_is_tensor(x, 3) # (Experts, Capacity, Hidden) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + + # Output accumulator + final_out = torch.zeros((tokens, hidden_size), dtype=x.dtype, device=x.device) + + current_idx_start = 0 + + for i in range(num_experts): + bin_end = bins[i].item() + count = bin_end - current_idx_start + valid_count = min(count, expert_capacity) + + if valid_count > 0: + # Slice from 3D buffer + expert_out = x[i, :valid_count, :] + + if weights is not None: + w = weights[current_idx_start : current_idx_start + valid_count].unsqueeze(1) + expert_out = expert_out * w + + target_indices = indices[current_idx_start : current_idx_start + valid_count] + + # Accumulate + final_out.index_add_(0, target_indices, expert_out) + + current_idx_start = bin_end + + if top_k > 1: + return final_out + else: + return final_out.view(tokens, hidden_size) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + """ + Computes router weight gradients for the binned layout. 
+ """ + assert_is_tensor(x, 3) + assert_is_matrix(grad) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + + num_experts, expert_capacity, _ = x.shape + current_idx_start = 0 + + for i in range(num_experts): + bin_end = bins[i].item() + count = bin_end - current_idx_start + valid_count = min(count, expert_capacity) + + if valid_count > 0: + # x: (Experts, Capacity, Hidden) + x_expert = x[i, :valid_count, :] + + # grad: (Tokens, Hidden) + token_indices = indices[current_idx_start : current_idx_start + valid_count] + grad_tokens = grad[token_indices] + + # Dot product + dot_prod = (x_expert * grad_tokens).sum(dim=1) + + out[current_idx_start : current_idx_start + valid_count] = dot_prod + + current_idx_start = bin_end + + return out diff --git a/megablocks/backend/npu_ops/__init__.py b/megablocks/backend/npu_ops/__init__.py new file mode 100644 index 00000000..8e49c8ce --- /dev/null +++ b/megablocks/backend/npu_ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .histogram import histogram +from .indices import indices +from .replicate import replicate_backward, replicate_forward +from .sort import sort + diff --git a/megablocks/backend/npu_ops/cumsum.py b/megablocks/backend/npu_ops/cumsum.py new file mode 100644 index 00000000..c09817e3 --- /dev/null +++ b/megablocks/backend/npu_ops/cumsum.py @@ -0,0 +1,22 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + assert x.dim() == 2 + assert dim == 1 + return torch.cumsum(x, dim=dim, out=out) + + +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + assert x.dim() == 2 + assert dim == 1 + + if out is not None: + torch.cumsum(x, dim=dim, out=out) + out.sub_(x) + return out + return torch.cumsum(x, 
# --- megablocks/backend/npu_ops: pure-PyTorch ports of the megablocks_ops
# CUDA kernels, used on devices (e.g. Ascend NPU) without the extension. ---
import torch


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """Pure PyTorch histogram matching `megablocks_ops.histogram`.

    Supports both 1D and 2D inputs (2D is binned row-wise, one histogram
    per row). Returns int32 counts. Assumes values lie in [0, num_bins)
    like the CUDA kernel — TODO confirm callers never pass negatives.
    """
    if x.ndim not in (1, 2):
        raise ValueError(f"Expected 1D or 2D tensor but got {x.ndim}D.")

    original_ndim = x.ndim
    if original_ndim == 1:
        x = x.view(1, -1)

    batch_size = x.shape[0]
    if x.numel() == 0:
        out = torch.zeros(batch_size, num_bins, device=x.device, dtype=torch.int32)
        return out.flatten() if original_ndim == 1 else out

    # Batched bincount by shifting each row into a disjoint bin range.
    offsets = torch.arange(batch_size, device=x.device, dtype=torch.int64) * int(num_bins)
    x_flat_shifted = (x.to(torch.int64) + offsets.unsqueeze(1)).reshape(-1)
    counts = torch.bincount(x_flat_shifted, minlength=batch_size * int(num_bins))
    out = counts.view(batch_size, int(num_bins)).to(torch.int32)
    return out.flatten() if original_ndim == 1 else out


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
    out: torch.Tensor,
) -> None:
    """Pure PyTorch indices matching `megablocks_ops.indices` out-params.

    FIX: callers allocate `out` with `torch.empty`, and the previous version
    left any tail of `out` beyond the computed entries uninitialized.
    Zero-fill `out` first, matching the zero-initialized `_indices_fallback`
    in megablocks/ops/topology.py.
    """
    # Make unwritten entries deterministic (was: left as torch.empty garbage).
    out.zero_()

    zeros = torch.zeros(1, device=padded_bins.device, dtype=padded_bins.dtype)
    starts_tokens = torch.cat([zeros, padded_bins[:-1]])
    ends_tokens = padded_bins

    # Convert token bounds to block-row bounds; padded bin sizes are
    # multiples of block_size, so consecutive bins tile contiguous rows.
    starts_blocks = starts_tokens.div(block_size, rounding_mode="floor")
    ends_blocks = ends_tokens.div(block_size, rounding_mode="floor")
    rows_per_bin = ends_blocks - starts_blocks

    num_bins = padded_bins.numel()
    bin_ids = torch.arange(num_bins, device=padded_bins.device)
    bin_base_values = bin_ids * output_block_columns
    row_base_vals = torch.repeat_interleave(bin_base_values, rows_per_bin)

    if row_base_vals.numel() == 0:
        return

    col_offsets = torch.arange(output_block_columns, device=padded_bins.device).unsqueeze(0)
    result_matrix = row_base_vals.unsqueeze(1) + col_offsets
    result_flat = result_matrix.flatten().to(dtype=torch.int16)

    num_elements = min(out.numel(), result_flat.numel())
    out[:num_elements] = result_flat[:num_elements]


def replicate_forward(x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor) -> None:
    """Pure PyTorch replicate_forward matching `megablocks_ops.replicate_forward`.

    `bins` holds inclusive cumulative bounds; column j of `x` is repeated
    (bins[j] - bins[j-1]) times along dim 1 into `out`.
    """
    zero = torch.tensor([0], device=bins.device, dtype=bins.dtype)
    counts = torch.diff(bins, prepend=zero).to(torch.long)
    res = torch.repeat_interleave(x, counts, dim=1)
    out.copy_(res)


def replicate_backward(grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor) -> None:
    """Pure PyTorch replicate_backward: sum `grad` columns back into their bins."""
    zero = torch.tensor([0], device=bins.device, dtype=bins.dtype)
    counts = torch.diff(bins, prepend=zero).to(torch.long)

    num_bins = bins.size(0)
    bin_indices = torch.repeat_interleave(
        torch.arange(num_bins, device=grad.device, dtype=torch.long),
        counts,
    )
    batch_size = grad.size(0)
    expanded_indices = bin_indices.unsqueeze(0).expand(batch_size, -1)

    out.zero_()
    out.scatter_add_(dim=1, index=expanded_indices, src=grad)


def sort(
    x: torch.Tensor,
    end_bit: int,
    x_out: torch.Tensor,
    iota_out: torch.Tensor,
) -> None:
    """Pure PyTorch sort matching `megablocks_ops.sort` out-params.

    `end_bit` (a radix-sort tuning knob in the CUDA kernel) is ignored; the
    stable sort preserves relative order of equal keys like the radix sort.
    """
    del end_bit
    if x.ndim != 1:
        raise ValueError("Expected a 1D tensor.")

    sorted_values, sorted_indices = torch.sort(x, stable=True)
    x_out.copy_(sorted_values)
    iota_out.copy_(sorted_indices.to(iota_out.dtype))


# --- megablocks/backend/triton_kernels.py: shape-check helpers shared by
# the Triton kernels below. Raise (not assert) so checks survive -O. ---
def assert_is_tensor(x, ndim):
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
    assert_is_tensor(x, 2)


def assert_is_vector(x):
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
# a: (tokens, hidden_size), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy(
    a,
    b,
    indices,
    bin_ids,
    weights,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # One program per (token, top_k) entry; copies one row between the
    # token-ordered array 'a' and the expert-ordered (possibly padded)
    # array 'b', optionally scaling by a per-entry weight.

    # Our index into array 'a'.
    index_a = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'b'.
    index_b = offset_in_bin
    if bin_idx > 0:
        index_b += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        # Accumulate in fp32 for precision, cast back on store.
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    """Gather rows of `x` into expert-major order with per-expert padding.

    Returns a zero-initialized (padded_bins[-1], hidden_size) tensor with
    each (token, top_k) row copied into its expert bin; rows in the padding
    region stay zero. Optionally scales rows by `weights`.
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: Because of the padding, the output size is dynamic.
    # We load the final padded bin bound to get the output rows.
    # (Forces a device->host sync.)
    output_rows = padded_bins[-1].cpu().item()
    out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def gather(x, indices, bin_ids, weights, bins, top_k):
    """Gather rows of `x` into expert-major order with no padding.

    Implemented as `_padded_copy` with `bins` passed for both the bin and
    padded-bin bounds, so output rows == tokens * top_k.
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: There is no padding so the output rows equals the
    # input rows multiplied by top_k.
    output_rows = x.shape[0] * top_k
    out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    """Scatter expert-ordered rows of `x` back to token order.

    Inverse of `padded_gather`; the top_k copies of each token are summed
    (in a (tokens, top_k, hidden) staging buffer, avoiding atomics).
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        out,
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
    """`padded_scatter` with no inter-expert padding (bins used as both bounds)."""
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
# x: (tokens, top_k, hidden_size), real
# grad: (tokens, hidden_size), real.
# wgrad: (tokens, top_k), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy_wgrad(
    x,
    grad,
    wgrad,
    indices,
    bin_ids,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # One program per (token, top_k) entry; computes the gradient w.r.t.
    # the routing weight as a dot product of the expert output row and the
    # incoming gradient row.

    # Our index into 'tokens * top_k'.
    index_out = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'x'.
    index_x = offset_in_bin
    if bin_idx > 0:
        index_x += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Strided fp32 dot product over the hidden dimension.
    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
    """Gradient of `padded_scatter` w.r.t. the routing weights.

    Returns a flat (tokens * top_k,) tensor of per-entry dot products.
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
    _padded_copy_wgrad[(indices.shape[0],)](
        x,
        grad,
        out,
        indices,
        bin_ids,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        TOP_K=top_k,
    )
    return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
    """`padded_scatter_wgrad` with no inter-expert padding."""
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # 2D grid: (expert, capacity slot). Copies one row between the
    # token-ordered array 'a' and the capacity-factored array 'b'.

    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_b = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    #
    # NOTE: We need to zero the output in both directions.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    """Gather tokens into a fixed-capacity (experts, capacity, hidden) layout.

    Slots beyond each expert's token count stay zero (dropped-token model).
    """
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def binned_scatter(x, indices, weights, bins, top_k):
    """Scatter a fixed-capacity (experts, capacity, hidden) layout back to tokens.

    Inverse of `binned_gather`; top_k copies are summed via a staging buffer.
    """
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
    _binned_copy[(num_experts, expert_capacity)](
        out,
        x,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=hidden_size,
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy_wgrad(
    x,
    grad,
    wgrad,
    num_experts,
    expert_capacity,
    indices,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # 2D grid: (expert, capacity slot). Routing-weight gradient for the
    # fixed-capacity layout: dot product of expert output and gradient rows.

    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_x = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_out = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
    """Gradient of `binned_scatter` w.r.t. the routing weights.

    Returns a flat (tokens * top_k,) tensor; entries for dropped tokens
    (slots past each expert's count) keep their zero initialization.
    """
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
    _binned_copy_wgrad[(num_experts, expert_capacity)](
        x,
        grad,
        out,
        num_experts,
        expert_capacity,
        indices,
        bins,
        NUM_COLUMNS=hidden_size,
        TOP_K=top_k,
    )
    return out
# --- megablocks/ops/cumsum.py ---
# The CUDA extension is optional: absent (or on NPU) we fall back to PyTorch.
try:
    import megablocks_ops as _ops  # type: ignore
except ModuleNotFoundError:
    _ops = None


def _inclusive_cumsum_fallback(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Pure PyTorch inclusive cumsum, used on Ascend NPU.
    return torch.cumsum(x, dim=dim)


def _exclusive_cumsum_fallback(x: torch.Tensor, dim: int) -> torch.Tensor:
    # exclusive[i] = inclusive[i] - x[i]: shifts the running sum by one slot.
    return torch.cumsum(x, dim=dim) - x


# Autograd wrappers for cumsum kernels.
class ExclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int):
        # 1-D inputs are lifted to (1, n), summed along dim 1, and squeezed
        # back, mirroring the CUDA kernel's calling convention.
        if len(x.size()) == 1:
            x = x.view([1, -1])
            if x.device.type == 'npu':
                return _exclusive_cumsum_fallback(x, 1).squeeze()
            if _ops is None:
                raise ModuleNotFoundError("No module named 'megablocks_ops'.")
            out = torch.empty_like(x)
            _ops.exclusive_cumsum(x, 1, out)
            return out.squeeze()
        if x.device.type == 'npu':
            return _exclusive_cumsum_fallback(x, dim)
        if _ops is None:
            raise ModuleNotFoundError("No module named 'megablocks_ops'.")
        out = torch.empty_like(x)
        _ops.exclusive_cumsum(x, dim, out)
        return out


class InclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
        # Same 1-D lifting convention as ExclusiveCumsumOp above.
        if len(x.size()) == 1:
            x = x.view([1, -1])
            if x.device.type == 'npu':
                return _inclusive_cumsum_fallback(x, 1).squeeze()
            if _ops is None:
                raise ModuleNotFoundError("No module named 'megablocks_ops'.")
            out = torch.empty_like(x)
            _ops.inclusive_cumsum(x, 1, out)
            return out.squeeze()
        if x.device.type == 'npu':
            return _inclusive_cumsum_fallback(x, dim)
        if _ops is None:
            raise ModuleNotFoundError("No module named 'megablocks_ops'.")
        out = torch.empty_like(x)
        _ops.inclusive_cumsum(x, dim, out)
        return out


# --- megablocks/ops/histogram.py ---
try:
    import megablocks_ops as _ops  # type: ignore
except ModuleNotFoundError:
    _ops = None


def _histogram_fallback(x: torch.Tensor, max_val: float) -> torch.Tensor:
    # `max_val` is used as number of bins throughout Megablocks.
    num_bins = int(max_val)

    # The C++ kernel supports both 1D and 2D inputs.
    if x.ndim == 1:
        return torch.bincount(x.to(torch.int64), minlength=num_bins).to(torch.int32)

    if x.ndim != 2:
        raise ValueError(f'Expected 1D or 2D tensor but got {x.ndim}D.')

    # Batched bincount: shift each row into its own disjoint bin range so a
    # single flat bincount produces one histogram per row.
    batch_size = x.shape[0]
    offsets = torch.arange(batch_size, device=x.device, dtype=torch.int64) * num_bins
    x_shifted = x.to(torch.int64) + offsets.unsqueeze(1)
    counts = torch.bincount(x_shifted.flatten(), minlength=batch_size * num_bins)
    return counts.view(batch_size, num_bins).to(torch.int32)


# Autograd wrapper for histogram kernel.
class HistogramOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, max_val: float):
        # Ascend NPU does not support the CUDA extension; fall back to PyTorch.
        if x.device.type == 'npu':
            return _histogram_fallback(x, max_val)
        if _ops is None:
            raise ModuleNotFoundError("No module named 'megablocks_ops'.")
        return _ops.histogram(x, max_val)


histogram = HistogramOp.apply
Otherwise libc10.so cannot be found. import torch -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. try: - import megablocks_ops as ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + import megablocks_ops as _ops # type: ignore +except ModuleNotFoundError: + _ops = None + + +def _bin_counts_from_inclusive_bins(bins: torch.Tensor) -> torch.Tensor: + zeros = bins.new_zeros((1,)) + counts = torch.diff(bins.to(torch.long), prepend=zeros) + return counts # Autograd wrapper for replicate kernel. @@ -21,15 +25,39 @@ class ReplicateOp(torch.autograd.Function): @staticmethod def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): ctx.save_for_backward(bins) + if x.device.type == 'npu': + counts = _bin_counts_from_inclusive_bins(bins) + out = torch.repeat_interleave(x, counts, dim=1) + # Defensive: match the requested output width. + if out.shape[1] < num_outputs: + out = torch.nn.functional.pad(out, (0, num_outputs - out.shape[1])) + elif out.shape[1] > num_outputs: + out = out[:, :num_outputs] + return out + if _ops is None: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) + _ops.replicate_forward(x, bins, out) return out @staticmethod def backward(ctx: Any, grad: torch.Tensor): bins, = ctx.saved_tensors + if grad.device.type == 'npu': + counts = _bin_counts_from_inclusive_bins(bins) + num_bins = bins.shape[0] + bin_ids = torch.repeat_interleave( + torch.arange(num_bins, device=grad.device, dtype=torch.long), + counts, + ) + expanded = bin_ids.unsqueeze(0).expand(grad.shape[0], -1) + out = torch.zeros((grad.shape[0], num_bins), dtype=grad.dtype, device=grad.device) + out.scatter_add_(dim=1, index=expanded, src=grad) + return out, None, None + if _ops is None: + raise ModuleNotFoundError("No 
module named 'megablocks_ops'.") out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) + _ops.replicate_backward(grad, bins, out) return out, None, None diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 4fb0aab4..db350346 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -7,12 +7,10 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. try: - import megablocks_ops as ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + import megablocks_ops as _ops # type: ignore +except ModuleNotFoundError: + _ops = None _BITS_FOR_DTYPE = { torch.int16: 16, @@ -29,9 +27,15 @@ class SortOp(torch.autograd.Function): def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] + if x.device.type == 'npu': + # `end_bit` is ignored by the PyTorch fallback. + x_out, idx = torch.sort(x) + return (x_out, idx.to(dtype=x.dtype)) + if _ops is None: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") x_out = torch.empty_like(x) iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) + _ops.sort(x, end_bit, x_out, iota_out) return (x_out, iota_out) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index b41b5fa5..f514346d 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -7,12 +7,54 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. 
try: - import megablocks_ops as ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + import megablocks_ops as _ops # type: ignore +except ModuleNotFoundError: + _ops = None + + +def _indices_fallback( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + # Mirror `csrc/indices.h` on any device type. + num_bins = padded_bins.numel() + if num_bins == 0 or output_block_rows * output_block_columns == 0: + return torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + + starts_tokens = torch.zeros_like(padded_bins) + starts_tokens[1:] = padded_bins[:-1] + starts_blocks = torch.div(starts_tokens, block_size, rounding_mode='floor') + ends_blocks = torch.div(padded_bins, block_size, rounding_mode='floor') + + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + out.zero_() + col_offsets = torch.arange( + output_block_columns, + device=padded_bins.device, + dtype=torch.int16, + ) + + # Fill each block-row belonging to each bin. + for bin_id in range(num_bins): + start = int(starts_blocks[bin_id].item()) + end = int(ends_blocks[bin_id].item()) + value = torch.as_tensor(bin_id * output_block_columns, device=out.device, dtype=torch.int16) + col_offsets + for row in range(start, min(end, output_block_rows)): + base = row * output_block_columns + out[base: base + output_block_columns] = value + + return out # Autograd wrapper for topology kernel. 
# --- megablocks/ops/topology.py (TopologyOp.forward) ---
# NOTE(review): the leading parameters (ctx, padded_bins, block_size) and the
# tail of the `_ops.indices(...)` call sit in diff context outside this hunk;
# reconstructed here — confirm against the full file.
    @staticmethod
    def forward(
        ctx: Any,
        padded_bins: torch.Tensor,
        block_size: int,
        output_block_rows: int,
        output_block_columns: int,
    ):
        # Ascend NPU cannot run the CUDA extension; use the PyTorch mirror.
        if padded_bins.device.type == 'npu':
            return _indices_fallback(
                padded_bins,
                block_size,
                output_block_rows,
                output_block_columns,
            )
        if _ops is None:
            raise ModuleNotFoundError("No module named 'megablocks_ops'.")
        out = torch.empty(
            output_block_rows * output_block_columns,
            dtype=torch.int16,
            device=padded_bins.device,
        )
        _ops.indices(
            padded_bins,
            block_size,
            output_block_rows,
            output_block_columns,
            out,
        )
        return out


# --- tests/backend_npu_kernels_test.py ---
# CPU tests for the pure-PyTorch NPU fallback kernels.
import torch

from megablocks.backend import npu_kernels


def _reference_gather(x: torch.Tensor, top_k: int) -> torch.Tensor:
    # indices = arange(tokens * top_k), single bin: each source row is
    # repeated top_k times in order.
    tokens = x.shape[0]
    idx = torch.arange(tokens * top_k, device=x.device)
    return x[idx // top_k]


def test_npu_kernels_gather_simple_repeats_tokens():
    # Single-bin gather should simply repeat every token top_k times.
    tokens, hidden, top_k = 8, 16, 2
    x = torch.randn(tokens, hidden)
    indices = torch.arange(tokens * top_k, dtype=torch.int64)
    bin_ids = torch.zeros(tokens * top_k, dtype=torch.int64)
    bins = torch.tensor([tokens * top_k], dtype=torch.int64)
    out = npu_kernels.gather(x, indices, bin_ids, None, bins, top_k)
    assert torch.allclose(out, _reference_gather(x, top_k))


def test_npu_kernels_scatter_inverts_gather_for_unit_weights():
    # With no weights, scatter sums the top_k copies of each token.
    tokens, hidden, top_k = 8, 16, 2
    x = torch.randn(tokens, hidden)
    gathered = _reference_gather(x, top_k)

    indices = torch.arange(tokens * top_k, dtype=torch.int64)
    bin_ids = torch.zeros(tokens * top_k, dtype=torch.int64)
    bins = torch.tensor([tokens * top_k], dtype=torch.int64)
    scattered = npu_kernels.scatter(gathered, indices, bin_ids, None, bins, top_k)

    # Each token appears `top_k` times.
    assert torch.allclose(scattered, x * top_k)


def test_npu_kernels_padded_gather_respects_padding():
    # Padding rows past each expert's real tokens must stay zero.
    tokens, hidden, top_k = 6, 8, 1
    x = torch.randn(tokens, hidden)

    # All entries belong to expert 0; expert 1 is empty.
    indices = torch.arange(tokens * top_k, dtype=torch.int64)
    bin_ids = torch.zeros(tokens * top_k, dtype=torch.int64)
    bins = torch.tensor([tokens, tokens], dtype=torch.int64)  # inclusive
    padded_bins = torch.tensor([tokens + 4, tokens + 4], dtype=torch.int64)  # pad expert 0 by 4 rows

    out = npu_kernels.padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
    assert out.shape == (tokens + 4, hidden)
    assert torch.allclose(out[:tokens], x)
    assert torch.allclose(out[tokens:], torch.zeros_like(out[tokens:]))