
Commit 50da1dc

Lint.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 60b68f7 commit 50da1dc

9 files changed: 26 additions & 44 deletions


transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 4 additions & 6 deletions
@@ -575,14 +575,12 @@ def set_device_mesh(
         weight_mesh : Optional[DeviceMesh]
             Not used for DotProductAttention as there are no quantized weights.
         """
+        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())
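
Worth noting: this is more than a cosmetic lint fix. The old code asserted on a parenthesized (condition, message) tuple, and a non-empty tuple is always truthy, so the assertion could never fail. A minimal standalone Python sketch of the pitfall and the fix:

    # Broken: asserting on a (condition, message) tuple. The tuple itself is
    # truthy, so this never raises; CPython even emits a SyntaxWarning
    # ("assertion is always true, perhaps remove parentheses?").
    assert (
        1 + 1 == 3,
        "this message is never shown",
    )

    # Fixed: keep the condition bare and parenthesize only the message.
    assert 1 + 1 == 3, (
        "this AssertionError actually fires, "
        "with this text as its message"
    )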

transformer_engine/pytorch/attention/multi_head_attention.py

Lines changed: 4 additions & 7 deletions
@@ -641,12 +641,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())
@@ -655,7 +652,7 @@ def set_device_mesh(
         # Iterate through child sub-modules without deep recursion.
         # Automatically detects TransformerEngine TP modules and
         # the capability to call this method at any level.
-        for name, child in self.named_children():
+        for child in self.children():
             if hasattr(child, "set_device_mesh"):
                 child.set_device_mesh(tp_mesh, weight_mesh)
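
The named_children() to children() switch is likewise lint-driven: the loop never used the name, and the unused binding is what the linter flagged. Both torch.nn.Module iterators walk only direct children, so the shallow, non-recursive behavior is unchanged. A small sketch against plain torch.nn modules:

    import torch.nn as nn

    block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    # named_children() yields (name, module) pairs; "name" went unused in the
    # original loop, which trips unused-variable checks.
    for name, child in block.named_children():
        pass

    # children() yields just the direct child modules, one level deep,
    # which is all this dispatch loop needs.
    for child in block.children():
        print(type(child).__name__)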

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 3 additions & 6 deletions
@@ -858,12 +858,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())

transformer_engine/pytorch/module/layernorm.py

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ def set_device_mesh(
             Quantized DTensor parameters are currently not supported for FusibleOperation(s),
             and this mesh is not used.
         """
+        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
             # with DTensor parameters in TP layers to support DTensor operations.
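
The new warning fires whenever a weight_mesh reaches a module that has no quantized weights to shard. If a caller passes weight_mesh uniformly to every layer and the message gets noisy, it can be filtered with the standard warnings machinery; a sketch, where layer, tp_mesh, and weight_mesh are placeholder names rather than objects from this diff:

    import warnings

    # Suppress only the informational weight_mesh warning; other warnings
    # still propagate normally.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=r"weight_mesh not necessary for .*")
        layer.set_device_mesh(tp_mesh, weight_mesh)  # placeholder objects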

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 3 additions & 6 deletions
@@ -1492,12 +1492,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 3 additions & 6 deletions
@@ -2058,12 +2058,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())

transformer_engine/pytorch/module/linear.py

Lines changed: 3 additions & 6 deletions
@@ -1383,12 +1383,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())

transformer_engine/pytorch/module/rmsnorm.py

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ def set_device_mesh(
             Quantized DTensor parameters are currently not supported for FusibleOperation(s),
             and this mesh is not used.
         """
+        warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
         if tp_mesh is not None:
             # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility
             # with DTensor parameters in TP layers to support DTensor operations.

transformer_engine/pytorch/transformer.py

Lines changed: 4 additions & 7 deletions
@@ -629,12 +629,9 @@ def set_device_mesh(
         """
         if tp_mesh is not None:
             # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
-            assert (
-                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
-                (
-                    f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
-                    f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
-                ),
+            assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
             )
             # Set the tensor parallel group from the mesh.
             self.set_tensor_parallel_group(tp_mesh.get_group())
@@ -643,7 +640,7 @@ def set_device_mesh(
         # Iterate through child sub-modules without deep recursion.
         # Automatically detects TransformerEngine TP modules and
         # the capability to call this method at any level.
-        for name, child in self.named_children():
+        for child in self.children():
             if hasattr(child, "set_device_mesh"):
                 child.set_device_mesh(tp_mesh, weight_mesh)
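For context, a rough sketch of how these entry points are typically driven from user code. init_device_mesh is real PyTorch API (torch >= 2.2); the mesh shape and the layer object are illustrative assumptions, not taken from this diff:

    from torch.distributed.device_mesh import init_device_mesh

    # Hypothetical 1-D tensor-parallel mesh over 8 ranks. set_device_mesh
    # asserts tp_mesh.ndim == 1 and tp_mesh.size() == layer.tp_size, sets the
    # TP process group from the mesh, and forwards the mesh to direct children.
    tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
    layer.set_device_mesh(tp_mesh, weight_mesh=None)  # layer: placeholder module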
