Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion cot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ blake3.workspace = true
bytes.workspace = true
chrono = { workspace = true, features = ["alloc", "serde", "clock"] }
chrono-tz.workspace = true
clap.workspace = true
clap = {workspace = true, features = ["string"] }
cot_core.workspace = true
cot_macros.workspace = true
deadpool-redis = { workspace = true, features = ["tokio-comp", "rt_tokio_1"], optional = true }
Expand Down
149 changes: 147 additions & 2 deletions cot/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::str::FromStr;

use crate::{Bootstrapper, Error, Result};
use async_trait::async_trait;
pub use clap;
use clap::{Arg, ArgMatches, Command, value_parser};
use cot::db::migrations::{MigrationEngine, SyncDynMigration};
use cot::project::BootstrappedProject;
use derive_more::Debug;

use crate::{Bootstrapper, Error, Result};

const CONFIG_PARAM: &str = "config";
const COLLECT_STATIC_SUBCOMMAND: &str = "collect-static";
const CHECK_SUBCOMMAND: &str = "check";
const LISTEN_PARAM: &str = "listen";
const COLLECT_STATIC_DIR_PARAM: &str = "dir";
const MIGRATION_GROUP_SUBCOMMAND: &str = "migration";
const MIGRATION_ROLLBACK_SUBCOMMAND: &str = "rollback";

/// A central point for configuring the default Command Line Interface (CLI) for
/// Cot-powered projects.
Expand Down Expand Up @@ -91,6 +94,12 @@ impl Cli {
cli.add_task(Check);
cli.add_task(CollectStatic);

let mut migration_group =
CliTaskGroup::new("migration").about("Database migration commands");
migration_group.add_task(MigrationRollback);

cli.add_task(migration_group);

cli
}

Expand Down Expand Up @@ -389,6 +398,142 @@ impl CliTask for Check {
}
}

/// A group of related sub-tasks under a single parent subcommand.
///
/// Usage:
/// ```
/// let mut migration = CliTaskGroup::new("migration")
/// .about("Database migration commands");
/// migration.add_task(MigrationRollback);
/// migration.add_task(MigrationSquash);
/// cli.add_task(migration);
/// ```
pub struct CliTaskGroup {
name: String,
about: String,
tasks: HashMap<String, Box<dyn CliTask + Send + 'static>>,
}

impl CliTaskGroup {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
about: String::new(),
tasks: HashMap::new(),
}
}

pub fn about(mut self, about: impl Into<String>) -> Self {
self.about = about.into();
self
}

pub fn add_task<C: CliTask + Send + 'static>(&mut self, task: C) {
let subcommand = task.subcommand();
let name = subcommand.get_name().to_owned();

assert!(
!self.tasks.contains_key(&name),
"Task with name {name} already exists in group '{}'",
self.name
);

self.tasks.insert(name, Box::new(task));
}
}

#[async_trait(?Send)]
impl CliTask for CliTaskGroup {
fn subcommand(&self) -> Command {
let name = self.name.clone();
let mut cmd = Command::new(name);

if !self.about.is_empty() {
cmd = cmd.about(self.about.clone());
}

cmd = cmd.subcommand_required(true).arg_required_else_help(true);
for task in self.tasks.values() {
cmd = cmd.subcommand(task.subcommand());
}
cmd
}

async fn execute(
&mut self,
matches: &ArgMatches,
bootstrapper: Bootstrapper<WithConfig>,
) -> Result<()> {
let (sub_name, sub_matches) = matches
.subcommand()
.expect("subcommand_required(true) ensures one is present");

self.tasks
.get_mut(sub_name)
.expect("clap only matches registered subcommands")
.execute(sub_matches, bootstrapper)
.await
}
}

struct MigrationRollback;

#[async_trait(?Send)]
impl CliTask for MigrationRollback {
fn subcommand(&self) -> Command {
Command::new(MIGRATION_ROLLBACK_SUBCOMMAND)
.about("Rollback migrations up to the specified migration file")
.arg(
Arg::new("file")
.help("The migration filename to roll back to (e.g. 0001_create_users)")
.value_name("FILE")
.required(true),
)
}

async fn execute(
&mut self,
matches: &ArgMatches,
bootstrapper: Bootstrapper<WithConfig>,
) -> Result<()> {
let file = matches
.get_one::<String>("file")
.expect("required argument");

let bootstrapper = bootstrapper
.with_apps()
.with_database()
.await?
.with_cache()
.await?
.boot()
.await?;

// migrations are currently tied to crates, so we use the crate name as the app name.
// TODO: cli command should take an explicit crate name as arg when workspaces are supported.
let crate_name = bootstrapper.project().cli_metadata().name;

let BootstrappedProject {
mut context,
mut handler,
mut error_handler,
} = bootstrapper.finish();

#[cfg(feature = "db")]
{
let mut migrations: Vec<Box<SyncDynMigration>> = Vec::new();
for app in context.apps() {
migrations.extend(app.migrations());
}
let migration_engine = MigrationEngine::new(migrations)?;
migration_engine
.rollback(context.database(), file, crate_name)
.await?;
}
Ok(())
}
}

