
Commit 53a1969

voltjiawooway777 authored and committed
issue/919 - Add a NineToothed implementation of scaled_dot_product_attention
1 parent 5ab3363 commit 53a1969

8 files changed

Lines changed: 249 additions & 10 deletions

python/infinicore/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,13 +4,15 @@
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .rope import RopeAlgo, rope
+from .scaled_dot_product_attention import scaled_dot_product_attention
 from .silu import silu
 from .swiglu import swiglu
 
 __all__ = [
     "causal_softmax",
     "random_sample",
     "rms_norm",
+    "scaled_dot_product_attention",
     "silu",
     "swiglu",
     "linear",
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import math
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    assert attn_mask is None and dropout_p == 0 and not enable_gqa
+
+    emb_dim = query.shape[-1]
+
+    if scale is None:
+        scale = 1 / math.sqrt(emb_dim)
+
+    return Tensor(
+        _infinicore.flash_attention(
+            query._underlying, key._underlying, value._underlying, scale, is_causal
+        )
+    )
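For orientation: the new wrapper asserts away the unsupported options (attn_mask, dropout, GQA) and forwards everything else to the _infinicore.flash_attention binding, so the semantics it exposes are softmax(Q·Kᵀ·scale)·V with an optional causal mask and a default scale of 1/sqrt(head_dim). Below is a minimal PyTorch sketch of that math for comparison only; the helper name sdpa_reference and the example shapes are illustrative and not part of the commit.

import math

import torch


def sdpa_reference(query, key, value, is_causal=False, scale=None):
    # softmax(Q @ K^T * scale) @ V, the semantics the wrapper forwards to the kernel.
    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])
    scores = query @ key.transpose(-2, -1) * scale  # (..., seq_q, seq_k)
    if is_causal:
        seq_q, seq_k = scores.shape[-2:]
        mask = torch.ones(seq_q, seq_k).tril().bool()  # lower-triangular causal mask
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value


# Check against PyTorch's own operator on one of the new 4-D test shapes.
q = k = v = torch.randn(1, 2, 8, 16)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(sdpa_reference(q, k, v, is_causal=True), expected, atol=1e-4)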

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
+#include "ops/flash_attention.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
@@ -29,6 +30,7 @@ inline void bind(py::module &m) {
     bind_add_rms_norm(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_flash_attention(m);
     bind_random_sample(m);
     bind_linear(m);
     bind_matmul(m);
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/flash_attention.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_flash_attention(py::module &m) {
+    m.def("flash_attention",
+          &op::flash_attention,
+          py::arg("q"),
+          py::arg("k"),
+          py::arg("v"),
+          py::arg("scale"),
+          py::arg("is_causal"));
+}
+
+} // namespace infinicore::ops
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import ninetoothed
+from ntops.kernels import scaled_dot_product_attention
+from ntops.kernels.scaled_dot_product_attention import CausalVariant
+
+import infiniop.ninetoothed.build
+
+
+def build():
+    with_kv_cache_values = (0,)
+    emb_dim_values = (16, 32, 64, 128, 256)
+    is_causal_values = (0, 1)
+    with_attn_mask_values = (0,)
+    causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
+    dtype_values = (ninetoothed.float16, ninetoothed.float32)
+    block_size_m_values = (64,)
+    block_size_n_values = (64,)
+
+    constexpr_param_grid = {
+        "with_kv_cache": with_kv_cache_values,
+        "emb_dim": emb_dim_values,
+        "is_causal": is_causal_values,
+        "with_attn_mask": with_attn_mask_values,
+        "causal_variant": causal_variant_values,
+        "dtype": dtype_values,
+        "block_size_m": block_size_m_values,
+        "block_size_n": block_size_n_values,
+    }
+
+    infiniop.ninetoothed.build.build(
+        scaled_dot_product_attention.premake,
+        constexpr_param_grid,
+        caller="cuda",
+        op_name="flash_attention",
+        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
+    )
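The grid above lists compile-time (constexpr) parameters for the NineToothed kernel, and the descriptor later picks among the resulting variants via emb_dim, dtype, is_causal, and the block sizes. A rough sketch of how such a grid expands, under the assumption that the build helper emits one ahead-of-time-compiled variant per combination; the Cartesian-product expansion below is illustrative, not the helper's actual code, and the string placeholders stand in for the real enum and dtype values.

import itertools

# Same option counts as in build(): 1 * 5 * 2 * 1 * 2 * 2 * 1 * 1 = 40 combinations.
constexpr_param_grid = {
    "with_kv_cache": (0,),
    "emb_dim": (16, 32, 64, 128, 256),
    "is_causal": (0, 1),
    "with_attn_mask": (0,),
    "causal_variant": ("UPPER_LEFT", "LOWER_RIGHT"),
    "dtype": ("float16", "float32"),
    "block_size_m": (64,),
    "block_size_n": (64,),
}

# Expand the grid into one parameter dict per kernel variant.
variants = [
    dict(zip(constexpr_param_grid, combination))
    for combination in itertools.product(*constexpr_param_grid.values())
]
assert len(variants) == 40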
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
+#define __FLASH_ATTENTION_DESCRIPTOR_H__
+
+#include "../../../handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+#include "../../../../../build/ninetoothed/flash_attention.h"
+#include "../../../ninetoothed/utils.h"
+
+namespace op::flash_attention::ninetoothed {
+
+class Descriptor final : public InfiniopDescriptor {
+public:
+    Descriptor(infiniopHandle_t handle,
+               infiniopTensorDescriptor_t out_desc,
+               infiniopTensorDescriptor_t q_desc,
+               infiniopTensorDescriptor_t k_desc,
+               infiniopTensorDescriptor_t v_desc,
+               double scale,
+               char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
+                                 _query_shape{q_desc->shape()},
+                                 _query_strides{q_desc->strides()},
+                                 _key_shape{k_desc->shape()},
+                                 _key_strides{k_desc->strides()},
+                                 _value_shape{v_desc->shape()},
+                                 _value_strides{v_desc->strides()},
+                                 _output_strides{out_desc->strides()},
+                                 _dtype{q_desc->dtype()},
+                                 _scale{scale},
+                                 _is_causal{is_causal} {}
+
+    ~Descriptor() = default;
+
+    size_t get_workspace_size() const {
+        return 0;
+    }
+
+    infiniStatus_t calculate(void *workspace,
+                             size_t workspace_size,
+                             void *out,
+                             const void *q,
+                             const void *k,
+                             const void *v,
+                             void *stream) const {
+        uint64_t empty_shape[4];
+        int64_t empty_strides[4];
+
+        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
+        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
+        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
+
+        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
+        NineToothedTensor is_causal;
+        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
+        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
+        NineToothedTensor with_attn_mask;
+        NineToothedTensor causal_variant;
+
+        const auto with_kv_cache_{0};
+        const auto emb_dim_{_query_shape[3]};
+        const auto is_causal_{_is_causal};
+        const auto with_attn_mask_{0};
+        const auto causal_variant_{1};
+        const auto dtype_{_dtype};
+
+        constexpr auto block_size_m_{64};
+        constexpr auto block_size_n_{64};
+
+        launch_flash_attention(stream,
+                               query,
+                               key,
+                               value,
+                               attn_mask,
+                               is_causal,
+                               scale,
+                               output,
+                               with_attn_mask,
+                               causal_variant,
+                               with_kv_cache_,
+                               emb_dim_,
+                               is_causal_,
+                               with_attn_mask_,
+                               causal_variant_,
+                               dtype_,
+                               block_size_m_,
+                               block_size_n_);
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+    static infiniStatus_t create(infiniopHandle_t handle,
+                                 Descriptor **desc,
+                                 infiniopTensorDescriptor_t out_desc,
+                                 infiniopTensorDescriptor_t q_desc,
+                                 infiniopTensorDescriptor_t k_desc,
+                                 infiniopTensorDescriptor_t v_desc,
+                                 double scale,
+                                 char is_causal) {
+        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, scale, is_causal};
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+private:
+    using Size = ::ninetoothed::Tensor<>::Size;
+
+    using Stride = ::ninetoothed::Tensor<>::Stride;
+
+    std::vector<Size> _query_shape;
+
+    std::vector<Stride> _query_strides;
+
+    std::vector<Size> _key_shape;
+
+    std::vector<Stride> _key_strides;
+
+    std::vector<Size> _value_shape;
+
+    std::vector<Stride> _value_strides;
+
+    std::vector<Stride> _output_strides;
+
+    infiniDtype_t _dtype;
+
+    double _scale;
+
+    char _is_causal;
+};
+
+} // namespace op::flash_attention::ninetoothed
+
+#endif // __FLASH_ATTENTION_DESCRIPTOR_H__

src/infiniop/ops/flash_attention/operator.cc

Lines changed: 20 additions & 0 deletions
@@ -6,8 +6,12 @@
 // #include "cpu/flash_attention_cpu.h"
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+#include "ninetoothed/descriptor.h"
+#else
 // #include "nvidia/flash_attention_nvidia.cuh"
 #endif
+#endif
 
 __C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
     infiniopHandle_t handle,
@@ -37,7 +41,11 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
     // CREATE(INFINI_DEVICE_CPU, cpu);
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+    CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
     // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -60,7 +68,11 @@ __C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
     // GET_SIZE(INFINI_DEVICE_CPU, cpu);
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+    GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
     // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -90,7 +102,11 @@ __C infiniStatus_t infiniopFlashAttention(
     // CALCULATE(INFINI_DEVICE_CPU, cpu);
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+    CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
     // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -112,7 +128,11 @@ __C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
     // DESTROY(INFINI_DEVICE_CPU, cpu);
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+    DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
     // DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

test/infinicore/ops/scaled_dot_product_attention.py

Lines changed: 8 additions & 10 deletions
@@ -11,17 +11,16 @@
 # q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)
 
 _TEST_CASES_DATA = [
-    ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False),
-    ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False),
-    ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True),
-    ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False),
-    ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True),
-    ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False),
+    ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False),
+    ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False),
+    ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False),
+    ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True),
+    ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False),
 ]
 
 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
-    infinicore.float32: {"atol": 1e-4, "rtol": 1e-4},
+    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
 }
 _TENSOR_DTYPES = [infinicore.float16, infinicore.float32]
 
@@ -68,9 +67,8 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs)
 
 
 def main():
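Each tuple in _TEST_CASES_DATA reads (q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal); the new cases are 4-D with the head dimension last, which lines up with emb_dim = query.shape[-1] in the Python wrapper and the emb_dim values baked into the NineToothed build. A small sketch of how one such case maps onto the PyTorch reference with the float32 tolerances from _TOLERANCE_MAP; the tensor setup below is plain PyTorch, not the test framework's own fixtures, and `out` is only a placeholder for the InfiniCore result.

import torch

# One new case: ((1, 2, 4, 16), ..., None, 0.0, True), checked at float32 tolerances.
shape, attn_mask, dropout_p, is_causal = (1, 2, 4, 16), None, 0.0, True
tol = {"atol": 1e-3, "rtol": 1e-3}

q, k, v = (torch.randn(shape, dtype=torch.float32) for _ in range(3))

# Reference output the InfiniCore operator would be compared against.
ref = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)

# In the real test, `out` comes from infinicore.nn.functional.scaled_dot_product_attention.
out = ref  # placeholder so this sketch runs on its own
assert torch.allclose(out, ref, **tol)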
