Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cmd/pilotctl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3960,7 +3960,8 @@ func cmdApprove(args []string) {

nodeID := resolveToNodeID(d, args[0])

result, err := d.ApproveHandshake(nodeID)
adminToken := getAdminToken()
result, err := d.ApproveHandshake(nodeID, adminToken)
if err != nil {
fatalCode("connection_failed", "approve: %v", err)
}
Expand All @@ -3986,7 +3987,8 @@ func cmdReject(args []string) {
reason = args[1]
}

result, err := d.RejectHandshake(nodeID, reason)
adminToken := getAdminToken()
result, err := d.RejectHandshake(nodeID, reason, adminToken)
if err != nil {
fatalCode("connection_failed", "reject: %v", err)
}
Expand All @@ -4006,7 +4008,8 @@ func cmdUntrust(args []string) {
defer d.Close()

nodeID := resolveToNodeID(d, args[0])
_, err := d.RevokeTrust(nodeID)
adminToken := getAdminToken()
_, err := d.RevokeTrust(nodeID, adminToken)
if err != nil {
fatalCode("connection_failed", "untrust: %v", err)
}
Expand Down
41 changes: 41 additions & 0 deletions pkg/daemon/ipc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package daemon

import (
"context"
"crypto/subtle"
"encoding/binary"
"encoding/json"
"errors"
Expand Down Expand Up @@ -192,6 +193,8 @@ var ErrIPCClosed = errors.New("ipc: connection closed")
// code paths.
var ErrIPCBackpressure = errors.New("ipc: backpressure (client too slow)")

var errHandshakeAuth = errors.New("ipc: handshake requires admin token")

// IPCEnvelopeHeaderSize is the size of the per-message header that sits
// inside the ipcutil length-framed envelope: 1 byte cmd.
const IPCEnvelopeHeaderSize = 1
Expand Down Expand Up @@ -1232,6 +1235,10 @@ func (s *IPCServer) handleHandshake(conn *ipcConn, reqID uint64, payload []byte)
s.ipcWriteHandshakeOK(conn, reqID, data)

case SubHandshakeApprove:
rest, err := s.checkHandshakeAdminToken(conn, reqID, rest, "approve")
if err != nil {
return // error already sent
}
if len(rest) < 4 {
s.sendError(conn, reqID, "handshake approve: missing node_id")
return
Expand All @@ -1248,6 +1255,10 @@ func (s *IPCServer) handleHandshake(conn *ipcConn, reqID uint64, payload []byte)
s.ipcWriteHandshakeOK(conn, reqID, data)

case SubHandshakeReject:
rest, err := s.checkHandshakeAdminToken(conn, reqID, rest, "reject")
if err != nil {
return // error already sent
}
if len(rest) < 4 {
s.sendError(conn, reqID, "handshake reject: missing node_id")
return
Expand Down Expand Up @@ -1301,6 +1312,10 @@ func (s *IPCServer) handleHandshake(conn *ipcConn, reqID uint64, payload []byte)
s.ipcWriteHandshakeOK(conn, reqID, data)

case SubHandshakeRevoke:
rest, err := s.checkHandshakeAdminToken(conn, reqID, rest, "revoke")
if err != nil {
return // error already sent
}
if len(rest) < 4 {
s.sendError(conn, reqID, "handshake revoke: missing node_id")
return
Expand Down Expand Up @@ -1340,6 +1355,32 @@ func (s *IPCServer) handleHandshake(conn *ipcConn, reqID uint64, payload []byte)
}
}

// checkHandshakeAdminToken verifies the admin token prefix when the daemon
// has an admin token configured. Handshake approve/reject/revoke are
// privileged state-mutation verbs — they require the same token gate as
// BroadcastDatagram. When no admin token is configured, the check is a
// no-op (backward-compatible with pre-token daemon configs).
func (s *IPCServer) checkHandshakeAdminToken(conn *ipcConn, reqID uint64, rest []byte, verb string) ([]byte, error) {
if len(rest) < 2 {
s.sendError(conn, reqID, fmt.Sprintf("handshake %s: missing admin token header", verb))
return nil, errHandshakeAuth
}
tokenLen := binary.BigEndian.Uint16(rest[0:2])
if len(rest) < 2+int(tokenLen) {
s.sendError(conn, reqID, fmt.Sprintf("handshake %s: truncated admin token", verb))
return nil, errHandshakeAuth
}
payload := rest[2+tokenLen:]
if s.daemon.config.AdminToken != "" {
token := string(rest[2 : 2+tokenLen])
if subtle.ConstantTimeCompare([]byte(token), []byte(s.daemon.config.AdminToken)) != 1 {
s.sendError(conn, reqID, fmt.Sprintf("handshake %s: invalid admin token", verb))
return nil, errHandshakeAuth
}
}
return payload, nil
}

func (s *IPCServer) ipcWriteHandshakeOK(conn *ipcConn, reqID uint64, data []byte) {
if err := conn.writeReply(CmdHandshakeOK, reqID, data); err != nil {
slog.Debug("IPC handshake reply failed", "err", err)
Expand Down
29 changes: 19 additions & 10 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,21 +228,27 @@ func (d *Driver) Handshake(nodeID uint32, justification string) (map[string]inte
}

// ApproveHandshake approves a pending trust handshake request.
func (d *Driver) ApproveHandshake(nodeID uint32) (map[string]interface{}, error) {
msg := make([]byte, 6)
func (d *Driver) ApproveHandshake(nodeID uint32, adminToken string) (map[string]interface{}, error) {
tokenBytes := []byte(adminToken)
msg := make([]byte, 1+1+2+len(tokenBytes)+4)
msg[0] = cmdHandshake
msg[1] = subHandshakeApprove
binary.BigEndian.PutUint32(msg[2:6], nodeID)
binary.BigEndian.PutUint16(msg[2:4], uint16(len(tokenBytes)))
copy(msg[4:4+len(tokenBytes)], tokenBytes)
binary.BigEndian.PutUint32(msg[4+len(tokenBytes):], nodeID)
return d.jsonRPC(msg, cmdHandshakeOK, "approve")
}

// RejectHandshake rejects a pending trust handshake request.
func (d *Driver) RejectHandshake(nodeID uint32, reason string) (map[string]interface{}, error) {
msg := make([]byte, 1+1+4+len(reason))
func (d *Driver) RejectHandshake(nodeID uint32, reason string, adminToken string) (map[string]interface{}, error) {
tokenBytes := []byte(adminToken)
msg := make([]byte, 1+1+2+len(tokenBytes)+4+len(reason))
msg[0] = cmdHandshake
msg[1] = subHandshakeReject
binary.BigEndian.PutUint32(msg[2:6], nodeID)
copy(msg[6:], reason)
binary.BigEndian.PutUint16(msg[2:4], uint16(len(tokenBytes)))
copy(msg[4:4+len(tokenBytes)], tokenBytes)
binary.BigEndian.PutUint32(msg[4+len(tokenBytes):], nodeID)
copy(msg[4+len(tokenBytes)+4:], reason)
return d.jsonRPC(msg, cmdHandshakeOK, "reject")
}

Expand Down Expand Up @@ -274,11 +280,14 @@ func (d *Driver) TrustedPeers() (map[string]interface{}, error) {
}

// RevokeTrust removes a peer from the trusted set and notifies the registry.
func (d *Driver) RevokeTrust(nodeID uint32) (map[string]interface{}, error) {
msg := make([]byte, 6)
func (d *Driver) RevokeTrust(nodeID uint32, adminToken string) (map[string]interface{}, error) {
tokenBytes := []byte(adminToken)
msg := make([]byte, 1+1+2+len(tokenBytes)+4)
msg[0] = cmdHandshake
msg[1] = subHandshakeRevoke
binary.BigEndian.PutUint32(msg[2:6], nodeID)
binary.BigEndian.PutUint16(msg[2:4], uint16(len(tokenBytes)))
copy(msg[4:4+len(tokenBytes)], tokenBytes)
binary.BigEndian.PutUint32(msg[4+len(tokenBytes):], nodeID)
return d.jsonRPC(msg, cmdHandshakeOK, "revoke")
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/driver/zz_driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,10 @@ func TestHandshakeFamilyRoundTrips(t *testing.T) {
if _, err := drv.Handshake(99, "please"); err != nil {
t.Fatalf("Handshake: %v", err)
}
if _, err := drv.ApproveHandshake(100); err != nil {
if _, err := drv.ApproveHandshake(100, ""); err != nil {
t.Fatalf("Approve: %v", err)
}
if _, err := drv.RejectHandshake(101, "no"); err != nil {
if _, err := drv.RejectHandshake(101, "no", ""); err != nil {
t.Fatalf("Reject: %v", err)
}
if _, err := drv.PendingHandshakes(); err != nil {
Expand All @@ -493,7 +493,7 @@ func TestHandshakeFamilyRoundTrips(t *testing.T) {
if _, err := drv.TrustedPeers(); err != nil {
t.Fatalf("Trusted: %v", err)
}
if _, err := drv.RevokeTrust(102); err != nil {
if _, err := drv.RevokeTrust(102, ""); err != nil {
t.Fatalf("Revoke: %v", err)
}

Expand Down
8 changes: 4 additions & 4 deletions tests/zz_handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func TestHandshakePendingApproveReject(t *testing.T) {
}

// B approves A
_, err = drvB.ApproveHandshake(daemonA.NodeID())
_, err = drvB.ApproveHandshake(daemonA.NodeID(), "")
if err != nil {
t.Fatalf("approve: %v", err)
}
Expand Down Expand Up @@ -235,7 +235,7 @@ func TestHandshakePendingApproveReject(t *testing.T) {
}

// B rejects C
_, err = drvB.RejectHandshake(daemonC.NodeID(), "not authorized")
_, err = drvB.RejectHandshake(daemonC.NodeID(), "not authorized", "")
if err != nil {
t.Fatalf("reject: %v", err)
}
Expand Down Expand Up @@ -409,7 +409,7 @@ func TestHandshakeRevokeTrust(t *testing.T) {
}

// A revokes trust in B
_, err := drvA.RevokeTrust(daemonB.NodeID())
_, err := drvA.RevokeTrust(daemonB.NodeID(), "")
if err != nil {
t.Fatalf("revoke: %v", err)
}
Expand Down Expand Up @@ -531,7 +531,7 @@ func TestHandshakeRejectReason(t *testing.T) {
}

// B rejects with reason
_, err = drvB.RejectHandshake(daemonA.NodeID(), "not authorized for this network")
_, err = drvB.RejectHandshake(daemonA.NodeID(), "not authorized for this network", "")
if err != nil {
t.Fatalf("reject: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions tests/zz_trust_gate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestWaitForTrustFastPathAfterTrust(t *testing.T) {
}

// B approves → both sides become trusted.
if _, err := b.Driver.ApproveHandshake(a.Daemon.NodeID()); err != nil {
if _, err := b.Driver.ApproveHandshake(a.Daemon.NodeID(), ""); err != nil {
t.Fatalf("B approve: %v", err)
}

Expand Down Expand Up @@ -171,7 +171,7 @@ func TestWaitForTrustBlocksUntilApproved(t *testing.T) {
case <-time.After(10 * time.Millisecond):
}
}
if _, err := b.Driver.ApproveHandshake(a.Daemon.NodeID()); err != nil {
if _, err := b.Driver.ApproveHandshake(a.Daemon.NodeID(), ""); err != nil {
t.Fatalf("B approve: %v", err)
}

Expand Down
Loading