22
33import json
44from collections .abc import Sequence
5- from pathlib import Path
6- from tempfile import NamedTemporaryFile
5+ from dataclasses import dataclass
6+ from http import HTTPStatus
7+ from time import sleep
78
8- from cmem .cmempy .workflow .workflow import execute_workflow_io , get_workflows_io
9+ from cmem .cmempy .api import config , get_json
10+ from cmem .cmempy .workflow .workflow import get_workflows_io
911from cmem_plugin_base .dataintegration .context import ExecutionContext , ExecutionReport
1012from cmem_plugin_base .dataintegration .description import Icon , Plugin , PluginParameter
1113from cmem_plugin_base .dataintegration .entity import (
1416 EntityPath ,
1517 EntitySchema ,
1618)
17- from cmem_plugin_base .dataintegration .plugins import WorkflowPlugin
19+ from cmem_plugin_base .dataintegration .plugins import PluginLogger , WorkflowPlugin
1820from cmem_plugin_base .dataintegration .ports import FixedNumberOfInputs , FlexibleSchemaPort
21+ from cmem_plugin_base .dataintegration .types import BoolParameterType , IntParameterType
1922from cmem_plugin_base .dataintegration .utils import setup_cmempy_user_access
23+ from requests import HTTPError
2024
2125from cmem_plugin_loopwf import exceptions
2226from cmem_plugin_loopwf .workflow_type import SuitableWorkflowParameterType
3741"""
3842
3943
44+ @dataclass
45+ class WorkflowExecution :
46+ """Represents the status of a concrete workflow execution"""
47+
48+ task_id : str
49+ project_id : str
50+ entity : Entity
51+ schema : EntitySchema
52+ instance_id : str | None = None
53+ activity_id : str | None = None
54+ status : str = "QUEUED"
55+ is_running : bool = False
56+ raw : dict [str , str ] | None = None
57+ execution_context : ExecutionContext | None = None
58+ logger : PluginLogger | None = None
59+
60+ @property
61+ def is_finished (self ) -> bool :
62+ """True if workflow is finished"""
63+ return self .status .upper () == "FINISHED"
64+
65+ @property
66+ def is_queued (self ) -> bool :
67+ """True if workflow is queued"""
68+ return self .status .upper () == "QUEUED"
69+
70+ def entity_as_json_str (self ) -> str :
71+ """Return the entity as a JSON string"""
72+ entity_as_dict = StartWorkflow .entity_to_dict (entity = self .entity , schema = self .schema )
73+ return json .dumps (entity_as_dict )
74+
75+ def start (self ) -> bool :
76+ """Start the workflow"""
77+ if self .logger :
78+ self .logger .debug (f"Starting workflow execution: { self .entity_as_json_str ()} " )
79+ try :
80+ response = get_json (
81+ f"{ config .get_di_api_endpoint ()} /api/workflow/executeAsync/{ self .project_id } /{ self .task_id } " ,
82+ headers = {"Content-Type" : "application/json" },
83+ method = "POST" ,
84+ data = self .entity_as_json_str (),
85+ )
86+ except HTTPError as error :
87+ if error .response .status_code == HTTPStatus .SERVICE_UNAVAILABLE :
88+ # 503 - no more execution capacity > no status change
89+ return False
90+ raise HTTPError from error
91+ self .instance_id = response ["instanceId" ]
92+ self .activity_id = response ["activityId" ]
93+ self .update ()
94+ return True
95+
96+ def wait_until_finished (self ) -> None :
97+ """Wait until the workflow is finished"""
98+ while self .is_running :
99+ self .update ()
100+ sleep (1 )
101+
102+ def update (self ) -> None :
103+ """Update the execution status"""
104+ response = get_json (
105+ f"{ config .get_di_api_endpoint ()} /workspace/activities/status" ,
106+ params = {
107+ "project" : self .project_id ,
108+ "task" : self .task_id ,
109+ "activity" : self .activity_id ,
110+ "instance" : self .instance_id ,
111+ },
112+ )
113+ self .status = response ["statusName" ]
114+ self .is_running = response ["isRunning" ]
115+ self .raw = response
116+ if self .logger :
117+ self .logger .debug (f"Updated Status: { self !s} " )
118+
119+
120+ @dataclass
121+ class WorkflowExecutionList :
122+ """Workflow execution status list / registry"""
123+
124+ statuses : list [WorkflowExecution ]
125+ context : ExecutionContext
126+ logger : PluginLogger
127+
128+ def __init__ (self ):
129+ self .statuses = []
130+
131+ def execute (self , parallel_execution : int ) -> None :
132+ """Execute all workflow executions"""
133+ while self .queued > 0 :
134+ while self .running < parallel_execution and self .queued > 0 :
135+ self .start_next ()
136+ self .report ()
137+ self .wait_until_finished ()
138+ self .report ()
139+
140+ def start_next (self ) -> bool :
141+ """Start next workflow execution in queue"""
142+ all_queued = [_ for _ in self .statuses if _ .is_queued ]
143+ if not all_queued :
144+ return False
145+ next_in_queue : WorkflowExecution = all_queued [0 ]
146+ return next_in_queue .start ()
147+
148+ def wait_until_finished (self , polling_time : int = 1 ) -> None :
149+ """Wait until all running workflows are finished"""
150+ while self .running > 0 :
151+ sleep (polling_time )
152+ self .update_running_status ()
153+
154+ def update_running_status (self ) -> None :
155+ """Update status of running workflows"""
156+ for _ in self .statuses :
157+ if _ .is_running :
158+ _ .update ()
159+
160+ def append (self , status : WorkflowExecution ) -> None :
161+ """Append a workflow execution to the list"""
162+ self .statuses .append (status )
163+
164+ def report (self ) -> None :
165+ """Report workflow statuses to logger and/or execution report from context"""
166+ line = f"finished ({ self .running } running, { self .queued } queued)"
167+ self .context .report .update (
168+ ExecutionReport (
169+ entity_count = self .finished ,
170+ operation = "start" ,
171+ operation_desc = line ,
172+ )
173+ )
174+ self .logger .info (f"{ self .finished } { line } " )
175+
176+ @property
177+ def running (self ) -> int :
178+ """Returns the number of running workflows"""
179+ return len ([_ for _ in self .statuses if _ .is_running ])
180+
181+ @property
182+ def finished (self ) -> int :
183+ """Returns the number of finished workflows"""
184+ return len ([_ for _ in self .statuses if _ .is_finished ])
185+
186+ @property
187+ def queued (self ) -> int :
188+ """Returns the number of queued workflows"""
189+ return len ([_ for _ in self .statuses if _ .is_queued ])
190+
191+
40192@Plugin (
41193 label = "Start Workflow per Entity" ,
42194 description = "Loop over the output of a task and start a sub-workflow for each entity." ,
50202 param_type = SuitableWorkflowParameterType (),
51203 description = "Which workflow do you want to start per entity." ,
52204 ),
205+ PluginParameter (
206+ name = "parallel_execution" ,
207+ label = "How many workflow jobs should run in parallel?" ,
208+ param_type = IntParameterType (),
209+ default_value = 1 ,
210+ ),
211+ PluginParameter (
212+ name = "forward_entities" ,
213+ label = "Forward incoming entities to the output port?" ,
214+ param_type = BoolParameterType (),
215+ default_value = False ,
216+ ),
53217 ],
54218)
55219class StartWorkflow (WorkflowPlugin ):
56220 """Start Workflow per Entity"""
57221
58222 context : ExecutionContext
59- schema : EntitySchema
223+ executions : WorkflowExecutionList
60224
61- def __init__ (self , workflow : str ) -> None :
225+ def __init__ (
226+ self , workflow : str , parallel_execution : int = 1 , forward_entities : bool = False
227+ ) -> None :
62228 self .workflow = workflow
229+ if parallel_execution < 1 :
230+ raise ValueError ("parallel_execution must be >= 1" )
231+ self .parallel_execution = parallel_execution
232+ self .forward_entities = forward_entities
63233 self .input_ports = FixedNumberOfInputs ([FlexibleSchemaPort ()])
64- self .output_port = None
234+ self .output_port = FlexibleSchemaPort () if forward_entities else None
65235 self .workflows_started = 0
236+ self .executions = WorkflowExecutionList ()
237+
238+ def start_workflows (self , inputs : Sequence [Entities ]) -> Entities :
239+ """Start the workflows and return output entities"""
240+ input_entities = inputs [0 ].entities
241+ schema = inputs [0 ].schema
242+ self .executions .context = self .context
243+ self .executions .logger = self .log
244+ self .executions .report ()
245+ for entity in input_entities :
246+ new_execution = WorkflowExecution (
247+ task_id = self .workflow ,
248+ project_id = self .context .task .project_id (),
249+ entity = entity ,
250+ schema = schema ,
251+ execution_context = self .context ,
252+ logger = self .log ,
253+ )
254+ self .log .info (f"Got new entity: { new_execution .entity_as_json_str ()} " )
255+ self .executions .append (new_execution )
256+ self .executions .report ()
257+ self .executions .execute (parallel_execution = self .parallel_execution )
258+ # remove execution via /workflow/workflows/{project}/{task}/execution/{executionId}
259+
260+ return Entities (
261+ schema = schema ,
262+ entities = iter ([_ .entity for _ in self .executions .statuses ]),
263+ )
66264
67265 def execute (
68266 self ,
69267 inputs : Sequence [Entities ],
70268 context : ExecutionContext ,
71- ) -> None :
269+ ) -> Entities | None :
72270 """Run the workflow operator."""
73271 self .log .info ("Start execute" )
74272 self .context = context
75273 self .validate_inputs (inputs = inputs )
76- self .schema = inputs [0 ].schema
77274 self .validate_workflow (workflow = self .workflow )
78-
79- for entity in inputs [0 ].entities :
80- self .start_workflow (entity = entity )
81-
82- self .log .info ("Stop execute" )
275+ output_entities = self .start_workflows (inputs = inputs )
276+ if self .forward_entities :
277+ self .log .info ("All done ... forward entities" )
278+ return output_entities
279+ self .log .info ("All done ..." )
280+ return None
83281
84282 @staticmethod
85283 def validate_inputs (inputs : Sequence [Entities ]) -> None :
@@ -106,35 +304,6 @@ def validate_workflow(self, workflow: str) -> None:
106304 )
107305 self .log .info (str (suitable_workflows ))
108306
109- def start_workflow (self , entity : Entity ) -> None :
110- """Start a single workflow."""
111- entity_as_dict : dict = self .entity_to_dict (entity = entity , schema = self .schema )
112- entity_as_json : str = json .dumps (entity_as_dict )
113- self .log .info (f"Processing new entity: { entity_as_json } " )
114- # start workflow here
115- with NamedTemporaryFile (mode = "w+" ) as temp_file :
116- self .log .info (f"temp file for entity: { temp_file .name } " )
117- temp_file .write (entity_as_json )
118- temp_file .flush ()
119- self .log .info (f"temp file content: { Path (temp_file .name ).read_text ()} " )
120- setup_cmempy_user_access (context = self .context .user )
121- execute_workflow_io (
122- project_name = self .context .task .project_id (),
123- task_name = self .workflow ,
124- input_file = temp_file .name ,
125- input_mime_type = "application/x-plugin-json" ,
126- output_mime_type = "guess" ,
127- auto_config = False ,
128- )
129- self .workflows_started += 1
130- self .context .report .update (
131- ExecutionReport (
132- entity_count = self .workflows_started ,
133- operation = "start" ,
134- operation_desc = "workflows started" ,
135- )
136- )
137-
138307 @staticmethod
139308 def entity_to_dict (entity : Entity , schema : EntitySchema ) -> dict :
140309 """Convert an entity to a dictionary, using the schema"""
0 commit comments