diff --git a/go.mod b/go.mod index 92a51b2..5c04c3a 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/nebius/gosdk v0.0.0-20250826102719-940ad1dfb5de github.com/pkg/errors v0.9.1 github.com/sfcompute/nodes-go v0.1.0-alpha.4 + github.com/sfcompute/sfc-go v0.1.0-preview github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.47.0 golang.org/x/text v0.33.0 @@ -85,6 +86,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/spyzhov/ajson v0.8.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 55a8123..f95ba1b 100644 --- a/go.sum +++ b/go.sum @@ -162,12 +162,16 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/sfcompute/nodes-go v0.1.0-alpha.4 h1:oFBWcMPSpqLYm/NDs5I1jTvzgx9rsXDL9Ghsm30Hc0Q= github.com/sfcompute/nodes-go v0.1.0-alpha.4/go.mod h1:nUviHgK+Fgt2hDFcRL3M8VoyiypC8fc0dsY8C30QU8M= +github.com/sfcompute/sfc-go v0.1.0-preview h1:yJ6ICglA/JZal2kauzb2aZlV9XdLPejsvFpsKwwThkQ= +github.com/sfcompute/sfc-go v0.1.0-preview/go.mod h1:vhUpRpAHKitZzzWPg87RjreC+pzK57PGe4ZuSIQSk94= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spyzhov/ajson v0.8.0 h1:sFXyMbi4Y/BKjrsfkUZHSjA2JM1184enheSjjoT/zCc= +github.com/spyzhov/ajson v0.8.0/go.mod h1:63V+CGM6f1Bu/p4nLIN8885ojBdt88TbLoSFzyqMuVA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/v1/providers/sfcomputev2/brev_constants.go b/v1/providers/sfcomputev2/brev_constants.go new file mode 100644 index 0000000..c9ec88f --- /dev/null +++ b/v1/providers/sfcomputev2/brev_constants.go @@ -0,0 +1,23 @@ +package v2 + +// Package-internal constants — SSH defaults and internal tag keys. +const ( + defaultPort = 22 + defaultSSHUsername = "ubuntu" + + // Internal tag keys written to every SFCompute V2 instance. These are stripped from + // v1.Instance.Tags on read so they don't surface as user-facing tags. + tagKeyCloudCredRefID = "brev-cloud-cred-ref-id" + tagKeyRefID = "brev-ref-id" +) + +// Brev environment config for SFCompute V2. +// TODO: source these from environment variables rather than hardcoding them here. +const ( + // BrevProductionCapacityID is the SFCompute V2 capacity ID for Brev production instances. + BrevProductionCapacityID = "brev-production-capacity" + + // BrevProductionImageID is the SFCompute image for Brev production instances + // (ubuntu-24.04.4-cuda-12.8, vm_images.vm_image_id). + BrevProductionImageID = "vmi_4GwEvmclFURy7ztFQjOdr" +) diff --git a/v1/providers/sfcomputev2/capabilities.go b/v1/providers/sfcomputev2/capabilities.go new file mode 100644 index 0000000..e9b62d6 --- /dev/null +++ b/v1/providers/sfcomputev2/capabilities.go @@ -0,0 +1,24 @@ +package v2 + +import ( + "context" + + v1 "github.com/brevdev/cloud/v1" +) + +func getSFCCapabilitiesV2() v1.Capabilities { + return v1.Capabilities{ + v1.CapabilityCreateInstance, + v1.CapabilityTerminateInstance, + v1.CapabilityCreateTerminateInstance, + v1.CapabilityTags, + } +} + +func (c *SFCClientV2) GetCapabilities(_ context.Context) (v1.Capabilities, error) { + return getSFCCapabilitiesV2(), nil +} + +func (c *SFCCredentialV2) GetCapabilities(_ context.Context) (v1.Capabilities, error) { + return getSFCCapabilitiesV2(), nil +} diff --git a/v1/providers/sfcomputev2/client.go b/v1/providers/sfcomputev2/client.go new file mode 100644 index 0000000..b5abb9a --- /dev/null +++ b/v1/providers/sfcomputev2/client.go @@ -0,0 +1,99 @@ +package v2 + +import ( + "context" + + v1 "github.com/brevdev/cloud/v1" + sfc "github.com/sfcompute/sfc-go" +) + +const CloudProviderID = "sfcompute" + +// SFCCredentialV2 holds authentication details for a Brev-managed SFCompute V2 account. +type SFCCredentialV2 struct { + RefID string + APIKey string `json:"api_key"` +} + +var _ v1.CloudCredential = &SFCCredentialV2{} + +func NewSFCCredentialV2(refID, apiKey string) *SFCCredentialV2 { + return &SFCCredentialV2{ + RefID: refID, + APIKey: apiKey, + } +} + +func (c *SFCCredentialV2) GetReferenceID() string { + return c.RefID +} + +func (c *SFCCredentialV2) GetAPIType() v1.APIType { + return v1.APITypeGlobal +} + +func (c *SFCCredentialV2) GetCloudProviderID() v1.CloudProviderID { + return CloudProviderID +} + +func (c *SFCCredentialV2) GetTenantID() (string, error) { + return "", nil +} + +type SFCClientV2 struct { + v1.NotImplCloudClient + refID string + location string + client *sfc.SDK + logger v1.Logger +} + +var _ v1.CloudClient = &SFCClientV2{} + +type SFCClientV2Option func(c *SFCClientV2) + +func WithLogger(logger v1.Logger) SFCClientV2Option { + return func(c *SFCClientV2) { + c.logger = logger + } +} + +func (c *SFCCredentialV2) MakeClientWithOptions(_ context.Context, location string, opts ...SFCClientV2Option) (v1.CloudClient, error) { + sfcClient := &SFCClientV2{ + refID: c.RefID, + location: location, + client: sfc.New(sfc.WithSecurity(c.APIKey)), + logger: &v1.NoopLogger{}, + } + + for _, opt := range opts { + opt(sfcClient) + } + + return sfcClient, nil +} + +func (c *SFCCredentialV2) MakeClient(ctx context.Context, location string) (v1.CloudClient, error) { + return c.MakeClientWithOptions(ctx, location) +} + +func (c *SFCClientV2) GetAPIType() v1.APIType { + return v1.APITypeGlobal +} + +func (c *SFCClientV2) GetCloudProviderID() v1.CloudProviderID { + return CloudProviderID +} + +func (c *SFCClientV2) GetReferenceID() string { + return c.refID +} + +func (c *SFCClientV2) GetTenantID() (string, error) { + return "", nil +} + +func (c *SFCClientV2) MakeClient(_ context.Context, location string) (v1.CloudClient, error) { + c.location = location + return c, nil +} diff --git a/v1/providers/sfcomputev2/instance.go b/v1/providers/sfcomputev2/instance.go new file mode 100644 index 0000000..fea8b53 --- /dev/null +++ b/v1/providers/sfcomputev2/instance.go @@ -0,0 +1,259 @@ +package v2 + +import ( + "context" + "encoding/base64" + "fmt" + "maps" + "slices" + "time" + + "github.com/alecthomas/units" + "github.com/brevdev/cloud/internal/errors" + v1 "github.com/brevdev/cloud/v1" + "github.com/sfcompute/sfc-go/models/components" + "github.com/sfcompute/sfc-go/models/operations" + "github.com/sfcompute/sfc-go/optionalnullable" +) + +func (c *SFCClientV2) CreateInstance(ctx context.Context, attrs v1.CreateInstanceAttrs) (*v1.Instance, error) { + c.logger.Debug(ctx, "sfcv2: CreateInstance start", + v1.LogField("name", attrs.Name), + v1.LogField("location", attrs.Location), + ) + + tags := make(map[string]string, len(attrs.Tags)+2) + maps.Copy(tags, attrs.Tags) + tags[tagKeyCloudCredRefID] = c.refID + tags[tagKeyRefID] = attrs.RefID + + cloudInit := sshKeyCloudInit(attrs.PublicKey) + resp, err := c.client.Instances.Create(ctx, components.CreateInstanceRequest{ + Capacity: BrevProductionCapacityID, + Image: BrevProductionImageID, + CloudInitUserData: &cloudInit, + Tags: optionalnullable.From(&tags), + Name: optionalnullable.From(&attrs.Name), + }) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + if resp.InstanceResponse == nil { + return nil, errors.WrapAndTrace(fmt.Errorf("no instance returned from create")) + } + + instance, err := c.sfcInstanceToBrevInstance(resp.InstanceResponse, "") + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfcv2: CreateInstance end", + v1.LogField("instanceID", resp.InstanceResponse.ID), + ) + + return instance, nil +} + +func sshKeyCloudInit(sshKey string) string { + script := fmt.Sprintf("#cloud-config\nssh_authorized_keys:\n - %s", sshKey) + return base64.StdEncoding.EncodeToString([]byte(script)) +} + +func (c *SFCClientV2) GetInstance(ctx context.Context, id v1.CloudProviderInstanceID) (*v1.Instance, error) { + c.logger.Debug(ctx, "sfcv2: GetInstance start", + v1.LogField("instanceID", id), + ) + + resp, err := c.client.Instances.Fetch(ctx, string(id), nil) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + if resp.InstanceResponse == nil { + return nil, errors.WrapAndTrace(fmt.Errorf("instance %s not found", id)) + } + + sshHostname, err := c.getSSHHostname(ctx, string(id), resp.InstanceResponse.Status) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + instance, err := c.sfcInstanceToBrevInstance(resp.InstanceResponse, sshHostname) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfcv2: GetInstance end", + v1.LogField("instanceID", id), + v1.LogField("status", resp.InstanceResponse.Status), + ) + + return instance, nil +} + +func (c *SFCClientV2) ListInstances(ctx context.Context, args v1.ListInstancesArgs) ([]v1.Instance, error) { + c.logger.Debug(ctx, "sfcv2: ListInstances start", + v1.LogField("location", c.location), + ) + + capacityID := BrevProductionCapacityID + resp, err := c.client.Instances.List(ctx, operations.ListInstancesRequest{ + Capacity: &capacityID, + }) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + if resp.ListInstancesResponse == nil { + return []v1.Instance{}, nil + } + + var instances []v1.Instance + for _, inst := range resp.ListInstancesResponse.Data { + // Filter by instance IDs if specified. + if len(args.InstanceIDs) > 0 && !slices.Contains(args.InstanceIDs, v1.CloudProviderInstanceID(inst.ID)) { + continue + } + + sshHostname, err := c.getSSHHostname(ctx, inst.ID, inst.Status) + if err != nil { + c.logger.Error(ctx, err, + v1.LogField("msg", "sfcv2: ListInstances skipping instance due to SSH error"), + v1.LogField("instanceID", inst.ID), + ) + continue + } + + brevInst, err := c.sfcInstanceToBrevInstance(&inst, sshHostname) + if err != nil { + c.logger.Error(ctx, err, + v1.LogField("msg", "sfcv2: ListInstances skipping instance due to conversion error"), + v1.LogField("instanceID", inst.ID), + ) + continue + } + instances = append(instances, *brevInst) + } + + c.logger.Debug(ctx, "sfcv2: ListInstances end", + v1.LogField("instance count", len(instances)), + ) + + return instances, nil +} + +func (c *SFCClientV2) TerminateInstance(ctx context.Context, id v1.CloudProviderInstanceID) error { + c.logger.Debug(ctx, "sfcv2: TerminateInstance start", + v1.LogField("instanceID", id), + ) + + _, err := c.client.Instances.TerminateInstance(ctx, string(id)) + if err != nil { + return errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfcv2: TerminateInstance end", + v1.LogField("instanceID", id), + ) + + return nil +} + +func (c *SFCClientV2) getSSHHostname(ctx context.Context, id string, status components.InstanceStatus) (string, error) { + if status != components.InstanceStatusRunning { + return "", nil + } + + resp, err := c.client.Instances.GetSSHInfoForInstance(ctx, id) + if err != nil { + return "", errors.WrapAndTrace(err) + } + if resp.InstanceSSHInfo == nil { + return "", nil + } + + return resp.InstanceSSHInfo.Hostname, nil +} + +func (c *SFCClientV2) sfcInstanceToBrevInstance(inst *components.InstanceResponse, sshHostname string) (*v1.Instance, error) { + tags, _ := inst.GetTags().GetOrZero() + + cloudCredRefID := tags[tagKeyCloudCredRefID] + if cloudCredRefID == "" { + cloudCredRefID = c.refID + } + + userTags := make(v1.Tags) + for k, v := range tags { + switch k { + case tagKeyCloudCredRefID, tagKeyRefID: + default: + userTags[k] = v + } + } + + status := sfcStatusToLifecycleStatus(inst.Status) + + diskInt64, err := h100InstanceTypeMetadata.diskBytes.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, err + } + diskSize := units.Base2Bytes(diskInt64 * int64(units.Gibibyte)) + + return &v1.Instance{ + Name: inst.Name, + CloudID: v1.CloudProviderInstanceID(inst.ID), + RefID: tags[tagKeyRefID], + PublicDNS: sshHostname, + PublicIP: sshHostname, + SSHUser: defaultSSHUsername, + SSHPort: defaultPort, + CreatedAt: time.Unix(inst.CreatedAt, 0), + DiskSize: diskSize, + DiskSizeBytes: h100InstanceTypeMetadata.diskBytes, + Status: v1.Status{ + LifecycleStatus: status, + }, + InstanceTypeID: h100InstanceTypeMetadata.instanceTypeID, + InstanceType: h100InstanceType, + Location: sfcLocation, + Spot: false, + Stoppable: false, + Rebootable: false, + CloudCredRefID: cloudCredRefID, + Tags: userTags, + }, nil +} + +func sfcStatusToLifecycleStatus(status components.InstanceStatus) v1.LifecycleStatus { + switch status { + case components.InstanceStatusAwaitingAllocation: + return v1.LifecycleStatusPending + case components.InstanceStatusRunning: + return v1.LifecycleStatusRunning + case components.InstanceStatusTerminated: + return v1.LifecycleStatusTerminated + case components.InstanceStatusFailed: + return v1.LifecycleStatusFailed + default: + return v1.LifecycleStatusPending + } +} + +func (c *SFCClientV2) RebootInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +func (c *SFCClientV2) StopInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +func (c *SFCClientV2) StartInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +func (c *SFCClientV2) MergeInstanceForUpdate(_ v1.Instance, newInst v1.Instance) v1.Instance { + return newInst +} + +func (c *SFCClientV2) MergeInstanceTypeForUpdate(_ v1.InstanceType, newIt v1.InstanceType) v1.InstanceType { + return newIt +} diff --git a/v1/providers/sfcomputev2/instancetype.go b/v1/providers/sfcomputev2/instancetype.go new file mode 100644 index 0000000..af06792 --- /dev/null +++ b/v1/providers/sfcomputev2/instancetype.go @@ -0,0 +1,202 @@ +package v2 + +import ( + "context" + "fmt" + "time" + + "github.com/alecthomas/units" + "github.com/bojanz/currency" + "github.com/brevdev/cloud/internal/errors" + v1 "github.com/brevdev/cloud/v1" + "github.com/sfcompute/sfc-go/models/components" + "github.com/sfcompute/sfc-go/models/operations" +) + +const ( + h100InstanceType = "h100.ib" + sfcVCPU = 112 + sfcGPUCount = 8 + sfcLocation = "sfc" + diskTypeSSD = "ssd" + formFactorSXM5 = "sxm5" +) + +type sfcInstanceTypeMetadata struct { + diskBytes v1.Bytes + memoryBytes v1.Bytes + gpuVRAM v1.Bytes + vcpu int32 + gpuCount int32 + gpuManufacturer v1.Manufacturer + architecture v1.Architecture + deployTime time.Duration + price currency.Amount + instanceTypeID v1.InstanceTypeID +} + +var h100InstanceTypeMetadata = func() sfcInstanceTypeMetadata { + price, err := currency.NewAmount("16.00", "USD") + if err != nil { + panic(err) + } + m := sfcInstanceTypeMetadata{ + diskBytes: v1.NewBytes(1500, v1.Gigabyte), + memoryBytes: v1.NewBytes(960, v1.Gigabyte), + gpuVRAM: v1.NewBytes(80, v1.Gigabyte), + vcpu: sfcVCPU, + gpuCount: sfcGPUCount, + gpuManufacturer: v1.ManufacturerNVIDIA, + architecture: v1.ArchitectureX86_64, + deployTime: 14 * time.Minute, + price: price, + } + + // Compute the instance type ID from a representative InstanceType so it matches + // what Brev expects when validating or storing the type. + it := buildInstanceType(m, true) + m.instanceTypeID = it.ID + return m +}() + +func buildInstanceType(m sfcInstanceTypeMetadata, isAvailable bool) v1.InstanceType { + ramInt64, _ := m.memoryBytes.ByteCountInUnitInt64(v1.Gibibyte) + ram := units.Base2Bytes(ramInt64 * int64(units.Gibibyte)) + + vramInt64, _ := m.gpuVRAM.ByteCountInUnitInt64(v1.Gibibyte) + vram := units.Base2Bytes(vramInt64 * int64(units.Gibibyte)) + + diskInt64, _ := m.diskBytes.ByteCountInUnitInt64(v1.Gibibyte) + diskSize := units.Base2Bytes(diskInt64 * int64(units.Gibibyte)) + + it := v1.InstanceType{ + IsAvailable: isAvailable, + Type: h100InstanceType, + Memory: ram, + MemoryBytes: m.memoryBytes, + VCPU: m.vcpu, + Location: sfcLocation, + Stoppable: false, + Rebootable: false, + IsContainer: false, + Provider: CloudProviderID, + BasePrice: &m.price, + EstimatedDeployTime: &m.deployTime, + SupportedGPUs: []v1.GPU{{ + Count: m.gpuCount, + Type: "H100", + Manufacturer: m.gpuManufacturer, + Name: "H100", + Memory: vram, + MemoryBytes: m.gpuVRAM, + NetworkDetails: formFactorSXM5, + }}, + SupportedStorage: []v1.Storage{{ + Type: diskTypeSSD, + Count: 1, + Size: diskSize, + SizeBytes: m.diskBytes, + }}, + SupportedArchitectures: []v1.Architecture{m.architecture}, + } + it.ID = v1.MakeGenericInstanceTypeID(it) + return it +} + +func (c *SFCClientV2) GetInstanceTypes(ctx context.Context, args v1.GetInstanceTypeArgs) ([]v1.InstanceType, error) { + c.logger.Debug(ctx, "sfcv2: GetInstanceTypes start", + v1.LogField("location", c.location), + ) + + available, err := c.availableSlots(ctx) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + if available <= 0 { + c.logger.Debug(ctx, "sfcv2: GetInstanceTypes no available slots") + return []v1.InstanceType{}, nil + } + + instanceType := buildInstanceType(h100InstanceTypeMetadata, true) + + if !v1.IsSelectedByArgs(instanceType, args) { + return []v1.InstanceType{}, nil + } + + c.logger.Debug(ctx, "sfcv2: GetInstanceTypes end", + v1.LogField("available slots", available), + ) + + return []v1.InstanceType{instanceType}, nil +} + +// availableSlots returns how many more instances can be created in the configured capacity. +// It subtracts the count of non-terminated instances from the current capacity allocation. +func (c *SFCClientV2) availableSlots(ctx context.Context) (int, error) { + allocated, err := c.currentCapacityAllocation(ctx) + if err != nil { + return 0, errors.WrapAndTrace(err) + } + + active, err := c.activeInstanceCount(ctx) + if err != nil { + return 0, errors.WrapAndTrace(err) + } + + return max(allocated-active, 0), nil +} + +// currentCapacityAllocation returns the NodeAllocation from the most recent schedule entry +// in BrevProductionCapacityID that is currently in effect (EffectiveAt <= now). +func (c *SFCClientV2) currentCapacityAllocation(ctx context.Context) (int, error) { + resp, err := c.client.Capacities.Fetch(ctx, BrevProductionCapacityID, nil, nil) + if err != nil { + return 0, errors.WrapAndTrace(err) + } + if resp.CapacityResponse == nil { + return 0, nil + } + + now := time.Now().Unix() + allocation := 0 + latestAt := int64(-1) + for _, entry := range resp.CapacityResponse.AllocationSchedule.Total { + if entry.EffectiveAt <= now && entry.EffectiveAt > latestAt { + latestAt = entry.EffectiveAt + allocation = entry.NodeAllocation + } + } + return allocation, nil +} + +// activeInstanceCount returns the number of non-terminated instances in BrevProductionCapacityID. +// All non-terminated instances occupy a slot in the capacity, including failed ones. +func (c *SFCClientV2) activeInstanceCount(ctx context.Context) (int, error) { + capacityID := BrevProductionCapacityID + resp, err := c.client.Instances.List(ctx, operations.ListInstancesRequest{ + Capacity: &capacityID, + }) + if err != nil { + return 0, errors.WrapAndTrace(err) + } + if resp.ListInstancesResponse == nil { + return 0, nil + } + + count := 0 + for _, inst := range resp.ListInstancesResponse.Data { + if inst.Status != components.InstanceStatusTerminated { + count++ + } + } + return count, nil +} + +func (c *SFCClientV2) GetLocations(_ context.Context, _ v1.GetLocationsArgs) ([]v1.Location, error) { + return []v1.Location{{ + Name: sfcLocation, + Description: fmt.Sprintf("sfc_%s_h100", sfcLocation), + Available: true, + }}, nil +} diff --git a/v1/providers/sfcomputev2/validation_test.go b/v1/providers/sfcomputev2/validation_test.go new file mode 100644 index 0000000..db4c265 --- /dev/null +++ b/v1/providers/sfcomputev2/validation_test.go @@ -0,0 +1,50 @@ +package v2 + +import ( + "os" + "testing" + + "github.com/brevdev/cloud/internal/validation" + v1 "github.com/brevdev/cloud/v1" +) + +func TestValidationFunctions(t *testing.T) { + t.Parallel() + checkSkip(t) + + config := validation.ProviderConfig{ + Credential: NewSFCCredentialV2("validation-test", getAPIKey()), + StableIDs: []v1.InstanceTypeID{ + h100InstanceTypeMetadata.instanceTypeID, + }, + } + + validation.RunValidationSuite(t, config) +} + +func TestInstanceLifecycleValidation(t *testing.T) { + t.Parallel() + checkSkip(t) + + config := validation.ProviderConfig{ + Credential: NewSFCCredentialV2("validation-test", getAPIKey()), + Location: sfcLocation, + } + + validation.RunInstanceLifecycleValidation(t, config) +} + +func checkSkip(t *testing.T) { + t.Helper() + apiKey := getAPIKey() + isValidationTest := os.Getenv("VALIDATION_TEST") + if apiKey == "" && isValidationTest != "" { + t.Fatal("SFCOMPUTE_API_KEY not set, but VALIDATION_TEST is set") + } else if apiKey == "" { + t.Skip("SFCOMPUTE_API_KEY not set, skipping sfcomputev2 validation tests") + } +} + +func getAPIKey() string { + return os.Getenv("SFCOMPUTE_API_KEY") +}