-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathasyncio.py
More file actions
84 lines (66 loc) · 3.39 KB
/
asyncio.py
File metadata and controls
84 lines (66 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict, Optional
import pyarrow as pa
from typing_extensions import Self, Unpack
from dbtsl.api.adbc.client.base import BaseADBCClient
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters
class AsyncADBCClient(BaseADBCClient):
    """An asyncio client to access the Semantic Layer via ADBC."""

    def __init__(
        self,
        server_host: str,
        environment_id: int,
        auth_token: str,
        url_format: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialize the ADBC client.

        Must be called from inside a running event loop: the loop is captured
        here and later used to offload blocking driver calls to its default
        executor.

        Args:
            server_host: the ADBC API host
            environment_id: Your dbt environment ID
            auth_token: The bearer token that will be used for authentication
            url_format: The full connection URL format that transforms the `server_host`
                into a full URL. If `None`, the default
                `grpc+tls://{server_host}:443`
                will be assumed.
            extra_headers: extra headers to be sent with the request.

        Raises:
            RuntimeError: if there is no running event loop (raised by
                `asyncio.get_running_loop`).
        """
        super().__init__(server_host, environment_id, auth_token, url_format, extra_headers=extra_headers)
        self._loop = asyncio.get_running_loop()

    @asynccontextmanager
    async def session(self) -> AsyncIterator[Self]:
        """Open a connection in the underlying ADBC driver.

        All requests made during the same session will reuse the same connection.
        The connection's context manager is exited, and the client returns to its
        "no open session" state, when the `async with` block ends -- including
        when it ends because of an exception.

        Raises:
            ValueError: if a session is already open on this client.
        """
        if self._conn_unsafe is not None:
            raise ValueError("A client session is already open.")

        ctx = self._get_connection_context_manager()
        # Entering/exiting the driver's context manager is run in the loop's
        # default executor so any blocking work it does stays off the loop.
        self._conn_unsafe = await self._loop.run_in_executor(None, ctx.__enter__)
        try:
            yield self
        except BaseException as exc:
            # Mirror a synchronous `with` statement: give the exception to
            # `__exit__` and re-raise it unless `__exit__` asks to suppress it.
            suppress = await self._loop.run_in_executor(None, ctx.__exit__, type(exc), exc, exc.__traceback__)
            if not suppress:
                raise
        else:
            await self._loop.run_in_executor(None, ctx.__exit__, None, None, None)
        finally:
            # Always clear the handle; otherwise a session that ended in an
            # error would leave the client stuck in "session already open".
            self._conn_unsafe = None

    async def _execute_and_fetch(self, query_sql: str) -> pa.Table:
        """Execute `query_sql` on a new cursor of the session connection and return the result table.

        The `execute` and `fetch_arrow_table` driver calls are run in the loop's
        default executor. Exceptions from `execute` are passed to
        `self._handle_error` (defined in the base class; presumably it
        translates and re-raises them -- not visible here).
        """
        # NOTE: We don't need to wrap this in a `loop.run_in_executor` since
        # just creating the cursor object doesn't perform any blocking IO.
        with self._conn.cursor() as cur:
            try:
                await self._loop.run_in_executor(None, cur.execute, query_sql)  # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType]
            except Exception as err:
                self._handle_error(err)
            return await self._loop.run_in_executor(None, cur.fetch_arrow_table)

    async def query(self, **query_params: Unpack[QueryParameters]) -> pa.Table:
        """Query for a dataframe in the Semantic Layer."""
        query_sql = self.PROTOCOL.get_query_sql(query_params)
        return await self._execute_and_fetch(query_sql)

    async def dimension_values(self, **query_params: Unpack[DimensionValuesQueryParameters]) -> pa.Table:
        """Query for the possible values of a dimension."""
        query_sql = self.PROTOCOL.get_dimension_values_sql(query_params)
        return await self._execute_and_fetch(query_sql)