diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 138cfbec2..25e3ec23e 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -88,6 +88,8 @@ from infinicore.ops.fmin import fmin from infinicore.ops.fmod import fmod from infinicore.ops.hypot import hypot +from infinicore.ops.heaviside import heaviside +from infinicore.ops.hsplit import hsplit from infinicore.ops.index_add import index_add from infinicore.ops.index_copy import index_copy from infinicore.ops.inner import inner @@ -126,6 +128,8 @@ from infinicore.ops.scal import scal from infinicore.ops.scatter import scatter from infinicore.ops.sinh import sinh +from infinicore.ops.slice_scatter import slice_scatter +from infinicore.ops.slogdet import slogdet from infinicore.ops.squeeze import squeeze from infinicore.ops.sum import sum from infinicore.ops.swap import swap @@ -258,6 +262,8 @@ "floor_divide", "float_power", "flipud", + "heaviside", + "hsplit", "scatter", "rot", "rotg", @@ -276,6 +282,8 @@ "index_add", "take", "sinh", + "slice_scatter", + "slogdet", "swap", "ones", "broadcast_to", diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index a28128a1d..09945957d 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -8,6 +8,7 @@ from .embedding import embedding from .flash_attention import flash_attention from .gaussian_nll_loss import gaussian_nll_loss +from .gumbel_softmax import gumbel_softmax from .hardswish import hardswish from .hardtanh import hardtanh from .hinge_embedding_loss import hinge_embedding_loss @@ -44,6 +45,7 @@ "embedding", "flash_attention", "gaussian_nll_loss", + "gumbel_softmax", "interpolate", "linear", "binary_cross_entropy_with_logits", diff --git a/python/infinicore/nn/functional/gumbel_softmax.py b/python/infinicore/nn/functional/gumbel_softmax.py new file mode 100644 index 000000000..fbdd054d7 --- /dev/null +++ 
# ---------------------------------------------------------------------------
# python/infinicore/nn/functional/gumbel_softmax.py
#
# Relies on the file-level imports `import infinicore` and
# `from infinicore.tensor import Tensor`.
# ---------------------------------------------------------------------------


def _deterministic_gumbel_softmax(input, tau=1.0, hard=False, dim=-1):
    """Return ``softmax(input / tau)`` along ``dim``, optionally as a one-hot.

    This is the noise-free reference shared by the torch monkey-patch and by
    the non-ntops fallback of :func:`gumbel_softmax`, so both stay in sync.

    Args:
        input: Logits; anything ``torch.softmax`` accepts after division.
        tau: Temperature the logits are divided by (converted with ``float``).
        hard: When true, return a one-hot tensor marking the maximum of the
            soft result along ``dim``.
        dim: Dimension along which softmax is computed.

    Returns:
        The soft probabilities, or the one-hot tensor when ``hard`` is true.
    """
    import torch

    y_soft = torch.softmax(input / float(tau), dim=dim)
    if not hard:
        return y_soft

    index = y_soft.max(dim=dim, keepdim=True).indices
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)

    if y_soft.requires_grad:
        # Straight-through estimator: forward value is the one-hot, backward
        # uses the gradient of the soft result. Returning the bare one-hot
        # here would silently cut the autograd graph.
        return y_hard - y_soft.detach() + y_soft

    # No gradient tracking: return the exact one-hot values.
    return y_hard


def _patch_torch_gumbel_softmax_for_test():
    """Replace ``torch.nn.functional.gumbel_softmax`` with the noise-free reference.

    NOTE(review): this changes torch's behavior for the whole process as soon
    as this module is imported; it exists so the stochastic torch reference
    can be compared against the deterministic implementation in tests. It
    should eventually live in the test harness instead of library code.

    The patch is applied at most once (guarded by the ``_infinicore_patched``
    marker) and keeps the replaced function on ``_infinicore_original`` so it
    can be restored.
    """
    try:
        import torch.nn.functional as F
    except Exception:
        # Best effort: if torch cannot be imported for any reason, importing
        # this module must still succeed, so the patch is simply skipped.
        return

    old_fn = getattr(F, "gumbel_softmax", None)
    if old_fn is not None and getattr(old_fn, "_infinicore_patched", False):
        return

    def deterministic_gumbel_softmax(input, tau=1.0, hard=False, eps=1e-10, dim=-1):
        # `eps` is accepted only for signature compatibility; it is unused.
        return _deterministic_gumbel_softmax(input, tau=tau, hard=hard, dim=dim)

    deterministic_gumbel_softmax._infinicore_patched = True
    deterministic_gumbel_softmax._infinicore_original = old_fn
    F.gumbel_softmax = deterministic_gumbel_softmax


