
TileGym Transformers Inference

End-to-end inference examples for transformer language models accelerated with TileGym kernels, optimized for the NVIDIA Blackwell architecture.

Supported Models

| Model | Model ID | Features |
|---|---|---|
| LLaMA-3.1-8B | meta-llama/Meta-Llama-3.1-8B | RoPE, SwiGLU, RMSNorm, Attention*, Flash Decoding* |
| DeepSeek-V2-Lite-Chat | deepseek-ai/DeepSeek-V2-Lite-Chat | RoPE, SwiGLU, RMSNorm, MoE, MLA Decoding*, Attention* |
| Qwen2-7B | Qwen/Qwen2-7B | RoPE, SwiGLU, RMSNorm, Attention* |
| Gemma-3-4B-IT | google/gemma-3-4b-it | RoPE, GEGLU, RMSNorm, Attention* |
| GPT-OSS | openai/gpt-oss-20b | RoPE, RMSNorm, Attention Sink* |
| Mistral-7B-Instruct-v0.3 | mistralai/Mistral-7B-Instruct-v0.3 | RoPE, SwiGLU, RMSNorm, Attention* |
| Phi-3-mini-4k-instruct | microsoft/Phi-3-mini-4k-instruct | RoPE, SwiGLU, RMSNorm, Attention* |

*Optional: enable with --use_attn to use the attention kernels provided by TileGym.

B200 GPUs can run all of the models listed above. Due to memory constraints, RTX 5090 GPUs only support LLaMA-3.1-8B; DeepSeek-V2-Lite-Chat in particular requires more memory than an RTX 5090 provides.
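
If you are unsure whether a GPU has enough free memory for a given model, a quick check before launching (this simply queries the driver via nvidia-smi and is not part of infer.py):

# Check the GPU name plus total and free memory before picking a model
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv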

Docker Support

# Option 1: Use the build script
cd modeling/transformers
./build_docker.sh

# Option 2: Build manually (must run from tilegym repository root)
cd /path/to/tilegym
docker build -t tilegym-transformers -f modeling/transformers/Dockerfile .

# Enter interactive mode
docker run --gpus all -it tilegym-transformers bash

# Or run inference directly
docker run --gpus all -it tilegym-transformers \
    python infer.py --model_id deepseek-ai/DeepSeek-V2-Lite-Chat --use_tilegym --use_cutile --use_attn --show_outputs
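
To avoid re-downloading model weights on every container start, the host's Hugging Face cache can be mounted into the container. A minimal sketch, assuming the image uses root's default cache location (/root/.cache/huggingface):

# Reuse the host's Hugging Face cache inside the container
docker run --gpus all \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    -it tilegym-transformers bash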

Quick Start

Basic Inference

# Transformer baseline
python infer.py --model_id meta-llama/Meta-Llama-3.1-8B --show_outputs

# With CUTILE backend
python infer.py --model_id meta-llama/Meta-Llama-3.1-8B --use_tilegym --use_cutile --use_attn --show_outputs

Using Custom Inputs

# From file
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_attn \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

# From command line
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --input_text "Explain machine learning" \
    --show_outputs

Performance Profiling

Profiling results are collected with the PyTorch profiler.

python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --num_runs 5

Kernel Coverage Report

Reports the fraction of GPU time and kernel launches covered by TileGym cuTile kernels. The model is run under Nsight Systems (nsys profile) and the resulting trace is analyzed automatically.

python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --report_kernel_coverage \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

Example output:

===== NSYS KERNEL GPU TIME ANALYSIS =====

Kernel Name                                                   # Calls   GPU Time (ms)   % of Total
------------------------------------------------------------  --------  -------------   ----------
fmha_kernel                                                       ...          54.507        10.5%
rms_norm_kernel_gather                                            ...           9.788         1.9%
...
------------------------------------------------------------  --------  -------------   ----------
TileGym Total                                                    9676          95.147        18.3%
All Kernels Total                                              104858         520.725       100.0%

>>> cuTile Kernel Coverage (GPU Time):    18.3% <<<
>>> cuTile Kernel Coverage (# Launches):  9.2% <<<
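
The underlying trace can also be inspected directly with the Nsight Systems CLI. A minimal sketch, assuming the trace was saved as report.nsys-rep (the file name actually produced by the script may differ):

# Summarize per-kernel GPU time from the nsys trace
nsys stats --report cuda_gpu_kern_sum report.nsys-rep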

Performance Benchmark

Benchmark TileGym's CUTILE-optimized kernels against the standard PyTorch implementation. The --profile flag enables detailed performance metrics, including throughput (tokens/sec) and generation latency.

Quick Start

Run benchmark scripts for automated comparison:

# LLaMA-3.1-8B benchmark
./bench_llama.sh

# DeepSeek-V2-Lite benchmark
./bench_deepseek.sh

# Qwen2-7B benchmark
./bench_qwen.sh

# Gemma-3-4B-IT benchmark
./bench_gemma3.sh

# GPT-OSS benchmark
./bench_gpt_oss.sh

# Mistral-7B benchmark
./bench_mistral.sh

# Phi-3-mini-4k-instruct benchmark
./bench_phi3.sh

Manual Benchmark

LLaMA-3.1-8B Benchmark

# PyTorch baseline
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --profile \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

# TileGym CUTILE backend
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

DeepSeek-V2-Lite Benchmark

# PyTorch baseline
python infer.py \
    --model_id deepseek-ai/DeepSeek-V2-Lite-Chat \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 100

# TileGym CUTILE backend
python infer.py \
    --model_id deepseek-ai/DeepSeek-V2-Lite-Chat \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 100

Qwen2-7B Benchmark

# PyTorch baseline
python infer.py \
    --model_id Qwen/Qwen2-7B \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --batch_size 16 \
    --output_length 100

# TileGym CUTILE backend
python infer.py \
    --model_id Qwen/Qwen2-7B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --batch_size 16 \
    --output_length 100

Gemma-3-4B-IT Benchmark

# PyTorch baseline
python infer.py \
    --model_id google/gemma-3-4b-it \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 100

# TileGym CUTILE backend
python infer.py \
    --model_id google/gemma-3-4b-it \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 100

Mistral-7B Benchmark

# PyTorch baseline
python infer.py \
    --model_id mistralai/Mistral-7B-Instruct-v0.3 \
    --profile \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

# TileGym CUTILE backend
python infer.py \
    --model_id mistralai/Mistral-7B-Instruct-v0.3 \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --profile \
    --sentence_file sample_inputs/input_prompt_32K.txt \
    --output_length 100

Command Line Options

| Option | Description | Default |
|---|---|---|
| --model_id | HuggingFace model ID or local path | meta-llama/Meta-Llama-3.1-8B |
| --use_tilegym | Enable TileGym optimization | False |
| --use_cutile | Use CUTILE backend | False |
| --use_attn | Enable attention optimization | False |
| --input_text | Input prompt text | - |
| --sentence_file | Input file path | - |
| --output_length | Number of tokens to generate | 100 |
| --batch_size | Batch size | 1 |
| --precision | bfloat16 or float32 | bfloat16 |
| --num_runs | Benchmark iterations | 5 |
| --warmup_runs | Warmup iterations | 2 |
| --profile | Enable profiling | False |
| --report_kernel_coverage | Report cuTile kernel GPU time and launch count coverage via nsys | False |
| --show_outputs | Print generated text | False |
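
For example, a profiled TileGym run that combines several of these options (the batch size of 4 is arbitrary and only for illustration):

# Profiled TileGym run with an explicit batch size and precision
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --precision bfloat16 \
    --batch_size 4 \
    --output_length 100 \
    --num_runs 5 \
    --warmup_runs 2 \
    --profile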

Using Local Models

You can use locally cached models by specifying the path directly:

# Use local model path
python infer.py \
    --model_id /path/to/local/model \
    --use_tilegym \
    --use_attn \
    --use_cutile \
    --show_outputs

Troubleshooting

CUDA Out of Memory

  • Reduce --batch_size or use shorter inputs (see the example below)
  • Use --precision bfloat16
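
A minimal low-memory invocation combining both suggestions (the prompt file is one of the bundled samples):

# Low-memory run: batch size 1, bfloat16, short prompt
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --use_tilegym \
    --use_cutile \
    --use_attn \
    --precision bfloat16 \
    --batch_size 1 \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 100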

Model Download Issues
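
If a gated model such as meta-llama/Meta-Llama-3.1-8B fails to download, you may need to accept its license on Hugging Face and authenticate first. One possible fix, assuming the huggingface_hub CLI is installed:

# Authenticate so gated model weights can be downloaded
huggingface-cli login
# or, for non-interactive environments, export an access token
export HF_TOKEN=<your_huggingface_access_token>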

Slow Performance

  • Drop the --use_tilegym flag to check whether the naive PyTorch version produces output
  • Generation can take more than a minute when the prompt or the requested output is long; try a shorter input and reduce --output_length (see the baseline example below)
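
A quick baseline sanity check with a short prompt and a small output length (all flags as documented above):

# Naive PyTorch baseline with a short prompt and output
python infer.py \
    --model_id meta-llama/Meta-Llama-3.1-8B \
    --sentence_file sample_inputs/input_prompt_small.txt \
    --output_length 20 \
    --show_outputs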

Import Errors

  • Install TileGym: pip install -e . (from repo root)