Skip to content

Commit 66e4b82

Browse files
committed
Sum and max reductions are fast
1 parent c93fb1f commit 66e4b82

12 files changed

Lines changed: 535 additions & 1 deletion

File tree

CMakeLists.txt

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# CMake's version-range syntax requires three dots: <min>...<max>.
# "3.19..4.0" (two dots) is not a valid range specifier.
cmake_minimum_required(VERSION 3.19...4.0)

project(
  Forward
  VERSION 0.1
  DESCRIPTION "A simple transformer inference engine"
  LANGUAGES CXX
)

# Cache fetched dependencies in-tree and avoid re-downloading on reconfigure.
set(FETCHCONTENT_BASE_DIR "${CMAKE_SOURCE_DIR}/.cmake/fetchcontent")
set(FETCHCONTENT_UPDATES_DISCONNECTED ON)

# msgpack in tokenizers_cpp is doing weird stuff
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)

# Speed up rebuilds via ccache for both C and C++.
set(CMAKE_C_COMPILER_LAUNCHER ccache)
set(CMAKE_CXX_COMPILER_LAUNCHER ccache)

# Only configure IDE niceties, testing, and docs when built as the top-level
# project (not when consumed via add_subdirectory / FetchContent).
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
  set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
  set(CMAKE_CXX_EXTENSIONS ON)
  set_property(GLOBAL PROPERTY USE_FOLDERS ON)
  include(CTest)

  find_package(Doxygen)
  if(Doxygen_FOUND)
    add_subdirectory(docs)
  else()
    message(STATUS "Doxygen not found, not building docs")
  endif()
endif()

include(FetchContent)

# Declare all external dependencies in one place
FetchContent_Declare(
  fmtlib
  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
  GIT_TAG 12.1.0
)

FetchContent_Declare(
  json
  GIT_REPOSITORY https://github.com/nlohmann/json
  GIT_TAG v3.12.0
)

FetchContent_Declare(
  safetensors_cpp
  GIT_REPOSITORY https://github.com/syoyo/safetensors-cpp.git
  GIT_TAG 10f7d8f
)
set(SAFETENSORS_CPP_CXX_EXCEPTIONS ON CACHE BOOL "" FORCE)

FetchContent_Declare(
  tokenizers_cpp
  GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp
  GIT_TAG 55d53aa
)

# On NixOS with Clang, clangd uses a different resource-dir than clang++
# Set the correct resource-dir for all C++ compilation
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND DEFINED CLANG_RESOURCE_DIR AND NOT CLANG_RESOURCE_DIR STREQUAL "")
  add_compile_options(-resource-dir=${CLANG_RESOURCE_DIR})
endif()

# Make dependencies available
FetchContent_MakeAvailable(fmtlib json safetensors_cpp tokenizers_cpp)

# compiled library code
add_subdirectory(src/tensor)
add_subdirectory(src/nn)
add_subdirectory(src/llama)
add_subdirectory(src/forward)

add_subdirectory(benchmarks)

# executable code
add_subdirectory(apps)

# testing only if this is the main app
if((CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME OR MODERN_CMAKE_BUILD_TESTING)
   AND BUILD_TESTING)
  add_subdirectory(tests)
endif()
87+

benchmarks/tensor/cpu/bm_ops.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,55 @@ BENCHMARK(BM_CPU_AddBf16)
3535
->Args({65536, 2048})
3636
->Unit(kMillisecond)
3737
->UseRealTime();
38+
39+
static void BM_CPU_SumFp32LastDim(State& state) {
40+
Tensor<float, CPU> tensor(
41+
{static_cast<size_t>(state.range(0)), static_cast<size_t>(state.range(1))});
42+
43+
tensor.fill_(float(1.0));
44+
45+
auto view = tensor.view();
46+
47+
for (auto _ : state)
48+
DoNotOptimize(sum(view, -1, true));
49+
50+
int64_t flops = 0;
51+
52+
flops += state.iterations() * state.range(0) * state.range(1);
53+
state.counters["FLOPs"] = Counter(flops, Counter::kIsRate);
54+
auto bytes_per_element = 4;
55+
state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * state.range(0) *
56+
state.range(1) * bytes_per_element);
57+
}
58+
59+
BENCHMARK(BM_CPU_SumFp32LastDim)
60+
->Args({16384, 2048})
61+
->Args({65536, 2048})
62+
->Unit(kMillisecond)
63+
->UseRealTime();
64+
65+
static void BM_CPU_SumFp32FirstDim(State& state) {
66+
Tensor<float, CPU> tensor(
67+
{static_cast<size_t>(state.range(0)), static_cast<size_t>(state.range(1))});
68+
69+
tensor.fill_(float(1.0));
70+
71+
auto view = tensor.view();
72+
73+
for (auto _ : state)
74+
DoNotOptimize(sum(view, 0, true));
75+
76+
int64_t flops = 0;
77+
78+
flops += state.iterations() * state.range(0) * state.range(1);
79+
state.counters["FLOPs"] = Counter(flops, Counter::kIsRate);
80+
auto bytes_per_element = 4;
81+
state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * state.range(0) *
82+
state.range(1) * bytes_per_element);
83+
}
84+
85+
BENCHMARK(BM_CPU_SumFp32FirstDim)
86+
->Args({16384, 2048})
87+
->Args({65536, 2048})
88+
->Unit(kMillisecond)
89+
->UseRealTime();

