@@ -87,8 +87,8 @@ def _get_not_self_contained_reps(self, modality_type):
8787 )
8888
8989 @lru_cache (maxsize = 32 )
90- def _get_context_operators (self ):
91- return self .operator_registry .get_context_operators ()
90+ def _get_context_operators (self , modality_type ):
91+ return self .operator_registry .get_context_operators (modality_type )
9292
9393 def store_results (self , file_name = None ):
9494 if file_name is None :
@@ -302,6 +302,39 @@ def _build_modality_dag(
302302 current_node_id = rep_node_id
303303 dags .append (builder .build (current_node_id ))
304304
305+ if operator .needs_context :
306+ context_operators = self ._get_context_operators (modality .modality_type )
307+ for context_op in context_operators :
308+ if operator .initial_context_length is not None :
309+ context_length = operator .initial_context_length
310+
311+ context_node_id = builder .create_operation_node (
312+ context_op ,
313+ [leaf_id ],
314+ context_op (context_length ).get_current_parameters (),
315+ )
316+ else :
317+ context_node_id = builder .create_operation_node (
318+ context_op ,
319+ [leaf_id ],
320+ context_op ().get_current_parameters (),
321+ )
322+
323+ context_rep_node_id = builder .create_operation_node (
324+ operator .__class__ ,
325+ [context_node_id ],
326+ operator .get_current_parameters (),
327+ )
328+
329+ agg_operator = AggregatedRepresentation ()
330+ context_agg_node_id = builder .create_operation_node (
331+ agg_operator .__class__ ,
332+ [context_rep_node_id ],
333+ agg_operator .get_current_parameters (),
334+ )
335+
336+ dags .append (builder .build (context_agg_node_id ))
337+
305338 if not operator .self_contained :
306339 not_self_contained_reps = self ._get_not_self_contained_reps (
307340 modality .modality_type
@@ -344,7 +377,7 @@ def _build_modality_dag(
344377
345378 def default_context_operators (self , modality , builder , leaf_id , current_node_id ):
346379 dags = []
347- context_operators = self ._get_context_operators ()
380+ context_operators = self ._get_context_operators (modality . modality_type )
348381 for context_op in context_operators :
349382 if (
350383 modality .modality_type != ModalityType .TEXT
@@ -368,7 +401,7 @@ def default_context_operators(self, modality, builder, leaf_id, current_node_id)
368401
369402 def temporal_context_operators (self , modality , builder , leaf_id , current_node_id ):
370403 aggregators = self .operator_registry .get_representations (modality .modality_type )
371- context_operators = self ._get_context_operators ()
404+ context_operators = self ._get_context_operators (modality . modality_type )
372405
373406 dags = []
374407 for agg in aggregators :
0 commit comments