Skip to content

Commit b9571ae

Browse files
committed
update
1 parent 2a5c018 commit b9571ae

19 files changed

Lines changed: 1546 additions & 10 deletions

include/flexflow/ffconst.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ enum OperatorType {
198198
OP_RESIDUAL_RMS_NORM,
199199
OP_BEAM_TOPK,
200200
OP_ARGMAX,
201+
OP_DECODING,
201202
OP_INC_MULTIHEAD_SELF_ATTENTION,
202203
OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION,
203204
OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,

include/flexflow/model.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ enum TaskIDs {
168168
ARGMAX_INIT_TASK_ID,
169169
ARGMAX_BEAM_INF_TASK_ID,
170170
ARGMAX_NORM_INF_TASK_ID,
171+
DECODING_INIT_TASK_ID,
172+
DECODING_BEAM_INF_TASK_ID,
173+
DECODING_NORM_INF_TASK_ID,
174+
DECODING_PEFT_BWD_TASK_ID,
171175
TRANSPOSE_INIT_TASK_ID,
172176
TRANSPOSE_FWD_TASK_ID,
173177
TRANSPOSE_BWD_TASK_ID,
@@ -375,6 +379,7 @@ class BeamTopK;
375379
class SpecIncMultiHeadSelfAttention;
376380
class Sampling;
377381
class ArgMax;
382+
class Decoding;
378383
class Combine;
379384
class Repartition;
380385
class Reduction;
@@ -720,6 +725,7 @@ class FFModel {
720725
bool speculative_decoding,
721726
char const *name = NULL);
722727
Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL);
728+
Tensor decoding(const Tensor input, bool beam_search, char const *name = NULL);
723729
Tensor sampling(const Tensor input, float top_p, char const *name = NULL);
724730
Tensor multihead_attention(const Tensor query,
725731
const Tensor key,
@@ -1221,6 +1227,8 @@ class FFModel {
12211227
Sampling *>,
12221228
std::unordered_map<std::pair<ParallelTensorShape, ArgMaxParams>,
12231229
ArgMax *>,
1230+
std::unordered_map<std::pair<ParallelTensorShape, DecodingParams>,
1231+
Decoding *>,
12241232
std::unordered_map<
12251233
std::pair<ParallelTensorShape, SpecIncMultiHeadSelfAttentionParams>,
12261234
SpecIncMultiHeadSelfAttention *>,

include/flexflow/operator_params.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "flexflow/ops/aggregate_spec_params.h"
77
#include "flexflow/ops/arg_topk_params.h"
88
#include "flexflow/ops/argmax_params.h"
9+
#include "flexflow/ops/decoding_params.h"
910
#include "flexflow/ops/attention_params.h"
1011
#include "flexflow/ops/batch_matmul_params.h"
1112
#include "flexflow/ops/beam_topk_params.h"
@@ -85,6 +86,7 @@ using OperatorParameters = mp::variant<AggregateParams,
8586
ArgTopKParams,
8687
SamplingParams,
8788
ArgMaxParams,
89+
DecodingParams,
8890
SoftmaxParams,
8991
TransposeParams,
9092
RepartitionParams,

include/flexflow/ops/decoding.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
#ifndef _FLEXFLOW_DECODING_H
#define _FLEXFLOW_DECODING_H

#include <cassert> // assert(0) in print_layer below
#include <cstddef> // size_t member of DecodingMeta

#include "flexflow/device.h"
#include "flexflow/fftype.h"
#include "flexflow/inference.h"
#include "flexflow/layer.h"
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/decoding_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

// Forward declaration; per-operator GPU state, defined later in this header.
class DecodingMeta;

/// Decoding operator.
/// Declares both a greedy and a beam-search inference path
/// (inference_task_norm / inference_task_beam) plus a PEFT backward pass.
/// NOTE(review): the kernel signatures (softmax_output_ptr, argmax_output_ptr,
/// loss) suggest this fuses softmax + argmax + loss computation — confirm
/// against the .cu implementation.
class Decoding : public Op {
public:
  using Params = DecodingParams;
  using Input = ParallelTensor;
  Decoding(FFModel &model,
           LayerID const &_layer_guid,
           const ParallelTensor input,
           bool beam_search,
           char const *name);
  Decoding(FFModel &model,
           Params const &params,
           const Input input,
           char const *name = nullptr);
  void init(FFModel const &) override;
  void init_inference(FFModel const &,
                      std::vector<ParallelTensor> const &,
                      std::vector<ParallelTensor> const &,
                      MachineView const *mv = nullptr) override;
  void forward(FFModel const &) override;
  Legion::FutureMap inference(FFModel const &,
                              BatchConfigFuture const &,
                              std::vector<ParallelTensor> const &,
                              std::vector<ParallelTensor> const &,
                              MachineView const *mv = nullptr) override;
  Legion::FutureMap peft_bwd(FFModel const &,
                             BatchConfigFuture const &,
                             std::vector<ParallelTensor> const &,
                             std::vector<ParallelTensor> const &,
                             MachineView const *mv = nullptr) override;
  void backward(FFModel const &) override;
  // Not supported for this operator.
  void print_layer(FFModel const &model) override {
    assert(0);
  }
  static Op *
      create_operator_from_layer(FFModel &model,
                                 Layer const *layer,
                                 std::vector<ParallelTensor> const &inputs);
  // Legion task entry points.
  static OpMeta *init_task(Legion::Task const *task,
                           std::vector<Legion::PhysicalRegion> const &regions,
                           Legion::Context ctx,
                           Legion::Runtime *runtime);
  static BeamInferenceResult
      inference_task_beam(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);
  static InferenceResult
      inference_task_norm(Legion::Task const *task,
                          std::vector<Legion::PhysicalRegion> const &regions,
                          Legion::Context ctx,
                          Legion::Runtime *runtime);
  static bool peft_bwd_task(Legion::Task const *task,
                            std::vector<Legion::PhysicalRegion> const &regions,
                            Legion::Context ctx,
                            Legion::Runtime *runtime);
  bool measure_operator_cost(Simulator *sim,
                             MachineView const &pc,
                             CostMetrics &cost_metrics) const override;
  void serialize(Legion::Serializer &) const override;
  static PCG::Node deserialize(FFModel &ff,
                               Legion::Deserializer &d,
                               ParallelTensor inputs[],
                               int num_inputs);
  Op *materialize(FFModel &ff,
                  ParallelTensor inputs[],
                  int num_inputs) const override;
  Params get_params() const;

  // Device kernels (implemented in the corresponding .cu/.cpp file).
  template <typename DT>
  static void inference_kernel(DecodingMeta const *m,
                               BatchConfig const *bc,
                               DT const *input_ptr,
                               DT *softmax_output_ptr,
                               int *argmax_output_ptr,
                               int num_classes,
                               float *loss,
                               ffStream_t stream);
  static void inference_kernel_wrapper(DecodingMeta *m,
                                       BatchConfig const *bc,
                                       bool is_last_op,
                                       GenericTensorAccessorR const &input,
                                       GenericTensorAccessorW const &softmax_output,
                                       GenericTensorAccessorW const &argmax_output);
  template <typename DT>
  static void peft_bwd_kernel(DecodingMeta const *m,
                              BatchConfig const *bc,
                              DT *input_grad_ptr,
                              int num_classes,
                              ffStream_t stream);
  static void peft_bwd_kernel_wrapper(DecodingMeta *m,
                                      BatchConfig const *bc,
                                      GenericTensorAccessorW const &input_grad);

public:
  LayerID layer_guid;  // id of the originating layer
  bool beam_search;    // true => beam-search variant, false => greedy
};

/// Per-device runtime state for a Decoding operator instance.
class DecodingMeta : public OpMeta {
public:
  DecodingMeta(FFHandler handler,
               Decoding const *decoding,
               Legion::Domain const &input_domain,
               bool is_last_op,
               MemoryAllocator &gpu_mem_allocator);
  ~DecodingMeta(void);
  bool beam_search;
  float *probs;
  float *d_loss;
  // Temporary buffers
  int *parent_output_buffer;
  // PEFT related fields
  void *output_grad_ptr = nullptr;
  size_t allocated_peft_buffer_size = 0;
  Realm::RegionInstance reserveInst;
  BatchConfig::TokenId peft_token_ids[BatchConfig::MAX_NUM_TOKENS];
};

} // namespace FlexFlow

#endif // _FLEXFLOW_DECODING_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
#ifndef _FLEXFLOW_DECODING_PARAMS_H
#define _FLEXFLOW_DECODING_PARAMS_H

#include <cstddef>    // size_t used by the hash specialization
#include <functional> // required: std::hash is specialized below

#include "flexflow/ffconst.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

/// Compile-time parameters that identify a Decoding operator instance
/// (used as a cache key in FFModel's operator map, keyed together with
/// the input ParallelTensorShape).
struct DecodingParams {
  LayerID layer_guid;    // id of the originating layer
  bool beam_search;      // true => beam-search variant, false => greedy
  bool is_valid(ParallelTensorShape const &) const;
  char name[MAX_OPNAME]; // operator name
};
bool operator==(DecodingParams const &, DecodingParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::DecodingParams> {
  size_t operator()(FlexFlow::DecodingParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_DECODING_PARAMS_H

inference/models/llama.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,9 @@ void LLAMA::create_llama_model(FFModel &ff,
291291
output = ff.sampling(softmax, generation_config.topp);
292292
} else {
293293
// output = ff.arg_top_k(dense, /*k=*/1, false);
294-
Tensor softmax = ff.softmax(dense, -1);
295-
output = ff.argmax(softmax, /*beam_Search*/ false);
294+
// Tensor softmax = ff.softmax(dense, -1);
295+
// output = ff.argmax(softmax, /*beam_Search*/ false);
296+
output = ff.decoding(dense, /*beam_search*/ false, "decoding");
296297
}
297298
}
298299

src/ops/argmax.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ bool operator==(ArgMaxParams const &lhs, ArgMaxParams const &rhs) {
107107
return lhs.beam_search == rhs.beam_search;
108108
}
109109

110+
// Strip the trailing "_<uid>" numeric suffix that FlexFlow appends to
// operator names (e.g. "dense_5" -> "dense", "layers_11_attn_3" ->
// "layers_11_attn"). Returns the name unchanged when no such suffix exists.
static std::string remove_uid(char const *op_name) {
  std::string name(op_name);
  size_t cut = name.length();
  // Walk backwards over the trailing run of digits/underscores, remembering
  // the earliest underscore in that run; stop at the first other character.
  // Index 0 is deliberately excluded so a leading '_' is never cut.
  for (int i = static_cast<int>(name.length()) - 1; i > 0; i--) {
    // Cast to unsigned char: calling std::isdigit with a negative char
    // value is undefined behavior.
    unsigned char c = static_cast<unsigned char>(name[i]);
    if (!(std::isdigit(c) || c == '_')) {
      break;
    } else if (c == '_') {
      cut = static_cast<size_t>(i);
    }
  }
  if (cut < name.length()) {
    name.erase(cut);
  }
  return name;
}
125+
110126
ArgMax::ArgMax(FFModel &model,
111127
const ParallelTensor _input,
112128
bool _beam_search,
@@ -136,6 +152,10 @@ ArgMax::ArgMax(FFModel &model,
136152
outputs[1] = model.create_parallel_tensor_legion_ordering(
137153
numdim, dims, DT_INT32, this, 1 /*owner_idx*/);
138154
}
155+
std::string const &input_label = std::string("Argmax input tensor");
156+
_input->print(input_label);
157+
std::string const &label = std::string("Argmax output tensor");
158+
outputs[0]->print(label);
139159
}
140160

141161
ArgMax::ArgMax(FFModel &model, ArgMax const &other, const ParallelTensor input)
@@ -397,6 +417,18 @@ InferenceResult
397417

398418
ArgMax::inference_kernel_wrapper(m, bc, input, indices, parent, &loss);
399419

420+
if (task->index_point.point_data[0] == 0) {
421+
int in_dim0 = input.domain.hi()[0] - input.domain.lo()[0] + 1;
422+
int in_dim1 = input.domain.hi()[1] - input.domain.lo()[1] + 1;
423+
int out_dim0 = indices.domain.hi()[0] - indices.domain.lo()[0] + 1;
424+
int out_dim1 = indices.domain.hi()[1] - indices.domain.lo()[1] + 1;
425+
std::string op_name_without_uid = remove_uid(m->op_name);
426+
printf("Argmax(%s): in=[%i, bz=%i/%i] -> out=[%i,bz=%i/%i]\n",
427+
op_name_without_uid.c_str(),
428+
in_dim0, bc->num_tokens, in_dim1,
429+
out_dim0, bc->num_tokens, out_dim1);
430+
}
431+
400432
InferenceResult ir;
401433
ir.finetuning_loss = loss;
402434

0 commit comments

Comments
 (0)