benchmarks/tensor/cuda/bm_ops.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,55 @@ BENCHMARK(BM_CUDA_AddBf16)
3737
->Args({262144, 2048})
3838
->Unit(kMillisecond)
3939
->UseRealTime();
40+
41+
static void BM_CUDA_SumFp32LastDim(State& state) {
42+
Tensor<float, CUDA> tensor(
43+
{static_cast<size_t>(state.range(0)), static_cast<size_t>(state.range(1))});
44+
45+
tensor.fill_(float(1.0));
46+
47+
auto view = tensor.view();
48+
49+
for (auto _ : state)
50+
DoNotOptimize(sum(view, -1, true));
51+
52+
int64_t flops = 0;
53+
54+
flops += state.iterations() * state.range(0) * state.range(1);
55+
state.counters["FLOPs"] = Counter(flops, Counter::kIsRate);
56+
auto bytes_per_element = 4;
57+
state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * state.range(0) *
58+
state.range(1) * bytes_per_element);
59+
}
60+
61+
BENCHMARK(BM_CUDA_SumFp32LastDim)
62+
->Args({16384, 2048})
63+
->Args({65536, 2048})
64+
->Unit(kMillisecond)
65+
->UseRealTime();
66+
67+
static void BM_CUDA_SumFp32FirstDim(State& state) {
68+
Tensor<float, CUDA> tensor(
69+
{static_cast<size_t>(state.range(0)), static_cast<size_t>(state.range(1))});
70+
71+
tensor.fill_(float(1.0));
72+
73+
auto view = tensor.view();
74+
75+
for (auto _ : state)
76+
DoNotOptimize(sum(view, 0, true));
77+
78+
int64_t flops = 0;
79+
80+
flops += state.iterations() * state.range(0) * state.range(1);
81+
state.counters["FLOPs"] = Counter(flops, Counter::kIsRate);
82+
auto bytes_per_element = 4;
83+
state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * state.range(0) *
84+
state.range(1) * bytes_per_element);
85+
}
86+
87+
BENCHMARK(BM_CUDA_SumFp32FirstDim)
88+
->Args({16384, 2048})
89+
->Args({65536, 2048})
90+
->Unit(kMillisecond)
91+
->UseRealTime();

src/tensor/cpu/ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,10 @@ template Tensor<bfloat16, CPU> mul(const TensorView<bfloat16, CPU>&, bfloat16);
567567
template Tensor<bfloat16, CPU> mul(const TensorView<bfloat16, CPU>&,
                                   const TensorView<bfloat16, CPU>&);

// Explicit instantiations of the reduction templates defined above, so the
// library emits them for the types it exposes.

// sum
template Tensor<float, CPU> sum(const TensorView<float, CPU>&, int, bool);

// max
template Tensor<float, CPU> max(const TensorView<float, CPU>&, int, bool);
template Tensor<bfloat16, CPU> masked_fill(const TensorView<bfloat16, CPU>&,
                                           const TensorView<int, CPU>&, bfloat16);

src/tensor/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
file(GLOB HEADER_LIST CONFIGURE_DEPENDS "${PROJECT_SOURCE_DIR}/include/tensor/*.hpp")

# Static CUDA tensor library: storage/loader plumbing plus one .cu per
# elementwise/reduction kernel family.
add_library(tensor_cuda STATIC storage.cu loader.cu ops.cu kernels/fill.cu kernels/arange.cu kernels/add.cu kernels/sub.cu kernels/div.cu kernels/mul.cu kernels/sum.cu kernels/max.cu)
44

