diff --git a/Cargo.lock b/Cargo.lock index ee5f69c17..cfd9d78ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1675,6 +1675,8 @@ dependencies = [ "chrono", "chrono-tz", "cot", + "tracing", + "tracing-subscriber", ] [[package]] diff --git a/cot/Cargo.toml b/cot/Cargo.toml index 75e9548d8..bc0067e35 100644 --- a/cot/Cargo.toml +++ b/cot/Cargo.toml @@ -24,7 +24,7 @@ blake3.workspace = true bytes.workspace = true chrono = { workspace = true, features = ["alloc", "serde", "clock"] } chrono-tz.workspace = true -clap.workspace = true +clap = {workspace = true, features = ["string"] } cot_core.workspace = true cot_macros.workspace = true deadpool-redis = { workspace = true, features = ["tokio-comp", "rt_tokio_1"], optional = true } diff --git a/cot/src/cli.rs b/cot/src/cli.rs index 97652a2e1..5601b252d 100644 --- a/cot/src/cli.rs +++ b/cot/src/cli.rs @@ -4,18 +4,21 @@ use std::collections::HashMap; use std::path::PathBuf; use std::str::FromStr; +use crate::{Bootstrapper, Error, Result}; use async_trait::async_trait; pub use clap; use clap::{Arg, ArgMatches, Command, value_parser}; +use cot::db::migrations::{MigrationEngine, SyncDynMigration}; +use cot::project::BootstrappedProject; use derive_more::Debug; -use crate::{Bootstrapper, Error, Result}; - const CONFIG_PARAM: &str = "config"; const COLLECT_STATIC_SUBCOMMAND: &str = "collect-static"; const CHECK_SUBCOMMAND: &str = "check"; const LISTEN_PARAM: &str = "listen"; const COLLECT_STATIC_DIR_PARAM: &str = "dir"; +const MIGRATION_GROUP_SUBCOMMAND: &str = "migration"; +const MIGRATION_ROLLBACK_SUBCOMMAND: &str = "rollback"; /// A central point for configuring the default Command Line Interface (CLI) for /// Cot-powered projects. @@ -91,6 +94,12 @@ impl Cli { cli.add_task(Check); cli.add_task(CollectStatic); + let mut migration_group = + CliTaskGroup::new("migration").about("Database migration commands"); + migration_group.add_task(MigrationRollback); + + cli.add_task(migration_group); + cli } @@ -389,6 +398,142 @@ impl CliTask for Check { } } +/// A group of related sub-tasks under a single parent subcommand. +/// +/// Usage: +/// ``` +/// let mut migration = CliTaskGroup::new("migration") +/// .about("Database migration commands"); +/// migration.add_task(MigrationRollback); +/// migration.add_task(MigrationSquash); +/// cli.add_task(migration); +/// ``` +pub struct CliTaskGroup { + name: String, + about: String, + tasks: HashMap>, +} + +impl CliTaskGroup { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + about: String::new(), + tasks: HashMap::new(), + } + } + + pub fn about(mut self, about: impl Into) -> Self { + self.about = about.into(); + self + } + + pub fn add_task(&mut self, task: C) { + let subcommand = task.subcommand(); + let name = subcommand.get_name().to_owned(); + + assert!( + !self.tasks.contains_key(&name), + "Task with name {name} already exists in group '{}'", + self.name + ); + + self.tasks.insert(name, Box::new(task)); + } +} + +#[async_trait(?Send)] +impl CliTask for CliTaskGroup { + fn subcommand(&self) -> Command { + let name = self.name.clone(); + let mut cmd = Command::new(name); + + if !self.about.is_empty() { + cmd = cmd.about(self.about.clone()); + } + + cmd = cmd.subcommand_required(true).arg_required_else_help(true); + for task in self.tasks.values() { + cmd = cmd.subcommand(task.subcommand()); + } + cmd + } + + async fn execute( + &mut self, + matches: &ArgMatches, + bootstrapper: Bootstrapper, + ) -> Result<()> { + let (sub_name, sub_matches) = matches + .subcommand() + .expect("subcommand_required(true) ensures one is present"); + + self.tasks + .get_mut(sub_name) + .expect("clap only matches registered subcommands") + .execute(sub_matches, bootstrapper) + .await + } +} + +struct MigrationRollback; + +#[async_trait(?Send)] +impl CliTask for MigrationRollback { + fn subcommand(&self) -> Command { + Command::new(MIGRATION_ROLLBACK_SUBCOMMAND) + .about("Rollback migrations up to the specified migration file") + .arg( + Arg::new("file") + .help("The migration filename to roll back to (e.g. 0001_create_users)") + .value_name("FILE") + .required(true), + ) + } + + async fn execute( + &mut self, + matches: &ArgMatches, + bootstrapper: Bootstrapper, + ) -> Result<()> { + let file = matches + .get_one::("file") + .expect("required argument"); + + let bootstrapper = bootstrapper + .with_apps() + .with_database() + .await? + .with_cache() + .await? + .boot() + .await?; + + // migrations are currently tied to crates, so we use the crate name as the app name. + // TODO: cli command should take an explicit crate name as arg when workspaces are supported. + let crate_name = bootstrapper.project().cli_metadata().name; + + let BootstrappedProject { + mut context, + mut handler, + mut error_handler, + } = bootstrapper.finish(); + + #[cfg(feature = "db")] + { + let mut migrations: Vec> = Vec::new(); + for app in context.apps() { + migrations.extend(app.migrations()); + } + let migration_engine = MigrationEngine::new(migrations)?; + migration_engine + .rollback(context.database(), file, crate_name) + .await?; + } + Ok(()) + } +} + /// A macro to generate a [`CliMetadata`] struct from the Cargo manifest. #[macro_export] macro_rules! metadata { diff --git a/cot/src/db/migrations.rs b/cot/src/db/migrations.rs index 51dc19761..c81c8fc5d 100644 --- a/cot/src/db/migrations.rs +++ b/cot/src/db/migrations.rs @@ -2,12 +2,13 @@ mod sorter; +pub use cot_macros::migration_op; +use sea_query::{ColumnDef, StringLen}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::fmt; use std::fmt::{Debug, Formatter}; use std::future::Future; - -pub use cot_macros::migration_op; -use sea_query::{ColumnDef, StringLen}; +use std::path::{Path, PathBuf}; use thiserror::Error; use tracing::{Level, info}; @@ -215,6 +216,92 @@ impl MigrationEngine { Ok(()) } + pub async fn rollback(&self, database: &Database, file: &str, app_name: &str) -> Result<()> { + let rollback_plan = self.rollback_plan(file, app_name)?; + + for migration in rollback_plan { + if !Self::is_migration_applied(database, migration).await? { + continue; + } + + let span = tracing::span!( + Level::TRACE, + "rollback_migration", + app_name = migration.app_name(), + migration_name = migration.name() + ); + let _enter = span.enter(); + + info!( + "Rolling back migration {} for app {}", + migration.name(), + migration.app_name() + ); + + for operation in migration.operations().iter().rev() { + operation.backwards(database).await?; + } + + Self::unapply_migration(database, migration).await?; + } + + Ok(()) + } + + fn rollback_plan<'a>( + &'a self, + file: &str, + app_name: &str, + ) -> Result> { + let target_index = self + .migrations + .iter() + .position(|migration| { + migration.app_name() == app_name + && resolve_migration_file_name(migration.name()).contains(&file) + }) + .ok_or_else(|| { + MigrationEngineError::Custom(format!( + "Migration with file name {file} not found for app {app_name}" + )) + })?; + + let mut rollback_indices = HashSet::new(); + // seed all possible migration files in the same app that are dependent on the target migration + rollback_indices.extend( + self.migrations + .iter() + .enumerate() + .filter(|(index, migration)| { + *index > target_index && migration.app_name() == app_name + }) + .map(|(index, _)| index), + ); + + let graph = MigrationSorter::generate_graph(&self.migrations).map_err(|e| { + MigrationEngineError::Custom(format!("Failed to generate migration graph: {}", e)) + })?; + let mut queue = rollback_indices.iter().copied().collect::>(); + // go through every migration dependent on the target migration and add any dependencies to the queue. + // This is also useful in cases where a migration is depended on by migrations in other apps. In this case, + // we need to make sure that we also roll back the migrations in other apps. + while let Some(index) = queue.pop_front() { + for &dependent_index in graph.get_edges(index) { + if rollback_indices.insert(dependent_index) { + queue.push_back(dependent_index); + } + } + } + + Ok(self + .migrations + .iter() + .enumerate() + .rev() + .filter_map(|(index, migration)| rollback_indices.contains(&index).then_some(migration)) + .collect()) + } + async fn is_migration_applied( database: &Database, migration: &MigrationWrapper, @@ -241,6 +328,22 @@ impl MigrationEngine { database.insert(&mut applied_migration).await?; Ok(()) } + + async fn unapply_migration(database: &Database, migration: &MigrationWrapper) -> Result<()> { + query!(AppliedMigration, $app == migration.app_name() && $name == migration.name()) + .delete(database) + .await?; + Ok(()) + } +} + +fn resolve_migration_file_name(file_name: &str) -> Vec<&str> { + let mut names = vec![file_name]; + let migration_number = file_name.split('_').nth(1); + if let Some(migration_number) = migration_number { + names.push(migration_number); + } + names } /// A migration operation that can be run forwards or backwards. diff --git a/cot/src/db/migrations/sorter.rs b/cot/src/db/migrations/sorter.rs index e2d5c7e66..46faf1f96 100644 --- a/cot/src/db/migrations/sorter.rs +++ b/cot/src/db/migrations/sorter.rs @@ -61,11 +61,11 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { Ok(()) } - fn toposort(&mut self) -> Result<()> { - let lookup = Self::create_lookup_table(self.migrations)?; - let mut graph = Graph::new(self.migrations.len()); + pub(super) fn generate_graph(migrations: &[T]) -> Result { + let lookup = Self::create_lookup_table(migrations)?; + let mut graph = Graph::new(migrations.len()); - for (index, migration) in self.migrations.iter().enumerate() { + for (index, migration) in migrations.iter().enumerate() { for dependency in migration.dependencies() { let dependency_index = lookup .get(&MigrationLookup::from(dependency)) @@ -74,6 +74,12 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { } } + Ok(graph) + } + + fn toposort(&mut self) -> Result<()> { + let mut graph = Self::generate_graph(self.migrations)?; + let mut sorted_indices = graph.toposort()?; apply_permutation(self.migrations, &mut sorted_indices); diff --git a/cot/src/utils/graph.rs b/cot/src/utils/graph.rs index 741f98400..bf6053efc 100644 --- a/cot/src/utils/graph.rs +++ b/cot/src/utils/graph.rs @@ -40,6 +40,10 @@ impl Graph { self.vertex_edges[from].push(to); } + pub(crate) fn get_edges(&self, from: usize) -> &[usize] { + &self.vertex_edges[from] + } + #[must_use] pub(crate) fn vertex_num(&self) -> usize { self.vertex_edges.len()