-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathQwenPipeline.cs
More file actions
237 lines (211 loc) · 9.62 KB
/
QwenPipeline.cs
File metadata and controls
237 lines (211 loc) · 9.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TensorStack.Common;
using TensorStack.Common.Pipeline;
using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
using TensorStack.TextGeneration.Common;
using TensorStack.TextGeneration.Processing;
using TensorStack.TextGeneration.Tokenizers;
namespace TensorStack.TextGeneration.Pipelines.Qwen
{
    /// <summary>
    /// Text-generation pipeline for Qwen decoder models. Supports greedy search
    /// (<see cref="GenerateOptions"/>) and beam search (<see cref="SearchOptions"/>).
    /// </summary>
    public class QwenPipeline : DecoderPipeline<GenerateOptions>,
        IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
        IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="QwenPipeline"/> class.
        /// </summary>
        /// <param name="configuration">The pipeline configuration; supplies both the tokenizer and the decoder configuration.</param>
        /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is <c>null</c>.</exception>
        public QwenPipeline(QwenConfig configuration)
            : base((configuration ?? throw new ArgumentNullException(nameof(configuration))).Tokenizer, configuration.DecoderConfig)
        {
            // The null check lives in the base-constructor argument list because
            // no statement can execute before base(...) is invoked.
            Configuration = configuration;
        }

        /// <summary>
        /// Gets the configuration this pipeline was created with.
        /// </summary>
        public QwenConfig Configuration { get; }

