1- import asyncio
21import importlib .metadata
32import os
4- from email . utils import parsedate_to_datetime
3+ from collections . abc import Awaitable , Callable
54from typing import (
65 Any ,
76 AsyncIterator ,
2524 RunTaskStreamChunk ,
2625 TaskRunResponse ,
2726)
27+ from workflowai .core .client .utils import build_retryable_wait
2828from workflowai .core .domain .cache_usage import CacheUsage
2929from workflowai .core .domain .errors import BaseError , WorkflowAIError
3030from workflowai .core .domain .task import Task , TaskInput , TaskOutput
@@ -77,9 +77,8 @@ async def run(
7777 use_cache : CacheUsage = "when_available" ,
7878 labels : Optional [set [str ]] = None ,
7979 metadata : Optional [dict [str , Any ]] = None ,
80- retry_delay : int = 5000 ,
81- max_retry_delay : int = 60000 ,
82- max_retry_count : int = 1 ,
80+ max_retry_delay : float = 60 ,
81+ max_retry_count : float = 1 ,
8382 ) -> TaskRun [TaskInput , TaskOutput ]: ...
8483
8584 @overload
@@ -94,12 +93,11 @@ async def run(
9493 use_cache : CacheUsage = "when_available" ,
9594 labels : Optional [set [str ]] = None ,
9695 metadata : Optional [dict [str , Any ]] = None ,
97- retry_delay : int = 5000 ,
98- max_retry_delay : int = 60000 ,
99- max_retry_count : int = 1 ,
96+ max_retry_delay : float = 60 ,
97+ max_retry_count : float = 1 ,
10098 ) -> AsyncIterator [TaskOutput ]: ...
10199
102- async def run ( # noqa: C901
100+ async def run (
103101 self ,
104102 task : Task [TaskInput , TaskOutput ],
105103 task_input : TaskInput ,
@@ -110,9 +108,8 @@ async def run( # noqa: C901
110108 use_cache : CacheUsage = "when_available" ,
111109 labels : Optional [set [str ]] = None ,
112110 metadata : Optional [dict [str , Any ]] = None ,
113- retry_delay : int = 5000 ,
114- max_retry_delay : int = 60000 ,
115- max_retry_count : int = 1 ,
111+ max_retry_delay : float = 60 ,
112+ max_retry_count : float = 1 ,
116113 ) -> Union [TaskRun [TaskInput , TaskOutput ], AsyncIterator [TaskOutput ]]:
117114 await self ._auto_register (task )
118115
@@ -135,76 +132,62 @@ async def run( # noqa: C901
135132 )
136133
137134 route = f"/tasks/{ task .id } /schemas/{ task .schema_id } /run"
135+ should_retry , wait_for_exception = build_retryable_wait (max_retry_delay , max_retry_count )
138136
139137 if not stream :
140- res = None
141- delay = retry_delay / 1000
142- retry_count = 0
143- while retry_count < max_retry_count :
144- try :
145- res = await self .api .post (route , request , returns = TaskRunResponse )
146- return res .to_domain (task )
147- except HTTPStatusError as e :
148- if e .response .status_code == 404 :
149- raise WorkflowAIError (
150- error = BaseError (
151- status_code = 404 ,
152- code = "not_found" ,
153- message = "Task not found" ,
154- ),
155- ) from e
156- retry_after = e .response .headers .get ("Retry-After" )
157- if retry_after :
158- try :
159- # for 429 errors this is non-negative decimal
160- delay = float (retry_after )
161- except ValueError :
162- try :
163- retry_after_date = parsedate_to_datetime (retry_after )
164- current_time = asyncio .get_event_loop ().time ()
165- delay = retry_after_date .timestamp () - current_time
166- except (TypeError , ValueError , OverflowError ):
167- delay = min (delay * 2 , max_retry_delay / 1000 )
168- await asyncio .sleep (delay )
169- elif e .response .status_code == 429 :
170- if delay < max_retry_delay / 1000 :
171- delay = min (delay * 2 , max_retry_delay / 1000 )
172- await asyncio .sleep (delay )
173- retry_count += 1
174-
175- async def _stream ():
176- delay = retry_delay / 1000
177- retry_count = 0
178- while retry_count < max_retry_count :
179- try :
180- async for chunk in self .api .stream (
181- method = "POST" ,
182- path = route ,
183- data = request ,
184- returns = RunTaskStreamChunk ,
185- ):
186- yield task .output_class .model_construct (None , ** chunk .task_output )
187- except HTTPStatusError as e :
188- if e .response .status_code == 404 :
189- raise WorkflowAIError (error = BaseError (message = "Task not found" )) from e
190- retry_after = e .response .headers .get ("Retry-After" )
191-
192- if retry_after :
193- try :
194- delay = float (retry_after )
195- except ValueError :
196- try :
197- retry_after_date = parsedate_to_datetime (retry_after )
198- current_time = asyncio .get_event_loop ().time ()
199- delay = retry_after_date .timestamp () - current_time
200- except (TypeError , ValueError , OverflowError ):
201- delay = min (delay * 2 , max_retry_delay / 1000 )
202- elif e .response .status_code == 429 and delay < max_retry_delay / 1000 :
203- delay = min (delay * 2 , max_retry_delay / 1000 )
204- await asyncio .sleep (delay )
205- retry_count += 1
206-
207- return _stream ()
138+ return await self ._retriable_run (
139+ route ,
140+ request ,
141+ task ,
142+ should_retry = should_retry ,
143+ wait_for_exception = wait_for_exception ,
144+ )
145+
146+ return self ._retriable_stream (
147+ route ,
148+ request ,
149+ task ,
150+ should_retry = should_retry ,
151+ wait_for_exception = wait_for_exception ,
152+ )
153+
154+ async def _retriable_run (
155+ self ,
156+ route : str ,
157+ request : RunRequest ,
158+ task : Task [TaskInput , TaskOutput ],
159+ should_retry : Callable [[], bool ],
160+ wait_for_exception : Callable [[HTTPStatusError ], Awaitable [None ]],
161+ ):
162+ while should_retry ():
163+ try :
164+ res = await self .api .post (route , request , returns = TaskRunResponse )
165+ return res .to_domain (task )
166+ except HTTPStatusError as e : # noqa: PERF203
167+ await wait_for_exception (e )
168+
169+ raise WorkflowAIError (error = BaseError (message = "max retries reached" ))
170+
171+ async def _retriable_stream (
172+ self ,
173+ route : str ,
174+ request : RunRequest ,
175+ task : Task [TaskInput , TaskOutput ],
176+ should_retry : Callable [[], bool ],
177+ wait_for_exception : Callable [[HTTPStatusError ], Awaitable [None ]],
178+ ):
179+ while should_retry ():
180+ try :
181+ async for chunk in self .api .stream (
182+ method = "POST" ,
183+ path = route ,
184+ data = request ,
185+ returns = RunTaskStreamChunk ,
186+ ):
187+ yield task .output_class .model_construct (None , ** chunk .task_output )
188+ return
189+ except HTTPStatusError as e : # noqa: PERF203
190+ await wait_for_exception (e )
208191
209192 async def import_run (
210193 self ,
0 commit comments