diff --git a/examples/snowflake_native_app_example.py b/examples/snowflake_native_app_example.py index 82781ab5..ea79a31a 100644 --- a/examples/snowflake_native_app_example.py +++ b/examples/snowflake_native_app_example.py @@ -2,46 +2,45 @@ from contextual import ContextualAI -SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app' -BASE_URL = f'https://{SF_BASE_URL}/v1' +SF_BASE_URL = "xxxxx-xxxxx-xxxxx.snowflakecomputing.app" +BASE_URL = f"https://{SF_BASE_URL}/v1" -SAMPLE_MESSAGE = 'Can you tell me about XYZ' +SAMPLE_MESSAGE = "Can you tell me about XYZ" -ctx = snowflake.connector.connect( # type: ignore - user="",# snowflake account user - password='', # snowflake account password - account="organization-account", # snowflake organization and account - - session_parameters={ - 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json' - }) +ctx = snowflake.connector.connect( # type: ignore + user="", # snowflake account user + password="", # snowflake account password + account="organization-account", # snowflake organization and account - + session_parameters={"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "json"}, +) # Obtain a session token. -token_data = ctx._rest._token_request('ISSUE') # type: ignore -token_extract = token_data['data']['sessionToken'] # type: ignore +token_data = ctx._rest._token_request("ISSUE") # type: ignore +token_extract = token_data["data"]["sessionToken"] # type: ignore # Create a request to the ingress endpoint with authz. -api_key = f'\"{token_extract}\"' +api_key = f'"{token_extract}"' client = ContextualAI(api_key=api_key, base_url=BASE_URL) -agents = [a for a in client.agents.list() ] +agents = [a for a in client.agents.list()] agent = agents[0] if agents else None if agent is None: - print('No agents found') + print("No agents found") exit() print(f"Found agent {agent.name} with id {agent.id}") messages = [ { - 'content': SAMPLE_MESSAGE, - 'role': 'user', + "content": SAMPLE_MESSAGE, + "role": "user", } ] -res = client.agents.query.create(agent.id, messages=messages) # type: ignore +res = client.agents.query.create(agent.id, messages=messages) # type: ignore -output = res.message.content # type: ignore +output = res.message.content # type: ignore -print(output) \ No newline at end of file +print(output) diff --git a/src/contextual/_client.py b/src/contextual/_client.py index b11c0638..dd5ae9e3 100644 --- a/src/contextual/_client.py +++ b/src/contextual/_client.py @@ -58,8 +58,9 @@ class ContextualAI(SyncAPIClient): with_streaming_response: ContextualAIWithStreamedResponse # client options - api_key: str - is_snowflake: bool + api_key: str | None = None + is_snowflake: bool = False + is_snowflake_internal: bool = False def __init__( self, @@ -91,9 +92,12 @@ def __init__( if api_key is None: api_key = os.environ.get("CONTEXTUAL_API_KEY") if api_key is None: - raise ContextualAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" - ) + if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False): + self.is_snowflake_internal = True + else: + raise ContextualAIError( + "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" + ) self.api_key = api_key if base_url is None: @@ -103,8 +107,6 @@ def __init__( if 'snowflakecomputing.app' in str(base_url): self.is_snowflake = True - else: - self.is_snowflake = False super().__init__( version=__version__, @@ -137,6 +139,8 @@ def auth_headers(self) -> dict[str, str]: api_key = self.api_key if self.is_snowflake: return {"Authorization": f"Snowflake Token={api_key}"} + elif self.is_snowflake_internal: + return {} else: return {"Authorization": f"Bearer {api_key}"} @@ -245,8 +249,9 @@ class AsyncContextualAI(AsyncAPIClient): with_streaming_response: AsyncContextualAIWithStreamedResponse # client options - api_key: str - is_snowflake: bool + api_key: str | None = None + is_snowflake: bool = False + is_snowflake_internal: bool = False def __init__( self, @@ -278,9 +283,12 @@ def __init__( if api_key is None: api_key = os.environ.get("CONTEXTUAL_API_KEY") if api_key is None: - raise ContextualAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" - ) + if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False): + self.is_snowflake_internal = True + else: + raise ContextualAIError( + "The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable" + ) self.api_key = api_key if base_url is None: @@ -290,8 +298,6 @@ def __init__( if 'snowflakecomputing.app' in str(base_url): self.is_snowflake = True - else: - self.is_snowflake = False super().__init__( version=__version__, @@ -319,11 +325,13 @@ def qs(self) -> Querystring: return Querystring(array_format="repeat") @property - @override + @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key if self.is_snowflake: return {"Authorization": f"Snowflake Token={api_key}"} + elif self.is_snowflake_internal: + return {} else: return {"Authorization": f"Bearer {api_key}"} diff --git a/tests/test_client.py b/tests/test_client.py index 76484fed..40052e12 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1617,7 +1617,8 @@ def test_get_platform(self) -> None: # # Since nest_asyncio.apply() is global and cannot be un-applied, this # test is run in a separate process to avoid affecting other tests. - test_code = dedent(""" + test_code = dedent( + """ import asyncio import nest_asyncio import threading @@ -1633,7 +1634,8 @@ async def test_main() -> None: nest_asyncio.apply() asyncio.run(test_main()) - """) + """ + ) with subprocess.Popen( [sys.executable, "-c", test_code], text=True,