1- import signal
1+ import concurrent . futures
22import threading
33import time
44from contextlib import contextmanager
@@ -93,55 +93,85 @@ def test_connection(self) -> bool:
9393 pass
9494 return False
9595
96+ def _invalidate_connection (self , connection : pymysql .Connection | None = None ):
97+ """关闭并清理失效的连接"""
98+ try :
99+ if connection :
100+ connection .close ()
101+ except Exception :
102+ pass
103+ finally :
104+ self .connection = None
105+
96106 @contextmanager
97107 def get_cursor (self ):
98108 """获取数据库游标的上下文管理器"""
99109 max_retries = 2
110+ cursor = None
111+ connection = None
112+ last_error : Exception | None = None
113+
114+ # 优先确保成功获取游标再交给调用方执行查询
100115 for attempt in range (max_retries ):
101116 try :
102117 connection = self ._get_connection ()
103118 cursor = connection .cursor ()
104-
105- try :
106- yield cursor
107- connection .commit ()
108- break # 成功,退出重试循环
109-
110- except Exception as e :
111- connection .rollback ()
112-
113- # 如果是连接错误,尝试重新连接
114- if "MySQL" in str (e ) or "connection" in str (e ).lower ():
115- if attempt < max_retries - 1 :
116- logger .warning (f"Connection error, retrying (attempt { attempt + 1 } ): { e } " )
117- # 强制重新连接
118- if self .connection :
119- try :
120- self .connection .close ()
121- except Exception as _ :
122- pass
123- self .connection = None
124- time .sleep (1 )
125- continue
126-
127- raise e # 其他错误直接抛出
128-
129- finally :
130- if cursor :
131- cursor .close ()
132-
119+ break
133120 except Exception as e :
121+ last_error = e
122+ logger .warning (f"Failed to acquire cursor (attempt { attempt + 1 } ): { e } " )
123+ self ._invalidate_connection (connection )
124+ cursor = None
125+ connection = None
134126 if attempt == max_retries - 1 :
135- raise e # 最后一次尝试失败,抛出异常
127+ raise e
136128 time .sleep (1 )
137129
130+ if cursor is None or connection is None :
131+ raise last_error or ConnectionError ("Unable to acquire MySQL cursor" )
132+
133+ try :
134+ yield cursor
135+ connection .commit ()
136+ except Exception as e :
137+ try :
138+ connection .rollback ()
139+ except Exception :
140+ pass
141+
142+ # 标记连接失效,等待下一次获取时重建
143+ if "MySQL" in str (e ) or "connection" in str (e ).lower ():
144+ logger .warning (f"MySQL connection error encountered, invalidating connection: { e } " )
145+ self ._invalidate_connection (connection )
146+
147+ raise
148+ finally :
149+ if cursor :
150+ try :
151+ cursor .close ()
152+ except Exception :
153+ pass
154+
138155 def close (self ):
139156 """关闭数据库连接"""
140157 if self .connection :
141158 self .connection .close ()
142159 self .connection = None
143160 logger .info ("MySQL connection closed" )
144161
162+ def get_connection (self ) -> pymysql .Connection :
163+ """对外暴露的连接获取方法"""
164+ return self ._get_connection ()
165+
166+ def invalidate_connection (self ):
167+ """手动标记连接失效"""
168+ self ._invalidate_connection (self .connection )
169+
170+ @property
171+ def database_name (self ) -> str :
172+ """返回当前配置的数据库名称"""
173+ return self .config ["database" ]
174+
145175
146176class QueryTimeoutError (Exception ):
147177 """查询超时异常"""
@@ -156,25 +186,30 @@ class QueryResultTooLargeError(Exception):
156186
157187
158188def execute_query_with_timeout (connection : pymysql .Connection , sql : str , params : tuple = None , timeout : int = 10 ):
159- """带超时的查询执行 """
189+ """使用线程池实现超时控制,避免信号导致的生成器问题 """
160190
161- def timeout_handler (_signum , _frame ):
162- raise QueryTimeoutError (f"Query timeout after { timeout } seconds" )
163-
164- # 设置信号处理
165- old_handler = signal .signal (signal .SIGALRM , timeout_handler )
166- signal .alarm (timeout )
167-
168- try :
191+ def query_worker ():
192+ """查询工作函数,在单独线程中执行"""
169193 cursor = connection .cursor (DictCursor )
170- cursor .execute (sql , params or ())
171- result = cursor .fetchall ()
172- cursor .close ()
173- return result
174- finally :
175- # 恢复原始信号处理
176- signal .alarm (0 )
177- signal .signal (signal .SIGALRM , old_handler )
194+ try :
195+ if params is None :
196+ cursor .execute (sql )
197+ else :
198+ cursor .execute (sql , params )
199+ result = cursor .fetchall ()
200+ return result
201+ finally :
202+ cursor .close ()
203+
204+ # 使用线程池执行查询,设置超时
205+ with concurrent .futures .ThreadPoolExecutor (max_workers = 1 ) as executor :
206+ future = executor .submit (query_worker )
207+ try :
208+ return future .result (timeout = timeout )
209+ except concurrent .futures .TimeoutError :
210+ # 尝试取消任务
211+ future .cancel ()
212+ raise QueryTimeoutError (f"Query timeout after { timeout } seconds" )
178213
179214
180215def limit_result_size (result : list , max_chars : int = 10000 ) -> list :
0 commit comments