-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathexpand.cc
More file actions
70 lines (58 loc) · 2.74 KB
/
expand.cc
File metadata and controls
70 lines (58 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include "../aie_kernel_utils.h"
#include <aie_api/aie.hpp>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <type_traits>
/// Expands one tile of packed 4-bit inputs to T_out, multiplying every group
/// of G consecutive values by that group's scale factor.
///
/// Input buffer layout, as established by the pointer setup below:
///   - the N input values, read 32 at a time through `in`;
///   - the per-group scale factors (one T_sf per group of G values), starting
///     at `in + N / 2`.
/// NOTE(review): the `N / 2` and `block_size / 2` offsets assume that pointer
/// arithmetic on T_in advances in bytes with two 4-bit values per byte, as the
/// original comments state — confirm for any T_in other than a 4-bit type.
///
/// @tparam T_in   Element type of the packed input values.
/// @tparam T_sf   Element type of the stored scale factors (converted to bfloat16).
/// @tparam T_out  Element type of the output buffer.
/// @tparam N      Number of input values in the tile; must be a multiple of G.
/// @tparam G      Group size sharing one scale factor; positive multiple of 32.
/// @param in      Input values followed by the scale factors (see layout above).
/// @param out     Destination for the N scaled output values.
template <typename T_in, typename T_sf, typename T_out, const int N, const int G> void expand(T_in *in, T_out *out)
{
    // Keep vector width constant; group size can vary as a multiple of 32
    constexpr int block_size = 32;

    // Validate the tile geometry before it is used in any division so that a
    // bad configuration fails with a readable message.
    static_assert(G > 0 && (G % block_size) == 0, "GROUP_SIZE must be a positive multiple of 32");
    // Without this check, N / G would silently drop the trailing partial group.
    static_assert((N % G) == 0, "TILE_SIZE must be a multiple of GROUP_SIZE");

    // Super block (group) size = block_size x blocks_per_group
    constexpr int blocks_per_group = G / block_size;
    constexpr int groups_per_tile = N / G;

    T_in *__restrict pI = in; // Input pointer
    // The scale factors are stored after the inputs. The offset is computed on
    // the T_in pointer first and only then reinterpreted as a scale-factor
    // pointer. Scale factors are only read, hence the const pointee.
    const T_sf *__restrict pSF = reinterpret_cast<const T_sf *>(in + N / 2);
    T_out *__restrict pO = out; // Output pointer

    event0(); // NOTE(review): presumably a trace/profiling marker — confirm
    // Iterate over the groups of size G in this tile.
    for (int i = 0; i < groups_per_tile; i++)
        chess_prepare_for_pipelining chess_loop_range(groups_per_tile, )
        {
            // Load one scale per group (scalar load) and advance to the next one.
            T_sf sf = *pSF;
            pSF += 1;

            // The scale is constant for the whole group, so convert and
            // broadcast it once here rather than once per 32-element block.
            bfloat16 sf_bf16 = sf;
            aie::vector<bfloat16, block_size> sf_broadcast = aie::broadcast(sf_bf16);

            for (int k = 0; k < blocks_per_group; k++) {
                // Load one block of input (32 4-bit values)
                aie::vector<T_in, block_size> I0 = aie::load_v<block_size>(pI);
                pI += block_size / 2; // 32 4-bit values; see NOTE(review) in the header

                // Upsize these to 8 bits -> 16 bits -> bfloat16
                aie::vector<uint8, block_size> asInt8 = aie::unpack(I0);        // 4 bit -> 8 bit
                aie::vector<uint16, block_size> asInt16 = aie::unpack(asInt8);  // 8 bit -> 16 bit
                aie::vector<bfloat16, block_size> as_bf16 = aie::to_float<bfloat16>(asInt16, 0); // Convert to bfloat16

                // Scale the bfloat16 values by the group's scale factor
                aie::vector<bfloat16, block_size> scaled_bf16 = aie::mul(as_bf16, sf_broadcast);

                // Write the scaled values to the output
                aie::store_v(pO, scaled_bf16);
                pO += block_size; // Advance by block_size output elements (not bytes)
            }
        }
    event1(); // NOTE(review): presumably the matching end marker — confirm
}
extern "C" {

// Build-time configuration. Either macro may be supplied by the build;
// otherwise the defaults below are used.
#if !defined(GROUP_SIZE)
#define GROUP_SIZE 32
#endif

#if !defined(TILE_SIZE)
#define TILE_SIZE 1024
#endif

/// C-linkage entry point for the expand kernel.
///
/// Instantiates `expand` with uint4 inputs, bfloat16 scale factors and
/// bfloat16 outputs, using TILE_SIZE values per call and GROUP_SIZE values per
/// scale factor.
///
/// @param a_in   Input buffer handed to `expand` as `in`.
/// @param c_out  Output buffer handed to `expand` as `out`.
void expand_int4_to_bfloat16(uint4 *a_in, bfloat16 *c_out)
{
    constexpr int tile_elems = TILE_SIZE;
    constexpr int group_elems = GROUP_SIZE;
    expand<uint4, bfloat16, bfloat16, tile_elems, group_elems>(a_in, c_out);
}

} // extern "C"