Skip to content

Commit 402d7fe

Browse files
author
Pierre
authored
Merge branch 'main' into pierre-examples-improvements
2 parents 1d2531e + ad52ded commit 402d7fe

File tree

5 files changed

+86
-15
lines changed

5 files changed

+86
-15
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.6.0.dev21"
3+
version = "0.6.0.dev22"
44
description = ""
55
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
66
readme = "README.md"

workflowai/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from workflowai.core.domain.model import Model as Model
1313
from workflowai.core.domain.run import Run as Run
1414
from workflowai.core.domain.version import Version as Version
15+
from workflowai.core.domain.version_properties import VersionProperties as VersionProperties
1516
from workflowai.core.domain.version_reference import (
1617
VersionReference as VersionReference,
1718
)

workflowai/core/client/_utils.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
import asyncio
55
import os
66
import re
7+
from collections.abc import Mapping
78
from json import JSONDecodeError
89
from time import time
9-
from typing import Any
10+
from typing import Any, NamedTuple, Optional, Union
11+
12+
from typing_extensions import Self
1013

1114
from workflowai.core._common_types import OutputValidator
1215
from workflowai.core._logger import logger
1316
from workflowai.core.domain.errors import BaseError, WorkflowAIError
1417
from workflowai.core.domain.task import AgentOutput
18+
from workflowai.core.domain.version_properties import VersionProperties
1519
from workflowai.core.domain.version_reference import VersionReference
1620
from workflowai.core.utils._pydantic import partial_model
1721

@@ -113,3 +117,38 @@ def global_default_version_reference() -> VersionReference:
113117
logger.warning("Invalid default version: %s", version)
114118

115119
return "production"
120+
121+
122+
class ModelInstructionTemperature(NamedTuple):
123+
"""A combination of run properties, with useful method
124+
for combination"""
125+
126+
model: Optional[str] = None
127+
instructions: Optional[str] = None
128+
temperature: Optional[float] = None
129+
130+
@classmethod
131+
def from_dict(cls, d: Mapping[str, Any]):
132+
return cls(
133+
model=d.get("model"),
134+
instructions=d.get("instructions"),
135+
temperature=d.get("temperature"),
136+
)
137+
138+
@classmethod
139+
def from_version(cls, version: Union[int, str, VersionProperties, None]):
140+
if isinstance(version, VersionProperties):
141+
return cls(
142+
model=version.model,
143+
instructions=version.instructions,
144+
temperature=version.temperature,
145+
)
146+
return cls()
147+
148+
@classmethod
149+
def combine(cls, *args: Self):
150+
return cls(
151+
model=next((a.model for a in args if a.model is not None), None),
152+
instructions=next((a.instructions for a in args if a.instructions is not None), None),
153+
temperature=next((a.temperature for a in args if a.temperature is not None), None),
154+
)

workflowai/core/client/agent.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from workflowai.core.client._types import RunParams
2323
from workflowai.core.client._utils import (
24+
ModelInstructionTemperature,
2425
build_retryable_wait,
2526
default_validator,
2627
global_default_version_reference,
@@ -123,32 +124,42 @@ class _PreparedRun(NamedTuple):
123124

124125
def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[str, Any]]:
125126
"""Combine a version requested at runtime and the version requested at build time."""
127+
# Version contains either the requested version or the default version
128+
# this is important to combine the check below of whether the version is a remote version (e-g production)
129+
# or a local version (VersionProperties)
126130
version = params.get("version", self.version)
127-
model = params.get("model")
128-
instructions = params.get("instructions")
129-
temperature = params.get("temperature")
130131

131-
has_property_overrides = bool(model or instructions or temperature or self._tools)
132+
# Combine all overrides in a tuple
133+
overrides = ModelInstructionTemperature.from_dict(params)
134+
has_property_overrides = bool(self._tools or any(o is not None for o in overrides))
132135

136+
# Version exists and is a remote version
133137
if version and not isinstance(version, VersionProperties):
138+
# No property override so we return as is
134139
if not has_property_overrides and not self._tools:
135140
return version
136141
# In the case where the version requested a build time was a remote version
137142
# (either an ID or an environment), we use an empty template for the version
138-
logger.warning("Overriding remove version with a local one")
143+
logger.warning("Overriding remote version with a local one")
139144
version = VersionProperties()
140145

146+
# Version does not exist and there are no overrides
147+
# We return the default version
141148
if not version and not has_property_overrides:
142149
g = global_default_version_reference()
143150
return g.model_dump(by_alias=True, exclude_unset=True) if isinstance(g, VersionProperties) else g
144151

145152
dumped = version.model_dump(by_alias=True, exclude_unset=True) if version else {}
146153

147-
if not dumped.get("model"):
154+
requested = ModelInstructionTemperature.from_version(version)
155+
defaults = ModelInstructionTemperature.from_version(self.version)
156+
combined = ModelInstructionTemperature.combine(overrides, requested, defaults)
157+
158+
if not combined.model:
148159
# We always provide a default model since it is required by the API
149160
import workflowai
150161

151-
dumped["model"] = workflowai.DEFAULT_MODEL
162+
combined = combined._replace(model=workflowai.DEFAULT_MODEL)
152163

153164
if self._tools:
154165
dumped["enabled_tools"] = [
@@ -161,12 +172,12 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
161172
for tool in self._tools.values()
162173
]
163174
# Finally we apply the property overrides
164-
if model:
165-
dumped["model"] = model
166-
if instructions:
167-
dumped["instructions"] = instructions
168-
if temperature:
169-
dumped["temperature"] = temperature
175+
if combined.model is not None:
176+
dumped["model"] = combined.model
177+
if combined.instructions is not None:
178+
dumped["instructions"] = combined.instructions
179+
if combined.temperature is not None:
180+
dumped["temperature"] = combined.temperature
170181
return dumped
171182

172183
async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]):

workflowai/core/client/agent_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,26 @@ def test_only_model_privider(self, agent: Agent[HelloTaskInput, HelloTaskOutput]
385385
"instructions": "You are a helpful assistant.",
386386
}
387387

388+
def test_with_explicit_version_without_instructions(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
389+
"""In the case where the agent has instructions but we send a version without instructions,
390+
we use the instructions from the agent"""
391+
392+
agent.version = VersionProperties(instructions="You are a helpful assistant.")
393+
sanitized = agent._sanitize_version({"version": VersionProperties(model="gpt-4o-latest")}) # pyright: ignore [reportPrivateUsage]
394+
assert sanitized == {
395+
"model": "gpt-4o-latest",
396+
"instructions": "You are a helpful assistant.",
397+
}
398+
399+
def test_override_with_0_temperature(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
400+
"""Test that a 0 temperature is not overridden by the default version"""
401+
agent.version = VersionProperties(temperature=0.7)
402+
sanitized = agent._sanitize_version({"version": VersionProperties(temperature=0)}) # pyright: ignore [reportPrivateUsage]
403+
assert sanitized == {
404+
"model": "gemini-1.5-pro-latest",
405+
"temperature": 0.0,
406+
}
407+
388408

389409
class TestListModels:
390410
async def test_list_models(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):

0 commit comments

Comments
 (0)