diff --git a/packages/ui/src/stores/session-actions.ts b/packages/ui/src/stores/session-actions.ts index 3cdf0321..af7515f8 100644 --- a/packages/ui/src/stores/session-actions.ts +++ b/packages/ui/src/stores/session-actions.ts @@ -11,6 +11,15 @@ import { removeMessagePartV2, removeMessageV2 } from "./message-v2/bridge" import { getLogger } from "../lib/logger" import { requestData } from "../lib/opencode-api" import { clearConversationPlaybackForSession } from "./conversation-speech" +import { + getOverride, + setAgentOverride, + setModelOverride, + clearOverrides, + clearAgentOverride, + clearModelOverride, + getModelOverrideValue, +} from "./session-overrides" const log = getLogger("actions") @@ -183,20 +192,25 @@ async function sendMessage( /* trigger reactivity for legacy session data */ }) + // Only send agent/model when the user explicitly selected them (override state). + // session.agent/session.model are display state (what the server has), NOT user intent. + const userOverride = getOverride(instanceId, sessionId) + const overrideModel = userOverride?.model + ? getModelOverrideValue(instanceId, sessionId) + : undefined + const requestBody = { parts: requestParts, - ...(session.agent && { agent: session.agent }), - ...(session.model.providerId && - session.model.modelId && { - model: { - providerID: session.model.providerId, - modelID: session.model.modelId, - }, - }), - ...(session.model.providerId && - session.model.modelId && + ...(userOverride?.agent && { agent: userOverride.agent }), + ...(overrideModel && { + model: { + providerID: overrideModel.providerId, + modelID: overrideModel.modelId, + }, + }), + ...(overrideModel && (() => { - const variant = getThinkingVariantToSend(instanceId, session.model) + const variant = getThinkingVariantToSend(instanceId, overrideModel) return variant ? { variant } : {} })()), } @@ -216,6 +230,7 @@ async function sendMessage( }), "session.promptAsync", ) + clearOverrides(instanceId, sessionId) } catch (error) { log.error("Failed to send prompt", error) throw error @@ -254,14 +269,20 @@ async function executeCustomCommand( messageID: createId("msg"), } - if (session.agent) { - body.agent = session.agent + // Only send agent/model when the user explicitly selected them (override state). + const userOverride = getOverride(instanceId, sessionId) + + if (userOverride?.agent) { + body.agent = userOverride.agent } - if (session.model.providerId && session.model.modelId) { - body.model = `${session.model.providerId}/${session.model.modelId}` - const variant = getThinkingVariantToSend(instanceId, session.model) - if (variant) body.variant = variant + if (userOverride?.model) { + const overrideModel = getModelOverrideValue(instanceId, sessionId) + if (overrideModel) { + body.model = `${overrideModel.providerId}/${overrideModel.modelId}` + const variant = getThinkingVariantToSend(instanceId, overrideModel) + if (variant) body.variant = variant + } } await requestData( @@ -271,6 +292,7 @@ async function executeCustomCommand( }), "session.command", ) + clearOverrides(instanceId, sessionId) } async function runShellCommand(instanceId: string, sessionId: string, command: string): Promise { @@ -287,7 +309,9 @@ async function runShellCommand(instanceId: string, sessionId: string, command: s throw new Error("Session not found") } - const agent = session.agent || "build" + // Only use override agent — fall back to "build" (OpenCode default) if user hasn't explicitly set one. + const userOverride = getOverride(instanceId, sessionId) + const agent = userOverride?.agent || "build" await requestData( client.session.shell({ @@ -297,6 +321,9 @@ async function runShellCommand(instanceId: string, sessionId: string, command: s }), "session.shell", ) + if (userOverride?.agent) { + clearAgentOverride(instanceId, sessionId) + } } async function abortSession(instanceId: string, sessionId: string): Promise { @@ -335,6 +362,7 @@ async function updateSessionAgent(instanceId: string, sessionId: string, agent: const nextModel = await getDefaultModel(instanceId, agent) const shouldApplyModel = isModelValid(instanceId, nextModel) + // Update local display state withSession(instanceId, sessionId, (current) => { current.agent = agent if (shouldApplyModel) { @@ -342,6 +370,15 @@ async function updateSessionAgent(instanceId: string, sessionId: string, agent: } }) + // Record explicit user override so the next prompt sends these values. + setAgentOverride(instanceId, sessionId, agent) + if (shouldApplyModel) { + setModelOverride(instanceId, sessionId, nextModel) + } else { + // Agent changed but no valid model found — clear any stale model override + clearModelOverride(instanceId, sessionId) + } + if (agent && shouldApplyModel) { await setAgentModelPreference(instanceId, agent, nextModel) } @@ -367,10 +404,14 @@ async function updateSessionModel( return } + // Update local display state withSession(instanceId, sessionId, (current) => { current.model = model }) + // Record explicit user override so the next prompt sends this model. + setModelOverride(instanceId, sessionId, model) + if (session.agent) { await setAgentModelPreference(instanceId, session.agent, model) } diff --git a/packages/ui/src/stores/session-api.ts b/packages/ui/src/stores/session-api.ts index 36609380..057c52a7 100644 --- a/packages/ui/src/stores/session-api.ts +++ b/packages/ui/src/stores/session-api.ts @@ -49,6 +49,11 @@ import { removeParentSessionMapping, setWorktreeSlugForParentSession, } from "./worktrees" +import { + clearOverrides, + setAgentOverride, + setModelOverride, +} from "./session-overrides" const log = getLogger("api") @@ -443,6 +448,16 @@ async function createSession(instanceId: string, agent?: string): Promise { const next = new Map(prev) const loadedSet = next.get(instanceId) || new Set() diff --git a/packages/ui/src/stores/session-overrides.ts b/packages/ui/src/stores/session-overrides.ts new file mode 100644 index 00000000..487b4a6c --- /dev/null +++ b/packages/ui/src/stores/session-overrides.ts @@ -0,0 +1,132 @@ +/** + * Session overrides track user-explicit agent/model selections. + * + * Invariant: + * `session.agent` / `session.model` = display state (what the client believes the server has). + * `overrides.agent` / `overrides.model` = user intent (what the user wants to force on the next request). + * + * Only overrides are sent in promptAsync / command request bodies. + * Server-observed values (from SSE, message loading) update display state but do NOT create overrides. + */ + +import { createSignal } from "solid-js" + +export interface SessionOverride { + agent?: string + /** Format: "providerId/modelId" */ + model?: string +} + +type OverrideMap = Map> + +const [overrides, setOverrides] = createSignal(new Map()) + +function getOverride(instanceId: string, sessionId: string): SessionOverride | undefined { + return overrides().get(instanceId)?.get(sessionId) +} + +function setOverride(instanceId: string, sessionId: string, patch: Partial): void { + setOverrides((prev) => { + const next = new Map(prev) + const instanceMap = new Map(next.get(instanceId) ?? new Map()) + const existing = instanceMap.get(sessionId) ?? {} + const merged: SessionOverride = { ...existing, ...patch } + + // Clean up undefined/empty values + if (!merged.agent) delete merged.agent + if (!merged.model) delete merged.model + + if (!merged.agent && !merged.model) { + instanceMap.delete(sessionId) + } else { + instanceMap.set(sessionId, merged) + } + + if (instanceMap.size === 0) { + next.delete(instanceId) + } else { + next.set(instanceId, instanceMap) + } + return next + }) +} + +/** Record that the user explicitly selected an agent for this session. */ +function setAgentOverride(instanceId: string, sessionId: string, agent: string): void { + setOverride(instanceId, sessionId, { agent }) +} + +/** Record that the user explicitly selected a model for this session. */ +function setModelOverride(instanceId: string, sessionId: string, model: { providerId: string; modelId: string }): void { + if (!model.providerId || !model.modelId) return + setOverride(instanceId, sessionId, { model: `${model.providerId}/${model.modelId}` }) +} + +/** Clear the agent override (e.g., after server confirms the agent change). */ +function clearAgentOverride(instanceId: string, sessionId: string): void { + const current = getOverride(instanceId, sessionId) + if (!current?.agent) return + setOverride(instanceId, sessionId, { agent: undefined }) +} + +/** Clear the model override (e.g., after server confirms the model change). */ +function clearModelOverride(instanceId: string, sessionId: string): void { + const current = getOverride(instanceId, sessionId) + if (!current?.model) return + setOverride(instanceId, sessionId, { model: undefined }) +} + +/** Clear both overrides for a session (e.g., when server state diverges). */ +function clearOverrides(instanceId: string, sessionId: string): void { + setOverrides((prev) => { + const next = new Map(prev) + const instanceMap = next.get(instanceId) + if (!instanceMap) return prev + if (!instanceMap.has(sessionId)) return prev + const updated = new Map(instanceMap) + updated.delete(sessionId) + if (updated.size === 0) { + next.delete(instanceId) + } else { + next.set(instanceId, updated) + } + return next + }) +} + +/** Clear all overrides for an instance (e.g., on disconnect). */ +function clearInstanceOverrides(instanceId: string): void { + setOverrides((prev) => { + if (!prev.has(instanceId)) return prev + const next = new Map(prev) + next.delete(instanceId) + return next + }) +} + +/** + * Parse a model override string back into providerId/modelId. + * Returns undefined if no model override exists. + */ +function getModelOverrideValue(instanceId: string, sessionId: string): { providerId: string; modelId: string } | undefined { + const raw = getOverride(instanceId, sessionId)?.model + if (!raw) return undefined + const slashIndex = raw.indexOf("/") + if (slashIndex <= 0) return undefined + return { + providerId: raw.substring(0, slashIndex), + modelId: raw.substring(slashIndex + 1), + } +} + +export { + overrides, + getOverride, + setAgentOverride, + setModelOverride, + clearAgentOverride, + clearModelOverride, + clearOverrides, + clearInstanceOverrides, + getModelOverrideValue, +}