From 462fef8be3129a2c5e0f7403cecdda443134350c Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Thu, 12 Mar 2026 12:01:12 -0700 Subject: [PATCH] feat: support hyperparameters in distillation tuning FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-genai/pull/2125 from googleapis:release-please--branches--main 0061b63851c5f406312768a8505600b9d22769d2 PiperOrigin-RevId: 882708166 --- google/genai/tests/tunings/test_tune.py | 19 +++++++++++ google/genai/tunings.py | 21 ++++++++++++ google/genai/types.py | 44 +++++++++++++++++-------- 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/google/genai/tests/tunings/test_tune.py b/google/genai/tests/tunings/test_tune.py index 3e35dc55e..bede40986 100755 --- a/google/genai/tests/tunings/test_tune.py +++ b/google/genai/tests/tunings/test_tune.py @@ -245,6 +245,25 @@ ), exception_if_mldev="not supported in Gemini API", ), + pytest_helper.TestTableItem( + name="test_tune_oss_distillation_hyperparams", + parameters=genai_types.CreateTuningJobParameters( + base_model="qwen/qwen3@qwen3-4b", + training_dataset=genai_types.TuningDataset( + gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl", + ), + config=genai_types.CreateTuningJobConfig( + method="DISTILLATION", + base_teacher_model="deepseek-ai/deepseek-r1-0528-maas", + learning_rate=1e-4, + batch_size=4, + output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test", + tuning_mode="TUNING_MODE_FULL", + http_options=VERTEX_HTTP_OPTIONS, + ), + ), + exception_if_mldev="not supported in Gemini API", + ), pytest_helper.TestTableItem( name="test_tune_encryption_spec", parameters=genai_types.CreateTuningJobParameters( diff --git a/google/genai/tunings.py b/google/genai/tunings.py index cba933af0..da4f2461b 100644 --- a/google/genai/tunings.py +++ b/google/genai/tunings.py @@ -409,6 +409,13 @@ def _CreateTuningJobConfig_to_vertex( ['supervisedTuningSpec', 'tuningMode'], getv(from_object, ['tuning_mode']), ) + elif discriminator == 'DISTILLATION': + if getv(from_object, ['tuning_mode']) is not None: + setv( + parent_object, + ['distillationSpec', 'tuningMode'], + getv(from_object, ['tuning_mode']), + ) if getv(from_object, ['custom_base_model']) is not None: setv( @@ -427,6 +434,13 @@ def _CreateTuningJobConfig_to_vertex( ['supervisedTuningSpec', 'hyperParameters', 'batchSize'], getv(from_object, ['batch_size']), ) + elif discriminator == 'DISTILLATION': + if getv(from_object, ['batch_size']) is not None: + setv( + parent_object, + ['distillationSpec', 'hyperParameters', 'batchSize'], + getv(from_object, ['batch_size']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: @@ -438,6 +452,13 @@ def _CreateTuningJobConfig_to_vertex( ['supervisedTuningSpec', 'hyperParameters', 'learningRate'], getv(from_object, ['learning_rate']), ) + elif discriminator == 'DISTILLATION': + if getv(from_object, ['learning_rate']) is not None: + setv( + parent_object, + ['distillationSpec', 'hyperParameters', 'learningRate'], + getv(from_object, ['learning_rate']), + ) discriminator = getv(root_object, ['config', 'method']) if discriminator is None: diff --git a/google/genai/types.py b/google/genai/types.py index d0cd822ac..9ac79ff61 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -11217,11 +11217,17 @@ class PreferenceOptimizationSpecDict(TypedDict, total=False): class DistillationHyperParameters(_common.BaseModel): - """Hyperparameters for Distillation. - - This data type is not supported in Gemini API. - """ + """Hyperparameters for distillation.""" + batch_size: Optional[int] = Field( + default=None, + description="""The batch size hyperparameter for tuning. + This is only supported for OSS models in Vertex.""", + ) + learning_rate: Optional[float] = Field( + default=None, + description="""The learning rate for tuning. OSS models only.""", + ) adapter_size: Optional[AdapterSize] = Field( default=None, description="""Optional. Adapter size for distillation.""" ) @@ -11236,10 +11242,14 @@ class DistillationHyperParameters(_common.BaseModel): class DistillationHyperParametersDict(TypedDict, total=False): - """Hyperparameters for Distillation. + """Hyperparameters for distillation.""" - This data type is not supported in Gemini API. - """ + batch_size: Optional[int] + """The batch size hyperparameter for tuning. + This is only supported for OSS models in Vertex.""" + + learning_rate: Optional[float] + """The learning rate for tuning. OSS models only.""" adapter_size: Optional[AdapterSize] """Optional. Adapter size for distillation.""" @@ -11263,14 +11273,17 @@ class DistillationSpec(_common.BaseModel): default=None, description="""The GCS URI of the prompt dataset to use during distillation.""", ) - base_teacher_model: Optional[str] = Field( - default=None, - description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""", + tuning_mode: Optional[TuningMode] = Field( + default=None, description="""Tuning mode for tuning.""" ) hyper_parameters: Optional[DistillationHyperParameters] = Field( default=None, description="""Optional. Hyperparameters for Distillation.""", ) + base_teacher_model: Optional[str] = Field( + default=None, + description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""", + ) pipeline_root_directory: Optional[str] = Field( default=None, description="""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""", @@ -11299,12 +11312,15 @@ class DistillationSpecDict(TypedDict, total=False): prompt_dataset_uri: Optional[str] """The GCS URI of the prompt dataset to use during distillation.""" - base_teacher_model: Optional[str] - """The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""" + tuning_mode: Optional[TuningMode] + """Tuning mode for tuning.""" hyper_parameters: Optional[DistillationHyperParametersDict] """Optional. Hyperparameters for Distillation.""" + base_teacher_model: Optional[str] + """The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""" + pipeline_root_directory: Optional[str] """Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""" @@ -13421,7 +13437,7 @@ class CreateTuningJobConfig(_common.BaseModel): default=None, description="""Adapter size for tuning.""" ) tuning_mode: Optional[TuningMode] = Field( - default=None, description="""Tuning mode for SFT tuning.""" + default=None, description="""Tuning mode for tuning.""" ) custom_base_model: Optional[str] = Field( default=None, @@ -13502,7 +13518,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False): """Adapter size for tuning.""" tuning_mode: Optional[TuningMode] - """Tuning mode for SFT tuning.""" + """Tuning mode for tuning.""" custom_base_model: Optional[str] """Custom base model for tuning. This is only supported for OSS models in Vertex."""