A Scalable PyTorch Training Framework: SOTA-level capabilities with 100% YAML-driven configuration.
TrainScale isn't just another training script. It's a comprehensive, modular architecture designed to solve the "last mile" problem in LLM training: Data Engineering.
Most frameworks treat data loading as an afterthought. TrainScale makes it a first-class citizen with SOTA preprocessing features usually found only in proprietary codebases (like flexible packing, token-aware distribution, and thorough dataset introspection).
- Zero Hardcoding: Every aspect of the pipeline is controlled via YAML.
- SOTA Data Pipeline: Smart truncation, content-aware token distribution, and dynamic packing.
- Rust-Inspired Reliability: Uses `Result<T, E>` patterns for robust error handling.
- Hardware Optimized: Built-in support for Flash Attention 2, Triton kernels, and 8-bit optimizers.
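The `Result` pattern referenced above can be sketched in a few lines. This is a minimal illustration of the idea, not the actual `core/types.py` implementation, and `load_split` is a hypothetical helper:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, Union

T = TypeVar("T")
E = TypeVar("E")

@dataclass
class Ok(Generic[T]):
    value: T

@dataclass
class Err(Generic[E]):
    error: E

# A Result is either an Ok carrying a value or an Err carrying a diagnostic.
Result = Union[Ok[T], Err[E]]

def load_split(available: list[str], requested: str) -> "Result[str, str]":
    # Callers must handle both branches explicitly (isinstance checks or
    # pattern matching) instead of catching exceptions after the fact.
    if requested in available:
        return Ok(requested)
    return Err(f"split '{requested}' not found; available: {available}")
```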
The TrainScale pipeline operates in distinct, modular stages to ensure scalability and reproducibility.
```mermaid
graph LR
    A[YAML Configuration] --> B[Dataset Introspector]
    B --> C[Dataset Loader]
    C --> D[Prompt Engine]
    D --> E[Length Manager]
    E --> F[Tokenizer Wrapper]
    F --> G[DataLoader Builder]
    G --> H[SOTATrainer]
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style D fill:#bbf,stroke:#333,stroke-width:2px
    style H fill:#bfb,stroke:#333,stroke-width:2px
```
- Problem: Hardcoding split names (`train`, `validation`) and columns (`text`, `input`) makes code brittle.
- Solution: Automatically inspects HuggingFace datasets to discover available splits and columns, mapping them to a standardized schema defined in your YAML.
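As a rough sketch of the idea, using plain dicts to stand in for a loaded HuggingFace `DatasetDict` (`introspect` and `schema_map` are hypothetical names, not TrainScale's real API):

```python
def introspect(dataset: dict[str, list[dict]], schema_map: dict[str, str]) -> dict:
    """Discover splits/columns and map them onto a standard schema.

    `dataset` stands in for a loaded DatasetDict: split name -> rows.
    `schema_map` comes from YAML, e.g. {"text": "content"} maps the
    raw column 'content' onto the standardized role 'text'.
    """
    splits = list(dataset.keys())
    columns = sorted({k for rows in dataset.values() for row in rows for k in row})
    resolved = {}
    for role, raw in schema_map.items():
        if raw not in columns:
            raise KeyError(f"column '{raw}' not in dataset columns {columns}")
        resolved[role] = raw
    return {"splits": splits, "columns": columns, "schema": resolved}
```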
- Problem: Simple truncation cuts off important context; "max_length" is a blunt instrument.
- Solution:
- Smart Truncation: Respects sentence and word boundaries.
- Content Distribution: Allocates token budgets intelligently (e.g., "Give 60% to context, 40% to history").
- Priority Trimming: Drops least important columns first when context window is exceeded.
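The content-distribution idea can be sketched as proportional budget allocation with slack redistribution. This is illustrative only; `allocate_budget` is a hypothetical helper, not TrainScale's actual implementation:

```python
def allocate_budget(max_total: int, ratios: dict[str, float],
                    lengths: dict[str, int]) -> dict[str, int]:
    """Split a token budget across columns by ratio, capped at each
    column's actual length. Unused budget from short columns becomes
    slack that long columns absorb, largest deficit first."""
    budgets = {c: int(max_total * r) for c, r in ratios.items()}
    slack = sum(max(0, budgets[c] - lengths[c]) for c in budgets)
    limits = {c: min(lengths[c], budgets[c]) for c in budgets}
    for col in sorted(budgets, key=lambda c: lengths[c] - budgets[c], reverse=True):
        if slack <= 0:
            break
        take = min(lengths[col] - limits[col], slack)
        limits[col] += take
        slack -= take
    return limits
```

With a 100-token budget split 60/40, a 30-token context hands its unused 30 tokens to a 100-token history column.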
- Problem: Training scripts are often monolithic and hard to extend.
- Solution: A modular trainer supporting multiple backends (FSDP, DDP, QLoRA) and advanced features like:
- Optimizers: Adam8bit, Lion, SophiaG, Prodigy.
- Schedulers: Cosine, WSD (Warmup-Stable-Decay), REX.
- Loss Functions: Fused CrossEntropy, DPO, SimPO.
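The WSD schedule listed above is straightforward to sketch: linear warmup, a flat plateau, then a linear decay over the final fraction of training. The fractions below are illustrative defaults, not TrainScale's:

```python
def wsd_lr(step: int, total_steps: int, peak_lr: float,
           warmup_frac: float = 0.1, decay_frac: float = 0.1) -> float:
    """Warmup-Stable-Decay: ramp to peak_lr, hold, then decay to zero."""
    warmup = int(total_steps * warmup_frac)
    decay_start = total_steps - int(total_steps * decay_frac)
    if step < warmup:
        return peak_lr * (step + 1) / warmup       # linear warmup
    if step < decay_start:
        return peak_lr                              # stable plateau
    decay_len = total_steps - decay_start
    return peak_lr * (total_steps - step) / decay_len  # linear decay
```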
```bash
# Clone repository
git clone https://github.com/generalaimodels/TrainScale.git
cd TrainScale

# Optional: Flash Attention 2 (Recommended for Ampere+)
pip install flash-attn --no-build-isolation
```

We provide a production-ready example in `examples/`. This script auto-detects your GPU setup (ROCm/CUDA) and launches a DDP training run.
```bash
# Single GPU
python examples/rocm_sota_demo_ddp.py --config examples/rocm_sota_config.yaml

# Multi-GPU (e.g., 4 GPUs)
torchrun --nproc_per_node=4 examples/rocm_sota_demo_ddp.py --config examples/rocm_sota_config.yaml
```

TrainScale is optimized for a wide range of hardware, from consumer GPUs to H100 clusters.
| GPU | VRAM | Mode | Max Context | Batch Size | Technique |
|---|---|---|---|---|---|
| RTX 3090 | 24GB | QLoRA | 2048 | 4 | 4-bit NF4 + Gradient Checkpointing |
| RTX 4090 | 24GB | QLoRA | 4096 | 4 | 4-bit NF4 + Flash Attn 2 |
| A100 40GB | 40GB | LoRA | 8192 | 8 | BF16 + Flash Attn 2 |
| A100 80GB | 80GB | Full | 8192 | 16 | BF16 + FSDP |
| H100 | 80GB | Full | 16384 | 32 | FP8 + Transformer Engine |
| Mac M1/M2 | Unified | MPS | 2048 | 1-2 | FP16 (Experimental) |
Let's be honest: training LLMs efficiently is hard. If you don't optimize, you are burning money and time. Here is the technical reality of high-performance training with TrainScale:
Stop using fp32. It consumes 2x memory and 2x bandwidth for zero perceptible gain in SFT.
- Mandatory: Use `bf16` (Brain Float 16) on Ampere/MI300+ hardware. It prevents the overflow/underflow issues common in `fp16` without the cost of `fp32`.
- Next-Gen: If you have H100s or MI300X, use `fp8` via Transformer Engine (supported in TrainScale).
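Why `bf16` sidesteps `fp16`'s overflow problems is easy to demonstrate: bfloat16 is just the top 16 bits of an fp32 value, so it keeps the full 8-bit exponent (same dynamic range) while dropping mantissa precision. A pure-Python round-trip, for illustration only:

```python
import struct

def to_bf16(x: float) -> float:
    """Round a float to bfloat16 by keeping the top 16 bits of its
    fp32 bit pattern (round-to-nearest-even on the dropped bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

# 1e38 survives in bf16 (max ~3.4e38); fp16 saturates at 65504.
# Precision is coarse: only ~3 decimal digits of mantissa remain.
```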
Native PyTorch layers have overhead. We stripped them out.
- Flash Attention 2: Non-negotiable for sequences > 2048. We enforce this by default.
- Triton Kernels: We implemented custom fused kernels for RMSNorm, RoPE, and CrossEntropy. If you disable these, your throughput will drop by 30-40%.
- DDP (Distributed Data Parallel): Perfect for LoRA/QLoRA on < 8 GPUs. Fast, simple, robust.
- FSDP (Fully Sharded Data Parallel): The only way to do Full Fine-Tuning on huge models (70B+). If you try DDP for full fine-tuning a 70B model on 24GB cards, you will OOM instantly.
- ZeRO-3: We support it, but it adds communication overhead. Use only if FSDP doesn't fit.
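A back-of-envelope estimate shows why full fine-tuning under DDP OOMs instantly: assuming bf16 weights and gradients plus fp32 AdamW master weights and moments (16 bytes/param total), and ignoring activations, the numbers below follow. `full_ft_mem_gb` is an illustrative helper, not part of TrainScale:

```python
def full_ft_mem_gb(n_params_b: float, n_gpus: int, sharded: bool) -> float:
    """Rough per-GPU memory (GB) for full fine-tuning with AdamW:
    bf16 params (2 B) + bf16 grads (2 B) + fp32 master weights and
    two Adam moments (12 B) = 16 bytes/param, activations excluded.
    FSDP shards all of it across GPUs; DDP replicates it per GPU."""
    total_gb = n_params_b * 16  # 1B params -> 16 GB of state
    return total_gb / n_gpus if sharded else total_gb

# 70B model on 8 GPUs:
#   DDP:  70 * 16     = 1120 GB per GPU -> impossible on any single card
#   FSDP: 70 * 16 / 8 =  140 GB per GPU -> needs multi-node or offload
```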
Most training runs are bottlenecked by CPU data processing, not GPU compute.
- TrainScale Solution: We pre-tokenize and "pack" datasets. We don't just truncate; we fill context windows (e.g., 4096) completely with multiple samples. This increases effective throughput by 2x-3x compared to naive padding.
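The packing idea can be sketched as a greedy concatenator over pre-tokenized samples. This is a simplified illustration, not TrainScale's actual packer:

```python
def pack_sequences(samples: list[list[int]], window: int,
                   eos: int = 0) -> list[list[int]]:
    """Greedily concatenate tokenized samples (separated by an EOS
    token) into fixed-size context windows, instead of padding each
    sample individually. Oversized samples are truncated to fit."""
    packed, current = [], []
    for tokens in samples:
        tokens = tokens[: window - 1]  # leave room for the EOS token
        if len(current) + len(tokens) + 1 > window:
            packed.append(current)
            current = []
        current.extend(tokens)
        current.append(eos)
    if current:
        packed.append(current)
    return packed
```

Three 5-token samples fit into two 16-token windows here, where naive padding would burn three full windows.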
- Standard AdamW: Memory hog. Avoid for models > 7B unless you have 80GB VRAM.
- AdamW-8bit: Recommended. Same convergence as 32-bit but uses 75% less memory for optimizer states.
- Lion: Great for throughput (simpler math than Adam), but requires careful hyper-parameter tuning.
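The 75% figure is simple arithmetic over AdamW's two per-parameter state tensors (`exp_avg` and `exp_avg_sq`). An illustrative helper, not bitsandbytes' API:

```python
def adam_state_bytes(n_params: int, bits_per_state: int) -> int:
    """AdamW keeps two state tensors per parameter. 32-bit states
    cost 8 bytes/param; 8-bit quantized states cost 2 bytes/param,
    which is the 75% saving cited above."""
    return n_params * 2 * (bits_per_state // 8)

n = 7_000_000_000                          # 7B-parameter model
fp32_gb = adam_state_bytes(n, 32) / 1e9    # 56 GB of optimizer state
int8_gb = adam_state_bytes(n, 8) / 1e9     # 14 GB
```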
TL;DR: Use BF16 + FlashAttn-2 + Packed Data + 8-bit Optimizer. Anything else is suboptimal.
Control how your data is processed with granular detail:
```yaml
preprocessing:
  length_manager:
    enabled: true
    max_total_length: 4096
    truncation_strategy: "smart"  # smart, sentence_boundary, word_boundary

    # Precise control over character limits per column
    per_column_limits:
      instruction: 500
      input: 2000
      output: 1500

    content_distribution:
      enabled: true
      mode: "proportional"  # or 'priority', 'ratio'
      column_ratios:
        instruction: 0.2
        input: 0.3
        output: 0.5
```

Switch between training modes and hardware optimizations instantly:
```yaml
training:
  mode: "qlora"  # full, lora, qlora
  hardware:
    precision: "bf16"
    compile_model: true  # torch.compile
  optimizer:
    type: "adamw_8bit"  # 75% memory saving over standard AdamW
    learning_rate: 2e-4
  scheduler:
    type: "wsd"  # Warmup-Stable-Decay (LLaMA-3 style)
```

We welcome contributions! Whether you're fixing a bug, adding a new feature, or improving documentation, here's how you can help:
- New Data Connectors: Support for SQL, S3, or Arrow datasets.
- Additional Kernels: Implement optimized Triton kernels for new attention mechanisms.
- Model Support: Add configs for new architectures (Mistral, Gemma, Phi).
- Benchmarks: Run hardware benchmarks and update the README table.
- Type Hints: All code must be fully type-hinted (`mypy` compliant).
- Error Handling: Use the `Result` type from `core/types.py` instead of raising raw exceptions where possible.
- Config-First: Avoid hardcoding. If a value might change, put it in the YAML schema.
- Tests: Add unit tests for new modules. Run existing tests before pushing.
- Fork the repo.
- Create a branch: `git checkout -b feature/my-cool-feature`.
- Commit your changes.
- Push to your fork and submit a Pull Request.
- **Phase 1: Foundation** (Complete) ✅
  - End-to-end YAML pipeline
  - SOTA preprocessing module
  - QLoRA/LoRA support
- **Phase 2: Scale** (In Progress) 🚧
  - Multi-node FSDP training
  - DeepSpeed integration
  - Streaming dataset support for infinite datasets
- **Phase 3: Multimodal** (Planned) 🔮
  - Image/Video tokenization support
  - Audio processing pipeline
MIT License. See LICENSE for details.
TrainScale: Train Smarter, Scale Faster 🚀