Skip to content

Commit 977451c

Browse files
committed
fix(batches): preserve order for Gemini inlined batch responses
1 parent c04be0d commit 977451c

3 files changed

Lines changed: 210 additions & 10 deletions

File tree

google/genai/batches.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434

3535
logger = logging.getLogger('google_genai.batches')
36+
_INLINED_REQUEST_ORDER_METADATA_KEY = '_google_genai_inlined_request_order'
3637

3738

3839
def _AuthConfig_to_mldev(
@@ -79,15 +80,37 @@ def _BatchJobDestination_from_mldev(
7980
setv(to_object, ['file_name'], getv(from_object, ['responsesFile']))
8081

8182
if getv(from_object, ['inlinedResponses', 'inlinedResponses']) is not None:
83+
inlined_responses = [
84+
_InlinedResponse_from_mldev(item, to_object)
85+
for item in getv(from_object, ['inlinedResponses', 'inlinedResponses'])
86+
]
87+
# Backend can return inlined responses out of input order. When we have the
88+
# SDK-injected order marker, restore the original order deterministically.
89+
sortable = True
90+
for inlined_response in inlined_responses:
91+
metadata = getv(inlined_response, ['metadata'])
92+
request_order = (
93+
metadata.get(_INLINED_REQUEST_ORDER_METADATA_KEY)
94+
if isinstance(metadata, dict)
95+
else None
96+
)
97+
if request_order is None or not str(request_order).isdigit():
98+
sortable = False
99+
break
100+
if sortable:
101+
inlined_responses.sort(
102+
key=lambda response: int(
103+
getv(response, ['metadata', _INLINED_REQUEST_ORDER_METADATA_KEY])
104+
)
105+
)
106+
for inlined_response in inlined_responses:
107+
metadata = getv(inlined_response, ['metadata'])
108+
if isinstance(metadata, dict):
109+
metadata.pop(_INLINED_REQUEST_ORDER_METADATA_KEY, None)
82110
setv(
83111
to_object,
84112
['inlined_responses'],
85-
[
86-
_InlinedResponse_from_mldev(item, to_object)
87-
for item in getv(
88-
from_object, ['inlinedResponses', 'inlinedResponses']
89-
)
90-
],
113+
inlined_responses,
91114
)
92115

93116
if (
@@ -213,13 +236,23 @@ def _BatchJobSource_to_mldev(
213236
setv(to_object, ['fileName'], getv(from_object, ['file_name']))
214237

215238
if getv(from_object, ['inlined_requests']) is not None:
239+
inlined_requests = []
240+
for index, inlined_request in enumerate(getv(from_object, ['inlined_requests'])):
241+
inlined_request_object = _InlinedRequest_to_mldev(
242+
api_client, inlined_request, to_object
243+
)
244+
metadata = getv(inlined_request_object, ['metadata'], default_value={})
245+
if not isinstance(metadata, dict):
246+
metadata = {}
247+
# Reserved SDK key: always stamp deterministic order marker, even when
248+
# caller metadata contains the same key.
249+
metadata[_INLINED_REQUEST_ORDER_METADATA_KEY] = str(index)
250+
setv(inlined_request_object, ['metadata'], metadata)
251+
inlined_requests.append(inlined_request_object)
216252
setv(
217253
to_object,
218254
['requests', 'requests'],
219-
[
220-
_InlinedRequest_to_mldev(api_client, item, to_object)
221-
for item in getv(from_object, ['inlined_requests'])
222-
],
255+
inlined_requests,
223256
)
224257

225258
return to_object

google/genai/tests/batches/test_create_with_inlined_requests.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import os
2222

2323
import pytest
24+
from unittest import mock
2425

26+
from ... import batches as batches_module
2527
from ... import _transformers as t
2628
from ... import types
2729
from .. import pytest_helper
@@ -258,6 +260,163 @@
258260
]
259261

260262

263+
def test_inlined_requests_include_internal_order_metadata(
264+
use_vertex, replays_prefix, http_options
265+
):
266+
del use_vertex, replays_prefix, http_options
267+
request_payload = {
268+
'inlined_requests': [
269+
{'contents': [{'parts': [{'text': 'first'}], 'role': 'user'}]},
270+
{
271+
'contents': [{'parts': [{'text': 'second'}], 'role': 'user'}],
272+
'metadata': {'caller': 'external'},
273+
},
274+
]
275+
}
276+
277+
converted = batches_module._BatchJobSource_to_mldev(
278+
mock.MagicMock(), request_payload
279+
)
280+
requests = converted['requests']['requests']
281+
key = batches_module._INLINED_REQUEST_ORDER_METADATA_KEY
282+
283+
assert requests[0]['metadata'][key] == '0'
284+
assert requests[1]['metadata'][key] == '1'
285+
assert requests[1]['metadata']['caller'] == 'external'
286+
287+
288+
def test_inlined_requests_internal_order_metadata_overrides_reserved_key(
289+
use_vertex, replays_prefix, http_options
290+
):
291+
del use_vertex, replays_prefix, http_options
292+
key = batches_module._INLINED_REQUEST_ORDER_METADATA_KEY
293+
request_payload = {
294+
'inlined_requests': [
295+
{
296+
'contents': [{'parts': [{'text': 'first'}], 'role': 'user'}],
297+
'metadata': {key: '999', 'caller': 'external'},
298+
},
299+
]
300+
}
301+
302+
converted = batches_module._BatchJobSource_to_mldev(
303+
mock.MagicMock(), request_payload
304+
)
305+
request = converted['requests']['requests'][0]
306+
307+
assert request['metadata'][key] == '0'
308+
assert request['metadata']['caller'] == 'external'
309+
310+
311+
def test_inlined_responses_are_reordered_by_internal_order_metadata(
312+
use_vertex, replays_prefix, http_options
313+
):
314+
del use_vertex, replays_prefix, http_options
315+
key = batches_module._INLINED_REQUEST_ORDER_METADATA_KEY
316+
response_payload = {
317+
'inlinedResponses': {
318+
'inlinedResponses': [
319+
{
320+
'metadata': {'request_key': 'two', key: '2'},
321+
'response': {'candidates': []},
322+
},
323+
{
324+
'metadata': {'request_key': 'zero', key: '0'},
325+
'response': {'candidates': []},
326+
},
327+
{
328+
'metadata': {'request_key': 'one', key: '1'},
329+
'response': {'candidates': []},
330+
},
331+
]
332+
}
333+
}
334+
335+
converted = batches_module._BatchJobDestination_from_mldev(response_payload)
336+
responses = converted['inlined_responses']
337+
338+
assert [item['metadata']['request_key'] for item in responses] == [
339+
'zero',
340+
'one',
341+
'two',
342+
]
343+
assert all(key not in item['metadata'] for item in responses)
344+
345+
346+
def test_inlined_responses_keep_input_order_when_metadata_missing(
347+
use_vertex, replays_prefix, http_options
348+
):
349+
del use_vertex, replays_prefix, http_options
350+
key = batches_module._INLINED_REQUEST_ORDER_METADATA_KEY
351+
response_payload = {
352+
'inlinedResponses': {
353+
'inlinedResponses': [
354+
{
355+
'metadata': {'request_key': 'two', key: '2'},
356+
'response': {'candidates': []},
357+
},
358+
{
359+
'metadata': {'request_key': 'zero'},
360+
'response': {'candidates': []},
361+
},
362+
{
363+
'metadata': {'request_key': 'one', key: '1'},
364+
'response': {'candidates': []},
365+
},
366+
]
367+
}
368+
}
369+
370+
converted = batches_module._BatchJobDestination_from_mldev(response_payload)
371+
responses = converted['inlined_responses']
372+
373+
assert [item['metadata']['request_key'] for item in responses] == [
374+
'two',
375+
'zero',
376+
'one',
377+
]
378+
assert responses[0]['metadata'][key] == '2'
379+
assert key not in responses[1]['metadata']
380+
assert responses[2]['metadata'][key] == '1'
381+
382+
383+
def test_inlined_responses_keep_input_order_when_metadata_non_numeric(
384+
use_vertex, replays_prefix, http_options
385+
):
386+
del use_vertex, replays_prefix, http_options
387+
key = batches_module._INLINED_REQUEST_ORDER_METADATA_KEY
388+
response_payload = {
389+
'inlinedResponses': {
390+
'inlinedResponses': [
391+
{
392+
'metadata': {'request_key': 'two', key: '2'},
393+
'response': {'candidates': []},
394+
},
395+
{
396+
'metadata': {'request_key': 'bad', key: 'not-a-number'},
397+
'response': {'candidates': []},
398+
},
399+
{
400+
'metadata': {'request_key': 'one', key: '1'},
401+
'response': {'candidates': []},
402+
},
403+
]
404+
}
405+
}
406+
407+
converted = batches_module._BatchJobDestination_from_mldev(response_payload)
408+
responses = converted['inlined_responses']
409+
410+
assert [item['metadata']['request_key'] for item in responses] == [
411+
'two',
412+
'bad',
413+
'one',
414+
]
415+
assert responses[0]['metadata'][key] == '2'
416+
assert responses[1]['metadata'][key] == 'not-a-number'
417+
assert responses[2]['metadata'][key] == '1'
418+
419+
261420
@pytest.mark.asyncio
262421
async def test_async_create(client):
263422
with pytest_helper.exception_if_vertex(client, ValueError):

google/genai/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ def client(use_vertex, replays_prefix, http_options, request):
9292
Assert an exception if the test is not supported in an API.""")
9393
replay_id = _get_replay_id(use_vertex, replays_prefix)
9494

95+
if mode in ['replay', 'tap'] and not use_vertex:
96+
# Replay mode should not require a real API key, but client init still
97+
# validates key presence on the mldev path.
98+
if not os.environ.get('GOOGLE_API_KEY') and not os.environ.get(
99+
'GEMINI_API_KEY'
100+
):
101+
os.environ['GOOGLE_API_KEY'] = 'dummy-api-key'
102+
95103
if mode == 'tap':
96104
mode = 'replay'
97105

0 commit comments

Comments
 (0)