Skip to content

Commit 0129cb1

Browse files
committed
slice is the opposite of cat
1 parent 94cbe8d commit 0129cb1

7 files changed

Lines changed: 175 additions & 4 deletions

File tree

src/tensor/cpu/ops.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,12 +607,16 @@ template Tensor<bfloat16, CPU> pow(bfloat16, const TensorView<bfloat16, CPU>&);
607607
template Tensor<float, CPU> pow(const TensorView<float, CPU>&, float);
608608
template Tensor<float, CPU> pow(float, const TensorView<float, CPU>&);
609609

610+
// tril
610611
template Tensor<bfloat16, CPU> tril(const TensorView<bfloat16, CPU>&, bool);
611612
template Tensor<int, CPU> tril(const TensorView<int, CPU>&, bool);
613+
614+
// slice
612615
template Tensor<bfloat16, CPU> slice(const TensorView<bfloat16, CPU>&, int, size_t, size_t);
613616
template Tensor<float, CPU> slice(const TensorView<float, CPU>&, int, size_t, size_t);
614617
template Tensor<float, CPU> slice(const TensorView<const float, CPU>&, int, size_t, size_t);
615-
template Tensor<int, CPU> slice(const TensorView<int, CPU>&, int, size_t, size_t);
618+
619+
// matmul
616620
template Tensor<bfloat16, CPU> matmul(const TensorView<bfloat16, CPU>&,
617621
const TensorView<bfloat16, CPU>&);
618622
template Tensor<bfloat16, CPU> matmul(const TensorView<bfloat16, CPU>&,

src/tensor/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
file(GLOB HEADER_LIST CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/include/tensor/*.hpp")
22

3-
add_library(tensor_cuda STATIC storage.cu loader.cu ops.cu kernels/fill.cu kernels/argmax.cu kernels/arange.cu kernels/sum.cu kernels/max.cu kernels/masked_fill.cu kernels/cat.cu kernels/map.cu kernels/zip.cu kernels/tril.cu)
3+
add_library(tensor_cuda STATIC storage.cu loader.cu ops.cu kernels/fill.cu kernels/argmax.cu kernels/arange.cu kernels/sum.cu kernels/max.cu kernels/masked_fill.cu kernels/cat.cu kernels/map.cu kernels/zip.cu kernels/tril.cu kernels/slice.cu)
44

55
set_target_properties(tensor_cuda PROPERTIES
66
CUDA_SEPARABLE_COMPILATION ON

src/tensor/cuda/kernels/slice.cu

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include "slice.cuh"
2+
#include "utils.cuh"
3+
#include <cstddef>
4+
#include <tensor/device_type.hpp>
5+
#include <cuda_runtime.h>
6+
7+
namespace tensor::kernels {
8+
9+
using namespace dtype;
10+
11+
template <typename T>
// Copies one packed slice chunk per block.
//
// Launch layout: grid.x = number of outer iterations (product of dims before
// the sliced dim), block.x = threads striding over one chunk. Each block
// copies `chunk_size` consecutive elements from its source row — rows are
// `source_stride` elements apart and the slice begins `start_offset` elements
// into each row — into the densely packed output.
//
// `out` and `input` are distinct allocations, so __restrict__ is safe and
// lets the compiler use the read-only data path for `input`.
__global__ void slice_kernel(Cuda<T>* __restrict__ out, const Cuda<T>* __restrict__ input,
                             size_t start_offset, size_t chunk_size, size_t source_stride) {
    size_t operation_idx = blockIdx.x;

    auto in_base = (operation_idx * source_stride) + start_offset;
    auto out_base = operation_idx * chunk_size;

    // Block-stride loop: correct for any blockDim.x, including chunk_size == 0.
    for (size_t element = threadIdx.x; element < chunk_size; element += blockDim.x) {
        out[out_base + element] = input[in_base + element];
    }
}
22+
23+
template <typename T>
// Returns a new CUDA tensor holding indices [start, end) of `view` along
// dimension `dim`. `dim` may be negative (counted from the last dimension).
// The input must be contiguous; the result is freshly allocated and packed.
Tensor<T, CUDA> slice(const TensorView<T, CUDA>& view, int dim, size_t start, size_t end) {
    assert(view.is_contiguous() && "tensor should be contiguous");

    auto shape = view.shape;

    // Normalize negative dims (e.g. -1 == last dimension).
    if (dim < 0) {
        dim = static_cast<int>(shape.size()) + dim;
    }
    assert(dim >= 0 && static_cast<size_t>(dim) < shape.size() && "slice: dim out of range");
    assert(start <= end && end <= shape[dim] && "slice: invalid [start, end) bounds");

    Shape new_shape{shape};
    new_shape[dim] = end - start;

    // product of all dimensions after dim: elements per index step along dim
    size_t inner_stride = 1;
    for (size_t idx = dim + 1; idx < shape.size(); ++idx) {
        inner_stride *= shape[idx];
    }

    // product of all dimensions before dim: independent chunks to copy
    size_t outer_iterations = 1;
    for (size_t idx = 0; idx < static_cast<size_t>(dim); ++idx) {
        outer_iterations *= shape[idx];
    }

    size_t source_stride = shape[dim] * inner_stride; // distance between consecutive source chunks
    size_t chunk_size = (end - start) * inner_stride; // elements copied per chunk
    size_t start_offset = start * inner_stride;       // offset of the slice inside each source chunk

    size_t n_elements = outer_iterations * chunk_size;
    TensorStorage<T, CUDA> storage(n_elements);
    Tensor<T, CUDA> out{new_shape, std::move(storage)};

    // Empty slice (end == start, or a zero-sized dimension): nothing to copy,
    // and a kernel launch with a zero-sized grid/block would be an invalid
    // launch configuration.
    if (n_elements == 0) {
        return out;
    }

    // fast path: slicing a contiguous tensor on the first dimension selects a
    // single contiguous range, so a plain device-to-device memcpy suffices
    if (dim == 0) {
        size_t bytes = n_elements * sizeof(T);
        CUDA_CHECK(cudaMemcpy(out.data(), view.data + start_offset, bytes, cudaMemcpyDeviceToDevice)); // NOLINT
        return out;
    }

    size_t block_size = cuda::get_block_size(chunk_size);

    auto* out_d = reinterpret_cast<Cuda<T>*>(out.data()); // NOLINT
    auto* in_d = reinterpret_cast<const Cuda<T>*>(view.data); // NOLINT

    slice_kernel<T><<<outer_iterations, block_size>>>(out_d, in_d, start_offset, chunk_size, source_stride);
    // Kernel launches return no status directly; surface bad launch
    // configurations here instead of at some later, unrelated CUDA call.
    CUDA_CHECK(cudaGetLastError());

    return out;
}

// Explicit instantiations for the supported element types.
template Tensor<bfloat16, CUDA> slice(const TensorView<bfloat16, CUDA>& view, int dim, size_t start, size_t end);
template Tensor<float, CUDA> slice(const TensorView<float, CUDA>& view, int dim, size_t start, size_t end);
75+
76+
} // namespace tensor::kernels

src/tensor/cuda/kernels/slice.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
3+
#include <tensor/tensor.hpp>
4+
5+
namespace tensor::kernels {
6+
7+
template<typename T>
8+
Tensor<T, CUDA> slice(const TensorView<T, CUDA>& view, int dim, size_t start, size_t end);
9+
10+
} // namespace tensor::kernels

src/tensor/cuda/kernels/tril.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ template <typename T> Tensor<T, CUDA> tril(const TensorView<T, CUDA>& tensor, bo
5151
auto* in_d = reinterpret_cast<Cuda<T>*>(tensor.data); // NOLINT
5252
Cuda<T> diagonal_d = to_device_type(diagonal, CUDA{});
5353

54-
fmt::println("Grid: {} by {}", grid_size, block_size);
55-
5654
tril_kernel<Cuda<T>><<<grid_size, block_size>>>(device_data, in_d, diagonal_d, cols, rows);
5755

5856
return out;

src/tensor/cuda/ops.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "kernels/map.cuh"
1515
#include "kernels/zip.cuh"
1616
#include "kernels/tril.cuh"
17+
#include "kernels/slice.cuh"
1718
#include "kernels/utils.cuh"
1819

1920
namespace tensor {
@@ -135,4 +136,14 @@ Tensor<int, CUDA> tril(const TensorView<int, CUDA>& tensor, bool diagonal) {
135136
return kernels::tril(tensor, diagonal);
136137
}
137138

139+
// slice — CUDA backend entry points; forward straight to the kernel module.
template <>
Tensor<bfloat16, CUDA> slice(const TensorView<bfloat16, CUDA>& view, int dim, size_t start, size_t end) {
    return kernels::slice(view, dim, start, end);
}

template <>
Tensor<float, CUDA> slice(const TensorView<float, CUDA>& view, int dim, size_t start, size_t end) {
    return kernels::slice(view, dim, start, end);
}
148+
138149
} // namespace tensor

tests/tensor/cuda/test_ops.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,75 @@ TEST(TensorCUDATest, TrilBf16) {
309309
exp = {1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1};
310310
tensor_is_close<bfloat16>(diag.span(), std::span(exp));
311311
}
312+
313+
TEST(TensorCUDATest, SliceBf16FirstDim) {
    SKIP_IF_NO_GPU();
    // 4x3 tensor filled row-major with 1..12:
    //   [1,2,3], [4,5,6], [7,8,9], [10,11,12]
    Tensor<bfloat16, CPU> host({4, 3});
    for (int idx = 0; idx < 12; ++idx) {
        host.set_(idx, bfloat16(idx + 1));
    }

    auto device = host.cuda();

    // Keep rows [1, 3): rows 1 and 2 (this exercises the dim-0 memcpy path).
    Tensor<bfloat16, CUDA> sliced = slice(device.view(), 0, 1, 3);

    auto host_result = sliced.cpu();

    Shape want_shape = {2, 3};
    EXPECT_EQ(host_result.shape(), want_shape);

    // row1=[4,5,6], row2=[7,8,9]
    std::vector<bfloat16> want = {4, 5, 6, 7, 8, 9};
    tensor_is_close<bfloat16>(host_result.span(), std::span(want));
}
336+
337+
TEST(TensorCUDATest, SliceBf16LastDim) {
    SKIP_IF_NO_GPU();
    // 2x6 tensor filled row-major with 1..12:
    //   row0=[1..6], row1=[7..12]
    Tensor<bfloat16, CPU> host({2, 6});
    for (int idx = 0; idx < 12; ++idx) {
        host.set_(idx, bfloat16(idx + 1));
    }

    auto device = host.cuda();

    // Keep columns [2, 5): columns 2, 3, 4 (exercises the strided kernel path).
    Tensor<bfloat16, CUDA> sliced = slice(device.view(), 1, 2, 5);

    auto host_result = sliced.cpu();

    Shape want_shape = {2, 3};
    EXPECT_EQ(host_result.shape(), want_shape);

    // row0=[3,4,5], row1=[9,10,11]
    std::vector<bfloat16> want = {3, 4, 5, 9, 10, 11};
    tensor_is_close<bfloat16>(host_result.span(), std::span(want));
}
360+
361+
TEST(TensorCUDATest, SliceBf16MiddleDim) {
    SKIP_IF_NO_GPU();
    // 2x4x3 tensor (2 batches, 4 rows, 3 cols) filled row-major with 1..24.
    Tensor<bfloat16, CPU> host({2, 4, 3});
    for (int idx = 0; idx < 24; ++idx) {
        host.set_(idx, bfloat16(idx + 1));
    }

    auto device = host.cuda();

    // Keep rows [1, 3) of every batch (a middle dimension, kernel path).
    Tensor<bfloat16, CUDA> sliced = slice(device.view(), 1, 1, 3);

    auto host_result = sliced.cpu();

    Shape want_shape = {2, 2, 3};
    EXPECT_EQ(host_result.shape(), want_shape);

    // batch0 rows 1-2 = [4..9], batch1 rows 1-2 = [16..21]
    std::vector<bfloat16> want = {4, 5, 6, 7, 8, 9, 16, 17, 18, 19, 20, 21};
    tensor_is_close<bfloat16>(host_result.span(), std::span(want));
}

0 commit comments

Comments
 (0)