-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy pathhelpers.py
More file actions
289 lines (243 loc) · 9.88 KB
/
helpers.py
File metadata and controls
289 lines (243 loc) · 9.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
This module provides helper functions for the mssql_python package.
"""
import re
import threading
import locale
from typing import Any, Union, Tuple, Optional
from mssql_python import ddbc_bindings
from mssql_python.exceptions import raise_exception
from mssql_python.logging import logger
from mssql_python.constants import ConstantsDDBC
# normalize_architecture import removed as it's unused
def check_error(handle_type: int, handle: Any, ret: int) -> None:
    """
    Turn a failing DDBC return code into a Python exception.

    Only negative return codes are treated as errors; any other value returns
    immediately without touching the driver. For an error, the diagnostic
    record is fetched via ``ddbc_bindings.DDBCSQLCheckError`` and logged.
    It is then handed to ``raise_exception`` together with its SQL state.

    Args:
        handle_type: The type of the handle (e.g., SQL_HANDLE_ENV, SQL_HANDLE_DBC).
        handle: The SqlHandle object associated with the operation.
        ret: The return code from the DDBC function call.

    Raises:
        Whatever ``raise_exception`` raises for the reported SQL state
        (earlier documentation named RuntimeError; NOTE(review): confirm
        against ``mssql_python.exceptions.raise_exception``).
    """
    is_error = ret < 0
    if not is_error:
        return

    logger.debug(
        "check_error: Error detected - handle_type=%d, return_code=%d", handle_type, ret
    )
    diag = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret)
    logger.error("Error: %s", diag.ddbcErrorMsg)
    logger.debug("check_error: SQL state=%s", diag.sqlState)
    raise_exception(diag.sqlState, diag.ddbcErrorMsg)
def sanitize_connection_string(conn_str: str) -> str:
    """
    Mask password values in a connection string so it is safe to log.

    Every ``Pwd=<value>`` pair is rewritten with its value replaced by
    ``***``. The keyword is matched case-insensitively, with optional
    whitespace around ``=``, and keeps its original spelling and spacing.
    Two value forms are recognised:

    * unbraced: the value runs up to the next ``;`` or the end of the string;
    * brace-quoted: ``{...}``, where ``}}`` stands for a literal ``}``. The
      entire braced value is masked even if it contains ``;``.

    Args:
        conn_str (str): The connection string to sanitize.

    Returns:
        str: The connection string with password values replaced by ``***``.
    """
    logger.debug(
        "sanitize_connection_string: Sanitizing connection string (length=%d)", len(conn_str)
    )
    # Group 1 keeps the "Pwd =" prefix exactly as written. The value is masked
    # in one of two ways:
    #   - a braced token: "{", then non-"}" characters or "}}" escapes, then
    #     the closing "}", plus any stray text up to the next ";";
    #   - otherwise everything up to the next ";" (the original behaviour).
    # Without the braced branch, a password such as {ab;cd} would only be
    # masked up to its first ";" and leak the remainder. An unterminated "{"
    # makes the braced branch fail and falls through to the unbraced one.
    sanitized = re.sub(
        r"(Pwd\s*=\s*)(?:\{(?:[^}]|\}\})*\}[^;]*|[^;]*)",
        r"\1***",
        conn_str,
        flags=re.IGNORECASE,
    )
    logger.debug("sanitize_connection_string: Password fields masked")
    return sanitized
def sanitize_user_input(user_input: str, max_length: int = 50) -> str:
    """
    Reduce arbitrary user input to a short token that is safe to log.

    Only word characters, ``-`` and ``.`` survive (enough for names such as
    encodings). If the cleaned text is longer than ``max_length``, it is cut
    to that length and ``...`` is appended.

    Args:
        user_input (str): The value to clean up.
        max_length (int): Number of characters kept before ``...`` is appended.

    Returns:
        str: The cleaned text. Returns ``"<non-string>"`` when the input is
        not a ``str`` and ``"<invalid>"`` when nothing survives the filtering.
    """
    input_is_text = isinstance(user_input, str)
    logger.debug(
        "sanitize_user_input: Sanitizing input (type=%s, length=%d)",
        type(user_input).__name__,
        len(user_input) if input_is_text else 0,
    )
    if not input_is_text:
        logger.debug("sanitize_user_input: Non-string input detected")
        return "<non-string>"

    # Strip control characters, whitespace and punctuation other than '-' and '.'
    kept = re.sub(r"[^\w\-\.]", "", user_input)

    # Cap the length so a huge value cannot flood the logs.
    was_truncated = len(kept) > max_length
    if was_truncated:
        kept = f"{kept[:max_length]}..."

    result = kept or "<invalid>"
    logger.debug(
        "sanitize_user_input: Result length=%d, truncated=%s", len(result), str(was_truncated)
    )
    return result
def validate_attribute_value(
    attribute: Union[int, str],
    value: Union[int, str, bytes, bytearray],
    is_connected: bool = True,
    sanitize_logs: bool = True,
    max_log_length: int = 50,
) -> Tuple[bool, Optional[str], str, str]:
    """
    Check whether a connection attribute/value pair may be set.

    The attribute must be an integer from a small allow-list of SQL_ATTR_*
    connection attributes. Some of those may only be set before the
    connection is established. The value must be an int, str, bytes or
    bytearray within basic limits.

    Args:
        attribute: The connection attribute (SQL_ATTR_* integer). Any
            non-integer is rejected.
        value: The value to set (int, str, bytes or bytearray).
        is_connected (bool): Whether the connection is already established.
            When True, pre-connection-only attributes are rejected.
        sanitize_logs (bool): When True, the returned attribute/value strings
            are reduced to word characters, '-' and '.' and truncated. When
            False, plain ``str()`` is used.
        max_log_length (int): Characters kept in a sanitized string before
            "..." is appended.

    Returns:
        tuple: ``(is_valid, error_message, sanitized_attribute,
        sanitized_value)``, where ``error_message`` is ``None`` on success.
    """
    logger.debug(
        "validate_attribute_value: Validating attribute=%s, value_type=%s, is_connected=%s",
        str(attribute),
        type(value).__name__,
        str(is_connected),
    )

    def _to_log_text(raw: Any, limit: int = max_log_length) -> str:
        # Coerce to text, keep only word chars, '-' and '.', then cap the length.
        if isinstance(raw, str):
            text = raw
        else:
            try:
                text = str(raw)
            except (TypeError, ValueError):
                return "<non-string>"
        cleaned = re.sub(r"[^\w\-\.]", "", text)
        if len(cleaned) > limit:
            cleaned = cleaned[:limit] + "..."
        return cleaned or "<invalid>"

    # Log-friendly renderings, returned alongside every verdict.
    if sanitize_logs:
        attr_text = _to_log_text(attribute)
        value_text = _to_log_text(value)
    else:
        attr_text = str(attribute)
        value_text = str(value)

    def _reject(reason: str) -> Tuple[bool, Optional[str], str, str]:
        # Shape shared by every failing outcome.
        return False, reason, attr_text, value_text

    # The attribute must be an integer SQL_ATTR_* code.
    if not isinstance(attribute, int):
        logger.debug(
            "validate_attribute_value: Attribute not an integer - type=%s", type(attribute).__name__
        )
        return _reject(f"Attribute must be an integer, got {type(attribute).__name__}")

    # Allow-list of driver-level attributes this package accepts.
    supported_attributes = (
        ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value,
        ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value,
        ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value,
        ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value,
        ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value,
        ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value,
    )
    if attribute not in supported_attributes:
        logger.debug("validate_attribute_value: Unsupported attribute - attr=%d", attribute)
        return _reject(f"Unsupported attribute: {attribute}")

    # These can only be applied before the connection is opened.
    pre_connect_only = (
        ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value,
        ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value,
    )
    if is_connected and attribute in pre_connect_only:
        logger.debug(
            "validate_attribute_value: Timing violation - attr=%d cannot be set after connection",
            attribute,
        )
        return _reject(
            f"Attribute {attribute} must be set before connection establishment. "
            "Use the attrs_before parameter when creating the connection."
        )

    if isinstance(value, int):
        # Negative integers are refused except for the login timeout
        # (original rationale: -1 selects the default).
        if value < 0 and attribute != ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value:
            return _reject(f"Integer value cannot be negative: {value}")
    else:
        if isinstance(value, str):
            label, limit = "String", 8192  # 8KB maximum
        elif isinstance(value, (bytes, bytearray)):
            label, limit = "Binary", 32768  # 32KB maximum
        else:
            return _reject(f"Unsupported attribute value type: {type(value).__name__}")
        # NOTE(review): for str values len() counts characters, although the
        # message says "bytes"; the wording is kept for compatibility.
        if len(value) > limit:
            return _reject(f"{label} value too large: {len(value)} bytes (max {limit})")

    logger.debug(
        "validate_attribute_value: Validation passed - attr=%d, value_type=%s",
        attribute,
        type(value).__name__,
    )
    return True, None, attr_text, value_text
# Settings support lives in this module to avoid circular imports.
#
# The locale's decimal point is read exactly once, at import time, so that
# creating a Settings object never calls locale.localeconv() itself (this
# sidesteps thread-safety issues with locale access).
_default_decimal_separator: str = "."
try:
    locale_separator = locale.localeconv()["decimal_point"]
    # Only a single-character separator is adopted; anything else keeps ".".
    if len(locale_separator or "") == 1:
        _default_decimal_separator = locale_separator
except (AttributeError, KeyError, TypeError, ValueError):
    # Locale information unavailable or malformed: keep the "." default.
    pass
class Settings:
    """
    Package-wide configuration for mssql_python.

    Holds global options that affect the package's behaviour:

    * ``lowercase``: the lowercase-column-names option (off by default);
    * ``decimal_separator``: the decimal separator, seeded from the
      module-level value resolved once at import time.
    """

    def __init__(self) -> None:
        # Reuse the separator computed at import; no locale access happens here.
        separator: str = _default_decimal_separator
        self.lowercase: bool = False
        self.decimal_separator: str = separator
# Single process-wide Settings instance handed out by get_settings().
_settings: Settings = Settings()
# Lock held while get_settings() reads the module-level ``_settings`` reference.
_settings_lock: threading.Lock = threading.Lock()


def get_settings() -> Settings:
    """
    Return the package-wide Settings object.

    The shared instance itself is returned (not a copy), so attribute changes
    made through it are seen by every other caller. The lock only guards the
    read of the module-level reference.
    """
    with _settings_lock:
        current = _settings
    return current