Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions google/genai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,20 @@ class UnknownApiResponseError(ValueError):
"""Raised when the response from the API cannot be parsed as JSON."""
pass


class FileProcessingError(Exception):
"""Error related to file processing in the API.

This exception is raised when a file fails to reach the ACTIVE state
required for using it in content generation requests.
"""

def __init__(
self, message: str, response_json: Optional[dict[str, Any]] = None
) -> None:
self.message = message
self.details = response_json or {}
super().__init__(message)


ExperimentalWarning = _common.ExperimentalWarning
85 changes: 85 additions & 0 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import json
import logging
import time
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
from urllib.parse import urlencode

Expand Down Expand Up @@ -4472,6 +4473,86 @@ def _Video_to_vertex(
return to_object


def _ensure_file_active(
api_client: BaseApiClient,
file_obj: types.File,
max_retries: int = 3,
retry_delay_seconds: int = 5,
) -> types.File:
"""Ensure a file object is in ACTIVE state before using it in content generation.

Args:
api_client: The API client to use for requests.
file_obj: The file object to check.
max_retries: Maximum number of retries for checking file state.
retry_delay_seconds: Delay between retries in seconds.

Returns:
The file object, refreshed if necessary.

Raises:
errors.FileProcessingError: If the file fails to become ACTIVE within the retry limit.
"""
if hasattr(file_obj, 'name') and file_obj.name and hasattr(file_obj, 'state'):
if file_obj.state == types.FileState.PROCESSING:
logger.info(
f'File {file_obj.name} is in PROCESSING state. Waiting for it to become ACTIVE.'
)
for attempt in range(max_retries):
time.sleep(retry_delay_seconds)
try:
file_id = file_obj.name.split('/')[-1]
response = api_client.request('GET', f'files/{file_id}', {}, None)
response_dict = {} if not response.body else json.loads(response.body)
refreshed_file = types.File._from_response(
response=response_dict, kwargs={}
)
logger.info(f'File {file_obj.name} state: {refreshed_file.state}')
if refreshed_file.state == types.FileState.ACTIVE:
return refreshed_file
if refreshed_file.state == types.FileState.FAILED:
error_msg = 'File processing failed'
if hasattr(refreshed_file, 'error') and refreshed_file.error:
error_msg = f'{error_msg}: {refreshed_file.error.message}'
raise errors.FileProcessingError(error_msg)
except errors.FileProcessingError:
raise
except Exception as e:
logger.warning(f'Error refreshing file state: {e}')
logger.warning(
f'File {file_obj.name} did not become ACTIVE after {max_retries} attempts. '
'This may cause the content generation to fail.'
)
return file_obj


def _process_contents_for_generation(
api_client: BaseApiClient,
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
) -> list[types.Content]:
"""Process the contents, ensuring all File objects are in the ACTIVE state.

Args:
api_client: The API client to use for requests.
contents: The contents to process.

Returns:
The processed contents.
"""
processed_contents = t.t_contents(contents)

def process_file_in_item(item: types.Content) -> types.Content:
if isinstance(item, types.Content):
if hasattr(item, 'parts') and item.parts:
for part in item.parts:
if hasattr(part, 'file_data') and part.file_data:
if isinstance(part.file_data, types.File):
part.file_data = _ensure_file_active(api_client, part.file_data)
return item

return [process_file_in_item(item) for item in processed_contents]


class Models(_api_module.BaseModule):

def _generate_content(
Expand All @@ -4481,6 +4562,7 @@ def _generate_content(
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> types.GenerateContentResponse:
contents = _process_contents_for_generation(self._api_client, contents)
parameter_model = types._GenerateContentParameters(
model=model,
contents=contents,
Expand Down Expand Up @@ -4562,6 +4644,7 @@ def _generate_content_stream(
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> Iterator[types.GenerateContentResponse]:
contents = _process_contents_for_generation(self._api_client, contents)
parameter_model = types._GenerateContentParameters(
model=model,
contents=contents,
Expand Down Expand Up @@ -6440,6 +6523,7 @@ async def _generate_content(
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> types.GenerateContentResponse:
contents = _process_contents_for_generation(self._api_client, contents)
parameter_model = types._GenerateContentParameters(
model=model,
contents=contents,
Expand Down Expand Up @@ -6521,6 +6605,7 @@ async def _generate_content_stream(
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> Awaitable[AsyncIterator[types.GenerateContentResponse]]:
contents = _process_contents_for_generation(self._api_client, contents)
parameter_model = types._GenerateContentParameters(
model=model,
contents=contents,
Expand Down
178 changes: 178 additions & 0 deletions google/genai/tests/models/test_file_state_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#!/usr/bin/env python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for file state handling in content generation."""

import json
import unittest
from unittest import mock

import pytest

from google.genai import errors
from google.genai import types
from google.genai.models import _ensure_file_active, _process_contents_for_generation
from google.genai.types import FileState


def _make_response_body(state: str, error_message: str = None) -> bytes:
"""Create a mock API response body for a file."""
data = {
'name': 'files/test123',
'displayName': 'Test File',
'mimeType': 'video/mp4',
'state': state,
}
if error_message:
data['error'] = {'message': error_message}
return json.dumps(data).encode()


class TestFileStateHandling(unittest.TestCase):
"""Test file state handling functionality."""

def setUp(self):
"""Set up test fixtures."""
self.api_client = mock.MagicMock()
self.file_obj = types.File(
name='files/test123',
display_name='Test File',
mime_type='video/mp4',
uri='https://example.com/files/test123',
state=types.FileState.PROCESSING,
)

def test_ensure_file_active_with_processing_file(self):
"""Test that _ensure_file_active waits for a PROCESSING file to become ACTIVE."""
response_mock = mock.MagicMock()
response_mock.body = _make_response_body('ACTIVE')
self.api_client.request.return_value = response_mock

result = _ensure_file_active(
self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0
)

self.api_client.request.assert_called_once_with(
'GET', 'files/test123', {}, None
)
self.assertEqual(result.state, types.FileState.ACTIVE)

def test_ensure_file_active_with_failed_file(self):
"""Test that _ensure_file_active raises FileProcessingError for a FAILED file."""
response_mock = mock.MagicMock()
response_mock.body = _make_response_body(
'FAILED', error_message='File processing failed'
)
self.api_client.request.return_value = response_mock

with pytest.raises(errors.FileProcessingError) as excinfo:
_ensure_file_active(
self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0
)

assert 'File processing failed' in str(excinfo.value)

def test_ensure_file_active_with_retries_exhausted(self):
"""Test that _ensure_file_active returns original file after exhausting retries."""
response_mock = mock.MagicMock()
response_mock.body = _make_response_body('PROCESSING')
self.api_client.request.return_value = response_mock

result = _ensure_file_active(
self.api_client, self.file_obj, max_retries=2, retry_delay_seconds=0
)

self.assertEqual(self.api_client.request.call_count, 2)
self.assertEqual(result, self.file_obj)
self.assertEqual(result.state, types.FileState.PROCESSING)

def test_ensure_file_active_with_already_active_file(self):
"""Test that _ensure_file_active returns immediately for an already ACTIVE file."""
active_file = types.File(
name='files/active123',
display_name='Active File',
mime_type='video/mp4',
state=types.FileState.ACTIVE,
)

result = _ensure_file_active(
self.api_client, active_file, max_retries=1, retry_delay_seconds=0
)

self.api_client.request.assert_not_called()
self.assertEqual(result, active_file)
self.assertEqual(result.state, types.FileState.ACTIVE)


class TestProcessContentsFunction(unittest.TestCase):
"""Test the _process_contents_for_generation function."""

def setUp(self):
"""Set up test fixtures."""
self.api_client = mock.MagicMock()
self.processing_file = types.File(
name='files/processing123',
display_name='Processing File',
mime_type='video/mp4',
uri='https://example.com/files/processing123',
state=types.FileState.PROCESSING,
)
self.active_file = types.File(
name='files/active123',
display_name='Active File',
mime_type='video/mp4',
uri='https://example.com/files/active123',
state=types.FileState.ACTIVE,
)

def test_process_contents_with_files(self):
"""Test that _process_contents_for_generation can handle various file scenarios."""
file_in_list = [self.processing_file, 'Process this file']
file_in_parts = types.Content(
role='user',
parts=[types.Part(text="Here's a video:"), self.processing_file],
)
multiple_files = [
types.Content(
role='user',
parts=[types.Part(text='First video:'), self.processing_file],
),
types.Content(
role='user',
parts=[types.Part(text='Second video:'), self.active_file],
),
]

with mock.patch(
'google.genai.models._ensure_file_active', side_effect=lambda client, f: f
):
for test_content in [file_in_list, file_in_parts, multiple_files]:
with mock.patch(
'google.genai.models.t.t_contents',
return_value=(
test_content
if isinstance(test_content, list)
else [test_content]
),
):
result = _process_contents_for_generation(
self.api_client, test_content
)
self.assertTrue(result)


if __name__ == '__main__':
unittest.main()