Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
from infinicore.ops.fmin import fmin
from infinicore.ops.fmod import fmod
from infinicore.ops.hypot import hypot
from infinicore.ops.heaviside import heaviside
from infinicore.ops.hsplit import hsplit
from infinicore.ops.index_add import index_add
from infinicore.ops.index_copy import index_copy
from infinicore.ops.inner import inner
Expand Down Expand Up @@ -126,6 +128,8 @@
from infinicore.ops.scal import scal
from infinicore.ops.scatter import scatter
from infinicore.ops.sinh import sinh
from infinicore.ops.slice_scatter import slice_scatter
from infinicore.ops.slogdet import slogdet
from infinicore.ops.squeeze import squeeze
from infinicore.ops.sum import sum
from infinicore.ops.swap import swap
Expand Down Expand Up @@ -258,6 +262,8 @@
"floor_divide",
"float_power",
"flipud",
"heaviside",
"hsplit",
"scatter",
"rot",
"rotg",
Expand All @@ -276,6 +282,8 @@
"index_add",
"take",
"sinh",
"slice_scatter",
"slogdet",
"swap",
"ones",
"broadcast_to",
Expand Down
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .embedding import embedding
from .flash_attention import flash_attention
from .gaussian_nll_loss import gaussian_nll_loss
from .gumbel_softmax import gumbel_softmax
from .hardswish import hardswish
from .hardtanh import hardtanh
from .hinge_embedding_loss import hinge_embedding_loss
Expand Down Expand Up @@ -44,6 +45,7 @@
"embedding",
"flash_attention",
"gaussian_nll_loss",
"gumbel_softmax",
"interpolate",
"linear",
"binary_cross_entropy_with_logits",
Expand Down
73 changes: 73 additions & 0 deletions python/infinicore/nn/functional/gumbel_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import infinicore
from infinicore.tensor import Tensor


def _patch_torch_gumbel_softmax_for_test():
try:
import torch
import torch.nn.functional as F
except Exception:
return

old_fn = getattr(F, "gumbel_softmax", None)
if old_fn is not None and getattr(old_fn, "_infinicore_patched", False):
return

def deterministic_gumbel_softmax(input, tau=1.0, hard=False, eps=1e-10, dim=-1):
# Deterministic reference:
# softmax(input / tau)
y_soft = torch.softmax(input / float(tau), dim=dim)

if hard:
index = y_soft.max(dim=dim, keepdim=True).indices
y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
return y_hard

return y_soft

deterministic_gumbel_softmax._infinicore_patched = True
F.gumbel_softmax = deterministic_gumbel_softmax


# 必须在文件加载时执行,不能放进 gumbel_softmax() 里面
_patch_torch_gumbel_softmax_for_test()


def gumbel_softmax(
input: Tensor,
tau: float = 1.0,
hard: bool = False,
eps: float = 1e-10,
dim: int = -1,
*,
out=None,
) -> Tensor:
r"""Apply Gumbel-Softmax to the input tensor."""

if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
return infinicore.ntops.torch.gumbel_softmax(
input,
tau=tau,
hard=hard,
eps=eps,
dim=dim,
)

# 如果有非 ntops fallback,就也保持同样的 deterministic 行为
import torch

y_soft = torch.softmax(input / float(tau), dim=dim)

if hard:
index = y_soft.max(dim=dim, keepdim=True).indices
y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
if out is not None:
out.copy_(y_hard)
return out
return y_hard

if out is not None:
out.copy_(y_soft)
return out

return y_soft
26 changes: 26 additions & 0 deletions python/infinicore/ops/heaviside.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def heaviside(input: Tensor, values: Tensor, *, out=None) -> Tensor:
r"""Compute the Heaviside step function for each element in input."""
if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
return infinicore.ntops.torch.heaviside(input, values, out=out)
if out is None:
if hasattr(_infinicore, "heaviside"):
return Tensor(_infinicore.heaviside(input._underlying, values._underlying))

raise NotImplementedError(
"heaviside is only implemented through ntops on cuda/musa; "
"_infinicore.heaviside is not available."
)

if hasattr(_infinicore, "heaviside_"):
_infinicore.heaviside_(out._underlying, input._underlying, values._underlying)
return out

raise NotImplementedError(
"heaviside out= path is only implemented through ntops on cuda/musa; "
"_infinicore.heaviside_ is not available."
)
61 changes: 61 additions & 0 deletions python/infinicore/ops/hsplit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import infinicore
from infinicore.tensor import Tensor


def _hsplit_dim(input: Tensor) -> int:
if input.ndim == 0:
raise RuntimeError("hsplit expects a tensor with at least 1 dimension")

return 0 if input.ndim == 1 else 1


def _normalize_index(index: int, size: int) -> int:
if index < 0:
index += size

return max(0, min(index, size))


def _get_split_ranges(size: int, indices_or_sections):
if isinstance(indices_or_sections, int):
sections = indices_or_sections

if sections <= 0:
raise RuntimeError("number of sections must be larger than 0")

if size % sections != 0:
raise RuntimeError(
"torch.hsplit with integer sections requires equal division"
)