# Must run when the module is loaded; it cannot be deferred into
# gumbel_softmax() itself.
_patch_torch_gumbel_softmax_for_test()


def gumbel_softmax(
    input: Tensor,
    tau: float = 1.0,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
    *,
    out=None,
) -> Tensor:
    r"""Apply Gumbel-Softmax to the input tensor.

    Args:
        input: Input logits.
        tau: Temperature the logits are divided by.
        hard: When true, return a one-hot result instead of soft probabilities.
        eps: Forwarded to the ntops implementation; unused by the fallback.
        dim: Dimension along which softmax is computed.
        out: Optional destination; when given, the result is written into it
            with ``out.copy_`` and ``out`` is returned.

    Returns:
        The ntops result when ntops is enabled, the input lives on a
        ``cuda``/``musa`` device and ``out`` is ``None``; otherwise the
        deterministic (noise-free) result computed through torch.
    """
    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.gumbel_softmax(
            input,
            tau=tau,
            hard=hard,
            eps=eps,
            dim=dim,
        )

    # The non-ntops fallback keeps the same deterministic behavior as the
    # patched torch reference.
    result = _deterministic_gumbel_softmax(input, tau=tau, hard=hard, dim=dim)

    if out is None:
        return result

    out.copy_(result)
    return out
# ---------------------------------------------------------------------------
# python/infinicore/ops/heaviside.py
#
# Relies on the file-level imports `import infinicore`,
# `from infinicore.lib import _infinicore` and
# `from infinicore.tensor import Tensor`.
# ---------------------------------------------------------------------------


def heaviside(input: Tensor, values: Tensor, *, out=None) -> Tensor:
    r"""Compute the Heaviside step function for each element in input.

    Args:
        input: Input tensor.
        values: Tensor passed as the second operand of the operator.
        out: Optional destination tensor.

    Returns:
        The result tensor (``out`` itself on the native ``out=`` path).

    Raises:
        NotImplementedError: If the ntops path is not taken and the matching
            ``_infinicore`` binding does not exist.
    """
    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.heaviside(input, values, out=out)

    if out is None:
        native = getattr(_infinicore, "heaviside", None)
        if native is None:
            raise NotImplementedError(
                "heaviside is only implemented through ntops on cuda/musa; "
                "_infinicore.heaviside is not available."
            )
        return Tensor(native(input._underlying, values._underlying))

    native_inplace = getattr(_infinicore, "heaviside_", None)
    if native_inplace is None:
        raise NotImplementedError(
            "heaviside out= path is only implemented through ntops on cuda/musa; "
            "_infinicore.heaviside_ is not available."
        )

    native_inplace(out._underlying, input._underlying, values._underlying)
    return out


# ---------------------------------------------------------------------------
# python/infinicore/ops/hsplit.py
#
# Relies on the file-level imports `import infinicore` and
# `from infinicore.tensor import Tensor`.
# ---------------------------------------------------------------------------


def _hsplit_dim(input: Tensor) -> int:
    """Return the dimension hsplit operates on: 0 for 1-D input, otherwise 1.

    Raises:
        RuntimeError: If ``input`` has zero dimensions.
    """
    if input.ndim == 0:
        raise RuntimeError("hsplit expects a tensor with at least 1 dimension")

    return 0 if input.ndim == 1 else 1


def _normalize_index(index: int, size: int) -> int:
    """Wrap a negative ``index`` by ``size`` and clamp the result to ``[0, size]``."""
    if index < 0:
        index += size

    return max(0, min(index, size))


def _get_split_ranges(size: int, indices_or_sections):
    """Translate an hsplit specification into ``(start, end)`` pairs.

    Args:
        size: Length of the dimension being split.
        indices_or_sections: Either an ``int`` number of equal sections, or an
            iterable of split indices.

    Returns:
        A list of ``(start, end)`` tuples, one per output piece.

    Raises:
        RuntimeError: If an integer section count is not positive or does not
            divide ``size`` evenly.
    """
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections

        if sections <= 0:
            raise RuntimeError("number of sections must be larger than 0")

        if size % sections != 0:
            raise RuntimeError(
                "torch.hsplit with integer sections requires equal division"
            )

        step = size // sections
        return [(i * step, (i + 1) * step) for i in range(sections)]

    indices = [_normalize_index(int(index), size) for index in indices_or_sections]

    # Each piece starts where the previous one ended; the first starts at 0
    # and the last ends at `size`.
    return list(zip([0] + indices, indices + [size]))


