-
Notifications
You must be signed in to change notification settings - Fork 744
Expand file tree
/
Copy pathresource_manager_v1.py
More file actions
1524 lines (1363 loc) · 72.8 KB
/
resource_manager_v1.py
File metadata and controls
1524 lines (1363 loc) · 72.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import threading
import time
import traceback
from collections import deque
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Union
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.multimodal_cache_manager import (
EncoderCacheManager,
ProcessorCacheManager,
)
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.request import (
ImagePosition,
Request,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.platforms import current_platform
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger
@dataclass
class ScheduledDecodeTask:
    """
    Task for allocating new blocks to decode.

    Instances are built by ``ResourceManagerV1._prepare_decode_task`` from a
    request's ``idx``, ``request_id`` and ``block_tables`` attributes.
    """

    # ``idx`` of the originating request (the same value is used elsewhere in
    # this file to index ``tasks_list`` / ``stop_flags``).
    idx: int
    # Identifier of the request this task belongs to.
    request_id: str
    # The request's block table list; the list object itself is passed in,
    # not a copy.
    block_tables: list[int]
    # Task kind tag; defaults to ``RequestType.DECODE``.
    task_type: RequestType = RequestType.DECODE
@dataclass
class ScheduledPreemptTask:
    """
    Task for terminating inference to recycle resource.

    Instances are built by ``ResourceManagerV1._prepare_preempt_task`` for a
    request whose blocks have been reclaimed.
    """

    # ``idx`` of the preempted request.
    idx: int
    # Identifier of the preempted request.
    request_id: str
    # Task kind tag; defaults to ``RequestType.PREEMPTED``.
    task_type: RequestType = RequestType.PREEMPTED
@dataclass
class ScheduledExtendBlocksTask:
    """
    Task for allocating new blocks to extend.

    Carries the request's ``extend_block_tables`` list alongside its slot
    index and identifier.
    """

    # ``idx`` of the request whose extend tables are being set up.
    idx: int
    # Identifier of the request this task belongs to.
    request_id: str
    # Block ids making up the request's extend block table.
    extend_block_tables: list[int]
    # Task kind tag; defaults to ``RequestType.EXTEND``.
    task_type: RequestType = RequestType.EXTEND
@dataclass
class ScheduledAbortTask:
    """
    Task marking a request as aborted (``RequestType.ABORT``).

    Instances are built by ``ResourceManagerV1._prepare_abort_task``, which
    ``_trigger_abort`` calls after freeing the request's blocks.
    """

    # ``idx`` of the aborted request.
    idx: int
    # Identifier of the aborted request.
    request_id: str
    # Task kind tag; defaults to ``RequestType.ABORT``.
    task_type: RequestType = RequestType.ABORT
class SignalConsumer:
    """
    Hold a signal value that may be handed out only a bounded number of times.

    The stored value can be inspected freely through ``watch``. Every call to
    ``consume`` spends one unit of the budget; once the budget is used up the
    stored value is replaced by 0.
    """

    def __init__(self, signal, consume_limit):
        """
        Store the signal together with its consumption budget.

        Args:
            signal: Value to hand out.
            consume_limit (int): Number of ``consume`` calls that may observe
                ``signal`` before it is zeroed. Must be greater than 0.

        Raises:
            AssertionError: If ``consume_limit`` is not greater than 0.
        """
        assert consume_limit > 0
        self._signal = signal
        self._consume_limit = consume_limit

    def watch(self):
        """
        Return the stored signal without spending any budget.

        Returns:
            The current signal value.
        """
        return self._signal

    def consume(self):
        """
        Return the stored signal and spend one unit of the budget.

        Returns:
            The signal value as it was before this call. Once the budget has
            been exhausted the stored value is 0, so later calls return 0.
        """
        observed = self._signal
        remaining = self._consume_limit
        if remaining > 0:
            remaining -= 1
            self._consume_limit = remaining
            if remaining == 0:
                # Budget spent: every later watch/consume sees 0.
                self._signal = 0
        return observed
class ResourceManagerV1(ResourceManager):
"""
Resource manager for scheduler v1.
In scheduler v1, all gpu blocks are managed by PrefixCacheManager.
Tasks sent to the worker are divided into 3 types: PREFILL, DECODE and PREEMPTED.
For a prefill task, the worker runs one inference step and then stops for this query if not all prompt tokens have been computed.
For a decode task, the worker continues decoding until the allocated blocks are exhausted.
For a preempted task, the worker resets all inputs to terminate the inference.
"""
def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id=0):
    """
    Initialise the scheduler-v1 bookkeeping on top of the base resource manager.

    Args:
        max_num_seqs: Forwarded to the base class; also sizes the
            ``need_block_num`` IPC array and sets the max-batch-size metric.
        config: Engine configuration object, kept as ``self.config``.
        tensor_parallel_size: Forwarded unchanged to the base class.
        splitwise_role: Forwarded unchanged to the base class.
        local_data_parallel_id: Forwarded unchanged to the base class.
    """
    super().__init__(max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
    self.config = config

    # --- request bookkeeping -------------------------------------------
    # req_id -> Request
    self.requests: dict[str, Request] = {}
    # Priority queues for requests.
    self.waiting: deque[Request] = deque()
    self.running: list[Request] = []
    self.preallocated_reqs: dict[str, Request] = {}
    self.enable_max_prefill = envs.FD_ENABLE_MAX_PREFILL
    self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
    self.lock = threading.Lock()
    self.to_be_rescheduled_request_id_set = set()
    main_process_metrics.max_batch_size.set(max_num_seqs)

    # --- extend-table and abort tracking -------------------------------
    self.using_extend_tables_req_id = set()
    self.reuse_block_num_map = {}
    self.abort_req_ids_set = set()
    self.waiting_abort_req_id_set = set()
    self.to_be_aborted_req_id_set = set()

    # --- per-slot "need block num" signal ------------------------------
    need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
    self.need_block_num_signal = IPCSignal(
        name="need_block_num_signal",
        array=need_block_num_data,
        dtype=np.int32,
        suffix=self.config.parallel_config.local_engine_worker_queue_port,
        create=True,
    )
    self.need_block_num_map = {}

    # --- multimodal caches (only when multimodal is enabled) -----------
    mm_enabled = config.model_config.enable_mm
    self.encoder_cache = None
    if mm_enabled and config.cache_config.max_encoder_cache > 0:
        self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)
    self.processor_cache = None
    if mm_enabled and config.cache_config.max_processor_cache > 0:
        # The configured value is scaled by 1024**3 to obtain bytes.
        max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
        self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
    self.bos_client = None
    self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)

    # --- output-block reservation used when admitting new prefills -----
    self.init_reserve_output_block_num = envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL  # int
    self.decay_output_block_num = (
        envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
    )  # float
    self.min_reserve_output_block_num = (
        envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
    )  # int
    self.current_reserve_output_block_num = self.init_reserve_output_block_num
    self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
    self.can_relax_prefill_strategy = True

    # Scheduler-side requests that have not been moved into resource manager waiting queue yet.
    self.scheduler_unhandled_request_num = 0
