Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions google/genai/tests/tunings/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,66 @@
),
exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode",
),
pytest_helper.TestTableItem(
name="test_tune_reinforcement",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
),
config=genai_types.CreateTuningJobConfig(
tuned_model_display_name="Model display name",
epoch_count=1,
method="REINFORCEMENT_TUNING",
adapter_size="ADAPTER_SIZE_ONE",
learning_rate_multiplier=1.0,
batch_size=4,
samples_per_prompt=4,
evaluate_interval=100,
checkpoint_interval=100,
max_output_tokens=2048,
reward_config=genai_types.SingleReinforcementTuningRewardConfig(
autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer(
autorater_config=genai_types.AutoraterConfig(
autorater_model="test-model"
)
)
),
),
),
exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode",
),
pytest_helper.TestTableItem(
name="test_tune_reinforcement_composite",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
),
config=genai_types.CreateTuningJobConfig(
tuned_model_display_name="Model display name",
epoch_count=1,
method="REINFORCEMENT_TUNING",
adapter_size="ADAPTER_SIZE_ONE",
learning_rate_multiplier=1.0,
composite_reward_config=genai_types.CompositeReinforcementTuningRewardConfig(
weighted_reward_configs=[
genai_types.CompositeReinforcementTuningRewardConfigWeightedRewardConfig(
weight=1.0,
reward_config=genai_types.SingleReinforcementTuningRewardConfig(
autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer(
autorater_config=genai_types.AutoraterConfig(
autorater_model="test-model"
)
)
)
)
]
),
),
),
exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode",
),
]

