-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathmock_engine.py
More file actions
307 lines (272 loc) · 10.5 KB
/
mock_engine.py
File metadata and controls
307 lines (272 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple test engine implementing the JetStream Engine API.
Contains simple functions whose outputs can be calculated by hand.
Prefill: Scales the sequence by multiplying it with the weight (i.e. doubles
it when weight=2); the first token it emits is the sum of the scaled sequence.
Insert: Writes this sequence into a cache row.
Generate step: Return sum(prefill_cache) + sum(generate_cache)/weight.
I.e. with weight=2, if we prefill [2, 65, 66] (i.e. <BOS>, 'A', 'B') using an
ASCII vocab, we should get [4, 130, 132].
If we then insert that, prefill's first token and the next two generation
steps should be
266 (prefill's first token: 4 + 130 + 132)
266 + [266] / 2 = 399
266 + [266, 399] / 2 = 598 (598.5, truncated when cast to int32)
I.e. ['Ċ', 'Ə', 'ɖ'] when converted back with chr()
"""
import functools
from typing import Any, Optional, Tuple
from flax import struct
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jetstream.engine import engine_api
from jetstream.engine import tokenizer_pb2
# Model parameters: `TestEngine.load_params` returns a float32 array of
# shape [1,] holding the weight.
Params = jax.Array  # [1,].
# Output of prefill, consumed by `TestEngine.insert`.
# NOTE(review): `TestEngine.prefill` actually returns a tuple of arrays
# `(prefill_cache, first_step)`, not a batch of strings; within this file the
# alias is only used in annotations, so the mismatch is not enforced.
Prefix = jax.Array
@struct.dataclass
class DecodeState:
  """The inputs into a generation step.

  Field order is the positional constructor order, so do not reorder fields.
  """

  # Prefill rows, one per slot; `TestEngine.init_decode_state` allocates it as
  # float32 [prefill_cache_batch, cache_length] and `insert` overwrites a row.
  prefill_cache: jax.Array
  # Circular cache of previously generated tokens, float32
  # [generate_cache_batch, cache_length]; `generate` writes one column per step.
  generate_cache: jax.Array
  # Column of `generate_cache` that `generate` writes next. Starts as the Python
  # int 0 in `init_decode_state`; after a jitted `generate` it is a jax array.
  generate_cache_index: int
  # int32 [generate_cache_batch]; `insert` resets a slot to 1 and `generate`
  # adds 1 per step. Used to mask the circular `generate_cache`.
  generate_lengths: jax.Array
  # float32 [generate_cache_batch, 1]; the latest token per row, written into
  # `generate_cache` at the start of the next `generate` step.
  generate_tokens: jax.Array
class TestEngine(engine_api.Engine):
  """Hand-calculable mock of the JetStream `Engine` API.

  Engine defines an API that models must adhere to as they plug into the
  JetStream efficient serving infrastructure. This implementation replaces the
  model with simple arithmetic on token ids (see the module docstring), plus a
  large dummy matmul per step to simulate model latency.
  """

  def __init__(self, batch_size: int, cache_length: int, weight: float):
    """Initializes the mock engine.

    Args:
      batch_size: Number of decode slots. Used as the batch dimension of both
        the prefill cache and the generate cache.
      cache_length: Length of every cache row; also `max_prefill_length`.
      weight: Multiplier returned by `load_params`.
    """
    self.prefill_cache_batch = batch_size
    self.generate_cache_batch = batch_size
    self.cache_length = cache_length
    self.weight = weight
    # `create_device_mesh` requires exactly prod(mesh_shape) devices, so hand it
    # a single device for the (1, 1, 1) mesh. Passing every device from
    # `jax.devices()` raises on hosts with more than one device.
    self._mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((1, 1, 1), jax.devices()[:1]),
        ("x", "y", "z"),
    )

  def load_params(self) -> Params:
    """Loads model weights.

    Returns:
      A float32 array of shape [1] holding `weight`, used to scale inputs.
    """
    return jnp.array([self.weight], dtype=jnp.float32)

  def _make_result_tokens(
      self, token_data: jax.Array, speculations: int
  ) -> engine_api.ResultTokens:
    """Wraps concatenated [tokens | validity | length] columns as ResultTokens.

    Tokens, their validity and the lengths of each sequence are concatenated
    into one tensor so that copy operations are faster on Cloud TPU
    infrastructure.

    Args:
      token_data: Array of shape [batch, 2 * speculations + 1].
      speculations: Number of token (and validity) columns.

    Returns:
      `ResultTokens` whose data is `token_data` cast to int32.
    """
    return engine_api.ResultTokens(
        data=token_data.astype(jnp.int32),
        # Tokens are shape [batch, speculations], so after concatenating
        # tokens, validity and length along axis 1 they occupy 0:speculations.
        tokens_idx=(0, speculations),
        # Validity occupies the same amount of space, but next in line.
        valid_idx=(speculations, 2 * speculations),
        # And lengths is rank 1, so it takes a single column.
        length_idx=(2 * speculations, 2 * speculations + 1),
        samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch,
    )

  @functools.partial(jax.jit, static_argnums=(0,))
  def prefill(
      self,
      *,
      params: Params,
      existing_prefix: Optional[jax.Array] = None,
      padded_tokens: jax.Array,
      true_length: int,
  ) -> Tuple[Prefix, engine_api.ResultTokens]:
    """Computes a fake kv-cache for a new generate request.

    Args:
      params: Scalar multiplier, as returned by `load_params`.
      existing_prefix: Unsupported by the mock; must be None.
      padded_tokens: Rank-1 array of (padded) token ids to prefill on.
      true_length: The real length of the tokens, pre-pad. Ignored.

    Returns:
      `((prefill_cache, first_step), first_token)` where `prefill_cache` is
      `padded_tokens * params` as a [1, T] float32 array, `first_step` is its
      [1, 1] row sum, and `first_token` is the `ResultTokens` carrying
      `first_step` as the first generated token (validity 1, length 1).

    Raises:
      NotImplementedError: If `existing_prefix` is provided.
    """
    if existing_prefix is not None:
      raise NotImplementedError
    del true_length
    assert padded_tokens.ndim == 1
    # Wait to simulate model step time.
    fake_size = 4096
    fake_work = jnp.ones((fake_size, fake_size)) @ jnp.ones(
        (fake_size, fake_size)
    )
    # Do some fake work that isn't eliminated by dead code elimination (DCE).
    params = params + fake_work.mean() - fake_work.mean()
    prefill_cache = padded_tokens[None, :] * params
    # The dummy first token is the row sum of the prefill cache.
    first_step = (prefill_cache.sum(axis=-1))[:, jnp.newaxis]
    first_token_data = jnp.concatenate(
        [first_step, jnp.ones_like(first_step), jnp.ones_like(first_step)],
        axis=-1,
    )
    first_token = self._make_result_tokens(
        first_token_data, speculations=first_step.shape[1]
    )
    return (prefill_cache.astype(jnp.float32), first_step), first_token

  @functools.partial(jax.jit, static_argnums=(0,))
  def generate(
      self, params: Params, decode_state: DecodeState
  ) -> Tuple[DecodeState, engine_api.ResultTokens]:
    """Generates one token for each sequence being decoded in parallel.

    Each new token is `sum(prefill row) + sum(masked generate-cache row) /
    params`, where the mask keeps the columns whose circular distance behind
    the (advanced) cache index is at most that row's `generate_lengths`.

    Args:
      params: Scalar multiplier, as returned by `load_params`.
      decode_state: Current decode state.

    Returns:
      The updated `DecodeState` and the `ResultTokens` for this step.
    """
    prefill_cache = decode_state.prefill_cache
    generate_cache = decode_state.generate_cache
    generate_cache_index = decode_state.generate_cache_index
    generate_lengths = decode_state.generate_lengths
    previous_timestep = decode_state.generate_tokens

    # Write the previous step's tokens into the circular generate cache.
    generate_cache = jax.lax.dynamic_update_slice_in_dim(
        generate_cache.astype(jnp.float32),
        previous_timestep,
        start_index=generate_cache_index,
        axis=1,
    )
    generate_cache_index = (generate_cache_index + 1) % self.cache_length
    # Sum each row of prefill cache and generate cache to produce new timestep,
    # dividing the generate part by params.
    l_iota = jax.lax.broadcasted_iota(
        jnp.int32,
        (self.generate_cache_batch, self.cache_length),
        dimension=1,
    )
    # The generate cache should be circular and right aligned.
    # TODO: Do we need a left aligned one to test spec sampling?
    # Don't need the + 1 you normally would, because we don't provide a
    # token from prefill in the dummy.
    # NOTE(review): prefill now does emit a first token; confirm this comment
    # still describes the intended mask.
    # This iota and masking is to allow for a circular cache.
    length_mask = (
        -(l_iota - generate_cache_index) % self.cache_length
    ) <= generate_lengths[:, None]
    length_masked_gen_cache = generate_cache * length_mask
    new_timestep = (
        prefill_cache.sum(axis=-1)
        + (length_masked_gen_cache.sum(axis=-1) / params)
    )[:, jnp.newaxis]
    # Wait to simulate model step time.
    fake_size = 4096
    fake_work = jnp.ones((fake_size, fake_size)) @ jnp.ones(
        (fake_size, fake_size)
    )
    # Do some fake work that isn't eliminated by dead code elimination (DCE).
    generate_cache = generate_cache + fake_work.mean() - fake_work.mean()
    new_lengths = generate_lengths + 1
    token_data = jnp.concatenate(
        [new_timestep, jnp.ones_like(new_timestep), new_lengths[:, None]],
        axis=-1,
    )
    new_state = DecodeState(
        prefill_cache=prefill_cache,
        generate_cache=generate_cache.astype(jnp.float32),
        generate_cache_index=generate_cache_index,
        generate_lengths=new_lengths,
        generate_tokens=new_timestep,
    )
    return new_state, self._make_result_tokens(
        token_data, speculations=new_timestep.shape[1]
    )

  @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,))
  def insert(
      self,
      prefix: Prefix,
      decode_state: DecodeState,
      slot: int,
  ) -> DecodeState:
    """Adds `prefix` into `decode_state` at `slot`.

    Args:
      prefix: The `(prefill_cache, first_step)` pair produced by `prefill`:
        a [1, T] float32 row with T <= cache_length, and its [1, 1] row sum.
      decode_state: State to update. It is donated, so callers must not reuse
        it after this call.
      slot: Index of the decode slot to overwrite.

    Returns:
      `decode_state` with the slot's prefill row replaced (zero-padded to
      `cache_length`), its generate-cache rows zeroed, its lengths reset to 1
      and its pending token set to `first_step`.
    """
    prefill_row, previous_timestep = prefix
    samples_per_slot = self.generate_cache_batch // self.prefill_cache_batch
    generate_start = slot * samples_per_slot
    # Zero-pad the [1, T] row to the full cache length so a shorter prompt
    # fully replaces whatever a previous request left in this slot. Updating
    # only the first T columns would leave stale values that `generate` sums
    # into every new token.
    padded_row = jax.lax.dynamic_update_slice_in_dim(
        jnp.zeros((1, self.cache_length), dtype=jnp.float32),
        prefill_row.astype(jnp.float32),
        0,
        axis=1,
    )
    prefill_cache = jax.lax.dynamic_update_slice_in_dim(
        decode_state.prefill_cache, padded_row, slot, axis=0
    )
    # Clear every generate-cache row belonging to this slot.
    generate_cache = jax.lax.dynamic_update_slice_in_dim(
        decode_state.generate_cache,
        jnp.zeros((samples_per_slot, self.cache_length), dtype=jnp.float32),
        generate_start,
        axis=0,
    )
    generate_lengths = jax.lax.dynamic_update_slice_in_dim(
        decode_state.generate_lengths,
        jnp.ones((samples_per_slot,), dtype=jnp.int32),
        generate_start,
        axis=0,
    )
    # Every sample in the slot starts from the same first token.
    generate_tokens = jax.lax.dynamic_update_slice_in_dim(
        decode_state.generate_tokens,
        jnp.broadcast_to(
            previous_timestep.astype(jnp.float32), (samples_per_slot, 1)
        ),
        generate_start,
        axis=0,
    )
    return decode_state.replace(
        prefill_cache=prefill_cache,
        generate_cache=generate_cache,
        generate_lengths=generate_lengths,
        generate_tokens=generate_tokens,
    )

  def get_prefix_destination_sharding(self) -> Any:
    """Returns a fully replicated `NamedSharding` on `self.mesh`."""
    return jax.sharding.NamedSharding(
        mesh=self.mesh, spec=jax.sharding.PartitionSpec()
    )

  def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
    """Return a protobuf of tokenizer info, callable from Py or C++."""
    return tokenizer_pb2.TokenizerParameters(path="test", extra_ids=0)

  def init_decode_state(self) -> DecodeState:
    """Initialises any state which a generation step transforms.

    Returns:
      A `DecodeState` with all caches, lengths and tokens zeroed and the
      generate-cache index at 0.
    """
    return DecodeState(
        prefill_cache=jnp.zeros(
            (self.prefill_cache_batch, self.cache_length), dtype=jnp.float32
        ),
        generate_cache=jnp.zeros(
            (self.generate_cache_batch, self.cache_length), dtype=jnp.float32
        ),
        generate_cache_index=0,
        generate_lengths=jnp.zeros(
            (self.generate_cache_batch,), dtype=jnp.int32
        ),
        generate_tokens=jnp.zeros(
            (self.generate_cache_batch, 1), dtype=jnp.float32
        ),
    )

  @property
  def max_concurrent_decodes(self) -> int:
    """Free slots."""
    return self.prefill_cache_batch

  @property
  def max_prefill_length(self) -> int:
    """Maximum prefill length."""
    return self.cache_length

  @property
  def samples_per_slot(self) -> int:
    """Number of samples per slot."""
    return self.generate_cache_batch // self.max_concurrent_decodes

  @property
  def mesh(self) -> jax.sharding.Mesh:
    """Mesh which the engine is running on."""
    return self._mesh

  @property
  def colocated_cpus(self) -> None:
    """CPU devices colocated with the engine's accelerators.

    Raises:
      NotImplementedError: Always; the mock does not implement this.
    """
    raise NotImplementedError