System Info
- Hardware: Jetson Orin Nano Super 8 GB (sm_87, Ampere, 1024 CUDA cores)
- OS: JetPack 6.2 (L4T R36.4) / Ubuntu 22.04
- CUDA: 12.6
- Python: 3.10.12
- PyTorch: 2.5.1 (source-rebuilt for Jetson with USE_DISTRIBUTED=1 USE_GLOO=1; reports as 2.5.0a0+gita8d6afb)
- bitsandbytes: 0.46.1, source-built at SHA 4bca844 with `cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=87 -S .`
- Power mode: MAXN_SUPER (`nvpmodel -m 2`)
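For reference, the versions and compute capability above can be sanity-checked on-device with a short snippet (a minimal sketch; nothing here is specific to the bug):

```python
import torch
import bitsandbytes as bnb

# Quick on-device confirmation of the environment listed above.
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("compute capability:", torch.cuda.get_device_capability(0))  # expect (8, 7) on Orin
print("bitsandbytes:", bnb.__version__)
```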
Reproduction
Summary
On Jetson Orin Nano Super (sm_87), running a sequence of three bnb.nn.Linear4bit construct + forward cycles at large output dimensions in a single Python process reboots the entire system — kernel-level reboot, not Python crash, not Linux OOM-killer, not cudaMalloc ENOMEM. It fires reliably; I've reproduced it twice on separate fresh boots, and the same shapes pass cleanly when each is isolated in its own process.
The fault appears to be accumulated state across multiple bnb-NF4 cycles in the same process — most likely allocator fragmentation, bnb scratch-buffer accumulation, or driver bookkeeping at sm_87.
Reproducer
Standalone gist: https://gist.github.com/neil-the-nowledgeable/a5ade63cbd2143d0628c96c8f840c87e
```python
# WARNING: reboots the Jetson on every attempt I've tried.
# Run via ssh from a separate machine; close any unsaved work first.
import torch
import bitsandbytes as bnb

shapes = [
    (4096, 32768),   # Linear A — large intermediate-class
    (4096, 128256),  # Linear B — Llama-3.1-8B-class lm_head
    (3584, 152064),  # Linear C — Qwen2.5-7B-class lm_head
]

for in_f, out_f in shapes:
    print(f"Constructing Linear4bit({in_f}, {out_f})...", flush=True)
    layer = bnb.nn.Linear4bit(
        in_f, out_f, bias=False,
        quant_type="nf4",
        compute_dtype=torch.bfloat16,
        quant_storage=torch.bfloat16,
    ).to("cuda")
    x = torch.randn(1, in_f, dtype=torch.bfloat16, device="cuda")
    with torch.no_grad():
        y = layer(x)
    torch.cuda.synchronize()
    print(f"  forward OK; peak_mb={torch.cuda.max_memory_allocated()/1e6:.1f}",
          flush=True)

print("Sequence complete", flush=True)
```
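As one probe of the accumulated-state hypothesis from the Summary, a variant of the same loop that explicitly drops each layer and returns cached blocks to the driver between shapes would look roughly like this (a sketch; untested whether it changes the outcome):

```python
# Sketch: same shapes as above, but free each layer and clear the CUDA caching
# allocator between iterations, to probe allocator-fragmentation / scratch-buffer
# accumulation as the trigger. Untested whether this avoids the reboot.
import gc
import torch
import bitsandbytes as bnb

shapes = [(4096, 32768), (4096, 128256), (3584, 152064)]

for in_f, out_f in shapes:
    layer = bnb.nn.Linear4bit(
        in_f, out_f, bias=False,
        quant_type="nf4",
        compute_dtype=torch.bfloat16,
        quant_storage=torch.bfloat16,
    ).to("cuda")
    x = torch.randn(1, in_f, dtype=torch.bfloat16, device="cuda")
    with torch.no_grad():
        y = layer(x)
    torch.cuda.synchronize()
    del layer, x, y
    gc.collect()
    torch.cuda.empty_cache()               # return cached blocks to the driver
    torch.cuda.reset_peak_memory_stats()   # start each shape with clean peak stats
```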
Observed
- Shapes A and B complete forward cleanly (peak memory printed, sync returns).
- During or shortly after Linear C's forward, the Jetson reboots. dmesg / journalctl from the prior boot show no kernel oops, no OOM-killer event, and no NVIDIA driver fault; the system simply restarts. ssh sessions die; the Jetson is back up ~60 s later.
- Reproduces every time so far (twice on fresh boots, identical shape sequence).
Bracket — what works vs what doesn't
| Test | Result |
| --- | --- |
| Linear C alone in a fresh process: `Linear4bit(3584, 152064)` construct + forward | PASS (fwd ~30 ms, peak ~2530 MB) |
| Linears A, B, C in separate processes, sequentially | PASS each |
| Linears A → B → C in the same Python process (script above) | REBOOT |
The single-shape PASS for (3584, 152064) is what made me revise an earlier framing — I'd initially thought it was a single-shape geometry fault. It isn't; the bug requires multi-shape sequencing.
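For completeness, the separate-process PASS rows can be reproduced with a driver like the one below (a sketch; `one_shape.py` is a hypothetical split of the repro that constructs and forwards a single Linear4bit from its command-line arguments):

```python
# Sketch: run each shape in its own Python process, mirroring the
# "separate processes, sequentially" PASS row in the table above.
# one_shape.py is hypothetical: it should build one Linear4bit(in_f, out_f),
# run one forward, and exit, as in the repro script.
import subprocess
import sys

shapes = [(4096, 32768), (4096, 128256), (3584, 152064)]

for in_f, out_f in shapes:
    subprocess.run(
        [sys.executable, "one_shape.py", str(in_f), str(out_f)],
        check=True,  # propagate a non-zero exit as an exception
    )
```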
Workload context
This is the limiting factor for running 7B-class QLoRA training on the Jetson Orin Nano Super via the bnb path. The multi-shape sequence reboot above is one of three* distinct bnb-NF4 sm_87 fault regimes I've observed at the broader 7B-class workload boundary. The 1B–4B class is solid on this build; Mistral-7B-v0.3 specifically (32K vocab, intermediate=14336, 32 layers) also avoids all three regimes in the single-Jetson, no-FSDP configuration and trains cleanly to 15,610 stable steps at 5044 ± 1 MB peak.
* The other two regimes, both reproducible and separate bugs, are out of scope for this issue but flagged for awareness:
- Stacked-context fault: Llama-3.1-8B real-model forward (no FSDP at all) hangs Python at the bnb forward call. Each individual Linear in the model passes synthetically; the fault requires the full 32-decoder-layer + lm_head graph context to fire.
- FSDP2-multilayer-aggravated: per-decoder `fully_shard()` wrap of any 7B-class model with `intermediate_size ≥ 14336` hangs forward.
Cross-reference
This is the kernel-fault caveat referenced in #1218 (prebuilt aarch64+CUDA wheel request).
Offer
I'm happy to run instrumented variants of the repro on my Jetson — per-shape cudaDeviceSynchronize, allocator stats between shapes, compute-sanitizer if it fits in memory, etc. — and share JSON outputs. Let me know what would be most helpful.
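As a concrete example of the instrumentation offered above, a per-shape snapshot could look like this (a sketch; the field names are standard `torch.cuda.memory_stats()` keys):

```python
import json
import torch

# Sketch of the per-shape instrumentation offered above: synchronize, then dump
# a small JSON record of allocator state for comparison across shapes.
def snapshot(tag: str) -> None:
    torch.cuda.synchronize()
    stats = torch.cuda.memory_stats()
    record = {
        "tag": tag,
        "allocated_current": stats.get("allocated_bytes.all.current"),
        "allocated_peak": stats.get("allocated_bytes.all.peak"),
        "reserved_current": stats.get("reserved_bytes.all.current"),
        "num_alloc_retries": stats.get("num_alloc_retries"),
    }
    print(json.dumps(record), flush=True)

# Usage: call snapshot(f"after_{in_f}x{out_f}") after each forward in the repro loop.
```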
Expected behavior
The shape sequence should either:
- Complete cleanly and print "Sequence complete" — same outcome as running each shape in its own process.
- Raise a Python-level error (e.g., CUDA OOM) that surfaces through normal exception handling (see the sketch below).
Rebooting the Jetson is unexpected and not recoverable in-process. There's no Python traceback, no kernel oops, no driver fault log — the system just restarts, which makes it both hard to diagnose and dangerous for any long-running training process that hits the sequence.
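For concreteness, a minimal sketch of what the second outcome (a catchable Python-level error) would look like:

```python
import torch
import bitsandbytes as bnb

# Sketch: the desired failure mode is a CUDA OOM surfacing as a normal Python
# exception that a training loop can catch, rather than a device reboot.
try:
    layer = bnb.nn.Linear4bit(3584, 152064, bias=False, quant_type="nf4",
                              compute_dtype=torch.bfloat16).to("cuda")
    x = torch.randn(1, 3584, dtype=torch.bfloat16, device="cuda")
    with torch.no_grad():
        _ = layer(x)
    torch.cuda.synchronize()
except torch.cuda.OutOfMemoryError as e:
    print(f"recoverable CUDA OOM: {e}", flush=True)
```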