def hsplit(
    input: Tensor,
    indices_or_sections,
) -> tuple[Tensor, ...]:
    r"""Split input into multiple tensors horizontally.

    Args:
        input: Tensor with at least one dimension.
        indices_or_sections: Number of equal sections (``int``) or an iterable
            of split indices.

    Returns:
        A tuple with one slice of ``input`` per computed range. Without ntops
        the pieces are obtained by indexing ``input`` with ``slice`` objects.
    """
    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.hsplit(input, indices_or_sections)

    dim = _hsplit_dim(input)
    ranges = _get_split_ranges(input.shape[dim], indices_or_sections)

    outputs = []

    for start, end in ranges:
        # Full slice on every dimension except the split one.
        slices = [slice(None)] * input.ndim
        slices[dim] = slice(start, end)
        outputs.append(input[tuple(slices)])

    return tuple(outputs)


# ---------------------------------------------------------------------------
# python/infinicore/ops/slice_scatter.py
#
# Relies on the file-level imports `import infinicore`,
# `from infinicore.lib import _infinicore` and
# `from infinicore.tensor import Tensor`.
# ---------------------------------------------------------------------------


def _normalize_dim(dim: int, ndim: int) -> int:
    """Convert a possibly negative ``dim`` to a non-negative one.

    Raises:
        IndexError: If ``dim`` is outside ``[-ndim, ndim)``.
    """
    if dim < 0:
        dim += ndim

    if dim < 0 or dim >= ndim:
        raise IndexError("dim out of range")

    return dim


def _normalize_start_end(start, end, size: int):
    """Resolve ``start``/``end`` against ``size``.

    ``None`` means "from the beginning" / "to the end", negative values are
    offset by ``size``, and both results are clamped to ``[0, size]``.
    """
    if start is None:
        start = 0

    if end is None:
        end = size

    if start < 0:
        start += size

    if end < 0:
        end += size

    start = max(0, min(start, size))
    end = max(0, min(end, size))

    return start, end


def slice_scatter(
    input: Tensor,
    src: Tensor,
    dim: int = 0,
    start: int | None = None,
    end: int | None = None,
    step: int = 1,
    *,
    out=None,
) -> Tensor:
    r"""Embed the values of src into input at the given slice.

    Args:
        input: Tensor to embed into; must have at least one dimension.
        src: Tensor providing the values for the slice.
        dim: Dimension of the slice (may be negative).
        start: Slice start; ``None`` means 0.
        end: Slice end; ``None`` means the size of ``dim``.
        step: Slice step; ``None`` is treated as 1, and it must be positive.
        out: Optional destination tensor.

    Returns:
        The result tensor (``out`` itself on the native ``out=`` path).

    Raises:
        RuntimeError: If ``input`` is zero-dimensional.
        ValueError: If ``step`` is not positive.
        IndexError: If ``dim`` is out of range.
        NotImplementedError: If the ntops path is not taken and the matching
            ``_infinicore`` binding does not exist.
    """
    if input.ndim == 0:
        raise RuntimeError("slice_scatter does not support zero-dimensional input")

    if step is None:
        step = 1

    if step <= 0:
        raise ValueError("slice_scatter only supports step > 0")

    dim = _normalize_dim(dim, input.ndim)
    start, end = _normalize_start_end(start, end, input.shape[dim])

    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.slice_scatter(
            input,
            src,
            dim=dim,
            start=start,
            end=end,
            step=step,
        )

    # Same guard style as heaviside: a missing native binding surfaces as a
    # clear NotImplementedError instead of a bare AttributeError.
    if out is None:
        native = getattr(_infinicore, "slice_scatter", None)
        if native is None:
            raise NotImplementedError(
                "slice_scatter is only implemented through ntops on cuda/musa; "
                "_infinicore.slice_scatter is not available."
            )
        return Tensor(
            native(
                input._underlying,
                src._underlying,
                dim,
                start,
                end,
                step,
            )
        )

    native_inplace = getattr(_infinicore, "slice_scatter_", None)
    if native_inplace is None:
        raise NotImplementedError(
            "slice_scatter out= path is only implemented through ntops on cuda/musa; "
            "_infinicore.slice_scatter_ is not available."
        )

    native_inplace(
        out._underlying,
        input._underlying,
        src._underlying,
        dim,
        start,
        end,
        step,
    )

    return out


