Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@
Q_LENGTH,
Q_LENGTH_NO_EXP,
)
from MaxText.layers import nnx_wrappers
from MaxText.layers.initializers import variable_to_logically_partitioned
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
from maxtext.inference import page_manager
from maxtext.inference.kvcache import KVQuant, KVTensor
from maxtext.kernels.attention import jax_flash_attention
from maxtext.kernels.attention.ragged_attention import ragged_gqa
from maxtext.kernels.attention.ragged_attention import ragged_mha
from MaxText.layers import nnx_wrappers
from MaxText.layers.initializers import variable_to_logically_partitioned
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name
from maxtext.utils import max_utils
import numpy as np
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from MaxText import common_types as ctypes
from MaxText.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_logical, create_sharding
from maxtext.kernels import megablox as mblx
from MaxText.sharding import logical_to_mesh_axes
from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
from maxtext.kernels import megablox as mblx
from maxtext.utils import max_logging
from maxtext.utils import max_utils
import numpy as np
Expand Down
Loading