Commit 035938c ("gemm")

1 parent 047df23

15 files changed
Lines changed: 1015 additions & 117 deletions

AGENTS.md

Lines changed: 1 addition & 0 deletions

@@ -59,3 +59,4 @@ This file helps AI agents discover and understand how to work with this repository
 - Restructured `README.md` into an onboarding-focused front door and added companion docs (`docs/use-cases.md`, `docs/hardware.md`, `docs/api-overview.md`, `docs/python-install.md`, `docs/torch.md`, `docs/gpu.md`, `examples/README.md`) so heavy reference material lives outside the visitor-facing overview.
 - Added optional CUDA/ROCm toggles plus a GPU dispatcher sketch (`include/t81/linalg/gemm_gpu.hpp`, `src/linalg/{gemm_cuda.cu,gemm_dispatch.cpp,gemm_rocm.cpp}`) so future teams can wire the new `where`/`clamp`/`lerp`/`addcmul` helpers into GPU kernels, introduced `t81::TensorMetadata` + Python helpers (`python/bindings.cpp`) that extract metadata from NumPy/Torch tensors, and expanded `tests/python/test_gpu_ops.py` to cover the metadata-backed bindings on both CPU and GPU paths.
 - Enhanced `tests/python/test_gguf.py` with quant-parameterized round-trip checks, metadata assertions, and a regression case for invalid quant identifiers to spotlight the GGUF helpers before future agents touch them.
+- Hardened the SIMD detection helpers in `include/t81/core/detail/simd.hpp` with CPUID/xgetbv fallbacks, documented the `add_trytes_*` overflow semantics, and made NEON runtime checks opt-out via `T81_DISABLE_NEON`.

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ option(T81LIB_ENABLE_TORCH_BINDINGS
 
 option(USE_CUDA "Enable CUDA backend" OFF)
 option(USE_ROCM "Enable ROCm/HIP backend" OFF)
+option(USE_METAL "Enable Apple Metal backend" OFF)
 
 if(USE_CUDA)
   enable_language(CUDA)
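The new toggle feeds the backend selection described in `docs/gpu.md`. As an illustrative sketch only (the function name is hypothetical and the CUDA-before-ROCm preference is an assumption; the real logic lives in `src/linalg/gemm_dispatch.cpp`), `Backend::Auto` resolution might look like:

```python
# Hypothetical sketch of Backend::Auto resolution: honor an explicit request,
# otherwise prefer a compiled GPU backend and fall back to the CPU path.
def pick_backend(use_cuda, use_rocm, requested="auto"):
    if requested != "auto":
        return requested        # explicit Backend::CUDA / Backend::ROCm / Backend::CPU
    if use_cuda:                # assumed preference order: CUDA first
        return "cuda"
    if use_rocm:
        return "rocm"
    return "cpu"                # always-available fallback

print(pick_backend(False, True))   # rocm
print(pick_backend(False, False))  # cpu
```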

docs/gpu.md

Lines changed: 3 additions & 1 deletion

@@ -2,4 +2,6 @@
 
 CUDA/ROCm kernels can be built when you configure with `-DUSE_CUDA=ON` or `-DUSE_ROCM=ON` (see `python/CMakeLists.txt`). The bindings expose `t81lib.where`, `t81lib.clamp`, `t81lib.lerp`, and `t81lib.addcmul`, which accept either NumPy buffers or PyTorch tensors and dispatch directly to the GPU kernels.
 
-Dispatch relies on `t81::TensorMetadata` (`include/t81/tensor_metadata.hpp`): a lightweight struct that carries device tags, dtype codes, shape, strides, and `data_ptr` so the dispatcher can call the right CUDA/HIP kernel without copies. When torch is available, `t81lib` automatically wraps tensors; without torch it gracefully falls back to CPU buffers. Review `python/bindings.cpp` for the extraction helpers and lifetime management and follow the [GPU dispatch diagram](diagrams/gpu-dispatch.mermaid.md) for the metadata flow.
+The dispatcher is driven entirely by `t81::TensorMetadata` (`include/t81/tensor_metadata.hpp`): a lightweight struct that carries device tags, dtype codes, shape, strides, and `data_ptr` so the runtime can call the right CUDA/HIP kernel without copies. Torch-aware helpers in `python/bindings.cpp` create metadata for GPU tensors (including a `requires_sync` flag when needed) and fall back to contiguous CPU buffers when torch is unavailable.
+
+`t81lib.gemm_ternary` now shares the same metadata plumbing. The CUDA/HIP kernels view `ScalarType::TernaryLimb` buffers as packed `core::limb` rows (`TRYTES_PER_LIMB` trytes packed into 16 bytes) and expect contiguous layouts (`np.dtype('V16')` rows or `torch.uint8` views with dimensions `(M, K_limbs)` / `(K_limbs, N)`). The accumulator `C` must remain float32 and contiguous. With `Backend::Auto`, the binding dispatches to CUDA/ROCm when available; otherwise it falls back to the CPU path. Review the [GPU dispatch diagram](diagrams/gpu-dispatch.mermaid.md) for how metadata flows from NumPy/Torch -> CUDA/HIP -> back to Python.
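The layout contract is easy to get wrong from Python, so a plain-Python sketch of the checks can help when preparing buffers. The helper name and error messages below are illustrative, not the binding's actual API; the 16-byte limb size and the K-divisible-by-48 rule come from the docs and `gemm.hpp`:

```python
# Pure-Python sketch of the gemm_ternary layout contract. Shapes are plain
# tuples standing in for NumPy dtype('V16') / torch.uint8 views.
LIMB_BYTES = 16        # one packed limb row element, i.e. np.dtype('V16')
TRITS_PER_LIMB = 48    # gemm_ternary requires K divisible by 48

def check_gemm_ternary_layout(m, n, k, a_shape, b_shape, c_shape):
    """Return K_limbs if the shapes satisfy the documented contract."""
    if k % TRITS_PER_LIMB != 0:
        raise ValueError("K must be divisible by 48")
    k_limbs = k // TRITS_PER_LIMB
    if a_shape != (m, k_limbs):
        raise ValueError("A must be (M, K_limbs) packed limb rows")
    if b_shape != (k_limbs, n):
        raise ValueError("B must be (K_limbs, N) packed limb rows")
    if c_shape != (m, n):
        raise ValueError("C must be (M, N), float32, contiguous")
    return k_limbs

print(check_gemm_ternary_layout(4, 3, 96, (4, 2), (2, 3), (4, 3)))  # 2
```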

docs/python-cookbook.md

Lines changed: 6 additions & 0 deletions

@@ -17,6 +17,12 @@ t81lib.gemm_ternary(packed, packed, c, 16, 16, 48)
 
 This shows how to drive the low-level binding (`t81lib.pack_dense_matrix`) together with `gemm_ternary` without needing PyTorch.
 
+### Keep packed buffers on the GPU
+
+When CUDA/ROCm support is enabled, `t81lib.gemm_ternary` accepts GPU-backed metadata directly. The dispatcher expects A/B to describe `ScalarType::TernaryLimb` rows with `TRYTES_PER_LIMB` packed trytes (e.g., NumPy `dtype('V16')` or `torch.uint8` views shaped `(M, k_limbs, 16)` after calling `torch.from_numpy(packed).reshape(...)`). The accumulator `C` stays float32 and contiguous, and with `Backend::Auto` the binding routes the work to the compiled GPU kernel if the necessary backend is available.
+
+Use `t81.torch.TernaryTensor` to keep limbs on the GPU and let the binding generate the required metadata, or copy a packed NumPy buffer onto CUDA/ROCm if you need to interface with other tooling. Because the binding now shares the same `TensorMetadata` flow as `t81lib.where`/`clamp`/`lerp`, no extra copying or manual span conversions are required when the inputs already live on a compatible device.
+
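When validating GPU results, a tiny pure-Python reference of the gemm semantics (`C = alpha * A.B + beta * C` over trits in {-1, 0, +1}) is handy. This sketch deliberately skips the packed limb layout, which `t81lib` handles internally, so it is a cross-check, not the library's algorithm:

```python
# Reference ternary GEMM on unpacked trit matrices: lists of lists with
# entries in {-1, 0, +1}. C may hold arbitrary floats.
def gemm_ternary_reference(a, b, c, alpha=1.0, beta=0.0):
    m, k = len(a), len(a[0])
    n = len(b[0])
    out = [[beta * c[i][j] for j in range(n)] for i in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            out[i][j] += alpha * acc
    return out

a = [[1, -1, 0], [0, 1, 1]]
b = [[1, 0], [-1, 1], [0, -1]]
c = [[0.0, 0.0], [0.0, 0.0]]
print(gemm_ternary_reference(a, b, c))  # [[2.0, -1.0], [-1.0, 0.0]]
```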
 ## 2. Drop in `t81.torch.TernaryTensor` during training
 
 ```python

include/t81/core/detail/simd.hpp

Lines changed: 123 additions & 20 deletions

@@ -5,57 +5,160 @@
 
 #include <optional>
 #include <utility>
+#include <cstdint>
+
+#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)
+#include <immintrin.h>
+#if defined(_MSC_VER)
+#include <intrin.h>
+#elif defined(__GNUC__) || defined(__clang__)
+#include <cpuid.h>
+#endif
+#endif
 
 namespace t81::core {
 class limb;
 } // namespace t81::core
 
 namespace t81::core::detail {
 
-inline bool cpu_supports_avx2() noexcept {
+namespace {
+
 #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)
-#if defined(__has_builtin)
-#if __has_builtin(__builtin_cpu_supports)
-  return __builtin_cpu_supports("avx2");
+struct cpuid_regs {
+  unsigned int eax{};
+  unsigned int ebx{};
+  unsigned int ecx{};
+  unsigned int edx{};
+};
+
+inline cpuid_regs read_cpuid(unsigned int leaf, unsigned int subleaf) {
+#if defined(_MSC_VER)
+  int regs[4];
+  __cpuidex(regs, leaf, subleaf);
+  return {static_cast<unsigned int>(regs[0]),
+          static_cast<unsigned int>(regs[1]),
+          static_cast<unsigned int>(regs[2]),
+          static_cast<unsigned int>(regs[3])};
+#elif defined(__GNUC__) || defined(__clang__)
+  unsigned int eax, ebx, ecx, edx;
+  __get_cpuid_count(leaf, subleaf, &eax, &ebx, &ecx, &edx);
+  return {eax, ebx, ecx, edx};
 #else
-  return false;
+  return {};
 #endif
+}
+
+inline unsigned long long read_xcr0() {
+#if defined(_MSC_VER)
+  return _xgetbv(0);
 #else
-  return false;
+  unsigned int eax, edx;
+  __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
+  return (static_cast<unsigned long long>(edx) << 32) | eax;
 #endif
-#else
+}
+
+inline bool os_supports_xsave() {
+  const auto leaf1 = read_cpuid(1, 0);
+  return (leaf1.ecx & (1u << 27)) != 0;
+}
+
+inline bool os_supports_xcr_states(unsigned long long mask) {
+  if (!os_supports_xsave()) {
+    return false;
+  }
+  return (read_xcr0() & mask) == mask;
+}
+
+inline bool os_supports_avx_states() {
+  constexpr unsigned long long kMask = (1ull << 1) | (1ull << 2);
+  return os_supports_xcr_states(kMask);
+}
+
+inline bool os_supports_avx512_states() {
+  constexpr unsigned long long kMask =
+      (1ull << 1) | (1ull << 2) | (1ull << 5) | (1ull << 6) | (1ull << 7);
+  return os_supports_xcr_states(kMask);
+}
+
+inline bool cpu_reports_avx() {
+  const auto leaf1 = read_cpuid(1, 0);
+  return (leaf1.ecx & (1u << 28)) != 0;
+}
+
+inline bool cpu_reports_avx2() {
+  const auto leaf7 = read_cpuid(7, 0);
+  return (leaf7.ebx & (1u << 5)) != 0;
+}
+
+inline bool cpu_reports_avx512f() {
+  const auto leaf7 = read_cpuid(7, 0);
+  return (leaf7.ebx & (1u << 16)) != 0;
+}
+
+inline bool has_runtime_avx2() {
+  if (!os_supports_avx_states() || !cpu_reports_avx()) {
+    return false;
+  }
+  return cpu_reports_avx2();
+}
+
+inline bool has_runtime_avx512f() {
+  if (!os_supports_avx512_states()) {
     return false;
+  }
+  return cpu_reports_avx512f();
+}
+#else
+inline bool has_runtime_avx2() {
+  return false;
+}
+
+inline bool has_runtime_avx512f() {
+  return false;
+}
+#endif
+
+} // namespace
+
+inline bool cpu_supports_avx2() noexcept {
+#if defined(__has_builtin)
+#if __has_builtin(__builtin_cpu_supports)
+#if defined(__AVX2__)
+  if (__builtin_cpu_supports("avx2")) {
+    return true;
+  }
+#endif
+#endif
 #endif
+  return has_runtime_avx2();
 }
 
 inline bool cpu_supports_avx512f() noexcept {
-#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)
 #if defined(__has_builtin)
 #if __has_builtin(__builtin_cpu_supports)
-  return __builtin_cpu_supports("avx512f");
-#else
-  return false;
+#if defined(__AVX512F__)
+  if (__builtin_cpu_supports("avx512f")) {
+    return true;
+  }
 #endif
-#else
-  return false;
 #endif
-#else
-  return false;
 #endif
+  return has_runtime_avx512f();
 }
 
 inline bool cpu_supports_neon() noexcept {
-#if defined(__ARM_NEON) || defined(__ARM_NEON__)
-#if defined(T81_ENABLE_NEON)
-  return true;
-#else
+#if defined(T81_DISABLE_NEON)
   return false;
-#endif
+#elif defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(T81_ENABLE_NEON)
+  return true;
 #else
   return false;
 #endif
 }
 
+// Returns true when SIMD addition completes without an overflow carry.
 bool add_trytes_avx2(const limb &, const limb &, limb &);
 bool add_trytes_avx512(const limb &, const limb &, limb &);
 bool add_trytes_neon(const limb &, const limb &, limb &);
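The XCR0 masks used by `os_supports_avx_states()` and `os_supports_avx512_states()` are easier to audit when spelled out bit by bit. This fragment reproduces the same constants; the bit names follow the Intel SDM's XSAVE state components:

```python
# XCR0 state-component bits checked by the header above.
XMM       = 1 << 1   # SSE register state
YMM       = 1 << 2   # AVX upper halves
OPMASK    = 1 << 5   # AVX-512 k-registers
ZMM_HI256 = 1 << 6   # upper 256 bits of ZMM0-ZMM15
HI16_ZMM  = 1 << 7   # ZMM16-ZMM31

AVX_MASK    = XMM | YMM
AVX512_MASK = XMM | YMM | OPMASK | ZMM_HI256 | HI16_ZMM

def os_supports(xcr0, mask):
    # Mirrors os_supports_xcr_states(): every required state bit must be set.
    return (xcr0 & mask) == mask

print(hex(AVX_MASK), hex(AVX512_MASK))  # 0x6 0xe6
print(os_supports(0x7, AVX_MASK))       # True
print(os_supports(0x7, AVX512_MASK))    # False
```

If the OS only enables x87/SSE/AVX state (XCR0 = 0x7), the AVX check passes while the AVX-512 check correctly fails, which is exactly why the header gates `has_runtime_avx512f()` on the wider mask.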

include/t81/linalg/gemm.hpp

Lines changed: 66 additions & 70 deletions

@@ -9,6 +9,7 @@
 #include <stdexcept>
 
 #include <t81/core/limb.hpp>
+#include <t81/linalg/gemm_gpu.hpp>
 
 namespace t81::linalg {
 
@@ -62,87 +63,82 @@ namespace t81::linalg {
     return low_value + high_value * radix;
   }
 
-} // namespace detail
-
-inline void gemm_ternary(std::span<const core::limb> A,
-                         std::span<const core::limb> B,
-                         std::span<float> C,
-                         int M,
-                         int N,
-                         int K,
-                         float alpha,
-                         float beta) {
-  if (M < 0 || N < 0 || K < 0) {
-    throw std::invalid_argument("gemm_ternary dimensions must be non-negative");
-  }
-  if (K % core::limb::TRITS != 0) {
-    throw std::invalid_argument("gemm_ternary requires K divisible by 48");
-  }
-  const int K_limbs = K / core::limb::TRITS;
-  if (static_cast<std::size_t>(M) * static_cast<std::size_t>(K_limbs) != A.size()) {
-    throw std::invalid_argument("A span size does not match (M, K / 48)");
-  }
-  if (static_cast<std::size_t>(K_limbs) * static_cast<std::size_t>(N) != B.size()) {
-    throw std::invalid_argument("B span size does not match (K / 48, N)");
-  }
-  if (static_cast<std::size_t>(M) * static_cast<std::size_t>(N) != C.size()) {
-    throw std::invalid_argument("C span size does not match (M, N)");
-  }
-
-  if (M == 0 || N == 0) {
-    return;
-  }
-
-  constexpr int BlockM = 8;
-  constexpr int BlockN = 8;
-  constexpr int BlockK = 4;
-  const std::size_t N_size = static_cast<std::size_t>(N);
-  const auto *const a_data = A.data();
-  const auto *const b_data = B.data();
-  auto *const c_data = C.data();
-
-  for (int ib = 0; ib < M; ib += BlockM) {
-    const int i_end = std::min(M, ib + BlockM);
-    for (int jb = 0; jb < N; jb += BlockN) {
-      const int j_end = std::min(N, jb + BlockN);
-      std::array<std::array<double, BlockN>, BlockM> accum{};
-      for (int i = ib; i < i_end; ++i) {
-        const std::size_t row = static_cast<std::size_t>(i) * N_size;
-        for (int j = jb; j < j_end; ++j) {
-          const float existing = c_data[row + static_cast<std::size_t>(j)];
-          accum[i - ib][j - jb] = static_cast<double>(existing) * beta;
-        }
-      }
-
-      for (int kb = 0; kb < K_limbs; kb += BlockK) {
-        const int k_end = std::min(K_limbs, kb + BlockK);
-        for (int k = kb; k < k_end; ++k) {
-          const std::size_t b_row = static_cast<std::size_t>(k) * N_size;
-          for (int j = jb; j < j_end; ++j) {
-            const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)];
-            detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1);
-            for (int i = ib; i < i_end; ++i) {
-              const std::size_t a_index = static_cast<std::size_t>(i) *
-                                          static_cast<std::size_t>(K_limbs) +
-                                          static_cast<std::size_t>(k);
-              const core::limb a_value = a_data[a_index];
-              const double product = detail::multiply_to_double(a_value, b_value);
-              accum[i - ib][j - jb] += product * static_cast<double>(alpha);
-              detail::prefetch_read(a_data + a_index + 1);
-            }
-          }
-        }
-      }
-
-      for (int i = ib; i < i_end; ++i) {
-        const std::size_t row = static_cast<std::size_t>(i) * N_size;
-        for (int j = jb; j < j_end; ++j) {
-          c_data[row + static_cast<std::size_t>(j)] =
-              static_cast<float>(accum[i - ib][j - jb]);
-        }
-      }
-    }
-  }
-}
+inline void gemm_ternary_cpu_impl(std::span<const core::limb> A,
+                                  std::span<const core::limb> B,
+                                  std::span<float> C,
+                                  int M,
+                                  int N,
+                                  int K,
+                                  int K_limbs,
+                                  float alpha,
+                                  float beta) {
+  if (M == 0 || N == 0) {
+    return;
+  }
+
+  constexpr int BlockM = 8;
+  constexpr int BlockN = 8;
+  constexpr int BlockK = 4;
+  const std::size_t N_size = static_cast<std::size_t>(N);
+  const auto *const a_data = A.data();
+  const auto *const b_data = B.data();
+  auto *const c_data = C.data();
+
+  for (int ib = 0; ib < M; ib += BlockM) {
+    const int i_end = std::min(M, ib + BlockM);
+    for (int jb = 0; jb < N; jb += BlockN) {
+      const int j_end = std::min(N, jb + BlockN);
+      std::array<std::array<double, BlockN>, BlockM> accum{};
+      for (int i = ib; i < i_end; ++i) {
+        const std::size_t row = static_cast<std::size_t>(i) * N_size;
+        for (int j = jb; j < j_end; ++j) {
+          const float existing = c_data[row + static_cast<std::size_t>(j)];
+          accum[i - ib][j - jb] = static_cast<double>(existing) * beta;
+        }
+      }
+
+      for (int kb = 0; kb < K_limbs; kb += BlockK) {
+        const int k_end = std::min(K_limbs, kb + BlockK);
+        for (int k = kb; k < k_end; ++k) {
+          const std::size_t b_row = static_cast<std::size_t>(k) * N_size;
+          for (int j = jb; j < j_end; ++j) {
+            const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)];
+            detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1);
+            for (int i = ib; i < i_end; ++i) {
+              const std::size_t a_index = static_cast<std::size_t>(i) *
+                                          static_cast<std::size_t>(K_limbs) +
+                                          static_cast<std::size_t>(k);
+              const core::limb a_value = a_data[a_index];
+              const double product = detail::multiply_to_double(a_value, b_value);
+              accum[i - ib][j - jb] += product * static_cast<double>(alpha);
+              detail::prefetch_read(a_data + a_index + 1);
+            }
+          }
+        }
+      }
+
+      for (int i = ib; i < i_end; ++i) {
+        const std::size_t row = static_cast<std::size_t>(i) * N_size;
+        for (int j = jb; j < j_end; ++j) {
+          c_data[row + static_cast<std::size_t>(j)] =
+              static_cast<float>(accum[i - ib][j - jb]);
+        }
+      }
+    }
+  }
+}
+
+} // namespace detail
+
+inline void gemm_ternary(std::span<const core::limb> A,
+                         std::span<const core::limb> B,
+                         std::span<float> C,
+                         int M,
+                         int N,
+                         int K,
+                         float alpha,
+                         float beta) {
+  detail::gemm_ternary_dispatch(A, B, C, M, N, K, alpha, beta);
+}
 
 } // namespace t81::linalg
