Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions pkg/cmd/gpucreate/gpucreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,38 @@ type typeCreateResult struct {
fatalError error
}

// validateInstanceTypeAvailability errors when the type is invalid or has no capacity. Returns nil when the listing is missing.
func (c *createContext) validateInstanceTypeAvailability(instanceType string) error {
if c.allInstanceTypes == nil {
return nil
}
if c.allInstanceTypes.GetWorkspaceGroupID(instanceType) != "" {
return nil
}
if !c.allInstanceTypes.HasInstanceType(instanceType) {
return breverrors.NewValidationError(fmt.Sprintf(
"instance type %q is not a recognized type; run 'brev search' to see available types",
instanceType,
))
}
return breverrors.NewValidationError(fmt.Sprintf(
"instance type %q is currently unavailable (no capacity); try again later or run 'brev search' to find another type",
instanceType,
))
}

// createInstancesWithType attempts to create instances using a specific type
func (c *createContext) createInstancesWithType(spec InstanceSpec, startIdx, count int) typeCreateResult {
result := typeCreateResult{}

if c.opts.LaunchableID == "" {
if err := c.validateInstanceTypeAvailability(spec.Type); err != nil {
c.logf("Skipping: %s\n", err.Error())
result.hadFailure = true
return result
}
}

var mu sync.Mutex
var wg sync.WaitGroup

Expand Down Expand Up @@ -1003,6 +1031,19 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity
}
}

if cwOptions.WorkspaceGroupID == "" {
if c.allInstanceTypes == nil {
return nil, breverrors.NewValidationError(fmt.Sprintf(
"could not resolve workspace group for %q (instance-type listing was unavailable); please retry",
spec.Type,
))
}
return nil, breverrors.NewValidationError(fmt.Sprintf(
"instance type %q is invalid or unavailable; run 'brev search' to see available types",
spec.Type,
))
}

workspace, err := c.store.CreateWorkspace(c.org.ID, cwOptions)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
Expand Down
156 changes: 156 additions & 0 deletions pkg/cmd/gpucreate/gpucreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/brevdev/brev-cli/pkg/cmd/gpusearch"
"github.com/brevdev/brev-cli/pkg/entity"
breverrors "github.com/brevdev/brev-cli/pkg/errors"
"github.com/brevdev/brev-cli/pkg/store"
"github.com/brevdev/brev-cli/pkg/terminal"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -665,3 +666,158 @@ func TestPollUntilReadyReportsWorkspaceFailureMessage(t *testing.T) {

assert.ErrorContains(t, err, "instance test failed: unexpected end of JSON input")
}

func TestValidateInstanceTypeAvailability(t *testing.T) {
t.Run("returns nil when listing is unavailable", func(t *testing.T) {
ctx := &createContext{}
assert.NoError(t, ctx.validateInstanceTypeAvailability("hyperstack_H100x8_one"))
})

t.Run("returns nil when type has a workspace group", func(t *testing.T) {
ctx := &createContext{
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{
Type: "hyperstack_H100_sxm5x8",
WorkspaceGroups: []gpusearch.WorkspaceGroup{
{ID: "wg-1", Name: "Shadeform", PlatformType: "shadeform"},
},
},
},
},
}
assert.NoError(t, ctx.validateInstanceTypeAvailability("hyperstack_H100_sxm5x8"))
})

t.Run("returns invalid-type error for unknown type", func(t *testing.T) {
ctx := &createContext{
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{Type: "hyperstack_H100_sxm5x8", WorkspaceGroups: []gpusearch.WorkspaceGroup{{ID: "wg-1"}}},
},
},
}
err := ctx.validateInstanceTypeAvailability("hyperstack_H100x8_one")
assert.Error(t, err)
assert.Contains(t, err.Error(), `"hyperstack_H100x8_one"`)
assert.Contains(t, err.Error(), "not a recognized type")
assert.Contains(t, err.Error(), "brev search")
})

t.Run("returns unavailable error for known type without workspace groups", func(t *testing.T) {
ctx := &createContext{
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{Type: "hyperstack_H100x8_NVLINK", WorkspaceGroups: nil},
},
},
}
err := ctx.validateInstanceTypeAvailability("hyperstack_H100x8_NVLINK")
assert.Error(t, err)
assert.Contains(t, err.Error(), `"hyperstack_H100x8_NVLINK"`)
assert.Contains(t, err.Error(), "currently unavailable")
assert.Contains(t, err.Error(), "brev search")
})

t.Run("error type is ValidationError so no stack trace is appended", func(t *testing.T) {
ctx := &createContext{
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{},
},
}
err := ctx.validateInstanceTypeAvailability("missing")
assert.Error(t, err)
var ve breverrors.ValidationError
assert.ErrorAs(t, err, &ve)
assert.NotContains(t, err.Error(), "gpucreate.go", "validation error should not include source-file traces")
})
}

func TestCreateInstancesWithTypeSkipsInvalidType(t *testing.T) {
mock := NewMockGPUCreateStore()
ctx := &createContext{
t: terminal.New(),
store: mock,
opts: GPUCreateOptions{Count: 1, Parallel: 1, Name: "jt-4"},
org: mock.Org,
user: mock.User,
piped: true,
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{
Type: "hyperstack_H100_sxm5x8",
WorkspaceGroups: []gpusearch.WorkspaceGroup{
{ID: "wg-shadeform", Name: "Shadeform", PlatformType: "shadeform"},
},
},
},
},
}
ctx.logf = func(_ string, _ ...interface{}) {}

