From 0ead8888695f379ecf35cfc68d69e4b7e8e20403 Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Fri, 22 May 2026 11:56:02 -0700 Subject: [PATCH] feat: Support Reinforcement Tuning in GenAI SDK PiperOrigin-RevId: 919788531 --- google/genai/tests/tunings/test_tune.py | 60 +++++++ google/genai/tunings.py | 217 ++++++++++++++++++++++++ google/genai/types.py | 148 +++++++++++++++- 3 files changed, 423 insertions(+), 2 deletions(-) diff --git a/google/genai/tests/tunings/test_tune.py b/google/genai/tests/tunings/test_tune.py index b99d1a068..e69d51036 100755 --- a/google/genai/tests/tunings/test_tune.py +++ b/google/genai/tests/tunings/test_tune.py @@ -279,6 +279,66 @@ ), exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode", ), + pytest_helper.TestTableItem( + name="test_tune_reinforcement", + parameters=genai_types.CreateTuningJobParameters( + base_model="gemini-2.5-flash", + training_dataset=genai_types.TuningDataset( + gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl", + ), + config=genai_types.CreateTuningJobConfig( + tuned_model_display_name="Model display name", + epoch_count=1, + method="REINFORCEMENT_TUNING", + adapter_size="ADAPTER_SIZE_ONE", + learning_rate_multiplier=1.0, + batch_size=4, + samples_per_prompt=4, + evaluate_interval=100, + checkpoint_interval=100, + max_output_tokens=2048, + reward_config=genai_types.SingleReinforcementTuningRewardConfig( + autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer( + autorater_config=genai_types.AutoraterConfig( + autorater_model="test-model" + ) + ) + ), + ), + ), + exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode", + ), + pytest_helper.TestTableItem( + name="test_tune_reinforcement_composite", + parameters=genai_types.CreateTuningJobParameters( + base_model="gemini-2.5-flash", + training_dataset=genai_types.TuningDataset( + gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl", + ), + config=genai_types.CreateTuningJobConfig( + tuned_model_display_name="Model display name", + epoch_count=1, + method="REINFORCEMENT_TUNING", + adapter_size="ADAPTER_SIZE_ONE", + learning_rate_multiplier=1.0, + composite_reward_config=genai_types.CompositeReinforcementTuningRewardConfig( + weighted_reward_configs=[ + genai_types.CompositeReinforcementTuningRewardConfigWeightedRewardConfig( + weight=1.0, + reward_config=genai_types.SingleReinforcementTuningRewardConfig( + autorater_scorer=genai_types.ReinforcementTuningAutoraterScorer( + autorater_config=genai_types.AutoraterConfig( + autorater_model="test-model" + ) + ) + ) + ) + ] + ), + ), + ), + exception_if_mldev="parameter is only supported in Gemini Enterprise Agent Platform mode", + ), ] pytestmark = pytest_helper.setup( diff --git a/google/genai/tunings.py b/google/genai/tunings.py index 1a927fb1c..fd4efe4da 100644 --- a/google/genai/tunings.py +++ b/google/genai/tunings.py @@ -137,6 +137,48 @@ def _CancelTuningJobResponse_from_vertex( return to_object +def _CompositeReinforcementTuningRewardConfigWeightedRewardConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + root_object: Optional[Union[dict[str, Any], object]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['reward_config']) is not None: + setv( + to_object, + ['rewardConfig'], + _SingleReinforcementTuningRewardConfig_to_vertex( + getv(from_object, ['reward_config']), to_object, root_object + ), + ) + + if getv(from_object, ['weight']) is not None: + setv(to_object, ['weight'], getv(from_object, ['weight'])) + + return to_object + + +def _CompositeReinforcementTuningRewardConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + root_object: Optional[Union[dict[str, Any], object]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['weighted_reward_configs']) is not None: + setv( + to_object, + ['weightedRewardConfigs'], + [ + _CompositeReinforcementTuningRewardConfigWeightedRewardConfig_to_vertex( + item, to_object, root_object + ) + for item in getv(from_object, ['weighted_reward_configs']) + ], + ) + + return to_object + + def _CreateTuningJobConfig_to_mldev( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -269,6 +311,42 @@ def _CreateTuningJobConfig_to_mldev( ' Platform mode, not in Gemini Developer API mode.' ) + if getv(from_object, ['reward_config']) is not None: + raise ValueError( + 'reward_config parameter is only supported in Gemini Enterprise Agent' + ' Platform mode, not in Gemini Developer API mode.' + ) + + if getv(from_object, ['composite_reward_config']) is not None: + raise ValueError( + 'composite_reward_config parameter is only supported in Gemini' + ' Enterprise Agent Platform mode, not in Gemini Developer API mode.' + ) + + if getv(from_object, ['samples_per_prompt']) is not None: + raise ValueError( + 'samples_per_prompt parameter is only supported in Gemini Enterprise' + ' Agent Platform mode, not in Gemini Developer API mode.' + ) + + if getv(from_object, ['evaluate_interval']) is not None: + raise ValueError( + 'evaluate_interval parameter is only supported in Gemini Enterprise' + ' Agent Platform mode, not in Gemini Developer API mode.' + ) + + if getv(from_object, ['checkpoint_interval']) is not None: + raise ValueError( + 'checkpoint_interval parameter is only supported in Gemini Enterprise' + ' Agent Platform mode, not in Gemini Developer API mode.' + ) + + if getv(from_object, ['max_output_tokens']) is not None: + raise ValueError( + 'max_output_tokens parameter is only supported in Gemini Enterprise' + ' Agent Platform mode, not in Gemini Developer API mode.' + ) + return to_object @@ -309,6 +387,15 @@ def _CreateTuningJobConfig_to_vertex( getv(from_object, ['validation_dataset']), to_object, root_object ), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['validation_dataset']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec'], + _TuningValidationDataset_to_vertex( + getv(from_object, ['validation_dataset']), to_object, root_object + ), + ) if getv(from_object, ['tuned_model_display_name']) is not None: setv( @@ -344,6 +431,13 @@ def _CreateTuningJobConfig_to_vertex( ['distillationSpec', 'hyperParameters', 'epochCount'], getv(from_object, ['epoch_count']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['epoch_count']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'epochCount'], + getv(from_object, ['epoch_count']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -373,6 +467,17 @@ def _CreateTuningJobConfig_to_vertex( ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'], getv(from_object, ['learning_rate_multiplier']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['learning_rate_multiplier']) is not None: + setv( + parent_object, + [ + 'reinforcementTuningSpec', + 'hyperParameters', + 'learningRateMultiplier', + ], + getv(from_object, ['learning_rate_multiplier']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -423,6 +528,13 @@ def _CreateTuningJobConfig_to_vertex( ['distillationSpec', 'hyperParameters', 'adapterSize'], getv(from_object, ['adapter_size']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['adapter_size']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'adapterSize'], + getv(from_object, ['adapter_size']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -466,6 +578,13 @@ def _CreateTuningJobConfig_to_vertex( ['distillationSpec', 'hyperParameters', 'batchSize'], getv(from_object, ['batch_size']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['batch_size']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'batchSize'], + getv(from_object, ['batch_size']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -557,6 +676,54 @@ def _CreateTuningJobConfig_to_vertex( getv(from_object, ['encryption_spec']), ) + if getv(from_object, ['reward_config']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'singleRewardConfig'], + _SingleReinforcementTuningRewardConfig_to_vertex( + getv(from_object, ['reward_config']), to_object, root_object + ), + ) + + if getv(from_object, ['composite_reward_config']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'compositeRewardConfig'], + _CompositeReinforcementTuningRewardConfig_to_vertex( + getv(from_object, ['composite_reward_config']), + to_object, + root_object, + ), + ) + + if getv(from_object, ['samples_per_prompt']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'samplesPerPrompt'], + getv(from_object, ['samples_per_prompt']), + ) + + if getv(from_object, ['evaluate_interval']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'evaluateInterval'], + getv(from_object, ['evaluate_interval']), + ) + + if getv(from_object, ['checkpoint_interval']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'checkpointInterval'], + getv(from_object, ['checkpoint_interval']), + ) + + if getv(from_object, ['max_output_tokens']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'hyperParameters', 'maxOutputTokens'], + getv(from_object, ['max_output_tokens']), + ) + return to_object @@ -1168,6 +1335,24 @@ def _MultiSpeakerVoiceConfig_to_vertex( return to_object +def _ReinforcementTuningAutoraterScorer_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + root_object: Optional[Union[dict[str, Any], object]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['autorater_config']) is not None: + setv( + to_object, + ['autoraterConfig'], + _AutoraterConfig_to_vertex( + getv(from_object, ['autorater_config']), to_object, root_object + ), + ) + + return to_object + + def _ReplicatedVoiceConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -1199,6 +1384,24 @@ def _ReplicatedVoiceConfig_to_vertex( return to_object +def _SingleReinforcementTuningRewardConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + root_object: Optional[Union[dict[str, Any], object]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['autorater_scorer']) is not None: + setv( + to_object, + ['autoraterScorer'], + _ReinforcementTuningAutoraterScorer_to_vertex( + getv(from_object, ['autorater_scorer']), to_object, root_object + ), + ) + + return to_object + + def _SpeakerVoiceConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -1326,6 +1529,13 @@ def _TuningDataset_to_vertex( ['distillationSpec', 'promptDatasetUri'], getv(from_object, ['gcs_uri']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['gcs_uri']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'trainingDatasetUri'], + getv(from_object, ['gcs_uri']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -1351,6 +1561,13 @@ def _TuningDataset_to_vertex( ['distillationSpec', 'promptDatasetUri'], getv(from_object, ['vertex_dataset_resource']), ) + elif discriminator == 'REINFORCEMENT_TUNING': + if getv(from_object, ['vertex_dataset_resource']) is not None: + setv( + parent_object, + ['reinforcementTuningSpec', 'trainingDatasetUri'], + getv(from_object, ['vertex_dataset_resource']), + ) if getv(from_object, ['examples']) is not None: raise ValueError( diff --git a/google/genai/types.py b/google/genai/types.py index 3da905f02..3d5a7fc45 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -1055,6 +1055,8 @@ class TuningMethod(_common.CaseInSensitiveEnum): """Preference optimization tuning.""" DISTILLATION = 'DISTILLATION' """Distillation tuning.""" + REINFORCEMENT_TUNING = 'REINFORCEMENT_TUNING' + """Reinforcement tuning.""" class FileState(_common.CaseInSensitiveEnum): @@ -14257,6 +14259,102 @@ class TuningValidationDatasetDict(TypedDict, total=False): ] +class ReinforcementTuningAutoraterScorer(_common.BaseModel): + """Reinforcement tuning autorater scorer.""" + + autorater_config: Optional[AutoraterConfig] = Field( + default=None, description="""Autorater config for evaluation.""" + ) + + +class ReinforcementTuningAutoraterScorerDict(TypedDict, total=False): + """Reinforcement tuning autorater scorer.""" + + autorater_config: Optional[AutoraterConfigDict] + """Autorater config for evaluation.""" + + +ReinforcementTuningAutoraterScorerOrDict = Union[ + ReinforcementTuningAutoraterScorer, ReinforcementTuningAutoraterScorerDict +] + + +class SingleReinforcementTuningRewardConfig(_common.BaseModel): + """Single reinforcement tuning reward config.""" + + autorater_scorer: Optional[ReinforcementTuningAutoraterScorer] = Field( + default=None, description="""""" + ) + + +class SingleReinforcementTuningRewardConfigDict(TypedDict, total=False): + """Single reinforcement tuning reward config.""" + + autorater_scorer: Optional[ReinforcementTuningAutoraterScorerDict] + """""" + + +SingleReinforcementTuningRewardConfigOrDict = Union[ + SingleReinforcementTuningRewardConfig, + SingleReinforcementTuningRewardConfigDict, +] + + +class CompositeReinforcementTuningRewardConfigWeightedRewardConfig( + _common.BaseModel +): + """Composite reinforcement tuning reward config weighted reward config.""" + + reward_config: Optional[SingleReinforcementTuningRewardConfig] = Field( + default=None, description="""""" + ) + weight: Optional[float] = Field( + default=None, + description="""How much this single reward contributes to the total overall reward.""", + ) + + +class CompositeReinforcementTuningRewardConfigWeightedRewardConfigDict( + TypedDict, total=False +): + """Composite reinforcement tuning reward config weighted reward config.""" + + reward_config: Optional[SingleReinforcementTuningRewardConfigDict] + """""" + + weight: Optional[float] + """How much this single reward contributes to the total overall reward.""" + + +CompositeReinforcementTuningRewardConfigWeightedRewardConfigOrDict = Union[ + CompositeReinforcementTuningRewardConfigWeightedRewardConfig, + CompositeReinforcementTuningRewardConfigWeightedRewardConfigDict, +] + + +class CompositeReinforcementTuningRewardConfig(_common.BaseModel): + """Composite reinforcement tuning reward config.""" + + weighted_reward_configs: Optional[ + list[CompositeReinforcementTuningRewardConfigWeightedRewardConfig] + ] = Field(default=None, description="""""") + + +class CompositeReinforcementTuningRewardConfigDict(TypedDict, total=False): + """Composite reinforcement tuning reward config.""" + + weighted_reward_configs: Optional[ + list[CompositeReinforcementTuningRewardConfigWeightedRewardConfigDict] + ] + """""" + + +CompositeReinforcementTuningRewardConfigOrDict = Union[ + CompositeReinforcementTuningRewardConfig, + CompositeReinforcementTuningRewardConfigDict, +] + + class CreateTuningJobConfig(_common.BaseModel): """Fine-tuning job creation request - optional fields.""" @@ -14265,7 +14363,7 @@ class CreateTuningJobConfig(_common.BaseModel): ) method: Optional[TuningMethod] = Field( default=None, - description="""The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION). If not set, the default method (SFT) will be used.""", + description="""The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION or REINFORCEMENT_TUNING). If not set, the default method (SFT) will be used.""", ) validation_dataset: Optional[TuningValidationDataset] = Field( default=None, @@ -14343,6 +14441,32 @@ class CreateTuningJobConfig(_common.BaseModel): default=None, description="""The encryption spec of the tuning job. Customer-managed encryption key options for a TuningJob. If this is set, then all resources created by the TuningJob will be encrypted with provided encryption key.""", ) + reward_config: Optional[SingleReinforcementTuningRewardConfig] = Field( + default=None, + description="""Reward function configuration for reinforcement tuning. Reinforcement tuning only.""", + ) + composite_reward_config: Optional[ + CompositeReinforcementTuningRewardConfig + ] = Field( + default=None, + description="""Composite reward function configuration for reinforcement tuning. Reinforcement tuning only.""", + ) + samples_per_prompt: Optional[int] = Field( + default=None, + description="""Number of different responses to generate per prompt during tuning. Reinforcement tuning only.""", + ) + evaluate_interval: Optional[int] = Field( + default=None, + description="""How often at steps to evaluate the tuning job during training. Reinforcement tuning only.""", + ) + checkpoint_interval: Optional[int] = Field( + default=None, + description="""How often at steps to save checkpoints during training. Reinforcement tuning only.""", + ) + max_output_tokens: Optional[int] = Field( + default=None, + description="""The maximum number of tokens to generate per prompt. Reinforcement tuning only.""", + ) class CreateTuningJobConfigDict(TypedDict, total=False): @@ -14352,7 +14476,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False): """Used to override HTTP request options.""" method: Optional[TuningMethod] - """The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION). If not set, the default method (SFT) will be used.""" + """The method to use for tuning (SUPERVISED_FINE_TUNING or PREFERENCE_TUNING or DISTILLATION or REINFORCEMENT_TUNING). If not set, the default method (SFT) will be used.""" validation_dataset: Optional[TuningValidationDatasetDict] """Validation dataset for tuning. The dataset must be formatted as a JSONL file.""" @@ -14414,6 +14538,26 @@ class CreateTuningJobConfigDict(TypedDict, total=False): encryption_spec: Optional[EncryptionSpecDict] """The encryption spec of the tuning job. Customer-managed encryption key options for a TuningJob. If this is set, then all resources created by the TuningJob will be encrypted with provided encryption key.""" + reward_config: Optional[SingleReinforcementTuningRewardConfigDict] + """Reward function configuration for reinforcement tuning. Reinforcement tuning only.""" + + composite_reward_config: Optional[ + CompositeReinforcementTuningRewardConfigDict + ] + """Composite reward function configuration for reinforcement tuning. Reinforcement tuning only.""" + + samples_per_prompt: Optional[int] + """Number of different responses to generate per prompt during tuning. Reinforcement tuning only.""" + + evaluate_interval: Optional[int] + """How often at steps to evaluate the tuning job during training. Reinforcement tuning only.""" + + checkpoint_interval: Optional[int] + """How often at steps to save checkpoints during training. Reinforcement tuning only.""" + + max_output_tokens: Optional[int] + """The maximum number of tokens to generate per prompt. Reinforcement tuning only.""" + CreateTuningJobConfigOrDict = Union[ CreateTuningJobConfig, CreateTuningJobConfigDict