@@ -82,6 +82,107 @@ def test_gemini_append_history(self, mock_get_model, mock_configure):
8282 self .assertEqual (agent .history [0 ]["parts" ][0 ]["text" ], "Hello" )
8383 self .assertEqual (agent .history [1 ]["role" ], "model" )
8484 self .assertEqual (agent .history [1 ]["parts" ][0 ]["text" ], "Hi there!" )
85+
86+ @patch ('google.generativeai.configure' )
87+ @patch ('google.generativeai.get_model' )
88+ def test_gemini_chat_generation_parameters (self , mock_get_model , mock_configure ):
89+ """Test chat generation parameters are properly set."""
90+ agent = GeminiModelAgent ("fake-api-key" , "gemini-pro" )
91+ mock_model = MagicMock ()
92+ mock_get_model .return_value = mock_model
93+
94+ # Setup the mock model's generate_content to return a valid response
95+ mock_response = MagicMock ()
96+ mock_content = MagicMock ()
97+ mock_content .text = "Generated response"
98+ mock_response .candidates = [MagicMock ()]
99+ mock_response .candidates [0 ].content = mock_content
100+ mock_model .generate_content .return_value = mock_response
101+
102+ # Add some history before chat
103+ agent .add_system_prompt ("System prompt" )
104+ agent .append_to_history (role = "user" , content = "Hello" )
105+
106+ # Call chat method with custom parameters
107+ response = agent .chat ("What can you help me with?" , temperature = 0.2 , max_tokens = 1000 )
108+
109+ # Verify the model was called with correct parameters
110+ mock_model .generate_content .assert_called_once ()
111+ args , kwargs = mock_model .generate_content .call_args
112+
113+ # Check that history was included
114+ self .assertEqual (len (args [0 ]), 3 ) # System prompt + user message + new query
115+
116+ # Check that generation parameters were passed correctly
117+ self .assertEqual (kwargs .get ('generation_config' ).temperature , 0.2 )
118+ self .assertEqual (kwargs .get ('generation_config' ).max_output_tokens , 1000 )
119+
120+ # Check response handling
121+ self .assertEqual (response , "Generated response" )
122+
123+ @patch ('google.generativeai.configure' )
124+ @patch ('google.generativeai.get_model' )
125+ def test_gemini_parse_response (self , mock_get_model , mock_configure ):
126+ """Test parsing different response formats from the Gemini API."""
127+ agent = GeminiModelAgent ("fake-api-key" , "gemini-pro" )
128+
129+ # Mock normal response
130+ normal_response = MagicMock ()
131+ normal_content = MagicMock ()
132+ normal_content .text = "Normal response"
133+ normal_response .candidates = [MagicMock ()]
134+ normal_response .candidates [0 ].content = normal_content
135+
136+ # Mock empty response
137+ empty_response = MagicMock ()
138+ empty_response .candidates = []
139+
140+ # Mock response with finish reason not STOP
141+ blocked_response = MagicMock ()
142+ blocked_response .candidates = [MagicMock ()]
143+ blocked_candidate = blocked_response .candidates [0 ]
144+ blocked_candidate .content .text = "Blocked content"
145+ blocked_candidate .finish_reason = MagicMock ()
146+ blocked_candidate .finish_reason .name = "SAFETY"
147+
148+ # Test normal response parsing
149+ result = agent ._parse_response (normal_response )
150+ self .assertEqual (result , "Normal response" )
151+
152+ # Test empty response parsing
153+ result = agent ._parse_response (empty_response )
154+ self .assertEqual (result , "No response generated. Please try again." )
155+
156+ # Test blocked response parsing
157+ result = agent ._parse_response (blocked_response )
158+ self .assertEqual (result , "The response was blocked due to: SAFETY" )
159+
160+ @patch ('google.generativeai.configure' )
161+ @patch ('google.generativeai.get_model' )
162+ def test_gemini_content_handling (self , mock_get_model , mock_configure ):
163+ """Test content handling for different input types."""
164+ agent = GeminiModelAgent ("fake-api-key" , "gemini-pro" )
165+
166+ # Test string content
167+ parts = agent ._prepare_content ("Hello world" )
168+ self .assertEqual (len (parts ), 1 )
169+ self .assertEqual (parts [0 ]["text" ], "Hello world" )
170+
171+ # Test list content
172+ parts = agent ._prepare_content (["Hello" , "world" ])
173+ self .assertEqual (len (parts ), 2 )
174+ self .assertEqual (parts [0 ]["text" ], "Hello" )
175+ self .assertEqual (parts [1 ]["text" ], "world" )
176+
177+ # Test already formatted content
178+ parts = agent ._prepare_content ([{"text" : "Already formatted" }])
179+ self .assertEqual (len (parts ), 1 )
180+ self .assertEqual (parts [0 ]["text" ], "Already formatted" )
181+
182+ # Test empty content
183+ parts = agent ._prepare_content ("" )
184+ self .assertEqual (len (parts ), 1 )
185+ self .assertEqual (parts [0 ]["text" ], "" )
85186
86187
87188@skipIf (not IMPORTS_AVAILABLE , "Required model imports not available" )
@@ -159,4 +260,74 @@ def test_ollama_prepare_chat_params(self, mock_post):
159260 self .assertEqual (params ["messages" ][0 ]["role" ], "system" )
160261 self .assertEqual (params ["messages" ][0 ]["content" ], "System instructions" )
161262 self .assertEqual (params ["messages" ][1 ]["role" ], "user" )
162- self .assertEqual (params ["messages" ][1 ]["content" ], "Hello" )
263+ self .assertEqual (params ["messages" ][1 ]["content" ], "Hello" )
264+
265+ @patch ('requests.post' )
266+ def test_ollama_chat_with_parameters (self , mock_post ):
267+ """Test chat method with various parameters."""
268+ agent = OllamaModelAgent ("http://localhost:11434" , "llama2" )
269+
270+ # Mock the response for the post request
271+ mock_response = MagicMock ()
272+ mock_response .json .return_value = {"message" : {"content" : "Response from model" }}
273+ mock_post .return_value = mock_response
274+
275+ # Add a system prompt
276+ agent .add_system_prompt ("Be helpful" )
277+
278+ # Call chat with different parameters
279+ result = agent .chat ("Hello" , temperature = 0.3 , max_tokens = 2000 )
280+
281+ # Verify the post request was called with correct parameters
282+ mock_post .assert_called_once ()
283+ args , kwargs = mock_post .call_args
284+
285+ # Check URL
286+ self .assertEqual (args [0 ], "http://localhost:11434/api/chat" )
287+
288+ # Check JSON payload
289+ json_data = kwargs .get ('json' , {})
290+ self .assertEqual (json_data ["model" ], "llama2" )
291+ self .assertEqual (len (json_data ["messages" ]), 3 ) # System + history + new message
292+ self .assertEqual (json_data ["temperature" ], 0.3 )
293+ self .assertEqual (json_data ["max_tokens" ], 2000 )
294+
295+ # Verify the response was correctly processed
296+ self .assertEqual (result , "Response from model" )
297+
298+ @patch ('requests.post' )
299+ def test_ollama_error_handling (self , mock_post ):
300+ """Test handling of various error cases."""
301+ agent = OllamaModelAgent ("http://localhost:11434" , "llama2" )
302+
303+ # Test connection error
304+ mock_post .side_effect = Exception ("Connection failed" )
305+ result = agent .chat ("Hello" )
306+ self .assertTrue ("Error communicating with Ollama API" in result )
307+
308+ # Test bad response
309+ mock_post .side_effect = None
310+ mock_response = MagicMock ()
311+ mock_response .json .return_value = {"error" : "Model not found" }
312+ mock_post .return_value = mock_response
313+ result = agent .chat ("Hello" )
314+ self .assertTrue ("Error" in result )
315+
316+ # Test missing content in response
317+ mock_response .json .return_value = {"message" : {}} # Missing content
318+ result = agent .chat ("Hello" )
319+ self .assertTrue ("Unexpected response format" in result )
320+
321+ def test_ollama_url_handling (self ):
322+ """Test handling of different URL formats."""
323+ # Test with trailing slash
324+ agent = OllamaModelAgent ("http://localhost:11434/" , "llama2" )
325+ self .assertEqual (agent .api_url , "http://localhost:11434" )
326+
327+ # Test without protocol
328+ agent = OllamaModelAgent ("localhost:11434" , "llama2" )
329+ self .assertEqual (agent .api_url , "http://localhost:11434" )
330+
331+ # Test with https
332+ agent = OllamaModelAgent ("https://ollama.example.com" , "llama2" )
333+ self .assertEqual (agent .api_url , "https://ollama.example.com" )
0 commit comments