diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs index 8d4dacd87d..91d1f507b5 100644 --- a/crates/chat-cli/src/cli/chat/cli/model.rs +++ b/crates/chat-cli/src/cli/chat/cli/model.rs @@ -196,9 +196,27 @@ fn get_fallback_models() -> Vec { context_window_tokens: 200_000, }, ModelInfo { - model_name: Some("claude-4-sonnet".to_string()), - model_id: "claude-4-sonnet".to_string(), + model_name: Some("claude-sonnet-4".to_string()), + model_id: "claude-sonnet-4".to_string(), context_window_tokens: 200_000, }, ] } + +pub fn normalize_model_name(name: &str) -> &str { + match name { + "claude-4-sonnet" => "claude-sonnet-4", + // can add more mapping for backward compatibility + _ => name, + } +} + +pub fn find_model<'a>(models: &'a [ModelInfo], name: &str) -> Option<&'a ModelInfo> { + let normalized = normalize_model_name(name); + models.iter().find(|m| { + m.model_name + .as_deref() + .is_some_and(|n| n.eq_ignore_ascii_case(normalized)) + || m.model_id.eq_ignore_ascii_case(normalized) + }) +} diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 0941df73a3..0d3b7b1d8c 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -133,6 +133,7 @@ use crate::auth::AuthError; use crate::auth::builder_id::is_idc_user; use crate::cli::agent::Agents; use crate::cli::chat::cli::SlashCommand; +use crate::cli::chat::cli::model::find_model; use crate::cli::chat::cli::prompts::{ GetPromptError, PromptsSubcommand, @@ -315,13 +316,7 @@ impl ChatArgs { // Otherwise, CLI will use a default model when starting chat let (models, default_model_opt) = get_available_models(os).await?; let model_id: Option = if let Some(requested) = self.model.as_ref() { - let requested_lower = requested.to_lowercase(); - if let Some(m) = models.iter().find(|m| { - m.model_name - .as_deref() - .is_some_and(|n| n.eq_ignore_ascii_case(&requested_lower)) - || m.model_id.eq_ignore_ascii_case(&requested_lower) - }) { + if let Some(m) = find_model(&models, requested) { Some(m.model_id.clone()) } else { let available = models @@ -332,14 +327,9 @@ impl ChatArgs { bail!("Model '{}' does not exist. Available models: {}", requested, available); } } else if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { - if let Some(m) = models.iter().find(|m| { - m.model_name.as_deref().is_some_and(|n| n.eq_ignore_ascii_case(&saved)) - || m.model_id.eq_ignore_ascii_case(&saved) - }) { - Some(m.model_id.clone()) - } else { - Some(default_model_opt.model_id.clone()) - } + find_model(&models, &saved) + .map(|m| m.model_id.clone()) + .or(Some(default_model_opt.model_id.clone())) } else { Some(default_model_opt.model_id.clone()) };