|
| 1 | +""" |
| 2 | +Tests for the model factory functions in code_agent.adk.models_v2 module. |
| 3 | +""" |
| 4 | + |
| 5 | +from unittest.mock import patch |
| 6 | + |
| 7 | +import pytest |
| 8 | +from google.adk.models import Gemini |
| 9 | + |
| 10 | +from code_agent.adk.models_v2 import ( |
| 11 | + LiteLlm, |
| 12 | + ModelConfig, |
| 13 | + OllamaLlm, |
| 14 | + create_model, |
| 15 | + get_default_models_by_provider, |
| 16 | + get_model_providers, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +class TestModelFactory: |
| 21 | + """Test the model factory functions in the models_v2 module.""" |
| 22 | + |
| 23 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 24 | + def test_create_model_gemini(self, mock_get_api_key): |
| 25 | + """Test creating a Gemini model.""" |
| 26 | + model = create_model(provider="ai_studio", model_name="gemini-1.5-flash") |
| 27 | + assert isinstance(model, Gemini) |
| 28 | + assert model.model == "gemini-1.5-flash" |
| 29 | + assert model.api_key == "fake-api-key" |
| 30 | + |
| 31 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 32 | + def test_create_model_openai(self, mock_get_api_key): |
| 33 | + """Test creating an OpenAI model.""" |
| 34 | + model = create_model(provider="openai", model_name="gpt-4-turbo") |
| 35 | + assert isinstance(model, LiteLlm) |
| 36 | + assert model.provider == "openai" |
| 37 | + assert model.model_name == "gpt-4-turbo" |
| 38 | + assert model.api_key == "fake-api-key" |
| 39 | + assert model.litellm_model == "openai/gpt-4-turbo" |
| 40 | + |
| 41 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 42 | + def test_create_model_anthropic(self, mock_get_api_key): |
| 43 | + """Test creating an Anthropic model.""" |
| 44 | + model = create_model(provider="anthropic", model_name="claude-3-opus") |
| 45 | + assert isinstance(model, LiteLlm) |
| 46 | + assert model.provider == "anthropic" |
| 47 | + assert model.model_name == "claude-3-opus" |
| 48 | + assert model.api_key == "fake-api-key" |
| 49 | + assert model.litellm_model == "anthropic/claude-3-opus" |
| 50 | + |
| 51 | + def test_create_model_ollama(self): |
| 52 | + """Test creating an Ollama model.""" |
| 53 | + model = create_model(provider="ollama", model_name="llama3.2") |
| 54 | + assert isinstance(model, OllamaLlm) |
| 55 | + assert model.provider == "ollama" |
| 56 | + assert model.model_name == "llama3.2" |
| 57 | + assert model.base_url == "http://localhost:11434" |
| 58 | + assert model.litellm_model == "ollama/llama3.2" |
| 59 | + |
| 60 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 61 | + def test_create_model_with_temperature(self, mock_get_api_key): |
| 62 | + """Test creating a model with a custom temperature.""" |
| 63 | + model = create_model(provider="openai", model_name="gpt-4", temperature=0.2) |
| 64 | + assert isinstance(model, LiteLlm) |
| 65 | + assert model.temperature == 0.2 |
| 66 | + |
| 67 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 68 | + def test_create_model_with_max_tokens(self, mock_get_api_key): |
| 69 | + """Test creating a model with custom max_tokens.""" |
| 70 | + model = create_model(provider="openai", model_name="gpt-4", max_tokens=1000) |
| 71 | + assert isinstance(model, LiteLlm) |
| 72 | + assert model.max_tokens == 1000 |
| 73 | + |
| 74 | + @patch("code_agent.adk.models_v2.get_api_key", return_value=None) |
| 75 | + def test_create_model_missing_api_key(self, mock_get_api_key): |
| 76 | + """Test error when API key is missing for providers that need it.""" |
| 77 | + with pytest.raises(ValueError, match="No API key found for provider"): |
| 78 | + create_model(provider="openai", model_name="gpt-4") |
| 79 | + |
| 80 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 81 | + def test_create_model_unknown_provider(self, mock_get_api_key): |
| 82 | + """Test creating a model with an unknown provider.""" |
| 83 | + with pytest.raises(ValueError, match="Unknown provider"): |
| 84 | + create_model(provider="unknown", model_name="model") |
| 85 | + |
| 86 | + @patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key") |
| 87 | + def test_create_model_with_fallback(self, mock_get_api_key): |
| 88 | + """Test model creation with fallback configuration.""" |
| 89 | + model = create_model(provider="openai", model_name="gpt-4", fallback_provider="anthropic", fallback_model="claude-3-sonnet") |
| 90 | + # Verify the fallback configuration is stored somewhere |
| 91 | + # The exact implementation depends on how fallback is handled |
| 92 | + assert hasattr(model, "_fallback_config") |
| 93 | + assert model._fallback_config.provider == "anthropic" |
| 94 | + assert model._fallback_config.model_name == "claude-3-sonnet" |
| 95 | + |
| 96 | + def test_get_model_providers(self): |
| 97 | + """Test the get_model_providers function returns a non-empty list.""" |
| 98 | + providers = get_model_providers() |
| 99 | + assert isinstance(providers, list) |
| 100 | + assert len(providers) > 0 |
| 101 | + assert "openai" in providers |
| 102 | + assert "ai_studio" in providers |
| 103 | + assert "anthropic" in providers |
| 104 | + assert "ollama" in providers |
| 105 | + |
| 106 | + def test_get_default_models_by_provider(self): |
| 107 | + """Test the get_default_models_by_provider function returns a non-empty dict.""" |
| 108 | + default_models = get_default_models_by_provider() |
| 109 | + assert isinstance(default_models, dict) |
| 110 | + assert len(default_models) > 0 |
| 111 | + assert "openai" in default_models |
| 112 | + assert "ai_studio" in default_models |
| 113 | + assert "anthropic" in default_models |
| 114 | + assert "ollama" in default_models |
| 115 | + |
| 116 | + |
| 117 | +class TestModelConfig: |
| 118 | + """Test the ModelConfig class.""" |
| 119 | + |
| 120 | + def test_model_config_creation(self): |
| 121 | + """Test creating a ModelConfig instance.""" |
| 122 | + config = ModelConfig( |
| 123 | + provider="openai", |
| 124 | + model_name="gpt-4", |
| 125 | + temperature=0.5, |
| 126 | + max_tokens=1000, |
| 127 | + timeout=60, |
| 128 | + retry_count=3, |
| 129 | + fallback_provider="anthropic", |
| 130 | + fallback_model="claude-3-opus", |
| 131 | + ) |
| 132 | + assert config.provider == "openai" |
| 133 | + assert config.model_name == "gpt-4" |
| 134 | + assert config.temperature == 0.5 |
| 135 | + assert config.max_tokens == 1000 |
| 136 | + assert config.timeout == 60 |
| 137 | + assert config.retry_count == 3 |
| 138 | + assert config.fallback_provider == "anthropic" |
| 139 | + assert config.fallback_model == "claude-3-opus" |
| 140 | + |
| 141 | + def test_model_config_defaults(self): |
| 142 | + """Test ModelConfig default values.""" |
| 143 | + config = ModelConfig(provider="openai", model_name="gpt-4") |
| 144 | + assert config.temperature == 0.7 # Default value |
| 145 | + assert config.max_tokens is None # Default value |
| 146 | + assert config.timeout is None # Default value |
| 147 | + assert config.retry_count == 2 # Default value |
| 148 | + assert config.fallback_provider is None # Default value |
| 149 | + assert config.fallback_model is None # Default value |
0 commit comments