This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 2118e6d

Expand model tests to improve coverage
1 parent f9b1801 commit 2118e6d

1 file changed

Lines changed: 172 additions & 1 deletion

File tree

test_dir/test_model_basic.py

@@ -82,6 +82,107 @@ def test_gemini_append_history(self, mock_get_model, mock_configure):
         self.assertEqual(agent.history[0]["parts"][0]["text"], "Hello")
         self.assertEqual(agent.history[1]["role"], "model")
         self.assertEqual(agent.history[1]["parts"][0]["text"], "Hi there!")
+
+    @patch('google.generativeai.configure')
+    @patch('google.generativeai.get_model')
+    def test_gemini_chat_generation_parameters(self, mock_get_model, mock_configure):
+        """Test chat generation parameters are properly set."""
+        agent = GeminiModelAgent("fake-api-key", "gemini-pro")
+        mock_model = MagicMock()
+        mock_get_model.return_value = mock_model
+
+        # Setup the mock model's generate_content to return a valid response
+        mock_response = MagicMock()
+        mock_content = MagicMock()
+        mock_content.text = "Generated response"
+        mock_response.candidates = [MagicMock()]
+        mock_response.candidates[0].content = mock_content
+        mock_model.generate_content.return_value = mock_response
+
+        # Add some history before chat
+        agent.add_system_prompt("System prompt")
+        agent.append_to_history(role="user", content="Hello")
+
+        # Call chat method with custom parameters
+        response = agent.chat("What can you help me with?", temperature=0.2, max_tokens=1000)
+
+        # Verify the model was called with correct parameters
+        mock_model.generate_content.assert_called_once()
+        args, kwargs = mock_model.generate_content.call_args
+
+        # Check that history was included
+        self.assertEqual(len(args[0]), 3)  # System prompt + user message + new query
+
+        # Check that generation parameters were passed correctly
+        self.assertEqual(kwargs.get('generation_config').temperature, 0.2)
+        self.assertEqual(kwargs.get('generation_config').max_output_tokens, 1000)
+
+        # Check response handling
+        self.assertEqual(response, "Generated response")
+
+    @patch('google.generativeai.configure')
+    @patch('google.generativeai.get_model')
+    def test_gemini_parse_response(self, mock_get_model, mock_configure):
+        """Test parsing different response formats from the Gemini API."""
+        agent = GeminiModelAgent("fake-api-key", "gemini-pro")
+
+        # Mock normal response
+        normal_response = MagicMock()
+        normal_content = MagicMock()
+        normal_content.text = "Normal response"
+        normal_response.candidates = [MagicMock()]
+        normal_response.candidates[0].content = normal_content
+
+        # Mock empty response
+        empty_response = MagicMock()
+        empty_response.candidates = []
+
+        # Mock response with finish reason not STOP
+        blocked_response = MagicMock()
+        blocked_response.candidates = [MagicMock()]
+        blocked_candidate = blocked_response.candidates[0]
+        blocked_candidate.content.text = "Blocked content"
+        blocked_candidate.finish_reason = MagicMock()
+        blocked_candidate.finish_reason.name = "SAFETY"
+
+        # Test normal response parsing
+        result = agent._parse_response(normal_response)
+        self.assertEqual(result, "Normal response")
+
+        # Test empty response parsing
+        result = agent._parse_response(empty_response)
+        self.assertEqual(result, "No response generated. Please try again.")
+
+        # Test blocked response parsing
+        result = agent._parse_response(blocked_response)
+        self.assertEqual(result, "The response was blocked due to: SAFETY")
+
+    @patch('google.generativeai.configure')
+    @patch('google.generativeai.get_model')
+    def test_gemini_content_handling(self, mock_get_model, mock_configure):
+        """Test content handling for different input types."""
+        agent = GeminiModelAgent("fake-api-key", "gemini-pro")
+
+        # Test string content
+        parts = agent._prepare_content("Hello world")
+        self.assertEqual(len(parts), 1)
+        self.assertEqual(parts[0]["text"], "Hello world")
+
+        # Test list content
+        parts = agent._prepare_content(["Hello", "world"])
+        self.assertEqual(len(parts), 2)
+        self.assertEqual(parts[0]["text"], "Hello")
+        self.assertEqual(parts[1]["text"], "world")
+
+        # Test already formatted content
+        parts = agent._prepare_content([{"text": "Already formatted"}])
+        self.assertEqual(len(parts), 1)
+        self.assertEqual(parts[0]["text"], "Already formatted")
+
+        # Test empty content
+        parts = agent._prepare_content("")
+        self.assertEqual(len(parts), 1)
+        self.assertEqual(parts[0]["text"], "")
 
 
 @skipIf(not IMPORTS_AVAILABLE, "Required model imports not available")
@@ -159,4 +260,74 @@ def test_ollama_prepare_chat_params(self, mock_post):
         self.assertEqual(params["messages"][0]["role"], "system")
         self.assertEqual(params["messages"][0]["content"], "System instructions")
         self.assertEqual(params["messages"][1]["role"], "user")
-        self.assertEqual(params["messages"][1]["content"], "Hello")
+        self.assertEqual(params["messages"][1]["content"], "Hello")
+
+    @patch('requests.post')
+    def test_ollama_chat_with_parameters(self, mock_post):
+        """Test chat method with various parameters."""
+        agent = OllamaModelAgent("http://localhost:11434", "llama2")
+
+        # Mock the response for the post request
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"message": {"content": "Response from model"}}
+        mock_post.return_value = mock_response
+
+        # Add a system prompt
+        agent.add_system_prompt("Be helpful")
+
+        # Call chat with different parameters
+        result = agent.chat("Hello", temperature=0.3, max_tokens=2000)
+
+        # Verify the post request was called with correct parameters
+        mock_post.assert_called_once()
+        args, kwargs = mock_post.call_args
+
+        # Check URL
+        self.assertEqual(args[0], "http://localhost:11434/api/chat")
+
+        # Check JSON payload
+        json_data = kwargs.get('json', {})
+        self.assertEqual(json_data["model"], "llama2")
+        self.assertEqual(len(json_data["messages"]), 3)  # System + history + new message
+        self.assertEqual(json_data["temperature"], 0.3)
+        self.assertEqual(json_data["max_tokens"], 2000)
+
+        # Verify the response was correctly processed
+        self.assertEqual(result, "Response from model")
+
+    @patch('requests.post')
+    def test_ollama_error_handling(self, mock_post):
+        """Test handling of various error cases."""
+        agent = OllamaModelAgent("http://localhost:11434", "llama2")
+
+        # Test connection error
+        mock_post.side_effect = Exception("Connection failed")
+        result = agent.chat("Hello")
+        self.assertTrue("Error communicating with Ollama API" in result)
+
+        # Test bad response
+        mock_post.side_effect = None
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"error": "Model not found"}
+        mock_post.return_value = mock_response
+        result = agent.chat("Hello")
+        self.assertTrue("Error" in result)
+
+        # Test missing content in response
+        mock_response.json.return_value = {"message": {}}  # Missing content
+        result = agent.chat("Hello")
+        self.assertTrue("Unexpected response format" in result)
+
+    def test_ollama_url_handling(self):
+        """Test handling of different URL formats."""
+        # Test with trailing slash
+        agent = OllamaModelAgent("http://localhost:11434/", "llama2")
+        self.assertEqual(agent.api_url, "http://localhost:11434")
+
+        # Test without protocol
+        agent = OllamaModelAgent("localhost:11434", "llama2")
+        self.assertEqual(agent.api_url, "http://localhost:11434")
+
+        # Test with https
+        agent = OllamaModelAgent("https://ollama.example.com", "llama2")
+        self.assertEqual(agent.api_url, "https://ollama.example.com")
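
Note on the last test: the URL normalization that test_ollama_url_handling pins down (trailing slash stripped, a missing scheme defaulting to http://) is not itself part of this diff. The following is only a minimal sketch of a constructor that would satisfy those assertions; the class name OllamaModelAgent and the api_url attribute come from the tests, everything else is an assumption for illustration.

# Hypothetical sketch, not from this commit: store the API URL the way
# test_ollama_url_handling expects it to be normalized.
class OllamaModelAgent:
    def __init__(self, api_url: str, model: str):
        url = api_url.rstrip("/")                  # "http://localhost:11434/" -> "http://localhost:11434"
        if not url.startswith(("http://", "https://")):
            url = f"http://{url}"                  # "localhost:11434" -> "http://localhost:11434"
        self.api_url = url                         # "https://ollama.example.com" is kept as-is
        self.model = model
        self.history = []                          # assumed: chat history used by the other tests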
