@@ -9,16 +9,21 @@ package dstack
99import (
1010 "bytes"
1111 "context"
12+ "crypto/ecdsa"
13+ "crypto/ed25519"
1214 "crypto/sha512"
15+ "crypto/x509"
1316 "encoding/hex"
1417 "encoding/json"
18+ "encoding/pem"
1519 "fmt"
1620 "io"
1721 "log/slog"
1822 "net"
1923 "net/http"
2024 "os"
2125 "strings"
26+ "time"
2227)
2328
2429// Represents the response from a TLS key derivation request.
@@ -27,20 +32,83 @@ type GetTlsKeyResponse struct {
2732 CertificateChain []string `json:"certificate_chain"`
2833}
2934
35+ // AsUint8Array converts the private key to bytes, optionally limiting the length
36+ func (r * GetTlsKeyResponse ) AsUint8Array (maxLength ... int ) ([]byte , error ) {
37+ block , _ := pem .Decode ([]byte (r .Key ))
38+ if block == nil {
39+ return nil , fmt .Errorf ("failed to decode pem private key" )
40+ }
41+
42+ key , err := x509 .ParsePKCS8PrivateKey (block .Bytes )
43+ if err != nil {
44+ return nil , fmt .Errorf ("failed to parse private key: %w" , err )
45+ }
46+
47+ var keyBytes []byte
48+ switch k := key .(type ) {
49+ case * ecdsa.PrivateKey :
50+ keyBytes = k .D .FillBytes (make ([]byte , (k .Curve .Params ().N .BitLen ()+ 7 )/ 8 ))
51+ case ed25519.PrivateKey :
52+ keyBytes = k .Seed ()
53+ default :
54+ return nil , fmt .Errorf ("unsupported key type: %T" , key )
55+ }
56+
57+ if len (maxLength ) > 0 && maxLength [0 ] > 0 && maxLength [0 ] < len (keyBytes ) {
58+ return keyBytes [:maxLength [0 ]], nil
59+ }
60+ return keyBytes , nil
61+ }
62+
3063// Represents the response from a key derivation request.
3164type GetKeyResponse struct {
3265 Key string `json:"key"`
3366 SignatureChain []string `json:"signature_chain"`
3467}
3568
69+ // DecodeKey returns the key as bytes
70+ func (r * GetKeyResponse ) DecodeKey () ([]byte , error ) {
71+ return hex .DecodeString (r .Key )
72+ }
73+
74+ // DecodeSignatureChain returns the signature chain as bytes
75+ func (r * GetKeyResponse ) DecodeSignatureChain () ([][]byte , error ) {
76+ result := make ([][]byte , len (r .SignatureChain ))
77+ for i , sig := range r .SignatureChain {
78+ bytes , err := hex .DecodeString (sig )
79+ if err != nil {
80+ return nil , fmt .Errorf ("failed to decode signature %d: %w" , i , err )
81+ }
82+ result [i ] = bytes
83+ }
84+ return result , nil
85+ }
86+
3687// Represents the response from a quote request.
3788type GetQuoteResponse struct {
38- Quote [] byte `json:"quote"`
89+ Quote string `json:"quote"`
3990 EventLog string `json:"event_log"`
40- ReportData [] byte `json:"report_data"`
91+ ReportData string `json:"report_data"`
4192 VmConfig string `json:"vm_config"`
4293}
4394
95+ // DecodeQuote returns the quote bytes
96+ func (r * GetQuoteResponse ) DecodeQuote () ([]byte , error ) {
97+ return hex .DecodeString (r .Quote )
98+ }
99+
100+ // DecodeReportData returns the report data bytes
101+ func (r * GetQuoteResponse ) DecodeReportData () ([]byte , error ) {
102+ return hex .DecodeString (r .ReportData )
103+ }
104+
105+ // DecodeEventLog returns the event log as structured data
106+ func (r * GetQuoteResponse ) DecodeEventLog () ([]EventLog , error ) {
107+ var events []EventLog
108+ err := json .Unmarshal ([]byte (r .EventLog ), & events )
109+ return events , err
110+ }
111+
44112// Represents the response from an attestation request.
45113type AttestResponse struct {
46114 Attestation []byte
@@ -57,17 +125,20 @@ type EventLog struct {
57125
58126// Represents the TCB information
59127type TcbInfo struct {
60- Mrtd string `json:"mrtd"`
61- Rtmr0 string `json:"rtmr0"`
62- Rtmr1 string `json:"rtmr1"`
63- Rtmr2 string `json:"rtmr2"`
64- Rtmr3 string `json:"rtmr3"`
65- // The hash of the OS image. This is empty if the OS image is not measured by KMS.
66- OsImageHash string `json:"os_image_hash,omitempty"`
67- ComposeHash string `json:"compose_hash"`
68- DeviceID string `json:"device_id"`
69- AppCompose string `json:"app_compose"`
70- EventLog []EventLog `json:"event_log"`
128+ Mrtd string `json:"mrtd"`
129+ Rtmr0 string `json:"rtmr0"`
130+ Rtmr1 string `json:"rtmr1"`
131+ Rtmr2 string `json:"rtmr2"`
132+ Rtmr3 string `json:"rtmr3"`
133+ AppCompose string `json:"app_compose"`
134+ EventLog []EventLog `json:"event_log"`
135+ // V0.3.x fields
136+ RootfsHash string `json:"rootfs_hash,omitempty"`
137+ // V0.5.x fields
138+ MrAggregated string `json:"mr_aggregated,omitempty"`
139+ OsImageHash string `json:"os_image_hash,omitempty"`
140+ ComposeHash string `json:"compose_hash,omitempty"`
141+ DeviceID string `json:"device_id,omitempty"`
71142}
72143
73144// Represents the response from an info request
@@ -81,9 +152,11 @@ type InfoResponse struct {
81152 MrAggregated string `json:"mr_aggregated,omitempty"`
82153 KeyProviderInfo string `json:"key_provider_info"`
83154 // Optional: empty if OS image is not measured by KMS
84- OsImageHash string `json:"os_image_hash,omitempty"`
85- ComposeHash string `json:"compose_hash"`
86- VmConfig string `json:"vm_config,omitempty"`
155+ OsImageHash string `json:"os_image_hash,omitempty"`
156+ ComposeHash string `json:"compose_hash"`
157+ VmConfig string `json:"vm_config,omitempty"`
158+ CloudVendor string `json:"cloud_vendor,omitempty"`
159+ CloudProduct string `json:"cloud_product,omitempty"`
87160}
88161
89162// DecodeTcbInfo decodes the TcbInfo string into a TcbInfo struct
@@ -269,6 +342,7 @@ func (c *DstackClient) sendRPCRequest(ctx context.Context, path string, payload
269342 }
270343
271344 req .Header .Set ("Content-Type" , "application/json" )
345+ req .Header .Set ("User-Agent" , "dstack-sdk-go/0.1.0" )
272346 resp , err := c .httpClient .Do (req )
273347 if err != nil {
274348 return nil , err
@@ -297,6 +371,9 @@ type tlsKeyOptions struct {
297371 usageRaTls bool
298372 usageServerAuth bool
299373 usageClientAuth bool
374+ notBefore * uint64
375+ notAfter * uint64
376+ withAppInfo * bool
300377}
301378
302379// WithSubject sets the subject for the TLS key
@@ -334,6 +411,27 @@ func WithUsageClientAuth(usage bool) TlsKeyOption {
334411 }
335412}
336413
414+ // WithNotBefore sets the not_before timestamp for the certificate
415+ func WithNotBefore (t uint64 ) TlsKeyOption {
416+ return func (opts * tlsKeyOptions ) {
417+ opts .notBefore = & t
418+ }
419+ }
420+
421+ // WithNotAfter sets the not_after timestamp for the certificate
422+ func WithNotAfter (t uint64 ) TlsKeyOption {
423+ return func (opts * tlsKeyOptions ) {
424+ opts .notAfter = & t
425+ }
426+ }
427+
428+ // WithAppInfo sets the with_app_info flag for the certificate
429+ func WithAppInfo (enabled bool ) TlsKeyOption {
430+ return func (opts * tlsKeyOptions ) {
431+ opts .withAppInfo = & enabled
432+ }
433+ }
434+
337435// Gets a TLS key from the dstack service with optional parameters.
338436func (c * DstackClient ) GetTlsKey (
339437 ctx context.Context ,
@@ -356,6 +454,15 @@ func (c *DstackClient) GetTlsKey(
356454 if len (opts .altNames ) > 0 {
357455 payload ["alt_names" ] = opts .altNames
358456 }
457+ if opts .notBefore != nil {
458+ payload ["not_before" ] = * opts .notBefore
459+ }
460+ if opts .notAfter != nil {
461+ payload ["not_after" ] = * opts .notAfter
462+ }
463+ if opts .withAppInfo != nil {
464+ payload ["with_app_info" ] = * opts .withAppInfo
465+ }
359466
360467 data , err := c .sendRPCRequest (ctx , "/GetTlsKey" , payload )
361468 if err != nil {
@@ -429,30 +536,12 @@ func (c *DstackClient) GetQuote(ctx context.Context, reportData []byte) (*GetQuo
429536 return nil , err
430537 }
431538
432- var response struct {
433- Quote string `json:"quote"`
434- EventLog string `json:"event_log"`
435- ReportData string `json:"report_data"`
436- }
539+ var response GetQuoteResponse
437540 if err := json .Unmarshal (data , & response ); err != nil {
438541 return nil , err
439542 }
440543
441- quote , err := hex .DecodeString (response .Quote )
442- if err != nil {
443- return nil , err
444- }
445-
446- reportDataBytes , err := hex .DecodeString (response .ReportData )
447- if err != nil {
448- return nil , err
449- }
450-
451- return & GetQuoteResponse {
452- Quote : quote ,
453- EventLog : response .EventLog ,
454- ReportData : reportDataBytes ,
455- }, nil
544+ return & response , nil
456545}
457546
458547// Gets a versioned attestation from the dstack service.
@@ -600,14 +689,142 @@ func (c *DstackClient) Verify(ctx context.Context, algorithm string, data []byte
600689 return & response , nil
601690}
602691
692+ // IsReachable checks if the service is reachable
693+ func (c * DstackClient ) IsReachable (ctx context.Context ) bool {
694+ ctx , cancel := context .WithTimeout (ctx , 500 * time .Millisecond )
695+ defer cancel ()
696+ _ , err := c .Info (ctx )
697+ return err == nil
698+ }
699+
603700// EmitEvent sends an event to be extended to RTMR3 on TDX platform.
604701// The event will be extended to RTMR3 with the provided name and payload.
605702//
606703// Requires dstack OS 0.5.0 or later.
607704func (c * DstackClient ) EmitEvent (ctx context.Context , event string , payload []byte ) error {
705+ if event == "" {
706+ return fmt .Errorf ("event name cannot be empty" )
707+ }
608708 _ , err := c .sendRPCRequest (ctx , "/EmitEvent" , map [string ]interface {}{
609709 "event" : event ,
610710 "payload" : hex .EncodeToString (payload ),
611711 })
612712 return err
613713}
714+
715+ // Legacy methods for backward compatibility with warnings
716+
717+ // DeriveKey is deprecated. Use GetKey instead.
718+ // Deprecated: Use GetKey instead.
719+ func (c * DstackClient ) DeriveKey (path string , subject string , altNames []string ) (* GetTlsKeyResponse , error ) {
720+ return nil , fmt .Errorf ("deriveKey is deprecated, please use GetKey instead" )
721+ }
722+
723+ // TdxQuote is deprecated. Use GetQuote instead.
724+ // Deprecated: Use GetQuote instead.
725+ func (c * DstackClient ) TdxQuote (ctx context.Context , reportData []byte , hashAlgorithm string ) (* GetQuoteResponse , error ) {
726+ c .logger .Warn ("tdxQuote is deprecated, please use GetQuote instead" )
727+ if hashAlgorithm != "raw" {
728+ return nil , fmt .Errorf ("tdxQuote only supports raw hash algorithm" )
729+ }
730+ return c .GetQuote (ctx , reportData )
731+ }
732+
733+ // TappdClient is a deprecated wrapper around DstackClient for backward compatibility.
734+ // Deprecated: Use DstackClient instead.
735+ type TappdClient struct {
736+ * DstackClient
737+ }
738+
739+ // NewTappdClient creates a new deprecated TappdClient.
740+ // Deprecated: Use NewDstackClient instead.
741+ func NewTappdClient (opts ... DstackClientOption ) * TappdClient {
742+ // Create a modified option to use TAPPD_SIMULATOR_ENDPOINT
743+ tappdOpts := make ([]DstackClientOption , 0 , len (opts )+ 1 )
744+
745+ // Add default endpoint option that checks TAPPD_SIMULATOR_ENDPOINT
746+ tappdOpts = append (tappdOpts , func (c * DstackClient ) {
747+ if c .endpoint == "" {
748+ if simEndpoint , exists := os .LookupEnv ("TAPPD_SIMULATOR_ENDPOINT" ); exists {
749+ c .logger .Warn ("Using tappd endpoint" , "endpoint" , simEndpoint )
750+ c .endpoint = simEndpoint
751+ } else {
752+ c .endpoint = "/var/run/tappd.sock"
753+ }
754+ }
755+ })
756+
757+ // Add user-provided options
758+ tappdOpts = append (tappdOpts , opts ... )
759+
760+ client := NewDstackClient (tappdOpts ... )
761+ client .logger .Warn ("TappdClient is deprecated, please use DstackClient instead" )
762+
763+ return & TappdClient {
764+ DstackClient : client ,
765+ }
766+ }
767+
768+ // Override deprecated methods to use proper tappd RPC paths
769+
770+ // DeriveKey is deprecated. Use GetKey instead.
771+ // Deprecated: Use GetKey instead.
772+ func (tc * TappdClient ) DeriveKey (ctx context.Context , path string , subject string , altNames []string ) (* GetTlsKeyResponse , error ) {
773+ tc .logger .Warn ("deriveKey is deprecated, please use GetKey instead" )
774+
775+ if subject == "" {
776+ subject = path
777+ }
778+
779+ payload := map [string ]interface {}{
780+ "path" : path ,
781+ "subject" : subject ,
782+ }
783+ if len (altNames ) > 0 {
784+ payload ["alt_names" ] = altNames
785+ }
786+
787+ data , err := tc .sendRPCRequest (ctx , "/prpc/Tappd.DeriveKey" , payload )
788+ if err != nil {
789+ return nil , err
790+ }
791+
792+ var response GetTlsKeyResponse
793+ if err := json .Unmarshal (data , & response ); err != nil {
794+ return nil , err
795+ }
796+ return & response , nil
797+ }
798+
799+ // TdxQuote is deprecated. Use GetQuote instead.
800+ // Deprecated: Use GetQuote instead.
801+ func (tc * TappdClient ) TdxQuote (ctx context.Context , reportData []byte , hashAlgorithm string ) (* GetQuoteResponse , error ) {
802+ tc .logger .Warn ("tdxQuote is deprecated, please use GetQuote instead" )
803+
804+ if hashAlgorithm == "raw" {
805+ if len (reportData ) > 64 {
806+ return nil , fmt .Errorf ("report data is too large, it should be at most 64 bytes when hashAlgorithm is raw" )
807+ }
808+ if len (reportData ) < 64 {
809+ // Left-pad with zeros
810+ padding := make ([]byte , 64 - len (reportData ))
811+ reportData = append (padding , reportData ... )
812+ }
813+ }
814+
815+ payload := map [string ]interface {}{
816+ "report_data" : hex .EncodeToString (reportData ),
817+ "hash_algorithm" : hashAlgorithm ,
818+ }
819+
820+ data , err := tc .sendRPCRequest (ctx , "/prpc/Tappd.TdxQuote" , payload )
821+ if err != nil {
822+ return nil , err
823+ }
824+
825+ var response GetQuoteResponse
826+ if err := json .Unmarshal (data , & response ); err != nil {
827+ return nil , err
828+ }
829+ return & response , nil
830+ }
0 commit comments