-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy path__init__.py
More file actions
248 lines (202 loc) · 7.9 KB
/
__init__.py
File metadata and controls
248 lines (202 loc) · 7.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import functools
import jinja2
import jinja2.ext
from jinja2.lexer import Token
import re
import logging
import six
import itertools, collections
import sqlalchemy
from packaging import version
# Parsed installed SQLAlchemy version; execute_sql checks its major component
# to choose between the 1.x and 2.x execution APIs.
_SA_VERSION = version.parse(sqlalchemy.__version__)
class UnsafeSqlException(Exception):
    """Raised when a value interpolated into a SQL template fails the safety check."""
# Identifier-like text only: ASCII letters, digits and underscore (the empty
# string is allowed). ``\Z`` rather than ``$`` because ``$`` also matches just
# before a trailing newline, which would let "abc\n" through.
NOT_DANGEROUS_RE = re.compile(r'^[A-Za-z0-9_]*\Z')


def is_safe(value):
    """Check whether *value* may be interpolated verbatim into SQL.

    Args:
        value: the text to check.

    Returns:
        A (truthy) match object if *value* consists solely of
        ``[A-Za-z0-9_]`` characters, otherwise ``None``.
    """
    # fullmatch: the whole string must match, never just a prefix.
    return NOT_DANGEROUS_RE.fullmatch(value)
