Reject CUDA BERT EmbedLayerNorm/SkipLayerNorm shapes exceeding 32-bit output indexing#29264
Open
titaiwangms wants to merge 1 commit into
Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR addresses integer overflow in CUDA BERT LayerNorm-family kernels by widening the global element write offset (row * hidden_size) from 32-bit to 64-bit, preventing wrapped output indexing for very large tensors.
Changes:
- Widen LayerNorm device helper offset/index parameters to
int64_tinlayer_norm.cuh. - Compute per-row offsets/indices in 64-bit in
skip_layer_norm_impl.cukernels. - Compute
output_offsetin 64-bit inembed_layer_norm_impl.cuand pass it through toLayerNorm.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu | Uses int64_t for per-row offset/idx to avoid overflow when indexing large output tensors. |
| onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh | Updates LayerNorm helpers to accept 64-bit offsets and use 64-bit indices for global element access. |
| onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu | Uses 64-bit output_offset for writing/normalizing large outputs in EmbedLayerNorm. |
Contributor
|
Is it needed? Typical max sequence length for BERT model is 512, and int32 offset is enough. |
… output indexing The CUDA EmbedLayerNormalization and SkipLayerNormalization kernels compute output write offsets (row_index * hidden_size) using 32-bit arithmetic. For very large output tensors the element count can exceed INT32_MAX and the offset would no longer be representable in 32 bits. Every output write index in these kernels is a pure function of the launch grid and hidden_size (no data-dependent write indexing), so the maximum index is exactly output_element_count - 1, which the host knows from the input shapes before launch. Add a host-side guard in each ComputeInternal that computes the output element count in 64-bit arithmetic and returns a clear error when it exceeds the supported 32-bit indexing range, instead of silently relying on the int32 kernels for shapes they cannot index. Kernels are unchanged (int32 baseline); no numeric behavior change for supported shapes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1379258 to
0b9d5e2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The CUDA
EmbedLayerNormalizationandSkipLayerNormalizationkernels compute output write offsets (row_index * hidden_size) using 32-bit arithmetic. For very large output tensors the element count can exceedINT32_MAX, at which point the offset is no longer representable in 32 bits.Every output write index in these kernels is a pure function of the launch grid and
hidden_size— there is no data-dependent write indexing — so the maximum index is exactlyoutput_element_count - 1, which the host knows from the input shapes before launch. This PR adds a host-side guard in each op'sComputeInternalthat computes the output element count in 64-bit arithmetic and returns a clear error when it exceeds the supported 32-bit indexing range.Design
EmbedLayerNormalization(embed_layer_norm.cc):output_element_count = (int64)batch_size * sequence_length * hidden_size, guarded withORT_RETURN_IF_NOT(... <= INT32_MAX, ...).SkipLayerNormalization(skip_layer_norm.cc):output_element_count = input->Shape().Size()(output shares the input shape), same guard.Behavior
This rejects (rather than silently attempting) single-op LayerNorm outputs larger than 2³¹ elements — a regime no real BERT-family model produces (it would require a multi-GB single-op activation). For all supported shapes there is no behavior or numeric change.
Co-authored-by: Copilot 223556219+Copilot@users.noreply.github.com