/// A macro to generate a [`CliMetadata`] struct from the Cargo manifest.
#[macro_export]
macro_rules! metadata {
Expand Down
109 changes: 106 additions & 3 deletions cot/src/db/migrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

mod sorter;

pub use cot_macros::migration_op;
use sea_query::{ColumnDef, StringLen};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;

pub use cot_macros::migration_op;
use sea_query::{ColumnDef, StringLen};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tracing::{Level, info};

Expand Down Expand Up @@ -215,6 +216,92 @@ impl MigrationEngine {
Ok(())
}

pub async fn rollback(&self, database: &Database, file: &str, app_name: &str) -> Result<()> {
let rollback_plan = self.rollback_plan(file, app_name)?;

for migration in rollback_plan {
if !Self::is_migration_applied(database, migration).await? {
continue;
}

let span = tracing::span!(
Level::TRACE,
"rollback_migration",
app_name = migration.app_name(),
migration_name = migration.name()
);
let _enter = span.enter();

info!(
"Rolling back migration {} for app {}",
migration.name(),
migration.app_name()
);

for operation in migration.operations().iter().rev() {
operation.backwards(database).await?;
}

Self::unapply_migration(database, migration).await?;
}

Ok(())
}

fn rollback_plan<'a>(
&'a self,
file: &str,
app_name: &str,
) -> Result<Vec<&'a MigrationWrapper>> {
let target_index = self
.migrations
.iter()
.position(|migration| {
migration.app_name() == app_name
&& resolve_migration_file_name(migration.name()).contains(&file)
})
.ok_or_else(|| {
MigrationEngineError::Custom(format!(
"Migration with file name {file} not found for app {app_name}"
))
})?;

let mut rollback_indices = HashSet::new();
// seed all possible migration files in the same app that are dependent on the target migration
rollback_indices.extend(
self.migrations
.iter()
.enumerate()
.filter(|(index, migration)| {
*index > target_index && migration.app_name() == app_name
})
.map(|(index, _)| index),
);

let graph = MigrationSorter::generate_graph(&self.migrations).map_err(|e| {
MigrationEngineError::Custom(format!("Failed to generate migration graph: {}", e))
})?;
let mut queue = rollback_indices.iter().copied().collect::<VecDeque<_>>();
// go through every migration dependent on the target migration and add any dependencies to the queue.
// This is also useful in cases where a migration is depended on by migrations in other apps. In this case,
// we need to make sure that we also roll back the migrations in other apps.
while let Some(index) = queue.pop_front() {
for &dependent_index in graph.get_edges(index) {
if rollback_indices.insert(dependent_index) {
queue.push_back(dependent_index);
}
}
}

Ok(self
.migrations
.iter()
.enumerate()
.rev()
.filter_map(|(index, migration)| rollback_indices.contains(&index).then_some(migration))
.collect())
}

async fn is_migration_applied(
database: &Database,
migration: &MigrationWrapper,
Expand All @@ -241,6 +328,22 @@ impl MigrationEngine {
database.insert(&mut applied_migration).await?;
Ok(())
}

async fn unapply_migration(database: &Database, migration: &MigrationWrapper) -> Result<()> {
query!(AppliedMigration, $app == migration.app_name() && $name == migration.name())
.delete(database)
.await?;
Ok(())
}
}

fn resolve_migration_file_name(file_name: &str) -> Vec<&str> {
let mut names = vec![file_name];
let migration_number = file_name.split('_').nth(1);
if let Some(migration_number) = migration_number {
names.push(migration_number);
}
names
}

/// A migration operation that can be run forwards or backwards.
Expand Down
14 changes: 10 additions & 4 deletions cot/src/db/migrations/sorter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> {
Ok(())
}

fn toposort(&mut self) -> Result<()> {
let lookup = Self::create_lookup_table(self.migrations)?;
let mut graph = Graph::new(self.migrations.len());
pub(super) fn generate_graph(migrations: &[T]) -> Result<Graph> {
let lookup = Self::create_lookup_table(migrations)?;
let mut graph = Graph::new(migrations.len());

for (index, migration) in self.migrations.iter().enumerate() {
for (index, migration) in migrations.iter().enumerate() {
for dependency in migration.dependencies() {
let dependency_index = lookup
.get(&MigrationLookup::from(dependency))
Expand All @@ -74,6 +74,12 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> {
}
}

Ok(graph)
}

fn toposort(&mut self) -> Result<()> {
let mut graph = Self::generate_graph(self.migrations)?;

let mut sorted_indices = graph.toposort()?;
apply_permutation(self.migrations, &mut sorted_indices);

Expand Down
4 changes: 4 additions & 0 deletions cot/src/utils/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ impl Graph {
self.vertex_edges[from].push(to);
}

pub(crate) fn get_edges(&self, from: usize) -> &[usize] {
&self.vertex_edges[from]
}

#[must_use]
pub(crate) fn vertex_num(&self) -> usize {
self.vertex_edges.len()
Expand Down
Loading