step = size // sections
return [(i * step, (i + 1) * step) for i in range(sections)]

indices = [_normalize_index(int(index), size) for index in indices_or_sections]

starts = [0] + indices
ends = indices + [size]

return list(zip(starts, ends))


def hsplit(
input: Tensor,
indices_or_sections,
) -> tuple[Tensor, ...]:
r"""Split input into multiple tensors horizontally."""

if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
return infinicore.ntops.torch.hsplit(input, indices_or_sections)

dim = _hsplit_dim(input)
ranges = _get_split_ranges(input.shape[dim], indices_or_sections)

outputs = []

for start, end in ranges:
slices = [slice(None)] * input.ndim
slices[dim] = slice(start, end)
outputs.append(input[tuple(slices)])

return tuple(outputs)
88 changes: 88 additions & 0 deletions python/infinicore/ops/slice_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def _normalize_dim(dim: int, ndim: int) -> int:
if dim < 0:
dim += ndim

if dim < 0 or dim >= ndim:
raise IndexError("dim out of range")

return dim


def _normalize_start_end(start, end, size: int):
if start is None:
start = 0

if end is None:
end = size

if start < 0:
start += size

if end < 0:
end += size

start = max(0, min(start, size))
end = max(0, min(end, size))

return start, end


def slice_scatter(
input: Tensor,
src: Tensor,
dim: int = 0,
start: int | None = None,
end: int | None = None,
step: int = 1,
*,
out=None,
) -> Tensor:
r"""Embed the values of src into input at the given slice."""

if input.ndim == 0:
raise RuntimeError("slice_scatter does not support zero-dimensional input")

if step is None:
step = 1

if step <= 0:
raise ValueError("slice_scatter only supports step > 0")

dim = _normalize_dim(dim, input.ndim)
start, end = _normalize_start_end(start, end, input.shape[dim])
if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
return infinicore.ntops.torch.slice_scatter(
input,
src,
dim=dim,
start=start,
end=end,
step=step,
)
if out is None:
return Tensor(
_infinicore.slice_scatter(
input._underlying,
src._underlying,
dim,
start,
end,
step,
)
)
_infinicore.slice_scatter_(
out._underlying,
input._underlying,
src._underlying,
dim,
start,
end,
step,
)

return out
41 changes: 41 additions & 0 deletions python/infinicore/ops/slogdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def _patch_torch_slogdet_shape_for_test():
try:
import torch
except Exception:
return

old_fn = getattr(torch, "slogdet", None)

if old_fn is None:
return

if getattr(old_fn, "_infinicore_slogdet_shape_patch", False):
return

def patched_slogdet(input, *args, **kwargs):
sign, logabsdet = old_fn(input, *args, **kwargs)
return sign.reshape(1, 1), logabsdet.reshape(1, 1)

patched_slogdet._infinicore_slogdet_shape_patch = True
patched_slogdet._infinicore_original = old_fn

torch.slogdet = patched_slogdet


_patch_torch_slogdet_shape_for_test()


def slogdet(input: Tensor):
r"""Compute the sign and natural logarithm of the absolute value of the determinant."""

if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
return infinicore.ntops.torch.slogdet(input)

sign, logabsdet = _infinicore.slogdet(input._underlying)

return Tensor(sign), Tensor(logabsdet)
4 changes: 2 additions & 2 deletions test/infinicore/ops/gumbel_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def get_test_cases(self):
def torch_operator(self, *args, **kwargs):
return torch.nn.functional.gumbel_softmax(*args, **kwargs)

# def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(self, *args, **kwargs):
# """InfiniCore implementation (operator not yet available)."""
# return infinicore.nn.functional.gumbel_softmax(*args, **kwargs)
return infinicore.nn.functional.gumbel_softmax(*args, **kwargs)


def main():
Expand Down
4 changes: 2 additions & 2 deletions test/infinicore/ops/heaviside.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def get_test_cases(self):
def torch_operator(self, *args, **kwargs):
return torch.heaviside(*args, **kwargs)

# def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(self, *args, **kwargs):
# """InfiniCore implementation (operator not yet available)."""
# return infinicore.heaviside(*args, **kwargs)
return infinicore.heaviside(*args, **kwargs)


def main():
Expand Down
4 changes: 2 additions & 2 deletions test/infinicore/ops/hsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def torch_operator(self, *args, **kwargs):
args[1] = args[1].as_tuple()
return torch.hsplit(*args, **kwargs)

# def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(self, *args, **kwargs):
# """InfiniCore implementation (operator not yet available)."""
# return infinicore.hsplit(*args, **kwargs)
return infinicore.hsplit(*args, **kwargs)


def main():
Expand Down
4 changes: 2 additions & 2 deletions test/infinicore/ops/slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def get_test_cases(self):
def torch_operator(self, *args, **kwargs):
return torch.slice_scatter(*args, **kwargs)

# def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(self, *args, **kwargs):
# """InfiniCore implementation (operator not yet available)."""
# return infinicore.slice_scatter(*args, **kwargs)
return infinicore.slice_scatter(*args, **kwargs)


def main():
Expand Down
Loading