-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathdatabase_ops.py
More file actions
144 lines (119 loc) · 5 KB
/
database_ops.py
File metadata and controls
144 lines (119 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import sqlalchemy
from sqlalchemy import orm
from . import backend_types_sql as bts
from . import filter_query_sql
def create_db_engine_and_migrate_db(
    database_uri: str,
    **kwargs,
) -> sqlalchemy.Engine:
    """Create a DB engine for `database_uri`, create missing tables and run migrations.

    Args:
        database_uri: SQLAlchemy database URL (e.g. "sqlite://", "mysql://...").
        **kwargs: Extra keyword arguments forwarded to `create_db_engine`.

    Returns:
        The ready-to-use engine.
    """
    db_engine = create_db_engine(database_uri=database_uri, **kwargs)
    # Delegate to the shared initialization entry point instead of repeating
    # its body, so table creation and migration stay in one place.
    initialize_and_migrate_db(db_engine=db_engine)
    return db_engine
def initialize_and_migrate_db(db_engine: sqlalchemy.Engine):
    """Create any missing tables on `db_engine`, then apply schema migrations."""
    metadata = bts._TableBase.metadata
    metadata.create_all(db_engine)
    migrate_db(db_engine=db_engine)
def create_db_engine(
    database_uri: str,
    **kwargs,
) -> sqlalchemy.Engine:
    """Create a SQLAlchemy engine with dialect-specific defaults.

    Defaults applied (all overridable via `**kwargs`):
    - "mysql://" URIs fall back to the PyMySQL driver when MySQLdb is absent.
    - The in-memory "sqlite://" URI gets a StaticPool so every connection
      shares the single in-memory database.
    - Any "sqlite://..." URI disables SQLite's same-thread check.
    - Pooled (non-StaticPool) engines get connection recycling and pre-ping.

    Args:
        database_uri: SQLAlchemy database URL.
        **kwargs: Overrides for `sqlalchemy.create_engine` keyword arguments;
            they take precedence over the defaults computed here.

    Returns:
        The configured engine.
    """
    if database_uri.startswith("mysql://"):
        try:
            import MySQLdb  # noqa: F401 -- import only probes for the native driver
        except ImportError:
            # Using PyMySQL instead of missing MySQLdb
            database_uri = database_uri.replace("mysql://", "mysql+pymysql://")
    create_engine_kwargs = {}
    if database_uri == "sqlite://":
        # In-memory SQLite: StaticPool keeps one shared connection so all
        # threads see the same database.
        # https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#using-a-memory-database-in-multiple-threads
        create_engine_kwargs["poolclass"] = sqlalchemy.pool.StaticPool
    if database_uri.startswith("sqlite://"):
        # FastApi claims it's needed and safe: https://fastapi.tiangolo.com/tutorial/sql-databases/#create-an-engine
        create_engine_kwargs.setdefault("connect_args", {})["check_same_thread"] = False
    if create_engine_kwargs.get("poolclass") != sqlalchemy.pool.StaticPool:
        # Preventing the "MySQL server has gone away" error:
        # https://docs.sqlalchemy.org/en/20/faq/connections.html#mysql-server-has-gone-away
        create_engine_kwargs["pool_recycle"] = 3600
        create_engine_kwargs["pool_pre_ping"] = True
    # Caller-supplied kwargs win over the computed defaults. (The previous
    # `if kwargs:` guard was redundant -- updating with an empty dict is a no-op.)
    create_engine_kwargs.update(kwargs)
    return sqlalchemy.create_engine(
        url=database_uri,
        **create_engine_kwargs,
    )
def _create_declared_index(
    table: sqlalchemy.Table,
    index_name: str,
    db_engine: sqlalchemy.Engine,
) -> None:
    """Create the index named `index_name` from `table`'s declared indexes.

    No-op when the table declares no such index; `checkfirst=True` makes the
    creation itself idempotent in the database.
    """
    for index in table.indexes:
        if index.name == index_name:
            index.create(db_engine, checkfirst=True)
            break


def migrate_db(db_engine: sqlalchemy.Engine):
    """Apply in-place schema migrations (indexes, data backfills) to `db_engine`."""
    # # Example:
    # sqlalchemy.Index(
    #     "ix_pipeline_run_created_by_created_at_desc",
    #     bts.PipelineRun.created_by,
    #     bts.PipelineRun.created_at.desc(),
    # ).create(db_engine, checkfirst=True)
    # index1 = sqlalchemy.Index(
    #     "ix_execution_node_container_execution_cache_key",
    #     bts.ExecutionNode.container_execution_cache_key,
    # )
    # index1.create(db_engine, checkfirst=True)
    # SqlAlchemy's Index constructor is broken and adds indexes to the table definition (even if they are duplicate)
    # See https://github.com/sqlalchemy/sqlalchemy/issues/12965
    # See https://github.com/sqlalchemy/sqlalchemy/discussions/12420
    # To work around that issue we either need to remove the index from the table
    # bts.ExecutionNode.__table__.indexes.remove(index1)
    # Or we need to avoid calling the Index constructor.
    _create_declared_index(
        bts.ExecutionNode.__table__,
        bts.ExecutionNode._IX_EXECUTION_NODE_CACHE_KEY,
        db_engine,
    )
    _create_declared_index(
        bts.PipelineRunAnnotation.__table__,
        bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE,
        db_engine,
    )
    _backfill_pipeline_run_created_by_annotations(db_engine=db_engine)
def _is_pipeline_run_annotation_key_already_backfilled(
    *,
    session: orm.Session,
    key: str,
) -> bool:
    """Return True if at least one annotation with the given key exists.

    Args:
        session: Open ORM session to query through.
        key: Annotation key to probe for.
    """
    # 2.0-style `session.scalar(select(...))` for consistency with the rest of
    # this module (the legacy `session.query(...)` API is deprecated in
    # SQLAlchemy 2.x).
    exists_clause = (
        sqlalchemy.select(sqlalchemy.literal(1))
        .select_from(bts.PipelineRunAnnotation)
        .where(
            bts.PipelineRunAnnotation.key == key,
        )
        .exists()
    )
    return session.scalar(sqlalchemy.select(exists_clause))
def _backfill_pipeline_run_created_by_annotations(
    *,
    db_engine: sqlalchemy.Engine,
) -> None:
    """Copy pipeline_run.created_by into pipeline_run_annotation so
    annotation-based search works for created_by.

    The check and insert run in a single session/transaction to avoid
    TOCTOU races between concurrent startup processes.

    Skips entirely if any created_by annotation key already exists (i.e. the
    write-path is populating them, so the backfill has already run or is
    no longer needed).
    """
    # `session.begin()` commits on success and rolls back on any exception,
    # so a failed insert cannot leave a half-applied backfill or a dangling
    # transaction behind (the original manual `commit()` had no error path).
    with orm.Session(db_engine) as session, session.begin():
        if _is_pipeline_run_annotation_key_already_backfilled(
            session=session,
            key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY,
        ):
            return
        # One annotation row per run that has a non-empty created_by value.
        stmt = sqlalchemy.insert(bts.PipelineRunAnnotation).from_select(
            ["pipeline_run_id", "key", "value"],
            sqlalchemy.select(
                bts.PipelineRun.id,
                sqlalchemy.literal(
                    filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY
                ),
                bts.PipelineRun.created_by,
            ).where(
                bts.PipelineRun.created_by.isnot(None),
                bts.PipelineRun.created_by != "",
            ),
        )
        session.execute(stmt)