Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions examples/snowflake_native_app_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,45 @@

from contextual import ContextualAI

SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app'
BASE_URL = f'https://{SF_BASE_URL}/v1'
SF_BASE_URL = "xxxxx-xxxxx-xxxxx.snowflakecomputing.app"
BASE_URL = f"https://{SF_BASE_URL}/v1"

SAMPLE_MESSAGE = 'Can you tell me about XYZ'
SAMPLE_MESSAGE = "Can you tell me about XYZ"

ctx = snowflake.connector.connect( # type: ignore
user="",# snowflake account user
password='', # snowflake account password
account="organization-account", # snowflake organization and account <Organization>-<Account>
session_parameters={
'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json'
})
ctx = snowflake.connector.connect( # type: ignore
user="", # snowflake account user
password="", # snowflake account password
account="organization-account", # snowflake organization and account <Organization>-<Account>
session_parameters={"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "json"},
)

# Obtain a session token.
token_data = ctx._rest._token_request('ISSUE') # type: ignore
token_extract = token_data['data']['sessionToken'] # type: ignore
token_data = ctx._rest._token_request("ISSUE") # type: ignore
token_extract = token_data["data"]["sessionToken"] # type: ignore

# Create a request to the ingress endpoint with authz.
api_key = f'\"{token_extract}\"'
api_key = f'"{token_extract}"'

client = ContextualAI(api_key=api_key, base_url=BASE_URL)

agents = [a for a in client.agents.list() ]
agents = [a for a in client.agents.list()]

agent = agents[0] if agents else None

if agent is None:
print('No agents found')
print("No agents found")
exit()
print(f"Found agent {agent.name} with id {agent.id}")

messages = [
{
'content': SAMPLE_MESSAGE,
'role': 'user',
"content": SAMPLE_MESSAGE,
"role": "user",
}
]

res = client.agents.query.create(agent.id, messages=messages) # type: ignore
res = client.agents.query.create(agent.id, messages=messages) # type: ignore

output = res.message.content # type: ignore
output = res.message.content # type: ignore

print(output)
print(output)
38 changes: 23 additions & 15 deletions src/contextual/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ class ContextualAI(SyncAPIClient):
with_streaming_response: ContextualAIWithStreamedResponse

# client options
api_key: str
is_snowflake: bool
api_key: str | None = None
is_snowflake: bool = False
is_snowflake_internal: bool = False

def __init__(
self,
Expand Down Expand Up @@ -91,9 +92,12 @@ def __init__(
if api_key is None:
api_key = os.environ.get("CONTEXTUAL_API_KEY")
if api_key is None:
raise ContextualAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
)
if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False):
self.is_snowflake_internal = True
else:
raise ContextualAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
)
self.api_key = api_key

if base_url is None:
Expand All @@ -103,8 +107,6 @@ def __init__(

if 'snowflakecomputing.app' in str(base_url):
self.is_snowflake = True
else:
self.is_snowflake = False

super().__init__(
version=__version__,
Expand Down Expand Up @@ -137,6 +139,8 @@ def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if self.is_snowflake:
return {"Authorization": f"Snowflake Token={api_key}"}
elif self.is_snowflake_internal:
return {}
else:
return {"Authorization": f"Bearer {api_key}"}

Expand Down Expand Up @@ -245,8 +249,9 @@ class AsyncContextualAI(AsyncAPIClient):
with_streaming_response: AsyncContextualAIWithStreamedResponse

# client options
api_key: str
is_snowflake: bool
api_key: str | None = None
is_snowflake: bool = False
is_snowflake_internal: bool = False

def __init__(
self,
Expand Down Expand Up @@ -278,9 +283,12 @@ def __init__(
if api_key is None:
api_key = os.environ.get("CONTEXTUAL_API_KEY")
if api_key is None:
raise ContextualAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
)
if os.getenv('SNOWFLAKE_INTERNAL_API_SERVICE', False):
self.is_snowflake_internal = True
else:
raise ContextualAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the CONTEXTUAL_API_KEY environment variable"
)
self.api_key = api_key

if base_url is None:
Expand All @@ -290,8 +298,6 @@ def __init__(

if 'snowflakecomputing.app' in str(base_url):
self.is_snowflake = True
else:
self.is_snowflake = False

super().__init__(
version=__version__,
Expand Down Expand Up @@ -319,11 +325,13 @@ def qs(self) -> Querystring:
return Querystring(array_format="repeat")

@property
@override
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if self.is_snowflake:
return {"Authorization": f"Snowflake Token={api_key}"}
elif self.is_snowflake_internal:
return {}
else:
return {"Authorization": f"Bearer {api_key}"}

Expand Down
6 changes: 4 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,8 @@ def test_get_platform(self) -> None:
#
# Since nest_asyncio.apply() is global and cannot be un-applied, this
# test is run in a separate process to avoid affecting other tests.
test_code = dedent("""
test_code = dedent(
"""
import asyncio
import nest_asyncio
import threading
Expand All @@ -1633,7 +1634,8 @@ async def test_main() -> None:

nest_asyncio.apply()
asyncio.run(test_main())
""")
"""
)
with subprocess.Popen(
[sys.executable, "-c", test_code],
text=True,
Expand Down