-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmigration_runner.py
More file actions
131 lines (105 loc) · 4.63 KB
/
migration_runner.py
File metadata and controls
131 lines (105 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import asyncio
import platform
from pathlib import Path
from typing import Optional, Union

import asyncpg

from src.infrastructure.config import settings
class MigrationRunner:
    """Apply and report on plain-SQL schema migrations through asyncpg.

    Every ``*.sql`` file in ``migrations_dir`` is one migration. The file
    stem (its name without ``.sql``) is the migration's version; versions
    are ordered by plain string sorting, and each applied version is
    recorded as a row in the ``schema_migrations`` table.
    """

    def __init__(self, migrations_dir: Optional[Union[str, Path]] = None):
        """Resolve the migrations directory, creating it if it is missing.

        Args:
            migrations_dir: Directory that holds the ``*.sql`` files.
                Defaults to a ``migrations`` folder next to this module.
        """
        if migrations_dir is None:
            migrations_dir = Path(__file__).resolve().parent / "migrations"
        self.migrations_dir = Path(migrations_dir)
        self.migrations_dir.mkdir(parents=True, exist_ok=True)

    async def _connect(self):
        """Open a new asyncpg connection from the project ``settings``."""
        return await asyncpg.connect(
            host=settings.database_host,
            port=settings.database_port,
            database=settings.database_name,
            user=settings.database_user,
            password=settings.database_password,
        )

    def _discover_migrations(self):
        """Return the stems of all ``*.sql`` files, sorted as strings."""
        return sorted(path.stem for path in self.migrations_dir.glob("*.sql"))

    def _read_migration(self, migration_name):
        """Return the SQL text stored in ``<migration_name>.sql``.

        The file is always decoded as UTF-8 so the SQL sent to the server
        does not depend on the host's default locale encoding.
        """
        migration_file = self.migrations_dir / f"{migration_name}.sql"
        return migration_file.read_text(encoding="utf-8")

    async def _ensure_migrations_table(self, conn):
        """Create the ``schema_migrations`` bookkeeping table if absent."""
        await conn.execute("""
            create table if not exists schema_migrations (
                version varchar(255) primary key,
                applied_at timestamp default current_timestamp
            )
        """)

    async def _get_applied_migrations(self, conn):
        """Return the set of versions recorded in ``schema_migrations``."""
        rows = await conn.fetch("select version from schema_migrations order by version")
        return {row['version'] for row in rows}

    async def _get_pending_migrations(self, conn):
        """Return on-disk migration names not yet recorded, in sorted order."""
        applied = await self._get_applied_migrations(conn)
        all_migrations = self._discover_migrations()
        print(f"🔍 Found migrations: {all_migrations}")
        return [m for m in all_migrations if m not in applied]

    async def migrate(self):
        """Apply every pending migration, in sorted order.

        Each migration file is executed and recorded inside its own
        transaction, so migrations applied before a failing one stay
        applied. The connection is always closed on exit.

        Raises:
            Exception: Any error raised while talking to the database or
                reading a migration file is reported on stdout and then
                re-raised unchanged.
        """
        print(
            f"🔌 Connecting to {settings.database_name}@"
            f"{settings.database_host}:{settings.database_port}"
            f" as {settings.database_user}")
        conn = await self._connect()
        try:
            # Runs inside the try block so the connection is still closed
            # if this first query fails.
            db_name = await conn.fetchval("SELECT current_database()")
            print(f"📡 Connected to database: {db_name}")
            await self._ensure_migrations_table(conn)
            pending = await self._get_pending_migrations(conn)
            if not pending:
                print("No pending migrations")
                return
            for migration_name in pending:
                print(f"Applying migration: {migration_name}")
                sql = self._read_migration(migration_name)
                # The migration SQL and its bookkeeping row share one
                # transaction: both take effect or neither does.
                async with conn.transaction():
                    await conn.execute(sql)
                    await conn.execute(
                        "insert into schema_migrations (version) values ($1)",
                        migration_name,
                    )
                print(f"✅ Applied: {migration_name}")
            print(f"\n✅ Successfully applied {len(pending)} migration(s)")
        except Exception as e:
            print(f"❌ Migration failed: {e}")
            raise
        finally:
            await conn.close()

    async def status(self):
        """Print an applied/pending table for every migration file on disk.

        Creates the ``schema_migrations`` table if it does not exist yet.
        The "Applied" total counts rows in that table, which may include
        versions whose files are no longer present on disk.
        """
        conn = await self._connect()
        try:
            # Runs inside the try block so the connection is still closed
            # if this first query fails.
            db_name = await conn.fetchval("SELECT current_database()")
            print(f"📡 Connected to database: {db_name}")
            await self._ensure_migrations_table(conn)
            applied = await self._get_applied_migrations(conn)
            all_migrations = self._discover_migrations()
            if not all_migrations:
                print("\n⚠️ No migration files found")
                print(f"Create SQL files in: {self.migrations_dir}")
                return
            print("\nMigration Status:")
            print("-" * 70)
            for migration in all_migrations:
                label = "✅ Applied" if migration in applied else "⏳ Pending"
                print(f"{migration:<55} {label}")
            print("-" * 70)
            pending_count = sum(1 for m in all_migrations if m not in applied)
            print(f"\nTotal: {len(all_migrations)} | Applied: {len(applied)} | Pending: {pending_count}")
        finally:
            await conn.close()
async def main():
    """Run the command selected on the command line.

    A first argument of ``status`` prints the migration status table;
    anything else, including no argument at all, applies pending
    migrations.
    """
    import sys

    cli_args = sys.argv[1:]
    runner = MigrationRunner()
    wants_status = bool(cli_args) and cli_args[0] == "status"
    # Pick the bound coroutine method first, then await it once.
    action = runner.status if wants_status else runner.migrate
    await action()
# On Windows, switch asyncio to the selector-based event loop policy before
# any loop is created. NOTE(review): presumably needed for compatibility with
# the database driver - confirm.
if platform.system() == "Windows":
    asyncio.set_event_loop_policy(
        asyncio.WindowsSelectorEventLoopPolicy()
    )

# Script entry point: run the CLI coroutine to completion.
if __name__ == "__main__":
    asyncio.run(main())