File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -96,16 +96,13 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
9696 if out_sharding is not None :
9797 if nnx .PARTITION_NAME in metadata :
9898 partition_name = metadata [nnx .PARTITION_NAME ]
99- # Only nnx.Param variables are typically scanned across the param_scan_axis
10099 scan_axis = metadata .get ("param_scan_axis" , 0 ) if variable .type == nnx .Param else 0
101100
102- if isinstance (out_sharding , str ):
103- out_sharding = [out_sharding ]
104- else :
105- out_sharding = list (out_sharding )
101+ sharding_list = [out_sharding ] if isinstance (out_sharding , str ) else list (out_sharding )
102+ if partition_name not in sharding_list :
103+ sharding_list .insert (scan_axis , partition_name )
106104
107- out_sharding .insert (scan_axis , partition_name )
108- out_sharding = tuple (out_sharding )
105+ out_sharding = tuple (sharding_list )
109106
110107 return nn .LogicallyPartitioned ( # type: ignore[wrong-keyword-args]
111108 variable .value ,
You can’t perform that action at this time.
0 commit comments