Skip to content

Commit 0905311

Browse files
committed
Reformat files with black to allow tests to pass
Reviewers: haneeshr
1 parent 3b5f2b5 commit 0905311

4 files changed

Lines changed: 47 additions & 46 deletions

File tree

src/rockset_sqlalchemy/connection.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
from .cursor import Cursor
44
from .exceptions import ProgrammingError
55

6+
67
class Connection(object):
78
def __init__(self, api_server, api_key, virtual_instance=None, debug_sql=False):
89
self._closed = False
9-
self._client = RocksetClient(
10-
host=api_server,
11-
api_key=api_key
12-
)
10+
self._client = RocksetClient(host=api_server, api_key=api_key)
1311
self.vi = virtual_instance
1412
self.debug_sql = debug_sql
1513
# Used for testing connectivity to Rockset.

src/rockset_sqlalchemy/cursor.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,23 @@ def execute_query(client, query, vi=None, query_params={}):
4343
query=query,
4444
parameters=[
4545
rockset.models.QueryParameter(
46-
name=param, value=str(val), type=Cursor.__convert_to_rockset_type(val)
46+
name=param,
47+
value=str(val),
48+
type=Cursor.__convert_to_rockset_type(val),
4749
)
4850
for param, val in query_params.items()
49-
]
51+
],
5052
)
5153
try:
52-
return client.VirtualInstances.query_virtual_instance(
53-
virtual_instance_id=vi,
54-
sql=request
55-
) if vi else client.Queries.query(sql=request)
54+
return (
55+
client.VirtualInstances.query_virtual_instance(
56+
virtual_instance_id=vi, sql=request
57+
)
58+
if vi
59+
else client.Queries.query(sql=request)
60+
)
5661
except rockset.exceptions.RocksetException as e:
57-
raise Error.map_rockset_exception(e)
62+
raise Error.map_rockset_exception(e)
5863

5964
def execute(self, sql, parameters=None):
6065
self.__check_cursor_opened()
@@ -68,7 +73,7 @@ def execute(self, sql, parameters=None):
6873
else:
6974
new_params[k] = v
7075
parameters = new_params
71-
76+
7277
if self._connection.debug_sql:
7378
print("+++++++++++++++++++++++++++++")
7479
print(f"Query:\n{sql}")
@@ -83,10 +88,7 @@ def execute(self, sql, parameters=None):
8388
)
8489

8590
self._response = Cursor.execute_query(
86-
self._connection._client,
87-
sql,
88-
self._connection.vi,
89-
query_params=parameters
91+
self._connection._client, sql, self._connection.vi, query_params=parameters
9092
)
9193
self._response_iter = iter(self._response.results)
9294

@@ -108,8 +110,8 @@ def fetchone(self):
108110
return None
109111

110112
result = []
111-
112-
for field in self._response_to_column_fields(self._response.column_fields):
113+
114+
for field in self._response_to_column_fields(self._response.column_fields):
113115
name = field["name"]
114116
if name in next_doc:
115117
result.append(next_doc[name])
@@ -126,7 +128,7 @@ def _response_to_column_fields(self, column_fields):
126128

127129
schema = rockset.Document()
128130
if self._response.results and len(self._response.results) > 0:
129-
# we only look at the first document because
131+
# we only look at the first document because
130132
# is sqlalchemy is typically used for relational
131133
# tables with no sparse fields
132134
schema.update(self._response.results[0])
@@ -152,7 +154,6 @@ def fetchmany(self, size=None):
152154
break
153155
docs.append(doc)
154156
return docs
155-
156157

157158
@property
158159
def description(self):

src/rockset_sqlalchemy/exceptions.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,27 @@
11
import rockset
22
from json import loads
33

4+
45
class Error(rockset.exceptions.RocksetException):
56
@classmethod
67
def map_rockset_exception(cls, exc):
78
err_body = loads(exc.body)
8-
args = [
9-
err_body["message"],
10-
exc.status,
11-
err_body["type"]
12-
]
9+
args = [err_body["message"], exc.status, err_body["type"]]
1310
exc_type = type(exc)
1411
if (
15-
exc_type == rockset.exceptions.ApiTypeError or
16-
exc_type == rockset.exceptions.ApiValueError or
17-
exc_type == rockset.exceptions.ApiAttributeError or
18-
exc_type == rockset.exceptions.ApiKeyError or
19-
exc_type == rockset.exceptions.NotFoundException or
20-
exc_type == rockset.exceptions.InputException or
21-
exc_type == rockset.exceptions.InitializationException or
22-
exc_type == rockset.exceptions.BadRequestException
23-
12+
exc_type == rockset.exceptions.ApiTypeError
13+
or exc_type == rockset.exceptions.ApiValueError
14+
or exc_type == rockset.exceptions.ApiAttributeError
15+
or exc_type == rockset.exceptions.ApiKeyError
16+
or exc_type == rockset.exceptions.NotFoundException
17+
or exc_type == rockset.exceptions.InputException
18+
or exc_type == rockset.exceptions.InitializationException
19+
or exc_type == rockset.exceptions.BadRequestException
2420
):
2521
ret = ProgrammingError(*args)
2622
elif (
27-
exc_type == rockset.exceptions.UnauthorizedException or
28-
exc_type == rockset.exceptions.ForbiddenException
23+
exc_type == rockset.exceptions.UnauthorizedException
24+
or exc_type == rockset.exceptions.ForbiddenException
2925
):
3026
ret = OperationalError(*args)
3127
elif exc_type == rockset.exceptions.ServiceException:

src/rockset_sqlalchemy/sqlalchemy/dialect.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,11 @@ class RocksetDialect(default.DefaultDialect):
4343

4444
@classmethod
4545
def dbapi(cls):
46-
"""Retained for backward compatibility with SQLAlchemy 1.x.
47-
"""
46+
"""Retained for backward compatibility with SQLAlchemy 1.x."""
4847
import rockset_sqlalchemy
4948

5049
return rockset_sqlalchemy
51-
50+
5251
@classmethod
5352
def import_dbapi(cls):
5453
return RocksetDialect.dbapi()
@@ -57,20 +56,27 @@ def create_connect_args(self, url):
5756
kwargs = {
5857
"api_server": "https://{}".format(url.host),
5958
"api_key": url.password or url.username,
60-
"virtual_instance": url.database
59+
"virtual_instance": url.database,
6160
}
6261
return ([], kwargs)
6362

6463
@reflection.cache
6564
def get_schema_names(self, connection, **kw):
66-
return [w["name"] for w in connection.connect().connection._client.Workspaces.list()["data"]]
65+
return [
66+
w["name"]
67+
for w in connection.connect().connection._client.Workspaces.list()["data"]
68+
]
6769

6870
@reflection.cache
6971
def get_table_names(self, connection, schema=None, **kw):
70-
tables = (connection.connect().connection._client.Collections.list()
71-
if schema is None else
72-
connection.connect().connection._client.Collections.workspace_collections(workspace=schema))['data']
73-
72+
tables = (
73+
connection.connect().connection._client.Collections.list()
74+
if schema is None
75+
else connection.connect().connection._client.Collections.workspace_collections(
76+
workspace=schema
77+
)
78+
)["data"]
79+
7480
return [w["name"] for w in tables]
7581

7682
def _get_table_columns(self, connection, table_name, schema):
@@ -132,7 +138,7 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
132138
@reflection.cache
133139
def get_indexes(self, connection, table_name, schema=None, **kw):
134140
return []
135-
141+
136142
def has_table(self, connection, table_name, schema=None):
137143
try:
138144
self._get_table_columns(connection, table_name, schema)

0 commit comments

Comments
 (0)