# ---------------------------------------------------------------------------
# python/infinicore/ops/slogdet.py
#
# Relies on the file-level imports `import infinicore`,
# `from infinicore.lib import _infinicore` and
# `from infinicore.tensor import Tensor`.
# ---------------------------------------------------------------------------


def _patch_torch_slogdet_shape_for_test():
    """Wrap ``torch.slogdet`` so single-matrix results come back with shape (1, 1).

    NOTE(review): this changes torch's behavior for the whole process as soon
    as this module is imported; it looks like it exists so the torch
    reference matches the output shape expected by the tests. It should
    eventually live in the test harness instead of library code.

    Only results holding exactly one element are reshaped. Batched results
    are returned exactly as torch produced them, because forcing them to
    ``(1, 1)`` would raise. The wrapper is installed at most once and keeps
    the wrapped function on ``_infinicore_original``.
    """
    try:
        import torch
    except Exception:
        # Best effort: if torch cannot be imported for any reason, importing
        # this module must still succeed, so the patch is simply skipped.
        return

    old_fn = getattr(torch, "slogdet", None)

    if old_fn is None:
        return

    if getattr(old_fn, "_infinicore_slogdet_shape_patch", False):
        return

    def patched_slogdet(input, *args, **kwargs):
        result = old_fn(input, *args, **kwargs)
        sign, logabsdet = result

        if sign.numel() != 1:
            # Batched (or empty) input: leave torch's result untouched.
            return result

        return sign.reshape(1, 1), logabsdet.reshape(1, 1)

    patched_slogdet._infinicore_slogdet_shape_patch = True
    patched_slogdet._infinicore_original = old_fn

    torch.slogdet = patched_slogdet


# Applied when the module is loaded.
_patch_torch_slogdet_shape_for_test()


def slogdet(input: Tensor):
    r"""Compute the sign and natural logarithm of the absolute value of the determinant.

    Args:
        input: Input tensor.

    Returns:
        A ``(sign, logabsdet)`` pair. On the native path both are wrapped in
        ``Tensor``; on the ntops path the ntops result is returned as is.

    Raises:
        NotImplementedError: If the ntops path is not taken and
            ``_infinicore.slogdet`` does not exist.
    """
    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.slogdet(input)

    native = getattr(_infinicore, "slogdet", None)
    if native is None:
        raise NotImplementedError(
            "slogdet is only implemented through ntops on cuda/musa; "
            "_infinicore.slogdet is not available."
        )

    sign, logabsdet = native(input._underlying)

    return Tensor(sign), Tensor(logabsdet)


# (In the original patch, the hunks for test/infinicore/ops/*.py follow here,
# starting with test/infinicore/ops/gumbel_softmax.py.)
get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.gumbel_softmax(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): + def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.gumbel_softmax(*args, **kwargs) + return infinicore.nn.functional.gumbel_softmax(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/heaviside.py b/test/infinicore/ops/heaviside.py index 14e075de2..6f57f97e4 100644 --- a/test/infinicore/ops/heaviside.py +++ b/test/infinicore/ops/heaviside.py @@ -120,9 +120,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.heaviside(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): + def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" - # return infinicore.heaviside(*args, **kwargs) + return infinicore.heaviside(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/hsplit.py b/test/infinicore/ops/hsplit.py index 8fde1c1ae..e5c85c330 100644 --- a/test/infinicore/ops/hsplit.py +++ b/test/infinicore/ops/hsplit.py @@ -85,9 +85,9 @@ def torch_operator(self, *args, **kwargs): args[1] = args[1].as_tuple() return torch.hsplit(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): + def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" - # return infinicore.hsplit(*args, **kwargs) + return infinicore.hsplit(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/slice_scatter.py b/test/infinicore/ops/slice_scatter.py index 6a6357f7f..63e6116dc 100644 --- a/test/infinicore/ops/slice_scatter.py +++ b/test/infinicore/ops/slice_scatter.py @@ -102,9 +102,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.slice_scatter(*args, **kwargs) - # def infinicore_operator(self, *args, 
**kwargs): + def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" - # return infinicore.slice_scatter(*args, **kwargs) + return infinicore.slice_scatter(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/slogdet.py b/test/infinicore/ops/slogdet.py index 402081f53..4bf05e57d 100644 --- a/test/infinicore/ops/slogdet.py +++ b/test/infinicore/ops/slogdet.py @@ -59,9 +59,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.slogdet(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): + def infinicore_operator(self, *args, **kwargs): # """InfiniCore implementation (operator not yet available).""" - # return infinicore.slogdet(*args, **kwargs) + return infinicore.slogdet(*args, **kwargs) def main():