Skip to content

Commit ed67de7

Browse files
committed
Addressing review comments and bug.
1 parent df566eb commit ed67de7

6 files changed

Lines changed: 48 additions & 42 deletions

File tree

.github/workflows/publish-to-pypi.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
run: |
3434
WHEEL=$(ls dist/*.whl | head -n 1)
3535
python -m pip install "${WHEEL}[openai,google,anthropic]"
36-
python -c "from oci_genai_auth.openai import OciOpenAI; from oci_genai_auth.google import OciGoogleGenAI; from oci_genai_auth.anthropic import OciAnthropic; import oci_genai_auth;"
36+
python -c "from oci_genai_auth import OciSessionAuth; from oci_genai_auth import OciUserPrincipalAuth; import oci_genai_auth;"
3737
# - name: Publish to Test PyPI
3838
# run: |
3939
# python -m pip install twine

examples/common.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,44 +18,33 @@
1818
COMPARTMENT_ID = ""
1919
CONVERSATION_STORE_ID = ""
2020
OPENAI_PROJECT = ""
21-
OVERRIDE_URL = ""
2221
PROFILE_NAME = "DEFAULT"
23-
REGION = "us-chicago-1"
2422
GEMINI_API_KEY = ""
25-
GEMINI_BASE_URL = ""
23+
24+
# OpenAI-compatible base URLs.
25+
OPENAI_BASE_URL_PT = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/v1"
26+
OPENAI_BASE_URL_NP = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/openai/v1"
27+
# Switch to "NP" for examples that store data on the server.
28+
RESPONSE_API_MODE = "PT" # "PT" (pass-through) or "NP" (non-pass-through)
29+
30+
# Other provider base URLs.
31+
ANTHROPIC_BASE_URL = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/anthropic"
32+
GOOGLE_BASE_URL = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com/google"
2633

2734

2835
def _build_headers(include_conversation_store_id: bool = False) -> dict[str, str]:
29-
headers: dict[str, str] = {}
30-
if COMPARTMENT_ID:
31-
headers["CompartmentId"] = COMPARTMENT_ID
32-
headers["opc-compartment-id"] = COMPARTMENT_ID
33-
if OPENAI_PROJECT:
34-
headers["OpenAI-Project"] = OPENAI_PROJECT
35-
if include_conversation_store_id and CONVERSATION_STORE_ID:
36+
headers: dict[str, str] = {
37+
"CompartmentId": COMPARTMENT_ID,
38+
"opc-compartment-id": COMPARTMENT_ID,
39+
"OpenAI-Project": OPENAI_PROJECT,
40+
}
41+
if include_conversation_store_id:
3642
headers["opc-conversation-store-id"] = CONVERSATION_STORE_ID
37-
return headers
43+
return {key: value for key, value in headers.items() if value}
3844

3945

4046
def _resolve_openai_base_url() -> str:
41-
service_endpoint = OVERRIDE_URL or (
42-
f"https://inference.generativeai.{REGION}.oci.oraclecloud.com" if REGION else ""
43-
)
44-
if not service_endpoint:
45-
raise ValueError("REGION or OVERRIDE_URL must be set.")
46-
return f"{service_endpoint.rstrip(' /')}/openai/v1"
47-
48-
49-
def _resolve_anthropic_base_url() -> str:
50-
if not REGION:
51-
raise ValueError("REGION or ANTHROPIC_BASE_URL must be set.")
52-
return f"https://inference.generativeai.{REGION}.oci.oraclecloud.com/anthropic"
53-
54-
55-
def _resolve_google_base_url() -> str:
56-
if not REGION:
57-
raise ValueError("REGION or GOOGLE_BASE_URL must be set.")
58-
return f"https://inference.generativeai.{REGION}.oci.oraclecloud.com/google"
47+
return OPENAI_BASE_URL_NP if RESPONSE_API_MODE == "NP" else OPENAI_BASE_URL_PT
5948

6049

6150
def build_openai_client() -> "OpenAI":
@@ -95,7 +84,7 @@ def build_anthropic_client() -> "Anthropic":
9584

9685
return Anthropic(
9786
api_key="not-used",
98-
base_url=_resolve_anthropic_base_url(),
87+
base_url=ANTHROPIC_BASE_URL,
9988
http_client=httpx.Client(
10089
auth=OciSessionAuth(profile_name=PROFILE_NAME),
10190
headers=_build_headers(),
@@ -108,7 +97,7 @@ def build_anthropic_async_client() -> "AsyncAnthropic":
10897

10998
return AsyncAnthropic(
11099
api_key="not-used",
111-
base_url=_resolve_anthropic_base_url(),
100+
base_url=ANTHROPIC_BASE_URL,
112101
http_client=httpx.AsyncClient(
113102
auth=OciSessionAuth(profile_name=PROFILE_NAME),
114103
headers=_build_headers(),
@@ -127,7 +116,7 @@ def build_google_client() -> "genai.Client":
127116
return genai.Client(
128117
api_key="not-used",
129118
http_options={
130-
"base_url": _resolve_google_base_url(),
119+
"base_url": GOOGLE_BASE_URL,
131120
"headers": headers,
132121
"httpx_client": http_client,
133122
},
@@ -145,7 +134,7 @@ def build_google_async_client() -> tuple["genai.Client", httpx.AsyncClient]:
145134
client = genai.Client(
146135
api_key="not-used",
147136
http_options={
148-
"base_url": _resolve_google_base_url(),
137+
"base_url": GOOGLE_BASE_URL,
149138
"headers": headers,
150139
"httpx_async_client": http_client,
151140
},

examples/openai/function/create_response_fc_ parallel_tool.py renamed to examples/openai/function/create_response_fc_parallel_tool.py

File renamed without changes.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ dev = [
4747
"pytest-cov>=5.0.0,<6.0.0",
4848
"respx",
4949
"typing-extensions>=4.11, <5",
50-
"openai>=v1.108.1",
50+
"openai>=1.108.1",
5151
"google-genai>=1.0.0",
5252
"anthropic>=0.79.0",
5353
"openai-agents>=0.5.1",
@@ -94,7 +94,7 @@ exclude_lines = [
9494
]
9595

9696
[tool.ruff]
97-
target-version = "py38"
97+
target-version = "py39"
9898
line-length = 100
9999
extend-exclude = [
100100
]
@@ -114,5 +114,5 @@ select = [
114114

115115
[tool.black]
116116
line-length = 100
117-
target-version = ['py38']
117+
target-version = ['py39']
118118
exclude = '\.venv|build|dist'

src/oci_genai_auth/auth.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def _should_refresh_token(self) -> bool:
5858
Returns:
5959
bool: True if token should be refreshed, False otherwise
6060
"""
61-
if not self._last_refresh:
62-
return True
6361
current_time = time.time()
6462
return (current_time - self._last_refresh) >= self.refresh_interval
6563

@@ -84,7 +82,7 @@ def _refresh_if_needed(self) -> None:
8482
self._last_refresh = time.time()
8583
logger.info("%s token refresh completed successfully", self.__class__.__name__)
8684
except Exception as e:
87-
logger.exception("Warning: Token refresh failed:", e)
85+
logger.exception("Warning: Token refresh failed")
8886

8987
def _sign_request(self, request: httpx.Request, content: bytes) -> None:
9088
"""
@@ -150,7 +148,7 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re
150148
self._sign_request(request, content)
151149
yield request
152150
except Exception as e:
153-
logger.exception("Token refresh on 401 failed:", e)
151+
logger.exception("Token refresh on 401 failed")
154152

155153

156154
class OciSessionAuth(HttpxOciAuth):
@@ -169,7 +167,7 @@ def __init__(
169167
config_file: str = DEFAULT_LOCATION,
170168
profile_name: str = DEFAULT_PROFILE,
171169
refresh_interval: int = 3600,
172-
**kwargs: Mapping[str, Any],
170+
**kwargs: Any,
173171
):
174172
"""
175173
Initialize a Security Token-based OCI signer.

tests/test_auth.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def _refresh_signer(self) -> None:
3434
self.signer = _DummySigner(f"signed-{self.refresh_calls}")
3535

3636

37+
class _BrokenRefreshAuth(HttpxOciAuth):
38+
def _refresh_signer(self) -> None:
39+
raise ConnectionError("metadata service unreachable")
40+
41+
3742
def test_auth_flow_signs_request():
3843
auth = _DummyAuth(_DummySigner("signed-0"))
3944
request = httpx.Request(
@@ -71,6 +76,20 @@ def test_refresh_if_needed_calls_refresh_signer():
7176
assert auth.refresh_calls == 1
7277

7378

79+
def test_refresh_failure_does_not_break_auth_flow(caplog):
80+
auth = _BrokenRefreshAuth(_DummySigner("signed-0"), refresh_interval=0)
81+
request = httpx.Request("GET", "https://example.com")
82+
83+
with caplog.at_level("ERROR"):
84+
flow = auth.auth_flow(request)
85+
signed_request = next(flow)
86+
87+
assert signed_request.headers["authorization"] == "signed-0"
88+
assert any(
89+
"Token refresh failed" in record.message for record in caplog.records
90+
)
91+
92+
7493
def test_session_auth_initializes_signer_from_config():
7594
config = {
7695
"key_file": "dummy.key",

0 commit comments

Comments
 (0)