diff --git a/.github/workflows/simulator-release.yml b/.github/workflows/simulator-release.yml new file mode 100644 index 000000000..b362ffe73 --- /dev/null +++ b/.github/workflows/simulator-release.yml @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2026 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +name: Simulator Release + +on: + workflow_dispatch: + inputs: + version: + description: 'Release version (for example: 0.5.8)' + required: true + type: string + push: + tags: + - 'simulator-v*' + +permissions: + contents: write + +jobs: + build-and-release: + runs-on: ubuntu-latest + env: + TARGET_TRIPLE: x86_64-unknown-linux-musl + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Resolve version and tag + run: | + if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then + VERSION="${{ github.event.inputs.version }}" + else + VERSION="${GITHUB_REF#refs/tags/simulator-v}" + fi + VERSION="${VERSION#simulator-v}" + TAG="simulator-v${VERSION}" + echo "VERSION=${VERSION}" >> "$GITHUB_ENV" + echo "TAG=${TAG}" >> "$GITHUB_ENV" + echo "Resolved release version: ${VERSION}" + + - name: Install musl toolchain + run: | + sudo apt-get update + sudo apt-get install -y musl-tools + + - name: Set up Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ env.TARGET_TRIPLE }} + + - name: Cache Rust build artifacts + uses: Swatinem/rust-cache@v2 + + - name: Build musl simulator binary + run: cargo build --locked --release --target "${TARGET_TRIPLE}" -p dstack-guest-agent-simulator + + - name: Package release bundle + run: ./guest-agent-simulator/package-release.sh "${VERSION}" "${TARGET_TRIPLE}" + + - name: GitHub Release + uses: softprops/action-gh-release@v1 + with: + tag_name: ${{ env.TAG }} + name: "Simulator Release v${{ env.VERSION }}" + files: | + guest-agent-simulator/dist/dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }}.tar.gz + guest-agent-simulator/dist/dstack-simulator-${{ env.VERSION }}-${{ 
env.TARGET_TRIPLE }}.tar.gz.sha256 + guest-agent-simulator/install-systemd.sh + body: | + ## Release Assets + + - `dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }}.tar.gz` + - `dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }}.tar.gz.sha256` + - `install-systemd.sh` + + The tarball contains the musl-linked `dstack-simulator` binary together with the default + simulator config, fixture data, and a systemd unit template. + + ## Quick Start + + Download and run directly: + + ```bash + curl -LO https://github.com/${{ github.repository }}/releases/download/${{ env.TAG }}/dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }}.tar.gz + tar -xzf dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }}.tar.gz + cd dstack-simulator-${{ env.VERSION }}-${{ env.TARGET_TRIPLE }} + ./dstack-simulator -c dstack.toml + ``` + + Install to systemd: + + ```bash + curl -fsSL https://raw.githubusercontent.com/${{ github.repository }}/${{ env.TAG }}/guest-agent-simulator/install-systemd.sh | sudo bash -s -- --version ${{ env.VERSION }} + ``` diff --git a/CLAUDE.md b/CLAUDE.md index c85c936a3..0b6984537 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -48,6 +48,7 @@ cargo build --release -p dstack-vmm cargo build --release -p dstack-kms cargo build --release -p dstack-gateway cargo build --release -p dstack-guest-agent +cargo build --release -p dstack-guest-agent-simulator # Check code cargo check --all-features diff --git a/Cargo.lock b/Cargo.lock index b2a11131b..5d35982a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2400,6 +2400,23 @@ dependencies = [ "serde_json", ] +[[package]] +name = "dstack-guest-agent-simulator" +version = "0.5.8" +dependencies = [ + "anyhow", + "clap", + "dstack-guest-agent", + "dstack-guest-agent-rpc", + "ra-rpc", + "ra-tls", + "rocket", + "serde", + "serde_json", + "tracing", + "tracing-subscriber", +] + [[package]] name = "dstack-kms" version = "0.5.8" diff --git a/Cargo.toml b/Cargo.toml index 841001799..982b892ab 100644 --- 
a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "iohash", "guest-agent", "guest-agent/rpc", + "guest-agent-simulator", "vmm", "vmm/rpc", "gateway", diff --git a/cert-client/src/lib.rs b/cert-client/src/lib.rs index 82141499f..8b5b329a1 100644 --- a/cert-client/src/lib.rs +++ b/cert-client/src/lib.rs @@ -6,11 +6,7 @@ use anyhow::{Context, Result}; use dstack_kms_rpc::{kms_client::KmsClient, SignCertRequest}; use dstack_types::{AppKeys, KeyProvider}; use ra_rpc::client::{RaClient, RaClientConfig}; -use ra_tls::{ - attestation::{QuoteContentType, VersionedAttestation}, - cert::{generate_ra_cert, CaCert, CertConfigV2, CertSigningRequestV2, Csr}, - rcgen::KeyPair, -}; +use ra_tls::cert::{generate_ra_cert, CaCert, CertSigningRequestV2}; pub enum CertRequestClient { Local { @@ -92,34 +88,4 @@ impl CertRequestClient { } } } - - pub async fn request_cert( - &self, - key: &KeyPair, - config: CertConfigV2, - attestation_override: Option, - ) -> Result> { - let pubkey = key.public_key_der(); - let report_data = QuoteContentType::RaTlsCert.to_report_data(&pubkey); - let attestation = match attestation_override { - Some(mut attestation) => { - attestation.set_report_data(report_data); - attestation - } - None => ra_rpc::Attestation::quote(&report_data) - .context("Failed to get quote for cert pubkey")? - .into_versioned(), - }; - - let csr = CertSigningRequestV2 { - confirm: "please sign cert:".to_string(), - pubkey, - config, - attestation, - }; - let signature = csr.signed_by(key).context("Failed to sign the CSR")?; - self.sign_csr(&csr, &signature) - .await - .context("Failed to sign the CSR") - } } diff --git a/dstack-attest/src/attestation.rs b/dstack-attest/src/attestation.rs index 7a93005a2..ed5d9623e 100644 --- a/dstack-attest/src/attestation.rs +++ b/dstack-attest/src/attestation.rs @@ -272,25 +272,6 @@ impl VersionedAttestation { } } - /// Set the report_data field in the attestation and in the raw TDX quote bytes (offset 568..632). 
- /// This is used by the simulator to patch a canned attestation with the correct report_data - /// that binds to the actual TLS public key. - pub fn set_report_data(&mut self, report_data: [u8; 64]) { - let VersionedAttestation::V0 { attestation } = self; - attestation.report_data = report_data; - if let Some(tdx_quote) = attestation.tdx_quote_mut() { - if tdx_quote.quote.len() >= TDX_QUOTE_REPORT_DATA_RANGE.end { - tdx_quote.quote[TDX_QUOTE_REPORT_DATA_RANGE].copy_from_slice(&report_data); - } else { - tracing::warn!( - "TDX quote too short to patch report_data ({} < {})", - tdx_quote.quote.len(), - TDX_QUOTE_REPORT_DATA_RANGE.end - ); - } - } - } - /// Strip data for certificate embedding (e.g. keep RTMR3 event logs only). pub fn into_stripped(mut self) -> Self { let VersionedAttestation::V0 { attestation } = &mut self; diff --git a/dstack-util/src/system_setup.rs b/dstack-util/src/system_setup.rs index 4bde55572..35cd3b003 100644 --- a/dstack-util/src/system_setup.rs +++ b/dstack-util/src/system_setup.rs @@ -27,8 +27,14 @@ use luks2::{ LuksAf, LuksConfig, LuksDigest, LuksHeader, LuksJson, LuksKdf, LuksKeyslot, LuksSegment, LuksSegmentSize, }; -use ra_rpc::client::{CertInfo, RaClient, RaClientConfig}; -use ra_tls::cert::{generate_ra_cert, CertConfigV2}; +use ra_rpc::{ + client::{CertInfo, RaClient, RaClientConfig}, + Attestation, +}; +use ra_tls::{ + attestation::QuoteContentType, + cert::{generate_ra_cert, CertConfigV2, CertSigningRequestV2, Csr}, +}; use rand::Rng as _; use safe_write::safe_write; use scopeguard::defer; @@ -53,6 +59,29 @@ use ra_tls::rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}; use serde_human_bytes as hex_bytes; use serde_json::Value; +async fn sign_cert_request( + cert_client: &CertRequestClient, + key: &KeyPair, + config: CertConfigV2, +) -> Result> { + let pubkey = key.public_key_der(); + let report_data = QuoteContentType::RaTlsCert.to_report_data(&pubkey); + let attestation = Attestation::quote(&report_data) + .context("Failed to get quote 
for cert pubkey")? + .into_versioned(); + let csr = CertSigningRequestV2 { + confirm: "please sign cert:".to_string(), + pubkey, + config, + attestation, + }; + let signature = csr.signed_by(key).context("Failed to sign the CSR")?; + cert_client + .sign_csr(&csr, &signature) + .await + .context("Failed to sign the CSR") +} + mod config_id_verifier; #[derive(clap::Parser)] @@ -500,8 +529,7 @@ impl<'a> GatewayContext<'a> { not_before: None, not_after: Some(cert_not_after), }; - let certs = cert_client - .request_cert(&key, config, None) + let certs = sign_cert_request(&cert_client, &key, config) .await .context("Failed to request cert")?; let client_cert = certs.join("\n"); @@ -520,8 +548,7 @@ impl<'a> GatewayContext<'a> { not_before: None, not_after: Some(cert_not_after), }; - let certs_with_quote = cert_client - .request_cert(&key, config_with_quote, None) + let certs_with_quote = sign_cert_request(&cert_client, &key, config_with_quote) .await .context("Failed to request cert with quote")?; let client_cert_with_quote = certs_with_quote.join("\n"); diff --git a/guest-agent-simulator/.gitignore b/guest-agent-simulator/.gitignore new file mode 100644 index 000000000..849ddff3b --- /dev/null +++ b/guest-agent-simulator/.gitignore @@ -0,0 +1 @@ +dist/ diff --git a/guest-agent-simulator/Cargo.toml b/guest-agent-simulator/Cargo.toml new file mode 100644 index 000000000..f476e01c2 --- /dev/null +++ b/guest-agent-simulator/Cargo.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "dstack-guest-agent-simulator" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true + +[[bin]] +name = "dstack-simulator" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +serde.workspace = true +serde_json.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +rocket.workspace = true +ra-rpc = { 
workspace = true, features = ["rocket"] } +ra-tls = { workspace = true, features = ["quote"] } +dstack-guest-agent = { path = "../guest-agent" } +dstack-guest-agent-rpc.workspace = true diff --git a/guest-agent-simulator/dstack-simulator.service b/guest-agent-simulator/dstack-simulator.service new file mode 100644 index 000000000..64b7be66e --- /dev/null +++ b/guest-agent-simulator/dstack-simulator.service @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: © 2026 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +[Unit] +Description=dstack Simulator Service +After=network.target + +[Service] +Type=simple +WorkingDirectory=@INSTALL_DIR@ +ExecStart=@INSTALL_DIR@/dstack-simulator -c @INSTALL_DIR@/dstack.toml +Restart=on-failure +RestartSec=2s +User=@USER@ +Group=@GROUP@ +Environment=RUST_LOG=@RUST_LOG@ +StandardOutput=journal +StandardError=journal + +[Install] +WantedBy=multi-user.target diff --git a/guest-agent-simulator/dstack.toml b/guest-agent-simulator/dstack.toml new file mode 100644 index 000000000..b1c7325a6 --- /dev/null +++ b/guest-agent-simulator/dstack.toml @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +[default] +workers = 8 +max_blocking = 64 +ident = "dstack Simulator" +temp_dir = "/tmp" +keep_alive = 10 +log_level = "debug" + +[default.core] +keys_file = "appkeys.json" +compose_file = "app-compose.json" +sys_config_file = "sys-config.json" +data_disks = ["/"] + +[default.core.simulator] +attestation_file = "attestation.bin" +patch_report_data = true + +[internal-v0] +address = "unix:./tappd.sock" +reuse = true + +[internal] +address = "unix:./dstack.sock" +reuse = true + +[external] +address = "unix:./external.sock" +reuse = true + +[guest-api] +address = "unix:./guest.sock" +reuse = true diff --git a/guest-agent-simulator/install-systemd.sh b/guest-agent-simulator/install-systemd.sh new file mode 100755 index 000000000..01a2a6cfe --- /dev/null +++ 
b/guest-agent-simulator/install-systemd.sh @@ -0,0 +1,274 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2026 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +REPO="Dstack-TEE/dstack" +TARGET="x86_64-unknown-linux-musl" +INSTALL_ROOT="/opt/dstack-simulator" +SERVICE_NAME="dstack-simulator" +SERVICE_FILE="/etc/systemd/system/${SERVICE_NAME}.service" +BIN_LINK="/usr/local/bin/dstack-simulator" +RUN_USER="root" +RUN_GROUP="root" +RUST_LOG="info" +VERSION="" +TARBALL="" +SKIP_SYSTEMD=0 + +usage() { + cat <<'EOF' +Usage: install-systemd.sh [options] + +Install dstack-simulator from a GitHub release tarball and register it as a systemd service. + +Options: + --version Release version to install (e.g. 0.5.8). Defaults to latest simulator release. + --tarball Install from a local tarball or explicit URL. + --repo GitHub repository to download from. Default: Dstack-TEE/dstack + --target Target triple asset to download. Default: x86_64-unknown-linux-musl + --install-root Installation root. Default: /opt/dstack-simulator + --service-name systemd service name. Default: dstack-simulator + --service-file systemd unit path. Default: /etc/systemd/system/dstack-simulator.service + --bin-link Binary symlink path. Default: /usr/local/bin/dstack-simulator + --user Service user. Default: root + --group Service group. Default: root + --rust-log RUST_LOG value for the systemd unit. Default: info + --skip-systemd Install files but skip systemd daemon-reload/enable/start. + -h, --help Show this help text. 
+EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --version) + VERSION="$2" + shift 2 + ;; + --tarball) + TARBALL="$2" + shift 2 + ;; + --repo) + REPO="$2" + shift 2 + ;; + --target) + TARGET="$2" + shift 2 + ;; + --install-root) + INSTALL_ROOT="$2" + shift 2 + ;; + --service-name) + SERVICE_NAME="$2" + SERVICE_FILE="/etc/systemd/system/${SERVICE_NAME}.service" + shift 2 + ;; + --service-file) + SERVICE_FILE="$2" + shift 2 + ;; + --bin-link) + BIN_LINK="$2" + shift 2 + ;; + --user) + RUN_USER="$2" + shift 2 + ;; + --group) + RUN_GROUP="$2" + shift 2 + ;; + --rust-log) + RUST_LOG="$2" + shift 2 + ;; + --skip-systemd) + SKIP_SYSTEMD=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 1 + ;; + esac +done + +if [[ $EUID -ne 0 ]]; then + echo "Please run as root." >&2 + exit 1 +fi + +need_cmd() { + command -v "$1" >/dev/null 2>&1 || { + echo "Missing required command: $1" >&2 + exit 1 + } +} + +need_cmd curl +need_cmd tar +need_cmd python3 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +LOCAL_BUNDLE_DIR="" +if [[ -f "$SCRIPT_DIR/dstack-simulator" && -f "$SCRIPT_DIR/dstack.toml" ]]; then + LOCAL_BUNDLE_DIR="$SCRIPT_DIR" +fi + +cleanup() { + if [[ -n "${TMP_DIR:-}" && -d "${TMP_DIR:-}" ]]; then + rm -rf "$TMP_DIR" + fi +} +trap cleanup EXIT + +normalize_version() { + local value="$1" + value="${value#refs/tags/}" + value="${value#simulator-v}" + echo "$value" +} + +latest_version() { + curl -fsSL "https://api.github.com/repos/${REPO}/releases?per_page=100" | python3 -c ' +import json, sys +for release in json.load(sys.stdin): + if release.get("draft") or release.get("prerelease"): + continue + tag = release.get("tag_name", "") + if tag.startswith("simulator-v"): + print(tag[len("simulator-v"):]) + break +else: + raise SystemExit("No simulator release found") +' +} + +guess_local_version() { + local base + base="$(basename "$LOCAL_BUNDLE_DIR")" + base="${base#dstack-simulator-}" + 
base="${base%-${TARGET}}" + if [[ "$base" == "dstack-simulator" || -z "$base" ]]; then + return 1 + fi + echo "$base" +} + +if [[ -z "$VERSION" ]]; then + if [[ -n "$LOCAL_BUNDLE_DIR" ]]; then + VERSION="$(guess_local_version || true)" + fi + if [[ -z "$VERSION" ]]; then + VERSION="$(latest_version)" + fi +fi +VERSION="$(normalize_version "$VERSION")" + +ASSET_NAME="dstack-simulator-${VERSION}-${TARGET}.tar.gz" +TAG="simulator-v${VERSION}" +TMP_DIR="$(mktemp -d)" + +fetch_to_file() { + local source="$1" + local dest="$2" + if [[ "$source" =~ ^https?:// ]]; then + curl -fsSL "$source" -o "$dest" + else + cp "$source" "$dest" + fi +} + +extract_tarball() { + local tarball_path="$1" + local dest_dir="$2" + tar -xzf "$tarball_path" -C "$dest_dir" +} + +BUNDLE_DIR="" +if [[ -n "$LOCAL_BUNDLE_DIR" && -z "$TARBALL" ]]; then + BUNDLE_DIR="$LOCAL_BUNDLE_DIR" +else + TARBALL_PATH="$TMP_DIR/$ASSET_NAME" + CHECKSUM_PATH="$TMP_DIR/${ASSET_NAME}.sha256" + + if [[ -n "$TARBALL" ]]; then + echo "Fetching simulator tarball from: $TARBALL" + fetch_to_file "$TARBALL" "$TARBALL_PATH" + else + local_url="https://github.com/${REPO}/releases/download/${TAG}/${ASSET_NAME}" + checksum_url="${local_url}.sha256" + echo "Downloading simulator release ${TAG}" + fetch_to_file "$local_url" "$TARBALL_PATH" + if curl -fsSL "$checksum_url" -o "$CHECKSUM_PATH"; then + ( + cd "$TMP_DIR" + sha256sum -c "$(basename "$CHECKSUM_PATH")" + ) + else + echo "Warning: checksum asset not found, skipping checksum verification." >&2 + fi + fi + + extract_tarball "$TARBALL_PATH" "$TMP_DIR" + BUNDLE_DIR="$TMP_DIR/dstack-simulator-${VERSION}-${TARGET}" +fi + +if [[ ! -f "$BUNDLE_DIR/dstack-simulator" || ! 
-f "$BUNDLE_DIR/dstack.toml" ]]; then + echo "Bundle directory is missing expected simulator files: $BUNDLE_DIR" >&2 + exit 1 +fi + +VERSION_DIR="$INSTALL_ROOT/releases/$VERSION" +CURRENT_DIR="$INSTALL_ROOT/current" + +mkdir -p "$INSTALL_ROOT/releases" +rm -rf "$VERSION_DIR" +mkdir -p "$VERSION_DIR" +cp -a "$BUNDLE_DIR/." "$VERSION_DIR/" +ln -sfn "$VERSION_DIR" "$CURRENT_DIR" + +chown -R "$RUN_USER:$RUN_GROUP" "$VERSION_DIR" +ln -sfn "$CURRENT_DIR/dstack-simulator" "$BIN_LINK" + +python3 - "$BUNDLE_DIR/dstack-simulator.service" "$SERVICE_FILE" "$CURRENT_DIR" "$RUN_USER" "$RUN_GROUP" "$RUST_LOG" <<'PY' +from pathlib import Path +import sys + +template_path, output_path, install_dir, user, group, rust_log = sys.argv[1:] +template = Path(template_path).read_text() +rendered = ( + template + .replace("@INSTALL_DIR@", install_dir) + .replace("@USER@", user) + .replace("@GROUP@", group) + .replace("@RUST_LOG@", rust_log) +) +Path(output_path).write_text(rendered) +PY + +echo "Installed dstack-simulator ${VERSION} to ${CURRENT_DIR}" +echo "Binary symlink: ${BIN_LINK}" +echo "Service file: ${SERVICE_FILE}" + +if [[ "$SKIP_SYSTEMD" -eq 0 ]]; then + need_cmd systemctl + UNIT_NAME="$(basename "$SERVICE_FILE")" + systemctl daemon-reload + systemctl enable --now "$UNIT_NAME" + echo "systemd service enabled and started: ${UNIT_NAME}" +else + echo "Skipping systemd enable/start (--skip-systemd)." 
+fi diff --git a/guest-agent-simulator/package-release.sh b/guest-agent-simulator/package-release.sh new file mode 100755 index 000000000..5223eda66 --- /dev/null +++ b/guest-agent-simulator/package-release.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2026 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +usage() { + cat <<'EOF' +Usage: package-release.sh [--binary ] [--out-dir ] + +Create a self-contained dstack-simulator release tarball that includes: + - the simulator binary + - default simulator config and fixture data + - a systemd unit template + - the install-systemd.sh helper +EOF +} + +if [[ $# -lt 2 ]]; then + usage >&2 + exit 1 +fi + +VERSION="$1" +TARGET="$2" +shift 2 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +OUT_DIR="$SCRIPT_DIR/dist" +BINARY_PATH="$ROOT_DIR/target/$TARGET/release/dstack-simulator" + +while [[ $# -gt 0 ]]; do + case "$1" in + --binary) + BINARY_PATH="$2" + shift 2 + ;; + --out-dir) + OUT_DIR="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 1 + ;; + esac +done + +if [[ ! 
-f "$BINARY_PATH" ]]; then + echo "Simulator binary not found: $BINARY_PATH" >&2 + exit 1 +fi + +PACKAGE_NAME="dstack-simulator-${VERSION}-${TARGET}" +STAGE_DIR="$OUT_DIR/$PACKAGE_NAME" +TARBALL_PATH="$OUT_DIR/${PACKAGE_NAME}.tar.gz" +CHECKSUM_PATH="${TARBALL_PATH}.sha256" + +rm -rf "$STAGE_DIR" "$TARBALL_PATH" "$CHECKSUM_PATH" +mkdir -p "$STAGE_DIR" + +install -m 755 "$BINARY_PATH" "$STAGE_DIR/dstack-simulator" +install -m 644 "$ROOT_DIR/sdk/simulator/dstack.toml" "$STAGE_DIR/dstack.toml" +install -m 644 "$ROOT_DIR/sdk/simulator/app-compose.json" "$STAGE_DIR/app-compose.json" +install -m 644 "$ROOT_DIR/sdk/simulator/appkeys.json" "$STAGE_DIR/appkeys.json" +install -m 644 "$ROOT_DIR/sdk/simulator/sys-config.json" "$STAGE_DIR/sys-config.json" +install -m 644 "$ROOT_DIR/sdk/simulator/attestation.bin" "$STAGE_DIR/attestation.bin" +install -m 644 "$SCRIPT_DIR/dstack-simulator.service" "$STAGE_DIR/dstack-simulator.service" +install -m 755 "$SCRIPT_DIR/install-systemd.sh" "$STAGE_DIR/install-systemd.sh" + +tar -C "$OUT_DIR" -czf "$TARBALL_PATH" "$PACKAGE_NAME" +( + cd "$OUT_DIR" + sha256sum "$(basename "$TARBALL_PATH")" > "$(basename "$CHECKSUM_PATH")" +) + +echo "Created release bundle:" +echo " $TARBALL_PATH" +echo " $CHECKSUM_PATH" diff --git a/guest-agent-simulator/src/main.rs b/guest-agent-simulator/src/main.rs new file mode 100644 index 000000000..65e0b0b9b --- /dev/null +++ b/guest-agent-simulator/src/main.rs @@ -0,0 +1,196 @@ +// SPDX-FileCopyrightText: © 2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +mod simulator; + +use std::sync::Arc; + +use anyhow::{bail, Context, Result}; +use clap::Parser; +use dstack_guest_agent::{ + backend::PlatformBackend, + config::{self, Config}, + run_server, AppState, +}; +use dstack_guest_agent_rpc::{AttestResponse, GetQuoteResponse}; +use ra_rpc::Attestation; +use ra_tls::attestation::VersionedAttestation; +use serde::Deserialize; +use tracing::warn; + +const DEFAULT_CONFIG: &str = 
include_str!("../dstack.toml"); + +#[derive(Parser)] +#[command(author, version, about = "dstack guest agent simulator", long_version = dstack_guest_agent::app_version())] +struct Args { + /// Path to the configuration file + #[arg(short, long)] + config: Option, + + /// Enable systemd watchdog + #[arg(short, long)] + watchdog: bool, +} + +#[derive(Debug, Clone, Deserialize)] +struct SimulatorSettings { + attestation_file: String, + #[serde(default = "default_patch_report_data")] + patch_report_data: bool, +} + +#[derive(Debug, Clone, Deserialize)] +struct SimulatorCoreConfig { + #[serde(flatten)] + core: Config, + simulator: SimulatorSettings, +} + +struct SimulatorPlatform { + attestation: VersionedAttestation, + patch_report_data: bool, +} + +impl SimulatorPlatform { + fn new(attestation: VersionedAttestation, patch_report_data: bool) -> Self { + Self { + attestation, + patch_report_data, + } + } +} + +fn default_patch_report_data() -> bool { + true +} + +impl PlatformBackend for SimulatorPlatform { + fn attestation_for_info(&self) -> Result { + Ok(simulator::simulated_info_attestation(&self.attestation)) + } + + fn certificate_attestation(&self, pubkey: &[u8]) -> Result { + Ok(simulator::simulated_certificate_attestation( + &self.attestation, + pubkey, + self.patch_report_data, + )) + } + + fn quote_response(&self, report_data: [u8; 64], vm_config: &str) -> Result { + simulator::simulated_quote_response( + &self.attestation, + report_data, + vm_config, + self.patch_report_data, + ) + } + + fn attest_response(&self, report_data: [u8; 64]) -> Result { + Ok(simulator::simulated_attest_response( + &self.attestation, + report_data, + self.patch_report_data, + )) + } + + fn emit_event(&self, event: &str, _payload: &[u8]) -> Result<()> { + bail!("runtime event emission is unavailable in simulator mode: {event}") + } +} + +#[rocket::main] +async fn main() -> Result<()> { + { + use tracing_subscriber::{fmt, EnvFilter}; + let filter = 
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + fmt().with_env_filter(filter).with_ansi(false).init(); + } + let args = Args::parse(); + let figment = config::load_config_figment_with_default(DEFAULT_CONFIG, args.config.as_deref()); + let sim_config: SimulatorCoreConfig = figment + .focus("core") + .extract() + .context("Failed to extract simulator core config")?; + warn!( + attestation_file = %sim_config.simulator.attestation_file, + patch_report_data = sim_config.simulator.patch_report_data, + "starting dstack guest-agent simulator" + ); + if sim_config.simulator.patch_report_data { + warn!("simulator will rewrite report_data to match requests; quote verification may fail against the original fixture signature"); + } else { + warn!("simulator will preserve fixture report_data; cert/key binding and requested report_data may not match"); + } + let attestation = + simulator::load_versioned_attestation(&sim_config.simulator.attestation_file)?; + let state = AppState::new_with_platform( + sim_config.core, + Arc::new(SimulatorPlatform::new( + attestation, + sim_config.simulator.patch_report_data, + )), + ) + .await + .context("Failed to create simulator app state")?; + run_server(state, figment, args.watchdog).await +} + +#[cfg(test)] +mod tests { + use super::*; + + fn load_fixture_platform() -> SimulatorPlatform { + let fixture = simulator::load_versioned_attestation( + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../guest-agent/fixtures/attestation.bin"), + ) + .expect("fixture attestation should load"); + SimulatorPlatform::new(fixture, true) + } + + #[test] + fn simulator_rejects_runtime_event_emission() { + let platform = load_fixture_platform(); + let err = platform.emit_event("test.event", b"payload").unwrap_err(); + assert!(err.to_string().contains("unavailable in simulator mode")); + } + + #[test] + fn simulator_provides_certificate_attestation() { + let platform = load_fixture_platform(); + let cert_attestation = 
platform + .certificate_attestation(b"test-public-key") + .unwrap(); + assert!(cert_attestation.decode_app_info(false).is_ok()); + let _ = platform.attestation_for_info().unwrap(); + } + + #[test] + fn simulator_attest_response_uses_supplied_report_data() { + let platform = load_fixture_platform(); + let report_data = [0x5a; 64]; + let response = platform.attest_response(report_data).unwrap(); + let patched = VersionedAttestation::from_scale(&response.attestation).unwrap(); + let VersionedAttestation::V0 { attestation } = patched; + assert_eq!(attestation.report_data, report_data); + } + + #[test] + fn simulator_can_preserve_fixture_report_data() { + let fixture = simulator::load_versioned_attestation( + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../guest-agent/fixtures/attestation.bin"), + ) + .expect("fixture attestation should load"); + let original = fixture.clone().into_inner().report_data; + let platform = SimulatorPlatform::new(fixture, false); + let report_data = [0x5a; 64]; + let response = platform.attest_response(report_data).unwrap(); + let patched = VersionedAttestation::from_scale(&response.attestation).unwrap(); + let VersionedAttestation::V0 { attestation } = patched; + assert_eq!(attestation.report_data, original); + assert_ne!(attestation.report_data, report_data); + } +} diff --git a/guest-agent-simulator/src/simulator.rs b/guest-agent-simulator/src/simulator.rs new file mode 100644 index 000000000..abe1781fc --- /dev/null +++ b/guest-agent-simulator/src/simulator.rs @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: © 2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use std::path::Path; + +use anyhow::{Context, Result}; +use dstack_guest_agent_rpc::{AttestResponse, GetQuoteResponse}; +use ra_rpc::Attestation; +use ra_tls::attestation::{QuoteContentType, VersionedAttestation, TDX_QUOTE_REPORT_DATA_RANGE}; +use std::fs; +use tracing::warn; + +pub fn load_versioned_attestation(path: impl AsRef<Path>) -> Result<VersionedAttestation> { + let path =
path.as_ref(); + let attestation_bytes = fs::read(path).with_context(|| { + format!( + "Failed to read simulator attestation file: {}", + path.display() + ) + })?; + VersionedAttestation::from_scale(&attestation_bytes) + .context("Failed to decode simulator attestation") +} + +pub fn simulated_quote_response( + attestation: &VersionedAttestation, + report_data: [u8; 64], + vm_config: &str, + patch_report_data: bool, +) -> Result<GetQuoteResponse> { + let VersionedAttestation::V0 { attestation } = + maybe_patch_report_data(attestation, report_data, patch_report_data, "quote"); + let mut attestation = attestation; + let Some(quote) = attestation.tdx_quote_mut() else { + return Err(anyhow::anyhow!("Quote not found")); + }; + + Ok(GetQuoteResponse { + quote: quote.quote.to_vec(), + event_log: serde_json::to_string(&quote.event_log) + .context("Failed to serialize event log")?, + report_data: report_data.to_vec(), + vm_config: vm_config.to_string(), + }) +} + +pub fn simulated_attest_response( + attestation: &VersionedAttestation, + report_data: [u8; 64], + patch_report_data: bool, +) -> AttestResponse { + AttestResponse { + attestation: maybe_patch_report_data(attestation, report_data, patch_report_data, "attest") + .to_scale(), + } +} + +pub fn simulated_info_attestation(attestation: &VersionedAttestation) -> Attestation { + attestation.clone().into_inner() +} + +pub fn simulated_certificate_attestation( + attestation: &VersionedAttestation, + pubkey: &[u8], + patch_report_data: bool, +) -> VersionedAttestation { + let report_data = QuoteContentType::RaTlsCert.to_report_data(pubkey); + maybe_patch_report_data( + attestation, + report_data, + patch_report_data, + "certificate_attestation", + ) +} + +fn maybe_patch_report_data( + attestation: &VersionedAttestation, + report_data: [u8; 64], + patch_report_data: bool, + context: &str, +) -> VersionedAttestation { + if !patch_report_data { + warn!( + context = context, + requested_report_data = ?report_data, + "simulator is preserving fixture 
report_data; returned attestation may not match the current request" + ); + return attestation.clone(); + } + + let VersionedAttestation::V0 { attestation } = attestation.clone(); + let mut attestation = attestation; + attestation.report_data = report_data; + if let Some(tdx_quote) = attestation.tdx_quote_mut() { + if tdx_quote.quote.len() >= TDX_QUOTE_REPORT_DATA_RANGE.end { + tdx_quote.quote[TDX_QUOTE_REPORT_DATA_RANGE].copy_from_slice(&report_data); + } else { + warn!( + "TDX quote too short to patch report_data ({} < {})", + tdx_quote.quote.len(), + TDX_QUOTE_REPORT_DATA_RANGE.end + ); + } + } + VersionedAttestation::V0 { attestation } +} diff --git a/guest-agent/dstack.toml b/guest-agent/dstack.toml index b350d886a..8b4bcd5db 100644 --- a/guest-agent/dstack.toml +++ b/guest-agent/dstack.toml @@ -16,9 +16,6 @@ compose_file = "/dstack/.host-shared/app-compose.json" sys_config_file = "/dstack/.host-shared/.sys-config.json" data_disks = ["/"] -[default.core.simulator] -enabled = false -attestation_file = "attestation.bin" [internal-v0] address = "unix:/var/run/dstack/tappd.sock" diff --git a/guest-agent/src/backend.rs b/guest-agent/src/backend.rs new file mode 100644 index 000000000..4a7d4fa91 --- /dev/null +++ b/guest-agent/src/backend.rs @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::{Context, Result}; +use dstack_attest::emit_runtime_event; +use dstack_guest_agent_rpc::{AttestResponse, GetQuoteResponse}; +use ra_rpc::Attestation; +use ra_tls::attestation::{QuoteContentType, VersionedAttestation}; + +pub trait PlatformBackend: Send + Sync { + fn attestation_for_info(&self) -> Result; + fn certificate_attestation(&self, pubkey: &[u8]) -> Result; + fn quote_response(&self, report_data: [u8; 64], vm_config: &str) -> Result; + fn attest_response(&self, report_data: [u8; 64]) -> Result; + fn emit_event(&self, event: &str, payload: &[u8]) -> Result<()>; +} + +#[derive(Debug, Default)] 
+pub struct RealPlatform; + +impl PlatformBackend for RealPlatform { + fn attestation_for_info(&self) -> Result { + Attestation::local().context("Failed to get local attestation") + } + + fn certificate_attestation(&self, pubkey: &[u8]) -> Result { + let report_data = QuoteContentType::RaTlsCert.to_report_data(pubkey); + Ok(Attestation::quote(&report_data) + .context("Failed to get quote for cert pubkey")? + .into_versioned()) + } + + fn quote_response(&self, report_data: [u8; 64], vm_config: &str) -> Result { + let attestation = Attestation::quote(&report_data).context("Failed to get quote")?; + let tdx_quote = attestation.get_tdx_quote_bytes(); + let tdx_event_log = attestation.get_tdx_event_log_string(); + Ok(GetQuoteResponse { + quote: tdx_quote.unwrap_or_default(), + event_log: tdx_event_log.unwrap_or_default(), + report_data: report_data.to_vec(), + vm_config: vm_config.to_string(), + }) + } + + fn attest_response(&self, report_data: [u8; 64]) -> Result { + let attestation = Attestation::quote(&report_data).context("Failed to get attestation")?; + Ok(AttestResponse { + attestation: attestation.into_versioned().to_scale(), + }) + } + + fn emit_event(&self, event: &str, payload: &[u8]) -> Result<()> { + emit_runtime_event(event, payload) + } +} diff --git a/guest-agent/src/config.rs b/guest-agent/src/config.rs index 6276a4daa..7f94185d7 100644 --- a/guest-agent/src/config.rs +++ b/guest-agent/src/config.rs @@ -13,7 +13,14 @@ use serde::{de::Error, Deserialize}; pub const DEFAULT_CONFIG: &str = include_str!("../dstack.toml"); pub fn load_config_figment(config_file: Option<&str>) -> Figment { - load_config("dstack", DEFAULT_CONFIG, config_file, true) + load_config_figment_with_default(DEFAULT_CONFIG, config_file) +} + +pub fn load_config_figment_with_default( + default_config: &str, + config_file: Option<&str>, +) -> Figment { + load_config("dstack", default_config, config_file, true) } #[derive(Debug, Clone, Copy, Deserialize)] @@ -43,8 +50,6 @@ pub struct 
Config { pub sys_config_file: PathBuf, #[serde(default)] pub pccs_url: Option, - pub simulator: Simulator, - // List of disks to be shown in the dashboard pub data_disks: HashSet, } @@ -67,9 +72,3 @@ where raw: content, }) } - -#[derive(Debug, Clone, Deserialize)] -pub struct Simulator { - pub enabled: bool, - pub attestation_file: String, -} diff --git a/guest-agent/src/lib.rs b/guest-agent/src/lib.rs new file mode 100644 index 000000000..ef6f844ce --- /dev/null +++ b/guest-agent/src/lib.rs @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +pub const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); +pub const GIT_REV: &str = git_version::git_version!( + args = ["--abbrev=20", "--always", "--dirty=-modified"], + prefix = "git:", + fallback = "unknown" +); + +pub mod backend; +pub mod config; +mod guest_api_service; +mod http_routes; +mod models; +pub mod rpc_service; +mod server; +mod socket_activation; + +pub use rpc_service::AppState; +pub use server::{app_version, run as run_server}; diff --git a/guest-agent/src/main.rs b/guest-agent/src/main.rs index bb68ee800..0263d3f7f 100644 --- a/guest-agent/src/main.rs +++ b/guest-agent/src/main.rs @@ -2,45 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::{future::pending, os::unix::net::UnixListener as StdUnixListener}; - -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result}; use clap::Parser; -use config::BindAddr; -use guest_api_service::GuestApiHandler; -use rocket::{ - fairing::AdHoc, - figment::Figment, - listener::{Bind, DefaultListener}, -}; -use rocket_vsock_listener::VsockListener; -use rpc_service::{AppState, ExternalRpcHandler, InternalRpcHandler, InternalRpcHandlerV0}; -use sd_notify::{notify as sd_notify, NotifyState}; -use socket_activation::{ActivatedSockets, ActivatedUnixListener}; -use std::time::Duration; -use tokio::sync::oneshot; -use tracing::{error, info}; - -mod config; -mod guest_api_service; -mod 
http_routes; -mod models; -mod rpc_service; -mod socket_activation; - -const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); -const GIT_REV: &str = git_version::git_version!( - args = ["--abbrev=20", "--always", "--dirty=-modified"], - prefix = "git:", - fallback = "unknown" -); - -fn app_version() -> String { - format!("v{CARGO_PKG_VERSION} ({GIT_REV})") -} +use dstack_guest_agent::{config, run_server, AppState}; #[derive(Parser)] -#[command(author, version, about, long_version = app_version())] +#[command(author, version, about, long_version = dstack_guest_agent::app_version())] struct Args { /// Path to the configuration file #[arg(short, long)] @@ -51,177 +18,6 @@ struct Args { watchdog: bool, } -async fn run_internal_v0( - state: AppState, - figment: Figment, - activated_socket: Option, - sock_ready_tx: oneshot::Sender<()>, -) -> Result<()> { - let rocket = rocket::custom(figment) - .mount( - "/prpc/", - ra_rpc::prpc_routes!(AppState, InternalRpcHandlerV0, trim: "Tappd."), - ) - .manage(state); - let ignite = rocket - .ignite() - .await - .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - - if let Some(std_listener) = activated_socket { - // Use systemd-activated socket - info!("Using systemd-activated socket for tappd.sock"); - let listener = ActivatedUnixListener::new(std_listener)?; - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; - } else { - // Fall back to binding our own socket - let endpoint = DefaultListener::bind_endpoint(&ignite) - .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; - } - Ok(()) -} - -async fn run_internal( - state: AppState, - figment: Figment, - activated_socket: Option, - sock_ready_tx: 
oneshot::Sender<()>, -) -> Result<()> { - let rocket = rocket::custom(figment) - .mount("/", ra_rpc::prpc_routes!(AppState, InternalRpcHandler)) - .manage(state); - let ignite = rocket - .ignite() - .await - .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - - if let Some(std_listener) = activated_socket { - // Use systemd-activated socket - info!("Using systemd-activated socket for dstack.sock"); - let listener = ActivatedUnixListener::new(std_listener)?; - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; - } else { - // Fall back to binding our own socket - let endpoint = DefaultListener::bind_endpoint(&ignite) - .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; - sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; - } - Ok(()) -} - -async fn run_external(state: AppState, figment: Figment) -> Result<()> { - let rocket = rocket::custom(figment) - .mount("/", http_routes::external_routes(state.config())) - .mount( - "/prpc", - ra_rpc::prpc_routes!(AppState, ExternalRpcHandler, trim: "Worker."), - ) - .attach(AdHoc::on_response("Add app version header", |_req, res| { - Box::pin(async move { - res.set_raw_header("X-App-Version", app_version()); - }) - })) - .manage(state); - let _ = rocket - .launch() - .await - .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - Ok(()) -} - -async fn run_guest_api(state: AppState, figment: Figment) -> Result<()> { - let rocket = rocket::custom(figment) - .mount("/api", ra_rpc::prpc_routes!(AppState, GuestApiHandler)) - .manage(state); - - let ignite = rocket - .ignite() - .await - .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; - if DefaultListener::bind_endpoint(&ignite).is_ok() { - let listener = DefaultListener::bind(&ignite) 
- .await - .map_err(|err| anyhow!("Failed to bind guest API : {err}"))?; - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; - } else { - let listener = VsockListener::bind_rocket(&ignite) - .map_err(|err| anyhow!("Failed to bind guest API : {err}"))?; - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; - } - Ok(()) -} - -async fn run_watchdog(port: u16) { - let mut watchdog_usec = 0; - let enabled = sd_notify::watchdog_enabled(false, &mut watchdog_usec); - if !enabled { - info!("Watchdog is not enabled in systemd service"); - return pending::<()>().await; - } - - info!("Starting watchdog"); - // Notify systemd that we're ready - if let Err(err) = sd_notify(false, &[NotifyState::Ready]) { - error!("Failed to notify systemd: {err}"); - } - let heatbeat_interval = Duration::from_micros(watchdog_usec / 2); - let heatbeat_interval = heatbeat_interval.max(Duration::from_secs(1)); - info!("Watchdog enabled, interval={watchdog_usec}us, heartbeat={heatbeat_interval:?}",); - let mut interval = tokio::time::interval(heatbeat_interval); - - let probe_url = format!("http://localhost:{port}/prpc/Worker.Version"); - loop { - interval.tick().await; - - // Create HTTP client for health checks - let client = reqwest::Client::new(); - // Perform health check - match client.get(&probe_url).send().await { - Ok(response) if response.status().is_success() => { - // Only notify systemd if health check passes - if let Err(err) = sd_notify(false, &[NotifyState::Watchdog]) { - error!("Failed to notify systemd: {err}"); - } - } - Ok(response) => { - error!("Health check failed with status: {}", response.status()); - } - Err(err) => { - error!("Health check request failed: {err:?}"); - } - } - } -} - #[rocket::main] async fn main() -> Result<()> { { @@ -234,36 +30,5 @@ async fn main() -> Result<()> { let state = AppState::new(figment.focus("core").extract()?) 
.await .context("Failed to create app state")?; - let internal_v0_figment = figment.clone().select("internal-v0"); - let internal_figment = figment.clone().select("internal"); - let external_figment = figment.clone().select("external"); - let bind_addr: BindAddr = external_figment - .extract() - .context("Failed to extract bind address")?; - let guest_api_figment = figment.select("guest-api"); - - // Get systemd-activated sockets if available - let activated = ActivatedSockets::from_env(); - if activated.any_activated() { - info!("Systemd socket activation detected"); - } - - let (tappd_ready_tx, tappd_ready_rx) = oneshot::channel(); - let (sock_ready_tx, sock_ready_rx) = oneshot::channel(); - tokio::select!( - res = run_internal_v0(state.clone(), internal_v0_figment, activated.tappd, tappd_ready_tx) => res?, - res = run_internal(state.clone(), internal_figment, activated.dstack, sock_ready_tx) => res?, - res = run_external(state.clone(), external_figment) => res?, - res = run_guest_api(state.clone(), guest_api_figment) => res?, - _ = async { - let _ = tappd_ready_rx.await; - let _ = sock_ready_rx.await; - if args.watchdog { - run_watchdog(bind_addr.port).await; - } else { - pending::<()>().await; - } - } => {} - ); - Ok(()) + run_server(state, figment, args.watchdog).await } diff --git a/guest-agent/src/rpc_service.rs b/guest-agent/src/rpc_service.rs index 80d10ff59..087139c49 100644 --- a/guest-agent/src/rpc_service.rs +++ b/guest-agent/src/rpc_service.rs @@ -6,9 +6,7 @@ use std::sync::{Arc, RwLock}; use anyhow::{Context, Result}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; -use cc_eventlog::tdx::read_event_log; use cert_client::CertRequestClient; -use dstack_attest::emit_runtime_event; use dstack_guest_agent_rpc::{ dstack_guest_server::{DstackGuestRpc, DstackGuestServer}, tappd_server::{TappdRpc, TappdServer}, @@ -26,12 +24,10 @@ use ed25519_dalek::{ use fs_err as fs; use k256::ecdsa::SigningKey; use or_panic::ResultOrPanic; -use 
ra_rpc::{Attestation, CallContext, RpcCall}; +use ra_rpc::{CallContext, RpcCall}; use ra_tls::{ - attestation::{ - QuoteContentType, VersionedAttestation, DEFAULT_HASH_ALGORITHM, TDX_QUOTE_REPORT_DATA_RANGE, - }, - cert::CertConfigV2, + attestation::{QuoteContentType, DEFAULT_HASH_ALGORITHM}, + cert::{CertConfigV2, CertSigningRequestV2, Csr}, kdf::{derive_key, derive_p256_key_pair_from_bytes}, }; use rcgen::KeyPair; @@ -40,7 +36,10 @@ use serde_json::json; use sha3::{Digest, Keccak256}; use tracing::error; -use crate::config::Config; +use crate::{ + backend::{PlatformBackend, RealPlatform}, + config::Config, +}; fn read_dmi_file(name: &str) -> String { fs::read_to_string(format!("/sys/class/dmi/id/{name}")) @@ -59,26 +58,37 @@ struct AppStateInner { vm_config: String, cert_client: CertRequestClient, demo_cert: RwLock, + platform: Arc, } impl AppStateInner { - fn simulator_attestation(&self) -> Result> { - if !self.config.simulator.enabled { - return Ok(None); - } - let attestation_bytes = fs::read(&self.config.simulator.attestation_file) - .context("Failed to read simulator attestation file")?; - let attestation = VersionedAttestation::from_scale(&attestation_bytes) - .context("Failed to decode simulator attestation")?; - Ok(Some(attestation)) + fn info_attestation(&self) -> Result { + self.platform.attestation_for_info() + } + + async fn issue_cert(&self, key: &KeyPair, config: CertConfigV2) -> Result> { + let pubkey = key.public_key_der(); + let attestation = self + .platform + .certificate_attestation(&pubkey) + .context("Failed to get certificate attestation")?; + let csr = CertSigningRequestV2 { + confirm: "please sign cert:".to_string(), + pubkey, + config, + attestation, + }; + let signature = csr.signed_by(key).context("Failed to sign the CSR")?; + self.cert_client + .sign_csr(&csr, &signature) + .await + .context("Failed to sign the CSR") } async fn request_demo_cert(&self) -> Result { let key = KeyPair::generate().context("Failed to generate demo key")?; 
- let attestation_override = self.simulator_attestation()?; let demo_cert = self - .cert_client - .request_cert( + .issue_cert( &key, CertConfigV2 { org_name: None, @@ -91,7 +101,6 @@ impl AppStateInner { not_after: None, not_before: None, }, - attestation_override, ) .await .context("Failed to get app cert")? @@ -123,7 +132,10 @@ impl AppState { }); } - pub async fn new(config: Config) -> Result { + pub async fn new_with_platform( + config: Config, + platform: Arc, + ) -> Result { let keys: AppKeys = serde_json::from_str(&fs::read_to_string(&config.keys_file)?) .context("Failed to parse app keys")?; let sys_config: SysConfig = @@ -141,15 +153,34 @@ impl AppState { cert_client, demo_cert: RwLock::new(String::new()), vm_config, + platform, }), }; me.maybe_request_demo_cert(); Ok(me) } + pub async fn new(config: Config) -> Result { + Self::new_with_platform(config, Arc::new(RealPlatform)).await + } + pub fn config(&self) -> &Config { &self.inner.config } + + fn quote_response(&self, report_data: [u8; 64]) -> Result { + self.inner + .platform + .quote_response(report_data, &self.inner.vm_config) + } + + fn attest_response(&self, report_data: [u8; 64]) -> Result { + self.inner.platform.attest_response(report_data) + } + + fn emit_event(&self, event: &str, payload: &[u8]) -> Result<()> { + self.inner.platform.emit_event(event, payload) + } } pub struct InternalRpcHandler { @@ -158,14 +189,7 @@ pub struct InternalRpcHandler { pub async fn get_info(state: &AppState, external: bool) -> Result { let hide_tcb_info = external && !state.config().app_compose.public_tcbinfo; - let attestation = if let Some(attestation) = state.inner.simulator_attestation()? 
{ - attestation.into_inner() - } else { - let Ok(attestation) = Attestation::local() else { - return Ok(AppInfo::default()); - }; - attestation - }; + let attestation = state.inner.info_attestation()?; let app_info = attestation .decode_app_info(false) .context("Failed to decode app info")?; @@ -249,18 +273,7 @@ impl DstackGuestRpc for InternalRpcHandler { not_after: request.not_after, not_before: request.not_before, }; - let attestation_override = self - .state - .inner - .simulator_attestation() - .context("Failed to load simulator attestation")?; - let certificate_chain = self - .state - .inner - .cert_client - .request_cert(&derived_key, config, attestation_override) - .await - .context("Failed to sign the CSR")?; + let certificate_chain = self.state.inner.issue_cert(&derived_key, config).await?; Ok(GetTlsKeyResponse { key: derived_key.serialize_pem(), certificate_chain, @@ -312,29 +325,11 @@ impl DstackGuestRpc for InternalRpcHandler { async fn get_quote(self, request: RawQuoteArgs) -> Result { let report_data = pad64(&request.report_data).context("Report data is too long")?; - if self.state.config().simulator.enabled { - return simulate_quote( - self.state.config(), - report_data, - &self.state.inner.vm_config, - ); - } - let attestation = Attestation::quote(&report_data).context("Failed to get quote")?; - let tdx_quote = attestation.get_tdx_quote_bytes(); - let tdx_event_log = attestation.get_tdx_event_log_string(); - Ok(GetQuoteResponse { - quote: tdx_quote.unwrap_or_default(), - event_log: tdx_event_log.unwrap_or_default(), - report_data: report_data.to_vec(), - vm_config: self.state.inner.vm_config.clone(), - }) + self.state.quote_response(report_data) } async fn emit_event(self, request: EmitEventArgs) -> Result<()> { - if self.state.config().simulator.enabled { - return Ok(()); - } - emit_runtime_event(&request.event, &request.payload) + self.state.emit_event(&request.event, &request.payload) } async fn info(self) -> Result { @@ -436,21 +431,13 @@ impl 
DstackGuestRpc for InternalRpcHandler { async fn attest(self, request: RawQuoteArgs) -> Result { let report_data = pad64(&request.report_data).context("Report data is too long")?; - if let Some(attestation) = self.state.inner.simulator_attestation()? { - return Ok(AttestResponse { - attestation: attestation.to_scale(), - }); - } - let attestation = Attestation::quote(&report_data).context("Failed to get attestation")?; - Ok(AttestResponse { - attestation: attestation.into_versioned().to_scale(), - }) + self.state.attest_response(report_data) } async fn version(self) -> Result { Ok(WorkerVersion { - version: env!("CARGO_PKG_VERSION").to_string(), - rev: super::GIT_REV.to_string(), + version: crate::CARGO_PKG_VERSION.to_string(), + rev: crate::GIT_REV.to_string(), }) } } @@ -473,31 +460,6 @@ fn pad64(data: &[u8]) -> Option<[u8; 64]> { Some(padded) } -fn simulate_quote( - config: &Config, - report_data: [u8; 64], - vm_config: &str, -) -> Result { - let attestation_bytes = fs::read(&config.simulator.attestation_file) - .context("Failed to read simulator attestation file")?; - let VersionedAttestation::V0 { attestation } = - VersionedAttestation::from_scale(&attestation_bytes) - .context("Failed to decode simulator attestation")?; - let mut attestation = attestation; - let Some(quote) = attestation.tdx_quote_mut() else { - return Err(anyhow::anyhow!("Quote not found")); - }; - - quote.quote[TDX_QUOTE_REPORT_DATA_RANGE].copy_from_slice(&report_data); - Ok(GetQuoteResponse { - quote: quote.quote.to_vec(), - event_log: serde_json::to_string("e.event_log) - .context("Failed to serialize event log")?, - report_data: report_data.to_vec(), - vm_config: vm_config.to_string(), - }) -} - impl RpcCall for InternalRpcHandler { type PrpcService = DstackGuestServer; @@ -536,18 +498,7 @@ impl TappdRpc for InternalRpcHandlerV0 { not_before: None, not_after: None, }; - let attestation_override = self - .state - .inner - .simulator_attestation() - .context("Failed to load simulator 
attestation")?; - let certificate_chain = self - .state - .inner - .cert_client - .request_cert(&derived_key, config, attestation_override) - .await - .context("Failed to sign the CSR")?; + let certificate_chain = self.state.inner.issue_cert(&derived_key, config).await?; Ok(GetTlsKeyResponse { key: derived_key.serialize_pem(), certificate_chain, @@ -582,28 +533,10 @@ impl TappdRpc for InternalRpcHandlerV0 { }; let report_data = content_type.to_report_data_with_hash(&request.report_data, &request.hash_algorithm)?; - if self.state.config().simulator.enabled { - let response = simulate_quote( - self.state.config(), - report_data, - &self.state.inner.vm_config, - )?; - return Ok(TdxQuoteResponse { - quote: response.quote, - event_log: response.event_log, - hash_algorithm: hash_algorithm.to_string(), - prefix, - }); - } - let event_log = read_event_log().context("Failed to decode event log")?; - // Strip RTMR[0-2] payloads, keep only digests - let stripped: Vec<_> = event_log.iter().map(|e| e.stripped()).collect(); - let event_log = - serde_json::to_string(&stripped).context("Failed to serialize event log")?; - let quote = tdx_attest::get_quote(&report_data).context("Failed to get quote")?; + let response = self.state.quote_response(report_data)?; Ok(TdxQuoteResponse { - quote, - event_log, + quote: response.quote, + event_log: response.event_log, hash_algorithm: hash_algorithm.to_string(), prefix, }) @@ -624,8 +557,8 @@ impl TappdRpc for InternalRpcHandlerV0 { async fn version(self) -> Result { Ok(WorkerVersion { - version: env!("CARGO_PKG_VERSION").to_string(), - rev: super::GIT_REV.to_string(), + version: crate::CARGO_PKG_VERSION.to_string(), + rev: crate::GIT_REV.to_string(), }) } } @@ -657,8 +590,8 @@ impl WorkerRpc for ExternalRpcHandler { async fn version(self) -> Result { Ok(WorkerVersion { - version: env!("CARGO_PKG_VERSION").to_string(), - rev: super::GIT_REV.to_string(), + version: crate::CARGO_PKG_VERSION.to_string(), + rev: crate::GIT_REV.to_string(), }) } 
@@ -693,26 +626,7 @@ impl WorkerRpc for ExternalRpcHandler { let ed_bytes = ed25519_report_string.as_bytes(); ed25519_report_data[..ed_bytes.len()].copy_from_slice(ed_bytes); - if self.state.config().simulator.enabled { - Ok(simulate_quote( - self.state.config(), - ed25519_report_data, - &self.state.inner.vm_config, - )?) - } else { - let ed25519_quote = tdx_attest::get_quote(&ed25519_report_data) - .context("Failed to get ed25519 quote")?; - let raw_event_log = read_event_log().context("Failed to read event log")?; - // Strip RTMR[0-2] payloads, keep only digests - let stripped: Vec<_> = raw_event_log.iter().map(|e| e.stripped()).collect(); - let event_log = serde_json::to_string(&stripped)?; - Ok(GetQuoteResponse { - quote: ed25519_quote, - event_log, - report_data: ed25519_report_data.to_vec(), - vm_config: self.state.inner.vm_config.clone(), - }) - } + self.state.quote_response(ed25519_report_data) } "secp256k1" | "secp256k1_prehashed" => { let secp256k1_key = SigningKey::from_slice(&key_response.key) @@ -725,27 +639,7 @@ impl WorkerRpc for ExternalRpcHandler { let secp_bytes = secp256k1_report_string.as_bytes(); secp256k1_report_data[..secp_bytes.len()].copy_from_slice(secp_bytes); - if self.state.config().simulator.enabled { - Ok(simulate_quote( - self.state.config(), - secp256k1_report_data, - &self.state.inner.vm_config, - )?) 
- } else { - let secp256k1_quote = tdx_attest::get_quote(&secp256k1_report_data) - .context("Failed to get secp256k1 quote")?; - let raw_event_log = read_event_log().context("Failed to read event log")?; - // Strip RTMR[0-2] payloads, keep only digests - let stripped: Vec<_> = raw_event_log.iter().map(|e| e.stripped()).collect(); - let event_log = serde_json::to_string(&stripped)?; - - Ok(GetQuoteResponse { - quote: secp256k1_quote, - event_log, - report_data: secp256k1_report_data.to_vec(), - vm_config: self.state.inner.vm_config.clone(), - }) - } + self.state.quote_response(secp256k1_report_data) } _ => Err(anyhow::anyhow!("Unsupported algorithm")), } @@ -765,7 +659,10 @@ impl RpcCall for ExternalRpcHandler { #[cfg(test)] mod tests { use super::*; - use crate::config::{AppComposeWrapper, Config, Simulator}; + use crate::{ + backend::PlatformBackend, + config::{AppComposeWrapper, Config}, + }; use dstack_guest_agent_rpc::{GetAttestationForAppKeyRequest, SignRequest}; use dstack_types::{AppCompose, AppKeys, KeyProvider}; use ed25519_dalek::ed25519::signature::hazmat::PrehashVerifier; @@ -773,6 +670,7 @@ mod tests { Signature as Ed25519Signature, Verifier, VerifyingKey as Ed25519VerifyingKey, }; use k256::ecdsa::{Signature as K256Signature, VerifyingKey}; + use ra_tls::attestation::VersionedAttestation; use sha2::Sha256; use std::collections::HashSet; use std::convert::TryFrom; @@ -801,11 +699,6 @@ mod tests { temp_attestation_file.write_all(attestation).unwrap(); temp_attestation_file.flush().unwrap(); - let dummy_simulator = Simulator { - enabled: true, - attestation_file: temp_attestation_file.path().to_str().unwrap().to_string(), - }; - let dummy_appcompose = AppCompose { manifest_version: 0, name: String::new(), @@ -837,7 +730,6 @@ mod tests { app_compose: dummy_appcompose_wrapper, sys_config_file: String::new().into(), pccs_url: None, - simulator: dummy_simulator, data_disks: HashSet::new(), }; @@ -914,12 +806,86 @@ pNs85uhOZE8z2jr8Pg== .await .expect("Failed 
to create CertRequestClient"); + struct TestSimulatorPlatform { + attestation: VersionedAttestation, + } + + fn patch_report_data( + attestation: &VersionedAttestation, + report_data: [u8; 64], + ) -> VersionedAttestation { + let ra_tls::attestation::VersionedAttestation::V0 { attestation } = attestation.clone(); + let mut attestation = attestation; + attestation.report_data = report_data; + if let Some(tdx_quote) = attestation.tdx_quote_mut() { + if tdx_quote.quote.len() >= ra_tls::attestation::TDX_QUOTE_REPORT_DATA_RANGE.end { + tdx_quote.quote[ra_tls::attestation::TDX_QUOTE_REPORT_DATA_RANGE] + .copy_from_slice(&report_data); + } else { + tracing::warn!( + "TDX quote too short to patch report_data ({} < {})", + tdx_quote.quote.len(), + ra_tls::attestation::TDX_QUOTE_REPORT_DATA_RANGE.end + ); + } + } + ra_tls::attestation::VersionedAttestation::V0 { attestation } + } + + impl PlatformBackend for TestSimulatorPlatform { + fn attestation_for_info(&self) -> Result { + Ok(self.attestation.clone().into_inner()) + } + + fn certificate_attestation(&self, pubkey: &[u8]) -> Result { + let report_data = + ra_tls::attestation::QuoteContentType::RaTlsCert.to_report_data(pubkey); + Ok(patch_report_data(&self.attestation, report_data)) + } + + fn quote_response( + &self, + report_data: [u8; 64], + vm_config: &str, + ) -> Result { + let ra_tls::attestation::VersionedAttestation::V0 { attestation } = + patch_report_data(&self.attestation, report_data); + let mut attestation = attestation; + let Some(quote) = attestation.tdx_quote_mut() else { + return Err(anyhow::anyhow!("Quote not found")); + }; + Ok(GetQuoteResponse { + quote: quote.quote.to_vec(), + event_log: serde_json::to_string("e.event_log) + .context("Failed to serialize event log")?, + report_data: report_data.to_vec(), + vm_config: vm_config.to_string(), + }) + } + + fn attest_response(&self, report_data: [u8; 64]) -> Result { + Ok(AttestResponse { + attestation: patch_report_data(&self.attestation, 
report_data).to_scale(), + }) + } + + fn emit_event(&self, _event: &str, _payload: &[u8]) -> Result<()> { + Ok(()) + } + } + let inner = AppStateInner { config: dummy_config, keys: dummy_keys, vm_config: String::new(), cert_client: dummy_cert_client, demo_cert: RwLock::new(String::new()), + platform: Arc::new(TestSimulatorPlatform { + attestation: VersionedAttestation::from_scale( + &std::fs::read(temp_attestation_file.path()).unwrap(), + ) + .unwrap(), + }), }; ( diff --git a/guest-agent/src/server.rs b/guest-agent/src/server.rs new file mode 100644 index 000000000..00d3d90c8 --- /dev/null +++ b/guest-agent/src/server.rs @@ -0,0 +1,228 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +use std::{future::pending, os::unix::net::UnixListener as StdUnixListener, time::Duration}; + +use crate::config::BindAddr; +use crate::guest_api_service::GuestApiHandler; +use crate::http_routes; +use crate::rpc_service::{AppState, ExternalRpcHandler, InternalRpcHandler, InternalRpcHandlerV0}; +use crate::socket_activation::{ActivatedSockets, ActivatedUnixListener}; +use anyhow::{anyhow, Context, Result}; +use rocket::{ + fairing::AdHoc, + figment::Figment, + listener::{Bind, DefaultListener}, +}; +use rocket_vsock_listener::VsockListener; +use sd_notify::{notify as sd_notify, NotifyState}; +use tokio::sync::oneshot; +use tracing::{error, info}; + +pub fn app_version() -> String { + format!("v{} ({})", crate::CARGO_PKG_VERSION, crate::GIT_REV) +} + +async fn run_internal_v0( + state: AppState, + figment: Figment, + activated_socket: Option, + sock_ready_tx: oneshot::Sender<()>, +) -> Result<()> { + let rocket = rocket::custom(figment) + .mount( + "/prpc/", + ra_rpc::prpc_routes!(AppState, InternalRpcHandlerV0, trim: "Tappd."), + ) + .manage(state); + let ignite = rocket + .ignite() + .await + .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; + + if let Some(std_listener) = activated_socket { + info!("Using 
systemd-activated socket for tappd.sock"); + let listener = ActivatedUnixListener::new(std_listener)?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; + } else { + let endpoint = DefaultListener::bind_endpoint(&ignite) + .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + Ok(()) +} + +async fn run_internal( + state: AppState, + figment: Figment, + activated_socket: Option, + sock_ready_tx: oneshot::Sender<()>, +) -> Result<()> { + let rocket = rocket::custom(figment) + .mount("/", ra_rpc::prpc_routes!(AppState, InternalRpcHandler)) + .manage(state); + let ignite = rocket + .ignite() + .await + .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; + + if let Some(std_listener) = activated_socket { + info!("Using systemd-activated socket for dstack.sock"); + let listener = ActivatedUnixListener::new(std_listener)?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err: rocket::Error| anyhow!(err.to_string()))?; + } else { + let endpoint = DefaultListener::bind_endpoint(&ignite) + .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + sock_ready_tx.send(()).ok(); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + Ok(()) +} + +async fn run_external(state: AppState, figment: Figment) -> Result<()> { + let rocket = rocket::custom(figment) + .mount("/", http_routes::external_routes(state.config())) + .mount( + "/prpc", + ra_rpc::prpc_routes!(AppState, ExternalRpcHandler, trim: "Worker."), + ) + .attach(AdHoc::on_response("Add app 
version header", |_req, res| { + Box::pin(async move { + res.set_raw_header("X-App-Version", app_version()); + }) + })) + .manage(state); + let _ = rocket + .launch() + .await + .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; + Ok(()) +} + +async fn run_guest_api(state: AppState, figment: Figment) -> Result<()> { + let rocket = rocket::custom(figment) + .mount("/api", ra_rpc::prpc_routes!(AppState, GuestApiHandler)) + .manage(state); + + let ignite = rocket + .ignite() + .await + .map_err(|err| anyhow!("Failed to ignite rocket: {err}"))?; + if DefaultListener::bind_endpoint(&ignite).is_ok() { + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind guest API : {err}"))?; + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } else { + let listener = VsockListener::bind_rocket(&ignite) + .map_err(|err| anyhow!("Failed to bind guest API : {err}"))?; + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + Ok(()) +} + +async fn run_watchdog(port: u16) { + let mut watchdog_usec = 0; + let enabled = sd_notify::watchdog_enabled(false, &mut watchdog_usec); + if !enabled { + info!("Watchdog is not enabled in systemd service"); + return pending::<()>().await; + } + + info!("Starting watchdog"); + if let Err(err) = sd_notify(false, &[NotifyState::Ready]) { + error!("Failed to notify systemd: {err}"); + } + let heatbeat_interval = Duration::from_micros(watchdog_usec / 2); + let heatbeat_interval = heatbeat_interval.max(Duration::from_secs(1)); + info!("Watchdog enabled, interval={watchdog_usec}us, heartbeat={heatbeat_interval:?}"); + let mut interval = tokio::time::interval(heatbeat_interval); + + let probe_url = format!("http://localhost:{port}/prpc/Worker.Version"); + loop { + interval.tick().await; + + let client = reqwest::Client::new(); + match client.get(&probe_url).send().await { + Ok(response) if response.status().is_success() => { + if let 
Err(err) = sd_notify(false, &[NotifyState::Watchdog]) { + error!("Failed to notify systemd: {err}"); + } + } + Ok(response) => { + error!("Health check failed with status: {}", response.status()); + } + Err(err) => { + error!("Health check request failed: {err:?}"); + } + } + } +} + +pub async fn run(state: AppState, figment: Figment, watchdog: bool) -> Result<()> { + let internal_v0_figment = figment.clone().select("internal-v0"); + let internal_figment = figment.clone().select("internal"); + let external_figment = figment.clone().select("external"); + let bind_addr = if watchdog { + Some( + external_figment + .extract::() + .context("Failed to extract bind address")?, + ) + } else { + None + }; + let guest_api_figment = figment.select("guest-api"); + + let activated = ActivatedSockets::from_env(); + if activated.any_activated() { + info!("Systemd socket activation detected"); + } + + let (tappd_ready_tx, tappd_ready_rx) = oneshot::channel(); + let (sock_ready_tx, sock_ready_rx) = oneshot::channel(); + tokio::select!( + res = run_internal_v0(state.clone(), internal_v0_figment, activated.tappd, tappd_ready_tx) => res?, + res = run_internal(state.clone(), internal_figment, activated.dstack, sock_ready_tx) => res?, + res = run_external(state.clone(), external_figment) => res?, + res = run_guest_api(state.clone(), guest_api_figment) => res?, + _ = async { + let _ = tappd_ready_rx.await; + let _ = sock_ready_rx.await; + if let Some(bind_addr) = bind_addr { + run_watchdog(bind_addr.port).await; + } else { + pending::<()>().await; + } + } => {} + ); + Ok(()) +} diff --git a/run-tests.sh b/run-tests.sh index e348d1bed..59aa0a5fd 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -4,22 +4,73 @@ # # SPDX-License-Identifier: Apache-2.0 -set -e +set -Eeuo pipefail -(cd sdk/simulator && ./build.sh) +ROOT_DIR="$(pwd -P)" +SIMULATOR_DIR="$ROOT_DIR/sdk/simulator" +SIMULATOR_LOG="$SIMULATOR_DIR/dstack-simulator.log" +DSTACK_SOCKET="$SIMULATOR_DIR/dstack.sock" 
+TAPPD_SOCKET="$SIMULATOR_DIR/tappd.sock" +SIMULATOR_PID="" -pushd sdk/simulator -./dstack-simulator & +cleanup() { + if [[ -n "${SIMULATOR_PID:-}" ]]; then + kill "$SIMULATOR_PID" 2>/dev/null || true + wait "$SIMULATOR_PID" 2>/dev/null || true + fi +} + +print_simulator_logs() { + if [[ -f "$SIMULATOR_LOG" ]]; then + echo "Last simulator logs:" + tail -100 "$SIMULATOR_LOG" || true + fi +} + +wait_for_socket() { + local socket_path="$1" + local name="$2" + + for _ in {1..100}; do + if [[ -S "$socket_path" ]]; then + return 0 + fi + if [[ -n "${SIMULATOR_PID:-}" ]] && ! kill -0 "$SIMULATOR_PID" 2>/dev/null; then + echo "Simulator exited before $name socket became ready." + print_simulator_logs + return 1 + fi + sleep 0.2 + done + + echo "Timed out waiting for $name socket at $socket_path" + print_simulator_logs + return 1 +} + +trap 'print_simulator_logs' ERR +trap cleanup EXIT INT TERM + +rm -f "$DSTACK_SOCKET" "$TAPPD_SOCKET" "$SIMULATOR_LOG" +( + cd "$SIMULATOR_DIR" + ./build.sh +) + +( + cd "$SIMULATOR_DIR" + ./dstack-simulator >"$SIMULATOR_LOG" 2>&1 +) & SIMULATOR_PID=$! -trap "kill $SIMULATOR_PID 2>/dev/null || true" EXIT echo "Simulator process (PID: $SIMULATOR_PID) started." 
-popd -export DSTACK_SIMULATOR_ENDPOINT=$(realpath sdk/simulator/dstack.sock) -export TAPPD_SIMULATOR_ENDPOINT=$(realpath sdk/simulator/tappd.sock) +wait_for_socket "$DSTACK_SOCKET" "dstack" +wait_for_socket "$TAPPD_SOCKET" "tappd" + +export DSTACK_SIMULATOR_ENDPOINT="$DSTACK_SOCKET" +export TAPPD_SIMULATOR_ENDPOINT="$TAPPD_SOCKET" echo "DSTACK_SIMULATOR_ENDPOINT: $DSTACK_SIMULATOR_ENDPOINT" echo "TAPPD_SIMULATOR_ENDPOINT: $TAPPD_SIMULATOR_ENDPOINT" -# Run the tests cargo test --all-features -- --show-output diff --git a/sdk/README.md b/sdk/README.md index 930eeba06..6d1b72a0b 100644 --- a/sdk/README.md +++ b/sdk/README.md @@ -19,5 +19,6 @@ All SDKs communicate with the guest agent via HTTP over a Unix socket (`/var/run For local development without TDX hardware, use the simulator: -- [Download releases](https://github.com/Leechael/dstack-simulator/releases) +- [Download releases](https://github.com/Dstack-TEE/dstack/releases?q=simulator-v&expanded=true) +- [Install as a systemd service](../guest-agent-simulator/install-systemd.sh) - [Docker image](https://hub.docker.com/r/phalanetwork/dstack-simulator) diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index f98484a94..2583e414b 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -3,9 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 import hashlib +import os import warnings from evidence_api.tdx.quote import TdxQuote +import httpx import pytest from dstack_sdk import AsyncDstackClient @@ -248,6 +250,14 @@ def test_unix_socket_file_not_exist(): os.environ["DSTACK_SIMULATOR_ENDPOINT"] = saved_env +def assert_emit_event_behavior(error: Exception | None) -> None: + if "DSTACK_SIMULATOR_ENDPOINT" in os.environ: + assert isinstance(error, httpx.HTTPStatusError) + assert error.response.status_code == 400 + else: + assert error is None, f"emit_event unexpectedly failed: {error}" + + def test_non_unix_socket_endpoints(): """Test that client doesn't throw error for 
non-unix socket paths.""" import os @@ -272,17 +282,25 @@ def test_non_unix_socket_endpoints(): async def test_emit_event(): """Test emit event functionality.""" client = AsyncDstackClient() - # This should not raise an error - await client.emit_event("test-event", "test payload") - await client.emit_event("test-event-bytes", b"test payload bytes") + error = None + try: + await client.emit_event("test-event", "test payload") + await client.emit_event("test-event-bytes", b"test payload bytes") + except Exception as exc: # pragma: no cover - behavior depends on runtime mode + error = exc + assert_emit_event_behavior(error) def test_sync_emit_event(): """Test sync emit event functionality.""" client = DstackClient() - # This should not raise an error - client.emit_event("test-event", "test payload") - client.emit_event("test-event-bytes", b"test payload bytes") + error = None + try: + client.emit_event("test-event", "test payload") + client.emit_event("test-event-bytes", b"test payload bytes") + except Exception as exc: # pragma: no cover - behavior depends on runtime mode + error = exc + assert_emit_event_behavior(error) def test_emit_event_validation(): diff --git a/sdk/run-tests.sh b/sdk/run-tests.sh index 00ebcd7d7..8ab258f2a 100755 --- a/sdk/run-tests.sh +++ b/sdk/run-tests.sh @@ -5,17 +5,70 @@ # # SPDX-License-Identifier: Apache-2.0 -set -e +set -Eeuo pipefail -export DSTACK_SIMULATOR_ENDPOINT=$(realpath simulator/dstack.sock) -export TAPPD_SIMULATOR_ENDPOINT=$(realpath simulator/tappd.sock) +ROOT_DIR="$(pwd -P)" +SIMULATOR_DIR="$ROOT_DIR/simulator" +SIMULATOR_LOG="$SIMULATOR_DIR/dstack-simulator.log" +DSTACK_SOCKET="$SIMULATOR_DIR/dstack.sock" +TAPPD_SOCKET="$SIMULATOR_DIR/tappd.sock" +SIMULATOR_PID="" -pushd simulator -./build.sh -./dstack-simulator >/dev/null 2>&1 & +cleanup() { + if [[ -n "${SIMULATOR_PID:-}" ]]; then + kill "$SIMULATOR_PID" 2>/dev/null || true + wait "$SIMULATOR_PID" 2>/dev/null || true + fi +} + +print_simulator_logs() { + if [[ -f 
"$SIMULATOR_LOG" ]]; then + echo "Last simulator logs:" + tail -100 "$SIMULATOR_LOG" || true + fi +} + +wait_for_socket() { + local socket_path="$1" + local name="$2" + + for _ in {1..100}; do + if [[ -S "$socket_path" ]]; then + return 0 + fi + if [[ -n "${SIMULATOR_PID:-}" ]] && ! kill -0 "$SIMULATOR_PID" 2>/dev/null; then + echo "Simulator exited before $name socket became ready." + print_simulator_logs + return 1 + fi + sleep 0.2 + done + + echo "Timed out waiting for $name socket at $socket_path" + print_simulator_logs + return 1 +} + +trap 'print_simulator_logs' ERR +trap cleanup EXIT INT TERM + +rm -f "$DSTACK_SOCKET" "$TAPPD_SOCKET" "$SIMULATOR_LOG" +export DSTACK_SIMULATOR_ENDPOINT="$DSTACK_SOCKET" +export TAPPD_SIMULATOR_ENDPOINT="$TAPPD_SOCKET" + +( + cd "$SIMULATOR_DIR" + ./build.sh +) + +( + cd "$SIMULATOR_DIR" + ./dstack-simulator >"$SIMULATOR_LOG" 2>&1 +) & SIMULATOR_PID=$! -trap "kill $SIMULATOR_PID 2>/dev/null || true" EXIT -popd + +wait_for_socket "$DSTACK_SOCKET" "dstack" +wait_for_socket "$TAPPD_SOCKET" "tappd" pushd rust/ cargo test -- --show-output diff --git a/sdk/rust/examples/dstack_client_usage.rs b/sdk/rust/examples/dstack_client_usage.rs index 722dbf2d0..bdb17302c 100644 --- a/sdk/rust/examples/dstack_client_usage.rs +++ b/sdk/rust/examples/dstack_client_usage.rs @@ -67,10 +67,16 @@ async fn main() -> anyhow::Result<()> { // 4. Emit an event let event_payload = b"Application started successfully".to_vec(); - client + match client .emit_event("AppStart".to_string(), event_payload) - .await?; - println!("Event emitted successfully!"); + .await + { + Ok(()) => println!("Event emitted successfully!"), + Err(err) if std::env::var_os("DSTACK_SIMULATOR_ENDPOINT").is_some() => { + println!("Event emission is unavailable in simulator mode: {err}"); + } + Err(err) => return Err(err), + } // 5. 
Get TLS key for server authentication let tls_config = TlsKeyConfig::builder() diff --git a/sdk/simulator/build.sh b/sdk/simulator/build.sh index 2567eeac1..fca993175 100755 --- a/sdk/simulator/build.sh +++ b/sdk/simulator/build.sh @@ -5,7 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 cd $(dirname $0) -cargo build --release -p dstack-guest-agent -cp ../../target/release/dstack-guest-agent . -ln -sf dstack-guest-agent dstack-simulator +cargo build --release -p dstack-guest-agent-simulator +cp ../../target/release/dstack-simulator . diff --git a/sdk/simulator/dstack.toml b/sdk/simulator/dstack.toml index ecf4a8e40..b1c7325a6 100644 --- a/sdk/simulator/dstack.toml +++ b/sdk/simulator/dstack.toml @@ -14,10 +14,11 @@ log_level = "debug" keys_file = "appkeys.json" compose_file = "app-compose.json" sys_config_file = "sys-config.json" +data_disks = ["/"] [default.core.simulator] -enabled = true attestation_file = "attestation.bin" +patch_report_data = true [internal-v0] address = "unix:./tappd.sock"