def allocated_slots(self, request: Request):
    """
    Return the number of token slots covered by the request's block table.

    Args:
        request: Request whose ``block_tables`` length is measured.

    Returns:
        ``len(request.block_tables)`` multiplied by the configured block size.
    """
    tokens_per_block = self.config.cache_config.block_size
    block_count = len(request.block_tables)
    return block_count * tokens_per_block
def get_new_block_nums(self, request: Request, num_new_tokens: int):
    """
    Compute how many additional blocks the request needs for ``num_new_tokens``.

    Args:
        request: Request providing ``num_computed_tokens`` and ``block_tables``.
        num_new_tokens: Number of tokens about to be computed.

    Returns:
        Blocks needed to hold computed + new tokens (rounded up to whole
        blocks) minus the blocks already in ``block_tables``. When a
        speculative method is configured, one extra block is added and the
        result is capped at ``max_block_num_per_seq``.
    """
    cache_cfg = self.config.cache_config
    tokens_after_step = request.num_computed_tokens + num_new_tokens
    # Ceiling division by the block size.
    blocks_required = (tokens_after_step + cache_cfg.block_size - 1) // cache_cfg.block_size
    missing_blocks = blocks_required - len(request.block_tables)
    if self.config.speculative_config.method is None:
        return missing_blocks
    return min(missing_blocks + 1, cache_cfg.max_block_num_per_seq)
def _prepare_prefill_task(self, request, new_token_num):
    """
    Mark ``request`` as a prefill task covering the next ``new_token_num`` tokens.

    Sets ``prefill_start_index`` to the request's current
    ``num_computed_tokens``, ``prefill_end_index`` to that value plus
    ``new_token_num``, and ``task_type`` to ``RequestType.PREFILL``.

    Returns:
        The same ``request`` object, mutated in place.
    """
    chunk_start = request.num_computed_tokens
    request.prefill_start_index = chunk_start
    request.prefill_end_index = chunk_start + new_token_num
    request.task_type = RequestType.PREFILL
    return request
def _prepare_decode_task(self, request):
    """Build a ``ScheduledDecodeTask`` carrying the request's idx, id and block table."""
    return ScheduledDecodeTask(
        idx=request.idx,
        request_id=request.request_id,
        block_tables=request.block_tables,
    )
def _prepare_preempt_task(self, request):
    """Build a ``ScheduledPreemptTask`` carrying the request's idx and id."""
    return ScheduledPreemptTask(
        idx=request.idx,
        request_id=request.request_id,
    )
def _prepare_abort_task(self, request):
    """Build a ``ScheduledAbortTask`` carrying the request's idx and id."""
    return ScheduledAbortTask(
        idx=request.idx,
        request_id=request.request_id,
    )
