-
Notifications
You must be signed in to change notification settings - Fork 631
Expand file tree
/
Copy pathtypes.h
More file actions
371 lines (325 loc) · 12.6 KB
/
types.h
File metadata and controls
371 lines (325 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Types shared between tensor definitions and `compress-inl.h`.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
#include <stddef.h>
#include <stdint.h>
// IWYU pragma: begin_exports
#include "util/basics.h" // BF16
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_INLINE
// IWYU pragma: end_exports
namespace gcpp {
// EMU128 must not be disabled because we disable SCALAR.
#define HWY_BROKEN_EMU128 0
// Allow user override of disabled targets.
#ifndef GEMMA_DISABLED_TARGETS
// All platforms: exclude SCALAR because we use ReorderWidenMulAccumulate.
#if HWY_ARCH_ARM_V7
// No NEON because we require double-precision support.
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_ALL_NEON)
#elif HWY_ARCH_ARM_A64
// We do not yet use AES (e.g. for random generation), hence NEON is the same
// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
#elif HWY_ARCH_X86
// Skip anything older than Haswell (2013); use Zen4/SPR for recent CPUs.
// Although we do not use SPR's F16, Zen4 is only enabled for AMD. We do not
// yet use any AVX 10.2 features.
#define GEMMA_DISABLED_TARGETS \
(HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX10_2)
#elif HWY_ARCH_WASM
#define GEMMA_DISABLED_TARGETS HWY_SCALAR
#endif // HWY_ARCH_*
#endif // GEMMA_DISABLED_TARGETS
// Only used in experiments, hence disable in default builds.
#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0
#endif
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
//
// Characteristics:
// - 24-bit dynamic range, with max exponent 2^0.
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
//
// A pointer to this is the *start* of an SFP stream. Aligning the allocation
// (see aligned_allocator.h) may speed up decoding but is not required.
//
// Layout: Values are stored in-order to enable vector-length agnostic seeking,
// because streams may be written to disk for loading on other CPUs.
//
// This is faster to decode than a straightforward implementation of eXmY, in
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
// require side information (shared exponents).
//
// Although the representation could probably be shrunk to 6-7 bits, more
// savings can be had by non-uniform clustering - see NuqStream.
#pragma pack(push, 1)
struct SfpStream {
// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
// shifting the value range (exponent bias).
static constexpr float kMax = 1.875f;
uint8_t byte;
};
#pragma pack(pop)
#pragma pack(push, 1)
struct I8Stream {
static constexpr size_t kGroupSize = 128;
using ScaleT = hwy::bfloat16_t;
// Returns number of I8Stream to allocate for the stream, which matches its
// size in bytes.
// TODO: should support other types beyond hwy::float32_t for scale and
// zero-point.
static constexpr size_t PackedEnd(size_t capacity) {
const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
return (sizeof(ScaleT) * num_groups) + // scale
(sizeof(ScaleT) * num_groups) + // zero-point
capacity; // 1 value per byte
}
int8_t i;
};
#pragma pack(pop)
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
// two vectors (for `Decompress2`), and decoding to bf16/f32.
//
// A pointer to this is the *start* of a NUQ stream. Aligning the allocation
// (see aligned_allocator.h) may be speed up decoding but is not required.
//
// Layout: first one table of kClusters entries per group, in ascending order
// of group index, then two packed indices per byte. Indices are stored
// in-order to enable vector-length agnostic decode, because streams may be
// persisted to disk and used by other CPUs.
//
// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters
// which refer to the stream, NOT the raw from/to pointers, which point directly
// to the source/destination. Offsets are in units of values, NOT compressed
// bytes within the stream.
#pragma pack(push, 1)
struct NuqStream {
// 4-bit indices are a sweet spot in terms of quality per size.
static constexpr size_t kClusters = 16;
// Number of weights that share a table. Larger = slower encode, higher error,
// smaller size (table amortized over more weights).
static constexpr size_t kGroupSize = 256;
// Storage for dynamic programming. There are two matrices; we use separate
// allocations to avoid type punning.
template <class T>
class AlignedMatrix {
public:
AlignedMatrix() : mem_(hwy::AllocateAligned<T>(kClusters * kGroupSize)) {}
HWY_INLINE const T& operator()(size_t row, size_t col) const {
return mem_[row * kGroupSize + col];
}
HWY_INLINE T& operator()(size_t row, size_t col) {
return mem_[row * kGroupSize + col];
}
private:
hwy::AlignedFreeUniquePtr<T[]> mem_;
};
// Reuse memory across calls to Enc to avoid per-call allocations.
struct ClusterBuf {
// Move-only (stored inside vector in CompressWorkingSet).
ClusterBuf() = default;
ClusterBuf(const ClusterBuf&) = delete;
ClusterBuf& operator=(const ClusterBuf&) = delete;
ClusterBuf(ClusterBuf&&) = default;
ClusterBuf& operator=(ClusterBuf&&) = default;
// Independent of num_groups.
AlignedMatrix<float> costs;
AlignedMatrix<int32_t> argmin;
};
// Returns offset of packed indices from the start of the stream. This matches
// the (padded) total table size because table entries are bytes.
static constexpr size_t PackedStart(size_t capacity) {
// Round up to avoid cache-line splits when loading indices. No effect on
// size as long as capacity / kGroupSize is a multiple of 4.
return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
}
// Returns number of NuqStream to allocate for the stream, which matches its
// size in bytes.
static constexpr size_t PackedEnd(size_t capacity) {
const size_t num_groups = hwy::DivCeil(capacity, kGroupSize);
return (kClusters * num_groups) +
hwy::DivCeil(capacity, 2); // 2x 4-bit/byte
}
uint8_t byte;
};
#pragma pack(pop)
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
}
template <typename Packed>
constexpr bool IsSfpStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>();
}
template <typename Packed>
constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
template <typename Packed>
constexpr bool IsI8Stream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
template <typename Packed>
constexpr bool SupportsPointerArithmetic() {
return !IsNuqStream<Packed>() && !IsI8Stream<Packed>();
}
// Tensor types for loading weights. Not all of these are supported weight
// types, some are only used for `Activations`. Append-only.
enum class Type {
kUnknown,
kF32,
kBF16,
kSFP,
kNUQ,
kF64,
kU32,
kU64,
kI8,
kU16,
kU8,
};
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "u32", "u64",
"i8", "u16", "u8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
0,
8 * sizeof(float),
8 * sizeof(BF16),
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
8 * sizeof(uint32_t),
8 * sizeof(uint64_t),
8 * sizeof(I8Stream),
8 * sizeof(uint16_t),
8 * sizeof(uint8_t),
};
static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes;
}
// Returns a Type enum for the type of the template parameter.
template <typename PackedT>
constexpr Type TypeEnum() {
using Packed = hwy::RemoveCvRef<PackedT>;
if constexpr (hwy::IsSame<Packed, float>()) {
return Type::kF32;
} else if constexpr (hwy::IsSame<Packed, BF16>()) {
return Type::kBF16;
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
return Type::kSFP;
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) {
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, uint32_t>()) {
return Type::kU32;
} else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
return Type::kU64;
} else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
return Type::kI8;
} else if constexpr (hwy::IsSame<Packed, uint16_t>()) {
return Type::kU16;
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
return Type::kU8;
} else {
return Type::kUnknown;
}
}
static inline size_t TypeBits(Type type) {
return kTypeBits[static_cast<int>(type)];
}
static inline const char* TypeName(Type type) {
return kTypeStrings[static_cast<int>(type)];
}
template <typename PackedT>
const char* TypeName() {
return TypeName(TypeEnum<PackedT>());
}
template <typename Packed>
constexpr bool IsCompressed() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>() ||
hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>();
}
// Returns the number of `MatT` elements required to store `capacity` values,
// which must not be zero. This is only intended to support the extra tables
// required for NUQ. `capacity` includes any padding and is `rows * stride`.
// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
template <typename Packed>
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
return NuqStream::PackedEnd(capacity);
} else if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, I8Stream>()) {
return I8Stream::PackedEnd(capacity);
} else {
return capacity;
}
}
// Non-owning view of packed elements. Shortens argument lists.
//
// Callers typically also pass an `ofs` starting offset. This is not folded
// into `ptr` because NUQ consists of two separate streams. To discourage direct
// use of `ptr` without that offset, we define a separate class instead of
// reusing `hwy::Span`.
template <typename Packed>
struct PackedSpan {
// Ensures callers can read or write `num_accessible` elements starting at
// `packed_ofs`.
void BoundsCheck(size_t packed_ofs, size_t num_accessible) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
// For NUQ, there can be fewer Packed than the number of elements, hence
// check the compressed count and ensure we have that many.
const size_t required =
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
if (num < required) {
HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
packed_ofs, num_accessible, required, num);
}
}
}
Packed* HWY_RESTRICT ptr;
size_t num; // for BoundsCheck, also required by nuq-inl.h.
};
// Avoids spelling out the template parameter in every call.
template <typename Packed>
HWY_INLINE PackedSpan<Packed> MakeSpan(Packed* ptr, size_t size) {
return {ptr, size};
}
template <typename Packed>
HWY_INLINE PackedSpan<const Packed> MakeConstSpan(Packed* ptr, size_t size) {
return {ptr, size};
}
// "Implicit" conversion from a PackedSpan<T> to PackedSpan<const T>, used in
// `RMSNormInplace` and compression tests.
template <typename Packed>
HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
return {packed.ptr, packed.num};
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_