diff --git a/src/clusterfuzz/_internal/swarming/__init__.py b/src/clusterfuzz/_internal/swarming/__init__.py index 0369abcb62..165c3e50e0 100644 --- a/src/clusterfuzz/_internal/swarming/__init__.py +++ b/src/clusterfuzz/_internal/swarming/__init__.py @@ -17,9 +17,6 @@ import json import uuid -from google.auth.transport import requests -from google.protobuf import json_format - from clusterfuzz._internal.base import utils from clusterfuzz._internal.base.errors import BadConfigError from clusterfuzz._internal.base.feature_flags import FeatureFlags @@ -31,11 +28,6 @@ from clusterfuzz._internal.protos import swarming_pb2 from clusterfuzz._internal.system import environment -_SWARMING_SCOPES = [ - 'https://www.googleapis.com/auth/cloud-platform', - 'https://www.googleapis.com/auth/userinfo.email' -] - def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool: """Returns True if the task is supposed to run on swarming.""" @@ -54,7 +46,7 @@ def is_swarming_task(job_name: str, job: data_types.Job | None = None) -> bool: logs.info('[Swarming DEBUG] No swarming env var', job_name=job_name) return False - swarming_config = _get_swarming_config() + swarming_config = get_swarming_config() if swarming_config is None: logs.warning( """[Swarming DEBUG] current task is not suitable for swarming. @@ -74,7 +66,7 @@ def _get_task_name(job_name: str): return f't-{str(uuid.uuid4()).lower()}-{job_name}' -def _get_swarming_config() -> local_config.SwarmingConfig | None: +def get_swarming_config() -> local_config.SwarmingConfig | None: """Returns the swarming config.""" try: return local_config.SwarmingConfig() @@ -87,7 +79,7 @@ def _get_task_dimensions(job: data_types.Job, platform_specific_dimensions: list ) -> list[swarming_pb2.StringPair]: # pylint: disable=no-member """ Gets all swarming dimensions for a task. Job dimensions have more precedence than static dimensions""" - swarming_config = _get_swarming_config() + swarming_config = get_swarming_config() if not swarming_config: logs.error( '[Swarming] No dimensions set. Reason: failed to retrieve config') @@ -95,7 +87,7 @@ def _get_task_dimensions(job: data_types.Job, platform_specific_dimensions: list unique_dimensions = {} unique_dimensions['os'] = str(job.platform).capitalize() - unique_dimensions['pool'] = _get_swarming_config().get('swarming_pool') + unique_dimensions['pool'] = get_swarming_config().get('swarming_pool') for dimension in platform_specific_dimensions: unique_dimensions[dimension['key'].lower()] = dimension['value'] @@ -202,7 +194,7 @@ def create_new_task_request(command: str, job_name: str, download_url: str if job is None: return None - swarming_config = _get_swarming_config() + swarming_config = get_swarming_config() if not swarming_config: return None @@ -255,36 +247,3 @@ def create_new_task_request(command: str, job_name: str, download_url: str ]) return new_task_request - - -def push_swarming_task(task_request: swarming_pb2.NewTaskRequest): # pylint: disable=no-member - """Schedules a task on swarming.""" - swarming_config = _get_swarming_config() - if not swarming_config: - logs.error( - '[Swarming] Failed to push task into swarming. Reason: No config.') - return - creds = credentials.get_scoped_service_account_credentials(_SWARMING_SCOPES) - if not creds: - logs.error( - '[Swarming] Failed to push task into swarming. Reason: No credentials.') - return - - if not creds.token: - creds.refresh(requests.Request()) - - headers = { - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {creds.token}' - } - swarming_server = _get_swarming_config().get('swarming_server') - url = f'https://{swarming_server}/prpc/swarming.v2.Tasks/NewTask' - message_body = json_format.MessageToJson(task_request) - logs.info( - f"""[Swarming] Pushing task {task_request.name} - as {creds.service_account_email}""", - url=url, - body=message_body) - response = utils.post_url(url=url, data=message_body, headers=headers) - logs.info(f'[Swarming] Response from {task_request.name}', response=response) diff --git a/src/clusterfuzz/_internal/swarming/api.py b/src/clusterfuzz/_internal/swarming/api.py new file mode 100644 index 0000000000..c64f3285fc --- /dev/null +++ b/src/clusterfuzz/_internal/swarming/api.py @@ -0,0 +1,135 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Swarming pRPC API client.""" + +from google.auth.transport import requests +from google.protobuf import json_format + +from clusterfuzz._internal.base import utils +from clusterfuzz._internal.config.local_config import SwarmingConfig +from clusterfuzz._internal.google_cloud_utils import credentials +from clusterfuzz._internal.metrics import logs +from clusterfuzz._internal.protos import swarming_pb2 +from clusterfuzz._internal.swarming import get_swarming_config + +_SWARMING_SCOPES = [ + 'https://www.googleapis.com/auth/cloud-platform', + 'https://www.googleapis.com/auth/userinfo.email' +] + +_COUNT_TASKS_ENDPOINT = 'swarming.v2.Tasks/CountTasks' +_NEW_TASK_ENDPOINT = 'swarming.v2.Tasks/NewTask' + + +class SwarmingAPI: + """Client for Swarming pRPC API.""" + + _config: SwarmingConfig = None + _base_url: str = "" + + def __init__(self): + self._config = get_swarming_config() + if self._config: + self._base_url = f"https://{self._config.get('swarming_server')}/prpc/" + + def _get_headers(self) -> dict[str, str]: + """Checks config and returns headers for pRPC request. + + Returns: + A dict containing headers, or empty dict if config is missing or + auth fails. + """ + if not self._config: + logs.error('[Swarming] No config available.') + return {} + + creds = credentials.get_scoped_service_account_credentials(_SWARMING_SCOPES) + if not creds: + logs.error('[Swarming] Failed to get credentials.') + return {} + + if not creds.token: + creds.refresh(requests.Request()) + + return { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {creds.token}' + } + + def _make_request(self, endpoint: str, body: str) -> str | None: + """Makes a pRPC request to the Swarming API. + + Args: + endpoint: The pRPC endpoint (e.g. "swarming.v2.Tasks/NewTask"). + body: The JSON body of the request. + + Returns: + The raw JSON response string from the server, or None if the request + could not be made (e.g. missing config, auth failure) or failed. + """ + headers = self._get_headers() + if not headers: + return None + + url = f'{self._base_url}{endpoint}' + response = utils.post_url(url=url, data=body, headers=headers) + if not response: + logs.error(f"[Swarming] Failed to make request to {url}") + return None + return response + + def push_task(self, task_request: swarming_pb2.NewTaskRequest) -> str | None: # pylint: disable=no-member + """Schedules a task on swarming. + + Args: + task_request: The NewTaskRequest proto message. + + Returns: + The raw JSON response string from the server, or None if the request + could not be made (e.g. missing config, auth failure) or failed. + """ + message_body = json_format.MessageToJson(task_request) + logs.info( + f"[Swarming] Pushing task {task_request.name}", + url=self._base_url, + endpoint=_NEW_TASK_ENDPOINT, + body=message_body) + + response = self._make_request(_NEW_TASK_ENDPOINT, message_body) + logs.info( + f'[Swarming] Response from {task_request.name}', response=response) + return response + + def count_tasks(self, + count_request: swarming_pb2.TasksCountRequest) -> str | None: # pylint: disable=no-member + """Counts tasks on swarming. + + Args: + count_request: The TasksCountRequest proto message. + + Returns: + The raw JSON response string from the server, or None if the request + could not be made (e.g. missing config, auth failure) or failed. + """ + message_body = json_format.MessageToJson(count_request) + logs.info( + "[Swarming] Counting tasks in queue", + url=self._base_url, + endpoint=_COUNT_TASKS_ENDPOINT, + body=message_body) + + response = self._make_request(_COUNT_TASKS_ENDPOINT, message_body) + logs.info('[Swarming] Response from CountTasks', response=response) + return response diff --git a/src/clusterfuzz/_internal/swarming/service.py b/src/clusterfuzz/_internal/swarming/service.py index 30c1bad677..4c09306feb 100644 --- a/src/clusterfuzz/_internal/swarming/service.py +++ b/src/clusterfuzz/_internal/swarming/service.py @@ -13,15 +13,59 @@ # limitations under the License. """Swarming service.""" +import json + from clusterfuzz._internal import swarming from clusterfuzz._internal.base.tasks import task_utils from clusterfuzz._internal.metrics import logs +from clusterfuzz._internal.protos import swarming_pb2 from clusterfuzz._internal.remote_task import remote_task_types +from clusterfuzz._internal.swarming.api import SwarmingAPI class SwarmingService(remote_task_types.RemoteTaskInterface): """Remote task service implementation for Swarming.""" + _api: SwarmingAPI = None + MAX_PENDING_TASKS = 50 + + def _get_api(self) -> SwarmingAPI: + """Returns the Swarming API instance.""" + if not self._api: + self._api = SwarmingAPI() + return self._api + + def _get_os_dimension(self, + request: swarming_pb2.NewTaskRequest) -> str | None: # pylint: disable=no-member + """Extracts the OS dimension from the task request.""" + for dimension in request.task_slices[0].properties.dimensions: + if dimension.key == 'os': + return dimension.value + return None + + def _is_backpressure_applied( + self, count_request: swarming_pb2.TasksCountRequest) -> bool: # pylint: disable=no-member + """Checks if backpressure should be applied based on pending tasks count. + + Returns True if backpressure is applied or if the check fails (Fail Closed). + """ + try: + response_str = self._get_api().count_tasks(count_request) + if not response_str: + raise RuntimeError("Empty response from CountTasks") + + response_json = json.loads(response_str) + count = int(response_json.get('count', 0)) + + if count >= self.MAX_PENDING_TASKS: + logs.info(f'[Swarming] Backpressure applied. Queue size: {count}. ' + 'Stopping scheduling.') + return True + return False + except Exception as e: + logs.error(f'[Swarming] Failed to check backpressure (Fail Closed): {e}') + return True # Fail closed, always fails if swarming request fails + def create_utask_main_job(self, module: str, job_type: str, input_download_url: str): """Creates a single swarming task for a uworker main task.""" @@ -44,14 +88,36 @@ def create_utask_main_jobs(self, """ unscheduled_tasks = [] logs.info(f'[Swarming] Pushing {len(remote_tasks)} tasks trough service.') - for task in remote_tasks: + for i, task in enumerate(remote_tasks): try: if not swarming.is_swarming_task(task.job_type): unscheduled_tasks.append(task) continue - if request := swarming.create_new_task_request( - task.command, task.job_type, task.argument): - swarming.push_swarming_task(request) + + request = swarming.create_new_task_request(task.command, task.job_type, + task.argument) + if not request: + unscheduled_tasks.append(task) + continue + + os_val = self._get_os_dimension(request) + if not os_val: + logs.error( + f'[Swarming] Failed to find OS dimension for job {task.job_type}.' + ) + unscheduled_tasks.append(task) + continue + + # Check backpressure + count_request = swarming_pb2.TasksCountRequest( # pylint: disable=no-member + tags=['pool:chrome-sec-clusterfuzz', f'os:{os_val}'], + state=swarming_pb2.QUERY_PENDING) # pylint: disable=no-member + + if self._is_backpressure_applied(count_request): + unscheduled_tasks.extend(remote_tasks[i:]) + break + + self._get_api().push_task(request) except Exception: # pylint: disable=broad-except logs.error( f'Failed to push task to Swarming: {task.command}, {task.job_type}.' diff --git a/src/clusterfuzz/_internal/tests/core/swarming/api_test.py b/src/clusterfuzz/_internal/tests/core/swarming/api_test.py new file mode 100644 index 0000000000..7be899ae57 --- /dev/null +++ b/src/clusterfuzz/_internal/tests/core/swarming/api_test.py @@ -0,0 +1,98 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for api.py.""" +import unittest +from unittest import mock + +from google.protobuf import json_format + +from clusterfuzz._internal.protos import swarming_pb2 +from clusterfuzz._internal.swarming.api import SwarmingAPI +from clusterfuzz._internal.tests.test_libs import helpers + + +class SwarmingAPITest(unittest.TestCase): + """Tests for SwarmingAPI.""" + + def setUp(self): + helpers.patch(self, [ + 'clusterfuzz._internal.base.utils.post_url', + 'clusterfuzz._internal.google_cloud_utils.credentials.get_scoped_service_account_credentials', + 'google.auth.transport.requests.Request', + ]) + + self.mock_creds = mock.MagicMock() + self.mock_creds.token = 'fake_token' + self.mock.get_scoped_service_account_credentials.return_value = self.mock_creds + + self.api = SwarmingAPI() + + def test_push_task(self): + """Tests that push_task works as expected.""" + task_request = swarming_pb2.NewTaskRequest(name='test_task') + self.api.push_task(task_request) + + expected_headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + 'Authorization': 'Bearer fake_token' + } + expected_url = 'https://server-name/prpc/swarming.v2.Tasks/NewTask' + self.mock.post_url.assert_called_with( + url=expected_url, + data=json_format.MessageToJson(task_request), + headers=expected_headers) + + def test_count_tasks(self): + """Tests that count_tasks works as expected.""" + count_request = swarming_pb2.TasksCountRequest(tags=['tag1']) + + # Mock response from post_url + self.mock.post_url.return_value = '{"count": 42}' + + response = self.api.count_tasks(count_request) + + expected_headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + 'Authorization': 'Bearer fake_token' + } + expected_url = 'https://server-name/prpc/swarming.v2.Tasks/CountTasks' + self.mock.post_url.assert_called_with( + url=expected_url, + data=json_format.MessageToJson(count_request), + headers=expected_headers) + + self.assertEqual(response, '{"count": 42}') + + def test_push_task_no_config(self): + """Tests that push_task fails when config is missing.""" + with mock.patch('clusterfuzz._internal.config.local_config.SwarmingConfig' + ) as mock_config: + mock_config.side_effect = ValueError('Failed to load') + api = SwarmingAPI() + response = api.push_task(swarming_pb2.NewTaskRequest()) + self.assertIsNone(response) + + def test_push_task_no_credentials(self): + """Tests that push_task fails when credentials are missing.""" + self.mock.get_scoped_service_account_credentials.return_value = None + response = self.api.push_task(swarming_pb2.NewTaskRequest()) + self.assertIsNone(response) + + def test_count_tasks_no_credentials(self): + """Tests that count_tasks fails when credentials are missing.""" + self.mock.get_scoped_service_account_credentials.return_value = None + response = self.api.count_tasks(swarming_pb2.TasksCountRequest()) + self.assertIsNone(response) diff --git a/src/clusterfuzz/_internal/tests/core/swarming/service_test.py b/src/clusterfuzz/_internal/tests/core/swarming/service_test.py index 234150cf5f..48a0769d82 100644 --- a/src/clusterfuzz/_internal/tests/core/swarming/service_test.py +++ b/src/clusterfuzz/_internal/tests/core/swarming/service_test.py @@ -27,15 +27,25 @@ class SwarmingServiceTest(unittest.TestCase): def setUp(self): helpers.patch(self, [ 'clusterfuzz._internal.swarming.is_swarming_task', - 'clusterfuzz._internal.swarming.push_swarming_task', + 'clusterfuzz._internal.swarming.service.SwarmingService._get_api', 'clusterfuzz._internal.swarming.create_new_task_request', 'clusterfuzz._internal.base.tasks.task_utils.get_command_from_module', 'clusterfuzz._internal.metrics.logs.error', 'clusterfuzz._internal.google_cloud_utils.compute_metadata.get', ]) self.service = service.SwarmingService() - self.mock.create_new_task_request.return_value = 'fake_request' + + self.mock_request = mock.MagicMock() + mock_dimension = mock.MagicMock() + mock_dimension.key = 'os' + mock_dimension.value = 'Linux' + self.mock_request.task_slices[0].properties.dimensions = [mock_dimension] + self.mock.create_new_task_request.return_value = self.mock_request + self.mock.get.return_value = None + self.mock_api = mock.MagicMock() + self.mock._get_api.return_value = self.mock_api # pylint: disable=protected-access + self.mock_api.count_tasks.return_value = '{"count": 0}' def test_create_utask_main_job_success(self): """Test creating a single task successfully.""" @@ -48,7 +58,7 @@ def test_create_utask_main_job_success(self): # Success returns None in this interface (consistent with GcpBatchService) self.assertIsNone(result) - self.mock.push_swarming_task.assert_called_once_with('fake_request') + self.mock_api.push_task.assert_called_once_with(self.mock_request) def test_create_utask_main_job_failure(self): """Test creating a single task that is not a swarming task.""" @@ -61,7 +71,7 @@ def test_create_utask_main_job_failure(self): # Failure returns the task itself self.assertIsInstance(result, remote_task_types.RemoteTask) self.assertEqual(result.command, 'fuzz') - self.mock.push_swarming_task.assert_not_called() + self.mock_api.push_task.assert_not_called() def test_create_utask_main_jobs_mixed_results(self): """Test creating multiple tasks with mixed success/failure.""" @@ -79,10 +89,10 @@ def test_create_utask_main_jobs_mixed_results(self): self.assertEqual(len(unscheduled), 1) self.assertEqual(unscheduled[0].job_type, 'job2') - self.assertEqual(self.mock.push_swarming_task.call_count, 2) - self.mock.push_swarming_task.assert_has_calls([ - mock.call('fake_request'), - mock.call('fake_request'), + self.assertEqual(self.mock_api.push_task.call_count, 2) + self.mock_api.push_task.assert_has_calls([ + mock.call(self.mock_request), + mock.call(self.mock_request), ]) def test_create_utask_main_jobs_all_success(self): @@ -96,7 +106,7 @@ def test_create_utask_main_jobs_all_success(self): unscheduled = self.service.create_utask_main_jobs(tasks) self.assertEqual(unscheduled, []) - self.assertEqual(self.mock.push_swarming_task.call_count, 2) + self.assertEqual(self.mock_api.push_task.call_count, 2) def test_create_utask_main_jobs_all_fail(self): """Test creating multiple tasks where all fail.""" @@ -109,13 +119,13 @@ def test_create_utask_main_jobs_all_fail(self): unscheduled = self.service.create_utask_main_jobs(tasks) self.assertEqual(unscheduled, tasks) - self.mock.push_swarming_task.assert_not_called() + self.mock_api.push_task.assert_not_called() def test_create_utask_main_jobs_empty(self): """Test creating tasks with an empty list.""" unscheduled = self.service.create_utask_main_jobs([]) self.assertEqual(unscheduled, []) - self.mock.push_swarming_task.assert_not_called() + self.mock_api.push_task.assert_not_called() def test_create_utask_main_jobs_exception(self): """Test creating tasks when push_swarming_task raises an exception.""" @@ -124,7 +134,7 @@ def test_create_utask_main_jobs_exception(self): ] self.mock.is_swarming_task.return_value = True - self.mock.push_swarming_task.side_effect = Exception('error') + self.mock_api.push_task.side_effect = Exception('error') unscheduled = self.service.create_utask_main_jobs(tasks) @@ -132,3 +142,38 @@ def test_create_utask_main_jobs_exception(self): self.assertEqual(unscheduled[0].job_type, 'job1') self.mock.error.assert_called_once_with( 'Failed to push task to Swarming: fuzz, job1.') + + def test_create_utask_main_jobs_backpressure(self): + """Test that backpressure stops scheduling.""" + tasks = [ + remote_task_types.RemoteTask('fuzz', 'job1', 'url1'), + remote_task_types.RemoteTask('fuzz', 'job2', 'url2'), + ] + self.mock.is_swarming_task.return_value = True + + self.mock_api.count_tasks.side_effect = ['{"count": 10}', '{"count": 50}'] + + unscheduled = self.service.create_utask_main_jobs(tasks) + + self.assertEqual(len(unscheduled), 1) + self.assertEqual(unscheduled[0].job_type, 'job2') + + self.assertEqual(self.mock_api.push_task.call_count, 1) + self.mock_api.push_task.assert_called_once_with(self.mock_request) + + def test_create_utask_main_jobs_count_tasks_failure(self): + """Test that count_tasks failure fails closed.""" + tasks = [ + remote_task_types.RemoteTask('fuzz', 'job1', 'url1'), + remote_task_types.RemoteTask('fuzz', 'job2', 'url2'), + ] + self.mock.is_swarming_task.return_value = True + self.mock_api.count_tasks.side_effect = Exception('api error') + + unscheduled = self.service.create_utask_main_jobs(tasks) + + self.assertEqual(len(unscheduled), 2) + self.assertEqual(unscheduled[0].job_type, 'job1') + self.assertEqual(unscheduled[1].job_type, 'job2') + + self.mock_api.push_task.assert_not_called() diff --git a/src/clusterfuzz/_internal/tests/core/swarming/swarming_test.py b/src/clusterfuzz/_internal/tests/core/swarming/swarming_test.py index c48d154058..fb8b99f67c 100644 --- a/src/clusterfuzz/_internal/tests/core/swarming/swarming_test.py +++ b/src/clusterfuzz/_internal/tests/core/swarming/swarming_test.py @@ -16,8 +16,6 @@ import unittest from unittest import mock -from google.protobuf import json_format - from clusterfuzz._internal import swarming from clusterfuzz._internal.datastore import data_types from clusterfuzz._internal.protos import swarming_pb2 @@ -252,109 +250,6 @@ def test_get_spec_from_config_for_fuzz_task(self): ]) self.assertEqual(spec, expected_spec) - def test_push_swarming_task(self): - """Tests that push_swarming_task works as expected.""" - mock_creds = mock.MagicMock() - mock_creds.token = 'fake_token' - self.mock.get_scoped_service_account_credentials.return_value = mock_creds - - job = data_types.Job(name='libfuzzer_chrome_asan', platform='LINUX') - job.put() - task_request = swarming.create_new_task_request('fuzz', job.name, - 'https://download_url') - swarming.push_swarming_task(task_request) - - expected_new_task_request = swarming_pb2.NewTaskRequest( - name='task_name', - priority=1, - realm='realm-name', - service_account='test-clusterfuzz-service-account-email', - task_slices=[ - swarming_pb2.TaskSlice( - expiration_secs=86400, - properties=swarming_pb2.TaskProperties( - command=[ - 'luci-auth', 'context', '--', './linux_entry_point.sh' - ], - dimensions=[ - swarming_pb2.StringPair( - key='os', value=str(job.platform).capitalize()), - swarming_pb2.StringPair(key='pool', value='pool-name') - ], - cipd_input=swarming_pb2.CipdInput(), # pylint: disable=no-member - cas_input_root=swarming_pb2.CASReference( - cas_instance= - 'projects/server-name/instances/instance_name', - digest=swarming_pb2.Digest( - hash='linux_entry_point_archive_hash', - size_bytes=1234)), - execution_timeout_secs=12345, - env=[ - swarming_pb2.StringPair( - key='DOCKER_IMAGE', - value= - 'gcr.io/clusterfuzz-images/base:a2f4dd6-202202070654' - ), - swarming_pb2.StringPair(key='UWORKER', value='True'), - swarming_pb2.StringPair( - key='SWARMING_BOT', value='True'), - swarming_pb2.StringPair(key='LOG_TO_GCP', value='True'), - swarming_pb2.StringPair(key='IS_K8S_ENV', value='True'), - swarming_pb2.StringPair( - key='DISABLE_MOUNTS', value='True'), - swarming_pb2.StringPair( - key='LOGGING_CLOUD_PROJECT_ID', value='project_id'), - swarming_pb2.StringPair( - key='DOCKER_ENV_VARS', - value= - ('{"DOCKER_IMAGE": "gcr.io/clusterfuzz-images/' - 'base:a2f4dd6-202202070654", "UWORKER": "True", ' - '"SWARMING_BOT": "True", "LOG_TO_GCP": "True", ' - '"IS_K8S_ENV": "True", "DISABLE_MOUNTS": "True", ' - '"LOGGING_CLOUD_PROJECT_ID": "project_id"}')), - ], - secret_bytes='https://download_url'.encode('utf-8'))) - ]) - - self.mock.get_scoped_service_account_credentials.assert_called_with( - swarming._SWARMING_SCOPES) # pylint: disable=protected-access - expected_headers = { - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'Authorization': 'Bearer fake_token' - } - expected_url = 'https://server-name/prpc/swarming.v2.Tasks/NewTask' - self.mock.post_url.assert_called_with( - url=expected_url, - data=json_format.MessageToJson(expected_new_task_request), - headers=expected_headers) - - def test_push_swarming_task_with_refresh(self): - """Tests that push_swarming_task refreshes credentials if token is missing.""" - mock_creds = mock.MagicMock() - mock_creds.token = None - self.mock.get_scoped_service_account_credentials.return_value = mock_creds - - def refresh_side_effect(_): - mock_creds.token = 'refreshed_token' - - mock_creds.refresh.side_effect = refresh_side_effect - - job = data_types.Job(name='libfuzzer_chrome_asan', platform='LINUX') - job.put() - request = swarming.create_new_task_request('fuzz', job.name, - 'https://download_url') - swarming.push_swarming_task(request) - - mock_creds.refresh.assert_called_with(self.mock.Request.return_value) - expected_headers = { - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'Authorization': 'Bearer refreshed_token' - } - self.assertEqual(self.mock.post_url.call_args[1]['headers'], - expected_headers) - def test_is_swarming_task(self): """Tests that is_swarming_task works as expected.""" job = data_types.Job(