def reschedule_preempt_task(self, request_id, process_func=None):
    """
    Move a preempted request back to the front of the waiting queue.

    Does nothing (beyond a debug log) unless ``request_id`` is both pending
    reschedule and still present in ``self.requests``.

    Args:
        request_id: Identifier of the preempted request.
        process_func: Optional callable invoked with the request before it
            is re-queued.
    """
    with self.lock:
        llm_logger.debug(f"reschedule {request_id} into waiting queue")
        if request_id not in self.to_be_rescheduled_request_id_set or request_id not in self.requests:
            return

        request = self.requests[request_id]
        request.has_been_preempted_before = True
        request.metrics.preempted_count += 1
        if process_func is not None:
            process_func(request)

        llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
        # Preempted requests go to the head of the queue.
        self.waiting.appendleft(request)
        self.to_be_rescheduled_request_id_set.remove(request_id)
def recycle_abort_task(self, request_id):
    """
    Release the slot and bookkeeping entries of an aborted request.

    Acts only when ``request_id`` is in ``to_be_aborted_req_id_set`` and is
    still known in ``self.requests``; otherwise nothing is changed.

    Args:
        request_id: Identifier of the request whose abort has been scheduled.
    """
    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the nesting below (everything inside the ``if``) is the most plausible
    # reading - confirm against the repository.
    with self.lock:
        if request_id in self.to_be_aborted_req_id_set and request_id in self.requests:
            request = self.requests[request_id]
            self.tasks_list[request.idx] = None  # clear the slot
            self.stop_flags[request.idx] = True  # set the stop flag
            del self.requests[request_id]
            del self.req_dict[request_id]
            self.to_be_aborted_req_id_set.remove(request_id)
            self.update_metrics()
def _trigger_abort(self, request_id, scheduled_reqs):
    """
    Start aborting a known request: free its blocks and schedule an abort task.

    The request's status is set to ``RequestStatus.PREEMPTED`` and its
    computed-token and cached-block counters are reset to 0. The id is moved
    from ``waiting_abort_req_id_set`` to ``to_be_aborted_req_id_set``.

    Args:
        request_id: Identifier of the request to abort.
        scheduled_reqs: List that receives the new ``ScheduledAbortTask``.
    """
    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the nesting below (everything inside the ``if``) is the most plausible
    # reading - confirm against the repository.
    if request_id in self.requests:
        abort_request = self.requests[request_id]
        abort_request.status = RequestStatus.PREEMPTED
        abort_request.num_computed_tokens = 0
        self._free_blocks(abort_request)  # release KV cache blocks
        abort_request.cached_block_num = 0
        scheduled_reqs.append(self._prepare_abort_task(abort_request))
        self.to_be_aborted_req_id_set.add(request_id)
        # Raises KeyError if the id was never queued via the waiting-abort set.
        self.waiting_abort_req_id_set.remove(request_id)
def _info_each_block(self):
    """
    Log, at debug level, how many regular and extend blocks each running request holds.
    """
    for req in self.running:
        block_count = len(req.block_tables)
        extend_count = len(req.extend_block_tables)
        llm_logger.debug(
            f"req idx {req.idx} occupy {block_count} block_tables and {extend_count} extend_block_tables"
        )
def _can_preempt(self):
"""
cannot preempt request which use extend block
"""
for req in self.running:
if not req.use_extend_tables:
return True
return False
def preempted_all(self):
    """
    Preempt every running request that does not use extend block tables.

    Each preempted request has its status set to ``RequestStatus.PREEMPTED``,
    its computed-token and cached-block counters reset, its blocks freed and
    its id added to ``to_be_rescheduled_request_id_set``.

    Returns:
        list: One ``ScheduledPreemptTask`` per preempted request.
    """
    with self.lock:
        preempt_tasks = []
        # Visit each request that was running at entry exactly once.
        for _ in range(len(self.running)):
            candidate = self.running.pop()
            # txt2image: req.use_extend_tables is True, req can not be preempted. txt2image is not used in RL.
            if candidate.use_extend_tables:
                # Put it back at the front so later pops do not see it again.
                self.running.insert(0, candidate)
                continue

            candidate.status = RequestStatus.PREEMPTED
            candidate.num_computed_tokens = 0
            self._free_blocks(candidate)
            candidate.cached_block_num = 0
            self.to_be_rescheduled_request_id_set.add(candidate.request_id)
            trace_print(LoggingEventName.PREEMPTED, candidate.request_id, getattr(candidate, "user", ""))
            preempt_tasks.append(self._prepare_preempt_task(candidate))
    return preempt_tasks
