Skip to content

Commit dcea681

Browse files
committed
issue/949 - feat: add silu_and_mul for Moore GPU, with passing tests
1 parent 3c8fb3c commit dcea681

18 files changed

Lines changed: 824 additions & 0 deletions

File tree

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
#include "ops/rms_norm.hpp"
1515
#include "ops/rope.hpp"
1616
#include "ops/silu.hpp"
17+
#include "ops/silu_and_mul.hpp"
1718
#include "ops/swiglu.hpp"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Graph-op class for the fused SiLU-and-mul (SwiGLU-style) activation,
// declared via the framework macro with signature (out, x).
INFINICORE_GRAPH_OP_CLASS(SiluAndMul, Tensor, Tensor);

// Allocating form: treats x as [..., 2*d] and returns a new [..., d] tensor
// (the implementation halves the last dimension; see silu_and_mul.cc).
Tensor silu_and_mul(Tensor x);

// Destination form: writes the result into a caller-provided `out`.
void silu_and_mul_(Tensor out, Tensor x);

} // namespace infinicore::op

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "infiniop/ops/rope.h"
2727
#include "infiniop/ops/sigmoid.h"
2828
#include "infiniop/ops/silu.h"
29+
#include "infiniop/ops/silu_and_mul.h"
2930
#include "infiniop/ops/softmax.h"
3031
#include "infiniop/ops/softplus.h"
3132
#include "infiniop/ops/sub.h"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef __INFINIOP_SILU_AND_MUL_API_H__
#define __INFINIOP_SILU_AND_MUL_API_H__

#include "../operator_descriptor.h"

/* Opaque handle to a SiluAndMul operator instance. */
typedef struct InfiniopDescriptor *infiniopSiluAndMulDescriptor_t;

/* Creates a descriptor binding the operator to the given output/input
 * tensor descriptors on the device owned by `handle`. */
__C __export infiniStatus_t infiniopCreateSiluAndMulDescriptor(
    infiniopHandle_t handle,
    infiniopSiluAndMulDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output,
    infiniopTensorDescriptor_t input);

/* Queries the scratch-space size required by infiniopSiluAndMul for this
 * descriptor (presumably in bytes — confirm against the other infiniop
 * workspace APIs). */
__C __export infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(
    infiniopSiluAndMulDescriptor_t desc,
    size_t *size);

/* Executes the operator. `workspace` must be at least the size reported by
 * infiniopGetSiluAndMulWorkspaceSize; `stream` is the backend-specific
 * device stream/queue. */
__C __export infiniStatus_t infiniopSiluAndMul(
    infiniopSiluAndMulDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input,
    void *stream);

/* Releases a descriptor created by infiniopCreateSiluAndMulDescriptor. */
__C __export infiniStatus_t infiniopDestroySiluAndMulDescriptor(
    infiniopSiluAndMulDescriptor_t desc);

#endif // __INFINIOP_SILU_AND_MUL_API_H__

python/infinicore/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .rms_norm import rms_norm
66
from .rope import RopeAlgo, rope
77
from .silu import silu
8+
from .silu_and_mul import silu_and_mul
89
from .swiglu import swiglu
910

1011
__all__ = [
@@ -17,4 +18,5 @@
1718
"embedding",
1819
"rope",
1920
"RopeAlgo",
21+
"silu_and_mul",
2022
]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def silu_and_mul(input: Tensor, out=None) -> Tensor:
    r"""Apply the SiLU and Mul (SwiGLU) function.

    Formula: output = SiLU(input_gate) * input_up
    Input shape: [..., 2*d], Output shape: [..., d]
    """

    # Destination form: write into the caller-provided tensor and return it.
    if out is not None:
        _infinicore.silu_and_mul_(out._underlying, input._underlying)
        return out

    # Allocating form: the native op creates and returns a new tensor.
    return Tensor(_infinicore.silu_and_mul(input._underlying))
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "infinicore/ops/silu_and_mul.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Generates the per-device dispatcher table for SiluAndMul (framework macro).
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SiluAndMul);

// Constructs the op node: asserts that `out` and `x` live on the same
// device, then dispatches to the backend registered for that device type.
SiluAndMul::SiluAndMul(Tensor out, Tensor x) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x);
}

