File tree Expand file tree Collapse file tree
transformer_engine/pytorch Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -575,14 +575,12 @@ def set_device_mesh(
575575 weight_mesh : Optional[DeviceMesh]
576576 Not used for DotProductAttention as there are no quantized weights.
577577 """
578+ warnings .warn (f"weight_mesh not necessary for { self .__class__ .__name__ } : { weight_mesh } " )
578579 if tp_mesh is not None :
579580 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
580- assert (
581- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
582- (
583- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
584- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
585- ),
581+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
582+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
583+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
586584 )
587585 # Set the tensor parallel group from the mesh.
588586 self .set_tensor_parallel_group (tp_mesh .get_group ())
Original file line number Diff line number Diff line change @@ -641,12 +641,9 @@ def set_device_mesh(
641641 """
642642 if tp_mesh is not None :
643643 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
644- assert (
645- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
646- (
647- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
648- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
649- ),
644+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
645+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
646+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
650647 )
651648 # Set the tensor parallel group from the mesh.
652649 self .set_tensor_parallel_group (tp_mesh .get_group ())
@@ -655,7 +652,7 @@ def set_device_mesh(
655652 # Iterate through child sub-modules without deep recursion.
656653 # Automatically detects TransformerEngine TP modules and
657654 # the capability to call this method at any level.
658- for name , child in self .named_children ():
655+ for child in self .children ():
659656 if hasattr (child , "set_device_mesh" ):
660657 child .set_device_mesh (tp_mesh , weight_mesh )
661658
Original file line number Diff line number Diff line change @@ -858,12 +858,9 @@ def set_device_mesh(
858858 """
859859 if tp_mesh is not None :
860860 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
861- assert (
862- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
863- (
864- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
865- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
866- ),
861+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
862+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
863+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
867864 )
868865 # Set the tensor parallel group from the mesh.
869866 self .set_tensor_parallel_group (tp_mesh .get_group ())
Original file line number Diff line number Diff line change @@ -168,6 +168,7 @@ def set_device_mesh(
168168 Quantized DTensor parameters are currently not supported for FusibleOperation(s),
169169 and this mesh is not used.
170170 """
171+ warnings .warn (f"weight_mesh not necessary for { self .__class__ .__name__ } : { weight_mesh } " )
171172 if tp_mesh is not None :
172173 # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
173174 # with DTensor parameters in TP layers to support DTensor operations.
Original file line number Diff line number Diff line change @@ -1492,12 +1492,9 @@ def set_device_mesh(
14921492 """
14931493 if tp_mesh is not None :
14941494 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
1495- assert (
1496- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
1497- (
1498- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
1499- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
1500- ),
1495+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
1496+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
1497+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
15011498 )
15021499 # Set the tensor parallel group from the mesh.
15031500 self .set_tensor_parallel_group (tp_mesh .get_group ())
Original file line number Diff line number Diff line change @@ -2058,12 +2058,9 @@ def set_device_mesh(
20582058 """
20592059 if tp_mesh is not None :
20602060 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
2061- assert (
2062- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
2063- (
2064- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
2065- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
2066- ),
2061+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
2062+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
2063+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
20672064 )
20682065 # Set the tensor parallel group from the mesh.
20692066 self .set_tensor_parallel_group (tp_mesh .get_group ())
Original file line number Diff line number Diff line change @@ -1383,12 +1383,9 @@ def set_device_mesh(
13831383 """
13841384 if tp_mesh is not None :
13851385 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
1386- assert (
1387- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
1388- (
1389- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
1390- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
1391- ),
1386+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
1387+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
1388+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
13921389 )
13931390 # Set the tensor parallel group from the mesh.
13941391 self .set_tensor_parallel_group (tp_mesh .get_group ())
Original file line number Diff line number Diff line change @@ -171,6 +171,7 @@ def set_device_mesh(
171171 Quantized DTensor parameters are currently not supported for FusibleOperation(s),
172172 and this mesh is not used.
173173 """
174+ warnings .warn (f"weight_mesh not necessary for { self .__class__ .__name__ } : { weight_mesh } " )
174175 if tp_mesh is not None :
175176 # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
176177 # with DTensor parameters in TP layers to support DTensor operations.
Original file line number Diff line number Diff line change @@ -629,12 +629,9 @@ def set_device_mesh(
629629 """
630630 if tp_mesh is not None :
631631 # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
632- assert (
633- tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (),
634- (
635- f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
636- f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
637- ),
632+ assert tp_mesh .ndim == 1 and self .tp_size == tp_mesh .size (), (
633+ f"TransformerEngine { self .__class__ .__name__ } TP init size ({ self .tp_size } ) "
634+ f"does not match the size of the provided TP DeviceMesh ({ tp_mesh .size ()} )."
638635 )
639636 # Set the tensor parallel group from the mesh.
640637 self .set_tensor_parallel_group (tp_mesh .get_group ())
@@ -643,7 +640,7 @@ def set_device_mesh(
643640 # Iterate through child sub-modules without deep recursion.
644641 # Automatically detects TransformerEngine TP modules and
645642 # the capability to call this method at any level.
646- for name , child in self .named_children ():
643+ for child in self .children ():
647644 if hasattr (child , "set_device_mesh" ):
648645 child .set_device_mesh (tp_mesh , weight_mesh )
649646
You can’t perform that action at this time.
0 commit comments