3 changes: 3 additions & 0 deletions configs/test/batch/batch.yaml
@@ -76,6 +76,9 @@ mapping:
name: east4-network2
weight: 1
project: 'test-clusterfuzz'
queue_check_regions:
- us-central1
- us-east4
subconfigs:
central1-network1:
region: 'us-central1'
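A minimal sketch of how the new key is read at runtime, mirroring _get_subconfig() in batch/service.py further down; running it standalone assumes a local ClusterFuzz config is available, and the printed value is only what the test config above would yield:

from clusterfuzz._internal.config import local_config

batch_config = local_config.BatchConfig()
# For the test config above this evaluates to ['us-central1', 'us-east4'];
# leaving the key unset disables the queue load check entirely.
queue_check_regions = batch_config.get('queue_check_regions')
print(queue_check_regions)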
39 changes: 36 additions & 3 deletions src/clusterfuzz/_internal/base/external_users.py
@@ -22,6 +22,9 @@

MEMCACHE_TTL_IN_SECONDS = 15 * 60

# Suffix of the OSS-Fuzz issue tracker CC group email for a project.
OSS_FUZZ_CC_GROUP_SUFFIX = '-ccs@oss-fuzz.com'


def _fuzzers_for_job(job_type, include_parents):
"""Return all fuzzers that have the job associated.
@@ -198,7 +201,10 @@ def _allowed_users_for_entity(name, entity_kind, auto_cc=None):
return sorted(allowed_users)


def _cc_users_for_entity(name, entity_type, security_flag):
def _cc_users_for_entity(name,
entity_type,
security_flag,
allow_cc_group_for_job=True):
"""Return CC users for entity."""
users = _allowed_users_for_entity(name, entity_type,
data_types.AutoCCType.ALL)
@@ -208,6 +214,20 @@ def _cc_users_for_entity(name, entity_type, security_flag):
_allowed_users_for_entity(name, entity_type,
data_types.AutoCCType.SECURITY))

if (entity_type != data_types.PermissionEntityKind.JOB or
not allow_cc_group_for_job):
return sorted(users)

# CC group is only available for jobs, as it is not possible to infer the
# project from the other permission entity kinds alone.
users_in_cc_group = _allowed_users_for_entity(
name, entity_type, data_types.AutoCCType.USE_CC_GROUP)
if users_in_cc_group:
# Assume individual users are kept in sync with the project CC group.
group_name = get_cc_group_from_job(name)
if group_name:
users.append(group_name)

return sorted(users)


@@ -336,15 +356,28 @@ def is_upload_allowed_for_user(user_email):
return bool(permissions.get())


def cc_users_for_job(job_type, security_flag):
def cc_users_for_job(job_type, security_flag, allow_cc_group=True):
"""Return external users that should be CC'ed according to the given rule.

Args:
job_type: The name of the job.
security_flag: Whether or not the CC is for a security issue.
allow_cc_group: Whether to include the project CC group derived from the
job when any user has the USE_CC_GROUP auto CC type.

Returns:
A list of user emails that should be CC'ed.
"""
return _cc_users_for_entity(job_type, data_types.PermissionEntityKind.JOB,
security_flag)
security_flag, allow_cc_group)


def get_cc_group_from_job(job_type: str) -> str | None:
"""Docstring for get_cc_group_from_entity"""
project_name = data_handler.get_project_name(job_type)
return get_oss_fuzz_project_cc_group(project_name)


def get_oss_fuzz_project_cc_group(project_name: str) -> str | None:
"""Return oss-fuzz issue tracker CC group email for a project."""
return f'{project_name}{OSS_FUZZ_CC_GROUP_SUFFIX}'
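A usage sketch for the new CC-group path; the job and project names below are invented, and it assumes the job's project has at least one permission with AutoCCType.USE_CC_GROUP:

from clusterfuzz._internal.base import external_users

# With the CC group allowed (the default), the derived group address is
# appended to the individually permitted users, e.g.
# ['dev@example.com', 'libpng-ccs@oss-fuzz.com'].
ccs = external_users.cc_users_for_job('libfuzzer_asan_libpng', security_flag=False)

# Passing allow_cc_group=False keeps the previous behaviour of listing
# individual users only.
individual_ccs = external_users.cc_users_for_job(
    'libfuzzer_asan_libpng', security_flag=False, allow_cc_group=False)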
25 changes: 25 additions & 0 deletions src/clusterfuzz/_internal/base/memoize.py
@@ -17,6 +17,7 @@
import functools
import json
import threading
import time

from clusterfuzz._internal.base import persistent_cache
from clusterfuzz._internal.metrics import logs
@@ -89,6 +90,30 @@ def get_key(self, func, args, kwargs):
return _default_key(func, args, kwargs)


class InMemory(FifoInMemory):
"""In-memory caching engine with TTL."""

def __init__(self, ttl_in_seconds, capacity=1000):
super().__init__(capacity)
self.ttl_in_seconds = ttl_in_seconds

def put(self, key, value):
"""Put (key, value) into cache."""
super().put(key, (value, time.time() + self.ttl_in_seconds))

def get(self, key):
"""Get the value from cache."""
entry = super().get(key)
if entry is None:
return None

value, expiry = entry
if expiry < time.time():
return None

return value


class FifoOnDisk:
"""On-disk caching engine."""

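A minimal usage sketch for the TTL cache, mirroring the decorator applied to get_region_load() in batch/service.py below; the decorated function here is purely illustrative:

import time

from clusterfuzz._internal.base import memoize

@memoize.wrap(memoize.InMemory(60))  # Cache results for roughly 60 seconds.
def region_load(region):
  # Stand-in for an expensive lookup, e.g. a remote API call.
  return int(time.time())

first = region_load('us-central1')
assert region_load('us-central1') == first  # Served from cache within the TTL.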
125 changes: 109 additions & 16 deletions src/clusterfuzz/_internal/batch/service.py
@@ -1,8 +1,8 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
@@ -18,14 +18,19 @@
and provides a simple interface for scheduling ClusterFuzz tasks.
"""
import collections
import json
import random
import threading
from typing import Dict
from typing import List
from typing import Tuple
import urllib.request
import uuid

import google.auth.transport.requests
from google.cloud import batch_v1 as batch

from clusterfuzz._internal.base import memoize
from clusterfuzz._internal.base import retry
from clusterfuzz._internal.base import tasks
from clusterfuzz._internal.base import utils
@@ -65,6 +70,13 @@
# See https://cloud.google.com/batch/quotas#job_limits
MAX_CONCURRENT_VMS_PER_JOB = 1000

MAX_QUEUE_SIZE = 100


class AllRegionsOverloadedError(Exception):
"""Raised when all batch regions are overloaded."""


_local = threading.local()

DEFAULT_RETRY_COUNT = 0
@@ -184,14 +196,58 @@ def count_queued_or_scheduled_tasks(project: str,
return (queued, scheduled)


@memoize.wrap(memoize.InMemory(60))
def get_region_load(project: str, region: str) -> int:
"""Gets the current load (queued and scheduled jobs) for a region."""
creds, _ = credentials.get_default()
if not creds.valid:
creds.refresh(google.auth.transport.requests.Request())

headers = {
'Authorization': f'Bearer {creds.token}',
'Content-Type': 'application/json'
}

try:
url = (f'https://batch.googleapis.com/v1alpha/projects/{project}/locations/'
f'{region}/jobs:countByState?states=QUEUED')
req = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(req) as response:
if response.status != 200:
logs.error(
f'Batch countByState failed: {response.status} {response.read()}')
return 0

data = json.loads(response.read())
logs.info(f'Batch countByState response for {region}: {data}')
# The API returns a mapping of job state to count, with counts encoded as
# strings, e.g. { "jobCounts": { "QUEUED": "10" } }.
total = 0
job_counts = data.get('jobCounts', {})
for state, count in job_counts.items():
count = int(count)
if state == 'QUEUED':
total += count
else:
logs.error(f'Unexpected job state in countByState response: {state}')

return total
except Exception as e:
logs.error(f'Failed to get region load for {region}: {e}')
return 0


def _get_batch_config():
"""Returns the batch config. This function was made to make mocking easier."""
return local_config.BatchConfig()


def is_remote_task(command: str, job_name: str) -> bool:
"""Returns whether a task is configured to run remotely on GCP Batch.

This is determined by checking if a valid batch workload specification can
be found for the given command and job type.
"""
@@ -242,15 +298,46 @@ def _get_config_names(batch_tasks: List[remote_task_types.RemoteTask]):


def _get_subconfig(batch_config, instance_spec):
# TODO(metzman): Make this pick one at random or based on conditions.
all_subconfigs = batch_config.get('subconfigs', {})
instance_subconfigs = instance_spec['subconfigs']
weighted_subconfigs = [
WeightedSubconfig(subconfig['name'], subconfig['weight'])
for subconfig in instance_subconfigs
]
weighted_subconfig = utils.random_weighted_choice(weighted_subconfigs)
return all_subconfigs[weighted_subconfig.name]

queue_check_regions = batch_config.get('queue_check_regions')
if not queue_check_regions:
logs.info(
'Skipping batch load check because queue_check_regions is not configured.'
)
weighted_subconfigs = [
WeightedSubconfig(subconfig['name'], subconfig['weight'])
for subconfig in instance_subconfigs
]
weighted_subconfig = utils.random_weighted_choice(weighted_subconfigs)
return all_subconfigs[weighted_subconfig.name]

# Check load for configured regions.
healthy_subconfigs = []
project = batch_config.get('project')

for subconfig in instance_subconfigs:
name = subconfig['name']
conf = all_subconfigs[name]
region = conf['region']

if region in queue_check_regions:
load = get_region_load(project, region)
logs.info(f'Region {region} has {load} queued jobs.')
if load >= MAX_QUEUE_SIZE:
logs.info(f'Region {region} overloaded (load={load}). Skipping.')
continue

healthy_subconfigs.append(name)

if not healthy_subconfigs:
logs.error('All candidate regions are overloaded.')
raise AllRegionsOverloadedError('All candidate regions are overloaded.')

# Pick one of the healthy regions at random to avoid a thundering herd.
chosen_name = random.choice(healthy_subconfigs)
return all_subconfigs[chosen_name]
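A self-contained sketch of the filtering rule above; the queue depths are invented and the real get_region_load() call is bypassed:

MAX_QUEUE_SIZE = 100
loads = {'us-central1': 150, 'us-east4': 20}  # Hypothetical queue depths.

# Regions at or above the limit are skipped, as in _get_subconfig(); if none
# survive, the real code raises AllRegionsOverloadedError instead.
healthy = [region for region, load in loads.items() if load < MAX_QUEUE_SIZE]
assert healthy == ['us-east4']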


def _get_specs_from_config(
@@ -277,7 +364,6 @@ def _get_specs_from_config(
versioned_images_map = instance_spec.get('versioned_docker_images')
if (base_os_version and versioned_images_map and
base_os_version in versioned_images_map):
# New path: Use the versioned image if specified and available.
docker_image_uri = versioned_images_map[base_os_version]
else:
# Fallback/legacy path: Use the original docker_image key.
@@ -324,7 +410,7 @@ def _get_specs_from_config(

class GcpBatchService(remote_task_types.RemoteTaskInterface):
"""A high-level service for creating and managing remote tasks.

This service provides a simple interface for scheduling ClusterFuzz tasks on
GCP Batch. It handles the details of creating batch jobs and tasks, and
provides a way to check if a task is configured to run remotely.
@@ -383,20 +469,27 @@ def create_utask_main_job(self, module: str, job_type: str,
def create_utask_main_jobs(self,
remote_tasks: List[remote_task_types.RemoteTask]):
"""Creates a batch job for a list of uworker main tasks.

This method groups the tasks by their workload specification and creates a
separate batch job for each group. This allows tasks with similar
requirements to be processed together, which can improve efficiency.
"""
job_specs = collections.defaultdict(list)
specs = _get_specs_from_config(remote_tasks)
try:
specs = _get_specs_from_config(remote_tasks)

except AllRegionsOverloadedError:
# All candidate regions are overloaded; return the remote tasks
# uncreated so the caller can retry them later.
return remote_tasks

for remote_task in remote_tasks:
logs.info(f'Scheduling {remote_task.command}, {remote_task.job_type}.')
spec = specs[(remote_task.command, remote_task.job_type)]
job_specs[spec].append(remote_task.input_download_url)

logs.info('Creating batch jobs.')
logs.info('Batching utask_mains.')

for spec, input_urls in job_specs.items():
for input_urls_portion in utils.batched(input_urls,
MAX_CONCURRENT_VMS_PER_JOB - 1):
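A sketch of the new contract only; the caller and its re-enqueue hook below are invented for illustration and are not part of this change:

def schedule_utask_mains(service, remote_tasks, requeue):
  # When every candidate region is overloaded, create_utask_main_jobs() now
  # returns the tasks uncreated instead of scheduling batch jobs for them.
  uncreated = service.create_utask_main_jobs(remote_tasks) or []
  for remote_task in uncreated:
    requeue(remote_task)  # Hypothetical hook: retry once the regions drain.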
7 changes: 4 additions & 3 deletions src/clusterfuzz/_internal/bot/tasks/task_types.py
@@ -162,9 +162,10 @@ def execute(self, task_argument, job_type, uworker_env):
utask_main_queue_size = tasks.get_utask_main_queue_size()

utask_main_queue_limit = UTASK_MAIN_QUEUE_LIMIT_DEFAULT
utask_flag = feature_flags.FeatureFlags.UTASK_MAIN_QUEUE_LIMIT.flag
if utask_flag and utask_flag.enabled:
utask_main_queue_limit = utask_flag.content
utask_flag = feature_flags.FeatureFlags.UTASK_MAIN_QUEUE_LIMIT
if utask_flag.enabled and utask_flag.content:
utask_main_queue_limit = int(utask_flag.content)

if utask_main_queue_size > utask_main_queue_limit:
base_os_version = environment.get_value('BASE_OS_VERSION')
queue_name = UTASK_MAIN_QUEUE if not base_os_version else \
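A small sketch of the reworked limit check, assuming the UTASK_MAIN_QUEUE_LIMIT flag carries its limit as a numeric string in `content`:

def queue_limit(flag, default):
  # Mirrors the check above: fall back to the default unless the flag is
  # enabled and actually carries a payload.
  if flag.enabled and flag.content:
    return int(flag.content)
  return default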