Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None:
return None


def _wandb_checkpoint_collection_path(
*,
from_model: str,
from_project: str,
model_entity: str | None,
default_entity: str | None,
from_entity: str | None = None,
) -> str:
resolved_entity = from_entity or model_entity or default_entity
if resolved_entity is None:
raise ValueError("A W&B entity is required to locate the source checkpoint")
return f"{resolved_entity}/{from_project}/{from_model}"


_UPSTREAM_TRAIN_METRIC_KEYS = {
"reward": "reward",
"reward_std_dev": "reward_std_dev",
Expand Down Expand Up @@ -728,6 +742,7 @@ async def _experimental_fork_checkpoint(
model: "Model",
from_model: str,
from_project: str | None = None,
from_entity: str | None = None,
from_s3_bucket: str | None = None,
not_after_step: int | None = None,
verbose: bool = False,
Expand All @@ -746,6 +761,8 @@ async def _experimental_fork_checkpoint(
model: The destination model to fork to.
from_model: The name of the source model to fork from.
from_project: The project of the source model. Defaults to model.project.
from_entity: Optional W&B entity for the source model. Defaults to the
destination model's entity, then the W&B default entity.
from_s3_bucket: Optional S3 bucket to pull the checkpoint from.
not_after_step: If provided, uses the latest checkpoint <= this step.
verbose: Whether to print verbose output.
Expand Down Expand Up @@ -812,12 +829,17 @@ async def _experimental_fork_checkpoint(
else:
# Pull from W&B artifacts
api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute]
from_entity = model.entity or api.default_entity

# Iterate all artifact versions to find the best step.
# We avoid relying on the W&B `:latest` alias because it
# may not correspond to the highest training step.
collection_path = f"{from_entity}/{from_project}/{from_model}"
collection_path = _wandb_checkpoint_collection_path(
from_model=from_model,
from_project=from_project,
from_entity=from_entity,
model_entity=model.entity,
default_entity=api.default_entity,
)
versions = api.artifacts("lora", collection_path)

best_step: int | None = None
Expand Down
92 changes: 92 additions & 0 deletions tests/unit/test_serverless_fork_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import sys
from types import SimpleNamespace

import pytest

from art.serverless.backend import (
ServerlessBackend,
_wandb_checkpoint_collection_path,
)


def test_checkpoint_collection_path_prefers_explicit_source_entity():
    """An explicit ``from_entity`` wins over the destination and default entities."""
    kwargs = {
        "from_model": "source-model",
        "from_project": "source-project",
        "from_entity": "source-entity",
        "model_entity": "destination-entity",
        "default_entity": "default-entity",
    }

    resolved = _wandb_checkpoint_collection_path(**kwargs)

    assert resolved == "source-entity/source-project/source-model"


def test_checkpoint_collection_path_falls_back_to_destination_entity():
    """Without ``from_entity``, the destination model's entity is used."""
    kwargs = {
        "from_model": "source-model",
        "from_project": "source-project",
        "from_entity": None,
        "model_entity": "destination-entity",
        "default_entity": "default-entity",
    }

    resolved = _wandb_checkpoint_collection_path(**kwargs)

    assert resolved == "destination-entity/source-project/source-model"


def test_checkpoint_collection_path_falls_back_to_default_entity():
    """With no explicit or destination entity, the default entity is used."""
    kwargs = {
        "from_model": "source-model",
        "from_project": "source-project",
        "from_entity": None,
        "model_entity": None,
        "default_entity": "default-entity",
    }

    resolved = _wandb_checkpoint_collection_path(**kwargs)

    assert resolved == "default-entity/source-project/source-model"


def test_checkpoint_collection_path_requires_an_entity():
    """A ``ValueError`` is raised when every candidate entity is ``None``."""
    no_entities = {
        "from_model": "source-model",
        "from_project": "source-project",
        "from_entity": None,
        "model_entity": None,
        "default_entity": None,
    }

    with pytest.raises(ValueError, match="W&B entity"):
        _wandb_checkpoint_collection_path(**no_entities)


@pytest.mark.asyncio
async def test_fork_checkpoint_uses_explicit_source_entity(monkeypatch):
    """The fork call queries W&B artifacts under the explicit ``from_entity``.

    A fake ``wandb`` module records each ``Api.artifacts`` call and returns no
    versions, so the fork is expected to fail with "No checkpoints found"
    after the lookup has been recorded.
    """
    recorded_lookups = []

    class FakeApi:
        # Deliberately differs from both the explicit and destination entities.
        default_entity = "default-entity"

        def __init__(self, api_key):
            assert api_key == "test-api-key"

        def artifacts(self, artifact_type, collection_path):
            recorded_lookups.append((artifact_type, collection_path))
            return []

    # Swap the real wandb module for a stub exposing only `Api`.
    monkeypatch.setitem(sys.modules, "wandb", SimpleNamespace(Api=FakeApi))

    # Bypass __init__ and attach only the client attribute this test provides.
    backend = ServerlessBackend.__new__(ServerlessBackend)
    backend._client = SimpleNamespace(api_key="test-api-key")

    destination = SimpleNamespace(
        entity="destination-entity",
        project="destination-project",
        name="destination-model",
    )

    with pytest.raises(ValueError, match="No checkpoints found"):
        await backend._experimental_fork_checkpoint(
            destination,
            from_model="source-model",
            from_project="source-project",
            from_entity="source-entity",
        )

    expected_lookup = ("lora", "source-entity/source-project/source-model")
    assert recorded_lookups == [expected_lookup]