-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathdatabase_ops.py
More file actions
151 lines (127 loc) · 5.45 KB
/
database_ops.py
File metadata and controls
151 lines (127 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
import sqlalchemy
from sqlalchemy import orm
from . import backend_types_sql as bts
from . import database_migrations
_logger = logging.getLogger(__name__)
def create_db_engine_and_migrate_db(
    *,
    database_uri: str,
    do_skip_backfill: bool = False,
    **kwargs,
) -> sqlalchemy.Engine:
    """Create a database engine, create any missing tables, and migrate.

    Convenience entry point: builds the engine with create_db_engine()
    and then performs the same table creation + migration steps as
    initialize_and_migrate_db().

    Args:
        database_uri: SQLAlchemy database URL.
        do_skip_backfill: When True, schema migrations still run but the
            annotation backfills are skipped.
        **kwargs: Extra keyword arguments forwarded to create_db_engine().

    Returns:
        The ready-to-use engine.
    """
    engine = create_db_engine(database_uri=database_uri, **kwargs)
    initialize_and_migrate_db(db_engine=engine, do_skip_backfill=do_skip_backfill)
    return engine
def initialize_and_migrate_db(
    *,
    db_engine: sqlalchemy.Engine,
    do_skip_backfill: bool = False,
) -> None:
    """Create any missing tables on *db_engine*, then run the migrations.

    Args:
        db_engine: An already-constructed SQLAlchemy engine.
        do_skip_backfill: When True, the annotation backfills are skipped.
    """
    # create_all() only adds tables that do not exist yet; column-level
    # changes are handled inside migrate_db().
    metadata = bts._TableBase.metadata
    metadata.create_all(db_engine)
    migrate_db(db_engine=db_engine, do_skip_backfill=do_skip_backfill)
def create_db_engine(
    database_uri: str,
    **kwargs,
) -> sqlalchemy.Engine:
    """Create a SQLAlchemy engine with sensible per-dialect defaults.

    Args:
        database_uri: SQLAlchemy database URL (e.g. "sqlite://",
            "mysql://user:pass@host/db").
        **kwargs: Extra keyword arguments forwarded to
            sqlalchemy.create_engine(); these override any default chosen
            below.

    Returns:
        The configured engine.
    """
    if database_uri.startswith("mysql://"):
        try:
            import MySQLdb  # noqa: F401
        except ImportError:
            # MySQLdb is not installed; fall back to the pure-Python PyMySQL
            # driver. Rewrite only the scheme prefix — str.replace() would
            # also rewrite any other "mysql://" occurrence inside the URI
            # (e.g. in a password or query parameter).
            database_uri = "mysql+pymysql://" + database_uri[len("mysql://"):]
    create_engine_kwargs = {}
    if database_uri == "sqlite://":
        # Pure in-memory SQLite: StaticPool keeps one shared connection so
        # every session sees the same database.
        create_engine_kwargs["poolclass"] = sqlalchemy.pool.StaticPool
    if database_uri.startswith("sqlite://"):
        # FastApi claims it's needed and safe: https://fastapi.tiangolo.com/tutorial/sql-databases/#create-an-engine
        create_engine_kwargs.setdefault("connect_args", {})["check_same_thread"] = False
        # https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#using-a-memory-database-in-multiple-threads
    if create_engine_kwargs.get("poolclass") != sqlalchemy.pool.StaticPool:
        # Preventing the "MySQL server has gone away" error:
        # https://docs.sqlalchemy.org/en/20/faq/connections.html#mysql-server-has-gone-away
        create_engine_kwargs["pool_recycle"] = 3600
        create_engine_kwargs["pool_pre_ping"] = True
    # Caller-supplied kwargs take precedence over all defaults above; the
    # previous `if kwargs:` guard was redundant (update({}) is a no-op).
    create_engine_kwargs.update(kwargs)
    db_engine = sqlalchemy.create_engine(
        url=database_uri,
        **create_engine_kwargs,
    )
    return db_engine
def _add_columns_if_missing(*, db_engine: sqlalchemy.Engine) -> None:
    """Add new nullable columns to existing tables when they are not yet present.

    SQLAlchemy's create_all() only creates missing tables, not missing
    columns, so new columns require an explicit migration step. Each ALTER
    is committed individually: most dialects auto-commit DDL anyway, and a
    per-statement commit (plus rollback on failure) keeps the connection
    usable when a concurrent migration races us to the same column —
    on e.g. PostgreSQL a failed statement otherwise leaves the transaction
    aborted and every subsequent statement (including commit) fails.
    """
    # Columns that newer code requires but older deployments may lack.
    _COLUMN_MIGRATIONS = [
        bts.ExecutionNode.__table__.c.status_updated_at,
    ]
    inspector = sqlalchemy.inspect(db_engine)
    with db_engine.connect() as conn:
        for col in _COLUMN_MIGRATIONS:
            existing = {c["name"] for c in inspector.get_columns(col.table.name)}
            if col.name in existing:
                _logger.info(
                    f"Column {col.table.name}.{col.name} already exists — skipping"
                )
                continue
            _logger.info(
                f"Migrating: ALTER TABLE {col.table.name} ADD COLUMN {col.name} ({col.type})"
            )
            try:
                col_type_str = col.type.compile(dialect=db_engine.dialect)
                conn.execute(
                    sqlalchemy.text(
                        f"ALTER TABLE {col.table.name}"
                        f" ADD COLUMN {col.name} {col_type_str}"
                    )
                )
            except sqlalchemy.exc.OperationalError:
                # Another process added the column between our inspection and
                # the ALTER. Roll back so the connection exits the
                # failed-transaction state before the next statement.
                conn.rollback()
                _logger.info(
                    f"Column {col.table.name}.{col.name} already exists (concurrent migration) — skipping"
                )
            else:
                conn.commit()
def migrate_db(
    *,
    db_engine: sqlalchemy.Engine,
    do_skip_backfill: bool,
) -> None:
    """Bring the database schema and data up to date.

    Adds any missing columns, creates the indexes listed below when absent,
    and (unless *do_skip_backfill* is set) runs the annotation backfills.

    Args:
        db_engine: Engine for the database to migrate.
        do_skip_backfill: When True, skip the annotation backfills.
    """
    _logger.info("Enter migrate DB")
    _add_columns_if_missing(db_engine=db_engine)
    # NOTE: we deliberately reuse the Index objects already attached to the
    # table definitions rather than constructing fresh sqlalchemy.Index
    # instances: the Index constructor registers itself on the table (even
    # as a duplicate), see
    # https://github.com/sqlalchemy/sqlalchemy/issues/12965 and
    # https://github.com/sqlalchemy/sqlalchemy/discussions/12420
    execution_node_index_names = (
        bts.ExecutionNode._IX_EXECUTION_NODE_CACHE_KEY,
        "ix_execution_node_container_execution_id",
    )
    for index in bts.ExecutionNode.__table__.indexes:
        if index.name in execution_node_index_names:
            index.create(db_engine, checkfirst=True)
    annotation_index = next(
        (
            index
            for index in bts.PipelineRunAnnotation.__table__.indexes
            if index.name == bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE
        ),
        None,
    )
    if annotation_index is not None:
        annotation_index.create(db_engine, checkfirst=True)
    if do_skip_backfill:
        _logger.info("Skipping annotation backfills")
    else:
        with orm.Session(db_engine) as session:
            database_migrations.run_all_annotation_backfills(
                session=session,
            )
    _logger.info("Exit migrate DB")