Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions kms/src/main_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use ra_tls::{
};
use scale::Decode;
use sha2::Digest;
use tokio::sync::OnceCell;
use tracing::info;
use upgrade_authority::BootInfo;
use upgrade_authority::{build_boot_info, local_kms_boot_info, BootInfo};

use crate::{
config::KmsConfig,
crypto::{derive_k256_key, sign_message, sign_message_with_timestamp},
};

mod upgrade_authority;
pub(crate) mod upgrade_authority;

#[derive(Clone)]
pub struct KmsState {
Expand All @@ -52,6 +53,7 @@ pub struct KmsStateInner {
temp_ca_cert: String,
temp_ca_key: String,
verifier: CvmVerifier,
self_boot_info: OnceCell<BootInfo>,
}

impl KmsState {
Expand Down Expand Up @@ -79,6 +81,7 @@ impl KmsState {
temp_ca_cert,
temp_ca_key,
verifier,
self_boot_info: OnceCell::new(),
}),
})
}
Expand All @@ -95,6 +98,29 @@ struct BootConfig {
}

impl RpcHandler {
async fn ensure_self_allowed(&self) -> Result<()> {
if !self.state.config.onboard.quote_enabled {
return Ok(());
}
let boot_info = self
.state
.self_boot_info
.get_or_try_init(|| local_kms_boot_info(self.state.config.pccs_url.as_deref()))
.await
.context("Failed to load cached self boot info")?;
let response = self
.state
.config
.auth_api
.is_app_allowed(boot_info, true)
.await
.context("Failed to call self KMS auth check")?;
if !response.is_allowed {
bail!("KMS is not allowed: {}", response.reason);
}
Ok(())
}

fn ensure_attested(&self) -> Result<&VerifiedAttestation> {
let Some(attestation) = &self.attestation else {
bail!("No attestation provided");
Expand Down Expand Up @@ -169,32 +195,7 @@ impl RpcHandler {
use_boottime_mr: bool,
vm_config_str: &str,
) -> Result<BootConfig> {
let tcb_status;
let advisory_ids;
match att.report.tdx_report() {
Some(report) => {
tcb_status = report.status.clone();
advisory_ids = report.advisory_ids.clone();
}
None => {
tcb_status = "".to_string();
advisory_ids = Vec::new();
}
};
let app_info = att.decode_app_info_ex(use_boottime_mr, vm_config_str)?;
let boot_info = BootInfo {
attestation_mode: att.quote.mode(),
mr_aggregated: app_info.mr_aggregated.to_vec(),
os_image_hash: app_info.os_image_hash,
mr_system: app_info.mr_system.to_vec(),
app_id: app_info.app_id,
compose_hash: app_info.compose_hash,
instance_id: app_info.instance_id,
device_id: app_info.device_id,
key_provider_info: app_info.key_provider_info,
tcb_status,
advisory_ids,
};
let boot_info = build_boot_info(att, use_boottime_mr, vm_config_str)?;
let response = self
.state
.config
Expand Down Expand Up @@ -239,6 +240,9 @@ impl KmsRpc for RpcHandler {
if request.api_version > 1 {
bail!("Unsupported API version: {}", request.api_version);
}
self.ensure_self_allowed()
.await
.context("KMS self authorization failed")?;
let BootConfig {
boot_info,
gateway_app_id,
Expand Down Expand Up @@ -279,6 +283,9 @@ impl KmsRpc for RpcHandler {
}

async fn get_app_env_encrypt_pub_key(self, request: AppId) -> Result<PublicKeyResponse> {
self.ensure_self_allowed()
.await
.context("KMS self authorization failed")?;
let secret = kdf::derive_dh_secret(
&self.state.root_ca.key,
&[&request.app_id[..], "env-encrypt-key".as_bytes()],
Expand Down Expand Up @@ -345,6 +352,9 @@ impl KmsRpc for RpcHandler {
}

async fn get_kms_key(self, request: GetKmsKeyRequest) -> Result<KmsKeyResponse> {
self.ensure_self_allowed()
.await
.context("KMS self authorization failed")?;
if self.state.config.onboard.quote_enabled {
let _info = self.ensure_kms_allowed(&request.vm_config).await?;
}
Expand All @@ -358,6 +368,9 @@ impl KmsRpc for RpcHandler {
}

async fn get_temp_ca_cert(self) -> Result<GetTempCaCertResponse> {
self.ensure_self_allowed()
.await
.context("KMS self authorization failed")?;
Ok(GetTempCaCertResponse {
temp_ca_cert: self.state.inner.temp_ca_cert.clone(),
temp_ca_key: self.state.inner.temp_ca_key.clone(),
Expand All @@ -366,6 +379,9 @@ impl KmsRpc for RpcHandler {
}

async fn sign_cert(self, request: SignCertRequest) -> Result<SignCertResponse> {
self.ensure_self_allowed()
.await
.context("KMS self authorization failed")?;
let csr = match request.api_version {
1 => {
let csr = CertSigningRequestV1::decode(&mut &request.csr[..])
Expand Down
70 changes: 70 additions & 0 deletions kms/src/main_service/upgrade_authority.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

use crate::config::AuthApi;
use anyhow::{bail, Context, Result};
use dstack_guest_agent_rpc::{
dstack_guest_client::DstackGuestClient, AttestResponse, RawQuoteArgs,
};
use http_client::prpc::PrpcClient;
use ra_tls::attestation::AttestationMode;
use ra_tls::attestation::VerifiedAttestation;
use ra_tls::attestation::VersionedAttestation;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_human_bytes as hex_bytes;
Expand Down Expand Up @@ -33,6 +39,53 @@ pub(crate) struct BootInfo {
pub advisory_ids: Vec<String>,
}

pub(crate) fn build_boot_info(
att: &VerifiedAttestation,
use_boottime_mr: bool,
vm_config_str: &str,
) -> Result<BootInfo> {
let tcb_status;
let advisory_ids;
match att.report.tdx_report() {
Some(report) => {
tcb_status = report.status.clone();
advisory_ids = report.advisory_ids.clone();
}
None => {
tcb_status = "".to_string();
advisory_ids = Vec::new();
}
};
let app_info = att.decode_app_info_ex(use_boottime_mr, vm_config_str)?;
Ok(BootInfo {
attestation_mode: att.quote.mode(),
mr_aggregated: app_info.mr_aggregated.to_vec(),
os_image_hash: app_info.os_image_hash,
mr_system: app_info.mr_system.to_vec(),
app_id: app_info.app_id,
compose_hash: app_info.compose_hash,
instance_id: app_info.instance_id,
device_id: app_info.device_id,
key_provider_info: app_info.key_provider_info,
tcb_status,
advisory_ids,
})
}

pub(crate) async fn local_kms_boot_info(pccs_url: Option<&str>) -> Result<BootInfo> {
let response = app_attest(pad64([0u8; 32]))
.await
.context("Failed to get local KMS attestation")?;
let attestation = VersionedAttestation::from_scale(&response.attestation)
.context("Failed to decode local KMS attestation")?
.into_inner();
let verified = attestation
.verify(pccs_url)
.await
.context("Failed to verify local KMS attestation")?;
build_boot_info(&verified, false, "")
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct BootResponse {
Expand Down Expand Up @@ -134,3 +187,20 @@ fn url_join(url: &str, path: &str) -> String {
url.push_str(path);
url
}

fn dstack_client() -> DstackGuestClient<PrpcClient> {
let address = dstack_types::dstack_agent_address();
let http_client = PrpcClient::new(address);
DstackGuestClient::new(http_client)
}

async fn app_attest(report_data: Vec<u8>) -> Result<AttestResponse> {
dstack_client().attest(RawQuoteArgs { report_data }).await
}

fn pad64(hash: [u8; 32]) -> Vec<u8> {
let mut padded = Vec::with_capacity(64);
padded.extend_from_slice(&hash);
padded.resize(64, 0);
padded
}
Loading
Loading