Skip to content

Commit 37df28a

Browse files
speedstorm1 and copybara-github
authored and committed
feat: support hyperparameters in distillation tuning
FUTURE_COPYBARA_INTEGRATE_REVIEW=#2125 from googleapis:release-please--branches--main 0061b63 PiperOrigin-RevId: 882708166
1 parent 3442105 commit 37df28a

3 files changed

Lines changed: 70 additions & 14 deletions

File tree

google/genai/tests/tunings/test_tune.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,25 @@
245245
),
246246
exception_if_mldev="not supported in Gemini API",
247247
),
248+
pytest_helper.TestTableItem(
249+
name="test_tune_oss_distillation_hyperparams",
250+
parameters=genai_types.CreateTuningJobParameters(
251+
base_model="qwen/qwen3@qwen3-4b",
252+
training_dataset=genai_types.TuningDataset(
253+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
254+
),
255+
config=genai_types.CreateTuningJobConfig(
256+
method="DISTILLATION",
257+
base_teacher_model="deepseek-ai/deepseek-r1-0528-maas",
258+
learning_rate=1e-4,
259+
batch_size=4,
260+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
261+
tuning_mode="TUNING_MODE_FULL",
262+
http_options=VERTEX_HTTP_OPTIONS,
263+
),
264+
),
265+
exception_if_mldev="not supported in Gemini API",
266+
),
248267
pytest_helper.TestTableItem(
249268
name="test_tune_encryption_spec",
250269
parameters=genai_types.CreateTuningJobParameters(

google/genai/tunings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ def _CreateTuningJobConfig_to_vertex(
409409
['supervisedTuningSpec', 'tuningMode'],
410410
getv(from_object, ['tuning_mode']),
411411
)
412+
elif discriminator == 'DISTILLATION':
413+
if getv(from_object, ['tuning_mode']) is not None:
414+
setv(
415+
parent_object,
416+
['distillationSpec', 'tuningMode'],
417+
getv(from_object, ['tuning_mode']),
418+
)
412419

413420
if getv(from_object, ['custom_base_model']) is not None:
414421
setv(
@@ -427,6 +434,13 @@ def _CreateTuningJobConfig_to_vertex(
427434
['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
428435
getv(from_object, ['batch_size']),
429436
)
437+
elif discriminator == 'DISTILLATION':
438+
if getv(from_object, ['batch_size']) is not None:
439+
setv(
440+
parent_object,
441+
['distillationSpec', 'hyperParameters', 'batchSize'],
442+
getv(from_object, ['batch_size']),
443+
)
430444

431445
discriminator = getv(root_object, ['config', 'method'])
432446
if discriminator is None:
@@ -438,6 +452,13 @@ def _CreateTuningJobConfig_to_vertex(
438452
['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
439453
getv(from_object, ['learning_rate']),
440454
)
455+
elif discriminator == 'DISTILLATION':
456+
if getv(from_object, ['learning_rate']) is not None:
457+
setv(
458+
parent_object,
459+
['distillationSpec', 'hyperParameters', 'learningRate'],
460+
getv(from_object, ['learning_rate']),
461+
)
441462

442463
discriminator = getv(root_object, ['config', 'method'])
443464
if discriminator is None:

google/genai/types.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11217,11 +11217,17 @@ class PreferenceOptimizationSpecDict(TypedDict, total=False):
1121711217

1121811218

1121911219
class DistillationHyperParameters(_common.BaseModel):
11220-
"""Hyperparameters for Distillation.
11221-
11222-
This data type is not supported in Gemini API.
11223-
"""
11220+
"""Hyperparameters for distillation."""
1122411221

11222+
batch_size: Optional[int] = Field(
11223+
default=None,
11224+
description="""The batch size hyperparameter for tuning.
11225+
This is only supported for OSS models in Vertex.""",
11226+
)
11227+
learning_rate: Optional[float] = Field(
11228+
default=None,
11229+
description="""The learning rate for tuning. OSS models only.""",
11230+
)
1122511231
adapter_size: Optional[AdapterSize] = Field(
1122611232
default=None, description="""Optional. Adapter size for distillation."""
1122711233
)
@@ -11236,10 +11242,14 @@ class DistillationHyperParameters(_common.BaseModel):
1123611242

1123711243

1123811244
class DistillationHyperParametersDict(TypedDict, total=False):
11239-
"""Hyperparameters for Distillation.
11245+
"""Hyperparameters for distillation."""
1124011246

11241-
This data type is not supported in Gemini API.
11242-
"""
11247+
batch_size: Optional[int]
11248+
"""The batch size hyperparameter for tuning.
11249+
This is only supported for OSS models in Vertex."""
11250+
11251+
learning_rate: Optional[float]
11252+
"""The learning rate for tuning. OSS models only."""
1124311253

1124411254
adapter_size: Optional[AdapterSize]
1124511255
"""Optional. Adapter size for distillation."""
@@ -11263,14 +11273,17 @@ class DistillationSpec(_common.BaseModel):
1126311273
default=None,
1126411274
description="""The GCS URI of the prompt dataset to use during distillation.""",
1126511275
)
11266-
base_teacher_model: Optional[str] = Field(
11267-
default=None,
11268-
description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""",
11276+
tuning_mode: Optional[TuningMode] = Field(
11277+
default=None, description="""Tuning mode for tuning."""
1126911278
)
1127011279
hyper_parameters: Optional[DistillationHyperParameters] = Field(
1127111280
default=None,
1127211281
description="""Optional. Hyperparameters for Distillation.""",
1127311282
)
11283+
base_teacher_model: Optional[str] = Field(
11284+
default=None,
11285+
description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""",
11286+
)
1127411287
pipeline_root_directory: Optional[str] = Field(
1127511288
default=None,
1127611289
description="""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""",
@@ -11299,12 +11312,15 @@ class DistillationSpecDict(TypedDict, total=False):
1129911312
prompt_dataset_uri: Optional[str]
1130011313
"""The GCS URI of the prompt dataset to use during distillation."""
1130111314

11302-
base_teacher_model: Optional[str]
11303-
"""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models)."""
11315+
tuning_mode: Optional[TuningMode]
11316+
"""Tuning mode for tuning."""
1130411317

1130511318
hyper_parameters: Optional[DistillationHyperParametersDict]
1130611319
"""Optional. Hyperparameters for Distillation."""
1130711320

11321+
base_teacher_model: Optional[str]
11322+
"""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models)."""
11323+
1130811324
pipeline_root_directory: Optional[str]
1130911325
"""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts."""
1131011326

@@ -13421,7 +13437,7 @@ class CreateTuningJobConfig(_common.BaseModel):
1342113437
default=None, description="""Adapter size for tuning."""
1342213438
)
1342313439
tuning_mode: Optional[TuningMode] = Field(
13424-
default=None, description="""Tuning mode for SFT tuning."""
13440+
default=None, description="""Tuning mode for tuning."""
1342513441
)
1342613442
custom_base_model: Optional[str] = Field(
1342713443
default=None,
@@ -13502,7 +13518,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
1350213518
"""Adapter size for tuning."""
1350313519

1350413520
tuning_mode: Optional[TuningMode]
13505-
"""Tuning mode for SFT tuning."""
13521+
"""Tuning mode for tuning."""
1350613522

1350713523
custom_base_model: Optional[str]
1350813524
"""Custom base model for tuning. This is only supported for OSS models in Vertex."""

0 commit comments

Comments (0)