Skip to content
Merged
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion crates/autopilot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@ tokio-stream = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true, features = ["trace"] }
tracing = { workspace = true }
url = { workspace = true }
url = { workspace = true, features = ["serde"] }
winner-selection = { workspace = true }
toml = { workspace = true }
tempfile = { workspace = true, optional = true }

[dev-dependencies]
maplit = { workspace = true }
mockall = { workspace = true }
tokio = { workspace = true, features = ["test-util"] }
shared = { workspace = true, features = ["test-util"] }
tempfile = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }
Expand All @@ -82,3 +85,4 @@ workspace = true
[features]
mimalloc-allocator = ["dep:mimalloc"]
tokio-console = ["observe/tokio-console"]
test-util = ["dep:tempfile"]
147 changes: 12 additions & 135 deletions crates/autopilot/src/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
use {
crate::{database::INSERT_BATCH_SIZE_DEFAULT, infra},
alloy::primitives::{Address, U256},
anyhow::{Context, anyhow, ensure},
alloy::primitives::Address,
anyhow::Context,
chrono::{DateTime, Utc},
clap::ValueEnum,
shared::{
arguments::{FeeFactor, display_list, display_option, display_secret_option},
arguments::{FeeFactor, display_option, display_secret_option},
http_client,
price_estimation::{self, NativePriceEstimators},
},
std::{
fmt::{self, Display, Formatter},
net::SocketAddr,
num::NonZeroUsize,
str::FromStr,
time::Duration,
},
std::{net::SocketAddr, num::NonZeroUsize, path::PathBuf, str::FromStr, time::Duration},
url::Url,
};

