Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@

FLASH_ATTN_VERSION = None

This comment was marked as outdated.

from fastdeploy.model_executor.utils import try_import


def init_flash_attn_version():
"""
Expand All @@ -85,12 +87,33 @@ def init_flash_attn_version():
if sm_version >= 100:
try:
paddle.enable_compat(scope={"cutlass"})
from flash_mask.cute.interface import flashmask_attention as fa4

global flashmask_attention_v4
flashmask_attention_v4 = fa4
FLASH_ATTN_VERSION = 4
logger.info("The current platform supports Flash Attention V4.")
try:
old_api = try_import(["paddlefleet.ops"])
if old_api is not None:
from paddlefleet.ops import is_flash_mask_available

if is_flash_mask_available():
from paddlefleet.ops.flash_mask.cute.interface import (
flashmask_attention as fa4,
)
else:
raise ModuleNotFoundError("flash_mask not available.")
else:
from paddlefleet_ops import is_flash_mask_available

if is_flash_mask_available():
from paddlefleet_ops.flash_mask.cute.interface import (
flashmask_attention as fa4,
)
else:
raise ModuleNotFoundError("flash_mask not available.")

global flashmask_attention_v4
flashmask_attention_v4 = fa4
FLASH_ATTN_VERSION = 4
logger.info("The current platform supports Flash Attention V4.")
except (ImportError, ModuleNotFoundError):
logger.info(f"The current platform[sm{get_sm_version()}] can't import fa V4.")
except ImportError:
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,4 @@ aistudio_sdk
p2pstore
py-cpuinfo
flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.3-py3-none-any.whl
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
transformers>=4.55.1,<5.0.0
181 changes: 181 additions & 0 deletions tests/layers/test_flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

from __future__ import annotations

import sys
import types
import unittest
from unittest import mock

import paddle

from fastdeploy.model_executor.layers.attention import flash_attn_backend
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
flash_attn_func,
)
Expand Down Expand Up @@ -205,5 +209,182 @@ def test_fa4(self):
)


class TestInitFlashAttnVersion(unittest.TestCase):
"""Tests for the init_flash_attn_version FA4 import branch (sm>=100)."""

_MODULE_NAMES = (
"paddlefleet",
"paddlefleet.ops",
"paddlefleet.ops.flash_mask",
"paddlefleet.ops.flash_mask.cute",
"paddlefleet.ops.flash_mask.cute.interface",
"paddlefleet_ops",
"paddlefleet_ops.flash_mask",
"paddlefleet_ops.flash_mask.cute",
"paddlefleet_ops.flash_mask.cute.interface",
)

def setUp(self):
# Save state to restore after each test.
self._saved_version = flash_attn_backend.FLASH_ATTN_VERSION
self._saved_v4 = flash_attn_backend.flashmask_attention_v4
self._saved_modules = {name: sys.modules.get(name) for name in self._MODULE_NAMES}
# Make sure each test starts with a clean module state.
for name in self._MODULE_NAMES:
sys.modules.pop(name, None)

def _block_old_api(self):
"""Force `paddlefleet.ops` import to fail regardless of what is installed."""
# Setting sys.modules[name] = None makes importlib.import_module raise ImportError.
sys.modules["paddlefleet"] = None
sys.modules["paddlefleet.ops"] = None

def _block_new_api(self):
"""Force `paddlefleet_ops` import to fail regardless of what is installed."""
sys.modules["paddlefleet_ops"] = None

def tearDown(self):
flash_attn_backend.FLASH_ATTN_VERSION = self._saved_version
flash_attn_backend.flashmask_attention_v4 = self._saved_v4
for name, mod in self._saved_modules.items():
if mod is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = mod

def _install_fake_paddlefleet_old_api(self, is_available: bool):
"""Inject fake `paddlefleet.ops` (old API) modules."""
pkg = types.ModuleType("paddlefleet")
pkg.__path__ = []
ops = types.ModuleType("paddlefleet.ops")
ops.__path__ = []
ops.is_flash_mask_available = lambda: is_available
pkg.ops = ops
flash_mask = types.ModuleType("paddlefleet.ops.flash_mask")
flash_mask.__path__ = []
cute = types.ModuleType("paddlefleet.ops.flash_mask.cute")
cute.__path__ = []
interface = types.ModuleType("paddlefleet.ops.flash_mask.cute.interface")
interface.flashmask_attention = mock.MagicMock(name="fa4_old")

sys.modules["paddlefleet"] = pkg
sys.modules["paddlefleet.ops"] = ops
sys.modules["paddlefleet.ops.flash_mask"] = flash_mask
sys.modules["paddlefleet.ops.flash_mask.cute"] = cute
sys.modules["paddlefleet.ops.flash_mask.cute.interface"] = interface
return interface.flashmask_attention

def _install_fake_paddlefleet_new_api(self, is_available: bool):
"""Inject fake `paddlefleet_ops` (new API) modules."""
ops = types.ModuleType("paddlefleet_ops")
ops.__path__ = []
ops.is_flash_mask_available = lambda: is_available
flash_mask = types.ModuleType("paddlefleet_ops.flash_mask")
flash_mask.__path__ = []
cute = types.ModuleType("paddlefleet_ops.flash_mask.cute")
cute.__path__ = []
interface = types.ModuleType("paddlefleet_ops.flash_mask.cute.interface")
interface.flashmask_attention = mock.MagicMock(name="fa4_new")

sys.modules["paddlefleet_ops"] = ops
sys.modules["paddlefleet_ops.flash_mask"] = flash_mask
sys.modules["paddlefleet_ops.flash_mask.cute"] = cute
sys.modules["paddlefleet_ops.flash_mask.cute.interface"] = interface
return interface.flashmask_attention

def test_fa4_old_api_import_success(self):
"""Old API (`paddlefleet.ops`) is preferred when available."""
fake_fa4 = self._install_fake_paddlefleet_old_api(is_available=True)
# Also install new API to verify the old API takes precedence.
new_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)
flash_attn_backend.FLASH_ATTN_VERSION = None
flash_attn_backend.flashmask_attention_v4 = None

with (
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
):
flash_attn_backend.init_flash_attn_version()

self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)
self.assertIsNot(flash_attn_backend.flashmask_attention_v4, new_fa4)

def test_fa4_old_api_flash_mask_unavailable(self):
"""Old API present but `is_flash_mask_available` is False."""
self._install_fake_paddlefleet_old_api(is_available=False)
self._block_new_api()
flash_attn_backend.FLASH_ATTN_VERSION = None
flash_attn_backend.flashmask_attention_v4 = None

with (
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
):
try:
flash_attn_backend.init_flash_attn_version()
except NameError:
pass

self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)

def test_fa4_new_api_import_success(self):
"""Falls back to new API (`paddlefleet_ops`) when old API is missing."""
fake_fa4 = self._install_fake_paddlefleet_new_api(is_available=True)
self._block_old_api()
flash_attn_backend.FLASH_ATTN_VERSION = None
flash_attn_backend.flashmask_attention_v4 = None

with (
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
):
flash_attn_backend.init_flash_attn_version()

self.assertEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)
self.assertIs(flash_attn_backend.flashmask_attention_v4, fake_fa4)

def test_fa4_new_api_flash_mask_unavailable(self):
"""New API present but `is_flash_mask_available` is False."""
self._install_fake_paddlefleet_new_api(is_available=False)
self._block_old_api()
flash_attn_backend.FLASH_ATTN_VERSION = None
flash_attn_backend.flashmask_attention_v4 = None

with (
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
):
try:
flash_attn_backend.init_flash_attn_version()
except NameError:
pass

self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)

def test_fa4_paddlefleet_import_error(self):
"""Neither old nor new API is importable."""
self._block_old_api()
self._block_new_api()
flash_attn_backend.FLASH_ATTN_VERSION = None
flash_attn_backend.flashmask_attention_v4 = None

with (
mock.patch.object(flash_attn_backend.current_platform, "is_cuda", return_value=True),
mock.patch.object(flash_attn_backend, "get_sm_version", return_value=100),
mock.patch.object(paddle, "enable_compat", create=True, return_value=None),
):
try:
flash_attn_backend.init_flash_attn_version()
except NameError:
pass

self.assertNotEqual(flash_attn_backend.FLASH_ATTN_VERSION, 4)


if __name__ == "__main__":
unittest.main()
Loading