From 7604e3e52f44a527af64f8d32e08d63c1b48d7c5 Mon Sep 17 00:00:00 2001 From: Miles Adkins Date: Sat, 1 Mar 2025 13:31:49 -0500 Subject: [PATCH] feat: Add special snowflake path for internal dns usage --- examples/snowflake_native_app_example.py | 39 ++++++++++++------------ src/contextual/_client.py | 38 ++++++++++++++--------- tests/test_client.py | 6 ++-- 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/examples/snowflake_native_app_example.py b/examples/snowflake_native_app_example.py index 82781ab5..ea79a31a 100644 --- a/examples/snowflake_native_app_example.py +++ b/examples/snowflake_native_app_example.py @@ -2,46 +2,45 @@ from contextual import ContextualAI -SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app' -BASE_URL = f'https://{SF_BASE_URL}/v1' +SF_BASE_URL = "xxxxx-xxxxx-xxxxx.snowflakecomputing.app" +BASE_URL = f"https://{SF_BASE_URL}/v1" -SAMPLE_MESSAGE = 'Can you tell me about XYZ' +SAMPLE_MESSAGE = "Can you tell me about XYZ" -ctx = snowflake.connector.connect( # type: ignore - user="",# snowflake account user - password='', # snowflake account password - account="organization-account", # snowflake organization and account - - session_parameters={ - 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json' - }) +ctx = snowflake.connector.connect( # type: ignore + user="", # snowflake account user + password="", # snowflake account password + account="organization-account", # snowflake organization and account - + session_parameters={"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "json"}, +) # Obtain a session token. -token_data = ctx._rest._token_request('ISSUE') # type: ignore -token_extract = token_data['data']['sessionToken'] # type: ignore +token_data = ctx._rest._token_request("ISSUE") # type: ignore +token_extract = token_data["data"]["sessionToken"] # type: ignore # Create a request to the ingress endpoint with authz. -api_key = f'\"{token_extract}\"' +api_key = f'"{token_extract}"' client = ContextualAI(api_key=api_key, base_url=BASE_URL) -agents = [a for a in client.agents.list() ] +agents = [a for a in client.agents.list()] agent = agents[0] if agents else None if agent is None: - print('No agents found') + print("No agents found") exit() print(f"Found agent {agent.name} with id {agent.id}") messages = [ { - 'content': SAMPLE_MESSAGE, - 'role': 'user', + "content": SAMPLE_MESSAGE, + "role": "user", } ] -res = client.agents.query.create(agent.id, messages=messages) # type: ignore +res = client.agents.query.create(agent.id, messages=messages) # type: ignore -output = res.message.content # type: ignore +output = res.message.content # type: ignore -print(output) \ No newline at end of file +print(output) diff --git a/src/contextual/_client.py b/src/contextual/_client.py index d255c5d0..46479dca 100644 --- a/src/contextual/_client.py +++ b/src/contextual/_client.py @@ -57,8 +57,9 @@ class ContextualAI(SyncAPIClient): with_streaming_response: ContextualAIWithStreamedResponse # client options - api_key: str - is_snowflake: bool + api_key: str | None = None + is_snowflake: bool = False + is_snowflake_internal: bool = False def __init__( self, @@ -90,9 +91,12 @@ def __init__( if api_key is None: api_key = os.environ.get("CONTEXTUAL_API_KEY") if api_key is None: - raise ContextualAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" - ) + if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False): + self.is_snowflake_internal = True + else: + raise ContextualAIError( + "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" + ) self.api_key = api_key if base_url is None: @@ -102,8 +106,6 @@ def __init__( if 'snowflakecomputing.app' in str(base_url): self.is_snowflake = True - else: - self.is_snowflake = False super().__init__( version=__version__, @@ -135,6 +137,8 @@ def auth_headers(self) -> dict[str, str]: api_key = self.api_key if self.is_snowflake: return {"Authorization": f"Snowflake Token={api_key}"} + elif self.is_snowflake_internal: + return {} else: return {"Authorization": f"Bearer {api_key}"} @@ -242,8 +246,9 @@ class AsyncContextualAI(AsyncAPIClient): with_streaming_response: AsyncContextualAIWithStreamedResponse # client options - api_key: str - is_snowflake: bool + api_key: str | None = None + is_snowflake: bool = False + is_snowflake_internal: bool = False def __init__( self, @@ -275,9 +280,12 @@ def __init__( if api_key is None: api_key = os.environ.get("CONTEXTUAL_API_KEY") if api_key is None: - raise ContextualAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" - ) + if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False): + self.is_snowflake_internal = True + else: + raise ContextualAIError( + "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" + ) self.api_key = api_key if base_url is None: @@ -287,8 +295,6 @@ def __init__( if 'snowflakecomputing.app' in str(base_url): self.is_snowflake = True - else: - self.is_snowflake = False super().__init__( version=__version__, @@ -315,11 +321,13 @@ def qs(self) -> Querystring: return Querystring(array_format="repeat") @property - @override + @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key if self.is_snowflake: return {"Authorization": f"Snowflake Token={api_key}"} + elif self.is_snowflake_internal: + return {} else: return {"Authorization": f"Bearer {api_key}"} diff --git a/tests/test_client.py b/tests/test_client.py index 2c685292..13724e05 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1615,7 +1615,8 @@ def test_get_platform(self) -> None: # # Since nest_asyncio.apply() is global and cannot be un-applied, this # test is run in a separate process to avoid affecting other tests. - test_code = dedent(""" + test_code = dedent( + """ import asyncio import nest_asyncio import threading @@ -1631,7 +1632,8 @@ async def test_main() -> None: nest_asyncio.apply() asyncio.run(test_main()) - """) + """ + ) with subprocess.Popen( [sys.executable, "-c", test_code], text=True,