55
set_target_properties(tensor_cuda PROPERTIES
66
CUDA_SEPARABLE_COMPILATION ON

src/tensor/cuda/kernels/max.cu

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#include "max.cuh"
#include "utils.cuh"
#include <cassert>
#include <cstddef>
#include <cuda_bf16.hpp>
#include <limits>
#include <utility>

namespace tensor::kernels {

using namespace dtype;

// Threads per block for the reduction kernel below. The shared-memory tree
// reduction assumes this is a power of two and at least 64 (it reads
// shmem[tid + 32] unconditionally in the final warp stage).
const int blockThreads = 256;

/// One max-reduction per block: block `b` computes out[b] as the max of the
/// `reduce_size` input elements belonging to output element `b`, which are
/// spaced `reduce_stride` apart starting at a base offset derived from `b`.
/// `num_reductions` bounds `b` so launching extra blocks is safe.
__global__ void max_float_kernel(Cuda<float>* out, Cuda<float>* input, size_t num_reductions, size_t reduce_size, size_t reduce_stride) {
  __shared__ Cuda<float> shmem[blockThreads]; // NOLINT
  size_t tid = threadIdx.x;

  size_t reduction_idx = blockIdx.x; // which reduction are we doing?
  if (reduction_idx >= num_reductions) { return; } // guard stray blocks

  // decompose into outer and inner indices
  size_t outer_idx = reduction_idx / reduce_stride;
  size_t inner_idx = reduction_idx % reduce_stride;

  // base offset for this reduction
  size_t base = (outer_idx * reduce_size * reduce_stride) + inner_idx;

  // Each thread folds a strided subset of the elements, so reduce_size may
  // exceed blockDim.x.
  // NOTE(review): std::numeric_limits in device code relies on
  // --expt-relaxed-constexpr — confirm the build enables it.
  float thread_max = -std::numeric_limits<float>::infinity();
  for (size_t element = tid; element < reduce_size; element += blockDim.x) {
    thread_max = fmaxf(thread_max, input[base + (element * reduce_stride)]);
  }
  // now we only have to reduce 'blockThreads' elements, which is easy within a block

  // load partial maxs onto shmem
  shmem[tid] = thread_max;
  __syncthreads();

  // tree-reduce in shared memory down to the final 64 elements
  for (int stride = blockDim.x / 2; stride > 32; stride >>= 1) { // NOLINT
    if (tid < stride) { shmem[tid] = fmaxf(shmem[tid], shmem[tid + stride]); }
    __syncthreads();
  }

  // final 64 -> 1: one shared-memory fold, then warp shuffles (no
  // __syncthreads needed inside a single warp)
  if (tid < 32) {
    float val = fmaxf(shmem[tid], shmem[tid + 32]);
    for (int offset = 16; offset > 0; offset >>= 1) {
      val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }

    if (tid == 0) {
      out[reduction_idx] = val;
    }
  }
}

/// Reduces `input` with max along dimension `dim` (negative values count
/// from the end), returning a new tensor. When `keepdim` is true the reduced
/// dimension is retained with extent 1; otherwise it is dropped (a full
/// reduction still yields a 1-element tensor).
/// The input must be contiguous (asserted).
Tensor<float, CUDA> max_float(const TensorView<float, CUDA>& input, int dim, bool keepdim) {
  assert(input.is_contiguous() && "the tensor should be contiguous");

  auto shape = input.shape;

  // normalize negative dims, e.g. -1 means the last dimension
  if (dim < 0) {
    dim = shape.size() + dim;
  }

  assert(dim >= 0 && static_cast<size_t>(dim) < shape.size());

  size_t outer_size = 1;  // how many reductions will we perform? ("batch size")
  size_t inner_size = 1;  // what's the distance between elements to reduce?
  size_t reduce_size = 1; // how many elements each reduction needs to reduce over

  bool found_dim = false;

  // Build the output shape while accumulating outer/inner extents around dim.
  Shape out_shape;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i == static_cast<size_t>(dim)) {
      if (keepdim) {
        out_shape.push_back(1);
      }
      reduce_size = shape[dim];
      found_dim = true;
    } else {
      if (!found_dim) {
        outer_size *= shape[i];
      } else {
        inner_size *= shape[i];
      }

      out_shape.push_back(shape[i]);
    }
  }

  // a full reduction with keepdim=false would otherwise produce rank 0
  if (out_shape.empty()) {
    out_shape.push_back(1);
  }

  auto n_elements = outer_size * inner_size; // one output element per reduction

  TensorStorage<float, CUDA> storage(n_elements);
  Tensor<float, CUDA> out{out_shape, std::move(storage)};

  int block_size = blockThreads;

  // Convert to device-native types for kernel call
  auto* out_d = reinterpret_cast<Cuda<float>*>(out.data()); // NOLINT
  auto* input_d = reinterpret_cast<Cuda<float>*>(input.data); // NOLINT

  // one block per output element; each block cooperatively reduces its slice
  max_float_kernel<<<n_elements, block_size>>>(out_d, input_d, n_elements, reduce_size, inner_size);

  return out;
}

} // namespace tensor::kernels

src/tensor/cuda/kernels/max.cuh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once

#include <cuda_runtime.h>
#include <tensor/device_type.hpp>
#include <tensor/tensor.hpp>
#include <cstddef>

namespace tensor::kernels {

using namespace dtype;

/// Device kernel: block `blockIdx.x` computes out[blockIdx.x] as the max of
/// `reduce_size` input elements spaced `reduce_stride` apart.
__global__ void max_float_kernel(Cuda<float>* out, Cuda<float>* input, size_t num_reductions, size_t reduce_size, size_t reduce_stride);

/// Host entry point: max-reduce a contiguous fp32 CUDA tensor along `dim`
/// (negative dims count from the end); `keepdim` retains the reduced
/// dimension with extent 1.
Tensor<float, CUDA> max_float(const TensorView<float, CUDA>& input, int dim, bool keepdim);

} // namespace tensor::kernels

0 commit comments

Comments
 (0)