Skip to content

Commit 31985c7

Browse files
committed
Implement all elementwise operations as maps and zips
1 parent efdd75b commit 31985c7

17 files changed

Lines changed: 359 additions & 410 deletions

File tree

include/tensor/ops.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ template <typename T, typename D>
4040
Tensor<std::remove_const_t<T>, D> pow(const TensorView<T, D>& tensor,
4141
std::remove_const_t<T> scalar);
4242

43+
template <typename T, typename D>
44+
Tensor<std::remove_const_t<T>, D> cos(const TensorView<T, D>& tensor);
45+
46+
template <typename T, typename D>
47+
Tensor<std::remove_const_t<T>, D> sin(const TensorView<T, D>& tensor);
48+
49+
template <typename T, typename D>
50+
Tensor<std::remove_const_t<T>, D> exp(const TensorView<T, D>& tensor);
51+
4352
template <typename T, typename D>
4453
Tensor<std::remove_const_t<T>, D> masked_fill(const TensorView<T, D>& input,
4554
const TensorView<int, D>& mask,

src/llama/rope.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ precompute_rope_values(size_t head_dim, float theta_base, size_t context_length)
6161

6262
angles = cat(angles.view(), angles.view(), 1); // context length, head_Dim
6363

64-
auto sin = angles.view().sin();
65-
auto cos = angles.view().cos();
64+
auto sin_ = sin(angles.view());
65+
auto cos_ = cos(angles.view());
6666

67-
return std::make_tuple(std::move(cos), std::move(sin));
67+
return std::make_tuple(std::move(cos_), std::move(sin_));
6868
}
6969

7070
template <typename T, typename D>

src/tensor/cpu/ops.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ Tensor<std::remove_const_t<T>, D> tril(const TensorView<T, D>& tensor, const boo
152152
return out;
153153
}
154154

155+
template <typename T, typename D>
156+
Tensor<std::remove_const_t<T>, D> cos(const TensorView<T, D>& tensor) {
157+
return tensor.template map<std::remove_const_t<T>>([](T val) { return std::cos(val); });
158+
}
159+
160+
template <typename T, typename D>
161+
Tensor<std::remove_const_t<T>, D> sin(const TensorView<T, D>& tensor) {
162+
return tensor.template map<std::remove_const_t<T>>([](T val) { return std::sin(val); });
163+
}
164+
165+
template <typename T, typename D>
166+
Tensor<std::remove_const_t<T>, D> exp(const TensorView<T, D>& tensor) {
167+
return tensor.template map<std::remove_const_t<T>>([](T val) { return std::exp(val); });
168+
}
169+
155170
template <typename T, typename D>
156171
Tensor<std::remove_const_t<T>, D> pow(std::remove_const_t<T> scalar,
157172
const TensorView<T, D>& tensor) {
@@ -578,9 +593,15 @@ template Tensor<float, CPU> max(const TensorView<float, CPU>&, int, bool);
578593
template Tensor<bfloat16, CPU> masked_fill(const TensorView<bfloat16, CPU>&,
579594
const TensorView<int, CPU>&, bfloat16);
580595

596+
// cat
581597
template Tensor<bfloat16, CPU> cat(const TensorView<bfloat16, CPU>&,
582598
const TensorView<bfloat16, CPU>&, int);
583599
template Tensor<float, CPU> cat(const TensorView<float, CPU>&, const TensorView<float, CPU>&, int);
600+
601+
template Tensor<float, CPU> cos(const TensorView<float, CPU>& tensor);
602+
template Tensor<float, CPU> sin(const TensorView<float, CPU>& tensor);
603+
template Tensor<float, CPU> exp(const TensorView<float, CPU>& tensor);
604+
584605
template Tensor<bfloat16, CPU> pow(bfloat16, const TensorView<bfloat16, CPU>&);
585606
template Tensor<float, CPU> pow(const TensorView<float, CPU>&, float);
586607
template Tensor<float, CPU> pow(float, const TensorView<float, CPU>&);

src/tensor/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
file(GLOB HEADER_LIST CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/include/tensor/*.hpp")
22

3-
add_library(tensor_cuda STATIC storage.cu loader.cu ops.cu kernels/fill.cu kernels/argmax.cu kernels/arange.cu kernels/add.cu kernels/sub.cu kernels/div.cu kernels/mul.cu kernels/sum.cu kernels/max.cu kernels/masked_fill.cu kernels/cat.cu)
3+
add_library(tensor_cuda STATIC storage.cu loader.cu ops.cu kernels/fill.cu kernels/argmax.cu kernels/arange.cu kernels/sum.cu kernels/max.cu kernels/masked_fill.cu kernels/cat.cu kernels/map.cu kernels/zip.cu)
44

55
set_target_properties(tensor_cuda PROPERTIES
66
CUDA_SEPARABLE_COMPILATION ON

src/tensor/cuda/kernels/add.cu

Lines changed: 0 additions & 60 deletions
This file was deleted.

src/tensor/cuda/kernels/add.cuh

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/tensor/cuda/kernels/div.cu

Lines changed: 0 additions & 103 deletions
This file was deleted.

src/tensor/cuda/kernels/div.cuh

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)