Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions benchmarks/xla_flags_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@


# Enable SparseCore All Gather (1D), Reduce Scatter (1D) and All Reduce (ND)
# On Ironwood, the following flags default to true:
#   xla_tpu_enable_sparse_core_collective_offload_all_gather
#   xla_tpu_enable_sparse_core_collective_offload_reduce_scatter
#   xla_tpu_enable_sparse_core_collective_offload_all_reduce
ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR = (
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false"
Expand All @@ -91,6 +95,8 @@

# Enable SparseCore Reduce Scatter (SC RS)
# Only one of CF or SC can be enabled at a time.
# On Ironwood, the following flag defaults to true:
#   xla_tpu_enable_sparse_core_collective_offload_reduce_scatter
ENABLE_SPARSECORE_OFFLOADING_FOR_REDUCE_SCATTER = (
" --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false"
" --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true"
Expand All @@ -99,6 +105,8 @@

# Enable SparseCore All Gather (SC AG).
# Only one of CF or SC can be enabled at a time.
# On Ironwood, the following flag defaults to true:
#   xla_tpu_enable_sparse_core_collective_offload_all_gather
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER = (
" --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false"
" --xla_tpu_enable_sparse_core_collective_offload_all_gather=true"
Expand All @@ -109,6 +117,8 @@
# Only one of CF or SC can be enabled at a time.
# This is useful for reducing the gradient-reduction all-reduce time by
# overlapping the all-reduce with compute.
# On Ironwood, the following flag defaults to true:
#   xla_tpu_enable_sparse_core_collective_offload_all_reduce
ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE = (
" --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false"
" --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true"
Expand Down
Loading
Loading