Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions google/genai/tests/tunings/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,25 @@
),
exception_if_mldev="not supported in Gemini API",
),
pytest_helper.TestTableItem(
name="test_tune_oss_distillation_hyperparams",
parameters=genai_types.CreateTuningJobParameters(
base_model="qwen/qwen3@qwen3-4b",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
),
config=genai_types.CreateTuningJobConfig(
method="DISTILLATION",
base_teacher_model="deepseek-ai/deepseek-r1-0528-maas",
learning_rate=1e-4,
batch_size=4,
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
tuning_mode="TUNING_MODE_FULL",
http_options=VERTEX_HTTP_OPTIONS,
),
),
exception_if_mldev="not supported in Gemini API",
),
pytest_helper.TestTableItem(
name="test_tune_encryption_spec",
parameters=genai_types.CreateTuningJobParameters(
Expand Down
21 changes: 21 additions & 0 deletions google/genai/tunings.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ def _CreateTuningJobConfig_to_vertex(
['supervisedTuningSpec', 'tuningMode'],
getv(from_object, ['tuning_mode']),
)
elif discriminator == 'DISTILLATION':
if getv(from_object, ['tuning_mode']) is not None:
setv(
parent_object,
['distillationSpec', 'tuningMode'],
getv(from_object, ['tuning_mode']),
)

if getv(from_object, ['custom_base_model']) is not None:
setv(
Expand All @@ -427,6 +434,13 @@ def _CreateTuningJobConfig_to_vertex(
['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
getv(from_object, ['batch_size']),
)
elif discriminator == 'DISTILLATION':
if getv(from_object, ['batch_size']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'batchSize'],
getv(from_object, ['batch_size']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand All @@ -438,6 +452,13 @@ def _CreateTuningJobConfig_to_vertex(
['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
getv(from_object, ['learning_rate']),
)
elif discriminator == 'DISTILLATION':
if getv(from_object, ['learning_rate']) is not None:
setv(
parent_object,
['distillationSpec', 'hyperParameters', 'learningRate'],
getv(from_object, ['learning_rate']),
)

discriminator = getv(root_object, ['config', 'method'])
if discriminator is None:
Expand Down
44 changes: 30 additions & 14 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11217,11 +11217,17 @@ class PreferenceOptimizationSpecDict(TypedDict, total=False):


class DistillationHyperParameters(_common.BaseModel):
"""Hyperparameters for Distillation.

This data type is not supported in Gemini API.
"""
"""Hyperparameters for distillation."""

batch_size: Optional[int] = Field(
default=None,
description="""The batch size hyperparameter for tuning.
This is only supported for OSS models in Vertex.""",
)
learning_rate: Optional[float] = Field(
default=None,
description="""The learning rate for tuning. OSS models only.""",
)
adapter_size: Optional[AdapterSize] = Field(
default=None, description="""Optional. Adapter size for distillation."""
)
Expand All @@ -11236,10 +11242,14 @@ class DistillationHyperParameters(_common.BaseModel):


class DistillationHyperParametersDict(TypedDict, total=False):
"""Hyperparameters for Distillation.
"""Hyperparameters for distillation."""

This data type is not supported in Gemini API.
"""
batch_size: Optional[int]
"""The batch size hyperparameter for tuning.
This is only supported for OSS models in Vertex."""

learning_rate: Optional[float]
"""The learning rate for tuning. OSS models only."""

adapter_size: Optional[AdapterSize]
"""Optional. Adapter size for distillation."""
Expand All @@ -11263,14 +11273,17 @@ class DistillationSpec(_common.BaseModel):
default=None,
description="""The GCS URI of the prompt dataset to use during distillation.""",
)
base_teacher_model: Optional[str] = Field(
default=None,
description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""",
tuning_mode: Optional[TuningMode] = Field(
default=None, description="""Tuning mode for tuning."""
)
hyper_parameters: Optional[DistillationHyperParameters] = Field(
default=None,
description="""Optional. Hyperparameters for Distillation.""",
)
base_teacher_model: Optional[str] = Field(
default=None,
description="""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models).""",
)
pipeline_root_directory: Optional[str] = Field(
default=None,
description="""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts.""",
Expand Down Expand Up @@ -11299,12 +11312,15 @@ class DistillationSpecDict(TypedDict, total=False):
prompt_dataset_uri: Optional[str]
"""The GCS URI of the prompt dataset to use during distillation."""

base_teacher_model: Optional[str]
"""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models)."""
tuning_mode: Optional[TuningMode]
"""Tuning mode for tuning."""

hyper_parameters: Optional[DistillationHyperParametersDict]
"""Optional. Hyperparameters for Distillation."""

base_teacher_model: Optional[str]
"""The base teacher model that is being distilled. See [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#supported_models)."""

pipeline_root_directory: Optional[str]
"""Deprecated. A path in a Cloud Storage bucket, which will be treated as the root output directory of the distillation pipeline. It is used by the system to generate the paths of output artifacts."""

Expand Down Expand Up @@ -13421,7 +13437,7 @@ class CreateTuningJobConfig(_common.BaseModel):
default=None, description="""Adapter size for tuning."""
)
tuning_mode: Optional[TuningMode] = Field(
default=None, description="""Tuning mode for SFT tuning."""
default=None, description="""Tuning mode for tuning."""
)
custom_base_model: Optional[str] = Field(
default=None,
Expand Down Expand Up @@ -13502,7 +13518,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
"""Adapter size for tuning."""

tuning_mode: Optional[TuningMode]
"""Tuning mode for SFT tuning."""
"""Tuning mode for tuning."""

custom_base_model: Optional[str]
"""Custom base model for tuning. This is only supported for OSS models in Vertex."""
Expand Down
Loading