class DangerouslyInjectedSql(object):
    """Marker wrapping raw SQL text that must bypass the ``assert_safe`` check.

    ``assert_safe_filter`` returns instances of this class unchanged, so they
    render as ``str(value)`` with no validation. Only wrap trusted,
    developer-controlled text.

    Attributes:
        value: the wrapped object, stored exactly as passed.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        # str() so non-string values (e.g. ints) render instead of raising
        # "TypeError: __str__ returned non-string".
        return str(self.value)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self.value)
def sql(engine, template, **params):
    """Render *template* with keyword *params* and execute it on *engine*.

    Keyword-argument front end for ``sql_inner``. It returns what
    ``sql_inner`` returns: a ``SqlProxy`` wrapping the execution result.
    """
    template_params = params
    return sql_inner(engine, template, template_params)
def sql_inner(engine, template, params):
    """Render, expand list parameters, execute, and wrap the result.

    Args:
        engine: object passed straight through to ``execute_sql``.
        template: Jinja2 SQL template source.
        params: template/bind parameters. This dict is mutated: ``render``
            adds ``bindparam``, and list expansion pops ``*_list`` entries.

    Returns:
        ``SqlProxy`` around the value returned by ``execute_sql``.
    """
    rendered = render(template, params)
    expanded_query, bound_params = format_query_with_list_params(rendered, params)
    raw_result = execute_sql(engine, expanded_query, bound_params)
    return SqlProxy(raw_result)


# Reference to the original implementation. Presumably this lets code rebind
# ``sql_inner`` (``sql`` looks it up globally) while still reaching the
# original — TODO confirm intended use.
sql_inner_original = sql_inner
def compile_template_nocache(template):
    """Compile *template* source with the module environment ``jenv`` (no caching)."""
    compiled = jenv.from_string(template)
    return compiled


# LRU-cached variant keyed on the template string (64 most recent entries).
# To disable caching: jsql.compile_template = jsql.compile_template_nocache
compile_template = functools.lru_cache(maxsize=64)(
    compile_template_nocache
)
def render(template, params):
    """Render *template* to SQL text using *params* as the Jinja context.

    If *params* has no ``bindparam`` entry, one built by
    ``gen_bindparam(params)`` is stored into *params* (the dict is mutated)
    before rendering, so templates can call ``bindparam(value)``.

    Returns:
        The rendered query string.
    """
    if 'bindparam' not in params:
        params['bindparam'] = gen_bindparam(params)
    compiled = compile_template(template)
    return compiled.render(**params)
# Module logger named 'jsql'; not referenced elsewhere in this file.
logger = logging.getLogger('jsql')
def assert_safe_filter(value):
    """Jinja filter that AssertSafeExtension applies to every ``{{ expr }}``.

    Returns:
        ``None`` and ``DangerouslyInjectedSql`` instances unchanged;
        anything else converted to text.

    Raises:
        UnsafeSqlException: if the text fails ``is_safe``.
    """
    # NOTE(review): Jinja's default finalize renders None as the text "None";
    # confirm that is the intended SQL output for None values.
    if value is None or isinstance(value, DangerouslyInjectedSql):
        return value
    as_text = six.text_type(value)
    if is_safe(as_text):
        return as_text
    raise UnsafeSqlException('unsafe sql param "{}"'.format(as_text))
class AssertSafeExtension(jinja2.ext.Extension):
    """Route every ``{{ expr }}`` through the ``assert_safe`` filter.

    The token stream is rewritten so ``{{ expr }}`` is parsed as
    ``{{ (expr)|assert_safe }}``.
    Based on https://github.com/pallets/jinja/issues/503
    """

    def filter_stream(self, stream):
        for tok in stream:
            kind, line = tok.type, tok.lineno
            if kind == 'variable_end':
                # Close the parenthesised expression and pipe it into the filter.
                yield Token(line, 'rparen', ')')
                yield Token(line, 'pipe', '|')
                yield Token(line, 'name', 'assert_safe')
                yield tok
            elif kind == 'variable_begin':
                yield tok
                # Open a parenthesis so the whole expression gets filtered.
                yield Token(line, 'lparen', '(')
            else:
                yield tok
# Template environment for SQL: no HTML autoescaping, and every {{ expr }}
# is routed through the assert_safe filter by AssertSafeExtension.
jenv = jinja2.Environment(
    autoescape=False,
    extensions=(AssertSafeExtension,),
)
jenv.filters.update(assert_safe=assert_safe_filter)
def dangerously_inject_sql(value):
    """Jinja filter: mark *value* as raw SQL so ``assert_safe`` lets it through.

    Only use for trusted text; the wrapped value is rendered without checks.
    """
    wrapped = DangerouslyInjectedSql(value)
    return wrapped
# Expose the raw-SQL escape hatch as a filter. Also expose a pre-wrapped ","
# global: a bare "," would fail the assert_safe check.
jenv.filters.update(dangerously_inject_sql=dangerously_inject_sql)
jenv.globals.update(comma=DangerouslyInjectedSql(","))
def execute_sql(engine, query, params):
    """Execute *query* (wrapped in ``sqlalchemy.text``) with bind *params*.

    Args:
        engine: on SQLAlchemy >= 2 a Connection or Session (a plain Engine
            is rejected). On 1.x an Engine/Connection, or any object whose
            class repr contains "session".
        query: SQL text with ``:name`` placeholders.
        params: dict of bind parameter values.

    Returns:
        Whatever ``engine.execute`` returns.

    Raises:
        TypeError: on SQLAlchemy >= 2 when *engine* is an ``Engine``.
    """
    from sqlalchemy.sql import text
    from sqlalchemy.engine import Engine
    statement = text(query)

    if _SA_VERSION.major < 2:
        # 1.x: Session.execute takes params=...; Engine/Connection take kwargs.
        if 'session' in repr(engine.__class__).lower():
            return engine.execute(statement, params=params)
        return engine.execute(statement, **params)

    if isinstance(engine, Engine):
        raise TypeError("SQLAlchemy 2.0 removed Engine.execute(). Pass a Connection instead.")
    return engine.execute(statement, params)
# Prefix for bind-parameter names generated by ``bindparam`` (bp0, bp1, ...).
BINDPARAM_PREFIX = 'bp'


def gen_bindparam(params):
    """Build the ``bindparam(val)`` template helper bound to *params*.

    Each call stores *val* in *params* under the next free generated name
    (skipping names already present) and returns that name.
    """
    next_key = key_generator()

    def bindparam(val):
        while True:
            candidate = next_key(BINDPARAM_PREFIX)
            if candidate not in params:
                break
        params[candidate] = val
        return candidate

    return bindparam


def key_generator():
    """Return ``gen_key(prefix)`` yielding prefix0, prefix1, ... per prefix."""
    counters = {}

    def gen_key(key):
        counter = counters.setdefault(key, itertools.count())
        return key + str(next(counter))

    return gen_key
# A ``:name_list`` bind placeholder, following SQLAlchemy's text() bind-name
# rule:
# - It must not be preceded by ``:``, a word character or a backslash, so
#   ``::casts`` and ``\:escaped`` colons are skipped.
# - The name may contain digits.
# - The trailing ``(?!\w)`` stops ``:ids_listing`` matching as ``:ids_list``.
_LIST_PARAM_RE = re.compile(r"(?<![:\w\\]):\w+_list(?!\w)")


def get_param_keys(query):
    """Return the distinct ``:name_list`` placeholders in *query*.

    Each key keeps its leading colon, e.g. ``{':ids_list'}``.
    """
    return set(_LIST_PARAM_RE.findall(query))


def _replace_param(query, key, replacement):
    """Replace whole-name occurrences of placeholder *key* in *query*.

    Unlike ``str.replace``, replacing ``:a_list`` leaves ``:a_list_list`` intact.
    """
    pattern = r"(?<![:\w\\])" + re.escape(key) + r"(?!\w)"
    # A callable replacement keeps any backslashes in *replacement* literal.
    return re.sub(pattern, lambda _match: replacement, query)


def format_query_with_list_params(query, params):
    """Expand every ``:name_list`` / ``:name_tuple_list`` placeholder.

    ``:ids_list`` with ``ids_list=[1, 2]`` becomes ``(:ids_list_0, :ids_list_1)``
    with ``ids_list_0=1, ids_list_1=2``.

    Args:
        query: SQL text with ``:name`` placeholders.
        params: bind-parameter dict. List entries are popped and replaced by
            one entry per element (mutated in place).

    Returns:
        ``(query, params)``: the expanded SQL and the same params dict.

    Raises:
        KeyError: if a placeholder has no matching entry in *params*.
        TypeError: if a ``*_tuple_list`` element is not a tuple.
    """
    # Sorted for deterministic output. Boundary-aware replacement makes the
    # processing order otherwise irrelevant.
    for key in sorted(get_param_keys(query)):
        if key.endswith('_tuple_list'):
            query, params = _format_query_tuple_list_key(key, query, params)
        else:
            query, params = _format_query_list_key(key, query, params)
    return query, params


def _format_query_list_key(key, query, params):
    """Expand ``:x_list`` into ``(:x_list_0, :x_list_1, ...)``, one param per element."""
    values = params.pop(key[1:])
    new_keys = []
    for i, value in enumerate(values):
        new_key = '{}_{}'.format(key, i)
        new_keys.append(new_key)
        params[new_key[1:]] = value
    # NOTE: an empty list renders as "(null)". "x IN (null)" matches nothing,
    # but "x NOT IN (null)" is also NULL rather than TRUE:
    # ("SELECT 'xyz' WHERE 'abc' NOT IN :i_list", i_list=[]) -> expected: 'xyz' | output: None
    new_keys_str = ", ".join(new_keys) or "null"
    query = _replace_param(query, key, "({})".format(new_keys_str))
    return query, params


def _format_query_tuple_list_key(key, query, params):
    """Expand ``:x_tuple_list`` into ``((:x_tuple_list_0_0, ...), ...)``."""
    values = params.pop(key[1:])
    row_placeholders = []
    for i, value in enumerate(values):
        # Explicit check (not ``assert``) so it survives ``python -O``.
        # Without it, a string would be silently split into characters.
        if not isinstance(value, tuple):
            raise TypeError('{} expects tuple elements, got {}'.format(
                key, type(value).__name__))
        row_key = '{}_{}'.format(key, i)
        item_keys = []
        for j, item in enumerate(value):
            item_key = '{}_{}'.format(row_key, j)
            item_keys.append(item_key)
            params[item_key[1:]] = item
        row_placeholders.append("({})".format(", ".join(item_keys)))
    # NOTE: same "(null)" caveat as _format_query_list_key:
    # ("SELECT 'xyz' WHERE ('abc', '') NOT IN :i_tuple_list", i_tuple_list=[]) -> expected: 'xyz' | output: None
    new_keys_str = ", ".join(row_placeholders) or "null"
    query = _replace_param(query, key, "({})".format(new_keys_str))
    return query, params
class ObjProxy(object):
    """Wrapper forwarding iteration and unknown attribute access to ``_proxied``."""

    def __init__(self, proxied):
        self._proxied = proxied

    def __iter__(self):
        return iter(self._proxied)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup fails, so instance
        # attributes never reach here. If ``_proxied`` itself is missing
        # (copy/pickle build the instance without __init__), raise
        # AttributeError instead of recursing forever.
        if attr == '_proxied':
            raise AttributeError(attr)
        return getattr(self._proxied, attr)


class SqlProxy(ObjProxy):
    """Result wrapper with convenience accessors (as returned by ``sql``).

    The proxied object must provide ``keys()`` (column names) and iterate over
    rows that support zipping and integer indexing.

    Most accessors accept ``dict``/``tuple`` factories so callers can pick the
    container type (e.g. ``collections.OrderedDict``).

    The ``*_iter`` methods iterate the proxied result directly. Whether a
    second pass sees rows again depends on the proxied object.
    """

    def dicts_iter(self, dict=dict):
        """Yield each row as a mapping ``column -> value``."""
        result = self._proxied
        keys = result.keys()
        for r in result:
            yield dict(zip(keys, r))

    def pk_map_iter(self, dict=dict):
        """Yield ``(first_column_value, row_mapping)`` pairs."""
        result = self._proxied
        keys = result.keys()
        for r in result:
            yield r[0], dict(zip(keys, r))

    def pks_map_iter(self, *keys, n: int = None, dict=dict, tuple=tuple):
        """Yield ``(key, row_mapping)`` pairs with a configurable key.

        The key is chosen by precedence:
          * one name in *keys*: that column's value;
          * several names in *keys*: a ``tuple`` of those columns' values;
          * ``n == 1``: the first column's value;
          * ``n > 1``: a ``tuple`` of the first *n* columns' values.

        Raises:
            ValueError: at the first row if neither *keys* nor ``n >= 1`` is
                given (an empty result therefore raises nothing).
        """
        result = self._proxied
        all_keys = result.keys()
        for r in result:
            d = dict(zip(all_keys, r))
            if len(keys) == 1:
                yield d[keys[0]], d
            elif len(keys) > 1:
                yield tuple(d[k] for k in keys), d
            elif n == 1:
                yield r[0], d
            elif n and n > 1:
                yield tuple(r[k] for k in range(n)), d
            else:
                raise ValueError('Expected either `n` as int >= 1 OR `keys` as a list of str arguments')

    def kv_map_iter(self):
        """Yield ``(first_column, second_column)`` pairs."""
        for r in self._proxied:
            yield r[0], r[1]

    def scalars_iter(self):
        """Yield the first column of every row."""
        for r in self._proxied:
            yield r[0]

    def tuples_iter(self, tuple=tuple):
        """Yield every row converted with the *tuple* factory."""
        for r in self._proxied:
            yield tuple(r)

    def pk_map(self, dict=dict):
        """Map first-column value -> row mapping.

        *dict* is used for the outer map AND the row mappings, consistent with
        ``pks_map``.
        """
        return dict(self.pk_map_iter(dict=dict))

    def pks_map(self, *keys, n=None, dict=dict, tuple=tuple):
        """Materialised ``pks_map_iter`` (see it for key selection)."""
        return dict(self.pks_map_iter(*keys, n=n, dict=dict, tuple=tuple))

    def kv_map(self, dict=dict):
        """Map first column -> second column."""
        return dict(self.kv_map_iter())

    def dicts(self, dict=dict):
        """List of row mappings."""
        return list(self.dicts_iter(dict=dict))

    def scalars(self):
        """List of first-column values."""
        return list(self.scalars_iter())

    def tuples(self, tuple=tuple):
        """List of rows as tuples.

        SQLAlchemy 2.0 offers ``Result.tuples()`` natively (and ``scalars()``
        since 1.4):
        https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.tuples
        https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.scalars
        """
        return list(self.tuples_iter(tuple=tuple))

    def scalar_set(self):
        """Set of first-column values."""
        return set(self.scalars_iter())

    def tuple_set(self):
        """Set of rows as tuples."""
        return set(self.tuples_iter())

    # ``dict`` and ``tuple`` must stay the last definitions: once bound in the
    # class namespace, later ``dict=dict``/``tuple=tuple`` defaults would pick
    # up these methods instead of the builtins.
    def dict(self, dict=dict):
        """First row as a mapping, or ``None`` when there are no rows.

        Materialises every row (via ``dicts``), exhausting the result.
        """
        rows = self.dicts(dict=dict)
        return rows[0] if rows else None

    def tuple(self, tuple=tuple):
        """First row as a tuple, or a tuple of ``None`` (one per column) if empty."""
        rows = self.tuples(tuple=tuple)
        if rows:
            return rows[0]
        return tuple(None for _ in range(len(self._proxied.keys())))