result := ctx.createInstancesWithType(InstanceSpec{Type: "hyperstack_H100x8_one"}, 0, 1)

assert.True(t, result.hadFailure, "expected hadFailure for an invalid instance type")
assert.Empty(t, result.successes, "expected no successes for invalid type")
assert.NoError(t, result.fatalError, "invalid type should not be fatal — caller may try the next type")
assert.Empty(t, mock.CreatedWorkspaces, "CreateWorkspace must not be called when the type is unrecognized")
}

func TestCreateInstancesWithTypeSkipsUnavailableType(t *testing.T) {
mock := NewMockGPUCreateStore()
ctx := &createContext{
t: terminal.New(),
store: mock,
opts: GPUCreateOptions{Count: 1, Parallel: 1, Name: "jt-4"},
org: mock.Org,
user: mock.User,
piped: true,
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{Type: "hyperstack_H100x8_NVLINK", WorkspaceGroups: nil},
},
},
}
ctx.logf = func(_ string, _ ...interface{}) {}

result := ctx.createInstancesWithType(InstanceSpec{Type: "hyperstack_H100x8_NVLINK"}, 0, 1)

assert.True(t, result.hadFailure, "expected hadFailure for an unavailable instance type")
assert.Empty(t, result.successes)
assert.Empty(t, mock.CreatedWorkspaces, "CreateWorkspace must not be called when no workspace group is available")
}

func TestCreateInstancesWithTypeBypassesValidationForLaunchable(t *testing.T) {
mock := NewMockGPUCreateStore()
ctx := &createContext{
t: terminal.New(),
store: mock,
opts: GPUCreateOptions{
Count: 1,
Parallel: 1,
Name: "jt-4",
LaunchableID: "env-abc",
LaunchableInfo: &store.LaunchableResponse{
ID: "env-abc",
Name: "test-launchable",
CreateWorkspaceRequest: store.LaunchableWorkspaceRequest{
WorkspaceGroupID: "wg-from-launchable",
InstanceType: "n2-standard-4",
},
},
},
org: mock.Org,
user: mock.User,
piped: true,
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{}, // launchable's type is not in the org listing
},
}
ctx.logf = func(_ string, _ ...interface{}) {}

result := ctx.createInstancesWithType(InstanceSpec{Type: "n2-standard-4"}, 0, 1)

assert.False(t, result.hadFailure, "launchable should not be blocked by pre-flight validation")
assert.Len(t, result.successes, 1, "expected the launchable instance to be created")
assert.Len(t, mock.CreatedWorkspaces, 1)
}
10 changes: 10 additions & 0 deletions pkg/cmd/gpusearch/gpusearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ func (r *AllInstanceTypesResponse) GetWorkspaceGroupID(instanceType string) stri
return ""
}

// HasInstanceType reports whether the type exists in the API listing, independent of capacity.
func (r *AllInstanceTypesResponse) HasInstanceType(instanceType string) bool {
for _, it := range r.AllInstanceTypes {
if it.Type == instanceType {
return true
}
}
return false
}

// GPUSearchStore defines the interface for fetching instance types
type GPUSearchStore interface {
GetInstanceTypes(includeCPU bool) (*InstanceTypesResponse, error)
Expand Down
45 changes: 45 additions & 0 deletions pkg/cmd/gpusearch/gpusearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,48 @@ func TestProcessInstancesCloudExtraction(t *testing.T) {
assert.Equal(t, "nebius", instances[1].Cloud)
assert.Equal(t, "nebius", instances[1].Provider)
}

func TestAllInstanceTypesResponseLookup(t *testing.T) {
resp := &AllInstanceTypesResponse{
AllInstanceTypes: []InstanceType{
{
Type: "hyperstack_H100_sxm5x8",
WorkspaceGroups: []WorkspaceGroup{
{ID: "wg-shadeform", Name: "Shadeform", PlatformType: "shadeform"},
},
},
{
Type: "hyperstack_H100x8_NVLINK",
WorkspaceGroups: nil,
},
{
Type: "verda-b300-8x",
WorkspaceGroups: []WorkspaceGroup{},
},
},
}

t.Run("GetWorkspaceGroupID returns id when type has groups", func(t *testing.T) {
assert.Equal(t, "wg-shadeform", resp.GetWorkspaceGroupID("hyperstack_H100_sxm5x8"))
})

t.Run("GetWorkspaceGroupID returns empty for type without groups", func(t *testing.T) {
assert.Equal(t, "", resp.GetWorkspaceGroupID("hyperstack_H100x8_NVLINK"))
assert.Equal(t, "", resp.GetWorkspaceGroupID("verda-b300-8x"))
})

t.Run("GetWorkspaceGroupID returns empty for unknown type", func(t *testing.T) {
assert.Equal(t, "", resp.GetWorkspaceGroupID("hyperstack_H100x8_one"))
})

t.Run("HasInstanceType is true even when groups are empty", func(t *testing.T) {
assert.True(t, resp.HasInstanceType("hyperstack_H100_sxm5x8"))
assert.True(t, resp.HasInstanceType("hyperstack_H100x8_NVLINK"))
assert.True(t, resp.HasInstanceType("verda-b300-8x"))
})

t.Run("HasInstanceType is false for unknown type", func(t *testing.T) {
assert.False(t, resp.HasInstanceType("hyperstack_H100x8_one"))
assert.False(t, resp.HasInstanceType(""))
})
}
Loading