Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ this plugin deployed in your Kubernetes cluster, you will be able to run jobs
* This plugin targets Kubernetes v1.18+.

## Deployment
The device plugin needs to be run on all the nodes that are equipped with Confidential Computing devices (e.g. TPM). The simplest way of doing so is to create a Kubernetes [DaemonSet][dp], which run a copy of a pod on all (or some) Nodes in the cluster. We have a pre-built Docker image on [Goolge Artifact Registry][release] that you can use for with your DaemonSet. This repository also have a pre-defined yaml file named `cc-device-plugin.yaml`. You can create a DaemonSet in your Kubernetes cluster by running this command:
The device plugin needs to be run on all the nodes that are equipped with Confidential Computing devices (e.g. TPM). The simplest way of doing so is to create a Kubernetes [DaemonSet][dp], which runs a copy of a pod on all (or some) Nodes in the cluster. We have a pre-built Docker image on [Google Artifact Registry][release] that you can use with your DaemonSet. This repository also has a pre-defined yaml file named `cc-device-plugin.yaml`. You can create a DaemonSet in your Kubernetes cluster by running this command:

```
kubectl create -f manifests/cc-device-plugin.yaml
Expand Down
246 changes: 152 additions & 94 deletions deviceplugin/ccdevice.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"

Expand All @@ -32,9 +33,18 @@ import (
)

const (
deviceCheckInterval = 5 * time.Second
// By default, GKE allows up to 110 Pods per node on Standard clusters. Standard clusters can be configured to allow up to 256 Pods per node.
workloadSharedLimit = 256
deviceCheckInterval = 5 * time.Second
copiedEventLogDirectory = "/run/cc-device-plugin"
copiedEventLogLocation = "/run/cc-device-plugin/binary_bios_measurements"
containerEventLogDirectory = "/run/cc-device-plugin"
)

// AttestationType defines if the attestation is based on software emulation or hardware.
type AttestationType string

const (
SoftwareAttestation AttestationType = "software" // e.g., vTPM
HardwareAttestation AttestationType = "hardware" // e.g., Intel TDX, AMD SEV-SNP
)

var (
Expand All @@ -47,25 +57,25 @@ type CcDeviceSpec struct {
Resource string
DevicePaths []string
MeasurementPaths []string
DeviceLimit int // Number of allocatable instances of this resource
Type AttestationType // New flag to explicitly define the device type
}

// CcDevice wraps the v1.beta1.Device type, which has hostPath, containerPath and permission
type CcDevice struct {
v1beta1.Device
DeviceSpecs []*v1beta1.DeviceSpec
Mounts []*v1beta1.Mount
// Limit specifies the cap number of workloads sharing a worker node
Limit int
}

// CcDevicePlugin is a device plugin for cc devices
type CcDevicePlugin struct {
cds *CcDeviceSpec
ccDevices map[string]CcDevice
copiedEventLogDirectory string
copiedEventLogLocation string
containerEventLogDirectory string
logger log.Logger
cds *CcDeviceSpec
ccDevices map[string]CcDevice
logger log.Logger
copiedEventLogDirectory string
copiedEventLogLocation string
containerEventLogDirectory string
// this lock prevents data race when kubelet sends multiple requests at the same time
mu sync.Mutex

Expand All @@ -79,14 +89,17 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string
if logger == nil {
logger = log.NewNopLogger()
}
if cds.DeviceLimit <= 0 {
cds.DeviceLimit = 1 // Default to 1 if not specified
}

cdp := &CcDevicePlugin{
cds: cds,
ccDevices: make(map[string]CcDevice),
logger: logger,
copiedEventLogDirectory: "/run/cc-device-plugin",
copiedEventLogLocation: "/run/cc-device-plugin/binary_bios_measurements",
containerEventLogDirectory: "/run/cc-device-plugin",
cds: cds,
ccDevices: make(map[string]CcDevice),
logger: logger,
copiedEventLogDirectory: copiedEventLogDirectory,
copiedEventLogLocation: copiedEventLogLocation, // Note: This path is static, used only by vTPM plugin instance.
containerEventLogDirectory: containerEventLogDirectory,
deviceGauge: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "cc_device_plugin_devices",
Help: "The number of cc devices managed by this device plugin.",
Expand All @@ -97,16 +110,19 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string
}),
}

// Check if the copiedEventLogDirectory directory exists
if _, err := os.Stat(cdp.copiedEventLogDirectory); os.IsNotExist(err) {
// Create the directory
err = os.Mkdir(cdp.copiedEventLogDirectory, 0755)
if err != nil {
return nil, err
// Only create the directory if the device type is software-based (e.g., vTPM),
// as hardware-based devices (TDX/SNP) do not require copying measurement files to /run.
if cdp.cds.Type == SoftwareAttestation {
if _, err := os.Stat(cdp.copiedEventLogDirectory); os.IsNotExist(err) {
// Create the directory
err = os.MkdirAll(cdp.copiedEventLogDirectory, 0755)
if err != nil {
return nil, err
}
level.Info(cdp.logger).Log("msg", "Directory created:" + cdp.copiedEventLogDirectory)
} else {
level.Info(cdp.logger).Log("msg", "Directory already exists:" + cdp.copiedEventLogDirectory)
}
level.Info(cdp.logger).Log("msg", "Directory created:"+cdp.copiedEventLogDirectory)
} else {
level.Info(cdp.logger).Log("msg", "Directory already exists:"+cdp.copiedEventLogDirectory)
Comment thread
jimmychiuuuu marked this conversation as resolved.
}

if reg != nil {
Expand All @@ -118,75 +134,109 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string

func (cdp *CcDevicePlugin) discoverCcDevices() ([]CcDevice, error) {
var ccDevices []CcDevice
cd := CcDevice{
Device: v1beta1.Device{
Health: v1beta1.Healthy,
},
// set cap
Limit: workloadSharedLimit,
}
h := sha1.New()
var foundDevicePaths []string

// We use foundDevicePaths as an accumulator because a single resource (like TDX)
// might be represented by multiple device path patterns.
for _, path := range cdp.cds.DevicePaths {
matches, err := filepath.Glob(path)
if err != nil {
return nil, err
}
for _, matchPath := range matches {
level.Info(cdp.logger).Log("msg", "device path found:"+matchPath)
cd.DeviceSpecs = append(cd.DeviceSpecs, &v1beta1.DeviceSpec{
HostPath: matchPath,
ContainerPath: matchPath,
Permissions: "mrw",
})
if len(matches) > 0 {
level.Info(cdp.logger).Log("msg", "found matching device path(s)", "pattern", path, "matches", strings.Join(matches, ","))
foundDevicePaths = append(foundDevicePaths, matches...)
Comment thread
jimmychiuuuu marked this conversation as resolved.
}
}

for _, path := range cdp.cds.MeasurementPaths {
matches, err := filepath.Glob(path)
if err != nil {
return nil, err
// If no device paths were found for this resource type, simply return an empty list.
// This is not an error; the node just doesn't have this specific hardware.
if len(foundDevicePaths) == 0 {
return nil, nil
}

baseDevice := CcDevice{
Device: v1beta1.Device{
Health: v1beta1.Healthy,
},
}

for _, matchPath := range foundDevicePaths {
baseDevice.DeviceSpecs = append(baseDevice.DeviceSpecs, &v1beta1.DeviceSpec{
HostPath: matchPath,
ContainerPath: matchPath,
Permissions: "mrw",
})
}

// Measurement files are currently only expected for software-emulated devices (vTPM).
if cdp.cds.Type == SoftwareAttestation && len(cdp.cds.MeasurementPaths) > 0 {
var foundMeasurementPath string
for _, path := range cdp.cds.MeasurementPaths {
matches, err := filepath.Glob(path)
if err != nil {
return nil, err
}
if len(matches) > 0 {
// We only expect one measurement file
Comment thread
jimmychiuuuu marked this conversation as resolved.
foundMeasurementPath = matches[0]
level.Info(cdp.logger).Log("msg", "measurement path found", "path", foundMeasurementPath)
break
}
}
for _, matchPath := range matches {
level.Info(cdp.logger).Log("msg", "measurement path found:"+matchPath)
cd.Mounts = append(cd.Mounts, &v1beta1.Mount{
if foundMeasurementPath != "" {
baseDevice.Mounts = append(baseDevice.Mounts, &v1beta1.Mount{
HostPath: cdp.copiedEventLogDirectory,
ContainerPath: cdp.containerEventLogDirectory,
ReadOnly: true,
})

// copy when no measurement file at copiedEventLogLocation
Comment thread
jimmychiuuuu marked this conversation as resolved.
fileInfo, err := os.Stat(cdp.copiedEventLogLocation)
if errors.Is(err, os.ErrNotExist) {
err := copyMeasurementFile(matchPath, cdp.copiedEventLogLocation)
if err != nil {
if err := copyMeasurementFile(foundMeasurementPath, cdp.copiedEventLogLocation); err != nil {
level.Error(cdp.logger).Log("msg", "failed to copy measurement file", "error", err)
return nil, err
}
} else {
// copy when measurement file at /run was updated, but not by the current instance.
// measurementFileLastUpdate is init to 0.
// when file exists during first run, this instance deletes and creates a new file
if fileInfo.ModTime().After(measurementFileLastUpdate) {
err := copyMeasurementFile(matchPath, cdp.copiedEventLogLocation)
if err != nil {
return nil, err
}
} else if err == nil && fileInfo.ModTime().After(measurementFileLastUpdate) {
Comment thread
jimmychiuuuu marked this conversation as resolved.
// Refresh the copy if the source file has been updated by the kernel since the last copy.
if err := copyMeasurementFile(foundMeasurementPath, cdp.copiedEventLogLocation); err != nil {
level.Error(cdp.logger).Log("msg", "failed to re-copy measurement file", "error", err)
return nil, err
}
} else if err != nil {
level.Error(cdp.logger).Log("msg", "failed to stat copied measurement file", "error", err)
return nil, err
}
} else {
level.Warn(cdp.logger).Log("msg", "MeasurementPaths specified but no measurement file found", "paths", strings.Join(cdp.cds.MeasurementPaths, ","))
}
}
if cd.DeviceSpecs != nil {
for i := 0; i < cd.Limit; i++ {
b := make([]byte, 1)
b[0] = byte(i)
cd.ID = fmt.Sprintf("%x", h.Sum(b))
ccDevices = append(ccDevices, cd)

// Create DeviceLimit instances of the device
h := sha1.New()
h.Write([]byte(cdp.cds.Resource))
baseID := fmt.Sprintf("%x", h.Sum(nil))

for i := 0; i < cdp.cds.DeviceLimit; i++ {
cd := baseDevice // Copy the base structure
// For single-limit devices, ID is baseID. For multi-limit, append index.
if cdp.cds.DeviceLimit > 1 {
cd.ID = fmt.Sprintf("%s-%d", baseID, i)
} else {
cd.ID = baseID
}
ccDevices = append(ccDevices, cd)
}

return ccDevices, nil
}

func copyMeasurementFile(src string, dest string) error {
// get time for src
sourceInfo, err := os.Stat(src)
if err != nil {
return err
}
// copy out measurement
eventlogFile, err := os.ReadFile(src)
if err != nil {
Expand All @@ -201,11 +251,7 @@ func copyMeasurementFile(src string, dest string) error {
if err != nil {
return err
}
fileInfo, err := os.Stat(dest)
if err != nil {
return err
}
measurementFileLastUpdate = fileInfo.ModTime()
measurementFileLastUpdate = sourceInfo.ModTime()
return nil
}

Expand Down Expand Up @@ -235,18 +281,28 @@ func (cdp *CcDevicePlugin) refreshDevices() (bool, error) {
devicesUnchange = false
}
}
if !devicesUnchange {
return false, nil
if len(ccDevices) != len(old) {
devicesUnchange = false
}

// Check if devices were removed.
if devicesUnchange {
return true, nil
}

// Log if devices were removed
for k := range old {
if _, ok := cdp.ccDevices[k]; !ok {
level.Warn(cdp.logger).Log("msg", "devices removed")
return false, nil
level.Info(cdp.logger).Log("msg", "device removed", "id", k)
}
}
return true, nil
// Log if devices were added
for k := range cdp.ccDevices {
if _, ok := old[k]; !ok {
level.Info(cdp.logger).Log("msg", "device added", "id", k)
}
}

return false, nil
}

// Allocate assigns cc devices to a Pod.
Expand All @@ -267,19 +323,18 @@ func (cdp *CcDevicePlugin) Allocate(_ context.Context, req *v1beta1.AllocateRequ
if ccDevice.Health != v1beta1.Healthy {
return nil, fmt.Errorf("requested cc device is not healthy %q", id)
}
level.Info(cdp.logger).Log("msg", "adding device and measurement to Pod, device id is:"+id)
level.Info(cdp.logger).Log("msg", "adding device and measurement to Pod", "device id", id)

for _, ds := range ccDevice.DeviceSpecs {
level.Info(cdp.logger).Log("msg", "added ccDevice.deviceSpecs is:"+ds.String())
level.Debug(cdp.logger).Log("msg", "added ccDevice.deviceSpecs", "spec", ds.String())
}

for _, dm := range ccDevice.Mounts {
level.Info(cdp.logger).Log("msg", "added ccDevice.mounts is:"+dm.String())
level.Debug(cdp.logger).Log("msg", "added ccDevice.mounts", "mount", dm.String())
}

resp.Devices = append(resp.Devices, ccDevice.DeviceSpecs...)
resp.Mounts = append(resp.Mounts, ccDevice.Mounts...)

}
res.ContainerResponses = append(res.ContainerResponses, resp)
}
Expand All @@ -298,23 +353,26 @@ func (cdp *CcDevicePlugin) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DeviceP
if _, err := cdp.refreshDevices(); err != nil {
return err
}
refreshComplete := false
var err error

for {
if !refreshComplete {
res := new(v1beta1.ListAndWatchResponse)
for _, dev := range cdp.ccDevices {
res.Devices = append(res.Devices, &v1beta1.Device{ID: dev.ID, Health: dev.Health})
}
if err := stream.Send(res); err != nil {
return err
}
res := new(v1beta1.ListAndWatchResponse)
cdp.mu.Lock()
for _, dev := range cdp.ccDevices {
res.Devices = append(res.Devices, &v1beta1.Device{ID: dev.ID, Health: dev.Health})
}
<-time.After(deviceCheckInterval)
refreshComplete, err = cdp.refreshDevices()
if err != nil {
cdp.mu.Unlock()

if err := stream.Send(res); err != nil {
level.Error(cdp.logger).Log("msg", "failed to send ListAndWatchResponse", "error", err)
return err
}

<-time.After(deviceCheckInterval)

if _, err := cdp.refreshDevices(); err != nil {
level.Error(cdp.logger).Log("msg", "error during device refresh", "error", err)
// Don't return error immediately, try to continue
}
}
}

Expand Down
Loading