Skip to content

Commit 2aeed83

Browse files
authored
Merge pull request #308 from Dstack-TEE/feature-sdk-golang
feat(sdk/go): update golang sdk.
2 parents 06a2347 + 4e4418b commit 2aeed83

14 files changed

Lines changed: 1985 additions & 352 deletions

sdk/go/README.md

Lines changed: 747 additions & 114 deletions
Large diffs are not rendered by default.

sdk/go/dstack/client.go

Lines changed: 253 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@ package dstack
99
import (
1010
"bytes"
1111
"context"
12+
"crypto/ecdsa"
13+
"crypto/ed25519"
1214
"crypto/sha512"
15+
"crypto/x509"
1316
"encoding/hex"
1417
"encoding/json"
18+
"encoding/pem"
1519
"fmt"
1620
"io"
1721
"log/slog"
1822
"net"
1923
"net/http"
2024
"os"
2125
"strings"
26+
"time"
2227
)
2328

2429
// Represents the response from a TLS key derivation request.
@@ -27,20 +32,83 @@ type GetTlsKeyResponse struct {
2732
CertificateChain []string `json:"certificate_chain"`
2833
}
2934

35+
// AsUint8Array converts the private key to bytes, optionally limiting the length
36+
func (r *GetTlsKeyResponse) AsUint8Array(maxLength ...int) ([]byte, error) {
37+
block, _ := pem.Decode([]byte(r.Key))
38+
if block == nil {
39+
return nil, fmt.Errorf("failed to decode pem private key")
40+
}
41+
42+
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to parse private key: %w", err)
45+
}
46+
47+
var keyBytes []byte
48+
switch k := key.(type) {
49+
case *ecdsa.PrivateKey:
50+
keyBytes = k.D.FillBytes(make([]byte, (k.Curve.Params().N.BitLen()+7)/8))
51+
case ed25519.PrivateKey:
52+
keyBytes = k.Seed()
53+
default:
54+
return nil, fmt.Errorf("unsupported key type: %T", key)
55+
}
56+
57+
if len(maxLength) > 0 && maxLength[0] > 0 && maxLength[0] < len(keyBytes) {
58+
return keyBytes[:maxLength[0]], nil
59+
}
60+
return keyBytes, nil
61+
}
62+
3063
// Represents the response from a key derivation request.
3164
type GetKeyResponse struct {
3265
Key string `json:"key"`
3366
SignatureChain []string `json:"signature_chain"`
3467
}
3568

69+
// DecodeKey returns the key as bytes
70+
func (r *GetKeyResponse) DecodeKey() ([]byte, error) {
71+
return hex.DecodeString(r.Key)
72+
}
73+
74+
// DecodeSignatureChain returns the signature chain as bytes
75+
func (r *GetKeyResponse) DecodeSignatureChain() ([][]byte, error) {
76+
result := make([][]byte, len(r.SignatureChain))
77+
for i, sig := range r.SignatureChain {
78+
bytes, err := hex.DecodeString(sig)
79+
if err != nil {
80+
return nil, fmt.Errorf("failed to decode signature %d: %w", i, err)
81+
}
82+
result[i] = bytes
83+
}
84+
return result, nil
85+
}
86+
3687
// Represents the response from a quote request.
3788
type GetQuoteResponse struct {
38-
Quote []byte `json:"quote"`
89+
Quote string `json:"quote"`
3990
EventLog string `json:"event_log"`
40-
ReportData []byte `json:"report_data"`
91+
ReportData string `json:"report_data"`
4192
VmConfig string `json:"vm_config"`
4293
}
4394

