Skip to content

Commit a0dbd86

Browse files
committed
move shared DbManager + fns to db.py
1 parent 1923849 commit a0dbd86

4 files changed

Lines changed: 313 additions & 331 deletions

File tree

llms/db.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
import json
2+
import sqlite3
3+
import threading
4+
from queue import Empty, Queue
5+
from threading import Event, Thread
6+
7+
8+
def create_reader_connection(db_path):
9+
conn = sqlite3.connect(db_path, timeout=1.0) # Lower - reads should be fast
10+
conn.execute("PRAGMA query_only=1") # Read-only optimization
11+
return conn
12+
13+
14+
def create_writer_connection(db_path):
15+
conn = sqlite3.connect(db_path)
16+
conn.execute("PRAGMA busy_timeout=5000") # Reasonable timeout for busy connections
17+
conn.execute("PRAGMA journal_mode=WAL") # Enable WAL mode for better concurrency
18+
conn.execute("PRAGMA cache_size=-128000") # Increase cache size for better performance
19+
conn.execute("PRAGMA synchronous=NORMAL") # Reasonable durability/performance balance
20+
return conn
21+
22+
23+
def writer_thread(ctx, db_path, task_queue, stop_event):
24+
conn = create_writer_connection(db_path)
25+
try:
26+
while not stop_event.is_set():
27+
try:
28+
# Use timeout to check stop_event periodically
29+
task = task_queue.get(timeout=0.1)
30+
31+
if task is None: # Poison pill for clean shutdown
32+
break
33+
34+
sql, args, callback = task # Optional callback for results
35+
36+
try:
37+
ctx.dbg("SQL>" + ("\n" if "\n" in sql else " ") + sql)
38+
cursor = conn.execute(sql, args)
39+
conn.commit()
40+
ctx.dbg(f"lastrowid {cursor.lastrowid}, rowcount {cursor.rowcount}")
41+
if callback:
42+
callback(cursor.lastrowid, cursor.rowcount)
43+
except sqlite3.Error as e:
44+
ctx.err("writer_thread", e)
45+
if callback:
46+
callback(None, None, error=e)
47+
finally:
48+
task_queue.task_done()
49+
50+
except Empty:
51+
continue
52+
53+
finally:
54+
conn.close()
55+
56+
57+
def valid_columns(all_columns, fields):
58+
if fields:
59+
if not isinstance(fields, list):
60+
fields = fields.split(",")
61+
cols = []
62+
for k in fields:
63+
k = k.strip()
64+
if k in all_columns:
65+
cols.append(k)
66+
return cols
67+
return []
68+
69+
70+
def table_columns(all_columns, fields):
71+
cols = valid_columns(all_columns, fields)
72+
return ", ".join(cols) if len(cols) > 0 else ", ".join(all_columns)
73+
74+
75+
def select_columns(all_columns, fields, select=None):
76+
columns = table_columns(all_columns, fields)
77+
if select == "distinct":
78+
return f"SELECT DISTINCT {columns}"
79+
return f"SELECT {columns}"
80+
81+
82+
def order_by(all_columns, sort):
83+
cols = []
84+
for k in sort.split(","):
85+
k = k.strip()
86+
by = ""
87+
if k[0] == "-":
88+
by = " DESC"
89+
k = k[1:]
90+
if k in all_columns:
91+
cols.append(f"{k}{by}")
92+
return f"ORDER BY {', '.join(cols)} " if len(cols) > 0 else ""
93+
94+
95+
class DbManager:
96+
def __init__(self, ctx, db_path):
97+
if db_path is None:
98+
raise ValueError("db_path is required")
99+
self.ctx = ctx
100+
self.db_path = db_path
101+
self.task_queue = Queue()
102+
self.stop_event = Event()
103+
self.writer_thread = Thread(target=writer_thread, args=(ctx, db_path, self.task_queue, self.stop_event))
104+
self.writer_thread.start()
105+
self.read_only_pool = Queue()
106+
107+
def create_reader_connection(self):
108+
return create_reader_connection(self.db_path)
109+
110+
def create_writer_connection(self):
111+
return create_writer_connection(self.db_path)
112+
113+
def resolve_connection(self):
114+
try:
115+
return self.read_only_pool.get_nowait()
116+
except Empty:
117+
return self.create_reader_connection()
118+
119+
def write(self, query, args=None, callback=None):
120+
"""
121+
Execute a write operation asynchronously.
122+
123+
Args:
124+
query (str): The SQL query to execute.
125+
args (tuple, optional): Arguments for the query.
126+
callback (callable, optional): A function called after execution with signature:
127+
callback(lastrowid, rowcount, error=None)
128+
- lastrowid (int): output of cursor.lastrowid
129+
- rowcount (int): output of cursor.rowcount
130+
- error (Exception): exception if operation failed, else None
131+
"""
132+
self.task_queue.put((query, args, callback))
133+
134+
def log_sql(self, sql, parameters=None):
135+
if self.ctx.debug:
136+
self.ctx.dbg("SQL>" + ("\n" if "\n" in sql else " ") + sql + ("\n" if parameters else "") + str(parameters))
137+
138+
def exec(self, connection, sql, parameters=None):
139+
self.log_sql(sql, parameters)
140+
return connection.execute(sql, parameters or ())
141+
142+
def all(self, sql, parameters=None, connection=None):
143+
"""
144+
Execute a query and return all rows as a list of dictionaries.
145+
"""
146+
conn = self.resolve_connection() if connection is None else connection
147+
148+
try:
149+
self.log_sql(sql, parameters)
150+
conn.row_factory = sqlite3.Row
151+
cursor = conn.execute(sql, parameters or ())
152+
rows = [dict(row) for row in cursor.fetchall()]
153+
return rows
154+
finally:
155+
if connection is None:
156+
conn.row_factory = None
157+
self.read_only_pool.put(conn)
158+
159+
def one(self, sql, parameters=None, connection=None):
160+
"""
161+
Execute a query and return the first row as a dictionary.
162+
"""
163+
conn = self.resolve_connection() if connection is None else connection
164+
165+
try:
166+
self.log_sql(sql, parameters)
167+
conn.row_factory = sqlite3.Row
168+
cursor = conn.execute(sql, parameters or ())
169+
row = cursor.fetchone()
170+
return dict(row) if row else None
171+
finally:
172+
if connection is None:
173+
conn.row_factory = None
174+
self.read_only_pool.put(conn)
175+
176+
def scalar(self, sql, parameters=None, connection=None):
177+
"""
178+
Execute a scalar query and return the first column of the first row.
179+
"""
180+
conn = self.resolve_connection() if connection is None else connection
181+
182+
try:
183+
self.log_sql(sql, parameters)
184+
conn.row_factory = sqlite3.Row
185+
cursor = conn.execute(sql, parameters or ())
186+
row = cursor.fetchone()
187+
return row[0] if row else None
188+
finally:
189+
if connection is None:
190+
conn.row_factory = None
191+
self.read_only_pool.put(conn)
192+
193+
def column(self, sql, parameters=None, connection=None):
194+
"""
195+
Execute a 1 column query and return the values as a list.
196+
"""
197+
conn = self.resolve_connection() if connection is None else connection
198+
199+
try:
200+
self.log_sql(sql, parameters)
201+
cursor = conn.execute(sql, parameters or ())
202+
return [row[0] for row in cursor.fetchall()]
203+
finally:
204+
if connection is None:
205+
self.read_only_pool.put(conn)
206+
207+
def dict(self, sql, parameters=None, connection=None):
208+
"""
209+
Execute a 2 column query and return the keys as the first column and the values as the second column.
210+
"""
211+
conn = self.resolve_connection() if connection is None else connection
212+
213+
try:
214+
self.log_sql(sql, parameters)
215+
conn.row_factory = sqlite3.Row
216+
cursor = conn.execute(sql, parameters or ())
217+
rows = cursor.fetchall()
218+
return {row[0]: row[1] for row in rows}
219+
finally:
220+
if connection is None:
221+
conn.row_factory = None
222+
self.read_only_pool.put(conn)
223+
224+
# Helper to safely dump JSON if value exists
225+
def value(self, val):
226+
if val is None or val == "":
227+
return None
228+
if isinstance(val, (dict, list)):
229+
return json.dumps(val)
230+
return val
231+
232+
def insert(self, table, columns, info, callback=None):
233+
if not info:
234+
raise Exception("info is required")
235+
236+
args = {}
237+
known_columns = columns.keys()
238+
for k, val in info.items():
239+
if k in known_columns and k != "id":
240+
args[k] = self.value(val)
241+
242+
insert_keys = list(args.keys())
243+
insert_body = ", ".join(insert_keys)
244+
insert_values = ", ".join(["?" for _ in insert_keys])
245+
246+
sql = f"INSERT INTO {table} ({insert_body}) VALUES ({insert_values})"
247+
248+
self.write(sql, tuple(args[k] for k in insert_keys), callback)
249+
250+
async def insert_async(self, table, columns, info):
251+
event = threading.Event()
252+
253+
ret = [None]
254+
255+
def cb(lastrowid, rowcount, error=None):
256+
nonlocal ret
257+
if error:
258+
raise error
259+
ret[0] = lastrowid
260+
event.set()
261+
262+
self.insert(table, columns, info, cb)
263+
event.wait()
264+
return ret[0]
265+
266+
def update(self, table, info, columns, callback=None):
267+
if not info:
268+
raise Exception("info is required")
269+
270+
args = {}
271+
known_columns = columns.keys()
272+
for k, val in info.items():
273+
if k in known_columns and k != "id":
274+
args[k] = self.value(val)
275+
276+
update_keys = list(args.keys())
277+
update_body = ", ".join([f"{k} = :{k}" for k in update_keys])
278+
279+
args["id"] = info["id"]
280+
sql = f"UPDATE {table} SET {update_body} WHERE id = :id"
281+
282+
self.write(sql, args, callback)
283+
284+
async def update_async(self, table, columns, info):
285+
event = threading.Event()
286+
287+
ret = [None]
288+
289+
def cb(lastrowid, rowcount, error=None):
290+
nonlocal ret
291+
if error:
292+
raise error
293+
ret[0] = rowcount
294+
event.set()
295+
296+
self.update(table, info, columns, cb)
297+
event.wait()
298+
return ret[0]
299+
300+
def close(self):
301+
self.ctx.dbg("Closing database")
302+
self.stop_event.set()
303+
self.task_queue.put(None) # Poison pill to signal shutdown
304+
self.writer_thread.join()
305+
306+
while not self.read_only_pool.empty():
307+
try:
308+
conn = self.read_only_pool.get_nowait()
309+
conn.close()
310+
except Empty:
311+
break

