-
Notifications
You must be signed in to change notification settings - Fork 743
Expand file tree
/
Copy pathsetup_ops.py
More file actions
830 lines (767 loc) · 33.1 KB
/
setup_ops.py
File metadata and controls
830 lines (767 loc) · 33.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""setup for FastDeploy custom ops"""
import importlib
import json
import os
import shutil
import subprocess
import sys
import tarfile
from pathlib import Path
import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
from setuptools import find_namespace_packages, find_packages
def load_module_from_path(module_name, path):
    """
    Load a Python source file as a module and register it in ``sys.modules``.

    :param module_name: Name under which the module is registered in ``sys.modules``.
    :param path: Filesystem path of the ``.py`` file to execute.
    :return: The executed module object.
    :raises ImportError: If no import spec/loader can be created for ``path``
        (e.g. the file does not have a recognised Python source suffix).

    Any exception raised while executing the module propagates unchanged; the
    half-initialised module is removed from ``sys.modules`` first.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for module {module_name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the standard importlib recipe.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module
def update_git_repo():
    """
    Ensure the third-party git submodules are checked out and the DeepGEMM patch is applied.

    * The submodules are synced and updated in this script's directory when
      ``third_party`` is missing or any of its sub-directories is empty. Otherwise a
      warning is printed and the update is skipped.
    * ``0001-DeepGEMM-95e81b3.patch`` is copied into ``third_party/DeepGEMM`` and applied
      with ``git apply``, unless the copied patch already exists there. The copy acts as
      the "already applied" marker, so it is removed again if ``git apply`` fails.

    Git commands run with an explicit ``cwd``. The caller's working directory is never
    changed, even on failure.

    :raises Exception: If a git command fails. The ``CalledProcessError`` is chained
        as the cause.
    """
    try:
        print("update third party repo...", flush=True)
        submodule_dir = os.path.dirname(os.path.abspath(__file__))
        root_path = Path(os.path.join(submodule_dir, "third_party"))

        # A missing third_party directory, or an empty sub-directory inside it,
        # means the submodules have not been checked out yet.
        update_third_party = not root_path.is_dir() or any(
            entry.is_dir() and not any(entry.iterdir()) for entry in root_path.iterdir()
        )
        if update_third_party:
            subprocess.run(
                "git submodule sync --recursive && git submodule update --init --recursive",
                shell=True,
                check=True,
                text=True,
                cwd=submodule_dir,
            )
        else:
            print(
                "\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
                flush=True,
            )

        # apply deep gemm patch
        dst_path = os.path.join(submodule_dir, "third_party/DeepGEMM")
        patch = "0001-DeepGEMM-95e81b3.patch"
        patch_source = os.path.join(submodule_dir, patch)
        patch_destination = os.path.join(dst_path, patch)
        if not os.path.exists(patch_destination):
            shutil.copy(patch_source, patch_destination)
            try:
                subprocess.run(["git", "apply", patch], cwd=dst_path, check=True)
            except subprocess.CalledProcessError:
                # The copied patch is the "already applied" marker. Drop it so the
                # next build retries the patch instead of silently skipping it.
                os.remove(patch_destination)
                raise
    except subprocess.CalledProcessError as e:
        raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.") from e
# Root of the FastDeploy repository (two levels above this script).
ROOT_DIR = Path(__file__).parent.parent

# `fastdeploy.envs` cannot be imported the usual way: it depends on the `fastdeploy`
# package, which is not installed yet while the custom ops are being built.
# Load the file directly instead.
_envs_path = os.path.join(ROOT_DIR, "fastdeploy", "envs.py")
envs = load_module_from_path("envs", _envs_path)

# Build configuration taken from the environment module.
archs = json.loads(envs.FD_BUILDING_ARCS)  # explicit compute capabilities; may be empty
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

# Check out/patch third-party submodules before any sources are collected below.
update_git_repo()
def download_and_extract(url, destination_directory):
    """
    Download a .tar.gz file using wget to the destination directory
    and extract its contents without renaming the downloaded file.

    This is best effort: download and extraction errors are printed and not
    re-raised. The archive is deleted only after a successful extraction.

    :param url: The URL of the .tar.gz file to download.
    :param destination_directory: The directory where the file should be downloaded and extracted.
    """
    os.makedirs(destination_directory, exist_ok=True)
    filename = os.path.basename(url)
    file_path = os.path.join(destination_directory, filename)
    # The archive comes from the network. Use the "data" extraction filter so that
    # members with absolute paths, "..", links pointing outside the destination, or
    # device files are refused. The filter argument only exists on Python versions
    # that ship tarfile extraction filters (tarfile.data_filter).
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    try:
        subprocess.run(
            ["wget", "-O", file_path, url],
            check=True,
        )
        print(f"Downloaded: {file_path}")
        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall(path=destination_directory, **extract_kwargs)
        print(f"Extracted: {file_path} to {destination_directory}")
        os.remove(file_path)
        print(f"Deleted downloaded file: {file_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error downloading file: {e}")
    except Exception as e:
        print(f"Error extracting file: {e}")
def get_sm_version(archs):
    """
    Return the CUDA compute capabilities to build for, sorted and without duplicates.

    :param archs: Explicitly requested compute capabilities (e.g. ``[80, 90]``).
    :return: ``sorted(set(archs))`` when ``archs`` is non-empty. Otherwise, the
        capability of the current Paddle CUDA device (``major * 10 + minor``). If
        querying the device raises ``ValueError``, an empty list is returned.
    """
    arch_set = set(archs)
    if not arch_set:
        try:
            prop = paddle.device.cuda.get_device_properties()
        except ValueError:
            # No device could be queried; the caller receives an empty list.
            pass
        else:
            arch_set.add(prop.major * 10 + prop.minor)
    # Sorted so the generated -gencode flags (and the nvcc command line) are reproducible.
    return sorted(arch_set)
def get_nvcc_version():
    """
    Return the CUDA release reported by ``nvcc --version`` as a float (e.g. ``12.4``).

    The value is the whitespace-separated token that follows ``release`` in the output,
    truncated at the first comma.

    NOTE(review): a float cannot distinguish versions such as ``12.1`` and ``12.10``.
    Callers compare against float literals, so the return type is kept.

    :raises subprocess.CalledProcessError: If ``nvcc`` exits with a non-zero status.
    :raises ValueError: If the output contains no ``release`` token or the version is not numeric.
    """
    tokens = subprocess.check_output(["nvcc", "--version"], universal_newlines=True).split()
    release_token = tokens[tokens.index("release") + 1]
    major_minor = release_token.split(",")[0]
    return float(major_minor)
def get_gencode_flags(archs):
    """
    Build the nvcc ``-gencode`` arguments for every target compute capability.

    Targets come from ``get_sm_version(archs)``. CC 9.0 and 10.0 are compiled for the
    architecture-specific ``90a`` / ``100a`` targets. Every other capability ``X`` maps
    to ``arch=compute_X,code=sm_X``.

    :return: Flat list such as ``["-gencode", "arch=compute_80,code=sm_80", ...]``.
    """
    # Per the NVIDIA dev blog, CUTLASS and architecture-specific features on CC 9.0 and
    # 10.0 require the "a" targets:
    # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
    arch_specific_codes = {90: "90a", 100: "100a"}
    gencode_flags = []
    for capability in get_sm_version(archs):
        code = arch_specific_codes.get(capability, capability)
        gencode_flags.extend(("-gencode", f"arch=compute_{code},code=sm_{code}"))
    return gencode_flags
def get_compile_parallelism():
    """
    Choose how many compile jobs run in parallel and how many threads each nvcc may use.

    ``MAX_JOBS`` from the environment is honoured when set, and must be a positive
    integer. When it is unset, the CPU count capped at 32 is used and written back to
    ``MAX_JOBS`` (presumably so the build backend sees the same value).

    :return: ``(max_jobs, nvcc_threads)`` with ``nvcc_threads == min(max_jobs, 4)``.
    :raises ValueError: If ``MAX_JOBS`` is set but is not a positive integer.
    """
    requested = os.getenv("MAX_JOBS")
    if requested is None:
        # Cap default build workers to avoid OOM in high-core CI runners.
        max_jobs = min(os.cpu_count() or 1, 32)
        os.environ["MAX_JOBS"] = str(max_jobs)
    else:
        try:
            max_jobs = int(requested)
            if max_jobs <= 0:
                raise ValueError
        except ValueError as exc:
            raise ValueError(f"Invalid MAX_JOBS={requested!r}, expected a positive integer.") from exc
    # Paddle's ThreadPoolExecutor already runs many compilations in parallel, and the
    # total thread count is roughly jobs * nvcc_threads, so keep nvcc's share small.
    return max_jobs, min(max_jobs, 4)
def find_end_files(directory, end_str):
    """
    Recursively collect files under ``directory`` whose names end with ``end_str``.

    The tree is walked in sorted order, so the returned list is reproducible across
    filesystems. That list determines the order in which sources are handed to the
    build.

    :param directory: Root directory to search. A missing directory yields ``[]``.
    :param end_str: Suffix to match, or a tuple of suffixes (anything ``str.endswith`` accepts).
    :return: Matching paths, each built as ``os.path.join(root, name)``.
    """
    gen_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting `dirs` in place makes os.walk descend in a deterministic order.
        dirs.sort()
        gen_files.extend(os.path.join(root, name) for name in sorted(files) if name.endswith(end_str))
    return gen_files
if paddle.is_compiled_with_rocm():
    # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() also returns True when paddle
    # is compiled with ROCm, so the ROCm check has to come first.
    sources = [
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/set_value_by_flags_and_idx.cu",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/unset_data_ipc.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_output.cc",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
        "gpu_ops/speculate_decoding/speculate_save_output.cc",
        "gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu",
        "gpu_ops/speculate_decoding/speculate_step.cu",
        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/step_reschedule.cu",
    ]
    # The -U flags undefine the __HIP_NO_* macros (presumably set by the toolchain)
    # so the half/bfloat16 operators and conversions are available to the kernels.
    hipcc_flags = [
        "-O3",
        "--gpu-max-threads-per-block=1024",
        "-U__HIP_NO_HALF_OPERATORS__",
        "-U__HIP_NO_HALF_CONVERSIONS__",
        "-U__HIP_NO_BFLOAT16_OPERATORS__",
        "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
        "-U__HIP_NO_BFLOAT162_OPERATORS__",
        "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
        "-DPADDLE_DEV",
        "-Ithird_party/nlohmann_json/include",
        "-Igpu_ops",
    ]
    rocm_extension = CUDAExtension(
        sources=sources,
        extra_compile_args={"cxx": ["-O3"], "hipcc": hipcc_flags},
    )
    setup(name="fastdeploy_ops", ext_modules=rocm_extension)
elif paddle.is_compiled_with_cuda():
sources = [
"gpu_ops/helper.cu",
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags_and_idx.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/get_output_ep.cc",
"gpu_ops/get_mm_split_fuse.cc",
"gpu_ops/get_img_boundaries.cc",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/fused_get_rotary_embedding.cu",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/enforce_generation.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/tune_cublaslt_gemm.cu",
"gpu_ops/swap_cache_batch.cu",
"gpu_ops/swap_cache.cu",
"gpu_ops/swap_cache_layout.cu",
"gpu_ops/swap_cache_optimized.cu", # 新增:优化的 KV cache 换入算子
"gpu_ops/step_system_cache.cu",
"gpu_ops/cpp_extensions.cc",
"gpu_ops/share_external_data.cu",
"gpu_ops/fused_mask_swiglu_fp8_quant_kernel.cu",
"gpu_ops/per_token_quant_fp8.cu",
"gpu_ops/update_split_fuse_input.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/merge_prefill_decode_output.cu",
"gpu_ops/limit_thinking_content_length.cu",
"gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/gelu_tanh.cu",
"gpu_ops/reasoning_phase_token_constraint.cu",
"gpu_ops/get_attn_mask_q.cu",
]
sm_versions = get_sm_version(archs)
# Some kernels in this file require SM75+ instructions. Exclude them when building SM70 (V100).
disable_gelu_tanh = 70 in sm_versions
if disable_gelu_tanh:
sources = [s for s in sources if s != "gpu_ops/gelu_tanh.cu"]
# pd_disaggregation
sources += [
"gpu_ops/remote_cache_kv_ipc.cc",
"gpu_ops/open_shm_and_get_meta_signal.cc",
"gpu_ops/init_signal_layerwise.cc",
"gpu_ops/get_data_ptr_ipc.cu",
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
]
dg_third_party_include_dirs = (
"third_party/cutlass/include/cute",
"third_party/cutlass/include/cutlass",
)
dg_include_dir = "third_party/DeepGEMM/deep_gemm/include"
os.makedirs(dg_include_dir, exist_ok=True)
for d in dg_third_party_include_dirs:
dirname = d.split("/")[-1]
src_dir = d
dst_dir = os.path.join(dg_include_dir, dirname)
# Remove existing directory if it exists
if os.path.exists(dst_dir):
if os.path.islink(dst_dir):
os.unlink(dst_dir)
else:
shutil.rmtree(dst_dir)
print(f"Copying {src_dir} to {dst_dir}")
# Copy the directory
try:
shutil.copytree(src_dir, dst_dir)
except Exception as e:
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
cc_compile_args = []
nvcc_compile_args = get_gencode_flags(archs)
if disable_gelu_tanh:
cc_compile_args += ["-DDISABLE_GELU_TANH_OP"]
nvcc_compile_args += ["-DDISABLE_GELU_TANH_OP"]
nvcc_compile_args += ["-DPADDLE_DEV"]
nvcc_compile_args += ["-DPADDLE_ON_INFERENCE"]
nvcc_compile_args += ["-DPy_LIMITED_API=0x03090000"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels",
"-Ithird_party/cutlass/include",
"-Ithird_party/cutlass/tools/util/include",
"-Igpu_ops/fp8_gemm_with_cutlass",
"-Igpu_ops",
"-Ithird_party/nlohmann_json/include",
]
max_jobs, nvcc_threads = get_compile_parallelism()
print(f"MAX_JOBS = {max_jobs}, nvcc -t = {nvcc_threads}")
nvcc_compile_args += ["-t", str(nvcc_threads)]
nvcc_version = get_nvcc_version()
print(f"nvcc_version = {nvcc_version}")
# CUDA 13.0+ (CCCL 3.0) changes the default -static-global-template-stub behavior
# Restore old linking behavior to allow kernel symbols to be visible in shared libraries
if nvcc_version >= 13.0:
nvcc_compile_args += ["-static-global-template-stub=false"]
if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(sm_versions)
print(f"cc = {cc}")
fp8_auto_gen_directory = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
if os.path.isdir(fp8_auto_gen_directory):
shutil.rmtree(fp8_auto_gen_directory)
if cc >= 75:
cc_compile_args += ["-DENABLE_SM75_EXT_OPS"]
nvcc_compile_args += [
"-DENABLE_SM75_EXT_OPS",
"-DENABLE_SCALED_MM_C2X=1",
"-Igpu_ops/cutlass_kernels/w8a8",
]
sources += [
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu",
"gpu_ops/quantization/common.cu",
# cpp_extensions.cc always registers these two ops; include their kernels on SM75 as well.
"gpu_ops/moe/moe_deepgemm_permute.cu",
"gpu_ops/moe/moe_deepgemm_depermute.cu",
]
if cc >= 80:
cc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
nvcc_compile_args += ["-DENABLE_SM80_EXT_OPS"]
# append_attention
os.system(
"python utils/auto_gen_template_instantiation.py --config gpu_ops/append_attn/template_config.json --output gpu_ops/append_attn/template_instantiation/autogen"
)
sources += ["gpu_ops/append_attention.cu"]
sources += find_end_files("gpu_ops/append_attn", ".cu")
# sparse indexer
sources += find_end_files("gpu_ops/sparse_indexer", ".cu")
# mla
sources += ["gpu_ops/multi_head_latent_attention.cu"]
# gemm_dequant
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
# speculate_decoding
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
nvcc_compile_args += ["-DENABLE_BF16"]
# moe
os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
os.system(
"python utils/auto_gen_template_instantiation.py --config gpu_ops/moe/template_config.json --output gpu_ops/moe/template_instantiation/autogen"
)
sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
sources += find_end_files("gpu_ops/moe/", ".cu")
nvcc_compile_args += ["-Igpu_ops/moe"]
if cc >= 89:
# Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
if cc >= 90: # Hopper and newer
# SM90 (Hopper) specific auto-generation and flags
if cc == 90: # Only for SM90
nvcc_compile_args += [
# The gencode for 90a is added in get_gencode_flags now
# "-gencode",
# "arch=compute_90a,code=compute_90a",
"-O3",
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
]
print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
]
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
print("SM100 (Blackwell): Applying SM100 configurations.")
nvcc_compile_args += [
# The gencode for 100a is added in get_gencode_flags
# "-gencode",
# "arch=compute_100a,code=compute_100a",
"-O3", # Common optimization flag
"-DNDEBUG", # Common debug flag
# Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
]
# Placeholder for SM100-specific kernel auto-generation scripts
# These might be needed if Blackwell has new FP8 hardware features
# not covered by existing generic CUTLASS templates or SM90 scripts.
# print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).")
# os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example
# os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example
# Add SM100 specific sources if any, e.g., for new hardware intrinsics
# sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
pass # No SM100 specific sources identified yet beyond what CUTLASS handles
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
# Common FP8 sources for SM89+
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/per_channel_fp8_fp8_half_gemm.cu",
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu",
"gpu_ops/scaled_gemm_f8_i4_f16_gemm.cu",
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu",
]
sources += find_end_files(fp8_auto_gen_directory, ".cu")
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optimized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
cc_compile_args += ["-DENABLE_FLASH_MASK_ATTENTION"]
sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu")
sources += find_end_files("gpu_ops/moba_attn/moba_encoder_attn/", ".cu")
sources += find_end_files("gpu_ops/moba_attn/moba_process/", ".cu")
sources += ["gpu_ops/moba_attn/moba_attn.cu"]
os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")
sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")
os.system("python gpu_ops/machete/generate.py")
sources += find_end_files("gpu_ops/machete", ".cu")
cc_compile_args += ["-DENABLE_MACHETE"]
# Deduplicate translation units while preserving order. Some files are
# appended explicitly for SM75 and also discovered by later directory globs.
sources = list(dict.fromkeys(sources))
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={"cxx": cc_compile_args, "nvcc": nvcc_compile_args},
libraries=["cublasLt"],
extra_link_args=["-lcuda", "-lnvidia-ml"],
),
packages=find_packages(where="third_party/DeepGEMM"),
package_dir={"": "third_party/DeepGEMM"},
package_data={
"deep_gemm": [
"include/deep_gemm/**/*",
"include/cute/**/*",
"include/cutlass/**/*",
]
},
include_package_data=True,
)
elif paddle.is_compiled_with_xpu():
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
_iluvatar_clang_cuda_flags = ["-Wno-non-pod-varargs", "-DPADDLE_DEV", "-DPADDLE_WITH_CUSTOM_DEVICE"]
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
extra_compile_args={
"cxx": _iluvatar_clang_cuda_flags,
"nvcc": _iluvatar_clang_cuda_flags,
},
sources=[
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/set_value_by_flags_and_idx.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/step.cu",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/limit_thinking_content_length.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/get_img_boundaries.cc",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/get_output_ep.cc",
"iluvatar_ops/moe_dispatch.cu",
"iluvatar_ops/moe_reduce.cu",
"iluvatar_ops/flash_attn_unpadded.cu",
"iluvatar_ops/paged_attn.cu",
"iluvatar_ops/prefill_fused_attn.cu",
"iluvatar_ops/mixed_fused_attn.cu",
"iluvatar_ops/w8a16_group_gemm.cu",
"iluvatar_ops/w8a16_group_gemv.cu",
"iluvatar_ops/wi4a16_group_gemm.cu",
"iluvatar_ops/wi4a16_weight_quantize.cu",
"iluvatar_ops/restore_tokens_per_expert.cu",
"iluvatar_ops/runtime/iluvatar_context.cc",
"iluvatar_ops/cpp_extensions.cc",
],
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
extra_link_args=[
"-lcuinfer",
],
),
)
elif paddle.is_compiled_with_custom_device("gcu"):
setup(
name="fastdeploy_ops",
ext_modules=CppExtension(
sources=[
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
]
),
)
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
maca_path = os.getenv("MACA_PATH", "/opt/maca")
sources = [
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"gpu_ops/save_with_output.cc",
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags_and_idx.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/get_output_ep.cc",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/fused_get_rotary_embedding.cu",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/limit_thinking_content_length.cu",
"gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/get_img_boundaries.cc",
"gpu_ops/remote_cache_kv_ipc.cc",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_data_ptr_ipc.cu",
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/swap_cache_batch.cu",
"gpu_ops/gelu_tanh.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
"metax_ops/fused_moe.cu",
"metax_ops/cache_kv_with_rope.cu",
"metax_ops/cpp_extensions.cc",
"metax_ops/split_merge_qkv.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
metax_extra_compile_args = {
"cxx": ["-O3"],
"nvcc": [
"-O3",
"-Ithird_party/nlohmann_json/include",
"-Igpu_ops",
"-DPADDLE_DEV",
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
],
}
def get_maca_version(version_file: str = "/opt/maca/Version.txt") -> list[int]:
try:
with open(version_file, "r", encoding="utf-8") as f:
version_str = f.readline().strip()
target_version = [int(part) for part in version_str.split(":")[1].split(".")]
except Exception as e:
print(f"Trigger exception: {type(e).__name__} - {e}")
raise
return target_version
maca_version = get_maca_version(f"{maca_path}/Version.txt")
if len(maca_version) == 4:
major_version = maca_version[0]
minor_version = maca_version[1]
patch_version = maca_version[2]
build_version = maca_version[3]
cur_maca_version = (
((major_version & 0xFF) << 24)
| ((minor_version & 0xFF) << 16)
| ((patch_version & 0xFF) << 8)
| ((build_version & 0xFF) << 0)
)
metax_extra_compile_args["nvcc"].append(f"-DMACA_VERSION={cur_maca_version}")
else:
raise ValueError(f"MACA version invalid - {maca_version}")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args=metax_extra_compile_args,
library_dirs=[os.path.join(maca_path, "lib")],
extra_link_args=["-lruntime_cu", "-lmctlassEx"],
include_dirs=[
os.path.join(maca_path, "include"),
os.path.join(maca_path, "include/mcr"),
os.path.join(maca_path, "include/common"),
os.path.join(maca_path, "include/mcfft"),
os.path.join(maca_path, "include/mcrand"),
os.path.join(maca_path, "include/mcsparse"),
os.path.join(maca_path, "include/mcblas"),
os.path.join(maca_path, "include/mcsolver"),
],
),
)
elif paddle.is_compiled_with_custom_device("intel_hpu"):
setup(
name="fastdeploy_ops",
ext_modules=CppExtension(
sources=[
"gpu_ops/get_output.cc",
]
),
)
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
# cc flags
paddle_extra_compile_args = [
"-std=c++17",
"-shared",
"-fPIC",
"-Wno-parentheses",
"-DPADDLE_WITH_CUSTOM_KERNEL",
"-DPADDLE_ON_INFERENCE",
"-Wall",
"-O3",
"-g",
"-lstdc++fs",
"-D_GLIBCXX_USE_CXX11_ABI=1",
"-DPy_LIMITED_API=0x03090000",
]
setup(
name="fastdeploy_cpu_ops",
ext_modules=CppExtension(
sources=[
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"cpu_ops/rebuild_padding.cc",
"cpu_ops/simd_sort.cc",
"cpu_ops/set_value_by_flags.cc",
"cpu_ops/token_penalty_multi_scores.cc",
"cpu_ops/stop_generation_multi_ends.cc",
"cpu_ops/update_inputs.cc",
"cpu_ops/get_padding_offset.cc",
],
extra_link_args=[
"-Wl,-rpath,$ORIGIN/x86-simd-sort/builddir",
"-Wl,-rpath,$ORIGIN/xFasterTransformer/build",
],
extra_compile_args=paddle_extra_compile_args,
),
packages=find_namespace_packages(where="third_party"),
package_dir={"": "third_party"},
include_package_data=True,
)