95+
// DecodeQuote returns the quote bytes
96+
func (r *GetQuoteResponse) DecodeQuote() ([]byte, error) {
97+
return hex.DecodeString(r.Quote)
98+
}
99+
100+
// DecodeReportData returns the report data bytes
101+
func (r *GetQuoteResponse) DecodeReportData() ([]byte, error) {
102+
return hex.DecodeString(r.ReportData)
103+
}
104+
105+
// DecodeEventLog returns the event log as structured data
106+
func (r *GetQuoteResponse) DecodeEventLog() ([]EventLog, error) {
107+
var events []EventLog
108+
err := json.Unmarshal([]byte(r.EventLog), &events)
109+
return events, err
110+
}
111+
44112
// Represents the response from an attestation request.
45113
type AttestResponse struct {
46114
Attestation []byte
@@ -57,17 +125,20 @@ type EventLog struct {
57125

58126
// Represents the TCB information
59127
type TcbInfo struct {
60-
Mrtd string `json:"mrtd"`
61-
Rtmr0 string `json:"rtmr0"`
62-
Rtmr1 string `json:"rtmr1"`
63-
Rtmr2 string `json:"rtmr2"`
64-
Rtmr3 string `json:"rtmr3"`
65-
// The hash of the OS image. This is empty if the OS image is not measured by KMS.
66-
OsImageHash string `json:"os_image_hash,omitempty"`
67-
ComposeHash string `json:"compose_hash"`
68-
DeviceID string `json:"device_id"`
69-
AppCompose string `json:"app_compose"`
70-
EventLog []EventLog `json:"event_log"`
128+
Mrtd string `json:"mrtd"`
129+
Rtmr0 string `json:"rtmr0"`
130+
Rtmr1 string `json:"rtmr1"`
131+
Rtmr2 string `json:"rtmr2"`
132+
Rtmr3 string `json:"rtmr3"`
133+
AppCompose string `json:"app_compose"`
134+
EventLog []EventLog `json:"event_log"`
135+
// V0.3.x fields
136+
RootfsHash string `json:"rootfs_hash,omitempty"`
137+
// V0.5.x fields
138+
MrAggregated string `json:"mr_aggregated,omitempty"`
139+
OsImageHash string `json:"os_image_hash,omitempty"`
140+
ComposeHash string `json:"compose_hash,omitempty"`
141+
DeviceID string `json:"device_id,omitempty"`
71142
}
72143

73144
// Represents the response from an info request
@@ -81,9 +152,11 @@ type InfoResponse struct {
81152
MrAggregated string `json:"mr_aggregated,omitempty"`
82153
KeyProviderInfo string `json:"key_provider_info"`
83154
// Optional: empty if OS image is not measured by KMS
84-
OsImageHash string `json:"os_image_hash,omitempty"`
85-
ComposeHash string `json:"compose_hash"`
86-
VmConfig string `json:"vm_config,omitempty"`
155+
OsImageHash string `json:"os_image_hash,omitempty"`
156+
ComposeHash string `json:"compose_hash"`
157+
VmConfig string `json:"vm_config,omitempty"`
158+
CloudVendor string `json:"cloud_vendor,omitempty"`
159+
CloudProduct string `json:"cloud_product,omitempty"`
87160
}
88161

89162
// DecodeTcbInfo decodes the TcbInfo string into a TcbInfo struct
@@ -269,6 +342,7 @@ func (c *DstackClient) sendRPCRequest(ctx context.Context, path string, payload
269342
}
270343

271344
req.Header.Set("Content-Type", "application/json")
345+
req.Header.Set("User-Agent", "dstack-sdk-go/0.1.0")
272346
resp, err := c.httpClient.Do(req)
273347
if err != nil {
274348
return nil, err
@@ -297,6 +371,9 @@ type tlsKeyOptions struct {
297371
usageRaTls bool
298372
usageServerAuth bool
299373
usageClientAuth bool
374+
notBefore *uint64
375+
notAfter *uint64
376+
withAppInfo *bool
300377
}
301378

302379
// WithSubject sets the subject for the TLS key
@@ -334,6 +411,27 @@ func WithUsageClientAuth(usage bool) TlsKeyOption {
334411
}
335412
}
336413

414+
// WithNotBefore sets the not_before timestamp for the certificate
415+
func WithNotBefore(t uint64) TlsKeyOption {
416+
return func(opts *tlsKeyOptions) {
417+
opts.notBefore = &t
418+
}
419+
}
420+
421+
// WithNotAfter sets the not_after timestamp for the certificate
422+
func WithNotAfter(t uint64) TlsKeyOption {
423+
return func(opts *tlsKeyOptions) {
424+
opts.notAfter = &t
425+
}
426+
}
427+
428+
// WithAppInfo sets the with_app_info flag for the certificate
429+
func WithAppInfo(enabled bool) TlsKeyOption {
430+
return func(opts *tlsKeyOptions) {
431+
opts.withAppInfo = &enabled
432+
}
433+
}
434+
337435
// Gets a TLS key from the dstack service with optional parameters.
338436
func (c *DstackClient) GetTlsKey(
339437
ctx context.Context,
@@ -356,6 +454,15 @@ func (c *DstackClient) GetTlsKey(
356454
if len(opts.altNames) > 0 {
357455
payload["alt_names"] = opts.altNames
358456
}
457+
if opts.notBefore != nil {
458+
payload["not_before"] = *opts.notBefore
459+
}
460+
if opts.notAfter != nil {
461+
payload["not_after"] = *opts.notAfter
462+
}
463+
if opts.withAppInfo != nil {
464+
payload["with_app_info"] = *opts.withAppInfo
465+
}
359466

360467
data, err := c.sendRPCRequest(ctx, "/GetTlsKey", payload)
361468
if err != nil {
@@ -429,30 +536,12 @@ func (c *DstackClient) GetQuote(ctx context.Context, reportData []byte) (*GetQuo
429536
return nil, err
430537
}
431538

432-
var response struct {
433-
Quote string `json:"quote"`
434-
EventLog string `json:"event_log"`
435-
ReportData string `json:"report_data"`
436-
}
539+
var response GetQuoteResponse
437540
if err := json.Unmarshal(data, &response); err != nil {
438541
return nil, err
439542
}
440543

441-
quote, err := hex.DecodeString(response.Quote)
442-
if err != nil {
443-
return nil, err
444-
}
445-
446-
reportDataBytes, err := hex.DecodeString(response.ReportData)
447-
if err != nil {
448-
return nil, err
449-
}
450-
451-
return &GetQuoteResponse{
452-
Quote: quote,
453-
EventLog: response.EventLog,
454-
ReportData: reportDataBytes,
455-
}, nil
544+
return &response, nil
456545
}
457546

458547
// Gets a versioned attestation from the dstack service.
@@ -600,14 +689,142 @@ func (c *DstackClient) Verify(ctx context.Context, algorithm string, data []byte
600689
return &response, nil
601690
}
602691

692+
// IsReachable checks if the service is reachable
693+
func (c *DstackClient) IsReachable(ctx context.Context) bool {
694+
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
695+
defer cancel()
696+
_, err := c.Info(ctx)
697+
return err == nil
698+
}
699+
603700
// EmitEvent sends an event to be extended to RTMR3 on TDX platform.
604701
// The event will be extended to RTMR3 with the provided name and payload.
605702
//
606703
// Requires dstack OS 0.5.0 or later.
607704
func (c *DstackClient) EmitEvent(ctx context.Context, event string, payload []byte) error {
705+
if event == "" {
706+
return fmt.Errorf("event name cannot be empty")
707+
}
608708
_, err := c.sendRPCRequest(ctx, "/EmitEvent", map[string]interface{}{
609709
"event": event,
610710
"payload": hex.EncodeToString(payload),
611711
})
612712
return err
613713
}
714+
715+
// Legacy methods for backward compatibility with warnings
716+
717+
// DeriveKey is deprecated. Use GetKey instead.
718+
// Deprecated: Use GetKey instead.
719+
func (c *DstackClient) DeriveKey(path string, subject string, altNames []string) (*GetTlsKeyResponse, error) {
720+
return nil, fmt.Errorf("deriveKey is deprecated, please use GetKey instead")
721+
}
722+
723+
// TdxQuote is deprecated. Use GetQuote instead.
724+
// Deprecated: Use GetQuote instead.
725+
func (c *DstackClient) TdxQuote(ctx context.Context, reportData []byte, hashAlgorithm string) (*GetQuoteResponse, error) {
726+
c.logger.Warn("tdxQuote is deprecated, please use GetQuote instead")
727+
if hashAlgorithm != "raw" {
728+
return nil, fmt.Errorf("tdxQuote only supports raw hash algorithm")
729+
}
730+
return c.GetQuote(ctx, reportData)
731+
}
732+
733+
// TappdClient is a deprecated wrapper around DstackClient for backward compatibility.
734+
// Deprecated: Use DstackClient instead.
735+
type TappdClient struct {
736+
*DstackClient
737+
}
738+
739+
// NewTappdClient creates a new deprecated TappdClient.
740+
// Deprecated: Use NewDstackClient instead.
741+
func NewTappdClient(opts ...DstackClientOption) *TappdClient {
742+
// Create a modified option to use TAPPD_SIMULATOR_ENDPOINT
743+
tappdOpts := make([]DstackClientOption, 0, len(opts)+1)
744+
745+
// Add default endpoint option that checks TAPPD_SIMULATOR_ENDPOINT
746+
tappdOpts = append(tappdOpts, func(c *DstackClient) {
747+
if c.endpoint == "" {
748+
if simEndpoint, exists := os.LookupEnv("TAPPD_SIMULATOR_ENDPOINT"); exists {
749+
c.logger.Warn("Using tappd endpoint", "endpoint", simEndpoint)
750+
c.endpoint = simEndpoint
751+
} else {
752+
c.endpoint = "/var/run/tappd.sock"
753+
}
754+
}
755+
})
756+
757+
// Add user-provided options
758+
tappdOpts = append(tappdOpts, opts...)
759+
760+
client := NewDstackClient(tappdOpts...)
761+
client.logger.Warn("TappdClient is deprecated, please use DstackClient instead")
762+
763+
return &TappdClient{
764+
DstackClient: client,
765+
}
766+
}
767+
768+
// Override deprecated methods to use proper tappd RPC paths
769+
770+
// DeriveKey is deprecated. Use GetKey instead.
771+
// Deprecated: Use GetKey instead.
772+
func (tc *TappdClient) DeriveKey(ctx context.Context, path string, subject string, altNames []string) (*GetTlsKeyResponse, error) {
773+
tc.logger.Warn("deriveKey is deprecated, please use GetKey instead")
774+
775+
if subject == "" {
776+
subject = path
777+
}
778+
779+
payload := map[string]interface{}{
780+
"path": path,
781+
"subject": subject,
782+
}
783+
if len(altNames) > 0 {
784+
payload["alt_names"] = altNames
785+
}
786+
787+
data, err := tc.sendRPCRequest(ctx, "/prpc/Tappd.DeriveKey", payload)
788+
if err != nil {
789+
return nil, err
790+
}
791+
792+
var response GetTlsKeyResponse
793+
if err := json.Unmarshal(data, &response); err != nil {
794+
return nil, err
795+
}
796+
return &response, nil
797+
}
798+
799+
// TdxQuote is deprecated. Use GetQuote instead.
800+
// Deprecated: Use GetQuote instead.
801+
func (tc *TappdClient) TdxQuote(ctx context.Context, reportData []byte, hashAlgorithm string) (*GetQuoteResponse, error) {
802+
tc.logger.Warn("tdxQuote is deprecated, please use GetQuote instead")
803+
804+
if hashAlgorithm == "raw" {
805+
if len(reportData) > 64 {
806+
return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is raw")
807+
}
808+
if len(reportData) < 64 {
809+
// Left-pad with zeros
810+
padding := make([]byte, 64-len(reportData))
811+
reportData = append(padding, reportData...)
812+
}
813+
}
814+
815+
payload := map[string]interface{}{
816+
"report_data": hex.EncodeToString(reportData),
817+
"hash_algorithm": hashAlgorithm,
818+
}
819+
820+
data, err := tc.sendRPCRequest(ctx, "/prpc/Tappd.TdxQuote", payload)
821+
if err != nil {
822+
return nil, err
823+
}
824+
825+
var response GetQuoteResponse
826+
if err := json.Unmarshal(data, &response); err != nil {
827+
return nil, err
828+
}
829+
return &response, nil
830+
}

0 commit comments

Comments
 (0)