From 43ca0b74b819eb6bf523ee9cb11194e22f58f3c7 Mon Sep 17 00:00:00 2001 From: PPPoint <1024879159@qq.com> Date: Tue, 27 Jan 2026 18:32:42 +0800 Subject: [PATCH 1/5] Finish T1-1-11: gcd --- src/ntops/kernels/gcd.py | 49 ++++++++++++++++++++++++++++++++++++++++ src/ntops/torch/gcd.py | 12 ++++++++++ 2 files changed, 61 insertions(+) create mode 100644 src/ntops/kernels/gcd.py create mode 100644 src/ntops/torch/gcd.py diff --git a/src/ntops/kernels/gcd.py b/src/ntops/kernels/gcd.py new file mode 100644 index 0000000..a4f733d --- /dev/null +++ b/src/ntops/kernels/gcd.py @@ -0,0 +1,49 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ntops.kernels.element_wise import arrangement + + +def application(input, other, output): + a = ntl.abs(ntl.cast(input, ntl.int64)) + b = ntl.abs(ntl.cast(other, ntl.int64)) + + while ntl.max(ntl.cast(b != 0, ntl.int32)) == 1: + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + mask = b != 0 + safe_b = ntl.where(mask, b, 1) + r = a % safe_b + a = ntl.where(mask, b, a) + b = ntl.where(mask, r, b) + + output = ntl.cast(a, output.dtype) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, other=0), + Tensor(ndim, dtype=dtype, other=0), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/gcd.py b/src/ntops/torch/gcd.py new file mode 100644 index 0000000..bd4e267 --- /dev/null +++ b/src/ntops/torch/gcd.py @@ -0,0 +1,12 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def gcd(input, other, out=None): + if out is None: + out = torch.empty_like(input) + + block_size = 1024 + kernel = _cached_make(ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size) + kernel(input, other, out) + return out \ No newline at end of file From cd229fb47a375f5f3d0275e77927012a641a9d30 Mon Sep 17 00:00:00 2001 From: PPPoint <1024879159@qq.com> Date: Tue, 27 Jan 2026 20:26:18 +0800 Subject: [PATCH 2/5] Finish T1-1-11: glu --- src/ntops/kernels/glu.py | 46 ++++++++++++++++++++++++++++++++++++++++ src/ntops/torch/glu.py | 25 ++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 src/ntops/kernels/glu.py create mode 100644 src/ntops/torch/glu.py diff --git a/src/ntops/kernels/glu.py b/src/ntops/kernels/glu.py new file mode 100644 index 0000000..2925735 --- /dev/null +++ b/src/ntops/kernels/glu.py @@ -0,0 +1,46 @@ +import functools +import ninetoothed.language as ntl +from ninetoothed import Tensor + +def arrangement(input, output, dim_size, dim, block_size): + ndim = input.ndim + if dim < 0: dim = ndim + dim + + tile_shape = [1] * ndim + tile_shape[dim] = block_size + + in_t = input.tile(tuple(tile_shape)) + out_t = output.tile(tuple(tile_shape)) + + for _ in range(ndim - 1): + + in_t.dtype = in_t.dtype.squeeze(0 if dim != 0 else 1) + out_t.dtype = out_t.dtype.squeeze(0 if dim != 0 else 1) + + if dim > 0: + dim -= 1 + + return in_t, out_t, dim_size + +def application(input, output, dim_size): + half = dim_size // 2 + + for i in range(half): + a = ntl.cast(input[i], ntl.float32) + b = 
ntl.cast(input[i + half], ntl.float32) + + res = a * ntl.sigmoid(b) + + output[i] = ntl.cast(res, output.dtype) + +def premake(ndim, dim, dim_size, dtype=None, block_size=None): + + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(0, constexpr=True, value=dim_size), + ) + + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/glu.py b/src/ntops/torch/glu.py new file mode 100644 index 0000000..7c1bab5 --- /dev/null +++ b/src/ntops/torch/glu.py @@ -0,0 +1,25 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def glu(input, dim=-1): + ndim = input.ndim + if dim < 0: dim = ndim + dim + + dim_size = input.size(dim) + out_shape = list(input.shape) + out_shape[dim] //= 2 + output = torch.empty(out_shape, dtype=input.dtype, device=input.device) + block_size = 1024 + + kernel = _cached_make( + ntops.kernels.glu.premake, + ndim, + dim, + dim_size, + input.dtype, + block_size + ) + + kernel(input, output, dim_size) + return output \ No newline at end of file From e6142ee94c120143ba42d8123eb0372e86098673 Mon Sep 17 00:00:00 2001 From: PPPoint <1024879159@qq.com> Date: Wed, 28 Jan 2026 18:59:28 +0800 Subject: [PATCH 3/5] Finish T1-1-11: select_scatter --- src/ntops/kernels/select_scatter.py | 46 +++++++++++++++++++++++++++++ src/ntops/torch/select_scatter.py | 23 +++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 src/ntops/kernels/select_scatter.py create mode 100644 src/ntops/torch/select_scatter.py diff --git a/src/ntops/kernels/select_scatter.py b/src/ntops/kernels/select_scatter.py new file mode 100644 index 0000000..78ad155 --- /dev/null +++ b/src/ntops/kernels/select_scatter.py @@ -0,0 +1,46 @@ +import functools +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, src, output, index, dim_size_pow2, dim, block_size): + ndim = input.ndim + if dim < 0: dim += ndim + non_target_dims = tuple(i for i in range(ndim) if i != dim) + + def _arrangement(t): + return t.permute(non_target_dims + (dim,)).flatten(end_dim=-1) + + # (Remaining, Dim_Size) + input_arranged = _arrangement(input).tile((block_size, -1)).squeeze(1) + src_arranged = _arrangement(src).tile((block_size, -1)).squeeze(1) + output_arranged = _arrangement(output).tile((block_size, -1)).squeeze(1) + + return input_arranged, src_arranged, output_arranged, index, dim_size_pow2 + +def application(input, src, output, target_index, dim_size_pow2): + col_indices = ntl.arange(0, dim_size_pow2) + + col_indices = ntl.expand_dims(col_indices, 0) + col_indices = ntl.broadcast_to(col_indices, (input.shape[0], dim_size_pow2)) + + actual_dim_size = input.shape[1] + + match_mask = (col_indices == ntl.cast(target_index, ntl.int32)) + valid_mask = col_indices < ntl.cast(actual_dim_size, ntl.int32) + + final_mask = match_mask & valid_mask + + output = ntl.where(final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype)) + +def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + Tensor(0, constexpr=True, value=index), + Tensor(0, 
constexpr=True, value=dim_size_pow2), + ) + return arrangement_, application, tensors \ No newline at end of file diff --git a/src/ntops/torch/select_scatter.py b/src/ntops/torch/select_scatter.py new file mode 100644 index 0000000..d738266 --- /dev/null +++ b/src/ntops/torch/select_scatter.py @@ -0,0 +1,23 @@ +import torch +import ntops +from ntops.torch.utils import _cached_make + +def select_scatter(input, src, dim, index): + ndim = input.ndim + if dim < 0: dim += ndim + + dim_size = input.shape[dim] + dim_size_pow2 = 1 << (dim_size - 1).bit_length() + + src_expanded = src.unsqueeze(dim) + output = torch.empty_like(input) + block_size = 1024 + + kernel = _cached_make( + ntops.kernels.select_scatter.premake, + ndim, dim, int(index), int(dim_size_pow2), + input.dtype, block_size + ) + + kernel(input, src_expanded, output, int(index), int(dim_size_pow2)) + return output \ No newline at end of file From 1533db6d9c37d8ebc5706802151af29d3846ceb5 Mon Sep 17 00:00:00 2001 From: PPPoint <1024879159@qq.com> Date: Fri, 30 Jan 2026 17:24:35 +0800 Subject: [PATCH 4/5] Finish T1-1-11: nll_loss --- src/ntops/kernels/nll_loss.py | 95 +++++++++++++++++++++++++++++++++++ src/ntops/torch/nll_loss.py | 82 ++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 src/ntops/kernels/nll_loss.py create mode 100644 src/ntops/torch/nll_loss.py diff --git a/src/ntops/kernels/nll_loss.py b/src/ntops/kernels/nll_loss.py new file mode 100644 index 0000000..1b80b41 --- /dev/null +++ b/src/ntops/kernels/nll_loss.py @@ -0,0 +1,95 @@ +import functools +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +def arrangement_loss(input, target, out_loss, out_weight, weight, ignore_index, C_val_Tensor, C_pow2_Tensor, has_weight_Tensor, C_pow2): + input_t = input.tile((1, C_pow2)) + target_t = target.tile((1,)) + out_l_t = out_loss.tile((1,)) + out_w_t = out_weight.tile((1,)) + + weight_t = weight.tile((C_pow2,)).expand((input_t.shape[0],)) + + return input_t, target_t, out_l_t, out_w_t, weight_t, ignore_index, C_val_Tensor, C_pow2_Tensor, has_weight_Tensor + +def application_loss(input, target, out_loss, out_weight, weight, ignore_index, C_val, C_pow2, has_weight): + ii = ntl.cast(ignore_index, ntl.int32) + C_num = ntl.cast(C_val, ntl.int32) + + col_indices = ntl.arange(0, C_pow2) + t_val = ntl.cast(target, ntl.int32) + + match_mask = (t_val == col_indices) & (col_indices < C_num) + is_ignore = (t_val == ii) | (t_val < 0) | (t_val >= C_num) + + input_f32 = ntl.cast(input, ntl.float32) + prob_vec = ntl.where(match_mask, input_f32, 0.0) + selected_prob = ntl.sum(prob_vec) + + if has_weight == True: + weight_f32 = ntl.cast(weight, ntl.float32) + weight_vec = ntl.where(match_mask, weight_f32, 0.0) + selected_weight = ntl.sum(weight_vec) + else: + selected_weight = 1.0 + + loss_val = 0.0 - (selected_prob * selected_weight) + + out_loss = ntl.where(is_ignore, 0.0, loss_val) + out_weight = ntl.where(is_ignore, 0.0, selected_weight) + + +def premake_loss(ignore_index, C_val, C_pow2, has_weight, dtype=None): + arrangement_ = functools.partial(arrangement_loss, C_pow2=C_pow2) + + tensors = ( + Tensor(2, dtype=dtype, other=0.0), # input + Tensor(1, dtype=ninetoothed.int64), # target + Tensor(1, dtype=dtype), # out_loss + Tensor(1, dtype=dtype), # out_weight + Tensor(1, dtype=dtype, other=0.0), # weight + Tensor(0, constexpr=True, value=ignore_index), + Tensor(0, constexpr=True, value=C_val), + Tensor(0, constexpr=True, value=C_pow2), + Tensor(0, constexpr=True, 
value=has_weight), + ) + return arrangement_, application_loss, tensors + +def arrangement_reduce(input, output, block_size): + input_t = input.tile((block_size,)) + output_t = output.tile((1,)) + return input_t, output_t + +def application_reduce(input, output): + accumulator = 0.0 + for i in range(input.shape[0]): + accumulator += ntl.cast(input[i], ntl.float32) + output[0] = ntl.cast(accumulator, output.dtype) + +def premake_reduce(dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement_reduce, block_size=block_size) + tensors = ( + Tensor(1, dtype=dtype, shape_options={"constexpr": True}), # input + Tensor(1, dtype=dtype, shape_options={"constexpr": True}), # output + ) + + return arrangement_, application_reduce, tensors + +def arrangement_div(loss_sum, weight_sum, output): + return loss_sum.tile((1,)), weight_sum.tile((1,)), output.tile((1,)) + +def application_div(loss_sum, weight_sum, output): + l = ntl.cast(loss_sum[0], ntl.float32) + w = ntl.cast(weight_sum[0], ntl.float32) + res = ntl.where(w > 0.0, l / w, 0.0) + output[0] = ntl.cast(res, output.dtype) + +def premake_div(dtype=None): + arrangement_ = functools.partial(arrangement_div) + tensors = ( + Tensor(1, dtype=dtype), # loss_sum + Tensor(1, dtype=dtype), # weight_sum + Tensor(1, dtype=dtype), # output + ) + return arrangement_, application_div, tensors diff --git a/src/ntops/torch/nll_loss.py b/src/ntops/torch/nll_loss.py new file mode 100644 index 0000000..df1892a --- /dev/null +++ b/src/ntops/torch/nll_loss.py @@ -0,0 +1,82 @@ + +import math +import torch +import ntops +from ntops.torch.utils import _cached_make + +def next_power_of_2(n): + if n == 0: + return 1 + return 1 << (n - 1).bit_length() + +def get_optimal_block_size(dim_size): + target_size = next_power_of_2(dim_size) + if target_size > 1024: + target_size = 1024 + if target_size < 32: + target_size = 32 + return target_size + +def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'): + N, C = input.shape + C_pow2 = 1 << (C - 1).bit_length() + device = input.device + dtype = input.dtype + + def iterative_reduce(current, name=""): + step = 0 + while current.numel() > 1: + block_size = get_optimal_block_size(current.numel()) + output_len = math.ceil(current.numel() / block_size) + output = torch.empty((output_len,), dtype=dtype, device=device) + + kernel_reduce = _cached_make( + ntops.kernels.nll_loss.premake_reduce, + dtype, + block_size + ) + kernel_reduce(current, output) + current = output + step += 1 + return current + + tmp_loss_sum = torch.zeros((N,), dtype=dtype, device=device) + tmp_weight_sum = torch.zeros((N,), dtype=dtype, device=device) + + if weight is None: + dummy_weight = torch.empty_like(target) + has_weight = False + else: + dummy_weight = weight.contiguous() + has_weight = True + + kernel_loss = _cached_make( + ntops.kernels.nll_loss.premake_loss, + int(ignore_index), + C, + C_pow2, + has_weight, + dtype + ) + kernel_loss(input, target, tmp_loss_sum, tmp_weight_sum, dummy_weight, int(ignore_index), C, C_pow2, has_weight) + + if reduction == 'none': + return tmp_loss_sum + + final_loss_tensor = iterative_reduce(tmp_loss_sum, "Loss") + final_weight_tensor = iterative_reduce(tmp_weight_sum, "Weight") + + loss_val = final_loss_tensor.view(()) + + if reduction == 'sum': + return loss_val + + elif reduction == 'mean': + final_output = torch.empty((1,), dtype=dtype, device=device) + kernel_div = _cached_make( + ntops.kernels.nll_loss.premake_div, + dtype + ) + kernel_div(final_loss_tensor, 
final_weight_tensor, final_output) + + return final_output From 80aac9c52ede12c83a3c14bd77eb525a3044b10e Mon Sep 17 00:00:00 2001 From: PPPoint <1024879159@qq.com> Date: Fri, 30 Jan 2026 19:41:18 +0800 Subject: [PATCH 5/5] Finish T1-1-11: gt && Format --- src/ntops/kernels/__init__.py | 8 ++++ src/ntops/kernels/gcd.py | 5 ++- src/ntops/kernels/glu.py | 21 +++++---- src/ntops/kernels/gt.py | 7 ++- src/ntops/kernels/nll_loss.py | 66 +++++++++++++++++++++-------- src/ntops/kernels/select_scatter.py | 28 +++++++----- src/ntops/torch/__init__.py | 8 ++++ src/ntops/torch/gcd.py | 10 +++-- src/ntops/torch/glu.py | 18 ++++---- src/ntops/torch/gt.py | 7 ++- src/ntops/torch/nll_loss.py | 51 +++++++++++++--------- src/ntops/torch/select_scatter.py | 21 ++++++--- 12 files changed, 166 insertions(+), 84 deletions(-) diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 084e52c..cb325b2 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -12,8 +12,10 @@ dropout, eq, exp, + gcd, ge, gelu, + glu, gt, isinf, isnan, @@ -24,12 +26,14 @@ mul, ne, neg, + nll_loss, pow, relu, rms_norm, rotary_position_embedding, rsqrt, scaled_dot_product_attention, + select_scatter, sigmoid, silu, sin, @@ -76,4 +80,8 @@ "softmax", "sub", "tanh", + "gcd", + "select_scatter", + "nll_loss", + "glu", ] diff --git a/src/ntops/kernels/gcd.py b/src/ntops/kernels/gcd.py index a4f733d..c1dd34e 100644 --- a/src/ntops/kernels/gcd.py +++ b/src/ntops/kernels/gcd.py @@ -2,6 +2,7 @@ import ninetoothed.language as ntl from ninetoothed import Tensor + from ntops.kernels.element_wise import arrangement @@ -21,7 +22,7 @@ def application(input, other, output): r = a % safe_b a = ntl.where(mask, b, a) b = ntl.where(mask, r, b) - + mask = b != 0 safe_b = ntl.where(mask, b, 1) r = a % safe_b @@ -46,4 +47,4 @@ def premake(ndim, dtype=None, block_size=None): Tensor(ndim, dtype=dtype), ) - return arrangement_, application, tensors \ No newline at end of file + return arrangement_, application, tensors diff --git a/src/ntops/kernels/glu.py b/src/ntops/kernels/glu.py index 2925735..b61644c 100644 --- a/src/ntops/kernels/glu.py +++ b/src/ntops/kernels/glu.py @@ -1,19 +1,21 @@ import functools + import ninetoothed.language as ntl from ninetoothed import Tensor + def arrangement(input, output, dim_size, dim, block_size): ndim = input.ndim - if dim < 0: dim = ndim + dim - + if dim < 0: + dim = ndim + dim + tile_shape = [1] * ndim tile_shape[dim] = block_size - + in_t = input.tile(tuple(tile_shape)) out_t = output.tile(tuple(tile_shape)) - - for _ in range(ndim - 1): + for _ in range(ndim - 1): in_t.dtype = in_t.dtype.squeeze(0 if dim != 0 else 1) out_t.dtype = out_t.dtype.squeeze(0 if dim != 0 else 1) @@ -22,19 +24,20 @@ def arrangement(input, output, dim_size, dim, block_size): return in_t, out_t, dim_size + def application(input, output, dim_size): half = dim_size // 2 for i in range(half): a = ntl.cast(input[i], ntl.float32) b = ntl.cast(input[i + half], ntl.float32) - + res = a * ntl.sigmoid(b) - + output[i] = ntl.cast(res, output.dtype) -def premake(ndim, dim, dim_size, dtype=None, block_size=None): +def premake(ndim, dim, dim_size, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) tensors = ( @@ -43,4 +46,4 @@ def premake(ndim, dim, dim_size, dtype=None, block_size=None): Tensor(0, constexpr=True, value=dim_size), ) - return arrangement_, application, tensors \ No newline at end of file + return arrangement_, application, tensors diff --git 
a/src/ntops/kernels/gt.py b/src/ntops/kernels/gt.py index 2a67cc5..5e39916 100644 --- a/src/ntops/kernels/gt.py +++ b/src/ntops/kernels/gt.py @@ -1,12 +1,15 @@ import functools +import ninetoothed.language as ntl from ninetoothed import Tensor from ntops.kernels.element_wise import arrangement def application(input, other, output): - output = input > other # noqa: F841 + tmp = input > other + result = ntl.cast(tmp, output.dtype) + output = result def premake(ndim, dtype=None, block_size=None): @@ -15,7 +18,7 @@ def premake(ndim, dtype=None, block_size=None): tensors = ( Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), - Tensor(ndim, dtype=dtype), + Tensor(ndim), ) return arrangement_, application, tensors diff --git a/src/ntops/kernels/nll_loss.py b/src/ntops/kernels/nll_loss.py index 1b80b41..afa80c9 100644 --- a/src/ntops/kernels/nll_loss.py +++ b/src/ntops/kernels/nll_loss.py @@ -1,22 +1,48 @@ import functools + import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor -def arrangement_loss(input, target, out_loss, out_weight, weight, ignore_index, C_val_Tensor, C_pow2_Tensor, has_weight_Tensor, C_pow2): - input_t = input.tile((1, C_pow2)) + +def arrangement_loss( + input, + target, + out_loss, + out_weight, + weight, + ignore_index, + C_val_Tensor, + C_pow2_Tensor, + has_weight_Tensor, + C_pow2, +): + input_t = input.tile((1, C_pow2)) target_t = target.tile((1,)) out_l_t = out_loss.tile((1,)) out_w_t = out_weight.tile((1,)) - + weight_t = weight.tile((C_pow2,)).expand((input_t.shape[0],)) - return input_t, target_t, out_l_t, out_w_t, weight_t, ignore_index, C_val_Tensor, C_pow2_Tensor, has_weight_Tensor + return ( + input_t, + target_t, + out_l_t, + out_w_t, + weight_t, + ignore_index, + C_val_Tensor, + C_pow2_Tensor, + has_weight_Tensor, + ) -def application_loss(input, target, out_loss, out_weight, weight, ignore_index, C_val, C_pow2, has_weight): + +def application_loss( + input, target, out_loss, out_weight, weight, ignore_index, C_val, C_pow2, has_weight +): ii = ntl.cast(ignore_index, ntl.int32) C_num = ntl.cast(C_val, ntl.int32) - + col_indices = ntl.arange(0, C_pow2) t_val = ntl.cast(target, ntl.int32) @@ -44,11 +70,11 @@ def premake_loss(ignore_index, C_val, C_pow2, has_weight, dtype=None): arrangement_ = functools.partial(arrangement_loss, C_pow2=C_pow2) tensors = ( - Tensor(2, dtype=dtype, other=0.0), # input - Tensor(1, dtype=ninetoothed.int64), # target - Tensor(1, dtype=dtype), # out_loss - Tensor(1, dtype=dtype), # out_weight - Tensor(1, dtype=dtype, other=0.0), # weight + Tensor(2, dtype=dtype, other=0.0), # input + Tensor(1, dtype=ninetoothed.int64), # target + Tensor(1, dtype=dtype), # out_loss + Tensor(1, dtype=dtype), # out_weight + Tensor(1, dtype=dtype, other=0.0), # weight Tensor(0, constexpr=True, value=ignore_index), Tensor(0, constexpr=True, value=C_val), Tensor(0, constexpr=True, value=C_pow2), @@ -56,40 +82,46 @@ def premake_loss(ignore_index, C_val, C_pow2, has_weight, dtype=None): ) return arrangement_, application_loss, tensors + def arrangement_reduce(input, output, block_size): input_t = input.tile((block_size,)) output_t = output.tile((1,)) return input_t, output_t + def application_reduce(input, output): accumulator = 0.0 for i in range(input.shape[0]): accumulator += ntl.cast(input[i], ntl.float32) output[0] = ntl.cast(accumulator, output.dtype) + def premake_reduce(dtype=None, block_size=None): arrangement_ = functools.partial(arrangement_reduce, block_size=block_size) tensors = ( - Tensor(1, dtype=dtype, 
shape_options={"constexpr": True}), # input - Tensor(1, dtype=dtype, shape_options={"constexpr": True}), # output + Tensor(1, dtype=dtype, shape_options={"constexpr": True}), # input + Tensor(1, dtype=dtype, shape_options={"constexpr": True}), # output ) - + return arrangement_, application_reduce, tensors + def arrangement_div(loss_sum, weight_sum, output): return loss_sum.tile((1,)), weight_sum.tile((1,)), output.tile((1,)) + def application_div(loss_sum, weight_sum, output): l = ntl.cast(loss_sum[0], ntl.float32) w = ntl.cast(weight_sum[0], ntl.float32) res = ntl.where(w > 0.0, l / w, 0.0) output[0] = ntl.cast(res, output.dtype) + def premake_div(dtype=None): arrangement_ = functools.partial(arrangement_div) tensors = ( - Tensor(1, dtype=dtype), # loss_sum - Tensor(1, dtype=dtype), # weight_sum - Tensor(1, dtype=dtype), # output + Tensor(1, dtype=dtype), # loss_sum + Tensor(1, dtype=dtype), # weight_sum + Tensor(1, dtype=dtype), # output ) return arrangement_, application_div, tensors diff --git a/src/ntops/kernels/select_scatter.py b/src/ntops/kernels/select_scatter.py index 78ad155..9f2d28b 100644 --- a/src/ntops/kernels/select_scatter.py +++ b/src/ntops/kernels/select_scatter.py @@ -1,13 +1,15 @@ import functools + import ninetoothed.language as ntl from ninetoothed import Tensor def arrangement(input, src, output, index, dim_size_pow2, dim, block_size): ndim = input.ndim - if dim < 0: dim += ndim + if dim < 0: + dim += ndim non_target_dims = tuple(i for i in range(ndim) if i != dim) - + def _arrangement(t): return t.permute(non_target_dims + (dim,)).flatten(end_dim=-1) @@ -18,24 +20,28 @@ def _arrangement(t): return input_arranged, src_arranged, output_arranged, index, dim_size_pow2 + def application(input, src, output, target_index, dim_size_pow2): col_indices = ntl.arange(0, dim_size_pow2) - + col_indices = ntl.expand_dims(col_indices, 0) col_indices = ntl.broadcast_to(col_indices, (input.shape[0], dim_size_pow2)) - + actual_dim_size = input.shape[1] - - match_mask = (col_indices == ntl.cast(target_index, ntl.int32)) + + match_mask = col_indices == ntl.cast(target_index, ntl.int32) valid_mask = col_indices < ntl.cast(actual_dim_size, ntl.int32) - + final_mask = match_mask & valid_mask - - output = ntl.where(final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype)) + + output = ntl.where( + final_mask, ntl.cast(src, output.dtype), ntl.cast(input, output.dtype) + ) + def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size) - + tensors = ( Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), @@ -43,4 +49,4 @@ def premake(ndim, dim, index, dim_size_pow2, dtype=None, block_size=None): Tensor(0, constexpr=True, value=index), Tensor(0, constexpr=True, value=dim_size_pow2), ) - return arrangement_, application, tensors \ No newline at end of file + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..4f82031 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -11,8 +11,10 @@ from ntops.torch.dropout import dropout from ntops.torch.eq import eq from ntops.torch.exp import exp +from ntops.torch.gcd import gcd from ntops.torch.ge import ge from ntops.torch.gelu import gelu +from ntops.torch.glu import glu from ntops.torch.gt import gt from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan @@ -24,12 
+26,14 @@ from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg +from ntops.torch.nll_loss import nll_loss from ntops.torch.pow import pow from ntops.torch.relu import relu from ntops.torch.rms_norm import rms_norm from ntops.torch.rotary_position_embedding import rotary_position_embedding from ntops.torch.rsqrt import rsqrt from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention +from ntops.torch.select_scatter import select_scatter from ntops.torch.sigmoid import sigmoid from ntops.torch.silu import silu from ntops.torch.sin import sin @@ -76,4 +80,8 @@ "softmax", "sub", "tanh", + "gcd", + "select_scatter", + "nll_loss", + "glu", ] diff --git a/src/ntops/torch/gcd.py b/src/ntops/torch/gcd.py index bd4e267..4614ae3 100644 --- a/src/ntops/torch/gcd.py +++ b/src/ntops/torch/gcd.py @@ -1,12 +1,16 @@ import torch + import ntops from ntops.torch.utils import _cached_make + def gcd(input, other, out=None): if out is None: out = torch.empty_like(input) - + block_size = 1024 - kernel = _cached_make(ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size) + kernel = _cached_make( + ntops.kernels.gcd.premake, input.ndim, input.dtype, block_size + ) kernel(input, other, out) - return out \ No newline at end of file + return out diff --git a/src/ntops/torch/glu.py b/src/ntops/torch/glu.py index 7c1bab5..31f822e 100644 --- a/src/ntops/torch/glu.py +++ b/src/ntops/torch/glu.py @@ -1,25 +1,23 @@ import torch + import ntops from ntops.torch.utils import _cached_make + def glu(input, dim=-1): ndim = input.ndim - if dim < 0: dim = ndim + dim - + if dim < 0: + dim = ndim + dim + dim_size = input.size(dim) out_shape = list(input.shape) out_shape[dim] //= 2 output = torch.empty(out_shape, dtype=input.dtype, device=input.device) block_size = 1024 - + kernel = _cached_make( - ntops.kernels.glu.premake, - ndim, - dim, - dim_size, - input.dtype, - block_size + ntops.kernels.glu.premake, ndim, dim, dim_size, input.dtype, block_size ) kernel(input, output, dim_size) - return output \ No newline at end of file + return output diff --git a/src/ntops/torch/gt.py b/src/ntops/torch/gt.py index 5404ffd..6fe2b46 100644 --- a/src/ntops/torch/gt.py +++ b/src/ntops/torch/gt.py @@ -6,9 +6,12 @@ def gt(input, other, *, out=None): if out is None: - out = torch.empty_like(input) + out = torch.empty(input.shape, dtype=torch.bool, device=input.device) - kernel = _cached_make(ntops.kernels.gt.premake, input.ndim) + block_size = 1024 + kernel = _cached_make( + ntops.kernels.gt.premake, input.ndim, dtype=input.dtype, block_size=block_size + ) kernel(input, other, out) diff --git a/src/ntops/torch/nll_loss.py b/src/ntops/torch/nll_loss.py index df1892a..7a67b17 100644 --- a/src/ntops/torch/nll_loss.py +++ b/src/ntops/torch/nll_loss.py @@ -1,14 +1,17 @@ - import math + import torch + import ntops from ntops.torch.utils import _cached_make + def next_power_of_2(n): if n == 0: return 1 return 1 << (n - 1).bit_length() + def get_optimal_block_size(dim_size): target_size = next_power_of_2(dim_size) if target_size > 1024: @@ -17,23 +20,22 @@ def get_optimal_block_size(dim_size): target_size = 32 return target_size -def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'): + +def nll_loss(input, target, weight=None, ignore_index=-100, reduction="mean"): N, C = input.shape C_pow2 = 1 << (C - 1).bit_length() device = input.device dtype = input.dtype - + def iterative_reduce(current, name=""): step = 0 while current.numel() > 1: block_size = 
get_optimal_block_size(current.numel()) output_len = math.ceil(current.numel() / block_size) output = torch.empty((output_len,), dtype=dtype, device=device) - + kernel_reduce = _cached_make( - ntops.kernels.nll_loss.premake_reduce, - dtype, - block_size + ntops.kernels.nll_loss.premake_reduce, dtype, block_size ) kernel_reduce(current, output) current = output @@ -49,34 +51,41 @@ def iterative_reduce(current, name=""): else: dummy_weight = weight.contiguous() has_weight = True - + kernel_loss = _cached_make( ntops.kernels.nll_loss.premake_loss, int(ignore_index), C, C_pow2, has_weight, - dtype + dtype, ) - kernel_loss(input, target, tmp_loss_sum, tmp_weight_sum, dummy_weight, int(ignore_index), C, C_pow2, has_weight) - - if reduction == 'none': + kernel_loss( + input, + target, + tmp_loss_sum, + tmp_weight_sum, + dummy_weight, + int(ignore_index), + C, + C_pow2, + has_weight, + ) + + if reduction == "none": return tmp_loss_sum final_loss_tensor = iterative_reduce(tmp_loss_sum, "Loss") final_weight_tensor = iterative_reduce(tmp_weight_sum, "Weight") - + loss_val = final_loss_tensor.view(()) - - if reduction == 'sum': + + if reduction == "sum": return loss_val - - elif reduction == 'mean': + + elif reduction == "mean": final_output = torch.empty((1,), dtype=dtype, device=device) - kernel_div = _cached_make( - ntops.kernels.nll_loss.premake_div, - dtype - ) + kernel_div = _cached_make(ntops.kernels.nll_loss.premake_div, dtype) kernel_div(final_loss_tensor, final_weight_tensor, final_output) return final_output diff --git a/src/ntops/torch/select_scatter.py b/src/ntops/torch/select_scatter.py index d738266..fb07b0f 100644 --- a/src/ntops/torch/select_scatter.py +++ b/src/ntops/torch/select_scatter.py @@ -1,23 +1,30 @@ import torch + import ntops from ntops.torch.utils import _cached_make + def select_scatter(input, src, dim, index): ndim = input.ndim - if dim < 0: dim += ndim - + if dim < 0: + dim += ndim + dim_size = input.shape[dim] dim_size_pow2 = 1 << (dim_size - 1).bit_length() - + src_expanded = src.unsqueeze(dim) output = torch.empty_like(input) block_size = 1024 kernel = _cached_make( - ntops.kernels.select_scatter.premake, - ndim, dim, int(index), int(dim_size_pow2), - input.dtype, block_size + ntops.kernels.select_scatter.premake, + ndim, + dim, + int(index), + int(dim_size_pow2), + input.dtype, + block_size, ) kernel(input, src_expanded, output, int(index), int(dim_size_pow2)) - return output \ No newline at end of file + return output
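
As a reading aid for the nll_loss patches above, the sketch below restates the three-kernel decomposition (masked per-sample pick, block-wise partial sums, safe final divide) in plain eager PyTorch. It is a conceptual reference under the same semantics, not the kernel code: the function name nll_loss_reference and the eager torch ops are illustrative stand-ins for the ninetoothed kernels.

import torch

def nll_loss_reference(logp, target, weight=None, ignore_index=-100, reduction="mean"):
    # Mirrors application_loss: compare each target against the column indices
    # and pick the matching log-probability (and class weight) with a masked sum.
    N, C = logp.shape
    cols = torch.arange(C, device=logp.device)
    match = target.unsqueeze(1) == cols
    ignore = (target == ignore_index) | (target < 0) | (target >= C)

    picked = (logp.float() * match).sum(dim=1)
    if weight is not None:
        w = (weight.float() * match).sum(dim=1)
    else:
        w = torch.ones(N, device=logp.device)

    per_sample_loss = torch.where(ignore, torch.zeros_like(picked), -(picked * w))
    per_sample_weight = torch.where(ignore, torch.zeros_like(w), w)

    if reduction == "none":
        return per_sample_loss
    # iterative_reduce in the torch wrapper collapses these sums block by block.
    loss_sum = per_sample_loss.sum()
    weight_sum = per_sample_weight.sum()
    if reduction == "sum":
        return loss_sum
    # application_div: divide only when the accumulated weight is positive.
    return torch.where(weight_sum > 0, loss_sum / weight_sum, torch.zeros_like(loss_sum))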
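
A minimal smoke-test sketch for the five ops added in this series follows. It is illustrative and not part of the patches: it assumes a CUDA device is available, that ntops.torch is importable as laid out in the diffs, and that each wrapper is intended to match the corresponding PyTorch reference; the shapes, tolerances, and the reshape of the (1,)-shaped mean reduction are assumptions made for the example.

import torch
import ntops.torch as nt  # the wrappers added in this series

device = "cuda"  # the ninetoothed kernels are assumed to need a GPU

# gcd / gt: element-wise integer GCD and greater-than.
a = torch.randint(1, 1000, (4, 256), device=device, dtype=torch.int64)
b = torch.randint(1, 1000, (4, 256), device=device, dtype=torch.int64)
assert torch.equal(nt.gcd(a, b), torch.gcd(a, b))

x = torch.randn(4, 256, device=device)
y = torch.randn(4, 256, device=device)
assert torch.equal(nt.gt(x, y), torch.gt(x, y))

# glu: halves the chosen dimension and gates the first half with the second.
g = torch.randn(4, 512, device=device)
torch.testing.assert_close(nt.glu(g, dim=-1), torch.nn.functional.glu(g, dim=-1))

# select_scatter: writes src into input at index along dim.
inp = torch.randn(4, 8, 16, device=device)
src = torch.randn(4, 16, device=device)
torch.testing.assert_close(
    nt.select_scatter(inp, src, dim=1, index=3),
    torch.select_scatter(inp, src, 1, 3),
)

# nll_loss over (N, C) log-probabilities; per the wrapper code above, the
# "mean" reduction returns a 1-element tensor, hence the reshape.
logp = torch.log_softmax(torch.randn(32, 10, device=device), dim=-1)
tgt = torch.randint(0, 10, (32,), device=device)
torch.testing.assert_close(
    nt.nll_loss(logp, tgt, reduction="mean").reshape(()),
    torch.nn.functional.nll_loss(logp, tgt, reduction="mean"),
)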