@@ -204,7 +204,6 @@ Decoding::Decoding(FFModel &model,
204204
205205struct DecodingInitMeta {
206206 Decoding *decoding;
207- bool is_last_op;
208207};
209208
210209void Decoding::init_inference (FFModel const &ff,
@@ -220,26 +219,8 @@ void Decoding::init_inference(FFModel const &ff,
220219 size_t machine_view_hash = view->hash ();
221220 set_argumentmap_for_init_inference (ff, argmap, batch_outputs[0 ]);
222221
223- int last_op = ff.operators .size () - 1 ;
224- assert (ff.operators [last_op]->op_type == OP_ARGMAX ||
225- ff.operators [last_op]->op_type == OP_ARG_TOPK ||
226- ff.operators [last_op]->op_type == OP_SAMPLING ||
227- ff.operators [last_op]->op_type == OP_DECODING);
228- last_op -= 1 ;
229- while (ff.operators [last_op]->op_type == OP_WEIGHT && last_op > 0 ) {
230- last_op -= 1 ;
231- }
232- bool is_last_op = false ;
233- if (ff.operators [last_op] == this ) {
234- is_last_op = true ;
235- } else if (ff.operators [last_op]->op_type == OP_FUSED) {
236- FusedOp *fused_op = static_cast <FusedOp *>(ff.operators [last_op]);
237- is_last_op = fused_op->operators [fused_op->numOperators - 1 ] == this ;
238- }
239-
240222 DecodingInitMeta meta;
241223 meta.decoding = this ;
242- meta.is_last_op = is_last_op;
243224
244225 IndexLauncher launcher (DECODING_INIT_TASK_ID,
245226 parallel_is,
@@ -317,7 +298,7 @@ OpMeta *Decoding::init_task(Task const *task,
317298 assert (task->regions .size () == regions.size ());
318299 DecodingInitMeta const *meta = (DecodingInitMeta *)task->args ;
319300 Decoding const *decoding = meta->decoding ;
320- bool is_last_op = meta-> is_last_op ;
301+
321302 FFHandler handle = *((FFHandler const *)task->local_args );
322303 Memory gpu_mem = get_proc_mem (Machine::get_machine (), task->target_proc );
323304 MemoryAllocator gpu_mem_allocator (gpu_mem);
@@ -352,7 +333,7 @@ OpMeta *Decoding::init_task(Task const *task,
352333 }
353334
354335 DecodingMeta *m =
355- new DecodingMeta (handle, decoding, domain, is_last_op, gpu_mem_allocator);
336+ new DecodingMeta (handle, decoding, domain, gpu_mem_allocator);
356337 std::strcpy (m->op_name , decoding->name );
357338 m->layer_guid = decoding->layer_guid ;
358339 m->beam_search = decoding->beam_search ;
@@ -378,21 +359,11 @@ FutureMap Decoding::inference(FFModel const &ff,
378359 size_t machine_view_hash = view->hash ();
379360
380361 assert (ff.config .computationMode == COMP_MODE_INFERENCE);
381- int last_op = ff.operators .size () - 1 ;
382- assert (ff.operators [last_op]->op_type == OP_ARGMAX ||
383- ff.operators [last_op]->op_type == OP_ARG_TOPK ||
384- ff.operators [last_op]->op_type == OP_SAMPLING ||
385- ff.operators [last_op]->op_type == OP_DECODING);
386- last_op -= 1 ;
387- while (ff.operators [last_op]->op_type == OP_WEIGHT && last_op > 0 ) {
388- last_op -= 1 ;
389- }
390- bool is_last_op = (ff.operators [last_op] == this );
391362
392363 if (beam_search) {
393364 IndexLauncher launcher (DECODING_BEAM_INF_TASK_ID,
394365 parallel_is,
395- TaskArgument (&is_last_op, sizeof ( bool ) ),
366+ TaskArgument (nullptr , 0 ),
396367 argmap,
397368 Predicate::TRUE_PRED,
398369 false /* must*/ ,
@@ -421,7 +392,7 @@ FutureMap Decoding::inference(FFModel const &ff,
421392 } else {
422393 IndexLauncher launcher (DECODING_NORM_INF_TASK_ID,
423394 parallel_is,
424- TaskArgument (&is_last_op, sizeof ( bool ) ),
395+ TaskArgument (nullptr , 0 ),
425396 argmap,
426397 Predicate::TRUE_PRED,
427398 false /* must*/ ,
@@ -457,7 +428,6 @@ BeamInferenceResult
457428 Runtime *runtime) {
458429 assert (regions.size () == 3 );
459430 assert (task->regions .size () == 3 );
460- bool is_last_op = *(bool *)task->args ;
461431 BatchConfig const *bc = BatchConfig::from_future (task->futures [0 ]);
462432 if (bc->num_tokens == 0 ) {
463433 // Directly return for empty batch config
@@ -475,7 +445,7 @@ BeamInferenceResult
475445 int batch_size = bc->num_active_tokens ();
476446 float loss = 0 .0f ;
477447
478- inference_kernel_wrapper (m, bc, is_last_op, input, softmax_output, argmax_output);
448+ inference_kernel_wrapper (m, bc, input, softmax_output, argmax_output);
479449
480450 BeamInferenceResult ir;
481451 // Copy argmax results from output region
@@ -504,7 +474,6 @@ InferenceResult
504474 Runtime *runtime) {
505475 assert (regions.size () == 3 );
506476 assert (task->regions .size () == 3 );
507- bool is_last_op = *(bool *)task->args ;
508477 DecodingMeta *m = *((DecodingMeta **)task->local_args );
509478 BatchConfig const *bc = BatchConfig::from_future (task->futures [0 ]);
510479 if (bc->num_tokens == 0 ) {
@@ -522,7 +491,7 @@ InferenceResult
522491 int batch_size = bc->num_active_tokens ();
523492 float loss = 0 .0f ;
524493
525- inference_kernel_wrapper (m, bc, is_last_op, input, softmax_output, argmax_output);
494+ inference_kernel_wrapper (m, bc, input, softmax_output, argmax_output);
526495
527496 if (task->index_point .point_data [0 ] == 0 ) {
528497 int in_dim0 = input.domain .hi ()[0 ] - input.domain .lo ()[0 ] + 1 ;
0 commit comments