forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
174 lines (142 loc) · 5.58 KB
/
utils.py
File metadata and controls
174 lines (142 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import Any
from typing import Optional
from urllib.parse import urljoin
from adk_answering_agent.settings import GITHUB_GRAPHQL_URL
from adk_answering_agent.settings import GITHUB_TOKEN
from google.adk.agents.run_config import RunConfig
from google.adk.runners import Runner
from google.genai import types
import requests
# Default HTTP headers sent with the GitHub requests made from this module
# (run_graphql_query passes them to requests.post). The token value comes from
# adk_answering_agent.settings.GITHUB_TOKEN.
headers = dict(
    Authorization=f"token {GITHUB_TOKEN}",
    Accept="application/vnd.github.v3+json",
)
def error_response(error_message: str) -> dict[str, Any]:
  """Builds a standard error payload.

  Args:
    error_message: Human-readable description of what went wrong.

  Returns:
    A dict whose ``status`` is ``"error"`` and whose ``error_message`` holds
    the given text.
  """
  payload: dict[str, Any] = dict(status="error")
  payload["error_message"] = error_message
  return payload
def run_graphql_query(query: str, variables: dict[str, Any]) -> dict[str, Any]:
  """Executes a GraphQL query against the configured ``GITHUB_GRAPHQL_URL``.

  The body of the response is not inspected. Any GraphQL-level ``errors``
  entry in a successful HTTP response is therefore passed through to the
  caller unchanged.

  Args:
    query: The GraphQL query document.
    variables: Values for the variables referenced by ``query``.

  Returns:
    The decoded JSON body of the response.

  Raises:
    requests.HTTPError: If the endpoint answers with an HTTP error status.
    requests.RequestException: For transport failures such as the 60 s
      timeout expiring.
  """
  resp = requests.post(
      GITHUB_GRAPHQL_URL,
      headers=headers,
      json={"query": query, "variables": variables},
      timeout=60,
  )
  resp.raise_for_status()
  body = resp.json()
  return body
def parse_number_string(number_str: str | None, default_value: int = 0) -> int:
"""Parse a number from the given string."""
if not number_str:
return default_value
try:
return int(number_str)
except ValueError:
print(
f"Warning: Invalid number string: {number_str}. Defaulting to"
f" {default_value}.",
file=sys.stderr,
)
return default_value
def _check_url_exists(url: str) -> bool:
  """Returns whether ``url`` answers a HEAD request successfully.

  Redirects are followed, and the verdict is taken from the final response:
  ``True`` when its status code is below 400 (``Response.ok``). Any error
  raised by ``requests`` maps to ``False``, including connection failures,
  the 5 s timeout and too many redirects.
  """
  try:
    # The short timeout keeps an unreachable host from stalling the caller.
    final_response = requests.head(url, allow_redirects=True, timeout=5)
  except requests.RequestException:
    return False
  else:
    return final_response.ok
def _generate_github_url(repo_name: str, relative_path: str) -> str:
"""Generates a standard GitHub URL for a repo file."""
return f"https://github.com/google/{repo_name}/blob/main/{relative_path}"
def convert_gcs_to_https(gcs_uri: str) -> Optional[str]:
"""Converts a GCS file link into a publicly accessible HTTPS link.
Args:
gcs_uri: The Google Cloud Storage link, in the format
'gs://bucket_name/prefix/relative_path'.
Returns:
The converted HTTPS link as a string, or None if the input format is
incorrect.
"""
# Parse the GCS link
if not gcs_uri or not gcs_uri.startswith("gs://"):
print(f"Error: Invalid GCS link format: {gcs_uri}")
return None
try:
# Strip 'gs://' and split by '/', requiring at least 3 parts
# (bucket, prefix, path)
parts = gcs_uri[5:].split("/", 2)
if len(parts) < 3:
raise ValueError(
"GCS link must contain a bucket, prefix, and relative_path."
)
_, prefix, relative_path = parts
except (ValueError, IndexError) as e:
print(f"Error: Failed to parse GCS link '{gcs_uri}': {e}")
return None
# Replace .html with .md
if relative_path.endswith(".html"):
relative_path = relative_path.removesuffix(".html") + ".md"
# Replace .txt with .yaml
if relative_path.endswith(".txt"):
relative_path = relative_path.removesuffix(".txt") + ".yaml"
# Convert the links for adk-docs
if prefix == "adk-docs" and relative_path.startswith("docs/"):
path_after_docs = relative_path[len("docs/") :]
if not path_after_docs.endswith(".md"):
# Use the regular github url
return _generate_github_url(prefix, relative_path)
base_url = "https://google.github.io/adk-docs/"
if os.path.basename(path_after_docs) == "index.md":
# Use the directory path if it is a index file
final_path_segment = os.path.dirname(path_after_docs)
else:
# Otherwise, use the file name without extension
final_path_segment = path_after_docs.removesuffix(".md")
if final_path_segment and not final_path_segment.endswith("/"):
final_path_segment += "/"
potential_url = urljoin(base_url, final_path_segment)
# Check if the generated link exists
if _check_url_exists(potential_url):
return potential_url
else:
# If it doesn't exist, fall back to the regular github url
return _generate_github_url(prefix, relative_path)
# Convert the links for other cases, e.g. adk-python
else:
return _generate_github_url(prefix, relative_path)
async def call_agent_async(
    runner: Runner, user_id: str, session_id: str, prompt: str
) -> str:
  """Sends ``prompt`` to the agent and collects the text it streams back.

  Runs one turn through ``runner.run_async`` for the given user and session,
  with ``save_input_blobs_as_artifacts`` disabled.

  Args:
    runner: The ADK runner hosting the agent.
    user_id: ID of the user that owns the session.
    session_id: ID of the session to run the turn in.
    prompt: Text sent to the agent as a single user message part.

  Returns:
    The text of every part of every event whose author is not ``"user"``,
    concatenated in arrival order. The result is an empty string if no such
    text was produced.
  """
  user_message = types.Content(
      role="user", parts=[types.Part.from_text(text=prompt)]
  )
  run_config = RunConfig(save_input_blobs_as_artifacts=False)
  chunks: list[str] = []
  async for event in runner.run_async(
      user_id=user_id,
      session_id=session_id,
      new_message=user_message,
      run_config=run_config,
  ):
    # Skip events without content and events echoing the user's own input.
    if not (event.content and event.content.parts):
      continue
    if event.author == "user":
      continue
    chunks.extend(part.text or "" for part in event.content.parts)
  return "".join(chunks)