Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/clusterfuzz/_internal/protos/swarming_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2026 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self):
self._service_map = {
adapter.id: adapter.service()
for adapter in remote_task_adapters.RemoteTaskAdapters
if adapter.feature_flag.enabled
}
self._adapters = remote_task_adapters.RemoteTaskAdapters

Expand Down
51 changes: 5 additions & 46 deletions src/clusterfuzz/_internal/swarming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
import json
import uuid

from google.auth.transport import requests
from google.protobuf import json_format

from clusterfuzz._internal.base import utils
from clusterfuzz._internal.base.errors import BadConfigError
from clusterfuzz._internal.base.feature_flags import FeatureFlags
Expand All @@ -31,11 +28,6 @@
from clusterfuzz._internal.protos import swarming_pb2
from clusterfuzz._internal.system import environment

_SWARMING_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email'
]


def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool:
"""Returns True if the task is supposed to run on swarming."""
Expand All @@ -54,7 +46,7 @@ def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool:
logs.info('[Swarming DEBUG] No swarming env var', job_name=job_name)
return False

swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if swarming_config is None:
logs.warning(
"""[Swarming DEBUG] current task is not suitable for swarming.
Expand All @@ -74,7 +66,7 @@ def _get_task_name(job_name: str):
return f't-{str(uuid.uuid4()).lower()}-{job_name}'


def _get_swarming_config() -> local_config.SwarmingConfig | None:
def get_swarming_config() -> local_config.SwarmingConfig | None:
"""Returns the swarming config."""
try:
return local_config.SwarmingConfig()
Expand All @@ -87,15 +79,15 @@ def _get_task_dimensions(job: data_types.Job, platform_specific_dimensions: list
) -> list[swarming_pb2.StringPair]: # pylint: disable=no-member
""" Gets all swarming dimensions for a task.
Job dimensions have more precedence than static dimensions"""
swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if not swarming_config:
logs.error(
'[Swarming] No dimensions set. Reason: failed to retrieve config')
return []

unique_dimensions = {}
unique_dimensions['os'] = str(job.platform).capitalize()
unique_dimensions['pool'] = _get_swarming_config().get('swarming_pool')
unique_dimensions['pool'] = swarming_config.get('swarming_pool')

for dimension in platform_specific_dimensions:
unique_dimensions[dimension['key'].lower()] = dimension['value']
Expand Down Expand Up @@ -202,7 +194,7 @@ def create_new_task_request(command: str, job_name: str, download_url: str
if job is None:
return None

swarming_config = _get_swarming_config()
swarming_config = get_swarming_config()
if not swarming_config:
return None

Expand Down Expand Up @@ -255,36 +247,3 @@ def create_new_task_request(command: str, job_name: str, download_url: str
])

return new_task_request


def push_swarming_task(task_request: swarming_pb2.NewTaskRequest): # pylint: disable=no-member
"""Schedules a task on swarming."""
swarming_config = _get_swarming_config()
if not swarming_config:
logs.error(
'[Swarming] Failed to push task into swarming. Reason: No config.')
return
creds = credentials.get_scoped_service_account_credentials(_SWARMING_SCOPES)
if not creds:
logs.error(
'[Swarming] Failed to push task into swarming. Reason: No credentials.')
return

if not creds.token:
creds.refresh(requests.Request())

headers = {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': f'Bearer {creds.token}'
}
swarming_server = _get_swarming_config().get('swarming_server')
url = f'https://{swarming_server}/prpc/swarming.v2.Tasks/NewTask'
message_body = json_format.MessageToJson(task_request)
logs.info(
f"""[Swarming] Pushing task {task_request.name}
as {creds.service_account_email}""",
url=url,
body=message_body)
response = utils.post_url(url=url, data=message_body, headers=headers)
logs.info(f'[Swarming] Response from {task_request.name}', response=response)
163 changes: 163 additions & 0 deletions src/clusterfuzz/_internal/swarming/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swarming pRPC API client."""

from typing import Optional

from google.auth import exceptions as auth_exceptions
from google.auth.transport import requests
from google.protobuf import json_format

from clusterfuzz._internal.base import utils
from clusterfuzz._internal.config.local_config import SwarmingConfig
from clusterfuzz._internal.google_cloud_utils import credentials
from clusterfuzz._internal.metrics import logs
from clusterfuzz._internal.protos import swarming_pb2
from clusterfuzz._internal.swarming import get_swarming_config

_SWARMING_SCOPES = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email'
]

_COUNT_TASKS_ENDPOINT = 'swarming.v2.Tasks/CountTasks'
_NEW_TASK_ENDPOINT = 'swarming.v2.Tasks/NewTask'


class SwarmingApi:
"""Client for Swarming pRPC API."""

_config: SwarmingConfig
_base_url: str = ""

def __init__(self, config: SwarmingConfig):
self._config = config
self._base_url = f"https://{self._config.get('swarming_server')}/prpc/"

@staticmethod
def create() -> Optional['SwarmingApi']:
"""Creates a SwarmingApi instance if config is available.

Returns:
A SwarmingApi instance if config is available, None otherwise.
"""
config = get_swarming_config()
if config is None:
return None

return SwarmingApi(config)

def _get_token(self) -> str:
"""Gets a valid token for the Swarming API. Returns "" if it fails."""
try:
creds = credentials.get_scoped_service_account_credentials(
_SWARMING_SCOPES)
if not creds:
logs.error('[Swarming] Failed to get credentials. None found.')
return ""

