2121)
2222from workflowai .core .client ._types import RunParams
2323from workflowai .core .client ._utils import (
24+ ModelInstructionTemperature ,
2425 build_retryable_wait ,
2526 default_validator ,
2627 global_default_version_reference ,
@@ -123,32 +124,42 @@ class _PreparedRun(NamedTuple):
123124
124125 def _sanitize_version (self , params : VersionRunParams ) -> Union [str , int , dict [str , Any ]]:
125126 """Combine a version requested at runtime and the version requested at build time."""
127+ # Version contains either the requested version or the default version
128+ # this is important to combine the check below of whether the version is a remote version (e-g production)
129+ # or a local version (VersionProperties)
126130 version = params .get ("version" , self .version )
127- model = params .get ("model" )
128- instructions = params .get ("instructions" )
129- temperature = params .get ("temperature" )
130131
131- has_property_overrides = bool (model or instructions or temperature or self ._tools )
132+ # Combine all overrides in a tuple
133+ overrides = ModelInstructionTemperature .from_dict (params )
134+ has_property_overrides = bool (self ._tools or any (o is not None for o in overrides ))
132135
136+ # Version exists and is a remote version
133137 if version and not isinstance (version , VersionProperties ):
138+ # No property override so we return as is
134139 if not has_property_overrides and not self ._tools :
135140 return version
136141 # In the case where the version requested a build time was a remote version
137142 # (either an ID or an environment), we use an empty template for the version
138- logger .warning ("Overriding remove version with a local one" )
143+ logger .warning ("Overriding remote version with a local one" )
139144 version = VersionProperties ()
140145
146+ # Version does not exist and there are no overrides
147+ # We return the default version
141148 if not version and not has_property_overrides :
142149 g = global_default_version_reference ()
143150 return g .model_dump (by_alias = True , exclude_unset = True ) if isinstance (g , VersionProperties ) else g
144151
145152 dumped = version .model_dump (by_alias = True , exclude_unset = True ) if version else {}
146153
147- if not dumped .get ("model" ):
154+ requested = ModelInstructionTemperature .from_version (version )
155+ defaults = ModelInstructionTemperature .from_version (self .version )
156+ combined = ModelInstructionTemperature .combine (overrides , requested , defaults )
157+
158+ if not combined .model :
148159 # We always provide a default model since it is required by the API
149160 import workflowai
150161
151- dumped [ "model" ] = workflowai .DEFAULT_MODEL
162+ combined = combined . _replace ( model = workflowai .DEFAULT_MODEL )
152163
153164 if self ._tools :
154165 dumped ["enabled_tools" ] = [
@@ -161,12 +172,12 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
161172 for tool in self ._tools .values ()
162173 ]
163174 # Finally we apply the property overrides
164- if model :
165- dumped ["model" ] = model
166- if instructions :
167- dumped ["instructions" ] = instructions
168- if temperature :
169- dumped ["temperature" ] = temperature
175+ if combined . model is not None :
176+ dumped ["model" ] = combined . model
177+ if combined . instructions is not None :
178+ dumped ["instructions" ] = combined . instructions
179+ if combined . temperature is not None :
180+ dumped ["temperature" ] = combined . temperature
170181 return dumped
171182
172183 async def _prepare_run (self , agent_input : AgentInput , stream : bool , ** kwargs : Unpack [RunParams [AgentOutput ]]):
0 commit comments