diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index d837ed902..363be3f7b 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -29,6 +29,8 @@ LOGGER = logging.getLogger("airbyte") +_MAX_DOWNLOAD_RETRIES = 1 + @dataclass class AsyncHttpJobRepository(AsyncJobRepository): @@ -188,7 +190,9 @@ def update_jobs_status(self, jobs: Iterable[AsyncJob]) -> None: lazy_log( LOGGER, logging.DEBUG, - lambda: f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}", + lambda: ( + f"Status of job {job.api_job_id()} changed from {job.status()} to {job_status}" + ), ) else: lazy_log( @@ -205,6 +209,10 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: """ Fetches records from the given job. + If a download fails due to an expired download URL (e.g., HTTP 403 from an expired + SAS token), this method re-polls the polling endpoint to obtain fresh download URLs + and retries the download. + Args: job (AsyncJob): The job to fetch records from. @@ -212,7 +220,24 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: Iterable[Mapping[str, Any]]: A generator that yields records as dictionaries. """ + for attempt in range(_MAX_DOWNLOAD_RETRIES + 1): + try: + yield from self._download_records_for_job(job) + return + except AirbyteTracedException as error: + is_last_attempt = attempt >= _MAX_DOWNLOAD_RETRIES + if is_last_attempt or not self._is_retriable_download_error(error): + raise + LOGGER.info( + f"Download failed (attempt {attempt + 1}/{_MAX_DOWNLOAD_RETRIES + 1}), " + f"re-polling for fresh download URLs. Error: {error.internal_message}" + ) + self._refresh_polling_response(job) + + yield from [] + def _download_records_for_job(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: + """Downloads and yields records from all download targets for a job.""" for download_target in self._get_download_targets(job): job_slice = job.job_parameters() stream_slice = StreamSlice( @@ -238,6 +263,20 @@ def fetch_records(self, job: AsyncJob) -> Iterable[Mapping[str, Any]]: yield from [] + def _is_retriable_download_error(self, error: AirbyteTracedException) -> bool: + """Check if a download error is likely caused by an expired download URL. + + HTTP 403 during the download phase typically indicates an expired pre-signed URL + or SAS token, not a genuine permissions issue. Re-polling for a fresh URL may resolve it. + """ + return error.internal_message is not None and "status code '403'" in error.internal_message + + def _refresh_polling_response(self, job: AsyncJob) -> None: + """Re-poll the polling endpoint to obtain a fresh response with updated download URLs.""" + stream_slice = self._get_create_job_stream_slice(job) + polling_response = self._get_validated_polling_response(stream_slice) + self._polling_job_response_by_id[job.api_job_id()] = polling_response + def abort(self, job: AsyncJob) -> None: if not self.abort_requester: return @@ -340,7 +379,9 @@ def _get_download_targets(self, job: AsyncJob) -> Iterable[str]: lazy_log( LOGGER, logging.DEBUG, - lambda: "No download_target_extractor or download_target_requester provided. Will attempt a single download request without a `download_target`.", + lambda: ( + "No download_target_extractor or download_target_requester provided. Will attempt a single download request without a `download_target`." + ), ) yield "" return