Skip to content

Commit b5f0a40

Browse files
committed
fix: remove the is_last_op flag from the Decoding operator
1 parent b9571ae commit b5f0a40

3 files changed

Lines changed: 8 additions & 43 deletions

File tree

include/flexflow/ops/decoding.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ class Decoding : public Op {
9595
ffStream_t stream);
9696
static void inference_kernel_wrapper(DecodingMeta *m,
9797
BatchConfig const *bc,
98-
bool is_last_op,
9998
GenericTensorAccessorR const &input,
10099
GenericTensorAccessorW const &softmax_output,
101100
GenericTensorAccessorW const &argmax_output);
@@ -119,7 +118,6 @@ class DecodingMeta : public OpMeta {
119118
DecodingMeta(FFHandler handler,
120119
Decoding const *decoding,
121120
Legion::Domain const &input_domain,
122-
bool is_last_op,
123121
MemoryAllocator &gpu_mem_allocator);
124122
~DecodingMeta(void);
125123
bool beam_search;

src/ops/decoding.cc

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ Decoding::Decoding(FFModel &model,
204204

205205
struct DecodingInitMeta {
206206
Decoding *decoding;
207-
bool is_last_op;
208207
};
209208

210209
void Decoding::init_inference(FFModel const &ff,
@@ -220,26 +219,8 @@ void Decoding::init_inference(FFModel const &ff,
220219
size_t machine_view_hash = view->hash();
221220
set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]);
222221

223-
int last_op = ff.operators.size() - 1;
224-
assert(ff.operators[last_op]->op_type == OP_ARGMAX ||
225-
ff.operators[last_op]->op_type == OP_ARG_TOPK ||
226-
ff.operators[last_op]->op_type == OP_SAMPLING ||
227-
ff.operators[last_op]->op_type == OP_DECODING);
228-
last_op -= 1;
229-
while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) {
230-
last_op -= 1;
231-
}
232-
bool is_last_op = false;
233-
if (ff.operators[last_op] == this) {
234-
is_last_op = true;
235-
} else if (ff.operators[last_op]->op_type == OP_FUSED) {
236-
FusedOp *fused_op = static_cast<FusedOp *>(ff.operators[last_op]);
237-
is_last_op = fused_op->operators[fused_op->numOperators - 1] == this;
238-
}
239-
240222
DecodingInitMeta meta;
241223
meta.decoding = this;
242-
meta.is_last_op = is_last_op;
243224

244225
IndexLauncher launcher(DECODING_INIT_TASK_ID,
245226
parallel_is,
@@ -317,7 +298,7 @@ OpMeta *Decoding::init_task(Task const *task,
317298
assert(task->regions.size() == regions.size());
318299
DecodingInitMeta const *meta = (DecodingInitMeta *)task->args;
319300
Decoding const *decoding = meta->decoding;
320-
bool is_last_op = meta->is_last_op;
301+
321302
FFHandler handle = *((FFHandler const *)task->local_args);
322303
Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc);
323304
MemoryAllocator gpu_mem_allocator(gpu_mem);
@@ -352,7 +333,7 @@ OpMeta *Decoding::init_task(Task const *task,
352333
}
353334

354335
DecodingMeta *m =
355-
new DecodingMeta(handle, decoding, domain, is_last_op, gpu_mem_allocator);
336+
new DecodingMeta(handle, decoding, domain, gpu_mem_allocator);
356337
std::strcpy(m->op_name, decoding->name);
357338
m->layer_guid = decoding->layer_guid;
358339
m->beam_search = decoding->beam_search;
@@ -378,21 +359,11 @@ FutureMap Decoding::inference(FFModel const &ff,
378359
size_t machine_view_hash = view->hash();
379360

380361
assert(ff.config.computationMode == COMP_MODE_INFERENCE);
381-
int last_op = ff.operators.size() - 1;
382-
assert(ff.operators[last_op]->op_type == OP_ARGMAX ||
383-
ff.operators[last_op]->op_type == OP_ARG_TOPK ||
384-
ff.operators[last_op]->op_type == OP_SAMPLING ||
385-
ff.operators[last_op]->op_type == OP_DECODING);
386-
last_op -= 1;
387-
while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) {
388-
last_op -= 1;
389-
}
390-
bool is_last_op = (ff.operators[last_op] == this);
391362

392363
if (beam_search) {
393364
IndexLauncher launcher(DECODING_BEAM_INF_TASK_ID,
394365
parallel_is,
395-
TaskArgument(&is_last_op, sizeof(bool)),
366+
TaskArgument(nullptr, 0),
396367
argmap,
397368
Predicate::TRUE_PRED,
398369
false /*must*/,
@@ -421,7 +392,7 @@ FutureMap Decoding::inference(FFModel const &ff,
421392
} else {
422393
IndexLauncher launcher(DECODING_NORM_INF_TASK_ID,
423394
parallel_is,
424-
TaskArgument(&is_last_op, sizeof(bool)),
395+
TaskArgument(nullptr, 0),
425396
argmap,
426397
Predicate::TRUE_PRED,
427398
false /*must*/,
@@ -457,7 +428,6 @@ BeamInferenceResult
457428
Runtime *runtime) {
458429
assert(regions.size() == 3);
459430
assert(task->regions.size() == 3);
460-
bool is_last_op = *(bool *)task->args;
461431
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
462432
if (bc->num_tokens == 0) {
463433
// Directly return for empty batch config
@@ -475,7 +445,7 @@ BeamInferenceResult
475445
int batch_size = bc->num_active_tokens();
476446
float loss = 0.0f;
477447

478-
inference_kernel_wrapper(m, bc, is_last_op, input, softmax_output, argmax_output);
448+
inference_kernel_wrapper(m, bc, input, softmax_output, argmax_output);
479449

480450
BeamInferenceResult ir;
481451
// Copy argmax results from output region
@@ -504,7 +474,6 @@ InferenceResult
504474
Runtime *runtime) {
505475
assert(regions.size() == 3);
506476
assert(task->regions.size() == 3);
507-
bool is_last_op = *(bool *)task->args;
508477
DecodingMeta *m = *((DecodingMeta **)task->local_args);
509478
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
510479
if (bc->num_tokens == 0) {
@@ -522,7 +491,7 @@ InferenceResult
522491
int batch_size = bc->num_active_tokens();
523492
float loss = 0.0f;
524493

525-
inference_kernel_wrapper(m, bc, is_last_op, input, softmax_output, argmax_output);
494+
inference_kernel_wrapper(m, bc, input, softmax_output, argmax_output);
526495

527496
if (task->index_point.point_data[0] == 0) {
528497
int in_dim0 = input.domain.hi()[0] - input.domain.lo()[0] + 1;

src/ops/decoding.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ void store_peft_activations(DecodingMeta *m,
261261
/*static*/
262262
void Decoding::inference_kernel_wrapper(DecodingMeta *m,
263263
BatchConfig const *bc,
264-
bool is_last_op,
265264
GenericTensorAccessorR const &input,
266265
GenericTensorAccessorW const &softmax_output,
267266
GenericTensorAccessorW const &argmax_output) {
@@ -303,7 +302,7 @@ void Decoding::inference_kernel_wrapper(DecodingMeta *m,
303302
assert(false && "Unsupported data type");
304303
}
305304

306-
if (is_last_op && bc->num_finetuning_fwd_requests() > 0) {
305+
if (bc->num_finetuning_fwd_requests() > 0) {
307306
store_peft_token_ids(m, bc);
308307
// Store softmax activations for PEFT backward pass
309308
if (input.data_type == DT_HALF) {
@@ -418,12 +417,11 @@ void Decoding::peft_bwd_kernel_wrapper(DecodingMeta *m,
418417
DecodingMeta::DecodingMeta(FFHandler handler,
419418
Decoding const *decoding,
420419
Legion::Domain const &input_domain,
421-
bool is_last_op,
422420
MemoryAllocator &gpu_mem_allocator)
423421
: OpMeta(handler, decoding) {
424422
beam_search = decoding->beam_search;
425423
426-
if (peft_finetuning_enabled(peft_support_mode) && is_last_op) {
424+
if (peft_finetuning_enabled(peft_support_mode)) {
427425
allocated_peft_buffer_size =
428426
input_domain.get_volume() * data_type_size(decoding->data_type);
429427
gpu_mem_allocator.create_legion_instance(

0 commit comments

Comments
 (0)