Skip to content
This repository was archived by the owner on Jun 21, 2025. It is now read-only.

Commit 73dbfdb

Browse files
committed
test: Improve code coverage
1 parent c26172e commit 73dbfdb

19 files changed

Lines changed: 1733 additions & 1239 deletions

code_agent/adk/models.py

Lines changed: 0 additions & 506 deletions
This file was deleted.

code_agent/adk/models_v2.py

Lines changed: 0 additions & 493 deletions
This file was deleted.

code_agent/agent/software_engineer/software_engineer/tools/shell_command.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def configure_shell_whitelist(args: dict, tool_context: ToolContext) -> Configur
101101
"ss",
102102
"uname",
103103
"uptime",
104+
"date",
104105
"df",
105106
"du",
106107
"free",

code_agent/cli/commands/run.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
operation_warning,
5050
run_cli,
5151
setup_logging,
52+
step_progress,
5253
thinking_indicator,
5354
)
5455
from code_agent.services.session_service import FileSystemSessionService
@@ -206,7 +207,9 @@ def run_command(
206207
# Print provider info using the correct config attribute
207208
# Check if provider attribute exists, otherwise fallback might be needed (e.g., default_provider)
208209
provider_display = getattr(cfg, "provider", cfg.default_provider) # Attempt to get effective provider, fallback to default
209-
console.print(f"[dim]Provider: {provider_display}[/dim]") # Use the determined provider
210+
model_display = getattr(cfg, "model", cfg.default_model) # Attempt to get effective model, fallback to default
211+
step_progress(console, f"[dim]Provider: {provider_display}[/dim]")
212+
step_progress(console, f"[dim]Model: {model_display}[/dim]")
210213

211214
agent_to_run = None
212215
try:
@@ -252,18 +255,18 @@ def run_command(
252255
# --- Get Agent Instance (Revert to previous working logic) ---
253256
if hasattr(agent_module, "root_agent"):
254257
agent_to_run = agent_module.root_agent
255-
operation_warning(console, f"Found 'agent.root_agent' structure in {resolved_agent_path.name}.")
258+
operation_warning(console, f"[dim]Found 'agent.root_agent' structure in {resolved_agent_path.name}.[/dim]")
256259
elif hasattr(agent_module, "agent"):
257260
potential_agent = agent_module.agent
258261
# Previous check: Look for expected attributes like name/tools
259262
if hasattr(potential_agent, "name") and hasattr(potential_agent, "tools"):
260263
agent_to_run = potential_agent
261-
operation_warning(console, f"Found top-level 'agent' variable in {resolved_agent_path.name} and using it.")
264+
operation_warning(console, f"[dim]Found top-level 'agent' variable in {resolved_agent_path.name} and using it.[/dim]")
262265
operation_warning(console, "Consider renaming to 'root_agent' for clarity.")
263266
# Check if agent_module.agent contains root_agent (less common)
264267
elif hasattr(potential_agent, "root_agent"):
265268
agent_to_run = potential_agent.root_agent
266-
operation_warning(console, f"Found 'agent.root_agent' structure in {resolved_agent_path.name}.")
269+
# operation_warning(console, f"[dim]Found 'agent.root_agent' structure in {resolved_agent_path.name}.[/dim]")
267270
else:
268271
# If root_agent doesn't exist, try 'agent'
269272
raise AttributeError(f"Module {resolved_agent_path} has 'agent' but not 'root_agent'. Please expose 'root_agent'.")
@@ -276,7 +279,9 @@ def run_command(
276279
# Should have been caught above, but double check
277280
raise ImportError("Failed to load a valid agent instance.")
278281

279-
operation_complete(console, f"Agent '{getattr(agent_to_run, 'name', 'Unnamed Agent')}' loaded successfully.")
282+
operation_complete(
283+
console, f"[dim]Agent '{getattr(agent_to_run, 'name', 'Unnamed Agent')}' loaded successfully from {resolved_agent_path.name}.[/dim]"
284+
)
280285

281286
except (ImportError, AttributeError) as e:
282287
operation_error(console, f"Failed to load agent: {e}")
@@ -352,7 +357,9 @@ def run_command(
352357
final_session_id = run_cli_args["session_id"]
353358

354359
if final_session_id:
355-
console.print(f"Session ID: {final_session_id}")
360+
# This is the last of the two print statements that look like:
361+
# Session ID: 0f6e2c63-76fc-494b-95fc-b9d0319004e0
362+
console.print(f"[dim]Session ID: {final_session_id}[/dim]")
356363

357364
# Check if saving is requested via CLI flag
358365
should_save = save_session_cli

code_agent/cli/utils.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,25 @@ async def process_message_async(current_session_id, user_input, show_events=True
270270
# Process events
271271
try:
272272
async for event in event_async_generator:
273+
# TODO: Improve event handling to include more details in chat loop
274+
# print(f"Event from: {event.author}")
275+
#
276+
# if event.content and event.content.parts:
277+
# if event.get_function_calls():
278+
# print(" Type: Tool Call Request")
279+
# elif event.get_function_responses():
280+
# print(" Type: Tool Result")
281+
# elif event.content.parts[0].text:
282+
# if event.partial:
283+
# print(" Type: Streaming Text Chunk")
284+
# else:
285+
# print(" Type: Complete Text Message")
286+
# else:
287+
# print(" Type: Other Content (e.g., code result)")
288+
# elif event.actions and (event.actions.state_delta or event.actions.artifact_delta):
289+
# print(" Type: State/Artifact Update")
290+
# else:
291+
# print(" Type: Control Signal or Other")
273292
if interrupted:
274293
console.print("[bold yellow]Processing interrupted by user.[/bold yellow]")
275294
break
@@ -301,11 +320,15 @@ async def process_message_async(current_session_id, user_input, show_events=True
301320
console.print(Markdown(content_text))
302321
# Update last content to avoid duplicates
303322
last_content = content_text
323+
# TODO: Currently not working...
304324
# Optionally handle other authors like 'tool' or 'system' if needed
325+
elif event.get_function_responses(): # author == "tool":
326+
console.print(f"{timestamp_str}[bold green]🔧Tool:[/bold green] {content_text}")
305327

306328
if is_final:
329+
# https://google.github.io/adk-docs/events/
307330
final_response_event = event
308-
operation_complete(console, "[dim]Agent finished processing.[/dim]") # Pass console
331+
operation_complete(console, f"[dim]{event.author} finished processing.[/dim]") # Pass console
309332

310333
except Exception as e:
311334
# Allow KeyboardInterrupt and SystemExit to propagate
@@ -463,10 +486,12 @@ async def run_interactively_async(initial_session_id):
463486
if interrupted:
464487
console.print("[bold yellow]Session terminated by user.[/bold yellow]")
465488

489+
# This is the first of the two print statements that look like:
490+
# Session ID: 0f6e2c63-76fc-494b-95fc-b9d0319004e0
466491
# Always print the session ID at the end for reference
467-
if current_session_id:
468-
console.print(f"[dim]Session ID: [bold cyan]{current_session_id}[/bold cyan][/dim]")
469-
else:
470-
console.print("[dim]No active session ID to display.[/dim]")
492+
# if current_session_id:
493+
# console.print(f"[dim]Session ID: [bold cyan]{current_session_id}[/bold cyan][/dim]")
494+
# else:
495+
# console.print("[dim]No active session ID to display.[/dim]")
471496

472497
return current_session_id
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
Tests for the create_model function in code_agent.adk.models_v2 module.
3+
"""
4+
5+
from unittest.mock import MagicMock, patch
6+
7+
from code_agent.adk.models_v2 import LiteLlm, OllamaLlm, create_model
8+
9+
10+
class TestCreateModel:
11+
"""Test the create_model function."""
12+
13+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
14+
@patch("google.adk.models.Gemini")
15+
def test_create_model_gemini(self, mock_gemini_class, mock_get_api_key):
16+
"""Test creating a Gemini model."""
17+
# Set up the mock
18+
mock_gemini_instance = MagicMock()
19+
mock_gemini_class.return_value = mock_gemini_instance
20+
21+
# Call create_model with ai_studio provider
22+
result = create_model(provider="ai_studio", model_name="gemini-1.5-flash")
23+
24+
# Check that the Gemini class was called with expected parameters
25+
mock_gemini_class.assert_called_once()
26+
assert result == mock_gemini_instance
27+
28+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
29+
@patch("code_agent.adk.models_v2.LiteLlm")
30+
def test_create_model_litellm(self, mock_litellm_class, mock_get_api_key):
31+
"""Test creating various LiteLlm models."""
32+
# Set up the mock
33+
mock_litellm_instance = MagicMock(spec=LiteLlm)
34+
mock_litellm_class.return_value = mock_litellm_instance
35+
36+
# Test OpenAI
37+
result = create_model(provider="openai", model_name="gpt-4-turbo")
38+
39+
# Check that LiteLlm was instantiated
40+
mock_litellm_class.assert_called_once()
41+
# Check that provider and model_name were passed (not testing all params to avoid brittle tests)
42+
kwargs = mock_litellm_class.call_args.kwargs
43+
assert kwargs["provider"] == "openai"
44+
assert kwargs["model_name"] == "gpt-4-turbo"
45+
assert kwargs["api_key"] == "fake-api-key"
46+
assert result == mock_litellm_instance
47+
48+
@patch("code_agent.adk.models_v2.OllamaLlm")
49+
def test_create_model_ollama(self, mock_ollama_class):
50+
"""Test creating an Ollama model."""
51+
# Set up the mock
52+
mock_ollama_instance = MagicMock(spec=OllamaLlm)
53+
mock_ollama_class.return_value = mock_ollama_instance
54+
55+
# Call create_model with ollama provider
56+
result = create_model(provider="ollama", model_name="llama3")
57+
58+
# Verify OllamaLlm was instantiated
59+
mock_ollama_class.assert_called_once()
60+
# Only check critical parameters
61+
kwargs = mock_ollama_class.call_args.kwargs
62+
assert kwargs["model_name"] == "llama3"
63+
assert result == mock_ollama_instance
64+
65+
@patch("code_agent.adk.models_v2.get_api_key", return_value=None)
66+
@patch("code_agent.adk.models_v2.LiteLlm")
67+
def test_missing_api_key_with_fallback(self, mock_litellm_class, mock_get_api_key):
68+
"""Test fallback to alternative model when API key is missing."""
69+
# Create a mock for the fallback model
70+
mock_fallback_instance = MagicMock(spec=LiteLlm)
71+
mock_litellm_class.return_value = mock_fallback_instance
72+
73+
# Call with a fallback configuration
74+
result = create_model(provider="openai", model_name="gpt-4-turbo", fallback_provider="anthropic", fallback_model="claude-3-opus")
75+
76+
# Should create the fallback model
77+
mock_litellm_class.assert_called_once()
78+
# Check at least one key parameter
79+
kwargs = mock_litellm_class.call_args.kwargs
80+
assert kwargs["provider"] == "anthropic"
81+
assert kwargs["model_name"] == "claude-3-opus"
82+
assert result == mock_fallback_instance
83+
84+
@patch("code_agent.adk.models_v2.get_api_key", return_value=None)
85+
@patch("google.adk.models.Gemini")
86+
def test_fallback_to_gemini_without_key(self, mock_gemini_class, mock_get_api_key):
87+
"""Test that without API key and no fallback, we get a Gemini model."""
88+
# Set up the mock
89+
mock_gemini_instance = MagicMock()
90+
mock_gemini_class.return_value = mock_gemini_instance
91+
92+
# Call create_model without fallback
93+
result = create_model(provider="openai", model_name="gpt-4-turbo")
94+
95+
# Should create a Gemini model (default fallback behavior)
96+
mock_gemini_class.assert_called_once()
97+
assert result == mock_gemini_instance
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
Tests for the model factory functions in code_agent.adk.models_v2 module.
3+
"""
4+
5+
from unittest.mock import patch
6+
7+
import pytest
8+
from google.adk.models import Gemini
9+
10+
from code_agent.adk.models_v2 import (
11+
LiteLlm,
12+
ModelConfig,
13+
OllamaLlm,
14+
create_model,
15+
get_default_models_by_provider,
16+
get_model_providers,
17+
)
18+
19+
20+
class TestModelFactory:
21+
"""Test the model factory functions in the models_v2 module."""
22+
23+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
24+
def test_create_model_gemini(self, mock_get_api_key):
25+
"""Test creating a Gemini model."""
26+
model = create_model(provider="ai_studio", model_name="gemini-1.5-flash")
27+
assert isinstance(model, Gemini)
28+
assert model.model == "gemini-1.5-flash"
29+
assert model.api_key == "fake-api-key"
30+
31+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
32+
def test_create_model_openai(self, mock_get_api_key):
33+
"""Test creating an OpenAI model."""
34+
model = create_model(provider="openai", model_name="gpt-4-turbo")
35+
assert isinstance(model, LiteLlm)
36+
assert model.provider == "openai"
37+
assert model.model_name == "gpt-4-turbo"
38+
assert model.api_key == "fake-api-key"
39+
assert model.litellm_model == "openai/gpt-4-turbo"
40+
41+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
42+
def test_create_model_anthropic(self, mock_get_api_key):
43+
"""Test creating an Anthropic model."""
44+
model = create_model(provider="anthropic", model_name="claude-3-opus")
45+
assert isinstance(model, LiteLlm)
46+
assert model.provider == "anthropic"
47+
assert model.model_name == "claude-3-opus"
48+
assert model.api_key == "fake-api-key"
49+
assert model.litellm_model == "anthropic/claude-3-opus"
50+
51+
def test_create_model_ollama(self):
52+
"""Test creating an Ollama model."""
53+
model = create_model(provider="ollama", model_name="llama3.2")
54+
assert isinstance(model, OllamaLlm)
55+
assert model.provider == "ollama"
56+
assert model.model_name == "llama3.2"
57+
assert model.base_url == "http://localhost:11434"
58+
assert model.litellm_model == "ollama/llama3.2"
59+
60+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
61+
def test_create_model_with_temperature(self, mock_get_api_key):
62+
"""Test creating a model with a custom temperature."""
63+
model = create_model(provider="openai", model_name="gpt-4", temperature=0.2)
64+
assert isinstance(model, LiteLlm)
65+
assert model.temperature == 0.2
66+
67+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
68+
def test_create_model_with_max_tokens(self, mock_get_api_key):
69+
"""Test creating a model with custom max_tokens."""
70+
model = create_model(provider="openai", model_name="gpt-4", max_tokens=1000)
71+
assert isinstance(model, LiteLlm)
72+
assert model.max_tokens == 1000
73+
74+
@patch("code_agent.adk.models_v2.get_api_key", return_value=None)
75+
def test_create_model_missing_api_key(self, mock_get_api_key):
76+
"""Test error when API key is missing for providers that need it."""
77+
with pytest.raises(ValueError, match="No API key found for provider"):
78+
create_model(provider="openai", model_name="gpt-4")
79+
80+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
81+
def test_create_model_unknown_provider(self, mock_get_api_key):
82+
"""Test creating a model with an unknown provider."""
83+
with pytest.raises(ValueError, match="Unknown provider"):
84+
create_model(provider="unknown", model_name="model")
85+
86+
@patch("code_agent.adk.models_v2.get_api_key", return_value="fake-api-key")
87+
def test_create_model_with_fallback(self, mock_get_api_key):
88+
"""Test model creation with fallback configuration."""
89+
model = create_model(provider="openai", model_name="gpt-4", fallback_provider="anthropic", fallback_model="claude-3-sonnet")
90+
# Verify the fallback configuration is stored somewhere
91+
# The exact implementation depends on how fallback is handled
92+
assert hasattr(model, "_fallback_config")
93+
assert model._fallback_config.provider == "anthropic"
94+
assert model._fallback_config.model_name == "claude-3-sonnet"
95+
96+
def test_get_model_providers(self):
97+
"""Test the get_model_providers function returns a non-empty list."""
98+
providers = get_model_providers()
99+
assert isinstance(providers, list)
100+
assert len(providers) > 0
101+
assert "openai" in providers
102+
assert "ai_studio" in providers
103+
assert "anthropic" in providers
104+
assert "ollama" in providers
105+
106+
def test_get_default_models_by_provider(self):
107+
"""Test the get_default_models_by_provider function returns a non-empty dict."""
108+
default_models = get_default_models_by_provider()
109+
assert isinstance(default_models, dict)
110+
assert len(default_models) > 0
111+
assert "openai" in default_models
112+
assert "ai_studio" in default_models
113+
assert "anthropic" in default_models
114+
assert "ollama" in default_models
115+
116+
117+
class TestModelConfig:
118+
"""Test the ModelConfig class."""
119+
120+
def test_model_config_creation(self):
121+
"""Test creating a ModelConfig instance."""
122+
config = ModelConfig(
123+
provider="openai",
124+
model_name="gpt-4",
125+
temperature=0.5,
126+
max_tokens=1000,
127+
timeout=60,
128+
retry_count=3,
129+
fallback_provider="anthropic",
130+
fallback_model="claude-3-opus",
131+
)
132+
assert config.provider == "openai"
133+
assert config.model_name == "gpt-4"
134+
assert config.temperature == 0.5
135+
assert config.max_tokens == 1000
136+
assert config.timeout == 60
137+
assert config.retry_count == 3
138+
assert config.fallback_provider == "anthropic"
139+
assert config.fallback_model == "claude-3-opus"
140+
141+
def test_model_config_defaults(self):
142+
"""Test ModelConfig default values."""
143+
config = ModelConfig(provider="openai", model_name="gpt-4")
144+
assert config.temperature == 0.7 # Default value
145+
assert config.max_tokens is None # Default value
146+
assert config.timeout is None # Default value
147+
assert config.retry_count == 2 # Default value
148+
assert config.fallback_provider is None # Default value
149+
assert config.fallback_model is None # Default value

0 commit comments

Comments
 (0)