llms/extensions/app/db.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime, timedelta
44
from typing import Any, Dict
55

6-
from llms.main import DbManager
6+
from llms.db import DbManager, order_by, select_columns, valid_columns
77

88

99
def with_user(data, user):
@@ -16,44 +16,6 @@ def with_user(data, user):
1616
return data
1717

1818

19-
def valid_columns(all_columns, fields):
20-
if fields:
21-
if not isinstance(fields, list):
22-
fields = fields.split(",")
23-
cols = []
24-
for k in fields:
25-
k = k.strip()
26-
if k in all_columns:
27-
cols.append(k)
28-
return cols
29-
return []
30-
31-
32-
def table_columns(all_columns, fields):
33-
cols = valid_columns(all_columns, fields)
34-
return ", ".join(cols) if len(cols) > 0 else ", ".join(all_columns)
35-
36-
37-
def select_columns(all_columns, fields, select=None):
38-
columns = table_columns(all_columns, fields)
39-
if select == "distinct":
40-
return f"SELECT DISTINCT {columns}"
41-
return f"SELECT {columns}"
42-
43-
44-
def order_by(all_columns, sort):
45-
cols = []
46-
for k in sort.split(","):
47-
k = k.strip()
48-
by = ""
49-
if k[0] == "-":
50-
by = " DESC"
51-
k = k[1:]
52-
if k in all_columns:
53-
cols.append(f"{k}{by}")
54-
return f"ORDER BY {', '.join(cols)} " if len(cols) > 0 else ""
55-
56-
5719
class AppDB:
5820
def __init__(self, ctx, db_path):
5921
if db_path is None:

llms/extensions/gallery/db.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Any, Dict
44

5-
from llms.main import DbManager
5+
from llms.db import DbManager, order_by
66

77

88
def with_user(data, user):
@@ -24,19 +24,6 @@ def ratio_format(ratio):
2424
return 0
2525

2626

27-
def order_by(all_columns, sort):
28-
cols = []
29-
for k in sort.split(","):
30-
k = k.strip()
31-
by = ""
32-
if k[0] == "-":
33-
by = " DESC"
34-
k = k[1:]
35-
if k in all_columns:
36-
cols.append(f"{k}{by}")
37-
return f"ORDER BY {', '.join(cols)} " if len(cols) > 0 else ""
38-
39-
4027
class GalleryDB:
4128
def __init__(self, ctx, db_path=None):
4229
if db_path is None:

0 commit comments

Comments
 (0)