        /// <summary>
        /// Runs greedy-search generation for the prompt in <paramref name="options"/>.
        /// </summary>
        /// <param name="options">The generation options, including the prompt.</param>
        /// <param name="progressCallback">The optional progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The decoded result of the generated sequence.</returns>
        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);

            // NOTE(review): Tokens and LastHiddenState are handed to the result before the
            // sequence is disposed; confirm Sequence.Dispose does not invalidate them.
            using var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
            return new GenerateResult
            {
                Score = sequence.Score,
                Result = Tokenizer.Decode(sequence.Tokens),
                Tokens = sequence.Tokens,
                LastHiddenState = sequence.LastHiddenState
            };
        }

        /// <summary>
        /// Runs beam-search generation for the prompt in <paramref name="options"/>.
        /// </summary>
        /// <param name="options">The search options, including the prompt.</param>
        /// <param name="progressCallback">The optional progress callback.</param>
        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
        /// <returns>One decoded result per beam, indexed by beam number.</returns>
        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);
            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
            try
            {
                var results = new GenerateResult[sequences.Length];
                for (var beam = 0; beam < sequences.Length; beam++)
                {
                    var sequence = sequences[beam];
                    results[beam] = new GenerateResult
                    {
                        Beam = beam,
                        Score = sequence.Score,
                        PenaltyScore = sequence.PenaltyScore,
                        Result = Tokenizer.Decode(sequence.Tokens),
                        Tokens = sequence.Tokens,
                        LastHiddenState = sequence.LastHiddenState
                    };
                }
                return results;
            }
            finally
            {
                // Dispose every beam, even if building one of the results threw;
                // otherwise the sequences after the failing beam would leak.
                foreach (var beamSequence in sequences)
                {
                    beamSequence?.Dispose();
                }
            }
        }

        /// <summary>
        /// Tokenizes the prompt and stores the result in <c>TokenizerOutput</c>.
        /// </summary>
        /// <param name="options">The options carrying the prompt and <c>MinLength</c>.</param>
        /// <returns>A task representing the asynchronous operation.</returns>
        protected override async Task TokenizePromptAsync(GenerateOptions options)
        {
            var tokenizerResult = await Tokenizer.EncodeAsync(options.Prompt);

            // Input ids are padded with EOS and the mask with 0, both driven by options.MinLength
            // (presumably padding up to that length; confirm Pad semantics).
            var inputIds = tokenizerResult.InputIds.Span.Pad(Tokenizer.EOS, options.MinLength);
            var mask = tokenizerResult.Mask.Span.Pad(0, options.MinLength);
            TokenizerOutput = new TokenizerResult(inputIds, mask);
        }

        /// <summary>
        /// Gets the token processors applied during generation.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>An EOS processor (configured with <c>MinLength</c>) and a max-length processor.</returns>
        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
        {
            return
            [
                new EOSTokenProcessor(options.MinLength, Tokenizer.EOS),
                new MaxLengthTokenProcessor(options.MaxLength)
            ];
        }

        /// <summary>
        /// Creates the KV cache and sequence, then runs the decoder once over the whole prompt.
        /// </summary>
        /// <param name="options">The options; <c>MaxLength</c> sizes the KV cache.</param>
        /// <returns>A task producing the initialized <see cref="Sequence"/>, seeded with the BOS token.</returns>
        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
            var sequence = new Sequence(kvCache, Tokenizer.BOS);
            sequence.Initialize(0);

            var position = TokenizerOutput.Length;
            var inputIds = TokenizerOutput.InputIds;
            var positionIds = GetPositionIds(modelMetadata, 0, position);

            // NOTE(review): this mask is all ones, so the zero-padded mask built in
            // TokenizePromptAsync is not consulted here; confirm padding should be attended.
            var attentionMask = new Tensor<long>([1, position], 1);

            // The prompt pass runs for its side effect on the sequence cache (UpdateCache);
            // its logits are discarded.
            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
            return sequence;
        }

        /// <summary>
        /// Runs the decoder for the most recent token of <paramref name="sequence"/>.
        /// </summary>
        /// <param name="sequence">The sequence whose last token is fed to the model.</param>
        /// <returns>A task producing the logits tensor for that step.</returns>
        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
        {
            var modelMetadata = await Decoder.LoadAsync();

            // Total length so far = prompt tokens + tokens held by the sequence.
            var position = TokenizerOutput.Length + sequence.Tokens.Count;
            var inputIds = new Tensor<long>([1, 1], sequence.Tokens[^1]);
            var positionIds = GetPositionIds(modelMetadata, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
        }

        /// <summary>
        /// Runs one decoder inference and pushes the returned key/value outputs into the sequence cache.
        /// </summary>
        /// <param name="modelMetadata">The model metadata.</param>
        /// <param name="sequence">The sequence whose cache supplies past key/values and receives the present ones.</param>
        /// <param name="inputIds">The input ids.</param>
        /// <param name="positionIds">The position ids; omitted from the inputs when <c>null</c>.</param>
        /// <param name="attentionMask">The attention mask.</param>
        /// <param name="useBranchCache">Forwarded to <c>Sequence.UpdateCache</c>.</param>
        /// <returns>The logits from the first model output, converted using its dimensions after the first.</returns>
        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor<long> inputIds, Tensor<long> positionIds, Tensor<long> attentionMask, bool useBranchCache)
        {
            using (var parameters = new ModelParameters(modelMetadata))
            {
                // Inputs: ids, mask, optional position ids, then every cached past key/value.
                parameters.AddInput(inputIds);
                parameters.AddInput(attentionMask);
                if (positionIds != null)
                {
                    parameters.AddInput(positionIds);
                }
                foreach (var pastKeyValue in sequence.Cache)
                {
                    parameters.AddInput(pastKeyValue, false);
                }

                // One output slot per model output.
                foreach (var _ in modelMetadata.Outputs)
                {
                    parameters.AddOutput();
                }

                // Output 0 is treated as logits; every remaining output is handed to the cache.
                var modelResult = Decoder.RunInference(parameters);
                using (var logitsResult = modelResult[0])
                {
                    var dimension = logitsResult.GetDimensions();
                    var logits = logitsResult.ToTensor(dimension[1..]);
                    var presentKeyValues = modelResult.ToArray()[1..];
                    sequence.UpdateCache(presentKeyValues, useBranchCache);
                    return logits;
                }
            }
        }

        /// <summary>
        /// Creates a <see cref="QwenPipeline"/> with the hard-coded Qwen2.5-0.5B architecture settings.
        /// </summary>
        /// <param name="provider">The execution provider for the decoder.</param>
        /// <param name="modelPath">The model directory; used as the tokenizer path and combined with <paramref name="model"/> for the decoder path.</param>
        /// <param name="model">The decoder model file name.</param>
        /// <returns>The configured <see cref="QwenPipeline"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is <c>null</c>.</exception>
        public static QwenPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
        {
            if (modelPath is null)
            {
                throw new ArgumentNullException(nameof(modelPath));
            }

            // Qwen-2.5 - https://huggingface.co/onnx-community/Qwen2.5-0.5B
            var numHeads = 14;
            var numLayers = 24;
            var hiddenSize = 896;
            var numKVHeads = 2;
            var vocabSize = 151936;
            var config = new QwenConfig
            {
                Tokenizer = new BPETokenizer(new TokenizerConfig
                {
                    BOS = 151643,
                    EOS = 151643,
                    Path = modelPath
                }),
                DecoderConfig = new DecoderConfig
                {
                    Path = Path.Combine(modelPath, model),
                    VocabSize = vocabSize,
                    NumHeads = numHeads,
                    NumLayers = numLayers,
                    HiddenSize = hiddenSize,
                    NumKVHeads = numKVHeads
                }
            };
            config.DecoderConfig.SetProvider(provider);
            return new QwenPipeline(config);
        }
    }
}