Commit 4190fef

Merge pull request #433 from Modalities/uv_support_for_different_cuda_versions
feat: Added cuda version selection to uv build.
2 parents 0596085 + 7fd495c commit 4190fef

3 files changed: 75 additions & 10 deletions

README.md

Lines changed: 4 additions & 3 deletions
````diff
@@ -44,11 +44,11 @@ It is recommended to install Modalities via uv or install PyTorch, psutil and Ni
 # Get uv (tested with uv version 0.9.13)
 curl -LsSf https://astral.sh/uv/install.sh | sh
 
-uv sync
+uv sync --extra [cpu|cu126|cu128|cu130] # Get CUDA version via nvidia-smi
 source .venv/bin/activate
 
 # For developers: use [tests,linting] and install pre-commit hooks
-uv sync --extra tests --extra linting
+uv sync --extra [cpu|cu126|cu128|cu130] --extra tests --extra linting
 pre-commit install --install-hooks
 ```
 
@@ -60,7 +60,8 @@ conda create -n modalities python=3.13
 conda activate modalities
 
 # Install PyTorch, psutil, Ninja and Flash Attention
-pip install "torch<2.11.0"
+# For PyTorch, select the correct index URL for your CUDA/CPU setup from https://pytorch.org/get-started/locally/ e.g.:
+pip install "torch>=2.10,<2.11.0" torchvision --index-url https://download.pytorch.org/whl/cu130
 pip install psutil ninja # Ninja lowers compilation time of flash attention significantly
 pip install flash-attn==2.8.3 --no-build-isolation
 ```
````
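
The `# Get CUDA version via nvidia-smi` hint can be scripted. Below is a minimal sketch, not part of this commit; the file name, the banner parsing, and the version-to-extra mapping are all assumptions about how one might pick the matching extra automatically:

```python
# pick_extra.py -- hypothetical helper, not shipped with Modalities.
# Maps the "CUDA Version: X.Y" line of the nvidia-smi banner to one of
# the extras defined in pyproject.toml (cpu, cu126, cu128, cu130).
import re
import shutil
import subprocess


def pick_extra() -> str:
    if shutil.which("nvidia-smi") is None:
        return "cpu"  # no NVIDIA driver visible -> CPU-only wheels
    banner = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", banner)
    if match is None:
        return "cpu"
    cuda = (int(match.group(1)), int(match.group(2)))
    # Round down to the newest extra the installed driver still supports.
    for extra, minimum in (("cu130", (13, 0)), ("cu128", (12, 8)), ("cu126", (12, 6))):
        if cuda >= minimum:
            return extra
    return "cpu"


if __name__ == "__main__":
    print(pick_extra())
```

The result can be fed straight into the new command, e.g. `uv sync --extra "$(python pick_extra.py)"`.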

pyproject.toml

Lines changed: 66 additions & 6 deletions
```diff
@@ -6,7 +6,6 @@ description = "Modalities, a PyTorch-native framework for distributed and reprod
 readme = "README.md"
 dependencies = [
     "numpy",
-    "torch<2.11.0",
     "ninja",
     "packaging",
     "tqdm",
@@ -25,25 +24,86 @@ dependencies = [
     "matplotlib",
     "wandb",
     "einops>=0.7.0",
-    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
     "debugpy", # For VSCode debugging support
 ]
 
 [project.urls]
 Homepage = "https://github.com/Modalities/modalities"
 Issues = "https://github.com/Modalities/modalities/issues"
 
-[project.optional-dependencies]
-linting = ["pre-commit"]
-tests = ["pytest", "pytest-cov", "debugpy"]
-
 [project.scripts]
 modalities = "modalities.__main__:main"
 
 [build-system]
 requires = ["setuptools >= 61.0.0"]
 build-backend = "setuptools.build_meta"
 
+[project.optional-dependencies]
+linting = ["pre-commit"]
+tests = ["pytest", "pytest-cov", "debugpy"]
+
+cpu = ["torch>=2.10,<2.11.0", "torchvision"]
+cu126 = [
+    "torch>=2.10,<2.11.0",
+    "torchvision",
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+]
+cu128 = [
+    "torch>=2.10,<2.11.0",
+    "torchvision",
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+]
+cu130 = [
+    "torch>=2.10,<2.11.0",
+    "torchvision",
+    "flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
+]
+
+[tool.uv]
+conflicts = [
+    [
+        { extra = "cpu" },
+        { extra = "cu126" },
+        { extra = "cu128" },
+        { extra = "cu130" },
+    ],
+]
+
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cpu", extra = "cpu" },
+    { index = "pytorch-cu126", extra = "cu126" },
+    { index = "pytorch-cu128", extra = "cu128" },
+    { index = "pytorch-cu130", extra = "cu130" },
+]
+torchvision = [
+    { index = "pytorch-cpu", extra = "cpu" },
+    { index = "pytorch-cu126", extra = "cu126" },
+    { index = "pytorch-cu128", extra = "cu128" },
+    { index = "pytorch-cu130", extra = "cu130" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+explicit = true
+
+[[tool.uv.index]]
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
+explicit = true
+
+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true
+
+[[tool.uv.index]]
+name = "pytorch-cu130"
+url = "https://download.pytorch.org/whl/cu130"
+explicit = true
+
+
 [tool.uv.extra-build-dependencies]
 flash-attn = [
     { requirement = "torch", match-runtime = true },
```
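
The `[tool.uv]` tables do the heavy lifting here: `conflicts` declares the four extras mutually exclusive, so one lockfile can carry all of them without ever co-installing two torch builds, and the `explicit = true` indexes are consulted only for the packages pinned to them in `[tool.uv.sources]`. A quick post-sync check, offered as a sketch assuming the synced environment is active (the `+cu130` local version tag is how PyTorch labels wheels from its cu130 index):

```python
# verify_torch.py -- illustrative check that the installed wheel matches
# the extra selected at `uv sync` time.
import torch

print(torch.__version__)          # e.g. "2.10.0+cu130" after `uv sync --extra cu130`
print(torch.version.cuda)         # CUDA toolkit the wheel was built against; None for the cpu extra
print(torch.cuda.is_available())  # True only with a compatible driver present
```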

src/modalities/utils/mfu.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -14,7 +14,7 @@
 # https://www.nvidia.com/en-us/data-center/h100/
 #
 # NOTE: These values are valid for fp16 and bf16 only
-PEAK_PERFORMANCE = {"A100": 312e12, "H100": 989e12, "GH200": 989e12}
+PEAK_PERFORMANCE = {"A100": 312e12, "H100": 989e12, "GH200": 989e12, "B200": 2.25e15}
 
 
 class MFUCalculatorABC:
@@ -130,6 +130,10 @@ def _get_theoretical_gpu_peak_performance(model_parts: FSDPX | list[FSDP2], worl
             single_gpu_peak_performance = MFUCalculatorABC._get_theoretical_gpu_peak_performance_single(
                 precision, "GH200"
             )
+        elif device_name.startswith("NVIDIA B200"):
+            single_gpu_peak_performance = MFUCalculatorABC._get_theoretical_gpu_peak_performance_single(
+                precision, "B200"
+            )
         else:
             warnings.warn(f"Could not get theoretical GPU peak performance for unknown device = {device_name}.")
             return None
```
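
For context, `PEAK_PERFORMANCE` feeds the MFU (Model FLOPs Utilization) ratio: achieved training throughput divided by the combined theoretical peak of all participating GPUs. A minimal sketch of that arithmetic follows; the function name and call shape are illustrative, not the module's API:

```python
# Peak dense fp16/bf16 throughput in FLOP/s, mirroring the table in mfu.py.
PEAK_PERFORMANCE = {"A100": 312e12, "H100": 989e12, "GH200": 989e12, "B200": 2.25e15}


def mfu(achieved_flops_per_sec: float, device: str, world_size: int) -> float:
    """Model FLOPs Utilization: achieved throughput over theoretical peak."""
    return achieved_flops_per_sec / (PEAK_PERFORMANCE[device] * world_size)


# A job sustaining 1.5e15 model FLOP/s across two B200s: 1.5e15 / 4.5e15.
print(f"{mfu(1.5e15, 'B200', world_size=2):.1%}")  # -> 33.3%
```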
