Skip to content
51 changes: 5 additions & 46 deletions src/clusterfuzz/_internal/swarming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
import json
import uuid

from google.auth.transport import requests
from google.protobuf import json_format

from clusterfuzz._internal.base import utils
from clusterfuzz._internal.base.errors import BadConfigError
from clusterfuzz._internal.base.feature_flags import FeatureFlags
Expand All @@ -31,11 +28,6 @@
from clusterfuzz._internal.protos import swarming_pb2
from clusterfuzz._internal.system import environment

_SWARMING_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email'
]


def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool:
"""Returns True if the task is supposed to run on swarming."""
Expand All @@ -54,7 +46,7 @@ def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool:
logs.info('[Swarming DEBUG] No swarming env var', job_name=job_name)
return False

swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if swarming_config is None:
logs.warning(
"""[Swarming DEBUG] current task is not suitable for swarming.
Expand All @@ -74,7 +66,7 @@ def _get_task_name(job_name: str):
return f't-{str(uuid.uuid4()).lower()}-{job_name}'


def _get_swarming_config() -> local_config.SwarmingConfig | None:
def get_swarming_config() -> local_config.SwarmingConfig | None:
"""Returns the swarming config."""
try:
return local_config.SwarmingConfig()
Expand All @@ -87,15 +79,15 @@ def _get_task_dimensions(job: data_types.Job, platform_specific_dimensions: list
) -> list[swarming_pb2.StringPair]: # pylint: disable=no-member
""" Gets all swarming dimensions for a task.
Job dimensions have more precedence than static dimensions"""
swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if not swarming_config:
logs.error(
'[Swarming] No dimensions set. Reason: failed to retrieve config')
return []

unique_dimensions = {}
unique_dimensions['os'] = str(job.platform).capitalize()
unique_dimensions['pool'] = _get_swarming_config().get('swarming_pool')
unique_dimensions['pool'] = get_swarming_config().get('swarming_pool')

for dimension in platform_specific_dimensions:
unique_dimensions[dimension['key'].lower()] = dimension['value']
Expand Down Expand Up @@ -202,7 +194,7 @@ def create_new_task_request(command: str, job_name: str, download_url: str
if job is None:
return None

swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if not swarming_config:
return None

Expand Down Expand Up @@ -255,36 +247,3 @@ def create_new_task_request(command: str, job_name: str, download_url: str
])

return new_task_request


def push_swarming_task(task_request: swarming_pb2.NewTaskRequest): # pylint: disable=no-member
"""Schedules a task on swarming."""
swarming_config = _get_swarming_config()
if not swarming_config:
logs.error(
'[Swarming] Failed to push task into swarming. Reason: No config.')
return
creds = credentials.get_scoped_service_account_credentials(_SWARMING_SCOPES)
if not creds:
logs.error(
'[Swarming] Failed to push task into swarming. Reason: No credentials.')
return

if not creds.token:
creds.refresh(requests.Request())

headers = {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': f'Bearer {creds.token}'
}
swarming_server = _get_swarming_config().get('swarming_server')
url = f'https://{swarming_server}/prpc/swarming.v2.Tasks/NewTask'
message_body = json_format.MessageToJson(task_request)
logs.info(
f"""[Swarming] Pushing task {task_request.name}
as {creds.service_account_email}""",
url=url,
body=message_body)
response = utils.post_url(url=url, data=message_body, headers=headers)
logs.info(f'[Swarming] Response from {task_request.name}', response=response)
135 changes: 135 additions & 0 deletions src/clusterfuzz/_internal/swarming/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swarming pRPC API client."""

from google.auth.transport import requests
from google.protobuf import json_format

from clusterfuzz._internal.base import utils
from clusterfuzz._internal.config.local_config import SwarmingConfig
from clusterfuzz._internal.google_cloud_utils import credentials
from clusterfuzz._internal.metrics import logs
from clusterfuzz._internal.protos import swarming_pb2
from clusterfuzz._internal.swarming import get_swarming_config

_SWARMING_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email'
]

_COUNT_TASKS_ENDPOINT = 'swarming.v2.Tasks/CountTasks'
_NEW_TASK_ENDPOINT = 'swarming.v2.Tasks/NewTask'


class SwarmingAPI:
"""Client for Swarming pRPC API."""

_config: SwarmingConfig = None
_base_url: str = ""

def __init__(self):
self._config = get_swarming_config()
if self._config:
self._base_url = f"https://{self._config.get('swarming_server')}/prpc/"

def _get_headers(self) -> dict[str, str]:
"""Checks config and returns headers for pRPC request.

Returns:
A dict containing headers, or empty dict if config is missing or
auth fails.
"""
if not self._config:
logs.error('[Swarming] No config available.')
return {}

creds = credentials.get_scoped_service_account_credentials(_SWARMING_SCOPES)
if not creds:
logs.error('[Swarming] Failed to get credentials.')
return {}

if not creds.token:
creds.refresh(requests.Request())

return {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': f'Bearer {creds.token}'
}

def _make_request(self, endpoint: str, body: str) -> str | None:
"""Makes a pRPC request to the Swarming API.

Args:
endpoint: The pRPC endpoint (e.g. "swarming.v2.Tasks/NewTask").
body: The JSON body of the request.

Returns:
The raw JSON response string from the server, or None if the request
could not be made (e.g. missing config, auth failure) or failed.
"""
headers = self._get_headers()
if not headers:
return None

