8 changes: 8 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -12,8 +12,10 @@
    dropout,
    eq,
    exp,
    gcd,
    ge,
    gelu,
    glu,
    gt,
    isinf,
    isnan,
@@ -24,12 +26,14 @@
    mul,
    ne,
    neg,
    nll_loss,
    pow,
    relu,
    rms_norm,
    rotary_position_embedding,
    rsqrt,
    scaled_dot_product_attention,
    select_scatter,
    sigmoid,
    silu,
    sin,
@@ -76,4 +80,8 @@
"softmax",
"sub",
"tanh",
"gcd",
"select_scatter",
"nll_loss",
"glu",
]
50 changes: 50 additions & 0 deletions src/ntops/kernels/gcd.py
@@ -0,0 +1,50 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    a = ntl.abs(ntl.cast(input, ntl.int64))
    b = ntl.abs(ntl.cast(other, ntl.int64))

    # Iterate while any lane in the block still has a nonzero divisor.
    while ntl.max(ntl.cast(b != 0, ntl.int32)) == 1:
        # Four Euclidean steps unrolled per convergence check; the mask
        # freezes lanes that already finished (b == 0), and safe_b keeps
        # the modulo from dividing by zero in those lanes.
        mask = b != 0
        safe_b = ntl.where(mask, b, 1)
        r = a % safe_b
        a = ntl.where(mask, b, a)
        b = ntl.where(mask, r, b)

        mask = b != 0
        safe_b = ntl.where(mask, b, 1)
        r = a % safe_b
        a = ntl.where(mask, b, a)
        b = ntl.where(mask, r, b)

        mask = b != 0
        safe_b = ntl.where(mask, b, 1)
        r = a % safe_b
        a = ntl.where(mask, b, a)
        b = ntl.where(mask, r, b)

        mask = b != 0
        safe_b = ntl.where(mask, b, 1)
        r = a % safe_b
        a = ntl.where(mask, b, a)
        b = ntl.where(mask, r, b)

    output = ntl.cast(a, output.dtype)


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, other=0),
        Tensor(ndim, dtype=dtype, other=0),
        Tensor(ndim, dtype=dtype),
    )

    return arrangement_, application, tensors
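A scalar model of what the masked loop computes per element may help review: `ntl.where` holds finished lanes fixed while the rest keep reducing, and each `while` check covers four unrolled steps. A minimal pure-Python sketch (illustrative only, not part of the PR):

```python
# Hypothetical scalar analogue of the masked, 4x-unrolled Euclidean loop.
def gcd_scalar(a, b):
    a, b = abs(a), abs(b)
    while b != 0:  # kernel form: ntl.max(ntl.cast(b != 0, ntl.int32)) == 1
        for _ in range(4):  # the kernel unrolls these four steps per check
            if b != 0:  # kernel form: mask = b != 0, with safe_b avoiding % 0
                a, b = b, a % b
    return a

assert gcd_scalar(12, 18) == 6
assert gcd_scalar(0, 7) == 7
```

The unrolling trades a few redundant masked steps for fewer cross-lane convergence checks, which are comparatively expensive.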
49 changes: 49 additions & 0 deletions src/ntops/kernels/glu.py
@@ -0,0 +1,49 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, output, dim_size, dim, block_size):
    ndim = input.ndim
    if dim < 0:
        dim = ndim + dim

    tile_shape = [1] * ndim
    tile_shape[dim] = block_size

    in_t = input.tile(tuple(tile_shape))
    out_t = output.tile(tuple(tile_shape))

    # Squeeze away the ndim - 1 singleton tile axes, tracking how the
    # target dim shifts left as leading axes are removed.
    for _ in range(ndim - 1):
        in_t.dtype = in_t.dtype.squeeze(0 if dim != 0 else 1)
        out_t.dtype = out_t.dtype.squeeze(0 if dim != 0 else 1)

        if dim > 0:
            dim -= 1

    return in_t, out_t, dim_size


def application(input, output, dim_size):
    half = dim_size // 2

    for i in range(half):
        a = ntl.cast(input[i], ntl.float32)
        b = ntl.cast(input[i + half], ntl.float32)

        res = a * ntl.sigmoid(b)

        output[i] = ntl.cast(res, output.dtype)


def premake(ndim, dim, dim_size, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0, constexpr=True, value=dim_size),
    )

    return arrangement_, application, tensors
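The application implements the usual gated linear unit: the first half of the slice along `dim` is gated by the sigmoid of the second half. A quick reference check against PyTorch semantics (shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

# GLU reference: split along dim, then first_half * sigmoid(second_half).
x = torch.randn(4, 8)
a, b = x.chunk(2, dim=-1)
manual = a * torch.sigmoid(b)
assert torch.allclose(manual, F.glu(x, dim=-1))
```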
7 changes: 5 additions & 2 deletions src/ntops/kernels/gt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
 import functools
 
+import ninetoothed.language as ntl
 from ninetoothed import Tensor
 
 from ntops.kernels.element_wise import arrangement
 
 
 def application(input, other, output):
-    output = input > other  # noqa: F841
+    tmp = input > other
+    result = ntl.cast(tmp, output.dtype)
+    output = result


def premake(ndim, dtype=None, block_size=None):
@@ -15,7 +18,7 @@ def premake(ndim, dtype=None, block_size=None):
     tensors = (
         Tensor(ndim, dtype=dtype),
         Tensor(ndim, dtype=dtype),
-        Tensor(ndim, dtype=dtype),
+        Tensor(ndim),
     )
 
     return arrangement_, application, tensors
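The rewrite makes the cast to the output dtype explicit and drops the `dtype=dtype` constraint on the output tensor, since a comparison's natural result type is boolean rather than the input dtype, matching PyTorch:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 2.0, 2.0])
# torch.gt returns a bool tensor regardless of the inputs' dtype.
assert torch.gt(x, y).dtype == torch.bool
```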
127 changes: 127 additions & 0 deletions src/ntops/kernels/nll_loss.py
@@ -0,0 +1,127 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement_loss(
    input,
    target,
    out_loss,
    out_weight,
    weight,
    ignore_index,
    C_val_Tensor,
    C_pow2_Tensor,
    has_weight_Tensor,
    C_pow2,
):
    input_t = input.tile((1, C_pow2))
    target_t = target.tile((1,))
    out_l_t = out_loss.tile((1,))
    out_w_t = out_weight.tile((1,))

    weight_t = weight.tile((C_pow2,)).expand((input_t.shape[0],))

    return (
        input_t,
        target_t,
        out_l_t,
        out_w_t,
        weight_t,
        ignore_index,
        C_val_Tensor,
        C_pow2_Tensor,
        has_weight_Tensor,
    )


def application_loss(
    input, target, out_loss, out_weight, weight, ignore_index, C_val, C_pow2, has_weight
):
    ii = ntl.cast(ignore_index, ntl.int32)
    C_num = ntl.cast(C_val, ntl.int32)

    col_indices = ntl.arange(0, C_pow2)
    t_val = ntl.cast(target, ntl.int32)

    match_mask = (t_val == col_indices) & (col_indices < C_num)
    is_ignore = (t_val == ii) | (t_val < 0) | (t_val >= C_num)

    input_f32 = ntl.cast(input, ntl.float32)
    prob_vec = ntl.where(match_mask, input_f32, 0.0)
    selected_prob = ntl.sum(prob_vec)

    if has_weight:
        weight_f32 = ntl.cast(weight, ntl.float32)
        weight_vec = ntl.where(match_mask, weight_f32, 0.0)
        selected_weight = ntl.sum(weight_vec)
    else:
        selected_weight = 1.0

    loss_val = 0.0 - (selected_prob * selected_weight)

    out_loss = ntl.where(is_ignore, 0.0, loss_val)
    out_weight = ntl.where(is_ignore, 0.0, selected_weight)


def premake_loss(ignore_index, C_val, C_pow2, has_weight, dtype=None):
    arrangement_ = functools.partial(arrangement_loss, C_pow2=C_pow2)

    tensors = (
        Tensor(2, dtype=dtype, other=0.0),  # input
        Tensor(1, dtype=ninetoothed.int64),  # target
        Tensor(1, dtype=dtype),  # out_loss
        Tensor(1, dtype=dtype),  # out_weight
        Tensor(1, dtype=dtype, other=0.0),  # weight
        Tensor(0, constexpr=True, value=ignore_index),
        Tensor(0, constexpr=True, value=C_val),
        Tensor(0, constexpr=True, value=C_pow2),
        Tensor(0, constexpr=True, value=has_weight),
    )
    return arrangement_, application_loss, tensors


def arrangement_reduce(input, output, block_size):
    input_t = input.tile((block_size,))
    output_t = output.tile((1,))
    return input_t, output_t


def application_reduce(input, output):
    accumulator = 0.0
    for i in range(input.shape[0]):
        accumulator += ntl.cast(input[i], ntl.float32)
    output[0] = ntl.cast(accumulator, output.dtype)


def premake_reduce(dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement_reduce, block_size=block_size)
    tensors = (
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # input
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
    )

    return arrangement_, application_reduce, tensors


def arrangement_div(loss_sum, weight_sum, output):
    return loss_sum.tile((1,)), weight_sum.tile((1,)), output.tile((1,))


def application_div(loss_sum, weight_sum, output):
    l = ntl.cast(loss_sum[0], ntl.float32)
    w = ntl.cast(weight_sum[0], ntl.float32)
    res = ntl.where(w > 0.0, l / w, 0.0)
    output[0] = ntl.cast(res, output.dtype)


def premake_div(dtype=None):
    arrangement_ = functools.partial(arrangement_div)
    tensors = (
        Tensor(1, dtype=dtype),  # loss_sum
        Tensor(1, dtype=dtype),  # weight_sum
        Tensor(1, dtype=dtype),  # output
    )
    return arrangement_, application_div, tensors
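The three kernels mirror the standard weighted-mean NLL decomposition: a per-sample stage selects `-log_prob * weight` (zeroed for ignored targets), a reduce stage sums losses and weights, and a div stage takes their ratio. A plain-PyTorch sketch of that decomposition (variable names are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(6, 5), dim=1)
target = torch.tensor([0, 2, 4, 1, 3, 2])
weight = torch.rand(5)
ignore_index = 2

valid = target != ignore_index
picked = log_probs[torch.arange(6), target]                # selected_prob
w = torch.where(valid, weight[target], torch.tensor(0.0))  # out_weight stage
loss = torch.where(valid, -picked * w, torch.tensor(0.0))  # out_loss stage
result = loss.sum() / w.sum()                              # reduce + div stages

expected = F.nll_loss(log_probs, target, weight=weight,
                      ignore_index=ignore_index, reduction="mean")
assert torch.allclose(result, expected)
```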
52 changes: 52 additions & 0 deletions src/ntops/kernels/select_scatter.py
@@ -0,0 +1,52 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(input, src, output, index, dim_size_pow2, dim, block_size):
    ndim = input.ndim
    if dim < 0:
        dim += ndim
    non_target_dims = tuple(i for i in range(ndim) if i != dim)

    def _arrangement(t):
        return t.permute(non_target_dims + (dim,)).flatten(end_dim=-1)

    # (Remaining, Dim_Size)
    input_arranged = _arrangement(input).tile((block_size, -1)).squeeze(1)
    src_arranged = _arrangement(src).tile((block_size, -1)).squeeze(1)
    output_arranged = _arrangement(output).tile((block_size, -1)).squeeze(1)

    return input_arranged, src_arranged, output_arranged, index, dim_size_pow2


def application(input, src, output, target_index, dim_size_pow2):
    col_indices = ntl.arange(0, dim_size_pow2)

    col_indices = ntl.expand_dims(col_indices, 0)
    col_indices = ntl.broadcast_to(col_indices, (input.shape[0], dim_size_pow2))

    actual_dim_size = input.shape[1]

    match_mask = col_indices == ntl.cast(target_index, ntl.int32)
    valid_mask = col_indices < ntl.cast(actual_dim_size, ntl.int32)

    final_mask = match_mask & valid_mask

    output = ntl.where(
        final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype)
    )


def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),
        Tensor(0, constexpr=True, value=index),
        Tensor(0, constexpr=True, value=dim_size_pow2),
    )
    return arrangement_, application, tensors
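The column-index mask reproduces `torch.select_scatter` semantics: along `dim`, the slice at `index` is taken from `src` and every other position is copied from `input`. A reference sketch:

```python
import torch

x = torch.zeros(3, 4)
src = torch.ones(4)
out = torch.select_scatter(x, src, dim=0, index=1)

manual = x.clone()
manual[1] = src  # only the selected row changes
assert torch.equal(out, manual)
```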
8 changes: 8 additions & 0 deletions src/ntops/torch/__init__.py
@@ -11,8 +11,10 @@
from ntops.torch.dropout import dropout
from ntops.torch.eq import eq
from ntops.torch.exp import exp
from ntops.torch.gcd import gcd
from ntops.torch.ge import ge
from ntops.torch.gelu import gelu
from ntops.torch.glu import glu
from ntops.torch.gt import gt
from ntops.torch.isinf import isinf
from ntops.torch.isnan import isnan
@@ -24,12 +26,14 @@
from ntops.torch.mul import mul
from ntops.torch.ne import ne
from ntops.torch.neg import neg
from ntops.torch.nll_loss import nll_loss
from ntops.torch.pow import pow
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
from ntops.torch.rsqrt import rsqrt
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
from ntops.torch.select_scatter import select_scatter
from ntops.torch.sigmoid import sigmoid
from ntops.torch.silu import silu
from ntops.torch.sin import sin
@@ -76,4 +80,8 @@
"softmax",
"sub",
"tanh",
"gcd",
"select_scatter",
"nll_loss",
"glu",
]
16 changes: 16 additions & 0 deletions src/ntops/torch/gcd.py
@@ -0,0 +1,16 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def gcd(input, other, out=None):
    if out is None:
        out = torch.empty_like(input)

    block_size = 1024
    kernel = _cached_make(
        ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size
    )
    kernel(input, other, out)
    return out
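A hedged usage sketch for the new wrapper, checked against `torch.gcd` (assumes integer CUDA tensors, since the generated kernels run on the GPU):

```python
import torch

import ntops.torch

a = torch.randint(1, 100, (1024,), dtype=torch.int32, device="cuda")
b = torch.randint(1, 100, (1024,), dtype=torch.int32, device="cuda")

out = ntops.torch.gcd(a, b)
assert torch.equal(out, torch.gcd(a, b))
```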