-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathelu.py
More file actions
32 lines (23 loc) · 1.03 KB
/
elu.py
File metadata and controls
32 lines (23 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Exponential Linear Unit (ELU) activation element-wise.

    ELU(x) = x                      for x >= 0
             alpha * (exp(x) - 1)   for x <  0

    Args:
        input: Tensor to transform.
        alpha: Scale applied to the negative-side exponential term (default: 1.0).
        inplace: When True, overwrite ``input`` with the result (default: False).
        out: Optional pre-allocated destination tensor (keyword-only).

    Returns:
        ``input`` when ``inplace`` is True, ``out`` when one is supplied,
        otherwise a newly allocated result tensor.
    """
    # Prefer the ntops backend on CUDA/MUSA devices, but only when the caller
    # did not request writing into an explicit destination tensor.
    use_ntops = (
        infinicore.use_ntops
        and out is None
        and input.device.type in ("cuda", "musa")
    )
    if use_ntops:
        return infinicore.ntops.torch.elu(input, alpha=alpha, inplace=inplace)

    if inplace:
        # In-place: the input's underlying buffer serves as both source and
        # destination. NOTE(review): an `out` argument is ignored on this path.
        _infinicore.elu_(input._underlying, input._underlying, alpha)
        return input

    if out is not None:
        # Write the result into the caller-provided destination tensor.
        _infinicore.elu_(out._underlying, input._underlying, alpha)
        return out

    # Default path: allocate and return a fresh tensor.
    return Tensor(_infinicore.elu(input._underlying, alpha))