def wait_worker_inflight_requests_finish(self, timeout=60):
    """
    Block until no request is running or awaiting reschedule, or until timeout.

    The state is polled roughly every millisecond. The deadline is measured
    on the monotonic clock, so ``timeout`` is honoured in wall-clock seconds
    regardless of how long each sleep really takes.

    Args:
        timeout: Maximum time to wait, in seconds.

    Returns:
        None. A message is logged if the wait timed out.
    """
    poll_interval = 0.001
    deadline = time.monotonic() + timeout
    while True:
        # wait ongoing running and rescheduled requests finished in worker
        rescheduling_count = len(self.to_be_rescheduled_request_id_set)
        running_count = len(self.running)
        if rescheduling_count + running_count == 0:
            return
        if time.monotonic() >= deadline:
            break
        time.sleep(poll_interval)

    llm_logger.info(
        f"wait_inflight_requests_finish timeout after {timeout} seconds, "
        f"still {running_count} requests running and "
        f"{rescheduling_count} requests waiting to be rescheduled"
    )
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
    """
    If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.

    Args:
        request: The request that needs ``num_new_blocks`` more blocks.
        num_new_blocks: Number of GPU blocks that must become allocatable.
        preempted_reqs: List that receives every preempted request.
        scheduled_reqs: List that receives a ``ScheduledPreemptTask`` for
            every preempted request.

    Returns:
        bool: True if ``num_new_blocks`` can be allocated once the loop ends;
        False if ``request`` itself was preempted or no preemptable request
        remained.

    Side effects:
        Always resets the reserved-output-block counters to their initial
        value and sets ``can_relax_prefill_strategy`` to False before
        returning, whether or not anything was preempted.
    """
    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the nesting below is the most plausible reading - confirm against the
    # repository.
    can_schedule = False
    while self._can_preempt():
        if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
            # Victim is the most recently appended running request.
            preempted_req = self.running.pop()
            if preempted_req.use_extend_tables:
                # Not preemptable: rotate it to the front and try the next one.
                self.running.insert(0, preempted_req)
                continue
            preempted_req.status = RequestStatus.PREEMPTED
            preempted_req.num_computed_tokens = 0
            if self.config.scheduler_config.splitwise_role == "decode":
                # In the decode role the request is dropped from this manager's
                # bookkeeping instead of being queued for reschedule.
                self.tasks_list[preempted_req.idx] = None
                self.stop_flags[preempted_req.idx] = True
                if preempted_req.request_id in self.requests:
                    del self.requests[preempted_req.request_id]
                if preempted_req.request_id in self.req_dict:
                    del self.req_dict[preempted_req.request_id]
                self._free_blocks(preempted_req)
                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
            else:
                self._free_blocks(preempted_req)
                # NOTE(review): other methods in this class reset
                # ``cached_block_num``; this one sets ``num_cached_blocks`` -
                # confirm which attribute name is intended.
                preempted_req.num_cached_blocks = 0
                self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                trace_print(
                    LoggingEventName.PREEMPTED, preempted_req.request_id, getattr(preempted_req, "user", "")
                )
                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
            preempted_reqs.append(preempted_req)
            scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
            llm_logger.debug(
                f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
            )
            llm_logger.debug(self.info())
            self._info_each_block()
            if preempted_req == request:
                # No more request to preempt.
                can_schedule = False
                break
        else:
            # The request can be scheduled.
            can_schedule = True
            break
    # Fall back to the initial reservation and the strict prefill strategy.
    self.current_reserve_output_block_num = self.init_reserve_output_block_num
    self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
    self.can_relax_prefill_strategy = False
    return can_schedule