if not creds.token:
creds.refresh(requests.Request())

return creds.token or ""
except (auth_exceptions.DefaultCredentialsError,
auth_exceptions.RefreshError, auth_exceptions.TransportError) as e:
logs.error(f'[Swarming] Failed to get token with: {e}.')
return ""

def _get_headers(self) -> dict[str, str]:
"""Checks config and returns headers for pRPC request.

Returns:
A dict containing headers.
"""
token = self._get_token()

return {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': f'Bearer {token}'
}

def _make_request(self, endpoint: str, body: str) -> str | None:
"""Makes a pRPC request to the Swarming API.

Args:
endpoint: The pRPC endpoint (e.g. "swarming.v2.Tasks/NewTask").
body: The JSON body of the request.

Returns:
The raw JSON response string from the server, or None if the response is
empty.

Raises:
requests.exceptions.HTTPError: If the request fails with a 4xx or 5xx
status code.
"""
headers = self._get_headers()

url = f'{self._base_url}{endpoint}'
logs.info(
f"[Swarming] Making request to {url}",
url=self._base_url,
endpoint=endpoint,
body=body,
headers=headers)
response = utils.post_url(url=url, data=body, headers=headers)
if not response:
logs.error(f"[Swarming] Failed to make request to {url}. Empty response")
return None
return response

def push_task(self, task_request: swarming_pb2.NewTaskRequest) -> str | None: # pylint: disable=no-member
"""Schedules a task on swarming.

Args:
task_request: The NewTaskRequest proto message.

Returns:
The raw JSON response string from the server, or None if the response is
empty.

Raises:
requests.exceptions.HTTPError: If the request fails with a 4xx or 5xx
status code.
"""
message_body = json_format.MessageToJson(task_request)

response = self._make_request(_NEW_TASK_ENDPOINT, message_body)
logs.info(
f'[Swarming] Response from {task_request.name}', response=response)
return response

def count_tasks(self,
count_request: swarming_pb2.TasksCountRequest) -> str | None: # pylint: disable=no-member
"""Counts tasks on swarming.

Args:
count_request: The TasksCountRequest proto message.

Returns:
The raw JSON response string from the server, or None if the response is
empty.

Raises:
requests.exceptions.HTTPError: If the request fails with a 4xx or 5xx
status code.
"""
message_body = json_format.MessageToJson(count_request)

response = self._make_request(_COUNT_TASKS_ENDPOINT, message_body)
logs.info('[Swarming] Response from CountTasks', response=response)
return response
26 changes: 22 additions & 4 deletions src/clusterfuzz/_internal/swarming/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swarming service."""
from requests.exceptions import HTTPError

from clusterfuzz._internal import swarming
from clusterfuzz._internal.base.tasks import task_utils
from clusterfuzz._internal.metrics import logs
from clusterfuzz._internal.remote_task import remote_task_types
from clusterfuzz._internal.swarming.api import SwarmingApi


class SwarmingService(remote_task_types.RemoteTaskInterface):
"""Remote task service implementation for Swarming."""

_api: SwarmingApi

def __init__(self):
api = SwarmingApi.create()
if api is None:
raise ValueError(
'Failed to instantiate SwarmingApi. Swarming config not available.')
self._api = api

def create_utask_main_job(self, module: str, job_type: str,
input_download_url: str):
"""Creates a single swarming task for a uworker main task."""
Expand Down Expand Up @@ -51,10 +62,17 @@ def create_utask_main_jobs(self,
continue
if request := swarming.create_new_task_request(
task.command, task.job_type, task.argument):
swarming.push_swarming_task(request)
except Exception: # pylint: disable=broad-except
self._api.push_task(request)
except HTTPError as api_failure:
logs.error(
f'''Failed to push task to Swarming: {task.command}, {task.job_type}
. Reason: {api_failure}.
''')
unscheduled_tasks.append(task)
except Exception as e: # pylint: disable=broad-except
logs.error(
f'Failed to push task to Swarming: {task.command}, {task.job_type}.'
)
f'''Failed to push task to Swarming: {task.command}, {task.job_type}
. Unexpected exception: {e}.
''')
unscheduled_tasks.append(task)
return unscheduled_tasks
Comment thread
IvanBM18 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ def setUp(self):
mock.Mock(
id='kubernetes',
service=mock.Mock(return_value=self.mock_k8s_service),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=0.0),
'GCP_BATCH':
mock.Mock(
id='gcp_batch',
service=mock.Mock(return_value=self.mock_gcp_batch_service),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=1.0),
'SWARMING':
mock.Mock(
id='swarming',
service=mock.Mock(return_value=self.mock_swarming_service),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=0.0),
})
self.patcher.start()
Expand Down Expand Up @@ -488,19 +488,19 @@ def setUp(self):
mock.Mock(
id='kubernetes',
service=mock.Mock(),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=0.0),
'GCP_BATCH':
mock.Mock(
id='gcp_batch',
service=mock.Mock(),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=1.0),
'SWARMING':
mock.Mock(
id='swarming',
service=mock.Mock(return_value=self.mock_swarming_service),
feature_flag=None,
feature_flag=mock.Mock(enabled=True),
default_weight=0.0),
})
self.patcher.start()
Expand Down
Loading
Loading