File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -21,14 +21,14 @@ jobs:
2121 fail-fast : false
2222 matrix :
2323 include :
24- - name : " python3.11-pytorch2.5.1 -gpus1"
24+ - name : " python3.11-pytorch2.6.0 -gpus1"
2525 gpu_num : 1
2626 python_version : 3.11
27- container : mosaicml/pytorch:2.5.1_cu124 -python3.11-ubuntu20 .04
28- - name : " python3.11-pytorch2.5.1 -gpus2"
27+ container : mosaicml/pytorch:2.6.0_cu124 -python3.11-ubuntu22 .04
28+ - name : " python3.11-pytorch2.6.0 -gpus2"
2929 gpu_num : 2
3030 python_version : 3.11
31- container : mosaicml/pytorch:2.5.1_cu124 -python3.11-ubuntu20 .04
31+ container : mosaicml/pytorch:2.6.0_cu124 -python3.11-ubuntu22 .04
3232 steps :
3333 - name : Run PR GPU tests
3434 uses : mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2
Original file line number Diff line number Diff line change 3232 additional_dependencies :
3333 - toml
3434- repo : https://github.com/hadialqattan/pycln
35- rev : v2.1.2
35+ rev : v2.5.0
3636 hooks :
3737 - id : pycln
3838 args : [. --all]
Original file line number Diff line number Diff line change @@ -73,6 +73,18 @@ class Arguments:
7373 moe_zloss_in_fp32 : bool = False
7474
7575 def __post_init__ (self ):
76+ # Sparse MLP is not supported with triton >=3.2.0
77+ # TODO: Remove this once sparse is supported with triton >=3.2.0
78+ if self .__getattribute__ ('mlp_impl' ) == 'sparse' :
79+ try :
80+ import triton
81+ if triton .__version__ >= '3.2.0' :
82+ raise ValueError (
83+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.' ,
84+ )
85+ except ImportError :
86+ raise ImportError ('Triton is required for sparse MLP implementation' )
87+
7688 if self .__getattribute__ ('mlp_impl' ) == 'grouped' :
7789 grouped_gemm .assert_grouped_gemm_is_available ()
7890
Original file line number Diff line number Diff line change 33
44# build requirements
55[build-system ]
6- requires = [" setuptools < 70.0.0" , " torch >= 2.5.1 , < 2.5.2 " ]
6+ requires = [" setuptools < 70.0.0" , " torch >= 2.6.0 , < 2.6.1 " ]
77build-backend = " setuptools.build_meta"
88
99# Pytest
Original file line number Diff line number Diff line change 6262install_requires = [
6363 'numpy>=1.21.5,<2.1.0' ,
6464 'packaging>=21.3.0,<24.2' ,
65- 'torch>=2.5.1 ,<2.5.2 ' ,
66- 'triton>=2.1 .0' ,
65+ 'torch>=2.6.0 ,<2.6.1 ' ,
66+ 'triton>=3.2.0,<3.3 .0' ,
6767 'stanford-stk==0.7.1' ,
6868]
6969
Original file line number Diff line number Diff line change @@ -53,6 +53,16 @@ def construct_moes(
5353 mlp_impl : str = 'sparse' ,
5454 moe_zloss_weight : float = 0 ,
5555):
56+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
57+ # TODO: Remove this once sparse is supported with triton >=3.2.0
58+ if mlp_impl == 'sparse' :
59+ try :
60+ import triton
61+ if triton .__version__ >= '3.2.0' :
62+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
63+ except ImportError :
64+ pass
65+
5666 init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
5767 args = Arguments (
5868 hidden_size = hidden_size ,
Original file line number Diff line number Diff line change @@ -23,6 +23,16 @@ def construct_dmoe_glu(
2323 mlp_impl : str = 'sparse' ,
2424 memory_optimized_mlp : bool = False ,
2525):
26+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
27+ # TODO: Remove this once sparse is supported with triton >=3.2.0
28+ if mlp_impl == 'sparse' :
29+ try :
30+ import triton
31+ if triton .__version__ >= '3.2.0' :
32+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
33+ except ImportError :
34+ pass
35+
2636 init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
2737 args = Arguments (
2838 hidden_size = hidden_size ,
Original file line number Diff line number Diff line change @@ -41,6 +41,15 @@ def construct_moe(
4141 moe_top_k : int = 1 ,
4242 moe_zloss_weight : float = 0 ,
4343):
44+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
45+ # TODO: Remove this once sparse is supported with triton >=3.2.0
46+ try :
47+ import triton
48+ if triton .__version__ >= '3.2.0' :
49+ pytest .skip ('Sparse MLP is not supported with triton >=3.2.0' )
50+ except ImportError :
51+ pass
52+
4453 init_method = partial (torch .nn .init .normal_ , mean = 0.0 , std = 0.1 )
4554 args = Arguments (
4655 hidden_size = hidden_size ,
You can’t perform that action at this time.
0 commit comments