def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block):
if self.can_relax_prefill_strategy:
can_schedule_block_num_threshold = num_chunk_new_block
else:
can_schedule_block_num_threshold = (
num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num
)
if self.config.speculative_config.method is not None:
can_schedule_block_num_threshold = min(
can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
)
return can_schedule_block_num_threshold
def _update_mm_hashes(self, request):
if request.multimodal_inputs is None:
return
inputs = request.multimodal_inputs
if (
inputs.get("images", None) is not None
and inputs.get("image_patch_id", None) is not None
and inputs.get("grid_thw", None) is not None
and len(inputs["grid_thw"]) != 0
):
grid_thw = []
new_mm_positions, new_mm_hashes = [], []
image_st = 0
for idx, one in enumerate(inputs["grid_thw"]):
t, h, w = one[0], one[1], one[2]
if t == 1:
grid_thw.append(one)
new_mm_positions.append(inputs["mm_positions"][idx])
new_mm_hashes.append(inputs["mm_hashes"][idx])
image_st += h * w
else:
grid_thw.extend([[2, h, w]] * (t // 2))
token_st = inputs["mm_positions"][idx].offset
for _ in range(t // 2):
mm_num_token = inputs["mm_num_token_func"](grid_thw=[2, h, w])
new_mm_positions.append(ImagePosition(token_st, mm_num_token))
# videos are split into patches every 2 frames, need to rehash
new_mm_hashes.append(
MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w])
)
image_st += 2 * h * w
token_st += mm_num_token
inputs["mm_positions"] = new_mm_positions
inputs["mm_hashes"] = new_mm_hashes
elif inputs.get("mm_positions", None) is None or inputs.get("mm_hashes", None) is None:
inputs["mm_positions"] = []
inputs["mm_hashes"] = []
def _is_mm_request(self, request):
inputs = request.multimodal_inputs
if inputs is None or len(inputs) == 0:
return False
if (
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
):
return True
elif (
inputs.get("images", None) is not None
and inputs.get("image_patch_id", None) is not None
and inputs.get("grid_thw", None) is not None
):
return True
return False
def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
    """
    Pull a matched-token count back so it does not end inside a multimodal span.

    Positions are examined from last to first. If the count falls strictly
    inside a span, it is moved back to the span's offset rounded down to a
    multiple of the block size, and the previous span is checked next.

    Args:
        mm_inputs: Multimodal inputs dict (may be ``None``); only
            ``mm_positions`` is read, each item exposing ``offset`` and
            ``length``.
        matched_token_num: Number of matched tokens to adjust.

    Returns:
        The adjusted count; unchanged when there are no positions.
    """
    if mm_inputs is None or "mm_positions" not in mm_inputs:
        return matched_token_num
    positions = mm_inputs["mm_positions"]
    if len(positions) == 0:
        return matched_token_num

    for position in reversed(positions):
        if matched_token_num <= 0:
            break
        span_start = position.offset
        span_end = span_start + position.length
        if span_start < matched_token_num < span_end:
            # Cut inside the span: fall back to the block boundary at or before its start.
            block_size = self.config.cache_config.block_size
            matched_token_num = (span_start // block_size) * block_size
        elif matched_token_num <= span_start:
            # Span lies entirely after the cut; look at the previous one.
            continue
        elif matched_token_num >= span_end:
            # Span is fully covered; earlier spans are covered too.
            break
        else:
            llm_logger.error(
                f"revert_chunked_mm_input error, matched_token_num:{matched_token_num} position:{position}, {mm_inputs['mm_positions']}"
            )
            break
    return matched_token_num
def _get_num_new_tokens(self, request, token_budget):
    """
    Decide how many prompt tokens of ``request`` to prefill in this step.

    The starting point is ``min(remaining prefill tokens, token_budget)``.
    When multimodal support is enabled the chunk end may then be moved so
    that it lines up with modality boundaries, and the per-modality ranges
    covered by the chunk are recorded on ``request``.

    Args:
        request: Request being scheduled. ``with_image`` is always reset; in
            the multimodal branches further attributes are written
            (``image_start``/``image_end``, ``video_start``/``video_end``,
            ``audio_start``/``audio_end``, or ``num_image_start``/
            ``num_image_end``, ``image_type_ids_start``/``image_type_ids_end``,
            ``multimodal_img_boundaries``, ``evict_mm_hashes``).
        token_budget: Maximum number of tokens available for this step.

    Returns:
        Number of new tokens to compute. In the multimodal branches this can
        differ from the initial budget-limited value because the end index is
        realigned.
    """
    # NOTE(review): indentation was lost in the reviewed copy of this file;
    # the nesting below is the most plausible reading - confirm against the
    # repository.
    # TODO: set condition to new _get_num_new_tokens
    num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
    num_new_tokens = min(num_new_tokens, token_budget)
    if (
        current_platform.is_intel_hpu()
        and request.need_prefill_tokens - request.num_computed_tokens > token_budget
        and token_budget > self.config.cache_config.block_size
    ):
        # On Intel HPU a budget-limited chunk is rounded down to whole blocks.
        num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
    request.with_image = False

    if not self.config.model_config.enable_mm:
        return num_new_tokens

    # NOTE(review): ``inputs`` is dereferenced without a None check - assumes
    # multimodal_inputs is always a dict when enable_mm is set; confirm.
    inputs = request.multimodal_inputs
    if inputs.get("patch_idx", None) is not None and inputs.get("patch_map", None) is not None:
        # ---- Branch 1: inputs described by patch_idx / patch_map ----
        pre_end_idx = request.num_computed_tokens
        new_end_idx = pre_end_idx + num_new_tokens

        prompt_token_ids_len = len(request.prompt_token_ids)
        if not inputs.get("tts", False):
            # One patch index per prompt token is expected unless "tts" is set.
            assert prompt_token_ids_len == len(inputs["patch_idx"]), (
                prompt_token_ids_len,
                len(inputs["patch_idx"]),
            )

        def _compute_audio_prefix_count(end_idx, end_patch_idx):
            # Sum the lengths of audio patches 0..end_patch_idx, where the
            # last patch only contributes up to ``end_idx``.
            audio_prefix_count = 0
            pre_patch_end_idx = 0
            for patch_idx in range(end_patch_idx + 1):
                patch_map = inputs["patch_map"][patch_idx]
                modal_id = patch_map["modal_id"]
                if modal_id == IDS_TYPE_FLAG["audio"]:
                    if patch_idx != end_patch_idx:
                        audio_prefix_count += patch_map["end_idx"] - pre_patch_end_idx
                    else:
                        audio_prefix_count += end_idx - pre_patch_end_idx
                pre_patch_end_idx = patch_map["end_idx"]
            return audio_prefix_count

        # start
        if pre_end_idx >= prompt_token_ids_len:
            start_patch_idx = inputs["patch_idx"][-1]
        else:
            start_patch_idx = inputs["patch_idx"][pre_end_idx]
            if (
                pre_end_idx > 0
                and request.prompt_token_ids[pre_end_idx]
                in [
                    inputs["image_patch_id"],
                    inputs["video_patch_id"],
                    inputs["audio_patch_id"],
                ]
                and request.prompt_token_ids[pre_end_idx] != request.prompt_token_ids[pre_end_idx - 1]
            ):
                # It just hit the starting position of the image / video / audio
                start_patch_idx -= 1
        start_patch_map = inputs["patch_map"][start_patch_idx]
        request.image_start = start_patch_map["image_num"]
        request.video_start = start_patch_map["video_num"]
        request.audio_start = _compute_audio_prefix_count(pre_end_idx, start_patch_idx)

        # end
        if new_end_idx >= prompt_token_ids_len:
            end_patch_idx = inputs["patch_idx"][-1]
        else:
            end_patch_idx = inputs["patch_idx"][new_end_idx]
            if request.prompt_token_ids[new_end_idx] in [
                inputs["image_end_id"],
                inputs["video_end_id"],
                inputs["audio_end_id"],
            ]:
                end_patch_idx -= 1
        end_patch_map = inputs["patch_map"][end_patch_idx]
        end_modal_id = end_patch_map["modal_id"]
        if end_modal_id == IDS_TYPE_FLAG["image"]:
            new_end_idx = end_patch_map["end_idx"]  # end position of the current modality

        if end_modal_id == IDS_TYPE_FLAG["video"] and "can_split_idx_list" in inputs:
            # Move the end to the first allowed split index at or after it.
            can_split_idx_list = inputs["can_split_idx_list"]
            for i in range(len(can_split_idx_list)):
                if can_split_idx_list[i] >= new_end_idx:
                    new_end_idx = can_split_idx_list[i]
                    break
        num_new_tokens = new_end_idx - pre_end_idx

        request.image_end = end_patch_map["image_num"]
        request.video_end = end_patch_map["video_num"]
        request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx)
    elif (
        inputs.get("images", None) is not None
        and inputs.get("image_patch_id", None) is not None
        and inputs.get("grid_thw", None) is not None
    ):
        # ---- Branch 2: inputs described by images / grid_thw ----
        input_ids_lst = request.prompt_token_ids + request.output_token_ids
        input_ids = paddle.to_tensor(input_ids_lst, dtype="int64")
        image_patch_id = inputs["image_patch_id"]

        if request.multimodal_img_boundaries is None:
            # First visit: expand grid_thw (entries with t != 1 become t // 2
            # rows of [2, h, w]) and compute the boundaries once.
            grid_thw = []
            for idx, one in enumerate(inputs["grid_thw"]):
                t, h, w = one[0], one[1], one[2]
                if t == 1:
                    grid_thw.append(one)
                else:
                    grid_thw.extend([[2, h, w]] * (t // 2))

            # The boundary op is imported from the platform-specific package.
            if current_platform.is_xpu():
                from fastdeploy.model_executor.ops.xpu import get_img_boundaries
            elif current_platform.is_iluvatar():
                from fastdeploy.model_executor.ops.iluvatar import (
                    get_img_boundaries,
                )
            else:
                from fastdeploy.model_executor.ops.gpu import get_img_boundaries

            mm_num_token = inputs["mm_num_token_func"](grid_thw=grid_thw)
            mm_num_token = paddle.to_tensor(mm_num_token, dtype="int64")
            request.multimodal_img_boundaries = get_img_boundaries(
                task_input_ids=input_ids, mm_num_token=mm_num_token, image_patch_id=image_patch_id
            ).numpy()

            # The expanded grid replaces the original one in ``inputs``.
            grid_thw = np.array(grid_thw).reshape([-1, 3])
            inputs["grid_thw"] = grid_thw

        grid_thw = inputs["grid_thw"]
        # presumably row 0 holds boundary token indices and row 1 the number of
        # images seen up to each boundary - TODO confirm against the op.
        img_boundaries_idx = request.multimodal_img_boundaries[0]
        img_num_per_boundary = request.multimodal_img_boundaries[1]
        ori_prompt_len = img_boundaries_idx[-1].item()
        pre_end_idx = request.num_computed_tokens
        new_end_idx = pre_end_idx + num_new_tokens
        if new_end_idx < ori_prompt_len and input_ids[new_end_idx - 1] == image_patch_id:
            # Chunk ends on an image-patch token: extend it to the first
            # boundary at or after the end.
            boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item()
            if boundary_idx == len(img_boundaries_idx):
                new_end_idx = ori_prompt_len
            else:
                new_end_idx = img_boundaries_idx[boundary_idx].item()
        elif new_end_idx >= ori_prompt_len and paddle.sum(input_ids[pre_end_idx:new_end_idx] == image_patch_id):
            # Chunk reaches the last boundary and contains image-patch tokens.
            new_end_idx = ori_prompt_len
        num_new_tokens = new_end_idx - pre_end_idx

        image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id
        request.with_image = image_mask.any()
        if request.with_image:
            # Map chunk start and end to boundary indices; when a position is
            # not exactly on a boundary, the preceding boundary is used.
            pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item()
            if pre_boundary_idx == len(img_boundaries_idx):
                request.num_image_start = img_num_per_boundary[-1]
            else:
                pre_boundary_idx = (
                    pre_boundary_idx
                    if pre_end_idx == img_boundaries_idx[pre_boundary_idx]
                    else pre_boundary_idx - 1
                )
                request.num_image_start = img_num_per_boundary[pre_boundary_idx]

            new_boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item()
            if new_boundary_idx == len(img_boundaries_idx):
                request.num_image_end = img_num_per_boundary[-1]
            else:
                new_boundary_idx = (
                    new_boundary_idx
                    if new_end_idx == img_boundaries_idx[new_boundary_idx]
                    else new_boundary_idx - 1
                )
                request.num_image_end = img_num_per_boundary[new_boundary_idx]

            # Prefix sums over grid_thw rows: column 0 for the type-id range,
            # the product of each row for the image range.
            request.image_type_ids_start = np.sum(grid_thw[: request.num_image_start, 0])
            request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0])
            request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1))
            request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1))

            if self.encoder_cache:
                cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
                cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end]
                request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions)

    # Compatible with scenarios without images and videos.
    return num_new_tokens