url = f'{self._base_url}{endpoint}'
response = utils.post_url(url=url, data=body, headers=headers)
if not response:
logs.error(f"[Swarming] Failed to make request to {url}")
return None
return response

def push_task(self, task_request: swarming_pb2.NewTaskRequest) -> str | None: # pylint: disable=no-member
"""Schedules a task on swarming.

Args:
task_request: The NewTaskRequest proto message.

Returns:
The raw JSON response string from the server, or None if the request
could not be made (e.g. missing config, auth failure) or failed.
"""
message_body = json_format.MessageToJson(task_request)
logs.info(
f"[Swarming] Pushing task {task_request.name}",
url=self._base_url,
endpoint=_NEW_TASK_ENDPOINT,
body=message_body)

response = self._make_request(_NEW_TASK_ENDPOINT, message_body)
logs.info(
f'[Swarming] Response from {task_request.name}', response=response)
return response

def count_tasks(self,
count_request: swarming_pb2.TasksCountRequest) -> str | None: # pylint: disable=no-member
"""Counts tasks on swarming.

Args:
count_request: The TasksCountRequest proto message.

Returns:
The raw JSON response string from the server, or None if the request
could not be made (e.g. missing config, auth failure) or failed.
"""
message_body = json_format.MessageToJson(count_request)
logs.info(
"[Swarming] Counting tasks in queue",
url=self._base_url,
endpoint=_COUNT_TASKS_ENDPOINT,
body=message_body)

response = self._make_request(_COUNT_TASKS_ENDPOINT, message_body)
logs.info('[Swarming] Response from CountTasks', response=response)
return response
74 changes: 70 additions & 4 deletions src/clusterfuzz/_internal/swarming/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,59 @@
# limitations under the License.
"""Swarming service."""

import json

from clusterfuzz._internal import swarming
from clusterfuzz._internal.base.tasks import task_utils
from clusterfuzz._internal.metrics import logs
from clusterfuzz._internal.protos import swarming_pb2
from clusterfuzz._internal.remote_task import remote_task_types
from clusterfuzz._internal.swarming.api import SwarmingAPI


class SwarmingService(remote_task_types.RemoteTaskInterface):
"""Remote task service implementation for Swarming."""

_api: SwarmingAPI = None
MAX_PENDING_TASKS = 50

def _get_api(self) -> SwarmingAPI:
"""Returns the Swarming API instance."""
if not self._api:
self._api = SwarmingAPI()
return self._api

def _get_os_dimension(self,
request: swarming_pb2.NewTaskRequest) -> str | None: # pylint: disable=no-member
"""Extracts the OS dimension from the task request."""
for dimension in request.task_slices[0].properties.dimensions:
if dimension.key == 'os':
return dimension.value
return None

def _is_backpressure_applied(
self, count_request: swarming_pb2.TasksCountRequest) -> bool: # pylint: disable=no-member
"""Checks if backpressure should be applied based on pending tasks count.

Returns True if backpressure is applied or if the check fails (Fail Closed).
"""
try:
response_str = self._get_api().count_tasks(count_request)
if not response_str:
raise RuntimeError("Empty response from CountTasks")

response_json = json.loads(response_str)
count = int(response_json.get('count', 0))

if count >= self.MAX_PENDING_TASKS:
logs.info(f'[Swarming] Backpressure applied. Queue size: {count}. '
'Stopping scheduling.')
return True
return False
except Exception as e:
logs.error(f'[Swarming] Failed to check backpressure (Fail Closed): {e}')
return True # Fail closed, always fails if swarming request fails

def create_utask_main_job(self, module: str, job_type: str,
input_download_url: str):
"""Creates a single swarming task for a uworker main task."""
Expand All @@ -44,14 +88,36 @@ def create_utask_main_jobs(self,
"""
unscheduled_tasks = []
logs.info(f'[Swarming] Pushing {len(remote_tasks)} tasks trough service.')
for task in remote_tasks:
for i, task in enumerate(remote_tasks):
try:
if not swarming.is_swarming_task(task.job_type):
unscheduled_tasks.append(task)
continue
if request := swarming.create_new_task_request(
task.command, task.job_type, task.argument):
swarming.push_swarming_task(request)

request = swarming.create_new_task_request(task.command, task.job_type,
task.argument)
if not request:
unscheduled_tasks.append(task)
continue

os_val = self._get_os_dimension(request)
if not os_val:
logs.error(
f'[Swarming] Failed to find OS dimension for job {task.job_type}.'
)
unscheduled_tasks.append(task)
continue

# Check backpressure
count_request = swarming_pb2.TasksCountRequest( # pylint: disable=no-member
tags=['pool:chrome-sec-clusterfuzz', f'os:{os_val}'],
state=swarming_pb2.QUERY_PENDING) # pylint: disable=no-member

if self._is_backpressure_applied(count_request):
unscheduled_tasks.extend(remote_tasks[i:])
break

self._get_api().push_task(request)
except Exception: # pylint: disable=broad-except
logs.error(
f'Failed to push task to Swarming: {task.command}, {task.job_type}.'
Expand Down
Loading
Loading