Skip to content

Commit 04b56d5

Browse files
Update
1 parent eff50f9 commit 04b56d5

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

src/maxtext/layers/initializers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,13 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
9696
if out_sharding is not None:
9797
if nnx.PARTITION_NAME in metadata:
9898
partition_name = metadata[nnx.PARTITION_NAME]
99-
# Only nnx.Param variables are typically scanned across the param_scan_axis
10099
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
101100

102-
if isinstance(out_sharding, str):
103-
out_sharding = [out_sharding]
104-
else:
105-
out_sharding = list(out_sharding)
101+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
102+
if partition_name not in sharding_list:
103+
sharding_list.insert(scan_axis, partition_name)
106104

107-
out_sharding.insert(scan_axis, partition_name)
108-
out_sharding = tuple(out_sharding)
105+
out_sharding = tuple(sharding_list)
109106

110107
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
111108
variable.value,

0 commit comments

Comments
 (0)