def exist_mm_prefill(self, scheduled_reqs):
    """
    Tell whether any scheduled item is a prefill task for a multimodal request.

    Args:
        scheduled_reqs: Iterable of scheduled items exposing ``task_type``.

    Returns:
        bool: True if at least one prefill item satisfies ``_is_mm_request``.
    """
    return any(
        item.task_type == RequestType.PREFILL and self._is_mm_request(item)
        for item in scheduled_reqs
    )
def exist_prefill(self, scheduled_reqs):
    """
    Tell whether any scheduled item is a prefill task.

    Args:
        scheduled_reqs: Iterable of scheduled items exposing ``task_type``.

    Returns:
        bool: True if at least one item has ``RequestType.PREFILL``.
    """
    return any(item.task_type == RequestType.PREFILL for item in scheduled_reqs)
def add_abort_req_ids(self, req_ids):
    """
    Queue one or more request ids for abortion.

    Args:
        req_ids: Either a single request id, or a collection of ids given as
            a list, tuple, set or frozenset. Any other value (for example a
            plain string) is treated as one id.
    """
    with self.lock:
        # Tuples and sets are expanded like lists; previously a tuple was
        # stored as a single element and a set raised TypeError.
        if isinstance(req_ids, (list, tuple, set, frozenset)):
            self.waiting_abort_req_id_set.update(req_ids)
        else:
            self.waiting_abort_req_id_set.add(req_ids)