pytestmark = pytest_helper.setup(
Expand Down
217 changes: 217 additions & 0 deletions google/genai/tunings.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,48 @@ def _CancelTuningJobResponse_from_vertex(
return to_object


def _CompositeReinforcementTuningRewardConfigWeightedRewardConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
root_object: Optional[Union[dict[str, Any], object]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['reward_config']) is not None:
setv(
to_object,
['rewardConfig'],
_SingleReinforcementTuningRewardConfig_to_vertex(
getv(from_object, ['reward_config']), to_object, root_object
),
)

if getv(from_object, ['weight']) is not None:
setv(to_object, ['weight'], getv(from_object, ['weight']))

return to_object


def _CompositeReinforcementTuningRewardConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
root_object: Optional[Union[dict[str, Any], object]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['weighted_reward_configs']) is not None:
setv(
to_object,
['weightedRewardConfigs'],
[
_CompositeReinforcementTuningRewardConfigWeightedRewardConfig_to_vertex(
item, to_object, root_object
)
for item in getv(from_object, ['weighted_reward_configs'])
],
)

return to_object


def _CreateTuningJobConfig_to_mldev(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -269,6 +311,42 @@ def _CreateTuningJobConfig_to_mldev(
' Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['reward_config']) is not None:
raise ValueError(
'reward_config parameter is only supported in Gemini Enterprise Agent'
' Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['composite_reward_config']) is not None:
raise ValueError(
'composite_reward_config parameter is only supported in Gemini'
' Enterprise Agent Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['samples_per_prompt']) is not None:
raise ValueError(
'samples_per_prompt parameter is only supported in Gemini Enterprise'
' Agent Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['evaluate_interval']) is not None:
raise ValueError(
'evaluate_interval parameter is only supported in Gemini Enterprise'
' Agent Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['checkpoint_interval']) is not None:
raise ValueError(
'checkpoint_interval parameter is only supported in Gemini Enterprise'
' Agent Platform mode, not in Gemini Developer API mode.'
)

if getv(from_object, ['max_output_tokens']) is not None:
raise ValueError(
'max_output_tokens parameter is only supported in Gemini Enterprise'
' Agent Platform mode, not in Gemini Developer API mode.'
)

return to_object


Expand Down Expand Up @@ -309,6 +387,15 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['validation_dataset']), to_object, root_object
),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['validation_dataset']) is not None:
setv(
parent_object,
['reinforcementTuningSpec'],
_TuningValidationDataset_to_vertex(
getv(from_object, ['validation_dataset']), to_object, root_object
),
)

if getv(from_object, ['tuned_model_display_name']) is not None:
setv(
Expand Down Expand Up @@ -344,6 +431,13 @@ def _CreateTuningJobConfig_to_vertex(
['distillationSpec', 'hyperParameters', 'epochCount'],
getv(from_object, ['epoch_count']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['epoch_count']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'epochCount'],
getv(from_object, ['epoch_count']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand Down Expand Up @@ -373,6 +467,17 @@ def _CreateTuningJobConfig_to_vertex(
['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
getv(from_object, ['learning_rate_multiplier']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['learning_rate_multiplier']) is not None:
setv(
parent_object,
[
'reinforcementTuningSpec',
'hyperParameters',
'learningRateMultiplier',
],
getv(from_object, ['learning_rate_multiplier']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand Down Expand Up @@ -423,6 +528,13 @@ def _CreateTuningJobConfig_to_vertex(
['distillationSpec', 'hyperParameters', 'adapterSize'],
getv(from_object, ['adapter_size']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['adapter_size']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'adapterSize'],
getv(from_object, ['adapter_size']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand Down Expand Up @@ -466,6 +578,13 @@ def _CreateTuningJobConfig_to_vertex(
['distillationSpec', 'hyperParameters', 'batchSize'],
getv(from_object, ['batch_size']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['batch_size']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'batchSize'],
getv(from_object, ['batch_size']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand Down Expand Up @@ -557,6 +676,54 @@ def _CreateTuningJobConfig_to_vertex(
getv(from_object, ['encryption_spec']),
)

if getv(from_object, ['reward_config']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'singleRewardConfig'],
_SingleReinforcementTuningRewardConfig_to_vertex(
getv(from_object, ['reward_config']), to_object, root_object
),
)

if getv(from_object, ['composite_reward_config']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'compositeRewardConfig'],
_CompositeReinforcementTuningRewardConfig_to_vertex(
getv(from_object, ['composite_reward_config']),
to_object,
root_object,
),
)

if getv(from_object, ['samples_per_prompt']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'samplesPerPrompt'],
getv(from_object, ['samples_per_prompt']),
)

if getv(from_object, ['evaluate_interval']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'evaluateInterval'],
getv(from_object, ['evaluate_interval']),
)

if getv(from_object, ['checkpoint_interval']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'checkpointInterval'],
getv(from_object, ['checkpoint_interval']),
)

if getv(from_object, ['max_output_tokens']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'hyperParameters', 'maxOutputTokens'],
getv(from_object, ['max_output_tokens']),
)

return to_object


Expand Down Expand Up @@ -1168,6 +1335,24 @@ def _MultiSpeakerVoiceConfig_to_vertex(
return to_object


def _ReinforcementTuningAutoraterScorer_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
root_object: Optional[Union[dict[str, Any], object]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['autorater_config']) is not None:
setv(
to_object,
['autoraterConfig'],
_AutoraterConfig_to_vertex(
getv(from_object, ['autorater_config']), to_object, root_object
),
)

return to_object


def _ReplicatedVoiceConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -1199,6 +1384,24 @@ def _ReplicatedVoiceConfig_to_vertex(
return to_object


def _SingleReinforcementTuningRewardConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
root_object: Optional[Union[dict[str, Any], object]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['autorater_scorer']) is not None:
setv(
to_object,
['autoraterScorer'],
_ReinforcementTuningAutoraterScorer_to_vertex(
getv(from_object, ['autorater_scorer']), to_object, root_object
),
)

return to_object


def _SpeakerVoiceConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -1326,6 +1529,13 @@ def _TuningDataset_to_vertex(
['distillationSpec', 'promptDatasetUri'],
getv(from_object, ['gcs_uri']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['gcs_uri']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'trainingDatasetUri'],
getv(from_object, ['gcs_uri']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand All @@ -1351,6 +1561,13 @@ def _TuningDataset_to_vertex(
['distillationSpec', 'promptDatasetUri'],
getv(from_object, ['vertex_dataset_resource']),
)
elif discriminator == 'REINFORCEMENT_TUNING':
if getv(from_object, ['vertex_dataset_resource']) is not None:
setv(
parent_object,
['reinforcementTuningSpec', 'trainingDatasetUri'],
getv(from_object, ['vertex_dataset_resource']),
)

if getv(from_object, ['examples']) is not None:
raise ValueError(
Expand Down
Loading
Loading