// Entry point used by silu_and_mul_: records the op into the graph or runs
// it immediately — exact behavior is supplied by the RECORD_OR_RUN macro.
void SiluAndMul::execute(Tensor out, Tensor x) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(SiluAndMul, out, x);
}
16+
17+
// Allocating entry point: interprets `x` as [..., 2*d] (gate and up halves
// packed along the last dimension), allocates a [..., d] output with the
// same dtype/device, and runs the fused kernel into it.
//
// Throws std::runtime_error when the input is 0-dimensional or its last
// dimension is odd (i.e. cannot be split into two equal halves).
Tensor silu_and_mul(Tensor x) {
    Shape shape = x->shape();
    size_t ndim = x->ndim();

    // Fix: guard scalar inputs — `shape[ndim - 1]` would index with
    // (size_t)-1 when ndim == 0.
    if (ndim == 0) {
        throw std::runtime_error("SiluAndMul input must have at least 1 dimension.");
    }
    if (shape[ndim - 1] % 2 != 0) {
        throw std::runtime_error("SiluAndMul input last dim must be even.");
    }
    shape[ndim - 1] /= 2; // output keeps leading dims, halves the last

    auto out = Tensor::empty(shape, x->dtype(), x->device());
    silu_and_mul_(out, x);
    return out;
}
30+
31+
// Destination form: computes SiLU(gate) * up from `x` ([..., 2*d]) into the
// caller-provided `out` ([..., d]) by executing the graph op.
void silu_and_mul_(Tensor out, Tensor x) {
    SiluAndMul::execute(out, x);
}

} // namespace infinicore::op
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "../infiniop_impl.hpp"
#include "infinicore/ops/silu_and_mul.hpp"

namespace infinicore::op::silu_and_mul_impl::infiniop {

// Cachable infiniop descriptor type for SiluAndMul; the 100 is presumably
// the cache capacity — confirm against the macro definition.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SiluAndMul, 100);

// Everything run() needs, captured at plan time: the (shared, cached)
// descriptor plus graph handles to the workspace/output/input tensors.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, output, input;
};
12+
13+
// Plan step: resolves (or creates and caches) the infiniop descriptor for
// this output/input combination, allocates its workspace, and packages
// everything into a heap-allocated PlannedMeta.
// Ownership: the returned pointer is released by cleanup().
void *plan(Tensor output, Tensor input) {
    // Cache key derived from both tensors' metadata.
    size_t seed = hash_combine(output, input);

    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, SiluAndMul,
        seed, output->desc(), input->desc());

    // Declares `workspace`, presumably sized via the operator's
    // GetWorkspaceSize query — behavior supplied by the framework macro.
    INFINIOP_WORKSPACE_TENSOR(workspace, SiluAndMul, descriptor);

    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(output),
        graph::GraphTensor(input)};

    return planned;
}
30+
31+
// Execution step: invokes the infiniop kernel with the pointers captured at
// plan time, on the current context stream.
void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

    INFINICORE_CHECK_ERROR(infiniopSiluAndMul(
        planned->descriptor->desc,
        planned->workspace->data(),
        // NOTE(review): numel() is passed where the API takes
        // workspace_size; this matches only if the workspace tensor has
        // 1-byte elements — confirm what INFINIOP_WORKSPACE_TENSOR yields.
        planned->workspace->numel(),
        planned->output->data(),
        planned->input->data(),
        context::getStream()));
}
42+
43+
void cleanup(void **planned_meta_ptr) {
44+
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
45+
*planned_meta_ptr = nullptr;
46+
}
47+
48+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SiluAndMul, &plan, &run, &cleanup);
49+
50+
} // namespace infinicore::op::silu_and_mul_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "ops/rms_norm.hpp"
1919
#include "ops/rope.hpp"
2020
#include "ops/silu.hpp"
21+
#include "ops/silu_and_mul.hpp"
2122
#include "ops/swiglu.hpp"
2223

2324
namespace py = pybind11;
@@ -42,6 +43,7 @@ inline void bind(py::module &m) {
4243
bind_swiglu(m);
4344
bind_rope(m);
4445
bind_embedding(m);
46+
bind_silu_and_mul(m);
4547
}
4648

4749
} // namespace infinicore::ops
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/silu_and_mul.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Registers the silu_and_mul / silu_and_mul_ Python bindings on module `m`.
inline void bind_silu_and_mul(py::module &m) {
    // Allocating form: returns a new tensor with the last dim halved.
    m.def("silu_and_mul",
          &op::silu_and_mul,
          py::arg("input"),
          R"doc(
SiLU and Mul (SwiGLU) activation function.
Input should be [..., 2*d], output will be [..., d].
)doc");

    // Destination form: writes into the caller-provided output tensor.
    m.def("silu_and_mul_",
          &op::silu_and_mul_,
          py::arg("output"),
          py::arg("input"),
          R"doc(
In-place or destination-specified SiLU and Mul (SwiGLU) activation function.
)doc");
}

} // namespace infinicore::ops

0 commit comments

Comments
 (0)