def cache_output_tokens(self, request):
    """
    Ask the cache manager to cache the output blocks of a decoding request.

    Nothing happens unless both prefix caching and output caching are enabled
    in the cache config, and the request has finished prefill
    (``num_computed_tokens >= need_prefill_tokens``).

    Args:
        request: Request whose output blocks may be cached.
    """
    cache_cfg = self.config.cache_config
    if not (cache_cfg.enable_prefix_caching and cache_cfg.enable_output_caching):
        return

    with self.lock:
        is_decoding = request.num_computed_tokens >= request.need_prefill_tokens
        if is_decoding:
            self.cache_manager.cache_output_blocks(request, self.config.cache_config.block_size)
def schedule(self):
"""
Try to pull a batch of requests from the waiting queue and schedule them.
"""
def get_enough_request(request, scheduled_reqs):
return (
ErnieArchitectures.is_ernie5_arch(self.config.model_config.architectures)
and self._is_mm_request(request)
and self.exist_mm_prefill(scheduled_reqs)
)
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
token_budget = self.config.scheduler_config.max_num_batched_tokens
need_abort_requests = [] # users trigger abortion
# First, schedule the RUNNING requests.
req_index = 0
num_decoding_req_nums = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
need_block_num = self.need_block_num_signal.value[request.idx]
if need_block_num != 0:
self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1)
self.need_block_num_signal.value[request.idx] = 0
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
if (
self.config.scheduler_config.splitwise_role == "prefill"
): # do not need to schedule for decoding
req_index += 1
continue
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
request.num_computed_tokens = request.num_total_tokens - 1
if request.request_id in self.waiting_abort_req_id_set:
self._trigger_abort(request.request_id, scheduled_reqs)
req_index += 1
need_abort_requests.append(request)
continue
if (
self.allocated_slots(request) - request.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
):
# Allocation for next decoding blocks
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
llm_logger.debug(
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
)
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
)
# Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request))
else:
# Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(
request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs
)
if not can_schedule:
break
# Allocation for next decoding blocks
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
)
# Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request))
num_decoding_req_nums += 1
token_budget -= 1
if (
request.use_extend_tables
and request.request_id not in self.using_extend_tables_req_id
and self.need_block_num_map[request.request_id].watch() > 0
):
def _allocate_decode_and_extend():
    """Schedule a decode step for ``request`` and set up its extend tables.

    Uses ``self``, ``request`` and ``scheduled_reqs`` from the enclosing
    scope. The pending block count stored for the request in
    ``self.need_block_num_map`` is consumed once, and that many GPU
    blocks are then allocated twice: first for ``request.block_tables``
    and then for ``request.extend_block_tables``. The surrounding code
    checks that ``2 * watch()`` blocks can be allocated (preempting if
    needed) before it calls this helper.
    """
    req_id = request.request_id
    cache_manager = self.cache_manager
    allocate_block_num = self.need_block_num_map[req_id].consume()

    # Decode side: grow the regular block table, then queue the decode task.
    decode_blocks = cache_manager.allocate_gpu_blocks(allocate_block_num, req_id)
    request.block_tables.extend(decode_blocks)
    decode_task = self._prepare_decode_task(request)
    scheduled_reqs.append(decode_task)

    # Extend side: number of whole blocks covered by the tokens so far.
    reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size
    llm_logger.info(
        f"req {request.request_id} at batch id {request.idx} with reuse_block_num {reuse_block_num} is going to enable extend tables,"
        f"need_block_num {allocate_block_num}"
    )

    # Record that this request now uses extend tables and how many of its
    # blocks are shared with the regular table.
    self.using_extend_tables_req_id.add(req_id)
    self.reuse_block_num_map[req_id] = reuse_block_num

    # Start the extend table from the leading blocks of the regular table,
    # then append a second, separately allocated batch of blocks.
    request.extend_block_tables = request.block_tables[:reuse_block_num]
    extend_blocks = cache_manager.allocate_gpu_blocks(allocate_block_num, req_id)
    request.extend_block_tables.extend(extend_blocks)

    extend_task = ScheduledExtendBlocksTask(
        idx=request.idx,
        request_id=req_id,
        extend_block_tables=request.extend_block_tables,
    )
    scheduled_reqs.append(extend_task)
    llm_logger.debug(f"extend blocks is {request.extend_block_tables}")