#[derive(clap::Parser)]
pub struct Arguments {
pub struct CliArguments {
#[clap(long, env)]
pub config: PathBuf,

#[clap(flatten)]
pub shared: shared::arguments::Arguments,

Expand Down Expand Up @@ -139,11 +136,6 @@ pub struct Arguments {
)]
pub trusted_tokens_update_interval: Duration,

/// A list of drivers in the following format:
/// `<NAME>|<URL>|<SUBMISSION_ADDRESS>|<FAIRNESS_THRESHOLD>`
#[clap(long, env, use_value_delimiter = true)]
pub drivers: Vec<Solver>,

/// The maximum number of blocks to wait for a settlement to appear on
/// chain.
#[clap(long, env, default_value = "5")]
Expand Down Expand Up @@ -272,9 +264,10 @@ pub struct Arguments {
pub native_price_prefetch_time: Duration,
}

impl std::fmt::Display for Arguments {
impl std::fmt::Display for CliArguments {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
config,
shared,
order_quoting,
http_client,
Expand All @@ -295,7 +288,6 @@ impl std::fmt::Display for Arguments {
trusted_tokens_url,
trusted_tokens,
trusted_tokens_update_interval,
drivers,
submission_deadline,
shadow,
solve_deadline,
Expand All @@ -319,7 +311,7 @@ impl std::fmt::Display for Arguments {
native_price_cache_refresh,
native_price_prefetch_time,
} = self;

write!(f, "{}", config.display())?;
write!(f, "{shared}")?;
write!(f, "{order_quoting}")?;
write!(f, "{http_client}")?;
Expand Down Expand Up @@ -354,7 +346,6 @@ impl std::fmt::Display for Arguments {
f,
"trusted_tokens_update_interval: {trusted_tokens_update_interval:?}"
)?;
display_list(f, "drivers", drivers.iter())?;
writeln!(f, "submission_deadline: {submission_deadline}")?;
display_option(f, "shadow", shadow)?;
writeln!(f, "solve_deadline: {solve_deadline:?}")?;
Expand Down Expand Up @@ -404,75 +395,6 @@ impl std::fmt::Display for Arguments {
}
}

/// External solver driver configuration
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Solver {
pub name: String,
pub url: Url,
pub submission_account: Account,
pub fairness_threshold: Option<U256>,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Account {
/// AWS KMS is used to retrieve the solver public key
Kms(Arn),
/// Solver public key
Address(Address),
}

// Wrapper type for AWS ARN identifiers
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Arn(pub String);

impl FromStr for Arn {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
// Could be more strict here, but this should suffice to catch unintended
// configuration mistakes
if s.starts_with("arn:aws:kms:") {
Ok(Self(s.to_string()))
} else {
Err(anyhow!("Invalid ARN identifier: {}", s))
}
}
}

impl Display for Solver {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}({})", self.name, self.url)
}
}

impl FromStr for Solver {
type Err = anyhow::Error;

fn from_str(solver: &str) -> anyhow::Result<Self> {
let parts: Vec<&str> = solver.split('|').collect();
ensure!(parts.len() >= 3, "not enough arguments for external solver");
let (name, url) = (parts[0], parts[1]);
let url: Url = url.parse()?;
let submission_account = match Arn::from_str(parts[2]) {
Ok(value) => Account::Kms(value),
_ => {
Account::Address(Address::from_str(parts[2]).context("failed to parse submission")?)
}
};

let fairness_threshold = parts
.get(3)
.and_then(|value| U256::from_str_radix(value, 10).ok());

Ok(Self {
name: name.to_owned(),
url,
fairness_threshold,
submission_account,
})
}
}

#[derive(clap::Parser, Debug, Clone)]
pub struct FeePoliciesConfig {
/// Describes how the protocol fees should be calculated.
Expand Down Expand Up @@ -665,7 +587,7 @@ impl FromStr for CowAmmConfig {

#[cfg(test)]
mod test {
use {super::*, alloy::primitives::address};
use super::*;

#[test]
fn test_fee_factor_limits() {
Expand All @@ -692,49 +614,4 @@ mod test {
)
}
}

#[test]
fn parse_driver_submission_account_address() {
let argument = "name1|http://localhost:8080|0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2";
let driver = Solver::from_str(argument).unwrap();
let expected = Solver {
name: "name1".into(),
url: Url::parse("http://localhost:8080").unwrap(),
fairness_threshold: None,
submission_account: Account::Address(address!(
"C02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"
)),
};
assert_eq!(driver, expected);
}

#[test]
fn parse_driver_submission_account_arn() {
let argument = "name1|http://localhost:8080|arn:aws:kms:supersecretstuff";
let driver = Solver::from_str(argument).unwrap();
let expected = Solver {
name: "name1".into(),
url: Url::parse("http://localhost:8080").unwrap(),
fairness_threshold: None,
submission_account: Account::Kms(
Arn::from_str("arn:aws:kms:supersecretstuff").unwrap(),
),
};
assert_eq!(driver, expected);
}

#[test]
fn parse_driver_with_threshold() {
let argument = "name1|http://localhost:8080|0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2|1000000000000000000";
let driver = Solver::from_str(argument).unwrap();
let expected = Solver {
name: "name1".into(),
url: Url::parse("http://localhost:8080").unwrap(),
submission_account: Account::Address(address!(
"C02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"
)),
fairness_threshold: Some(U256::from(10).pow(U256::from(18))),
};
assert_eq!(driver, expected);
}
}
60 changes: 60 additions & 0 deletions crates/autopilot/src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use {
crate::config::solver::Solver,
anyhow::{anyhow, ensure},
serde::{Deserialize, Serialize},
std::path::Path,
};

pub mod solver;

#[derive(Debug, Default, Deserialize, Serialize)]
Comment thread
jmg-duarte marked this conversation as resolved.
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
pub struct Configuration {
// #[serde(default)]
pub drivers: Vec<Solver>,
}

impl Configuration {
pub async fn from_path<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
match toml::from_str(&tokio::fs::read_to_string(&path).await?) {
Ok(self_) => Ok(self_),
Err(err) if std::env::var("TOML_TRACE_ERROR").is_ok_and(|v| v == "1") => Err(anyhow!(
"failed to parse TOML config at {}: {err:#?}",
path.as_ref().display()
)),
Err(_) => Err(anyhow!(
"failed to parse TOML config at: {}. Set TOML_TRACE_ERROR=1 to print parsing \
error but this may leak secrets.",
path.as_ref().display()
)),
}
}

pub async fn to_path<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
Ok(tokio::fs::write(path, toml::to_string_pretty(self)?).await?)
}

#[cfg(any(test, feature = "test-util"))]
pub fn to_temp_path(&self) -> tempfile::NamedTempFile {
use std::io::Write;
let mut file = tempfile::NamedTempFile::new().expect("temp file creation should not fail");
file.write_all(
toml::to_string_pretty(self)
.expect("serialization should not fail")
.as_bytes(),
)
.expect("writing to temp file should not fail");
file
}

// Note for reviewers: if this and other validations are always applied,
// we should instead move them to the deserialization stage
// https://lexi-lambda.github.io/blog/2019/11/05/parse-don-t-validate/
pub fn validate(self) -> anyhow::Result<Self> {
ensure!(
!self.drivers.is_empty(),
"colocation is enabled but no drivers are configured"
);
Ok(self)
}
}
Loading
Loading