From d0a79ff3c50d65e69a0d2f89dddac1495883fffd Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Thu, 5 Mar 2026 16:30:20 -0800 Subject: [PATCH 01/28] Add shared infrastructure for functions push/pull Introduce foundational modules that the push and pull commands will build upon: - source_language: SourceLanguage enum and extension classification for distinguishing JS/TS from Python files - utils/fs_atomic: atomic file writes via temp-then-rename - utils/git: GitRepo discovery and dirty-state detection - js_runner: runner script materialization and JS runtime discovery (tsx, vite-node, ts-node, deno) - python_runner: Python interpreter resolution with venv support - scripts/runner-common.ts: shared TS types for runner manifests - scripts/python_runner_common.py: shared Python utilities for module loading, file normalization, and source collection --- Cargo.lock | 1 + Cargo.toml | 1 + scripts/python_runner_common.py | 99 +++++++++++ scripts/runner-common.ts | 83 ++++++++++ src/js_runner.rs | 283 ++++++++++++++++++++++++++++++++ src/main.rs | 7 + src/python_runner.rs | 84 ++++++++++ src/source_language.rs | 68 ++++++++ src/utils/fs_atomic.rs | 41 +++++ src/utils/git.rs | 106 ++++++++++++ src/utils/mod.rs | 4 + 11 files changed, 777 insertions(+) create mode 100644 scripts/python_runner_common.py create mode 100644 scripts/runner-common.ts create mode 100644 src/js_runner.rs create mode 100644 src/python_runner.rs create mode 100644 src/source_language.rs create mode 100644 src/utils/fs_atomic.rs create mode 100644 src/utils/git.rs diff --git a/Cargo.lock b/Cargo.lock index 9666f04..597cd9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ dependencies = [ "dialoguer", "dirs", "dotenvy", + "flate2", "futures-util", "getrandom 0.3.4", "glob", diff --git a/Cargo.toml b/Cargo.toml index 32d50ab..7e300fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ chrono = { version = "0.4.40", features = ["clock"] } dirs = "5" pathdiff = "0.2.3" glob = "0.3" +flate2 = "1.1.2" [profile.dist] inherits = "release" diff --git a/scripts/python_runner_common.py b/scripts/python_runner_common.py new file mode 100644 index 0000000..4a738f9 --- /dev/null +++ b/scripts/python_runner_common.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +import importlib.util +import os +import sys +from types import ModuleType + + +def normalize_file_list(files: list[str]) -> list[str]: + unique: set[str] = set() + for file_path in files: + unique.add(os.path.abspath(file_path)) + return sorted(unique) + + +def resolve_module_info(in_file: str) -> tuple[str, list[str]]: + in_file = os.path.abspath(in_file) + module_dir = os.path.dirname(in_file) + module_name = os.path.splitext(os.path.basename(in_file))[0] + + package_parts: list[str] = [] + current = module_dir + while os.path.isfile(os.path.join(current, "__init__.py")): + package_parts.insert(0, os.path.basename(current)) + current = os.path.dirname(current) + + extra_paths = [module_dir] + if package_parts: + module_name = ".".join(package_parts + [module_name]) + if current not in extra_paths: + extra_paths.append(current) + + return module_name, extra_paths + + +def import_file(module_name: str, file_path: str, extra_paths: list[str]) -> ModuleType: + for extra_path in reversed(extra_paths): + if extra_path not in sys.path: + sys.path.insert(0, extra_path) + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module spec for {file_path}") + + 
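+    # Drop any previously imported copy of this module so a repeated run
+    # re-executes the current file contents instead of reusing stale state.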
sys.modules.pop(module_name, None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def purge_local_modules(cwd: str, preserve_modules: set[str] | None = None) -> None: + preserved = preserve_modules or set() + cwd_abs = os.path.abspath(cwd) + for module_name, module in list(sys.modules.items()): + if module_name in preserved: + continue + module_file = getattr(module, "__file__", None) + if not module_file: + continue + candidate = module_file[:-1] if module_file.endswith(".pyc") else module_file + candidate_abs = os.path.abspath(candidate) + if not os.path.isfile(candidate_abs): + continue + try: + common = os.path.commonpath([candidate_abs, cwd_abs]) + except ValueError: + continue + if common == cwd_abs: + sys.modules.pop(module_name, None) + + +def collect_python_sources(cwd: str, input_file: str) -> list[str]: + sources: set[str] = set() + input_abs = os.path.abspath(input_file) + sources.add(input_abs) + + for module in list(sys.modules.values()): + module_file = getattr(module, "__file__", None) + if not module_file: + continue + candidate = module_file[:-1] if module_file.endswith(".pyc") else module_file + candidate_abs = os.path.abspath(candidate) + if not os.path.isfile(candidate_abs): + continue + if not candidate_abs.endswith(".py"): + continue + try: + common = os.path.commonpath([candidate_abs, cwd]) + except ValueError: + continue + if common != cwd: + continue + sources.add(candidate_abs) + + return sorted(sources) + + +def python_version() -> str: + return f"{sys.version_info.major}.{sys.version_info.minor}" diff --git a/scripts/runner-common.ts b/scripts/runner-common.ts new file mode 100644 index 0000000..8a0dd63 --- /dev/null +++ b/scripts/runner-common.ts @@ -0,0 +1,83 @@ +export type JsonPrimitive = string | number | boolean | null; +export type JsonArray = JsonValue[]; +export type JsonObject = { [key: string]: JsonValue }; +export type JsonValue = JsonPrimitive | JsonArray | JsonObject; + +export type ProjectSelector = { + project_id?: string; + project_name?: string; +}; + +export type ProjectRef = { + id?: string; + name?: string; +}; + +export function asProjectSelector( + project: ProjectRef | undefined, +): ProjectSelector { + if (!project) { + return {}; + } + + if (typeof project.id === "string" && project.id.trim().length > 0) { + return { project_id: project.id }; + } + + if (typeof project.name === "string" && project.name.trim().length > 0) { + return { project_name: project.name }; + } + + return {}; +} + +export function selectorToProjectId(selector: ProjectSelector): string { + if ( + typeof selector.project_id === "string" && + selector.project_id.trim().length > 0 + ) { + return selector.project_id; + } + + if ( + typeof selector.project_name === "string" && + selector.project_name.trim().length > 0 + ) { + return `name:${selector.project_name}`; + } + + return ""; +} + +export function isJsonObject( + value: JsonValue | undefined, +): value is JsonObject { + return typeof value === "object" && value !== null && !Array.isArray(value); +} + +export function toJsonValue(input: JsonValue): JsonValue { + if (Array.isArray(input)) { + return input.map((item) => toJsonValue(item)); + } + + if (input !== null && typeof input === "object") { + const out: JsonObject = {}; + for (const [key, value] of Object.entries(input)) { + if ( + value === null || + typeof value === "string" || + typeof value === "number" || + typeof value === "boolean" + ) { + out[key] = value; + 
+      } else if (Array.isArray(value)) {
+        out[key] = value.map((entry) => toJsonValue(entry));
+      } else if (typeof value === "object") {
+        out[key] = toJsonValue(value as JsonObject);
+      }
+    }
+    return out;
+  }
+
+  return input;
+}
diff --git a/src/js_runner.rs b/src/js_runner.rs
new file mode 100644
index 0000000..cd4c94d
--- /dev/null
+++ b/src/js_runner.rs
@@ -0,0 +1,283 @@
+use std::ffi::OsStr;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+
+use anyhow::{bail, Context, Result};
+
+pub fn materialize_runner_script(
+    cache_dir: &Path,
+    file_name: &str,
+    source: &str,
+) -> Result<PathBuf> {
+    std::fs::create_dir_all(cache_dir).with_context(|| {
+        format!(
+            "failed to create runner cache directory {}",
+            cache_dir.display()
+        )
+    })?;
+    ensure_not_symlink(cache_dir)?;
+
+    let path = cache_dir.join(file_name);
+    ensure_not_symlink(&path)?;
+    let current = std::fs::read_to_string(&path).ok();
+    if current.as_deref() != Some(source) {
+        crate::utils::write_text_atomic(&path, source)
+            .with_context(|| format!("failed to write runner script {}", path.display()))?;
+    }
+    Ok(path)
+}
+
+pub fn materialize_runner_script_in_cwd(
+    cache_subdir: &str,
+    file_name: &str,
+    source: &str,
+) -> Result<PathBuf> {
+    let cwd = std::env::current_dir().context("failed to resolve current working directory")?;
+    let cache_dir = cwd
+        .join(".bt")
+        .join(cache_subdir)
+        .join(env!("CARGO_PKG_VERSION"));
+    ensure_descendant_components_not_symlinks(&cwd, &cache_dir)?;
+    materialize_runner_script(&cache_dir, file_name, source)
+}
+
+pub fn build_js_runner_command(
+    runner_override: Option<&str>,
+    runner_script: &Path,
+    files: &[PathBuf],
+) -> Command {
+    if let Some(explicit) = runner_override {
+        let resolved = resolve_js_runner_command(explicit, files);
+        if is_deno_runner_path(&resolved) {
+            return build_deno_command(resolved.as_os_str(), runner_script, files);
+        }
+
+        let mut command = Command::new(&resolved);
+        command.arg(runner_script);
+        for file in files {
+            command.arg(file);
+        }
+        return command;
+    }
+
+    if let Some(auto_runner) = find_js_runner_binary(files) {
+        if is_deno_runner_path(&auto_runner) {
+            return build_deno_command(auto_runner.as_os_str(), runner_script, files);
+        }
+
+        let mut command = Command::new(&auto_runner);
+        command.arg(runner_script);
+        for file in files {
+            command.arg(file);
+        }
+        return command;
+    }
+
+    let mut command = Command::new("npx");
+    command.arg("--yes").arg("tsx").arg(runner_script);
+    for file in files {
+        command.arg(file);
+    }
+    command
+}
+
+pub fn find_js_runner_binary(files: &[PathBuf]) -> Option<PathBuf> {
+    const CANDIDATES: &[&str] = &["tsx", "vite-node", "ts-node", "ts-node-esm", "deno"];
+
+    for candidate in CANDIDATES {
+        if let Some(path) = find_node_module_bin_for_files(candidate, files) {
+            return Some(path);
+        }
+    }
+
+    find_binary_in_path(CANDIDATES)
+}
+
+pub fn resolve_js_runner_command(runner: &str, files: &[PathBuf]) -> PathBuf {
+    if is_path_like_runner(runner) {
+        return PathBuf::from(runner);
+    }
+
+    find_node_module_bin_for_files(runner, files)
+        .or_else(|| find_binary_in_path(&[runner]))
+        .unwrap_or_else(|| PathBuf::from(runner))
+}
+
+fn build_deno_command(deno_runner: &OsStr, runner_script: &Path, files: &[PathBuf]) -> Command {
+    let mut command = Command::new(deno_runner);
+    command
+        .arg("run")
+        .arg("-A")
+        .arg("--node-modules-dir=auto")
+        .arg("--unstable-detect-cjs")
+        .arg(runner_script);
+    for file in files {
+        command.arg(file);
+    }
+    command
+}
+
+fn is_path_like_runner(runner: &str) -> bool {
+    let path = Path::new(runner);
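+    // A path separator or a leading '.' marks an explicit filesystem path; bare
+    // names are resolved via node_modules/.bin or PATH instead.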
+    path.is_absolute() || runner.contains('/') || runner.contains('\\') || runner.starts_with('.')
+}
+
+fn is_deno_runner_path(runner: &Path) -> bool {
+    runner
+        .file_name()
+        .and_then(|value| value.to_str())
+        .map(|name| name.eq_ignore_ascii_case("deno") || name.eq_ignore_ascii_case("deno.exe"))
+        .unwrap_or(false)
+}
+
+fn find_node_module_bin_for_files(binary: &str, files: &[PathBuf]) -> Option<PathBuf> {
+    for root in js_runner_search_roots(files) {
+        if let Some(path) = find_node_module_bin(binary, &root) {
+            return Some(path);
+        }
+    }
+    None
+}
+
+fn js_runner_search_roots(files: &[PathBuf]) -> Vec<PathBuf> {
+    let mut roots = Vec::new();
+    if let Ok(cwd) = std::env::current_dir() {
+        roots.push(cwd.clone());
+        for file in files {
+            let absolute = if file.is_absolute() {
+                file.clone()
+            } else {
+                cwd.join(file)
+            };
+            if let Some(parent) = absolute.parent() {
+                roots.push(parent.to_path_buf());
+            }
+        }
+    }
+    roots
+}
+
+fn find_node_module_bin(binary: &str, start: &Path) -> Option<PathBuf> {
+    let mut current = Some(start);
+    while let Some(dir) = current {
+        let base = dir.join("node_modules").join(".bin").join(binary);
+        if base.is_file() {
+            return Some(base);
+        }
+        if cfg!(windows) {
+            for candidate in with_windows_extensions(&base) {
+                if candidate.is_file() {
+                    return Some(candidate);
+                }
+            }
+        }
+        current = dir.parent();
+    }
+    None
+}
+
+fn find_binary_in_path(candidates: &[&str]) -> Option<PathBuf> {
+    let paths = std::env::var_os("PATH")?;
+    for dir in std::env::split_paths(&paths) {
+        for candidate in candidates {
+            let path = dir.join(candidate);
+            if path.is_file() {
+                return Some(path);
+            }
+            if cfg!(windows) {
+                for candidate_path in with_windows_extensions(&path) {
+                    if candidate_path.is_file() {
+                        return Some(candidate_path);
+                    }
+                }
+            }
+        }
+    }
+    None
+}
+
+#[cfg(windows)]
+fn with_windows_extensions(path: &Path) -> [PathBuf; 2] {
+    [path.with_extension("exe"), path.with_extension("cmd")]
+}
+
+#[cfg(not(windows))]
+fn with_windows_extensions(_path: &Path) -> [PathBuf; 0] {
+    []
+}
+
+fn ensure_descendant_components_not_symlinks(base: &Path, descendant: &Path) -> Result<()> {
+    let Ok(relative) = descendant.strip_prefix(base) else {
+        return Ok(());
+    };
+
+    let mut current = base.to_path_buf();
+    for component in relative.components() {
+        current.push(component.as_os_str());
+        let metadata = match std::fs::symlink_metadata(&current) {
+            Ok(metadata) => metadata,
+            Err(err) if err.kind() == std::io::ErrorKind::NotFound => break,
+            Err(err) => {
+                return Err(err).with_context(|| {
+                    format!("failed to inspect path component {}", current.display())
+                })
+            }
+        };
+        if metadata.file_type().is_symlink() {
+            bail!(
+                "refusing to write runner script through symlink path component {}",
+                current.display()
+            );
+        }
+    }
+    Ok(())
+}
+
+fn ensure_not_symlink(path: &Path) -> Result<()> {
+    match std::fs::symlink_metadata(path) {
+        Ok(metadata) => {
+            if metadata.file_type().is_symlink() {
+                bail!(
+                    "refusing to write runner script via symlink {}",
+                    path.display()
+                );
+            }
+        }
+        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
+        Err(err) => {
+            return Err(err)
+                .with_context(|| format!("failed to inspect runner path {}", path.display()))
+        }
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn path_like_runner_detection() {
+        assert!(is_path_like_runner("./tsx"));
+        assert!(is_path_like_runner("bin/tsx"));
+        assert!(!is_path_like_runner("tsx"));
+    }
+
+    #[cfg(unix)]
+    #[test]
+    fn descendant_symlink_check_rejects_symlinked_component() {
+        use std::os::unix::fs::symlink;
+
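+        // Lay out base/link -> real so the checked path crosses a symlinked component.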
+        let dir = tempfile::tempdir().expect("tempdir");
+        let base = dir.path().join("base");
+        let real = dir.path().join("real");
+        std::fs::create_dir_all(&base).expect("create base directory");
+        std::fs::create_dir_all(&real).expect("create real directory");
+        let link = base.join("link");
+        symlink(&real, &link).expect("create symlink");
+
+        let err = ensure_descendant_components_not_symlinks(&base, &link.join("cache"))
+            .expect_err("must reject symlink path");
+        assert!(err.to_string().contains("symlink"));
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 0c2361a..a4e2496 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,11 +13,14 @@ mod experiments;
 mod functions;
 mod http;
 mod init;
+mod js_runner;
 mod projects;
 mod prompts;
+mod python_runner;
 mod scorers;
 mod self_update;
 mod setup;
+mod source_language;
 mod sql;
 mod status;
 mod switch;
diff --git a/src/python_runner.rs b/src/python_runner.rs
new file mode 100644
index 0000000..bf9760c
--- /dev/null
+++ b/src/python_runner.rs
@@ -0,0 +1,84 @@
+use std::path::{Path, PathBuf};
+
+pub fn resolve_python_interpreter(
+    explicit: Option<&str>,
+    env_overrides: &[&str],
+) -> Option<PathBuf> {
+    if let Some(explicit) = explicit {
+        return Some(PathBuf::from(explicit));
+    }
+
+    for env_name in env_overrides {
+        if let Some(value) = std::env::var_os(env_name) {
+            if !value.is_empty() {
+                return Some(PathBuf::from(value));
+            }
+        }
+    }
+
+    // Prefer the interpreter from an active virtual environment (VIRTUAL_ENV).
+    if let Some(venv) = find_virtual_env_python() {
+        return Some(venv);
+    }
+
+    find_binary_in_path(&["python3", "python"])
+}
+
+fn find_virtual_env_python() -> Option<PathBuf> {
+    let venv_root = std::env::var_os("VIRTUAL_ENV")?;
+    let root = PathBuf::from(venv_root);
+
+    let unix = root.join("bin").join("python");
+    if unix.is_file() {
+        return Some(unix);
+    }
+
+    let windows = root.join("Scripts").join("python.exe");
+    if windows.is_file() {
+        return Some(windows);
+    }
+
+    None
+}
+
+pub fn find_binary_in_path(candidates: &[&str]) -> Option<PathBuf> {
+    let paths = std::env::var_os("PATH")?;
+    for dir in std::env::split_paths(&paths) {
+        for candidate in candidates {
+            let path = dir.join(candidate);
+            if path.is_file() {
+                return Some(path);
+            }
+            if cfg!(windows) {
+                let exe = with_windows_extensions(&path);
+                for candidate_path in exe {
+                    if candidate_path.is_file() {
+                        return Some(candidate_path);
+                    }
+                }
+            }
+        }
+    }
+    None
+}
+
+#[cfg(windows)]
+fn with_windows_extensions(path: &Path) -> [PathBuf; 2] {
+    [path.with_extension("exe"), path.with_extension("cmd")]
+}
+
+#[cfg(not(windows))]
+fn with_windows_extensions(_path: &Path) -> [PathBuf; 0] {
+    []
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn explicit_python_runner_wins() {
+        let resolved = resolve_python_interpreter(Some("/tmp/python"), &["BT_UNUSED"]);
+        assert_eq!(resolved, Some(PathBuf::from("/tmp/python")));
+    }
+}
diff --git a/src/source_language.rs b/src/source_language.rs
new file mode 100644
index 0000000..8a1b71f
--- /dev/null
+++ b/src/source_language.rs
@@ -0,0 +1,68 @@
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SourceLanguage {
+    JsLike,
+    Python,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum JsExtensionProfile {
+    FunctionsPush,
+    Eval,
+}
+
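+/// Classify a file extension into the runtime language used by push/eval,
+/// returning None for extensions the given profile does not accept.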
+pub fn classify_runtime_extension(
+    ext: &str,
+    js_profile: JsExtensionProfile,
+) -> Option<SourceLanguage> {
+    let normalized = ext.to_ascii_lowercase();
+    if normalized == "py" {
+        return Some(SourceLanguage::Python);
+    }
+
+    let is_js_like = match js_profile {
+        JsExtensionProfile::FunctionsPush => {
+            matches!(normalized.as_str(), "ts" | "tsx" | "js" | "jsx")
+        }
+        JsExtensionProfile::Eval => {
+            matches!(normalized.as_str(), "ts" | "tsx" | "js" | "mjs" | "cjs")
+        }
+    };
+    if is_js_like {
+        Some(SourceLanguage::JsLike)
+    } else {
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn classifies_push_extensions_case_insensitively() {
+        assert_eq!(
+            classify_runtime_extension("TS", JsExtensionProfile::FunctionsPush),
+            Some(SourceLanguage::JsLike)
+        );
+        assert_eq!(
+            classify_runtime_extension("Py", JsExtensionProfile::FunctionsPush),
+            Some(SourceLanguage::Python)
+        );
+        assert_eq!(
+            classify_runtime_extension("mjs", JsExtensionProfile::FunctionsPush),
+            None
+        );
+    }
+
+    #[test]
+    fn classifies_eval_extensions() {
+        assert_eq!(
+            classify_runtime_extension("mjs", JsExtensionProfile::Eval),
+            Some(SourceLanguage::JsLike)
+        );
+        assert_eq!(
+            classify_runtime_extension("cjs", JsExtensionProfile::Eval),
+            Some(SourceLanguage::JsLike)
+        );
+    }
+}
diff --git a/src/utils/fs_atomic.rs b/src/utils/fs_atomic.rs
new file mode 100644
index 0000000..4906a74
--- /dev/null
+++ b/src/utils/fs_atomic.rs
@@ -0,0 +1,41 @@
+use std::path::Path;
+
+use anyhow::{Context, Result};
+
+pub fn write_text_atomic(path: &Path, contents: &str) -> Result<()> {
+    let parent = path.parent().ok_or_else(|| {
+        anyhow::anyhow!(
+            "cannot atomically write {} because it has no parent directory",
+            path.display()
+        )
+    })?;
+
+    std::fs::create_dir_all(parent)
+        .with_context(|| format!("failed to create parent directory {}", parent.display()))?;
+
+    let nonce = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .context("failed to read system time for atomic write")?
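+        // Nanosecond timestamp plus the pid below keep concurrent writers from
+        // colliding on the same temporary file name.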
+        .as_nanos();
+    let pid = std::process::id();
+
+    let file_name = path
+        .file_name()
+        .and_then(|name| name.to_str())
+        .ok_or_else(|| anyhow::anyhow!("invalid target file name: {}", path.display()))?;
+
+    let tmp = parent.join(format!(".{file_name}.tmp.{pid}.{nonce}"));
+
+    std::fs::write(&tmp, contents)
+        .with_context(|| format!("failed to write temporary file {}", tmp.display()))?;
+
+    std::fs::rename(&tmp, path).with_context(|| {
+        format!(
+            "failed to replace {} with temporary file {}",
+            path.display(),
+            tmp.display()
+        )
+    })?;
+
+    Ok(())
+}
diff --git a/src/utils/git.rs b/src/utils/git.rs
new file mode 100644
index 0000000..cbbf53b
--- /dev/null
+++ b/src/utils/git.rs
@@ -0,0 +1,106 @@
+use std::path::{Path, PathBuf};
+use std::process::Command;
+
+use anyhow::{Context, Result};
+
+#[derive(Debug, Clone)]
+pub struct GitRepo {
+    root: PathBuf,
+}
+
+impl GitRepo {
+    pub fn root(&self) -> &Path {
+        &self.root
+    }
+
+    pub fn is_dirty_or_untracked(&self, path: &Path) -> Result<bool> {
+        let output = Command::new("git")
+            .arg("-C")
+            .arg(&self.root)
+            .arg("status")
+            .arg("--porcelain")
+            .arg("--")
+            .arg(path)
+            .output()
+            .with_context(|| {
+                format!(
+                    "failed to check git status for {} in {}",
+                    path.display(),
+                    self.root.display()
+                )
+            })?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            anyhow::bail!(
+                "git status failed for {} in {}: {}",
+                path.display(),
+                self.root.display(),
+                stderr.trim()
+            );
+        }
+
+        Ok(!String::from_utf8_lossy(&output.stdout).trim().is_empty())
+    }
+
+    pub fn discover_from(path: &Path) -> Option<Self> {
+        find_repo_root_from(path).map(|root| Self { root })
+    }
+}
+
+pub fn find_repo_root_from(start: &Path) -> Option<PathBuf> {
+    let mut current = start.to_path_buf();
+    if current.is_file() {
+        current = current.parent()?.to_path_buf();
+    }
+
+    loop {
+        if current.join(".git").exists() {
+            return Some(current);
+        }
+        if !current.pop() {
+            return None;
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::fs;
+
+    use super::*;
+
+    #[test]
+    fn find_repo_root_detects_git_dir() {
+        let unique = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .expect("clock")
+            .as_nanos();
+        let root = std::env::temp_dir().join(format!("bt-git-root-{unique}"));
+        let nested = root.join("a").join("b");
+        fs::create_dir_all(&nested).expect("create nested dirs");
+        fs::create_dir_all(root.join(".git")).expect("create .git dir");
+
+        let found = find_repo_root_from(&nested).expect("should find root");
+        assert_eq!(found, root);
+
+        let _ = fs::remove_dir_all(found);
+    }
+
+    #[test]
+    fn find_repo_root_detects_git_file() {
+        let unique = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .expect("clock")
+            .as_nanos();
+        let root = std::env::temp_dir().join(format!("bt-git-file-{unique}"));
+        let nested = root.join("x").join("y");
+        fs::create_dir_all(&nested).expect("create nested dirs");
+        fs::write(root.join(".git"), "gitdir: /tmp/mock").expect("write .git file");
+
+        let found = find_repo_root_from(&nested).expect("should find root");
+        assert_eq!(found, root);
+
+        let _ = fs::remove_dir_all(found);
+    }
+}
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index ef1b708..7a5b4da 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -1,3 +1,7 @@
+mod fs_atomic;
+mod git;
 mod plurals;
 
+pub use fs_atomic::write_text_atomic;
+pub use git::GitRepo;
 pub use plurals::pluralize;

From 0e3d5d505a7ee836d9bf70239b4b5be4656df1f8 Mon Sep 17 00:00:00 2001
From: Parker Henderson
Date: Thu, 5 Mar 2026 16:36:59 -0800
Subject: [PATCH 02/28] Add functions push/pull command structure, API layer,
 and report types

Introduce the command scaffolding for `bt functions push` and
`bt functions pull`:

- functions/mod.rs: PushArgs, PullArgs, AuthContext, IfExistsMode,
  FunctionsLanguage enums, and refactored context resolution
- functions/api.rs: paginated function listing, code upload slots,
  bundle upload, and batch insert endpoints
- functions/report.rs: structured report types for JSON output with
  HardFailureReason, SoftSkipReason, and summary types
- auth.rs: AvailableOrg struct and list_available_orgs()
- http.rs: put_signed_url() for uploading to signed URLs
---
 src/auth.rs             |  35 +++
 src/functions/api.rs    | 270 ++++++++++++++++++++-
 src/functions/mod.rs    | 512 +++++++++++++++++++++++++++++++++++++---
 src/functions/report.rs | 172 ++++++++++++++
 src/http.rs             |  31 +++
 5 files changed, 987 insertions(+), 33 deletions(-)
 create mode 100644 src/functions/report.rs

diff --git a/src/auth.rs b/src/auth.rs
index 29a70be..384fd98 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -60,6 +60,13 @@ pub struct ProfileInfo {
     pub api_key_hint: Option<String>,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct AvailableOrg {
+    pub id: String,
+    pub name: String,
+    pub api_url: Option<String>,
+}
+
 pub fn list_profiles() -> Result<Vec<ProfileInfo>> {
     let store = load_auth_store()?;
     Ok(store
@@ -155,6 +162,34 @@ pub fn select_profile_interactive(current: Option<&str>) -> Result<String> {
 
+pub async fn list_available_orgs(base: &BaseArgs) -> Result<Vec<AvailableOrg>> {
+    let resolved = resolve_auth(base).await?;
+    let app_url = resolved
+        .app_url
+        .unwrap_or_else(|| DEFAULT_APP_URL.to_string());
+    let api_key = match resolved.api_key {
+        Some(api_key) => api_key,
+        None => login(base).await?.login.api_key,
+    };
+
+    let mut orgs = fetch_login_orgs(&api_key, &app_url).await?;
+    orgs.sort_by(|a, b| {
+        a.name
+            .to_ascii_lowercase()
+            .cmp(&b.name.to_ascii_lowercase())
+            .then_with(|| a.name.cmp(&b.name))
+    });
+
+    Ok(orgs
+        .into_iter()
+        .map(|org| AvailableOrg {
+            id: org.id,
+            name: org.name,
+            api_url: org.api_url,
+        })
+        .collect())
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 struct AuthStore {
     #[serde(default)]
diff --git a/src/functions/api.rs b/src/functions/api.rs
index a230046..34b9972 100644
--- a/src/functions/api.rs
+++ b/src/functions/api.rs
@@ -1,5 +1,6 @@
-use anyhow::Result;
+use anyhow::{Context, Result};
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use urlencoding::encode;
 
 use crate::http::ApiClient;
@@ -32,6 +33,37 @@ pub struct Function {
     pub _xact_id: Option<String>,
 }
 
+#[derive(Debug, Clone, Default)]
+pub struct FunctionListQuery {
+    pub project_id: Option<String>,
+    pub project_name: Option<String>,
+    pub slug: Option<String>,
+    pub id: Option<String>,
+    pub cursor: Option<String>,
+    pub snapshot: Option<String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct FunctionListPage {
+    pub objects: Vec<Value>,
+    pub next_cursor: Option<String>,
+    pub snapshot: Option<String>,
+    pub pagination_field_present: bool,
+    pub snapshot_field_present: bool,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct CodeUploadSlot {
+    pub url: String,
+    #[serde(rename = "bundleId")]
+    pub bundle_id: String,
+}
+
+#[derive(Debug, Clone)]
+pub struct InsertFunctionsResult {
+    pub ignored_entries: Option<usize>,
+}
+
 pub async fn list_functions(
     client: &ApiClient,
     project_id: &str,
@@ -83,3 +115,239 @@ pub async fn delete_function(client: &ApiClient, function_id: &str) -> Result<()> {
     let path = format!("/v1/function/{}", encode(function_id));
     client.delete(&path).await
 }
+
+pub async fn list_functions_page(
+    client: &ApiClient,
+    query: &FunctionListQuery,
+) -> Result<FunctionListPage> {
+    let mut params = Vec::new();
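+    // Assemble the query string by hand so only the provided filters are sent.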
+    if let Some(project_id) = &query.project_id {
+        params.push(("project_id", project_id.clone()));
+    }
+    if let Some(project_name) = &query.project_name {
+        params.push(("project_name", project_name.clone()));
+    }
+    if let Some(slug) = &query.slug {
+        params.push(("slug", slug.clone()));
+    }
+    if let Some(id) = &query.id {
+        params.push(("ids", id.clone()));
+    }
+    if let Some(cursor) = &query.cursor {
+        params.push(("cursor", cursor.clone()));
+    }
+    if let Some(snapshot) = &query.snapshot {
+        params.push(("snapshot", snapshot.clone()));
+    }
+
+    let path = if params.is_empty() {
+        "/v1/function".to_string()
+    } else {
+        let query = params
+            .into_iter()
+            .map(|(key, value)| format!("{}={}", encode(key), encode(&value)))
+            .collect::<Vec<_>>()
+            .join("&");
+        format!("/v1/function?{query}")
+    };
+
+    let raw: Value = client
+        .get(&path)
+        .await
+        .with_context(|| format!("failed to list functions via {path}"))?;
+
+    parse_function_list_page(raw)
+}
+
+fn parse_function_list_page(raw: Value) -> Result<FunctionListPage> {
+    let objects = raw
+        .get("objects")
+        .and_then(Value::as_array)
+        .cloned()
+        .ok_or_else(|| anyhow::anyhow!("missing 'objects' array in /v1/function response"))?;
+
+    let explicit_next_cursor = raw
+        .get("next_cursor")
+        .and_then(Value::as_str)
+        .or_else(|| raw.get("nextCursor").and_then(Value::as_str))
+        .or_else(|| raw.get("next").and_then(Value::as_str))
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(ToOwned::to_owned);
+
+    let cursor_field = raw
+        .get("cursor")
+        .and_then(Value::as_str)
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(ToOwned::to_owned);
+
+    let has_more = raw
+        .get("has_more")
+        .and_then(Value::as_bool)
+        .or_else(|| raw.get("hasMore").and_then(Value::as_bool));
+
+    let next_cursor = explicit_next_cursor.or(match has_more {
+        Some(false) => None,
+        _ => cursor_field,
+    });
+
+    let snapshot = raw
+        .get("snapshot")
+        .and_then(Value::as_str)
+        .or_else(|| raw.get("snapshot_id").and_then(Value::as_str))
+        .or_else(|| raw.get("as_of").and_then(Value::as_str))
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(ToOwned::to_owned);
+
+    Ok(FunctionListPage {
+        objects,
+        next_cursor,
+        snapshot,
+        pagination_field_present: raw.get("next_cursor").is_some()
+            || raw.get("nextCursor").is_some()
+            || raw.get("next").is_some()
+            || raw.get("cursor").is_some()
+            || raw.get("has_more").is_some()
+            || raw.get("hasMore").is_some(),
+        snapshot_field_present: raw.get("snapshot").is_some()
+            || raw.get("snapshot_id").is_some()
+            || raw.get("as_of").is_some(),
+    })
+}
+
+pub async fn request_code_upload_slot(
+    client: &ApiClient,
+    org_id: &str,
+    runtime: &str,
+    version: &str,
+) -> Result<CodeUploadSlot> {
+    let body = serde_json::json!({
+        "org_id": org_id,
+        "runtime_context": {
+            "runtime": runtime,
+            "version": version,
+        }
+    });
+
+    client
+        .post("/function/code", &body)
+        .await
+        .context("failed to request code upload slot")
+}
+
+pub async fn upload_bundle(
+    url: &str,
+    bundle_bytes: Vec<u8>,
+    content_encoding: Option<&str>,
+) -> Result<()> {
+    crate::http::put_signed_url(url, bundle_bytes, content_encoding)
+        .await
+        .context("failed to upload code bundle to signed URL")
+}
+
+pub async fn insert_functions(
+    client: &ApiClient,
+    functions: &[Value],
+) -> Result<InsertFunctionsResult> {
+    let body = serde_json::json!({ "functions": functions });
+    let raw: Value = client
+        .post("/insert-functions", &body)
+        .await
+        .context("failed to insert functions")?;
+
+    Ok(InsertFunctionsResult {
+        ignored_entries: ignored_count(&raw),
+    })
+}
+
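+// The insert endpoint has reported ignored entries in a few response shapes;
+// probe the known ones and normalize to a single count.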
+fn ignored_count(raw: &Value) -> Option<usize> {
+    if let Some(count) = raw.get("ignored_count").and_then(Value::as_u64) {
+        return usize::try_from(count).ok();
+    }
+
+    if let Some(items) = raw.get("ignored").and_then(Value::as_array) {
+        return Some(items.len());
+    }
+
+    if let Some(count) = raw
+        .get("stats")
+        .and_then(Value::as_object)
+        .and_then(|stats| stats.get("ignored"))
+        .and_then(Value::as_u64)
+    {
+        return usize::try_from(count).ok();
+    }
+
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn ignored_count_extracts_known_shapes() {
+        let first = serde_json::json!({ "ignored_count": 3 });
+        assert_eq!(ignored_count(&first), Some(3));
+
+        let second = serde_json::json!({ "ignored": [1, 2] });
+        assert_eq!(ignored_count(&second), Some(2));
+
+        let third = serde_json::json!({ "stats": { "ignored": 5 } });
+        assert_eq!(ignored_count(&third), Some(5));
+
+        assert_eq!(ignored_count(&serde_json::json!({})), None);
+    }
+
+    #[test]
+    fn parse_function_list_page_allows_non_paginated_shape() {
+        let raw = serde_json::json!({
+            "objects": [],
+        });
+
+        let page = parse_function_list_page(raw).expect("parse function page");
+        assert!(page.objects.is_empty());
+        assert!(!page.pagination_field_present);
+        assert!(page.next_cursor.is_none());
+    }
+
+    #[test]
+    fn parse_function_list_page_detects_next_pagination_field() {
+        let raw = serde_json::json!({
+            "objects": [],
+            "next": "cursor-1",
+        });
+
+        let page = parse_function_list_page(raw).expect("parse function page");
+        assert!(page.pagination_field_present);
+        assert_eq!(page.next_cursor.as_deref(), Some("cursor-1"));
+    }
+
+    #[test]
+    fn parse_function_list_page_supports_cursor_has_more_shape() {
+        let raw = serde_json::json!({
+            "objects": [],
+            "cursor": "cursor-2",
+            "has_more": true,
+        });
+
+        let page = parse_function_list_page(raw).expect("parse function page");
+        assert!(page.pagination_field_present);
+        assert_eq!(page.next_cursor.as_deref(), Some("cursor-2"));
+    }
+
+    #[test]
+    fn parse_function_list_page_ignores_cursor_when_has_more_false() {
+        let raw = serde_json::json!({
+            "objects": [],
+            "cursor": "cursor-2",
+            "has_more": false,
+        });
+
+        let page = parse_function_list_page(raw).expect("parse function page");
+        assert!(page.pagination_field_present);
+        assert!(page.next_cursor.is_none());
+    }
+}
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index 3596449..615bbaa 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -1,5 +1,7 @@
+use std::path::PathBuf;
+
 use anyhow::{anyhow, bail, Result};
-use clap::{Args, Subcommand, ValueEnum};
+use clap::{builder::BoolishValueParser, Args, Subcommand, ValueEnum};
 
 use crate::{
     args::BaseArgs,
@@ -14,6 +16,9 @@ pub(crate) mod api;
 mod delete;
 mod invoke;
 mod list;
+mod pull;
+mod push;
+pub(crate) mod report;
 mod view;
 
 use api::Function;
@@ -72,6 +77,38 @@ impl FunctionTypeFilter {
     }
 }
 
+#[derive(Debug, Clone, Copy, ValueEnum, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum IfExistsMode {
+    Error,
+    Replace,
+    Ignore,
+}
+
+impl IfExistsMode {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            Self::Error => "error",
+            Self::Replace => "replace",
+            Self::Ignore => "ignore",
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
+pub enum FunctionsLanguage {
+    Typescript,
+    Python,
+}
+
+#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
+pub enum PushLanguage {
+    Auto,
+    #[value(name = "javascript")]
+    JavaScript,
+    Python,
+}
+
 fn build_web_path(function: &Function) -> String {
     let id = &function.id;
     match function.function_type.as_deref() {
@@ -98,8 +135,6 @@ fn label_plural(ft: Option<FunctionTypeFilter>) -> &'static str {
     ft.map_or("functions", |f| f.plural())
 }
 
-// --- Slug args (shared) ---
-
 #[derive(Debug, Clone, Args)]
 struct SlugArgs {
     /// Function slug
@@ -118,8 +153,6 @@ impl SlugArgs {
     }
 }
 
-// --- Wrapper args (bt tools / bt scorers) ---
-
 #[derive(Debug, Clone, Args)]
 #[command(after_help = "\
 Examples:
@@ -145,19 +178,20 @@ enum FunctionCommands {
     Invoke(invoke::InvokeArgs),
 }
 
-// --- bt functions args ---
-
 #[derive(Debug, Clone, Args)]
 #[command(after_help = "\
 Examples:
   bt functions list
  bt functions view my-function
  bt functions invoke my-function --input '{\"key\":\"value\"}'
+  bt functions push --file ./functions
+  bt functions pull --output-dir ./braintrust
 ")]
 pub struct FunctionsArgs {
     /// Filter by function type
     #[arg(long = "type", short = 't', value_enum)]
     function_type: Option<FunctionTypeFilter>,
+
     #[command(subcommand)]
     command: Option<FunctionsCommands>,
 }
@@ -167,11 +201,15 @@ enum FunctionsCommands {
     /// List functions in the current project
     List(FunctionsListArgs),
     /// View function details
-    View(ViewArgs),
+    View(FunctionsViewArgs),
     /// Delete a function
     Delete(FunctionsDeleteArgs),
     /// Invoke a function
     Invoke(FunctionsInvokeArgs),
+    /// Push local function definitions
+    Push(PushArgs),
+    /// Pull remote function definitions
+    Pull(PullArgs),
 }
 
 #[derive(Debug, Clone, Args)]
@@ -181,6 +219,15 @@ struct FunctionsListArgs {
     function_type: Option<FunctionTypeFilter>,
 }
 
+#[derive(Debug, Clone, Args)]
+struct FunctionsViewArgs {
+    #[command(flatten)]
+    inner: ViewArgs,
+    /// Filter by function type (for interactive selection)
+    #[arg(long = "type", short = 't', value_enum)]
+    function_type: Option<FunctionTypeFilter>,
+}
+
 #[derive(Debug, Clone, Args)]
 struct FunctionsDeleteArgs {
     #[command(flatten)]
@@ -208,7 +255,112 @@ struct FunctionsInvokeArgs {
     function_type: Option<FunctionTypeFilter>,
 }
 
-// --- Shared view/delete args ---
+#[derive(Debug, Clone, Args)]
+pub(crate) struct PushArgs {
+    /// File or directory path(s) to scan for function definitions.
+    #[arg(
+        long = "file",
+        env = "BT_FUNCTIONS_PUSH_FILES",
+        default_value = ".",
+        value_name = "PATH",
+        value_delimiter = ','
+    )]
+    pub files: Vec<PathBuf>,
+
+    /// Behavior when a function with the same slug already exists.
+    #[arg(
+        long = "if-exists",
+        env = "BT_FUNCTIONS_PUSH_IF_EXISTS",
+        value_enum,
+        default_value = "error"
+    )]
+    pub if_exists: IfExistsMode,
+
+    /// Stop after the first hard failure.
+    #[arg(
+        long,
+        env = "BT_FUNCTIONS_PUSH_TERMINATE_ON_FAILURE",
+        default_value_t = false,
+        value_parser = BoolishValueParser::new()
+    )]
+    pub terminate_on_failure: bool,
+
+    /// Override runner binary (e.g. tsx, vite-node, deno, python).
+    #[arg(long, env = "BT_FUNCTIONS_PUSH_RUNNER", value_name = "RUNNER")]
+    pub runner: Option<String>,
+
+    /// Force runtime language selection.
+    #[arg(
+        long = "language",
+        env = "BT_FUNCTIONS_PUSH_LANGUAGE",
+        value_enum,
+        default_value = "auto"
+    )]
+    pub language: PushLanguage,
+
+    /// Optional Python requirements file.
+    #[arg(long, env = "BT_FUNCTIONS_PUSH_REQUIREMENTS", value_name = "PATH")]
+    pub requirements: Option<PathBuf>,
+
+    /// Create missing projects referenced by function definitions.
+    #[arg(
+        long = "create-missing-projects",
+        env = "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS",
+        default_value_t = false,
+        value_parser = BoolishValueParser::new()
+    )]
+    pub create_missing_projects: bool,
+}
+
+#[derive(Debug, Clone, Args)]
+pub(crate) struct PullArgs {
+    /// Destination directory for generated files.
+    #[arg(
+        long,
+        env = "BT_FUNCTIONS_PULL_OUTPUT_DIR",
+        default_value = ".",
+        value_name = "PATH"
+    )]
+    pub output_dir: PathBuf,
+
+    /// Output language.
+    #[arg(
+        long = "language",
+        env = "BT_FUNCTIONS_PULL_LANGUAGE",
+        value_enum,
+        default_value = "typescript"
+    )]
+    pub language: FunctionsLanguage,
+
+    /// Project name filter.
+    #[arg(long, env = "BT_FUNCTIONS_PULL_PROJECT_NAME")]
+    pub project_name: Option<String>,
+
+    /// Project id filter.
+    #[arg(
+        long,
+        env = "BT_FUNCTIONS_PULL_PROJECT_ID",
+        conflicts_with = "project_name"
+    )]
+    pub project_id: Option<String>,
+
+    /// Function id selector.
+    #[arg(long, env = "BT_FUNCTIONS_PULL_ID", conflicts_with = "slug")]
+    pub id: Option<String>,
+
+    /// Function slug selector.
+    #[arg(long, env = "BT_FUNCTIONS_PULL_SLUG")]
+    pub slug: Option<String>,
+
+    /// Overwrite targets even when dirty or already existing.
+    #[arg(
+        long,
+        env = "BT_FUNCTIONS_PULL_FORCE",
+        default_value_t = false,
+        value_parser = BoolishValueParser::new()
+    )]
+    pub force: bool,
+}
 
 #[derive(Debug, Clone, Args)]
 pub struct ViewArgs {
@@ -243,7 +395,11 @@ impl DeleteArgs {
     }
 }
 
-// --- Resolved context ---
+pub(crate) struct AuthContext {
+    pub client: ApiClient,
+    pub app_url: String,
+    pub org_id: String,
+}
 
 pub(crate) struct ResolvedContext {
     pub client: ApiClient,
@@ -251,27 +407,58 @@ pub(crate) struct ResolvedContext {
     pub project: Project,
 }
 
-async fn resolve_context(base: &BaseArgs) -> Result<ResolvedContext> {
+pub(crate) async fn resolve_auth_context(base: &BaseArgs) -> Result<AuthContext> {
     let ctx = login(base).await?;
     let client = ApiClient::new(&ctx)?;
+    Ok(AuthContext {
+        client,
+        app_url: ctx.app_url,
+        org_id: ctx.login.org_id,
+    })
+}
+
+pub(crate) async fn resolve_project_context(
+    base: &BaseArgs,
+    auth_ctx: &AuthContext,
+) -> Result<Project> {
+    resolve_project_context_optional(base, auth_ctx, true)
+        .await?
+        .ok_or_else(|| anyhow!("--project required (or set BRAINTRUST_DEFAULT_PROJECT)"))
+}
+
+pub(crate) async fn resolve_project_context_optional(
+    base: &BaseArgs,
+    auth_ctx: &AuthContext,
+    allow_interactive_selection: bool,
+) -> Result<Option<Project>> {
     let config_project = config::load().ok().and_then(|c| c.project);
     let project_name = match base.project.as_deref().or(config_project.as_deref()) {
-        Some(p) => p.to_string(),
-        None if is_interactive() => select_project_interactive(&client, None, None).await?,
-        None => anyhow::bail!("--project required (or set BRAINTRUST_DEFAULT_PROJECT)"),
+        Some(p) => Some(p.to_string()),
+        None if allow_interactive_selection && is_interactive() => {
+            Some(select_project_interactive(&auth_ctx.client, None, None).await?)
+        }
+        None => None,
     };
-    let project = get_project_by_name(&client, &project_name)
-        .await?
-        .ok_or_else(|| anyhow!("project '{project_name}' not found"))?;
+
+    match project_name {
+        Some(project_name) => get_project_by_name(&auth_ctx.client, &project_name)
+            .await?
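+            // Option<Project> -> Option<Option<Project>> so ok_or_else can report a
+            // missing project while keeping the optional return shape.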
+            .map(Some)
+            .ok_or_else(|| anyhow!("project '{project_name}' not found")),
+        None => Ok(None),
+    }
+}
+
+async fn resolve_context(base: &BaseArgs) -> Result<ResolvedContext> {
+    let auth_ctx = resolve_auth_context(base).await?;
+    let project = resolve_project_context(base, &auth_ctx).await?;
     Ok(ResolvedContext {
-        client,
-        app_url: ctx.app_url,
+        client: auth_ctx.client,
+        app_url: auth_ctx.app_url,
         project,
     })
 }
 
-// --- Interactive selection ---
-
 pub(crate) async fn select_function_interactive(
     client: &ApiClient,
     project_id: &str,
@@ -306,8 +493,6 @@ pub(crate) async fn select_function_interactive(
     Ok(functions[selection].clone())
 }
 
-// --- Entry points ---
-
 pub async fn run_typed(base: BaseArgs, args: FunctionArgs, kind: FunctionTypeFilter) -> Result<()> {
     let ctx = resolve_context(&base).await?;
     let ft = Some(kind);
@@ -322,26 +507,289 @@ pub async fn run_typed(base: BaseArgs, args: FunctionArgs, kind: FunctionTypeFil
 }
 
 pub async fn run(base: BaseArgs, args: FunctionsArgs) -> Result<()> {
-    let ctx = resolve_context(&base).await?;
     match args.command {
-        None => list::run(&ctx, base.json, args.function_type).await,
-        Some(FunctionsCommands::List(ref la)) => list::run(&ctx, base.json, la.function_type).await,
+        None => {
+            let ctx = resolve_context(&base).await?;
+            list::run(&ctx, base.json, args.function_type).await
+        }
+        Some(FunctionsCommands::List(ref la)) => {
+            let ctx = resolve_context(&base).await?;
+            list::run(&ctx, base.json, la.function_type.or(args.function_type)).await
+        }
         Some(FunctionsCommands::View(v)) => {
+            let ctx = resolve_context(&base).await?;
             view::run(
                 &ctx,
-                v.slug(),
+                v.inner.slug(),
                 base.json,
-                v.web,
-                v.verbose,
-                args.function_type,
+                v.inner.web,
+                v.inner.verbose,
+                v.function_type.or(args.function_type),
             )
             .await
         }
         Some(FunctionsCommands::Delete(d)) => {
-            delete::run(&ctx, d.slug(), d.force, d.function_type).await
+            let ctx = resolve_context(&base).await?;
+            delete::run(
+                &ctx,
+                d.slug(),
+                d.force,
+                d.function_type.or(args.function_type),
+            )
+            .await
         }
         Some(FunctionsCommands::Invoke(i)) => {
-            invoke::run(&ctx, &i.inner, base.json, i.function_type).await
+            let ctx = resolve_context(&base).await?;
+            invoke::run(
+                &ctx,
+                &i.inner,
+                base.json,
+                i.function_type.or(args.function_type),
+            )
+            .await
         }
+        Some(FunctionsCommands::Push(args)) => push::run(base, args).await,
+        Some(FunctionsCommands::Pull(args)) => pull::run(base, args).await,
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::{Mutex, MutexGuard, OnceLock};
+
+    use clap::{Parser, Subcommand};
+
+    use super::*;
+
+    #[derive(Debug, Parser)]
+    struct CliHarness {
+        #[command(subcommand)]
+        command: Commands,
+    }
+
+    #[derive(Debug, Subcommand)]
+    enum Commands {
+        Functions(FunctionsArgs),
+    }
+
+    fn parse(args: &[&str]) -> anyhow::Result<FunctionsArgs> {
+        let mut argv = vec!["bt"];
+        argv.extend_from_slice(args);
+        let parsed = CliHarness::try_parse_from(argv)?;
+        match parsed.command {
+            Commands::Functions(args) => Ok(args),
+        }
+    }
+
+    fn test_lock() -> MutexGuard<'static, ()> {
+        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+        LOCK.get_or_init(|| Mutex::new(()))
+            .lock()
+            .unwrap_or_else(|err| err.into_inner())
+    }
+
+    #[test]
+    fn push_rejects_legacy_type_flag() {
+        let _guard = test_lock();
+        let err = parse(&["functions", "push", "--type", "tool"]).expect_err("should fail");
+        let msg = err.to_string();
+        assert!(msg.contains("--type"));
+    }
+
+    #[test]
+    fn top_level_type_flag_still_parses_for_functions_namespace() {
+        let _guard = test_lock();
+        let parsed = parse(&["functions", "--type",
"tool"]).expect("parse functions"); + assert!(matches!( + parsed.function_type, + Some(FunctionTypeFilter::Tool) + )); + } + + #[test] + fn push_file_env_uses_delimiter() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PUSH_FILES", "a.ts,b.ts"); + } + let parsed = parse(&["functions", "push"]).expect("parse push"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PUSH_FILES"); + } + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + + assert_eq!( + push.files, + vec![PathBuf::from("a.ts"), PathBuf::from("b.ts")] + ); + } + + #[test] + fn push_boolish_flag_from_env() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PUSH_TERMINATE_ON_FAILURE", "true"); + } + let parsed = parse(&["functions", "push"]).expect("parse push"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PUSH_TERMINATE_ON_FAILURE"); + } + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert!(push.terminate_on_failure); + } + + #[test] + fn push_create_missing_projects_flag_from_env() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", "true"); + } + let parsed = parse(&["functions", "push"]).expect("parse push"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS"); + } + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert!(push.create_missing_projects); + } + + #[test] + fn push_repeated_file_flags_append_in_order() { + let _guard = test_lock(); + let parsed = parse(&[ + "functions", + "push", + "--file", + "a.ts", + "--file", + "b.ts", + "--file", + "c.ts", + ]) + .expect("parse push"); + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert_eq!( + push.files, + vec![ + PathBuf::from("a.ts"), + PathBuf::from("b.ts"), + PathBuf::from("c.ts") + ] + ); + } + + #[test] + fn push_language_from_env() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PUSH_LANGUAGE", "python"); + } + let parsed = parse(&["functions", "push"]).expect("parse push"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PUSH_LANGUAGE"); + } + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert_eq!(push.language, PushLanguage::Python); + } + + #[test] + fn push_requirements_from_env() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PUSH_REQUIREMENTS", "requirements.txt"); + } + let parsed = parse(&["functions", "push"]).expect("parse push"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PUSH_REQUIREMENTS"); + } + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert_eq!(push.requirements, Some(PathBuf::from("requirements.txt"))); + } + + #[test] + fn pull_language_from_env() { + let _guard = test_lock(); + unsafe { + std::env::set_var("BT_FUNCTIONS_PULL_LANGUAGE", "python"); + } + let parsed = parse(&["functions", "pull"]).expect("parse pull"); + unsafe { + std::env::remove_var("BT_FUNCTIONS_PULL_LANGUAGE"); + } + + let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else { + panic!("expected pull command"); + }; + assert_eq!(pull.language, FunctionsLanguage::Python); + } + + #[test] + fn 
pull_language_defaults_to_typescript() {
+        let _guard = test_lock();
+        unsafe {
+            std::env::remove_var("BT_FUNCTIONS_PULL_LANGUAGE");
+        }
+        let parsed = parse(&["functions", "pull"]).expect("parse pull");
+        let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+            panic!("expected pull command");
+        };
+        assert_eq!(pull.language, FunctionsLanguage::Typescript);
+    }
+
+    #[test]
+    fn pull_rejects_invalid_language() {
+        let _guard = test_lock();
+        let err = parse(&["functions", "pull", "--language", "ruby"]).expect_err("should fail");
+        assert!(err.to_string().contains("ruby"));
+    }
+
+    #[test]
+    fn push_rejects_invalid_language() {
+        let _guard = test_lock();
+        let err =
+            parse(&["functions", "push", "--language", "typescript"]).expect_err("should fail");
+        assert!(err.to_string().contains("typescript"));
+    }
+
+    #[test]
+    fn pull_conflicts_project_selectors() {
+        let _guard = test_lock();
+        let err = parse(&[
+            "functions",
+            "pull",
+            "--project-id",
+            "p1",
+            "--project-name",
+            "proj",
+        ])
+        .expect_err("should conflict");
+
+        assert!(err.to_string().contains("--project-name"));
+    }
+
+    #[test]
+    fn pull_conflicts_id_and_slug() {
+        let _guard = test_lock();
+        let err = parse(&["functions", "pull", "--id", "f1", "--slug", "slug"])
+            .expect_err("should conflict");
+
+        assert!(err.to_string().contains("--slug"));
+    }
+}
diff --git a/src/functions/report.rs b/src/functions/report.rs
new file mode 100644
index 0000000..b6953ca
--- /dev/null
+++ b/src/functions/report.rs
@@ -0,0 +1,172 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum CommandStatus {
+    Success,
+    Partial,
+    Failed,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum FileStatus {
+    Success,
+    Skipped,
+    Failed,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum HardFailureReason {
+    AuthFailed,
+    RequestFailed,
+    ResponseInvalid,
+    UserCancelled,
+    OutputDirInvalid,
+    AtomicWriteFailed,
+    UnsafeOutputPath,
+    RunnerSpawnFailed,
+    RunnerExitNonzero,
+    ManifestInvalidJson,
+    ManifestSchemaInvalid,
+    ManifestPathMissing,
+    UploadSlotFailed,
+    BundleUploadFailed,
+    InsertFunctionsFailed,
+    SelectorNotFound,
+    PaginationUnsupported,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum SoftSkipReason {
+    DirtyTarget,
+    ExistingNonGitNoForce,
+    MalformedRecord,
+    UnsupportedFunctionType,
+    SupersededVersion,
+    TerminatedAfterFailure,
+    IfExistsIgnored,
+    NoDefinitionsFound,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum WarningReason {
+    PaginationNotSnapshotConsistent,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct ReportWarning {
+    pub reason: WarningReason,
+    pub message: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct ReportError {
+    pub reason: HardFailureReason,
+    pub message: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct PushFileReport {
+    pub source_file: String,
+    pub status: FileStatus,
+    pub uploaded_entries: usize,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub skipped_reason: Option<SoftSkipReason>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_reason: Option<HardFailureReason>,
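+    /// Identifier of the uploaded code bundle, when one was created for this file.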
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bundle_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub message: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct PushSummary {
+    pub status: CommandStatus,
+    pub total_files: usize,
+    pub uploaded_files: usize,
+    pub failed_files: usize,
+    pub skipped_files: usize,
+    pub ignored_entries: usize,
+    pub files: Vec<PushFileReport>,
+    pub warnings: Vec<ReportWarning>,
+    pub errors: Vec<ReportError>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct PullFileReport {
+    pub output_file: String,
+    pub status: FileStatus,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub skipped_reason: Option<SoftSkipReason>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_reason: Option<HardFailureReason>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub message: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct PullSummary {
+    pub status: CommandStatus,
+    pub projects_total: usize,
+    pub files_written: usize,
+    pub files_skipped: usize,
+    pub files_failed: usize,
+    pub functions_seen: usize,
+    pub functions_materialized: usize,
+    pub malformed_records_skipped: usize,
+    pub unsupported_records_skipped: usize,
+    pub files: Vec<PullFileReport>,
+    pub warnings: Vec<ReportWarning>,
+    pub errors: Vec<ReportError>,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn enums_serialize_as_snake_case() {
+        let reason = serde_json::to_string(&HardFailureReason::ManifestInvalidJson)
+            .expect("serialize reason");
+        assert_eq!(reason, "\"manifest_invalid_json\"");
+
+        let status = serde_json::to_string(&CommandStatus::Partial).expect("serialize status");
+        assert_eq!(status, "\"partial\"");
+
+        let warning = serde_json::to_string(&WarningReason::PaginationNotSnapshotConsistent)
+            .expect("serialize warning");
+        assert_eq!(warning, "\"pagination_not_snapshot_consistent\"");
+    }
+
+    #[test]
+    fn push_summary_roundtrip() {
+        let summary = PushSummary {
+            status: CommandStatus::Partial,
+            total_files: 2,
+            uploaded_files: 1,
+            failed_files: 0,
+            skipped_files: 1,
+            ignored_entries: 1,
+            files: vec![PushFileReport {
+                source_file: "a.ts".to_string(),
+                status: FileStatus::Skipped,
+                uploaded_entries: 0,
+                skipped_reason: Some(SoftSkipReason::IfExistsIgnored),
+                error_reason: None,
+                bundle_id: None,
+                message: None,
+            }],
+            warnings: vec![],
+            errors: vec![],
+        };
+
+        let encoded = serde_json::to_string(&summary).expect("encode");
+        let decoded: PushSummary = serde_json::from_str(&encoded).expect("decode");
+        assert_eq!(decoded, summary);
+    }
+}
diff --git a/src/http.rs b/src/http.rs
index 564ac4a..fea804e 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,4 +1,5 @@
 use anyhow::{Context, Result};
+use reqwest::header::HeaderValue;
 use reqwest::Client;
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
@@ -198,3 +199,33 @@ impl ApiClient {
         self.post_with_headers("/btql", &body, &headers).await
     }
 }
+
+pub async fn put_signed_url(
+    url: &str,
+    body: Vec<u8>,
+    content_encoding: Option<&str>,
+) -> Result<()> {
+    let client = Client::builder()
+        .timeout(DEFAULT_HTTP_TIMEOUT)
+        .build()
+        .context("failed to build signed-url HTTP client")?;
+
+    let mut request = client.put(url).body(body);
+    if let Some(encoding) = content_encoding {
+        request = request.header("Content-Encoding", encoding);
+    }
+    if url.contains(".blob.core.windows.net") {
+        request = request.header("x-ms-blob-type", HeaderValue::from_static("BlockBlob"));
+    }
+
+    let response = request
+        .send()
+        .await
+        .context("signed-url upload request
failed")?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(HttpError { status, body }.into()); + } + Ok(()) +} From 17bc63935550a84f983897b1eda9b225c7d32574 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Thu, 5 Mar 2026 16:39:30 -0800 Subject: [PATCH 03/28] Implement bt functions push Add the full push pipeline for deploying local function definitions to Braintrust: - push.rs: file classification, runner invocation, manifest parsing, project preflight, bundle compression/upload, and batch insert with structured reporting - functions-runner.ts: TS/JS runner that imports user files, inspects the Braintrust global registry, and emits a JSON manifest - functions-runner.py: Python runner with bundle collection (entry module + source files) for server-side execution Supports both TypeScript and Python source files with automatic language detection, interactive org/project selection, and configurable conflict resolution (error/replace/ignore). --- scripts/functions-runner.py | 312 ++++ scripts/functions-runner.ts | 371 ++++ src/functions/push.rs | 3179 +++++++++++++++++++++++++++++++++++ 3 files changed, 3862 insertions(+) create mode 100644 scripts/functions-runner.py create mode 100644 scripts/functions-runner.ts create mode 100644 src/functions/push.rs diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py new file mode 100644 index 0000000..1e5ce1b --- /dev/null +++ b/scripts/functions-runner.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +import asyncio +import json +import os +import sys +from contextlib import nullcontext +from typing import Any + +from python_runner_common import ( + collect_python_sources, + import_file, + normalize_file_list, + purge_local_modules, + python_version, + resolve_module_info, +) + + +def to_json_value(value: Any) -> Any: + if value is None: + return None + if isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, list): + return [to_json_value(item) for item in value] + if isinstance(value, tuple): + return [to_json_value(item) for item in value] + if isinstance(value, dict): + return {str(key): to_json_value(val) for key, val in value.items()} + if hasattr(value, "model_dump"): + return to_json_value(value.model_dump()) + if hasattr(value, "dict"): + return to_json_value(value.dict()) + if hasattr(value, "__dict__"): + result: dict[str, Any] = {} + for key, val in vars(value).items(): + if key.startswith("_"): + continue + result[key] = to_json_value(val) + return result + return str(value) + + +def load_framework_globals() -> tuple[Any, Any, Any]: + try: + from braintrust.framework2.global_ import functions, prompts + except Exception: + from braintrust.framework2 import global_ as global_state + + functions = getattr(global_state, "functions", []) + prompts = getattr(global_state, "prompts", []) + + lazy = None + try: + from braintrust.framework2.lazy_load import _set_lazy_load as lazy + except Exception: + try: + from braintrust.framework import _set_lazy_load as lazy + except Exception: + lazy = None + + return functions, prompts, lazy + + +def normalize_project_selector(project: Any) -> tuple[str | None, str | None]: + if project is None: + return None, None + + if isinstance(project, str): + trimmed = project.strip() + if trimmed: + return None, trimmed + return None, None + + if isinstance(project, dict): + project_id = project.get("project_id") + project_name = project.get("project_name") + if 
isinstance(project_id, str) and project_id.strip(): + return project_id.strip(), None + if isinstance(project_name, str) and project_name.strip(): + return None, project_name.strip() + return None, None + + project_id = getattr(project, "project_id", None) + project_name = getattr(project, "project_name", None) + if isinstance(project_id, str) and project_id.strip(): + return project_id.strip(), None + if isinstance(project_name, str) and project_name.strip(): + return None, project_name.strip() + return None, None + + +def normalize_function_type(raw: Any) -> str | None: + if isinstance(raw, str): + value = raw.strip() + return value if value else None + + value_attr = getattr(raw, "value", None) + if isinstance(value_attr, str): + value = value_attr.strip() + return value if value else None + + name_attr = getattr(raw, "name", None) + if isinstance(name_attr, str): + value = name_attr.strip().lower() + return value if value else None + + return None + + +def selector_to_project_placeholder(project: Any) -> str: + project_id, project_name = normalize_project_selector(project) + if project_id: + return project_id + if project_name: + return f"name:{project_name}" + return "" + + +class Resolver: + async def resolve(self, project: Any) -> str: + return selector_to_project_placeholder(project) + + +def clear_registry(registry: Any) -> None: + if hasattr(registry, "clear"): + registry.clear() + + +def collect_code_entries(functions_registry: Any) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + items = functions_registry if isinstance(functions_registry, list) else list(functions_registry) + for index, item in enumerate(items): + name = getattr(item, "name", None) + slug = getattr(item, "slug", None) + if not isinstance(name, str) or not isinstance(slug, str) or not name or not slug: + continue + + project_id, project_name = normalize_project_selector(getattr(item, "project", None)) + + entry: dict[str, Any] = { + "kind": "code", + "name": name, + "slug": slug, + "location": {"type": "function", "index": index}, + } + description = getattr(item, "description", None) + if isinstance(description, str): + entry["description"] = description + function_type = ( + getattr(item, "type", None) + or getattr(item, "function_type", None) + or getattr(item, "type_", None) + ) + normalized_function_type = normalize_function_type(function_type) + if normalized_function_type: + entry["function_type"] = normalized_function_type + if_exists = getattr(item, "if_exists", None) or getattr(item, "ifExists", None) + if isinstance(if_exists, str): + entry["if_exists"] = if_exists + metadata = getattr(item, "metadata", None) + if metadata is not None: + entry["metadata"] = to_json_value(metadata) + if project_id: + entry["project_id"] = project_id + if project_name: + entry["project_name"] = project_name + + preview = getattr(item, "preview", None) + if isinstance(preview, str): + entry["preview"] = preview + + entries.append(entry) + return entries + + +def collect_legacy_prompt_event(item: Any, resolver: Resolver) -> dict[str, Any] | None: + name = getattr(item, "name", None) + slug = getattr(item, "slug", None) + if not isinstance(name, str) or not isinstance(slug, str) or not name or not slug: + return None + + prompt = to_json_value(getattr(item, "prompt", {}) or {}) + if not isinstance(prompt, dict): + prompt = {} + + tool_functions = getattr(item, "tool_functions", None) + if isinstance(tool_functions, list) and tool_functions: + resolved_tools: list[Any] = [] + for tool in tool_functions: + if 
isinstance(tool, dict): + slug_value = tool.get("slug") + project = tool.get("project") + if isinstance(slug_value, str) and project is not None: + placeholder = selector_to_project_placeholder(project) + if placeholder: + resolved_tools.append( + {"type": "slug", "project_id": placeholder, "slug": slug_value} + ) + continue + resolved_tools.append(to_json_value(tool)) + else: + resolved_tools.append(to_json_value(tool)) + if resolved_tools: + prompt["tool_functions"] = resolved_tools + + event: dict[str, Any] = { + "name": name, + "slug": slug, + "description": getattr(item, "description", "") or "", + "function_data": {"type": "prompt"}, + "prompt_data": prompt, + } + + if_exists = getattr(item, "if_exists", None) or getattr(item, "ifExists", None) + if isinstance(if_exists, str): + event["if_exists"] = if_exists + metadata = getattr(item, "metadata", None) + if metadata is not None: + event["metadata"] = to_json_value(metadata) + + project_id, project_name = normalize_project_selector(getattr(item, "project", None)) + out: dict[str, Any] = {"kind": "function_event", "event": event} + if project_id: + out["project_id"] = project_id + if project_name: + out["project_name"] = project_name + return out + + +async def collect_function_event_entries(prompts_registry: Any) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + resolver = Resolver() + items = prompts_registry if isinstance(prompts_registry, list) else list(prompts_registry) + for item in items: + to_definition = getattr(item, "to_function_definition", None) + if callable(to_definition): + definition = to_definition(resolver) + if asyncio.iscoroutine(definition): + definition = await definition + normalized = to_json_value(definition) + if isinstance(normalized, dict): + project_id, project_name = normalize_project_selector(getattr(item, "project", None)) + event_entry: dict[str, Any] = {"kind": "function_event", "event": normalized} + if project_id: + event_entry["project_id"] = project_id + if project_name: + event_entry["project_name"] = project_name + entries.append(event_entry) + continue + + legacy = collect_legacy_prompt_event(item, resolver) + if legacy is not None: + entries.append(legacy) + + return entries + + +async def process_file(file_path: str) -> dict[str, Any]: + abs_path = os.path.abspath(file_path) + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.insert(0, cwd) + + functions_registry, prompts_registry, lazy_loader = load_framework_globals() + clear_registry(functions_registry) + clear_registry(prompts_registry) + purge_local_modules(cwd, preserve_modules={__name__, "python_runner_common"}) + + module_name, extra_paths = resolve_module_info(abs_path) + lazy_ctx = lazy_loader(True) if callable(lazy_loader) else nullcontext() + with lazy_ctx: + import_file(module_name, abs_path, extra_paths) + code_entries = collect_code_entries(functions_registry) + event_entries = await collect_function_event_entries(prompts_registry) + entries = [*code_entries, *event_entries] + file_manifest: dict[str, Any] = { + "source_file": abs_path, + "entries": entries, + } + if code_entries: + file_manifest["python_bundle"] = { + "entry_module": module_name, + "sources": collect_python_sources(cwd, abs_path), + } + + clear_registry(functions_registry) + clear_registry(prompts_registry) + return file_manifest + + +async def main() -> None: + files = normalize_file_list(sys.argv[1:]) + if not files: + raise RuntimeError("functions-runner.py requires at least one input file") + + manifest: dict[str, Any] = { + 
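+        # Illustrative output shape (mirrors the RunnerManifest struct the CLI
+        # parses): {"runtime_context": {"runtime": "python", "version": "3.12"},
+        #          "files": [{"source_file": "/abs/path.py", "entries": [...]}]}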
"runtime_context": {"runtime": "python", "version": python_version()}, + "files": [], + } + for file_path in files: + manifest["files"].append(await process_file(file_path)) + + sys.stdout.write(json.dumps(manifest)) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception as exc: + sys.stderr.write(f"{exc}\n") + sys.exit(1) diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts new file mode 100644 index 0000000..de15262 --- /dev/null +++ b/scripts/functions-runner.ts @@ -0,0 +1,371 @@ +import path from "node:path"; +import { pathToFileURL } from "node:url"; + +import { + asProjectSelector, + isJsonObject, + ProjectRef, + selectorToProjectId, + toJsonValue, + type JsonObject, + type JsonValue, +} from "./runner-common"; + +type Resolver = { + resolve: (project: ProjectRef) => Promise; +}; + +type CodeRegistryItem = { + project?: ProjectRef; + name?: string; + slug?: string; + description?: string; + type?: string; + functionType?: string; + ifExists?: string; + metadata?: JsonValue; + preview?: string; +}; + +type EventRegistryItem = { + project?: ProjectRef; + toFunctionDefinition?: (resolver: Resolver) => Promise; + name?: string; + slug?: string; + description?: string; + ifExists?: string; + metadata?: JsonValue; + prompt?: JsonValue; + toolFunctions?: LegacyToolFunction[]; +}; + +type LegacyToolFunction = { + type?: string; + id?: string; + name?: string; + slug?: string; + project?: ProjectRef; + project_id?: string; +}; + +type CodeEntry = { + kind: "code"; + project_id?: string; + project_name?: string; + name: string; + slug: string; + description?: string; + function_type?: string; + if_exists?: string; + metadata?: JsonValue; + preview?: string; + location: JsonValue; +}; + +type FunctionEventEntry = { + kind: "function_event"; + project_id?: string; + project_name?: string; + event: JsonValue; +}; + +type ManifestFile = { + source_file: string; + entries: Array; +}; + +type Manifest = { + runtime_context: { + runtime: "node"; + version: string; + }; + files: ManifestFile[]; +}; + +type EvalRegistry = NonNullable; + +function freshRegistry(): EvalRegistry { + return { + functions: [], + prompts: [], + parameters: [], + evaluators: {}, + reporters: {}, + }; +} + +function currentRegistry(fallback: EvalRegistry): EvalRegistry { + const registry = globalThis._evals; + if (!registry) { + return fallback; + } + + return { + functions: Array.isArray(registry.functions) ? registry.functions : [], + prompts: Array.isArray(registry.prompts) ? registry.prompts : [], + parameters: Array.isArray(registry.parameters) ? registry.parameters : [], + evaluators: + registry.evaluators !== null && typeof registry.evaluators === "object" + ? registry.evaluators + : {}, + reporters: + registry.reporters !== null && typeof registry.reporters === "object" + ? 
registry.reporters + : {}, + }; +} + +async function collectFunctionEvents( + items: EventRegistryItem[], + includeLegacyPrompts: boolean, +): Promise { + const entries: FunctionEventEntry[] = []; + + const resolver: Resolver = { + resolve: async (project: ProjectRef): Promise => { + const selector = asProjectSelector(project); + return selectorToProjectId(selector); + }, + }; + + for (const item of items) { + if (!item.toFunctionDefinition) { + if (includeLegacyPrompts) { + const entry = await collectLegacyPromptEvent(item, resolver); + if (entry) { + entries.push(entry); + } + } + continue; + } + + const event = await item.toFunctionDefinition(resolver); + const normalizedEvent = toJsonValue(event); + if (!isJsonObject(normalizedEvent)) { + continue; + } + + const selector = asProjectSelector(item.project); + const projectId = + typeof selector.project_id === "string" ? selector.project_id : undefined; + const projectName = + typeof selector.project_name === "string" + ? selector.project_name + : undefined; + + entries.push({ + kind: "function_event", + project_id: projectId, + project_name: projectName, + event: normalizedEvent, + }); + } + + return entries; +} + +async function collectLegacyPromptEvent( + item: EventRegistryItem, + resolver: Resolver, +): Promise { + if (typeof item.name !== "string" || typeof item.slug !== "string") { + return null; + } + + const normalizedPrompt = toJsonValue(item.prompt ?? {}); + if (!isJsonObject(normalizedPrompt)) { + return null; + } + + const promptData: JsonObject = { ...normalizedPrompt }; + const toolFunctions = Array.isArray(item.toolFunctions) + ? item.toolFunctions + : []; + if (toolFunctions.length > 0) { + const resolvedTools: JsonValue[] = []; + for (const tool of toolFunctions) { + const resolved = await resolveLegacyToolFunction(tool, resolver); + if (resolved) { + resolvedTools.push(resolved); + } + } + if (resolvedTools.length > 0) { + promptData.tool_functions = resolvedTools; + } + } + + const selector = asProjectSelector(item.project); + const projectId = + typeof selector.project_id === "string" ? selector.project_id : undefined; + const projectName = + typeof selector.project_name === "string" + ? selector.project_name + : undefined; + + const event: JsonObject = { + name: item.name, + slug: item.slug, + description: typeof item.description === "string" ? item.description : "", + function_data: { + type: "prompt", + }, + prompt_data: promptData, + }; + if (typeof item.ifExists === "string") { + event.if_exists = item.ifExists; + } + if (item.metadata !== undefined) { + event.metadata = item.metadata; + } + + return { + kind: "function_event", + project_id: projectId, + project_name: projectName, + event, + }; +} + +async function resolveLegacyToolFunction( + tool: LegacyToolFunction, + resolver: Resolver, +): Promise { + if ( + typeof tool.slug === "string" && + tool.slug.length > 0 && + tool.project !== undefined + ) { + const projectId = await resolver.resolve(tool.project); + if (projectId.length > 0) { + return { + type: "slug", + project_id: projectId, + slug: tool.slug, + }; + } + } + + const direct: JsonObject = {}; + if (typeof tool.type === "string") { + direct.type = tool.type; + } + if (typeof tool.id === "string") { + direct.id = tool.id; + } + if (typeof tool.name === "string") { + direct.name = tool.name; + } + if (typeof tool.project_id === "string") { + direct.project_id = tool.project_id; + } + if (typeof tool.slug === "string") { + direct.slug = tool.slug; + } + + return Object.keys(direct).length > 0 ? 
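+  // unresolved tools pass through verbatim with whatever recognizable fields
+  // they carry; a tool with none of them is dropped (null)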
direct : null; +} + +function collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { + const entries: CodeEntry[] = []; + + for (let index = 0; index < items.length; index += 1) { + const item = items[index]; + + if (typeof item.name !== "string" || typeof item.slug !== "string") { + continue; + } + + const selector = asProjectSelector(item.project); + + entries.push({ + kind: "code", + project_id: + typeof selector.project_id === "string" + ? selector.project_id + : undefined, + project_name: + typeof selector.project_name === "string" + ? selector.project_name + : undefined, + name: item.name, + slug: item.slug, + description: + typeof item.description === "string" ? item.description : undefined, + function_type: + typeof item.type === "string" + ? item.type + : typeof item.functionType === "string" + ? item.functionType + : undefined, + if_exists: typeof item.ifExists === "string" ? item.ifExists : undefined, + metadata: item.metadata, + preview: typeof item.preview === "string" ? item.preview : undefined, + location: { + type: "function", + index, + }, + }); + } + + return entries; +} + +async function processFile(filePath: string): Promise { + const absolutePath = path.resolve(process.cwd(), filePath); + const fallbackRegistry = freshRegistry(); + globalThis._evals = fallbackRegistry; + globalThis._lazy_load = true; + + await import(pathToFileURL(absolutePath).href); + const registry = currentRegistry(fallbackRegistry); + + const entries: Array = [ + ...collectCodeEntries(registry.functions as CodeRegistryItem[]), + ...(await collectFunctionEvents( + registry.prompts as EventRegistryItem[], + true, + )), + ...(await collectFunctionEvents( + registry.parameters as EventRegistryItem[], + false, + )), + ]; + + return { + source_file: absolutePath, + entries, + }; +} + +async function main(): Promise { + const files = process.argv.slice(2); + if (files.length === 0) { + throw new Error("functions-runner requires at least one input file"); + } + + const manifest: Manifest = { + runtime_context: { + runtime: "node", + version: + typeof process.version === "string" && process.version.startsWith("v") + ? process.version.slice(1) + : typeof process.version === "string" && process.version.length > 0 + ? process.version + : "unknown", + }, + files: [], + }; + + for (const file of files) { + const result = await processFile(file); + manifest.files.push(result); + } + + process.stdout.write(JSON.stringify(manifest)); +} + +main().catch((error: Error) => { + const message = error instanceof Error ? 
error.message : String(error); + process.stderr.write(`${message}\n`); + process.exitCode = 1; +}); diff --git a/src/functions/push.rs b/src/functions/push.rs new file mode 100644 index 0000000..fd864a4 --- /dev/null +++ b/src/functions/push.rs @@ -0,0 +1,3179 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::ffi::OsString; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use anyhow::{anyhow, bail, Context, Result}; +use dialoguer::Confirm; +use reqwest::StatusCode; +use serde::Deserialize; +use serde_json::{json, Map, Value}; + +use crate::args::BaseArgs; +use crate::auth::{list_available_orgs, AvailableOrg}; +use crate::config; +use crate::functions::report::{ + CommandStatus, FileStatus, HardFailureReason, PushFileReport, PushSummary, ReportError, + ReportWarning, SoftSkipReason, +}; +use crate::js_runner; +use crate::projects::api::{create_project, get_project_by_name, list_projects}; +use crate::python_runner; +use crate::source_language::{classify_runtime_extension, JsExtensionProfile, SourceLanguage}; +use crate::ui::{fuzzy_select, is_interactive}; + +use super::api; +use super::{resolve_auth_context, PushArgs, PushLanguage}; + +const FUNCTIONS_JS_RUNNER_FILE: &str = "functions-runner.ts"; +const FUNCTIONS_PY_RUNNER_FILE: &str = "functions-runner.py"; +const RUNNER_COMMON_FILE: &str = "runner-common.ts"; +const PYTHON_RUNNER_COMMON_FILE: &str = "python_runner_common.py"; +const FUNCTIONS_JS_RUNNER_SOURCE: &str = include_str!("../../scripts/functions-runner.ts"); +const FUNCTIONS_PY_RUNNER_SOURCE: &str = include_str!("../../scripts/functions-runner.py"); +const RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/runner-common.ts"); +const PYTHON_RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/python_runner_common.py"); +const PYTHON_BASELINE_DEPS: &[&str] = + &["pydantic", "braintrust", "autoevals", "requests", "openai"]; + +#[derive(Debug, Deserialize)] +struct RunnerManifest { + runtime_context: RuntimeContext, + files: Vec, +} + +#[derive(Debug, Deserialize)] +struct RuntimeContext { + runtime: String, + version: String, +} + +#[derive(Debug, Deserialize)] +struct ManifestFile { + source_file: String, + #[serde(default)] + entries: Vec, + #[serde(default)] + python_bundle: Option, +} + +#[derive(Debug, Deserialize, Clone)] +struct PythonBundle { + entry_module: String, + #[serde(default)] + sources: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "kind")] +#[allow(clippy::large_enum_variant)] +enum ManifestEntry { + #[serde(rename = "code")] + Code(CodeEntry), + #[serde(rename = "function_event")] + FunctionEvent(FunctionEventEntry), +} + +#[derive(Debug, Deserialize)] +struct CodeEntry { + #[serde(default)] + project_id: Option, + #[serde(default)] + project_name: Option, + name: String, + slug: String, + #[serde(default)] + description: Option, + #[serde(default)] + function_type: Option, + #[serde(default)] + if_exists: Option, + #[serde(default)] + metadata: Option, + #[serde(default)] + location: Option, + #[serde(default)] + preview: Option, +} + +#[derive(Debug, Deserialize)] +struct FunctionEventEntry { + #[serde(default)] + project_id: Option, + #[serde(default)] + project_name: Option, + event: Value, +} + +#[derive(Debug, Clone)] +struct FileFailure { + reason: HardFailureReason, + message: String, +} + +fn error_chain(err: &anyhow::Error) -> String { + format!("{err:#}") +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum OrgDecision { + Continue, + Switch(String), + Cancel, 
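+    // Cancel: the user declined the interactive confirmation; run() exits via
+    // cancel_push without uploading anything.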
+} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ProjectSelector { + Id(String), + Name(String), + Fallback, +} + +#[derive(Debug, Clone)] +struct ProjectPreflight { + default_project_name: Option, + requires_default_project: bool, + named_projects: BTreeSet, + direct_project_ids: BTreeSet, + selector_preview: Vec, +} + +#[derive(Debug, Clone)] +struct ResolvedEntryTarget { + source_file: String, + slug: String, + project_id: String, +} + +#[derive(Debug, Clone)] +struct ResolvedFileTargets { + source_file: String, + entry_project_ids: Vec, +} + +#[derive(Debug, Clone)] +struct ResolvedManifestTargets { + entries: Vec, + per_file: Vec, + unique_project_ids: Vec, +} + +#[derive(Debug, Default)] +struct ClassifiedFiles { + js_like: Vec, + python: Vec, + had_directory_inputs: bool, + explicit_file_inputs: usize, + explicit_supported_files: usize, + explicit_js_like: usize, + explicit_python: usize, + allowed_roots: Vec, +} + +impl ClassifiedFiles { + fn files_for_language(&self, language: SourceLanguage) -> Vec { + match language { + SourceLanguage::JsLike => self.js_like.clone(), + SourceLanguage::Python => self.python.clone(), + } + } +} + +pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { + let available_orgs = match list_available_orgs(&base) + .await + .context("failed to list available orgs") + { + Ok(orgs) => orgs, + Err(err) => { + return fail_push( + &base, + 0, + HardFailureReason::AuthFailed, + error_chain(&err), + "failed to list available orgs", + ); + } + }; + + if let Err(err) = validate_explicit_org_selection(&base, &available_orgs) { + return fail_push( + &base, + 0, + HardFailureReason::ResponseInvalid, + error_chain(&err), + "invalid org selection", + ); + } + + let mut auth_ctx = match resolve_auth_context(&base) + .await + .context("failed to resolve auth context") + { + Ok(ctx) => ctx, + Err(err) => { + return fail_push( + &base, + 0, + HardFailureReason::AuthFailed, + error_chain(&err), + "failed to resolve auth context", + ); + } + }; + + let classified = match collect_classified_files(&args.files) { + Ok(files) => files, + Err(err) => { + return fail_push( + &base, + 0, + HardFailureReason::ManifestPathMissing, + err.to_string(), + "failed to collect input files", + ); + } + }; + if classified.explicit_file_inputs > 0 && classified.explicit_supported_files == 0 { + return fail_push( + &base, + 0, + HardFailureReason::ManifestPathMissing, + "no eligible source files found in explicit file inputs; supported extensions: .ts, .tsx, .js, .jsx, .py".to_string(), + "no eligible source files found", + ); + } + + let selected_language = match select_push_language(&args, &classified) { + Ok(language) => language, + Err(err) => { + return fail_push( + &base, + 0, + HardFailureReason::ManifestSchemaInvalid, + err.to_string(), + "failed to select push language", + ); + } + }; + emit_language_selection_notice(&args, &classified, selected_language); + + if args.requirements.is_some() && selected_language != SourceLanguage::Python { + return fail_push( + &base, + 0, + HardFailureReason::ManifestSchemaInvalid, + "--requirements can only be used when pushing Python sources".to_string(), + "invalid --requirements usage", + ); + } + + let files = classified.files_for_language(selected_language); + if files.is_empty() { + if args.language != PushLanguage::Auto { + let selected = match args.language { + PushLanguage::JavaScript => "javascript", + PushLanguage::Python => "python", + PushLanguage::Auto => "auto", + }; + return fail_push( + &base, + 0, + 
HardFailureReason::ManifestPathMissing, + format!("no eligible files matched selected language '{selected}'"), + "no matching files for selected language", + ); + } + let summary = PushSummary { + status: CommandStatus::Success, + total_files: 0, + uploaded_files: 0, + failed_files: 0, + skipped_files: 0, + ignored_entries: 0, + files: vec![], + warnings: vec![], + errors: vec![], + }; + emit_summary(&base, &summary)?; + return Ok(()); + } + + let manifest = match run_functions_runner(&args, &files, selected_language) { + Ok(manifest) => manifest, + Err(failure) => { + return fail_push_with_all_skipped( + &base, + &files, + failure.reason, + &failure.message, + "skipped because manifest generation failed", + ); + } + }; + + if let Err(failure) = validate_manifest_paths( + &manifest, + &files, + selected_language, + &classified.allowed_roots, + ) { + return fail_push_with_all_skipped( + &base, + &files, + failure.reason, + &failure.message, + "skipped because manifest validation failed", + ); + } + + let has_code_entries = manifest.files.iter().any(|file| { + file.entries + .iter() + .any(|entry| matches!(entry, ManifestEntry::Code(_))) + }); + let requirements_path = if selected_language == SourceLanguage::Python { + if let Some(requirements) = args.requirements.as_deref() { + if has_code_entries { + match validate_requirements_path(requirements, &classified.allowed_roots) { + Ok(validated) => Some(validated), + Err(err) => { + return fail_push_with_all_skipped( + &base, + &files, + HardFailureReason::ManifestPathMissing, + &err.to_string(), + "skipped because requirements validation failed", + ); + } + } + } else { + eprintln!( + "Notice: ignoring --requirements because no Python code functions were discovered." + ); + None + } + } else { + None + } + } else { + None + }; + + let preflight = match collect_project_preflight(&base, &manifest) { + Ok(preflight) => preflight, + Err(err) => { + let message = format!("failed to resolve project selectors in manifest: {err}"); + return fail_push_manifest_preflight( + &base, + &files, + &message, + "skipped because project selector preflight failed", + ); + } + }; + + let (org_decision, org_prompt_confirmed) = match resolve_org_decision( + &base, + &auth_ctx, + &available_orgs, + &preflight.selector_preview, + manifest.files.len(), + ) { + Ok(outcome) => outcome, + Err(err) => { + return fail_push( + &base, + files.len(), + HardFailureReason::ResponseInvalid, + error_chain(&err), + "failed to resolve org context", + ); + } + }; + + match org_decision { + OrgDecision::Continue => {} + OrgDecision::Switch(org_name) => { + let mut switched_base = base.clone(); + switched_base.org_name = Some(org_name); + auth_ctx = match resolve_auth_context(&switched_base) + .await + .context("failed to resolve switched org context") + { + Ok(ctx) => ctx, + Err(err) => { + return fail_push( + &base, + files.len(), + HardFailureReason::AuthFailed, + error_chain(&err), + "failed to resolve switched org context", + ); + } + }; + } + OrgDecision::Cancel => { + return cancel_push(&base, &files); + } + } + + let mut project_name_cache = match resolve_and_ensure_named_projects( + &auth_ctx, + &preflight.named_projects, + args.create_missing_projects, + ) + .await + { + Ok(cache) => cache, + Err(err) => { + let message = format!("failed to resolve target projects for push: {err}"); + return fail_push_manifest_preflight( + &base, + &files, + &message, + "skipped because project target resolution failed", + ); + } + }; + + if let Err(err) = 
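+    // preflight explicit project ids before any bundle is built or uploaded; a
+    // bad id fails the whole push with every file reported as skipped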
validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await { + let message = format!("failed to validate project ids for push: {err}"); + return fail_push_manifest_preflight( + &base, + &files, + &message, + "skipped because project id validation failed", + ); + } + + let default_project_id = match resolve_default_project_id(&preflight, &project_name_cache) { + Ok(id) => id, + Err(err) => { + let message = format!("failed to resolve default project for push: {err}"); + return fail_push_manifest_preflight( + &base, + &files, + &message, + "skipped because default project resolution failed", + ); + } + }; + + let resolved_targets = match resolve_manifest_targets( + &auth_ctx, + default_project_id.as_deref(), + &manifest, + &mut project_name_cache, + ) + .await + { + Ok(targets) => targets, + Err(err) => { + let message = format!("failed to resolve target projects for push: {err}"); + return fail_push_manifest_preflight( + &base, + &files, + &message, + "skipped because project target resolution failed", + ); + } + }; + + if let Err(err) = validate_duplicate_slugs(&resolved_targets.entries) { + return fail_push_manifest_preflight( + &base, + &files, + &err.to_string(), + "skipped because duplicate slug validation failed", + ); + } + + let target_project_ids = resolved_targets.unique_project_ids.clone(); + + if !org_prompt_confirmed + && !confirm_push_targets(&auth_ctx, &target_project_ids, manifest.files.len())? + { + return cancel_push(&base, &files); + } + + let mut summary = PushSummary { + status: CommandStatus::Success, + total_files: manifest.files.len(), + uploaded_files: 0, + failed_files: 0, + skipped_files: 0, + ignored_entries: 0, + files: Vec::with_capacity(manifest.files.len()), + warnings: vec![], + errors: vec![], + }; + + if resolved_targets.per_file.len() != manifest.files.len() { + return fail_push_manifest_preflight( + &base, + &files, + "internal error: resolved target count did not match manifest file count", + "skipped because internal target resolution failed", + ); + } + + for (index, (file, resolved_file)) in manifest + .files + .iter() + .zip(resolved_targets.per_file.iter()) + .enumerate() + { + if resolved_file.source_file != file.source_file { + return fail_push_manifest_preflight( + &base, + &files, + "internal error: resolved target source mismatch", + "skipped because internal target resolution failed", + ); + } + let source_path = PathBuf::from(&file.source_file); + let file_result = push_file( + &auth_ctx, + default_project_id.as_deref(), + &manifest.runtime_context, + &source_path, + file, + &resolved_file.entry_project_ids, + &args, + selected_language, + requirements_path.as_deref(), + &classified.allowed_roots, + &mut project_name_cache, + ) + .await; + + match file_result { + Ok(file_success) => { + summary.ignored_entries += file_success.ignored_entries; + let skipped_reason = if file_success.uploaded_entries == 0 { + if file_success.ignored_entries > 0 { + Some(SoftSkipReason::IfExistsIgnored) + } else { + Some(SoftSkipReason::NoDefinitionsFound) + } + } else { + None + }; + let status = if skipped_reason.is_some() { + summary.skipped_files += 1; + FileStatus::Skipped + } else { + summary.uploaded_files += 1; + FileStatus::Success + }; + summary.files.push(PushFileReport { + source_file: file.source_file.clone(), + status, + uploaded_entries: file_success.uploaded_entries, + skipped_reason, + error_reason: None, + bundle_id: file_success.bundle_id, + message: if file_success.uploaded_entries == 0 + && file_success.ignored_entries == 0 
+ { + Some("no publishable definitions found in this file".to_string()) + } else { + None + }, + }); + } + Err(file_failure) => { + summary.failed_files += 1; + summary.status = CommandStatus::Failed; + summary.errors.push(ReportError { + reason: file_failure.reason, + message: file_failure.message.clone(), + }); + summary.files.push(PushFileReport { + source_file: file.source_file.clone(), + status: FileStatus::Failed, + uploaded_entries: 0, + skipped_reason: None, + error_reason: Some(file_failure.reason), + bundle_id: None, + message: Some(file_failure.message), + }); + + if args.terminate_on_failure { + for remaining in manifest.files.iter().skip(index + 1) { + summary.skipped_files += 1; + summary.files.push(PushFileReport { + source_file: remaining.source_file.clone(), + status: FileStatus::Skipped, + uploaded_entries: 0, + skipped_reason: Some(SoftSkipReason::TerminatedAfterFailure), + error_reason: None, + bundle_id: None, + message: Some( + "skipped because --terminate-on-failure was set".to_string(), + ), + }); + } + break; + } + } + } + } + + if summary.status != CommandStatus::Failed && summary.skipped_files > 0 { + summary.status = CommandStatus::Partial; + } + + let failure = summary.status == CommandStatus::Failed; + emit_summary(&base, &summary)?; + + if failure { + bail!("functions push failed; see summary for details"); + } + + Ok(()) +} + +struct FileSuccess { + uploaded_entries: usize, + ignored_entries: usize, + bundle_id: Option, +} + +fn default_code_location(index: usize) -> Value { + json!({ + "type": "function", + "index": index + }) +} + +fn build_code_function_data( + runtime_context: &RuntimeContext, + location: Value, + bundle_id: &str, + preview: Option<&str>, +) -> Value { + let mut data = Map::new(); + data.insert("type".to_string(), Value::String("bundle".to_string())); + data.insert( + "runtime_context".to_string(), + json!({ + "runtime": runtime_context.runtime, + "version": runtime_context.version, + }), + ); + data.insert("location".to_string(), location); + data.insert( + "bundle_id".to_string(), + Value::String(bundle_id.to_string()), + ); + if let Some(preview) = preview.map(str::trim).filter(|preview| !preview.is_empty()) { + data.insert("preview".to_string(), Value::String(preview.to_string())); + } + + json!({ + "type": "code", + "data": Value::Object(data), + }) +} + +async fn push_file( + auth_ctx: &super::AuthContext, + default_project_id: Option<&str>, + runtime_context: &RuntimeContext, + source_path: &Path, + manifest_file: &ManifestFile, + entry_project_ids: &[String], + args: &PushArgs, + selected_language: SourceLanguage, + requirements_path: Option<&Path>, + allowed_roots: &[PathBuf], + project_name_cache: &mut BTreeMap, +) -> std::result::Result { + let mut code_entries = Vec::new(); + let mut events = Vec::new(); + + for (entry_index, entry) in manifest_file.entries.iter().enumerate() { + let project_id = + entry_project_ids + .get(entry_index) + .cloned() + .ok_or_else(|| FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: format!( + "internal error: missing resolved project id for '{}' entry {}", + manifest_file.source_file, entry_index + ), + })?; + match entry { + ManifestEntry::Code(code) => code_entries.push((code, project_id)), + ManifestEntry::FunctionEvent(event) => events.push((event, project_id)), + } + } + + let mut bundle_id: Option = None; + + let mut function_events: Vec = Vec::new(); + + if !code_entries.is_empty() { + let (upload_bytes, content_encoding) = match selected_language { + 
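+            // the bundle format differs by runtime: JS/TS uploads the gzipped
+            // source file itself, while Python builds a zip of sources plus
+            // pip-installed dependencies (see build_python_bundle_archive)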
SourceLanguage::JsLike => { + let bundle_bytes = std::fs::read(source_path).map_err(|err| FileFailure { + reason: HardFailureReason::ManifestPathMissing, + message: format!("failed to read {}: {err}", source_path.display()), + })?; + let gzipped = gzip_bytes(&bundle_bytes).map_err(|err| FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: format!("failed to gzip {}: {err}", source_path.display()), + })?; + (gzipped, Some("gzip")) + } + SourceLanguage::Python => { + let bundle = validate_python_bundle(manifest_file, source_path, allowed_roots) + .map_err(|err| FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: err.to_string(), + })?; + let archive = build_python_bundle_archive( + &bundle.entry_module, + &bundle.sources, + &bundle.archive_root, + requirements_path, + args.runner.as_deref(), + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: err.to_string(), + })?; + (archive, None) + } + }; + + let slot = api::request_code_upload_slot( + &auth_ctx.client, + &auth_ctx.org_id, + &runtime_context.runtime, + &runtime_context.version, + ) + .await + .map_err(|err| FileFailure { + reason: HardFailureReason::UploadSlotFailed, + message: err.to_string(), + })?; + + api::upload_bundle(&slot.url, upload_bytes, content_encoding) + .await + .map_err(|err| FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: err.to_string(), + })?; + + bundle_id = Some(slot.bundle_id.clone()); + + for (index, (code, project_id)) in code_entries.iter().enumerate() { + let mut obj = Map::new(); + obj.insert("project_id".to_string(), Value::String(project_id.clone())); + obj.insert("name".to_string(), Value::String(code.name.clone())); + obj.insert("slug".to_string(), Value::String(code.slug.clone())); + obj.insert( + "description".to_string(), + Value::String(code.description.clone().unwrap_or_default()), + ); + obj.insert( + "function_data".to_string(), + build_code_function_data( + runtime_context, + code.location + .clone() + .unwrap_or_else(|| default_code_location(index)), + &slot.bundle_id, + code.preview.as_deref(), + ), + ); + + if let Some(function_type) = &code.function_type { + obj.insert( + "function_type".to_string(), + Value::String(function_type.clone()), + ); + } + if let Some(metadata) = &code.metadata { + obj.insert("metadata".to_string(), metadata.clone()); + } + let if_exists = code + .if_exists + .as_deref() + .map(ToOwned::to_owned) + .unwrap_or_else(|| args.if_exists.as_str().to_string()); + obj.insert("if_exists".to_string(), Value::String(if_exists)); + + function_events.push(Value::Object(obj)); + } + } + + for (event_entry, resolved_project_id) in &events { + let mut event = event_entry.event.clone(); + if !event.is_object() { + return Err(FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: "function_event entry must be a JSON object".to_string(), + }); + } + + let mut placeholders = BTreeSet::new(); + collect_project_name_placeholders_checked(&event, &mut placeholders).map_err(|err| { + FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: err.to_string(), + } + })?; + + let mut resolved_placeholders = BTreeMap::new(); + for project_name in placeholders { + let resolved = resolve_project_id( + &auth_ctx.client, + default_project_id, + None, + Some(&project_name), + project_name_cache, + ) + .await + .map_err(|err| FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: err.to_string(), + })?; + 
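+            // cache the id under its "name:<project>" placeholder string so
+            // the event JSON can be rewritten in place below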
resolved_placeholders.insert(project_name, resolved); + } + + replace_project_name_placeholders(&mut event, &resolved_placeholders); + + let fallback_project_id = resolved_project_id.clone(); + + if let Some(object) = event.as_object_mut() { + let needs_project_id = object + .get("project_id") + .and_then(Value::as_str) + .map(|value| value.trim().is_empty()) + .unwrap_or(true); + if needs_project_id { + object.insert("project_id".to_string(), Value::String(fallback_project_id)); + } + if object.get("if_exists").is_none() { + object.insert( + "if_exists".to_string(), + Value::String(args.if_exists.as_str().to_string()), + ); + } + } + + function_events.push(event); + } + + if function_events.is_empty() { + return Ok(FileSuccess { + uploaded_entries: 0, + ignored_entries: 0, + bundle_id, + }); + } + + let insert_result = api::insert_functions(&auth_ctx.client, &function_events) + .await + .map_err(|err| FileFailure { + reason: HardFailureReason::InsertFunctionsFailed, + message: { + let details = format!("{err:#}"); + if let Some(id) = &bundle_id { + format!( + "failed to save function definitions for {} (bundle_id={}): {}. Retry by re-running `bt functions push --file {}`", + source_path.display(), + id, + details, + source_path.display() + ) + } else { + format!( + "failed to save function definitions for {}: {}", + source_path.display(), + details + ) + } + }, + })?; + + let (uploaded_entries, ignored_entries) = + calculate_upload_counts(function_events.len(), insert_result.ignored_entries); + + Ok(FileSuccess { + uploaded_entries, + ignored_entries, + bundle_id, + }) +} + +fn calculate_upload_counts(total_entries: usize, ignored_entries: Option) -> (usize, usize) { + let ignored_entries = ignored_entries.unwrap_or(0); + let uploaded_entries = total_entries.saturating_sub(ignored_entries); + (uploaded_entries, ignored_entries) +} + +fn run_functions_runner( + args: &PushArgs, + files: &[PathBuf], + language: SourceLanguage, +) -> std::result::Result { + let mut command = match language { + SourceLanguage::JsLike => { + let _common = js_runner::materialize_runner_script_in_cwd( + "functions-runners", + RUNNER_COMMON_FILE, + RUNNER_COMMON_SOURCE, + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to materialize shared runner helper: {err}"), + })?; + let runner_script = js_runner::materialize_runner_script_in_cwd( + "functions-runners", + FUNCTIONS_JS_RUNNER_FILE, + FUNCTIONS_JS_RUNNER_SOURCE, + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to materialize functions runner: {err}"), + })?; + js_runner::build_js_runner_command(args.runner.as_deref(), &runner_script, files) + } + SourceLanguage::Python => { + let _common = js_runner::materialize_runner_script_in_cwd( + "functions-runners", + PYTHON_RUNNER_COMMON_FILE, + PYTHON_RUNNER_COMMON_SOURCE, + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to materialize shared Python runner helper: {err}"), + })?; + let runner_script = js_runner::materialize_runner_script_in_cwd( + "functions-runners", + FUNCTIONS_PY_RUNNER_FILE, + FUNCTIONS_PY_RUNNER_SOURCE, + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to materialize Python functions runner: {err}"), + })?; + let Some(python) = + python_runner::resolve_python_interpreter(args.runner.as_deref(), &[]) + else { + return Err(FileFailure { + reason: 
HardFailureReason::RunnerSpawnFailed, + message: "No Python interpreter found. Install python or pass --runner." + .to_string(), + }); + }; + let mut command = Command::new(python); + command.arg(runner_script); + for file in files { + command.arg(file); + } + command + } + }; + + let output = command.output().map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to spawn functions runner: {err}"), + })?; + + parse_runner_manifest_output(output) +} + +fn parse_runner_manifest_output( + output: Output, +) -> std::result::Result { + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(FileFailure { + reason: HardFailureReason::RunnerExitNonzero, + message: format!( + "runner exited with status {}: {}", + output.status, + stderr.trim() + ), + }); + } + + let stdout = String::from_utf8(output.stdout).map_err(|err| FileFailure { + reason: HardFailureReason::ManifestInvalidJson, + message: format!("runner output was not valid UTF-8: {err}"), + })?; + serde_json::from_str(&stdout).map_err(|err| FileFailure { + reason: HardFailureReason::ManifestInvalidJson, + message: format!("failed to parse functions runner manifest JSON: {err}"), + }) +} + +fn classify_source_file(path: &Path) -> Option { + path.extension() + .and_then(|ext| ext.to_str()) + .and_then(|ext| classify_runtime_extension(ext, JsExtensionProfile::FunctionsPush)) +} + +fn collect_classified_files(inputs: &[PathBuf]) -> Result { + let mut js_like = BTreeSet::new(); + let mut python = BTreeSet::new(); + let mut allowed_roots = BTreeSet::new(); + let mut had_directory_inputs = false; + let mut explicit_file_inputs = 0usize; + let mut explicit_supported_files = 0usize; + let mut explicit_js_like = 0usize; + let mut explicit_python = 0usize; + + for input in inputs { + let path = if input.is_absolute() { + input.clone() + } else { + std::env::current_dir() + .context("failed to resolve current directory")? + .join(input) + }; + + if !path.exists() { + bail!("path does not exist: {}", input.display()); + } + + if path.is_file() { + explicit_file_inputs += 1; + let canonical = path + .canonicalize() + .with_context(|| format!("failed to canonicalize file {}", path.display()))?; + let parent = canonical + .parent() + .map(Path::to_path_buf) + .ok_or_else(|| anyhow!("failed to find parent dir for {}", canonical.display()))?; + allowed_roots.insert(parent); + match classify_source_file(&canonical) { + Some(SourceLanguage::JsLike) => { + explicit_supported_files += 1; + explicit_js_like += 1; + js_like.insert(canonical); + } + Some(SourceLanguage::Python) => { + explicit_supported_files += 1; + explicit_python += 1; + python.insert(canonical); + } + None => {} + } + continue; + } + + had_directory_inputs = true; + let canonical_dir = path + .canonicalize() + .with_context(|| format!("failed to canonicalize directory {}", path.display()))?; + allowed_roots.insert(canonical_dir.clone()); + collect_from_dir(&canonical_dir, &mut js_like, &mut python)?; + } + + Ok(ClassifiedFiles { + js_like: js_like.into_iter().collect(), + python: python.into_iter().collect(), + had_directory_inputs, + explicit_file_inputs, + explicit_supported_files, + explicit_js_like, + explicit_python, + allowed_roots: allowed_roots.into_iter().collect(), + }) +} + +fn collect_from_dir( + dir: &Path, + js_like: &mut BTreeSet, + python: &mut BTreeSet, +) -> Result<()> { + for entry in std::fs::read_dir(dir) + .with_context(|| format!("failed to read directory {}", dir.display()))? 
+    {
+        let entry = entry.with_context(|| format!("failed to read entry in {}", dir.display()))?;
+        let path = entry.path();
+        if path.is_dir() {
+            collect_from_dir(&path, js_like, python)?;
+        } else if path.is_file() {
+            let canonical = path
+                .canonicalize()
+                .with_context(|| format!("failed to canonicalize file {}", path.display()))?;
+            match classify_source_file(&canonical) {
+                Some(SourceLanguage::JsLike) => {
+                    js_like.insert(canonical);
+                }
+                Some(SourceLanguage::Python) => {
+                    python.insert(canonical);
+                }
+                None => {}
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn select_push_language(args: &PushArgs, files: &ClassifiedFiles) -> Result<SourceLanguage> {
+    if files.explicit_js_like > 0 && files.explicit_python > 0 {
+        bail!(
+            "mixed source languages are not supported in one push invocation; run separate commands for Python and JS/TS files"
+        );
+    }
+
+    match args.language {
+        PushLanguage::Auto => {
+            if !files.js_like.is_empty() && !files.python.is_empty() {
+                Ok(SourceLanguage::JsLike)
+            } else if !files.python.is_empty() {
+                Ok(SourceLanguage::Python)
+            } else {
+                Ok(SourceLanguage::JsLike)
+            }
+        }
+        PushLanguage::JavaScript => Ok(SourceLanguage::JsLike),
+        PushLanguage::Python => Ok(SourceLanguage::Python),
+    }
+}
+
+fn emit_language_selection_notice(
+    args: &PushArgs,
+    files: &ClassifiedFiles,
+    selected_language: SourceLanguage,
+) {
+    let has_mixed = !files.js_like.is_empty() && !files.python.is_empty();
+    if !has_mixed {
+        return;
+    }
+
+    let (selected_count, skipped_count, skipped_label) = match selected_language {
+        SourceLanguage::JsLike => (files.js_like.len(), files.python.len(), "python"),
+        SourceLanguage::Python => (files.python.len(), files.js_like.len(), "js/ts"),
+    };
+
+    if args.language == PushLanguage::Auto
+        && selected_language == SourceLanguage::JsLike
+        && files.had_directory_inputs
+    {
+        eprintln!(
+            "Notice: discovered mixed runtimes during directory scan; defaulting to JS/TS for compatibility and skipping {skipped_count} Python files. Run a separate `bt functions push --language python` invocation." 
+ ); + return; + } + + if skipped_count > 0 { + eprintln!( + "Notice: selected {} runtime; processing {selected_count} files and skipping {skipped_count} {skipped_label} files.", + language_label(selected_language) + ); + } +} + +fn language_label(language: SourceLanguage) -> &'static str { + match language { + SourceLanguage::JsLike => "javascript", + SourceLanguage::Python => "python", + } +} + +fn validate_manifest_paths( + manifest: &RunnerManifest, + files: &[PathBuf], + language: SourceLanguage, + allowed_roots: &[PathBuf], +) -> std::result::Result<(), FileFailure> { + let expected: BTreeSet = files.iter().cloned().collect(); + let mut seen = BTreeSet::new(); + + for file in &manifest.files { + let path = PathBuf::from(&file.source_file) + .canonicalize() + .map_err(|err| FileFailure { + reason: HardFailureReason::ManifestPathMissing, + message: format!("manifest source file missing: {} ({err})", file.source_file), + })?; + if !expected.contains(&path) { + return Err(FileFailure { + reason: HardFailureReason::ManifestPathMissing, + message: format!("manifest referenced unexpected file: {}", path.display()), + }); + } + let has_code_entries = file + .entries + .iter() + .any(|entry| matches!(entry, ManifestEntry::Code(_))); + if language != SourceLanguage::Python && file.python_bundle.is_some() { + return Err(FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: format!( + "manifest file '{}' contained python_bundle metadata for non-Python runtime", + file.source_file + ), + }); + } + if language == SourceLanguage::Python && !has_code_entries && file.python_bundle.is_some() { + return Err(FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: format!( + "manifest file '{}' contained python_bundle metadata without code entries", + file.source_file + ), + }); + } + if language == SourceLanguage::Python && has_code_entries { + validate_python_bundle(file, &path, allowed_roots).map_err(|err| FileFailure { + reason: HardFailureReason::ManifestSchemaInvalid, + message: err.to_string(), + })?; + } + seen.insert(path); + } + + if let Some(missing) = expected.difference(&seen).next() { + return Err(FileFailure { + reason: HardFailureReason::ManifestPathMissing, + message: format!("manifest missing expected file: {}", missing.display()), + }); + } + + Ok(()) +} + +#[derive(Debug)] +struct ValidatedPythonBundle { + entry_module: String, + sources: Vec, + archive_root: PathBuf, +} + +fn validate_python_bundle( + manifest_file: &ManifestFile, + source_path: &Path, + allowed_roots: &[PathBuf], +) -> Result { + let python_bundle = manifest_file.python_bundle.as_ref().ok_or_else(|| { + anyhow!( + "manifest file '{}' includes Python code entries but is missing python_bundle metadata", + manifest_file.source_file + ) + })?; + let entry_module = python_bundle.entry_module.trim(); + if entry_module.is_empty() { + bail!( + "manifest file '{}' has empty python_bundle.entry_module", + manifest_file.source_file + ); + } + if python_bundle.sources.is_empty() { + bail!( + "manifest file '{}' has empty python_bundle.sources", + manifest_file.source_file + ); + } + + let mut sources = BTreeSet::new(); + for raw_source in &python_bundle.sources { + let canonical = PathBuf::from(raw_source).canonicalize().with_context(|| { + format!( + "manifest file '{}' referenced missing python source {}", + manifest_file.source_file, raw_source + ) + })?; + if !canonical.is_file() { + bail!( + "manifest file '{}' referenced non-file python source {}", + manifest_file.source_file, + 
canonical.display() + ); + } + if !is_within_allowed_roots(&canonical, allowed_roots) { + bail!( + "manifest file '{}' referenced python source outside allowed roots: {}", + manifest_file.source_file, + canonical.display() + ); + } + sources.insert(canonical); + } + + let source_list: Vec = sources.into_iter().collect(); + if !entry_module_matches_sources(entry_module, &source_list, allowed_roots) { + bail!( + "python_bundle.entry_module '{}' does not match any bundled source module for '{}'", + entry_module, + source_path.display() + ); + } + + let archive_root = infer_python_archive_root(entry_module, source_path)?; + for source in &source_list { + if !source.starts_with(&archive_root) { + bail!( + "python source '{}' is outside inferred archive root '{}'", + source.display(), + archive_root.display() + ); + } + } + + Ok(ValidatedPythonBundle { + entry_module: entry_module.to_string(), + sources: source_list, + archive_root, + }) +} + +fn infer_python_archive_root(entry_module: &str, source_path: &Path) -> Result { + let module_parts = entry_module + .split('.') + .filter(|part| !part.trim().is_empty()) + .collect::>(); + if module_parts.is_empty() { + bail!("python_bundle.entry_module cannot be empty"); + } + + let parent = source_path + .parent() + .ok_or_else(|| anyhow!("source file has no parent: {}", source_path.display()))?; + let file_name = source_path + .file_name() + .and_then(|name| name.to_str()) + .ok_or_else(|| { + anyhow!( + "source file has invalid utf-8 name: {}", + source_path.display() + ) + })?; + + let module_depth = if file_name == "__init__.py" { + module_parts.len() + } else { + module_parts.len().saturating_sub(1) + }; + + let mut root = parent.to_path_buf(); + for _ in 0..module_depth { + root = root.parent().map(Path::to_path_buf).ok_or_else(|| { + anyhow!( + "failed to infer archive root for module '{}' from source '{}'", + entry_module, + source_path.display() + ) + })?; + } + + Ok(root) +} + +fn is_within_allowed_roots(path: &Path, allowed_roots: &[PathBuf]) -> bool { + allowed_roots.iter().any(|root| path.starts_with(root)) +} + +fn entry_module_matches_sources( + entry_module: &str, + sources: &[PathBuf], + allowed_roots: &[PathBuf], +) -> bool { + let entry_tail = entry_module + .rsplit('.') + .next() + .unwrap_or(entry_module) + .trim(); + if entry_tail.is_empty() { + return false; + } + + for source in sources { + if source + .file_stem() + .and_then(|stem| stem.to_str()) + .is_some_and(|stem| stem == entry_tail) + { + return true; + } + + for root in allowed_roots { + if let Some(candidate) = module_name_for_source(source, root) { + if candidate == entry_module { + return true; + } + } + } + } + + false +} + +fn module_name_for_source(source: &Path, root: &Path) -> Option { + let rel = source.strip_prefix(root).ok()?; + if rel.extension().and_then(|ext| ext.to_str()) != Some("py") { + return None; + } + + let mut parts = Vec::new(); + let components: Vec<_> = rel.iter().collect(); + if components.is_empty() { + return None; + } + for (index, component) in components.iter().enumerate() { + let component = component.to_str()?; + if component.is_empty() { + return None; + } + if index + 1 == components.len() { + let stem = component.strip_suffix(".py").unwrap_or(component); + if stem != "__init__" { + parts.push(stem.to_string()); + } + } else { + parts.push(component.to_string()); + } + } + + if parts.is_empty() { + None + } else { + Some(parts.join(".")) + } +} + +struct TempBuildDir { + path: PathBuf, +} + +impl TempBuildDir { + fn create(prefix: &str) 
-> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .context("failed to read system clock")? + .as_nanos(); + let path = std::env::temp_dir().join(format!("{prefix}-{}-{now}", std::process::id())); + std::fs::create_dir_all(&path) + .with_context(|| format!("failed to create temp directory {}", path.display()))?; + Ok(Self { path }) + } +} + +impl Drop for TempBuildDir { + fn drop(&mut self) { + let _ = std::fs::remove_dir_all(&self.path); + } +} + +fn build_python_bundle_archive( + entry_module: &str, + sources: &[PathBuf], + archive_root: &Path, + requirements_path: Option<&Path>, + runner: Option<&str>, +) -> Result> { + let build_dir = TempBuildDir::create("bt-functions-python-bundle")?; + let pkg_dir = build_dir.path.join("pkg"); + std::fs::create_dir_all(&pkg_dir) + .with_context(|| format!("failed to create {}", pkg_dir.display()))?; + + install_python_dependencies(&pkg_dir, requirements_path)?; + + let stage_dir = build_dir.path.join("stage"); + std::fs::create_dir_all(&stage_dir) + .with_context(|| format!("failed to create {}", stage_dir.display()))?; + + copy_directory_files_into_stage(&pkg_dir, &stage_dir)?; + for source in sources { + let archive_path = archive_source_path(source, archive_root)?; + copy_file_into_stage(source, &archive_path, &stage_dir)?; + } + std::fs::write( + stage_dir.join("register.py"), + format!("import {entry_module} as _\n"), + ) + .context("failed to write register.py")?; + + let zip_path = build_dir.path.join("pkg.zip"); + create_zip_with_python(runner, &stage_dir, &zip_path)?; + std::fs::read(&zip_path) + .with_context(|| format!("failed to read generated archive {}", zip_path.display())) +} + +fn archive_source_path(source: &Path, archive_root: &Path) -> Result { + let rel = source.strip_prefix(archive_root).with_context(|| { + format!( + "source '{}' is not under archive root '{}'", + source.display(), + archive_root.display() + ) + })?; + if rel.as_os_str().is_empty() { + bail!( + "refusing to archive source with empty path: {}", + source.display() + ); + } + Ok(rel.to_path_buf()) +} + +fn copy_directory_files_into_stage(source_root: &Path, stage_root: &Path) -> Result<()> { + let files = collect_regular_files_recursive(source_root)?; + for file in files { + let rel = file + .strip_prefix(source_root) + .with_context(|| format!("failed to strip prefix for {}", file.display()))?; + copy_file_into_stage(&file, rel, stage_root)?; + } + Ok(()) +} + +fn copy_file_into_stage(source: &Path, rel_path: &Path, stage_root: &Path) -> Result<()> { + let archive_rel = normalized_archive_relative_path(rel_path)?; + let dest = stage_root.join(archive_rel); + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("failed to create {}", parent.display()))?; + } + std::fs::copy(source, &dest) + .with_context(|| format!("failed to copy {} -> {}", source.display(), dest.display()))?; + Ok(()) +} + +fn normalized_archive_relative_path(path: &Path) -> Result { + let mut out = PathBuf::new(); + for component in path.components() { + match component { + std::path::Component::Normal(segment) => out.push(segment), + std::path::Component::CurDir => {} + _ => { + bail!("invalid archive path component in '{}'", path.display()); + } + } + } + if out.as_os_str().is_empty() { + bail!( + "archive path resolved to empty name for '{}'", + path.display() + ); + } + Ok(out) +} + +fn create_zip_with_python(runner: Option<&str>, stage_root: &Path, zip_path: &Path) -> Result<()> { + const ZIP_SCRIPT: &str = r#"import os 
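+# Invoked via `python -c` with argv: <stage_root> <zip_path>. The sorted walk
+# keeps archive entry order deterministic across runs.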
+import sys
+import zipfile
+
+stage_root = sys.argv[1]
+zip_path = sys.argv[2]
+
+with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
+    for root, dirs, files in os.walk(stage_root):
+        dirs.sort()
+        files.sort()
+        for filename in files:
+            source = os.path.join(root, filename)
+            rel = os.path.relpath(source, stage_root)
+            zf.write(source, rel)
+"#;
+
+    let Some(python) = python_runner::resolve_python_interpreter(runner, &[]) else {
+        bail!("No Python interpreter found. Install python or pass --runner.")
+    };
+    let output = Command::new(python)
+        .arg("-c")
+        .arg(ZIP_SCRIPT)
+        .arg(stage_root)
+        .arg(zip_path)
+        .output()
+        .context("failed to spawn Python archive builder")?;
+    if output.status.success() {
+        return Ok(());
+    }
+
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let excerpt = stderr
+        .lines()
+        .take(20)
+        .collect::<Vec<_>>()
+        .join("\n")
+        .trim()
+        .to_string();
+    if excerpt.is_empty() {
+        bail!(
+            "Python archive builder failed with status {}",
+            output.status
+        );
+    }
+    bail!(
+        "Python archive builder failed with status {}: {}",
+        output.status,
+        excerpt
+    );
+}
+
+fn install_python_dependencies(pkg_dir: &Path, requirements_path: Option<&Path>) -> Result<()> {
+    let uv = python_runner::find_binary_in_path(&["uv"]).ok_or_else(|| {
+        anyhow!("`uv` is required to build Python code bundles; please install uv")
+    })?;
+
+    let mut baseline_args = vec![
+        OsString::from("pip"),
+        OsString::from("install"),
+        OsString::from("--target"),
+        pkg_dir.as_os_str().to_os_string(),
+    ];
+    baseline_args.extend(PYTHON_BASELINE_DEPS.iter().map(OsString::from));
+    run_uv_command(
+        &uv,
+        &baseline_args,
+        "installing baseline Python bundle dependencies",
+    )?;
+
+    if let Some(requirements) = requirements_path {
+        let args = vec![
+            OsString::from("pip"),
+            OsString::from("install"),
+            OsString::from("--target"),
+            pkg_dir.as_os_str().to_os_string(),
+            OsString::from("-r"),
+            requirements.as_os_str().to_os_string(),
+        ];
+        run_uv_command(&uv, &args, "installing requirements file dependencies")?;
+    }
+
+    Ok(())
+}
+
+fn run_uv_command(uv: &Path, args: &[OsString], stage: &str) -> Result<()> {
+    let args_debug = args
+        .iter()
+        .map(|arg| arg.to_string_lossy().to_string())
+        .collect::<Vec<_>>()
+        .join(" ");
+    let output = Command::new(uv)
+        .args(args)
+        .output()
+        .with_context(|| format!("failed to run `{} {args_debug}`", uv.display()))?;
+    if output.status.success() {
+        return Ok(());
+    }
+
+    let stderr = String::from_utf8_lossy(&output.stderr);
+    let excerpt = stderr
+        .lines()
+        .take(20)
+        .collect::<Vec<_>>()
+        .join("\n")
+        .trim()
+        .to_string();
+    let message = if excerpt.is_empty() {
+        format!("{stage} failed with status {}", output.status)
+    } else {
+        format!("{stage} failed with status {}: {excerpt}", output.status)
+    };
+    bail!(message);
+}
+
+fn collect_regular_files_recursive(root: &Path) -> Result<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    collect_regular_files_recursive_impl(root, &mut files)?;
+    files.sort();
+    Ok(files)
+}
+
+fn collect_regular_files_recursive_impl(root: &Path, out: &mut Vec<PathBuf>) -> Result<()> {
+    for entry in
+        std::fs::read_dir(root).with_context(|| format!("failed to read {}", root.display()))?
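+    // Depth-first walk: subdirectories are recursed into, and anything that
+    // reports as a regular file is collected.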
+ { + let entry = entry.with_context(|| format!("failed to read entry in {}", root.display()))?; + let path = entry.path(); + if path.is_dir() { + collect_regular_files_recursive_impl(&path, out)?; + } else if path.is_file() { + out.push(path); + } + } + Ok(()) +} + +fn validate_requirements_path(path: &Path, allowed_roots: &[PathBuf]) -> Result { + let canonical = path + .canonicalize() + .with_context(|| format!("requirements file not found: {}", path.display()))?; + if !canonical.is_file() { + bail!("requirements path is not a file: {}", canonical.display()); + } + let mut visited = BTreeSet::new(); + validate_requirements_local_refs(&canonical, allowed_roots, &mut visited)?; + Ok(canonical) +} + +fn validate_requirements_local_refs( + path: &Path, + allowed_roots: &[PathBuf], + visited: &mut BTreeSet, +) -> Result<()> { + if !visited.insert(path.to_path_buf()) { + return Ok(()); + } + + let parent = path + .parent() + .ok_or_else(|| anyhow!("requirements path has no parent: {}", path.display()))?; + let content = std::fs::read_to_string(path) + .with_context(|| format!("failed to read requirements file {}", path.display()))?; + + for (line_index, raw_line) in content.lines().enumerate() { + let line = strip_requirement_comment(raw_line).trim(); + if line.is_empty() { + continue; + } + + if let Some(reference) = parse_requirement_include(line) { + let resolved = resolve_requirement_path(reference, parent)?; + ensure_path_within_allowed_roots(&resolved, allowed_roots, path, line_index + 1)?; + validate_requirements_local_refs(&resolved, allowed_roots, visited)?; + continue; + } + + if let Some(reference) = parse_editable_local_path(line) { + let resolved = resolve_requirement_path(reference, parent)?; + ensure_path_within_allowed_roots(&resolved, allowed_roots, path, line_index + 1)?; + continue; + } + + if let Some(reference) = parse_local_dependency_path(line) { + let resolved = resolve_requirement_path(reference, parent)?; + ensure_path_within_allowed_roots(&resolved, allowed_roots, path, line_index + 1)?; + } + } + + Ok(()) +} + +fn strip_requirement_comment(line: &str) -> &str { + line.split_once('#').map_or(line, |(head, _)| head) +} + +fn parse_requirement_include(line: &str) -> Option<&str> { + let mut parts = line.split_whitespace(); + let first = parts.next()?; + match first { + "-r" | "--requirement" | "-c" | "--constraint" => parts.next(), + _ => first + .strip_prefix("-r") + .or_else(|| first.strip_prefix("-c")) + .or_else(|| first.strip_prefix("--requirement=")) + .or_else(|| first.strip_prefix("--constraint=")) + .filter(|value| !value.is_empty()), + } +} + +fn parse_editable_local_path(line: &str) -> Option<&str> { + let mut parts = line.split_whitespace(); + let first = parts.next()?; + let value = match first { + "-e" | "--editable" => parts.next(), + _ => first + .strip_prefix("-e") + .or_else(|| first.strip_prefix("--editable=")) + .filter(|value| !value.is_empty()), + }?; + if is_local_path_spec(value) { + Some(value) + } else { + None + } +} + +fn parse_local_dependency_path(line: &str) -> Option<&str> { + let spec = line.split(';').next()?.trim(); + if is_local_path_spec(spec) { + Some(spec) + } else { + None + } +} + +fn is_local_path_spec(spec: &str) -> bool { + if spec.is_empty() { + return false; + } + if spec.starts_with("file:") { + return true; + } + if spec.contains("://") { + return false; + } + spec.starts_with("./") + || spec.starts_with("../") + || spec.starts_with('/') + || spec.starts_with("~/") + || spec.contains('/') + || spec.contains('\\') + || 
spec.ends_with(".whl") + || spec.ends_with(".tar.gz") + || spec.ends_with(".zip") +} + +fn resolve_requirement_path(reference: &str, parent: &Path) -> Result { + let normalized = reference.trim(); + if normalized.is_empty() { + bail!("empty requirements reference"); + } + + let candidate = if let Some(file) = normalized.strip_prefix("file://") { + PathBuf::from(file) + } else if let Some(file) = normalized.strip_prefix("file:") { + PathBuf::from(file) + } else if let Some(home_relative) = normalized.strip_prefix("~/") { + let home = + dirs::home_dir().ok_or_else(|| anyhow!("unable to resolve HOME for {}", normalized))?; + home.join(home_relative) + } else { + PathBuf::from(normalized) + }; + + let absolute = if candidate.is_absolute() { + candidate + } else { + parent.join(candidate) + }; + absolute + .canonicalize() + .with_context(|| format!("failed to resolve requirements reference {}", normalized)) +} + +fn ensure_path_within_allowed_roots( + path: &Path, + allowed_roots: &[PathBuf], + requirements_path: &Path, + line_number: usize, +) -> Result<()> { + if is_within_allowed_roots(path, allowed_roots) { + return Ok(()); + } + bail!( + "requirements reference escapes allowed roots at {}:{} -> {}", + requirements_path.display(), + line_number, + path.display() + ); +} + +fn validate_explicit_org_selection(base: &BaseArgs, available_orgs: &[AvailableOrg]) -> Result<()> { + let Some(explicit_org) = base + .org_name + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + return Ok(()); + }; + + let exists = available_orgs + .iter() + .any(|org| org.name == explicit_org || org.name.eq_ignore_ascii_case(explicit_org)); + if exists { + return Ok(()); + } + + let available = available_orgs + .iter() + .map(|org| org.name.as_str()) + .collect::>() + .join(", "); + bail!("org '{explicit_org}' is not available for this credential. Available: {available}"); +} + +fn resolve_org_decision( + base: &BaseArgs, + auth_ctx: &super::AuthContext, + available_orgs: &[AvailableOrg], + selector_preview: &[String], + file_count: usize, +) -> Result<(OrgDecision, bool)> { + if base + .org_name + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .is_some() + { + return Ok((OrgDecision::Continue, false)); + } + + if available_orgs.len() <= 1 { + return Ok((OrgDecision::Continue, false)); + } + + if !is_interactive() { + bail!( + "multiple organizations are available for this credential; pass --org in non-interactive mode" + ); + } + + let org_label = current_org_label(auth_ctx); + let selector_label = if selector_preview.is_empty() { + "none".to_string() + } else { + selector_preview.join(", ") + }; + let prompt = format!( + "Push {file_count} file(s) with org '{org_label}'. 
Project selectors: [{selector_label}]" + ); + let options = [ + "Continue with current org".to_string(), + "Switch organization".to_string(), + "Cancel".to_string(), + ]; + let option_refs = options.iter().map(String::as_str).collect::>(); + let choice = fuzzy_select(&prompt, &option_refs, 0)?; + + match choice { + 0 => Ok((OrgDecision::Continue, true)), + 1 => { + let mut labels = Vec::with_capacity(available_orgs.len()); + let mut default_index = 0usize; + for (index, org) in available_orgs.iter().enumerate() { + let label = if org.api_url.is_some() { + format!("{} [{}]", org.name, org.id) + } else { + org.name.clone() + }; + if org.name == org_label || org.name.eq_ignore_ascii_case(&org_label) { + default_index = index; + } + labels.push(label); + } + let label_refs = labels.iter().map(String::as_str).collect::>(); + let selected_index = fuzzy_select("Select organization", &label_refs, default_index)?; + let selected = available_orgs + .get(selected_index) + .ok_or_else(|| anyhow!("invalid org selection"))?; + if selected.name == org_label || selected.name.eq_ignore_ascii_case(&org_label) { + Ok((OrgDecision::Continue, true)) + } else { + Ok((OrgDecision::Switch(selected.name.clone()), true)) + } + } + _ => Ok((OrgDecision::Cancel, true)), + } +} + +fn current_org_label(auth_ctx: &super::AuthContext) -> String { + if auth_ctx.client.org_name().trim().is_empty() { + auth_ctx.org_id.clone() + } else { + auth_ctx.client.org_name().to_string() + } +} + +fn cancel_push(base: &BaseArgs, files: &[PathBuf]) -> Result<()> { + if base.json { + let summary = PushSummary { + status: CommandStatus::Failed, + total_files: files.len(), + uploaded_files: 0, + failed_files: 0, + skipped_files: files.len(), + ignored_entries: 0, + files: files + .iter() + .map(|path| PushFileReport { + source_file: path.display().to_string(), + status: FileStatus::Skipped, + uploaded_entries: 0, + skipped_reason: Some(SoftSkipReason::TerminatedAfterFailure), + error_reason: Some(HardFailureReason::UserCancelled), + bundle_id: None, + message: Some("push cancelled by user".to_string()), + }) + .collect(), + warnings: vec![], + errors: vec![ReportError { + reason: HardFailureReason::UserCancelled, + message: "push cancelled by user".to_string(), + }], + }; + emit_summary(base, &summary)?; + } else { + eprintln!("Push cancelled. No changes were made."); + } + + bail!("push cancelled by user"); +} + +fn resolve_default_project_name(base: &BaseArgs) -> Result> { + let configured = base + .project + .clone() + .or_else(|| config::load().ok().and_then(|value| value.project)); + let Some(configured) = configured else { + return Ok(None); + }; + let trimmed = configured.trim(); + if trimmed.is_empty() { + bail!("default project name cannot be empty"); + } + Ok(Some(trimmed.to_string())) +} + +fn collect_project_preflight( + base: &BaseArgs, + manifest: &RunnerManifest, +) -> Result { + let default_project_name = resolve_default_project_name(base)?; + let mut requires_default_project = false; + let mut named_projects = BTreeSet::new(); + let mut direct_project_ids = BTreeSet::new(); + let mut selector_preview = BTreeSet::new(); + + for file in &manifest.files { + for entry in &file.entries { + let selector = match entry { + ManifestEntry::Code(code) => project_selector_for_code(code)?, + ManifestEntry::FunctionEvent(event) => { + let mut placeholders = BTreeSet::new(); + collect_project_name_placeholders_checked(&event.event, &mut placeholders)?; + named_projects.extend(placeholders); + project_selector_for_event(event)? 
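+                    // Nested `name:` selectors inside the event payload were
+                    // collected above so they can be resolved to project ids later.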
+                }
+            };
+
+            add_selector_requirement(
+                file,
+                entry_slug(entry)?,
+                &selector,
+                default_project_name.as_deref(),
+                &mut named_projects,
+                &mut direct_project_ids,
+                &mut selector_preview,
+                &mut requires_default_project,
+            )?;
+        }
+    }
+
+    Ok(ProjectPreflight {
+        default_project_name,
+        requires_default_project,
+        named_projects,
+        direct_project_ids,
+        selector_preview: selector_preview.into_iter().collect(),
+    })
+}
+
+fn entry_slug(entry: &ManifestEntry) -> Result<&str> {
+    match entry {
+        ManifestEntry::Code(code) => Ok(code.slug.as_str()),
+        ManifestEntry::FunctionEvent(event) => event
+            .event
+            .get("slug")
+            .and_then(Value::as_str)
+            .map(str::trim)
+            .filter(|value| !value.is_empty())
+            .ok_or_else(|| anyhow!("function_event missing non-empty slug")),
+    }
+}
+
+fn add_selector_requirement(
+    file: &ManifestFile,
+    slug: &str,
+    selector: &ProjectSelector,
+    default_project_name: Option<&str>,
+    named_projects: &mut BTreeSet<String>,
+    direct_project_ids: &mut BTreeSet<String>,
+    selector_preview: &mut BTreeSet<String>,
+    requires_default_project: &mut bool,
+) -> Result<()> {
+    match selector {
+        ProjectSelector::Id(project_id) => {
+            direct_project_ids.insert(project_id.clone());
+            selector_preview.insert(project_id.clone());
+        }
+        ProjectSelector::Name(project_name) => {
+            named_projects.insert(project_name.clone());
+            selector_preview.insert(format!("name:{project_name}"));
+        }
+        ProjectSelector::Fallback => {
+            let Some(default_project_name) = default_project_name else {
+                bail!(
+                    "missing project for slug '{}' in '{}'; set project in the definition or pass --project",
+                    slug,
+                    file.source_file
+                );
+            };
+            *requires_default_project = true;
+            named_projects.insert(default_project_name.to_string());
+            selector_preview.insert(format!("default:{default_project_name}"));
+        }
+    }
+    Ok(())
+}
+
+fn normalize_project_id_field(project_id: Option<&str>) -> Result<Option<String>> {
+    let Some(project_id) = project_id else {
+        return Ok(None);
+    };
+    let trimmed = project_id.trim();
+    if trimmed.is_empty() {
+        return Ok(None);
+    }
+    if let Some(name) = trimmed.strip_prefix("name:") {
+        let name = name.trim();
+        if name.is_empty() {
+            bail!("invalid project selector '{trimmed}': expected non-empty name after 'name:'");
+        }
+        return Ok(Some(format!("name:{name}")));
+    }
+    Ok(Some(trimmed.to_string()))
+}
+
+fn normalize_project_name_field(project_name: Option<&str>) -> Result<Option<String>> {
+    let Some(project_name) = project_name else {
+        return Ok(None);
+    };
+    let trimmed = project_name.trim();
+    if trimmed.is_empty() {
+        bail!("project_name cannot be empty when provided");
+    }
+    Ok(Some(trimmed.to_string()))
+}
+
+fn parse_project_selector(
+    project_id: Option<&str>,
+    project_name: Option<&str>,
+) -> Result<ProjectSelector> {
+    let normalized_id = normalize_project_id_field(project_id)?;
+    if let Some(project_id) = normalized_id {
+        if let Some(name) = project_id.strip_prefix("name:") {
+            return Ok(ProjectSelector::Name(name.to_string()));
+        }
+        return Ok(ProjectSelector::Id(project_id));
+    }
+
+    let normalized_name = normalize_project_name_field(project_name)?;
+    if let Some(project_name) = normalized_name {
+        return Ok(ProjectSelector::Name(project_name));
+    }
+
+    Ok(ProjectSelector::Fallback)
+}
+
+fn project_selector_for_code(code: &CodeEntry) -> Result<ProjectSelector> {
+    parse_project_selector(code.project_id.as_deref(), code.project_name.as_deref())
+}
+
+fn project_selector_for_event(event: &FunctionEventEntry) -> Result<ProjectSelector> {
+    let event_project_id =
+        normalize_project_id_field(event.event.get("project_id").and_then(Value::as_str))?;
+    let
entry_project_id = normalize_project_id_field(event.project_id.as_deref())?; + let entry_project_name = normalize_project_name_field(event.project_name.as_deref())?; + + if let Some(project_id) = event_project_id.or(entry_project_id) { + if let Some(name) = project_id.strip_prefix("name:") { + return Ok(ProjectSelector::Name(name.to_string())); + } + return Ok(ProjectSelector::Id(project_id)); + } + if let Some(project_name) = entry_project_name { + return Ok(ProjectSelector::Name(project_name)); + } + + Ok(ProjectSelector::Fallback) +} + +fn resolve_default_project_id( + preflight: &ProjectPreflight, + project_name_cache: &BTreeMap, +) -> Result> { + if !preflight.requires_default_project { + return Ok(None); + } + + let default_project_name = preflight + .default_project_name + .as_deref() + .ok_or_else(|| anyhow!("default project is required but not configured"))?; + let project_id = project_name_cache + .get(default_project_name) + .cloned() + .ok_or_else(|| anyhow!("default project '{default_project_name}' was not resolved"))?; + Ok(Some(project_id)) +} + +async fn resolve_and_ensure_named_projects( + auth_ctx: &super::AuthContext, + named_projects: &BTreeSet, + auto_create: bool, +) -> Result> { + let mut project_name_cache = BTreeMap::new(); + let mut missing = Vec::new(); + + for project_name in named_projects { + let project = get_project_by_name(&auth_ctx.client, project_name).await?; + if let Some(project) = project { + project_name_cache.insert(project_name.clone(), project.id); + } else { + missing.push(project_name.clone()); + } + } + + if missing.is_empty() { + return Ok(project_name_cache); + } + + if !auto_create && !is_interactive() { + let joined = missing.join(", "); + let org = current_org_label(auth_ctx); + bail!( + "project(s) not found in org '{org}': {joined}. Re-run with --create-missing-projects or create them first" + ); + } + + for project_name in missing { + let should_create = if auto_create { + true + } else { + Confirm::new() + .with_prompt(format!( + "Project '{}' does not exist in org '{}'. Create it?", + project_name, + current_org_label(auth_ctx) + )) + .default(false) + .interact()? + }; + + if !should_create { + bail!("project '{project_name}' is missing; push cancelled"); + } + + match create_project(&auth_ctx.client, &project_name).await { + Ok(project) => { + project_name_cache.insert(project_name.clone(), project.id); + } + Err(err) if is_http_conflict(&err) => { + let project = get_project_by_name(&auth_ctx.client, &project_name) + .await? 
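+                    // 409 conflict: another writer created the project concurrently,
+                    // so re-fetch it instead of failing the push.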
+ .ok_or_else(|| { + anyhow!( + "project '{}' already exists but could not be resolved after create conflict", + project_name + ) + })?; + project_name_cache.insert(project_name.clone(), project.id); + } + Err(err) => { + return Err(err).context(format!("failed to create project '{project_name}'")); + } + } + } + + Ok(project_name_cache) +} + +fn is_http_conflict(err: &anyhow::Error) -> bool { + err.downcast_ref::() + .is_some_and(|http| http.status == StatusCode::CONFLICT) +} + +async fn validate_direct_project_ids( + auth_ctx: &super::AuthContext, + direct_project_ids: &BTreeSet, +) -> Result<()> { + if direct_project_ids.is_empty() { + return Ok(()); + } + + let projects = list_projects(&auth_ctx.client).await?; + let known_project_ids = projects + .into_iter() + .map(|project| project.id) + .collect::>(); + + if let Some(inaccessible) = direct_project_ids + .iter() + .find(|project_id| !known_project_ids.contains(project_id.as_str())) + { + bail!( + "project_id '{}' is not accessible in org '{}'; verify --org and project selector", + inaccessible, + current_org_label(auth_ctx) + ); + } + + Ok(()) +} + +async fn resolve_manifest_targets( + auth_ctx: &super::AuthContext, + default_project_id: Option<&str>, + manifest: &RunnerManifest, + project_name_cache: &mut BTreeMap, +) -> Result { + let mut seen_project_ids = BTreeSet::new(); + let mut entries = Vec::new(); + let mut per_file = Vec::with_capacity(manifest.files.len()); + + for file in &manifest.files { + let mut entry_project_ids = Vec::with_capacity(file.entries.len()); + for entry in &file.entries { + let slug = entry_slug(entry)?.to_string(); + let selector = match entry { + ManifestEntry::Code(code) => project_selector_for_code(code)?, + ManifestEntry::FunctionEvent(event) => project_selector_for_event(event)?, + }; + let project_id = resolve_project_selector( + &auth_ctx.client, + default_project_id, + &selector, + project_name_cache, + ) + .await?; + seen_project_ids.insert(project_id.clone()); + entry_project_ids.push(project_id.clone()); + entries.push(ResolvedEntryTarget { + source_file: file.source_file.clone(), + slug, + project_id, + }); + } + + per_file.push(ResolvedFileTargets { + source_file: file.source_file.clone(), + entry_project_ids, + }); + } + + Ok(ResolvedManifestTargets { + entries, + per_file, + unique_project_ids: seen_project_ids.into_iter().collect(), + }) +} + +fn validate_duplicate_slugs(entries: &[ResolvedEntryTarget]) -> Result<()> { + let mut seen: BTreeMap<(String, String), String> = BTreeMap::new(); + for entry in entries { + if let Some(existing_file) = seen.get(&(entry.project_id.clone(), entry.slug.clone())) { + bail!( + "duplicate slug '{}' for project '{}' in files '{}' and '{}'", + entry.slug, + entry.project_id, + existing_file, + entry.source_file + ); + } + + seen.insert( + (entry.project_id.clone(), entry.slug.clone()), + entry.source_file.clone(), + ); + } + + Ok(()) +} + +async fn resolve_project_selector( + client: &crate::http::ApiClient, + default_project_id: Option<&str>, + selector: &ProjectSelector, + project_name_cache: &mut BTreeMap, +) -> Result { + match selector { + ProjectSelector::Id(project_id) => { + resolve_project_id( + client, + default_project_id, + Some(project_id.as_str()), + None, + project_name_cache, + ) + .await + } + ProjectSelector::Name(project_name) => { + resolve_project_id( + client, + default_project_id, + None, + Some(project_name.as_str()), + project_name_cache, + ) + .await + } + ProjectSelector::Fallback => { + resolve_project_id(client, 
default_project_id, None, None, project_name_cache).await + } + } +} + +async fn resolve_project_id( + client: &crate::http::ApiClient, + default_project_id: Option<&str>, + project_id: Option<&str>, + project_name: Option<&str>, + project_name_cache: &mut BTreeMap, +) -> Result { + let normalized_project_id = normalize_project_id_field(project_id)?; + if let Some(project_id) = normalized_project_id { + if let Some(name) = project_id.strip_prefix("name:") { + return resolve_project_name(client, name.trim(), project_name_cache).await; + } + return Ok(project_id); + } + + let normalized_project_name = normalize_project_name_field(project_name)?; + if let Some(project_name) = normalized_project_name { + return resolve_project_name(client, project_name.trim(), project_name_cache).await; + } + + default_project_id.map(ToOwned::to_owned).ok_or_else(|| { + anyhow!("project is required; set project in the definition or pass --project") + }) +} + +async fn resolve_project_name( + client: &crate::http::ApiClient, + project_name: &str, + project_name_cache: &mut BTreeMap, +) -> Result { + let project_name = project_name.trim(); + if project_name.is_empty() { + bail!("project name cannot be empty"); + } + + if let Some(cached) = project_name_cache.get(project_name) { + return Ok(cached.clone()); + } + + let project = get_project_by_name(client, project_name) + .await? + .ok_or_else(|| anyhow!("project '{project_name}' not found"))?; + + project_name_cache.insert(project_name.to_string(), project.id.clone()); + Ok(project.id) +} + +fn confirm_push_targets( + auth_ctx: &super::AuthContext, + target_project_ids: &[String], + file_count: usize, +) -> Result { + if !is_interactive() || target_project_ids.is_empty() { + return Ok(true); + } + + let org_label = if auth_ctx.client.org_name().is_empty() { + auth_ctx.org_id.clone() + } else { + auth_ctx.client.org_name().to_string() + }; + + let prompt = format!( + "Push {} file(s) to org '{}' and project(s) [{}]?", + file_count, + org_label, + target_project_ids.join(", ") + ); + let confirmed = Confirm::new() + .with_prompt(prompt) + .default(false) + .interact()?; + + Ok(confirmed) +} + +fn collect_project_name_placeholders_checked( + value: &Value, + out: &mut BTreeSet, +) -> Result<()> { + match value { + Value::Object(map) => { + for (key, value) in map { + if key == "project_id" { + if let Some(project_id) = value.as_str() { + if let Some(name) = project_id.strip_prefix("name:") { + let name = name.trim(); + if name.is_empty() { + bail!( + "invalid nested project selector 'name:' in function_event payload" + ); + } + out.insert(name.to_string()); + } + } + } + collect_project_name_placeholders_checked(value, out)?; + } + } + Value::Array(items) => { + for item in items { + collect_project_name_placeholders_checked(item, out)?; + } + } + _ => {} + } + Ok(()) +} + +fn replace_project_name_placeholders( + value: &mut Value, + project_name_to_id: &BTreeMap, +) { + match value { + Value::Object(map) => { + for (key, value) in map { + if key == "project_id" { + if let Some(project_id) = value.as_str() { + if let Some(name) = project_id.strip_prefix("name:") { + if let Some(resolved) = project_name_to_id.get(name.trim()) { + *value = Value::String(resolved.clone()); + continue; + } + } + } + } + replace_project_name_placeholders(value, project_name_to_id); + } + } + Value::Array(items) => { + for item in items { + replace_project_name_placeholders(item, project_name_to_id); + } + } + _ => {} + } +} + +fn gzip_bytes(bytes: &[u8]) -> Result> { + use 
std::io::Write; + + let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default()); + encoder + .write_all(bytes) + .context("failed to write gzip input bytes")?; + encoder.finish().context("failed to finalize gzip bytes") +} + +fn emit_summary(base: &BaseArgs, summary: &PushSummary) -> Result<()> { + if base.json { + println!("{}", serde_json::to_string(summary)?); + } else { + match summary.status { + CommandStatus::Success => { + eprintln!("Pushed {} file(s) successfully.", summary.uploaded_files); + } + CommandStatus::Partial => { + eprintln!( + "Pushed with partial success: uploaded={}, skipped={}, failed={}", + summary.uploaded_files, summary.skipped_files, summary.failed_files + ); + } + CommandStatus::Failed => { + eprintln!( + "Push failed: uploaded={}, skipped={}, failed={}", + summary.uploaded_files, summary.skipped_files, summary.failed_files + ); + } + } + for warning in &summary.warnings { + eprintln!( + "warning ({}): {}", + to_warning_code(warning), + warning.message + ); + } + for error in &summary.errors { + eprintln!("error ({}): {}", to_error_code(error), error.message); + } + } + Ok(()) +} + +fn to_warning_code(warning: &ReportWarning) -> &'static str { + match warning.reason { + super::report::WarningReason::PaginationNotSnapshotConsistent => { + "pagination_not_snapshot_consistent" + } + } +} + +fn to_error_code(error: &ReportError) -> &'static str { + match error.reason { + HardFailureReason::AuthFailed => "auth_failed", + HardFailureReason::RequestFailed => "request_failed", + HardFailureReason::ResponseInvalid => "response_invalid", + HardFailureReason::UserCancelled => "user_cancelled", + HardFailureReason::OutputDirInvalid => "output_dir_invalid", + HardFailureReason::AtomicWriteFailed => "atomic_write_failed", + HardFailureReason::UnsafeOutputPath => "unsafe_output_path", + HardFailureReason::RunnerSpawnFailed => "runner_spawn_failed", + HardFailureReason::RunnerExitNonzero => "runner_exit_nonzero", + HardFailureReason::ManifestInvalidJson => "manifest_invalid_json", + HardFailureReason::ManifestSchemaInvalid => "manifest_schema_invalid", + HardFailureReason::ManifestPathMissing => "manifest_path_missing", + HardFailureReason::UploadSlotFailed => "upload_slot_failed", + HardFailureReason::BundleUploadFailed => "bundle_upload_failed", + HardFailureReason::InsertFunctionsFailed => "insert_functions_failed", + HardFailureReason::SelectorNotFound => "selector_not_found", + HardFailureReason::PaginationUnsupported => "pagination_unsupported", + } +} + +fn fail_push( + base: &BaseArgs, + total_files: usize, + reason: HardFailureReason, + message: String, + file_message: &str, +) -> Result<()> { + if base.json { + let summary = PushSummary { + status: CommandStatus::Failed, + total_files, + uploaded_files: 0, + failed_files: 0, + skipped_files: total_files, + ignored_entries: 0, + files: vec![PushFileReport { + source_file: String::new(), + status: FileStatus::Failed, + uploaded_entries: 0, + skipped_reason: None, + error_reason: Some(reason), + bundle_id: None, + message: Some(file_message.to_string()), + }], + warnings: vec![], + errors: vec![ReportError { + reason, + message: message.clone(), + }], + }; + emit_summary(base, &summary)?; + } + + bail!(message); +} + +fn fail_push_with_all_skipped( + base: &BaseArgs, + files: &[PathBuf], + reason: HardFailureReason, + message: &str, + file_message: &str, +) -> Result<()> { + if base.json { + let summary = PushSummary { + status: CommandStatus::Failed, + total_files: files.len(), + 
uploaded_files: 0, + failed_files: 0, + skipped_files: files.len(), + ignored_entries: 0, + files: files + .iter() + .map(|path| PushFileReport { + source_file: path.display().to_string(), + status: FileStatus::Skipped, + uploaded_entries: 0, + skipped_reason: Some(SoftSkipReason::TerminatedAfterFailure), + error_reason: None, + bundle_id: None, + message: Some(file_message.to_string()), + }) + .collect(), + warnings: vec![], + errors: vec![ReportError { + reason, + message: message.to_string(), + }], + }; + emit_summary(base, &summary)?; + } + + bail!(message.to_string()); +} + +fn fail_push_manifest_preflight( + base: &BaseArgs, + files: &[PathBuf], + message: &str, + file_message: &str, +) -> Result<()> { + fail_push_with_all_skipped( + base, + files, + HardFailureReason::ManifestSchemaInvalid, + message, + file_message, + ) +} + +#[cfg(test)] +mod tests { + use crate::args::BaseArgs; + use crate::functions::IfExistsMode; + + use super::*; + + #[test] + fn supported_extension_filtering() { + assert_eq!( + classify_source_file(Path::new("a.ts")), + Some(SourceLanguage::JsLike) + ); + assert_eq!( + classify_source_file(Path::new("a.tsx")), + Some(SourceLanguage::JsLike) + ); + assert_eq!( + classify_source_file(Path::new("a.js")), + Some(SourceLanguage::JsLike) + ); + assert_eq!( + classify_source_file(Path::new("a.jsx")), + Some(SourceLanguage::JsLike) + ); + assert_eq!( + classify_source_file(Path::new("a.py")), + Some(SourceLanguage::Python) + ); + assert_eq!(classify_source_file(Path::new("a.txt")), None); + } + + #[test] + fn parse_project_selector_rejects_empty_name_prefix() { + let err = parse_project_selector(Some("name: "), None).expect_err("must fail"); + assert!(err.to_string().contains("non-empty name")); + } + + #[test] + fn fallback_selector_requires_default_project_name() { + let file = ManifestFile { + source_file: "a.ts".to_string(), + entries: vec![], + python_bundle: None, + }; + let mut named_projects = BTreeSet::new(); + let mut direct_project_ids = BTreeSet::new(); + let mut selector_preview = BTreeSet::new(); + let mut requires_default_project = false; + + let err = add_selector_requirement( + &file, + "same", + &ProjectSelector::Fallback, + None, + &mut named_projects, + &mut direct_project_ids, + &mut selector_preview, + &mut requires_default_project, + ) + .expect_err("must fail"); + assert!(err.to_string().contains("missing project")); + } + + #[test] + fn collect_project_preflight_uses_default_project_when_needed() { + let mut base = test_base_args(); + base.project = Some("demo-project".to_string()); + let manifest = RunnerManifest { + runtime_context: RuntimeContext { + runtime: "node".to_string(), + version: "20.0.0".to_string(), + }, + files: vec![ManifestFile { + source_file: "a.ts".to_string(), + entries: vec![ManifestEntry::Code(CodeEntry { + project_id: None, + project_name: None, + name: "A".to_string(), + slug: "same".to_string(), + description: None, + function_type: Some("tool".to_string()), + if_exists: None, + metadata: None, + location: None, + preview: None, + })], + python_bundle: None, + }], + }; + + let preflight = collect_project_preflight(&base, &manifest).expect("preflight"); + assert!(preflight.requires_default_project); + assert!( + preflight.named_projects.contains("demo-project"), + "default project should be included in named set" + ); + } + + #[test] + fn explicit_org_validation_rejects_unknown_org() { + let mut base = test_base_args(); + base.org_name = Some("missing-org".to_string()); + let orgs = vec![AvailableOrg { + id: 
"o1".to_string(), + name: "existing-org".to_string(), + api_url: None, + }]; + + let err = validate_explicit_org_selection(&base, &orgs).expect_err("must fail"); + assert!(err.to_string().contains("missing-org")); + } + + #[test] + fn select_push_language_auto_prefers_js_like_for_mixed_scan() { + let args = PushArgs { + files: vec![PathBuf::from(".")], + if_exists: IfExistsMode::Error, + terminate_on_failure: false, + runner: None, + language: PushLanguage::Auto, + requirements: None, + create_missing_projects: false, + }; + let classified = ClassifiedFiles { + js_like: vec![PathBuf::from("/tmp/a.ts")], + python: vec![PathBuf::from("/tmp/a.py")], + had_directory_inputs: true, + explicit_file_inputs: 0, + explicit_supported_files: 0, + explicit_js_like: 0, + explicit_python: 0, + allowed_roots: Vec::new(), + }; + + let selected = select_push_language(&args, &classified).expect("select language"); + assert_eq!(selected, SourceLanguage::JsLike); + } + + #[test] + fn select_push_language_rejects_mixed_explicit_files() { + let args = PushArgs { + files: vec![PathBuf::from("a.ts"), PathBuf::from("b.py")], + if_exists: IfExistsMode::Error, + terminate_on_failure: false, + runner: None, + language: PushLanguage::Auto, + requirements: None, + create_missing_projects: false, + }; + let classified = ClassifiedFiles { + js_like: vec![PathBuf::from("/tmp/a.ts")], + python: vec![PathBuf::from("/tmp/b.py")], + had_directory_inputs: false, + explicit_file_inputs: 2, + explicit_supported_files: 2, + explicit_js_like: 1, + explicit_python: 1, + allowed_roots: Vec::new(), + }; + + let err = select_push_language(&args, &classified).expect_err("must fail"); + assert!(err.to_string().contains("mixed source languages")); + } + + #[test] + fn placeholder_rewrite_updates_nested_project_ids() { + let mut value = serde_json::json!({ + "project_id": "name:alpha", + "nested": { + "tool": { + "project_id": "name:beta" + } + } + }); + + let mut mappings = BTreeMap::new(); + mappings.insert("alpha".to_string(), "p1".to_string()); + mappings.insert("beta".to_string(), "p2".to_string()); + + replace_project_name_placeholders(&mut value, &mappings); + + assert_eq!(value["project_id"], "p1"); + assert_eq!(value["nested"]["tool"]["project_id"], "p2"); + } + + #[test] + fn placeholder_rewrite_trims_nested_project_ids() { + let mut value = serde_json::json!({ + "project_id": "name: alpha", + "nested": { + "tool": { + "project_id": "name:\tbeta " + } + } + }); + + let mut mappings = BTreeMap::new(); + mappings.insert("alpha".to_string(), "p1".to_string()); + mappings.insert("beta".to_string(), "p2".to_string()); + + replace_project_name_placeholders(&mut value, &mappings); + + assert_eq!(value["project_id"], "p1"); + assert_eq!(value["nested"]["tool"]["project_id"], "p2"); + } + + #[test] + fn nested_placeholder_validation_rejects_empty_name() { + let value = serde_json::json!({ + "project_id": "name: " + }); + let mut placeholders = BTreeSet::new(); + let err = collect_project_name_placeholders_checked(&value, &mut placeholders) + .expect_err("must fail"); + assert!(err.to_string().contains("invalid nested project selector")); + } + + #[test] + fn upload_count_calculation_respects_ignored_entries() { + assert_eq!(calculate_upload_counts(3, Some(1)), (2, 1)); + assert_eq!(calculate_upload_counts(3, Some(10)), (0, 10)); + assert_eq!(calculate_upload_counts(3, None), (3, 0)); + } + + #[test] + fn requirements_reference_escape_is_rejected() { + let dir = tempfile::tempdir().expect("tempdir"); + let root = dir.path().join("root"); + 
std::fs::create_dir_all(&root).expect("create root"); + let req = root.join("requirements.txt"); + std::fs::write(&req, "-r ../outside.txt\n").expect("write requirements"); + let outside = dir.path().join("outside.txt"); + std::fs::write(&outside, "requests\n").expect("write outside"); + + let err = + validate_requirements_path(&req, std::slice::from_ref(&root)).expect_err("must fail"); + assert!(err.to_string().contains("escapes allowed roots")); + } + + #[test] + fn validate_manifest_paths_rejects_python_bundle_for_non_python_runtime() { + let dir = tempfile::tempdir().expect("tempdir"); + let source = dir.path().join("tool.js"); + std::fs::write(&source, "export const x = 1;\n").expect("write source file"); + let source = source.canonicalize().expect("canonicalize source"); + let root = dir.path().canonicalize().expect("canonicalize root"); + + let manifest = RunnerManifest { + runtime_context: RuntimeContext { + runtime: "node".to_string(), + version: "20.0.0".to_string(), + }, + files: vec![ManifestFile { + source_file: source.to_string_lossy().to_string(), + entries: vec![], + python_bundle: Some(PythonBundle { + entry_module: "tool".to_string(), + sources: vec![source.to_string_lossy().to_string()], + }), + }], + }; + + let err = validate_manifest_paths( + &manifest, + std::slice::from_ref(&source), + SourceLanguage::JsLike, + std::slice::from_ref(&root), + ) + .expect_err("must fail"); + assert_eq!(err.reason, HardFailureReason::ManifestSchemaInvalid); + assert!(err + .message + .contains("python_bundle metadata for non-Python")); + } + + #[test] + fn validate_manifest_paths_rejects_missing_python_bundle_for_code_entries() { + let dir = tempfile::tempdir().expect("tempdir"); + let source = dir.path().join("tool.py"); + std::fs::write(&source, "VALUE = 1\n").expect("write source file"); + let source = source.canonicalize().expect("canonicalize source"); + let root = dir.path().canonicalize().expect("canonicalize root"); + + let manifest = RunnerManifest { + runtime_context: RuntimeContext { + runtime: "python".to_string(), + version: "3.12.0".to_string(), + }, + files: vec![ManifestFile { + source_file: source.to_string_lossy().to_string(), + entries: vec![ManifestEntry::Code(CodeEntry { + project_id: None, + project_name: None, + name: "Tool".to_string(), + slug: "tool".to_string(), + description: None, + function_type: Some("tool".to_string()), + if_exists: None, + metadata: None, + location: Some(serde_json::json!({"type":"function","index":0})), + preview: None, + })], + python_bundle: None, + }], + }; + + let err = validate_manifest_paths( + &manifest, + std::slice::from_ref(&source), + SourceLanguage::Python, + std::slice::from_ref(&root), + ) + .expect_err("must fail"); + assert_eq!(err.reason, HardFailureReason::ManifestSchemaInvalid); + assert!(err.message.contains("missing python_bundle metadata")); + } + + #[test] + fn validate_manifest_paths_accepts_valid_python_bundle() { + let dir = tempfile::tempdir().expect("tempdir"); + let source = dir.path().join("tool.py"); + std::fs::write(&source, "VALUE = 1\n").expect("write source file"); + let source = source.canonicalize().expect("canonicalize source"); + let root = dir.path().canonicalize().expect("canonicalize root"); + + let manifest = RunnerManifest { + runtime_context: RuntimeContext { + runtime: "python".to_string(), + version: "3.12.0".to_string(), + }, + files: vec![ManifestFile { + source_file: source.to_string_lossy().to_string(), + entries: vec![ManifestEntry::Code(CodeEntry { + project_id: None, + project_name: 
None, + name: "Tool".to_string(), + slug: "tool".to_string(), + description: None, + function_type: Some("tool".to_string()), + if_exists: None, + metadata: None, + location: Some(serde_json::json!({"type":"function","index":0})), + preview: None, + })], + python_bundle: Some(PythonBundle { + entry_module: "tool".to_string(), + sources: vec![source.to_string_lossy().to_string()], + }), + }], + }; + + validate_manifest_paths( + &manifest, + std::slice::from_ref(&source), + SourceLanguage::Python, + std::slice::from_ref(&root), + ) + .expect("valid python bundle should pass validation"); + } + + #[test] + fn validate_manifest_paths_rejects_entry_module_mismatch() { + let dir = tempfile::tempdir().expect("tempdir"); + let source = dir.path().join("tool.py"); + std::fs::write(&source, "VALUE = 1\n").expect("write source file"); + let source = source.canonicalize().expect("canonicalize source"); + let root = dir.path().canonicalize().expect("canonicalize root"); + + let manifest = RunnerManifest { + runtime_context: RuntimeContext { + runtime: "python".to_string(), + version: "3.12.0".to_string(), + }, + files: vec![ManifestFile { + source_file: source.to_string_lossy().to_string(), + entries: vec![ManifestEntry::Code(CodeEntry { + project_id: None, + project_name: None, + name: "Tool".to_string(), + slug: "tool".to_string(), + description: None, + function_type: Some("tool".to_string()), + if_exists: None, + metadata: None, + location: Some(serde_json::json!({"type":"function","index":0})), + preview: None, + })], + python_bundle: Some(PythonBundle { + entry_module: "pkg.missing".to_string(), + sources: vec![source.to_string_lossy().to_string()], + }), + }], + }; + + let err = validate_manifest_paths( + &manifest, + std::slice::from_ref(&source), + SourceLanguage::Python, + std::slice::from_ref(&root), + ) + .expect_err("must fail"); + assert_eq!(err.reason, HardFailureReason::ManifestSchemaInvalid); + assert!(err + .message + .contains("does not match any bundled source module")); + } + + #[test] + fn code_function_data_includes_non_empty_preview() { + let runtime = RuntimeContext { + runtime: "python".to_string(), + version: "3.12".to_string(), + }; + let value = build_code_function_data( + &runtime, + serde_json::json!({"type": "function", "index": 0}), + "bundle-123", + Some("print('hello')"), + ); + + assert_eq!(value["type"], "code"); + assert_eq!(value["data"]["type"], "bundle"); + assert_eq!(value["data"]["bundle_id"], "bundle-123"); + assert_eq!(value["data"]["preview"], "print('hello')"); + } + + #[test] + fn code_function_data_omits_empty_preview() { + let runtime = RuntimeContext { + runtime: "node".to_string(), + version: "20.0.0".to_string(), + }; + let value = build_code_function_data( + &runtime, + serde_json::json!({"type": "function", "index": 1}), + "bundle-456", + Some(" "), + ); + + assert_eq!(value["type"], "code"); + assert!(value["data"].get("preview").is_none()); + } + + fn test_base_args() -> BaseArgs { + BaseArgs { + json: false, + quiet: false, + no_color: false, + profile: None, + org_name: None, + project: None, + api_key: None, + prefer_profile: false, + no_input: false, + api_url: None, + app_url: None, + env_file: None, + } + } +} From 093712966da2c09002eba9e371e77e2549ec2549 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Thu, 5 Mar 2026 16:39:56 -0800 Subject: [PATCH 04/28] Implement bt functions pull Add the pull pipeline for downloading Braintrust function definitions as local source files: - Paginated fetching with cursor-based pagination and 
snapshot consistency - Code generation for both TypeScript and Python with proper imports, typed prompt definitions, and recursive JSON value formatting - Safety checks: git dirty detection, existing file protection, and force flag override - Sanitized identifiers and filenames with Windows reserved name handling - Per-project directory organization with atomic file writes --- src/functions/pull.rs | 1792 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1792 insertions(+) create mode 100644 src/functions/pull.rs diff --git a/src/functions/pull.rs b/src/functions/pull.rs new file mode 100644 index 0000000..1ecb488 --- /dev/null +++ b/src/functions/pull.rs @@ -0,0 +1,1792 @@ +use std::cmp::Ordering; +use std::collections::{BTreeMap, BTreeSet}; +use std::fs::OpenOptions; +use std::path::{Path, PathBuf}; + +use anyhow::{anyhow, bail, Context, Result}; +use serde::Deserialize; +use serde_json::Value; + +use crate::args::BaseArgs; +use crate::functions::report::{ + CommandStatus, FileStatus, HardFailureReason, PullFileReport, PullSummary, ReportError, + ReportWarning, SoftSkipReason, WarningReason, +}; +use crate::projects::api::{list_projects, Project}; +use crate::utils::{write_text_atomic, GitRepo}; + +use super::api::{self, FunctionListQuery}; +use super::{resolve_auth_context, resolve_project_context, FunctionsLanguage, PullArgs}; + +const PAGINATION_PAGE_LIMIT: usize = 10_000; +const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock"; + +#[derive(Debug, Clone, Deserialize)] +struct PullFunctionRow { + id: String, + name: String, + slug: String, + project_id: String, + #[serde(default)] + project_name: Option, + #[serde(default)] + description: Option, + #[serde(default)] + prompt_data: Option, + #[serde(default)] + function_data: Option, + #[serde(default)] + created: Option, + #[serde(default)] + _xact_id: Option, +} + +#[derive(Debug, Clone)] +struct NormalizedPrompt { + variable_seed: String, + name: String, + slug: String, + description: Option, + prompt: Option, + messages: Option, + model: Option, + params: Option, + tools: Option, +} + +#[derive(Debug)] +struct OutputLock { + path: PathBuf, +} + +impl OutputLock { + fn acquire(output_dir: &Path) -> Result { + let path = output_dir.join(OUTPUT_LOCK_FILE); + OpenOptions::new() + .create_new(true) + .write(true) + .open(&path) + .with_context(|| { + format!( + "failed to acquire output lock {}; another pull may be running", + path.display() + ) + })?; + Ok(Self { path }) + } +} + +impl Drop for OutputLock { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { + let mut summary = PullSummary { + status: CommandStatus::Success, + projects_total: 0, + files_written: 0, + files_skipped: 0, + files_failed: 0, + functions_seen: 0, + functions_materialized: 0, + malformed_records_skipped: 0, + unsupported_records_skipped: 0, + files: vec![], + warnings: vec![], + errors: vec![], + }; + let mut projects_cache: Option> = None; + + let auth_ctx = match resolve_auth_context(&base) + .await + .context("failed to resolve auth context") + { + Ok(ctx) => ctx, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::AuthFailed, + err.to_string(), + ); + } + }; + + let mut query = FunctionListQuery::default(); + + if let Some(project_id) = &args.project_id { + query.project_id = Some(project_id.clone()); + } else if let Some(project_name) = &args.project_name { + let projects = match get_projects_cached(&auth_ctx.client, &mut 
projects_cache).await { + Ok(projects) => projects, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + if let Err(err) = ensure_unambiguous_project_name(projects, project_name) { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + query.project_name = Some(project_name.clone()); + } else { + let project = match resolve_project_context(&base, &auth_ctx) + .await + .context("failed to resolve default project context") + { + Ok(project) => project, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + query.project_id = Some(project.id); + } + + if let Some(id) = &args.id { + query.id = Some(id.clone()); + } + if let Some(slug) = &args.slug { + query.slug = Some(slug.clone()); + } + let fetched = match fetch_all_function_rows(&auth_ctx.client, &query).await { + Ok(fetched) => fetched, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::PaginationUnsupported, + err.to_string(), + ); + } + }; + summary.functions_seen = fetched.rows.len(); + summary.warnings.extend(fetched.warnings); + + let mut parsed_rows = Vec::new(); + for raw_row in fetched.rows { + match serde_json::from_value::(raw_row) { + Ok(row) => parsed_rows.push(row), + Err(err) => { + summary.malformed_records_skipped += 1; + summary.files.push(PullFileReport { + output_file: String::new(), + status: FileStatus::Skipped, + skipped_reason: Some(SoftSkipReason::MalformedRecord), + error_reason: None, + message: Some(format!("skipped malformed function row: {err}")), + }); + } + } + } + + let narrowed_rows = match apply_selector_narrowing(parsed_rows, &args) { + Ok(rows) => rows, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::SelectorNotFound, + err.to_string(), + ); + } + }; + + let winners = select_winner_rows(narrowed_rows, &mut summary); + + if (args.id.is_some() || args.slug.is_some()) && winners.is_empty() { + return fail_pull( + &base, + &mut summary, + HardFailureReason::SelectorNotFound, + "no matching function rows found for selector".to_string(), + ); + } + + let mut materializable = Vec::new(); + for row in winners { + if is_prompt_row(&row) { + materializable.push(row); + } else { + summary.unsupported_records_skipped += 1; + } + } + + if (args.id.is_some() || args.slug.is_some()) && materializable.is_empty() { + return fail_pull( + &base, + &mut summary, + HardFailureReason::SelectorNotFound, + "selector matched records but none are materializable prompts".to_string(), + ); + } + + if args.slug.is_some() && materializable.len() > 1 { + return fail_pull( + &base, + &mut summary, + HardFailureReason::SelectorNotFound, + "slug selector matched multiple prompts; pass --project-name or --project-id" + .to_string(), + ); + } + + let output_dir = if args.output_dir.is_absolute() { + args.output_dir.clone() + } else { + std::env::current_dir() + .context("failed to resolve current directory")? 
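+            // A relative output dir is resolved against the current working directory.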
+ .join(&args.output_dir) + }; + + if let Err(err) = std::fs::create_dir_all(&output_dir) + .with_context(|| format!("failed to create output directory {}", output_dir.display())) + { + return fail_pull( + &base, + &mut summary, + HardFailureReason::OutputDirInvalid, + err.to_string(), + ); + } + + let canonical_output_dir = match output_dir + .canonicalize() + .with_context(|| format!("failed to canonicalize output dir {}", output_dir.display())) + { + Ok(path) => path, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::OutputDirInvalid, + err.to_string(), + ); + } + }; + + let _lock = match OutputLock::acquire(&canonical_output_dir) { + Ok(lock) => lock, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::OutputDirInvalid, + err.to_string(), + ); + } + }; + let repo = GitRepo::discover_from(&canonical_output_dir); + + let project_names = if materializable.is_empty() { + BTreeMap::new() + } else { + let projects = match get_projects_cached(&auth_ctx.client, &mut projects_cache).await { + Ok(projects) => projects, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + match resolve_project_names(&materializable, projects) { + Ok(names) => names, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + } + }; + + let grouped_by_project = match group_rows_by_project(materializable, &project_names) { + Ok(grouped) => grouped, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + summary.projects_total = grouped_by_project.len(); + + let ext = match args.language { + FunctionsLanguage::Typescript => "ts", + FunctionsLanguage::Python => "py", + }; + let file_names = match build_output_file_names(&grouped_by_project, args.slug.as_deref(), ext) { + Ok(file_names) => file_names, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::SelectorNotFound, + err.to_string(), + ); + } + }; + + for ((project_id, project_name), rows) in grouped_by_project { + let file_name = file_names + .get(&(project_id.clone(), project_name.clone())) + .ok_or_else(|| anyhow!("missing output file mapping"))? 
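+            // The name map was built above for every grouped project, so a miss here is a bug.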
+ .clone(); + + let target = canonical_output_dir.join(&file_name); + let display_target = display_output_path(&target); + if !target.starts_with(&canonical_output_dir) { + record_pull_file_failure( + &mut summary, + target.display().to_string(), + HardFailureReason::UnsafeOutputPath, + format!("refusing to write outside output dir: {}", target.display()), + ); + continue; + } + + let skip_reason = match should_skip_target(&repo, &target, args.force) { + Ok(reason) => reason, + Err(err) => { + record_pull_file_failure( + &mut summary, + target.display().to_string(), + HardFailureReason::RequestFailed, + err.to_string(), + ); + continue; + } + }; + if let Some(reason) = skip_reason { + summary.files_skipped += 1; + summary.files.push(PullFileReport { + output_file: target.display().to_string(), + status: FileStatus::Skipped, + skipped_reason: Some(reason), + error_reason: None, + message: None, + }); + continue; + } + + let rendered = + match render_project_file(args.language, &project_name, &display_target, &rows) { + Ok(rendered) => rendered, + Err(err) => { + record_pull_file_failure( + &mut summary, + target.display().to_string(), + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + continue; + } + }; + match write_text_atomic(&target, &rendered) { + Ok(()) => { + summary.files_written += 1; + summary.functions_materialized += rows.len(); + summary.files.push(PullFileReport { + output_file: target.display().to_string(), + status: FileStatus::Success, + skipped_reason: None, + error_reason: None, + message: None, + }); + } + Err(err) => { + record_pull_file_failure( + &mut summary, + target.display().to_string(), + HardFailureReason::AtomicWriteFailed, + err.to_string(), + ); + } + } + } + + if summary.status != CommandStatus::Failed + && (summary.files_skipped > 0 + || summary.unsupported_records_skipped > 0 + || summary.malformed_records_skipped > 0 + || !summary.warnings.is_empty()) + { + summary.status = CommandStatus::Partial; + } + + let failure = summary.status == CommandStatus::Failed; + emit_summary(&base, &summary)?; + if failure { + bail!("functions pull failed; see summary for details"); + } + + Ok(()) +} + +async fn get_projects_cached<'a>( + client: &crate::http::ApiClient, + cache: &'a mut Option>, +) -> Result<&'a [Project]> { + if cache.is_none() { + *cache = Some(list_projects(client).await?); + } + Ok(cache + .as_deref() + .expect("project cache should be initialized")) +} + +fn ensure_unambiguous_project_name(projects: &[Project], project_name: &str) -> Result<()> { + let exact: Vec<_> = projects + .iter() + .filter(|project| project.name == project_name) + .collect(); + + match exact.len() { + 0 => bail!("project '{project_name}' not found"), + 1 => Ok(()), + count => { + bail!("project-name '{project_name}' is ambiguous ({count} matches); use --project-id") + } + } +} + +struct FetchRowsResult { + rows: Vec, + warnings: Vec, +} + +async fn fetch_all_function_rows( + client: &crate::http::ApiClient, + query: &FunctionListQuery, +) -> Result { + let mut page_count = 0usize; + let mut rows = Vec::new(); + let mut cursor: Option = None; + let mut snapshot: Option = None; + let mut seen_cursors = BTreeSet::new(); + seen_cursors.insert("__start__".to_string()); + let mut warnings = Vec::new(); + let mut snapshot_consistent = true; + + loop { + if page_count >= PAGINATION_PAGE_LIMIT { + bail!("pagination page limit exceeded"); + } + + let mut page_query = query.clone(); + page_query.cursor = cursor.clone(); + page_query.snapshot = snapshot.clone(); + + let page 
= api::list_functions_page(client, &page_query).await?;
+
+        if page_count == 0 && !page.pagination_field_present {
+            page_count += 1;
+            rows.extend(page.objects);
+            break;
+        }
+
+        page_count += 1;
+
+        if page_count == 1 {
+            snapshot = page.snapshot.clone();
+        } else if snapshot.is_none() || !page.snapshot_field_present {
+            snapshot_consistent = false;
+        }
+
+        if page.objects.is_empty() && page.next_cursor.is_some() {
+            bail!("pagination returned empty page with non-empty next cursor");
+        }
+
+        rows.extend(page.objects);
+
+        let Some(next_cursor) = page.next_cursor else {
+            break;
+        };
+
+        if cursor.as_deref() == Some(next_cursor.as_str()) || seen_cursors.contains(&next_cursor) {
+            bail!("pagination cursor did not advance");
+        }
+        seen_cursors.insert(next_cursor.clone());
+        cursor = Some(next_cursor);
+    }
+
+    if page_count > 1 && !snapshot_consistent {
+        warnings.push(ReportWarning {
+            reason: WarningReason::PaginationNotSnapshotConsistent,
+            message: "pagination endpoint does not appear to support snapshot-consistent traversal"
+                .to_string(),
+        });
+    }
+
+    Ok(FetchRowsResult { rows, warnings })
+}
+
+fn apply_selector_narrowing(
+    rows: Vec<PullFunctionRow>,
+    args: &PullArgs,
+) -> Result<Vec<PullFunctionRow>> {
+    let narrowed = if let Some(id) = &args.id {
+        rows.into_iter()
+            .filter(|row| row.id == *id)
+            .collect::<Vec<_>>()
+    } else if let Some(slug) = &args.slug {
+        rows.into_iter()
+            .filter(|row| row.slug == *slug)
+            .collect::<Vec<_>>()
+    } else {
+        rows
+    };
+
+    if (args.id.is_some() || args.slug.is_some()) && narrowed.is_empty() {
+        bail!("selector did not match any function rows");
+    }
+
+    Ok(narrowed)
+}
+
+fn select_winner_rows(
+    rows: Vec<PullFunctionRow>,
+    summary: &mut PullSummary,
+) -> Vec<PullFunctionRow> {
+    let mut winners: BTreeMap<(String, String), PullFunctionRow> = BTreeMap::new();
+
+    for row in rows {
+        let key = (row.project_id.clone(), row.slug.clone());
+        if let Some(existing) = winners.get_mut(&key) {
+            summary.files_skipped += 1;
+            if compare_rows_desc(&row, existing) == Ordering::Less {
+                *existing = row;
+            }
+        } else {
+            winners.insert(key, row);
+        }
+    }
+
+    winners.into_values().collect()
+}
+
+fn is_prompt_row(row: &PullFunctionRow) -> bool {
+    row.function_data
+        .as_ref()
+        .and_then(|data| data.get("type"))
+        .and_then(Value::as_str)
+        == Some("prompt")
+}
+
+fn compare_rows_desc(left: &PullFunctionRow, right: &PullFunctionRow) -> Ordering {
+    let left_xact = left
+        ._xact_id
+        .as_deref()
+        .and_then(|value| value.parse::<u64>().ok())
+        .unwrap_or(0);
+    let right_xact = right
+        ._xact_id
+        .as_deref()
+        .and_then(|value| value.parse::<u64>().ok())
+        .unwrap_or(0);
+
+    match right_xact.cmp(&left_xact) {
+        Ordering::Equal => {}
+        non_eq => return non_eq,
+    }
+
+    match right.created.cmp(&left.created) {
+        Ordering::Equal => {}
+        non_eq => return non_eq,
+    }
+
+    right.id.cmp(&left.id)
+}
+
+fn resolve_project_names(
+    rows: &[PullFunctionRow],
+    projects: &[Project],
+) -> Result<BTreeMap<String, String>> {
+    let mut names_by_id = BTreeMap::new();
+    if rows.is_empty() {
+        return Ok(names_by_id);
+    }
+
+    for project in projects {
+        names_by_id.insert(project.id.clone(), project.name.clone());
+    }
+
+    for row in rows {
+        if let Some(project_name) = row
+            .project_name
+            .as_deref()
+            .map(str::trim)
+            .filter(|value| !value.is_empty())
+        {
+            names_by_id
+                .entry(row.project_id.clone())
+                .or_insert_with(|| project_name.to_string());
+        }
+    }
+
+    for row in rows {
+        if !names_by_id.contains_key(&row.project_id) {
+            bail!(
+                "failed to resolve project name for project id '{}'",
+                row.project_id
+            );
+        }
+    }
+
+    Ok(names_by_id)
+}
+
+fn group_rows_by_project(
+    rows: Vec<PullFunctionRow>,
+fn group_rows_by_project(
+    rows: Vec<PullFunctionRow>,
+    project_names: &BTreeMap<String, String>,
+) -> Result<BTreeMap<(String, String), Vec<PullFunctionRow>>> {
+    let mut grouped = BTreeMap::new();
+    for row in rows {
+        let Some(project_name) = project_names.get(&row.project_id).cloned() else {
+            bail!(
+                "missing resolved project name for project id '{}'",
+                row.project_id
+            );
+        };
+        grouped
+            .entry((row.project_id.clone(), project_name))
+            .or_insert_with(Vec::new)
+            .push(row);
+    }
+    Ok(grouped)
+}
+
+fn build_project_file_names(
+    grouped_by_project: &BTreeMap<(String, String), Vec<PullFunctionRow>>,
+    ext: &str,
+) -> BTreeMap<(String, String), String> {
+    let mut used_casefold = BTreeSet::new();
+    let mut names = BTreeMap::new();
+
+    for (project_id, project_name) in grouped_by_project.keys() {
+        let base = sanitize_filename(project_name);
+        let mut candidate = if base.is_empty() {
+            "project".to_string()
+        } else {
+            base
+        };
+        if is_reserved_filename(&candidate) {
+            candidate.push_str("-file");
+        }
+
+        let casefold = candidate.to_ascii_lowercase();
+        if used_casefold.contains(&casefold) {
+            candidate = format!("{}-{}", candidate, sanitize_filename(project_id));
+        }
+
+        used_casefold.insert(candidate.to_ascii_lowercase());
+        names.insert(
+            (project_id.clone(), project_name.clone()),
+            format!("{candidate}.{ext}"),
+        );
+    }
+
+    names
+}
+
+fn build_output_file_names(
+    grouped_by_project: &BTreeMap<(String, String), Vec<PullFunctionRow>>,
+    slug_selector: Option<&str>,
+    ext: &str,
+) -> Result<BTreeMap<(String, String), String>> {
+    if let Some(slug) = slug_selector {
+        if grouped_by_project.len() != 1 {
+            bail!("slug selector matched multiple projects; pass --project-name or --project-id");
+        }
+        let mut names = BTreeMap::new();
+        let key = grouped_by_project
+            .keys()
+            .next()
+            .ok_or_else(|| anyhow!("missing grouped project for slug selector"))?
+            .clone();
+        let base = sanitize_filename(slug);
+        let file_stem = if base.is_empty() {
+            "function".to_string()
+        } else {
+            base
+        };
+        names.insert(key, format!("{file_stem}.{ext}"));
+        return Ok(names);
+    }
+
+    Ok(build_project_file_names(grouped_by_project, ext))
+}
+
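+// Keeps lowercase ASCII alphanumerics plus '-', '_' and '.'; every other
+// character collapses into a single '-' and leading/trailing dashes are
+// trimmed, so e.g. "My Project!" becomes "my-project".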
+fn sanitize_filename(value: &str) -> String {
+    let mut out = String::with_capacity(value.len());
+    let mut previous_dash = false;
+    for ch in value.chars() {
+        let normalized = if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.' {
+            ch.to_ascii_lowercase()
+        } else {
+            '-'
+        };
+        if normalized == '-' {
+            if !previous_dash {
+                out.push('-');
+                previous_dash = true;
+            }
+        } else {
+            out.push(normalized);
+            previous_dash = false;
+        }
+    }
+
+    out.trim_matches('-').to_string()
+}
+
+fn is_reserved_filename(value: &str) -> bool {
+    matches!(
+        value.to_ascii_lowercase().as_str(),
+        "con"
+            | "prn"
+            | "aux"
+            | "nul"
+            | "com1"
+            | "com2"
+            | "com3"
+            | "com4"
+            | "com5"
+            | "com6"
+            | "com7"
+            | "com8"
+            | "com9"
+            | "lpt1"
+            | "lpt2"
+            | "lpt3"
+            | "lpt4"
+            | "lpt5"
+            | "lpt6"
+            | "lpt7"
+            | "lpt8"
+            | "lpt9"
+    )
+}
+
+fn sanitize_renderer_identifier(
+    seed: &str,
+    language: FunctionsLanguage,
+    used: &mut BTreeSet<String>,
+) -> String {
+    let mut candidate = match language {
+        FunctionsLanguage::Typescript => sanitize_typescript_identifier(seed),
+        FunctionsLanguage::Python => sanitize_python_identifier(seed),
+    };
+    if used.contains(&candidate) {
+        let base = candidate.clone();
+        let mut suffix = 1usize;
+        while used.contains(&candidate) {
+            candidate = format!("{base}_{suffix}");
+            suffix += 1;
+        }
+    }
+    used.insert(candidate.clone());
+    candidate
+}
+
+fn sanitize_typescript_identifier(seed: &str) -> String {
+    let mut parts = Vec::new();
+    let mut current = String::new();
+    for ch in seed.chars() {
+        if ch.is_ascii_alphanumeric() || ch == '_' || ch == '$' {
+            current.push(ch);
+        } else if !current.is_empty() {
+            parts.push(current.clone());
+            current.clear();
+        }
+    }
+    if !current.is_empty() {
+        parts.push(current);
+    }
+
+    if parts.is_empty() {
+        return "prompt".to_string();
+    }
+
+    let mut out = String::new();
+    for (index, part) in parts.iter().enumerate() {
+        if index == 0 {
+            out.push_str(&part.to_ascii_lowercase());
+        } else {
+            let mut chars = part.chars();
+            if let Some(first) = chars.next() {
+                out.push(first.to_ascii_uppercase());
+            }
+            out.push_str(&chars.as_str().to_ascii_lowercase());
+        }
+    }
+
+    if out.is_empty() {
+        return "prompt".to_string();
+    }
+    if out
+        .chars()
+        .next()
+        .is_some_and(|first| first.is_ascii_digit())
+    {
+        out.insert_str(0, "prompt");
+    }
+    if out == "project" || out == "braintrust" {
+        out.push('_');
+    }
+    out
+}
+
+fn sanitize_python_identifier(seed: &str) -> String {
+    let mut out = String::with_capacity(seed.len());
+    let mut previous_was_underscore = false;
+    for ch in seed.chars() {
+        let normalized = if ch.is_ascii_alphanumeric() { ch } else { '_' };
+        if normalized == '_' {
+            if !previous_was_underscore {
+                out.push('_');
+            }
+            previous_was_underscore = true;
+        } else {
+            out.push(normalized.to_ascii_lowercase());
+            previous_was_underscore = false;
+        }
+    }
+
+    let mut out = out.trim_matches('_').to_string();
+    if out.is_empty() {
+        out = "prompt".to_string();
+    }
+    if out
+        .chars()
+        .next()
+        .is_some_and(|first| first.is_ascii_digit())
+    {
+        out.insert_str(0, "prompt_");
+    }
+    if is_python_keyword(&out) || out == "project" || out == "braintrust" {
+        out.push('_');
+    }
+    out
+}
+
+fn is_python_keyword(value: &str) -> bool {
+    matches!(
+        value,
+        "false"
+            | "none"
+            | "true"
+            | "and"
+            | "as"
+            | "assert"
+            | "async"
+            | "await"
+            | "break"
+            | "class"
+            | "continue"
+            | "def"
+            | "del"
+            | "elif"
+            | "else"
+            | "except"
+            | "finally"
+            | "for"
+            | "from"
+            | "global"
+            | "if"
+            | "import"
+            | "in"
+            | "is"
+            | "lambda"
+            | "nonlocal"
+            | "not"
+            | "or"
+            | "pass"
+            | "raise"
+            | "return"
+            | "try"
+            | "while"
+            | "with"
+            | "yield"
+    )
+}
+
+fn should_skip_target(
+    repo: &Option<GitRepo>,
+    target: &Path,
+    force: bool,
+) -> Result<Option<SoftSkipReason>> {
+    if force {
+        return Ok(None);
+    }
+
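+    // A target that does not exist yet is always safe to write.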
+    if !target.exists() {
+        return Ok(None);
+    }
+
+    let Some(repo) = repo else {
+        return Ok(Some(SoftSkipReason::ExistingNonGitNoForce));
+    };
+
+    if !target.starts_with(repo.root()) {
+        return Ok(Some(SoftSkipReason::ExistingNonGitNoForce));
+    }
+
+    if repo.is_dirty_or_untracked(target)? {
+        return Ok(Some(SoftSkipReason::DirtyTarget));
+    }
+
+    Ok(None)
+}
+
+fn display_output_path(target: &Path) -> String {
+    let cwd = match std::env::current_dir() {
+        Ok(cwd) => cwd,
+        Err(_) => return target.display().to_string(),
+    };
+
+    pathdiff::diff_paths(target, &cwd)
+        .filter(|path| !path.as_os_str().is_empty())
+        .map(|path| path.display().to_string())
+        .unwrap_or_else(|| target.display().to_string())
+}
+
+fn render_project_file(
+    language: FunctionsLanguage,
+    project_name: &str,
+    file_name: &str,
+    rows: &[PullFunctionRow],
+) -> Result<String> {
+    let mut sorted_rows = rows.to_vec();
+    sorted_rows.sort_by(compare_rows_for_render);
+
+    let mut normalized = Vec::with_capacity(sorted_rows.len());
+    for row in &sorted_rows {
+        normalized.push(normalize_prompt_row(row)?);
+    }
+
+    match language {
+        FunctionsLanguage::Typescript => {
+            render_project_file_ts(project_name, file_name, &normalized)
+        }
+        FunctionsLanguage::Python => render_project_file_py(project_name, file_name, &normalized),
+    }
+}
+
+fn compare_rows_for_render(left: &PullFunctionRow, right: &PullFunctionRow) -> Ordering {
+    match left.slug.cmp(&right.slug) {
+        Ordering::Equal => {}
+        non_eq => return non_eq,
+    }
+    match left.name.cmp(&right.name) {
+        Ordering::Equal => {}
+        non_eq => return non_eq,
+    }
+    left.id.cmp(&right.id)
+}
+
+fn normalize_prompt_row(row: &PullFunctionRow) -> Result<NormalizedPrompt> {
+    let description = row
+        .description
+        .as_deref()
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(ToOwned::to_owned);
+
+    let prompt_data = row
+        .prompt_data
+        .as_ref()
+        .ok_or_else(|| anyhow!("prompt row '{}' missing prompt_data", row.slug))?;
+    let prompt_block = prompt_data
+        .get("prompt")
+        .ok_or_else(|| anyhow!("prompt row '{}' missing prompt_data.prompt", row.slug))?;
+
+    let mut prompt = None;
+    let mut messages = None;
+    if prompt_block
+        .get("type")
+        .and_then(Value::as_str)
+        .is_some_and(|value| value == "completion")
+    {
+        if let Some(content) = prompt_block.get("content") {
+            if !is_empty_render_value(content) {
+                prompt = Some(content.clone());
+            }
+        }
+    } else if prompt_block
+        .get("type")
+        .and_then(Value::as_str)
+        .is_some_and(|value| value == "chat")
+    {
+        if let Some(raw_messages) = prompt_block.get("messages") {
+            if !is_empty_render_value(raw_messages) {
+                messages = Some(raw_messages.clone());
+            }
+        }
+    }
+
+    let model = prompt_data
+        .get("options")
+        .and_then(|options| options.get("model"))
+        .filter(|value| !is_empty_render_value(value))
+        .cloned();
+    let params = prompt_data
+        .get("options")
+        .and_then(|options| options.get("params"))
+        .filter(|value| !is_empty_render_value(value))
+        .cloned();
+
+    let mut tools: Vec<Value> = prompt_data
+        .get("tool_functions")
+        .and_then(Value::as_array)
+        .cloned()
+        .unwrap_or_default();
+    if let Some(raw_tools) = prompt_block.get("tools").and_then(Value::as_str) {
+        if !raw_tools.trim().is_empty() {
+            if let Ok(parsed) = serde_json::from_str::<Value>(raw_tools) {
+                if let Some(items) = parsed.as_array() {
+                    tools.extend(items.iter().cloned());
+                }
+            }
+        }
+    }
+    let tools = if tools.is_empty() {
+        None
+    } else {
+        Some(Value::Array(tools))
+    };
+
+    Ok(NormalizedPrompt {
+        variable_seed: row.slug.clone(),
+        name: row.name.clone(),
+        slug: row.slug.clone(),
+        description,
+        prompt,
+        messages,
+        model,
+        params,
+        tools,
+    })
+}
+
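+// Renders one generated TypeScript module per project: a regeneration header,
+// a shared `project` handle, then one exported prompt per row. The tests
+// below pin the expected shape; e.g. a prompt slugged "doc-search" becomes:
+//
+//   export const docSearch = project.prompts.create({
+//     name: "Doc Search",
+//     slug: "doc-search",
+//     ...
+//   });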
+fn render_project_file_ts(
+    project_name: &str,
+    file_name: &str,
+    prompts: &[NormalizedPrompt],
+) -> Result<String> {
+    let mut out = String::new();
+    out.push_str("// This file was automatically generated by bt functions pull. You can\n");
+    out.push_str("// generate it again by running:\n");
+    out.push_str(&format!(
+        "// $ bt functions pull --project-name {}\n",
+        serde_json::to_string(project_name)?
+    ));
+    out.push_str(
+        "// Feel free to edit this file manually, but once you do, you should make sure to\n",
+    );
+    out.push_str("// sync your changes with Braintrust by running:\n");
+    out.push_str(&format!(
+        "// $ bt functions push --file {}\n\n",
+        serde_json::to_string(file_name)?
+    ));
+
+    out.push_str("import braintrust from \"braintrust\";\n\n");
+    out.push_str("const project = braintrust.projects.create({\n");
+    out.push_str(&format!(
+        " name: {},\n",
+        serde_json::to_string(project_name)?
+    ));
+    out.push_str("});\n\n");
+
+    let mut seen_names = BTreeSet::new();
+
+    for row in prompts {
+        let var_name = sanitize_renderer_identifier(
+            &row.variable_seed,
+            FunctionsLanguage::Typescript,
+            &mut seen_names,
+        );
+
+        let mut body_lines = Vec::new();
+        body_lines.push(format!(" name: {},", serde_json::to_string(&row.name)?));
+        body_lines.push(format!(" slug: {},", serde_json::to_string(&row.slug)?));
+
+        if let Some(description) = &row.description {
+            body_lines.push(format!(
+                " description: {},",
+                serde_json::to_string(description)?
+            ));
+        }
+
+        if let Some(prompt) = &row.prompt {
+            body_lines.push(format!(" prompt: {},", format_ts_value(prompt, 2)));
+        }
+        if let Some(messages) = &row.messages {
+            body_lines.push(format!(" messages: {},", format_ts_value(messages, 2)));
+        }
+        if let Some(model) = &row.model {
+            body_lines.push(format!(" model: {},", format_ts_value(model, 2)));
+        }
+        if let Some(params) = &row.params {
+            body_lines.push(format!(" params: {},", format_ts_value(params, 2)));
+        }
+        if let Some(tools) = &row.tools {
+            body_lines.push(format!(" tools: {},", format_ts_value(tools, 2)));
+        }
+
+        out.push_str(&format!(
+            "export const {var_name} = project.prompts.create({{\n"
+        ));
+        out.push_str(&body_lines.join("\n"));
+        out.push_str("\n});\n\n");
+    }
+
+    Ok(out)
+}
+
+fn render_project_file_py(
+    project_name: &str,
+    file_name: &str,
+    prompts: &[NormalizedPrompt],
+) -> Result<String> {
+    let mut out = String::new();
+    out.push_str("# This file was automatically generated by bt functions pull. You can\n");
+    out.push_str("# generate it again by running:\n");
+    out.push_str(&format!(
+        "# $ bt functions pull --project-name {} --language python\n",
+        serde_json::to_string(project_name)?
+    ));
+    out.push_str(
+        "# Feel free to edit this file manually, but once you do, you should make sure to\n",
+    );
+    out.push_str("# sync your changes with Braintrust by running:\n");
+    out.push_str(&format!(
+        "# $ bt functions push --file {}\n\n",
+        serde_json::to_string(file_name)?
+    ));
+    out.push_str("import braintrust\n\n");
+    out.push_str(&format!(
+        "project = braintrust.projects.create(name={})\n\n",
+        serde_json::to_string(project_name)?
+    ));
+
+    let mut seen_names = BTreeSet::new();
+    for row in prompts {
+        let var_name = sanitize_renderer_identifier(
+            &row.variable_seed,
+            FunctionsLanguage::Python,
+            &mut seen_names,
+        );
+        out.push_str(&format!("{var_name} = project.prompts.create(\n"));
+        out.push_str(&format!(
+            " name={},\n",
+            serde_json::to_string(&row.name)?
+ )); + out.push_str(&format!( + " slug={},\n", + serde_json::to_string(&row.slug)? + )); + if let Some(description) = &row.description { + out.push_str(&format!( + " description={},\n", + serde_json::to_string(description)? + )); + } + if let Some(prompt) = &row.prompt { + out.push_str(&format!(" prompt={},\n", format_py_value(prompt, 4))); + } + if let Some(messages) = &row.messages { + out.push_str(&format!(" messages={},\n", format_py_value(messages, 4))); + } + if let Some(model) = &row.model { + out.push_str(&format!(" model={},\n", format_py_value(model, 4))); + } + if let Some(params) = &row.params { + out.push_str(&format!(" params={},\n", format_py_value(params, 4))); + } + if let Some(tools) = &row.tools { + out.push_str(&format!(" tools={},\n", format_py_value(tools, 4))); + } + out.push_str(")\n\n"); + } + + Ok(out) +} + +fn format_ts_value(value: &Value, indent: usize) -> String { + let json = format_ts_value_inner(value, 0); + let pad = " ".repeat(indent); + let mut lines = json.lines(); + let Some(first) = lines.next() else { + return "null".to_string(); + }; + + let mut out = first.to_string(); + for line in lines { + out.push('\n'); + out.push_str(&pad); + out.push_str(line); + } + out +} + +fn format_ts_value_inner(value: &Value, depth: usize) -> String { + match value { + Value::Null => "null".to_string(), + Value::Bool(boolean) => boolean.to_string(), + Value::Number(number) => number.to_string(), + Value::String(string) => { + serde_json::to_string(string).unwrap_or_else(|_| "\"\"".to_string()) + } + Value::Array(items) => { + if items.is_empty() { + return "[]".to_string(); + } + + let indent = " ".repeat(depth + 1); + let closing_indent = " ".repeat(depth); + let mut out = String::from("[\n"); + for (index, item) in items.iter().enumerate() { + out.push_str(&indent); + out.push_str(&format_ts_value_inner(item, depth + 1)); + if index + 1 < items.len() { + out.push(','); + } + out.push('\n'); + } + out.push_str(&closing_indent); + out.push(']'); + out + } + Value::Object(object) => { + if object.is_empty() { + return "{}".to_string(); + } + + let indent = " ".repeat(depth + 1); + let closing_indent = " ".repeat(depth); + let mut out = String::from("{\n"); + for (index, (key, val)) in object.iter().enumerate() { + out.push_str(&indent); + out.push_str(&format_ts_object_key(key)); + out.push_str(": "); + out.push_str(&format_ts_value_inner(val, depth + 1)); + if index + 1 < object.len() { + out.push(','); + } + out.push('\n'); + } + out.push_str(&closing_indent); + out.push('}'); + out + } + } +} + +fn format_ts_object_key(key: &str) -> String { + if should_unquote_object_key(key) { + key.to_string() + } else { + serde_json::to_string(key).unwrap_or_else(|_| "\"\"".to_string()) + } +} + +fn format_py_value(value: &Value, indent: usize) -> String { + let rendered = format_py_value_inner(value, 0); + let pad = " ".repeat(indent); + let mut lines = rendered.lines(); + let Some(first) = lines.next() else { + return "None".to_string(); + }; + + let mut out = first.to_string(); + for line in lines { + out.push('\n'); + out.push_str(&pad); + out.push_str(line); + } + out +} + +fn format_py_value_inner(value: &Value, depth: usize) -> String { + match value { + Value::Null => "None".to_string(), + Value::Bool(boolean) => { + if *boolean { + "True".to_string() + } else { + "False".to_string() + } + } + Value::Number(number) => number.to_string(), + Value::String(string) => { + serde_json::to_string(string).unwrap_or_else(|_| "\"\"".to_string()) + } + Value::Array(items) => { + 
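+            // Empty arrays collapse to "[]"; otherwise render one element per
+            // line, mirroring the TypeScript formatter above.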
+            if items.is_empty() {
+                return "[]".to_string();
+            }
+            let indent = " ".repeat(depth + 1);
+            let closing_indent = " ".repeat(depth);
+            let mut out = String::from("[\n");
+            for (index, item) in items.iter().enumerate() {
+                out.push_str(&indent);
+                out.push_str(&format_py_value_inner(item, depth + 1));
+                if index + 1 < items.len() {
+                    out.push(',');
+                }
+                out.push('\n');
+            }
+            out.push_str(&closing_indent);
+            out.push(']');
+            out
+        }
+        Value::Object(object) => {
+            if object.is_empty() {
+                return "{}".to_string();
+            }
+            let indent = " ".repeat(depth + 1);
+            let closing_indent = " ".repeat(depth);
+            let mut out = String::from("{\n");
+            let mut entries = object.iter().collect::<Vec<_>>();
+            entries.sort_by(|(left, _), (right, _)| left.cmp(right));
+            for (index, (key, val)) in entries.into_iter().enumerate() {
+                out.push_str(&indent);
+                out.push_str(&serde_json::to_string(key).unwrap_or_else(|_| "\"\"".to_string()));
+                out.push_str(": ");
+                out.push_str(&format_py_value_inner(val, depth + 1));
+                if index + 1 < object.len() {
+                    out.push(',');
+                }
+                out.push('\n');
+            }
+            out.push_str(&closing_indent);
+            out.push('}');
+            out
+        }
+    }
+}
+
+fn should_unquote_object_key(key: &str) -> bool {
+    if key.is_empty() || key == "__proto__" {
+        return false;
+    }
+
+    let mut chars = key.chars();
+    let Some(first) = chars.next() else {
+        return false;
+    };
+    if !(first == '$' || first == '_' || first.is_ascii_alphabetic()) {
+        return false;
+    }
+
+    chars.all(|ch| ch == '$' || ch == '_' || ch.is_ascii_alphanumeric())
+}
+
+fn is_empty_render_value(value: &Value) -> bool {
+    match value {
+        Value::Null => true,
+        Value::String(value) => value.trim().is_empty(),
+        Value::Array(value) => value.is_empty(),
+        Value::Object(value) => value.is_empty(),
+        Value::Bool(_) | Value::Number(_) => false,
+    }
+}
+
+fn emit_summary(base: &BaseArgs, summary: &PullSummary) -> Result<()> {
+    if base.json {
+        println!("{}", serde_json::to_string(summary)?);
+    } else {
+        match summary.status {
+            CommandStatus::Success => {
+                eprintln!(
+                    "Pulled {} file(s), materialized {} prompt(s).",
+                    summary.files_written, summary.functions_materialized
+                );
+            }
+            CommandStatus::Partial => {
+                eprintln!(
+                    "Pull completed with partial results: written={}, skipped={}, failed={}",
+                    summary.files_written, summary.files_skipped, summary.files_failed
+                );
+            }
+            CommandStatus::Failed => {
+                eprintln!(
+                    "Pull failed: written={}, skipped={}, failed={}",
+                    summary.files_written, summary.files_skipped, summary.files_failed
+                );
+            }
+        }
+        for warning in &summary.warnings {
+            eprintln!("warning: {}", warning.message);
+        }
+        for error in &summary.errors {
+            eprintln!("error: {}", error.message);
+        }
+    }
+
+    Ok(())
+}
+
+fn fail_pull(
+    base: &BaseArgs,
+    summary: &mut PullSummary,
+    reason: HardFailureReason,
+    message: String,
+) -> Result<()> {
+    summary.status = CommandStatus::Failed;
+    summary.errors.push(ReportError {
+        reason,
+        message: message.clone(),
+    });
+    if base.json {
+        emit_summary(base, summary)?;
+    }
+    bail!(message);
+}
+
+fn record_pull_file_failure(
+    summary: &mut PullSummary,
+    output_file: String,
+    reason: HardFailureReason,
+    message: String,
+) {
+    summary.files_failed += 1;
+    summary.status = CommandStatus::Failed;
+    summary.errors.push(ReportError {
+        reason,
+        message: message.clone(),
+    });
+    summary.files.push(PullFileReport {
+        output_file,
+        status: FileStatus::Failed,
+        skipped_reason: None,
+        error_reason: Some(reason),
+        message: Some(message),
+    });
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
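+    // These unit tests pin the pure helpers; end-to-end push/pull flows are
+    // exercised by the integration tests and CLI fixtures in a later patch.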
+    #[test]
+    fn sanitize_identifier_helpers() {
+        assert_eq!(sanitize_typescript_identifier("my-prompt"), "myPrompt");
+        assert_eq!(sanitize_typescript_identifier("1prompt"), "prompt1prompt");
+        assert_eq!(sanitize_typescript_identifier("doc-search"), "docSearch");
+        assert_eq!(sanitize_typescript_identifier("tt-6bb2"), "tt6bb2");
+        assert_eq!(sanitize_python_identifier("1prompt"), "prompt_1prompt");
+        assert_eq!(sanitize_python_identifier("class"), "class_");
+    }
+
+    #[test]
+    fn file_name_builder_handles_case_collisions() {
+        let mut grouped = BTreeMap::new();
+        grouped.insert(
+            ("p1".to_string(), "Project".to_string()),
+            Vec::<PullFunctionRow>::new(),
+        );
+        grouped.insert(
+            ("p2".to_string(), "project".to_string()),
+            Vec::<PullFunctionRow>::new(),
+        );
+
+        let names = build_project_file_names(&grouped, "ts");
+        let first = names
+            .get(&("p1".to_string(), "Project".to_string()))
+            .expect("first");
+        let second = names
+            .get(&("p2".to_string(), "project".to_string()))
+            .expect("second");
+
+        assert_ne!(first.to_ascii_lowercase(), second.to_ascii_lowercase());
+    }
+
+    #[test]
+    fn selector_narrowing_enforces_presence() {
+        let row = PullFunctionRow {
+            id: "f1".to_string(),
+            name: "Prompt".to_string(),
+            slug: "prompt".to_string(),
+            project_id: "p1".to_string(),
+            project_name: Some("Proj".to_string()),
+            description: None,
+            prompt_data: None,
+            function_data: None,
+            created: None,
+            _xact_id: None,
+        };
+        let args = PullArgs {
+            output_dir: PathBuf::from("."),
+            language: FunctionsLanguage::Typescript,
+            project_name: None,
+            project_id: None,
+            id: Some("missing".to_string()),
+            slug: None,
+            force: false,
+        };
+
+        let err = apply_selector_narrowing(vec![row], &args).expect_err("should fail");
+        assert!(err.to_string().contains("selector"));
+    }
+
+    #[test]
+    fn group_rows_uses_resolved_project_name() {
+        let row = PullFunctionRow {
+            id: "f1".to_string(),
+            name: "Prompt".to_string(),
+            slug: "prompt".to_string(),
+            project_id: "p1".to_string(),
+            project_name: None,
+            description: None,
+            prompt_data: None,
+            function_data: None,
+            created: None,
+            _xact_id: None,
+        };
+
+        let mut names = BTreeMap::new();
+        names.insert("p1".to_string(), "Woohoo".to_string());
+
+        let grouped = group_rows_by_project(vec![row], &names).expect("grouped");
+        assert_eq!(grouped.len(), 1);
+        assert!(grouped.contains_key(&("p1".to_string(), "Woohoo".to_string())));
+    }
+
+    #[test]
+    fn group_rows_fails_when_project_name_missing() {
+        let row = PullFunctionRow {
+            id: "f1".to_string(),
+            name: "Prompt".to_string(),
+            slug: "prompt".to_string(),
+            project_id: "p1".to_string(),
+            project_name: None,
+            description: None,
+            prompt_data: None,
+            function_data: None,
+            created: None,
+            _xact_id: None,
+        };
+
+        let err = group_rows_by_project(vec![row], &BTreeMap::new()).expect_err("should fail");
+        assert!(err.to_string().contains("project id"));
+    }
+
+    #[test]
+    fn slug_selector_names_output_file_from_slug() {
+        let mut grouped = BTreeMap::new();
+        grouped.insert(
+            ("p1".to_string(), "Project".to_string()),
+            Vec::<PullFunctionRow>::new(),
+        );
+
+        let names =
+            build_output_file_names(&grouped, Some("doc-search"), "ts").expect("file names");
+        assert_eq!(
+            names
+                .get(&("p1".to_string(), "Project".to_string()))
+                .map(String::as_str),
+            Some("doc-search.ts")
+        );
+    }
+
+    #[test]
+    fn slug_selector_rejects_multiple_projects() {
+        let mut grouped = BTreeMap::new();
+        grouped.insert(
+            ("p1".to_string(), "Project One".to_string()),
+            Vec::<PullFunctionRow>::new(),
+        );
+        grouped.insert(
+            ("p2".to_string(), "Project Two".to_string()),
+            Vec::<PullFunctionRow>::new(),
+        );
+
+        let err =
build_output_file_names(&grouped, Some("doc-search"), "ts").expect_err("should fail"); + assert!(err.to_string().contains("multiple projects")); + } + + #[test] + fn render_project_file_matches_legacy_shape() { + let row = PullFunctionRow { + id: "f1".to_string(), + name: "Doc Search".to_string(), + slug: "doc-search".to_string(), + project_id: "p1".to_string(), + project_name: Some("woohoo".to_string()), + description: Some(String::new()), + prompt_data: Some(serde_json::json!({ + "prompt": { + "type": "chat", + "messages": [ + { "content": "Hello", "role": "system" } + ] + }, + "options": { + "model": "gpt-4o-mini" + }, + "tool_functions": [ + { "type": "function", "id": "tool-1" } + ] + })), + function_data: Some(serde_json::json!({ "type": "prompt" })), + created: None, + _xact_id: Some("123".to_string()), + }; + + let rendered = render_project_file( + FunctionsLanguage::Typescript, + "woohoo", + "braintrust/woohoo.ts", + &[row], + ) + .expect("rendered"); + + assert!(rendered.contains("automatically generated by bt functions pull")); + assert!(rendered.contains("bt functions pull --project-name \"woohoo\"")); + assert!(rendered.contains("bt functions push --file \"braintrust/woohoo.ts\"")); + assert!( + rendered.contains("const project = braintrust.projects.create({\n name: \"woohoo\",") + ); + assert!(rendered.contains("export const docSearch = project.prompts.create({")); + assert!(!rendered.contains("description: \"\",")); + assert!(!rendered.contains("version:")); + assert!(!rendered.contains("id: \"f1\"")); + } + + #[test] + fn render_project_file_python_shape() { + let row = PullFunctionRow { + id: "f1".to_string(), + name: "Doc Search".to_string(), + slug: "doc-search".to_string(), + project_id: "p1".to_string(), + project_name: Some("woohoo".to_string()), + description: Some("find docs".to_string()), + prompt_data: Some(serde_json::json!({ + "prompt": { + "type": "chat", + "messages": [ + { "content": "Hello", "role": "system" } + ] + }, + "options": { + "model": "gpt-4o-mini", + "params": { "temperature": 0 } + }, + "tool_functions": [ + { "type": "function", "id": "tool-1" } + ] + })), + function_data: Some(serde_json::json!({ "type": "prompt" })), + created: None, + _xact_id: Some("123".to_string()), + }; + + let rendered = render_project_file( + FunctionsLanguage::Python, + "woohoo", + "braintrust/woohoo.py", + &[row], + ) + .expect("rendered"); + + assert!(rendered.contains("bt functions pull --project-name \"woohoo\" --language python")); + assert!(rendered.contains("bt functions push --file \"braintrust/woohoo.py\"")); + assert!(rendered.contains("import braintrust")); + assert!(rendered.contains("project = braintrust.projects.create(name=\"woohoo\")")); + assert!(rendered.contains("doc_search = project.prompts.create(")); + assert!(rendered.contains("messages=[")); + assert!(rendered.contains("model=\"gpt-4o-mini\"")); + } + + #[test] + fn format_ts_value_unquotes_safe_keys_only() { + let value = serde_json::json!({ + "content": "Hello", + "role": "system", + "$valid_1": true, + "foo-bar": 1, + "__proto__": { "x": 1 } + }); + + let rendered = format_ts_value(&value, 0); + assert!(rendered.contains("content: \"Hello\"")); + assert!(rendered.contains("role: \"system\"")); + assert!(rendered.contains("$valid_1: true")); + assert!(rendered.contains("\"foo-bar\": 1")); + assert!(rendered.contains("\"__proto__\": {")); + assert!(!rendered.contains("\"content\":")); + assert!(!rendered.contains("\"role\":")); + } + + #[test] + fn format_py_value_maps_literals() { + let value 
= serde_json::json!({ + "null": null, + "bool_true": true, + "bool_false": false, + "items": [1, "x"] + }); + + let rendered = format_py_value(&value, 0); + assert!(rendered.contains("\"null\": None")); + assert!(rendered.contains("\"bool_true\": True")); + assert!(rendered.contains("\"bool_false\": False")); + assert!(rendered.contains("\"items\": [")); + } + + #[test] + fn is_empty_render_value_handles_supported_shapes() { + assert!(is_empty_render_value(&Value::Null)); + assert!(is_empty_render_value(&Value::String("".to_string()))); + assert!(is_empty_render_value(&Value::String(" ".to_string()))); + assert!(is_empty_render_value(&Value::Array(Vec::new()))); + assert!(is_empty_render_value( + &Value::Object(serde_json::Map::new()) + )); + + assert!(!is_empty_render_value(&Value::String("x".to_string()))); + assert!(!is_empty_render_value(&serde_json::json!(false))); + assert!(!is_empty_render_value(&serde_json::json!(0))); + assert!(!is_empty_render_value(&serde_json::json!([1]))); + assert!(!is_empty_render_value(&serde_json::json!({ "a": 1 }))); + } + + #[test] + fn display_output_path_prefers_relative_path_when_available() { + let cwd = std::env::current_dir().expect("cwd"); + let target = cwd.join("braintrust").join("woohoo.ts"); + let display = display_output_path(&target); + assert_eq!( + display, + Path::new("braintrust") + .join("woohoo.ts") + .display() + .to_string() + ); + } +} From 4f5c904c4c8eda8f9944bb185336e53e383e5f49 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Thu, 5 Mar 2026 16:40:31 -0800 Subject: [PATCH 05/28] Add functions push/pull tests and CLI fixtures Comprehensive test coverage for the push and pull commands: - 16 CLI fixture tests covering help text, flag validation, env var parsing, language selection, and argument conflict detection - Mock API server (actix-web) with handlers for login, projects, upload slots, bundle upload, and function insert/list - Integration tests for full push and pull flows against mock server - JS and Python runner manifest validation tests - Python bundle validation and cross-file module purge tests --- src/main.rs | 4 - .../pull-help-env-vars/fixture.json | 13 + .../pull-help-flags/fixture.json | 13 + .../pull-id-slug-conflict/fixture.json | 5 + .../pull-invalid-language/fixture.json | 8 + .../fixture.json | 12 + .../fixture.json | 6 + .../fixture.json | 15 + .../push-help-env-vars/fixture.json | 13 + .../push-help-flags/fixture.json | 13 + .../push-invalid-bool-env/fixture.json | 8 + .../push-invalid-language/fixture.json | 8 + .../push-multiple-files-accepted/fixture.json | 15 + .../fixture.json | 5 + .../push-reject-tsconfig/fixture.json | 5 + .../push-reject-type/fixture.json | 5 + .../push-valid-bool-env/fixture.json | 9 + .../fixture.json | 6 + .../fixture.json | 13 + .../fixture.json | 6 + tests/functions.rs | 1404 +++++++++++++++++ 21 files changed, 1582 insertions(+), 4 deletions(-) create mode 100644 tests/functions-fixtures/pull-help-env-vars/fixture.json create mode 100644 tests/functions-fixtures/pull-help-flags/fixture.json create mode 100644 tests/functions-fixtures/pull-id-slug-conflict/fixture.json create mode 100644 tests/functions-fixtures/pull-invalid-language/fixture.json create mode 100644 tests/functions-fixtures/pull-project-id-name-conflict/fixture.json create mode 100644 tests/functions-fixtures/pull-valid-language-python-parses/fixture.json create mode 100644 tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json create mode 100644 
tests/functions-fixtures/push-help-env-vars/fixture.json create mode 100644 tests/functions-fixtures/push-help-flags/fixture.json create mode 100644 tests/functions-fixtures/push-invalid-bool-env/fixture.json create mode 100644 tests/functions-fixtures/push-invalid-language/fixture.json create mode 100644 tests/functions-fixtures/push-multiple-files-accepted/fixture.json create mode 100644 tests/functions-fixtures/push-reject-external-packages/fixture.json create mode 100644 tests/functions-fixtures/push-reject-tsconfig/fixture.json create mode 100644 tests/functions-fixtures/push-reject-type/fixture.json create mode 100644 tests/functions-fixtures/push-valid-bool-env/fixture.json create mode 100644 tests/functions-fixtures/push-valid-language-auto-parses/fixture.json create mode 100644 tests/functions-fixtures/push-valid-language-javascript-parses/fixture.json create mode 100644 tests/functions-fixtures/push-valid-language-python-parses/fixture.json create mode 100644 tests/functions.rs diff --git a/src/main.rs b/src/main.rs index a4e2496..db7760b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,10 +31,6 @@ mod ui; mod util_cmd; mod utils; -mod js_runner; -mod python_runner; -mod source_language; - use crate::args::{BaseArgs, CLIArgs}; const DEFAULT_CANARY_VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "-canary.dev"); diff --git a/tests/functions-fixtures/pull-help-env-vars/fixture.json b/tests/functions-fixtures/pull-help-env-vars/fixture.json new file mode 100644 index 0000000..34ab809 --- /dev/null +++ b/tests/functions-fixtures/pull-help-env-vars/fixture.json @@ -0,0 +1,13 @@ +{ + "command": ["functions", "pull", "--help"], + "expect_success": true, + "stdout_contains": [ + "BT_FUNCTIONS_PULL_OUTPUT_DIR", + "BT_FUNCTIONS_PULL_PROJECT_ID", + "BT_FUNCTIONS_PULL_PROJECT_NAME", + "BT_FUNCTIONS_PULL_ID", + "BT_FUNCTIONS_PULL_SLUG", + "BT_FUNCTIONS_PULL_FORCE", + "BT_FUNCTIONS_PULL_LANGUAGE" + ] +} diff --git a/tests/functions-fixtures/pull-help-flags/fixture.json b/tests/functions-fixtures/pull-help-flags/fixture.json new file mode 100644 index 0000000..1bd2716 --- /dev/null +++ b/tests/functions-fixtures/pull-help-flags/fixture.json @@ -0,0 +1,13 @@ +{ + "command": ["functions", "pull", "--help"], + "expect_success": true, + "stdout_contains": [ + "--output-dir", + "--project-id", + "--project-name", + "--id", + "--slug", + "--language", + "--force" + ] +} diff --git a/tests/functions-fixtures/pull-id-slug-conflict/fixture.json b/tests/functions-fixtures/pull-id-slug-conflict/fixture.json new file mode 100644 index 0000000..6d2b4d6 --- /dev/null +++ b/tests/functions-fixtures/pull-id-slug-conflict/fixture.json @@ -0,0 +1,5 @@ +{ + "command": ["functions", "pull", "--id", "abc", "--slug", "slug"], + "expect_success": false, + "stderr_contains": ["--id", "--slug"] +} diff --git a/tests/functions-fixtures/pull-invalid-language/fixture.json b/tests/functions-fixtures/pull-invalid-language/fixture.json new file mode 100644 index 0000000..34c1ff3 --- /dev/null +++ b/tests/functions-fixtures/pull-invalid-language/fixture.json @@ -0,0 +1,8 @@ +{ + "command": ["functions", "pull", "--language", "ruby"], + "expect_success": false, + "stderr_contains": [ + "invalid value 'ruby'", + "possible values: typescript, python" + ] +} diff --git a/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json b/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json new file mode 100644 index 0000000..07538fd --- /dev/null +++ 
b/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json @@ -0,0 +1,12 @@ +{ + "command": [ + "functions", + "pull", + "--project-id", + "proj_123", + "--project-name", + "demo" + ], + "expect_success": false, + "stderr_contains": ["--project-id", "--project-name"] +} diff --git a/tests/functions-fixtures/pull-valid-language-python-parses/fixture.json b/tests/functions-fixtures/pull-valid-language-python-parses/fixture.json new file mode 100644 index 0000000..65380f0 --- /dev/null +++ b/tests/functions-fixtures/pull-valid-language-python-parses/fixture.json @@ -0,0 +1,6 @@ +{ + "command": ["functions", "pull", "--language", "python", "--help"], + "expect_success": true, + "stdout_contains": ["Usage:", "--language"], + "stderr_not_contains": ["invalid value 'python'"] +} diff --git a/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json b/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json new file mode 100644 index 0000000..4b53f14 --- /dev/null +++ b/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json @@ -0,0 +1,15 @@ +{ + "command": [ + "functions", + "pull", + "--language", + "typescript", + "--id", + "abc", + "--slug", + "slug" + ], + "expect_success": false, + "stderr_contains": ["--id", "--slug"], + "stderr_not_contains": ["invalid value 'typescript'"] +} diff --git a/tests/functions-fixtures/push-help-env-vars/fixture.json b/tests/functions-fixtures/push-help-env-vars/fixture.json new file mode 100644 index 0000000..ec52bc0 --- /dev/null +++ b/tests/functions-fixtures/push-help-env-vars/fixture.json @@ -0,0 +1,13 @@ +{ + "command": ["functions", "push", "--help"], + "expect_success": true, + "stdout_contains": [ + "BT_FUNCTIONS_PUSH_FILES", + "BT_FUNCTIONS_PUSH_IF_EXISTS", + "BT_FUNCTIONS_PUSH_TERMINATE_ON_FAILURE", + "BT_FUNCTIONS_PUSH_RUNNER", + "BT_FUNCTIONS_PUSH_LANGUAGE", + "BT_FUNCTIONS_PUSH_REQUIREMENTS", + "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS" + ] +} diff --git a/tests/functions-fixtures/push-help-flags/fixture.json b/tests/functions-fixtures/push-help-flags/fixture.json new file mode 100644 index 0000000..af5772f --- /dev/null +++ b/tests/functions-fixtures/push-help-flags/fixture.json @@ -0,0 +1,13 @@ +{ + "command": ["functions", "push", "--help"], + "expect_success": true, + "stdout_contains": [ + "--file", + "--if-exists", + "--terminate-on-failure", + "--create-missing-projects", + "--language", + "--requirements", + "--runner" + ] +} diff --git a/tests/functions-fixtures/push-invalid-bool-env/fixture.json b/tests/functions-fixtures/push-invalid-bool-env/fixture.json new file mode 100644 index 0000000..fb531cf --- /dev/null +++ b/tests/functions-fixtures/push-invalid-bool-env/fixture.json @@ -0,0 +1,8 @@ +{ + "command": ["functions", "push"], + "env": { + "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS": "notabool" + }, + "expect_success": false, + "stderr_contains": ["--create-missing-projects", "value was not a boolean"] +} diff --git a/tests/functions-fixtures/push-invalid-language/fixture.json b/tests/functions-fixtures/push-invalid-language/fixture.json new file mode 100644 index 0000000..f7e5887 --- /dev/null +++ b/tests/functions-fixtures/push-invalid-language/fixture.json @@ -0,0 +1,8 @@ +{ + "command": ["functions", "push", "--language", "typescript"], + "expect_success": false, + "stderr_contains": [ + "invalid value 'typescript'", + "possible values: auto, javascript, python" + ] +} diff --git a/tests/functions-fixtures/push-multiple-files-accepted/fixture.json 
b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json new file mode 100644 index 0000000..2e7af6b --- /dev/null +++ b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json @@ -0,0 +1,15 @@ +{ + "command": [ + "functions", + "push", + "--file", + "a.ts", + "--file", + "b.ts", + "--type", + "tool" + ], + "expect_success": false, + "stderr_contains": ["--type"], + "stderr_not_contains": ["unexpected argument '--file'"] +} diff --git a/tests/functions-fixtures/push-reject-external-packages/fixture.json b/tests/functions-fixtures/push-reject-external-packages/fixture.json new file mode 100644 index 0000000..818959f --- /dev/null +++ b/tests/functions-fixtures/push-reject-external-packages/fixture.json @@ -0,0 +1,5 @@ +{ + "command": ["functions", "push", "--external-packages", "react"], + "expect_success": false, + "stderr_contains": ["--external-packages"] +} diff --git a/tests/functions-fixtures/push-reject-tsconfig/fixture.json b/tests/functions-fixtures/push-reject-tsconfig/fixture.json new file mode 100644 index 0000000..c65a197 --- /dev/null +++ b/tests/functions-fixtures/push-reject-tsconfig/fixture.json @@ -0,0 +1,5 @@ +{ + "command": ["functions", "push", "--tsconfig", "./tsconfig.json"], + "expect_success": false, + "stderr_contains": ["--tsconfig"] +} diff --git a/tests/functions-fixtures/push-reject-type/fixture.json b/tests/functions-fixtures/push-reject-type/fixture.json new file mode 100644 index 0000000..e1579b1 --- /dev/null +++ b/tests/functions-fixtures/push-reject-type/fixture.json @@ -0,0 +1,5 @@ +{ + "command": ["functions", "push", "--type", "tool"], + "expect_success": false, + "stderr_contains": ["--type"] +} diff --git a/tests/functions-fixtures/push-valid-bool-env/fixture.json b/tests/functions-fixtures/push-valid-bool-env/fixture.json new file mode 100644 index 0000000..51c3ce0 --- /dev/null +++ b/tests/functions-fixtures/push-valid-bool-env/fixture.json @@ -0,0 +1,9 @@ +{ + "command": ["functions", "push", "--language", "typescript"], + "env": { + "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS": "true" + }, + "expect_success": false, + "stderr_contains": ["invalid value 'typescript'"], + "stderr_not_contains": ["value was not a boolean"] +} diff --git a/tests/functions-fixtures/push-valid-language-auto-parses/fixture.json b/tests/functions-fixtures/push-valid-language-auto-parses/fixture.json new file mode 100644 index 0000000..1ce328f --- /dev/null +++ b/tests/functions-fixtures/push-valid-language-auto-parses/fixture.json @@ -0,0 +1,6 @@ +{ + "command": ["functions", "push", "--language", "auto", "--type", "tool"], + "expect_success": false, + "stderr_contains": ["--type"], + "stderr_not_contains": ["invalid value 'auto'"] +} diff --git a/tests/functions-fixtures/push-valid-language-javascript-parses/fixture.json b/tests/functions-fixtures/push-valid-language-javascript-parses/fixture.json new file mode 100644 index 0000000..b95a673 --- /dev/null +++ b/tests/functions-fixtures/push-valid-language-javascript-parses/fixture.json @@ -0,0 +1,13 @@ +{ + "command": [ + "functions", + "push", + "--language", + "javascript", + "--type", + "tool" + ], + "expect_success": false, + "stderr_contains": ["--type"], + "stderr_not_contains": ["invalid value 'javascript'"] +} diff --git a/tests/functions-fixtures/push-valid-language-python-parses/fixture.json b/tests/functions-fixtures/push-valid-language-python-parses/fixture.json new file mode 100644 index 0000000..f684eed --- /dev/null +++ 
b/tests/functions-fixtures/push-valid-language-python-parses/fixture.json
@@ -0,0 +1,6 @@
+{
+  "command": ["functions", "push", "--language", "python", "--type", "tool"],
+  "expect_success": false,
+  "stderr_contains": ["--type"],
+  "stderr_not_contains": ["invalid value 'python'"]
+}
diff --git a/tests/functions.rs b/tests/functions.rs
new file mode 100644
index 0000000..c1f62f4
--- /dev/null
+++ b/tests/functions.rs
@@ -0,0 +1,1404 @@
+use std::collections::BTreeMap;
+use std::fs;
+use std::io::Read;
+use std::net::TcpListener;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+use std::sync::{Arc, Mutex};
+
+use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer};
+use flate2::read::GzDecoder;
+use serde::Deserialize;
+use serde_json::Value;
+use tempfile::tempdir;
+
+#[derive(Debug, Deserialize)]
+struct FixtureConfig {
+    command: Vec<String>,
+    #[serde(default)]
+    env: BTreeMap<String, String>,
+    #[serde(default = "default_expect_success")]
+    expect_success: bool,
+    #[serde(default)]
+    stdout_contains: Vec<String>,
+    #[serde(default)]
+    stderr_contains: Vec<String>,
+    #[serde(default)]
+    stdout_not_contains: Vec<String>,
+    #[serde(default)]
+    stderr_not_contains: Vec<String>,
+    #[serde(default)]
+    live: bool,
+    #[serde(default)]
+    required_env: Vec<String>,
+}
+
+fn default_expect_success() -> bool {
+    true
+}
+
+fn repo_root() -> PathBuf {
+    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+}
+
+fn bt_binary_path() -> PathBuf {
+    if let Ok(path) = std::env::var("CARGO_BIN_EXE_bt") {
+        return PathBuf::from(path);
+    }
+
+    let root = repo_root();
+    let candidate = root.join("target").join("debug").join("bt");
+    if !candidate.is_file() {
+        build_bt_binary(&root);
+    }
+    candidate
+}
+
+fn build_bt_binary(root: &Path) {
+    let status = Command::new("cargo")
+        .args(["build", "--bin", "bt"])
+        .current_dir(root)
+        .status()
+        .expect("cargo build --bin bt");
+    if !status.success() {
+        panic!("cargo build --bin bt failed");
+    }
+}
+
+fn find_python() -> Option<String> {
+    for candidate in ["python3", "python"] {
+        let Ok(status) = Command::new(candidate).arg("--version").status() else {
+            continue;
+        };
+        if status.success() {
+            return Some(candidate.to_string());
+        }
+    }
+    None
+}
+
+fn command_exists(command: &str) -> bool {
+    let Some(paths) = std::env::var_os("PATH") else {
+        return false;
+    };
+
+    for dir in std::env::split_paths(&paths) {
+        let candidate = dir.join(command);
+        if candidate.is_file() {
+            return true;
+        }
+        if cfg!(windows) {
+            let exe = candidate.with_extension("exe");
+            if exe.is_file() {
+                return true;
+            }
+            let cmd = candidate.with_extension("cmd");
+            if cmd.is_file() {
+                return true;
+            }
+        }
+    }
+
+    false
+}
+
+fn find_tsc() -> Option<PathBuf> {
+    let local = if cfg!(windows) {
+        repo_root()
+            .join("node_modules")
+            .join(".bin")
+            .join("tsc.cmd")
+    } else {
+        repo_root().join("node_modules").join(".bin").join("tsc")
+    };
+    if local.is_file() {
+        return Some(local);
+    }
+
+    if command_exists("tsc") {
+        return Some(PathBuf::from("tsc"));
+    }
+
+    None
+}
+
+fn read_fixture_config(path: &Path) -> FixtureConfig {
+    let raw = fs::read_to_string(path).expect("read fixture.json");
+    serde_json::from_str(&raw).expect("parse fixture.json")
+}
+
+fn env_flag(name: &str) -> bool {
+    match std::env::var(name) {
+        Ok(value) => matches!(
+            value.trim().to_ascii_lowercase().as_str(),
+            "1" | "true" | "yes" | "on"
+        ),
+        Err(_) => false,
+    }
+}
+
+fn sanitized_env_keys() -> &'static [&'static str] {
+    &[
+        "BT_FUNCTIONS_PUSH_FILES",
+        "BT_FUNCTIONS_PUSH_IF_EXISTS",
+        "BT_FUNCTIONS_PUSH_TERMINATE_ON_FAILURE",
+        "BT_FUNCTIONS_PUSH_RUNNER",
+        "BT_FUNCTIONS_PUSH_LANGUAGE",
+        "BT_FUNCTIONS_PUSH_REQUIREMENTS",
+        "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS",
+        "BT_FUNCTIONS_PULL_OUTPUT_DIR",
+        "BT_FUNCTIONS_PULL_PROJECT_ID",
+        "BT_FUNCTIONS_PULL_PROJECT_NAME",
+        "BT_FUNCTIONS_PULL_ID",
+        "BT_FUNCTIONS_PULL_SLUG",
+        "BT_FUNCTIONS_PULL_FORCE",
+        "BT_FUNCTIONS_PULL_LANGUAGE",
+    ]
+}
+
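+// Minimal in-memory stand-in for the Braintrust backend: just enough state
+// for the push and pull flows exercised below (projects, uploaded bundles,
+// and inserted/listed function rows).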
"BT_FUNCTIONS_PUSH_LANGUAGE", + "BT_FUNCTIONS_PUSH_REQUIREMENTS", + "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", + "BT_FUNCTIONS_PULL_OUTPUT_DIR", + "BT_FUNCTIONS_PULL_PROJECT_ID", + "BT_FUNCTIONS_PULL_PROJECT_NAME", + "BT_FUNCTIONS_PULL_ID", + "BT_FUNCTIONS_PULL_SLUG", + "BT_FUNCTIONS_PULL_FORCE", + "BT_FUNCTIONS_PULL_LANGUAGE", + ] +} + +#[derive(Debug, Clone)] +struct MockProject { + id: String, + name: String, + org_id: String, +} + +#[derive(Default)] +struct MockServerState { + requests: Mutex>, + projects: Mutex>, + pull_rows: Mutex>, + uploaded_bundles: Mutex>>, + inserted_functions: Mutex>, + bundle_counter: Mutex, +} + +struct MockServer { + base_url: String, + handle: actix_web::dev::ServerHandle, +} + +impl MockServer { + async fn start(state: Arc) -> Self { + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind mock server"); + let addr = listener.local_addr().expect("mock server addr"); + let base_url = format!("http://{addr}"); + let data = web::Data::new(state.clone()); + + let server = HttpServer::new(move || { + App::new() + .app_data(data.clone()) + .route("/api/apikey/login", web::post().to(mock_login)) + .route("/v1/project", web::get().to(mock_list_projects)) + .route("/v1/project", web::post().to(mock_create_project)) + .route("/function/code", web::post().to(mock_request_code_slot)) + .route("/upload/{bundle_id}", web::put().to(mock_upload_bundle)) + .route("/insert-functions", web::post().to(mock_insert_functions)) + .route("/v1/function", web::get().to(mock_list_functions)) + }) + .listen(listener) + .expect("listen mock server") + .run(); + let handle = server.handle(); + tokio::spawn(server); + + Self { base_url, handle } + } + + async fn stop(&self) { + self.handle.stop(true).await; + } +} + +async fn mock_login(state: web::Data>, req: HttpRequest) -> HttpResponse { + log_request(&state, &req); + let base = request_base_url(&req); + HttpResponse::Ok().json(serde_json::json!({ + "org_info": [ + { + "id": "org_mock", + "name": "test-org", + "api_url": base + } + ] + })) +} + +async fn mock_list_projects( + state: web::Data>, + req: HttpRequest, +) -> HttpResponse { + log_request(&state, &req); + let query = parse_query(req.query_string()); + let requested_name = query.get("project_name").cloned(); + let projects = state.projects.lock().expect("projects lock").clone(); + let objects = projects + .into_iter() + .filter(|project| { + requested_name + .as_deref() + .is_none_or(|name| project.name == name) + }) + .map(|project| { + serde_json::json!({ + "id": project.id, + "name": project.name, + "org_id": project.org_id + }) + }) + .collect::>(); + HttpResponse::Ok().json(serde_json::json!({ "objects": objects })) +} + +#[derive(Deserialize)] +struct CreateProjectRequest { + name: String, + org_name: String, +} + +async fn mock_create_project( + state: web::Data>, + req: HttpRequest, + body: web::Json, +) -> HttpResponse { + log_request(&state, &req); + let mut projects = state.projects.lock().expect("projects lock"); + if let Some(existing) = projects.iter().find(|project| project.name == body.name) { + return HttpResponse::Ok().json(serde_json::json!({ + "id": existing.id, + "name": existing.name, + "org_id": existing.org_id + })); + } + + let created = MockProject { + id: format!("proj_created_{}", projects.len() + 1), + name: body.name.clone(), + org_id: body.org_name.clone(), + }; + projects.push(created.clone()); + HttpResponse::Ok().json(serde_json::json!({ + "id": created.id, + "name": created.name, + "org_id": created.org_id + })) +} + +async fn 
+async fn mock_request_code_slot(
+    state: web::Data<Arc<MockServerState>>,
+    req: HttpRequest,
+) -> HttpResponse {
+    log_request(&state, &req);
+    let mut counter = state.bundle_counter.lock().expect("bundle counter lock");
+    *counter += 1;
+    let bundle_id = format!("bundle-{counter}");
+    let base = request_base_url(&req);
+    let upload_url = format!("{base}/upload/{bundle_id}");
+    HttpResponse::Ok().json(serde_json::json!({
+        "url": upload_url,
+        "bundleId": bundle_id
+    }))
+}
+
+async fn mock_upload_bundle(
+    state: web::Data<Arc<MockServerState>>,
+    req: HttpRequest,
+    body: web::Bytes,
+) -> HttpResponse {
+    log_request(&state, &req);
+    state
+        .uploaded_bundles
+        .lock()
+        .expect("uploaded bundles lock")
+        .push(body.to_vec());
+    HttpResponse::Ok().finish()
+}
+
+#[derive(Deserialize)]
+struct InsertFunctionsRequest {
+    functions: Vec<Value>,
+}
+
+async fn mock_insert_functions(
+    state: web::Data<Arc<MockServerState>>,
+    req: HttpRequest,
+    body: web::Json<InsertFunctionsRequest>,
+) -> HttpResponse {
+    log_request(&state, &req);
+    let mut inserted = state
+        .inserted_functions
+        .lock()
+        .expect("inserted functions lock");
+    inserted.extend(body.functions.clone());
+    HttpResponse::Ok().json(serde_json::json!({ "ignored_count": 0 }))
+}
+
+async fn mock_list_functions(
+    state: web::Data<Arc<MockServerState>>,
+    req: HttpRequest,
+) -> HttpResponse {
+    log_request(&state, &req);
+    let query = parse_query(req.query_string());
+    let id = query.get("ids").cloned();
+    let slug = query.get("slug").cloned();
+    let project_id = query.get("project_id").cloned();
+
+    let rows = state.pull_rows.lock().expect("pull rows lock").clone();
+    let filtered = rows
+        .into_iter()
+        .filter(|row| {
+            id.as_deref()
+                .is_none_or(|needle| row.get("id").and_then(Value::as_str) == Some(needle))
+        })
+        .filter(|row| {
+            slug.as_deref()
+                .is_none_or(|needle| row.get("slug").and_then(Value::as_str) == Some(needle))
+        })
+        .filter(|row| {
+            project_id
+                .as_deref()
+                .is_none_or(|needle| row.get("project_id").and_then(Value::as_str) == Some(needle))
+        })
+        .collect::<Vec<_>>();
+
+    HttpResponse::Ok().json(serde_json::json!({
+        "objects": filtered
+    }))
+}
+
+fn log_request(state: &Arc<MockServerState>, req: &HttpRequest) {
+    let entry = if req.query_string().is_empty() {
+        req.path().to_string()
+    } else {
+        format!("{}?{}", req.path(), req.query_string())
+    };
+    state.requests.lock().expect("requests lock").push(entry);
+}
+
+fn request_base_url(req: &HttpRequest) -> String {
+    let info = req.connection_info();
+    format!("{}://{}", info.scheme(), info.host())
+}
+
+fn parse_query(query: &str) -> BTreeMap<String, String> {
+    let mut values = BTreeMap::new();
+    for pair in query.split('&') {
+        if pair.is_empty() {
+            continue;
+        }
+        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
+        let key = urlencoding::decode(raw_key)
+            .map(|value| value.into_owned())
+            .unwrap_or_else(|_| raw_key.to_string());
+        let value = urlencoding::decode(raw_value)
+            .map(|value| value.into_owned())
+            .unwrap_or_else(|_| raw_value.to_string());
+        values.insert(key, value);
+    }
+    values
+}
+
+#[test]
+fn functions_fixtures() {
+    let root = repo_root();
+    let fixtures_root = root.join("tests").join("functions-fixtures");
+    if !fixtures_root.exists() {
+        eprintln!("No functions fixtures found.");
+        return;
+    }
+
+    let bt_path = bt_binary_path();
+    let run_live = env_flag("BT_FUNCTIONS_FIXTURE_LIVE");
+
+    let mut fixture_dirs: Vec<PathBuf> = fs::read_dir(&fixtures_root)
+        .expect("read functions fixture root")
+        .filter_map(|entry| entry.ok())
+        .map(|entry| entry.path())
+        .filter(|path| path.is_dir())
+        .collect();
+    fixture_dirs.sort();
+
+    let mut ran_any = false;
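+    // Each fixture directory is self-describing: its fixture.json declares
+    // the command, env, and expectations, so new cases need no Rust changes.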
+    for dir in fixture_dirs {
+        let config_path = dir.join("fixture.json");
+        if !config_path.is_file() {
+            continue;
+        }
+        ran_any = true;
+
+        let fixture_name = dir
+            .file_name()
+            .map(|name| name.to_string_lossy().to_string())
+            .expect("fixture directory name");
+        let config = read_fixture_config(&config_path);
+        if config.command.is_empty() {
+            panic!("Fixture {fixture_name} has an empty command.");
+        }
+
+        if config.live && !run_live {
+            eprintln!("Skipping {fixture_name} (live fixture; set BT_FUNCTIONS_FIXTURE_LIVE=1).");
+            continue;
+        }
+
+        let missing_required: Vec<String> = config
+            .required_env
+            .iter()
+            .filter(|key| std::env::var(key.as_str()).is_err())
+            .cloned()
+            .collect();
+        if !missing_required.is_empty() {
+            eprintln!(
+                "Skipping {fixture_name} (missing required env: {}).",
+                missing_required.join(", ")
+            );
+            continue;
+        }
+
+        let mut cmd = Command::new(&bt_path);
+        cmd.args(&config.command).current_dir(&dir);
+        for key in sanitized_env_keys() {
+            cmd.env_remove(key);
+        }
+        for (key, value) in &config.env {
+            cmd.env(key, value);
+        }
+
+        let output = cmd
+            .output()
+            .unwrap_or_else(|err| panic!("failed to run fixture {fixture_name}: {err}"));
+        let stdout = String::from_utf8_lossy(&output.stdout);
+        let stderr = String::from_utf8_lossy(&output.stderr);
+        if output.status.success() != config.expect_success {
+            panic!(
+                "Fixture {fixture_name} command {:?} had status {} (expected success={})\nstdout:\n{}\nstderr:\n{}",
+                config.command,
+                output.status,
+                config.expect_success,
+                stdout,
+                stderr
+            );
+        }
+
+        for expected in &config.stdout_contains {
+            assert!(
+                stdout.contains(expected),
+                "Fixture {fixture_name}: stdout missing expected text: {expected}\nstdout:\n{stdout}"
+            );
+        }
+        for expected in &config.stderr_contains {
+            assert!(
+                stderr.contains(expected),
+                "Fixture {fixture_name}: stderr missing expected text: {expected}\nstderr:\n{stderr}"
+            );
+        }
+        for unexpected in &config.stdout_not_contains {
+            assert!(
+                !stdout.contains(unexpected),
+                "Fixture {fixture_name}: stdout unexpectedly contained text: {unexpected}\nstdout:\n{stdout}"
+            );
+        }
+        for unexpected in &config.stderr_not_contains {
+            assert!(
+                !stderr.contains(unexpected),
+                "Fixture {fixture_name}: stderr unexpectedly contained text: {unexpected}\nstderr:\n{stderr}"
+            );
+        }
+    }
+
+    if !ran_any {
+        eprintln!("No functions fixtures with fixture.json found.");
+    }
+}
+
+#[test]
+fn functions_push_help_includes_expected_flags() {
+    let output = Command::new(bt_binary_path())
+        .arg("functions")
+        .arg("push")
+        .arg("--help")
+        .output()
+        .expect("run bt functions push --help");
+
+    assert!(output.status.success());
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    assert!(stdout.contains("--file"));
+    assert!(stdout.contains("--if-exists"));
+    assert!(stdout.contains("--terminate-on-failure"));
+    assert!(stdout.contains("--create-missing-projects"));
+    assert!(stdout.contains("--language"));
+    assert!(stdout.contains("--requirements"));
+}
+
+#[test]
+fn functions_pull_help_includes_expected_flags() {
+    let output = Command::new(bt_binary_path())
+        .arg("functions")
+        .arg("pull")
+        .arg("--help")
+        .output()
+        .expect("run bt functions pull --help");
+
+    assert!(output.status.success());
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    assert!(stdout.contains("--output-dir"));
+    assert!(stdout.contains("--project-id"));
+    assert!(stdout.contains("--project-name"));
+    assert!(stdout.contains("--language"));
+}
+
+#[test]
+fn functions_pull_id_and_slug_conflict() {
+    let output =
Command::new(bt_binary_path()) + .arg("functions") + .arg("pull") + .arg("--id") + .arg("abc") + .arg("--slug") + .arg("slug") + .output() + .expect("run conflicting pull command"); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("--id")); + assert!(stderr.contains("--slug")); +} + +#[test] +fn functions_push_rejects_type_flag() { + let output = Command::new(bt_binary_path()) + .arg("functions") + .arg("push") + .arg("--type") + .arg("tool") + .output() + .expect("run push with invalid --type"); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("--type")); +} + +#[test] +fn functions_pull_rejects_invalid_language() { + let output = Command::new(bt_binary_path()) + .arg("functions") + .arg("pull") + .arg("--language") + .arg("ruby") + .output() + .expect("run pull with invalid language"); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("ruby")); +} + +#[test] +fn functions_push_rejects_invalid_language() { + let output = Command::new(bt_binary_path()) + .arg("functions") + .arg("push") + .arg("--language") + .arg("typescript") + .output() + .expect("run push with invalid language"); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("typescript")); +} + +#[test] +fn functions_help_lists_push_and_pull() { + let output = Command::new(bt_binary_path()) + .arg("functions") + .arg("--help") + .output() + .expect("run bt functions --help"); + + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("push")); + assert!(stdout.contains("pull")); +} + +#[test] +fn push_and_pull_help_are_machine_readable() { + let push_help = Command::new(bt_binary_path()) + .arg("functions") + .arg("push") + .arg("--help") + .output() + .expect("run push help"); + assert!(push_help.status.success()); + + let pull_help = Command::new(bt_binary_path()) + .arg("functions") + .arg("pull") + .arg("--help") + .output() + .expect("run pull help"); + assert!(pull_help.status.success()); + + let push_stdout = String::from_utf8_lossy(&push_help.stdout); + let pull_stdout = String::from_utf8_lossy(&pull_help.stdout); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_FILES")); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS")); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_LANGUAGE")); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_REQUIREMENTS")); + assert!(pull_stdout.contains("BT_FUNCTIONS_PULL_OUTPUT_DIR")); + assert!(pull_stdout.contains("BT_FUNCTIONS_PULL_LANGUAGE")); +} + +#[test] +fn functions_python_runner_scripts_compile_when_python_available() { + let Some(python) = find_python() else { + eprintln!( + "Skipping functions_python_runner_scripts_compile_when_python_available (python not installed)." 
+ );
+ return;
+ };
+
+ let root = repo_root();
+ let output = Command::new(&python)
+ .arg("-m")
+ .arg("py_compile")
+ .arg(root.join("scripts").join("functions-runner.py"))
+ .arg(root.join("scripts").join("python_runner_common.py"))
+ .output()
+ .expect("run py_compile");
+
+ if !output.status.success() {
+ let stderr = String::from_utf8_lossy(&output.stderr);
+ panic!("Python runner scripts failed py_compile:\n{stderr}");
+ }
+}
+
+#[test]
+fn functions_python_runner_collects_function_type_from_type_() {
+ let Some(python) = find_python() else {
+ eprintln!(
+ "Skipping functions_python_runner_collects_function_type_from_type_ (python not installed)."
+ );
+ return;
+ };
+
+ let root = repo_root();
+ let scripts_dir = root.join("scripts");
+ let runner_script = scripts_dir.join("functions-runner.py");
+ let snippet = r#"
+import importlib.util
+import json
+import pathlib
+import sys
+
+runner_path = pathlib.Path(sys.argv[1])
+spec = importlib.util.spec_from_file_location("functions_runner", runner_path)
+if spec is None or spec.loader is None:
+    raise RuntimeError(f"failed to load {runner_path}")
+module = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(module)
+
+class TypeEnum:
+    value = "tool"
+
+class Item:
+    def __init__(self):
+        self.name = "my-tool"
+        self.slug = "my-tool"
+        self.type_ = TypeEnum()
+        self.preview = "def handler(x):\\n return x"
+
+entries = module.collect_code_entries([Item()])
+print(json.dumps(entries))
+"#;
+
+ let output = Command::new(&python)
+ .env("PYTHONPATH", &scripts_dir)
+ .arg("-c")
+ .arg(snippet)
+ .arg(&runner_script)
+ .output()
+ .expect("run functions-runner collect_code_entries regression script");
+
+ if !output.status.success() {
+ let stderr = String::from_utf8_lossy(&output.stderr);
+ panic!("Python runner regression script failed:\n{stderr}");
+ }
+
+ let stdout = String::from_utf8(output.stdout).expect("stdout utf-8");
+ let entries: Vec<Value> =
+ serde_json::from_str(stdout.trim()).expect("parse entries JSON from regression script");
+ let first = entries.first().expect("first entry");
+ assert_eq!(
+ first.get("function_type").and_then(Value::as_str),
+ Some("tool")
+ );
+ assert_eq!(
+ first.get("preview").and_then(Value::as_str),
+ Some("def handler(x):\\n return x")
+ );
+}
+
+#[test]
+fn functions_js_runner_emits_valid_manifest() {
+ if !command_exists("node") {
+ eprintln!("Skipping functions_js_runner_emits_valid_manifest (node not installed).");
+ return;
+ }
+ let Some(tsc) = find_tsc() else {
+ eprintln!("Skipping functions_js_runner_emits_valid_manifest (tsc not installed).");
+ return;
+ };
+
+ let root = repo_root();
+ let tmp = tempdir().expect("tempdir");
+ let sample_path = tmp.path().join("sample.js");
+ std::fs::write(
+ &sample_path,
+ r#"globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };
+globalThis._evals.functions.push({
+ name: "js-tool",
+ slug: "js-tool",
+ type: "tool",
+ preview: "export function handler() { return 1; }"
+});
+"#,
+ )
+ .expect("write sample.js");
+
+ let runner_dir = tmp.path().join("runner");
+ let compile_output = Command::new(&tsc)
+ .current_dir(&root)
+ .args([
+ "scripts/functions-runner.ts",
+ "scripts/runner-common.ts",
+ "--module",
+ "esnext",
+ "--target",
+ "es2020",
+ "--moduleResolution",
+ "bundler",
+ "--outDir",
+ ])
+ .arg(&runner_dir)
+ .output()
+ .expect("compile functions runner");
+ if !compile_output.status.success() {
+ let stderr = String::from_utf8_lossy(&compile_output.stderr);
+ panic!("tsc failed for
functions runner:\n{stderr}"); + } + + let runner_js = runner_dir.join("functions-runner.js"); + let runner_common_js = runner_dir.join("runner-common.js"); + assert!(runner_js.is_file(), "compiled functions-runner.js missing"); + assert!( + runner_common_js.is_file(), + "compiled runner-common.js missing" + ); + + let runner_code = std::fs::read_to_string(&runner_js).expect("read compiled runner"); + let patched_runner_code = runner_code + .replace("\"./runner-common\"", "\"./runner-common.js\"") + .replace("'./runner-common'", "'./runner-common.js'"); + assert_ne!( + runner_code, patched_runner_code, + "compiled runner import path did not contain ./runner-common" + ); + std::fs::write(&runner_js, patched_runner_code).expect("write patched compiled runner"); + std::fs::write(runner_dir.join("package.json"), r#"{ "type": "module" }"#) + .expect("write runner package.json"); + + let output = Command::new("node") + .arg(&runner_js) + .arg(&sample_path) + .output() + .expect("run compiled functions runner"); + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("compiled functions runner failed:\n{stderr}"); + } + + let manifest: Value = serde_json::from_slice(&output.stdout).expect("parse manifest JSON"); + assert_eq!( + manifest["runtime_context"]["runtime"].as_str(), + Some("node"), + "runtime_context.runtime should be node" + ); + assert!( + manifest["runtime_context"]["version"] + .as_str() + .is_some_and(|value| !value.trim().is_empty()), + "runtime_context.version should be present" + ); + + let files = manifest["files"].as_array().expect("files array"); + assert_eq!(files.len(), 1, "expected one manifest file"); + let file = files[0].as_object().expect("manifest file object"); + let reported_source = PathBuf::from( + file.get("source_file") + .and_then(Value::as_str) + .expect("source_file"), + ); + assert_eq!( + reported_source + .canonicalize() + .expect("canonicalize source_file"), + sample_path + .canonicalize() + .expect("canonicalize sample file"), + "manifest source_file mismatch" + ); + assert!( + file.get("python_bundle").is_none(), + "JS runner should not emit python_bundle" + ); + + let entries = file + .get("entries") + .and_then(Value::as_array) + .expect("entries array"); + assert_eq!(entries.len(), 1, "expected one code entry"); + let entry = entries[0].as_object().expect("entry object"); + assert_eq!(entry.get("kind").and_then(Value::as_str), Some("code")); + assert_eq!(entry.get("name").and_then(Value::as_str), Some("js-tool")); + assert_eq!(entry.get("slug").and_then(Value::as_str), Some("js-tool")); + assert_eq!( + entry.get("function_type").and_then(Value::as_str), + Some("tool") + ); + assert_eq!( + entry.get("preview").and_then(Value::as_str), + Some("export function handler() { return 1; }") + ); + assert_eq!( + entry + .get("location") + .and_then(Value::as_object) + .and_then(|value| value.get("type")) + .and_then(Value::as_str), + Some("function") + ); +} + +#[test] +fn functions_python_runner_emits_valid_manifest_with_bundle() { + let Some(python) = find_python() else { + eprintln!( + "Skipping functions_python_runner_emits_valid_manifest_with_bundle (python not installed)." 
+ );
+ return;
+ };
+
+ let root = repo_root();
+ let scripts_dir = root.join("scripts");
+ let runner_script = scripts_dir.join("functions-runner.py");
+ let tmp = tempdir().expect("tempdir");
+ let stub_root = tmp.path().join("stub");
+ let framework_dir = stub_root.join("braintrust").join("framework2");
+ std::fs::create_dir_all(&framework_dir).expect("create stub framework dir");
+ std::fs::write(stub_root.join("braintrust").join("__init__.py"), "").expect("write __init__");
+ std::fs::write(framework_dir.join("__init__.py"), "").expect("write framework __init__");
+ std::fs::write(
+ framework_dir.join("global_.py"),
+ "functions = []\nprompts = []\n",
+ )
+ .expect("write global_.py");
+ std::fs::write(
+ framework_dir.join("lazy_load.py"),
+ "from contextlib import nullcontext\n\ndef _set_lazy_load(_enabled):\n return nullcontext()\n",
+ )
+ .expect("write lazy_load.py");
+
+ let sample_path = tmp.path().join("sample_tool.py");
+ std::fs::write(
+ &sample_path,
+ r#"from braintrust.framework2.global_ import functions
+
+class TypeEnum:
+    value = "tool"
+
+class Item:
+    def __init__(self):
+        self.name = "py-tool"
+        self.slug = "py-tool"
+        self.type_ = TypeEnum()
+        self.preview = "def handler(x):\n return x"
+
+functions.append(Item())
+"#,
+ )
+ .expect("write sample_tool.py");
+
+ let mut python_path_entries = vec![stub_root.clone()];
+ if let Some(existing) = std::env::var_os("PYTHONPATH") {
+ python_path_entries.extend(std::env::split_paths(&existing));
+ }
+ let python_path = std::env::join_paths(python_path_entries).expect("join PYTHONPATH");
+
+ let output = Command::new(&python)
+ .env("PYTHONPATH", python_path)
+ .arg(&runner_script)
+ .arg(&sample_path)
+ .output()
+ .expect("run python functions runner");
+ if !output.status.success() {
+ let stderr = String::from_utf8_lossy(&output.stderr);
+ panic!("python functions runner failed:\n{stderr}");
+ }
+
+ let manifest: Value = serde_json::from_slice(&output.stdout).expect("parse manifest JSON");
+ assert_eq!(
+ manifest["runtime_context"]["runtime"].as_str(),
+ Some("python"),
+ "runtime_context.runtime should be python"
+ );
+ assert!(
+ manifest["runtime_context"]["version"]
+ .as_str()
+ .is_some_and(|value| !value.trim().is_empty()),
+ "runtime_context.version should be present"
+ );
+
+ let files = manifest["files"].as_array().expect("files array");
+ assert_eq!(files.len(), 1, "expected one manifest file");
+ let file = files[0].as_object().expect("manifest file object");
+ let reported_source = PathBuf::from(
+ file.get("source_file")
+ .and_then(Value::as_str)
+ .expect("source_file"),
+ );
+ let expected_source = sample_path
+ .canonicalize()
+ .expect("canonicalize sample file");
+ assert_eq!(
+ reported_source
+ .canonicalize()
+ .expect("canonicalize source_file"),
+ expected_source,
+ "manifest source_file mismatch"
+ );
+
+ let entries = file
+ .get("entries")
+ .and_then(Value::as_array)
+ .expect("entries array");
+ assert_eq!(entries.len(), 1, "expected one code entry");
+ let entry = entries[0].as_object().expect("entry object");
+ assert_eq!(entry.get("kind").and_then(Value::as_str), Some("code"));
+ assert_eq!(entry.get("name").and_then(Value::as_str), Some("py-tool"));
+ assert_eq!(entry.get("slug").and_then(Value::as_str), Some("py-tool"));
+ assert_eq!(
+ entry.get("function_type").and_then(Value::as_str),
+ Some("tool")
+ );
+ assert_eq!(
+ entry.get("preview").and_then(Value::as_str),
+ Some("def handler(x):\n return x")
+ );
+
+ let bundle = file
+ .get("python_bundle")
+ .and_then(Value::as_object)
+ .expect("python_bundle object"); + assert!( + bundle + .get("entry_module") + .and_then(Value::as_str) + .is_some_and(|value| !value.trim().is_empty()), + "python_bundle.entry_module should be present" + ); + let sources = bundle + .get("sources") + .and_then(Value::as_array) + .expect("python_bundle.sources array"); + assert!( + !sources.is_empty(), + "python_bundle.sources should include source files" + ); + let source_paths = sources + .iter() + .filter_map(Value::as_str) + .map(PathBuf::from) + .map(|path| path.canonicalize().expect("canonicalize bundled source")) + .collect::>(); + assert!( + source_paths.contains(&expected_source), + "python_bundle.sources should include sample file" + ); +} + +#[test] +fn python_runner_common_purge_prevents_cross_file_source_leakage() { + let Some(python) = find_python() else { + eprintln!( + "Skipping python_runner_common_purge_prevents_cross_file_source_leakage (python not installed)." + ); + return; + }; + + let root = repo_root(); + let scripts_dir = root.join("scripts"); + let tmp = tempdir().expect("tempdir"); + let a_path = tmp.path().join("a.py"); + let b_path = tmp.path().join("b.py"); + std::fs::write(&a_path, "VALUE_A = 1\n").expect("write a.py"); + std::fs::write(&b_path, "VALUE_B = 2\n").expect("write b.py"); + + let snippet = r#" +import importlib.util +import json +import pathlib +import sys + +scripts_dir = pathlib.Path(sys.argv[1]) +tmp_dir = pathlib.Path(sys.argv[2]) + +common_path = scripts_dir / "python_runner_common.py" +spec = importlib.util.spec_from_file_location("python_runner_common", common_path) +if spec is None or spec.loader is None: + raise RuntimeError(f"failed to load {common_path}") +common = importlib.util.module_from_spec(spec) +spec.loader.exec_module(common) +sys.modules["python_runner_common"] = common + +cwd = str(tmp_dir) +a_path = tmp_dir / "a.py" +b_path = tmp_dir / "b.py" + +module_name_a, extra_a = common.resolve_module_info(str(a_path)) +common.import_file(module_name_a, str(a_path), extra_a) +sources_a = common.collect_python_sources(cwd, str(a_path)) + +common.purge_local_modules(cwd, preserve_modules={"__main__", "python_runner_common"}) + +module_name_b, extra_b = common.resolve_module_info(str(b_path)) +common.import_file(module_name_b, str(b_path), extra_b) +sources_b = common.collect_python_sources(cwd, str(b_path)) + +print(json.dumps({"sources_a": sources_a, "sources_b": sources_b})) +"#; + + let output = Command::new(&python) + .arg("-c") + .arg(snippet) + .arg(&scripts_dir) + .arg(tmp.path()) + .output() + .expect("run python_runner_common purge regression script"); + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("Python runner common regression script failed:\n{stderr}"); + } + + let stdout = String::from_utf8(output.stdout).expect("stdout utf-8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("parse JSON output"); + let sources_a = parsed + .get("sources_a") + .and_then(Value::as_array) + .expect("sources_a array") + .iter() + .filter_map(Value::as_str) + .collect::>(); + let sources_b = parsed + .get("sources_b") + .and_then(Value::as_array) + .expect("sources_b array") + .iter() + .filter_map(Value::as_str) + .collect::>(); + + let a_str = a_path.to_string_lossy().to_string(); + let b_str = b_path.to_string_lossy().to_string(); + assert!( + sources_a.contains(&a_str.as_str()), + "sources_a should contain a.py" + ); + assert!( + sources_b.contains(&b_str.as_str()), + "sources_b should contain b.py" + ); + assert!( + 
!sources_b.contains(&a_str.as_str()), + "sources_b should not include a.py from prior file import" + ); +} + +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_push_works_against_mock_api() { + if !command_exists("node") { + eprintln!("Skipping functions_push_works_against_mock_api (node not installed)."); + return; + } + + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let source = tmp.path().join("tool.js"); + std::fs::write( + &source, + "globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };\n", + ) + .expect("write source file"); + + let runner = tmp.path().join("mock-runner.sh"); + std::fs::write( + &runner, + r#"#!/bin/sh +set -eu +_runner_script="$1" +shift +node - "$@" <<'NODE' +const path = require("node:path"); +const files = process.argv.slice(2); +const manifest = { + runtime_context: { runtime: "node", version: process.versions.node || "unknown" }, + files: files.map((file, index) => ({ + source_file: path.resolve(file), + entries: [ + { + kind: "code", + project_id: "proj_mock", + name: index === 0 ? "mock-tool" : `mock-tool-${index}`, + slug: index === 0 ? "mock-tool" : `mock-tool-${index}`, + function_type: "tool", + preview: "function handler() { return 1; }", + location: { type: "function", index: 0 } + } + ] + })) +}; +process.stdout.write(JSON.stringify(manifest)); +NODE +"#, + ) + .expect("write mock runner"); + use std::os::unix::fs::PermissionsExt; + let mut perms = std::fs::metadata(&runner) + .expect("runner metadata") + .permissions(); + perms.set_mode(0o755); + std::fs::set_permissions(&runner, perms).expect("runner permissions"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "push", + "--file", + source + .to_str() + .expect("source path should be valid UTF-8 for test"), + "--language", + "javascript", + "--runner", + runner + .to_str() + .expect("runner path should be valid UTF-8 for test"), + "--if-exists", + "replace", + ]) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions push"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock push failed:\n{stderr}"); + } + + let summary: Value = serde_json::from_slice(&output.stdout).expect("parse push summary"); + assert_eq!(summary["status"].as_str(), Some("success")); + assert_eq!(summary["uploaded_files"].as_u64(), Some(1)); + assert_eq!(summary["failed_files"].as_u64(), Some(0)); + + let inserted = state + .inserted_functions + .lock() + .expect("inserted functions lock") + .clone(); + assert_eq!(inserted.len(), 1, "exactly one function should be inserted"); + let first = inserted[0].as_object().expect("inserted function object"); + assert_eq!( + first.get("project_id").and_then(Value::as_str), + Some("proj_mock") + ); + assert_eq!(first.get("slug").and_then(Value::as_str), Some("mock-tool")); + assert_eq!( + first.get("function_type").and_then(Value::as_str), + 
Some("tool") + ); + let function_data = first + .get("function_data") + .and_then(Value::as_object) + .expect("function_data object"); + assert_eq!( + function_data.get("type").and_then(Value::as_str), + Some("code"), + "function_data.type must be code" + ); + let data = function_data + .get("data") + .and_then(Value::as_object) + .expect("function_data.data object"); + assert_eq!(data.get("type").and_then(Value::as_str), Some("bundle")); + assert_eq!( + data.get("preview").and_then(Value::as_str), + Some("function handler() { return 1; }") + ); + + let uploaded = state + .uploaded_bundles + .lock() + .expect("uploaded bundles lock") + .clone(); + assert_eq!(uploaded.len(), 1, "expected one uploaded bundle"); + let bundle = &uploaded[0]; + if bundle.starts_with(&[0x1f, 0x8b]) { + let mut decoder = GzDecoder::new(bundle.as_slice()); + let mut decompressed = String::new(); + decoder + .read_to_string(&mut decompressed) + .expect("decompress uploaded bundle"); + assert!( + decompressed.contains("globalThis._evals"), + "uploaded bundle should contain original source" + ); + } else { + let raw = String::from_utf8(bundle.clone()).expect("uploaded bundle utf8"); + assert!( + raw.contains("globalThis._evals"), + "uploaded bundle should contain original source" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_pull_works_against_mock_api() { + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + state + .pull_rows + .lock() + .expect("pull rows lock") + .push(serde_json::json!({ + "id": "fn_123", + "name": "Doc Search", + "slug": "doc-search", + "project_id": "proj_mock", + "description": "", + "function_data": { "type": "prompt" }, + "prompt_data": { + "prompt": { + "type": "chat", + "messages": [ + { "role": "system", "content": "You answer from docs." 
}
+ ]
+ },
+ "options": {
+ "model": "gpt-4o-mini"
+ }
+ },
+ "_xact_id": "0000000000000001"
+ }));
+
+ let server = MockServer::start(state.clone()).await;
+
+ let tmp = tempdir().expect("tempdir");
+ let out_dir = tmp.path().join("pulled");
+ std::fs::create_dir_all(&out_dir).expect("create output dir");
+
+ let output = Command::new(bt_binary_path())
+ .current_dir(tmp.path())
+ .args([
+ "functions",
+ "--json",
+ "pull",
+ "--project-id",
+ "proj_mock",
+ "--slug",
+ "doc-search",
+ "--force",
+ "--output-dir",
+ out_dir
+ .to_str()
+ .expect("output dir should be valid UTF-8 for test"),
+ "--language",
+ "typescript",
+ ])
+ .env("BRAINTRUST_API_KEY", "test-key")
+ .env("BRAINTRUST_ORG_NAME", "test-org")
+ .env("BRAINTRUST_API_URL", &server.base_url)
+ .env("BRAINTRUST_APP_URL", &server.base_url)
+ .env("BRAINTRUST_NO_COLOR", "1")
+ .env_remove("BRAINTRUST_PROFILE")
+ .output()
+ .expect("run bt functions pull");
+
+ server.stop().await;
+
+ if !output.status.success() {
+ let stderr = String::from_utf8_lossy(&output.stderr);
+ panic!("mock pull failed:\n{stderr}");
+ }
+
+ let summary: Value = serde_json::from_slice(&output.stdout).expect("parse pull summary");
+ assert_eq!(summary["status"].as_str(), Some("success"));
+ assert_eq!(summary["files_written"].as_u64(), Some(1));
+ assert_eq!(summary["files_failed"].as_u64(), Some(0));
+
+ let rendered_file = out_dir.join("doc-search.ts");
+ assert!(rendered_file.is_file(), "expected rendered file to exist");
+ let rendered = std::fs::read_to_string(&rendered_file).expect("read rendered file");
+ assert!(
+ rendered.contains("project.prompts.create"),
+ "rendered file should materialize prompt definitions"
+ );
+ assert!(
+ rendered.contains("slug: \"doc-search\""),
+ "rendered file should include slug"
+ );
+ assert!(
+ rendered.contains("gpt-4o-mini"),
+ "rendered file should include model config"
+ );
+
+ let requests = state.requests.lock().expect("requests lock").clone();
+ assert!(
+ requests.iter().any(|entry| {
+ entry.contains("/v1/function")
+ && entry.contains("project_id=proj_mock")
+ && entry.contains("slug=doc-search")
+ }),
+ "pull request should include selector query params"
+ );
+}

From f635890157a0e710fc0f2b4cfa8e2aa329acde4e Mon Sep 17 00:00:00 2001
From: Parker Henderson
Date: Thu, 5 Mar 2026 16:56:33 -0800
Subject: [PATCH 06/28] feat(functions): add positional file arguments to push command

---
 src/functions/mod.rs | 19 +++++++++++++++++--
 src/functions/push.rs | 3 ++-
 .../push-multiple-files-accepted/fixture.json | 13 ++-----------
 3 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index 615bbaa..a556b95 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -257,15 +257,18 @@ struct FunctionsInvokeArgs {
 
 #[derive(Debug, Clone, Args)]
 pub(crate) struct PushArgs {
+ /// File or directory path(s) to scan for function definitions.
+ #[arg(value_name = "PATH")]
+ pub files: Vec<PathBuf>,
+
 /// File or directory path(s) to scan for function definitions.
 #[arg(
 long = "file",
 env = "BT_FUNCTIONS_PUSH_FILES",
- default_value = ".",
 value_name = "PATH",
 value_delimiter = ','
 )]
- pub files: Vec<PathBuf>,
+ pub file_flag: Vec<PathBuf>,
 
 /// Behavior when a function with the same slug already exists.
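+ /// For example (illustrative sketch, not part of this patch):
+ /// `bt functions push a.ts --if-exists replace` re-pushes over an
+ /// existing slug instead of erroring out.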
 #[arg(
@@ -312,6 +315,18 @@ pub(crate) struct PushArgs {
 pub create_missing_projects: bool,
 }
 
+impl PushArgs {
+ pub fn resolved_files(&self) -> Vec<PathBuf> {
+ let mut all = self.files.clone();
+ all.extend(self.file_flag.iter().cloned());
+ if all.is_empty() {
+ vec![PathBuf::from(".")]
+ } else {
+ all
+ }
+ }
+}
+
 #[derive(Debug, Clone, Args)]
 pub(crate) struct PullArgs {
 /// Destination directory for generated files.
diff --git a/src/functions/push.rs b/src/functions/push.rs
index fd864a4..805fb3b 100644
--- a/src/functions/push.rs
+++ b/src/functions/push.rs
@@ -223,7 +223,8 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> {
 }
 };
 
- let classified = match collect_classified_files(&args.files) {
+ let files = args.resolved_files();
+ let classified = match collect_classified_files(&files) {
 Ok(files) => files,
 Err(err) => {
 return fail_push(
diff --git a/tests/functions-fixtures/push-multiple-files-accepted/fixture.json b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
index 2e7af6b..3838942 100644
--- a/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
+++ b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
@@ -1,15 +1,6 @@
 {
- "command": [
- "functions",
- "push",
- "--file",
- "a.ts",
- "--file",
- "b.ts",
- "--type",
- "tool"
- ],
+ "command": ["functions", "push", "a.ts", "b.ts", "--type", "tool"],
 "expect_success": false,
 "stderr_contains": ["--type"],
- "stderr_not_contains": ["unexpected argument '--file'"]
+ "stderr_not_contains": ["unexpected argument"]
 }

From 44879072ea39c2dfcb5f02a99453de8c4c29e4dd Mon Sep 17 00:00:00 2001
From: Parker Henderson
Date: Thu, 5 Mar 2026 17:23:24 -0800
Subject: [PATCH 07/28] refactor(push): improve push confirmation prompt with file and project names

---
 src/functions/mod.rs | 4 +-
 src/functions/push.rs | 156 ++++++++++++------
 src/http.rs | 4 +
 .../push-multiple-files-accepted/fixture.json | 5 +-
 4 files changed, 116 insertions(+), 53 deletions(-)

diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index a556b95..bcfc2b5 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -637,7 +637,7 @@ mod tests {
 };
 
 assert_eq!(
- push.files,
+ push.file_flag,
 vec![PathBuf::from("a.ts"), PathBuf::from("b.ts")]
 );
 }
@@ -695,7 +695,7 @@ mod tests {
 panic!("expected push command");
 };
 assert_eq!(
- push.files,
+ push.file_flag,
 vec![
 PathBuf::from("a.ts"),
 PathBuf::from("b.ts"),
diff --git a/src/functions/push.rs b/src/functions/push.rs
index 805fb3b..091e30e 100644
--- a/src/functions/push.rs
+++ b/src/functions/push.rs
@@ -5,13 +5,15 @@ use std::process::{Command, Output};
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use anyhow::{anyhow, bail, Context, Result};
+use dialoguer::console::style;
+use dialoguer::theme::ColorfulTheme;
 use dialoguer::Confirm;
 use reqwest::StatusCode;
 use serde::Deserialize;
 use serde_json::{json, Map, Value};
 
 use crate::args::BaseArgs;
-use crate::auth::{list_available_orgs, AvailableOrg};
+use crate::auth::{list_available_orgs, list_profiles, AvailableOrg};
 use crate::config;
 use crate::functions::report::{
 CommandStatus, FileStatus, HardFailureReason, PushFileReport, PushSummary, ReportError,
@@ -136,7 +138,6 @@ struct ProjectPreflight {
 requires_default_project: bool,
 named_projects: BTreeSet<String>,
 direct_project_ids: BTreeSet<String>,
- selector_preview: Vec<String>,
 }
 
 #[derive(Debug, Clone)]
@@ -301,18 +302,19 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> {
 return Ok(());
 }
 
- let manifest = match run_functions_runner(&args, &files, selected_language) {
- Ok(manifest) => manifest,
- Err(failure) => {
- return fail_push_with_all_skipped(
- &base,
- &files,
- failure.reason,
- &failure.message,
- "skipped because manifest generation failed",
- );
- }
- };
+ let manifest =
+ match run_functions_runner(&args, &files, selected_language, auth_ctx.client.api_key()) {
+ Ok(manifest) => manifest,
+ Err(failure) => {
+ return fail_push_with_all_skipped(
+ &base,
+ &files,
+ failure.reason,
+ &failure.message,
+ "skipped because manifest generation failed",
+ );
+ }
+ };
 
 if let Err(failure) = validate_manifest_paths(
 &manifest,
@@ -375,12 +377,19 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> {
 }
 };
 
+ let preflight_source_files: Vec<&str> = manifest
+ .files
+ .iter()
+ .map(|f| f.source_file.as_str())
+ .collect();
+ let preflight_project_names: Vec<String> = preflight.named_projects.iter().cloned().collect();
+
 let (org_decision, org_prompt_confirmed) = match resolve_org_decision(
 &base,
 &auth_ctx,
 &available_orgs,
- &preflight.selector_preview,
- manifest.files.len(),
+ &preflight_source_files,
+ &preflight_project_names,
 ) {
 Ok(outcome) => outcome,
 Err(err) => {
@@ -493,8 +502,19 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> {
 
 let target_project_ids = resolved_targets.unique_project_ids.clone();
 
+ let source_files: Vec<&str> = manifest
+ .files
+ .iter()
+ .map(|f| f.source_file.as_str())
+ .collect();
+
 if !org_prompt_confirmed
- && !confirm_push_targets(&auth_ctx, &target_project_ids, manifest.files.len())?
+ && !confirm_push_targets(
+ &auth_ctx,
+ &target_project_ids,
+ &source_files,
+ &project_name_cache,
+ )?
 {
 return cancel_push(&base, &files);
 }
@@ -924,6 +944,7 @@ fn run_functions_runner(
 args: &PushArgs,
 files: &[PathBuf],
 language: SourceLanguage,
+ api_key: &str,
) -> std::result::Result {
 let mut command = match language {
 SourceLanguage::JsLike => {
@@ -984,6 +1005,8 @@ fn run_functions_runner(
 }
 };
 
+ command.env("BRAINTRUST_API_KEY", api_key);
+
 let output = command.output().map_err(|err| FileFailure {
 reason: HardFailureReason::RunnerSpawnFailed,
 message: format!("failed to spawn functions runner: {err}"),
@@ -1896,8 +1919,8 @@ fn resolve_org_decision(
 base: &BaseArgs,
 auth_ctx: &super::AuthContext,
 available_orgs: &[AvailableOrg],
- selector_preview: &[String],
- file_count: usize,
+ source_files: &[&str],
+ project_names: &[String],
) -> Result<(OrgDecision, bool)> {
 if base
 .org_name
@@ -1920,17 +1943,31 @@ fn resolve_org_decision(
 }
 
 let org_label = current_org_label(auth_ctx);
- let selector_label = if selector_preview.is_empty() {
- "none".to_string()
+
+ let file_names: Vec<&str> = source_files
+ .iter()
+ .map(|f| {
+ Path::new(f)
+ .file_name()
+ .and_then(|n| n.to_str())
+ .unwrap_or(f)
+ })
+ .collect();
+ let files_part = file_names
+ .iter()
+ .map(|f| style(f).green().to_string())
+ .collect::<Vec<_>>()
+ .join(", ");
+ let projects_part = if project_names.is_empty() {
+ "(no project)".to_string()
 } else {
- selector_preview.join(", ")
+ project_names.join(", ")
 };
-
- let prompt = format!(
- "Push {file_count} file(s) with org '{org_label}'.
Project selectors: [{selector_label}]"
- );
+
+ let prompt = format!("Push {files_part} to {projects_part} in {org_label}");
 let options = [
- "Continue with current org".to_string(),
- "Switch organization".to_string(),
+ format!("Push to {org_label}"),
+ "Switch org".to_string(),
 "Cancel".to_string(),
 ];
 let option_refs = options.iter().map(String::as_str).collect::<Vec<_>>();
@@ -2033,8 +2070,6 @@ fn collect_project_preflight(
 let mut requires_default_project = false;
 let mut named_projects = BTreeSet::new();
 let mut direct_project_ids = BTreeSet::new();
- let mut selector_preview = BTreeSet::new();
-
 for file in &manifest.files {
 for entry in &file.entries {
 let selector = match entry {
@@ -2054,7 +2089,6 @@ fn collect_project_preflight(
 default_project_name.as_deref(),
 &mut named_projects,
 &mut direct_project_ids,
- &mut selector_preview,
 &mut requires_default_project,
 )?;
 }
@@ -2065,7 +2099,6 @@ fn collect_project_preflight(
 requires_default_project,
 named_projects,
 direct_project_ids,
- selector_preview: selector_preview.into_iter().collect(),
 })
 }
 
@@ -2089,17 +2122,14 @@ fn add_selector_requirement(
 default_project_name: Option<&str>,
 named_projects: &mut BTreeSet<String>,
 direct_project_ids: &mut BTreeSet<String>,
- selector_preview: &mut BTreeSet<String>,
 requires_default_project: &mut bool,
) -> Result<()> {
 match selector {
 ProjectSelector::Id(project_id) => {
 direct_project_ids.insert(project_id.clone());
- selector_preview.insert(project_id.clone());
 }
 ProjectSelector::Name(project_name) => {
 named_projects.insert(project_name.clone());
- selector_preview.insert(format!("name:{project_name}"));
 }
 ProjectSelector::Fallback => {
 let Some(default_project_name) = default_project_name else {
@@ -2111,7 +2141,6 @@ fn add_selector_requirement(
 };
 *requires_default_project = true;
 named_projects.insert(default_project_name.to_string());
- selector_preview.insert(format!("default:{default_project_name}"));
 }
 }
 Ok(())
@@ -2464,25 +2493,56 @@ async fn resolve_project_name(
 fn confirm_push_targets(
 auth_ctx: &super::AuthContext,
 target_project_ids: &[String],
- file_count: usize,
+ source_files: &[&str],
+ project_name_cache: &BTreeMap<String, String>,
) -> Result<bool> {
 if !is_interactive() || target_project_ids.is_empty() {
 return Ok(true);
 }
 
- let org_label = if auth_ctx.client.org_name().is_empty() {
- auth_ctx.org_id.clone()
+ let id_to_name: BTreeMap<&str, &str> = project_name_cache
+ .iter()
+ .map(|(name, id)| (id.as_str(), name.as_str()))
+ .collect();
+
+ let project_labels: Vec<&str> = target_project_ids
+ .iter()
+ .map(|id| id_to_name.get(id.as_str()).copied().unwrap_or(id.as_str()))
+ .collect();
+
+ let file_names: Vec<&str> = source_files
+ .iter()
+ .map(|f| {
+ Path::new(f)
+ .file_name()
+ .and_then(|n| n.to_str())
+ .unwrap_or(f)
+ })
+ .collect();
+
+ let files_part = file_names
+ .iter()
+ .map(|f| style(f).green().to_string())
+ .collect::<Vec<_>>()
+ .join(", ");
+ let projects_part = project_labels.join(", ");
+
+ let multi_org = list_profiles().is_ok_and(|p| p.len() > 1);
+ let prompt = if multi_org {
+ let org_label = if auth_ctx.client.org_name().is_empty() {
+ &auth_ctx.org_id
+ } else {
+ auth_ctx.client.org_name()
+ };
+ format!(
+ "Push {files_part} to {projects_part} {}?",
+ style(format!("({org_label})")).dim()
+ )
 } else {
- auth_ctx.client.org_name().to_string()
+ format!("Push {files_part} to {projects_part}?")
 };
- let prompt = format!(
- "Push {} file(s) to org '{}' and project(s) [{}]?",
- file_count,
- org_label,
- target_project_ids.join(", ")
- );
- let confirmed = Confirm::new()
+ let confirmed =
Confirm::with_theme(&ColorfulTheme::default())
 .with_prompt(prompt)
 .default(false)
 .interact()?;
@@ -2764,7 +2824,6 @@ mod tests {
 };
 let mut named_projects = BTreeSet::new();
 let mut direct_project_ids = BTreeSet::new();
- let mut selector_preview = BTreeSet::new();
 let mut requires_default_project = false;
 
 let err = add_selector_requirement(
@@ -2774,7 +2833,6 @@ mod tests {
 None,
 &mut named_projects,
 &mut direct_project_ids,
- &mut selector_preview,
 &mut requires_default_project,
 )
 .expect_err("must fail");
@@ -2834,6 +2892,7 @@ mod tests {
 fn select_push_language_auto_prefers_js_like_for_mixed_scan() {
 let args = PushArgs {
 files: vec![PathBuf::from(".")],
+ file_flag: vec![],
 if_exists: IfExistsMode::Error,
 terminate_on_failure: false,
 runner: None,
@@ -2860,6 +2919,7 @@ mod tests {
 fn select_push_language_rejects_mixed_explicit_files() {
 let args = PushArgs {
 files: vec![PathBuf::from("a.ts"), PathBuf::from("b.py")],
+ file_flag: vec![],
 if_exists: IfExistsMode::Error,
 terminate_on_failure: false,
 runner: None,
diff --git a/src/http.rs b/src/http.rs
index fea804e..40685af 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -56,6 +56,10 @@ impl ApiClient {
 format!("{}/{}", self.base_url, path)
 }
 
+ pub fn api_key(&self) -> &str {
+ &self.api_key
+ }
+
 pub fn org_name(&self) -> &str {
 &self.org_name
 }
diff --git a/tests/functions-fixtures/push-multiple-files-accepted/fixture.json b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
index 3838942..3fd0386 100644
--- a/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
+++ b/tests/functions-fixtures/push-multiple-files-accepted/fixture.json
@@ -1,6 +1,5 @@
 {
- "command": ["functions", "push", "a.ts", "b.ts", "--type", "tool"],
+ "command": ["functions", "push", "a.ts", "b.ts"],
 "expect_success": false,
- "stderr_contains": ["--type"],
- "stderr_not_contains": ["unexpected argument"]
+ "stderr_not_contains": ["unexpected argument", "unrecognized"]
 }

From fd96ed5b43056d903ad6b0588641ae14ab2f44e5 Mon Sep 17 00:00:00 2001
From: Parker Henderson
Date: Fri, 6 Mar 2026 14:15:34 -0800
Subject: [PATCH 08/28] refactor(functions): add multi-slug support to pull command

---
 src/functions/mod.rs | 231 +++++++++++++++-
 src/functions/pull.rs | 570 ++++++++++++++++++++++++++++------------
 src/functions/push.rs | 158 +++--------
 src/functions/report.rs | 1 +
 4 files changed, 663 insertions(+), 297 deletions(-)

diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index bcfc2b5..6550e05 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -5,11 +5,11 @@ use clap::{builder::BoolishValueParser, Args, Subcommand, ValueEnum};
 
 use crate::{
 args::BaseArgs,
- auth::login,
+ auth::{login, AvailableOrg},
 config,
 http::ApiClient,
 projects::api::{get_project_by_name, Project},
- ui::{self, is_interactive, select_project_interactive, with_spinner},
+ ui::{self, fuzzy_select, is_interactive, select_project_interactive, with_spinner},
 };
 
 pub(crate) mod api;
@@ -329,6 +329,19 @@ impl PushArgs {
 
 #[derive(Debug, Clone, Args)]
 pub(crate) struct PullArgs {
+ /// Function slug(s) to pull.
+ #[arg(value_name = "SLUG")]
+ pub slugs: Vec<String>,
+
+ /// Function slug(s) to pull.
+ #[arg(
+ long = "slug",
+ short = 's',
+ env = "BT_FUNCTIONS_PULL_SLUG",
+ value_delimiter = ','
+ )]
+ pub slug_flag: Vec<String>,
+
 /// Destination directory for generated files.
 #[arg(
 long,
@@ -360,13 +373,9 @@ pub(crate) struct PullArgs {
 pub project_id: Option<String>,
 
 /// Function id selector.
- #[arg(long, env = "BT_FUNCTIONS_PULL_ID", conflicts_with = "slug")] + #[arg(long, env = "BT_FUNCTIONS_PULL_ID", conflicts_with_all = ["slugs", "slug_flag"])] pub id: Option, - /// Function slug selector. - #[arg(long, env = "BT_FUNCTIONS_PULL_SLUG")] - pub slug: Option, - /// Overwrite targets even when dirty or already existing. #[arg( long, @@ -375,6 +384,27 @@ pub(crate) struct PullArgs { value_parser = BoolishValueParser::new() )] pub force: bool, + + /// Show skipped files in output. + #[arg(long, default_value_t = false)] + pub verbose: bool, +} + +impl PullArgs { + pub fn resolved_slugs(&self) -> Vec { + let mut seen = std::collections::BTreeSet::new(); + let mut result = Vec::new(); + for s in self.slugs.iter().chain(self.slug_flag.iter()) { + if seen.insert(s.clone()) { + result.push(s.clone()); + } + } + result + } + + pub fn has_slug_selector(&self) -> bool { + !self.slugs.is_empty() || !self.slug_flag.is_empty() + } } #[derive(Debug, Clone, Args)] @@ -432,6 +462,119 @@ pub(crate) async fn resolve_auth_context(base: &BaseArgs) -> Result }) } +#[derive(Debug)] +pub(crate) enum OrgDecision { + Continue, + Switch(String), + Cancel, +} + +pub(crate) fn current_org_label(auth_ctx: &AuthContext) -> String { + if auth_ctx.client.org_name().trim().is_empty() { + auth_ctx.org_id.clone() + } else { + auth_ctx.client.org_name().to_string() + } +} + +pub(crate) fn validate_explicit_org_selection( + base: &BaseArgs, + available_orgs: &[AvailableOrg], +) -> Result<()> { + let Some(explicit_org) = base + .org_name + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + return Ok(()); + }; + + let exists = available_orgs + .iter() + .any(|org| org.name == explicit_org || org.name.eq_ignore_ascii_case(explicit_org)); + if exists { + return Ok(()); + } + + let available = available_orgs + .iter() + .map(|org| org.name.as_str()) + .collect::>() + .join(", "); + bail!("org '{explicit_org}' is not available for this credential. Available: {available}"); +} + +/// Prompt the user to confirm/switch org when multiple orgs are available. +/// `prompt` is the question text, `action_label` is used for the confirm option (e.g. "Push to", "Pull from"). 
+pub(crate) fn resolve_org_decision(
+ base: &BaseArgs,
+ auth_ctx: &AuthContext,
+ available_orgs: &[AvailableOrg],
+ prompt: &str,
+ action_label: &str,
+) -> Result<(OrgDecision, bool)> {
+ if base
+ .org_name
+ .as_deref()
+ .map(str::trim)
+ .filter(|value| !value.is_empty())
+ .is_some()
+ {
+ return Ok((OrgDecision::Continue, false));
+ }
+
+ if available_orgs.len() <= 1 {
+ return Ok((OrgDecision::Continue, false));
+ }
+
+ if !is_interactive() {
+ bail!(
+ "multiple organizations are available for this credential; pass --org in non-interactive mode"
+ );
+ }
+
+ let org_label = current_org_label(auth_ctx);
+
+ let options = [
+ format!("{action_label} {org_label}"),
+ "Switch org".to_string(),
+ "Cancel".to_string(),
+ ];
+ let option_refs = options.iter().map(String::as_str).collect::<Vec<_>>();
+ let choice = fuzzy_select(prompt, &option_refs, 0)?;
+
+ match choice {
+ 0 => Ok((OrgDecision::Continue, true)),
+ 1 => {
+ let mut labels = Vec::with_capacity(available_orgs.len());
+ let mut default_index = 0usize;
+ for (index, org) in available_orgs.iter().enumerate() {
+ let label = if org.api_url.is_some() {
+ format!("{} [{}]", org.name, org.id)
+ } else {
+ org.name.clone()
+ };
+ if org.name == org_label || org.name.eq_ignore_ascii_case(&org_label) {
+ default_index = index;
+ }
+ labels.push(label);
+ }
+ let label_refs = labels.iter().map(String::as_str).collect::<Vec<_>>();
+ let selected_index = fuzzy_select("Select organization", &label_refs, default_index)?;
+ let selected = available_orgs
+ .get(selected_index)
+ .ok_or_else(|| anyhow!("invalid org selection"))?;
+ if selected.name == org_label || selected.name.eq_ignore_ascii_case(&org_label) {
+ Ok((OrgDecision::Continue, true))
+ } else {
+ Ok((OrgDecision::Switch(selected.name.clone()), true))
+ }
+ }
+ _ => Ok((OrgDecision::Cancel, true)),
+ }
+}
+
 pub(crate) async fn resolve_project_context(
 base: &BaseArgs,
 auth_ctx: &AuthContext,
@@ -800,11 +943,81 @@ mod tests {
 }
 
 #[test]
- fn pull_conflicts_id_and_slug() {
+ fn pull_conflicts_id_and_slug_flag() {
 let _guard = test_lock();
 let err = parse(&["functions", "pull", "--id", "f1", "--slug", "slug"])
 .expect_err("should conflict");
+ assert!(err.to_string().contains("--slug") || err.to_string().contains("--id"));
+ }
+
+ #[test]
+ fn pull_conflicts_id_and_positional_slug() {
+ let _guard = test_lock();
+ let err =
+ parse(&["functions", "pull", "--id", "f1", "my-slug"]).expect_err("should conflict");
+ assert!(
+ err.to_string().contains("--id")
+ || err.to_string().contains("SLUG")
+ || err.to_string().contains("cannot be used")
+ );
+ }
+
+ #[test]
+ fn pull_positional_slugs_parse() {
+ let _guard = test_lock();
+ let parsed = parse(&["functions", "pull", "slug-a", "slug-b"]).expect("parse pull");
+ let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+ panic!("expected pull");
+ };
+ assert_eq!(pull.resolved_slugs(), vec!["slug-a", "slug-b"]);
+ }
 
- assert!(err.to_string().contains("--slug"));
+ #[test]
+ fn pull_slug_flag_repeats() {
+ let _guard = test_lock();
+ let parsed =
+ parse(&["functions", "pull", "--slug", "a", "--slug", "b"]).expect("parse pull");
+ let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+ panic!("expected pull");
+ };
+ assert_eq!(pull.resolved_slugs(), vec!["a", "b"]);
+ }
+
+ #[test]
+ fn pull_merges_positional_and_flag_slugs() {
+ let _guard = test_lock();
+ let parsed =
+ parse(&["functions", "pull", "pos-slug", "--slug", "flag-slug"]).expect("parse pull");
+ let FunctionsCommands::Pull(pull) =
parsed.command.expect("subcommand") else {
+ panic!("expected pull");
+ };
+ assert_eq!(pull.resolved_slugs(), vec!["pos-slug", "flag-slug"]);
+ }
+
+ #[test]
+ fn pull_deduplicates_slugs() {
+ let _guard = test_lock();
+ let parsed = parse(&["functions", "pull", "same", "--slug", "same"]).expect("parse pull");
+ let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+ panic!("expected pull");
+ };
+ assert_eq!(pull.resolved_slugs(), vec!["same"]);
+ }
+
+ #[test]
+ fn pull_slug_env_uses_delimiter() {
+ let _guard = test_lock();
+ unsafe {
+ std::env::set_var("BT_FUNCTIONS_PULL_SLUG", "a,b,c");
+ }
+ let parsed = parse(&["functions", "pull"]).expect("parse pull");
+ unsafe {
+ std::env::remove_var("BT_FUNCTIONS_PULL_SLUG");
+ }
+
+ let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+ panic!("expected pull command");
+ };
+ assert_eq!(pull.slug_flag, vec!["a", "b", "c"]);
+ }
 }
diff --git a/src/functions/pull.rs b/src/functions/pull.rs
index 1ecb488..0ada470 100644
--- a/src/functions/pull.rs
+++ b/src/functions/pull.rs
@@ -8,15 +8,20 @@ use serde::Deserialize;
 use serde_json::Value;
 
 use crate::args::BaseArgs;
+use crate::auth::list_available_orgs;
 use crate::functions::report::{
 CommandStatus, FileStatus, HardFailureReason, PullFileReport, PullSummary, ReportError,
 ReportWarning, SoftSkipReason, WarningReason,
 };
-use crate::projects::api::{list_projects, Project};
+use crate::projects::api::{get_project_by_name, list_projects, Project};
+use crate::ui::{fuzzy_select, select_project_interactive};
 use crate::utils::{write_text_atomic, GitRepo};
 
 use super::api::{self, FunctionListQuery};
-use super::{resolve_auth_context, resolve_project_context, FunctionsLanguage, PullArgs};
+use super::{
+ current_org_label, resolve_auth_context, resolve_project_context_optional,
+ validate_explicit_org_selection, FunctionsLanguage, PullArgs,
+};
 
 const PAGINATION_PAGE_LIMIT: usize = 10_000;
 const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock";
@@ -98,19 +103,35 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> {
 errors: vec![],
 };
 let mut projects_cache: Option<Vec<Project>> = None;
+ let has_explicit_project = args.project_id.is_some() || args.project_name.is_some();
 
- let auth_ctx = match resolve_auth_context(&base)
- .await
- .context("failed to resolve auth context")
- {
- Ok(ctx) => ctx,
- Err(err) => {
- return fail_pull(
- &base,
- &mut summary,
- HardFailureReason::AuthFailed,
- err.to_string(),
- );
+ let mut base = base;
+ let (auth_ctx, selected_project) = if !has_explicit_project {
+ match resolve_pull_target(&mut base).await {
+ Ok(result) => result,
+ Err(err) => {
+ return fail_pull(
+ &base,
+ &mut summary,
+ HardFailureReason::AuthFailed,
+ err.to_string(),
+ );
+ }
+ }
+ } else {
+ match resolve_auth_context(&base)
+ .await
+ .context("failed to resolve auth context")
+ {
+ Ok(ctx) => (ctx, None),
+ Err(err) => {
+ return fail_pull(
+ &base,
+ &mut summary,
+ HardFailureReason::AuthFailed,
+ err.to_string(),
+ );
+ }
 }
 };
 
@@ -139,29 +160,23 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> {
 );
 }
 query.project_name = Some(project_name.clone());
- } else {
- let project = match resolve_project_context(&base, &auth_ctx)
- .await
- .context("failed to resolve default project context")
- {
- Ok(project) => project,
- Err(err) => {
- return fail_pull(
- &base,
- &mut summary,
- HardFailureReason::ResponseInvalid,
- err.to_string(),
- );
- }
- };
+ } else if let Some(project) = selected_project {
 query.project_id =
Some(project.id); + } else { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + "no project selected".to_string(), + ); } if let Some(id) = &args.id { query.id = Some(id.clone()); } - if let Some(slug) = &args.slug { - query.slug = Some(slug.clone()); + let resolved_slugs = args.resolved_slugs(); + if resolved_slugs.len() == 1 { + query.slug = Some(resolved_slugs[0].clone()); } let fetched = match fetch_all_function_rows(&auth_ctx.client, &query).await { Ok(fetched) => fetched, @@ -208,7 +223,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let winners = select_winner_rows(narrowed_rows, &mut summary); - if (args.id.is_some() || args.slug.is_some()) && winners.is_empty() { + if (args.id.is_some() || args.has_slug_selector()) && winners.is_empty() { return fail_pull( &base, &mut summary, @@ -226,7 +241,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { } } - if (args.id.is_some() || args.slug.is_some()) && materializable.is_empty() { + if (args.id.is_some() || args.has_slug_selector()) && materializable.is_empty() { return fail_pull( &base, &mut summary, @@ -235,16 +250,6 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { ); } - if args.slug.is_some() && materializable.len() > 1 { - return fail_pull( - &base, - &mut summary, - HardFailureReason::SelectorNotFound, - "slug selector matched multiple prompts; pass --project-name or --project-id" - .to_string(), - ); - } - let output_dir = if args.output_dir.is_absolute() { args.output_dir.clone() } else { @@ -319,110 +324,104 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { } }; - let grouped_by_project = match group_rows_by_project(materializable, &project_names) { - Ok(grouped) => grouped, - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - } - }; - summary.projects_total = grouped_by_project.len(); - let ext = match args.language { FunctionsLanguage::Typescript => "ts", FunctionsLanguage::Python => "py", }; - let file_names = match build_output_file_names(&grouped_by_project, args.slug.as_deref(), ext) { - Ok(file_names) => file_names, - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::SelectorNotFound, - err.to_string(), - ); - } - }; - - for ((project_id, project_name), rows) in grouped_by_project { - let file_name = file_names - .get(&(project_id.clone(), project_name.clone())) - .ok_or_else(|| anyhow!("missing output file mapping"))? - .clone(); - - let target = canonical_output_dir.join(&file_name); - let display_target = display_output_path(&target); - if !target.starts_with(&canonical_output_dir) { - record_pull_file_failure( - &mut summary, - target.display().to_string(), - HardFailureReason::UnsafeOutputPath, - format!("refusing to write outside output dir: {}", target.display()), - ); - continue; - } - let skip_reason = match should_skip_target(&repo, &target, args.force) { - Ok(reason) => reason, - Err(err) => { - record_pull_file_failure( - &mut summary, - target.display().to_string(), - HardFailureReason::RequestFailed, - err.to_string(), - ); - continue; + if !resolved_slugs.is_empty() { + // Slug mode: one file per slug. 
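+ // E.g. `bt functions pull doc-search other-tool` should write
+ // doc-search.ts and other-tool.ts (illustrative slugs; the file names
+ // come from sanitize_filename over each slug below).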
+ let found_slugs: BTreeSet<&str> = materializable.iter().map(|r| r.slug.as_str()).collect();
+ for slug in &resolved_slugs {
+ if !found_slugs.contains(slug.as_str()) {
+ summary.warnings.push(ReportWarning {
+ reason: WarningReason::SelectorPartialMatch,
+ message: format!("slug '{}' not found", slug),
+ });
+ }
+ }
+
+ // Check for cross-project collisions: if any slug appears in multiple projects,
+ // require --project-name/--project-id to disambiguate.
+ if !has_explicit_project {
+ let mut slug_projects: BTreeMap<&str, BTreeSet<&str>> = BTreeMap::new();
+ for row in &materializable {
+ slug_projects
+ .entry(row.slug.as_str())
+ .or_default()
+ .insert(row.project_id.as_str());
+ }
+ for (slug, projects) in &slug_projects {
+ if projects.len() > 1 {
+ return fail_pull(
+ &base,
+ &mut summary,
+ HardFailureReason::SelectorNotFound,
+ format!(
+ "slug '{}' exists in {} projects; pass --project-name or --project-id",
+ slug,
+ projects.len()
+ ),
+ );
+ }
+ }
+ }
+
+ summary.projects_total = materializable
+ .iter()
+ .map(|r| r.project_id.as_str())
+ .collect::<BTreeSet<_>>()
+ .len();
+
+ for row in &materializable {
+ let project_name = project_names
+ .get(&row.project_id)
+ .map(String::as_str)
+ .unwrap_or("unknown");
+ let file_name = format!("{}.{ext}", sanitize_filename(&row.slug));
+ write_pull_file(
+ &mut summary,
+ &canonical_output_dir,
+ &repo,
+ args.force,
+ args.language,
+ project_name,
+ &file_name,
+ std::slice::from_ref(row),
+ );
+ }
+ } else {
+ // Project mode: one file per project (existing behavior).
+ let grouped_by_project = match group_rows_by_project(materializable, &project_names) {
+ Ok(grouped) => grouped,
+ Err(err) => {
+ return fail_pull(
+ &base,
+ &mut summary,
+ HardFailureReason::ResponseInvalid,
+ err.to_string(),
+ );
+ }
+ };
+ summary.projects_total = grouped_by_project.len();
+
+ let file_names = build_project_file_names(&grouped_by_project, ext);
+
+ for ((project_id, project_name), rows) in grouped_by_project {
+ let file_name = file_names
+ .get(&(project_id.clone(), project_name.clone()))
+ .ok_or_else(|| anyhow!("missing output file mapping"))?
+ .clone();
+ write_pull_file(
+ &mut summary,
+ &canonical_output_dir,
+ &repo,
+ args.force,
+ args.language,
+ &project_name,
+ &file_name,
+ &rows,
+ );
 }
 }
 
@@ -436,7 +435,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> {
 }
 
 let failure = summary.status == CommandStatus::Failed;
- emit_summary(&base, &summary)?;
+ emit_summary(&base, &summary, args.verbose)?;
 if failure {
 bail!("functions pull failed; see summary for details");
 }
@@ -546,19 +545,21 @@ fn apply_selector_narrowing(
 rows: Vec<PullFunctionRow>,
 args: &PullArgs,
) -> Result<Vec<PullFunctionRow>> {
+ let resolved_slugs = args.resolved_slugs();
 let narrowed = if let Some(id) = &args.id {
 rows.into_iter()
 .filter(|row| row.id == *id)
 .collect::<Vec<_>>()
- } else if let Some(slug) = &args.slug {
+ } else if !resolved_slugs.is_empty() {
+ let slug_set: BTreeSet<&str> = resolved_slugs.iter().map(String::as_str).collect();
 rows.into_iter()
- .filter(|row| row.slug == *slug)
+ .filter(|row| slug_set.contains(row.slug.as_str()))
 .collect::<Vec<_>>()
 } else {
 rows
 };
 
- if (args.id.is_some() || args.slug.is_some()) && narrowed.is_empty() {
+ if (args.id.is_some() || args.has_slug_selector()) && narrowed.is_empty() {
 bail!("selector did not match any function rows");
 }
 
@@ -677,6 +678,87 @@ fn group_rows_by_project(
 Ok(grouped)
 }
 
+fn write_pull_file(
+ summary: &mut PullSummary,
+ canonical_output_dir: &Path,
+ repo: &Option<GitRepo>,
+ force: bool,
+ language: FunctionsLanguage,
+ project_name: &str,
+ file_name: &str,
+ rows: &[PullFunctionRow],
+) {
+ let target = canonical_output_dir.join(file_name);
+ let display_target = display_output_path(&target);
+ if !target.starts_with(canonical_output_dir) {
+ record_pull_file_failure(
+ summary,
+ target.display().to_string(),
+ HardFailureReason::UnsafeOutputPath,
+ format!("refusing to write outside output dir: {}", target.display()),
+ );
+ return;
+ }
+
+ let skip_reason = match should_skip_target(repo, &target, force) {
+ Ok(reason) => reason,
+ Err(err) => {
+ record_pull_file_failure(
+ summary,
+ target.display().to_string(),
+ HardFailureReason::RequestFailed,
+ err.to_string(),
+ );
+ return;
+ }
+ };
+ if let Some(reason) = skip_reason {
+ summary.files_skipped += 1;
+ summary.files.push(PullFileReport {
+ output_file: target.display().to_string(),
+ status: FileStatus::Skipped,
+ skipped_reason: Some(reason),
+ error_reason: None,
+ message: None,
+ });
+ return;
+ }
+
+ let rendered = match render_project_file(language, project_name, &display_target, rows) {
+ Ok(rendered) => rendered,
+ Err(err) => {
+ record_pull_file_failure(
+ summary,
+ target.display().to_string(),
+ HardFailureReason::ResponseInvalid,
+ err.to_string(),
+ );
+ return;
+ }
+ };
+ match write_text_atomic(&target, &rendered) {
+ Ok(()) => {
+ summary.files_written += 1;
+ summary.functions_materialized += rows.len();
+ summary.files.push(PullFileReport {
+ output_file: target.display().to_string(),
+ status: FileStatus::Success,
+ skipped_reason: None,
+ error_reason: None,
+ message: None,
+ });
+ }
+ Err(err) => {
+ record_pull_file_failure(
+ summary,
+ target.display().to_string(),
+ HardFailureReason::AtomicWriteFailed,
+ err.to_string(),
+ );
+ }
+ }
+}
+
 fn build_project_file_names(
 grouped_by_project: &BTreeMap<(String, String), Vec<PullFunctionRow>>,
 ext: &str,
@@ -1414,36 +1496,77 @@ fn is_empty_render_value(value: &Value) -> bool {
 }
 }
 
-fn emit_summary(base: &BaseArgs, summary: &PullSummary) -> Result<()> {
+fn skip_reason_label(reason: Option<SoftSkipReason>) -> &'static str {
+ match reason {
+ Some(SoftSkipReason::DirtyTarget) => "dirty target",
+
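+ // These short labels feed the "Skipped {name} ({reason})" lines that
+ // emit_summary prints when --verbose is set.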
Some(SoftSkipReason::ExistingNonGitNoForce) => "already exists",
+ Some(SoftSkipReason::MalformedRecord) => "malformed record",
+ Some(SoftSkipReason::UnsupportedFunctionType) => "unsupported type",
+ Some(SoftSkipReason::SupersededVersion) => "superseded",
+ Some(SoftSkipReason::TerminatedAfterFailure) => "terminated after failure",
+ Some(SoftSkipReason::IfExistsIgnored) => "ignored",
+ Some(SoftSkipReason::NoDefinitionsFound) => "no definitions found",
+ None => "skipped",
+ }
+}
+
+fn short_display_path(path_str: &str, cwd: Option<&Path>) -> String {
+ let p = Path::new(path_str);
+ let file_name = p.file_name().unwrap_or(p.as_os_str());
+ match (cwd, p.parent()) {
+ (Some(cwd), Some(parent)) if parent == cwd => file_name.to_string_lossy().into_owned(),
+ _ => {
+ // parent_dir/file_name
+ let parent_name = p.parent().and_then(|p| p.file_name()).unwrap_or_default();
+ Path::new(parent_name).join(file_name).display().to_string()
+ }
+ }
+}
+
+fn emit_summary(base: &BaseArgs, summary: &PullSummary, verbose: bool) -> Result<()> {
 if base.json {
 println!("{}", serde_json::to_string(summary)?);
- } else {
- match summary.status {
- CommandStatus::Success => {
- eprintln!(
- "Pulled {} file(s), materialized {} prompt(s).",
- summary.files_written, summary.functions_materialized
- );
- }
- CommandStatus::Partial => {
- eprintln!(
- "Pull completed with partial results: written={}, skipped={}, failed={}",
- summary.files_written, summary.files_skipped, summary.files_failed
- );
- }
- CommandStatus::Failed => {
- eprintln!(
- "Pull failed: written={}, skipped={}, failed={}",
- summary.files_written, summary.files_skipped, summary.files_failed
- );
+ return Ok(());
+ }
+
+ let has_visible_files = summary
+ .files
+ .iter()
+ .any(|f| f.status == FileStatus::Success || f.status == FileStatus::Failed || verbose);
+ let mut parts = vec![format!("Wrote {} file(s)", summary.files_written)];
+ if has_visible_files {
+ let cwd = std::env::current_dir().ok();
+ for f in &summary.files {
+ let name = short_display_path(&f.output_file, cwd.as_deref());
+ match f.status {
+ FileStatus::Success => eprintln!("Pulled {name}"),
+ FileStatus::Failed => {
+ let msg = f.message.as_deref().unwrap_or("unknown error");
+ eprintln!("Failed to pull {name} ({msg})");
+ }
+ FileStatus::Skipped if verbose => {
+ let reason = skip_reason_label(f.skipped_reason);
+ eprintln!("Skipped {name} ({reason})");
+ }
+ FileStatus::Skipped => {}
 }
 }
- for warning in &summary.warnings {
- eprintln!("warning: {}", warning.message);
- }
- for error in &summary.errors {
- eprintln!("error: {}", error.message);
- }
+ eprintln!();
+ }
+
+ if summary.files_skipped > 0 {
+ parts.push(format!("skipped {}", summary.files_skipped));
+ }
+ if summary.files_failed > 0 {
+ parts.push(format!("failed {}", summary.files_failed));
+ }
+ eprintln!("{}.", parts.join(", "));
+
+ for warning in &summary.warnings {
+ eprintln!("warning: {}", warning.message);
+ }
+ for error in &summary.errors {
+ eprintln!("error: {}", error.message);
 }
 
 Ok(())
@@ -1461,7 +1584,7 @@ fn fail_pull(
 message: message.clone(),
 });
 if base.json {
- emit_summary(base, summary)?;
+ emit_summary(base, summary, false)?;
 }
 bail!(message);
 }
@@ -1487,6 +1610,68 @@ fn record_pull_file_failure(
 });
 }
 
+async fn resolve_pull_target(base: &mut BaseArgs) -> Result<(super::AuthContext, Option<Project>)> {
+ let available_orgs = list_available_orgs(base)
+ .await
+ .context("failed to list available orgs")?;
+
+ validate_explicit_org_selection(base, &available_orgs)?;
+
+ let mut auth_ctx =
resolve_auth_context(base) + .await + .context("failed to resolve auth context")?; + + // Org selector: show when multiple orgs and no explicit --org. + let has_explicit_org = base + .org_name + .as_deref() + .map(str::trim) + .filter(|v| !v.is_empty()) + .is_some(); + if !has_explicit_org && available_orgs.len() > 1 && crate::ui::is_interactive() { + let org_label = current_org_label(&auth_ctx); + let names: Vec<&str> = available_orgs.iter().map(|o| o.name.as_str()).collect(); + let default_index = names + .iter() + .position(|n| *n == org_label || n.eq_ignore_ascii_case(&org_label)) + .unwrap_or(0); + let selected_index = fuzzy_select("Select org:", &names, default_index)?; + let selected = &available_orgs[selected_index]; + if selected.name != org_label && !selected.name.eq_ignore_ascii_case(&org_label) { + base.org_name = Some(selected.name.clone()); + auth_ctx = resolve_auth_context(base) + .await + .context("failed to resolve switched org context")?; + } + } else if !has_explicit_org && available_orgs.len() > 1 { + bail!( + "multiple organizations are available for this credential; pass --org in non-interactive mode" + ); + } + + // Project selector with current default pre-focused. + let default_project = resolve_project_context_optional(base, &auth_ctx, false) + .await + .ok() + .flatten(); + if !crate::ui::is_interactive() { + return Ok((auth_ctx, default_project)); + } + let current_name = default_project.as_ref().map(|p| p.name.as_str()); + let selected_name = select_project_interactive( + &auth_ctx.client, + Some("Select project to pull from:"), + current_name, + ) + .await?; + + let project = get_project_by_name(&auth_ctx.client, &selected_name) + .await? + .ok_or_else(|| anyhow!("project '{selected_name}' not found"))?; + + Ok((auth_ctx, Some(project))) +} + #[cfg(test)] mod tests { use super::*; @@ -1539,19 +1724,80 @@ mod tests { _xact_id: None, }; let args = PullArgs { + slugs: vec![], + slug_flag: vec![], output_dir: PathBuf::from("."), language: FunctionsLanguage::Typescript, project_name: None, project_id: None, id: Some("missing".to_string()), - slug: None, force: false, + verbose: false, }; let err = apply_selector_narrowing(vec![row], &args).expect_err("should fail"); assert!(err.to_string().contains("selector")); } + #[test] + fn multi_slug_narrowing_filters_to_matching() { + let rows = vec![ + PullFunctionRow { + id: "f1".to_string(), + name: "A".to_string(), + slug: "alpha".to_string(), + project_id: "p1".to_string(), + project_name: Some("Proj".to_string()), + description: None, + prompt_data: None, + function_data: None, + created: None, + _xact_id: None, + }, + PullFunctionRow { + id: "f2".to_string(), + name: "B".to_string(), + slug: "beta".to_string(), + project_id: "p1".to_string(), + project_name: Some("Proj".to_string()), + description: None, + prompt_data: None, + function_data: None, + created: None, + _xact_id: None, + }, + PullFunctionRow { + id: "f3".to_string(), + name: "G".to_string(), + slug: "gamma".to_string(), + project_id: "p1".to_string(), + project_name: Some("Proj".to_string()), + description: None, + prompt_data: None, + function_data: None, + created: None, + _xact_id: None, + }, + ]; + let args = PullArgs { + slugs: vec!["alpha".to_string()], + slug_flag: vec!["gamma".to_string()], + output_dir: PathBuf::from("."), + language: FunctionsLanguage::Typescript, + project_name: None, + project_id: None, + id: None, + force: false, + verbose: false, + }; + + let narrowed = apply_selector_narrowing(rows, &args).expect("should narrow"); + 
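+        // "alpha" (positional) and "gamma" (--slug flag) should both survive
+        // narrowing; "beta" matches neither selector.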
assert_eq!(narrowed.len(), 2); + let slugs: Vec<&str> = narrowed.iter().map(|r| r.slug.as_str()).collect(); + assert!(slugs.contains(&"alpha")); + assert!(slugs.contains(&"gamma")); + } + #[test] fn group_rows_uses_resolved_project_name() { let row = PullFunctionRow { diff --git a/src/functions/push.rs b/src/functions/push.rs index 091e30e..c534f7a 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -13,7 +13,7 @@ use serde::Deserialize; use serde_json::{json, Map, Value}; use crate::args::BaseArgs; -use crate::auth::{list_available_orgs, list_profiles, AvailableOrg}; +use crate::auth::{list_available_orgs, list_profiles}; use crate::config; use crate::functions::report::{ CommandStatus, FileStatus, HardFailureReason, PushFileReport, PushSummary, ReportError, @@ -23,10 +23,13 @@ use crate::js_runner; use crate::projects::api::{create_project, get_project_by_name, list_projects}; use crate::python_runner; use crate::source_language::{classify_runtime_extension, JsExtensionProfile, SourceLanguage}; -use crate::ui::{fuzzy_select, is_interactive}; +use crate::ui::is_interactive; use super::api; -use super::{resolve_auth_context, PushArgs, PushLanguage}; +use super::{ + current_org_label, resolve_auth_context, resolve_org_decision, validate_explicit_org_selection, + OrgDecision, PushArgs, PushLanguage, +}; const FUNCTIONS_JS_RUNNER_FILE: &str = "functions-runner.ts"; const FUNCTIONS_PY_RUNNER_FILE: &str = "functions-runner.py"; @@ -118,13 +121,6 @@ fn error_chain(err: &anyhow::Error) -> String { format!("{err:#}") } -#[derive(Debug, Clone, PartialEq, Eq)] -enum OrgDecision { - Continue, - Switch(String), - Cancel, -} - #[derive(Debug, Clone, PartialEq, Eq)] enum ProjectSelector { Id(String), @@ -384,24 +380,21 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { .collect(); let preflight_project_names: Vec = preflight.named_projects.iter().cloned().collect(); - let (org_decision, org_prompt_confirmed) = match resolve_org_decision( - &base, - &auth_ctx, - &available_orgs, - &preflight_source_files, - &preflight_project_names, - ) { - Ok(outcome) => outcome, - Err(err) => { - return fail_push( - &base, - files.len(), - HardFailureReason::ResponseInvalid, - error_chain(&err), - "failed to resolve org context", - ); - } - }; + let org_prompt = + build_push_org_prompt(&auth_ctx, &preflight_source_files, &preflight_project_names); + let (org_decision, org_prompt_confirmed) = + match resolve_org_decision(&base, &auth_ctx, &available_orgs, &org_prompt, "Push to") { + Ok(outcome) => outcome, + Err(err) => { + return fail_push( + &base, + files.len(), + HardFailureReason::ResponseInvalid, + error_chain(&err), + "failed to resolve org context", + ); + } + }; match org_decision { OrgDecision::Continue => {} @@ -1890,60 +1883,11 @@ fn ensure_path_within_allowed_roots( ); } -fn validate_explicit_org_selection(base: &BaseArgs, available_orgs: &[AvailableOrg]) -> Result<()> { - let Some(explicit_org) = base - .org_name - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - else { - return Ok(()); - }; - - let exists = available_orgs - .iter() - .any(|org| org.name == explicit_org || org.name.eq_ignore_ascii_case(explicit_org)); - if exists { - return Ok(()); - } - - let available = available_orgs - .iter() - .map(|org| org.name.as_str()) - .collect::>() - .join(", "); - bail!("org '{explicit_org}' is not available for this credential. 
Available: {available}"); -} - -fn resolve_org_decision( - base: &BaseArgs, +fn build_push_org_prompt( auth_ctx: &super::AuthContext, - available_orgs: &[AvailableOrg], source_files: &[&str], project_names: &[String], -) -> Result<(OrgDecision, bool)> { - if base - .org_name - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - .is_some() - { - return Ok((OrgDecision::Continue, false)); - } - - if available_orgs.len() <= 1 { - return Ok((OrgDecision::Continue, false)); - } - - if !is_interactive() { - bail!( - "multiple organizations are available for this credential; pass --org in non-interactive mode" - ); - } - - let org_label = current_org_label(auth_ctx); - +) -> String { let file_names: Vec<&str> = source_files .iter() .map(|f| { @@ -1961,55 +1905,15 @@ fn resolve_org_decision( let projects_part = if project_names.is_empty() { "(no project)".to_string() } else { - project_names.join(", ") + project_names + .iter() + .map(|p| style(p).green().to_string()) + .collect::>() + .join(", ") }; + let org_styled = style(current_org_label(auth_ctx)).green(); - let prompt = format!("Push {files_part} to {projects_part} in {org_label}"); - let options = [ - format!("Push to {org_label}"), - "Switch org".to_string(), - "Cancel".to_string(), - ]; - let option_refs = options.iter().map(String::as_str).collect::>(); - let choice = fuzzy_select(&prompt, &option_refs, 0)?; - - match choice { - 0 => Ok((OrgDecision::Continue, true)), - 1 => { - let mut labels = Vec::with_capacity(available_orgs.len()); - let mut default_index = 0usize; - for (index, org) in available_orgs.iter().enumerate() { - let label = if org.api_url.is_some() { - format!("{} [{}]", org.name, org.id) - } else { - org.name.clone() - }; - if org.name == org_label || org.name.eq_ignore_ascii_case(&org_label) { - default_index = index; - } - labels.push(label); - } - let label_refs = labels.iter().map(String::as_str).collect::>(); - let selected_index = fuzzy_select("Select organization", &label_refs, default_index)?; - let selected = available_orgs - .get(selected_index) - .ok_or_else(|| anyhow!("invalid org selection"))?; - if selected.name == org_label || selected.name.eq_ignore_ascii_case(&org_label) { - Ok((OrgDecision::Continue, true)) - } else { - Ok((OrgDecision::Switch(selected.name.clone()), true)) - } - } - _ => Ok((OrgDecision::Cancel, true)), - } -} - -fn current_org_label(auth_ctx: &super::AuthContext) -> String { - if auth_ctx.client.org_name().trim().is_empty() { - auth_ctx.org_id.clone() - } else { - auth_ctx.client.org_name().to_string() - } + format!("Push {files_part} to {projects_part} in {org_styled}") } fn cancel_push(base: &BaseArgs, files: &[PathBuf]) -> Result<()> { @@ -2662,6 +2566,7 @@ fn to_warning_code(warning: &ReportWarning) -> &'static str { super::report::WarningReason::PaginationNotSnapshotConsistent => { "pagination_not_snapshot_consistent" } + super::report::WarningReason::SelectorPartialMatch => "selector_partial_match", } } @@ -2780,6 +2685,7 @@ fn fail_push_manifest_preflight( #[cfg(test)] mod tests { use crate::args::BaseArgs; + use crate::auth::AvailableOrg; use crate::functions::IfExistsMode; use super::*; diff --git a/src/functions/report.rs b/src/functions/report.rs index b6953ca..3f01149 100644 --- a/src/functions/report.rs +++ b/src/functions/report.rs @@ -55,6 +55,7 @@ pub enum SoftSkipReason { #[serde(rename_all = "snake_case")] pub enum WarningReason { PaginationNotSnapshotConsistent, + SelectorPartialMatch, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, 
Eq)]

From 6547198b08e3ea66db7a9c2eb567f46948abcdcf Mon Sep 17 00:00:00 2001
From: Parker Henderson
Date: Fri, 6 Mar 2026 15:21:16 -0800
Subject: [PATCH 09/28] feat(functions): add version filter and legacy
 compatibility flags

Allow --id and --slug selectors to combine instead of conflicting, add
a --version selector with pretty version ID decoding, accept the legacy
--tsconfig/--external-packages push flags, expose top-level push/pull
aliases, and stop treating untracked files as dirty during pull.

---
 src/args.rs                                   |   8 +-
 src/functions/api.rs                          |   4 +
 src/functions/mod.rs                          |  54 ++-
 src/functions/pull.rs                         | 331 +++++++-----------
 src/functions/push.rs                         |  25 ++
 src/main.rs                                   |  10 +
 src/utils/git.rs                              |  79 ++++-
 .../pull-help-env-vars/fixture.json           |   1 +
 .../pull-help-flags/fixture.json              |   1 +
 .../pull-id-slug-conflict/fixture.json        |   7 +-
 .../fixture.json                              |   7 +-
 .../push-help-env-vars/fixture.json           |   2 +
 .../push-help-flags/fixture.json              |   2 +
 .../fixture.json                              |   7 +-
 .../push-reject-tsconfig/fixture.json         |   7 +-
 tests/functions.rs                            |  52 ++-
 16 files changed, 355 insertions(+), 242 deletions(-)

diff --git a/src/args.rs b/src/args.rs
index 3748e8b..e6d30c1 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -23,7 +23,13 @@ pub struct BaseArgs {
     pub profile: Option<String>,
 
     /// Override active org (or via BRAINTRUST_ORG_NAME)
-    #[arg(short = 'o', long = "org", env = "BRAINTRUST_ORG_NAME", global = true)]
+    #[arg(
+        short = 'o',
+        long = "org",
+        alias = "org-name",
+        env = "BRAINTRUST_ORG_NAME",
+        global = true
+    )]
     pub org_name: Option<String>,
 
     /// Override active project
diff --git a/src/functions/api.rs b/src/functions/api.rs
index 34b9972..6317574 100644
--- a/src/functions/api.rs
+++ b/src/functions/api.rs
@@ -39,6 +39,7 @@ pub struct FunctionListQuery {
     pub project_name: Option<String>,
     pub slug: Option<String>,
     pub id: Option<String>,
+    pub version: Option<String>,
     pub cursor: Option<String>,
     pub snapshot: Option<String>,
 }
@@ -133,6 +134,9 @@ pub async fn list_functions_page(
     if let Some(id) = &query.id {
         params.push(("ids", id.clone()));
     }
+    if let Some(version) = &query.version {
+        params.push(("version", version.clone()));
+    }
     if let Some(cursor) = &query.cursor {
         params.push(("cursor", cursor.clone()));
     }
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index 6550e05..1ec27db 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -305,11 +305,24 @@ pub(crate) struct PushArgs {
     #[arg(long, env = "BT_FUNCTIONS_PUSH_REQUIREMENTS", value_name = "PATH")]
     pub requirements: Option<PathBuf>,
 
+    /// Compatibility flag for legacy push workflows. Currently informational.
+    #[arg(long, env = "BT_FUNCTIONS_PUSH_TSCONFIG", value_name = "PATH")]
+    pub tsconfig: Option<PathBuf>,
+
+    /// Compatibility flag for legacy push workflows. Currently informational.
+    #[arg(
+        long = "external-packages",
+        env = "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES",
+        value_delimiter = ',',
+        value_name = "PACKAGE"
+    )]
+    pub external_packages: Vec<String>,
+
     /// Create missing projects referenced by function definitions.
     #[arg(
         long = "create-missing-projects",
         env = "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS",
-        default_value_t = false,
+        default_value_t = true,
         value_parser = BoolishValueParser::new()
     )]
     pub create_missing_projects: bool,
@@ -346,7 +359,7 @@ pub(crate) struct PullArgs {
     #[arg(
         long,
         env = "BT_FUNCTIONS_PULL_OUTPUT_DIR",
-        default_value = ".",
+        default_value = "./braintrust",
         value_name = "PATH"
     )]
     pub output_dir: PathBuf,
@@ -373,9 +386,13 @@ pub(crate) struct PullArgs {
     pub project_id: Option<String>,
 
     /// Function id selector.
-    #[arg(long, env = "BT_FUNCTIONS_PULL_ID", conflicts_with_all = ["slugs", "slug_flag"])]
+    #[arg(long, env = "BT_FUNCTIONS_PULL_ID")]
     pub id: Option<String>,
 
+    /// Version selector (supports pretty version IDs).
+    #[arg(long, env = "BT_FUNCTIONS_PULL_VERSION")]
+    pub version: Option<String>,
+
     /// Overwrite targets even when dirty or already existing.
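+    ///
+    /// Without this flag, targets with uncommitted tracked changes are
+    /// skipped as `dirty target`; untracked files are treated as clean for
+    /// legacy-pull compatibility (see utils/git.rs).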
    #[arg(
        long,
@@ -711,6 +728,14 @@ pub async fn run(base: BaseArgs, args: FunctionsArgs) -> Result<()> {
     }
 }
 
+pub async fn run_push(base: BaseArgs, args: PushArgs) -> Result<()> {
+    push::run(base, args).await
+}
+
+pub async fn run_pull(base: BaseArgs, args: PullArgs) -> Result<()> {
+    pull::run(base, args).await
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::{Mutex, MutexGuard, OnceLock};
@@ -945,21 +970,24 @@ mod tests {
     #[test]
-    fn pull_conflicts_id_and_slug_flag() {
+    fn pull_accepts_id_and_slug_flag_together() {
         let _guard = test_lock();
-        let err = parse(&["functions", "pull", "--id", "f1", "--slug", "slug"])
-            .expect_err("should conflict");
-        assert!(err.to_string().contains("--slug") || err.to_string().contains("--id"));
+        let parsed =
+            parse(&["functions", "pull", "--id", "f1", "--slug", "slug"]).expect("parse pull");
+        let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+            panic!("expected pull");
+        };
+        assert_eq!(pull.id.as_deref(), Some("f1"));
+        assert_eq!(pull.resolved_slugs(), vec!["slug"]);
     }
 
     #[test]
-    fn pull_conflicts_id_and_positional_slug() {
+    fn pull_accepts_id_and_positional_slug_together() {
         let _guard = test_lock();
-        let err =
-            parse(&["functions", "pull", "--id", "f1", "my-slug"]).expect_err("should conflict");
-        assert!(
-            err.to_string().contains("--id")
-                || err.to_string().contains("SLUG")
-                || err.to_string().contains("cannot be used")
-        );
+        let parsed = parse(&["functions", "pull", "--id", "f1", "my-slug"]).expect("parse pull");
+        let FunctionsCommands::Pull(pull) = parsed.command.expect("subcommand") else {
+            panic!("expected pull");
+        };
+        assert_eq!(pull.id.as_deref(), Some("f1"));
+        assert_eq!(pull.resolved_slugs(), vec!["my-slug"]);
     }
 
     #[test]
diff --git a/src/functions/pull.rs b/src/functions/pull.rs
index 0ada470..4881fe6 100644
--- a/src/functions/pull.rs
+++ b/src/functions/pull.rs
@@ -8,23 +8,22 @@ use serde::Deserialize;
 use serde_json::Value;
 
 use crate::args::BaseArgs;
-use crate::auth::list_available_orgs;
 use crate::functions::report::{
     CommandStatus, FileStatus, HardFailureReason, PullFileReport, PullSummary, ReportError,
     ReportWarning, SoftSkipReason, WarningReason,
 };
-use crate::projects::api::{get_project_by_name, list_projects, Project};
-use crate::ui::{fuzzy_select, select_project_interactive};
+use crate::projects::api::{list_projects, Project};
 use crate::utils::{write_text_atomic, GitRepo};
 
 use super::api::{self, FunctionListQuery};
-use super::{
-    current_org_label, resolve_auth_context, resolve_project_context_optional,
-    validate_explicit_org_selection, FunctionsLanguage, PullArgs,
-};
+use super::{resolve_auth_context, FunctionsLanguage, PullArgs};
 
 const PAGINATION_PAGE_LIMIT: usize = 10_000;
 const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock";
+const TOP_BITS: u64 = 0x0DE1u64 << 48;
+const MODULUS: u128 = 1u128 << 64;
+const COPRIME: u64 = 205_891_132_094_649;
+const COPRIME_INVERSE: u64 = 1_522_336_535_492_693_385;
 
 #[derive(Debug, Clone, Deserialize)]
 struct PullFunctionRow {
@@ -103,35 +102,18 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> {
         errors: vec![],
     };
     let mut projects_cache: Option<Vec<Project>> = None;
-    let has_explicit_project = args.project_id.is_some() || args.project_name.is_some();
-
-    let mut base = base;
-    let (auth_ctx, selected_project) = if !has_explicit_project {
-        match resolve_pull_target(&mut base).await {
-            Ok(result) => result,
-            Err(err) => {
-                return fail_pull(
-                    &base,
-                    &mut summary,
-                    HardFailureReason::AuthFailed,
-                    err.to_string(),
-                );
-            }
-        }
-    } else {
-        match resolve_auth_context(&base)
-            .await
-            .context("failed to resolve auth context")
-        {
-            Ok(ctx)
=> (ctx, None), - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::AuthFailed, - err.to_string(), - ); - } + let auth_ctx = match resolve_auth_context(&base) + .await + .context("failed to resolve auth context") + { + Ok(ctx) => ctx, + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::AuthFailed, + err.to_string(), + ); } }; @@ -160,20 +142,24 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { ); } query.project_name = Some(project_name.clone()); - } else if let Some(project) = selected_project { - query.project_id = Some(project.id); - } else { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - "no project selected".to_string(), - ); } if let Some(id) = &args.id { query.id = Some(id.clone()); } + if let Some(version) = &args.version { + query.version = match load_pretty_xact_compat(version) { + Ok(value) => Some(value), + Err(err) => { + return fail_pull( + &base, + &mut summary, + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + } let resolved_slugs = args.resolved_slugs(); if resolved_slugs.len() == 1 { query.slug = Some(resolved_slugs[0].clone()); @@ -232,8 +218,13 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { ); } + let project_ids_with_matches = winners + .iter() + .map(|row| row.project_id.clone()) + .collect::>(); + let mut materializable = Vec::new(); - for row in winners { + for row in winners.iter().cloned() { if is_prompt_row(&row) { materializable.push(row); } else { @@ -241,15 +232,6 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { } } - if (args.id.is_some() || args.has_slug_selector()) && materializable.is_empty() { - return fail_pull( - &base, - &mut summary, - HardFailureReason::SelectorNotFound, - "selector matched records but none are materializable prompts".to_string(), - ); - } - let output_dir = if args.output_dir.is_absolute() { args.output_dir.clone() } else { @@ -330,7 +312,6 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { }; if !resolved_slugs.is_empty() { - // Slug mode: one file per slug. let found_slugs: BTreeSet<&str> = materializable.iter().map(|r| r.slug.as_str()).collect(); for slug in &resolved_slugs { if !found_slugs.contains(slug.as_str()) { @@ -340,89 +321,64 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { }); } } + } - // Check for cross-project collisions: if any slug appears in multiple projects, - // require --project-name/--project-id to disambiguate. - if !has_explicit_project { - let mut slug_projects: BTreeMap<&str, BTreeSet<&str>> = BTreeMap::new(); - for row in &materializable { - slug_projects - .entry(row.slug.as_str()) - .or_default() - .insert(row.project_id.as_str()); - } - for (slug, projects) in &slug_projects { - if projects.len() > 1 { - return fail_pull( - &base, - &mut summary, - HardFailureReason::SelectorNotFound, - format!( - "slug '{}' exists in {} projects; pass --project-name or --project-id", - slug, - projects.len() - ), - ); - } - } - } - - summary.projects_total = materializable - .iter() - .map(|r| r.project_id.as_str()) - .collect::>() - .len(); - - for row in &materializable { - let project_name = project_names - .get(&row.project_id) - .map(String::as_str) - .unwrap_or("unknown"); - let file_name = format!("{}.{ext}", sanitize_filename(&row.slug)); - write_pull_file( + // Legacy-compatible project mode: one output file per project, even for + // selector pulls that only matched unsupported record types. 
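+    // For example (illustrative): `bt functions pull --slug alpha` where
+    // "alpha" lives in project "Docs" now emits one per-project output file
+    // for "Docs" rather than a per-slug `alpha.ts`, matching the legacy
+    // layout.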
+ let mut grouped_by_project = BTreeMap::<(String, String), Vec>::new(); + for project_id in project_ids_with_matches { + let Some(project_name) = project_names.get(&project_id).cloned() else { + return fail_pull( + &base, &mut summary, - &canonical_output_dir, - &repo, - args.force, - args.language, - project_name, - &file_name, - std::slice::from_ref(row), + HardFailureReason::ResponseInvalid, + format!( + "missing resolved project name for project id '{}'", + project_id + ), ); - } - } else { - // Project mode: one file per project (existing behavior). - let grouped_by_project = match group_rows_by_project(materializable, &project_names) { - Ok(grouped) => grouped, - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - } }; - summary.projects_total = grouped_by_project.len(); - - let file_names = build_project_file_names(&grouped_by_project, ext); - - for ((project_id, project_name), rows) in grouped_by_project { - let file_name = file_names - .get(&(project_id.clone(), project_name.clone())) - .ok_or_else(|| anyhow!("missing output file mapping"))? - .clone(); - write_pull_file( + grouped_by_project + .entry((project_id, project_name)) + .or_default(); + } + for row in materializable { + let Some(project_name) = project_names.get(&row.project_id).cloned() else { + return fail_pull( + &base, &mut summary, - &canonical_output_dir, - &repo, - args.force, - args.language, - &project_name, - &file_name, - &rows, + HardFailureReason::ResponseInvalid, + format!( + "missing resolved project name for project id '{}'", + row.project_id + ), ); - } + }; + grouped_by_project + .entry((row.project_id.clone(), project_name)) + .or_default() + .push(row); + } + + summary.projects_total = grouped_by_project.len(); + + let file_names = build_project_file_names(&grouped_by_project, ext); + + for ((project_id, project_name), rows) in grouped_by_project { + let file_name = file_names + .get(&(project_id.clone(), project_name.clone())) + .ok_or_else(|| anyhow!("missing output file mapping"))? 
+ .clone(); + write_pull_file( + &mut summary, + &canonical_output_dir, + &repo, + args.force, + args.language, + &project_name, + &file_name, + &rows, + ); } if summary.status != CommandStatus::Failed @@ -470,6 +426,26 @@ fn ensure_unambiguous_project_name(projects: &[Project], project_name: &str) -> } } +fn modular_multiply(value: u64, prime: u64) -> u64 { + ((value as u128 * prime as u128) % MODULUS) as u64 +} + +fn load_pretty_xact_compat(encoded_hex: &str) -> Result { + if encoded_hex.len() != 16 { + return Ok(encoded_hex.to_string()); + } + let value = u64::from_str_radix(encoded_hex, 16).with_context(|| { + format!("invalid pretty version '{encoded_hex}' (expected 16 hex characters)") + })?; + let multiplied_inverse = modular_multiply(value, COPRIME_INVERSE); + let with_top_bits = TOP_BITS | multiplied_inverse; + let roundtrip = modular_multiply(with_top_bits, COPRIME); + if roundtrip != value { + bail!("invalid pretty version '{encoded_hex}' (failed compatibility decode)"); + } + Ok(with_top_bits.to_string()) +} + struct FetchRowsResult { rows: Vec, warnings: Vec, @@ -546,20 +522,17 @@ fn apply_selector_narrowing( args: &PullArgs, ) -> Result> { let resolved_slugs = args.resolved_slugs(); - let narrowed = if let Some(id) = &args.id { - rows.into_iter() - .filter(|row| row.id == *id) - .collect::>() - } else if !resolved_slugs.is_empty() { - let slug_set: BTreeSet<&str> = resolved_slugs.iter().map(String::as_str).collect(); - rows.into_iter() - .filter(|row| slug_set.contains(row.slug.as_str())) - .collect::>() - } else { - rows - }; + let slug_set: BTreeSet<&str> = resolved_slugs.iter().map(String::as_str).collect(); + let has_id_selector = args.id.is_some(); + let has_slug_selector = !slug_set.is_empty(); + + let narrowed = rows + .into_iter() + .filter(|row| args.id.as_ref().is_none_or(|id| row.id == *id)) + .filter(|row| !has_slug_selector || slug_set.contains(row.slug.as_str())) + .collect::>(); - if (args.id.is_some() || args.has_slug_selector()) && narrowed.is_empty() { + if (has_id_selector || has_slug_selector) && narrowed.is_empty() { bail!("selector did not match any function rows"); } @@ -658,6 +631,7 @@ fn resolve_project_names( Ok(names_by_id) } +#[allow(dead_code)] fn group_rows_by_project( rows: Vec, project_names: &BTreeMap, @@ -792,6 +766,7 @@ fn build_project_file_names( names } +#[allow(dead_code)] fn build_output_file_names( grouped_by_project: &BTreeMap<(String, String), Vec>, slug_selector: Option<&str>, @@ -1610,68 +1585,6 @@ fn record_pull_file_failure( }); } -async fn resolve_pull_target(base: &mut BaseArgs) -> Result<(super::AuthContext, Option)> { - let available_orgs = list_available_orgs(base) - .await - .context("failed to list available orgs")?; - - validate_explicit_org_selection(base, &available_orgs)?; - - let mut auth_ctx = resolve_auth_context(base) - .await - .context("failed to resolve auth context")?; - - // Org selector: show when multiple orgs and no explicit --org. 
- let has_explicit_org = base - .org_name - .as_deref() - .map(str::trim) - .filter(|v| !v.is_empty()) - .is_some(); - if !has_explicit_org && available_orgs.len() > 1 && crate::ui::is_interactive() { - let org_label = current_org_label(&auth_ctx); - let names: Vec<&str> = available_orgs.iter().map(|o| o.name.as_str()).collect(); - let default_index = names - .iter() - .position(|n| *n == org_label || n.eq_ignore_ascii_case(&org_label)) - .unwrap_or(0); - let selected_index = fuzzy_select("Select org:", &names, default_index)?; - let selected = &available_orgs[selected_index]; - if selected.name != org_label && !selected.name.eq_ignore_ascii_case(&org_label) { - base.org_name = Some(selected.name.clone()); - auth_ctx = resolve_auth_context(base) - .await - .context("failed to resolve switched org context")?; - } - } else if !has_explicit_org && available_orgs.len() > 1 { - bail!( - "multiple organizations are available for this credential; pass --org in non-interactive mode" - ); - } - - // Project selector with current default pre-focused. - let default_project = resolve_project_context_optional(base, &auth_ctx, false) - .await - .ok() - .flatten(); - if !crate::ui::is_interactive() { - return Ok((auth_ctx, default_project)); - } - let current_name = default_project.as_ref().map(|p| p.name.as_str()); - let selected_name = select_project_interactive( - &auth_ctx.client, - Some("Select project to pull from:"), - current_name, - ) - .await?; - - let project = get_project_by_name(&auth_ctx.client, &selected_name) - .await? - .ok_or_else(|| anyhow!("project '{selected_name}' not found"))?; - - Ok((auth_ctx, Some(project))) -} - #[cfg(test)] mod tests { use super::*; @@ -1731,6 +1644,7 @@ mod tests { project_name: None, project_id: None, id: Some("missing".to_string()), + version: None, force: false, verbose: false, }; @@ -1787,6 +1701,7 @@ mod tests { project_name: None, project_id: None, id: None, + version: None, force: false, verbose: false, }; diff --git a/src/functions/push.rs b/src/functions/push.rs index c534f7a..ba4c42c 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -257,6 +257,17 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { }; emit_language_selection_notice(&args, &classified, selected_language); + if args.tsconfig.is_some() { + eprintln!( + "Notice: --tsconfig is enabled for runner compatibility where supported (TS_NODE_PROJECT/TSX_TSCONFIG_PATH)." + ); + } + if !args.external_packages.is_empty() { + eprintln!( + "Notice: --external-packages is accepted for compatibility; dependency handling is runner-managed in bt functions push." 
+ ); + } + if args.requirements.is_some() && selected_language != SourceLanguage::Python { return fail_push( &base, @@ -999,6 +1010,16 @@ fn run_functions_runner( }; command.env("BRAINTRUST_API_KEY", api_key); + if let Some(tsconfig) = &args.tsconfig { + command.env("TS_NODE_PROJECT", tsconfig); + command.env("TSX_TSCONFIG_PATH", tsconfig); + } + if !args.external_packages.is_empty() { + command.env( + "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", + args.external_packages.join(","), + ); + } let output = command.output().map_err(|err| FileFailure { reason: HardFailureReason::RunnerSpawnFailed, @@ -2804,6 +2825,8 @@ mod tests { runner: None, language: PushLanguage::Auto, requirements: None, + tsconfig: None, + external_packages: vec![], create_missing_projects: false, }; let classified = ClassifiedFiles { @@ -2831,6 +2854,8 @@ mod tests { runner: None, language: PushLanguage::Auto, requirements: None, + tsconfig: None, + external_packages: vec![], create_missing_projects: false, }; let classified = ClassifiedFiles { diff --git a/src/main.rs b/src/main.rs index db7760b..d80374d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,6 +60,8 @@ Core Projects & resources projects Manage projects prompts Manage prompts + push Compatibility alias for `functions push` + pull Compatibility alias for `functions pull` functions Manage functions (tools, scorers, and more) tools Manage tools scorers Manage scorers @@ -140,6 +142,10 @@ enum Commands { Scorers(CLIArgs), /// Manage functions (tools, scorers, and more) Functions(CLIArgs), + /// Compatibility alias for `functions push` + Push(CLIArgs), + /// Compatibility alias for `functions pull` + Pull(CLIArgs), /// Manage experiments Experiments(CLIArgs), /// Synchronize project logs between Braintrust and local NDJSON files @@ -171,6 +177,8 @@ impl Commands { Commands::Tools(cmd) => &cmd.base, Commands::Scorers(cmd) => &cmd.base, Commands::Functions(cmd) => &cmd.base, + Commands::Push(cmd) => &cmd.base, + Commands::Pull(cmd) => &cmd.base, Commands::Experiments(cmd) => &cmd.base, Commands::Sync(cmd) => &cmd.base, Commands::Util(cmd) => &cmd.base, @@ -224,6 +232,8 @@ async fn try_main() -> Result<()> { Commands::Tools(cmd) => tools::run(cmd.base, cmd.args).await?, Commands::Scorers(cmd) => scorers::run(cmd.base, cmd.args).await?, Commands::Functions(cmd) => functions::run(cmd.base, cmd.args).await?, + Commands::Push(cmd) => functions::run_push(cmd.base, cmd.args).await?, + Commands::Pull(cmd) => functions::run_pull(cmd.base, cmd.args).await?, Commands::Experiments(cmd) => experiments::run(cmd.base, cmd.args).await?, Commands::Sync(cmd) => sync::run(cmd.base, cmd.args).await?, Commands::Util(cmd) => util_cmd::run(cmd.base, cmd.args).await?, diff --git a/src/utils/git.rs b/src/utils/git.rs index cbbf53b..51fd606 100644 --- a/src/utils/git.rs +++ b/src/utils/git.rs @@ -40,7 +40,9 @@ impl GitRepo { ); } - Ok(!String::from_utf8_lossy(&output.stdout).trim().is_empty()) + Ok(has_tracked_changes(&String::from_utf8_lossy( + &output.stdout, + ))) } pub fn discover_from(path: &Path) -> Option { @@ -48,6 +50,14 @@ impl GitRepo { } } +fn has_tracked_changes(porcelain: &str) -> bool { + porcelain + .lines() + .map(str::trim_end) + .filter(|line| !line.is_empty()) + .any(|line| !line.starts_with("?? ") && !line.starts_with("!! 
")) +} + pub fn find_repo_root_from(start: &Path) -> Option { let mut current = start.to_path_buf(); if current.is_file() { @@ -67,9 +77,23 @@ pub fn find_repo_root_from(start: &Path) -> Option { #[cfg(test)] mod tests { use std::fs; + use std::process::Command; use super::*; + fn run_git(cwd: &Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .expect("run git"); + assert!( + status.success(), + "git command failed: git {}", + args.join(" ") + ); + } + #[test] fn find_repo_root_detects_git_dir() { let unique = std::time::SystemTime::now() @@ -103,4 +127,57 @@ mod tests { let _ = fs::remove_dir_all(found); } + + #[test] + fn tracked_modifications_are_reported_dirty() { + let unique = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("clock") + .as_nanos(); + let root = std::env::temp_dir().join(format!("bt-git-dirty-{unique}")); + fs::create_dir_all(&root).expect("create repo root"); + + run_git(&root, &["init"]); + run_git(&root, &["config", "user.email", "tests@example.com"]); + run_git(&root, &["config", "user.name", "BT Tests"]); + + let file = root.join("tracked.txt"); + fs::write(&file, "v1\n").expect("write tracked file"); + run_git(&root, &["add", "tracked.txt"]); + run_git(&root, &["commit", "-m", "init"]); + + fs::write(&file, "v2\n").expect("modify tracked file"); + + let repo = GitRepo { root: root.clone() }; + assert!(repo.is_dirty_or_untracked(&file).expect("git status")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn untracked_file_is_not_treated_as_dirty_for_pull_compat() { + let unique = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("clock") + .as_nanos(); + let root = std::env::temp_dir().join(format!("bt-git-untracked-{unique}")); + fs::create_dir_all(&root).expect("create repo root"); + + run_git(&root, &["init"]); + run_git(&root, &["config", "user.email", "tests@example.com"]); + run_git(&root, &["config", "user.name", "BT Tests"]); + + let tracked = root.join("tracked.txt"); + fs::write(&tracked, "v1\n").expect("write tracked file"); + run_git(&root, &["add", "tracked.txt"]); + run_git(&root, &["commit", "-m", "init"]); + + let untracked = root.join("untracked.txt"); + fs::write(&untracked, "local-only\n").expect("write untracked file"); + + let repo = GitRepo { root: root.clone() }; + assert!(!repo.is_dirty_or_untracked(&untracked).expect("git status")); + + let _ = fs::remove_dir_all(root); + } } diff --git a/tests/functions-fixtures/pull-help-env-vars/fixture.json b/tests/functions-fixtures/pull-help-env-vars/fixture.json index 34ab809..5756da8 100644 --- a/tests/functions-fixtures/pull-help-env-vars/fixture.json +++ b/tests/functions-fixtures/pull-help-env-vars/fixture.json @@ -7,6 +7,7 @@ "BT_FUNCTIONS_PULL_PROJECT_NAME", "BT_FUNCTIONS_PULL_ID", "BT_FUNCTIONS_PULL_SLUG", + "BT_FUNCTIONS_PULL_VERSION", "BT_FUNCTIONS_PULL_FORCE", "BT_FUNCTIONS_PULL_LANGUAGE" ] diff --git a/tests/functions-fixtures/pull-help-flags/fixture.json b/tests/functions-fixtures/pull-help-flags/fixture.json index 1bd2716..197aca7 100644 --- a/tests/functions-fixtures/pull-help-flags/fixture.json +++ b/tests/functions-fixtures/pull-help-flags/fixture.json @@ -7,6 +7,7 @@ "--project-name", "--id", "--slug", + "--version", "--language", "--force" ] diff --git a/tests/functions-fixtures/pull-id-slug-conflict/fixture.json b/tests/functions-fixtures/pull-id-slug-conflict/fixture.json index 6d2b4d6..184920f 100644 --- 
a/tests/functions-fixtures/pull-id-slug-conflict/fixture.json +++ b/tests/functions-fixtures/pull-id-slug-conflict/fixture.json @@ -1,5 +1,6 @@ { - "command": ["functions", "pull", "--id", "abc", "--slug", "slug"], - "expect_success": false, - "stderr_contains": ["--id", "--slug"] + "command": ["functions", "pull", "--id", "abc", "--slug", "slug", "--help"], + "expect_success": true, + "stdout_contains": ["--id", "--slug"], + "stderr_not_contains": ["conflicts with", "cannot be used with"] } diff --git a/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json b/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json index 4b53f14..843cbd1 100644 --- a/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json +++ b/tests/functions-fixtures/pull-valid-language-typescript-parses/fixture.json @@ -7,9 +7,10 @@ "--id", "abc", "--slug", - "slug" + "slug", + "--help" ], - "expect_success": false, - "stderr_contains": ["--id", "--slug"], + "expect_success": true, + "stdout_contains": ["--id", "--slug"], "stderr_not_contains": ["invalid value 'typescript'"] } diff --git a/tests/functions-fixtures/push-help-env-vars/fixture.json b/tests/functions-fixtures/push-help-env-vars/fixture.json index ec52bc0..3bd4748 100644 --- a/tests/functions-fixtures/push-help-env-vars/fixture.json +++ b/tests/functions-fixtures/push-help-env-vars/fixture.json @@ -8,6 +8,8 @@ "BT_FUNCTIONS_PUSH_RUNNER", "BT_FUNCTIONS_PUSH_LANGUAGE", "BT_FUNCTIONS_PUSH_REQUIREMENTS", + "BT_FUNCTIONS_PUSH_TSCONFIG", + "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS" ] } diff --git a/tests/functions-fixtures/push-help-flags/fixture.json b/tests/functions-fixtures/push-help-flags/fixture.json index af5772f..4ffc039 100644 --- a/tests/functions-fixtures/push-help-flags/fixture.json +++ b/tests/functions-fixtures/push-help-flags/fixture.json @@ -8,6 +8,8 @@ "--create-missing-projects", "--language", "--requirements", + "--tsconfig", + "--external-packages", "--runner" ] } diff --git a/tests/functions-fixtures/push-reject-external-packages/fixture.json b/tests/functions-fixtures/push-reject-external-packages/fixture.json index 818959f..2745683 100644 --- a/tests/functions-fixtures/push-reject-external-packages/fixture.json +++ b/tests/functions-fixtures/push-reject-external-packages/fixture.json @@ -1,5 +1,6 @@ { - "command": ["functions", "push", "--external-packages", "react"], - "expect_success": false, - "stderr_contains": ["--external-packages"] + "command": ["functions", "push", "--external-packages", "react", "--help"], + "expect_success": true, + "stdout_contains": ["--external-packages"], + "stderr_not_contains": ["unexpected argument", "unrecognized"] } diff --git a/tests/functions-fixtures/push-reject-tsconfig/fixture.json b/tests/functions-fixtures/push-reject-tsconfig/fixture.json index c65a197..194bc08 100644 --- a/tests/functions-fixtures/push-reject-tsconfig/fixture.json +++ b/tests/functions-fixtures/push-reject-tsconfig/fixture.json @@ -1,5 +1,6 @@ { - "command": ["functions", "push", "--tsconfig", "./tsconfig.json"], - "expect_success": false, - "stderr_contains": ["--tsconfig"] + "command": ["functions", "push", "--tsconfig", "./tsconfig.json", "--help"], + "expect_success": true, + "stdout_contains": ["--tsconfig"], + "stderr_not_contains": ["unexpected argument", "unrecognized"] } diff --git a/tests/functions.rs b/tests/functions.rs index c1f62f4..476096a 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -145,12 +145,15 @@ 
fn sanitized_env_keys() -> &'static [&'static str] { "BT_FUNCTIONS_PUSH_RUNNER", "BT_FUNCTIONS_PUSH_LANGUAGE", "BT_FUNCTIONS_PUSH_REQUIREMENTS", + "BT_FUNCTIONS_PUSH_TSCONFIG", + "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", "BT_FUNCTIONS_PULL_OUTPUT_DIR", "BT_FUNCTIONS_PULL_PROJECT_ID", "BT_FUNCTIONS_PULL_PROJECT_NAME", "BT_FUNCTIONS_PULL_ID", "BT_FUNCTIONS_PULL_SLUG", + "BT_FUNCTIONS_PULL_VERSION", "BT_FUNCTIONS_PULL_FORCE", "BT_FUNCTIONS_PULL_LANGUAGE", ] @@ -527,6 +530,8 @@ fn functions_push_help_includes_expected_flags() { assert!(stdout.contains("--create-missing-projects")); assert!(stdout.contains("--language")); assert!(stdout.contains("--requirements")); + assert!(stdout.contains("--tsconfig")); + assert!(stdout.contains("--external-packages")); } #[test] @@ -543,11 +548,12 @@ fn functions_pull_help_includes_expected_flags() { assert!(stdout.contains("--output-dir")); assert!(stdout.contains("--project-id")); assert!(stdout.contains("--project-name")); + assert!(stdout.contains("--version")); assert!(stdout.contains("--language")); } #[test] -fn functions_pull_id_and_slug_conflict() { +fn functions_pull_accepts_id_and_slug_together() { let output = Command::new(bt_binary_path()) .arg("functions") .arg("pull") @@ -555,13 +561,14 @@ fn functions_pull_id_and_slug_conflict() { .arg("abc") .arg("--slug") .arg("slug") + .arg("--help") .output() - .expect("run conflicting pull command"); + .expect("run pull with id and slug"); - assert!(!output.status.success()); - let stderr = String::from_utf8_lossy(&output.stderr); - assert!(stderr.contains("--id")); - assert!(stderr.contains("--slug")); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("--id")); + assert!(stdout.contains("--slug")); } #[test] @@ -623,6 +630,34 @@ fn functions_help_lists_push_and_pull() { assert!(stdout.contains("pull")); } +#[test] +fn top_level_push_help_is_available() { + let output = Command::new(bt_binary_path()) + .arg("push") + .arg("--help") + .output() + .expect("run bt push --help"); + + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("--if-exists")); + assert!(stdout.contains("--file")); +} + +#[test] +fn top_level_pull_help_is_available() { + let output = Command::new(bt_binary_path()) + .arg("pull") + .arg("--help") + .output() + .expect("run bt pull --help"); + + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("--output-dir")); + assert!(stdout.contains("--version")); +} + #[test] fn push_and_pull_help_are_machine_readable() { let push_help = Command::new(bt_binary_path()) @@ -647,8 +682,11 @@ fn push_and_pull_help_are_machine_readable() { assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS")); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_LANGUAGE")); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_REQUIREMENTS")); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_TSCONFIG")); + assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES")); assert!(pull_stdout.contains("BT_FUNCTIONS_PULL_OUTPUT_DIR")); assert!(pull_stdout.contains("BT_FUNCTIONS_PULL_LANGUAGE")); + assert!(pull_stdout.contains("BT_FUNCTIONS_PULL_VERSION")); } #[test] @@ -1376,7 +1414,7 @@ async fn functions_pull_works_against_mock_api() { assert_eq!(summary["files_written"].as_u64(), Some(1)); assert_eq!(summary["files_failed"].as_u64(), Some(0)); - let rendered_file = 
out_dir.join("doc-search.ts"); + let rendered_file = out_dir.join("mock-project.ts"); assert!(rendered_file.is_file(), "expected rendered file to exist"); let rendered = std::fs::read_to_string(&rendered_file).expect("read rendered file"); assert!( From 8974925284b3bfb3eb023716998762851acd3e6e Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Fri, 6 Mar 2026 15:38:31 -0800 Subject: [PATCH 10/28] feat(functions): add JS bundling support with esbuild --- scripts/functions-bundler.ts | 152 +++++++++++ src/functions/mod.rs | 47 +++- src/functions/push.rs | 81 +++++- tests/functions.rs | 471 +++++++++++++++++++++++++++++++++-- 4 files changed, 727 insertions(+), 24 deletions(-) create mode 100644 scripts/functions-bundler.ts diff --git a/scripts/functions-bundler.ts b/scripts/functions-bundler.ts new file mode 100644 index 0000000..62de932 --- /dev/null +++ b/scripts/functions-bundler.ts @@ -0,0 +1,152 @@ +import fs from "node:fs"; +import path from "node:path"; + +type EsbuildBuild = (options: Record) => Promise; +type EsbuildModule = { + build: EsbuildBuild; +}; + +function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function isEsbuildModule(value: unknown): value is EsbuildModule { + return isObject(value) && typeof value.build === "function"; +} + +function parseExternalPackages(raw: string | undefined): string[] { + if (!raw) { + return []; + } + return raw + .split(",") + .map((value) => value.trim()) + .filter((value) => value.length > 0); +} + +function loadTsconfigPath(): string | undefined { + const tsNode = process.env.TS_NODE_PROJECT?.trim(); + if (tsNode) { + return tsNode; + } + const tsx = process.env.TSX_TSCONFIG_PATH?.trim(); + if (tsx) { + return tsx; + } + return undefined; +} + +function createMarkKnownPackagesExternalPlugin(additionalPackages: string[]) { + return { + name: "make-known-packages-external", + setup(build: { + onResolve: ( + opts: { filter: RegExp }, + cb: (args: { path: string }) => { path: string; external: boolean }, + ) => void; + }) { + const knownPackages = [ + "braintrust", + "autoevals", + "@braintrust/", + "config", + "lightningcss", + "@mapbox/node-pre-gyp", + "fsevents", + "chokidar", + ...additionalPackages, + ]; + const escapedPackages = knownPackages.map((pkg) => { + const escaped = pkg.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); + if (pkg.endsWith("/")) { + return `${escaped}.*`; + } + return `${escaped}(?:\\/.*)?`; + }); + const knownPackagesFilter = new RegExp( + `^(${escapedPackages.join("|")})$`, + ); + build.onResolve({ filter: knownPackagesFilter }, (args) => ({ + path: args.path, + external: true, + })); + }, + }; +} + +async function loadEsbuild(): Promise { + if (typeof require === "function") { + try { + const loaded = require("esbuild") as unknown; + if (isEsbuildModule(loaded)) { + return loaded; + } + if (isObject(loaded) && isEsbuildModule(loaded.default)) { + return loaded.default; + } + } catch { + // Fall through to dynamic import. + } + } + + try { + // Keep module name dynamic so TypeScript doesn't require local esbuild types at compile time. 
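+      // (illustrative note, not in the original change) A literal
+      // `import("esbuild")` would make tsc resolve esbuild's type
+      // declarations at compile time; routing the specifier through a
+      // variable defers module resolution to runtime, where the catch below
+      // turns a missing install into the friendly error thrown at the end of
+      // loadEsbuild.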
+ const specifier = "esbuild"; + const loaded = (await import(specifier)) as unknown; + if (isEsbuildModule(loaded)) { + return loaded; + } + if (isObject(loaded) && isEsbuildModule(loaded.default)) { + return loaded.default; + } + } catch { + // handled below + } + + throw new Error( + "failed to load esbuild for JS bundling; install esbuild in your project or use a runner that provides it", + ); +} + +async function main(): Promise { + const [sourceFile, outputFile] = process.argv.slice(2); + if (!sourceFile || !outputFile) { + throw new Error("functions-bundler requires "); + } + + const esbuild = await loadEsbuild(); + const externalPackages = parseExternalPackages( + process.env.BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES, + ); + const tsconfig = loadTsconfigPath(); + + const outputDir = path.dirname(outputFile); + fs.mkdirSync(outputDir, { recursive: true }); + + const targetVersion = + typeof process.version === "string" && process.version.startsWith("v") + ? process.version.slice(1) + : process.versions.node || "18"; + + await esbuild.build({ + entryPoints: [sourceFile], + bundle: true, + treeShaking: true, + platform: "node", + target: `node${targetVersion}`, + write: true, + outfile: outputFile, + tsconfig, + external: ["node_modules/*", "fsevents"], + plugins: [createMarkKnownPackagesExternalPlugin(externalPackages)], + }); +} + +main().catch((error: unknown) => { + const message = + error instanceof Error + ? error.message + : `failed to bundle JS source: ${String(error)}`; + process.stderr.write(`${message}\n`); + process.exitCode = 1; +}); diff --git a/src/functions/mod.rs b/src/functions/mod.rs index 1ec27db..0f12a68 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -305,14 +305,15 @@ pub(crate) struct PushArgs { #[arg(long, env = "BT_FUNCTIONS_PUSH_REQUIREMENTS", value_name = "PATH")] pub requirements: Option, - /// Compatibility flag for legacy push workflows. Currently informational. + /// Optional tsconfig path for JS runner and bundler. #[arg(long, env = "BT_FUNCTIONS_PUSH_TSCONFIG", value_name = "PATH")] pub tsconfig: Option, - /// Compatibility flag for legacy push workflows. Currently informational. + /// Additional packages to mark external during JS bundling. 
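+    ///
+    /// Accepts repeated or comma-delimited values, e.g.
+    /// `--external-packages sqlite3,fsevents`.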
#[arg( long = "external-packages", env = "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", + num_args = 1.., value_delimiter = ',', value_name = "PACKAGE" )] @@ -906,6 +907,48 @@ mod tests { assert_eq!(push.requirements, Some(PathBuf::from("requirements.txt"))); } + #[test] + fn push_external_packages_flag_accepts_space_separated_values() { + let _guard = test_lock(); + let parsed = parse(&[ + "functions", + "push", + "--external-packages", + "sqlite3", + "fsevents", + "@mapbox/node-pre-gyp", + ]) + .expect("parse push"); + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert_eq!( + push.external_packages, + vec!["sqlite3", "fsevents", "@mapbox/node-pre-gyp"] + ); + } + + #[test] + fn push_external_packages_flag_accepts_comma_delimited_values() { + let _guard = test_lock(); + let parsed = parse(&[ + "functions", + "push", + "--external-packages", + "sqlite3,fsevents,@mapbox/node-pre-gyp", + ]) + .expect("parse push"); + + let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { + panic!("expected push command"); + }; + assert_eq!( + push.external_packages, + vec!["sqlite3", "fsevents", "@mapbox/node-pre-gyp"] + ); + } + #[test] fn pull_language_from_env() { let _guard = test_lock(); diff --git a/src/functions/push.rs b/src/functions/push.rs index ba4c42c..953d199 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -32,10 +32,12 @@ use super::{ }; const FUNCTIONS_JS_RUNNER_FILE: &str = "functions-runner.ts"; +const FUNCTIONS_JS_BUNDLER_FILE: &str = "functions-bundler.ts"; const FUNCTIONS_PY_RUNNER_FILE: &str = "functions-runner.py"; const RUNNER_COMMON_FILE: &str = "runner-common.ts"; const PYTHON_RUNNER_COMMON_FILE: &str = "python_runner_common.py"; const FUNCTIONS_JS_RUNNER_SOURCE: &str = include_str!("../../scripts/functions-runner.ts"); +const FUNCTIONS_JS_BUNDLER_SOURCE: &str = include_str!("../../scripts/functions-bundler.ts"); const FUNCTIONS_PY_RUNNER_SOURCE: &str = include_str!("../../scripts/functions-runner.py"); const RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/runner-common.ts"); const PYTHON_RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/python_runner_common.py"); @@ -259,12 +261,19 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { if args.tsconfig.is_some() { eprintln!( - "Notice: --tsconfig is enabled for runner compatibility where supported (TS_NODE_PROJECT/TSX_TSCONFIG_PATH)." + "Notice: --tsconfig is enabled for JS runner and JS bundling (TS_NODE_PROJECT/TSX_TSCONFIG_PATH)." ); } if !args.external_packages.is_empty() { - eprintln!( - "Notice: --external-packages is accepted for compatibility; dependency handling is runner-managed in bt functions push." 
+ eprintln!("Notice: --external-packages will be applied to JS bundle builds."); + } + if !args.external_packages.is_empty() && selected_language != SourceLanguage::JsLike { + return fail_push( + &base, + 0, + HardFailureReason::ManifestSchemaInvalid, + "--external-packages can only be used when pushing JS/TS sources".to_string(), + "invalid --external-packages usage", ); } @@ -745,10 +754,7 @@ async fn push_file( if !code_entries.is_empty() { let (upload_bytes, content_encoding) = match selected_language { SourceLanguage::JsLike => { - let bundle_bytes = std::fs::read(source_path).map_err(|err| FileFailure { - reason: HardFailureReason::ManifestPathMissing, - message: format!("failed to read {}: {err}", source_path.display()), - })?; + let bundle_bytes = build_js_bundle(source_path, args)?; let gzipped = gzip_bytes(&bundle_bytes).map_err(|err| FileFailure { reason: HardFailureReason::BundleUploadFailed, message: format!("failed to gzip {}: {err}", source_path.display()), @@ -938,6 +944,67 @@ async fn push_file( }) } +fn build_js_bundle( + source_path: &Path, + args: &PushArgs, +) -> std::result::Result, FileFailure> { + let build_dir = TempBuildDir::create("bt-functions-js-bundle").map_err(|err| FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: err.to_string(), + })?; + let output_bundle = build_dir.path.join("bundle.js"); + + let bundler_script = js_runner::materialize_runner_script_in_cwd( + "functions-runners", + FUNCTIONS_JS_BUNDLER_FILE, + FUNCTIONS_JS_BUNDLER_SOURCE, + ) + .map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to materialize JS bundler script: {err}"), + })?; + + let mut command = js_runner::build_js_runner_command( + args.runner.as_deref(), + &bundler_script, + &[source_path.to_path_buf(), output_bundle.clone()], + ); + if let Some(tsconfig) = &args.tsconfig { + command.env("TS_NODE_PROJECT", tsconfig); + command.env("TSX_TSCONFIG_PATH", tsconfig); + } + if !args.external_packages.is_empty() { + command.env( + "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", + args.external_packages.join(","), + ); + } + + let output = command.output().map_err(|err| FileFailure { + reason: HardFailureReason::RunnerSpawnFailed, + message: format!("failed to spawn JS bundler: {err}"), + })?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: format!( + "JS bundler exited with status {}: {}", + output.status, + stderr.trim() + ), + }); + } + + std::fs::read(&output_bundle).map_err(|err| FileFailure { + reason: HardFailureReason::BundleUploadFailed, + message: format!( + "failed to read bundled JS output {}: {err}", + output_bundle.display() + ), + }) +} + fn calculate_upload_counts(total_entries: usize, ignored_entries: Option) -> (usize, usize) { let ignored_entries = ignored_entries.unwrap_or(0); let uploaded_entries = total_entries.saturating_sub(ignored_entries); diff --git a/tests/functions.rs b/tests/functions.rs index 476096a..fc8762f 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -122,6 +122,19 @@ fn find_tsc() -> Option { None } +fn decode_uploaded_bundle(bundle: &[u8]) -> String { + if bundle.starts_with(&[0x1f, 0x8b]) { + let mut decoder = GzDecoder::new(bundle); + let mut out = String::new(); + decoder + .read_to_string(&mut out) + .expect("decompress uploaded bundle"); + out + } else { + String::from_utf8(bundle.to_vec()).expect("uploaded bundle utf8") + } +} + fn 
read_fixture_config(path: &Path) -> FixtureConfig { let raw = fs::read_to_string(path).expect("read fixture.json"); serde_json::from_str(&raw).expect("parse fixture.json") @@ -199,6 +212,7 @@ impl MockServer { .route("/insert-functions", web::post().to(mock_insert_functions)) .route("/v1/function", web::get().to(mock_list_functions)) }) + .workers(1) .listen(listener) .expect("listen mock server") .run(); @@ -1199,6 +1213,9 @@ async fn functions_push_works_against_mock_api() { set -eu _runner_script="$1" shift +_runner_name="$(basename "$_runner_script")" + +if [ "$_runner_name" = "functions-runner.ts" ]; then node - "$@" <<'NODE' const path = require("node:path"); const files = process.argv.slice(2); @@ -1221,6 +1238,18 @@ const manifest = { }; process.stdout.write(JSON.stringify(manifest)); NODE +exit 0 +fi + +if [ "$_runner_name" = "functions-bundler.ts" ]; then + _source_file="$1" + _output_file="$2" + cp "$_source_file" "$_output_file" + exit 0 +fi + +echo "unexpected runner script: $_runner_name" >&2 +exit 24 "#, ) .expect("write mock runner"); @@ -1313,23 +1342,435 @@ NODE .clone(); assert_eq!(uploaded.len(), 1, "expected one uploaded bundle"); let bundle = &uploaded[0]; - if bundle.starts_with(&[0x1f, 0x8b]) { - let mut decoder = GzDecoder::new(bundle.as_slice()); - let mut decompressed = String::new(); - decoder - .read_to_string(&mut decompressed) - .expect("decompress uploaded bundle"); - assert!( - decompressed.contains("globalThis._evals"), - "uploaded bundle should contain original source" - ); - } else { - let raw = String::from_utf8(bundle.clone()).expect("uploaded bundle utf8"); - assert!( - raw.contains("globalThis._evals"), - "uploaded bundle should contain original source" + let decompressed = decode_uploaded_bundle(bundle); + assert!( + decompressed.contains("globalThis._evals"), + "uploaded bundle should contain original source" + ); +} + +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_push_external_packages_bundles_with_runner() { + if !command_exists("node") { + eprintln!( + "Skipping functions_push_external_packages_bundles_with_runner (node not installed)." ); + return; } + + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let source = tmp.path().join("tool.js"); + std::fs::write( + &source, + "globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };\n", + ) + .expect("write source file"); + + let runner = tmp.path().join("mock-runner.sh"); + std::fs::write( + &runner, + r#"#!/bin/sh +set -eu +_runner_script="$1" +shift +_runner_name="$(basename "$_runner_script")" + +if [ "$_runner_name" = "functions-runner.ts" ]; then + node - "$@" <<'NODE' +const path = require("node:path"); +const files = process.argv.slice(2); +const manifest = { + runtime_context: { runtime: "node", version: process.versions.node || "unknown" }, + files: files.map((file, index) => ({ + source_file: path.resolve(file), + entries: [ + { + kind: "code", + project_id: "proj_mock", + name: index === 0 ? "mock-tool" : `mock-tool-${index}`, + slug: index === 0 ? 
"mock-tool" : `mock-tool-${index}`, + function_type: "tool", + preview: "function handler() { return 1; }", + location: { type: "function", index: 0 } + } + ] + })) +}; +process.stdout.write(JSON.stringify(manifest)); +NODE + exit 0 +fi + +if [ "$_runner_name" = "functions-bundler.ts" ]; then + if [ "${BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES:-}" != "sqlite3,fsevents" ]; then + echo "unexpected BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES=${BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES:-}" >&2 + exit 23 + fi + _source_file="$1" + _output_file="$2" + printf '%s\n' "// bundled output" >"$_output_file" + printf '%s\n' "const externalMarker = \"externals:${BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES}\";" >>"$_output_file" + printf '%s\n' "const sourceMarker = \"source:${_source_file}\";" >>"$_output_file" + exit 0 +fi + +echo "unexpected runner script: $_runner_name" >&2 +exit 24 +"#, + ) + .expect("write mock runner"); + use std::os::unix::fs::PermissionsExt; + let mut perms = std::fs::metadata(&runner) + .expect("runner metadata") + .permissions(); + perms.set_mode(0o755); + std::fs::set_permissions(&runner, perms).expect("runner permissions"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "push", + "--file", + source + .to_str() + .expect("source path should be valid UTF-8 for test"), + "--language", + "javascript", + "--runner", + runner + .to_str() + .expect("runner path should be valid UTF-8 for test"), + "--external-packages", + "sqlite3,fsevents", + ]) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions push"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock push failed:\n{stderr}"); + } + + let summary: Value = serde_json::from_slice(&output.stdout).expect("parse push summary"); + assert_eq!(summary["status"].as_str(), Some("success")); + assert_eq!(summary["uploaded_files"].as_u64(), Some(1)); + assert_eq!(summary["failed_files"].as_u64(), Some(0)); + + let uploaded = state + .uploaded_bundles + .lock() + .expect("uploaded bundles lock") + .clone(); + assert_eq!(uploaded.len(), 1, "expected one uploaded bundle"); + let bundle = &uploaded[0]; + let decompressed = decode_uploaded_bundle(bundle); + assert!( + decompressed.contains("externals:sqlite3,fsevents"), + "uploaded bundle should include bundler output with external package marker" + ); +} + +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_push_js_bundles_by_default() { + if !command_exists("node") { + eprintln!("Skipping functions_push_js_bundles_by_default (node not installed)."); + return; + } + + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let source = tmp.path().join("tool.js"); + std::fs::write( + &source, + "globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };\n", + ) + .expect("write source file"); + + let runner = tmp.path().join("mock-runner.sh"); + std::fs::write( + &runner, + r#"#!/bin/sh 
+set -eu +_runner_script="$1" +shift +_runner_name="$(basename "$_runner_script")" + +if [ "$_runner_name" = "functions-runner.ts" ]; then + node - "$@" <<'NODE' +const path = require("node:path"); +const files = process.argv.slice(2); +const manifest = { + runtime_context: { runtime: "node", version: process.versions.node || "unknown" }, + files: files.map((file, index) => ({ + source_file: path.resolve(file), + entries: [ + { + kind: "code", + project_id: "proj_mock", + name: index === 0 ? "mock-tool" : `mock-tool-${index}`, + slug: index === 0 ? "mock-tool" : `mock-tool-${index}`, + function_type: "tool", + preview: "function handler() { return 1; }", + location: { type: "function", index: 0 } + } + ] + })) +}; +process.stdout.write(JSON.stringify(manifest)); +NODE + exit 0 +fi + +if [ "$_runner_name" = "functions-bundler.ts" ]; then + _output_file="$2" + printf '%s\n' "// bundled by default path" >"$_output_file" + printf '%s\n' "const marker = \"default-bundler-used\";" >>"$_output_file" + exit 0 +fi + +echo "unexpected runner script: $_runner_name" >&2 +exit 24 +"#, + ) + .expect("write mock runner"); + use std::os::unix::fs::PermissionsExt; + let mut perms = std::fs::metadata(&runner) + .expect("runner metadata") + .permissions(); + perms.set_mode(0o755); + std::fs::set_permissions(&runner, perms).expect("runner permissions"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "push", + "--file", + source + .to_str() + .expect("source path should be valid UTF-8 for test"), + "--language", + "javascript", + "--runner", + runner + .to_str() + .expect("runner path should be valid UTF-8 for test"), + ]) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions push"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock push failed:\n{stderr}"); + } + + let uploaded = state + .uploaded_bundles + .lock() + .expect("uploaded bundles lock") + .clone(); + assert_eq!(uploaded.len(), 1, "expected one uploaded bundle"); + let decompressed = decode_uploaded_bundle(&uploaded[0]); + assert!( + decompressed.contains("default-bundler-used"), + "uploaded bundle should include marker emitted by bundler path" + ); +} + +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_push_tsconfig_is_forwarded_to_bundler() { + if !command_exists("node") { + eprintln!("Skipping functions_push_tsconfig_is_forwarded_to_bundler (node not installed)."); + return; + } + + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let source = tmp.path().join("tool.js"); + std::fs::write( + &source, + "globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };\n", + ) + .expect("write source file"); + let tsconfig = tmp.path().join("tsconfig.json"); + std::fs::write( + &tsconfig, + "{ \"compilerOptions\": { \"target\": \"ES2020\" } }", + ) + .expect("write tsconfig"); + + let runner = 
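
The mock runners in these tests all exercise the same process contract: the materialized bundler script path arrives first, followed by the source and output paths as positional arguments, while optional settings travel through environment variables. A sketch of the bundler-side half of that contract (the parsing here is an assumption; only the variable name and argument order come from the tests):

    use std::env;
    use std::path::PathBuf;

    // Positional args after the program name: source path, then output path.
    // External packages arrive comma-joined in the environment, mirroring
    // how push.rs serializes them before spawning the bundler.
    fn read_bundler_contract() -> (PathBuf, PathBuf, Vec<String>) {
        let mut argv = env::args().skip(1);
        let source = PathBuf::from(argv.next().expect("source path"));
        let output = PathBuf::from(argv.next().expect("output path"));
        let externals = env::var("BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES")
            .map(|raw| raw.split(',').map(str::to_owned).collect())
            .unwrap_or_default();
        (source, output, externals)
    }
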
tmp.path().join("mock-runner.sh"); + std::fs::write( + &runner, + r#"#!/bin/sh +set -eu +_runner_script="$1" +shift +_runner_name="$(basename "$_runner_script")" + +if [ "$_runner_name" = "functions-runner.ts" ]; then + node - "$@" <<'NODE' +const path = require("node:path"); +const files = process.argv.slice(2); +const manifest = { + runtime_context: { runtime: "node", version: process.versions.node || "unknown" }, + files: files.map((file, index) => ({ + source_file: path.resolve(file), + entries: [ + { + kind: "code", + project_id: "proj_mock", + name: index === 0 ? "mock-tool" : `mock-tool-${index}`, + slug: index === 0 ? "mock-tool" : `mock-tool-${index}`, + function_type: "tool", + preview: "function handler() { return 1; }", + location: { type: "function", index: 0 } + } + ] + })) +}; +process.stdout.write(JSON.stringify(manifest)); +NODE + exit 0 +fi + +if [ "$_runner_name" = "functions-bundler.ts" ]; then + if [ "${TS_NODE_PROJECT:-}" != "${EXPECTED_TSCONFIG:-}" ]; then + echo "unexpected TS_NODE_PROJECT=${TS_NODE_PROJECT:-}" >&2 + exit 31 + fi + if [ "${TSX_TSCONFIG_PATH:-}" != "${EXPECTED_TSCONFIG:-}" ]; then + echo "unexpected TSX_TSCONFIG_PATH=${TSX_TSCONFIG_PATH:-}" >&2 + exit 32 + fi + _output_file="$2" + printf '%s\n' "// bundled with tsconfig" >"$_output_file" + printf '%s\n' "const marker = \"tsconfig-forwarded:${TS_NODE_PROJECT}\";" >>"$_output_file" + exit 0 +fi + +echo "unexpected runner script: $_runner_name" >&2 +exit 24 +"#, + ) + .expect("write mock runner"); + use std::os::unix::fs::PermissionsExt; + let mut perms = std::fs::metadata(&runner) + .expect("runner metadata") + .permissions(); + perms.set_mode(0o755); + std::fs::set_permissions(&runner, perms).expect("runner permissions"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "push", + "--file", + source + .to_str() + .expect("source path should be valid UTF-8 for test"), + "--language", + "javascript", + "--runner", + runner + .to_str() + .expect("runner path should be valid UTF-8 for test"), + "--tsconfig", + tsconfig + .to_str() + .expect("tsconfig path should be valid UTF-8 for test"), + ]) + .env("EXPECTED_TSCONFIG", &tsconfig) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions push"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock push failed:\n{stderr}"); + } + + let uploaded = state + .uploaded_bundles + .lock() + .expect("uploaded bundles lock") + .clone(); + assert_eq!(uploaded.len(), 1, "expected one uploaded bundle"); + let decompressed = decode_uploaded_bundle(&uploaded[0]); + assert!( + decompressed.contains(&format!( + "tsconfig-forwarded:{}", + tsconfig + .to_str() + .expect("tsconfig path should be valid UTF-8 for test") + )), + "uploaded bundle should include tsconfig marker emitted by bundler path" + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] From b751e135177cb630b43037083a0681aac555ca79 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Fri, 6 Mar 2026 16:16:13 -0800 Subject: [PATCH 11/28] fix(pull): use correct variable names for project resolution --- Cargo.lock | 1 + Cargo.toml | 3 ++ src/eval.rs | 1 + src/functions/pull.rs | 105 +++++++++++++---------------------------- 
src/functions/push.rs | 38 +++++++++++++-- src/http.rs | 4 +- src/utils/fs_atomic.rs | 104 +++++++++++++++++++++++++++++++++++++++- tests/functions.rs | 86 +++++++++++++++++++++++++++++++++ 8 files changed, 265 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 597cd9d..cb90717 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,7 @@ dependencies = [ "tokio", "unicode-width 0.1.14", "urlencoding", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 7e300fb..00efe45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,3 +74,6 @@ install-success-msg = "" [dev-dependencies] tempfile = "3" + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.59", features = ["Win32_Storage_FileSystem"] } diff --git a/src/eval.rs b/src/eval.rs index aba1c34..d108c61 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -642,6 +642,7 @@ struct EvalSpawned { runner_kind: RunnerKind, } +#[allow(clippy::too_many_arguments)] async fn spawn_eval_runner( base: &BaseArgs, language: EvalLanguage, diff --git a/src/functions/pull.rs b/src/functions/pull.rs index 4881fe6..fdec265 100644 --- a/src/functions/pull.rs +++ b/src/functions/pull.rs @@ -20,6 +20,11 @@ use super::{resolve_auth_context, FunctionsLanguage, PullArgs}; const PAGINATION_PAGE_LIMIT: usize = 10_000; const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock"; +// Pretty version IDs are a reversible encoding of internal transaction IDs +// (_xact_id). The encoding multiplies the xact ID (with a fixed top-nibble tag) +// by COPRIME mod 2^64, producing a 16-hex-char string that looks random but +// decodes back via the modular inverse. This lets `--version` accept either the +// raw numeric xact ID or the pretty hex form transparently. const TOP_BITS: u64 = 0x0DE1u64 << 48; const MODULUS: u128 = 1u128 << 64; const COPRIME: u64 = 205_891_132_094_649; @@ -279,7 +284,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { }; let repo = GitRepo::discover_from(&canonical_output_dir); - let project_names = if materializable.is_empty() { + let project_names = if project_ids_with_matches.is_empty() { BTreeMap::new() } else { let projects = match get_projects_cached(&auth_ctx.client, &mut projects_cache).await { @@ -293,7 +298,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { ); } }; - match resolve_project_names(&materializable, projects) { + match resolve_project_names(&winners, projects) { Ok(names) => names, Err(err) => { return fail_pull( @@ -631,27 +636,7 @@ fn resolve_project_names( Ok(names_by_id) } -#[allow(dead_code)] -fn group_rows_by_project( - rows: Vec, - project_names: &BTreeMap, -) -> Result>> { - let mut grouped = BTreeMap::new(); - for row in rows { - let Some(project_name) = project_names.get(&row.project_id).cloned() else { - bail!( - "missing resolved project name for project id '{}'", - row.project_id - ); - }; - grouped - .entry((row.project_id.clone(), project_name)) - .or_insert_with(Vec::new) - .push(row); - } - Ok(grouped) -} - +#[allow(clippy::too_many_arguments)] fn write_pull_file( summary: &mut PullSummary, canonical_output_dir: &Path, @@ -766,35 +751,6 @@ fn build_project_file_names( names } -#[allow(dead_code)] -fn build_output_file_names( - grouped_by_project: &BTreeMap<(String, String), Vec>, - slug_selector: Option<&str>, - ext: &str, -) -> Result> { - if let Some(slug) = slug_selector { - if grouped_by_project.len() != 1 { - bail!("slug selector matched multiple projects; pass --project-name or --project-id"); - } - let mut names = 
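
The comment block above is the whole trick: COPRIME is odd (it is 3^30), so multiplying by it is a bijection on u64, and the reverse direction is just another multiplication by its modular inverse; note the 0x0DE1 tag occupies the top 16 bits, not a single nibble. A worked sketch of the round trip (function names hypothetical; the decode side is assumed from the comment's description):

    const TOP_BITS: u64 = 0x0DE1u64 << 48;
    const COPRIME: u64 = 205_891_132_094_649; // 3^30, odd, so invertible mod 2^64

    // Hensel/Newton lifting: for odd a, x <- x * (2 - a*x) doubles the number
    // of correct low bits per step, and x = a is already correct mod 8, so
    // five iterations reach all 64 bits.
    fn inverse_mod_2_64(a: u64) -> u64 {
        let mut x = a;
        for _ in 0..5 {
            x = x.wrapping_mul(2u64.wrapping_sub(a.wrapping_mul(x)));
        }
        x
    }

    // Assumes xact_id fits in the low 48 bits, which the tag leaves free.
    fn encode_pretty(xact_id: u64) -> String {
        format!("{:016x}", (TOP_BITS | xact_id).wrapping_mul(COPRIME))
    }

    fn decode_pretty(pretty: &str) -> Option<u64> {
        let value = u64::from_str_radix(pretty, 16).ok()?;
        let raw = value.wrapping_mul(inverse_mod_2_64(COPRIME));
        // A failed tag check means the input was not a pretty ID; callers can
        // then fall back to treating it as a raw numeric xact ID.
        ((raw & (0xFFFFu64 << 48)) == TOP_BITS).then_some(raw & !(0xFFFFu64 << 48))
    }
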
BTreeMap::new(); - let key = grouped_by_project - .keys() - .next() - .ok_or_else(|| anyhow!("missing grouped project for slug selector"))? - .clone(); - let base = sanitize_filename(slug); - let file_stem = if base.is_empty() { - "function".to_string() - } else { - base - }; - names.insert(key, format!("{file_stem}.{ext}")); - return Ok(names); - } - - Ok(build_project_file_names(grouped_by_project, ext)) -} - fn sanitize_filename(value: &str) -> String { let mut out = String::with_capacity(value.len()); let mut previous_dash = false; @@ -1714,7 +1670,7 @@ mod tests { } #[test] - fn group_rows_uses_resolved_project_name() { + fn resolve_project_names_uses_project_lookup() { let row = PullFunctionRow { id: "f1".to_string(), name: "Prompt".to_string(), @@ -1728,16 +1684,19 @@ mod tests { _xact_id: None, }; - let mut names = BTreeMap::new(); - names.insert("p1".to_string(), "Woohoo".to_string()); + let projects = vec![crate::projects::api::Project { + id: "p1".to_string(), + name: "Woohoo".to_string(), + org_id: "o1".to_string(), + description: None, + }]; - let grouped = group_rows_by_project(vec![row], &names).expect("grouped"); - assert_eq!(grouped.len(), 1); - assert!(grouped.contains_key(&("p1".to_string(), "Woohoo".to_string()))); + let names = resolve_project_names(&[row], &projects).expect("resolved names"); + assert_eq!(names.get("p1").map(String::as_str), Some("Woohoo")); } #[test] - fn group_rows_fails_when_project_name_missing() { + fn resolve_project_names_fails_when_missing() { let row = PullFunctionRow { id: "f1".to_string(), name: "Prompt".to_string(), @@ -1751,43 +1710,47 @@ mod tests { _xact_id: None, }; - let err = group_rows_by_project(vec![row], &BTreeMap::new()).expect_err("should fail"); - assert!(err.to_string().contains("project id")); + let err = resolve_project_names(&[row], &[]).expect_err("should fail"); + assert!(err.to_string().contains("failed to resolve project name")); } #[test] - fn slug_selector_names_output_file_from_slug() { + fn project_file_names_use_sanitized_project_name() { let mut grouped = BTreeMap::new(); grouped.insert( - ("p1".to_string(), "Project".to_string()), + ("p1".to_string(), "Doc Search".to_string()), Vec::::new(), ); - let names = - build_output_file_names(&grouped, Some("doc-search"), "ts").expect("file names"); + let names = build_project_file_names(&grouped, "ts"); assert_eq!( names - .get(&("p1".to_string(), "Project".to_string())) + .get(&("p1".to_string(), "Doc Search".to_string())) .map(String::as_str), Some("doc-search.ts") ); } #[test] - fn slug_selector_rejects_multiple_projects() { + fn project_file_names_include_project_id_suffix_on_collision() { let mut grouped = BTreeMap::new(); grouped.insert( ("p1".to_string(), "Project One".to_string()), Vec::::new(), ); grouped.insert( - ("p2".to_string(), "Project Two".to_string()), + ("p2".to_string(), "project-one".to_string()), Vec::::new(), ); - let err = - build_output_file_names(&grouped, Some("doc-search"), "ts").expect_err("should fail"); - assert!(err.to_string().contains("multiple projects")); + let names = build_project_file_names(&grouped, "ts"); + let first = names + .get(&("p1".to_string(), "Project One".to_string())) + .expect("first"); + let second = names + .get(&("p2".to_string(), "project-one".to_string())) + .expect("second"); + assert_ne!(first.to_ascii_lowercase(), second.to_ascii_lowercase()); } #[test] diff --git a/src/functions/push.rs b/src/functions/push.rs index 953d199..c7b9d4f 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -713,6 
+713,7 @@ fn build_code_function_data( }) } +#[allow(clippy::too_many_arguments)] async fn push_file( auth_ctx: &super::AuthContext, default_project_id: Option<&str>, @@ -1196,18 +1197,36 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { }) } +const MAX_DIR_DEPTH: usize = 64; + fn collect_from_dir( dir: &Path, js_like: &mut BTreeSet, python: &mut BTreeSet, ) -> Result<()> { + collect_from_dir_inner(dir, js_like, python, 0) +} + +fn collect_from_dir_inner( + dir: &Path, + js_like: &mut BTreeSet, + python: &mut BTreeSet, + depth: usize, +) -> Result<()> { + if depth > MAX_DIR_DEPTH { + bail!( + "directory traversal exceeded maximum depth ({}); possible symlink loop at {}", + MAX_DIR_DEPTH, + dir.display() + ); + } for entry in std::fs::read_dir(dir) .with_context(|| format!("failed to read directory {}", dir.display()))? { let entry = entry.with_context(|| format!("failed to read entry in {}", dir.display()))?; let path = entry.path(); if path.is_dir() { - collect_from_dir(&path, js_like, python)?; + collect_from_dir_inner(&path, js_like, python, depth + 1)?; } else if path.is_file() { let canonical = path .canonicalize() @@ -1786,19 +1805,30 @@ fn run_uv_command(uv: &Path, args: &[OsString], stage: &str) -> Result<()> { fn collect_regular_files_recursive(root: &Path) -> Result> { let mut files = Vec::new(); - collect_regular_files_recursive_impl(root, &mut files)?; + collect_regular_files_recursive_impl(root, &mut files, 0)?; files.sort(); Ok(files) } -fn collect_regular_files_recursive_impl(root: &Path, out: &mut Vec) -> Result<()> { +fn collect_regular_files_recursive_impl( + root: &Path, + out: &mut Vec, + depth: usize, +) -> Result<()> { + if depth > MAX_DIR_DEPTH { + bail!( + "directory traversal exceeded maximum depth ({}); possible symlink loop at {}", + MAX_DIR_DEPTH, + root.display() + ); + } for entry in std::fs::read_dir(root).with_context(|| format!("failed to read {}", root.display()))? 
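
A bounded depth converts a symlink cycle (or a pathologically deep tree) from an infinite recursion into a diagnosable error; the cap starts at 64 here and a later patch in this series raises it to 256 and tightens the checks to DirEntry::file_type(), which does not follow symlinks. The traversal pattern in isolation (a sketch; the real functions also classify extensions and use anyhow errors):

    use std::path::{Path, PathBuf};

    const MAX_DIR_DEPTH: usize = 64;

    fn walk(dir: &Path, depth: usize, out: &mut Vec<PathBuf>) -> std::io::Result<()> {
        if depth > MAX_DIR_DEPTH {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("possible symlink loop at {}", dir.display()),
            ));
        }
        for entry in std::fs::read_dir(dir)? {
            let entry = entry?;
            // DirEntry::file_type() reports on the entry itself and never
            // follows symlinks, so a dir-symlink cycle is skipped rather
            // than recursed into.
            let file_type = entry.file_type()?;
            if file_type.is_dir() && !file_type.is_symlink() {
                walk(&entry.path(), depth + 1, out)?;
            } else if file_type.is_file() {
                out.push(entry.path());
            }
        }
        Ok(())
    }
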
{ let entry = entry.with_context(|| format!("failed to read entry in {}", root.display()))?; let path = entry.path(); if path.is_dir() { - collect_regular_files_recursive_impl(&path, out)?; + collect_regular_files_recursive_impl(&path, out, depth + 1)?; } else if path.is_file() { out.push(path); } diff --git a/src/http.rs b/src/http.rs index 40685af..e077add 100644 --- a/src/http.rs +++ b/src/http.rs @@ -204,13 +204,15 @@ impl ApiClient { } } +const UPLOAD_HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); + pub async fn put_signed_url( url: &str, body: Vec, content_encoding: Option<&str>, ) -> Result<()> { let client = Client::builder() - .timeout(DEFAULT_HTTP_TIMEOUT) + .timeout(UPLOAD_HTTP_TIMEOUT) .build() .context("failed to build signed-url HTTP client")?; diff --git a/src/utils/fs_atomic.rs b/src/utils/fs_atomic.rs index 4906a74..b54ecc0 100644 --- a/src/utils/fs_atomic.rs +++ b/src/utils/fs_atomic.rs @@ -29,13 +29,115 @@ pub fn write_text_atomic(path: &Path, contents: &str) -> Result<()> { std::fs::write(&tmp, contents) .with_context(|| format!("failed to write temporary file {}", tmp.display()))?; - std::fs::rename(&tmp, path).with_context(|| { + replace_file_atomic(&tmp, path)?; + + Ok(()) +} + +#[cfg(not(windows))] +fn replace_file_atomic(tmp: &Path, path: &Path) -> Result<()> { + std::fs::rename(tmp, path).with_context(|| { + format!( + "failed to replace {} with temporary file {}", + path.display(), + tmp.display() + ) + })?; + Ok(()) +} + +#[cfg(windows)] +fn replace_file_atomic(tmp: &Path, path: &Path) -> Result<()> { + if path.exists() { + replace_existing_file_windows(tmp, path)?; + return Ok(()); + } + + let rename_attempt = std::fs::rename(tmp, path); + if rename_attempt.is_ok() { + return Ok(()); + } + + if path.exists() { + replace_existing_file_windows(tmp, path)?; + return Ok(()); + } + + rename_attempt.with_context(|| { format!( "failed to replace {} with temporary file {}", path.display(), tmp.display() ) })?; + Ok(()) +} + +#[cfg(windows)] +fn replace_existing_file_windows(tmp: &Path, path: &Path) -> Result<()> { + use std::iter; + use std::os::windows::ffi::OsStrExt; + use windows_sys::Win32::Storage::FileSystem::ReplaceFileW; + + let target = path + .as_os_str() + .encode_wide() + .chain(iter::once(0)) + .collect::>(); + let replacement = tmp + .as_os_str() + .encode_wide() + .chain(iter::once(0)) + .collect::>(); + + // SAFETY: Both paths are null-terminated UTF-16 strings with stable backing + // storage for the duration of the call, and optional pointers are null. 
+ let replaced = unsafe { + ReplaceFileW( + target.as_ptr(), + replacement.as_ptr(), + std::ptr::null(), + 0, + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + }; + if replaced == 0 { + return Err(std::io::Error::last_os_error()).with_context(|| { + format!( + "failed to replace {} with temporary file {}", + path.display(), + tmp.display() + ) + }); + } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn write_text_atomic_creates_file() { + let tmp = tempdir().expect("tempdir"); + let path = tmp.path().join("file.txt"); + + write_text_atomic(&path, "hello").expect("write"); + let contents = std::fs::read_to_string(&path).expect("read"); + assert_eq!(contents, "hello"); + } + + #[test] + fn write_text_atomic_overwrites_existing_file() { + let tmp = tempdir().expect("tempdir"); + let path = tmp.path().join("file.txt"); + std::fs::write(&path, "old").expect("seed file"); + + write_text_atomic(&path, "new").expect("overwrite"); + let contents = std::fs::read_to_string(&path).expect("read"); + assert_eq!(contents, "new"); + } +} diff --git a/tests/functions.rs b/tests/functions.rs index fc8762f..bf3ca3a 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -1881,3 +1881,89 @@ async fn functions_pull_works_against_mock_api() { "pull request should include selector query params" ); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_pull_selector_with_unsupported_only_rows_still_succeeds() { + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + state + .pull_rows + .lock() + .expect("pull rows lock") + .push(serde_json::json!({ + "id": "fn_code_1", + "name": "Legacy Code Function", + "slug": "legacy-code", + "project_id": "proj_mock", + "description": "", + "function_data": { "type": "code" }, + "_xact_id": "0000000000000001" + })); + + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let out_dir = tmp.path().join("pulled"); + std::fs::create_dir_all(&out_dir).expect("create output dir"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "pull", + "--project-id", + "proj_mock", + "--slug", + "legacy-code", + "--force", + "--output-dir", + out_dir + .to_str() + .expect("output dir should be valid UTF-8 for test"), + "--language", + "typescript", + ]) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions pull"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock pull failed:\n{stderr}"); + } + + let summary: Value = serde_json::from_slice(&output.stdout).expect("parse pull summary"); + assert_eq!(summary["status"].as_str(), Some("partial")); + assert_eq!(summary["files_written"].as_u64(), Some(1)); + assert_eq!(summary["files_failed"].as_u64(), Some(0)); + assert_eq!(summary["unsupported_records_skipped"].as_u64(), Some(1)); + assert_eq!(summary["functions_materialized"].as_u64(), Some(0)); + + let rendered_file = out_dir.join("mock-project.ts"); + assert!(rendered_file.is_file(), "expected rendered file to 
exist"); + let rendered = std::fs::read_to_string(&rendered_file).expect("read rendered file"); + assert!( + rendered.contains("const project = braintrust.projects.create"), + "rendered file should still include project scaffold" + ); + assert!( + !rendered.contains("project.prompts.create"), + "rendered file should not include prompt materializations" + ); +} From c9622eb2cd08517adf27fda3827177f92bb91bb8 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 10:56:41 -0700 Subject: [PATCH 12/28] refactor(functions): remove create-missing-projects flag and confirmation prompt --- src/functions/mod.rs | 107 +---- src/functions/push.rs | 379 +++++------------- .../push-help-env-vars/fixture.json | 3 +- .../push-help-flags/fixture.json | 1 - .../push-invalid-bool-env/fixture.json | 8 - .../push-valid-bool-env/fixture.json | 9 - tests/functions.rs | 3 - 7 files changed, 106 insertions(+), 404 deletions(-) delete mode 100644 tests/functions-fixtures/push-invalid-bool-env/fixture.json delete mode 100644 tests/functions-fixtures/push-valid-bool-env/fixture.json diff --git a/src/functions/mod.rs b/src/functions/mod.rs index 0f12a68..70155b7 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -9,7 +9,7 @@ use crate::{ config, http::ApiClient, projects::api::{get_project_by_name, Project}, - ui::{self, fuzzy_select, is_interactive, select_project_interactive, with_spinner}, + ui::{self, is_interactive, select_project_interactive, with_spinner}, }; pub(crate) mod api; @@ -319,14 +319,9 @@ pub(crate) struct PushArgs { )] pub external_packages: Vec, - /// Create missing projects referenced by function definitions. - #[arg( - long = "create-missing-projects", - env = "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", - default_value_t = true, - value_parser = BoolishValueParser::new() - )] - pub create_missing_projects: bool, + /// Skip confirmation prompt. + #[arg(long, short = 'y')] + pub yes: bool, } impl PushArgs { @@ -480,13 +475,6 @@ pub(crate) async fn resolve_auth_context(base: &BaseArgs) -> Result }) } -#[derive(Debug)] -pub(crate) enum OrgDecision { - Continue, - Switch(String), - Cancel, -} - pub(crate) fn current_org_label(auth_ctx: &AuthContext) -> String { if auth_ctx.client.org_name().trim().is_empty() { auth_ctx.org_id.clone() @@ -523,76 +511,6 @@ pub(crate) fn validate_explicit_org_selection( bail!("org '{explicit_org}' is not available for this credential. Available: {available}"); } -/// Prompt the user to confirm/switch org when multiple orgs are available. -/// `prompt` is the question text, `action_label` is used for the confirm option (e.g. "Push to", "Pull from"). 
-pub(crate) fn resolve_org_decision( - base: &BaseArgs, - auth_ctx: &AuthContext, - available_orgs: &[AvailableOrg], - prompt: &str, - action_label: &str, -) -> Result<(OrgDecision, bool)> { - if base - .org_name - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - .is_some() - { - return Ok((OrgDecision::Continue, false)); - } - - if available_orgs.len() <= 1 { - return Ok((OrgDecision::Continue, false)); - } - - if !is_interactive() { - bail!( - "multiple organizations are available for this credential; pass --org in non-interactive mode" - ); - } - - let org_label = current_org_label(auth_ctx); - - let options = [ - format!("{action_label} {org_label}"), - "Switch org".to_string(), - "Cancel".to_string(), - ]; - let option_refs = options.iter().map(String::as_str).collect::>(); - let choice = fuzzy_select(prompt, &option_refs, 0)?; - - match choice { - 0 => Ok((OrgDecision::Continue, true)), - 1 => { - let mut labels = Vec::with_capacity(available_orgs.len()); - let mut default_index = 0usize; - for (index, org) in available_orgs.iter().enumerate() { - let label = if org.api_url.is_some() { - format!("{} [{}]", org.name, org.id) - } else { - org.name.clone() - }; - if org.name == org_label || org.name.eq_ignore_ascii_case(&org_label) { - default_index = index; - } - labels.push(label); - } - let label_refs = labels.iter().map(String::as_str).collect::>(); - let selected_index = fuzzy_select("Select organization", &label_refs, default_index)?; - let selected = available_orgs - .get(selected_index) - .ok_or_else(|| anyhow!("invalid org selection"))?; - if selected.name == org_label || selected.name.eq_ignore_ascii_case(&org_label) { - Ok((OrgDecision::Continue, true)) - } else { - Ok((OrgDecision::Switch(selected.name.clone()), true)) - } - } - _ => Ok((OrgDecision::Cancel, true)), - } -} - pub(crate) async fn resolve_project_context( base: &BaseArgs, auth_ctx: &AuthContext, @@ -828,23 +746,6 @@ mod tests { assert!(push.terminate_on_failure); } - #[test] - fn push_create_missing_projects_flag_from_env() { - let _guard = test_lock(); - unsafe { - std::env::set_var("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", "true"); - } - let parsed = parse(&["functions", "push"]).expect("parse push"); - unsafe { - std::env::remove_var("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS"); - } - - let FunctionsCommands::Push(push) = parsed.command.expect("subcommand") else { - panic!("expected push command"); - }; - assert!(push.create_missing_projects); - } - #[test] fn push_repeated_file_flags_append_in_order() { let _guard = test_lock(); diff --git a/src/functions/push.rs b/src/functions/push.rs index c7b9d4f..d6d0f18 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -4,31 +4,34 @@ use std::path::{Path, PathBuf}; use std::process::{Command, Output}; use std::time::{SystemTime, UNIX_EPOCH}; +use std::io::IsTerminal; +use std::time::Duration; + use anyhow::{anyhow, bail, Context, Result}; use dialoguer::console::style; -use dialoguer::theme::ColorfulTheme; use dialoguer::Confirm; -use reqwest::StatusCode; +use indicatif::{ProgressBar, ProgressStyle}; use serde::Deserialize; use serde_json::{json, Map, Value}; use crate::args::BaseArgs; -use crate::auth::{list_available_orgs, list_profiles}; + +use crate::auth::list_available_orgs; use crate::config; use crate::functions::report::{ CommandStatus, FileStatus, HardFailureReason, PushFileReport, PushSummary, ReportError, - ReportWarning, SoftSkipReason, + SoftSkipReason, }; use crate::js_runner; -use 
crate::projects::api::{create_project, get_project_by_name, list_projects}; +use crate::projects::api::{get_project_by_name, list_projects}; use crate::python_runner; use crate::source_language::{classify_runtime_extension, JsExtensionProfile, SourceLanguage}; -use crate::ui::is_interactive; +use crate::ui::{animations_enabled, is_interactive, is_quiet}; use super::api; use super::{ - current_org_label, resolve_auth_context, resolve_org_decision, validate_explicit_org_selection, - OrgDecision, PushArgs, PushLanguage, + current_org_label, resolve_auth_context, validate_explicit_org_selection, PushArgs, + PushLanguage, }; const FUNCTIONS_JS_RUNNER_FILE: &str = "functions-runner.ts"; @@ -155,7 +158,6 @@ struct ResolvedFileTargets { struct ResolvedManifestTargets { entries: Vec, per_file: Vec, - unique_project_ids: Vec, } #[derive(Debug, Default)] @@ -206,7 +208,7 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { ); } - let mut auth_ctx = match resolve_auth_context(&base) + let auth_ctx = match resolve_auth_context(&base) .await .context("failed to resolve auth context") { @@ -400,67 +402,32 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { .collect(); let preflight_project_names: Vec = preflight.named_projects.iter().cloned().collect(); - let org_prompt = - build_push_org_prompt(&auth_ctx, &preflight_source_files, &preflight_project_names); - let (org_decision, org_prompt_confirmed) = - match resolve_org_decision(&base, &auth_ctx, &available_orgs, &org_prompt, "Push to") { - Ok(outcome) => outcome, + if !args.yes && is_interactive() { + let prompt = + build_push_confirm_prompt(&auth_ctx, &preflight_source_files, &preflight_project_names); + let confirmed = Confirm::new() + .with_prompt(prompt) + .default(false) + .interact()?; + if !confirmed { + return cancel_push(&base, &files); + } + } + + let mut project_name_cache = + match resolve_named_projects(&auth_ctx, &preflight.named_projects).await { + Ok(cache) => cache, Err(err) => { - return fail_push( + let message = format!("failed to resolve target projects for push: {err}"); + return fail_push_manifest_preflight( &base, - files.len(), - HardFailureReason::ResponseInvalid, - error_chain(&err), - "failed to resolve org context", + &files, + &message, + "skipped because project target resolution failed", ); } }; - match org_decision { - OrgDecision::Continue => {} - OrgDecision::Switch(org_name) => { - let mut switched_base = base.clone(); - switched_base.org_name = Some(org_name); - auth_ctx = match resolve_auth_context(&switched_base) - .await - .context("failed to resolve switched org context") - { - Ok(ctx) => ctx, - Err(err) => { - return fail_push( - &base, - files.len(), - HardFailureReason::AuthFailed, - error_chain(&err), - "failed to resolve switched org context", - ); - } - }; - } - OrgDecision::Cancel => { - return cancel_push(&base, &files); - } - } - - let mut project_name_cache = match resolve_and_ensure_named_projects( - &auth_ctx, - &preflight.named_projects, - args.create_missing_projects, - ) - .await - { - Ok(cache) => cache, - Err(err) => { - let message = format!("failed to resolve target projects for push: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, - "skipped because project target resolution failed", - ); - } - }; - if let Err(err) = validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await { let message = format!("failed to validate project ids for push: {err}"); return fail_push_manifest_preflight( @@ -513,25 +480,6 @@ pub async 
fn run(base: BaseArgs, args: PushArgs) -> Result<()> { ); } - let target_project_ids = resolved_targets.unique_project_ids.clone(); - - let source_files: Vec<&str> = manifest - .files - .iter() - .map(|f| f.source_file.as_str()) - .collect(); - - if !org_prompt_confirmed - && !confirm_push_targets( - &auth_ctx, - &target_project_ids, - &source_files, - &project_name_cache, - )? - { - return cancel_push(&base, &files); - } - let mut summary = PushSummary { status: CommandStatus::Success, total_files: manifest.files.len(), @@ -553,6 +501,35 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { ); } + let use_progress = + !base.json && std::io::stderr().is_terminal() && animations_enabled() && !is_quiet(); + + let file_parts: Vec<&str> = manifest + .files + .iter() + .map(|f| { + Path::new(&f.source_file) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or(&f.source_file) + }) + .collect(); + let file_label = file_parts.join(", "); + + let spinner = if use_progress { + let spinner_style = ProgressStyle::default_spinner() + .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", " "]) + .template("{spinner:.cyan} {msg}") + .unwrap(); + let pb = ProgressBar::new_spinner(); + pb.set_style(spinner_style); + pb.set_message(format!("Pushing {file_label}...")); + pb.enable_steady_tick(Duration::from_millis(80)); + pb + } else { + ProgressBar::hidden() + }; + for (index, (file, resolved_file)) in manifest .files .iter() @@ -560,6 +537,7 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { .enumerate() { if resolved_file.source_file != file.source_file { + spinner.finish_and_clear(); return fail_push_manifest_preflight( &base, &files, @@ -567,6 +545,7 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { "skipped because internal target resolution failed", ); } + let source_path = PathBuf::from(&file.source_file); let file_result = push_file( &auth_ctx, @@ -586,36 +565,15 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { match file_result { Ok(file_success) => { summary.ignored_entries += file_success.ignored_entries; - let skipped_reason = if file_success.uploaded_entries == 0 { - if file_success.ignored_entries > 0 { - Some(SoftSkipReason::IfExistsIgnored) - } else { - Some(SoftSkipReason::NoDefinitionsFound) - } - } else { - None - }; - let status = if skipped_reason.is_some() { - summary.skipped_files += 1; - FileStatus::Skipped - } else { - summary.uploaded_files += 1; - FileStatus::Success - }; + summary.uploaded_files += 1; summary.files.push(PushFileReport { source_file: file.source_file.clone(), - status, + status: FileStatus::Success, uploaded_entries: file_success.uploaded_entries, - skipped_reason, + skipped_reason: None, error_reason: None, bundle_id: file_success.bundle_id, - message: if file_success.uploaded_entries == 0 - && file_success.ignored_entries == 0 - { - Some("no publishable definitions found in this file".to_string()) - } else { - None - }, + message: None, }); } Err(file_failure) => { @@ -656,14 +614,23 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { } } - if summary.status != CommandStatus::Failed && summary.skipped_files > 0 { - summary.status = CommandStatus::Partial; + if summary.status == CommandStatus::Failed { + spinner.finish_with_message(format!( + "{} Failed to push {}", + style("✗").red(), + file_label, + )); + } else { + spinner.finish_with_message(format!( + "{} Successfully pushed {}", + style("✓").green(), + file_label, + )); } - let failure = summary.status == 
CommandStatus::Failed; emit_summary(&base, &summary)?; - if failure { + if summary.status == CommandStatus::Failed { bail!("functions push failed; see summary for details"); } @@ -1197,7 +1164,7 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { }) } -const MAX_DIR_DEPTH: usize = 64; +const MAX_DIR_DEPTH: usize = 256; fn collect_from_dir( dir: &Path, @@ -1224,8 +1191,11 @@ fn collect_from_dir_inner( .with_context(|| format!("failed to read directory {}", dir.display()))? { let entry = entry.with_context(|| format!("failed to read entry in {}", dir.display()))?; + let file_type = entry + .file_type() + .with_context(|| format!("failed to read file type in {}", dir.display()))?; let path = entry.path(); - if path.is_dir() { + if file_type.is_dir() && !file_type.is_symlink() { collect_from_dir_inner(&path, js_like, python, depth + 1)?; } else if path.is_file() { let canonical = path @@ -1826,8 +1796,11 @@ fn collect_regular_files_recursive_impl( std::fs::read_dir(root).with_context(|| format!("failed to read {}", root.display()))? { let entry = entry.with_context(|| format!("failed to read entry in {}", root.display()))?; + let file_type = entry + .file_type() + .with_context(|| format!("failed to read file type in {}", root.display()))?; let path = entry.path(); - if path.is_dir() { + if file_type.is_dir() && !file_type.is_symlink() { collect_regular_files_recursive_impl(&path, out, depth + 1)?; } else if path.is_file() { out.push(path); @@ -2001,7 +1974,7 @@ fn ensure_path_within_allowed_roots( ); } -fn build_push_org_prompt( +fn build_push_confirm_prompt( auth_ctx: &super::AuthContext, source_files: &[&str], project_names: &[String], @@ -2020,18 +1993,18 @@ fn build_push_org_prompt( .map(|f| style(f).green().to_string()) .collect::>() .join(", "); - let projects_part = if project_names.is_empty() { - "(no project)".to_string() + let org_label = current_org_label(auth_ctx); + let targets_part = if project_names.is_empty() { + style(&org_label).green().to_string() } else { project_names .iter() - .map(|p| style(p).green().to_string()) + .map(|p| style(format!("{org_label}/{p}")).green().to_string()) .collect::>() .join(", ") }; - let org_styled = style(current_org_label(auth_ctx)).green(); - format!("Push {files_part} to {projects_part} in {org_styled}") + format!("Push {files_part} to {targets_part}") } fn cancel_push(base: &BaseArgs, files: &[PathBuf]) -> Result<()> { @@ -2259,10 +2232,9 @@ fn resolve_default_project_id( Ok(Some(project_id)) } -async fn resolve_and_ensure_named_projects( +async fn resolve_named_projects( auth_ctx: &super::AuthContext, named_projects: &BTreeSet, - auto_create: bool, ) -> Result> { let mut project_name_cache = BTreeMap::new(); let mut missing = Vec::new(); @@ -2276,65 +2248,15 @@ async fn resolve_and_ensure_named_projects( } } - if missing.is_empty() { - return Ok(project_name_cache); - } - - if !auto_create && !is_interactive() { + if !missing.is_empty() { let joined = missing.join(", "); let org = current_org_label(auth_ctx); - bail!( - "project(s) not found in org '{org}': {joined}. Re-run with --create-missing-projects or create them first" - ); - } - - for project_name in missing { - let should_create = if auto_create { - true - } else { - Confirm::new() - .with_prompt(format!( - "Project '{}' does not exist in org '{}'. Create it?", - project_name, - current_org_label(auth_ctx) - )) - .default(false) - .interact()? 
- }; - - if !should_create { - bail!("project '{project_name}' is missing; push cancelled"); - } - - match create_project(&auth_ctx.client, &project_name).await { - Ok(project) => { - project_name_cache.insert(project_name.clone(), project.id); - } - Err(err) if is_http_conflict(&err) => { - let project = get_project_by_name(&auth_ctx.client, &project_name) - .await? - .ok_or_else(|| { - anyhow!( - "project '{}' already exists but could not be resolved after create conflict", - project_name - ) - })?; - project_name_cache.insert(project_name.clone(), project.id); - } - Err(err) => { - return Err(err).context(format!("failed to create project '{project_name}'")); - } - } + bail!("project(s) not found in org '{org}': {joined}"); } Ok(project_name_cache) } -fn is_http_conflict(err: &anyhow::Error) -> bool { - err.downcast_ref::() - .is_some_and(|http| http.status == StatusCode::CONFLICT) -} - async fn validate_direct_project_ids( auth_ctx: &super::AuthContext, direct_project_ids: &BTreeSet, @@ -2369,7 +2291,6 @@ async fn resolve_manifest_targets( manifest: &RunnerManifest, project_name_cache: &mut BTreeMap, ) -> Result { - let mut seen_project_ids = BTreeSet::new(); let mut entries = Vec::new(); let mut per_file = Vec::with_capacity(manifest.files.len()); @@ -2388,7 +2309,6 @@ async fn resolve_manifest_targets( project_name_cache, ) .await?; - seen_project_ids.insert(project_id.clone()); entry_project_ids.push(project_id.clone()); entries.push(ResolvedEntryTarget { source_file: file.source_file.clone(), @@ -2403,11 +2323,7 @@ async fn resolve_manifest_targets( }); } - Ok(ResolvedManifestTargets { - entries, - per_file, - unique_project_ids: seen_project_ids.into_iter().collect(), - }) + Ok(ResolvedManifestTargets { entries, per_file }) } fn validate_duplicate_slugs(entries: &[ResolvedEntryTarget]) -> Result<()> { @@ -2512,66 +2428,6 @@ async fn resolve_project_name( Ok(project.id) } -fn confirm_push_targets( - auth_ctx: &super::AuthContext, - target_project_ids: &[String], - source_files: &[&str], - project_name_cache: &BTreeMap, -) -> Result { - if !is_interactive() || target_project_ids.is_empty() { - return Ok(true); - } - - let id_to_name: BTreeMap<&str, &str> = project_name_cache - .iter() - .map(|(name, id)| (id.as_str(), name.as_str())) - .collect(); - - let project_labels: Vec<&str> = target_project_ids - .iter() - .map(|id| id_to_name.get(id.as_str()).copied().unwrap_or(id.as_str())) - .collect(); - - let file_names: Vec<&str> = source_files - .iter() - .map(|f| { - Path::new(f) - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or(f) - }) - .collect(); - - let files_part = file_names - .iter() - .map(|f| style(f).green().to_string()) - .collect::>() - .join(", "); - let projects_part = project_labels.join(", "); - - let multi_org = list_profiles().is_ok_and(|p| p.len() > 1); - let prompt = if multi_org { - let org_label = if auth_ctx.client.org_name().is_empty() { - &auth_ctx.org_id - } else { - auth_ctx.client.org_name() - }; - format!( - "Push {files_part} to {projects_part} {}?", - style(format!("({org_label})")).dim() - ) - } else { - format!("Push {files_part} to {projects_part}?") - }; - - let confirmed = Confirm::with_theme(&ColorfulTheme::default()) - .with_prompt(prompt) - .default(false) - .interact()?; - - Ok(confirmed) -} - fn collect_project_name_placeholders_checked( value: &Value, out: &mut BTreeSet, @@ -2648,30 +2504,6 @@ fn emit_summary(base: &BaseArgs, summary: &PushSummary) -> Result<()> { if base.json { println!("{}", serde_json::to_string(summary)?); } else { 
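
resolve_named_projects now walks every name before failing, so a single run surfaces the complete set of missing projects instead of stopping at the first miss. The pattern in isolation (a sketch with a synchronous lookup closure standing in for the async API client):

    use std::collections::{BTreeMap, BTreeSet};

    use anyhow::{bail, Result};

    fn resolve_all(
        names: &BTreeSet<String>,
        mut lookup: impl FnMut(&str) -> Option<String>,
        org: &str,
    ) -> Result<BTreeMap<String, String>> {
        let mut cache = BTreeMap::new();
        let mut missing = Vec::new();
        for name in names {
            match lookup(name) {
                Some(id) => {
                    cache.insert(name.clone(), id);
                }
                None => missing.push(name.clone()),
            }
        }
        if !missing.is_empty() {
            // Aggregate the misses so the error names them all at once.
            bail!("project(s) not found in org '{org}': {}", missing.join(", "));
        }
        Ok(cache)
    }
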
- match summary.status { - CommandStatus::Success => { - eprintln!("Pushed {} file(s) successfully.", summary.uploaded_files); - } - CommandStatus::Partial => { - eprintln!( - "Pushed with partial success: uploaded={}, skipped={}, failed={}", - summary.uploaded_files, summary.skipped_files, summary.failed_files - ); - } - CommandStatus::Failed => { - eprintln!( - "Push failed: uploaded={}, skipped={}, failed={}", - summary.uploaded_files, summary.skipped_files, summary.failed_files - ); - } - } - for warning in &summary.warnings { - eprintln!( - "warning ({}): {}", - to_warning_code(warning), - warning.message - ); - } for error in &summary.errors { eprintln!("error ({}): {}", to_error_code(error), error.message); } @@ -2679,15 +2511,6 @@ fn emit_summary(base: &BaseArgs, summary: &PushSummary) -> Result<()> { Ok(()) } -fn to_warning_code(warning: &ReportWarning) -> &'static str { - match warning.reason { - super::report::WarningReason::PaginationNotSnapshotConsistent => { - "pagination_not_snapshot_consistent" - } - super::report::WarningReason::SelectorPartialMatch => "selector_partial_match", - } -} - fn to_error_code(error: &ReportError) -> &'static str { match error.reason { HardFailureReason::AuthFailed => "auth_failed", @@ -2924,7 +2747,7 @@ mod tests { requirements: None, tsconfig: None, external_packages: vec![], - create_missing_projects: false, + yes: false, }; let classified = ClassifiedFiles { js_like: vec![PathBuf::from("/tmp/a.ts")], @@ -2953,7 +2776,7 @@ mod tests { requirements: None, tsconfig: None, external_packages: vec![], - create_missing_projects: false, + yes: false, }; let classified = ClassifiedFiles { js_like: vec![PathBuf::from("/tmp/a.ts")], diff --git a/tests/functions-fixtures/push-help-env-vars/fixture.json b/tests/functions-fixtures/push-help-env-vars/fixture.json index 3bd4748..abe76c5 100644 --- a/tests/functions-fixtures/push-help-env-vars/fixture.json +++ b/tests/functions-fixtures/push-help-env-vars/fixture.json @@ -9,7 +9,6 @@ "BT_FUNCTIONS_PUSH_LANGUAGE", "BT_FUNCTIONS_PUSH_REQUIREMENTS", "BT_FUNCTIONS_PUSH_TSCONFIG", - "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", - "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS" + "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES" ] } diff --git a/tests/functions-fixtures/push-help-flags/fixture.json b/tests/functions-fixtures/push-help-flags/fixture.json index 4ffc039..16d6106 100644 --- a/tests/functions-fixtures/push-help-flags/fixture.json +++ b/tests/functions-fixtures/push-help-flags/fixture.json @@ -5,7 +5,6 @@ "--file", "--if-exists", "--terminate-on-failure", - "--create-missing-projects", "--language", "--requirements", "--tsconfig", diff --git a/tests/functions-fixtures/push-invalid-bool-env/fixture.json b/tests/functions-fixtures/push-invalid-bool-env/fixture.json deleted file mode 100644 index fb531cf..0000000 --- a/tests/functions-fixtures/push-invalid-bool-env/fixture.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "command": ["functions", "push"], - "env": { - "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS": "notabool" - }, - "expect_success": false, - "stderr_contains": ["--create-missing-projects", "value was not a boolean"] -} diff --git a/tests/functions-fixtures/push-valid-bool-env/fixture.json b/tests/functions-fixtures/push-valid-bool-env/fixture.json deleted file mode 100644 index 51c3ce0..0000000 --- a/tests/functions-fixtures/push-valid-bool-env/fixture.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "command": ["functions", "push", "--language", "typescript"], - "env": { - "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS": "true" - }, - 
"expect_success": false, - "stderr_contains": ["invalid value 'typescript'"], - "stderr_not_contains": ["value was not a boolean"] -} diff --git a/tests/functions.rs b/tests/functions.rs index bf3ca3a..acaa681 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -160,7 +160,6 @@ fn sanitized_env_keys() -> &'static [&'static str] { "BT_FUNCTIONS_PUSH_REQUIREMENTS", "BT_FUNCTIONS_PUSH_TSCONFIG", "BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES", - "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", "BT_FUNCTIONS_PULL_OUTPUT_DIR", "BT_FUNCTIONS_PULL_PROJECT_ID", "BT_FUNCTIONS_PULL_PROJECT_NAME", @@ -541,7 +540,6 @@ fn functions_push_help_includes_expected_flags() { assert!(stdout.contains("--file")); assert!(stdout.contains("--if-exists")); assert!(stdout.contains("--terminate-on-failure")); - assert!(stdout.contains("--create-missing-projects")); assert!(stdout.contains("--language")); assert!(stdout.contains("--requirements")); assert!(stdout.contains("--tsconfig")); @@ -693,7 +691,6 @@ fn push_and_pull_help_are_machine_readable() { let push_stdout = String::from_utf8_lossy(&push_help.stdout); let pull_stdout = String::from_utf8_lossy(&pull_help.stdout); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_FILES")); - assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS")); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_LANGUAGE")); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_REQUIREMENTS")); assert!(push_stdout.contains("BT_FUNCTIONS_PUSH_TSCONFIG")); From a6f1062b0f35caee2ad7ff2f80a64684c121c7cd Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 12:44:56 -0700 Subject: [PATCH 13/28] feat(functions): add progress indicator for pull command and cleanup output --- src/functions/pull.rs | 125 ++++++++++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 30 deletions(-) diff --git a/src/functions/pull.rs b/src/functions/pull.rs index fdec265..509a5ef 100644 --- a/src/functions/pull.rs +++ b/src/functions/pull.rs @@ -1,9 +1,13 @@ use std::cmp::Ordering; use std::collections::{BTreeMap, BTreeSet}; use std::fs::OpenOptions; +use std::io::IsTerminal; use std::path::{Path, PathBuf}; +use std::time::Duration; use anyhow::{anyhow, bail, Context, Result}; +use dialoguer::console::style; +use indicatif::{ProgressBar, ProgressStyle}; use serde::Deserialize; use serde_json::Value; @@ -16,7 +20,11 @@ use crate::projects::api::{list_projects, Project}; use crate::utils::{write_text_atomic, GitRepo}; use super::api::{self, FunctionListQuery}; -use super::{resolve_auth_context, FunctionsLanguage, PullArgs}; +use super::{ + current_org_label, resolve_auth_context, resolve_project_context_optional, FunctionsLanguage, + PullArgs, +}; +use crate::ui::{animations_enabled, is_quiet}; const PAGINATION_PAGE_LIMIT: usize = 10_000; const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock"; @@ -123,6 +131,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { }; let mut query = FunctionListQuery::default(); + let mut resolved_project_name: Option = None; if let Some(project_id) = &args.project_id { query.project_id = Some(project_id.clone()); @@ -146,7 +155,24 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { err.to_string(), ); } + resolved_project_name = Some(project_name.clone()); query.project_name = Some(project_name.clone()); + } else { + let project = match resolve_project_context_optional(&base, &auth_ctx, false).await { + Ok(project) => project, + Err(err) => { + return fail_pull( + &base, + &mut summary, + 
HardFailureReason::ResponseInvalid, + err.to_string(), + ); + } + }; + if let Some(project) = project { + resolved_project_name = Some(project.name.clone()); + query.project_id = Some(project.id); + } } if let Some(id) = &args.id { @@ -169,9 +195,37 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { if resolved_slugs.len() == 1 { query.slug = Some(resolved_slugs[0].clone()); } + + let org_label = current_org_label(&auth_ctx); + let subject = if !resolved_slugs.is_empty() { + resolved_slugs.join(", ") + } else { + "functions".to_string() + }; + let from_label = match &resolved_project_name { + Some(project) => format!("{org_label}/{project}"), + None => org_label.clone(), + }; + let use_progress = + !base.json && std::io::stderr().is_terminal() && animations_enabled() && !is_quiet(); + let spinner = if use_progress { + let spinner_style = ProgressStyle::default_spinner() + .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", " "]) + .template("{spinner:.cyan} {msg}") + .unwrap(); + let pb = ProgressBar::new_spinner(); + pb.set_style(spinner_style); + pb.set_message(format!("Pulling {subject} from {from_label}...")); + pb.enable_steady_tick(Duration::from_millis(80)); + pb + } else { + ProgressBar::hidden() + }; + let fetched = match fetch_all_function_rows(&auth_ctx.client, &query).await { Ok(fetched) => fetched, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -203,6 +257,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let narrowed_rows = match apply_selector_narrowing(parsed_rows, &args) { Ok(rows) => rows, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -215,6 +270,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let winners = select_winner_rows(narrowed_rows, &mut summary); if (args.id.is_some() || args.has_slug_selector()) && winners.is_empty() { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -248,6 +304,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { if let Err(err) = std::fs::create_dir_all(&output_dir) .with_context(|| format!("failed to create output directory {}", output_dir.display())) { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -262,6 +319,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { { Ok(path) => path, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -274,6 +332,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let _lock = match OutputLock::acquire(&canonical_output_dir) { Ok(lock) => lock, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -290,6 +349,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let projects = match get_projects_cached(&auth_ctx.client, &mut projects_cache).await { Ok(projects) => projects, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -301,6 +361,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { match resolve_project_names(&winners, projects) { Ok(names) => names, Err(err) => { + spinner.finish_and_clear(); return fail_pull( &base, &mut summary, @@ -396,6 +457,26 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { } let failure = summary.status == CommandStatus::Failed; + spinner.finish_and_clear(); + if use_progress { + if failure { + eprintln!("{} Failed to pull {subject}", style("✗").red()); + } else { + let cwd = 
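
Push and pull now gate and build their spinners identically: JSON mode, non-TTY stderr, disabled animations, or quiet mode all collapse to a hidden bar, so every later call site can stay unconditional. That shared shape, factored out as a sketch (assuming the same indicatif 0.17-style API used above):

    use std::time::Duration;

    use indicatif::{ProgressBar, ProgressStyle};

    // A hidden bar accepts the same set_message/finish calls as a visible
    // one, so callers never need to re-check the gating condition.
    fn gated_spinner(enabled: bool, message: String) -> ProgressBar {
        if !enabled {
            return ProgressBar::hidden();
        }
        let pb = ProgressBar::new_spinner();
        pb.set_style(
            ProgressStyle::default_spinner()
                .template("{spinner:.cyan} {msg}")
                .expect("static spinner template parses"),
        );
        pb.set_message(message);
        pb.enable_steady_tick(Duration::from_millis(80));
        pb
    }

Usage then reduces to let spinner = gated_spinner(use_progress, format!("Pulling {subject} from {from_label}...")); followed by the existing finish_and_clear calls.
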
std::env::current_dir().ok(); + let pulled_files: Vec<_> = summary + .files + .iter() + .filter(|f| f.status == FileStatus::Success) + .map(|f| short_display_path(&f.output_file, cwd.as_deref())) + .collect(); + let file_label = if pulled_files.is_empty() { + from_label.clone() + } else { + pulled_files.join(", ") + }; + eprintln!("{} Pulled {subject} to {file_label}", style("✓").green(),); + } + } emit_summary(&base, &summary, args.verbose)?; if failure { bail!("functions pull failed; see summary for details"); @@ -1460,38 +1541,22 @@ fn emit_summary(base: &BaseArgs, summary: &PullSummary, verbose: bool) -> Result return Ok(()); } - let has_visible_files = summary - .files - .iter() - .any(|f| f.status == FileStatus::Success || f.status == FileStatus::Failed || verbose); - let mut parts = vec![format!("Wrote {} file(s)", summary.files_written)]; - if has_visible_files { - let cwd = std::env::current_dir().ok(); - for f in &summary.files { - let name = short_display_path(&f.output_file, cwd.as_deref()); - match f.status { - FileStatus::Success => eprintln!("Pulled {name}"), - FileStatus::Failed => { - let msg = f.message.as_deref().unwrap_or("unknown error"); - eprintln!("Failed to pull {name} ({msg})"); - } - FileStatus::Skipped if verbose => { - let reason = skip_reason_label(f.skipped_reason); - eprintln!("Skipped {name} ({reason})"); - } - FileStatus::Skipped => {} + let cwd = std::env::current_dir().ok(); + for f in &summary.files { + let name = short_display_path(&f.output_file, cwd.as_deref()); + match f.status { + FileStatus::Success => {} + FileStatus::Failed => { + let msg = f.message.as_deref().unwrap_or("unknown error"); + eprintln!("Failed to pull {name} ({msg})"); + } + FileStatus::Skipped if verbose => { + let reason = skip_reason_label(f.skipped_reason); + eprintln!("Skipped {name} ({reason})"); } + FileStatus::Skipped => {} } - eprintln!(); - } - - if summary.files_skipped > 0 { - parts.push(format!("skipped {}", summary.files_skipped)); - } - if summary.files_failed > 0 { - parts.push(format!("failed {}", summary.files_failed)); } - eprintln!("{}.", parts.join(", ")); for warning in &summary.warnings { eprintln!("warning: {}", warning.message); From 4c551888f471b16bf16e1feb73a95f377d7adbd2 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 13:06:45 -0700 Subject: [PATCH 14/28] refactor(push): use file_type instead of path.is_file() for consistency --- src/functions/pull.rs | 34 ++----- src/functions/push.rs | 206 ++++++++++++++++++++++++++++++------------ tests/functions.rs | 14 +-- 3 files changed, 166 insertions(+), 88 deletions(-) diff --git a/src/functions/pull.rs b/src/functions/pull.rs index 509a5ef..1755c15 100644 --- a/src/functions/pull.rs +++ b/src/functions/pull.rs @@ -279,17 +279,19 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { ); } - let project_ids_with_matches = winners - .iter() - .map(|row| row.project_id.clone()) - .collect::>(); - let mut materializable = Vec::new(); for row in winners.iter().cloned() { if is_prompt_row(&row) { materializable.push(row); } else { summary.unsupported_records_skipped += 1; + if !is_quiet() { + eprintln!( + "{} skipping '{}' because it is not a prompt", + style("warning:").yellow(), + row.slug + ); + } } } @@ -343,7 +345,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { }; let repo = GitRepo::discover_from(&canonical_output_dir); - let project_names = if project_ids_with_matches.is_empty() { + let project_names = if materializable.is_empty() { BTreeMap::new() } 
else {
        let projects = match get_projects_cached(&auth_ctx.client, &mut projects_cache).await {
@@ -378,7 +380,7 @@
     };
 
     if !resolved_slugs.is_empty() {
-        let found_slugs: BTreeSet<&str> = materializable.iter().map(|r| r.slug.as_str()).collect();
+        let found_slugs: BTreeSet<&str> = winners.iter().map(|r| r.slug.as_str()).collect();
         for slug in &resolved_slugs {
             if !found_slugs.contains(slug.as_str()) {
                 summary.warnings.push(ReportWarning {
@@ -389,25 +391,7 @@
         }
     }
 
-    // Legacy-compatible project mode: one output file per project, even for
-    // selector pulls that only matched unsupported record types.
     let mut grouped_by_project = BTreeMap::<(String, String), Vec<PullFunctionRow>>::new();
-    for project_id in project_ids_with_matches {
-        let Some(project_name) = project_names.get(&project_id).cloned() else {
-            return fail_pull(
-                &base,
-                &mut summary,
-                HardFailureReason::ResponseInvalid,
-                format!(
-                    "missing resolved project name for project id '{}'",
-                    project_id
-                ),
-            );
-        };
-        grouped_by_project
-            .entry((project_id, project_name))
-            .or_default();
-    }
     for row in materializable {
         let Some(project_name) = project_names.get(&row.project_id).cloned() else {
             return fail_pull(
diff --git a/src/functions/push.rs b/src/functions/push.rs
index d6d0f18..5adec3f 100644
--- a/src/functions/push.rs
+++ b/src/functions/push.rs
@@ -614,18 +614,11 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> {
         }
     }
 
+    spinner.finish_and_clear();
     if summary.status == CommandStatus::Failed {
-        spinner.finish_with_message(format!(
-            "{} Failed to push {}",
-            style("✗").red(),
-            file_label,
-        ));
+        eprintln!("{} Failed to push {}", style("✗").red(), file_label);
     } else {
-        spinner.finish_with_message(format!(
-            "{} Successfully pushed {}",
-            style("✓").green(),
-            file_label,
-        ));
+        eprintln!("{} Successfully pushed {}", style("✓").green(), file_label);
     }
 
     emit_summary(&base, &summary)?;
@@ -1197,7 +1190,7 @@
             let path = entry.path();
             if file_type.is_dir() && !file_type.is_symlink() {
                 collect_from_dir_inner(&path, js_like, python, depth + 1)?;
-            } else if path.is_file() {
+            } else if file_type.is_file() {
                 let canonical = path
                     .canonicalize()
                     .with_context(|| format!("failed to canonicalize file {}", path.display()))?;
@@ -1569,12 +1562,16 @@ fn build_python_bundle_archive(
     requirements_path: Option<&Path>,
     runner: Option<&str>,
 ) -> Result<Vec<u8>> {
+    let Some(python) = python_runner::resolve_python_interpreter(runner, &[]) else {
+        bail!("No Python interpreter found. Install python or pass --runner.")
+    };
+
     let build_dir = TempBuildDir::create("bt-functions-python-bundle")?;
     let pkg_dir = build_dir.path.join("pkg");
     std::fs::create_dir_all(&pkg_dir)
         .with_context(|| format!("failed to create {}", pkg_dir.display()))?;
 
-    install_python_dependencies(&pkg_dir, requirements_path)?;
+    install_python_dependencies(&pkg_dir, requirements_path, &python)?;
 
     let stage_dir = build_dir.path.join("stage");
     std::fs::create_dir_all(&stage_dir)
@@ -1592,7 +1589,7 @@
         .context("failed to write register.py")?;
 
     let zip_path = build_dir.path.join("pkg.zip");
-    create_zip_with_python(runner, &stage_dir, &zip_path)?;
+    create_zip_with_python(&python, &stage_dir, &zip_path)?;
     std::fs::read(&zip_path)
         .with_context(|| format!("failed to read generated archive {}", zip_path.display()))
 }
@@ -1657,7 +1654,7 @@ fn normalized_archive_relative_path(path: &Path) -> Result<String> {
     Ok(out)
 }
 
-fn create_zip_with_python(runner: Option<&str>, stage_root: &Path, zip_path: &Path) -> Result<()> {
+fn create_zip_with_python(python: &Path, stage_root: &Path, zip_path: &Path) -> Result<()> {
     const ZIP_SCRIPT: &str = r#"import os
 import sys
 import zipfile
@@ -1675,9 +1672,6 @@ with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED, compressle
         zf.write(source, rel)
 "#;
 
-    let Some(python) = python_runner::resolve_python_interpreter(runner, &[]) else {
-        bail!("No Python interpreter found. Install python or pass --runner.")
-    };
     let output = Command::new(python)
         .arg("-c")
         .arg(ZIP_SCRIPT)
@@ -1710,18 +1704,46 @@ with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED, compressle
     );
 }
 
-fn install_python_dependencies(pkg_dir: &Path, requirements_path: Option<&Path>) -> Result<()> {
-    let uv = python_runner::find_binary_in_path(&["uv"]).ok_or_else(|| {
-        anyhow!("`uv` is required to build Python code bundles; please install uv")
-    })?;
-
-    let mut baseline_args = vec![
+fn baseline_uv_install_args(pkg_dir: &Path, python: &Path) -> Vec<OsString> {
+    let mut args = vec![
         OsString::from("pip"),
         OsString::from("install"),
+        OsString::from("--python"),
+        python.as_os_str().to_os_string(),
         OsString::from("--target"),
         pkg_dir.as_os_str().to_os_string(),
     ];
-    baseline_args.extend(PYTHON_BASELINE_DEPS.iter().map(OsString::from));
+    args.extend(PYTHON_BASELINE_DEPS.iter().map(OsString::from));
+    args
+}
+
+fn requirements_uv_install_args(
+    pkg_dir: &Path,
+    requirements: &Path,
+    python: &Path,
+) -> Vec<OsString> {
+    vec![
+        OsString::from("pip"),
+        OsString::from("install"),
+        OsString::from("--python"),
+        python.as_os_str().to_os_string(),
+        OsString::from("--target"),
+        pkg_dir.as_os_str().to_os_string(),
+        OsString::from("-r"),
+        requirements.as_os_str().to_os_string(),
+    ]
+}
+
+fn install_python_dependencies(
+    pkg_dir: &Path,
+    requirements_path: Option<&Path>,
+    python: &Path,
+) -> Result<()> {
+    let uv = python_runner::find_binary_in_path(&["uv"]).ok_or_else(|| {
+        anyhow!("`uv` is required to build Python code bundles; please install uv")
+    })?;
+
+    let baseline_args = baseline_uv_install_args(pkg_dir, python);
     run_uv_command(
         &uv,
         &baseline_args,
@@ -1729,14 +1751,7 @@
     )?;
 
     if let Some(requirements) = requirements_path {
-        let args = vec![
-            OsString::from("pip"),
-            OsString::from("install"),
-            OsString::from("--target"),
-            pkg_dir.as_os_str().to_os_string(),
-            OsString::from("-r"),
-            requirements.as_os_str().to_os_string(),
-        ];
+        let args = requirements_uv_install_args(pkg_dir, requirements, python);
         run_uv_command(&uv, &args, "installing requirements file dependencies")?;
     }
 
@@ -1802,7 +1817,7 @@ fn collect_regular_files_recursive_impl(
         let path = entry.path();
         if file_type.is_dir() && !file_type.is_symlink() {
             collect_regular_files_recursive_impl(&path, out, depth + 1)?;
-        } else if path.is_file() {
+        } else if file_type.is_file() {
             out.push(path);
         }
     }
@@ -2505,34 +2520,16 @@ fn emit_summary(base: &BaseArgs, summary: &PushSummary) -> Result<()> {
         println!("{}", serde_json::to_string(summary)?);
     } else {
         for error in &summary.errors {
-            eprintln!("error ({}): {}", to_error_code(error), error.message);
+            let code = serde_json::to_value(error.reason)
+                .ok()
+                .and_then(|v| v.as_str().map(ToOwned::to_owned))
+                .unwrap_or_else(|| format!("{:?}", error.reason));
+            eprintln!("error ({code}): {}", error.message);
         }
     }
 
     Ok(())
 }
 
-fn to_error_code(error: &ReportError) -> &'static str {
-    match error.reason {
-        HardFailureReason::AuthFailed => "auth_failed",
-        HardFailureReason::RequestFailed => "request_failed",
-        HardFailureReason::ResponseInvalid => "response_invalid",
-        HardFailureReason::UserCancelled => "user_cancelled",
-        HardFailureReason::OutputDirInvalid => "output_dir_invalid",
-        HardFailureReason::AtomicWriteFailed => "atomic_write_failed",
-        HardFailureReason::UnsafeOutputPath => "unsafe_output_path",
-        HardFailureReason::RunnerSpawnFailed => "runner_spawn_failed",
-        HardFailureReason::RunnerExitNonzero => "runner_exit_nonzero",
-        HardFailureReason::ManifestInvalidJson => "manifest_invalid_json",
-        HardFailureReason::ManifestSchemaInvalid => "manifest_schema_invalid",
-        HardFailureReason::ManifestPathMissing => "manifest_path_missing",
-        HardFailureReason::UploadSlotFailed => "upload_slot_failed",
-        HardFailureReason::BundleUploadFailed => "bundle_upload_failed",
-        HardFailureReason::InsertFunctionsFailed => "insert_functions_failed",
-        HardFailureReason::SelectorNotFound => "selector_not_found",
-        HardFailureReason::PaginationUnsupported => "pagination_unsupported",
-    }
-}
-
 fn fail_push(
     base: &BaseArgs,
     total_files: usize,
@@ -3072,6 +3069,103 @@ mod tests {
         assert!(value["data"].get("preview").is_none());
     }
 
+    #[test]
+    fn uv_install_args_include_selected_python() {
+        let pkg_dir = PathBuf::from("/tmp/pkg");
+        let python = PathBuf::from("/tmp/custom-python");
+        let rendered = baseline_uv_install_args(&pkg_dir, &python)
+            .into_iter()
+            .map(|value| value.to_string_lossy().to_string())
+            .collect::<Vec<_>>();
+        let python_str = python.to_string_lossy().to_string();
+
+        assert!(
+            rendered
+                .windows(2)
+                .any(|window| window[0] == "--python" && window[1] == python_str.as_str()),
+            "baseline uv args should pin the selected python interpreter"
+        );
+    }
+
+    #[test]
+    fn requirements_uv_install_args_include_selected_python() {
+        let pkg_dir = PathBuf::from("/tmp/pkg");
+        let requirements = PathBuf::from("/tmp/requirements.txt");
+        let python = PathBuf::from("/tmp/custom-python");
+        let rendered = requirements_uv_install_args(&pkg_dir, &requirements, &python)
+            .into_iter()
+            .map(|value| value.to_string_lossy().to_string())
+            .collect::<Vec<_>>();
+        let python_str = python.to_string_lossy().to_string();
+
+        assert!(
+            rendered
+                .windows(2)
+                .any(|window| window[0] == "--python" && window[1] == python_str.as_str()),
+            "requirements uv args should pin the selected python interpreter"
+        );
+    }
+
+    #[cfg(unix)]
+    #[test]
+    fn collect_from_dir_skips_symlinked_files() {
+        use std::os::unix::fs::symlink;
+
+        let dir = tempfile::tempdir().expect("tempdir");
+        let root = 
dir.path().join("root"); + std::fs::create_dir_all(&root).expect("create root"); + + let inside = root.join("inside.ts"); + std::fs::write(&inside, "export const inside = 1;\n").expect("write inside"); + + let outside = dir.path().join("outside.ts"); + std::fs::write(&outside, "export const outside = 2;\n").expect("write outside"); + symlink(&outside, root.join("outside-link.ts")).expect("create symlink"); + + let mut js_like = BTreeSet::new(); + let mut python = BTreeSet::new(); + collect_from_dir(&root, &mut js_like, &mut python).expect("collect sources"); + + let inside = inside.canonicalize().expect("canonicalize inside"); + let outside = outside.canonicalize().expect("canonicalize outside"); + assert!(js_like.contains(&inside)); + assert!( + !js_like.contains(&outside), + "directory scan should not follow symlinked files" + ); + } + + #[cfg(unix)] + #[test] + fn collect_regular_files_recursive_skips_symlinked_files() { + use std::os::unix::fs::symlink; + + let dir = tempfile::tempdir().expect("tempdir"); + let root = dir.path().join("root"); + std::fs::create_dir_all(&root).expect("create root"); + + let inside = root.join("inside.txt"); + std::fs::write(&inside, "inside\n").expect("write inside"); + + let outside = dir.path().join("outside.txt"); + std::fs::write(&outside, "outside\n").expect("write outside"); + symlink(&outside, root.join("outside-link.txt")).expect("create symlink"); + + let files = collect_regular_files_recursive(&root).expect("collect regular files"); + assert!(files.contains(&inside)); + assert!( + files.iter().all(|path| path != &outside), + "collector must not include symlink targets outside root" + ); + assert!( + files.iter().all(|path| path + .file_name() + .and_then(|value| value.to_str()) + .is_none_or(|value| value != "outside-link.txt")), + "collector must skip symlink file entries" + ); + } + fn test_base_args() -> BaseArgs { BaseArgs { json: false, diff --git a/tests/functions.rs b/tests/functions.rs index acaa681..2b9a1bb 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -1947,20 +1947,20 @@ async fn functions_pull_selector_with_unsupported_only_rows_still_succeeds() { let summary: Value = serde_json::from_slice(&output.stdout).expect("parse pull summary"); assert_eq!(summary["status"].as_str(), Some("partial")); - assert_eq!(summary["files_written"].as_u64(), Some(1)); + assert_eq!(summary["files_written"].as_u64(), Some(0)); assert_eq!(summary["files_failed"].as_u64(), Some(0)); assert_eq!(summary["unsupported_records_skipped"].as_u64(), Some(1)); assert_eq!(summary["functions_materialized"].as_u64(), Some(0)); let rendered_file = out_dir.join("mock-project.ts"); - assert!(rendered_file.is_file(), "expected rendered file to exist"); - let rendered = std::fs::read_to_string(&rendered_file).expect("read rendered file"); assert!( - rendered.contains("const project = braintrust.projects.create"), - "rendered file should still include project scaffold" + !rendered_file.exists(), + "no file should be written when all rows are unsupported" ); + + let stderr = String::from_utf8_lossy(&output.stderr); assert!( - !rendered.contains("project.prompts.create"), - "rendered file should not include prompt materializations" + stderr.contains("skipping 'legacy-code' because it is not a prompt"), + "expected warning about non-prompt function on stderr, got:\n{stderr}" ); } From 615392ff4aa019e00591ddbe1583ed4f91c532cc Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 18:46:18 -0700 Subject: [PATCH 15/28] refactor(functions): remove 
legacy compatibility code and aliases --- scripts/functions-runner.py | 80 +----- scripts/functions-runner.ts | 130 +--------- src/args.rs | 8 +- src/functions/api.rs | 84 +----- src/functions/mod.rs | 110 +++----- src/functions/pull.rs | 165 +----------- src/functions/push.rs | 244 +++++++++--------- src/main.rs | 10 - .../pull-help-env-vars/fixture.json | 1 - .../pull-help-flags/fixture.json | 1 - .../fixture.json | 12 - tests/functions.rs | 30 +-- 12 files changed, 180 insertions(+), 695 deletions(-) delete mode 100644 tests/functions-fixtures/pull-project-id-name-conflict/fixture.json diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index 1e5ce1b..d870193 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -42,23 +42,8 @@ def to_json_value(value: Any) -> Any: def load_framework_globals() -> tuple[Any, Any, Any]: - try: - from braintrust.framework2.global_ import functions, prompts - except Exception: - from braintrust.framework2 import global_ as global_state - - functions = getattr(global_state, "functions", []) - prompts = getattr(global_state, "prompts", []) - - lazy = None - try: - from braintrust.framework2.lazy_load import _set_lazy_load as lazy - except Exception: - try: - from braintrust.framework import _set_lazy_load as lazy - except Exception: - lazy = None - + from braintrust.framework2.global_ import functions, prompts + from braintrust.framework2.lazy_load import _set_lazy_load as lazy return functions, prompts, lazy @@ -155,7 +140,7 @@ def collect_code_entries(functions_registry: Any) -> list[dict[str, Any]]: normalized_function_type = normalize_function_type(function_type) if normalized_function_type: entry["function_type"] = normalized_function_type - if_exists = getattr(item, "if_exists", None) or getattr(item, "ifExists", None) + if_exists = getattr(item, "if_exists", None) if isinstance(if_exists, str): entry["if_exists"] = if_exists metadata = getattr(item, "metadata", None) @@ -174,60 +159,6 @@ def collect_code_entries(functions_registry: Any) -> list[dict[str, Any]]: return entries -def collect_legacy_prompt_event(item: Any, resolver: Resolver) -> dict[str, Any] | None: - name = getattr(item, "name", None) - slug = getattr(item, "slug", None) - if not isinstance(name, str) or not isinstance(slug, str) or not name or not slug: - return None - - prompt = to_json_value(getattr(item, "prompt", {}) or {}) - if not isinstance(prompt, dict): - prompt = {} - - tool_functions = getattr(item, "tool_functions", None) - if isinstance(tool_functions, list) and tool_functions: - resolved_tools: list[Any] = [] - for tool in tool_functions: - if isinstance(tool, dict): - slug_value = tool.get("slug") - project = tool.get("project") - if isinstance(slug_value, str) and project is not None: - placeholder = selector_to_project_placeholder(project) - if placeholder: - resolved_tools.append( - {"type": "slug", "project_id": placeholder, "slug": slug_value} - ) - continue - resolved_tools.append(to_json_value(tool)) - else: - resolved_tools.append(to_json_value(tool)) - if resolved_tools: - prompt["tool_functions"] = resolved_tools - - event: dict[str, Any] = { - "name": name, - "slug": slug, - "description": getattr(item, "description", "") or "", - "function_data": {"type": "prompt"}, - "prompt_data": prompt, - } - - if_exists = getattr(item, "if_exists", None) or getattr(item, "ifExists", None) - if isinstance(if_exists, str): - event["if_exists"] = if_exists - metadata = getattr(item, "metadata", None) - if metadata is not None: - 
event["metadata"] = to_json_value(metadata) - - project_id, project_name = normalize_project_selector(getattr(item, "project", None)) - out: dict[str, Any] = {"kind": "function_event", "event": event} - if project_id: - out["project_id"] = project_id - if project_name: - out["project_name"] = project_name - return out - - async def collect_function_event_entries(prompts_registry: Any) -> list[dict[str, Any]]: entries: list[dict[str, Any]] = [] resolver = Resolver() @@ -247,11 +178,6 @@ async def collect_function_event_entries(prompts_registry: Any) -> list[dict[str if project_name: event_entry["project_name"] = project_name entries.append(event_entry) - continue - - legacy = collect_legacy_prompt_event(item, resolver) - if legacy is not None: - entries.append(legacy) return entries diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index de15262..4a35c47 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -30,22 +30,6 @@ type CodeRegistryItem = { type EventRegistryItem = { project?: ProjectRef; toFunctionDefinition?: (resolver: Resolver) => Promise; - name?: string; - slug?: string; - description?: string; - ifExists?: string; - metadata?: JsonValue; - prompt?: JsonValue; - toolFunctions?: LegacyToolFunction[]; -}; - -type LegacyToolFunction = { - type?: string; - id?: string; - name?: string; - slug?: string; - project?: ProjectRef; - project_id?: string; }; type CodeEntry = { @@ -117,7 +101,6 @@ function currentRegistry(fallback: EvalRegistry): EvalRegistry { async function collectFunctionEvents( items: EventRegistryItem[], - includeLegacyPrompts: boolean, ): Promise { const entries: FunctionEventEntry[] = []; @@ -130,12 +113,6 @@ async function collectFunctionEvents( for (const item of items) { if (!item.toFunctionDefinition) { - if (includeLegacyPrompts) { - const entry = await collectLegacyPromptEvent(item, resolver); - if (entry) { - entries.push(entry); - } - } continue; } @@ -164,107 +141,6 @@ async function collectFunctionEvents( return entries; } -async function collectLegacyPromptEvent( - item: EventRegistryItem, - resolver: Resolver, -): Promise { - if (typeof item.name !== "string" || typeof item.slug !== "string") { - return null; - } - - const normalizedPrompt = toJsonValue(item.prompt ?? {}); - if (!isJsonObject(normalizedPrompt)) { - return null; - } - - const promptData: JsonObject = { ...normalizedPrompt }; - const toolFunctions = Array.isArray(item.toolFunctions) - ? item.toolFunctions - : []; - if (toolFunctions.length > 0) { - const resolvedTools: JsonValue[] = []; - for (const tool of toolFunctions) { - const resolved = await resolveLegacyToolFunction(tool, resolver); - if (resolved) { - resolvedTools.push(resolved); - } - } - if (resolvedTools.length > 0) { - promptData.tool_functions = resolvedTools; - } - } - - const selector = asProjectSelector(item.project); - const projectId = - typeof selector.project_id === "string" ? selector.project_id : undefined; - const projectName = - typeof selector.project_name === "string" - ? selector.project_name - : undefined; - - const event: JsonObject = { - name: item.name, - slug: item.slug, - description: typeof item.description === "string" ? 
item.description : "", - function_data: { - type: "prompt", - }, - prompt_data: promptData, - }; - if (typeof item.ifExists === "string") { - event.if_exists = item.ifExists; - } - if (item.metadata !== undefined) { - event.metadata = item.metadata; - } - - return { - kind: "function_event", - project_id: projectId, - project_name: projectName, - event, - }; -} - -async function resolveLegacyToolFunction( - tool: LegacyToolFunction, - resolver: Resolver, -): Promise { - if ( - typeof tool.slug === "string" && - tool.slug.length > 0 && - tool.project !== undefined - ) { - const projectId = await resolver.resolve(tool.project); - if (projectId.length > 0) { - return { - type: "slug", - project_id: projectId, - slug: tool.slug, - }; - } - } - - const direct: JsonObject = {}; - if (typeof tool.type === "string") { - direct.type = tool.type; - } - if (typeof tool.id === "string") { - direct.id = tool.id; - } - if (typeof tool.name === "string") { - direct.name = tool.name; - } - if (typeof tool.project_id === "string") { - direct.project_id = tool.project_id; - } - if (typeof tool.slug === "string") { - direct.slug = tool.slug; - } - - return Object.keys(direct).length > 0 ? direct : null; -} - function collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { const entries: CodeEntry[] = []; @@ -321,13 +197,9 @@ async function processFile(filePath: string): Promise { const entries: Array = [ ...collectCodeEntries(registry.functions as CodeRegistryItem[]), - ...(await collectFunctionEvents( - registry.prompts as EventRegistryItem[], - true, - )), + ...(await collectFunctionEvents(registry.prompts as EventRegistryItem[])), ...(await collectFunctionEvents( registry.parameters as EventRegistryItem[], - false, )), ]; diff --git a/src/args.rs b/src/args.rs index e6d30c1..3748e8b 100644 --- a/src/args.rs +++ b/src/args.rs @@ -23,13 +23,7 @@ pub struct BaseArgs { pub profile: Option, /// Override active org (or via BRAINTRUST_ORG_NAME) - #[arg( - short = 'o', - long = "org", - alias = "org-name", - env = "BRAINTRUST_ORG_NAME", - global = true - )] + #[arg(short = 'o', long = "org", env = "BRAINTRUST_ORG_NAME", global = true)] pub org_name: Option, /// Override active project diff --git a/src/functions/api.rs b/src/functions/api.rs index 6317574..e4e1f1d 100644 --- a/src/functions/api.rs +++ b/src/functions/api.rs @@ -49,8 +49,6 @@ pub struct FunctionListPage { pub objects: Vec, pub next_cursor: Option, pub snapshot: Option, - pub pagination_field_present: bool, - pub snapshot_field_present: bool, } #[derive(Debug, Clone, Deserialize)] @@ -170,37 +168,16 @@ fn parse_function_list_page(raw: Value) -> Result { .cloned() .ok_or_else(|| anyhow::anyhow!("missing 'objects' array in /v1/function response"))?; - let explicit_next_cursor = raw + let next_cursor = raw .get("next_cursor") .and_then(Value::as_str) - .or_else(|| raw.get("nextCursor").and_then(Value::as_str)) - .or_else(|| raw.get("next").and_then(Value::as_str)) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned); - let cursor_field = raw - .get("cursor") - .and_then(Value::as_str) - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(ToOwned::to_owned); - - let has_more = raw - .get("has_more") - .and_then(Value::as_bool) - .or_else(|| raw.get("hasMore").and_then(Value::as_bool)); - - let next_cursor = explicit_next_cursor.or(match has_more { - Some(false) => None, - _ => cursor_field, - }); - let snapshot = raw .get("snapshot") .and_then(Value::as_str) - .or_else(|| raw.get("snapshot_id").and_then(Value::as_str)) 
- .or_else(|| raw.get("as_of").and_then(Value::as_str)) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned); @@ -209,15 +186,6 @@ fn parse_function_list_page(raw: Value) -> Result { objects, next_cursor, snapshot, - pagination_field_present: raw.get("next_cursor").is_some() - || raw.get("nextCursor").is_some() - || raw.get("next").is_some() - || raw.get("cursor").is_some() - || raw.get("has_more").is_some() - || raw.get("hasMore").is_some(), - snapshot_field_present: raw.get("snapshot").is_some() - || raw.get("snapshot_id").is_some() - || raw.get("as_of").is_some(), }) } @@ -267,24 +235,9 @@ pub async fn insert_functions( } fn ignored_count(raw: &Value) -> Option { - if let Some(count) = raw.get("ignored_count").and_then(Value::as_u64) { - return usize::try_from(count).ok(); - } - - if let Some(items) = raw.get("ignored").and_then(Value::as_array) { - return Some(items.len()); - } - - if let Some(count) = raw - .get("stats") - .and_then(Value::as_object) - .and_then(|stats| stats.get("ignored")) + raw.get("ignored_count") .and_then(Value::as_u64) - { - return usize::try_from(count).ok(); - } - - None + .and_then(|count| usize::try_from(count).ok()) } #[cfg(test)] @@ -292,15 +245,15 @@ mod tests { use super::*; #[test] - fn ignored_count_extracts_known_shapes() { + fn ignored_count_extracts_canonical_shape() { let first = serde_json::json!({ "ignored_count": 3 }); assert_eq!(ignored_count(&first), Some(3)); let second = serde_json::json!({ "ignored": [1, 2] }); - assert_eq!(ignored_count(&second), Some(2)); + assert_eq!(ignored_count(&second), None); let third = serde_json::json!({ "stats": { "ignored": 5 } }); - assert_eq!(ignored_count(&third), Some(5)); + assert_eq!(ignored_count(&third), None); assert_eq!(ignored_count(&serde_json::json!({})), None); } @@ -313,7 +266,6 @@ mod tests { let page = parse_function_list_page(raw).expect("parse function page"); assert!(page.objects.is_empty()); - assert!(!page.pagination_field_present); assert!(page.next_cursor.is_none()); } @@ -321,37 +273,21 @@ mod tests { fn parse_function_list_page_detects_next_pagination_field() { let raw = serde_json::json!({ "objects": [], - "next": "cursor-1", + "next_cursor": "cursor-1", }); let page = parse_function_list_page(raw).expect("parse function page"); - assert!(page.pagination_field_present); assert_eq!(page.next_cursor.as_deref(), Some("cursor-1")); } #[test] - fn parse_function_list_page_supports_cursor_has_more_shape() { + fn parse_function_list_page_extracts_snapshot() { let raw = serde_json::json!({ "objects": [], - "cursor": "cursor-2", - "has_more": true, + "snapshot": "snapshot-1", }); let page = parse_function_list_page(raw).expect("parse function page"); - assert!(page.pagination_field_present); - assert_eq!(page.next_cursor.as_deref(), Some("cursor-2")); - } - - #[test] - fn parse_function_list_page_ignores_cursor_when_has_more_false() { - let raw = serde_json::json!({ - "objects": [], - "cursor": "cursor-2", - "has_more": false, - }); - - let page = parse_function_list_page(raw).expect("parse function page"); - assert!(page.pagination_field_present); - assert!(page.next_cursor.is_none()); + assert_eq!(page.snapshot.as_deref(), Some("snapshot-1")); } } diff --git a/src/functions/mod.rs b/src/functions/mod.rs index 70155b7..f57e459 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -369,23 +369,15 @@ pub(crate) struct PullArgs { )] pub language: FunctionsLanguage, - /// Project name filter. 
- #[arg(long, env = "BT_FUNCTIONS_PULL_PROJECT_NAME")] - pub project_name: Option, - /// Project id filter. - #[arg( - long, - env = "BT_FUNCTIONS_PULL_PROJECT_ID", - conflicts_with = "project_name" - )] + #[arg(long, env = "BT_FUNCTIONS_PULL_PROJECT_ID")] pub project_id: Option, /// Function id selector. #[arg(long, env = "BT_FUNCTIONS_PULL_ID")] pub id: Option, - /// Version selector (supports pretty version IDs). + /// Version selector. #[arg(long, env = "BT_FUNCTIONS_PULL_VERSION")] pub version: Option, @@ -414,10 +406,6 @@ impl PullArgs { } result } - - pub fn has_slug_selector(&self) -> bool { - !self.slugs.is_empty() || !self.slug_flag.is_empty() - } } #[derive(Debug, Clone, Args)] @@ -601,60 +589,42 @@ pub async fn run_typed(base: BaseArgs, args: FunctionArgs, kind: FunctionTypeFil } pub async fn run(base: BaseArgs, args: FunctionsArgs) -> Result<()> { + let function_type = args.function_type; match args.command { - None => { - let ctx = resolve_context(&base).await?; - list::run(&ctx, base.json, args.function_type).await - } - Some(FunctionsCommands::List(ref la)) => { + Some(FunctionsCommands::Push(push_args)) => push::run(base, push_args).await, + Some(FunctionsCommands::Pull(pull_args)) => pull::run(base, pull_args).await, + command => { let ctx = resolve_context(&base).await?; - list::run(&ctx, base.json, la.function_type.or(args.function_type)).await - } - Some(FunctionsCommands::View(v)) => { - let ctx = resolve_context(&base).await?; - view::run( - &ctx, - v.inner.slug(), - base.json, - v.inner.web, - v.inner.verbose, - v.function_type.or(args.function_type), - ) - .await - } - Some(FunctionsCommands::Delete(d)) => { - let ctx = resolve_context(&base).await?; - delete::run( - &ctx, - d.slug(), - d.force, - d.function_type.or(args.function_type), - ) - .await - } - Some(FunctionsCommands::Invoke(i)) => { - let ctx = resolve_context(&base).await?; - invoke::run( - &ctx, - &i.inner, - base.json, - i.function_type.or(args.function_type), - ) - .await + match command { + None => list::run(&ctx, base.json, function_type).await, + Some(FunctionsCommands::List(la)) => { + list::run(&ctx, base.json, la.function_type.or(function_type)).await + } + Some(FunctionsCommands::View(v)) => { + view::run( + &ctx, + v.inner.slug(), + base.json, + v.inner.web, + v.inner.verbose, + v.function_type.or(function_type), + ) + .await + } + Some(FunctionsCommands::Delete(d)) => { + delete::run(&ctx, d.slug(), d.force, d.function_type.or(function_type)).await + } + Some(FunctionsCommands::Invoke(i)) => { + invoke::run(&ctx, &i.inner, base.json, i.function_type.or(function_type)).await + } + Some(FunctionsCommands::Push(_)) | Some(FunctionsCommands::Pull(_)) => { + unreachable!("handled before context resolution") + } + } } - Some(FunctionsCommands::Push(args)) => push::run(base, args).await, - Some(FunctionsCommands::Pull(args)) => pull::run(base, args).await, } } -pub async fn run_push(base: BaseArgs, args: PushArgs) -> Result<()> { - push::run(base, args).await -} - -pub async fn run_pull(base: BaseArgs, args: PullArgs) -> Result<()> { - pull::run(base, args).await -} - #[cfg(test)] mod tests { use std::sync::{Mutex, MutexGuard, OnceLock}; @@ -895,22 +865,6 @@ mod tests { assert!(err.to_string().contains("typescript")); } - #[test] - fn pull_conflicts_project_selectors() { - let _guard = test_lock(); - let err = parse(&[ - "functions", - "pull", - "--project-id", - "p1", - "--project-name", - "proj", - ]) - .expect_err("should conflict"); - - assert!(err.to_string().contains("--project-name")); - } 
- #[test] fn pull_conflicts_id_and_slug_flag() { let _guard = test_lock(); diff --git a/src/functions/pull.rs b/src/functions/pull.rs index 1755c15..543562a 100644 --- a/src/functions/pull.rs +++ b/src/functions/pull.rs @@ -28,15 +28,6 @@ use crate::ui::{animations_enabled, is_quiet}; const PAGINATION_PAGE_LIMIT: usize = 10_000; const OUTPUT_LOCK_FILE: &str = ".bt-functions-pull.lock"; -// Pretty version IDs are a reversible encoding of internal transaction IDs -// (_xact_id). The encoding multiplies the xact ID (with a fixed top-nibble tag) -// by COPRIME mod 2^64, producing a 16-hex-char string that looks random but -// decodes back via the modular inverse. This lets `--version` accept either the -// raw numeric xact ID or the pretty hex form transparently. -const TOP_BITS: u64 = 0x0DE1u64 << 48; -const MODULUS: u128 = 1u128 << 64; -const COPRIME: u64 = 205_891_132_094_649; -const COPRIME_INVERSE: u64 = 1_522_336_535_492_693_385; #[derive(Debug, Clone, Deserialize)] struct PullFunctionRow { @@ -135,28 +126,6 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { if let Some(project_id) = &args.project_id { query.project_id = Some(project_id.clone()); - } else if let Some(project_name) = &args.project_name { - let projects = match get_projects_cached(&auth_ctx.client, &mut projects_cache).await { - Ok(projects) => projects, - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - } - }; - if let Err(err) = ensure_unambiguous_project_name(projects, project_name) { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - } - resolved_project_name = Some(project_name.clone()); - query.project_name = Some(project_name.clone()); } else { let project = match resolve_project_context_optional(&base, &auth_ctx, false).await { Ok(project) => project, @@ -179,17 +148,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { query.id = Some(id.clone()); } if let Some(version) = &args.version { - query.version = match load_pretty_xact_compat(version) { - Ok(value) => Some(value), - Err(err) => { - return fail_pull( - &base, - &mut summary, - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - } - }; + query.version = Some(version.clone()); } let resolved_slugs = args.resolved_slugs(); if resolved_slugs.len() == 1 { @@ -269,23 +228,13 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { let winners = select_winner_rows(narrowed_rows, &mut summary); - if (args.id.is_some() || args.has_slug_selector()) && winners.is_empty() { - spinner.finish_and_clear(); - return fail_pull( - &base, - &mut summary, - HardFailureReason::SelectorNotFound, - "no matching function rows found for selector".to_string(), - ); - } - let mut materializable = Vec::new(); for row in winners.iter().cloned() { if is_prompt_row(&row) { materializable.push(row); } else { summary.unsupported_records_skipped += 1; - if !is_quiet() { + if args.verbose { eprintln!( "{} skipping '{}' because it is not a prompt", style("warning:").yellow(), @@ -453,12 +402,12 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { .filter(|f| f.status == FileStatus::Success) .map(|f| short_display_path(&f.output_file, cwd.as_deref())) .collect(); - let file_label = if pulled_files.is_empty() { - from_label.clone() + if pulled_files.is_empty() { + eprintln!("No functions to pull from {from_label}"); } else { - pulled_files.join(", ") - }; - eprintln!("{} Pulled {subject} to 
{file_label}", style("✓").green(),); + let file_label = pulled_files.join(", "); + eprintln!("{} Pulled {subject} to {file_label}", style("✓").green(),); + } } } emit_summary(&base, &summary, args.verbose)?; @@ -481,41 +430,6 @@ async fn get_projects_cached<'a>( .expect("project cache should be initialized")) } -fn ensure_unambiguous_project_name(projects: &[Project], project_name: &str) -> Result<()> { - let exact: Vec<_> = projects - .iter() - .filter(|project| project.name == project_name) - .collect(); - - match exact.len() { - 0 => bail!("project '{project_name}' not found"), - 1 => Ok(()), - count => { - bail!("project-name '{project_name}' is ambiguous ({count} matches); use --project-id") - } - } -} - -fn modular_multiply(value: u64, prime: u64) -> u64 { - ((value as u128 * prime as u128) % MODULUS) as u64 -} - -fn load_pretty_xact_compat(encoded_hex: &str) -> Result { - if encoded_hex.len() != 16 { - return Ok(encoded_hex.to_string()); - } - let value = u64::from_str_radix(encoded_hex, 16).with_context(|| { - format!("invalid pretty version '{encoded_hex}' (expected 16 hex characters)") - })?; - let multiplied_inverse = modular_multiply(value, COPRIME_INVERSE); - let with_top_bits = TOP_BITS | multiplied_inverse; - let roundtrip = modular_multiply(with_top_bits, COPRIME); - if roundtrip != value { - bail!("invalid pretty version '{encoded_hex}' (failed compatibility decode)"); - } - Ok(with_top_bits.to_string()) -} - struct FetchRowsResult { rows: Vec, warnings: Vec, @@ -544,18 +458,11 @@ async fn fetch_all_function_rows( page_query.snapshot = snapshot.clone(); let page = api::list_functions_page(client, &page_query).await?; - - if page_count == 0 && !page.pagination_field_present { - page_count += 1; - rows.extend(page.objects); - break; - } - page_count += 1; if page_count == 1 { snapshot = page.snapshot.clone(); - } else if snapshot.is_none() || !page.snapshot_field_present { + } else if snapshot != page.snapshot { snapshot_consistent = false; } @@ -1177,7 +1084,7 @@ fn render_project_file_ts( out.push_str("// This file was automatically generated by bt functions pull. You can\n"); out.push_str("// generate it again by running:\n"); out.push_str(&format!( - "// $ bt functions pull --project-name {}\n", + "// $ bt functions pull --project {}\n", serde_json::to_string(project_name)? )); out.push_str( @@ -1252,7 +1159,7 @@ fn render_project_file_py( out.push_str("# This file was automatically generated by bt functions pull. You can\n"); out.push_str("# generate it again by running:\n"); out.push_str(&format!( - "# $ bt functions pull --project-name {} --language python\n", + "# $ bt functions pull --project {} --language python\n", serde_json::to_string(project_name)? 
)); out.push_str( @@ -1646,7 +1553,6 @@ mod tests { slug_flag: vec![], output_dir: PathBuf::from("."), language: FunctionsLanguage::Typescript, - project_name: None, project_id: None, id: Some("missing".to_string()), version: None, @@ -1703,7 +1609,6 @@ mod tests { slug_flag: vec!["gamma".to_string()], output_dir: PathBuf::from("."), language: FunctionsLanguage::Typescript, - project_name: None, project_id: None, id: None, version: None, @@ -1802,54 +1707,6 @@ mod tests { assert_ne!(first.to_ascii_lowercase(), second.to_ascii_lowercase()); } - #[test] - fn render_project_file_matches_legacy_shape() { - let row = PullFunctionRow { - id: "f1".to_string(), - name: "Doc Search".to_string(), - slug: "doc-search".to_string(), - project_id: "p1".to_string(), - project_name: Some("woohoo".to_string()), - description: Some(String::new()), - prompt_data: Some(serde_json::json!({ - "prompt": { - "type": "chat", - "messages": [ - { "content": "Hello", "role": "system" } - ] - }, - "options": { - "model": "gpt-4o-mini" - }, - "tool_functions": [ - { "type": "function", "id": "tool-1" } - ] - })), - function_data: Some(serde_json::json!({ "type": "prompt" })), - created: None, - _xact_id: Some("123".to_string()), - }; - - let rendered = render_project_file( - FunctionsLanguage::Typescript, - "woohoo", - "braintrust/woohoo.ts", - &[row], - ) - .expect("rendered"); - - assert!(rendered.contains("automatically generated by bt functions pull")); - assert!(rendered.contains("bt functions pull --project-name \"woohoo\"")); - assert!(rendered.contains("bt functions push --file \"braintrust/woohoo.ts\"")); - assert!( - rendered.contains("const project = braintrust.projects.create({\n name: \"woohoo\",") - ); - assert!(rendered.contains("export const docSearch = project.prompts.create({")); - assert!(!rendered.contains("description: \"\",")); - assert!(!rendered.contains("version:")); - assert!(!rendered.contains("id: \"f1\"")); - } - #[test] fn render_project_file_python_shape() { let row = PullFunctionRow { @@ -1887,7 +1744,7 @@ mod tests { ) .expect("rendered"); - assert!(rendered.contains("bt functions pull --project-name \"woohoo\" --language python")); + assert!(rendered.contains("bt functions pull --project \"woohoo\" --language python")); assert!(rendered.contains("bt functions push --file \"braintrust/woohoo.py\"")); assert!(rendered.contains("import braintrust")); assert!(rendered.contains("project = braintrust.projects.create(name=\"woohoo\")")); diff --git a/src/functions/push.rs b/src/functions/push.rs index 5adec3f..e631029 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -164,7 +164,6 @@ struct ResolvedManifestTargets { struct ClassifiedFiles { js_like: Vec, python: Vec, - had_directory_inputs: bool, explicit_file_inputs: usize, explicit_supported_files: usize, explicit_js_like: usize, @@ -259,16 +258,8 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { ); } }; - emit_language_selection_notice(&args, &classified, selected_language); + emit_language_selection_notice(&classified, selected_language); - if args.tsconfig.is_some() { - eprintln!( - "Notice: --tsconfig is enabled for JS runner and JS bundling (TS_NODE_PROJECT/TSX_TSCONFIG_PATH)." 
- ); - } - if !args.external_packages.is_empty() { - eprintln!("Notice: --external-packages will be applied to JS bundle builds."); - } if !args.external_packages.is_empty() && selected_language != SourceLanguage::JsLike { return fail_push( &base, @@ -288,6 +279,14 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { "invalid --requirements usage", ); } + if args.tsconfig.is_some() { + eprintln!( + "Notice: --tsconfig is enabled for JS runner and JS bundling (TS_NODE_PROJECT/TSX_TSCONFIG_PATH)." + ); + } + if !args.external_packages.is_empty() { + eprintln!("Notice: --external-packages will be applied to JS bundle builds."); + } let files = classified.files_for_language(selected_language); if files.is_empty() { @@ -382,14 +381,22 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { None }; + let fail_manifest_preflight = |message: String, file_message: &str| { + fail_push_with_all_skipped( + &base, + &files, + HardFailureReason::ManifestSchemaInvalid, + &message, + file_message, + ) + }; + let preflight = match collect_project_preflight(&base, &manifest) { Ok(preflight) => preflight, Err(err) => { let message = format!("failed to resolve project selectors in manifest: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, + return fail_manifest_preflight( + message, "skipped because project selector preflight failed", ); } @@ -419,10 +426,8 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { Ok(cache) => cache, Err(err) => { let message = format!("failed to resolve target projects for push: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, + return fail_manifest_preflight( + message, "skipped because project target resolution failed", ); } @@ -430,22 +435,15 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { if let Err(err) = validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await { let message = format!("failed to validate project ids for push: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, - "skipped because project id validation failed", - ); + return fail_manifest_preflight(message, "skipped because project id validation failed"); } let default_project_id = match resolve_default_project_id(&preflight, &project_name_cache) { Ok(id) => id, Err(err) => { let message = format!("failed to resolve default project for push: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, + return fail_manifest_preflight( + message, "skipped because default project resolution failed", ); } @@ -462,20 +460,16 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { Ok(targets) => targets, Err(err) => { let message = format!("failed to resolve target projects for push: {err}"); - return fail_push_manifest_preflight( - &base, - &files, - &message, + return fail_manifest_preflight( + message, "skipped because project target resolution failed", ); } }; if let Err(err) = validate_duplicate_slugs(&resolved_targets.entries) { - return fail_push_manifest_preflight( - &base, - &files, - &err.to_string(), + return fail_manifest_preflight( + err.to_string(), "skipped because duplicate slug validation failed", ); } @@ -493,10 +487,8 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { }; if resolved_targets.per_file.len() != manifest.files.len() { - return fail_push_manifest_preflight( - &base, - &files, - "internal error: resolved target count did not match manifest file count", + return 
fail_manifest_preflight( + "internal error: resolved target count did not match manifest file count".to_string(), "skipped because internal target resolution failed", ); } @@ -538,10 +530,8 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { { if resolved_file.source_file != file.source_file { spinner.finish_and_clear(); - return fail_push_manifest_preflight( - &base, - &files, - "internal error: resolved target source mismatch", + return fail_manifest_preflight( + "internal error: resolved target source mismatch".to_string(), "skipped because internal target resolution failed", ); } @@ -1092,7 +1082,6 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { let mut js_like = BTreeSet::new(); let mut python = BTreeSet::new(); let mut allowed_roots = BTreeSet::new(); - let mut had_directory_inputs = false; let mut explicit_file_inputs = 0usize; let mut explicit_supported_files = 0usize; let mut explicit_js_like = 0usize; @@ -1137,7 +1126,6 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { continue; } - had_directory_inputs = true; let canonical_dir = path .canonicalize() .with_context(|| format!("failed to canonicalize directory {}", path.display()))?; @@ -1148,7 +1136,6 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { Ok(ClassifiedFiles { js_like: js_like.into_iter().collect(), python: python.into_iter().collect(), - had_directory_inputs, explicit_file_inputs, explicit_supported_files, explicit_js_like, @@ -1219,7 +1206,9 @@ fn select_push_language(args: &PushArgs, files: &ClassifiedFiles) -> Result { if !files.js_like.is_empty() && !files.python.is_empty() { - Ok(SourceLanguage::JsLike) + bail!( + "mixed source languages are not supported in one push invocation; run separate commands for Python and JS/TS files" + ); } else if !files.python.is_empty() { Ok(SourceLanguage::Python) } else { @@ -1231,11 +1220,7 @@ fn select_push_language(args: &PushArgs, files: &ClassifiedFiles) -> Result (files.python.len(), files.js_like.len(), "js/ts"), }; - if args.language == PushLanguage::Auto - && selected_language == SourceLanguage::JsLike - && files.had_directory_inputs - { - eprintln!( - "Notice: discovered mixed runtimes during directory scan; defaulting to JS/TS for compatibility and skipping {skipped_count} Python files. Run a separate `bt functions push --language python` invocation." 
- ); - return; - } - if skipped_count > 0 { eprintln!( "Notice: selected {} runtime; processing {selected_count} files and skipping {skipped_count} {skipped_label} files.", @@ -2530,22 +2505,36 @@ fn emit_summary(base: &BaseArgs, summary: &PushSummary) -> Result<()> { Ok(()) } -fn fail_push( +enum FailedPushFiles<'a> { + SingleFailed { + total_files: usize, + file_message: &'a str, + reason: HardFailureReason, + }, + AllSkipped { + files: &'a [PathBuf], + file_message: &'a str, + }, +} + +fn emit_failed_push_summary( base: &BaseArgs, - total_files: usize, reason: HardFailureReason, - message: String, - file_message: &str, + message: &str, + file_shape: FailedPushFiles<'_>, ) -> Result<()> { - if base.json { - let summary = PushSummary { - status: CommandStatus::Failed, + if !base.json { + return Ok(()); + } + + let (total_files, files) = match file_shape { + FailedPushFiles::SingleFailed { total_files, - uploaded_files: 0, - failed_files: 0, - skipped_files: total_files, - ignored_entries: 0, - files: vec![PushFileReport { + file_message, + reason, + } => ( + total_files, + vec![PushFileReport { source_file: String::new(), status: FileStatus::Failed, uploaded_entries: 0, @@ -2554,34 +2543,13 @@ fn fail_push( bundle_id: None, message: Some(file_message.to_string()), }], - warnings: vec![], - errors: vec![ReportError { - reason, - message: message.clone(), - }], - }; - emit_summary(base, &summary)?; - } - - bail!(message); -} - -fn fail_push_with_all_skipped( - base: &BaseArgs, - files: &[PathBuf], - reason: HardFailureReason, - message: &str, - file_message: &str, -) -> Result<()> { - if base.json { - let summary = PushSummary { - status: CommandStatus::Failed, - total_files: files.len(), - uploaded_files: 0, - failed_files: 0, - skipped_files: files.len(), - ignored_entries: 0, - files: files + ), + FailedPushFiles::AllSkipped { + files, + file_message, + } => ( + files.len(), + files .iter() .map(|path| PushFileReport { source_file: path.display().to_string(), @@ -2593,31 +2561,63 @@ fn fail_push_with_all_skipped( message: Some(file_message.to_string()), }) .collect(), - warnings: vec![], - errors: vec![ReportError { - reason, - message: message.to_string(), - }], - }; - emit_summary(base, &summary)?; - } + ), + }; - bail!(message.to_string()); + let summary = PushSummary { + status: CommandStatus::Failed, + total_files, + uploaded_files: 0, + failed_files: 0, + skipped_files: total_files, + ignored_entries: 0, + files, + warnings: vec![], + errors: vec![ReportError { + reason, + message: message.to_string(), + }], + }; + emit_summary(base, &summary) +} + +fn fail_push( + base: &BaseArgs, + total_files: usize, + reason: HardFailureReason, + message: String, + file_message: &str, +) -> Result<()> { + emit_failed_push_summary( + base, + reason, + &message, + FailedPushFiles::SingleFailed { + total_files, + file_message, + reason, + }, + )?; + bail!(message); } -fn fail_push_manifest_preflight( +fn fail_push_with_all_skipped( base: &BaseArgs, files: &[PathBuf], + reason: HardFailureReason, message: &str, file_message: &str, ) -> Result<()> { - fail_push_with_all_skipped( + emit_failed_push_summary( base, - files, - HardFailureReason::ManifestSchemaInvalid, + reason, message, - file_message, - ) + FailedPushFiles::AllSkipped { + files, + file_message, + }, + )?; + bail!(message.to_string()); } #[cfg(test)] @@ -2733,7 +2733,7 @@ mod tests { } #[test] - fn select_push_language_auto_prefers_js_like_for_mixed_scan() { + fn select_push_language_auto_rejects_mixed_scan() { let args = PushArgs { 
files: vec![PathBuf::from(".")], file_flag: vec![], @@ -2749,7 +2749,6 @@ mod tests { let classified = ClassifiedFiles { js_like: vec![PathBuf::from("/tmp/a.ts")], python: vec![PathBuf::from("/tmp/a.py")], - had_directory_inputs: true, explicit_file_inputs: 0, explicit_supported_files: 0, explicit_js_like: 0, @@ -2757,8 +2756,8 @@ mod tests { allowed_roots: Vec::new(), }; - let selected = select_push_language(&args, &classified).expect("select language"); - assert_eq!(selected, SourceLanguage::JsLike); + let err = select_push_language(&args, &classified).expect_err("must fail"); + assert!(err.to_string().contains("mixed source languages")); } #[test] @@ -2778,7 +2777,6 @@ mod tests { let classified = ClassifiedFiles { js_like: vec![PathBuf::from("/tmp/a.ts")], python: vec![PathBuf::from("/tmp/b.py")], - had_directory_inputs: false, explicit_file_inputs: 2, explicit_supported_files: 2, explicit_js_like: 1, diff --git a/src/main.rs b/src/main.rs index d80374d..db7760b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,8 +60,6 @@ Core Projects & resources projects Manage projects prompts Manage prompts - push Compatibility alias for `functions push` - pull Compatibility alias for `functions pull` functions Manage functions (tools, scorers, and more) tools Manage tools scorers Manage scorers @@ -142,10 +140,6 @@ enum Commands { Scorers(CLIArgs), /// Manage functions (tools, scorers, and more) Functions(CLIArgs), - /// Compatibility alias for `functions push` - Push(CLIArgs), - /// Compatibility alias for `functions pull` - Pull(CLIArgs), /// Manage experiments Experiments(CLIArgs), /// Synchronize project logs between Braintrust and local NDJSON files @@ -177,8 +171,6 @@ impl Commands { Commands::Tools(cmd) => &cmd.base, Commands::Scorers(cmd) => &cmd.base, Commands::Functions(cmd) => &cmd.base, - Commands::Push(cmd) => &cmd.base, - Commands::Pull(cmd) => &cmd.base, Commands::Experiments(cmd) => &cmd.base, Commands::Sync(cmd) => &cmd.base, Commands::Util(cmd) => &cmd.base, @@ -232,8 +224,6 @@ async fn try_main() -> Result<()> { Commands::Tools(cmd) => tools::run(cmd.base, cmd.args).await?, Commands::Scorers(cmd) => scorers::run(cmd.base, cmd.args).await?, Commands::Functions(cmd) => functions::run(cmd.base, cmd.args).await?, - Commands::Push(cmd) => functions::run_push(cmd.base, cmd.args).await?, - Commands::Pull(cmd) => functions::run_pull(cmd.base, cmd.args).await?, Commands::Experiments(cmd) => experiments::run(cmd.base, cmd.args).await?, Commands::Sync(cmd) => sync::run(cmd.base, cmd.args).await?, Commands::Util(cmd) => util_cmd::run(cmd.base, cmd.args).await?, diff --git a/tests/functions-fixtures/pull-help-env-vars/fixture.json b/tests/functions-fixtures/pull-help-env-vars/fixture.json index 5756da8..8b012ec 100644 --- a/tests/functions-fixtures/pull-help-env-vars/fixture.json +++ b/tests/functions-fixtures/pull-help-env-vars/fixture.json @@ -4,7 +4,6 @@ "stdout_contains": [ "BT_FUNCTIONS_PULL_OUTPUT_DIR", "BT_FUNCTIONS_PULL_PROJECT_ID", - "BT_FUNCTIONS_PULL_PROJECT_NAME", "BT_FUNCTIONS_PULL_ID", "BT_FUNCTIONS_PULL_SLUG", "BT_FUNCTIONS_PULL_VERSION", diff --git a/tests/functions-fixtures/pull-help-flags/fixture.json b/tests/functions-fixtures/pull-help-flags/fixture.json index 197aca7..45d2534 100644 --- a/tests/functions-fixtures/pull-help-flags/fixture.json +++ b/tests/functions-fixtures/pull-help-flags/fixture.json @@ -4,7 +4,6 @@ "stdout_contains": [ "--output-dir", "--project-id", - "--project-name", "--id", "--slug", "--version", diff --git 
a/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json b/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json deleted file mode 100644 index 07538fd..0000000 --- a/tests/functions-fixtures/pull-project-id-name-conflict/fixture.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "command": [ - "functions", - "pull", - "--project-id", - "proj_123", - "--project-name", - "demo" - ], - "expect_success": false, - "stderr_contains": ["--project-id", "--project-name"] -} diff --git a/tests/functions.rs b/tests/functions.rs index 2b9a1bb..95dd87e 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -559,7 +559,6 @@ fn functions_pull_help_includes_expected_flags() { let stdout = String::from_utf8_lossy(&output.stdout); assert!(stdout.contains("--output-dir")); assert!(stdout.contains("--project-id")); - assert!(stdout.contains("--project-name")); assert!(stdout.contains("--version")); assert!(stdout.contains("--language")); } @@ -642,34 +641,6 @@ fn functions_help_lists_push_and_pull() { assert!(stdout.contains("pull")); } -#[test] -fn top_level_push_help_is_available() { - let output = Command::new(bt_binary_path()) - .arg("push") - .arg("--help") - .output() - .expect("run bt push --help"); - - assert!(output.status.success()); - let stdout = String::from_utf8_lossy(&output.stdout); - assert!(stdout.contains("--if-exists")); - assert!(stdout.contains("--file")); -} - -#[test] -fn top_level_pull_help_is_available() { - let output = Command::new(bt_binary_path()) - .arg("pull") - .arg("--help") - .output() - .expect("run bt pull --help"); - - assert!(output.status.success()); - let stdout = String::from_utf8_lossy(&output.stdout); - assert!(stdout.contains("--output-dir")); - assert!(stdout.contains("--version")); -} - #[test] fn push_and_pull_help_are_machine_readable() { let push_help = Command::new(bt_binary_path()) @@ -1928,6 +1899,7 @@ async fn functions_pull_selector_with_unsupported_only_rows_still_succeeds() { .expect("output dir should be valid UTF-8 for test"), "--language", "typescript", + "--verbose", ]) .env("BRAINTRUST_API_KEY", "test-key") .env("BRAINTRUST_ORG_NAME", "test-org") From b37e5e136f6f59eadcfd3aecaaf07bda0b32e3f8 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 20:02:13 -0700 Subject: [PATCH 16/28] fix: add legacy prompt support with tool function resolution --- scripts/functions-runner.ts | 130 +++++++++++++++++++++++++++++++++++- src/functions/push.rs | 86 ++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 1 deletion(-) diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index 4a35c47..de15262 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -30,6 +30,22 @@ type CodeRegistryItem = { type EventRegistryItem = { project?: ProjectRef; toFunctionDefinition?: (resolver: Resolver) => Promise; + name?: string; + slug?: string; + description?: string; + ifExists?: string; + metadata?: JsonValue; + prompt?: JsonValue; + toolFunctions?: LegacyToolFunction[]; +}; + +type LegacyToolFunction = { + type?: string; + id?: string; + name?: string; + slug?: string; + project?: ProjectRef; + project_id?: string; }; type CodeEntry = { @@ -101,6 +117,7 @@ function currentRegistry(fallback: EvalRegistry): EvalRegistry { async function collectFunctionEvents( items: EventRegistryItem[], + includeLegacyPrompts: boolean, ): Promise { const entries: FunctionEventEntry[] = []; @@ -113,6 +130,12 @@ async function collectFunctionEvents( for (const item of items) { if (!item.toFunctionDefinition) { 
+ if (includeLegacyPrompts) { + const entry = await collectLegacyPromptEvent(item, resolver); + if (entry) { + entries.push(entry); + } + } continue; } @@ -141,6 +164,107 @@ async function collectFunctionEvents( return entries; } +async function collectLegacyPromptEvent( + item: EventRegistryItem, + resolver: Resolver, +): Promise { + if (typeof item.name !== "string" || typeof item.slug !== "string") { + return null; + } + + const normalizedPrompt = toJsonValue(item.prompt ?? {}); + if (!isJsonObject(normalizedPrompt)) { + return null; + } + + const promptData: JsonObject = { ...normalizedPrompt }; + const toolFunctions = Array.isArray(item.toolFunctions) + ? item.toolFunctions + : []; + if (toolFunctions.length > 0) { + const resolvedTools: JsonValue[] = []; + for (const tool of toolFunctions) { + const resolved = await resolveLegacyToolFunction(tool, resolver); + if (resolved) { + resolvedTools.push(resolved); + } + } + if (resolvedTools.length > 0) { + promptData.tool_functions = resolvedTools; + } + } + + const selector = asProjectSelector(item.project); + const projectId = + typeof selector.project_id === "string" ? selector.project_id : undefined; + const projectName = + typeof selector.project_name === "string" + ? selector.project_name + : undefined; + + const event: JsonObject = { + name: item.name, + slug: item.slug, + description: typeof item.description === "string" ? item.description : "", + function_data: { + type: "prompt", + }, + prompt_data: promptData, + }; + if (typeof item.ifExists === "string") { + event.if_exists = item.ifExists; + } + if (item.metadata !== undefined) { + event.metadata = item.metadata; + } + + return { + kind: "function_event", + project_id: projectId, + project_name: projectName, + event, + }; +} + +async function resolveLegacyToolFunction( + tool: LegacyToolFunction, + resolver: Resolver, +): Promise { + if ( + typeof tool.slug === "string" && + tool.slug.length > 0 && + tool.project !== undefined + ) { + const projectId = await resolver.resolve(tool.project); + if (projectId.length > 0) { + return { + type: "slug", + project_id: projectId, + slug: tool.slug, + }; + } + } + + const direct: JsonObject = {}; + if (typeof tool.type === "string") { + direct.type = tool.type; + } + if (typeof tool.id === "string") { + direct.id = tool.id; + } + if (typeof tool.name === "string") { + direct.name = tool.name; + } + if (typeof tool.project_id === "string") { + direct.project_id = tool.project_id; + } + if (typeof tool.slug === "string") { + direct.slug = tool.slug; + } + + return Object.keys(direct).length > 0 ? 
direct : null; +} + function collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { const entries: CodeEntry[] = []; @@ -197,9 +321,13 @@ async function processFile(filePath: string): Promise { const entries: Array = [ ...collectCodeEntries(registry.functions as CodeRegistryItem[]), - ...(await collectFunctionEvents(registry.prompts as EventRegistryItem[])), + ...(await collectFunctionEvents( + registry.prompts as EventRegistryItem[], + true, + )), ...(await collectFunctionEvents( registry.parameters as EventRegistryItem[], + false, )), ]; diff --git a/src/functions/push.rs b/src/functions/push.rs index e631029..b7b9f63 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -848,6 +848,7 @@ async fn push_file( Value::String(args.if_exists.as_str().to_string()), ); } + default_prompt_function_type(object); } function_events.push(event); @@ -962,6 +963,36 @@ fn calculate_upload_counts(total_entries: usize, ignored_entries: Option) (uploaded_entries, ignored_entries) } +fn default_prompt_function_type(event: &mut Map) { + if !is_prompt_function_event(event) { + return; + } + + if function_type_missing_or_empty(event.get("function_type")) { + event.insert( + "function_type".to_string(), + Value::String("prompt".to_string()), + ); + } +} + +fn is_prompt_function_event(event: &Map) -> bool { + event + .get("function_data") + .and_then(Value::as_object) + .and_then(|function_data| function_data.get("type")) + .and_then(Value::as_str) + == Some("prompt") +} + +fn function_type_missing_or_empty(value: Option<&Value>) -> bool { + match value { + None | Some(Value::Null) => true, + Some(Value::String(s)) => s.trim().is_empty(), + Some(_) => false, + } +} + fn run_functions_runner( args: &PushArgs, files: &[PathBuf], @@ -2848,6 +2879,61 @@ mod tests { assert_eq!(calculate_upload_counts(3, None), (3, 0)); } + #[test] + fn prompt_event_defaults_function_type_when_missing() { + let mut event = Map::new(); + event.insert( + "function_data".to_string(), + serde_json::json!({ + "type": "prompt" + }), + ); + + default_prompt_function_type(&mut event); + + assert_eq!( + event.get("function_type"), + Some(&Value::String("prompt".to_string())) + ); + } + + #[test] + fn prompt_event_preserves_existing_function_type() { + let mut event = Map::new(); + event.insert( + "function_data".to_string(), + serde_json::json!({ + "type": "prompt" + }), + ); + event.insert( + "function_type".to_string(), + Value::String("scorer".to_string()), + ); + + default_prompt_function_type(&mut event); + + assert_eq!( + event.get("function_type"), + Some(&Value::String("scorer".to_string())) + ); + } + + #[test] + fn non_prompt_event_does_not_default_function_type() { + let mut event = Map::new(); + event.insert( + "function_data".to_string(), + serde_json::json!({ + "type": "code" + }), + ); + + default_prompt_function_type(&mut event); + + assert!(!event.contains_key("function_type")); + } + #[test] fn requirements_reference_escape_is_rejected() { let dir = tempfile::tempdir().expect("tempdir"); From 8e91881e3a0003eba7937a65fc2a3ba88d179071 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 9 Mar 2026 20:17:03 -0700 Subject: [PATCH 17/28] fix(functions-runner): force re-evaluation of imported input files --- scripts/functions-runner.ts | 12 ++- tests/functions.rs | 151 ++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 1 deletion(-) diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index de15262..b85c7f9 100644 --- a/scripts/functions-runner.ts +++ 
b/scripts/functions-runner.ts @@ -83,6 +83,7 @@ type Manifest = { }; type EvalRegistry = NonNullable; +let moduleImportNonce = 0; function freshRegistry(): EvalRegistry { return { @@ -115,6 +116,15 @@ function currentRegistry(fallback: EvalRegistry): EvalRegistry { }; } +function buildIsolatedImportUrl(absolutePath: string): string { + const moduleUrl = pathToFileURL(absolutePath); + // Force top-level evaluation for each input file, even if imported earlier + // as a dependency while processing a previous input file. + moduleUrl.searchParams.set("bt_runner_input_nonce", `${moduleImportNonce}`); + moduleImportNonce += 1; + return moduleUrl.href; +} + async function collectFunctionEvents( items: EventRegistryItem[], includeLegacyPrompts: boolean, @@ -316,7 +326,7 @@ async function processFile(filePath: string): Promise { globalThis._evals = fallbackRegistry; globalThis._lazy_load = true; - await import(pathToFileURL(absolutePath).href); + await import(buildIsolatedImportUrl(absolutePath)); const registry = currentRegistry(fallbackRegistry); const entries: Array = [ diff --git a/tests/functions.rs b/tests/functions.rs index 95dd87e..83e7534 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -902,6 +902,157 @@ globalThis._evals.functions.push({ ); } +#[test] +fn functions_js_runner_reexecutes_imported_input_files() { + if !command_exists("node") { + eprintln!( + "Skipping functions_js_runner_reexecutes_imported_input_files (node not installed)." + ); + return; + } + let Some(tsc) = find_tsc() else { + eprintln!( + "Skipping functions_js_runner_reexecutes_imported_input_files (tsc not installed)." + ); + return; + }; + + let root = repo_root(); + let tmp = tempdir().expect("tempdir"); + let sample_b_path = tmp.path().join("sample-b.mjs"); + std::fs::write( + &sample_b_path, + r#"globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} }; +globalThis._evals.functions.push({ + name: "js-tool-b", + slug: "js-tool-b", + type: "tool", + preview: "export function b() { return 2; }" +}); +export const b = 2; +"#, + ) + .expect("write sample-b.mjs"); + + let sample_a_path = tmp.path().join("sample-a.mjs"); + std::fs::write( + &sample_a_path, + r#"import "./sample-b.mjs"; +globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} }; +globalThis._evals.functions.push({ + name: "js-tool-a", + slug: "js-tool-a", + type: "tool", + preview: "export function a() { return 1; }" +}); +"#, + ) + .expect("write sample-a.mjs"); + + let runner_dir = tmp.path().join("runner"); + let compile_output = Command::new(&tsc) + .current_dir(&root) + .args([ + "scripts/functions-runner.ts", + "scripts/runner-common.ts", + "--module", + "esnext", + "--target", + "es2020", + "--moduleResolution", + "bundler", + "--outDir", + ]) + .arg(&runner_dir) + .output() + .expect("compile functions runner"); + if !compile_output.status.success() { + let stderr = String::from_utf8_lossy(&compile_output.stderr); + panic!("tsc failed for functions runner:\n{stderr}"); + } + + let runner_js = runner_dir.join("functions-runner.js"); + let runner_common_js = runner_dir.join("runner-common.js"); + assert!(runner_js.is_file(), "compiled functions-runner.js missing"); + assert!( + runner_common_js.is_file(), + "compiled runner-common.js missing" + ); + + let runner_code = std::fs::read_to_string(&runner_js).expect("read compiled runner"); + let patched_runner_code = runner_code + .replace("\"./runner-common\"", "\"./runner-common.js\"") + 
.replace("'./runner-common'", "'./runner-common.js'"); + assert_ne!( + runner_code, patched_runner_code, + "compiled runner import path did not contain ./runner-common" + ); + std::fs::write(&runner_js, patched_runner_code).expect("write patched compiled runner"); + std::fs::write(runner_dir.join("package.json"), r#"{ "type": "module" }"#) + .expect("write runner package.json"); + + let output = Command::new("node") + .arg(&runner_js) + .arg(&sample_a_path) + .arg(&sample_b_path) + .output() + .expect("run compiled functions runner"); + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("compiled functions runner failed:\n{stderr}"); + } + + let manifest: Value = serde_json::from_slice(&output.stdout).expect("parse manifest JSON"); + let files = manifest["files"].as_array().expect("files array"); + assert_eq!(files.len(), 2, "expected two manifest files"); + + let sample_a_canonical = sample_a_path + .canonicalize() + .expect("canonicalize sample-a.mjs"); + let sample_b_canonical = sample_b_path + .canonicalize() + .expect("canonicalize sample-b.mjs"); + let mut files_by_source = BTreeMap::new(); + for file in files { + let source_file = file + .get("source_file") + .and_then(Value::as_str) + .expect("source_file"); + let canonical_source = PathBuf::from(source_file) + .canonicalize() + .expect("canonicalize manifest source_file"); + files_by_source.insert(canonical_source, file); + } + + let file_a = files_by_source + .get(&sample_a_canonical) + .expect("manifest file for sample-a.mjs"); + let entries_a = file_a + .get("entries") + .and_then(Value::as_array) + .expect("sample-a entries"); + assert!( + entries_a + .iter() + .any(|entry| { entry.get("slug").and_then(Value::as_str) == Some("js-tool-a") }), + "expected sample-a.mjs entries to include js-tool-a" + ); + + let file_b = files_by_source + .get(&sample_b_canonical) + .expect("manifest file for sample-b.mjs"); + let entries_b = file_b + .get("entries") + .and_then(Value::as_array) + .expect("sample-b entries"); + assert!( + entries_b + .iter() + .any(|entry| { entry.get("slug").and_then(Value::as_str) == Some("js-tool-b") }), + "expected sample-b.mjs entries to include js-tool-b" + ); +} + #[test] fn functions_python_runner_emits_valid_manifest_with_bundle() { let Some(python) = find_python() else { From 4c40773e7b1f819f5c1fc41bf9ba4a1a47ebff2e Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 16 Mar 2026 18:47:01 -0700 Subject: [PATCH 18/28] fix(functions-push): restore runner, bundler, and project parity --- scripts/functions-bundler.ts | 300 ++++++++++++++---- scripts/functions-runner.py | 161 +++++++++- scripts/functions-runner.ts | 147 ++++++++- src/functions/mod.rs | 9 + src/functions/push.rs | 235 +++++++------- src/python_runner.rs | 12 + .../push-help-flags/fixture.json | 1 + tests/functions.rs | 1 + 8 files changed, 679 insertions(+), 187 deletions(-) diff --git a/scripts/functions-bundler.ts b/scripts/functions-bundler.ts index 62de932..02b9d11 100644 --- a/scripts/functions-bundler.ts +++ b/scripts/functions-bundler.ts @@ -1,5 +1,8 @@ +import { spawnSync } from "node:child_process"; import fs from "node:fs"; +import { createRequire } from "node:module"; import path from "node:path"; +import { pathToFileURL } from "node:url"; type EsbuildBuild = (options: Record) => Promise; type EsbuildModule = { @@ -36,53 +39,147 @@ function loadTsconfigPath(): string | undefined { return undefined; } -function createMarkKnownPackagesExternalPlugin(additionalPackages: string[]) { - 
return { - name: "make-known-packages-external", - setup(build: { - onResolve: ( - opts: { filter: RegExp }, - cb: (args: { path: string }) => { path: string; external: boolean }, - ) => void; - }) { - const knownPackages = [ - "braintrust", - "autoevals", - "@braintrust/", - "config", - "lightningcss", - "@mapbox/node-pre-gyp", - "fsevents", - "chokidar", - ...additionalPackages, - ]; - const escapedPackages = knownPackages.map((pkg) => { - const escaped = pkg.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); - if (pkg.endsWith("/")) { - return `${escaped}.*`; - } - return `${escaped}(?:\\/.*)?`; - }); - const knownPackagesFilter = new RegExp( - `^(${escapedPackages.join("|")})$`, +function buildExternalPackagePatterns(additionalPackages: string[]): string[] { + const knownPackages = [ + "braintrust", + "autoevals", + "@braintrust/", + "config", + "lightningcss", + "@mapbox/node-pre-gyp", + "fsevents", + "chokidar", + ...additionalPackages, + ]; + const patterns = new Set(["node_modules/*"]); + for (const pkg of knownPackages) { + const trimmed = pkg.trim(); + if (!trimmed) { + continue; + } + if (trimmed.endsWith("/")) { + patterns.add(`${trimmed}*`); + continue; + } + patterns.add(trimmed); + patterns.add(`${trimmed}/*`); + } + return [...patterns]; +} + +function findNodeModulesBinary( + binary: string, + startPath: string, +): string | null { + let current = path.resolve(startPath); + if (!fs.existsSync(current)) { + current = path.dirname(current); + } else if (!fs.statSync(current).isDirectory()) { + current = path.dirname(current); + } + + const binaryCandidates = + process.platform === "win32" ? [`${binary}.cmd`, binary] : [binary]; + + while (true) { + for (const candidateName of binaryCandidates) { + const candidate = path.join( + current, + "node_modules", + ".bin", + candidateName, ); - build.onResolve({ filter: knownPackagesFilter }, (args) => ({ - path: args.path, - external: true, - })); - }, - }; + if (fs.existsSync(candidate)) { + return candidate; + } + } + + const parent = path.dirname(current); + if (parent === current) { + return null; + } + current = parent; + } +} + +function resolveEsbuildBinary(sourceFile: string): string | null { + const searchRoots = [path.resolve(sourceFile), process.cwd()]; + const seen = new Set(); + for (const root of searchRoots) { + const normalized = path.resolve(root); + if (seen.has(normalized)) { + continue; + } + seen.add(normalized); + const candidate = findNodeModulesBinary("esbuild", normalized); + if (candidate) { + return candidate; + } + } + return null; +} + +function resolveEsbuildModulePath(sourceFile: string): string | null { + const filePath = path.resolve(sourceFile); + try { + const requireFromFile = createRequire(pathToFileURL(filePath).href); + return requireFromFile.resolve("esbuild"); + } catch { + // Fall through to process cwd. 
+ } + + try { + const requireFromCwd = createRequire(path.join(process.cwd(), "noop.js")); + return requireFromCwd.resolve("esbuild"); + } catch { + return null; + } +} + +function normalizeEsbuildModule(loaded: unknown): EsbuildModule | null { + if (isEsbuildModule(loaded)) { + return loaded; + } + if (isObject(loaded) && isEsbuildModule(loaded.default)) { + return loaded.default; + } + return null; } -async function loadEsbuild(): Promise { +async function loadEsbuild(sourceFile: string): Promise { + const resolvedPath = resolveEsbuildModulePath(sourceFile); + if (resolvedPath) { + if (typeof require === "function") { + try { + const loaded = require(resolvedPath) as unknown; + const normalized = normalizeEsbuildModule(loaded); + if (normalized) { + return normalized; + } + } catch { + // Fall through to dynamic import. + } + } + + try { + const loaded = (await import( + pathToFileURL(resolvedPath).href + )) as unknown; + const normalized = normalizeEsbuildModule(loaded); + if (normalized) { + return normalized; + } + } catch { + // Fall through to direct require/import. + } + } + if (typeof require === "function") { try { const loaded = require("esbuild") as unknown; - if (isEsbuildModule(loaded)) { - return loaded; - } - if (isObject(loaded) && isEsbuildModule(loaded.default)) { - return loaded.default; + const normalized = normalizeEsbuildModule(loaded); + if (normalized) { + return normalized; } } catch { // Fall through to dynamic import. @@ -93,19 +190,80 @@ async function loadEsbuild(): Promise { // Keep module name dynamic so TypeScript doesn't require local esbuild types at compile time. const specifier = "esbuild"; const loaded = (await import(specifier)) as unknown; - if (isEsbuildModule(loaded)) { - return loaded; - } - if (isObject(loaded) && isEsbuildModule(loaded.default)) { - return loaded.default; + const normalized = normalizeEsbuildModule(loaded); + if (normalized) { + return normalized; } } catch { // handled below } - throw new Error( - "failed to load esbuild for JS bundling; install esbuild in your project or use a runner that provides it", - ); + return null; +} + +function computeNodeTargetVersion(): string { + return typeof process.version === "string" && process.version.startsWith("v") + ? process.version.slice(1) + : process.versions.node || "18"; +} + +async function bundleWithEsbuildModule( + esbuild: EsbuildModule, + sourceFile: string, + outputFile: string, + tsconfig: string | undefined, + external: string[], +): Promise { + await esbuild.build({ + entryPoints: [sourceFile], + bundle: true, + treeShaking: true, + platform: "node", + target: `node${computeNodeTargetVersion()}`, + write: true, + outfile: outputFile, + tsconfig, + external, + }); +} + +function bundleWithEsbuildBinary( + esbuildBinary: string, + sourceFile: string, + outputFile: string, + tsconfig: string | undefined, + external: string[], +): void { + const args: string[] = [ + sourceFile, + "--bundle", + "--tree-shaking=true", + "--platform=node", + `--target=node${computeNodeTargetVersion()}`, + `--outfile=${outputFile}`, + ]; + + if (tsconfig) { + args.push(`--tsconfig=${tsconfig}`); + } + for (const pattern of external) { + args.push(`--external:${pattern}`); + } + + const result = spawnSync(esbuildBinary, args, { encoding: "utf8" }); + if (result.error) { + throw new Error( + `failed to invoke esbuild CLI at ${esbuildBinary}: ${result.error.message}`, + ); + } + if (result.status !== 0) { + const stderr = (result.stderr ?? "").trim(); + const stdout = (result.stdout ?? 
"").trim(); + const details = stderr || stdout || "unknown error"; + throw new Error( + `esbuild CLI exited with status ${String(result.status)}: ${details}`, + ); + } } async function main(): Promise { @@ -114,32 +272,42 @@ async function main(): Promise { throw new Error("functions-bundler requires "); } - const esbuild = await loadEsbuild(); const externalPackages = parseExternalPackages( process.env.BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES, ); + const external = buildExternalPackagePatterns(externalPackages); const tsconfig = loadTsconfigPath(); const outputDir = path.dirname(outputFile); fs.mkdirSync(outputDir, { recursive: true }); - const targetVersion = - typeof process.version === "string" && process.version.startsWith("v") - ? process.version.slice(1) - : process.versions.node || "18"; + const esbuild = await loadEsbuild(sourceFile); + if (esbuild) { + await bundleWithEsbuildModule( + esbuild, + sourceFile, + outputFile, + tsconfig, + external, + ); + return; + } - await esbuild.build({ - entryPoints: [sourceFile], - bundle: true, - treeShaking: true, - platform: "node", - target: `node${targetVersion}`, - write: true, - outfile: outputFile, - tsconfig, - external: ["node_modules/*", "fsevents"], - plugins: [createMarkKnownPackagesExternalPlugin(externalPackages)], - }); + const esbuildBinary = resolveEsbuildBinary(sourceFile); + if (esbuildBinary) { + bundleWithEsbuildBinary( + esbuildBinary, + sourceFile, + outputFile, + tsconfig, + external, + ); + return; + } + + throw new Error( + "failed to load esbuild for JS bundling; install esbuild in your project or use a runner that provides it", + ); } main().catch((error: unknown) => { diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index d870193..1f16f06 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import asyncio +import inspect import json import os import sys @@ -42,9 +43,20 @@ def to_json_value(value: Any) -> Any: def load_framework_globals() -> tuple[Any, Any, Any]: - from braintrust.framework2.global_ import functions, prompts - from braintrust.framework2.lazy_load import _set_lazy_load as lazy - return functions, prompts, lazy + # Prefer current SDK layout first: + # - braintrust.framework2 exposes module-level `global_` + # - braintrust.framework exposes `_set_lazy_load` + try: + from braintrust.framework import _set_lazy_load as lazy + from braintrust.framework2 import global_ as global_state + + return global_state.functions, global_state.prompts, lazy + except (ImportError, ModuleNotFoundError): + # Backward compatibility with older SDK layout. + from braintrust.framework2.global_ import functions, prompts + from braintrust.framework2.lazy_load import _set_lazy_load as lazy + + return functions, prompts, lazy def normalize_project_selector(project: Any) -> tuple[str | None, str | None]: @@ -72,6 +84,10 @@ def normalize_project_selector(project: Any) -> tuple[str | None, str | None]: return project_id.strip(), None if isinstance(project_name, str) and project_name.strip(): return None, project_name.strip() + # braintrust.framework2.Project exposes `.name`. 
+ project_display_name = getattr(project, "name", None) + if isinstance(project_display_name, str) and project_display_name.strip(): + return None, project_display_name.strip() return None, None @@ -93,6 +109,16 @@ def normalize_function_type(raw: Any) -> str | None: return None +def pydantic_to_json_schema(model: Any) -> Any | None: + if model is None: + return None + if hasattr(model, "model_json_schema"): + return to_json_value(model.model_json_schema()) + if hasattr(model, "schema"): + return to_json_value(model.schema()) + return None + + def selector_to_project_placeholder(project: Any) -> str: project_id, project_name = normalize_project_selector(project) if project_id: @@ -102,10 +128,45 @@ def selector_to_project_placeholder(project: Any) -> str: return "" +def import_module_name_from_cwd(cwd: str, source_file: str) -> str | None: + try: + rel = os.path.relpath(source_file, cwd) + except ValueError: + return None + + parent_prefix = os.pardir + os.sep + if rel == os.pardir or rel.startswith(parent_prefix): + return None + + module = os.path.splitext(rel)[0] + module = module.replace("-", "_") + module = module.replace(os.sep, ".") + if os.altsep: + module = module.replace(os.altsep, ".") + return module + + +def package_init_sources_for_module(cwd: str, module_name: str) -> list[str]: + package_parts = [part for part in module_name.split(".")[:-1] if part] + if not package_parts: + return [] + sources: list[str] = [] + current = cwd + for part in package_parts: + current = os.path.join(current, part) + init_path = os.path.join(current, "__init__.py") + if os.path.isfile(init_path): + sources.append(os.path.abspath(init_path)) + return sources + + class Resolver: - async def resolve(self, project: Any) -> str: + def get(self, project: Any) -> str: return selector_to_project_placeholder(project) + async def resolve(self, project: Any) -> str: + return self.get(project) + def clear_registry(registry: Any) -> None: if hasattr(registry, "clear"): @@ -140,12 +201,30 @@ def collect_code_entries(functions_registry: Any) -> list[dict[str, Any]]: normalized_function_type = normalize_function_type(function_type) if normalized_function_type: entry["function_type"] = normalized_function_type + parameters_model = getattr(item, "parameters", None) + if parameters_model is None: + raise ValueError(f"Function {name} has no supplied parameters") + parameters_schema = pydantic_to_json_schema(parameters_model) + if parameters_schema is None: + raise ValueError(f"Function {name} has invalid parameters schema") + function_schema: dict[str, Any] = {"parameters": parameters_schema} + returns_model = getattr(item, "returns", None) + if returns_model is not None: + returns_schema = pydantic_to_json_schema(returns_model) + if returns_schema is not None: + function_schema["returns"] = returns_schema + entry["function_schema"] = function_schema if_exists = getattr(item, "if_exists", None) if isinstance(if_exists, str): entry["if_exists"] = if_exists metadata = getattr(item, "metadata", None) if metadata is not None: entry["metadata"] = to_json_value(metadata) + tags = getattr(item, "tags", None) + if isinstance(tags, list): + normalized_tags = [tag for tag in tags if isinstance(tag, str)] + if normalized_tags: + entry["tags"] = normalized_tags if project_id: entry["project_id"] = project_id if project_name: @@ -166,7 +245,23 @@ async def collect_function_event_entries(prompts_registry: Any) -> list[dict[str for item in items: to_definition = getattr(item, "to_function_definition", None) if 
callable(to_definition): - definition = to_definition(resolver) + definition: Any + signature = inspect.signature(to_definition) + positional_params = [ + parameter + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + if len(positional_params) >= 2: + definition = to_definition(None, resolver) + elif len(positional_params) == 1: + definition = to_definition(resolver) + else: + definition = to_definition() if asyncio.iscoroutine(definition): definition = await definition normalized = to_json_value(definition) @@ -188,12 +283,16 @@ async def process_file(file_path: str) -> dict[str, Any]: if cwd not in sys.path: sys.path.insert(0, cwd) + purge_local_modules(cwd, preserve_modules={__name__, "python_runner_common"}) functions_registry, prompts_registry, lazy_loader = load_framework_globals() clear_registry(functions_registry) clear_registry(prompts_registry) - purge_local_modules(cwd, preserve_modules={__name__, "python_runner_common"}) - module_name, extra_paths = resolve_module_info(abs_path) + module_name = import_module_name_from_cwd(cwd, abs_path) + if module_name is None: + module_name, extra_paths = resolve_module_info(abs_path) + else: + extra_paths = [cwd] lazy_ctx = lazy_loader(True) if callable(lazy_loader) else nullcontext() with lazy_ctx: import_file(module_name, abs_path, extra_paths) @@ -205,9 +304,55 @@ async def process_file(file_path: str) -> dict[str, Any]: "entries": entries, } if code_entries: + runner_root = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.abspath(cwd) + path_rest: list[str] = [] + for path_entry in sys.path: + if not path_entry: + continue + entry_abs = os.path.abspath(path_entry) + if entry_abs == project_root: + continue + path_rest.append(entry_abs) + bundled_sources: list[str] = [] + seen_sources: set[str] = set() + for source in collect_python_sources(cwd, abs_path): + source_abs = os.path.abspath(source) + try: + common = os.path.commonpath([source_abs, runner_root]) + except ValueError: + common = "" + if common == runner_root: + continue + try: + project_common = os.path.commonpath([source_abs, project_root]) + except ValueError: + project_common = "" + if project_common != project_root: + continue + covered_by_other_path_root = False + for root in path_rest: + try: + root_common = os.path.commonpath([source_abs, root]) + except ValueError: + continue + if root_common == root: + covered_by_other_path_root = True + break + if covered_by_other_path_root: + continue + if source_abs in seen_sources: + continue + seen_sources.add(source_abs) + bundled_sources.append(source_abs) + for init_source in package_init_sources_for_module(cwd, module_name): + if init_source in seen_sources: + continue + seen_sources.add(init_source) + bundled_sources.append(init_source) file_manifest["python_bundle"] = { "entry_module": module_name, - "sources": collect_python_sources(cwd, abs_path), + "sources": bundled_sources, } clear_registry(functions_registry) diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index b85c7f9..ab8a7de 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -1,4 +1,5 @@ import path from "node:path"; +import { createRequire } from "node:module"; import { pathToFileURL } from "node:url"; import { @@ -24,6 +25,9 @@ type CodeRegistryItem = { functionType?: string; ifExists?: string; metadata?: JsonValue; + tags?: unknown; + parameters?: unknown; + returns?: unknown; 
preview?: string; }; @@ -58,6 +62,8 @@ type CodeEntry = { function_type?: string; if_exists?: string; metadata?: JsonValue; + tags?: string[]; + function_schema?: JsonValue; preview?: string; location: JsonValue; }; @@ -83,7 +89,24 @@ type Manifest = { }; type EvalRegistry = NonNullable; +type ZodToJsonSchemaFn = (schema: unknown) => unknown; + let moduleImportNonce = 0; +let zodToJsonSchemaFn: ZodToJsonSchemaFn | null | undefined; + +const runtimeRequire: NodeRequire | null = + typeof require === "function" ? require : null; + +function safeCreateRequire(modulePath: string): NodeRequire | null { + try { + return createRequire(modulePath); + } catch { + return null; + } +} + +const localRequire = + runtimeRequire ?? safeCreateRequire(path.join(process.cwd(), "package.json")); function freshRegistry(): EvalRegistry { return { @@ -125,6 +148,105 @@ function buildIsolatedImportUrl(absolutePath: string): string { return moduleUrl.href; } +function loadZodToJsonSchemaFn(): ZodToJsonSchemaFn | null { + if (zodToJsonSchemaFn !== undefined) { + return zodToJsonSchemaFn; + } + + const extractConverter = (module: unknown): ZodToJsonSchemaFn | null => { + if ( + module && + typeof module === "object" && + "zodToJsonSchema" in module && + typeof (module as { zodToJsonSchema?: unknown }).zodToJsonSchema === + "function" + ) { + return (module as { zodToJsonSchema: ZodToJsonSchemaFn }).zodToJsonSchema; + } + if ( + module && + typeof module === "object" && + "default" in module && + typeof (module as { default?: unknown }).default === "function" + ) { + return (module as { default: ZodToJsonSchemaFn }).default; + } + return null; + }; + + const requireCandidates: NodeRequire[] = []; + if (localRequire) { + requireCandidates.push(localRequire); + } + const cwdRequire = safeCreateRequire( + path.join(process.cwd(), "package.json"), + ); + if (cwdRequire) { + let exists = false; + for (const candidate of requireCandidates) { + if (candidate === cwdRequire) { + exists = true; + break; + } + } + if (!exists) { + requireCandidates.push(cwdRequire); + } + } + + for (const candidateRequire of requireCandidates) { + try { + const converter = extractConverter( + candidateRequire("zod-to-json-schema"), + ); + if (converter) { + zodToJsonSchemaFn = converter; + return zodToJsonSchemaFn; + } + } catch { + // Try the next location. + } + } + + for (const candidateRequire of requireCandidates) { + try { + const braintrustPkg = candidateRequire.resolve("braintrust/package.json"); + const braintrustRequire = createRequire(braintrustPkg); + const converter = extractConverter( + braintrustRequire("zod-to-json-schema"), + ); + if (converter) { + zodToJsonSchemaFn = converter; + return zodToJsonSchemaFn; + } + } catch { + // Try the next location. + } + } + + zodToJsonSchemaFn = null; + return zodToJsonSchemaFn; +} + +function schemaToJsonSchema(schema: unknown): JsonObject | undefined { + if (schema === undefined || schema === null) { + return undefined; + } + + const converter = loadZodToJsonSchemaFn(); + if (!converter) { + return undefined; + } + + try { + const converted = converter(schema); + const normalized = toJsonValue(converted as JsonValue); + return isJsonObject(normalized) ? 
normalized : undefined; + } catch { + return undefined; + } +} + async function collectFunctionEvents( items: EventRegistryItem[], includeLegacyPrompts: boolean, @@ -286,8 +408,20 @@ function collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { } const selector = asProjectSelector(item.project); + const tags = Array.isArray(item.tags) + ? item.tags.filter((tag): tag is string => typeof tag === "string") + : []; + const parametersSchema = schemaToJsonSchema(item.parameters); + const returnsSchema = schemaToJsonSchema(item.returns); + const functionSchema: JsonObject = {}; + if (parametersSchema) { + functionSchema.parameters = parametersSchema; + } + if (returnsSchema) { + functionSchema.returns = returnsSchema; + } - entries.push({ + const entry: CodeEntry = { kind: "code", project_id: typeof selector.project_id === "string" @@ -314,7 +448,16 @@ function collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { type: "function", index, }, - }); + }; + + if (tags.length > 0) { + entry.tags = tags; + } + if (Object.keys(functionSchema).length > 0) { + entry.function_schema = functionSchema; + } + + entries.push(entry); } return entries; diff --git a/src/functions/mod.rs b/src/functions/mod.rs index f57e459..9974284 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -288,6 +288,15 @@ pub(crate) struct PushArgs { )] pub terminate_on_failure: bool, + /// Create referenced projects automatically when they do not exist. + #[arg( + long = "create-missing-projects", + env = "BT_FUNCTIONS_PUSH_CREATE_MISSING_PROJECTS", + default_value_t = true, + value_parser = BoolishValueParser::new() + )] + pub create_missing_projects: bool, + /// Override runner binary (e.g. tsx, vite-node, deno, python). #[arg(long, env = "BT_FUNCTIONS_PUSH_RUNNER", value_name = "RUNNER")] pub runner: Option, diff --git a/src/functions/push.rs b/src/functions/push.rs index b7b9f63..3803c25 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -23,7 +23,7 @@ use crate::functions::report::{ SoftSkipReason, }; use crate::js_runner; -use crate::projects::api::{get_project_by_name, list_projects}; +use crate::projects::api::{create_project, get_project_by_name, list_projects}; use crate::python_runner; use crate::source_language::{classify_runtime_extension, JsExtensionProfile, SourceLanguage}; use crate::ui::{animations_enabled, is_interactive, is_quiet}; @@ -46,6 +46,10 @@ const RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/runner-common.ts" const PYTHON_RUNNER_COMMON_SOURCE: &str = include_str!("../../scripts/python_runner_common.py"); const PYTHON_BASELINE_DEPS: &[&str] = &["pydantic", "braintrust", "autoevals", "requests", "openai"]; +// Compatibility shim for existing test harnesses and eval workflows that set +// Python interpreter via BT_EVAL_* variables. Preferred path is still +// --runner / BT_FUNCTIONS_PUSH_RUNNER. 
+const PYTHON_INTERPRETER_ENV_OVERRIDES: &[&str] = &["BT_EVAL_PYTHON_RUNNER", "BT_EVAL_PYTHON"]; #[derive(Debug, Deserialize)] struct RunnerManifest { @@ -102,6 +106,10 @@ struct CodeEntry { #[serde(default)] metadata: Option, #[serde(default)] + tags: Option>, + #[serde(default)] + function_schema: Option, + #[serde(default)] location: Option, #[serde(default)] preview: Option, @@ -421,17 +429,22 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { } } - let mut project_name_cache = - match resolve_named_projects(&auth_ctx, &preflight.named_projects).await { - Ok(cache) => cache, - Err(err) => { - let message = format!("failed to resolve target projects for push: {err}"); - return fail_manifest_preflight( - message, - "skipped because project target resolution failed", - ); - } - }; + let mut project_name_cache = match resolve_named_projects( + &auth_ctx, + &preflight.named_projects, + args.create_missing_projects, + ) + .await + { + Ok(cache) => cache, + Err(err) => { + let message = format!("failed to resolve target projects for push: {err}"); + return fail_manifest_preflight( + message, + "skipped because project target resolution failed", + ); + } + }; if let Err(err) = validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await { let message = format!("failed to validate project ids for push: {err}"); @@ -454,6 +467,7 @@ pub async fn run(base: BaseArgs, args: PushArgs) -> Result<()> { default_project_id.as_deref(), &manifest, &mut project_name_cache, + args.create_missing_projects, ) .await { @@ -784,6 +798,15 @@ async fn push_file( if let Some(metadata) = &code.metadata { obj.insert("metadata".to_string(), metadata.clone()); } + if let Some(tags) = &code.tags { + obj.insert( + "tags".to_string(), + Value::Array(tags.iter().cloned().map(Value::String).collect()), + ); + } + if let Some(function_schema) = &code.function_schema { + obj.insert("function_schema".to_string(), function_schema.clone()); + } let if_exists = code .if_exists .as_deref() @@ -820,6 +843,7 @@ async fn push_file( None, Some(&project_name), project_name_cache, + args.create_missing_projects, ) .await .map_err(|err| FileFailure { @@ -848,7 +872,6 @@ async fn push_file( Value::String(args.if_exists.as_str().to_string()), ); } - default_prompt_function_type(object); } function_events.push(event); @@ -963,36 +986,6 @@ fn calculate_upload_counts(total_entries: usize, ignored_entries: Option) (uploaded_entries, ignored_entries) } -fn default_prompt_function_type(event: &mut Map) { - if !is_prompt_function_event(event) { - return; - } - - if function_type_missing_or_empty(event.get("function_type")) { - event.insert( - "function_type".to_string(), - Value::String("prompt".to_string()), - ); - } -} - -fn is_prompt_function_event(event: &Map) -> bool { - event - .get("function_data") - .and_then(Value::as_object) - .and_then(|function_data| function_data.get("type")) - .and_then(Value::as_str) - == Some("prompt") -} - -fn function_type_missing_or_empty(value: Option<&Value>) -> bool { - match value { - None | Some(Value::Null) => true, - Some(Value::String(s)) => s.trim().is_empty(), - Some(_) => false, - } -} - fn run_functions_runner( args: &PushArgs, files: &[PathBuf], @@ -1040,9 +1033,10 @@ fn run_functions_runner( reason: HardFailureReason::RunnerSpawnFailed, message: format!("failed to materialize Python functions runner: {err}"), })?; - let Some(python) = - python_runner::resolve_python_interpreter(args.runner.as_deref(), &[]) - else { + let Some(python) = 
python_runner::resolve_python_interpreter( + args.runner.as_deref(), + PYTHON_INTERPRETER_ENV_OVERRIDES, + ) else { return Err(FileFailure { reason: HardFailureReason::RunnerSpawnFailed, message: "No Python interpreter found. Install python or pass --runner." @@ -1113,6 +1107,11 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { let mut js_like = BTreeSet::new(); let mut python = BTreeSet::new(); let mut allowed_roots = BTreeSet::new(); + if let Ok(cwd) = std::env::current_dir() { + if let Ok(canonical_cwd) = cwd.canonicalize() { + allowed_roots.insert(canonical_cwd); + } + } let mut explicit_file_inputs = 0usize; let mut explicit_supported_files = 0usize; let mut explicit_js_like = 0usize; @@ -1568,7 +1567,9 @@ fn build_python_bundle_archive( requirements_path: Option<&Path>, runner: Option<&str>, ) -> Result> { - let Some(python) = python_runner::resolve_python_interpreter(runner, &[]) else { + let Some(python) = + python_runner::resolve_python_interpreter(runner, PYTHON_INTERPRETER_ENV_OVERRIDES) + else { bail!("No Python interpreter found. Install python or pass --runner.") }; @@ -2256,16 +2257,37 @@ fn resolve_default_project_id( async fn resolve_named_projects( auth_ctx: &super::AuthContext, named_projects: &BTreeSet, + create_missing_projects: bool, ) -> Result> { let mut project_name_cache = BTreeMap::new(); let mut missing = Vec::new(); for project_name in named_projects { - let project = get_project_by_name(&auth_ctx.client, project_name).await?; - if let Some(project) = project { + if let Some(project) = get_project_by_name(&auth_ctx.client, project_name).await? { project_name_cache.insert(project_name.clone(), project.id); - } else { + continue; + } + + if !create_missing_projects { missing.push(project_name.clone()); + continue; + } + + match create_project(&auth_ctx.client, project_name).await { + Ok(project) => { + project_name_cache.insert(project_name.clone(), project.id); + } + Err(_) => { + // Another writer may have created the project concurrently. + if let Some(project) = get_project_by_name(&auth_ctx.client, project_name).await? 
{ + project_name_cache.insert(project_name.clone(), project.id); + } else { + bail!( + "failed to create project '{project_name}' in org '{}'", + current_org_label(auth_ctx) + ); + } + } } } @@ -2311,6 +2333,7 @@ async fn resolve_manifest_targets( default_project_id: Option<&str>, manifest: &RunnerManifest, project_name_cache: &mut BTreeMap, + create_missing_projects: bool, ) -> Result { let mut entries = Vec::new(); let mut per_file = Vec::with_capacity(manifest.files.len()); @@ -2328,6 +2351,7 @@ async fn resolve_manifest_targets( default_project_id, &selector, project_name_cache, + create_missing_projects, ) .await?; entry_project_ids.push(project_id.clone()); @@ -2374,6 +2398,7 @@ async fn resolve_project_selector( default_project_id: Option<&str>, selector: &ProjectSelector, project_name_cache: &mut BTreeMap, + create_missing_projects: bool, ) -> Result { match selector { ProjectSelector::Id(project_id) => { @@ -2383,6 +2408,7 @@ async fn resolve_project_selector( Some(project_id.as_str()), None, project_name_cache, + create_missing_projects, ) .await } @@ -2393,11 +2419,20 @@ async fn resolve_project_selector( None, Some(project_name.as_str()), project_name_cache, + create_missing_projects, ) .await } ProjectSelector::Fallback => { - resolve_project_id(client, default_project_id, None, None, project_name_cache).await + resolve_project_id( + client, + default_project_id, + None, + None, + project_name_cache, + create_missing_projects, + ) + .await } } } @@ -2408,18 +2443,31 @@ async fn resolve_project_id( project_id: Option<&str>, project_name: Option<&str>, project_name_cache: &mut BTreeMap, + create_missing_projects: bool, ) -> Result { let normalized_project_id = normalize_project_id_field(project_id)?; if let Some(project_id) = normalized_project_id { if let Some(name) = project_id.strip_prefix("name:") { - return resolve_project_name(client, name.trim(), project_name_cache).await; + return resolve_project_name( + client, + name.trim(), + project_name_cache, + create_missing_projects, + ) + .await; } return Ok(project_id); } let normalized_project_name = normalize_project_name_field(project_name)?; if let Some(project_name) = normalized_project_name { - return resolve_project_name(client, project_name.trim(), project_name_cache).await; + return resolve_project_name( + client, + project_name.trim(), + project_name_cache, + create_missing_projects, + ) + .await; } default_project_id.map(ToOwned::to_owned).ok_or_else(|| { @@ -2431,6 +2479,7 @@ async fn resolve_project_name( client: &crate::http::ApiClient, project_name: &str, project_name_cache: &mut BTreeMap, + create_missing_projects: bool, ) -> Result { let project_name = project_name.trim(); if project_name.is_empty() { @@ -2441,9 +2490,18 @@ async fn resolve_project_name( return Ok(cached.clone()); } - let project = get_project_by_name(client, project_name) - .await? - .ok_or_else(|| anyhow!("project '{project_name}' not found"))?; + let project = if let Some(project) = get_project_by_name(client, project_name).await? { + project + } else if create_missing_projects { + match create_project(client, project_name).await { + Ok(project) => project, + Err(_) => get_project_by_name(client, project_name) + .await? 
+ .ok_or_else(|| anyhow!("failed to create project '{project_name}'"))?, + } + } else { + return Err(anyhow!("project '{project_name}' not found")); + }; project_name_cache.insert(project_name.to_string(), project.id.clone()); Ok(project.id) @@ -2734,6 +2792,8 @@ mod tests { function_type: Some("tool".to_string()), if_exists: None, metadata: None, + tags: None, + function_schema: None, location: None, preview: None, })], @@ -2770,6 +2830,7 @@ mod tests { file_flag: vec![], if_exists: IfExistsMode::Error, terminate_on_failure: false, + create_missing_projects: true, runner: None, language: PushLanguage::Auto, requirements: None, @@ -2798,6 +2859,7 @@ mod tests { file_flag: vec![], if_exists: IfExistsMode::Error, terminate_on_failure: false, + create_missing_projects: true, runner: None, language: PushLanguage::Auto, requirements: None, @@ -2879,61 +2941,6 @@ mod tests { assert_eq!(calculate_upload_counts(3, None), (3, 0)); } - #[test] - fn prompt_event_defaults_function_type_when_missing() { - let mut event = Map::new(); - event.insert( - "function_data".to_string(), - serde_json::json!({ - "type": "prompt" - }), - ); - - default_prompt_function_type(&mut event); - - assert_eq!( - event.get("function_type"), - Some(&Value::String("prompt".to_string())) - ); - } - - #[test] - fn prompt_event_preserves_existing_function_type() { - let mut event = Map::new(); - event.insert( - "function_data".to_string(), - serde_json::json!({ - "type": "prompt" - }), - ); - event.insert( - "function_type".to_string(), - Value::String("scorer".to_string()), - ); - - default_prompt_function_type(&mut event); - - assert_eq!( - event.get("function_type"), - Some(&Value::String("scorer".to_string())) - ); - } - - #[test] - fn non_prompt_event_does_not_default_function_type() { - let mut event = Map::new(); - event.insert( - "function_data".to_string(), - serde_json::json!({ - "type": "code" - }), - ); - - default_prompt_function_type(&mut event); - - assert!(!event.contains_key("function_type")); - } - #[test] fn requirements_reference_escape_is_rejected() { let dir = tempfile::tempdir().expect("tempdir"); @@ -3009,6 +3016,8 @@ mod tests { function_type: Some("tool".to_string()), if_exists: None, metadata: None, + tags: None, + function_schema: None, location: Some(serde_json::json!({"type":"function","index":0})), preview: None, })], @@ -3051,6 +3060,8 @@ mod tests { function_type: Some("tool".to_string()), if_exists: None, metadata: None, + tags: None, + function_schema: None, location: Some(serde_json::json!({"type":"function","index":0})), preview: None, })], @@ -3094,6 +3105,8 @@ mod tests { function_type: Some("tool".to_string()), if_exists: None, metadata: None, + tags: None, + function_schema: None, location: Some(serde_json::json!({"type":"function","index":0})), preview: None, })], diff --git a/src/python_runner.rs b/src/python_runner.rs index bf9760c..3f6c92b 100644 --- a/src/python_runner.rs +++ b/src/python_runner.rs @@ -81,4 +81,16 @@ mod tests { let resolved = resolve_python_interpreter(Some("/tmp/python"), &["BT_UNUSED"]); assert_eq!(resolved, Some(PathBuf::from("/tmp/python"))); } + + #[test] + fn env_override_python_runner_is_used() { + unsafe { + std::env::set_var("BT_TEST_PYTHON_RUNNER", "/tmp/from-env-python"); + } + let resolved = resolve_python_interpreter(None, &["BT_TEST_PYTHON_RUNNER"]); + unsafe { + std::env::remove_var("BT_TEST_PYTHON_RUNNER"); + } + assert_eq!(resolved, Some(PathBuf::from("/tmp/from-env-python"))); + } } diff --git 
a/tests/functions-fixtures/push-help-flags/fixture.json b/tests/functions-fixtures/push-help-flags/fixture.json index 16d6106..4ffc039 100644 --- a/tests/functions-fixtures/push-help-flags/fixture.json +++ b/tests/functions-fixtures/push-help-flags/fixture.json @@ -5,6 +5,7 @@ "--file", "--if-exists", "--terminate-on-failure", + "--create-missing-projects", "--language", "--requirements", "--tsconfig", diff --git a/tests/functions.rs b/tests/functions.rs index 83e7534..d062abd 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -540,6 +540,7 @@ fn functions_push_help_includes_expected_flags() { assert!(stdout.contains("--file")); assert!(stdout.contains("--if-exists")); assert!(stdout.contains("--terminate-on-failure")); + assert!(stdout.contains("--create-missing-projects")); assert!(stdout.contains("--language")); assert!(stdout.contains("--requirements")); assert!(stdout.contains("--tsconfig")); From 5f7d51fdad55382947ac4f10271d197caadc6e60 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Mon, 16 Mar 2026 18:47:14 -0700 Subject: [PATCH 19/28] fix(functions-pull): preserve prompt serialization and identity metadata --- src/functions/pull.rs | 153 +++++++++++++++++++++++++++++++++--------- 1 file changed, 122 insertions(+), 31 deletions(-) diff --git a/src/functions/pull.rs b/src/functions/pull.rs index 543562a..0efe2cf 100644 --- a/src/functions/pull.rs +++ b/src/functions/pull.rs @@ -51,6 +51,8 @@ struct PullFunctionRow { #[derive(Debug, Clone)] struct NormalizedPrompt { + id: String, + version: Option, variable_seed: String, name: String, slug: String, @@ -60,6 +62,8 @@ struct NormalizedPrompt { model: Option, params: Option, tools: Option, + raw_tools_json: Option, + tool_functions: Option, } #[derive(Debug)] @@ -374,6 +378,7 @@ pub async fn run(base: BaseArgs, args: PullArgs) -> Result<()> { &repo, args.force, args.language, + &project_id, &project_name, &file_name, &rows, @@ -615,6 +620,7 @@ fn write_pull_file( repo: &Option, force: bool, language: FunctionsLanguage, + project_id: &str, project_name: &str, file_name: &str, rows: &[PullFunctionRow], @@ -655,18 +661,19 @@ fn write_pull_file( return; } - let rendered = match render_project_file(language, project_name, &display_target, rows) { - Ok(rendered) => rendered, - Err(err) => { - record_pull_file_failure( - summary, - target.display().to_string(), - HardFailureReason::ResponseInvalid, - err.to_string(), - ); - return; - } - }; + let rendered = + match render_project_file(language, project_id, project_name, &display_target, rows) { + Ok(rendered) => rendered, + Err(err) => { + record_pull_file_failure( + summary, + target.display().to_string(), + HardFailureReason::ResponseInvalid, + err.to_string(), + ); + return; + } + }; match write_text_atomic(&target, &rendered) { Ok(()) => { summary.files_written += 1; @@ -959,6 +966,7 @@ fn display_output_path(target: &Path) -> String { fn render_project_file( language: FunctionsLanguage, + project_id: &str, project_name: &str, file_name: &str, rows: &[PullFunctionRow], @@ -973,7 +981,7 @@ fn render_project_file( match language { FunctionsLanguage::Typescript => { - render_project_file_ts(project_name, file_name, &normalized) + render_project_file_ts(project_id, project_name, file_name, &normalized) } FunctionsLanguage::Python => render_project_file_py(project_name, file_name, &normalized), } @@ -1042,27 +1050,31 @@ fn normalize_prompt_row(row: &PullFunctionRow) -> Result { .filter(|value| !is_empty_render_value(value)) .cloned(); - let mut tools: Vec = prompt_data - 
.get("tool_functions") - .and_then(Value::as_array) - .cloned() - .unwrap_or_default(); - if let Some(raw_tools) = prompt_block.get("tools").and_then(Value::as_str) { - if !raw_tools.trim().is_empty() { - if let Ok(parsed) = serde_json::from_str::(raw_tools) { - if let Some(items) = parsed.as_array() { - tools.extend(items.iter().cloned()); - } - } - } - } - let tools = if tools.is_empty() { - None - } else { - Some(Value::Array(tools)) + let raw_tools_json = prompt_block + .get("tools") + .and_then(Value::as_str) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned); + let tools = match prompt_block.get("tools") { + Some(Value::String(_)) | None => None, + Some(other) if is_empty_render_value(other) => None, + Some(other) => Some(other.clone()), }; + let tool_functions = prompt_data + .get("tool_functions") + .filter(|value| !is_empty_render_value(value)) + .cloned(); + let version = row + ._xact_id + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned); Ok(NormalizedPrompt { + id: row.id.clone(), + version, variable_seed: row.slug.clone(), name: row.name.clone(), slug: row.slug.clone(), @@ -1072,10 +1084,13 @@ fn normalize_prompt_row(row: &PullFunctionRow) -> Result { model, params, tools, + raw_tools_json, + tool_functions, }) } fn render_project_file_ts( + project_id: &str, project_name: &str, file_name: &str, prompts: &[NormalizedPrompt], @@ -1098,6 +1113,7 @@ fn render_project_file_ts( out.push_str("import braintrust from \"braintrust\";\n\n"); out.push_str("const project = braintrust.projects.create({\n"); + out.push_str(&format!(" id: {},\n", serde_json::to_string(project_id)?)); out.push_str(&format!( " name: {},\n", serde_json::to_string(project_name)? @@ -1114,8 +1130,12 @@ fn render_project_file_ts( ); let mut body_lines = Vec::new(); + body_lines.push(format!(" id: {},", serde_json::to_string(&row.id)?)); body_lines.push(format!(" name: {},", serde_json::to_string(&row.name)?)); body_lines.push(format!(" slug: {},", serde_json::to_string(&row.slug)?)); + if let Some(version) = &row.version { + body_lines.push(format!(" version: {},", serde_json::to_string(version)?)); + } if let Some(description) = &row.description { body_lines.push(format!( @@ -1139,6 +1159,18 @@ fn render_project_file_ts( if let Some(tools) = &row.tools { body_lines.push(format!(" tools: {},", format_ts_value(tools, 2))); } + if let Some(raw_tools_json) = &row.raw_tools_json { + body_lines.push(format!( + " tools: JSON.parse({}),", + serde_json::to_string(raw_tools_json)? + )); + } + if let Some(tool_functions) = &row.tool_functions { + body_lines.push(format!( + " toolFunctions: {},", + format_ts_value(tool_functions, 2) + )); + } out.push_str(&format!( "export const {var_name} = project.prompts.create({{\n" @@ -1155,6 +1187,7 @@ fn render_project_file_py( file_name: &str, prompts: &[NormalizedPrompt], ) -> Result { + let needs_json_import = prompts.iter().any(|row| row.raw_tools_json.is_some()); let mut out = String::new(); out.push_str("# This file was automatically generated by bt functions pull. You can\n"); out.push_str("# generate it again by running:\n"); @@ -1170,6 +1203,9 @@ fn render_project_file_py( "# $ bt functions push --file {}\n\n", serde_json::to_string(file_name)? 
)); + if needs_json_import { + out.push_str("import json\n"); + } out.push_str("import braintrust\n\n"); out.push_str(&format!( "project = braintrust.projects.create(name={})\n\n", @@ -1213,6 +1249,18 @@ fn render_project_file_py( if let Some(tools) = &row.tools { out.push_str(&format!(" tools={},\n", format_py_value(tools, 4))); } + if let Some(raw_tools_json) = &row.raw_tools_json { + out.push_str(&format!( + " tools=json.loads({}),\n", + format_py_value(&Value::String(raw_tools_json.clone()), 4) + )); + } + if let Some(tool_functions) = &row.tool_functions { + out.push_str(&format!( + " tool_functions={},\n", + format_py_value(tool_functions, 4) + )); + } out.push_str(")\n\n"); } @@ -1738,6 +1786,7 @@ mod tests { let rendered = render_project_file( FunctionsLanguage::Python, + "p1", "woohoo", "braintrust/woohoo.py", &[row], @@ -1753,6 +1802,48 @@ mod tests { assert!(rendered.contains("model=\"gpt-4o-mini\"")); } + #[test] + fn render_project_file_typescript_includes_prompt_identity() { + let row = PullFunctionRow { + id: "f1".to_string(), + name: "Basic math".to_string(), + slug: "basic-math".to_string(), + project_id: "p1".to_string(), + project_name: Some("woohoo".to_string()), + description: None, + prompt_data: Some(serde_json::json!({ + "prompt": { + "type": "chat", + "messages": [ + { "content": "Hello", "role": "system" } + ] + }, + "options": { + "model": "gpt-4o-mini" + } + })), + function_data: Some(serde_json::json!({ "type": "prompt" })), + created: None, + _xact_id: Some("123".to_string()), + }; + + let rendered = render_project_file( + FunctionsLanguage::Typescript, + "p1", + "woohoo", + "braintrust/woohoo.ts", + &[row], + ) + .expect("rendered"); + + assert!(rendered.contains("const project = braintrust.projects.create({")); + assert!(rendered.contains(" id: \"p1\",")); + assert!(rendered.contains(" name: \"woohoo\",")); + assert!(rendered.contains("export const basicMath = project.prompts.create({")); + assert!(rendered.contains(" id: \"f1\",")); + assert!(rendered.contains(" version: \"123\",")); + } + #[test] fn format_ts_value_unquotes_safe_keys_only() { let value = serde_json::json!({ From 0cd82adb931327bade36ff1ba28f7a47838ee175 Mon Sep 17 00:00:00 2001 From: Parker Henderson Date: Tue, 17 Mar 2026 16:28:44 -0700 Subject: [PATCH 20/28] test(ci): add functions test to GitHub Actions workflow and fix more tests --- .github/workflows/tests.yml | 2 ++ scripts/functions-runner.ts | 18 ++++++++++++++---- tests/functions.rs | 24 ++++++++++++++++++++---- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 742cfba..0dd22e2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -71,6 +71,8 @@ jobs: corepack prepare pnpm@10.28.2 --activate - name: Run eval fixtures run: cargo test --test eval_fixtures + - name: Run functions fixtures + run: cargo test --test functions eval-tests-python: name: eval-tests-python (py ${{ matrix.python-version }}) diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index ab8a7de..20fc308 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -233,6 +233,11 @@ function schemaToJsonSchema(schema: unknown): JsonObject | undefined { return undefined; } + const normalizedSchema = toJsonValue(schema as JsonValue); + if (isJsonObject(normalizedSchema)) { + return normalizedSchema; + } + const converter = loadZodToJsonSchemaFn(); if (!converter) { return undefined; @@ -411,12 +416,17 @@ function 
collectCodeEntries(items: CodeRegistryItem[]): CodeEntry[] { const tags = Array.isArray(item.tags) ? item.tags.filter((tag): tag is string => typeof tag === "string") : []; + if (item.parameters === undefined || item.parameters === null) { + throw new Error(`Function ${item.name} has no supplied parameters`); + } const parametersSchema = schemaToJsonSchema(item.parameters); - const returnsSchema = schemaToJsonSchema(item.returns); - const functionSchema: JsonObject = {}; - if (parametersSchema) { - functionSchema.parameters = parametersSchema; + if (!parametersSchema) { + throw new Error(`Function ${item.name} has invalid parameters schema`); } + const returnsSchema = schemaToJsonSchema(item.returns); + const functionSchema: JsonObject = { + parameters: parametersSchema, + }; if (returnsSchema) { functionSchema.returns = returnsSchema; } diff --git a/tests/functions.rs b/tests/functions.rs index d062abd..5f18f8e 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -724,11 +724,17 @@ spec.loader.exec_module(module) class TypeEnum: value = "tool" +class Params: + @staticmethod + def model_json_schema(): + return {"type": "object", "properties": {}} + class Item: def __init__(self): self.name = "my-tool" self.slug = "my-tool" self.type_ = TypeEnum() + self.parameters = Params self.preview = "def handler(x):\\n return x" entries = module.collect_code_entries([Item()]) @@ -783,6 +789,7 @@ globalThis._evals.functions.push({ name: "js-tool", slug: "js-tool", type: "tool", + parameters: { type: "object", properties: {} }, preview: "export function handler() { return 1; }" }); "#, @@ -928,6 +935,7 @@ globalThis._evals.functions.push({ name: "js-tool-b", slug: "js-tool-b", type: "tool", + parameters: { type: "object", properties: {} }, preview: "export function b() { return 2; }" }); export const b = 2; @@ -944,6 +952,7 @@ globalThis._evals.functions.push({ name: "js-tool-a", slug: "js-tool-a", type: "tool", + parameters: { type: "object", properties: {} }, preview: "export function a() { return 1; }" }); "#, @@ -1091,11 +1100,17 @@ fn functions_python_runner_emits_valid_manifest_with_bundle() { class TypeEnum: value = "tool" +class Params: + @staticmethod + def model_json_schema(): + return {"type": "object", "properties": {}} + class Item: def __init__(self): self.name = "py-tool" self.slug = "py-tool" self.type_ = TypeEnum() + self.parameters = Params self.preview = "def handler(x):\n return x" functions.append(Item()) @@ -1108,11 +1123,15 @@ functions.append(Item()) python_path_entries.extend(std::env::split_paths(&existing)); } let python_path = std::env::join_paths(python_path_entries).expect("join PYTHONPATH"); + let expected_source = sample_path + .canonicalize() + .expect("canonicalize sample file"); let output = Command::new(&python) + .current_dir(tmp.path()) .env("PYTHONPATH", python_path) .arg(&runner_script) - .arg(&sample_path) + .arg(&expected_source) .output() .expect("run python functions runner"); if !output.status.success() { @@ -1141,9 +1160,6 @@ functions.append(Item()) .and_then(Value::as_str) .expect("source_file"), ); - let expected_source = sample_path - .canonicalize() - .expect("canonicalize sample file"); assert_eq!( reported_source .canonicalize() From 1e8e17c214fe53d508b374c5f38c4f72d0b9185b Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Fri, 13 Mar 2026 20:12:25 +0000 Subject: [PATCH 21/28] add support for pushing sandboxes --- scripts/functions-runner.py | 121 ++++++++++- scripts/functions-runner.ts | 109 ++++++++++ src/functions/mod.rs | 3 + 
src/functions/push.rs | 51 +++++ .../list-sandbox-type-parses/fixture.json | 5 + tests/functions.rs | 196 ++++++++++++++++++ 6 files changed, 478 insertions(+), 7 deletions(-) create mode 100644 tests/functions-fixtures/list-sandbox-type-parses/fixture.json diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index 1f16f06..aa7bd05 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -3,6 +3,7 @@ import inspect import json import os +import re import sys from contextlib import nullcontext from typing import Any @@ -42,7 +43,7 @@ def to_json_value(value: Any) -> Any: return str(value) -def load_framework_globals() -> tuple[Any, Any, Any]: +def load_framework_globals() -> tuple[Any, Any, Any, Any]: # Prefer current SDK layout first: # - braintrust.framework2 exposes module-level `global_` # - braintrust.framework exposes `_set_lazy_load` @@ -50,13 +51,23 @@ def load_framework_globals() -> tuple[Any, Any, Any]: from braintrust.framework import _set_lazy_load as lazy from braintrust.framework2 import global_ as global_state - return global_state.functions, global_state.prompts, lazy + try: + from braintrust.framework import _evals + except (ImportError, ModuleNotFoundError): + _evals = None + + return global_state.functions, global_state.prompts, lazy, _evals except (ImportError, ModuleNotFoundError): # Backward compatibility with older SDK layout. from braintrust.framework2.global_ import functions, prompts from braintrust.framework2.lazy_load import _set_lazy_load as lazy - return functions, prompts, lazy + try: + from braintrust.framework import _evals + except (ImportError, ModuleNotFoundError): + _evals = None + + return functions, prompts, lazy, _evals def normalize_project_selector(project: Any) -> tuple[str | None, str | None]: @@ -277,16 +288,105 @@ async def collect_function_event_entries(prompts_registry: Any) -> list[dict[str return entries +def slugify(text: str) -> str: + return re.sub(r"^-|-$", "", re.sub(r"[^a-z0-9]+", "-", text.lower())) + + +def collect_evaluator_entries(evals_registry: Any, source_file: str) -> list[dict[str, Any]]: + if evals_registry is None: + return [] + + evaluators = getattr(evals_registry, "evaluators", None) + if not evaluators or not isinstance(evaluators, dict): + return [] + + entries: list[dict[str, Any]] = [] + stem_base, _ = os.path.splitext(os.path.basename(source_file)) + stem = re.sub(r"\.eval$", "", stem_base) + + for eval_name, instance in evaluators.items(): + if instance is None: + continue + evaluator = getattr(instance, "evaluator", None) + if evaluator is None: + continue + + project_name = getattr(evaluator, "project_name", None) + project_id, proj_name = normalize_project_selector( + {"project_name": project_name} if isinstance(project_name, str) else None + ) + + scores = getattr(evaluator, "scores", []) or [] + score_descriptors = [ + {"name": getattr(score, "__name__", f"scorer_{i}")} + for i, score in enumerate(scores) + ] + + evaluator_definition: dict[str, Any] = {"scores": score_descriptors} + + raw_params = getattr(evaluator, "parameters", None) + if raw_params is not None: + marker = getattr(raw_params, "__braintrust_parameters_marker", None) + if marker is True: + evaluator_definition["parameters"] = { + "type": "braintrust.parameters", + "schema": getattr(raw_params, "schema", None), + "source": { + "parametersId": getattr(raw_params, "id", None), + "slug": getattr(raw_params, "slug", None), + "name": getattr(raw_params, "name", None), + "projectId": getattr(raw_params, 
"projectId", None), + "version": getattr(raw_params, "version", None), + }, + } + else: + serialized = to_json_value(raw_params) + if serialized is not None: + evaluator_definition["parameters"] = serialized + + base_entry: dict[str, Any] = {"kind": "code"} + if project_id: + base_entry["project_id"] = project_id + if proj_name: + base_entry["project_name"] = proj_name + + # Sandbox entry only — task and scorer entries are pushed separately + # when the eval is actually run, matching the Python SDK behavior. + sandbox_entry = { + **base_entry, + "name": f"Eval {eval_name} sandbox", + "slug": slugify(f"{stem}-{eval_name}-sandbox"), + "function_type": "sandbox", + "location": { + "type": "sandbox", + "sandbox_spec": {"provider": "lambda"}, + "entrypoints": [source_file], + "eval_name": eval_name, + "evaluator_definition": evaluator_definition, + }, + "metadata": {"_bt_sandbox_group_name": stem}, + } + entries.append(sandbox_entry) + + return entries + + async def process_file(file_path: str) -> dict[str, Any]: abs_path = os.path.abspath(file_path) cwd = os.getcwd() if cwd not in sys.path: sys.path.insert(0, cwd) - purge_local_modules(cwd, preserve_modules={__name__, "python_runner_common"}) - functions_registry, prompts_registry, lazy_loader = load_framework_globals() + functions_registry, prompts_registry, lazy_loader, evals_registry = load_framework_globals() clear_registry(functions_registry) clear_registry(prompts_registry) + if ( + evals_registry is not None + and hasattr(evals_registry, "evaluators") + and isinstance(evals_registry.evaluators, dict) + ): + evals_registry.evaluators.clear() + purge_local_modules(cwd, preserve_modules={__name__, "python_runner_common"}) module_name = import_module_name_from_cwd(cwd, abs_path) if module_name is None: @@ -298,12 +398,13 @@ async def process_file(file_path: str) -> dict[str, Any]: import_file(module_name, abs_path, extra_paths) code_entries = collect_code_entries(functions_registry) event_entries = await collect_function_event_entries(prompts_registry) - entries = [*code_entries, *event_entries] + evaluator_entries = collect_evaluator_entries(evals_registry, abs_path) + entries = [*code_entries, *event_entries, *evaluator_entries] file_manifest: dict[str, Any] = { "source_file": abs_path, "entries": entries, } - if code_entries: + if code_entries or evaluator_entries: runner_root = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(cwd) path_rest: list[str] = [] @@ -357,6 +458,12 @@ async def process_file(file_path: str) -> dict[str, Any]: clear_registry(functions_registry) clear_registry(prompts_registry) + if ( + evals_registry is not None + and hasattr(evals_registry, "evaluators") + and isinstance(evals_registry.evaluators, dict) + ): + evals_registry.evaluators.clear() return file_manifest diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index 20fc308..1a5b9e8 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -88,6 +88,20 @@ type Manifest = { files: ManifestFile[]; }; +function slugify(input: string): string { + return input + .toLowerCase() + .replace(/[^a-z0-9]+/g, "-") + .replace(/^-|-$/g, ""); +} + +function extractScoreName(score: unknown, idx: number): string { + if (typeof score === "function" && typeof score.name === "string") { + return score.name || `scorer_${idx}`; + } + return `scorer_${idx}`; +} + type EvalRegistry = NonNullable; type ZodToJsonSchemaFn = (schema: unknown) => unknown; @@ -473,6 +487,97 @@ function collectCodeEntries(items: 
   return entries;
 }
 
+function collectEvaluatorEntries(
+  evaluators: Record<string, unknown>,
+  sourceFilePath: string,
+): CodeEntry[] {
+  const entries: CodeEntry[] = [];
+  const ext = path.extname(sourceFilePath);
+  const stem = path.basename(sourceFilePath, ext).replace(/\.eval$/, "");
+
+  for (const [evalName, entry] of Object.entries(evaluators)) {
+    if (!entry || typeof entry !== "object") {
+      continue;
+    }
+
+    const evaluator = (entry as Record<string, unknown>).evaluator;
+    if (!evaluator || typeof evaluator !== "object") {
+      continue;
+    }
+
+    const evalObj = evaluator as Record<string, unknown>;
+    const projectName =
+      typeof evalObj.project_name === "string" ? evalObj.project_name : undefined;
+    const scores = Array.isArray(evalObj.scores) ? evalObj.scores : [];
+
+    const selector = asProjectSelector(
+      typeof projectName === "string" ? { name: projectName } : undefined,
+    );
+    const projectId =
+      typeof selector.project_id === "string" ? selector.project_id : undefined;
+    const selectorProjectName =
+      typeof selector.project_name === "string"
+        ? selector.project_name
+        : undefined;
+
+    const scoreDescriptors = scores.map((s: unknown, i: number) => ({
+      name: extractScoreName(s, i),
+    }));
+
+    const evaluatorDefinition: JsonObject = {
+      scores: scoreDescriptors as JsonValue,
+    };
+
+    const rawParams = evalObj.parameters;
+    if (rawParams !== undefined && rawParams !== null) {
+      const marker =
+        rawParams !== null &&
+        typeof rawParams === "object" &&
+        (rawParams as Record<string, unknown>).__braintrust_parameters_marker === true;
+      if (marker) {
+        const paramObj = rawParams as Record<string, unknown>;
+        evaluatorDefinition.parameters = toJsonValue({
+          type: "braintrust.parameters",
+          schema: paramObj.schema,
+          source: {
+            parametersId: paramObj.id,
+            slug: paramObj.slug,
+            name: paramObj.name,
+            projectId: paramObj.projectId,
+            version: paramObj.version,
+          },
+        } as JsonValue);
+      } else {
+        const serialized = toJsonValue(rawParams as JsonValue);
+        if (serialized !== undefined) {
+          evaluatorDefinition.parameters = serialized;
+        }
+      }
+    }
+
+    // Sandbox entry only — task and scorer entries are pushed separately
+    // when the eval is actually run, matching the Python SDK behavior.
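+    // Illustrative (hypothetical names): a file `math.eval.ts` defining an
+    // eval named "arith" yields name "Eval arith sandbox" and slug
+    // "math-arith-sandbox" via the stem/slugify logic above.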
+    entries.push({
+      kind: "code",
+      project_id: projectId,
+      project_name: selectorProjectName,
+      name: `Eval ${evalName} sandbox`,
+      slug: slugify(`${stem}-${evalName}-sandbox`),
+      function_type: "sandbox",
+      location: {
+        type: "sandbox",
+        sandbox_spec: { provider: "lambda" },
+        entrypoints: [sourceFilePath],
+        eval_name: evalName,
+        evaluator_definition: evaluatorDefinition as JsonValue,
+      } as JsonValue,
+      metadata: { _bt_sandbox_group_name: stem },
+    });
+  }
+
+  return entries;
+}
+
 async function processFile(filePath: string): Promise<ManifestFile> {
   const absolutePath = path.resolve(process.cwd(), filePath);
   const fallbackRegistry = freshRegistry();
@@ -492,6 +597,10 @@ async function processFile(filePath: string): Promise<ManifestFile> {
       registry.parameters as EventRegistryItem[],
       false,
     )),
+    ...collectEvaluatorEntries(
+      registry.evaluators as Record<string, unknown>,
+      absolutePath,
+    ),
   ];
 
   return {
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index 9974284..eadc100 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -35,6 +35,7 @@ pub enum FunctionTypeFilter {
     Classifier,
     Tag,
     Parameters,
+    Sandbox,
 }
 
 impl FunctionTypeFilter {
@@ -50,6 +51,7 @@ impl FunctionTypeFilter {
             Self::Classifier => "classifier",
             Self::Tag => "tag",
             Self::Parameters => "parameters",
+            Self::Sandbox => "sandbox",
         }
@@ -73,6 +75,7 @@ impl FunctionTypeFilter {
             Self::Classifier => "classifiers",
             Self::Tag => "tags",
             Self::Parameters => "parameters",
+            Self::Sandbox => "sandboxes",
         }
     }
 }
diff --git a/src/functions/push.rs b/src/functions/push.rs
index 3803c25..a83e58f 100644
--- a/src/functions/push.rs
+++ b/src/functions/push.rs
@@ -3263,6 +3263,57 @@ mod tests {
         );
     }
 
+    #[test]
+    fn code_function_data_passes_through_sandbox_location() {
+        let runtime = RuntimeContext {
+            runtime: "node".to_string(),
+            version: "20.0.0".to_string(),
+        };
+        let sandbox_location = serde_json::json!({
+            "type": "sandbox",
+            "sandbox_spec": { "provider": "lambda" },
+            "entrypoints": ["/tmp/eval.ts"],
+            "eval_name": "my-eval",
+            "evaluator_definition": {
+                "scores": [{ "name": "accuracy" }]
+            }
+        });
+        let value = build_code_function_data(
+            &runtime,
+            sandbox_location.clone(),
+            "bundle-sandbox-1",
+            None,
+        );
+
+        assert_eq!(value["type"], "code");
+        assert_eq!(value["data"]["type"], "bundle");
+        assert_eq!(value["data"]["bundle_id"], "bundle-sandbox-1");
+        assert_eq!(value["data"]["location"], sandbox_location);
+        assert!(value["data"].get("preview").is_none());
+    }
+
+    #[test]
+    fn code_function_data_passes_through_experiment_location() {
+        let runtime = RuntimeContext {
+            runtime: "node".to_string(),
+            version: "20.0.0".to_string(),
+        };
+        let experiment_location = serde_json::json!({
+            "type": "experiment",
+            "eval_name": "my-eval",
+            "position": { "type": "task" }
+        });
+        let value = build_code_function_data(
+            &runtime,
+            experiment_location.clone(),
+            "bundle-task-1",
+            None,
+        );
+
+        assert_eq!(value["type"], "code");
+        assert_eq!(value["data"]["location"], experiment_location);
+    }
+
     fn test_base_args() -> BaseArgs {
         BaseArgs {
             json: false,
diff --git a/tests/functions-fixtures/list-sandbox-type-parses/fixture.json b/tests/functions-fixtures/list-sandbox-type-parses/fixture.json
new file mode 100644
index 0000000..642a921
--- /dev/null
+++ b/tests/functions-fixtures/list-sandbox-type-parses/fixture.json
@@ -0,0 +1,5 @@
+{
+  "command": ["functions", "list", "--type", "sandbox"],
+  "expect_success": false,
+  "stderr_not_contains": ["invalid value 'sandbox'"]
+}
diff --git a/tests/functions.rs b/tests/functions.rs
index 5f18f8e..c9e6a17
100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -1909,6 +1909,202 @@ exit 24 ); } +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn functions_push_sandbox_entries_reach_api() { + if !command_exists("node") { + eprintln!("Skipping functions_push_sandbox_entries_reach_api (node not installed)."); + return; + } + + let state = Arc::new(MockServerState::default()); + state + .projects + .lock() + .expect("projects lock") + .push(MockProject { + id: "proj_mock".to_string(), + name: "mock-project".to_string(), + org_id: "org_mock".to_string(), + }); + let server = MockServer::start(state.clone()).await; + + let tmp = tempdir().expect("tempdir"); + let source = tmp.path().join("my-eval.js"); + std::fs::write( + &source, + "globalThis._evals ??= { functions: [], prompts: [], parameters: [], evaluators: {}, reporters: {} };\n", + ) + .expect("write source file"); + + let runner = tmp.path().join("mock-runner.sh"); + std::fs::write( + &runner, + r#"#!/bin/sh +set -eu +_runner_script="$1" +shift +_runner_name="$(basename "$_runner_script")" + +if [ "$_runner_name" = "functions-runner.ts" ]; then +node - "$@" <<'NODE' +const path = require("node:path"); +const files = process.argv.slice(2); +const manifest = { + runtime_context: { runtime: "node", version: process.versions.node || "unknown" }, + files: files.map((file) => { + const abs = path.resolve(file); + return { + source_file: abs, + entries: [ + { + kind: "code", + project_name: "mock-project", + name: "Eval my-eval sandbox", + slug: "my-eval-my-eval-sandbox", + function_type: "sandbox", + location: { + type: "sandbox", + sandbox_spec: { provider: "lambda" }, + entrypoints: [abs], + eval_name: "my-eval", + evaluator_definition: { scores: [{ name: "accuracy" }] } + }, + metadata: { _bt_sandbox_group_name: "my-eval" } + } + ] + }; + }) +}; +process.stdout.write(JSON.stringify(manifest)); +NODE +exit 0 +fi + +if [ "$_runner_name" = "functions-bundler.ts" ]; then + _source_file="$1" + _output_file="$2" + cp "$_source_file" "$_output_file" + exit 0 +fi + +echo "unexpected runner script: $_runner_name" >&2 +exit 24 +"#, + ) + .expect("write mock runner"); + use std::os::unix::fs::PermissionsExt; + let mut perms = std::fs::metadata(&runner) + .expect("runner metadata") + .permissions(); + perms.set_mode(0o755); + std::fs::set_permissions(&runner, perms).expect("runner permissions"); + + let output = Command::new(bt_binary_path()) + .current_dir(tmp.path()) + .args([ + "functions", + "--json", + "push", + "--file", + source + .to_str() + .expect("source path should be valid UTF-8 for test"), + "--language", + "javascript", + "--runner", + runner + .to_str() + .expect("runner path should be valid UTF-8 for test"), + "--if-exists", + "replace", + ]) + .env("BRAINTRUST_API_KEY", "test-key") + .env("BRAINTRUST_ORG_NAME", "test-org") + .env("BRAINTRUST_API_URL", &server.base_url) + .env("BRAINTRUST_APP_URL", &server.base_url) + .env("BRAINTRUST_NO_COLOR", "1") + .env_remove("BRAINTRUST_PROFILE") + .output() + .expect("run bt functions push"); + + server.stop().await; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + panic!("mock push failed:\n{stderr}"); + } + + let summary: Value = serde_json::from_slice(&output.stdout).expect("parse push summary"); + assert_eq!(summary["status"].as_str(), Some("success")); + assert_eq!(summary["uploaded_files"].as_u64(), Some(1)); + assert_eq!(summary["failed_files"].as_u64(), Some(0)); + + let inserted = state + .inserted_functions + 
.lock() + .expect("inserted functions lock") + .clone(); + assert_eq!( + inserted.len(), + 1, + "expected 1 inserted function (sandbox only)" + ); + + let sandbox_obj = inserted[0].as_object().expect("sandbox should be an object"); + assert_eq!( + sandbox_obj.get("slug").and_then(Value::as_str), + Some("my-eval-my-eval-sandbox") + ); + assert_eq!( + sandbox_obj.get("function_type").and_then(Value::as_str), + Some("sandbox") + ); + + // Verify function_data.data.location is sandbox type + let function_data = sandbox_obj + .get("function_data") + .and_then(Value::as_object) + .expect("function_data object"); + assert_eq!( + function_data.get("type").and_then(Value::as_str), + Some("code") + ); + let data = function_data + .get("data") + .and_then(Value::as_object) + .expect("function_data.data object"); + let location = data + .get("location") + .and_then(Value::as_object) + .expect("location object"); + assert_eq!( + location.get("type").and_then(Value::as_str), + Some("sandbox") + ); + let sandbox_spec = location + .get("sandbox_spec") + .and_then(Value::as_object) + .expect("sandbox_spec object"); + assert_eq!( + sandbox_spec.get("provider").and_then(Value::as_str), + Some("lambda") + ); + + // Verify metadata + let metadata = sandbox_obj + .get("metadata") + .and_then(Value::as_object) + .expect("metadata object"); + assert_eq!( + metadata + .get("_bt_sandbox_group_name") + .and_then(Value::as_str), + Some("my-eval") + ); + +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn functions_pull_works_against_mock_api() { let state = Arc::new(MockServerState::default()); From 9a70a3d9357f8093a24c9528c48b220d317050af Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Fri, 13 Mar 2026 21:18:35 +0000 Subject: [PATCH 22/28] fixes --- scripts/functions-bundler.ts | 5 ++++- scripts/functions-runner.py | 2 +- scripts/functions-runner.ts | 2 +- src/functions/push.rs | 10 +++++++++- tests/functions.rs | 16 ++++++++++++++++ 5 files changed, 31 insertions(+), 4 deletions(-) diff --git a/scripts/functions-bundler.ts b/scripts/functions-bundler.ts index 02b9d11..de0af0b 100644 --- a/scripts/functions-bundler.ts +++ b/scripts/functions-bundler.ts @@ -275,7 +275,10 @@ async function main(): Promise { const externalPackages = parseExternalPackages( process.env.BT_FUNCTIONS_PUSH_EXTERNAL_PACKAGES, ); - const external = buildExternalPackagePatterns(externalPackages); + const selfContained = process.env.BT_FUNCTIONS_PUSH_SELF_CONTAINED === "1"; + const external = selfContained + ? 
["fsevents", "chokidar"] + : buildExternalPackagePatterns(externalPackages); const tsconfig = loadTsconfigPath(); const outputDir = path.dirname(outputFile); diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index aa7bd05..a140a19 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -360,7 +360,7 @@ def collect_evaluator_entries(evals_registry: Any, source_file: str) -> list[dic "location": { "type": "sandbox", "sandbox_spec": {"provider": "lambda"}, - "entrypoints": [source_file], + "entrypoints": [os.path.relpath(source_file)], "eval_name": eval_name, "evaluator_definition": evaluator_definition, }, diff --git a/scripts/functions-runner.ts b/scripts/functions-runner.ts index 1a5b9e8..cd49251 100644 --- a/scripts/functions-runner.ts +++ b/scripts/functions-runner.ts @@ -567,7 +567,7 @@ function collectEvaluatorEntries( location: { type: "sandbox", sandbox_spec: { provider: "lambda" }, - entrypoints: [sourceFilePath], + entrypoints: [path.relative(process.cwd(), sourceFilePath)], eval_name: evalName, evaluator_definition: evaluatorDefinition as JsonValue, } as JsonValue, diff --git a/src/functions/push.rs b/src/functions/push.rs index a83e58f..5d4de74 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -716,10 +716,14 @@ async fn push_file( let mut function_events: Vec = Vec::new(); + let has_sandbox_entries = code_entries + .iter() + .any(|(code, _)| code.function_type.as_deref() == Some("sandbox")); + if !code_entries.is_empty() { let (upload_bytes, content_encoding) = match selected_language { SourceLanguage::JsLike => { - let bundle_bytes = build_js_bundle(source_path, args)?; + let bundle_bytes = build_js_bundle(source_path, args, has_sandbox_entries)?; let gzipped = gzip_bytes(&bundle_bytes).map_err(|err| FileFailure { reason: HardFailureReason::BundleUploadFailed, message: format!("failed to gzip {}: {err}", source_path.display()), @@ -922,6 +926,7 @@ async fn push_file( fn build_js_bundle( source_path: &Path, args: &PushArgs, + self_contained: bool, ) -> std::result::Result, FileFailure> { let build_dir = TempBuildDir::create("bt-functions-js-bundle").map_err(|err| FileFailure { reason: HardFailureReason::BundleUploadFailed, @@ -954,6 +959,9 @@ fn build_js_bundle( args.external_packages.join(","), ); } + if self_contained { + command.env("BT_FUNCTIONS_PUSH_SELF_CONTAINED", "1"); + } let output = command.output().map_err(|err| FileFailure { reason: HardFailureReason::RunnerSpawnFailed, diff --git a/tests/functions.rs b/tests/functions.rs index c9e6a17..fa3de7c 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -1091,6 +1091,22 @@ fn functions_python_runner_emits_valid_manifest_with_bundle() { "from contextlib import nullcontext\n\ndef _set_lazy_load(_enabled):\n return nullcontext()\n", ) .expect("write lazy_load.py"); + std::fs::write( + stub_root.join("braintrust").join("framework.py"), + concat!( + "from contextlib import nullcontext\n", + "\n", + "def _set_lazy_load(_enabled):\n", + " return nullcontext()\n", + "\n", + "class _EvalFile:\n", + " def __init__(self):\n", + " self.evaluators = {}\n", + "\n", + "_evals = _EvalFile()\n", + ), + ) + .expect("write framework.py"); let sample_path = tmp.path().join("sample_tool.py"); std::fs::write( From 0aafaf19c8e978fdd4578fcce5615f484f0806f8 Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Fri, 13 Mar 2026 22:20:12 +0000 Subject: [PATCH 23/28] fixes --- .gitignore | 1 + scripts/eval-runner.py | 15 +++++++++++++++ scripts/functions-runner.py | 11 ++++++++--- 
scripts/python_runner_common.py | 12 ++++++++++++ src/eval.rs | 26 ++++++++------------------ src/functions/push.rs | 8 ++++++++ 6 files changed, 52 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index da0bacf..a0fe928 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ tests/evals/js/eval-bun/test-data.txt __pycache__ bt-sync +*.env \ No newline at end of file diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index 8742375..e9d3ce0 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -435,6 +435,21 @@ def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str cwd = os.getcwd() if cwd not in sys.path: sys.path.insert(0, cwd) + + # Add the project root inferred from input files to sys.path so that + # sibling-package imports work when files live outside CWD (e.g. + # sandbox bundles extracted to a temp directory). Walk up from each + # file's directory looking for a register.py (bundle marker) or the + # filesystem root, whichever comes first. + for f in files: + d = os.path.dirname(os.path.abspath(f)) + while d and d != os.path.dirname(d): + if os.path.isfile(os.path.join(d, "register.py")): + if d not in sys.path: + sys.path.insert(0, d) + break + d = os.path.dirname(d) + unique_files: set[str] = set() for file_path in files: for candidate in collect_files(file_path): diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index a140a19..fa93c75 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -29,9 +29,9 @@ def to_json_value(value: Any) -> Any: return [to_json_value(item) for item in value] if isinstance(value, dict): return {str(key): to_json_value(val) for key, val in value.items()} - if hasattr(value, "model_dump"): + if hasattr(value, "model_dump") and not isinstance(value, type): return to_json_value(value.model_dump()) - if hasattr(value, "dict"): + if hasattr(value, "dict") and not isinstance(value, type): return to_json_value(value.dict()) if hasattr(value, "__dict__"): result: dict[str, Any] = {} @@ -451,8 +451,13 @@ async def process_file(file_path: str) -> dict[str, Any]: continue seen_sources.add(init_source) bundled_sources.append(init_source) + # Compute entry_module as a CWD-relative dotted path so that the + # archive root inferred by push.rs walks back to CWD, matching + # the Python SDK behavior and allowing sibling-package imports. + rel_path = os.path.relpath(abs_path, cwd) + archive_module = re.sub(r"\.py$", "", rel_path).replace("-", "_").replace(os.sep, ".") file_manifest["python_bundle"] = { - "entry_module": module_name, + "entry_module": archive_module, "sources": bundled_sources, } diff --git a/scripts/python_runner_common.py b/scripts/python_runner_common.py index 4a738f9..1d83141 100644 --- a/scripts/python_runner_common.py +++ b/scripts/python_runner_common.py @@ -61,6 +61,12 @@ def purge_local_modules(cwd: str, preserve_modules: set[str] | None = None) -> N candidate_abs = os.path.abspath(candidate) if not os.path.isfile(candidate_abs): continue + # Skip installed packages inside virtualenvs under cwd (e.g. .venv/lib/.../site-packages). + if os.sep + "site-packages" + os.sep in candidate_abs: + continue + # Skip bt runner scripts materialised under .bt/. 
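+        # (Illustrative path: a runner module materialised at
+        # "<cwd>/.bt/.../eval-runner.py" lives under cwd but must stay
+        # loaded, so it is never purged from sys.modules.)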
+ if os.sep + ".bt" + os.sep in candidate_abs: + continue try: common = os.path.commonpath([candidate_abs, cwd_abs]) except ValueError: @@ -84,6 +90,12 @@ def collect_python_sources(cwd: str, input_file: str) -> list[str]: continue if not candidate_abs.endswith(".py"): continue + # Skip installed packages inside virtualenvs under cwd (e.g. .venv/lib/.../site-packages). + if os.sep + "site-packages" + os.sep in candidate_abs: + continue + # Skip bt runner scripts materialised under .bt/. + if os.sep + ".bt" + os.sep in candidate_abs: + continue try: common = os.path.commonpath([candidate_abs, cwd]) except ValueError: diff --git a/src/eval.rs b/src/eval.rs index d108c61..6bc5d40 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1207,12 +1207,6 @@ fn serialize_sse_event(event: &str, data: &str) -> String { format!("event: {event}\ndata: {data}\n\n") } -fn is_eval_progress_payload(progress: &SseProgressEventData) -> bool { - serde_json::from_str::(&progress.data) - .map(|payload| payload.kind_type == "eval_progress") - .unwrap_or(false) -} - fn encode_eval_event_for_http(event: &EvalEvent) -> Option { match event { EvalEvent::Processing(payload) => serde_json::to_string(payload) @@ -1224,15 +1218,9 @@ fn encode_eval_event_for_http(event: &EvalEvent) -> Option { EvalEvent::Summary(summary) => serde_json::to_string(summary) .ok() .map(|data| serialize_sse_event("summary", &data)), - EvalEvent::Progress(progress) => { - if is_eval_progress_payload(progress) { - None - } else { - serde_json::to_string(progress) - .ok() - .map(|data| serialize_sse_event("progress", &data)) - } - } + EvalEvent::Progress(progress) => serde_json::to_string(progress) + .ok() + .map(|data| serialize_sse_event("progress", &data)), EvalEvent::Dependencies { .. } => None, EvalEvent::Done => Some(serialize_sse_event("done", "")), EvalEvent::Error { @@ -2188,7 +2176,7 @@ fn build_python_command( .or_else(|| std::env::var("BT_EVAL_PYTHON_RUNNER").ok()) .or_else(|| std::env::var("BT_EVAL_PYTHON").ok()); - let command = if let Some(explicit) = runner_override { + let mut command = if let Some(explicit) = runner_override { let mut command = Command::new(explicit); command.arg(runner).args(files); command @@ -4022,7 +4010,7 @@ mod tests { } #[test] - fn encode_eval_event_for_http_filters_internal_eval_progress() { + fn encode_eval_event_for_http_forwards_eval_progress() { let event = EvalEvent::Progress(SseProgressEventData { id: "id-1".to_string(), object_type: "task".to_string(), @@ -4034,7 +4022,9 @@ mod tests { data: r#"{"type":"eval_progress","kind":"start","total":1}"#.to_string(), }); - assert!(encode_eval_event_for_http(&event).is_none()); + let encoded = encode_eval_event_for_http(&event).expect("eval_progress should be forwarded"); + assert!(encoded.contains("event: progress")); + assert!(encoded.contains("eval_progress")); } #[test] diff --git a/src/functions/push.rs b/src/functions/push.rs index 5d4de74..0ca6109 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -1125,6 +1125,14 @@ fn collect_classified_files(inputs: &[PathBuf]) -> Result { let mut explicit_js_like = 0usize; let mut explicit_python = 0usize; + // Always include CWD so that Python files importing from sibling + // packages (e.g. `from src.agents import ...`) are accepted. 
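+    // Hypothetical layout: pushing `evals/math.eval.py` that imports
+    // `src.agents` only resolves if the repo root (the CWD) is an allowed
+    // root, not just the eval file's own directory.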
+ if let Ok(cwd) = std::env::current_dir() { + if let Ok(canonical_cwd) = cwd.canonicalize() { + allowed_roots.insert(canonical_cwd); + } + } + for input in inputs { let path = if input.is_absolute() { input.clone() From 9628bbc49f6889acda041fdfd869a51dd22893dc Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Mon, 16 Mar 2026 16:11:49 +0000 Subject: [PATCH 24/28] WIP --- scripts/functions-runner.py | 10 +++++++++- src/eval.rs | 24 +++++++++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/scripts/functions-runner.py b/scripts/functions-runner.py index fa93c75..0291a50 100644 --- a/scripts/functions-runner.py +++ b/scripts/functions-runner.py @@ -340,7 +340,15 @@ def collect_evaluator_entries(evals_registry: Any, source_file: str) -> list[dic }, } else: - serialized = to_json_value(raw_params) + # Use the braintrust SDK's parameters_to_json_schema when + # available so that Pydantic model classes are converted to + # proper staticParametersSchema entries (type: "data" with a + # JSON Schema) that the UI can parse. + try: + from braintrust.parameters import parameters_to_json_schema + serialized = parameters_to_json_schema(raw_params) + except Exception: + serialized = to_json_value(raw_params) if serialized is not None: evaluator_definition["parameters"] = serialized diff --git a/src/eval.rs b/src/eval.rs index 6bc5d40..bf3afd5 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1218,9 +1218,21 @@ fn encode_eval_event_for_http(event: &EvalEvent) -> Option { EvalEvent::Summary(summary) => serde_json::to_string(summary) .ok() .map(|data| serialize_sse_event("summary", &data)), - EvalEvent::Progress(progress) => serde_json::to_string(progress) - .ok() - .map(|data| serialize_sse_event("progress", &data)), + EvalEvent::Progress(progress) => { + // Filter out internal eval_progress events (start/increment/stop) + // which are used for CLI progress bars but crash the UI stream + // parser. Only forward external progress events (e.g. json_delta). + if serde_json::from_str::(&progress.data) + .map(|p| p.kind_type == "eval_progress") + .unwrap_or(false) + { + None + } else { + serde_json::to_string(progress) + .ok() + .map(|data| serialize_sse_event("progress", &data)) + } + } EvalEvent::Dependencies { .. 
} => None,
         EvalEvent::Done => Some(serialize_sse_event("done", "")),
         EvalEvent::Error {
@@ -4010,7 +4022,7 @@ mod tests {
     }
 
     #[test]
-    fn encode_eval_event_for_http_forwards_eval_progress() {
+    fn encode_eval_event_for_http_filters_internal_eval_progress() {
         let event = EvalEvent::Progress(SseProgressEventData {
             id: "id-1".to_string(),
             object_type: "task".to_string(),
@@ -4022,9 +4034,7 @@ mod tests {
             data: r#"{"type":"eval_progress","kind":"start","total":1}"#.to_string(),
         });
 
-        let encoded = encode_eval_event_for_http(&event).expect("eval_progress should be forwarded");
-        assert!(encoded.contains("event: progress"));
-        assert!(encoded.contains("eval_progress"));
+        assert!(encode_eval_event_for_http(&event).is_none());
     }
 
     #[test]

From 1e8e17c214fe53d508b374c5f38c4f72d0b9185b Mon Sep 17 00:00:00 2001
From: Nate Selvidge
Date: Wed, 18 Mar 2026 21:34:08 +0000
Subject: [PATCH 25/28] fix CI lint warnings

---
 src/eval.rs              |  2 +-
 src/source_language.rs   |  1 +
 tests/eval_dev_server.rs | 16 ++++++++--------
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/eval.rs b/src/eval.rs
index bf3afd5..bb67433 100644
--- a/src/eval.rs
+++ b/src/eval.rs
@@ -2188,7 +2188,7 @@ fn build_python_command(
         .or_else(|| std::env::var("BT_EVAL_PYTHON_RUNNER").ok())
         .or_else(|| std::env::var("BT_EVAL_PYTHON").ok());
 
-    let mut command = if let Some(explicit) = runner_override {
+    let command = if let Some(explicit) = runner_override {
         let mut command = Command::new(explicit);
         command.arg(runner).args(files);
         command
diff --git a/src/source_language.rs b/src/source_language.rs
index 8a1b71f..1bb82bc 100644
--- a/src/source_language.rs
+++ b/src/source_language.rs
@@ -7,6 +7,7 @@ pub enum SourceLanguage {
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum JsExtensionProfile {
     FunctionsPush,
+    #[allow(dead_code)]
     Eval,
 }
 
diff --git a/tests/eval_dev_server.rs b/tests/eval_dev_server.rs
index 3921b92..88d6936 100644
--- a/tests/eval_dev_server.rs
+++ b/tests/eval_dev_server.rs
@@ -123,10 +123,10 @@ fn parse_sse_events(body: &str) -> Vec<SseEvent> {
     let mut current_data = Vec::<String>::new();
 
     for line in body.lines() {
-        if line.starts_with("event: ") {
-            current_event = line["event: ".len()..].to_string();
-        } else if line.starts_with("data: ") {
-            current_data.push(line["data: ".len()..].to_string());
+        if let Some(event) = line.strip_prefix("event: ") {
+            current_event = event.to_string();
+        } else if let Some(data) = line.strip_prefix("data: ") {
+            current_data.push(data.to_string());
         } else if line.is_empty() && !current_event.is_empty() {
             events.push(SseEvent {
                 event: std::mem::take(&mut current_event),
@@ -518,10 +518,10 @@ fn streaming_eval_post(
             Ok(l) => l,
             Err(_) => break,
         };
-        if line.starts_with("event: ") {
-            current_event = line["event: ".len()..].to_string();
-        } else if line.starts_with("data: ") {
-            current_data.push(line["data: ".len()..].to_string());
+        if let Some(event) = line.strip_prefix("event: ") {
+            current_event = event.to_string();
+        } else if let Some(data) = line.strip_prefix("data: ") {
+            current_data.push(data.to_string());
         } else if line.is_empty() && !current_event.is_empty() {
             let event = SseEvent {
                 event: std::mem::take(&mut current_event),

From 8eb180f0042e052c0a831467bce81d95eb058de2 Mon Sep 17 00:00:00 2001
From: Nate Selvidge
Date: Thu, 19 Mar 2026 18:56:47 +0000
Subject: [PATCH 26/28] add --sandbox flag

---
 scripts/eval-runner.py | 201 ++++++++++-
 scripts/eval-runner.ts | 252 ++++++++++++-
 src/eval.rs            | 795 ++++++++++++++++++++++++++++++++++++++++-
 src/experiments/api.rs |   2 +
 src/functions/mod.rs   |
3 +- src/functions/push.rs | 155 +++++++- src/sync.rs | 2 +- tests/functions.rs | 5 +- 8 files changed, 1383 insertions(+), 32 deletions(-) diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index e9d3ce0..d8ac067 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -9,9 +9,10 @@ import re import socket import sys +import time import traceback from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, AsyncIterator, Callable try: from braintrust import init_dataset, invoke, login @@ -79,6 +80,41 @@ def close(self) -> None: self.sock.close() +@dataclass +class PullChannel: + sock: socket.socket + + def send(self, payload: Any) -> None: + self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + async def lines(self) -> AsyncIterator[str]: + buffer = "" + while True: + chunk = await asyncio.to_thread(self.sock.recv, 4096) + if not chunk: + break + buffer += chunk.decode("utf-8") + while True: + newline = buffer.find("\n") + if newline == -1: + break + line = buffer[:newline].strip() + buffer = buffer[newline + 1 :] + if line: + yield line + + trailing = buffer.strip() + if trailing: + yield trailing + + def close(self) -> None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self.sock.close() + + def serialize_sse_event(event: str, data: Any) -> str: if isinstance(data, (dict, list)): data_str = json.dumps(data) @@ -105,6 +141,16 @@ def create_sse_writer() -> SseWriter | None: return None +def create_pull_channel() -> PullChannel | None: + sock_path = os.getenv("BT_EVAL_PULL_SOCK") + if not sock_path: + return None + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return PullChannel(sock) + + def env_flag(name: str) -> bool: value = os.getenv(name) if value is None: @@ -137,7 +183,7 @@ def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: def parse_dev_mode(value: str | None) -> str | None: if value is None or value == "": return None - if value in {"list", "eval"}: + if value in {"list", "eval", "rows"}: return value raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}") @@ -302,6 +348,26 @@ def parse_eval_request(raw: str | None) -> dict[str, Any]: return parsed +def parse_eval_pull_request(raw: str | None) -> dict[str, Any]: + if not raw: + raise ValueError("Missing BT_EVAL_DEV_REQUEST_JSON") + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid BT_EVAL_DEV_REQUEST_JSON: {exc}") from exc + + if not isinstance(parsed, dict): + raise ValueError("BT_EVAL_DEV_REQUEST_JSON must be a JSON object.") + if not isinstance(parsed.get("name"), str) or not parsed["name"]: + raise ValueError("Pull request must include a non-empty evaluator name.") + + parameters = parsed.get("parameters") + if parameters is not None and not isinstance(parameters, dict): + raise ValueError("Pull request parameters must be an object.") + + return parsed + + def resolve_eval_data(data: dict[str, Any]) -> Any: if "data" in data: return data["data"] @@ -324,6 +390,33 @@ def resolve_eval_data(data: dict[str, Any]) -> Any: raise ValueError("Invalid eval data payload.") +async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: + data_result = data + if inspect.isclass(data_result): + data_result = data_result() + if inspect.isfunction(data_result) or inspect.isroutine(data_result): + data_result = data_result() + if inspect.isawaitable(data_result): + data_result = await data_result + + base_experiment_name = None 
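+    # A BaseExperiment data source is recorded by name only; the caller is
+    # expected to resolve the referenced experiment's rows later.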
+ if isinstance(data_result, BaseExperiment): + base_experiment_name = data_result.name + + return data_result, base_experiment_name + + +def to_async_iterator(value: Any) -> AsyncIterator[Any]: + if inspect.isasyncgen(value): + return value + + async def to_async(it): + for item in it: + yield item + + return to_async(value) + + def make_eval_scorer( score: dict[str, Any], project_id: str | None, @@ -851,6 +944,108 @@ async def run_requested_eval( return True +async def run_dataset_pull( + evaluator_instances: list[EvaluatorInstance], + config: RunnerConfig, +) -> bool: + channel = create_pull_channel() + if channel is None: + raise ValueError("Missing BT_EVAL_PULL_SOCK") + + try: + request = parse_eval_pull_request(config.dev_request_json) + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + channel.close() + return False + + target_name = request["name"] + evaluator_instance = next( + (candidate for candidate in evaluator_instances if candidate.evaluator.eval_name == target_name), + None, + ) + if evaluator_instance is None: + channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) + channel.close() + return False + + evaluator = evaluator_instance.evaluator + try: + raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) + data_iterator = to_async_iterator(raw_data) + iterator = data_iterator.__aiter__() + trial_count = getattr(evaluator, "trial_count", 1) + try: + trial_count = int(trial_count) + except Exception: + trial_count = 1 + if trial_count < 1: + trial_count = 1 + + max_concurrency = getattr(evaluator, "max_concurrency", None) + try: + max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 + except Exception: + max_concurrency = 10 + if max_concurrency < 1: + max_concurrency = 1 + + experiment_name = getattr(evaluator, "experiment_name", None) + if not isinstance(experiment_name, str) or not experiment_name: + experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" + + channel.send( + { + "type": "ready", + "evaluator_name": evaluator.eval_name, + "max_concurrency": max_concurrency, + "experiment_name": experiment_name, + } + ) + + current_datum = None + trial_index = 0 + async for line in channel.lines(): + parsed = json.loads(line) + command_type = parsed.get("type") if isinstance(parsed, dict) else None + if command_type == "close": + break + if command_type != "next": + channel.send( + { + "type": "error", + "message": f"Unsupported pull command '{command_type}'", + } + ) + break + + if current_datum is None: + try: + current_datum = await iterator.__anext__() + trial_index = 0 + except StopAsyncIteration: + channel.send({"type": "eof"}) + continue + + channel.send( + { + "type": "row", + "datum": current_datum, + "trial_index": trial_index, + } + ) + trial_index += 1 + if trial_index >= trial_count: + current_datum = None + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + channel.close() + return False + + channel.close() + return True + + async def run_once( files: list[str], no_send_logs: bool, @@ -872,6 +1067,8 @@ async def run_once( return True if config.dev_mode == "eval": return await run_requested_eval(evaluators, reporters, no_send_logs, sse, config) + if config.dev_mode == "rows": + return await run_dataset_pull(evaluators, config) if config.list_only: for evaluator_instance in evaluators: diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index 2a19c10..fdd7152 100644 --- a/scripts/eval-runner.ts +++ 
b/scripts/eval-runner.ts
@@ -2,11 +2,17 @@
 import { createRequire } from "node:module";
 import path from "node:path";
 import { fileURLToPath, pathToFileURL } from "node:url";
 
+type EvaluatorDefinition = {
+  evalName: string;
+  projectName: string;
+  data?: unknown;
+  trialCount?: unknown;
+  maxConcurrency?: unknown;
+  experimentName?: unknown;
+} & Record<string, unknown>;
+
 type EvaluatorEntry = {
-  evaluator: {
-    evalName: string;
-    projectName: string;
-  } & Record<string, unknown>;
+  evaluator: EvaluatorDefinition;
   reporter?: unknown;
 };
@@ -111,12 +117,17 @@ type EvalRequest = {
   scores?: EvalScoreSpec[];
 };
 
+type EvalPullRequest = {
+  name: string;
+  parameters?: Record<string, unknown>;
+};
+
 type RunnerConfig = {
   jsonl: boolean;
   list: boolean;
   terminateOnFailure: boolean;
   filters: EvalFilter[];
-  devMode: "list" | "eval" | null;
+  devMode: "list" | "eval" | "rows" | null;
   devRequestJson: string | null;
 };
@@ -243,11 +254,13 @@ function parseSerializedFilters(serialized: string | undefined): EvalFilter[] {
   }
 }
 
-function parseDevMode(value: string | undefined): "list" | "eval" | null {
+function parseDevMode(
+  value: string | undefined,
+): "list" | "eval" | "rows" | null {
   if (!value) {
     return null;
   }
-  if (value === "list" || value === "eval") {
+  if (value === "list" || value === "eval" || value === "rows") {
     return value;
   }
   throw new Error(`Invalid BT_EVAL_DEV_MODE value: ${value}`);
@@ -283,6 +296,7 @@ type NetModule = {
     setNoDelay: (value?: boolean) => void;
     on: (event: string, listener: (...args: unknown[]) => void) => void;
     write: (data: string) => void;
+    [Symbol.asyncIterator]?: () => AsyncIterator<Buffer | string>;
   };
 };
@@ -766,6 +780,69 @@ function createSseWriter(): SseWriter | null {
   return { send, close };
 }
 
+type PullChannel = {
+  send: (payload: unknown) => void;
+  close: () => void;
+  lines: () => AsyncGenerator<string>;
+};
+
+function createPullChannel(): PullChannel | null {
+  const netModule = (() => {
+    try {
+      return runtimeRequire("node:net") as NetModule;
+    } catch {
+      return null;
+    }
+  })();
+  const sock = process.env.BT_EVAL_PULL_SOCK;
+  if (!sock) {
+    return null;
+  }
+  if (!netModule) {
+    return null;
+  }
+
+  const socket = netModule.createConnection({ path: sock });
+  socket.setNoDelay(true);
+
+  const send = (payload: unknown) => {
+    if (!socket.writable) {
+      return;
+    }
+    socket.write(`${JSON.stringify(payload)}\n`);
+  };
+
+  const close = () => {
+    socket.end();
+  };
+
+  const lines = async function* () {
+    let buffer = "";
+    for await (const chunk of socket as unknown as AsyncIterable<
+      Buffer | string
+    >) {
+      buffer += typeof chunk === "string" ? chunk : chunk.toString("utf8");
+      while (true) {
+        const newline = buffer.indexOf("\n");
+        if (newline === -1) {
+          break;
+        }
+        const line = buffer.slice(0, newline).trim();
+        buffer = buffer.slice(newline + 1);
+        if (line.length > 0) {
+          yield line;
+        }
+      }
+    }
+    const trailing = buffer.trim();
+    if (trailing.length > 0) {
+      yield trailing;
+    }
+  };
+
+  return { send, close, lines };
+}
+
 function initRegistry() {
   globalThis._evals = {
     functions: [],
@@ -1405,6 +1482,24 @@ async function buildEvaluatorDefinitions(evaluators: EvaluatorEntry[]) {
   return result;
 }
 
+function parseEvalPullRequest(raw: string | null): EvalPullRequest {
+  if (!raw) {
+    throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON");
+  }
+  const parsed = JSON.parse(raw);
+  if (!isObject(parsed) || typeof parsed.name !== "string" || parsed.name.length === 0) {
+    throw new Error("Pull request must include a non-empty evaluator name.");
+  }
+  const request = parsed as EvalPullRequest;
+  if (
+    request.parameters !== undefined &&
+    (!isObject(request.parameters) || Array.isArray(request.parameters))
+  ) {
+    throw new Error("Pull request parameters must be an object.");
+  }
+  return request;
+}
+
 function parseEvalRequest(raw: string | null): EvalRequest {
   if (!raw) {
     throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON");
@@ -1487,6 +1582,48 @@ function resolveEvalData(
   throw new Error("Invalid eval data payload.");
 }
 
+function callEvaluatorData(
+  data: unknown,
+): { data: unknown; baseExperiment: string | undefined } {
+  const dataResult = typeof data === "function" ? (data as () => unknown)() : data;
+  let baseExperiment: string | undefined = undefined;
+  if (
+    isObject(dataResult) &&
+    Reflect.get(dataResult, "_type") === "BaseExperiment" &&
+    typeof Reflect.get(dataResult, "name") === "string"
+  ) {
+    baseExperiment = Reflect.get(dataResult, "name") as string;
+  }
+  return { data: dataResult, baseExperiment };
+}
+
+function toAsyncIterable(value: unknown): AsyncIterable<unknown> {
+  if (
+    typeof value === "object" &&
+    value !== null &&
+    Symbol.asyncIterator in value &&
+    typeof (value as AsyncIterable<unknown>)[Symbol.asyncIterator] === "function"
+  ) {
+    return value as AsyncIterable<unknown>;
+  }
+  if (
+    typeof value === "object" &&
+    value !== null &&
+    Symbol.iterator in value &&
+    typeof (value as Iterable<unknown>)[Symbol.iterator] === "function"
+  ) {
+    const iterable = value as Iterable<unknown>;
+    return (async function* () {
+      for (const item of iterable) {
+        yield item;
+      }
+    })();
+  }
+  throw new Error(
+    "Evaluator data must be an array, iterable, or async iterable",
+  );
+}
+
 function convertFunctionId(
   functionId: Record<string, unknown>,
 ): Record<string, unknown> {
@@ -1661,6 +1798,102 @@ async function runRequestedEval(config: RunnerConfig, runner: EvalRunner) {
   }
 }
 
+async function runDatasetPull(config: RunnerConfig, runner: EvalRunner) {
+  const channel = createPullChannel();
+  if (!channel) {
+    throw new Error("Missing BT_EVAL_PULL_SOCK");
+  }
+
+  try {
+    const request = parseEvalPullRequest(config.devRequestJson);
+    const entry = getEvaluators().find(
+      (candidate) => candidate.evaluator.evalName === request.name,
+    );
+    if (!entry) {
+      channel.send({
+        type: "error",
+        message: `Evaluator '${request.name}' not found`,
+      });
+      return;
+    }
+
+    const state = runner.getState ? runner.getState() : undefined;
+    const evaluator = {
+      ...entry.evaluator,
+      ...(state !== undefined && state !== null ? { state } : {}),
+    };
+    const { data: rawData } = callEvaluatorData(evaluator.data);
+    const dataIterable = toAsyncIterable(rawData);
+    const iterator = dataIterable[Symbol.asyncIterator]();
+    const trialCountRaw = Number(evaluator.trialCount ?? 1);
+    const trialCount =
+      Number.isFinite(trialCountRaw) && trialCountRaw > 0
+        ? Math.floor(trialCountRaw)
+        : 1;
+    const maxConcurrencyRaw = Number(evaluator.maxConcurrency ?? 10);
+    const maxConcurrency =
+      Number.isFinite(maxConcurrencyRaw) && maxConcurrencyRaw > 0
+        ? Math.floor(maxConcurrencyRaw)
+        : 10;
+    const experimentName =
+      typeof evaluator.experimentName === "string" &&
+      evaluator.experimentName.length > 0
+        ? evaluator.experimentName
+        : `${entry.evaluator.evalName}-${Date.now()}`;
+
+    channel.send({
+      type: "ready",
+      evaluator_name: entry.evaluator.evalName,
+      max_concurrency: maxConcurrency,
+      experiment_name: experimentName,
+    });
+
+    let currentDatum: unknown | undefined = undefined;
+    let trialIndex = 0;
+    for await (const line of channel.lines()) {
+      const parsed = JSON.parse(line) as { type?: string };
+      if (parsed.type === "close") {
+        break;
+      }
+      if (parsed.type !== "next") {
+        channel.send({
+          type: "error",
+          message: `Unsupported pull command '${String(parsed.type)}'`,
+        });
+        break;
+      }
+
+      if (currentDatum === undefined) {
+        const next = await iterator.next();
+        if (next.done) {
+          channel.send({ type: "eof" });
+          continue;
+        }
+        currentDatum = next.value;
+        trialIndex = 0;
+      }
+
+      channel.send({
+        type: "row",
+        datum: currentDatum,
+        trial_index: trialIndex,
+      });
+
+      trialIndex += 1;
+      if (trialIndex >= trialCount) {
+        currentDatum = undefined;
+      }
+    }
+  } catch (err) {
+    channel.send({
+      type: "error",
+      message: err instanceof Error ? err.message : String(err),
+    });
+  } finally {
+    channel.close();
+  }
+}
+
 function extractBtEvalMain(mod: unknown): BtEvalMain | null {
   if (!mod || typeof mod !== "object") {
     return null;
   }
@@ -2061,6 +2294,11 @@ async function main() {
     return;
   }
 
+  if (config.devMode === "rows") {
+    await runDatasetPull(config, runner);
+    return;
+  }
+
   if (config.list) {
     for (const entry of filteredEvaluators) {
       console.log(entry.evaluator.evalName);
diff --git a/src/eval.rs b/src/eval.rs
index bb67433..ac792e4 100644
--- a/src/eval.rs
+++ b/src/eval.rs
@@ -15,6 +15,7 @@ use actix_web::http::header::{
 };
 use actix_web::{guard, web, App, HttpRequest, HttpResponse, HttpServer};
 use anyhow::{Context, Result};
+use chrono::{SecondsFormat, Utc};
 use clap::{Args, ValueEnum};
 use crossterm::queue;
 use crossterm::style::{
@@ -22,11 +23,13 @@ use crossterm::style::{
     Stylize,
 };
 use futures_util::stream;
+use futures_util::StreamExt;
 use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use strip_ansi_escapes::strip;
+use tokio::io::AsyncWriteExt;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::net::UnixListener;
 use tokio::process::Command;
@@ -41,7 +44,12 @@ use ratatui::widgets::{Cell, Row, Table};
 use ratatui::Terminal;
 
 use crate::args::BaseArgs;
+use crate::auth::login;
 use crate::auth::resolved_auth_env;
+use crate::experiments::api::create_experiment;
+use crate::functions::publish_eval_sandbox_functions;
+use crate::http::ApiClient;
+use crate::source_language::SourceLanguage;
 use crate::ui::{animations_enabled, is_quiet};
 
 const MAX_NAME_LENGTH: usize = 40;
@@ -161,6 +169,61 @@ struct ResolvedDatasetEvalData {
     _internal_btql: Option<Value>,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
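+/// Selects the evaluator to stream rows from when BT_EVAL_DEV_MODE=rows;
+/// serialized into BT_EVAL_DEV_REQUEST_JSON, e.g. `{"name": "my-eval",
+/// "parameters": {}}` (illustrative payload; `parameters` is optional).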
+struct EvalPullRequest {
+    name: String,
+    #[serde(default)]
+    parameters: Option<Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum EvalPullClientMessage {
+    Next,
+    Close,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum EvalPullResponse {
+    Ready {
+        evaluator_name: String,
+        max_concurrency: usize,
+        experiment_name: String,
+    },
+    Row {
+        datum: Value,
+        trial_index: usize,
+    },
+    Eof,
+    Error {
+        message: String,
+    },
+}
+
+#[derive(Debug)]
+struct EvalDataPuller {
+    child: tokio::process::Child,
+    writer: tokio::net::unix::OwnedWriteHalf,
+    reader: BufReader<tokio::net::unix::OwnedReadHalf>,
+    _socket_cleanup_guard: SocketCleanupGuard,
+}
+
+#[derive(Debug, Clone)]
+struct EvalSandboxPlan {
+    evaluator_name: String,
+    function_id: String,
+    project_id: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct SandboxSummaryRow {
+    #[serde(default)]
+    scores: HashMap<String, Option<f64>>,
+    #[serde(default)]
+    metrics: HashMap<String, f64>,
+}
+
 #[derive(Clone)]
 struct DevServerState {
     base: BaseArgs,
@@ -194,6 +257,7 @@ const PY_RUNNER_FILE: &str = "eval-runner.py";
 const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts");
 const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py");
 
+#[derive(Debug)]
 struct SocketCleanupGuard {
     path: PathBuf,
 }
@@ -218,6 +282,12 @@ pub enum EvalLanguage {
     Python,
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
+pub enum EvalSandbox {
+    Local,
+    Lambda,
+}
+
 #[derive(Debug, Clone, Args)]
 #[command(after_help = "\
 Examples:
@@ -245,6 +315,10 @@ pub struct EvalArgs {
     )]
     pub language: Option<EvalLanguage>,
 
+    /// Execute evals locally or in a remote sandbox.
+    #[arg(long, env = "BT_EVAL_SANDBOX", value_enum, default_value = "local")]
+    pub sandbox: EvalSandbox,
+
     /// Run evals locally (do not send logs to Braintrust).
     #[arg(
         long,
@@ -389,6 +463,31 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> {
         extra_args: args.extra_args,
     };
 
+    if args.sandbox != EvalSandbox::Local {
+        if args.dev {
+            anyhow::bail!("--sandbox is not supported with --dev.");
+        }
+        if args.watch {
+            anyhow::bail!("--sandbox is not supported with --watch.");
+        }
+        if args.list {
+            anyhow::bail!("--sandbox is not supported with --list.");
+        }
+        if files.len() != 1 {
+            anyhow::bail!("`bt eval --sandbox lambda` currently supports exactly one eval file.");
+        }
+        return run_eval_files_sandbox(
+            &base,
+            args.sandbox,
+            args.language,
+            args.runner.as_deref(),
+            &files,
+            args.no_send_logs,
+            &options,
+        )
+        .await;
+    }
+
     if args.dev {
         let language = detect_eval_language(&files, args.language)?;
         let app_url = resolve_app_url(&base);
@@ -500,6 +599,667 @@ async fn run_eval_files_watch(
     }
 }
 
+async fn run_eval_files_sandbox(
+    base: &BaseArgs,
+    sandbox: EvalSandbox,
+    language_override: Option<EvalLanguage>,
+    runner_override: Option<&str>,
+    files: &[String],
+    no_send_logs: bool,
+    options: &EvalRunOptions,
+) -> Result<()> {
+    if sandbox != EvalSandbox::Lambda {
+        anyhow::bail!("unsupported sandbox mode");
+    }
+    if no_send_logs {
+        anyhow::bail!("--sandbox lambda is not supported with --no-send-logs.");
+    }
+    let language = detect_eval_language(files, language_override)?;
+    let source_language = match language {
+        EvalLanguage::JavaScript => SourceLanguage::JsLike,
+        EvalLanguage::Python => SourceLanguage::Python,
+    };
+
+    let source_file = PathBuf::from(
+        files
+            .first()
+            .ok_or_else(|| anyhow::anyhow!("missing sandbox source file"))?,
+    );
+    let published =
+        publish_eval_sandbox_functions(base, &source_file, runner_override, source_language)
+            .await?;
+    let evaluator_names = list_sandbox_evaluator_names(
+        base,
+        language,
+        runner_override,
+        files,
+        no_send_logs,
+        options,
+    )
+    .await?;
+    if evaluator_names.is_empty() {
+        anyhow::bail!("No evaluators found. Did you call Eval() in the file?");
+    }
+
+    let mut plans = Vec::new();
+    for evaluator_name in evaluator_names {
+        let slug = sandbox_slug_from_source(&source_file, &evaluator_name);
+        let published_entry = published
+            .iter()
+            .find(|entry| entry.slug == slug)
+            .ok_or_else(|| {
+                anyhow::anyhow!(
+                    "sandbox function '{}' for evaluator '{}' was not published",
+                    slug,
+                    evaluator_name
+                )
+            })?;
+        plans.push(EvalSandboxPlan {
+            evaluator_name,
+            function_id: published_entry.function_id.clone(),
+            project_id: published_entry.project_id.clone(),
+        });
+    }
+
+    let login_ctx = login(base).await?;
+    let client = ApiClient::new(&login_ctx)?;
+    let started_at = Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true);
+
+    for plan in plans {
+        let mut puller = spawn_eval_data_puller(
+            base,
+            language,
+            runner_override,
+            files,
+            no_send_logs,
+            options,
+            &EvalPullRequest {
+                name: plan.evaluator_name.clone(),
+                parameters: Some(json!({})),
+            },
+        )
+        .await?;
+        let ready = puller.read_message().await?;
+        let (max_concurrency, experiment_name) = match ready {
+            EvalPullResponse::Ready {
+                evaluator_name,
+                max_concurrency,
+                experiment_name,
+            } => {
+                if evaluator_name != plan.evaluator_name {
+                    anyhow::bail!(
+                        "sandbox runner selected unexpected evaluator '{}', expected '{}'",
+                        evaluator_name,
+                        plan.evaluator_name
+                    );
+                }
+                (max_concurrency.max(1), experiment_name)
+            }
+            EvalPullResponse::Error { message } => anyhow::bail!("{message}"),
+            other => anyhow::bail!("unexpected initial sandbox pull response: {other:?}"),
+        };
+
+        let experiment = create_experiment(&client, &plan.project_id, &experiment_name, true)
+            .await
+            .with_context(|| {
+                format!(
+                    "failed to create sandbox parent experiment '{}' for evaluator '{}'",
+                    experiment_name, plan.evaluator_name
+                )
+            })?;
+        let mut in_flight: tokio::task::JoinSet>> =
+            tokio::task::JoinSet::new();
+        let mut saw_eof = false;
+        let mut experiment_url: Option<String> = None;
+
+        while !saw_eof || !in_flight.is_empty() {
+            while !saw_eof && in_flight.len() < max_concurrency {
+                puller.send_message(&EvalPullClientMessage::Next).await?;
+                match puller.read_message().await? {
+                    EvalPullResponse::Row {
+                        datum,
+                        trial_index: _trial_index,
+                    } => {
+                        let function_id = plan.function_id.clone();
+                        let evaluator_name = plan.evaluator_name.clone();
+                        let project_id = plan.project_id.clone();
+                        let body = json!({
+                            "api_version": 1,
+                            "function_id": { "function_id": function_id },
+                            "name": evaluator_name,
+                            "project_id": project_id,
+                            "scores": [],
+                            "stream": true,
+                            "experiment_name": experiment.name,
+                            "parent": {
+                                "object_type": "experiment",
+                                "object_id": experiment.id,
+                            },
+                            "data": { "data": [datum] },
+                        });
+                        let client_cloned = client.clone();
+                        let org_name = login_ctx.login.org_name.clone();
+                        let project_id = plan.project_id.clone();
+                        in_flight.spawn(async move {
+                            invoke_sandbox_eval(&client_cloned, &org_name, &project_id, body).await
+                        });
+                    }
+                    EvalPullResponse::Eof => saw_eof = true,
+                    EvalPullResponse::Error { message } => anyhow::bail!("{message}"),
+                    other => anyhow::bail!("unexpected sandbox pull response: {other:?}"),
+                }
+            }
+
+            if let Some(joined) = in_flight.join_next().await {
+                if let Some(start) = joined?? {
{ + if experiment_url.is_none() { + experiment_url = start.experiment_url.clone(); + } + } + } + } + + puller.send_message(&EvalPullClientMessage::Close).await?; + puller.wait().await?; + + let summary = summarize_sandbox_experiment( + &client, + &plan.project_id, + &plan.project_id, + &experiment.name, + &experiment.id, + experiment_url, + &started_at, + ) + .await?; + let rendered = format_experiment_summary(&summary); + println!("{rendered}"); + } + + Ok(()) +} + +impl EvalDataPuller { + async fn send_message(&mut self, message: &EvalPullClientMessage) -> Result<()> { + let mut payload = + serde_json::to_string(message).context("failed to serialize pull request")?; + payload.push('\n'); + self.writer + .write_all(payload.as_bytes()) + .await + .context("failed to write pull request")?; + self.writer + .flush() + .await + .context("failed to flush pull request")?; + Ok(()) + } + + async fn read_message(&mut self) -> Result { + let mut line = String::new(); + let read = self + .reader + .read_line(&mut line) + .await + .context("failed to read sandbox pull response")?; + if read == 0 { + let status = self + .child + .wait() + .await + .context("sandbox pull runner exited unexpectedly")?; + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + serde_json::from_str(line.trim()).context("failed to parse sandbox pull response JSON") + } + + async fn wait(mut self) -> Result<()> { + let status = self + .child + .wait() + .await + .context("sandbox pull runner failed")?; + if !status.success() { + anyhow::bail!("sandbox pull runner exited with status {status}"); + } + Ok(()) + } +} + +fn sandbox_slugify(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + let mut previous_dash = false; + for ch in input.chars() { + let lower = ch.to_ascii_lowercase(); + if lower.is_ascii_alphanumeric() { + out.push(lower); + previous_dash = false; + } else if !previous_dash { + out.push('-'); + previous_dash = true; + } + } + out.trim_matches('-').to_string() +} + +fn sandbox_slug_from_source(source_file: &Path, eval_name: &str) -> String { + let stem = source_file + .file_stem() + .and_then(|value| value.to_str()) + .map(|value| value.strip_suffix(".eval").unwrap_or(value)) + .unwrap_or("eval"); + sandbox_slugify(&format!("{stem}-{eval_name}-sandbox")) +} + +async fn list_sandbox_evaluator_names( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, +) -> Result> { + let output = run_eval_runner_command_to_completion( + base, + language, + runner_override, + files, + no_send_logs, + options, + &[("BT_EVAL_DEV_MODE".to_string(), "list".to_string())], + JsMode::Auto, + ) + .await?; + + let parsed: Value = + serde_json::from_slice(&output.stdout).context("failed to parse sandbox evaluator list")?; + let object = parsed + .as_object() + .ok_or_else(|| anyhow::anyhow!("sandbox evaluator list was not a JSON object"))?; + Ok(object.keys().cloned().collect()) +} + +async fn spawn_eval_data_puller( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + request: &EvalPullRequest, +) -> Result { + let (listener, socket_path, socket_cleanup_guard) = + bind_unix_listener("bt-eval-pull").context("failed to bind sandbox pull socket")?; + let request_json = + serde_json::to_string(request).context("failed to serialize sandbox pull request")?; + let extra_env = vec![ + 
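+        // The child runner is driven entirely by environment variables (a
+        // sketch of the contract as used here, values illustrative):
+        // BT_EVAL_DEV_MODE=rows switches it into row-streaming mode,
+        // BT_EVAL_DEV_REQUEST_JSON carries the evaluator selection, e.g.
+        // {"name":"my-eval","parameters":{}}, and BT_EVAL_PULL_SOCK names the
+        // unix socket the runner dials back on.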
("BT_EVAL_DEV_MODE".to_string(), "rows".to_string()), + ("BT_EVAL_DEV_REQUEST_JSON".to_string(), request_json), + ( + "BT_EVAL_PULL_SOCK".to_string(), + socket_path.to_string_lossy().to_string(), + ), + ]; + let child = spawn_eval_support_process( + base, + language, + runner_override, + files, + no_send_logs, + options, + &extra_env, + JsMode::Auto, + ) + .await?; + + let (stream, _) = tokio::time::timeout(Duration::from_secs(30), listener.accept()) + .await + .context("timed out waiting for sandbox pull runner to connect")? + .context("sandbox pull runner failed to connect")?; + let (read_half, write_half) = stream.into_split(); + Ok(EvalDataPuller { + child, + writer: write_half, + reader: BufReader::new(read_half), + _socket_cleanup_guard: socket_cleanup_guard, + }) +} + +async fn spawn_eval_support_process( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + extra_env: &[(String, String)], + js_mode: JsMode, +) -> Result { + let (js_runner, py_runner) = prepare_eval_runners()?; + let force_esm = matches!(js_mode, JsMode::ForceEsm); + let (mut cmd, runner_kind) = match language { + EvalLanguage::Python => ( + build_python_command(runner_override, &py_runner, files)?, + RunnerKind::Other, + ), + EvalLanguage::JavaScript => { + if force_esm { + ( + build_vite_node_fallback_command(&js_runner, files)?, + RunnerKind::ViteNode, + ) + } else { + let plan = build_js_plan(runner_override, &js_runner, files)?; + (plan.cmd, plan.kind) + } + } + }; + if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { + set_node_heap_size_env(&mut cmd); + } + cmd.envs(build_env(base).await?); + for (key, value) in extra_env { + cmd.env(key, value); + } + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if language == EvalLanguage::JavaScript && force_esm { + cmd.env("BT_EVAL_FORCE_ESM", "1"); + } + if !options.extra_args.is_empty() { + let serialized = + serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + cmd.stdout(Stdio::inherit()); + cmd.stderr(Stdio::inherit()); + cmd.spawn().context("failed to start eval support runner") +} + +async fn run_eval_runner_command_to_completion( + base: &BaseArgs, + language: EvalLanguage, + runner_override: Option<&str>, + files: &[String], + no_send_logs: bool, + options: &EvalRunOptions, + extra_env: &[(String, String)], + js_mode: JsMode, +) -> Result { + let (js_runner, py_runner) = prepare_eval_runners()?; + let force_esm = matches!(js_mode, JsMode::ForceEsm); + let (mut cmd, runner_kind) = match language { + EvalLanguage::Python => ( + build_python_command(runner_override, &py_runner, files)?, + RunnerKind::Other, + ), + EvalLanguage::JavaScript => { + if force_esm { + ( + build_vite_node_fallback_command(&js_runner, files)?, + RunnerKind::ViteNode, + ) + 
} else { + let plan = build_js_plan(runner_override, &js_runner, files)?; + (plan.cmd, plan.kind) + } + } + }; + if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { + set_node_heap_size_env(&mut cmd); + } + cmd.envs(build_env(base).await?); + for (key, value) in extra_env { + cmd.env(key, value); + } + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = + serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let output = cmd + .output() + .await + .context("failed to run eval support runner")?; + if !output.status.success() { + anyhow::bail!( + "eval support runner exited with status {}: {}", + output.status, + String::from_utf8_lossy(&output.stderr).trim() + ); + } + Ok(output) +} + +async fn invoke_sandbox_eval( + client: &ApiClient, + org_name: &str, + project_id: &str, + body: Value, +) -> Result> { + let response = client + .post_with_headers_raw( + "/function/sandbox", + &body, + &[("x-bt-org-name", org_name), ("x-bt-project-id", project_id)], + ) + .await?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("sandbox invoke failed ({status}): {body}"); + } + + let mut bytes = response.bytes_stream(); + let mut buffer = String::new(); + let mut current_event: Option = None; + let mut data_lines: Vec = Vec::new(); + let mut start: Option = None; + let mut saw_done = false; + + while let Some(chunk) = bytes.next().await { + let chunk = chunk.context("failed to read sandbox SSE response")?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + while let Some(pos) = buffer.find('\n') { + let mut line: String = buffer.drain(..=pos).collect(); + if line.ends_with('\n') { + line.pop(); + } + if line.ends_with('\r') { + line.pop(); + } + if line.is_empty() { + if current_event.is_some() || !data_lines.is_empty() { + let event_name = current_event.take().unwrap_or_default(); + let data = data_lines.join("\n"); + data_lines.clear(); + match event_name.as_str() { + "start" => { + if let Ok(parsed) = serde_json::from_str::(&data) { + if start.is_none() { + start = Some(parsed); + } + } + } + "error" => { + if let Ok(payload) = serde_json::from_str::(&data) { + let message = payload + .get("message") + .or_else(|| payload.get("error")) + .and_then(Value::as_str) + .unwrap_or("sandbox eval failed"); + anyhow::bail!("{message}"); + } + anyhow::bail!("{data}"); + } + "done" => { + saw_done = true; + } + _ => {} + } + } + continue; + } + if let Some(value) = line.strip_prefix("event:") { + current_event = Some(value.trim().to_string()); + } else if let Some(value) = line.strip_prefix("data:") { + data_lines.push(value.trim_start().to_string()); + } + } + } + + if !saw_done { + anyhow::bail!("sandbox SSE stream ended before a done event"); + } + Ok(start) +} + +async fn summarize_sandbox_experiment( + client: &ApiClient, + project_name: &str, + _project_id: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + 
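+    // `started_at` is captured before the first invoke and bounds the BTQL
+    // summary built below, so aggregation only counts rows created during
+    // this run. The generated query looks like (illustrative values):
+    //
+    //     select: scores, metrics | from: experiment('exp-123') summary
+    //       | filter: created >= '2026-03-19T12:00:00.000Z' | limit: 1000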
started_at: &str, +) -> Result { + let query = build_sandbox_summary_query(experiment_id, started_at); + let response = client.btql::(&query).await?; + Ok(aggregate_sandbox_summary( + project_name, + experiment_name, + experiment_id, + experiment_url, + &response.data, + )) +} + +fn build_sandbox_summary_query(experiment_id: &str, started_at: &str) -> String { + format!( + "select: scores, metrics | from: experiment('{}') summary | filter: created >= '{}' | limit: 1000", + experiment_id.replace('\'', "''"), + started_at.replace('\'', "''") + ) +} + +fn aggregate_sandbox_summary( + project_name: &str, + experiment_name: &str, + experiment_id: &str, + experiment_url: Option, + rows: &[SandboxSummaryRow], +) -> ExperimentSummary { + let mut scores: HashMap = HashMap::new(); + let mut metrics: HashMap = HashMap::new(); + for row in rows { + for (name, value) in &row.scores { + if let Some(value) = value { + let entry = scores.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += value; + entry.1 += 1; + } + } + for (name, value) in &row.metrics { + let Some(number) = value.as_f64() else { + continue; + }; + let entry = metrics.entry(name.clone()).or_insert((0.0, 0)); + entry.0 += number; + entry.1 += 1; + } + } + + ExperimentSummary { + project_name: project_name.to_string(), + experiment_name: experiment_name.to_string(), + project_id: None, + experiment_id: Some(experiment_id.to_string()), + project_url: None, + experiment_url, + comparison_experiment_name: None, + scores: scores + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + ScoreSummary { + name, + score: average, + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + metrics: if metrics.is_empty() { + None + } else { + Some( + metrics + .into_iter() + .map(|(name, (total, count))| { + let average = if count == 0 { + 0.0 + } else { + total / count as f64 + }; + ( + name.clone(), + MetricSummary { + name, + metric: average, + unit: String::new(), + diff: None, + improvements: 0, + regressions: 0, + }, + ) + }) + .collect(), + ) + }, + } +} + struct EvalPlan<'a> { language: EvalLanguage, files: &'a [String], @@ -2386,20 +3146,20 @@ fn find_binary_in_path(candidates: &[&str]) -> Option { None } -fn build_sse_socket_path() -> Result { +fn build_socket_path(prefix: &str) -> Result { let pid = std::process::id(); let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); let now = SystemTime::now() .duration_since(UNIX_EPOCH) .context("failed to read system time")? 
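         // pid + nanosecond timestamp + process-wide counter keeps concurrent
         // listeners in one process from colliding, e.g. (illustrative)
         //   /tmp/bt-eval-pull-4242-1763581200000000000-0.sock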
.as_nanos(); - Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock"))) + Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock"))) } -fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { +fn bind_unix_listener(prefix: &str) -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { let mut last_bind_err: Option = None; for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS { - let socket_path = build_sse_socket_path()?; + let socket_path = build_socket_path(prefix)?; let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); let _ = std::fs::remove_file(&socket_path); match UnixListener::bind(&socket_path) { @@ -2425,10 +3185,14 @@ fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { ) }); Err(err).context(format!( - "failed to bind SSE unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" + "failed to bind unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" )) } +fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { + bind_unix_listener("bt-eval") +} + fn eval_runner_cache_dir() -> PathBuf { let root = std::env::var_os("XDG_CACHE_HOME") .map(PathBuf::from) @@ -4016,8 +4780,8 @@ mod tests { #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { - let first = build_sse_socket_path().expect("first socket path"); - let second = build_sse_socket_path().expect("second socket path"); + let first = build_socket_path("bt-eval").expect("first socket path"); + let second = build_socket_path("bt-eval").expect("second socket path"); assert_ne!(first, second); } @@ -4193,6 +4957,7 @@ mod tests { "BT_EVAL_TERMINATE_ON_FAILURE", "BT_EVAL_NUM_WORKERS", "BT_EVAL_LIST", + "BT_EVAL_SANDBOX", "BT_EVAL_FILTER", "BT_EVAL_VERBOSE", "BT_EVAL_WATCH", @@ -4207,6 +4972,7 @@ mod tests { set_env_var("BT_EVAL_TERMINATE_ON_FAILURE", "1"); set_env_var("BT_EVAL_NUM_WORKERS", "4"); set_env_var("BT_EVAL_LIST", "yes"); + set_env_var("BT_EVAL_SANDBOX", "lambda"); set_env_var("BT_EVAL_FILTER", "metadata.case=smoke.*,metadata.kind=fast"); set_env_var("BT_EVAL_VERBOSE", "1"); set_env_var("BT_EVAL_WATCH", "on"); @@ -4221,6 +4987,7 @@ mod tests { assert!(parsed.eval.terminate_on_failure); assert_eq!(parsed.eval.num_workers, Some(4)); assert!(parsed.eval.list); + assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda); assert_eq!( parsed.eval.filter, vec![ @@ -4239,4 +5006,18 @@ mod tests { restore_env_var(key, value); } } + + #[test] + fn build_sandbox_summary_query_includes_timestamp_filter() { + let query = build_sandbox_summary_query("exp'123", "2026-03-19T12:00:00.000Z"); + assert!(query.contains("from: experiment('exp''123') summary")); + assert!(query.contains("filter: created >= '2026-03-19T12:00:00.000Z'")); + assert!(query.contains("select: scores, metrics")); + } + + #[test] + fn sandbox_slug_from_source_uses_source_stem_and_eval_name() { + let slug = sandbox_slug_from_source(Path::new("/tmp/My Eval.ts"), "Demo Eval"); + assert_eq!(slug, "my-eval-demo-eval-sandbox"); + } } diff --git a/src/experiments/api.rs b/src/experiments/api.rs index 3507a7e..bce8a00 100644 --- a/src/experiments/api.rs +++ b/src/experiments/api.rs @@ -70,11 +70,13 @@ pub async fn create_experiment( client: &ApiClient, project_id: &str, name: &str, + ensure_new: bool, ) -> Result { let body = serde_json::json!({ "name": name, "project_id": project_id, "org_name": client.org_name(), + "ensure_new": ensure_new, }); client.post("/v1/experiment", &body).await } diff --git a/src/functions/mod.rs b/src/functions/mod.rs 
index eadc100..934d0f7 100644 --- a/src/functions/mod.rs +++ b/src/functions/mod.rs @@ -17,11 +17,12 @@ mod delete; mod invoke; mod list; mod pull; -mod push; +pub(crate) mod push; pub(crate) mod report; mod view; use api::Function; +pub(crate) use push::publish_eval_sandbox_functions; #[derive(Debug, Clone, Copy, ValueEnum)] pub enum FunctionTypeFilter { diff --git a/src/functions/push.rs b/src/functions/push.rs index 0ca6109..89ecd87 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -640,6 +640,145 @@ struct FileSuccess { bundle_id: Option, } +#[derive(Debug, Clone)] +pub(crate) struct PublishedSandboxFunction { + pub slug: String, + pub project_id: String, + pub function_id: String, +} + +pub(crate) async fn publish_eval_sandbox_functions( + base: &BaseArgs, + source_file: &Path, + runner_override: Option<&str>, + language: SourceLanguage, +) -> Result> { + let available_orgs = list_available_orgs(base) + .await + .context("failed to list available orgs")?; + validate_explicit_org_selection(base, &available_orgs)?; + let auth_ctx = resolve_auth_context(base) + .await + .context("failed to resolve auth context")?; + + let args = PushArgs { + files: vec![source_file.to_path_buf()], + file_flag: Vec::new(), + if_exists: super::IfExistsMode::Replace, + terminate_on_failure: true, + create_missing_projects: true, + runner: runner_override.map(ToOwned::to_owned), + language: match language { + SourceLanguage::JsLike => PushLanguage::JavaScript, + SourceLanguage::Python => PushLanguage::Python, + }, + requirements: None, + tsconfig: None, + external_packages: Vec::new(), + yes: true, + }; + + let input_files = args.resolved_files(); + let classified = collect_classified_files(&input_files)?; + let files = classified.files_for_language(language); + if files.is_empty() { + bail!( + "no eligible {} files found for sandbox publish: {}", + language_label(language), + source_file.display() + ); + } + + let mut manifest = + run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) + .map_err(|failure| anyhow!(failure.message))?; + + for file in &mut manifest.files { + file.entries.retain(|entry| match entry { + ManifestEntry::Code(code) => code.function_type.as_deref() == Some("sandbox"), + ManifestEntry::FunctionEvent(_) => false, + }); + } + manifest + .files + .retain(|file| !file.entries.is_empty() || file.python_bundle.is_some()); + + if manifest.files.is_empty() { + bail!("no sandbox evaluators found in {}", source_file.display()); + } + + validate_manifest_paths( + &manifest, + &files, + language, + &classified.allowed_roots, + ) + .map_err(|failure| anyhow!(failure.message))?; + + let preflight = collect_project_preflight(base, &manifest)?; + let mut project_name_cache = + resolve_named_projects(&auth_ctx, &preflight.named_projects, true).await?; + validate_direct_project_ids(&auth_ctx, &preflight.direct_project_ids).await?; + let default_project_id = resolve_default_project_id(&preflight, &project_name_cache)?; + let resolved_targets = resolve_manifest_targets( + &auth_ctx, + default_project_id.as_deref(), + &manifest, + &mut project_name_cache, + true, + ) + .await?; + validate_duplicate_slugs(&resolved_targets.entries)?; + + let mut published = Vec::new(); + for (manifest_file, resolved_file) in + manifest.files.iter().zip(resolved_targets.per_file.iter()) + { + let source_path = PathBuf::from(&manifest_file.source_file); + push_file( + &auth_ctx, + default_project_id.as_deref(), + &manifest.runtime_context, + &source_path, + manifest_file, + 
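+            // Only entries whose function_type is "sandbox" survived the
+            // retain above, so this publishes just the sandbox bundle; the
+            // loop below then maps each published slug back to a function id.
+            // Slugs follow "{file-stem}-{eval-name}-sandbox", e.g. a file
+            // "My Eval.ts" with Eval("Demo Eval") yields
+            // "my-eval-demo-eval-sandbox" (see the slug unit test).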
&resolved_file.entry_project_ids, + &args, + language, + None, + &classified.allowed_roots, + &mut project_name_cache, + ) + .await + .map_err(|failure| anyhow!(failure.message))?; + + for (entry_index, entry) in manifest_file.entries.iter().enumerate() { + let ManifestEntry::Code(code) = entry else { + continue; + }; + let project_id = resolved_file + .entry_project_ids + .get(entry_index) + .cloned() + .ok_or_else(|| anyhow!("missing resolved project id for sandbox entry"))?; + let function = api::get_function_by_slug(&auth_ctx.client, &project_id, &code.slug) + .await? + .ok_or_else(|| { + anyhow!( + "sandbox function '{}' was not found after publish", + code.slug + ) + })?; + published.push(PublishedSandboxFunction { + slug: code.slug.clone(), + project_id, + function_id: function.id, + }); + } + } + + Ok(published) +} + fn default_code_location(index: usize) -> Value { json!({ "type": "function", @@ -3294,12 +3433,8 @@ mod tests { "scores": [{ "name": "accuracy" }] } }); - let value = build_code_function_data( - &runtime, - sandbox_location.clone(), - "bundle-sandbox-1", - None, - ); + let value = + build_code_function_data(&runtime, sandbox_location.clone(), "bundle-sandbox-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["type"], "bundle"); @@ -3319,12 +3454,8 @@ mod tests { "eval_name": "my-eval", "position": { "type": "task" } }); - let value = build_code_function_data( - &runtime, - experiment_location.clone(), - "bundle-task-1", - None, - ); + let value = + build_code_function_data(&runtime, experiment_location.clone(), "bundle-task-1", None); assert_eq!(value["type"], "code"); assert_eq!(value["data"]["location"], experiment_location); diff --git a/src/sync.rs b/src/sync.rs index e62a856..76e60ce 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -3116,7 +3116,7 @@ async fn resolve_push_experiment_target( ); } - let created = create_experiment(client, &project.id, experiment_selector) + let created = create_experiment(client, &project.id, experiment_selector, false) .await .with_context(|| { format!("experiment '{experiment_selector}' not found, and creating it failed") diff --git a/tests/functions.rs b/tests/functions.rs index fa3de7c..419a351 100644 --- a/tests/functions.rs +++ b/tests/functions.rs @@ -2067,7 +2067,9 @@ exit 24 "expected 1 inserted function (sandbox only)" ); - let sandbox_obj = inserted[0].as_object().expect("sandbox should be an object"); + let sandbox_obj = inserted[0] + .as_object() + .expect("sandbox should be an object"); assert_eq!( sandbox_obj.get("slug").and_then(Value::as_str), Some("my-eval-my-eval-sandbox") @@ -2118,7 +2120,6 @@ exit 24 .and_then(Value::as_str), Some("my-eval") ); - } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] From 1df96cc0e0e68a4152b9e18d10a17a6310c022c4 Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Thu, 19 Mar 2026 21:00:48 +0000 Subject: [PATCH 27/28] Add tests --- src/eval.rs | 145 ++++++++++++++++++ src/functions/push.rs | 14 +- tests/eval_fixtures.rs | 329 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 478 insertions(+), 10 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index ac792e4..04037f9 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -4160,6 +4160,46 @@ mod tests { eval: EvalArgs, } + fn base_args() -> BaseArgs { + BaseArgs { + json: false, + quiet: false, + no_color: false, + profile: None, + org_name: None, + project: None, + api_key: None, + prefer_profile: false, + no_input: false, + api_url: None, + app_url: None, + env_file: None, + } + } + + fn 
make_eval_args(files: Vec<String>) -> EvalArgs {
+        EvalArgs {
+            files,
+            runner: None,
+            language: None,
+            sandbox: EvalSandbox::Local,
+            no_send_logs: false,
+            jsonl: false,
+            terminate_on_failure: false,
+            num_workers: None,
+            list: false,
+            filter: Vec::new(),
+            verbose: false,
+            watch: false,
+            extra_args: Vec::new(),
+            dev: false,
+            dev_host: "localhost".to_string(),
+            dev_port: 8300,
+            dev_org_name: None,
+            dev_allowed_origin: Vec::new(),
+        }
+    }
+
     fn env_test_lock() -> &'static Mutex<()> {
         static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
         LOCK.get_or_init(|| Mutex::new(()))
@@ -4215,6 +4255,12 @@ mod tests {
         path
     }
 
+    fn write_eval_file(dir: &Path, name: &str) -> String {
+        let path = dir.join(name);
+        fs::write(&path, "export {};").expect("eval file should be written");
+        path.to_string_lossy().to_string()
+    }
+
     #[test]
     fn join_app_url_normalizes_slashes() {
         let joined =
@@ -5007,6 +5053,105 @@
     }
 
+    #[test]
+    fn eval_args_parse_sandbox_flag() {
+        let parsed =
+            EvalArgsHarness::try_parse_from(["bt", "--sandbox", "lambda", "sample.eval.ts"])
+                .expect("sandbox flag should parse");
+        assert_eq!(parsed.eval.sandbox, EvalSandbox::Lambda);
+        assert_eq!(parsed.eval.files, vec!["sample.eval.ts".to_string()]);
+    }
+
+    #[tokio::test]
+    async fn sandbox_eval_rejects_dev_mode() {
+        let dir = make_temp_dir("sandbox-dev");
+        let file = write_eval_file(&dir, "sample.eval.ts");
+        let mut args = make_eval_args(vec![file]);
+        args.sandbox = EvalSandbox::Lambda;
+        args.dev = true;
+
+        let err = run(base_args(), args)
+            .await
+            .expect_err("sandbox+dev should fail");
+        assert!(err
+            .to_string()
+            .contains("--sandbox is not supported with --dev."));
+
+        let _ = fs::remove_dir_all(&dir);
+    }
+
+    #[tokio::test]
+    async fn sandbox_eval_rejects_watch_mode() {
+        let dir = make_temp_dir("sandbox-watch");
+        let file = write_eval_file(&dir, "sample.eval.ts");
+        let mut args = make_eval_args(vec![file]);
+        args.sandbox = EvalSandbox::Lambda;
+        args.watch = true;
+
+        let err = run(base_args(), args)
+            .await
+            .expect_err("sandbox+watch should fail");
+        assert!(err
+            .to_string()
+            .contains("--sandbox is not supported with --watch."));
+
+        let _ = fs::remove_dir_all(&dir);
+    }
+
+    #[tokio::test]
+    async fn sandbox_eval_rejects_list_mode() {
+        let dir = make_temp_dir("sandbox-list");
+        let file = write_eval_file(&dir, "sample.eval.ts");
+        let mut args = make_eval_args(vec![file]);
+        args.sandbox = EvalSandbox::Lambda;
+        args.list = true;
+
+        let err = run(base_args(), args)
+            .await
+            .expect_err("sandbox+list should fail");
+        assert!(err
+            .to_string()
+            .contains("--sandbox is not supported with --list."));
+
+        let _ = fs::remove_dir_all(&dir);
+    }
+
+    #[tokio::test]
+    async fn sandbox_eval_rejects_no_send_logs() {
+        let dir = make_temp_dir("sandbox-local");
+        let file = write_eval_file(&dir, "sample.eval.ts");
+        let mut args = make_eval_args(vec![file]);
+        args.sandbox = EvalSandbox::Lambda;
+        args.no_send_logs = true;
+
+        let err = run(base_args(), args)
+            .await
+            .expect_err("sandbox+no-send-logs should fail");
+        assert!(err
+            .to_string()
+            .contains("--sandbox lambda is not supported with --no-send-logs."));
+
+        let _ = fs::remove_dir_all(&dir);
+    }
+
+    #[tokio::test]
+    async fn sandbox_eval_rejects_multiple_files() {
+        let dir = make_temp_dir("sandbox-multi");
+        let first = write_eval_file(&dir, "first.eval.ts");
+        let second = write_eval_file(&dir, "second.eval.ts");
+        let mut args = make_eval_args(vec![first, second]);
+        args.sandbox = EvalSandbox::Lambda;
+
+        let err = run(base_args(), args)
+            .await
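+            // These rejection tests drive run() end to end yet stay offline:
+            // each guard exercised here fires before any login or API call,
+            // so a temp dir holding a stub "export {};" eval file suffices.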
+ .expect_err("sandbox+multiple files should fail"); + assert!(err + .to_string() + .contains("`bt eval --sandbox lambda` currently supports exactly one eval file.")); + + let _ = fs::remove_dir_all(&dir); + } + #[test] fn build_sandbox_summary_query_includes_timestamp_filter() { let query = build_sandbox_summary_query("exp'123", "2026-03-19T12:00:00.000Z"); diff --git a/src/functions/push.rs b/src/functions/push.rs index 89ecd87..e3979f7 100644 --- a/src/functions/push.rs +++ b/src/functions/push.rs @@ -689,9 +689,8 @@ pub(crate) async fn publish_eval_sandbox_functions( ); } - let mut manifest = - run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) - .map_err(|failure| anyhow!(failure.message))?; + let mut manifest = run_functions_runner(&args, &files, language, auth_ctx.client.api_key()) + .map_err(|failure| anyhow!(failure.message))?; for file in &mut manifest.files { file.entries.retain(|entry| match entry { @@ -707,13 +706,8 @@ pub(crate) async fn publish_eval_sandbox_functions( bail!("no sandbox evaluators found in {}", source_file.display()); } - validate_manifest_paths( - &manifest, - &files, - language, - &classified.allowed_roots, - ) - .map_err(|failure| anyhow!(failure.message))?; + validate_manifest_paths(&manifest, &files, language, &classified.allowed_roots) + .map_err(|failure| anyhow!(failure.message))?; let preflight = collect_project_preflight(base, &manifest)?; let mut project_name_cache = diff --git a/tests/eval_fixtures.rs b/tests/eval_fixtures.rs index 84c63a6..1e1d5ba 100644 --- a/tests/eval_fixtures.rs +++ b/tests/eval_fixtures.rs @@ -1,6 +1,10 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fs; +#[cfg(unix)] +use std::io::Write; use std::io::{BufRead, BufReader, Read}; +#[cfg(unix)] +use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::process::{Child, Command, Stdio}; use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; @@ -8,6 +12,8 @@ use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use serde::Deserialize; +#[cfg(unix)] +use serde_json::json; use serde_json::Value; #[derive(Debug, Deserialize, Clone)] @@ -435,11 +441,334 @@ fn eval_runner_list_mode_serializes_parameter_defaults() { ); } +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_js_rows_and_trials() { + let _guard = test_lock(); + if !command_exists("node") { + if required_runtimes().contains("node") { + panic!("node runtime is required but unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_js_rows_and_trials (node not installed)." + ); + return; + } + + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixture_dir = root + .join("tests") + .join("evals") + .join("js") + .join("eval-ts-cjs"); + ensure_dependencies(&fixture_dir); + let runner = local_tsx_path(&fixture_dir).expect("resolve tsx runner"); + let runner_script = root.join("scripts").join("eval-runner.ts"); + let fixture_name = format!("sandbox_rows_{}.eval.ts", unique_test_suffix()); + let fixture_path = fixture_dir.join(&fixture_name); + let fixture_source = r#"import { Eval } from "braintrust"; + +async function* rows() { + yield { input: { case_id: "row-1" }, expected: "alpha" }; + yield { input: { case_id: "row-2" }, expected: "bravo" }; +} + +Eval("sandbox-rows-js", { + data: rows, + task: async (input: { case_id: string }) => + input.case_id === "row-1" ? 
"alpha" : "bravo", + scores: [ + ({ output, expected }: { output: string; expected?: string }) => ({ + name: "match", + score: output === expected ? 1 : 0, + }), + ], + maxConcurrency: 3, + trialCount: 2, +}); +"#; + fs::write(&fixture_path, fixture_source).expect("write js rows fixture"); + + let socket_path = unique_socket_path("bt-eval-js-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&runner) + .arg(&runner_script) + .arg(&fixture_name) + .current_dir(&fixture_dir) + .env("BT_EVAL_DEV_MODE", "rows") + .env( + "BT_EVAL_DEV_REQUEST_JSON", + json!({ + "name": "sandbox-rows-js", + "parameters": {}, + }) + .to_string(), + ) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn js eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-js"); + assert_eq!(ready["max_concurrency"], 3); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-js-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child.wait_with_output().expect("wait for js rows runner"); + if !output.status.success() { + panic!( + "js rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); + let _ = fs::remove_file(&fixture_path); +} + +#[cfg(unix)] +#[test] +fn eval_runner_rows_mode_streams_python_rows_and_trials() { + let _guard = test_lock(); + let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let fixtures_root = root.join("tests").join("evals"); + let fixture_dir = fixtures_root.join("py").join("local_import"); + let python = match ensure_python_env(&fixtures_root.join("py")) { + Some(python) => python, + None => { + if 
required_runtimes().contains("python") { + panic!("python runtime unavailable for rows-mode test"); + } + eprintln!( + "Skipping eval_runner_rows_mode_streams_python_rows_and_trials (python runtime unavailable)." + ); + return; + } + }; + + let runner_script = root.join("scripts").join("eval-runner.py"); + let fixture_name = format!("sandbox_rows_{}.py", unique_test_suffix()); + let fixture_path = fixture_dir.join(&fixture_name); + let fixture_source = r#"from braintrust import Eval + +def rows(): + yield {"input": {"case_id": "row-1"}, "expected": "alpha"} + yield {"input": {"case_id": "row-2"}, "expected": "bravo"} + +def task(input): + return "alpha" if input["case_id"] == "row-1" else "bravo" + +def match(output, expected): + return {"name": "match", "score": 1 if output == expected else 0} + +Eval( + "sandbox-rows-py", + data=rows, + task=task, + scores=[match], + max_concurrency=4, + trial_count=2, +) +"#; + fs::write(&fixture_path, fixture_source).expect("write python rows fixture"); + + let socket_path = unique_socket_path("bt-eval-py-rows"); + let listener = UnixListener::bind(&socket_path).expect("bind unix listener"); + listener + .set_nonblocking(true) + .expect("set listener nonblocking"); + + let mut child = Command::new(&python) + .arg(&runner_script) + .arg(&fixture_name) + .current_dir(&fixture_dir) + .env("BT_EVAL_DEV_MODE", "rows") + .env( + "BT_EVAL_DEV_REQUEST_JSON", + json!({ + "name": "sandbox-rows-py", + "parameters": {}, + }) + .to_string(), + ) + .env("BT_EVAL_PULL_SOCK", &socket_path) + .env("BT_EVAL_LOCAL", "1") + .env("BT_EVAL_NO_SEND_LOGS", "1") + .env("BRAINTRUST_API_KEY", "local") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("spawn python eval runner"); + + let stream = accept_pull_stream(&listener, &mut child, Duration::from_secs(10)); + let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); + let mut writer = stream; + + let ready = read_pull_message(&mut reader); + assert_eq!(ready["type"], "ready"); + assert_eq!(ready["evaluator_name"], "sandbox-rows-py"); + assert_eq!(ready["max_concurrency"], 4); + assert!(ready["experiment_name"] + .as_str() + .is_some_and(|value| value.starts_with("sandbox-rows-py-"))); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let first = read_pull_message(&mut reader); + assert_eq!(first["type"], "row"); + assert_eq!(first["trial_index"], 0); + assert_eq!(first["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let second = read_pull_message(&mut reader); + assert_eq!(second["type"], "row"); + assert_eq!(second["trial_index"], 1); + assert_eq!(second["datum"]["input"]["case_id"], "row-1"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let third = read_pull_message(&mut reader); + assert_eq!(third["type"], "row"); + assert_eq!(third["trial_index"], 0); + assert_eq!(third["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let fourth = read_pull_message(&mut reader); + assert_eq!(fourth["type"], "row"); + assert_eq!(fourth["trial_index"], 1); + assert_eq!(fourth["datum"]["input"]["case_id"], "row-2"); + + write_pull_message(&mut writer, &json!({ "type": "next" })); + let eof = read_pull_message(&mut reader); + assert_eq!(eof["type"], "eof"); + + write_pull_message(&mut writer, &json!({ "type": "close" })); + drop(writer); + drop(reader); + + let output = child + .wait_with_output() + .expect("wait for python rows runner"); + 
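+    // Condensed transcript of the exchange above (newline-delimited JSON
+    // over the unix socket; values illustrative):
+    //   <- {"type":"ready","evaluator_name":"sandbox-rows-py","max_concurrency":4,"experiment_name":"sandbox-rows-py-..."}
+    //   -> {"type":"next"}
+    //   <- {"type":"row","datum":{"input":{"case_id":"row-1"},"expected":"alpha"},"trial_index":0}
+    //   ... (three more next/row exchanges) ...
+    //   -> {"type":"next"}
+    //   <- {"type":"eof"}
+    //   -> {"type":"close"}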
if !output.status.success() { + panic!( + "python rows runner failed with status {}\nstdout:\n{}\nstderr:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + let _ = fs::remove_file(&socket_path); + let _ = fs::remove_file(&fixture_path); +} + fn read_fixture_config(path: &Path) -> FixtureConfig { let raw = fs::read_to_string(path).expect("read fixture.json"); serde_json::from_str(&raw).expect("parse fixture.json") } +#[cfg(unix)] +fn unique_test_suffix() -> String { + format!( + "{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before epoch") + .as_nanos() + ) +} + +#[cfg(unix)] +fn unique_socket_path(prefix: &str) -> PathBuf { + std::env::temp_dir().join(format!("{prefix}-{}.sock", unique_test_suffix())) +} + +#[cfg(unix)] +fn accept_pull_stream(listener: &UnixListener, child: &mut Child, timeout: Duration) -> UnixStream { + let started = Instant::now(); + loop { + match listener.accept() { + Ok((stream, _)) => return stream, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {} + Err(err) => panic!("accept pull stream: {err}"), + } + + if let Some(status) = child.try_wait().expect("try_wait runner") { + panic!("runner exited early with status {status}"); + } + + if started.elapsed() > timeout { + panic!("timed out waiting for runner pull socket connection"); + } + + thread::sleep(Duration::from_millis(25)); + } +} + +#[cfg(unix)] +fn read_pull_message(reader: &mut BufReader) -> Value { + let mut line = String::new(); + let read = reader.read_line(&mut line).expect("read pull message"); + assert!(read > 0, "pull channel closed unexpectedly"); + serde_json::from_str(line.trim()).expect("parse pull message json") +} + +#[cfg(unix)] +fn write_pull_message(writer: &mut UnixStream, payload: &Value) { + writer + .write_all(format!("{payload}\n").as_bytes()) + .expect("write pull message"); + writer.flush().expect("flush pull message"); +} + fn collect_deno_eval_diagnostics(dir: &Path, files: &[String]) -> Option { if !command_exists("deno") { return None; From a68e4cad758f4c03db590cc53712804c480ec93a Mon Sep 17 00:00:00 2001 From: Nate Selvidge Date: Thu, 19 Mar 2026 21:33:56 +0000 Subject: [PATCH 28/28] use unix socket for communication and split data generation to its own runner --- scripts/data-runner.py | 201 ++++++++++++ scripts/data-runner.ts | 201 ++++++++++++ scripts/eval-runner.py | 386 +---------------------- scripts/eval-runner.ts | 665 +++------------------------------------ scripts/runner-common.ts | 447 ++++++++++++++++++++++++++ scripts/runner_common.py | 225 +++++++++++++ src/eval.rs | 312 ++++++++++++------ tests/eval_fixtures.rs | 76 +++-- 8 files changed, 1374 insertions(+), 1139 deletions(-) create mode 100644 scripts/data-runner.py create mode 100644 scripts/data-runner.ts create mode 100644 scripts/runner_common.py diff --git a/scripts/data-runner.py b/scripts/data-runner.py new file mode 100644 index 0000000..35de3d5 --- /dev/null +++ b/scripts/data-runner.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import socket +import sys +import time +from dataclasses import dataclass +from typing import Any + +try: + from braintrust.util import eprint + from runner_common import call_evaluator_data, load_evaluators, to_async_iterator +except Exception as exc: # pragma: no cover - runtime guard + print( + "Unable to import the braintrust package. 
Please install it in your Python environment.", + file=sys.stderr, + ) + print(str(exc), file=sys.stderr) + sys.exit(1) + + +@dataclass +class PullChannel: + sock: socket.socket + + def send(self, payload: Any) -> None: + self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + async def lines(self): + buffer = "" + while True: + chunk = await asyncio.to_thread(self.sock.recv, 4096) + if not chunk: + break + buffer += chunk.decode("utf-8") + while True: + newline = buffer.find("\n") + if newline == -1: + break + line = buffer[:newline].strip() + buffer = buffer[newline + 1 :] + if line: + yield line + + trailing = buffer.strip() + if trailing: + yield trailing + + def close(self) -> None: + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + self.sock.close() + + +def create_pull_channel() -> PullChannel: + sock_path = os.getenv("BT_EVAL_PULL_SOCK") + if not sock_path: + raise ValueError("Missing BT_EVAL_PULL_SOCK") + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return PullChannel(sock) + + +def parse_start_request(raw: str) -> str: + parsed = json.loads(raw) + if not isinstance(parsed, dict): + raise ValueError("Start request must be a JSON object.") + if parsed.get("type") != "start": + raise ValueError("Expected initial start command.") + name = parsed.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Start request must include a non-empty evaluator name.") + return name + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Stream eval rows over a unix socket for bt.") + parser.add_argument("files", nargs="*", help="Eval files or directories to load.") + return parser + + +async def run(files: list[str]) -> int: + evaluators, _reporters = load_evaluators(files) + channel = create_pull_channel() + + try: + line_iter = channel.lines() + try: + start_line = await anext(line_iter) + except StopAsyncIteration: + return 0 + + try: + target_name = parse_start_request(start_line) + except Exception as exc: + channel.send({"type": "error", "message": str(exc)}) + return 1 + + evaluator_instance = next( + (candidate for candidate in evaluators if candidate.evaluator.eval_name == target_name), + None, + ) + if evaluator_instance is None: + channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) + return 1 + + evaluator = evaluator_instance.evaluator + raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) + data_iterator = to_async_iterator(raw_data) + iterator = data_iterator.__aiter__() + + trial_count = getattr(evaluator, "trial_count", 1) + try: + trial_count = int(trial_count) + except Exception: + trial_count = 1 + if trial_count < 1: + trial_count = 1 + + max_concurrency = getattr(evaluator, "max_concurrency", None) + try: + max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 + except Exception: + max_concurrency = 10 + if max_concurrency < 1: + max_concurrency = 1 + + experiment_name = getattr(evaluator, "experiment_name", None) + if not isinstance(experiment_name, str) or not experiment_name: + experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" + + channel.send( + { + "type": "ready", + "evaluator_name": evaluator.eval_name, + "max_concurrency": max_concurrency, + "experiment_name": experiment_name, + } + ) + + current_datum = None + trial_index = 0 + async for line in line_iter: + parsed = json.loads(line) + command_type = parsed.get("type") if 
isinstance(parsed, dict) else None + if command_type == "close": + break + if command_type != "next": + channel.send( + { + "type": "error", + "message": f"Unsupported pull command '{command_type}'", + } + ) + return 1 + + if current_datum is None: + try: + current_datum = await iterator.__anext__() + trial_index = 0 + except StopAsyncIteration: + channel.send({"type": "eof"}) + continue + + channel.send( + { + "type": "row", + "datum": current_datum, + "trial_index": trial_index, + } + ) + trial_index += 1 + if trial_index >= trial_count: + current_datum = None + + return 0 + finally: + channel.close() + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + files = args.files or ["."] + + try: + return asyncio.run(run(files)) + except Exception as exc: + eprint(str(exc)) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/data-runner.ts b/scripts/data-runner.ts new file mode 100644 index 0000000..53cb81e --- /dev/null +++ b/scripts/data-runner.ts @@ -0,0 +1,201 @@ +import net from "node:net"; +import readline from "node:readline"; + +import { + callEvaluatorData, + formatError, + getBraintrustStateGetter, + getEvaluators, + initRegistry, + loadBraintrust, + loadFiles, + normalizeFiles, + propagateInheritedBraintrustState, + toAsyncIterable, +} from "./runner-common"; + +type StartMessage = { + type: "start"; + name: string; +}; + +type ClientMessage = + | StartMessage + | { type: "next" } + | { type: "close" }; + +type ServerMessage = + | { + type: "ready"; + evaluator_name: string; + max_concurrency: number; + experiment_name: string; + } + | { type: "row"; datum: unknown; trial_index: number } + | { type: "eof" } + | { type: "error"; message: string }; + +function writeMessage(socket: net.Socket, message: ServerMessage) { + socket.write(`${JSON.stringify(message)}\n`); +} + +function parseMessage(line: string): ClientMessage { + const parsed = JSON.parse(line) as { type?: unknown; name?: unknown }; + if (parsed.type === "start") { + if (typeof parsed.name !== "string" || parsed.name.length === 0) { + throw new Error("Start request must include a non-empty evaluator name."); + } + return { type: "start", name: parsed.name }; + } + if (parsed.type === "next" || parsed.type === "close") { + return { type: parsed.type }; + } + throw new Error(`Unsupported pull command '${String(parsed.type)}'`); +} + +async function readMessage( + lines: AsyncIterator, +): Promise { + const next = await lines.next(); + if (next.done) { + return null; + } + return parseMessage(next.value); +} + +function applyExtraArgsFromEnv() { + const extraArgs: string[] = process.env.BT_EVAL_EXTRA_ARGS_JSON + ? 
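+      // BT_EVAL_EXTRA_ARGS_JSON holds user args as a JSON array; argv is
+      // then rewritten to [node, script, ...extraArgs] so loaded eval files
+      // that inspect process.argv see only their own arguments.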
(JSON.parse(process.env.BT_EVAL_EXTRA_ARGS_JSON) as string[]) + : []; + process.argv = [...process.argv.slice(0, 2), ...extraArgs]; +} + +function toPositiveInteger(value: unknown, fallback: number): number { + const parsed = Number(value); + if (Number.isFinite(parsed) && parsed > 0) { + return Math.floor(parsed); + } + return fallback; +} + +async function main() { + const files = process.argv.slice(2); + if (files.length === 0) { + throw new Error("No eval files provided."); + } + const socketPath = process.env.BT_EVAL_PULL_SOCK; + if (!socketPath) { + throw new Error("Missing BT_EVAL_PULL_SOCK"); + } + + const normalized = normalizeFiles(files); + const braintrust = await loadBraintrust(normalized); + propagateInheritedBraintrustState(braintrust); + initRegistry(); + applyExtraArgsFromEnv(); + await loadFiles(normalized); + + const socket = net.createConnection({ path: socketPath }); + const socketReady = new Promise((resolve, reject) => { + socket.once("connect", resolve); + socket.once("error", reject); + }); + await socketReady; + + const reader = readline.createInterface({ + input: socket, + crlfDelay: Infinity, + }); + const lines = reader[Symbol.asyncIterator](); + + try { + const start = await readMessage(lines); + if (!start) { + return; + } + if (start.type !== "start") { + throw new Error("Expected initial start command."); + } + + const entry = getEvaluators().find( + (candidate) => candidate.evaluator.evalName === start.name, + ); + if (!entry) { + writeMessage(socket, { + type: "error", + message: `Evaluator '${start.name}' not found`, + }); + return; + } + + const getState = getBraintrustStateGetter(braintrust); + const state = getState ? getState() : undefined; + const evaluator = { + ...entry.evaluator, + ...(state !== undefined && state !== null ? { state } : {}), + }; + const { data: rawData } = callEvaluatorData(evaluator.data); + const dataIterable = toAsyncIterable(rawData); + const iterator = dataIterable[Symbol.asyncIterator](); + const trialCount = toPositiveInteger(evaluator.trialCount, 1); + const maxConcurrency = toPositiveInteger(evaluator.maxConcurrency, 10); + const experimentName = + typeof evaluator.experimentName === "string" && + evaluator.experimentName.length > 0 + ? 
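+        // An explicit experimentName wins; otherwise the runner falls back
+        // to `${evalName}-${Date.now()}`, which is why the fixture tests
+        // only assert a "sandbox-rows-js-" / "sandbox-rows-py-" prefix.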
evaluator.experimentName + : `${entry.evaluator.evalName}-${Date.now()}`; + + writeMessage(socket, { + type: "ready", + evaluator_name: entry.evaluator.evalName, + max_concurrency: maxConcurrency, + experiment_name: experimentName, + }); + + let currentDatum: unknown | undefined; + let trialIndex = 0; + while (true) { + const message = await readMessage(lines); + if (!message || message.type === "close") { + return; + } + if (message.type !== "next") { + throw new Error(`Unsupported pull command '${message.type}'`); + } + + if (currentDatum === undefined) { + const next = await iterator.next(); + if (next.done) { + writeMessage(socket, { type: "eof" }); + continue; + } + currentDatum = next.value; + trialIndex = 0; + } + + writeMessage(socket, { + type: "row", + datum: currentDatum, + trial_index: trialIndex, + }); + + trialIndex += 1; + if (trialIndex >= trialCount) { + currentDatum = undefined; + } + } + } catch (err) { + writeMessage(socket, { + type: "error", + message: formatError(err), + }); + } finally { + reader.close(); + socket.end(); + } +} + +main().catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py index d8ac067..776aa1c 100755 --- a/scripts/eval-runner.py +++ b/scripts/eval-runner.py @@ -1,26 +1,19 @@ #!/usr/bin/env python3 import argparse import asyncio -import fnmatch -import importlib.util import inspect import json import os -import re import socket import sys -import time import traceback from dataclasses import dataclass -from typing import Any, AsyncIterator, Callable +from typing import Any, Callable try: from braintrust import init_dataset, invoke, login from braintrust.framework import ( BaseExperiment, - EvaluatorInstance, - _evals, - _set_lazy_load, run_evaluator, set_thread_pool_max_workers, ) @@ -28,6 +21,15 @@ from braintrust.parameters import parameters_to_json_schema, validate_parameters from braintrust.util import eprint from braintrust.span_identifier_v4 import parse_parent + from runner_common import ( + EvalFilter, + EvaluatorInstance, + call_evaluator_data, + env_flag, + filter_evaluators, + load_evaluators, + parse_serialized_filters, + ) except Exception as exc: # pragma: no cover - runtime guard print( "Unable to import the braintrust package. 
Please install it in your Python environment.", @@ -45,14 +47,6 @@ "/venv/", ) _DATASET_TOTAL_CACHE: dict[str, int] = {} - - -@dataclass(frozen=True) -class EvalFilter: - path: list[str] - pattern: re.Pattern[str] - - @dataclass(frozen=True) class RunnerConfig: jsonl: bool @@ -80,41 +74,6 @@ def close(self) -> None: self.sock.close() -@dataclass -class PullChannel: - sock: socket.socket - - def send(self, payload: Any) -> None: - self.sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) - - async def lines(self) -> AsyncIterator[str]: - buffer = "" - while True: - chunk = await asyncio.to_thread(self.sock.recv, 4096) - if not chunk: - break - buffer += chunk.decode("utf-8") - while True: - newline = buffer.find("\n") - if newline == -1: - break - line = buffer[:newline].strip() - buffer = buffer[newline + 1 :] - if line: - yield line - - trailing = buffer.strip() - if trailing: - yield trailing - - def close(self) -> None: - try: - self.sock.shutdown(socket.SHUT_RDWR) - except OSError: - pass - self.sock.close() - - def serialize_sse_event(event: str, data: Any) -> str: if isinstance(data, (dict, list)): data_str = json.dumps(data) @@ -139,51 +98,10 @@ def create_sse_writer() -> SseWriter | None: return SseWriter(sock) return None - - -def create_pull_channel() -> PullChannel | None: - sock_path = os.getenv("BT_EVAL_PULL_SOCK") - if not sock_path: - return None - - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(sock_path) - return PullChannel(sock) - - -def env_flag(name: str) -> bool: - value = os.getenv(name) - if value is None: - return False - return value.lower() not in {"0", "false", "no", "off", ""} - - -def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]: - if not serialized: - return [] - - parsed = json.loads(serialized) - if not isinstance(parsed, list): - raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array") - - filters: list[EvalFilter] = [] - for i, entry in enumerate(parsed): - if not isinstance(entry, dict): - raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}") - key_path = entry.get("path") - pattern = entry.get("pattern") - if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings") - if not isinstance(pattern, str): - raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string") - filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern))) - return filters - - def parse_dev_mode(value: str | None) -> str | None: if value is None or value == "": return None - if value in {"list", "eval", "rows"}: + if value in {"list", "eval"}: return value raise ValueError(f"Invalid BT_EVAL_DEV_MODE value: {value}") @@ -201,46 +119,6 @@ def read_runner_config() -> RunnerConfig: dev_request_json=os.getenv("BT_EVAL_DEV_REQUEST_JSON"), ) - -def _to_mapping(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_mapping(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_mapping(v) for v in value] - if hasattr(value, "__dict__"): - return { - key: _to_mapping(val) - for key, val in vars(value).items() - if not key.startswith("_") - } - return value - - -def serialize_json_with_plain_string(value: Any) -> str: - if isinstance(value, str): - return value - return json.dumps(value) - - -def evaluate_filter(value: Any, filt: EvalFilter) -> bool: - current = _to_mapping(value) - for part in filt.path: - if not 
isinstance(current, dict) or part not in current: - return False - current = current[part] - return bool(filt.pattern.search(serialize_json_with_plain_string(current))) - - -def filter_evaluators(evaluators: list[EvaluatorInstance], filters: list[EvalFilter]) -> list[EvaluatorInstance]: - if not filters: - return evaluators - return [ - evaluator - for evaluator in evaluators - if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) - ] - - def snake_to_camel(value: str) -> str: parts = value.split("_") if not parts: @@ -348,26 +226,6 @@ def parse_eval_request(raw: str | None) -> dict[str, Any]: return parsed -def parse_eval_pull_request(raw: str | None) -> dict[str, Any]: - if not raw: - raise ValueError("Missing BT_EVAL_DEV_REQUEST_JSON") - try: - parsed = json.loads(raw) - except json.JSONDecodeError as exc: - raise ValueError(f"Invalid BT_EVAL_DEV_REQUEST_JSON: {exc}") from exc - - if not isinstance(parsed, dict): - raise ValueError("BT_EVAL_DEV_REQUEST_JSON must be a JSON object.") - if not isinstance(parsed.get("name"), str) or not parsed["name"]: - raise ValueError("Pull request must include a non-empty evaluator name.") - - parameters = parsed.get("parameters") - if parameters is not None and not isinstance(parameters, dict): - raise ValueError("Pull request parameters must be an object.") - - return parsed - - def resolve_eval_data(data: dict[str, Any]) -> Any: if "data" in data: return data["data"] @@ -390,33 +248,6 @@ def resolve_eval_data(data: dict[str, Any]) -> Any: raise ValueError("Invalid eval data payload.") -async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: - data_result = data - if inspect.isclass(data_result): - data_result = data_result() - if inspect.isfunction(data_result) or inspect.isroutine(data_result): - data_result = data_result() - if inspect.isawaitable(data_result): - data_result = await data_result - - base_experiment_name = None - if isinstance(data_result, BaseExperiment): - base_experiment_name = data_result.name - - return data_result, base_experiment_name - - -def to_async_iterator(value: Any) -> AsyncIterator[Any]: - if inspect.isasyncgen(value): - return value - - async def to_async(it): - for item in it: - yield item - - return to_async(value) - - def make_eval_scorer( score: dict[str, Any], project_id: str | None, @@ -457,16 +288,6 @@ def build_eval_definitions(evaluator_instances: list[EvaluatorInstance]) -> dict return definitions -def collect_files(input_path: str) -> list[str]: - if os.path.isdir(input_path): - matches: list[str] = [] - for root, _, files in os.walk(input_path): - for filename in files: - matches.append(os.path.join(root, filename)) - return matches - return [input_path] - - def is_watchable_dependency(path_input: str, cwd: str) -> bool: path = os.path.abspath(path_input) normalized = path.replace("\\", "/") @@ -502,86 +323,6 @@ def collect_dependency_files(cwd: str, input_files: list[str]) -> list[str]: return sorted(dependencies) -def resolve_module_info(in_file: str) -> tuple[str, list[str]]: - in_file = os.path.abspath(in_file) - module_dir = os.path.dirname(in_file) - module_name = os.path.splitext(os.path.basename(in_file))[0] - - package_parts: list[str] = [] - current = module_dir - while os.path.isfile(os.path.join(current, "__init__.py")): - package_parts.insert(0, os.path.basename(current)) - current = os.path.dirname(current) - - extra_paths = [module_dir] - if package_parts: - module_name = ".".join(package_parts + [module_name]) - if current not in extra_paths: - 
extra_paths.append(current) - - return module_name, extra_paths - - -def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: - evaluator_instances: list[EvaluatorInstance] = [] - reporters: dict[str, Any] = {} - cwd = os.getcwd() - if cwd not in sys.path: - sys.path.insert(0, cwd) - - # Add the project root inferred from input files to sys.path so that - # sibling-package imports work when files live outside CWD (e.g. - # sandbox bundles extracted to a temp directory). Walk up from each - # file's directory looking for a register.py (bundle marker) or the - # filesystem root, whichever comes first. - for f in files: - d = os.path.dirname(os.path.abspath(f)) - while d and d != os.path.dirname(d): - if os.path.isfile(os.path.join(d, "register.py")): - if d not in sys.path: - sys.path.insert(0, d) - break - d = os.path.dirname(d) - - unique_files: set[str] = set() - for file_path in files: - for candidate in collect_files(file_path): - unique_files.add(os.path.abspath(candidate)) - - for file_path in sorted(unique_files): - module_name, extra_paths = resolve_module_info(file_path) - with _set_lazy_load(True): - _evals.clear() - try: - for extra_path in reversed(extra_paths): - if extra_path not in sys.path: - sys.path.insert(0, extra_path) - - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None or spec.loader is None: - raise ImportError(f"Unable to load module spec for {file_path}") - - sys.modules.pop(module_name, None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - - evaluator_instances.extend( - [ - instance - for instance in _evals.evaluators.values() - if isinstance(instance, EvaluatorInstance) - ] - ) - for reporter_name, reporter in _evals.reporters.items(): - if reporter_name not in reporters: - reporters[reporter_name] = reporter - finally: - _evals.clear() - - return evaluator_instances, reporters - - def resolve_reporter( reporter: Any, reporters: dict[str, Any], @@ -944,108 +685,6 @@ async def run_requested_eval( return True -async def run_dataset_pull( - evaluator_instances: list[EvaluatorInstance], - config: RunnerConfig, -) -> bool: - channel = create_pull_channel() - if channel is None: - raise ValueError("Missing BT_EVAL_PULL_SOCK") - - try: - request = parse_eval_pull_request(config.dev_request_json) - except Exception as exc: - channel.send({"type": "error", "message": str(exc)}) - channel.close() - return False - - target_name = request["name"] - evaluator_instance = next( - (candidate for candidate in evaluator_instances if candidate.evaluator.eval_name == target_name), - None, - ) - if evaluator_instance is None: - channel.send({"type": "error", "message": f"Evaluator '{target_name}' not found"}) - channel.close() - return False - - evaluator = evaluator_instance.evaluator - try: - raw_data, _base_experiment_name = await call_evaluator_data(evaluator.data) - data_iterator = to_async_iterator(raw_data) - iterator = data_iterator.__aiter__() - trial_count = getattr(evaluator, "trial_count", 1) - try: - trial_count = int(trial_count) - except Exception: - trial_count = 1 - if trial_count < 1: - trial_count = 1 - - max_concurrency = getattr(evaluator, "max_concurrency", None) - try: - max_concurrency = int(max_concurrency) if max_concurrency is not None else 10 - except Exception: - max_concurrency = 10 - if max_concurrency < 1: - max_concurrency = 1 - - experiment_name = getattr(evaluator, "experiment_name", None) - if not 
isinstance(experiment_name, str) or not experiment_name: - experiment_name = f"{evaluator.eval_name}-{int(time.time() * 1000)}" - - channel.send( - { - "type": "ready", - "evaluator_name": evaluator.eval_name, - "max_concurrency": max_concurrency, - "experiment_name": experiment_name, - } - ) - - current_datum = None - trial_index = 0 - async for line in channel.lines(): - parsed = json.loads(line) - command_type = parsed.get("type") if isinstance(parsed, dict) else None - if command_type == "close": - break - if command_type != "next": - channel.send( - { - "type": "error", - "message": f"Unsupported pull command '{command_type}'", - } - ) - break - - if current_datum is None: - try: - current_datum = await iterator.__anext__() - trial_index = 0 - except StopAsyncIteration: - channel.send({"type": "eof"}) - continue - - channel.send( - { - "type": "row", - "datum": current_datum, - "trial_index": trial_index, - } - ) - trial_index += 1 - if trial_index >= trial_count: - current_datum = None - except Exception as exc: - channel.send({"type": "error", "message": str(exc)}) - channel.close() - return False - - channel.close() - return True - - async def run_once( files: list[str], no_send_logs: bool, @@ -1067,9 +706,6 @@ async def run_once( return True if config.dev_mode == "eval": return await run_requested_eval(evaluators, reporters, no_send_logs, sse, config) - if config.dev_mode == "rows": - return await run_dataset_pull(evaluators, config) - if config.list_only: for evaluator_instance in evaluators: print(evaluator_instance.evaluator.eval_name) diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts index fdd7152..40f4b86 100644 --- a/scripts/eval-runner.ts +++ b/scripts/eval-runner.ts @@ -2,19 +2,27 @@ import { createRequire } from "node:module"; import path from "node:path"; import { fileURLToPath, pathToFileURL } from "node:url"; -type EvaluatorDefinition = { - evalName: string; - projectName: string; - data?: unknown; - trialCount?: unknown; - maxConcurrency?: unknown; - experimentName?: unknown; -} & Record; - -type EvaluatorEntry = { - evaluator: EvaluatorDefinition; - reporter?: unknown; -}; +import { + envFlag, + filterEvaluators, + formatError, + getBraintrustStateGetter, + getEvaluators, + getReporters, + initRegistry, + isObject, + loadBraintrust, + loadBraintrustUtilParseParent, + loadFiles, + normalizeFiles, + parseSerializedFilters, + propagateInheritedBraintrustState, + resolveBraintrustPath, + type EvalFilter, + type EvaluatorEntry, + type GlobalEvals, + type ParseParentFunction, +} from "./runner-common"; type EvalResult = { results: Array<{ error?: unknown }>; @@ -49,23 +57,6 @@ type InitDatasetFunction = ( ) => unknown; type InvokeFunction = (options: Record) => Promise; -type BraintrustModule = { - Eval?: EvalFunction; - login?: LoginFunction; - initDataset?: InitDatasetFunction; - invoke?: InvokeFunction; - _internalGetGlobalState?: () => unknown; - default?: BraintrustModule; -}; - -type GlobalEvals = { - functions: unknown[]; - prompts: unknown[]; - parameters: unknown[]; - evaluators: Record; - reporters: Record; -}; - type BtEvalMain = (context: BtEvalContext) => void | Promise; type BtEvalContext = { @@ -89,16 +80,6 @@ type SseWriter = { close: () => void; }; -type EvalFilter = { - path: string[]; - pattern: RegExp; -}; - -type SerializedEvalFilter = { - path: string[]; - pattern: string; -}; - type EvalScoreSpec = { name: string; function_id: Record; @@ -117,17 +98,12 @@ type EvalRequest = { scores?: EvalScoreSpec[]; }; -type EvalPullRequest = { - name: 
string; - parameters?: Record; -}; - type RunnerConfig = { jsonl: boolean; list: boolean; terminateOnFailure: boolean; filters: EvalFilter[]; - devMode: "list" | "eval" | "rows" | null; + devMode: "list" | "eval" | null; devRequestJson: string | null; }; @@ -156,8 +132,6 @@ type EvalRunner = { type ParameterContainerSerializer = (parameters: unknown) => unknown; type PromptDefinitionSerializer = (prompt: unknown) => unknown; type ZodSchemaSerializer = (schema: unknown) => Record; -type ParseParentFunction = (parent: unknown) => string | undefined; - type ParameterSerializationHelpers = { sdkSerializeParameters: ParameterContainerSerializer | null; promptDefinitionToPromptData: PromptDefinitionSerializer | null; @@ -173,94 +147,13 @@ declare global { var __inherited_braintrust_state: unknown; } -function isObject(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} - -function isBraintrustModule(value: unknown): value is BraintrustModule { - return isObject(value) && ("Eval" in value || "login" in value); -} - -function normalizeBraintrustModule(value: unknown): BraintrustModule { - if (isBraintrustModule(value)) { - return value; - } - if (isObject(value) && isBraintrustModule(value.default)) { - return value.default; - } - throw new Error("Unable to load braintrust module."); -} - -function normalizeFiles(files: string[]): string[] { - return files.map((file) => path.resolve(process.cwd(), file)); -} - -function envFlag(name: string): boolean { - const value = process.env[name]; - if (!value) { - return false; - } - const normalized = value.toLowerCase(); - return !["0", "false", "no", "off", ""].includes(normalized); -} - -function serializeJSONWithPlainString(value: unknown): string { - if (typeof value === "string") { - return value; - } - return JSON.stringify(value); -} - -function parseSerializedFilters(serialized: string | undefined): EvalFilter[] { - if (!serialized) { - return []; - } - - try { - const parsed = JSON.parse(serialized); - if (!Array.isArray(parsed)) { - throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array."); - } - return parsed.map((value) => { - if (!isObject(value)) { - throw new Error( - "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.", - ); - } - const { path: rawPath, pattern: rawPattern } = - value as SerializedEvalFilter; - if ( - !Array.isArray(rawPath) || - !rawPath.every((part) => typeof part === "string") - ) { - throw new Error( - "BT_EVAL_FILTER_PARSED entry path must be an array of strings.", - ); - } - if (typeof rawPattern !== "string") { - throw new Error( - "BT_EVAL_FILTER_PARSED entry pattern must be a string.", - ); - } - return { - path: rawPath, - pattern: new RegExp(rawPattern), - }; - }); - } catch (err) { - throw new Error( - `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? 
err.message : String(err)}`, - ); - } -} - function parseDevMode( value: string | undefined, -): "list" | "eval" | "rows" | null { +): "list" | "eval" | null { if (!value) { return null; } - if (value === "list" || value === "eval" || value === "rows") { + if (value === "list" || value === "eval") { return value; } throw new Error(`Invalid BT_EVAL_DEV_MODE value: ${value}`); @@ -780,128 +673,6 @@ function createSseWriter(): SseWriter | null { return { send, close }; } -type PullChannel = { - send: (payload: unknown) => void; - close: () => void; - lines: () => AsyncGenerator; -}; - -function createPullChannel(): PullChannel | null { - const netModule = (() => { - try { - return runtimeRequire("node:net") as NetModule; - } catch { - return null; - } - })(); - const sock = process.env.BT_EVAL_PULL_SOCK; - if (!sock) { - return null; - } - if (!netModule) { - return null; - } - - const socket = netModule.createConnection({ path: sock }); - socket.setNoDelay(true); - - const send = (payload: unknown) => { - if (!socket.writable) { - return; - } - socket.write(`${JSON.stringify(payload)}\n`); - }; - - const close = () => { - socket.end(); - }; - - const lines = async function* () { - let buffer = ""; - for await (const chunk of socket as unknown as AsyncIterable< - Buffer | string - >) { - buffer += typeof chunk === "string" ? chunk : chunk.toString("utf8"); - while (true) { - const newline = buffer.indexOf("\n"); - if (newline === -1) { - break; - } - const line = buffer.slice(0, newline).trim(); - buffer = buffer.slice(newline + 1); - if (line.length > 0) { - yield line; - } - } - } - const trailing = buffer.trim(); - if (trailing.length > 0) { - yield trailing; - } - }; - - return { send, close, lines }; -} - -function initRegistry() { - globalThis._evals = { - functions: [], - prompts: [], - parameters: [], - evaluators: {}, - reporters: {}, - }; - globalThis._lazy_load = true; -} - -function ensureBraintrustAvailable() { - resolveBraintrustPath(); -} - -function resolveBraintrustPath(): string { - const files = normalizeFiles(process.argv.slice(2)); - for (const file of files) { - try { - const require = createRequire(pathToFileURL(file).href); - return require.resolve("braintrust"); - } catch { - continue; - } - } - - try { - const require = createRequire(process.cwd() + "/"); - return require.resolve("braintrust"); - } catch { - const message = - "Unable to resolve the `braintrust` package. " + - "Please install it in your project (e.g. 
`pnpm add braintrust` or `npm install braintrust`)."; - throw new Error(message); - } -} - -async function loadBraintrust() { - const cjsPath = resolveBraintrustPath(); - const cjsUrl = pathToFileURL(cjsPath).href; - - try { - const mod: unknown = await import(cjsUrl); - return normalizeBraintrustModule(mod); - } catch {} - - const esmPath = cjsPath.replace(/\.js$/, ".mjs"); - if (esmPath !== cjsPath && fsMutable.existsSync(esmPath)) { - try { - const mod: unknown = await import(pathToFileURL(esmPath).href); - return normalizeBraintrustModule(mod); - } catch {} - } - - const require = createRequire(cjsUrl); - const mod: unknown = require(cjsPath); - return normalizeBraintrustModule(mod); -} - function extractParameterSerializer( mod: unknown, ): ParameterContainerSerializer | null { @@ -1041,7 +812,7 @@ function loadZodSchemaSerializer( } async function loadParameterSerializationHelpers(): Promise { - const braintrustPath = resolveBraintrustPath(); + const braintrustPath = resolveBraintrustPath(process.argv.slice(2)); const zodToJsonSchema = loadZodSchemaSerializer(braintrustPath); try { const mod: unknown = await import(pathToFileURL(braintrustPath).href); @@ -1059,169 +830,6 @@ async function loadParameterSerializationHelpers(): Promise unknown) | null { - if (!isObject(mod)) { - return null; - } - const candidate = Reflect.get(mod, "_internalGetGlobalState"); - if (typeof candidate === "function") { - return candidate as () => unknown; - } - const defaultExport = Reflect.get(mod, "default"); - if (isObject(defaultExport)) { - const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState"); - if (typeof fromDefault === "function") { - return fromDefault as () => unknown; - } - } - return null; -} - -function loadBraintrustUtilParseParent(): ParseParentFunction | null { - const braintrustPath = resolveBraintrustPath(); - const requireFromBraintrust = createRequire( - pathToFileURL(braintrustPath).href, - ); - try { - const utilMod: unknown = requireFromBraintrust("braintrust/util"); - return extractParseParent(utilMod); - } catch { - return null; - } -} - -function propagateInheritedBraintrustState(braintrust: BraintrustModule) { - const getter = (braintrust as Record) - ._internalGetGlobalState; - if (typeof getter !== "function") { - return; - } - const state = getter(); - if (state !== undefined && state !== null) { - globalThis.__inherited_braintrust_state = state; - } -} - -async function loadFiles(files: string[]): Promise { - const modules: unknown[] = []; - // Internal CLI-controlled flag for ESM retry; not user-facing config. - const forceEsm = envFlag("BT_EVAL_FORCE_ESM"); - // vite-node installs transform hooks that handle TypeScript (including - // extension-less imports) and CJS named-export interop natively. A failed - // require() corrupts Node's module cache and causes the subsequent import() - // to hit the "module imported again after being required" bug, so we skip - // require() for .ts/.tsx files entirely when running under vite-node. 
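For reference, the loading order this comment describes reduces to the following minimal sketch; the helper name `loadModuleWithFallback` is illustrative only and not part of the patch, and the vite-node and BT_EVAL_FORCE_ESM special cases handled by the real loadFiles() are omitted here:

    import { createRequire } from "node:module";
    import { pathToFileURL } from "node:url";

    // Illustrative sketch of the CJS-first, ESM-fallback strategy described
    // in the comment above (assumptions: no vite-node, no forced-ESM flag).
    async function loadModuleWithFallback(file: string): Promise<unknown> {
      const fileUrl = pathToFileURL(file).href;
      try {
        // Try require() first; a transform hook (e.g. tsx) makes this work
        // for TypeScript sources as well as plain .cjs files.
        return createRequire(fileUrl)(file);
      } catch {
        // Fall back to native ESM import. Under vite-node the real loader
        // skips the require() attempt so a failure cannot poison Node's
        // module cache in the first place.
        return await import(fileUrl);
      }
    }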
- const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node"; - for (const file of files) { - const fileUrl = pathToFileURL(file).href; - const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx"); - const preferRequire = - !forceEsm && - !(isViteNode && isTypeScript) && - (isTypeScript || file.endsWith(".cjs")); - - if (preferRequire) { - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (esmErr) { - throw new Error( - `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`, - ); - } - } - } - - try { - const mod = await import(fileUrl); - modules.push(mod); - continue; - } catch (err) { - if (!shouldTryRequire(file, err)) { - throw err; - } - try { - const require = createRequire(fileUrl); - const mod = require(file); - modules.push(mod); - continue; - } catch (requireErr) { - throw new Error( - `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`, - ); - } - } - } - return modules; -} - -function shouldTryRequire(file: string, err: unknown): boolean { - if (envFlag("BT_EVAL_FORCE_ESM")) { - return false; - } - if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") { - return false; - } - if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) { - return true; - } - if ( - (file.endsWith(".ts") || file.endsWith(".tsx")) && - isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION") - ) { - return true; - } - if (!(err instanceof Error)) { - return false; - } - const message = err.message || ""; - return ( - message.includes("require is not defined") || - message.includes("exports is not defined") || - message.includes("module is not defined") || - message.includes("Cannot use import statement outside a module") - ); -} - -function isNodeErrorCode(err: unknown, code: string): boolean { - if (!isObject(err) || !("code" in err)) { - return false; - } - return typeof err.code === "string" && err.code === code; -} - -function formatError(err: unknown): string { - if (err instanceof Error) { - return err.message; - } - return String(err); -} - function createEvalProgressReporter( sse: SseWriter | null, evaluatorName: string, @@ -1293,22 +901,6 @@ function sendConsole( sse.send("console", { stream, message }); } -function getEvaluators(): EvaluatorEntry[] { - const evals = globalThis._evals; - if (!evals || !evals.evaluators) { - return []; - } - return Object.values(evals.evaluators) as EvaluatorEntry[]; -} - -function getReporters(): Record { - const evals = globalThis._evals; - if (!evals || !evals.reporters) { - return {}; - } - return evals.reporters as Record; -} - function resolveReporter( reporter: unknown, reporters: Record, @@ -1336,34 +928,6 @@ function resolveReporter( ); } -function evaluateFilter( - object: Record, - filter: EvalFilter, -): boolean { - const key = filter.path.reduce((acc, part) => { - if (!isObject(acc)) { - return undefined; - } - return acc[part]; - }, object); - if (key === undefined) { - return false; - } - return filter.pattern.test(serializeJSONWithPlainString(key)); -} - -function filterEvaluators( - evaluators: EvaluatorEntry[], - filters: EvalFilter[], -): EvaluatorEntry[] { - if (filters.length === 0) { - return evaluators; - } - return evaluators.filter((entry) => - filters.every((filter) => evaluateFilter(entry.evaluator, filter)), - ); -} - function extractScoreName(score: unknown, idx: 
number): string { if (typeof score === "function" && typeof score.name === "string") { return score.name || `scorer_${idx}`; @@ -1482,24 +1046,6 @@ async function buildEvaluatorDefinitions(evaluators: EvaluatorEntry[]) { return result; } -function parseEvalPullRequest(raw: string | null): EvalPullRequest { - if (!raw) { - throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); - } - const parsed = JSON.parse(raw); - if (!isObject(parsed) || typeof parsed.name !== "string" || parsed.name.length === 0) { - throw new Error("Pull request must include a non-empty evaluator name."); - } - const request = parsed as EvalPullRequest; - if ( - request.parameters !== undefined && - (!isObject(request.parameters) || Array.isArray(request.parameters)) - ) { - throw new Error("Pull request parameters must be an object."); - } - return request; -} - function parseEvalRequest(raw: string | null): EvalRequest { if (!raw) { throw new Error("Missing BT_EVAL_DEV_REQUEST_JSON"); @@ -1582,48 +1128,6 @@ function resolveEvalData( throw new Error("Invalid eval data payload."); } -function callEvaluatorData( - data: unknown, -): { data: unknown; baseExperiment: string | undefined } { - const dataResult = typeof data === "function" ? (data as () => unknown)() : data; - let baseExperiment: string | undefined = undefined; - if ( - isObject(dataResult) && - Reflect.get(dataResult, "_type") === "BaseExperiment" && - typeof Reflect.get(dataResult, "name") === "string" - ) { - baseExperiment = Reflect.get(dataResult, "name") as string; - } - return { data: dataResult, baseExperiment }; -} - -function toAsyncIterable(value: unknown): AsyncIterable { - if ( - typeof value === "object" && - value !== null && - Symbol.asyncIterator in value && - typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" - ) { - return value as AsyncIterable; - } - if ( - typeof value === "object" && - value !== null && - Symbol.iterator in value && - typeof (value as Iterable)[Symbol.iterator] === "function" - ) { - const iterable = value as Iterable; - return (async function* () { - for (const item of iterable) { - yield item; - } - })(); - } - throw new Error( - "Evaluator data must be an array, iterable, or async iterable", - ); -} - function convertFunctionId( functionId: Record, ): Record { @@ -1798,102 +1302,6 @@ async function runRequestedEval(config: RunnerConfig, runner: EvalRunner) { } } -async function runDatasetPull(config: RunnerConfig, runner: EvalRunner) { - const channel = createPullChannel(); - if (!channel) { - throw new Error("Missing BT_EVAL_PULL_SOCK"); - } - - try { - const request = parseEvalPullRequest(config.devRequestJson); - const entry = getEvaluators().find( - (candidate) => candidate.evaluator.evalName === request.name, - ); - if (!entry) { - channel.send({ - type: "error", - message: `Evaluator '${request.name}' not found`, - }); - return; - } - - const state = runner.getState ? runner.getState() : undefined; - const evaluator = { - ...entry.evaluator, - ...(state !== undefined && state !== null ? { state } : {}), - }; - const { data: rawData } = callEvaluatorData(evaluator.data); - const dataIterable = toAsyncIterable(rawData); - const iterator = dataIterable[Symbol.asyncIterator](); - const trialCountRaw = Number(evaluator.trialCount ?? 1); - const trialCount = - Number.isFinite(trialCountRaw) && trialCountRaw > 0 - ? Math.floor(trialCountRaw) - : 1; - const maxConcurrencyRaw = Number(evaluator.maxConcurrency ?? 
10); - const maxConcurrency = - Number.isFinite(maxConcurrencyRaw) && maxConcurrencyRaw > 0 - ? Math.floor(maxConcurrencyRaw) - : 10; - const experimentName = - typeof evaluator.experimentName === "string" && - evaluator.experimentName.length > 0 - ? evaluator.experimentName - : `${entry.evaluator.evalName}-${Date.now()}`; - - channel.send({ - type: "ready", - evaluator_name: entry.evaluator.evalName, - max_concurrency: maxConcurrency, - experiment_name: experimentName, - }); - - let currentDatum: unknown | undefined = undefined; - let trialIndex = 0; - for await (const line of channel.lines()) { - const parsed = JSON.parse(line) as { type?: string }; - if (parsed.type === "close") { - break; - } - if (parsed.type !== "next") { - channel.send({ - type: "error", - message: `Unsupported pull command '${String(parsed.type)}'`, - }); - break; - } - - if (currentDatum === undefined) { - const next = await iterator.next(); - if (next.done) { - channel.send({ type: "eof" }); - continue; - } - currentDatum = next.value; - trialIndex = 0; - } - - channel.send({ - type: "row", - datum: currentDatum, - trial_index: trialIndex, - }); - - trialIndex += 1; - if (trialIndex >= trialCount) { - currentDatum = undefined; - } - } - } catch (err) { - channel.send({ - type: "error", - message: err instanceof Error ? err.message : String(err), - }); - } finally { - channel.close(); - } -} - function extractBtEvalMain(mod: unknown): BtEvalMain | null { if (!mod || typeof mod !== "object") { return null; @@ -2064,9 +1472,12 @@ function mergeProgress( }; } -async function createEvalRunner(config: RunnerConfig): Promise { - const braintrust = await loadBraintrust(); - const Eval = braintrust.Eval; +async function createEvalRunner( + config: RunnerConfig, + files: string[], +): Promise { + const braintrust = await loadBraintrust(files); + const Eval = braintrust.Eval as EvalFunction | undefined; if (typeof Eval !== "function") { throw new Error("Unable to load Eval() from braintrust package."); } @@ -2076,8 +1487,8 @@ async function createEvalRunner(config: RunnerConfig): Promise { const sse = createSseWriter(); const noSendLogs = shouldDisableSendLogs(); - const parseParent = loadBraintrustUtilParseParent(); - const getState = extractGlobalStateGetter(braintrust); + const parseParent = loadBraintrustUtilParseParent(files); + const getState = getBraintrustStateGetter(braintrust); const makeEvalOptions = ( evaluatorName: string, @@ -2229,8 +1640,7 @@ async function main() { maybeRecordDependency(file); } collectStaticLocalDependencies(normalized); - ensureBraintrustAvailable(); - const braintrust = await loadBraintrust(); + const braintrust = await loadBraintrust(normalized); propagateInheritedBraintrustState(braintrust); initRegistry(); // Replace process.argv with [runtime, script, ...extraArgs] so that user @@ -2243,7 +1653,7 @@ async function main() { const modules = await loadFiles(normalized); const btEvalMains = collectBtEvalMains(modules); - const runner = await createEvalRunner(config); + const runner = await createEvalRunner(config, normalized); if (!runner.noSendLogs && typeof runner.login === "function") { try { await runner.login({}); @@ -2294,11 +1704,6 @@ async function main() { return; } - if (config.devMode === "rows") { - await runDatasetPull(config, runner); - return; - } - if (config.list) { for (const entry of filteredEvaluators) { console.log(entry.evaluator.evalName); diff --git a/scripts/runner-common.ts b/scripts/runner-common.ts index 8a0dd63..a98c612 100644 --- a/scripts/runner-common.ts 
+++ b/scripts/runner-common.ts
@@ -1,3 +1,8 @@
+import { createRequire } from "node:module";
+import fs from "node:fs";
+import path from "node:path";
+import { pathToFileURL } from "node:url";
+
 export type JsonPrimitive = string | number | boolean | null;
 export type JsonArray = JsonValue[];
 export type JsonObject = { [key: string]: JsonValue };
@@ -13,6 +18,56 @@ export type ProjectRef = {
   name?: string;
 };
 
+export type EvaluatorDefinition = {
+  evalName: string;
+  projectName: string;
+  data?: unknown;
+  trialCount?: unknown;
+  maxConcurrency?: unknown;
+  experimentName?: unknown;
+} & Record<string, unknown>;
+
+export type EvaluatorEntry = {
+  evaluator: EvaluatorDefinition;
+  reporter?: unknown;
+};
+
+export type BraintrustModule = {
+  Eval?: (...args: unknown[]) => unknown;
+  login?: (...args: unknown[]) => Promise<unknown>;
+  initDataset?: (...args: unknown[]) => unknown;
+  invoke?: (...args: unknown[]) => Promise<unknown>;
+  _internalGetGlobalState?: () => unknown;
+  default?: BraintrustModule;
+};
+
+export type GlobalEvals = {
+  functions: unknown[];
+  prompts: unknown[];
+  parameters: unknown[];
+  evaluators: Record<string, unknown>;
+  reporters: Record<string, unknown>;
+};
+
+export type EvalFilter = {
+  path: string[];
+  pattern: RegExp;
+};
+
+export type SerializedEvalFilter = {
+  path: string[];
+  pattern: string;
+};
+
+declare global {
+  // eslint-disable-next-line no-var
+  var _evals: GlobalEvals | undefined;
+  // eslint-disable-next-line no-var
+  var _lazy_load: boolean | undefined;
+  // eslint-disable-next-line no-var
+  var __inherited_braintrust_state: unknown;
+}
+
 export function asProjectSelector(
   project: ProjectRef | undefined,
 ): ProjectSelector {
@@ -81,3 +136,395 @@ export function toJsonValue(input: JsonValue): JsonValue {
 
   return input;
 }
+
+export function isObject(value: unknown): value is Record<string, unknown> {
+  return typeof value === "object" && value !== null;
+}
+
+export function normalizeFiles(files: string[]): string[] {
+  return files.map((file) => path.resolve(process.cwd(), file));
+}
+
+export function envFlag(name: string): boolean {
+  const value = process.env[name];
+  if (!value) {
+    return false;
+  }
+  const normalized = value.toLowerCase();
+  return !["0", "false", "no", "off", ""].includes(normalized);
+}
+
+export function serializeJSONWithPlainString(value: unknown): string {
+  if (typeof value === "string") {
+    return value;
+  }
+  return JSON.stringify(value);
+}
+
+export function parseSerializedFilters(
+  serialized: string | undefined,
+): EvalFilter[] {
+  if (!serialized) {
+    return [];
+  }
+
+  try {
+    const parsed = JSON.parse(serialized);
+    if (!Array.isArray(parsed)) {
+      throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array.");
+    }
+    return parsed.map((value) => {
+      if (!isObject(value)) {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.",
+        );
+      }
+      const { path: rawPath, pattern: rawPattern } =
+        value as SerializedEvalFilter;
+      if (
+        !Array.isArray(rawPath) ||
+        !rawPath.every((part) => typeof part === "string")
+      ) {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entry path must be an array of strings.",
+        );
+      }
+      if (typeof rawPattern !== "string") {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entry pattern must be a string.",
+        );
+      }
+      return {
+        path: rawPath,
+        pattern: new RegExp(rawPattern),
+      };
+    });
+  } catch (err) {
+    throw new Error(
+      `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`,
+    );
+  }
+}
+
+export function formatError(err: unknown): string {
+  if (err instanceof Error) {
+    return err.message;
+  }
+  return String(err);
+}
+
+export function initRegistry() {
+  globalThis._evals = {
+    functions: [],
+    prompts: [],
+    parameters: [],
+    evaluators: {},
+    reporters: {},
+  };
+  globalThis._lazy_load = true;
+}
+
+function isBraintrustModule(value: unknown): value is BraintrustModule {
+  return isObject(value) && ("Eval" in value || "login" in value);
+}
+
+function normalizeBraintrustModule(value: unknown): BraintrustModule {
+  if (isBraintrustModule(value)) {
+    return value;
+  }
+  if (isObject(value) && isBraintrustModule(value.default)) {
+    return value.default;
+  }
+  throw new Error("Unable to load braintrust module.");
+}
+
+export function resolveBraintrustPath(files: string[]): string {
+  const normalizedFiles = normalizeFiles(files);
+  for (const file of normalizedFiles) {
+    try {
+      const require = createRequire(pathToFileURL(file).href);
+      return require.resolve("braintrust");
+    } catch {
+      continue;
+    }
+  }
+
+  try {
+    const require = createRequire(process.cwd() + "/");
+    return require.resolve("braintrust");
+  } catch {
+    const message =
+      "Unable to resolve the `braintrust` package. " +
+      "Please install it in your project (e.g. `pnpm add braintrust` or `npm install braintrust`).";
+    throw new Error(message);
+  }
+}
+
+export async function loadBraintrust(
+  files: string[],
+): Promise<BraintrustModule> {
+  const cjsPath = resolveBraintrustPath(files);
+  const cjsUrl = pathToFileURL(cjsPath).href;
+
+  try {
+    const mod: unknown = await import(cjsUrl);
+    return normalizeBraintrustModule(mod);
+  } catch {}
+
+  const esmPath = cjsPath.replace(/\.js$/, ".mjs");
+  if (esmPath !== cjsPath && fs.existsSync(esmPath)) {
+    try {
+      const mod: unknown = await import(pathToFileURL(esmPath).href);
+      return normalizeBraintrustModule(mod);
+    } catch {}
+  }
+
+  const require = createRequire(cjsUrl);
+  const mod: unknown = require(cjsPath);
+  return normalizeBraintrustModule(mod);
+}
+
+export type ParseParentFunction = (parent: unknown) => string | undefined;
+
+function extractParseParent(mod: unknown): ParseParentFunction | null {
+  if (!isObject(mod)) {
+    return null;
+  }
+  const candidate = Reflect.get(mod, "parseParent");
+  if (typeof candidate === "function") {
+    return candidate as ParseParentFunction;
+  }
+  const defaultExport = Reflect.get(mod, "default");
+  if (isObject(defaultExport)) {
+    const fromDefault = Reflect.get(defaultExport, "parseParent");
+    if (typeof fromDefault === "function") {
+      return fromDefault as ParseParentFunction;
+    }
+  }
+  return null;
+}
+
+export function loadBraintrustUtilParseParent(
+  files: string[],
+): ParseParentFunction | null {
+  const braintrustPath = resolveBraintrustPath(files);
+  const requireFromBraintrust = createRequire(
+    pathToFileURL(braintrustPath).href,
+  );
+  try {
+    const utilMod: unknown = requireFromBraintrust("braintrust/util");
+    return extractParseParent(utilMod);
+  } catch {
+    return null;
+  }
+}
+
+function extractGlobalStateGetter(mod: unknown): (() => unknown) | null {
+  if (!isObject(mod)) {
+    return null;
+  }
+  const candidate = Reflect.get(mod, "_internalGetGlobalState");
+  if (typeof candidate === "function") {
+    return candidate as () => unknown;
+  }
+  const defaultExport = Reflect.get(mod, "default");
+  if (isObject(defaultExport)) {
+    const fromDefault = Reflect.get(defaultExport, "_internalGetGlobalState");
+    if (typeof fromDefault === "function") {
+      return fromDefault as () => unknown;
+    }
+  }
+  return null;
+}
+
+export function getBraintrustStateGetter(
+  braintrust: BraintrustModule,
+): (() => unknown) | null {
+  return extractGlobalStateGetter(braintrust);
+}
+
+export function propagateInheritedBraintrustState(braintrust: BraintrustModule) {
+  const getter = getBraintrustStateGetter(braintrust);
+  if (!getter) {
+    return;
+  }
+  const state = getter();
+  if (state !== undefined && state !== null) {
+    globalThis.__inherited_braintrust_state = state;
+  }
+}
+
+export async function loadFiles(files: string[]): Promise<unknown[]> {
+  const modules: unknown[] = [];
+  const forceEsm = envFlag("BT_EVAL_FORCE_ESM");
+  const isViteNode = process.env.BT_EVAL_RUNNER_KIND === "vite-node";
+  for (const file of files) {
+    const fileUrl = pathToFileURL(file).href;
+    const isTypeScript = file.endsWith(".ts") || file.endsWith(".tsx");
+    const preferRequire =
+      !forceEsm &&
+      !(isViteNode && isTypeScript) &&
+      (isTypeScript || file.endsWith(".cjs"));
+
+    if (preferRequire) {
+      try {
+        const require = createRequire(fileUrl);
+        const mod = require(file);
+        modules.push(mod);
+        continue;
+      } catch (requireErr) {
+        try {
+          const mod = await import(fileUrl);
+          modules.push(mod);
+          continue;
+        } catch (esmErr) {
+          throw new Error(
+            `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`,
+          );
+        }
+      }
+    }
+
+    try {
+      const mod = await import(fileUrl);
+      modules.push(mod);
+      continue;
+    } catch (err) {
+      if (!shouldTryRequire(file, err)) {
+        throw err;
+      }
+      try {
+        const require = createRequire(fileUrl);
+        const mod = require(file);
+        modules.push(mod);
+        continue;
+      } catch (requireErr) {
+        throw new Error(
+          `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`,
+        );
+      }
+    }
+  }
+  return modules;
+}
+
+function shouldTryRequire(file: string, err: unknown): boolean {
+  if (envFlag("BT_EVAL_FORCE_ESM")) {
+    return false;
+  }
+  if (process.env.BT_EVAL_RUNNER_KIND === "vite-node") {
+    return false;
+  }
+  if (process.env.BT_EVAL_CJS === "1" || file.endsWith(".cjs")) {
+    return true;
+  }
+  if (
+    (file.endsWith(".ts") || file.endsWith(".tsx")) &&
+    isNodeErrorCode(err, "ERR_UNKNOWN_FILE_EXTENSION")
+  ) {
+    return true;
+  }
+  if (!(err instanceof Error)) {
+    return false;
+  }
+  const message = err.message || "";
+  return (
+    message.includes("require is not defined") ||
+    message.includes("exports is not defined") ||
+    message.includes("module is not defined") ||
+    message.includes("Cannot use import statement outside a module")
+  );
+}
+
+function isNodeErrorCode(err: unknown, code: string): boolean {
+  if (!isObject(err) || !("code" in err)) {
+    return false;
+  }
+  return typeof err.code === "string" && err.code === code;
+}
+
+export function getEvaluators(): EvaluatorEntry[] {
+  const evals = globalThis._evals;
+  if (!evals || !evals.evaluators) {
+    return [];
+  }
+  return Object.values(evals.evaluators) as EvaluatorEntry[];
+}
+
+export function getReporters(): Record<string, unknown> {
+  const evals = globalThis._evals;
+  if (!evals || !evals.reporters) {
+    return {};
+  }
+  return evals.reporters as Record<string, unknown>;
+}
+
+export function evaluateFilter(
+  object: Record<string, unknown>,
+  filter: EvalFilter,
+): boolean {
+  const key = filter.path.reduce<unknown>((acc, part) => {
+    if (!isObject(acc)) {
+      return undefined;
+    }
+    return acc[part];
+  }, object);
+  if (key === undefined) {
+    return false;
+  }
+  return filter.pattern.test(serializeJSONWithPlainString(key));
+}
+
+export function filterEvaluators(
+  evaluators: EvaluatorEntry[],
+  filters: EvalFilter[],
+): EvaluatorEntry[] {
+  if (filters.length === 0) {
+    return evaluators;
+  }
+  return evaluators.filter((entry) =>
+    filters.every((filter) => evaluateFilter(entry.evaluator, filter)),
+  );
+}
+
+export function callEvaluatorData(
+  data: unknown,
+): { data: unknown; baseExperiment: string | undefined } {
+  const dataResult = typeof data === "function" ? (data as () => unknown)() : data;
+  let baseExperiment: string | undefined = undefined;
+  if (
+    isObject(dataResult) &&
+    Reflect.get(dataResult, "_type") === "BaseExperiment" &&
+    typeof Reflect.get(dataResult, "name") === "string"
+  ) {
+    baseExperiment = Reflect.get(dataResult, "name") as string;
+  }
+  return { data: dataResult, baseExperiment };
+}
+
+export function toAsyncIterable(value: unknown): AsyncIterable<unknown> {
+  if (
+    typeof value === "object" &&
+    value !== null &&
+    Symbol.asyncIterator in value &&
+    typeof (value as AsyncIterable<unknown>)[Symbol.asyncIterator] === "function"
+  ) {
+    return value as AsyncIterable<unknown>;
+  }
+  if (
+    typeof value === "object" &&
+    value !== null &&
+    Symbol.iterator in value &&
+    typeof (value as Iterable<unknown>)[Symbol.iterator] === "function"
+  ) {
+    const iterable = value as Iterable<unknown>;
+    return (async function* () {
+      for (const item of iterable) {
+        yield item;
+      }
+    })();
+  }
+  throw new Error(
+    "Evaluator data must be an array, iterable, or async iterable",
+  );
+}
diff --git a/scripts/runner_common.py b/scripts/runner_common.py
new file mode 100644
index 0000000..f079a8a
--- /dev/null
+++ b/scripts/runner_common.py
@@ -0,0 +1,225 @@
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+import inspect
+import json
+import os
+import re
+import sys
+from dataclasses import dataclass
+from typing import Any, AsyncIterator
+
+try:
+    from braintrust.framework import (
+        BaseExperiment,
+        EvaluatorInstance,
+        _evals,
+        _set_lazy_load,
+    )
+    from braintrust.logger import Dataset
+except Exception as exc:
+    raise ImportError("Unable to import `braintrust`. Please install it in your Python environment.") from exc
+
+
+@dataclass(frozen=True)
+class EvalFilter:
+    path: list[str]
+    pattern: re.Pattern[str]
+
+
+def env_flag(name: str) -> bool:
+    value = os.getenv(name)
+    if value is None:
+        return False
+    return value.lower() not in {"0", "false", "no", "off", ""}
+
+
+def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]:
+    if not serialized:
+        return []
+
+    parsed = json.loads(serialized)
+    if not isinstance(parsed, list):
+        raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array")
+
+    filters: list[EvalFilter] = []
+    for i, entry in enumerate(parsed):
+        if not isinstance(entry, dict):
+            raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}")
+        key_path = entry.get("path")
+        pattern = entry.get("pattern")
+        if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path):
+            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings")
+        if not isinstance(pattern, str):
+            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string")
+        filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern)))
+    return filters
+
+
+def _to_mapping(value: Any) -> Any:
+    if isinstance(value, dict):
+        return {k: _to_mapping(v) for k, v in value.items()}
+    if isinstance(value, list):
+        return [_to_mapping(v) for v in value]
+    if hasattr(value, "__dict__"):
+        return {
+            key: _to_mapping(val)
+            for key, val in vars(value).items()
+            if not key.startswith("_")
+        }
+    return value
+
+
+def serialize_json_with_plain_string(value: Any) -> str:
+    if isinstance(value, str):
+        return value
+    return 
json.dumps(value) + + +def evaluate_filter(value: Any, filt: EvalFilter) -> bool: + current = _to_mapping(value) + for part in filt.path: + if not isinstance(current, dict) or part not in current: + return False + current = current[part] + return bool(filt.pattern.search(serialize_json_with_plain_string(current))) + + +def filter_evaluators( + evaluators: list[EvaluatorInstance], filters: list[EvalFilter] +) -> list[EvaluatorInstance]: + if not filters: + return evaluators + return [ + evaluator + for evaluator in evaluators + if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters) + ] + + +async def call_evaluator_data(data: Any) -> tuple[Any, str | None]: + data_result = data + if inspect.isclass(data_result): + data_result = data_result() + if inspect.isfunction(data_result) or inspect.isroutine(data_result): + data_result = data_result() + if inspect.isawaitable(data_result): + data_result = await data_result + + base_experiment_name = None + if isinstance(data_result, BaseExperiment): + base_experiment_name = data_result.name + + return data_result, base_experiment_name + + +def to_async_iterator(value: Any) -> AsyncIterator[Any]: + if inspect.isasyncgen(value): + return value + + async def to_async(it): + for item in it: + yield item + + return to_async(value) + + +def collect_files(input_path: str) -> list[str]: + if os.path.isdir(input_path): + matches: list[str] = [] + for root, _, files in os.walk(input_path): + for filename in files: + matches.append(os.path.join(root, filename)) + return matches + return [input_path] + + +def resolve_module_info(in_file: str) -> tuple[str, list[str]]: + in_file = os.path.abspath(in_file) + module_dir = os.path.dirname(in_file) + module_name = os.path.splitext(os.path.basename(in_file))[0] + + package_parts: list[str] = [] + current = module_dir + while os.path.isfile(os.path.join(current, "__init__.py")): + package_parts.insert(0, os.path.basename(current)) + current = os.path.dirname(current) + + extra_paths = [module_dir] + if package_parts: + module_name = ".".join(package_parts + [module_name]) + if current not in extra_paths: + extra_paths.append(current) + + return module_name, extra_paths + + +def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]: + evaluator_instances: list[EvaluatorInstance] = [] + reporters: dict[str, Any] = {} + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.insert(0, cwd) + + for f in files: + d = os.path.dirname(os.path.abspath(f)) + while d and d != os.path.dirname(d): + if os.path.isfile(os.path.join(d, "register.py")): + if d not in sys.path: + sys.path.insert(0, d) + break + d = os.path.dirname(d) + + unique_files: set[str] = set() + for file_path in files: + for candidate in collect_files(file_path): + unique_files.add(os.path.abspath(candidate)) + + for file_path in sorted(unique_files): + module_name, extra_paths = resolve_module_info(file_path) + with _set_lazy_load(True): + _evals.clear() + try: + for extra_path in reversed(extra_paths): + if extra_path not in sys.path: + sys.path.insert(0, extra_path) + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load module spec for {file_path}") + + sys.modules.pop(module_name, None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + evaluator_instances.extend( + [ + instance + for instance in _evals.evaluators.values() + if 
isinstance(instance, EvaluatorInstance) + ] + ) + for reporter_name, reporter in _evals.reporters.items(): + if reporter_name not in reporters: + reporters[reporter_name] = reporter + finally: + _evals.clear() + + return evaluator_instances, reporters + + +__all__ = [ + "BaseExperiment", + "Dataset", + "EvalFilter", + "EvaluatorInstance", + "call_evaluator_data", + "env_flag", + "filter_evaluators", + "load_evaluators", + "parse_serialized_filters", + "serialize_json_with_plain_string", + "to_async_iterator", +] diff --git a/src/eval.rs b/src/eval.rs index 04037f9..8c88d03 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -179,6 +179,7 @@ struct EvalPullRequest { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum EvalPullClientMessage { + Start { name: String }, Next, Close, } @@ -253,9 +254,17 @@ struct RunnerFilter { } const JS_RUNNER_FILE: &str = "eval-runner.ts"; +const JS_DATA_RUNNER_FILE: &str = "data-runner.ts"; +const JS_RUNNER_COMMON_FILE: &str = "runner-common.ts"; const PY_RUNNER_FILE: &str = "eval-runner.py"; +const PY_DATA_RUNNER_FILE: &str = "data-runner.py"; +const PY_RUNNER_COMMON_FILE: &str = "runner_common.py"; const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); +const JS_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.ts"); +const JS_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner-common.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); +const PY_DATA_RUNNER_SOURCE: &str = include_str!("../scripts/data-runner.py"); +const PY_RUNNER_COMMON_SOURCE: &str = include_str!("../scripts/runner_common.py"); #[derive(Debug)] struct SocketCleanupGuard { @@ -890,110 +899,131 @@ async fn spawn_eval_data_puller( ) -> Result { let (listener, socket_path, socket_cleanup_guard) = bind_unix_listener("bt-eval-pull").context("failed to bind sandbox pull socket")?; - let request_json = - serde_json::to_string(request).context("failed to serialize sandbox pull request")?; - let extra_env = vec![ - ("BT_EVAL_DEV_MODE".to_string(), "rows".to_string()), - ("BT_EVAL_DEV_REQUEST_JSON".to_string(), request_json), - ( - "BT_EVAL_PULL_SOCK".to_string(), - socket_path.to_string_lossy().to_string(), - ), - ]; - let child = spawn_eval_support_process( - base, - language, - runner_override, - files, - no_send_logs, - options, - &extra_env, - JsMode::Auto, - ) - .await?; + let child = match language { + EvalLanguage::JavaScript => { + let js_runner = prepare_js_data_runner()?; + let mut extra_env = vec![( + "BT_EVAL_PULL_SOCK".to_string(), + socket_path.to_string_lossy().to_string(), + )]; + let mut plan = build_js_plan_with_entrypoint( + runner_override, + &js_runner, + files, + JS_DATA_RUNNER_FILE, + JS_DATA_RUNNER_SOURCE, + )?; + if should_set_node_heap_size(plan.kind) { + set_node_heap_size_env(&mut plan.cmd); + } + plan.cmd.envs(build_env(base).await?); + for (key, value) in extra_env.drain(..) 
{ + plan.cmd.env(key, value); + } + if no_send_logs { + plan.cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + plan.cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + plan.cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + plan.cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + plan.cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + plan.cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + plan.cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + plan.cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + let runner_name = match plan.kind { + RunnerKind::Tsx => "tsx", + RunnerKind::ViteNode => "vite-node", + RunnerKind::Deno => "deno", + RunnerKind::Bun => "bun", + RunnerKind::Other => "other", + }; + plan.cmd.env("BT_EVAL_RUNNER_KIND", runner_name); + plan.cmd + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + plan.cmd + .spawn() + .context("failed to spawn sandbox pull runner")? + } + EvalLanguage::Python => { + let py_runner = prepare_py_data_runner()?; + let mut cmd = build_python_command(runner_override, &py_runner, files)?; + cmd.envs(build_env(base).await?); + cmd.env( + "BT_EVAL_PULL_SOCK", + socket_path.to_string_lossy().to_string(), + ); + if no_send_logs { + cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); + cmd.env("BT_EVAL_LOCAL", "1"); + } + if options.jsonl { + cmd.env("BT_EVAL_JSONL", "1"); + } + if options.terminate_on_failure { + cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); + } + if options.list { + cmd.env("BT_EVAL_LIST", "1"); + } + if let Some(num_workers) = options.num_workers { + cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); + } + if !options.filter.is_empty() { + let parsed = parse_eval_filter_expressions(&options.filter)?; + let serialized = + serde_json::to_string(&parsed).context("failed to serialize eval filters")?; + cmd.env("BT_EVAL_FILTER_PARSED", serialized); + } + if !options.extra_args.is_empty() { + let serialized = serde_json::to_string(&options.extra_args) + .context("failed to serialize eval extra args")?; + cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + } + cmd.stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + cmd.spawn().context("failed to spawn sandbox pull runner")? + } + }; let (stream, _) = tokio::time::timeout(Duration::from_secs(30), listener.accept()) .await .context("timed out waiting for sandbox pull runner to connect")? 
.context("sandbox pull runner failed to connect")?; let (read_half, write_half) = stream.into_split(); - Ok(EvalDataPuller { + let mut puller = EvalDataPuller { child, writer: write_half, reader: BufReader::new(read_half), _socket_cleanup_guard: socket_cleanup_guard, - }) -} - -async fn spawn_eval_support_process( - base: &BaseArgs, - language: EvalLanguage, - runner_override: Option<&str>, - files: &[String], - no_send_logs: bool, - options: &EvalRunOptions, - extra_env: &[(String, String)], - js_mode: JsMode, -) -> Result { - let (js_runner, py_runner) = prepare_eval_runners()?; - let force_esm = matches!(js_mode, JsMode::ForceEsm); - let (mut cmd, runner_kind) = match language { - EvalLanguage::Python => ( - build_python_command(runner_override, &py_runner, files)?, - RunnerKind::Other, - ), - EvalLanguage::JavaScript => { - if force_esm { - ( - build_vite_node_fallback_command(&js_runner, files)?, - RunnerKind::ViteNode, - ) - } else { - let plan = build_js_plan(runner_override, &js_runner, files)?; - (plan.cmd, plan.kind) - } - } }; - if language == EvalLanguage::JavaScript && should_set_node_heap_size(runner_kind) { - set_node_heap_size_env(&mut cmd); - } - cmd.envs(build_env(base).await?); - for (key, value) in extra_env { - cmd.env(key, value); - } - if no_send_logs { - cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); - cmd.env("BT_EVAL_LOCAL", "1"); - } - if options.jsonl { - cmd.env("BT_EVAL_JSONL", "1"); - } - if options.terminate_on_failure { - cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1"); - } - if options.list { - cmd.env("BT_EVAL_LIST", "1"); - } - if let Some(num_workers) = options.num_workers { - cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string()); - } - if !options.filter.is_empty() { - let parsed = parse_eval_filter_expressions(&options.filter)?; - let serialized = - serde_json::to_string(&parsed).context("failed to serialize eval filters")?; - cmd.env("BT_EVAL_FILTER_PARSED", serialized); - } - if language == EvalLanguage::JavaScript && force_esm { - cmd.env("BT_EVAL_FORCE_ESM", "1"); - } - if !options.extra_args.is_empty() { - let serialized = - serde_json::to_string(&options.extra_args).context("failed to serialize extra args")?; - cmd.env("BT_EVAL_EXTRA_ARGS_JSON", serialized); + if matches!(language, EvalLanguage::JavaScript | EvalLanguage::Python) { + puller + .send_message(&EvalPullClientMessage::Start { + name: request.name.clone(), + }) + .await?; } - cmd.stdout(Stdio::inherit()); - cmd.stderr(Stdio::inherit()); - cmd.spawn().context("failed to start eval support runner") + Ok(puller) } async fn run_eval_runner_command_to_completion( @@ -2858,18 +2888,40 @@ fn build_js_plan( runner_override: Option<&str>, runner: &Path, files: &[String], +) -> Result { + build_js_plan_with_entrypoint( + runner_override, + runner, + files, + JS_RUNNER_FILE, + JS_RUNNER_SOURCE, + ) +} + +fn build_js_plan_with_entrypoint( + runner_override: Option<&str>, + runner: &Path, + files: &[String], + embedded_file_name: &str, + embedded_source: &str, ) -> Result { if let Some(explicit) = runner_override { let resolved_runner = resolve_js_runner_command(explicit, files); if is_deno_runner(explicit) || is_deno_runner_path(resolved_runner.as_ref()) { - let runner_script = prepare_js_runner_in_cwd()?; + let runner_script = + prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(resolved_runner.as_os_str(), &runner_script, files), kind: RunnerKind::Deno, }); } let kind = runner_kind_for_bin(resolved_runner.as_ref()); - let 
runner_script = select_js_runner_entrypoint(runner, resolved_runner.as_ref())?;
+        let runner_script = select_js_runner_entrypoint_with_source(
+            runner,
+            resolved_runner.as_ref(),
+            embedded_file_name,
+            embedded_source,
+        )?;
         let mut command = Command::new(resolved_runner);
         command.arg(runner_script).args(files);
         return Ok(JsRunnerPlan { cmd: command, kind });
@@ -2877,14 +2929,20 @@
 
     if let Some(auto_runner) = find_js_runner_binary(files) {
         if is_deno_runner_path(&auto_runner) {
-            let runner_script = prepare_js_runner_in_cwd()?;
+            let runner_script =
+                prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source)?;
             return Ok(JsRunnerPlan {
                 cmd: build_deno_js_command(auto_runner.as_os_str(), &runner_script, files),
                 kind: RunnerKind::Deno,
            });
         }
         let kind = runner_kind_for_bin(auto_runner.as_ref());
-        let runner_script = select_js_runner_entrypoint(runner, auto_runner.as_ref())?;
+        let runner_script = select_js_runner_entrypoint_with_source(
+            runner,
+            auto_runner.as_ref(),
+            embedded_file_name,
+            embedded_source,
+        )?;
         let mut command = Command::new(auto_runner);
         command.arg(runner_script).args(files);
         return Ok(JsRunnerPlan { cmd: command, kind });
@@ -3039,14 +3097,19 @@ fn is_deno_runner_path(runner: &Path) -> bool {
         .unwrap_or(false)
 }
 
-fn select_js_runner_entrypoint(default_runner: &Path, runner_command: &Path) -> Result<PathBuf> {
+fn select_js_runner_entrypoint_with_source(
+    default_runner: &Path,
+    runner_command: &Path,
+    embedded_file_name: &str,
+    embedded_source: &str,
+) -> Result<PathBuf> {
     if is_ts_node_runner(runner_command) {
-        return prepare_js_runner_in_cwd();
+        return prepare_js_embedded_runner_in_cwd(embedded_file_name, embedded_source);
     }
     Ok(default_runner.to_path_buf())
 }
 
-fn prepare_js_runner_in_cwd() -> Result<PathBuf> {
+fn prepare_js_embedded_runner_in_cwd(file_name: &str, source: &str) -> Result<PathBuf> {
     let cwd = std::env::current_dir().context("failed to resolve current working directory")?;
     let cache_dir = cwd
         .join(".bt")
@@ -3058,7 +3121,8 @@
             cache_dir.display()
         )
     })?;
-    materialize_runner_script(&cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE)
+    materialize_runner_script(&cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?;
+    materialize_runner_script(&cache_dir, file_name, source)
 }
 
 fn runner_bin_name(runner_command: &Path) -> Option<String> {
@@ -3208,6 +3272,14 @@ fn prepare_eval_runners() -> Result<(PathBuf, PathBuf)> {
     prepare_eval_runners_in_dir(&eval_runner_cache_dir())
 }
 
+fn prepare_js_data_runner() -> Result<PathBuf> {
+    prepare_js_data_runner_in_dir(&eval_runner_cache_dir())
+}
+
+fn prepare_py_data_runner() -> Result<PathBuf> {
+    prepare_py_data_runner_in_dir(&eval_runner_cache_dir())
+}
+
 fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> {
     std::fs::create_dir_all(cache_dir).with_context(|| {
         format!(
             "failed to create eval runner cache dir {}",
             cache_dir.display()
         )
     })?;
 
+    materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?;
+    materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?;
     let js_runner = materialize_runner_script(cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE)?;
     let py_runner = materialize_runner_script(cache_dir, PY_RUNNER_FILE, PY_RUNNER_SOURCE)?;
     Ok((js_runner, py_runner))
 }
 
+fn prepare_js_data_runner_in_dir(cache_dir: &Path) -> Result<PathBuf> {
+    std::fs::create_dir_all(cache_dir).with_context(|| {
+        format!(
+            "failed to create eval runner cache dir {}",
+            cache_dir.display()
+        )
+    })?;
+    materialize_runner_script(cache_dir, JS_RUNNER_COMMON_FILE, JS_RUNNER_COMMON_SOURCE)?;
+    materialize_runner_script(cache_dir, JS_DATA_RUNNER_FILE, JS_DATA_RUNNER_SOURCE)
+}
+
+fn prepare_py_data_runner_in_dir(cache_dir: &Path) -> Result<PathBuf> {
+    std::fs::create_dir_all(cache_dir).with_context(|| {
+        format!(
+            "failed to create eval runner cache dir {}",
+            cache_dir.display()
+        )
+    })?;
+    materialize_runner_script(cache_dir, PY_RUNNER_COMMON_FILE, PY_RUNNER_COMMON_SOURCE)?;
+    materialize_runner_script(cache_dir, PY_DATA_RUNNER_FILE, PY_DATA_RUNNER_SOURCE)
+}
+
 fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Result<PathBuf> {
     let path = cache_dir.join(file_name);
     let current = std::fs::read_to_string(&path).ok();
@@ -4324,11 +4420,27 @@ mod tests {
         let dir = make_temp_dir("embedded");
         let (js_runner, py_runner) =
             prepare_eval_runners_in_dir(&dir).expect("embedded runners should be materialized");
+        let js_data_runner = prepare_js_data_runner_in_dir(&dir)
+            .expect("embedded data runner should be materialized");
+        let py_data_runner = prepare_py_data_runner_in_dir(&dir)
+            .expect("embedded py data runner should be materialized");
         let js = fs::read_to_string(js_runner).expect("js runner should be readable");
+        let js_data =
+            fs::read_to_string(js_data_runner).expect("js data runner should be readable");
+        let js_common = fs::read_to_string(dir.join(JS_RUNNER_COMMON_FILE))
+            .expect("js common should be readable");
         let py = fs::read_to_string(py_runner).expect("python runner should be readable");
+        let py_data =
+            fs::read_to_string(py_data_runner).expect("python data runner should be readable");
+        let py_common = fs::read_to_string(dir.join(PY_RUNNER_COMMON_FILE))
+            .expect("python common should be readable");
         assert_eq!(js, JS_RUNNER_SOURCE);
+        assert_eq!(js_data, JS_DATA_RUNNER_SOURCE);
+        assert_eq!(js_common, JS_RUNNER_COMMON_SOURCE);
         assert_eq!(py, PY_RUNNER_SOURCE);
+        assert_eq!(py_data, PY_DATA_RUNNER_SOURCE);
+        assert_eq!(py_common, PY_RUNNER_COMMON_SOURCE);
 
         let _ = fs::remove_dir_all(&dir);
     }
diff --git a/tests/eval_fixtures.rs b/tests/eval_fixtures.rs
index 1e1d5ba..b64485f 100644
--- a/tests/eval_fixtures.rs
+++ b/tests/eval_fixtures.rs
@@ -4,6 +4,8 @@ use std::fs;
 use std::io::Write;
 use std::io::{BufRead, BufReader, Read};
 #[cfg(unix)]
+use std::os::unix::fs::symlink;
+#[cfg(unix)]
 use std::os::unix::net::{UnixListener, UnixStream};
 use std::path::{Path, PathBuf};
 use std::process::{Child, Command, Stdio};
@@ -15,6 +17,8 @@
 use serde::Deserialize;
 #[cfg(unix)]
 use serde_json::json;
 use serde_json::Value;
+#[cfg(unix)]
+use tempfile::tempdir;
 
 #[derive(Debug, Deserialize, Clone)]
 struct FixtureConfig {
@@ -463,9 +467,20 @@
         .join("eval-ts-cjs");
     ensure_dependencies(&fixture_dir);
     let runner = local_tsx_path(&fixture_dir).expect("resolve tsx runner");
-    let runner_script = root.join("scripts").join("eval-runner.ts");
-    let fixture_name = format!("sandbox_rows_{}.eval.ts", unique_test_suffix());
-    let fixture_path = fixture_dir.join(&fixture_name);
+    let runner_script = root.join("scripts").join("data-runner.ts");
+    let temp_fixture_dir = tempdir().expect("create js rows tempdir");
+    fs::copy(
+        fixture_dir.join("package.json"),
+        temp_fixture_dir.path().join("package.json"),
+    )
+    .expect("copy js fixture package.json");
+    symlink(
+        fixture_dir.join("node_modules"),
+        temp_fixture_dir.path().join("node_modules"),
+    )
+    .expect("symlink js fixture node_modules");
+    let fixture_name = "sandbox_rows.eval.ts";
+    
let fixture_path = temp_fixture_dir.path().join(fixture_name); let fixture_source = r#"import { Eval } from "braintrust"; async function* rows() { @@ -475,10 +490,9 @@ async function* rows() { Eval("sandbox-rows-js", { data: rows, - task: async (input: { case_id: string }) => - input.case_id === "row-1" ? "alpha" : "bravo", + task: async (input) => (input.case_id === "row-1" ? "alpha" : "bravo"), scores: [ - ({ output, expected }: { output: string; expected?: string }) => ({ + ({ output, expected }) => ({ name: "match", score: output === expected ? 1 : 0, }), @@ -497,17 +511,8 @@ Eval("sandbox-rows-js", { let mut child = Command::new(&runner) .arg(&runner_script) - .arg(&fixture_name) - .current_dir(&fixture_dir) - .env("BT_EVAL_DEV_MODE", "rows") - .env( - "BT_EVAL_DEV_REQUEST_JSON", - json!({ - "name": "sandbox-rows-js", - "parameters": {}, - }) - .to_string(), - ) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) .env("BT_EVAL_PULL_SOCK", &socket_path) .env("BT_EVAL_LOCAL", "1") .env("BT_EVAL_NO_SEND_LOGS", "1") @@ -521,6 +526,13 @@ Eval("sandbox-rows-js", { let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); let mut writer = stream; + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-js", + }), + ); let ready = read_pull_message(&mut reader); assert_eq!(ready["type"], "ready"); assert_eq!(ready["evaluator_name"], "sandbox-rows-js"); @@ -572,7 +584,6 @@ Eval("sandbox-rows-js", { } let _ = fs::remove_file(&socket_path); - let _ = fs::remove_file(&fixture_path); } #[cfg(unix)] @@ -581,7 +592,6 @@ fn eval_runner_rows_mode_streams_python_rows_and_trials() { let _guard = test_lock(); let root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); let fixtures_root = root.join("tests").join("evals"); - let fixture_dir = fixtures_root.join("py").join("local_import"); let python = match ensure_python_env(&fixtures_root.join("py")) { Some(python) => python, None => { @@ -595,9 +605,10 @@ fn eval_runner_rows_mode_streams_python_rows_and_trials() { } }; - let runner_script = root.join("scripts").join("eval-runner.py"); - let fixture_name = format!("sandbox_rows_{}.py", unique_test_suffix()); - let fixture_path = fixture_dir.join(&fixture_name); + let runner_script = root.join("scripts").join("data-runner.py"); + let temp_fixture_dir = tempdir().expect("create python rows tempdir"); + let fixture_name = "sandbox_rows.py"; + let fixture_path = temp_fixture_dir.path().join(fixture_name); let fixture_source = r#"from braintrust import Eval def rows(): @@ -629,17 +640,8 @@ Eval( let mut child = Command::new(&python) .arg(&runner_script) - .arg(&fixture_name) - .current_dir(&fixture_dir) - .env("BT_EVAL_DEV_MODE", "rows") - .env( - "BT_EVAL_DEV_REQUEST_JSON", - json!({ - "name": "sandbox-rows-py", - "parameters": {}, - }) - .to_string(), - ) + .arg(fixture_name) + .current_dir(temp_fixture_dir.path()) .env("BT_EVAL_PULL_SOCK", &socket_path) .env("BT_EVAL_LOCAL", "1") .env("BT_EVAL_NO_SEND_LOGS", "1") @@ -653,6 +655,13 @@ Eval( let mut reader = BufReader::new(stream.try_clone().expect("clone pull stream")); let mut writer = stream; + write_pull_message( + &mut writer, + &json!({ + "type": "start", + "name": "sandbox-rows-py", + }), + ); let ready = read_pull_message(&mut reader); assert_eq!(ready["type"], "ready"); assert_eq!(ready["evaluator_name"], "sandbox-rows-py"); @@ -706,7 +715,6 @@ Eval( } let _ = fs::remove_file(&socket_path); - let _ = fs::remove_file(&fixture_path); } fn read_fixture_config(path: &Path) -> FixtureConfig {