if self.cache_manager.can_allocate_gpu_blocks(
2 * self.need_block_num_map[request.request_id].watch()
):
_allocate_decode_and_extend()
else:
llm_logger.info(
f"{request.idx} using extend block need {2 * self.need_block_num_map[request.request_id].watch()} blocks but got not enough blocks, ready to preempt"
)
can_schedule = self._trigger_preempt(
request,
2 * self.need_block_num_map[request.request_id].watch(),
preempted_reqs,
scheduled_reqs,
)
if can_schedule:
_allocate_decode_and_extend()
else:
break
else: # need to prefill
llm_logger.debug(
f"scheduler prefill task in running queue: {request.request_id}, "
f"request.need_prefill_tokens {request.need_prefill_tokens},"
f"request.num_computed_tokens {request.num_computed_tokens}"
)
if (
current_platform.is_intel_hpu()
and request.need_prefill_tokens - request.num_computed_tokens
>= self.config.cache_config.block_size
and token_budget < self.config.cache_config.block_size
):
req_index += 1
continue
if get_enough_request(request, scheduled_reqs):
req_index += 1
continue
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule:
break
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if self.config.cache_config.enable_prefix_caching:
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
req_index += 1
# remove requests to be aborted from running list
for request in need_abort_requests:
self.running.remove(request)
# Second, schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_seqs:
break
request = self.waiting[0]
if get_enough_request(request, scheduled_reqs):
break
if request.status == RequestStatus.WAITING:
result = self.waiting_async_process(request)
if result is None:
error_reqs.append((request.request_id, request.error_message))
self.waiting.popleft()
continue
elif result is True:
# skip current request, try next request
skip_requests.append(request)
self.waiting.popleft()
continue
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
break
if (
current_platform.is_intel_hpu()
and request.need_prefill_tokens - request.num_computed_tokens
>= self.config.cache_config.block_size
and token_budget < self.config.cache_config.block_size
):
continue
# Allocate blocks for the tokens that does not hit cache
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
request, num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if self.config.cache_config.enable_prefix_caching:
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[allocated_position] = request
self.stop_flags[allocated_position] = False
self.req_dict[request.request_id] = allocated_position
llm_logger.debug(f"req_id:{request.request_id} allocate pos end")
else:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
break
elif request.status == RequestStatus.PREEMPTED:
request.need_prefill_tokens = (
request.num_total_tokens
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
if self.config.cache_config.enable_prefix_caching:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
break
# Allocate blocks for the tokens that does not hit cache
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
request, num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if self.config.cache_config.enable_prefix_caching:
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens