-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathsum.py
More file actions
38 lines (32 loc) · 1.06 KB
/
sum.py
File metadata and controls
38 lines (32 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def sum(
    input: Tensor,
    dim: int | tuple[int] | list[int] | None = None,
    keepdim=False,
    *,
    dtype=None,
    out=None,
) -> Tensor:
    r"""Sum the elements of ``input``, either over everything or along ``dim``.

    Args:
        input: Tensor to reduce.
        dim: Dimension(s) to reduce over. ``None`` selects the all-elements
            reduction. A one-element tuple/list is unwrapped to its single
            entry before being handed to the native kernel; any other value
            is passed through unchanged.
        keepdim: Forwarded to the ntops path and to the per-dimension native
            reduction. The native all-elements reduction does not receive it.
        dtype: Forwarded only to the ntops path; the native path never reads
            it.
        out: Optional destination tensor. On the native path, when given, the
            result is written into it and ``out`` itself is returned.

    Returns:
        On the ntops path, whatever ``infinicore.ntops.torch.sum`` returns.
        Otherwise ``out`` when it was supplied, or a new ``Tensor`` wrapping
        the native result.
    """
    # Accelerator devices are routed to the ntops backend when it is enabled.
    # The device is only inspected after ``use_ntops`` is confirmed truthy.
    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.sum(
            input, dim, keepdim=keepdim, dtype=dtype, out=out
        )

    # NOTE(review): ``dtype`` is silently unused from here on. Confirm whether
    # the native kernels are expected to honor it.
    src = input._underlying

    if dim is None:
        # All-elements reduction; ``keepdim`` is not passed to these kernels.
        if out is not None:
            _infinicore.sum_global_(src, out._underlying)
            return out
        return Tensor(_infinicore.sum_global(src))

    # Collapse a singleton sequence like ``(1,)`` or ``[1]`` to a bare int.
    reduce_dim = dim
    if isinstance(dim, (tuple, list)) and len(dim) == 1:
        reduce_dim = dim[0]

    if out is not None:
        _infinicore.sum_reduce_(src, out._underlying, reduce_dim, keepdim)
        return out
    return Tensor(_infinicore.sum_reduce(src, reduce_dim, keepdim))