|
9 | 9 | #include <stdexcept> |
10 | 10 |
|
11 | 11 | #include <t81/core/limb.hpp> |
| 12 | +#include <t81/linalg/gemm_gpu.hpp> |
12 | 13 |
|
13 | 14 | namespace t81::linalg { |
14 | 15 |
|
@@ -62,87 +63,82 @@ namespace t81::linalg { |
62 | 63 | return low_value + high_value * radix; |
63 | 64 | } |
64 | 65 |
|
65 | | - } // namespace detail |
66 | | - |
67 | | - inline void gemm_ternary(std::span<const core::limb> A, |
68 | | - std::span<const core::limb> B, |
69 | | - std::span<float> C, |
70 | | - int M, |
71 | | - int N, |
72 | | - int K, |
73 | | - float alpha, |
74 | | - float beta) { |
75 | | - if (M < 0 || N < 0 || K < 0) { |
76 | | - throw std::invalid_argument("gemm_ternary dimensions must be non-negative"); |
77 | | - } |
78 | | - if (K % core::limb::TRITS != 0) { |
79 | | - throw std::invalid_argument("gemm_ternary requires K divisible by 48"); |
80 | | - } |
81 | | - const int K_limbs = K / core::limb::TRITS; |
82 | | - if (static_cast<std::size_t>(M) * static_cast<std::size_t>(K_limbs) != A.size()) { |
83 | | - throw std::invalid_argument("A span size does not match (M, K / 48)"); |
84 | | - } |
85 | | - if (static_cast<std::size_t>(K_limbs) * static_cast<std::size_t>(N) != B.size()) { |
86 | | - throw std::invalid_argument("B span size does not match (K / 48, N)"); |
87 | | - } |
88 | | - if (static_cast<std::size_t>(M) * static_cast<std::size_t>(N) != C.size()) { |
89 | | - throw std::invalid_argument("C span size does not match (M, N)"); |
90 | | - } |
91 | | - |
92 | | - if (M == 0 || N == 0) { |
93 | | - return; |
94 | | - } |
| 66 | + inline void gemm_ternary_cpu_impl(std::span<const core::limb> A, |
| 67 | + std::span<const core::limb> B, |
| 68 | + std::span<float> C, |
| 69 | + int M, |
| 70 | + int N, |
| 71 | + int K, |
| 72 | + int K_limbs, |
| 73 | + float alpha, |
| 74 | + float beta) { |
| 75 | + if (M == 0 || N == 0) { |
| 76 | + return; |
| 77 | + } |
95 | 78 |
|
96 | | - constexpr int BlockM = 8; |
97 | | - constexpr int BlockN = 8; |
98 | | - constexpr int BlockK = 4; |
99 | | - const std::size_t N_size = static_cast<std::size_t>(N); |
100 | | - const auto *const a_data = A.data(); |
101 | | - const auto *const b_data = B.data(); |
102 | | - auto *const c_data = C.data(); |
103 | | - |
104 | | - for (int ib = 0; ib < M; ib += BlockM) { |
105 | | - const int i_end = std::min(M, ib + BlockM); |
106 | | - for (int jb = 0; jb < N; jb += BlockN) { |
107 | | - const int j_end = std::min(N, jb + BlockN); |
108 | | - std::array<std::array<double, BlockN>, BlockM> accum{}; |
109 | | - for (int i = ib; i < i_end; ++i) { |
110 | | - const std::size_t row = static_cast<std::size_t>(i) * N_size; |
111 | | - for (int j = jb; j < j_end; ++j) { |
112 | | - const float existing = c_data[row + static_cast<std::size_t>(j)]; |
113 | | - accum[i - ib][j - jb] = static_cast<double>(existing) * beta; |
| 79 | + constexpr int BlockM = 8; |
| 80 | + constexpr int BlockN = 8; |
| 81 | + constexpr int BlockK = 4; |
| 82 | + const std::size_t N_size = static_cast<std::size_t>(N); |
| 83 | + const auto *const a_data = A.data(); |
| 84 | + const auto *const b_data = B.data(); |
| 85 | + auto *const c_data = C.data(); |
| 86 | + |
| 87 | + for (int ib = 0; ib < M; ib += BlockM) { |
| 88 | + const int i_end = std::min(M, ib + BlockM); |
| 89 | + for (int jb = 0; jb < N; jb += BlockN) { |
| 90 | + const int j_end = std::min(N, jb + BlockN); |
| 91 | + std::array<std::array<double, BlockN>, BlockM> accum{}; |
| 92 | + for (int i = ib; i < i_end; ++i) { |
| 93 | + const std::size_t row = static_cast<std::size_t>(i) * N_size; |
| 94 | + for (int j = jb; j < j_end; ++j) { |
| 95 | + const float existing = c_data[row + static_cast<std::size_t>(j)]; |
| 96 | + accum[i - ib][j - jb] = static_cast<double>(existing) * beta; |
| 97 | + } |
114 | 98 | } |
115 | | - } |
116 | 99 |
|
117 | | - for (int kb = 0; kb < K_limbs; kb += BlockK) { |
118 | | - const int k_end = std::min(K_limbs, kb + BlockK); |
119 | | - for (int k = kb; k < k_end; ++k) { |
120 | | - const std::size_t b_row = static_cast<std::size_t>(k) * N_size; |
121 | | - for (int j = jb; j < j_end; ++j) { |
122 | | - const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)]; |
123 | | - detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1); |
124 | | - for (int i = ib; i < i_end; ++i) { |
125 | | - const std::size_t a_index = static_cast<std::size_t>(i) * |
126 | | - static_cast<std::size_t>(K_limbs) + |
127 | | - static_cast<std::size_t>(k); |
128 | | - const core::limb a_value = a_data[a_index]; |
129 | | - const double product = detail::multiply_to_double(a_value, b_value); |
130 | | - accum[i - ib][j - jb] += product * static_cast<double>(alpha); |
131 | | - detail::prefetch_read(a_data + a_index + 1); |
| 100 | + for (int kb = 0; kb < K_limbs; kb += BlockK) { |
| 101 | + const int k_end = std::min(K_limbs, kb + BlockK); |
| 102 | + for (int k = kb; k < k_end; ++k) { |
| 103 | + const std::size_t b_row = static_cast<std::size_t>(k) * N_size; |
| 104 | + for (int j = jb; j < j_end; ++j) { |
| 105 | + const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)]; |
| 106 | + detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1); |
| 107 | + for (int i = ib; i < i_end; ++i) { |
| 108 | + const std::size_t a_index = static_cast<std::size_t>(i) * |
| 109 | + static_cast<std::size_t>(K_limbs) + |
| 110 | + static_cast<std::size_t>(k); |
| 111 | + const core::limb a_value = a_data[a_index]; |
| 112 | + const double product = detail::multiply_to_double(a_value, b_value); |
| 113 | + accum[i - ib][j - jb] += product * static_cast<double>(alpha); |
| 114 | + detail::prefetch_read(a_data + a_index + 1); |
| 115 | + } |
132 | 116 | } |
133 | 117 | } |
134 | 118 | } |
135 | | - } |
136 | 119 |
|
137 | | - for (int i = ib; i < i_end; ++i) { |
138 | | - const std::size_t row = static_cast<std::size_t>(i) * N_size; |
139 | | - for (int j = jb; j < j_end; ++j) { |
140 | | - c_data[row + static_cast<std::size_t>(j)] = |
141 | | - static_cast<float>(accum[i - ib][j - jb]); |
| 120 | + for (int i = ib; i < i_end; ++i) { |
| 121 | + const std::size_t row = static_cast<std::size_t>(i) * N_size; |
| 122 | + for (int j = jb; j < j_end; ++j) { |
| 123 | + c_data[row + static_cast<std::size_t>(j)] = |
| 124 | + static_cast<float>(accum[i - ib][j - jb]); |
| 125 | + } |
142 | 126 | } |
143 | 127 | } |
144 | 128 | } |
145 | 129 | } |

    } // namespace detail

    /// Computes C = alpha * A * B + beta * C over packed balanced-ternary
    /// operands by forwarding to detail::gemm_ternary_dispatch (declared in
    /// <t81/linalg/gemm_gpu.hpp>), which selects the concrete backend.
    ///
    /// @param A     M x (K / core::limb::TRITS) packed limbs, row-major.
    /// @param B     (K / core::limb::TRITS) x N packed limbs, row-major.
    /// @param C     M x N accumulator/output floats, row-major.
    /// @param M,N,K GEMM dimensions; K counts trits, not limbs.
    /// @param alpha Scale applied to the A*B product.
    /// @param beta  Scale applied to the existing contents of C.
    ///
    /// NOTE(review): the old inline path validated dimensions and span sizes
    /// (throwing std::invalid_argument) before computing; that validation is
    /// presumably now inside gemm_ternary_dispatch — confirm, since <stdexcept>
    /// is still included here for that contract.
    inline void gemm_ternary(std::span<const core::limb> A,
                             std::span<const core::limb> B,
                             std::span<float> C,
                             int M,
                             int N,
                             int K,
                             float alpha,
                             float beta) {
        detail::gemm_ternary_dispatch(A, B, C, M, N, K, alpha, beta);
    }
147 | 143 |
|
148 | 144 | } // namespace t81::linalg |
0 commit comments