diff --git a/cli/azd/cmd/middleware/debug_test.go b/cli/azd/cmd/middleware/debug_test.go new file mode 100644 index 00000000000..1c93b021cea --- /dev/null +++ b/cli/azd/cmd/middleware/debug_test.go @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package middleware + +import ( + "context" + "errors" + "testing" + + "github.com/azure/azure-dev/cli/azd/cmd/actions" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockinput" + "github.com/stretchr/testify/require" +) + +func TestNewDebugMiddleware(t *testing.T) { + mc := mockinput.NewMockConsole() + m := NewDebugMiddleware(&Options{}, mc) + require.NotNil(t, m) + + dm, ok := m.(*DebugMiddleware) + require.True(t, ok) + require.NotNil(t, dm.options) + require.NotNil(t, dm.console) +} + +func TestDebugMiddleware_Run_ChildAction(t *testing.T) { + mc := mockinput.NewMockConsole() + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mc, + } + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + ctx := WithChildAction(context.Background()) + result, err := m.Run(ctx, nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled, + "child action should skip debug and call next") +} + +func TestDebugMiddleware_Run_NoEnvVar(t *testing.T) { + mc := mockinput.NewMockConsole() + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mc, + } + + // Ensure AZD_DEBUG is not set + t.Setenv("AZD_DEBUG", "") + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + result, err := m.Run(context.Background(), nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled, + "no AZD_DEBUG should skip debug and call next") +} + +func TestDebugMiddleware_Run_EnvVarFalse(t *testing.T) { + mc := mockinput.NewMockConsole() + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mc, + } + + t.Setenv("AZD_DEBUG", "false") + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + result, err := m.Run(context.Background(), nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled) +} + +func TestDebugMiddleware_Run_EnvVarInvalid(t *testing.T) { + mc := mockinput.NewMockConsole() + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mc, + } + + t.Setenv("AZD_DEBUG", "not-a-bool") + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + result, err := m.Run(context.Background(), nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled, + "invalid bool should parse as false and call next") +} + +func TestDebugMiddleware_Run_TelemetryCommand(t *testing.T) { + mc := mockinput.NewMockConsole() + m := &DebugMiddleware{ + options: &Options{ + CommandPath: "azd telemetry upload", + }, + console: mc, + } + + // AZD_DEBUG is set but we're running a telemetry command. + // It checks AZD_DEBUG_TELEMETRY instead. + t.Setenv("AZD_DEBUG", "true") + t.Setenv("AZD_DEBUG_TELEMETRY", "") + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + result, err := m.Run(context.Background(), nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled, + "telemetry checks AZD_DEBUG_TELEMETRY, not AZD_DEBUG") +} + +func TestDebugMiddleware_Run_ConfirmDeclined(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.Console. + WhenConfirm(func(options input.ConsoleOptions) bool { + return true + }). + Respond(false) + + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mockContext.Console, + } + + t.Setenv("AZD_DEBUG", "true") + + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + t.Fatal("next should not be called when declined") + return nil, nil + } + + _, err := m.Run(*mockContext.Context, nextFn) + + require.Error(t, err) + require.True(t, errors.Is(err, ErrDebuggerAborted)) +} + +func TestDebugMiddleware_Run_ConfirmAccepted(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.Console. + WhenConfirm(func(options input.ConsoleOptions) bool { + return true + }). + Respond(true) + + m := &DebugMiddleware{ + options: &Options{CommandPath: "test"}, + console: mockContext.Console, + } + + t.Setenv("AZD_DEBUG", "true") + + nextCalled := false + nextFn := func( + ctx context.Context, + ) (*actions.ActionResult, error) { + nextCalled = true + return &actions.ActionResult{}, nil + } + + result, err := m.Run(*mockContext.Context, nextFn) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, nextCalled) +} diff --git a/cli/azd/cmd/middleware/middleware_helpers_test.go b/cli/azd/cmd/middleware/middleware_helpers_test.go new file mode 100644 index 00000000000..94d5ef64ae6 --- /dev/null +++ b/cli/azd/cmd/middleware/middleware_helpers_test.go @@ -0,0 +1,200 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package middleware + +import ( + "context" + "testing" + + "github.com/azure/azure-dev/cli/azd/cmd/actions" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/require" +) + +func TestWithChildAction_IsChildAction(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want bool + }{ + { + name: "PlainContext", + ctx: context.Background(), + want: false, + }, + { + name: "ChildActionContext", + ctx: WithChildAction(context.Background()), + want: true, + }, + { + name: "NestedChildActionContext", + ctx: WithChildAction( + WithChildAction(context.Background()), + ), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, IsChildAction(tt.ctx)) + }) + } +} + +func TestIsChildAction_WrongType(t *testing.T) { + // Manually set a non-bool value under the same key + ctx := context.WithValue( + context.Background(), + childActionKey, + "not-a-bool", + ) + require.False(t, IsChildAction(ctx)) +} + +func TestIsChildAction_FalseValue(t *testing.T) { + ctx := context.WithValue( + context.Background(), + childActionKey, + false, + ) + require.False(t, IsChildAction(ctx)) +} + +func TestOptions_WithContainer(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + + opts := &Options{ + Name: "test", + } + require.Nil(t, opts.container) + + opts.WithContainer(mockContext.Container) + require.NotNil(t, opts.container) + require.Equal(t, mockContext.Container, opts.container) +} + +func TestMiddlewareRunner_Use_AddsSingleEntry(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + runner := NewMiddlewareRunner(mockContext.Container) + + err := runner.Use("single", func() Middleware { + return middlewareFunc( + func( + ctx context.Context, + next NextFn, + ) (*actions.ActionResult, error) { + return next(ctx) + }, + ) + }) + require.NoError(t, err) + require.Len(t, runner.chain, 1) + require.Equal(t, "single", runner.chain[0]) +} + +func TestMiddlewareRunner_Use_MultipleMiddleware(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + runner := NewMiddlewareRunner(mockContext.Container) + + for _, name := range []string{"first", "second", "third"} { + err := runner.Use(name, func() Middleware { + return middlewareFunc( + func( + ctx context.Context, + next NextFn, + ) (*actions.ActionResult, error) { + return next(ctx) + }, + ) + }) + require.NoError(t, err) + } + + require.Len(t, runner.chain, 3) + require.Equal(t, + []string{"first", "second", "third"}, + runner.chain, + ) +} + +func TestMiddlewareRunner_RunAction_WithOptionsContainer(t *testing.T) { + // Verify that Options.container is used when set + mockContext := mocks.NewMockContext(context.Background()) + runner := NewMiddlewareRunner(mockContext.Container) + + actionRan := false + err := mockContext.Container.RegisterNamedTransient( + "test-action", func() actions.Action { + return &testAction{ + runFunc: func( + ctx context.Context, + ) (*actions.ActionResult, error) { + actionRan = true + return &actions.ActionResult{ + Message: &actions.ResultMessage{ + Header: "OK", + }, + }, nil + }, + } + }) + require.NoError(t, err) + + opts := &Options{ + Name: "test", + container: mockContext.Container, + } + + result, err := runner.RunAction( + *mockContext.Context, opts, "test-action") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, actionRan) +} + +func TestMiddlewareRunner_RunAction_NoMiddleware(t *testing.T) { + // When no middleware is registered, the action runs directly + mockContext := mocks.NewMockContext(context.Background()) + runner := NewMiddlewareRunner(mockContext.Container) + + err := mockContext.Container.RegisterNamedTransient( + "direct-action", func() actions.Action { + return &testAction{ + runFunc: func( + ctx context.Context, + ) (*actions.ActionResult, error) { + return &actions.ActionResult{ + Message: &actions.ResultMessage{ + Header: "Direct", + }, + }, nil + }, + } + }) + require.NoError(t, err) + + result, err := runner.RunAction( + *mockContext.Context, + &Options{Name: "test"}, + "direct-action", + ) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "Direct", result.Message.Header) +} + +func TestMiddlewareRunner_RunAction_InvalidAction(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + runner := NewMiddlewareRunner(mockContext.Container) + + // Don't register any action — resolution should fail + _, err := runner.RunAction( + *mockContext.Context, + &Options{Name: "test"}, + "nonexistent-action", + ) + require.Error(t, err) +} diff --git a/cli/azd/extensions/microsoft.azd.extensions/internal/helpers_coverage_test.go b/cli/azd/extensions/microsoft.azd.extensions/internal/helpers_coverage_test.go new file mode 100644 index 00000000000..2b3bb7247e7 --- /dev/null +++ b/cli/azd/extensions/microsoft.azd.extensions/internal/helpers_coverage_test.go @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "archive/zip" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToPascalCase(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "SingleWord", + input: "hello", + expected: "Hello", + }, + { + name: "DotSeparated", + input: "azure.ai.models", + expected: "Azure.Ai.Models", + }, + { + name: "AlreadyPascal", + input: "Hello.World", + expected: "Hello.World", + }, + { + name: "EmptyString", + input: "", + expected: "", + }, + { + name: "SingleChar", + input: "a", + expected: "A", + }, + { + name: "SingleDot", + input: "a.b", + expected: "A.B", + }, + { + name: "EmptyParts", + input: "a..b", + expected: "A..B", + }, + { + name: "TrailingDot", + input: "hello.", + expected: "Hello.", + }, + { + name: "LeadingDot", + input: ".hello", + expected: ".Hello", + }, + { + name: "NoDots", + input: "helloworld", + expected: "Helloworld", + }, + { + name: "Unicode", + input: "über.straße", + expected: "Über.Straße", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToPascalCase(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestComputeChecksum(t *testing.T) { + t.Run("ValidFile", func(t *testing.T) { + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, "test.txt") + require.NoError(t, os.WriteFile( + filePath, []byte("hello world"), 0600, + )) + + checksum, err := ComputeChecksum(filePath) + require.NoError(t, err) + // SHA-256 of "hello world" + require.Equal(t, + "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + checksum, + ) + }) + + t.Run("EmptyFile", func(t *testing.T) { + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, "empty.txt") + require.NoError(t, os.WriteFile( + filePath, []byte{}, 0600, + )) + + checksum, err := ComputeChecksum(filePath) + require.NoError(t, err) + // SHA-256 of empty input + require.Equal(t, + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + checksum, + ) + }) + + t.Run("NonExistentFile", func(t *testing.T) { + _, err := ComputeChecksum("/nonexistent/file.txt") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to open file") + }) +} + +func TestCopyFile(t *testing.T) { + t.Run("SuccessfulCopy", func(t *testing.T) { + tempDir := t.TempDir() + srcPath := filepath.Join(tempDir, "source.txt") + dstPath := filepath.Join(tempDir, "dest.txt") + + content := []byte("test file content") + require.NoError(t, os.WriteFile(srcPath, content, 0600)) + + err := CopyFile(srcPath, dstPath) + require.NoError(t, err) + + copied, err := os.ReadFile(dstPath) + require.NoError(t, err) + require.Equal(t, content, copied) + }) + + t.Run("SourceNotFound", func(t *testing.T) { + tempDir := t.TempDir() + err := CopyFile( + filepath.Join(tempDir, "nonexistent.txt"), + filepath.Join(tempDir, "dest.txt"), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to open source file") + }) + + t.Run("OverwriteExisting", func(t *testing.T) { + tempDir := t.TempDir() + srcPath := filepath.Join(tempDir, "source.txt") + dstPath := filepath.Join(tempDir, "dest.txt") + + require.NoError(t, os.WriteFile( + srcPath, []byte("new content"), 0600, + )) + require.NoError(t, os.WriteFile( + dstPath, []byte("old content"), 0600, + )) + + err := CopyFile(srcPath, dstPath) + require.NoError(t, err) + + result, err := os.ReadFile(dstPath) + require.NoError(t, err) + require.Equal(t, "new content", string(result)) + }) + + t.Run("LargeFile", func(t *testing.T) { + tempDir := t.TempDir() + srcPath := filepath.Join(tempDir, "large.bin") + dstPath := filepath.Join(tempDir, "large_copy.bin") + + data := make([]byte, 1024*1024) // 1 MB + for i := range data { + data[i] = byte(i % 256) + } + require.NoError(t, os.WriteFile(srcPath, data, 0600)) + + err := CopyFile(srcPath, dstPath) + require.NoError(t, err) + + copied, err := os.ReadFile(dstPath) + require.NoError(t, err) + require.Equal(t, data, copied) + }) +} + +func TestZipSource(t *testing.T) { + // ZipSource has a known file-handle leak on Windows (os.Open without + // Close), so t.TempDir() cleanup fails. Use os.MkdirTemp + manual + // cleanup that tolerates locked files. + + t.Run("CreateZipWithMultipleFiles", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "ziptest-*") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(tempDir) }) + + file1 := filepath.Join(tempDir, "file1.txt") + file2 := filepath.Join(tempDir, "file2.txt") + targetZip := filepath.Join(tempDir, "output.zip") + + require.NoError(t, os.WriteFile( + file1, []byte("content one"), 0600, + )) + require.NoError(t, os.WriteFile( + file2, []byte("content two"), 0600, + )) + + err = ZipSource([]string{file1, file2}, targetZip) + require.NoError(t, err) + + info, err := os.Stat(targetZip) + require.NoError(t, err) + require.Greater(t, info.Size(), int64(0)) + + reader, err := zip.OpenReader(targetZip) + require.NoError(t, err) + + require.Len(t, reader.File, 2) + require.Equal(t, "file1.txt", reader.File[0].Name) + require.Equal(t, "file2.txt", reader.File[1].Name) + + rc, err := reader.File[0].Open() + require.NoError(t, err) + data, err := io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + require.Equal(t, "content one", string(data)) + + rc, err = reader.File[1].Open() + require.NoError(t, err) + data, err = io.ReadAll(rc) + require.NoError(t, err) + rc.Close() + require.Equal(t, "content two", string(data)) + + reader.Close() + }) + + t.Run("NonExistentSourceFile", func(t *testing.T) { + tempDir := t.TempDir() + err := ZipSource( + []string{filepath.Join(tempDir, "missing.txt")}, + filepath.Join(tempDir, "out.zip"), + ) + require.Error(t, err) + }) + + t.Run("SingleFile", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "ziptest-*") + require.NoError(t, err) + t.Cleanup(func() { _ = os.RemoveAll(tempDir) }) + + srcFile := filepath.Join(tempDir, "single.txt") + targetZip := filepath.Join(tempDir, "single.zip") + + require.NoError(t, os.WriteFile( + srcFile, []byte("only file"), 0600, + )) + + err = ZipSource([]string{srcFile}, targetZip) + require.NoError(t, err) + + reader, err := zip.OpenReader(targetZip) + require.NoError(t, err) + + require.Len(t, reader.File, 1) + require.Equal(t, "single.txt", reader.File[0].Name) + + reader.Close() + }) +} + +func TestDownloadAssetToTemp_LocalFile(t *testing.T) { + tempDir := t.TempDir() + localFile := filepath.Join(tempDir, "asset.bin") + content := []byte("local asset data") + require.NoError(t, os.WriteFile(localFile, content, 0600)) + + result, err := DownloadAssetToTemp(localFile, "asset.bin") + require.NoError(t, err) + defer os.Remove(result) + + data, err := os.ReadFile(result) + require.NoError(t, err) + require.Equal(t, content, data) +} + +func TestDownloadAssetToTemp_NonExistentLocal(t *testing.T) { + _, err := DownloadAssetToTemp( + "/nonexistent/path/asset.bin", "asset.bin", + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to open local file") +} + +func TestToPtr(t *testing.T) { + t.Run("Int", func(t *testing.T) { + p := ToPtr(42) + require.NotNil(t, p) + require.Equal(t, 42, *p) + }) + + t.Run("String", func(t *testing.T) { + p := ToPtr("hello") + require.NotNil(t, p) + require.Equal(t, "hello", *p) + }) + + t.Run("Bool", func(t *testing.T) { + p := ToPtr(true) + require.NotNil(t, p) + require.True(t, *p) + }) + + t.Run("ZeroValue", func(t *testing.T) { + p := ToPtr(0) + require.NotNil(t, p) + require.Equal(t, 0, *p) + }) +} diff --git a/cli/azd/extensions/microsoft.azd.extensions/internal/models/extension_schema_coverage_test.go b/cli/azd/extensions/microsoft.azd.extensions/internal/models/extension_schema_coverage_test.go new file mode 100644 index 00000000000..78850829596 --- /dev/null +++ b/cli/azd/extensions/microsoft.azd.extensions/internal/models/extension_schema_coverage_test.go @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package models + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/stretchr/testify/require" +) + +func TestExtensionSchema_SafeDashId(t *testing.T) { + tests := []struct { + name string + id string + expected string + }{ + { + name: "DottedId", + id: "azure.ai.models", + expected: "azure-ai-models", + }, + { + name: "NoDots", + id: "simple", + expected: "simple", + }, + { + name: "SingleDot", + id: "azure.ai", + expected: "azure-ai", + }, + { + name: "AlreadyDashed", + id: "azure-ai-models", + expected: "azure-ai-models", + }, + { + name: "EmptyId", + id: "", + expected: "", + }, + { + name: "LeadingDot", + id: ".hidden", + expected: "-hidden", + }, + { + name: "TrailingDot", + id: "ext.", + expected: "ext-", + }, + { + name: "MultipleDots", + id: "a.b.c.d.e", + expected: "a-b-c-d-e", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + schema := &ExtensionSchema{Id: tt.id} + require.Equal(t, tt.expected, schema.SafeDashId()) + }) + } +} + +func TestLoadExtension_Success(t *testing.T) { + tempDir := t.TempDir() + yamlContent := `id: test.extension +version: "1.0.0" +displayName: Test Extension +description: A test extension +usage: test usage +capabilities: + - custom-commands +` + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "extension.yaml"), + []byte(yamlContent), + 0600, + )) + + ext, err := LoadExtension(tempDir) + require.NoError(t, err) + require.NotNil(t, ext) + require.Equal(t, "test.extension", ext.Id) + require.Equal(t, "1.0.0", ext.Version) + require.Equal(t, "Test Extension", ext.DisplayName) + require.Equal(t, "A test extension", ext.Description) + require.Equal(t, "test usage", ext.Usage) + require.Len(t, ext.Capabilities, 1) + require.Equal(t, + extensions.CustomCommandCapability, ext.Capabilities[0], + ) + + absPath, err := filepath.Abs(tempDir) + require.NoError(t, err) + require.Equal(t, absPath, ext.Path) +} + +func TestLoadExtension_FileNotFound(t *testing.T) { + tempDir := t.TempDir() + + ext, err := LoadExtension(tempDir) + require.Error(t, err) + require.Nil(t, ext) + require.Contains(t, err.Error(), "Extension manifest file not found") +} + +func TestLoadExtension_MissingId(t *testing.T) { + tempDir := t.TempDir() + yamlContent := `version: "1.0.0" +displayName: Test +description: desc +usage: usage +` + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "extension.yaml"), + []byte(yamlContent), + 0600, + )) + + ext, err := LoadExtension(tempDir) + require.Error(t, err) + require.Nil(t, ext) + require.Contains(t, err.Error(), "id is required") +} + +func TestLoadExtension_MissingVersion(t *testing.T) { + tempDir := t.TempDir() + yamlContent := `id: test.extension +displayName: Test +description: desc +usage: usage +` + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "extension.yaml"), + []byte(yamlContent), + 0600, + )) + + ext, err := LoadExtension(tempDir) + require.Error(t, err) + require.Nil(t, ext) + require.Contains(t, err.Error(), "version is required") +} + +func TestLoadExtension_InvalidYAML(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "extension.yaml"), + []byte(":::invalid:::yaml:::"), + 0600, + )) + + ext, err := LoadExtension(tempDir) + require.Error(t, err) + require.Nil(t, ext) + require.Contains(t, err.Error(), "id is required") +} + +func TestLoadRegistry_Success(t *testing.T) { + tempDir := t.TempDir() + registry := extensions.Registry{ + Extensions: []*extensions.ExtensionMetadata{ + { + Id: "ext.one", + DisplayName: "Extension One", + Versions: []extensions.ExtensionVersion{ + {Version: "1.0.0"}, + }, + }, + }, + } + + data, err := json.MarshalIndent(registry, "", " ") + require.NoError(t, err) + + regPath := filepath.Join(tempDir, "registry.json") + require.NoError(t, os.WriteFile(regPath, data, 0600)) + + loaded, err := LoadRegistry(regPath) + require.NoError(t, err) + require.NotNil(t, loaded) + require.Len(t, loaded.Extensions, 1) + require.Equal(t, "ext.one", loaded.Extensions[0].Id) + require.Equal(t, "Extension One", loaded.Extensions[0].DisplayName) +} + +func TestLoadRegistry_FileNotFound(t *testing.T) { + _, err := LoadRegistry("/nonexistent/registry.json") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read registry file") +} + +func TestLoadRegistry_InvalidJSON(t *testing.T) { + tempDir := t.TempDir() + regPath := filepath.Join(tempDir, "registry.json") + require.NoError(t, os.WriteFile( + regPath, []byte("not json"), 0600, + )) + + _, err := LoadRegistry(regPath) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to parse registry file") +} + +func TestLoadRegistry_EmptyRegistry(t *testing.T) { + tempDir := t.TempDir() + regPath := filepath.Join(tempDir, "registry.json") + require.NoError(t, os.WriteFile( + regPath, []byte(`{"extensions":[]}`), 0600, + )) + + loaded, err := LoadRegistry(regPath) + require.NoError(t, err) + require.NotNil(t, loaded) + require.Empty(t, loaded.Extensions) +} diff --git a/cli/azd/extensions/microsoft.azd.extensions/internal/user_friendly_error_test.go b/cli/azd/extensions/microsoft.azd.extensions/internal/user_friendly_error_test.go new file mode 100644 index 00000000000..3bd54720ea3 --- /dev/null +++ b/cli/azd/extensions/microsoft.azd.extensions/internal/user_friendly_error_test.go @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUserFriendlyError_Error(t *testing.T) { + tests := []struct { + name string + err *UserFriendlyError + expected string + }{ + { + name: "BasicMessage", + err: &UserFriendlyError{ + ErrorMessage: "something went wrong", + UserDetails: "Try running the command again", + }, + expected: "something went wrong", + }, + { + name: "EmptyMessage", + err: &UserFriendlyError{ + ErrorMessage: "", + UserDetails: "details only", + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.err.Error()) + }) + } +} + +func TestUserFriendlyError_GetUserDetails(t *testing.T) { + tests := []struct { + name string + err *UserFriendlyError + expected string + }{ + { + name: "WithDetails", + err: &UserFriendlyError{ + ErrorMessage: "error", + UserDetails: "Run azd init first", + }, + expected: "Run azd init first", + }, + { + name: "EmptyDetails", + err: &UserFriendlyError{ + ErrorMessage: "error", + UserDetails: "", + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.err.GetUserDetails()) + }) + } +} + +func TestUserFriendlyError_ImplementsErrorInterface(t *testing.T) { + ufe := NewUserFriendlyError("test error", "test details") + var err error = ufe + require.Error(t, err) + require.Equal(t, "test error", err.Error()) +} + +func TestNewUserFriendlyError(t *testing.T) { + err := NewUserFriendlyError("the error", "the details") + require.NotNil(t, err) + require.Equal(t, "the error", err.ErrorMessage) + require.Equal(t, "the details", err.UserDetails) + require.Equal(t, "the error", err.Error()) + require.Equal(t, "the details", err.GetUserDetails()) +} + +func TestNewUserFriendlyErrorf(t *testing.T) { + tests := []struct { + name string + errorMessage string + userDetails string + args []any + expectedMessage string + expectedDetails string + }{ + { + name: "WithFormatArgs", + errorMessage: "build failed", + userDetails: "Run %s in %s directory", + args: []any{"go build", "/src"}, + expectedMessage: "build failed", + expectedDetails: "Run go build in /src directory", + }, + { + name: "NoFormatArgs", + errorMessage: "error", + userDetails: "plain details", + args: nil, + expectedMessage: "error", + expectedDetails: "plain details", + }, + { + name: "SingleArg", + errorMessage: "not found", + userDetails: "file %q does not exist", + args: []any{"config.yaml"}, + expectedMessage: "not found", + expectedDetails: `file "config.yaml" does not exist`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := NewUserFriendlyErrorf( + tt.errorMessage, tt.userDetails, tt.args..., + ) + require.NotNil(t, err) + require.Equal(t, tt.expectedMessage, err.Error()) + require.Equal(t, tt.expectedDetails, err.GetUserDetails()) + }) + } +} + +func TestUserFriendlyError_ErrorsAs(t *testing.T) { + ufe := NewUserFriendlyError("wrapped", "details") + wrapped := errors.Join(errors.New("context"), ufe) + + var target *UserFriendlyError + require.True(t, errors.As(wrapped, &target)) + require.Equal(t, "wrapped", target.ErrorMessage) + require.Equal(t, "details", target.UserDetails) +} diff --git a/cli/azd/internal/agent/consent/checker_test.go b/cli/azd/internal/agent/consent/checker_test.go new file mode 100644 index 00000000000..00a07c84f58 --- /dev/null +++ b/cli/azd/internal/agent/consent/checker_test.go @@ -0,0 +1,797 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package consent + +import ( + "context" + "fmt" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +// mockConsentManager is a simple in-memory ConsentManager for testing. +type mockConsentManager struct { + decisions map[string]*ConsentDecision + granted []ConsentRule + err error +} + +func (m *mockConsentManager) CheckConsent( + _ context.Context, req ConsentRequest, +) (*ConsentDecision, error) { + if m.err != nil { + return nil, m.err + } + if d, ok := m.decisions[req.ToolID]; ok { + return d, nil + } + return &ConsentDecision{ + Allowed: false, + RequiresPrompt: true, + Reason: "no consent", + }, nil +} + +func (m *mockConsentManager) GrantConsent( + _ context.Context, rule ConsentRule, +) error { + if m.err != nil { + return m.err + } + m.granted = append(m.granted, rule) + return nil +} + +func (m *mockConsentManager) ListConsentRules( + _ context.Context, _ ...FilterOption, +) ([]ConsentRule, error) { + return nil, nil +} + +func (m *mockConsentManager) ClearConsentRules( + _ context.Context, _ ...FilterOption, +) error { + return nil +} + +func (m *mockConsentManager) PromptWorkflowConsent( + _ context.Context, _ []string, +) error { + return nil +} + +func (m *mockConsentManager) IsProjectScopeAvailable( + _ context.Context, +) bool { + return false +} + +func TestNewConsentChecker(t *testing.T) { + mgr := &mockConsentManager{} + cc := NewConsentChecker(mgr, "test-server") + require.NotNil(t, cc) + require.Equal(t, "test-server", cc.serverName) +} + +func TestCheckToolConsent(t *testing.T) { + tests := []struct { + name string + decisions map[string]*ConsentDecision + wantAllow bool + wantPrompt bool + }{ + { + name: "Allowed", + decisions: map[string]*ConsentDecision{ + "srv/myTool": {Allowed: true, Reason: "allowed"}, + }, + wantAllow: true, + wantPrompt: false, + }, + { + name: "NoConsent", + decisions: map[string]*ConsentDecision{}, + wantAllow: false, + wantPrompt: true, + }, + { + name: "Denied", + decisions: map[string]*ConsentDecision{ + "srv/myTool": { + Allowed: false, + Reason: "denied", + }, + }, + wantAllow: false, + wantPrompt: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mgr := &mockConsentManager{decisions: tt.decisions} + cc := NewConsentChecker(mgr, "srv") + + decision, err := cc.CheckToolConsent( + context.Background(), + "myTool", + "does stuff", + mcp.ToolAnnotation{}, + ) + require.NoError(t, err) + require.Equal(t, tt.wantAllow, decision.Allowed) + require.Equal(t, tt.wantPrompt, decision.RequiresPrompt) + }) + } +} + +func TestCheckToolConsent_Error(t *testing.T) { + mgr := &mockConsentManager{ + err: fmt.Errorf("storage failure"), + } + cc := NewConsentChecker(mgr, "srv") + + _, err := cc.CheckToolConsent( + context.Background(), + "tool", + "desc", + mcp.ToolAnnotation{}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "storage failure") +} + +func TestCheckSamplingConsent(t *testing.T) { + mgr := &mockConsentManager{ + decisions: map[string]*ConsentDecision{ + "srv/sample": {Allowed: true, Reason: "ok"}, + }, + } + cc := NewConsentChecker(mgr, "srv") + + decision, err := cc.CheckSamplingConsent( + context.Background(), "sample", + ) + require.NoError(t, err) + require.True(t, decision.Allowed) +} + +func TestCheckElicitationConsent(t *testing.T) { + mgr := &mockConsentManager{ + decisions: map[string]*ConsentDecision{ + "srv/elicit": {Allowed: true, Reason: "ok"}, + }, + } + cc := NewConsentChecker(mgr, "srv") + + decision, err := cc.CheckElicitationConsent( + context.Background(), "elicit", + ) + require.NoError(t, err) + require.True(t, decision.Allowed) +} + +func TestFormatToolDescriptionWithAnnotations(t *testing.T) { + cc := &ConsentChecker{serverName: "test"} + boolPtr := func(b bool) *bool { return &b } + + tests := []struct { + name string + desc string + annotations mcp.ToolAnnotation + contains []string + notContains []string + }{ + { + name: "EmptyDescGetsDefault", + desc: "", + annotations: mcp.ToolAnnotation{}, + contains: []string{"No description available"}, + }, + { + name: "WithTitle", + desc: "base desc", + annotations: mcp.ToolAnnotation{ + Title: "My Tool", + }, + contains: []string{ + "base desc", + "Title: My Tool", + }, + }, + { + name: "ReadOnlyTrue", + desc: "desc", + annotations: mcp.ToolAnnotation{ + ReadOnlyHint: boolPtr(true), + }, + contains: []string{"Read-only operation"}, + }, + { + name: "ReadOnlyFalse", + desc: "desc", + annotations: mcp.ToolAnnotation{ + ReadOnlyHint: boolPtr(false), + }, + contains: []string{"May modify data"}, + }, + { + name: "DestructiveTrue", + desc: "desc", + annotations: mcp.ToolAnnotation{ + DestructiveHint: boolPtr(true), + }, + contains: []string{ + "Potentially destructive", + }, + }, + { + name: "DestructiveFalse", + desc: "desc", + annotations: mcp.ToolAnnotation{ + DestructiveHint: boolPtr(false), + }, + contains: []string{"Non-destructive"}, + }, + { + name: "IdempotentTrue", + desc: "desc", + annotations: mcp.ToolAnnotation{ + IdempotentHint: boolPtr(true), + }, + contains: []string{"safe to retry"}, + }, + { + name: "IdempotentFalse", + desc: "desc", + annotations: mcp.ToolAnnotation{ + IdempotentHint: boolPtr(false), + }, + contains: []string{"side effects on retry"}, + }, + { + name: "OpenWorldTrue", + desc: "desc", + annotations: mcp.ToolAnnotation{ + OpenWorldHint: boolPtr(true), + }, + contains: []string{"external resources"}, + }, + { + name: "OpenWorldFalse", + desc: "desc", + annotations: mcp.ToolAnnotation{ + OpenWorldHint: boolPtr(false), + }, + contains: []string{"local resources only"}, + }, + { + name: "AllAnnotations", + desc: "full desc", + annotations: mcp.ToolAnnotation{ + Title: "Full Tool", + ReadOnlyHint: boolPtr(true), + DestructiveHint: boolPtr(false), + IdempotentHint: boolPtr(true), + OpenWorldHint: boolPtr(false), + }, + contains: []string{ + "full desc", + "Tool characteristics:", + "Title: Full Tool", + "Read-only operation", + "Non-destructive", + "safe to retry", + "local resources only", + }, + }, + { + name: "NoAnnotations", + desc: "plain", + annotations: mcp.ToolAnnotation{}, + contains: []string{"plain"}, + notContains: []string{"Tool characteristics:"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cc.formatToolDescriptionWithAnnotations( + tt.desc, tt.annotations, + ) + for _, s := range tt.contains { + require.Contains(t, result, s) + } + for _, s := range tt.notContains { + require.NotContains(t, result, s) + } + }) + } +} + +func TestGrantConsentFromChoice(t *testing.T) { + tests := []struct { + name string + toolID string + choice string + operation OperationType + wantScope Scope + wantTgt Target + wantAct ActionType + wantErr bool + errMsg string + }{ + { + name: "Once", + toolID: "srv/tool", + choice: "once", + operation: OperationTypeTool, + wantScope: ScopeOneTime, + wantTgt: NewToolTarget("srv", "tool"), + wantAct: ActionAny, + }, + { + name: "Session", + toolID: "srv/tool", + choice: "session", + operation: OperationTypeTool, + wantScope: ScopeSession, + wantTgt: NewToolTarget("srv", "tool"), + wantAct: ActionAny, + }, + { + name: "Project", + toolID: "srv/tool", + choice: "project", + operation: OperationTypeSampling, + wantScope: ScopeProject, + wantTgt: NewToolTarget("srv", "tool"), + wantAct: ActionAny, + }, + { + name: "Always", + toolID: "srv/tool", + choice: "always", + operation: OperationTypeTool, + wantScope: ScopeGlobal, + wantTgt: NewToolTarget("srv", "tool"), + wantAct: ActionAny, + }, + { + name: "Server", + toolID: "srv/tool", + choice: "server", + operation: OperationTypeTool, + wantScope: ScopeGlobal, + wantTgt: NewServerTarget("srv"), + wantAct: ActionAny, + }, + { + name: "Global", + toolID: "srv/tool", + choice: "global", + operation: OperationTypeElicitation, + wantScope: ScopeGlobal, + wantTgt: NewGlobalTarget(), + wantAct: ActionAny, + }, + { + name: "ReadOnlySession", + toolID: "srv/tool", + choice: "readonly_session", + operation: OperationTypeTool, + wantScope: ScopeSession, + wantTgt: NewGlobalTarget(), + wantAct: ActionReadOnly, + }, + { + name: "ReadOnlyGlobal", + toolID: "srv/tool", + choice: "readonly_global", + operation: OperationTypeTool, + wantScope: ScopeGlobal, + wantTgt: NewGlobalTarget(), + wantAct: ActionReadOnly, + }, + { + name: "ReadOnlySessionNonToolFails", + toolID: "srv/tool", + choice: "readonly_session", + operation: OperationTypeSampling, + wantErr: true, + errMsg: "readonly session option only available", + }, + { + name: "ReadOnlyGlobalNonToolFails", + toolID: "srv/tool", + choice: "readonly_global", + operation: OperationTypeSampling, + wantErr: true, + errMsg: "readonly global option only available", + }, + { + name: "UnknownChoice", + toolID: "srv/tool", + choice: "magic", + operation: OperationTypeTool, + wantErr: true, + errMsg: "unknown consent choice", + }, + { + name: "InvalidToolID", + toolID: "notool", + choice: "once", + operation: OperationTypeTool, + wantErr: true, + errMsg: "invalid toolId format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mgr := &mockConsentManager{} + cc := NewConsentChecker(mgr, "srv") + + err := cc.grantConsentFromChoice( + context.Background(), + tt.toolID, + tt.choice, + tt.operation, + ) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + return + } + + require.NoError(t, err) + require.Len(t, mgr.granted, 1) + rule := mgr.granted[0] + require.Equal(t, tt.wantScope, rule.Scope) + require.Equal(t, tt.wantTgt, rule.Target) + require.Equal(t, tt.wantAct, rule.Action) + require.Equal(t, tt.operation, rule.Operation) + require.Equal(t, PermissionAllow, rule.Permission) + }) + } +} + +func TestIsServerAlreadyTrusted(t *testing.T) { + t.Run("Trusted", func(t *testing.T) { + mgr := &mockConsentManager{ + decisions: map[string]*ConsentDecision{ + "srv/test-tool": {Allowed: true}, + }, + } + cc := NewConsentChecker(mgr, "srv") + require.True( + t, cc.isServerAlreadyTrusted( + context.Background(), OperationTypeTool, + ), + ) + }) + + t.Run("NotTrusted", func(t *testing.T) { + mgr := &mockConsentManager{ + decisions: map[string]*ConsentDecision{}, + } + cc := NewConsentChecker(mgr, "srv") + require.False( + t, cc.isServerAlreadyTrusted( + context.Background(), OperationTypeTool, + ), + ) + }) + + t.Run("ErrorReturnsFalse", func(t *testing.T) { + mgr := &mockConsentManager{ + err: fmt.Errorf("boom"), + } + cc := NewConsentChecker(mgr, "srv") + require.False( + t, cc.isServerAlreadyTrusted( + context.Background(), OperationTypeSampling, + ), + ) + }) +} + +func TestConsentManagerRuleMatchesFilters(t *testing.T) { + mgr := newTestConsentManager(t) + cm := mgr.(*consentManager) + ctx := context.Background() + + // Grant several rules of different types + rules := []ConsentRule{ + { + Scope: ScopeSession, + Target: NewToolTarget("srv", "read"), + Action: ActionReadOnly, + Operation: OperationTypeTool, + Permission: PermissionAllow, + }, + { + Scope: ScopeGlobal, + Target: NewToolTarget("srv", "write"), + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + }, + { + Scope: ScopeSession, + Target: NewToolTarget("srv", "sample"), + Action: ActionAny, + Operation: OperationTypeSampling, + Permission: PermissionAllow, + }, + } + for _, r := range rules { + require.NoError(t, mgr.GrantConsent(ctx, r)) + } + + t.Run("FilterByScope", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, WithScope(ScopeSession), + ) + require.NoError(t, err) + require.Len(t, listed, 2) + for _, r := range listed { + require.Equal(t, ScopeSession, r.Scope) + } + }) + + t.Run("FilterByOperation", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, WithOperation(OperationTypeSampling), + ) + require.NoError(t, err) + require.Len(t, listed, 1) + require.Equal( + t, OperationTypeSampling, listed[0].Operation, + ) + }) + + t.Run("FilterByAction", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, WithAction(ActionReadOnly), + ) + require.NoError(t, err) + require.Len(t, listed, 1) + require.Equal(t, ActionReadOnly, listed[0].Action) + }) + + t.Run("FilterByTarget", func(t *testing.T) { + tgt := NewToolTarget("srv", "write") + listed, err := mgr.ListConsentRules( + ctx, WithTarget(tgt), + ) + require.NoError(t, err) + require.Len(t, listed, 1) + require.Equal(t, tgt, listed[0].Target) + }) + + t.Run("FilterByPermission", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, WithPermission(PermissionAllow), + ) + require.NoError(t, err) + require.Len(t, listed, 3) + }) + + t.Run("CombinedFilters", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, + WithScope(ScopeSession), + WithOperation(OperationTypeTool), + ) + require.NoError(t, err) + require.Len(t, listed, 1) + require.Equal( + t, + NewToolTarget("srv", "read"), + listed[0].Target, + ) + }) + + t.Run("NoMatchReturnsEmpty", func(t *testing.T) { + listed, err := mgr.ListConsentRules( + ctx, WithPermission(PermissionDeny), + ) + require.NoError(t, err) + require.Empty(t, listed) + }) + + t.Run("EvaluateRule", func(t *testing.T) { + d := cm.evaluateRule(ConsentRule{ + Permission: PermissionAllow, + }) + require.True(t, d.Allowed) + + d = cm.evaluateRule(ConsentRule{ + Permission: PermissionDeny, + }) + require.False(t, d.Allowed) + require.False(t, d.RequiresPrompt) + + d = cm.evaluateRule(ConsentRule{ + Permission: PermissionPrompt, + }) + require.False(t, d.Allowed) + require.True(t, d.RequiresPrompt) + + d = cm.evaluateRule(ConsentRule{ + Permission: Permission("unknown"), + }) + require.False(t, d.Allowed) + require.True(t, d.RequiresPrompt) + }) +} + +func TestActionMatches(t *testing.T) { + mgr := newTestConsentManager(t) + cm := mgr.(*consentManager) + + tests := []struct { + name string + ruleAction ActionType + readOnly bool + want bool + }{ + {"AnyMatchesReadOnly", ActionAny, true, true}, + {"AnyMatchesNonReadOnly", ActionAny, false, true}, + {"ReadOnlyMatchesReadOnly", ActionReadOnly, true, true}, + {"ReadOnlyRejectsNonReadOnly", ActionReadOnly, false, false}, + {"UnknownRejects", ActionType("x"), true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cm.actionMatches(tt.ruleAction, tt.readOnly) + require.Equal(t, tt.want, got) + }) + } +} + +func TestTargetMatches(t *testing.T) { + mgr := newTestConsentManager(t) + cm := mgr.(*consentManager) + + tests := []struct { + name string + rule Target + request Target + want bool + }{ + { + "GlobalStar", + Target("*"), Target("srv/tool"), true, + }, + { + "GlobalStarSlashStar", + Target("*/*"), Target("srv/tool"), true, + }, + { + "ServerWildcard", + Target("srv/*"), Target("srv/tool"), true, + }, + { + "ServerWildcardNoMatch", + Target("other/*"), Target("srv/tool"), false, + }, + { + "ExactMatch", + Target("srv/tool"), Target("srv/tool"), true, + }, + { + "ExactNoMatch", + Target("srv/tool"), Target("srv/other"), false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cm.targetMatches(tt.rule, tt.request) + require.Equal(t, tt.want, got) + }) + } +} + +func TestClearConsentRules(t *testing.T) { + mgr := newTestConsentManager(t) + ctx := context.Background() + + // Grant rules in session scope + require.NoError(t, mgr.GrantConsent(ctx, ConsentRule{ + Scope: ScopeSession, + Target: NewToolTarget("srv", "a"), + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + })) + require.NoError(t, mgr.GrantConsent(ctx, ConsentRule{ + Scope: ScopeSession, + Target: NewToolTarget("srv", "b"), + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + })) + + // Verify we have 2 rules + all, err := mgr.ListConsentRules(ctx) + require.NoError(t, err) + require.Len(t, all, 2) + + // Clear all session rules + require.NoError(t, mgr.ClearConsentRules( + ctx, WithScope(ScopeSession), + )) + + // Verify empty + all, err = mgr.ListConsentRules(ctx) + require.NoError(t, err) + require.Empty(t, all) +} + +func TestGrantConsent_InvalidRule(t *testing.T) { + mgr := newTestConsentManager(t) + err := mgr.GrantConsent(context.Background(), ConsentRule{ + Scope: ScopeGlobal, + Target: Target(""), // invalid + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid consent rule") +} + +func TestGrantConsent_UnknownScope(t *testing.T) { + mgr := newTestConsentManager(t) + err := mgr.GrantConsent(context.Background(), ConsentRule{ + Scope: Scope("unknown"), + Target: NewToolTarget("s", "t"), + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid consent rule") +} + +func TestAddOrUpdateRule(t *testing.T) { + mgr := newTestConsentManager(t) + cm := mgr.(*consentManager) + + rules := []ConsentRule{} + r1 := ConsentRule{ + Target: NewToolTarget("s", "t"), + Operation: OperationTypeTool, + Action: ActionAny, + Permission: PermissionAllow, + } + + // Add new rule + rules = cm.addOrUpdateRule(rules, r1) + require.Len(t, rules, 1) + + // Update existing rule + r2 := r1 + r2.Permission = PermissionDeny + rules = cm.addOrUpdateRule(rules, r2) + require.Len(t, rules, 1) + require.Equal(t, PermissionDeny, rules[0].Permission) + + // Add different target + r3 := ConsentRule{ + Target: NewToolTarget("s", "other"), + Operation: OperationTypeTool, + Action: ActionAny, + Permission: PermissionAllow, + } + rules = cm.addOrUpdateRule(rules, r3) + require.Len(t, rules, 2) +} diff --git a/cli/azd/internal/agent/consent/types_test.go b/cli/azd/internal/agent/consent/types_test.go new file mode 100644 index 00000000000..8f48cc5571e --- /dev/null +++ b/cli/azd/internal/agent/consent/types_test.go @@ -0,0 +1,390 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package consent + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewToolTarget(t *testing.T) { + target := NewToolTarget("myServer", "myTool") + require.Equal(t, Target("myServer/myTool"), target) +} + +func TestNewServerTarget(t *testing.T) { + target := NewServerTarget("myServer") + require.Equal(t, Target("myServer/*"), target) +} + +func TestNewGlobalTarget(t *testing.T) { + target := NewGlobalTarget() + require.Equal(t, Target("*/*"), target) +} + +func TestTargetValidate(t *testing.T) { + tests := []struct { + name string + target Target + wantErr bool + errMsg string + }{ + { + name: "ValidServerTool", + target: Target("server/tool"), + }, + { + name: "ValidServerWildcard", + target: Target("server/*"), + }, + { + name: "ValidGlobalStar", + target: Target("*"), + }, + { + name: "ValidGlobalStarSlashStar", + target: Target("*/*"), + }, + { + name: "Empty", + target: Target(""), + wantErr: true, + errMsg: "target cannot be empty", + }, + { + name: "NoSlash", + target: Target("noslash"), + wantErr: true, + errMsg: "target must be in format", + }, + { + name: "EmptyServer", + target: Target("/tool"), + wantErr: true, + errMsg: "server part of target cannot be empty", + }, + { + name: "EmptyTool", + target: Target("server/"), + wantErr: true, + errMsg: "tool part of target cannot be empty", + }, + { + name: "TooManyParts", + target: Target("a/b/c"), + wantErr: true, + errMsg: "target must be in format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.target.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestParseOperationType(t *testing.T) { + tests := []struct { + name string + input string + want OperationType + wantErr bool + }{ + {"Tool", "tool", OperationTypeTool, false}, + {"Sampling", "sampling", OperationTypeSampling, false}, + {"Elicitation", "elicitation", OperationTypeElicitation, false}, + {"Invalid", "unknown", "", true}, + {"Empty", "", "", true}, + {"CaseSensitive", "Tool", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseOperationType(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid operation context") + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestParseScope(t *testing.T) { + tests := []struct { + name string + input string + want Scope + wantErr bool + }{ + {"Global", "global", ScopeGlobal, false}, + {"Project", "project", ScopeProject, false}, + {"Session", "session", ScopeSession, false}, + {"OneTime", "one_time", ScopeOneTime, false}, + {"Invalid", "forever", "", true}, + {"Empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseScope(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid scope") + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestParseActionType(t *testing.T) { + tests := []struct { + name string + input string + want ActionType + wantErr bool + }{ + {"ReadOnly", "readonly", ActionReadOnly, false}, + {"All", "all", ActionAny, false}, + {"Invalid", "write", "", true}, + {"Empty", "", "", true}, + {"AnyLiteral", "any", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseActionType(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid action type") + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestParsePermission(t *testing.T) { + tests := []struct { + name string + input string + want Permission + wantErr bool + }{ + {"Allow", "allow", PermissionAllow, false}, + {"Deny", "deny", PermissionDeny, false}, + {"Prompt", "prompt", PermissionPrompt, false}, + {"Invalid", "block", "", true}, + {"Empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParsePermission(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid permission") + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestConsentRuleValidate(t *testing.T) { + validRule := ConsentRule{ + Scope: ScopeGlobal, + Target: NewToolTarget("server", "tool"), + Action: ActionAny, + Operation: OperationTypeTool, + Permission: PermissionAllow, + GrantedAt: time.Now(), + } + + tests := []struct { + name string + modify func(r *ConsentRule) + wantErr bool + errMsg string + }{ + { + name: "Valid", + modify: func(_ *ConsentRule) {}, + }, + { + name: "InvalidTarget", + modify: func(r *ConsentRule) { + r.Target = Target("") + }, + wantErr: true, + errMsg: "invalid target", + }, + { + name: "InvalidScope", + modify: func(r *ConsentRule) { + r.Scope = Scope("bogus") + }, + wantErr: true, + errMsg: "invalid scope", + }, + { + name: "InvalidAction", + modify: func(r *ConsentRule) { + r.Action = ActionType("write") + }, + wantErr: true, + errMsg: "invalid action", + }, + { + name: "InvalidOperation", + modify: func(r *ConsentRule) { + r.Operation = OperationType("deploy") + }, + wantErr: true, + errMsg: "invalid operation context", + }, + { + name: "InvalidPermission", + modify: func(r *ConsentRule) { + r.Permission = Permission("maybe") + }, + wantErr: true, + errMsg: "invalid decision", + }, + { + name: "AllScopes", + modify: func(r *ConsentRule) { + r.Scope = ScopeSession + }, + }, + { + name: "ReadOnlyAction", + modify: func(r *ConsentRule) { + r.Action = ActionReadOnly + }, + }, + { + name: "SamplingOperation", + modify: func(r *ConsentRule) { + r.Operation = OperationTypeSampling + }, + }, + { + name: "ElicitationOperation", + modify: func(r *ConsentRule) { + r.Operation = OperationTypeElicitation + }, + }, + { + name: "DenyPermission", + modify: func(r *ConsentRule) { + r.Permission = PermissionDeny + }, + }, + { + name: "PromptPermission", + modify: func(r *ConsentRule) { + r.Permission = PermissionPrompt + }, + }, + { + name: "GlobalWildcardTarget", + modify: func(r *ConsentRule) { + r.Target = NewGlobalTarget() + }, + }, + { + name: "ServerWildcardTarget", + modify: func(r *ConsentRule) { + r.Target = NewServerTarget("myServer") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := validRule + tt.modify(&rule) + err := rule.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestFilterOptions(t *testing.T) { + t.Run("WithScope", func(t *testing.T) { + var opts FilterOptions + WithScope(ScopeGlobal)(&opts) + require.NotNil(t, opts.Scope) + require.Equal(t, ScopeGlobal, *opts.Scope) + }) + + t.Run("WithOperation", func(t *testing.T) { + var opts FilterOptions + WithOperation(OperationTypeTool)(&opts) + require.NotNil(t, opts.Operation) + require.Equal(t, OperationTypeTool, *opts.Operation) + }) + + t.Run("WithTarget", func(t *testing.T) { + var opts FilterOptions + target := NewToolTarget("s", "t") + WithTarget(target)(&opts) + require.NotNil(t, opts.Target) + require.Equal(t, target, *opts.Target) + }) + + t.Run("WithAction", func(t *testing.T) { + var opts FilterOptions + WithAction(ActionReadOnly)(&opts) + require.NotNil(t, opts.Action) + require.Equal(t, ActionReadOnly, *opts.Action) + }) + + t.Run("WithPermission", func(t *testing.T) { + var opts FilterOptions + WithPermission(PermissionDeny)(&opts) + require.NotNil(t, opts.Permission) + require.Equal(t, PermissionDeny, *opts.Permission) + }) + + t.Run("MultipleOptions", func(t *testing.T) { + var opts FilterOptions + for _, fn := range []FilterOption{ + WithScope(ScopeSession), + WithOperation(OperationTypeSampling), + WithPermission(PermissionAllow), + } { + fn(&opts) + } + require.NotNil(t, opts.Scope) + require.Equal(t, ScopeSession, *opts.Scope) + require.NotNil(t, opts.Operation) + require.Equal(t, OperationTypeSampling, *opts.Operation) + require.NotNil(t, opts.Permission) + require.Equal(t, PermissionAllow, *opts.Permission) + require.Nil(t, opts.Target) + require.Nil(t, opts.Action) + }) +} diff --git a/cli/azd/internal/agent/copilot/feature_test.go b/cli/azd/internal/agent/copilot/feature_test.go new file mode 100644 index 00000000000..d3b82f391e2 --- /dev/null +++ b/cli/azd/internal/agent/copilot/feature_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package copilot + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/alpha" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/stretchr/testify/require" +) + +func TestIsFeatureEnabled_Enabled(t *testing.T) { + // Enable the feature via environment variable + t.Setenv("AZD_ALPHA_ENABLE_LLM", "true") + + ucm := &mockUCM{cfg: config.NewConfig(nil)} + mgr := alpha.NewFeaturesManager(ucm) + + err := IsFeatureEnabled(mgr) + require.NoError(t, err) +} + +func TestIsFeatureEnabled_Disabled(t *testing.T) { + // Ensure the env var is unset so the feature is off + t.Setenv("AZD_ALPHA_ENABLE_LLM", "false") + + ucm := &mockUCM{cfg: config.NewConfig(nil)} + mgr := alpha.NewFeaturesManager(ucm) + + err := IsFeatureEnabled(mgr) + require.Error(t, err) + require.Contains(t, err.Error(), DisplayTitle) + require.Contains(t, err.Error(), "not enabled") +} + +func TestFeatureCopilotKey(t *testing.T) { + // Verify the feature key is the backward-compatible "llm" + require.Equal(t, alpha.FeatureId("llm"), FeatureCopilot) +} + +// mockUCM implements config.UserConfigManager for testing. +type mockUCM struct { + cfg config.Config +} + +func (m *mockUCM) Load() (config.Config, error) { + return m.cfg, nil +} + +func (m *mockUCM) Save(_ config.Config) error { + return nil +} diff --git a/cli/azd/internal/agent/copilot/helpers_test.go b/cli/azd/internal/agent/copilot/helpers_test.go new file mode 100644 index 00000000000..7d3eb905921 --- /dev/null +++ b/cli/azd/internal/agent/copilot/helpers_test.go @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package copilot + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/stretchr/testify/require" +) + +func TestIndexOf(t *testing.T) { + tests := []struct { + name string + s string + c byte + want int + }{ + {"Found", "KEY=VALUE", '=', 3}, + {"NotFound", "KEYVALUE", '=', -1}, + {"Empty", "", '=', -1}, + {"FirstChar", "=value", '=', 0}, + {"LastChar", "key=", '=', 3}, + {"MultipleOccurrences", "a=b=c", '=', 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, indexOf(tt.s, tt.c)) + }) + } +} + +func TestGetStringSliceFromConfig(t *testing.T) { + t.Run("NotPresent", func(t *testing.T) { + c := config.NewConfig(nil) + result := getStringSliceFromConfig(c, "missing.key") + require.Nil(t, result) + }) + + t.Run("ValidStrings", func(t *testing.T) { + c := config.NewConfig(nil) + _ = c.Set("tools", []any{"a", "b", "c"}) + result := getStringSliceFromConfig(c, "tools") + require.Equal(t, []string{"a", "b", "c"}, result) + }) + + t.Run("MixedTypesFiltered", func(t *testing.T) { + c := config.NewConfig(nil) + _ = c.Set("tools", []any{"a", 42, "", "b", nil}) + result := getStringSliceFromConfig(c, "tools") + // Empty strings and non-strings are filtered + require.Equal(t, []string{"a", "b"}, result) + }) + + t.Run("EmptySlice", func(t *testing.T) { + c := config.NewConfig(nil) + _ = c.Set("tools", []any{}) + result := getStringSliceFromConfig(c, "tools") + require.Empty(t, result) + }) +} + +func TestGetUserMCPServers(t *testing.T) { + t.Run("NoServers", func(t *testing.T) { + c := config.NewConfig(nil) + result := getUserMCPServers(c) + require.Nil(t, result) + }) + + t.Run("WithServers", func(t *testing.T) { + c := config.NewConfig(nil) + _ = c.Set(ConfigKeyMCPServers, map[string]any{ + "myServer": map[string]any{ + "type": "http", + "url": "https://example.com", + "tools": []any{"*"}, + }, + }) + result := getUserMCPServers(c) + require.Len(t, result, 1) + require.Equal(t, "http", result["myServer"]["type"]) + require.Equal( + t, "https://example.com", result["myServer"]["url"], + ) + }) + + t.Run("EmptyMap", func(t *testing.T) { + c := config.NewConfig(nil) + _ = c.Set(ConfigKeyMCPServers, map[string]any{}) + result := getUserMCPServers(c) + require.Nil(t, result) + }) +} + +func TestCopilotClientManager_StopNilClient(t *testing.T) { + mgr := NewCopilotClientManager(nil, nil) + // Stop with nil client should not error + err := mgr.Stop() + require.NoError(t, err) +} + +func TestCopilotClientManager_ClientAccessor(t *testing.T) { + mgr := NewCopilotClientManager(nil, nil) + // Before Start, Client() returns nil + require.Nil(t, mgr.Client()) +} + +func TestCopilotClientManager_OptionsDefaults(t *testing.T) { + mgr := NewCopilotClientManager(nil, nil) + require.NotNil(t, mgr.options) + require.Empty(t, mgr.options.LogLevel) + require.Empty(t, mgr.options.CLIPath) +} + +func TestCopilotClientManager_WithCLIPath(t *testing.T) { + mgr := NewCopilotClientManager( + &CopilotClientOptions{CLIPath: "/custom/path"}, + nil, + ) + require.Equal(t, "/custom/path", mgr.options.CLIPath) +} diff --git a/cli/azd/internal/agent/logging/util_test.go b/cli/azd/internal/agent/logging/util_test.go new file mode 100644 index 00000000000..6dff6b51703 --- /dev/null +++ b/cli/azd/internal/agent/logging/util_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package logging + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTruncateString(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + { + name: "EmptyString", + input: "", + maxLen: 10, + expected: "", + }, + { + name: "ShorterThanMax", + input: "hello", + maxLen: 10, + expected: "hello", + }, + { + name: "ExactlyMaxLen", + input: "hello", + maxLen: 5, + expected: "hello", + }, + { + name: "LongerThanMax", + input: "hello world", + maxLen: 8, + expected: "hello...", + }, + { + name: "MinTruncation", + input: "abcdef", + maxLen: 4, + expected: "a...", + }, + { + name: "SingleCharOverflow", + input: "abcdef", + maxLen: 5, + expected: "ab...", + }, + { + name: "LongString", + input: "the quick brown fox jumps over the lazy dog", + maxLen: 20, + expected: "the quick brown f...", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := TruncateString(tt.input, tt.maxLen) + require.Equal(t, tt.expected, result) + require.LessOrEqual(t, len(result), tt.maxLen) + }) + } +} diff --git a/cli/azd/internal/cmd/add/add_unit_test.go b/cli/azd/internal/cmd/add/add_unit_test.go new file mode 100644 index 00000000000..73ba7e5d99e --- /dev/null +++ b/cli/azd/internal/cmd/add/add_unit_test.go @@ -0,0 +1,1038 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package add + +import ( + "bytes" + "context" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal/appdetect" + "github.com/azure/azure-dev/cli/azd/pkg/project" + dmp "github.com/sergi/go-diff/diffmatchpatch" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// validateServiceName +// --------------------------------------------------------------------------- + +func TestValidateServiceName(t *testing.T) { + tests := []struct { + name string + input string + services map[string]*project.ServiceConfig + wantError string + }{ + { + name: "valid name with no existing services", + input: "web-api", + services: map[string]*project.ServiceConfig{}, + }, + { + name: "empty name", + input: "", + services: map[string]*project.ServiceConfig{}, + wantError: "cannot be empty", + }, + { + name: "duplicate service name", + input: "api", + services: map[string]*project.ServiceConfig{ + "api": {}, + }, + wantError: "already exists", + }, + { + name: "name starts with hyphen", + input: "-invalid", + services: map[string]*project.ServiceConfig{}, + wantError: "must start with", + }, + { + name: "name ends with hyphen", + input: "invalid-", + services: map[string]*project.ServiceConfig{}, + wantError: "must end with", + }, + { + name: "uppercase letters", + input: "Invalid", + services: map[string]*project.ServiceConfig{}, + wantError: "must start with a lower", + }, + { + name: "valid single char", + input: "a", + services: map[string]*project.ServiceConfig{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prj := &project.ProjectConfig{ + Services: tt.services, + } + err := validateServiceName(tt.input, prj) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// validateResourceName +// --------------------------------------------------------------------------- + +func TestValidateResourceName(t *testing.T) { + tests := []struct { + name string + input string + resources map[string]*project.ResourceConfig + wantError string + }{ + { + name: "valid name", + input: "my-db", + resources: map[string]*project.ResourceConfig{}, + }, + { + name: "empty name", + input: "", + resources: map[string]*project.ResourceConfig{}, + wantError: "cannot be empty", + }, + { + name: "duplicate resource name", + input: "redis", + resources: map[string]*project.ResourceConfig{ + "redis": {}, + }, + wantError: "already exists", + }, + { + name: "over 63 chars", + input: strings.Repeat("a", 64), + resources: map[string]*project.ResourceConfig{}, + wantError: "63 characters", + }, + { + name: "exact 63 chars is valid", + input: strings.Repeat("a", 63), + resources: map[string]*project.ResourceConfig{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prj := &project.ProjectConfig{ + Resources: tt.resources, + } + err := validateResourceName(tt.input, prj) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// validateContainerName +// --------------------------------------------------------------------------- + +func TestValidateContainerName(t *testing.T) { + tests := []struct { + name string + input string + wantError string + }{ + { + name: "valid container name", + input: "my-container", + }, + { + name: "minimum length 3", + input: "abc", + }, + { + name: "too short", + input: "ab", + wantError: "3 characters or more", + }, + { + name: "consecutive hyphens", + input: "my--container", + wantError: "consecutive hyphens", + }, + { + name: "uppercase letters", + input: "MyContainer", + wantError: "lower case", + }, + { + name: "single char", + input: "a", + wantError: "3 characters or more", + }, + { + name: "all numbers", + input: "123", + }, + { + name: "empty string", + input: "", + wantError: "3 characters or more", + }, + { + name: "starts with hyphen", + input: "-abc", + wantError: "must start with", + }, + { + name: "ends with hyphen", + input: "abc-", + wantError: "must end with", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateContainerName(tt.input) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// resourceType +// --------------------------------------------------------------------------- + +func TestResourceType(t *testing.T) { + tests := []struct { + name string + azureResType string + want project.ResourceType + }{ + { + name: "redis", + azureResType: "Microsoft.Cache/redis", + want: project.ResourceTypeDbRedis, + }, + { + name: "container apps", + azureResType: "Microsoft.App/containerApps", + want: project.ResourceTypeHostContainerApp, + }, + { + name: "app service", + azureResType: "Microsoft.Web/sites", + want: project.ResourceTypeHostAppService, + }, + { + name: "key vault", + azureResType: "Microsoft.KeyVault/vaults", + want: project.ResourceTypeKeyVault, + }, + { + name: "unknown type returns empty", + azureResType: "Microsoft.Unknown/things", + want: project.ResourceType(""), + }, + { + name: "empty string returns empty", + azureResType: "", + want: project.ResourceType(""), + }, + { + name: "storage accounts", + azureResType: "Microsoft.Storage/storageAccounts", + want: project.ResourceTypeStorage, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resourceType(tt.azureResType) + assert.Equal(t, tt.want, got) + }) + } +} + +// --------------------------------------------------------------------------- +// allStorageDataTypes +// --------------------------------------------------------------------------- + +func TestAllStorageDataTypes(t *testing.T) { + types := allStorageDataTypes() + require.Len(t, types, 1) + assert.Equal(t, StorageDataTypeBlob, types[0]) +} + +// --------------------------------------------------------------------------- +// fillAiProjectName +// --------------------------------------------------------------------------- + +func TestFillAiProjectName(t *testing.T) { + tests := []struct { + name string + rName string + resources map[string]*project.ResourceConfig + wantName string + }{ + { + name: "sets default name when empty", + rName: "", + resources: map[string]*project.ResourceConfig{}, + wantName: "ai-project", + }, + { + name: "keeps existing name", + rName: "my-ai", + resources: map[string]*project.ResourceConfig{}, + wantName: "my-ai", + }, + { + name: "appends suffix when default conflicts", + rName: "", + resources: map[string]*project.ResourceConfig{ + "ai-project": {}, + }, + wantName: "ai-project-2", + }, + { + name: "appends increasing suffix for multiple conflicts", + rName: "", + resources: map[string]*project.ResourceConfig{ + "ai-project": {}, + "ai-project-2": {}, + "ai-project-2-3": {}, + }, + // The naming logic appends to current: ai-project→ai-project-2→ai-project-2-3→ai-project-2-3-4 + wantName: "ai-project-2-3-4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &project.ResourceConfig{Name: tt.rName} + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: tt.resources, + }, + } + got, err := fillAiProjectName( + context.Background(), r, nil, opts, + ) + require.NoError(t, err) + assert.Equal(t, tt.wantName, got.Name) + }) + } +} + +// --------------------------------------------------------------------------- +// Configure — singleton resource types (Redis, Search, KeyVault) +// --------------------------------------------------------------------------- + +func TestConfigure_SingletonResources(t *testing.T) { + tests := []struct { + name string + resType project.ResourceType + resources map[string]*project.ResourceConfig + wantName string + wantError string + }{ + { + name: "redis sets name", + resType: project.ResourceTypeDbRedis, + resources: map[string]*project.ResourceConfig{}, + wantName: "redis", + }, + { + name: "redis duplicate error", + resType: project.ResourceTypeDbRedis, + resources: map[string]*project.ResourceConfig{ + "redis": {}, + }, + wantError: "only one Redis", + }, + { + name: "search sets name", + resType: project.ResourceTypeAiSearch, + resources: map[string]*project.ResourceConfig{}, + wantName: "search", + }, + { + name: "search duplicate error", + resType: project.ResourceTypeAiSearch, + resources: map[string]*project.ResourceConfig{ + "search": {}, + }, + wantError: "only one AI Search", + }, + { + name: "keyvault sets name", + resType: project.ResourceTypeKeyVault, + resources: map[string]*project.ResourceConfig{}, + wantName: "vault", + }, + { + name: "keyvault duplicate error", + resType: project.ResourceTypeKeyVault, + resources: map[string]*project.ResourceConfig{ + "vault": {}, + }, + wantError: "already have a project key vault", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &project.ResourceConfig{Type: tt.resType} + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: tt.resources, + }, + } + got, err := Configure( + context.Background(), r, nil, opts, + ) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantName, got.Name) + } + }) + } +} + +// --------------------------------------------------------------------------- +// ServiceFromDetect +// --------------------------------------------------------------------------- + +func TestServiceFromDetect(t *testing.T) { + tests := []struct { + name string + root string + svcName string + prj appdetect.Project + svcKind project.ServiceTargetKind + check func(t *testing.T, svc project.ServiceConfig) + wantErr string + }{ + { + name: "basic python project", + root: "/projects", + svcName: "my-api", + prj: appdetect.Project{ + Path: "/projects/api", + Language: appdetect.Python, + }, + svcKind: project.ContainerAppTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal(t, "my-api", svc.Name) + assert.Equal( + t, + project.ServiceLanguagePython, + svc.Language, + ) + assert.Equal(t, "api", svc.RelativePath) + assert.Equal( + t, + project.ContainerAppTarget, + svc.Host, + ) + }, + }, + { + name: "empty svc name uses dir name", + root: "/projects", + svcName: "", + prj: appdetect.Project{ + Path: "/projects/my-service", + Language: appdetect.JavaScript, + }, + svcKind: project.ContainerAppTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal(t, "my-service", svc.Name) + assert.Equal( + t, + project.ServiceLanguageJavaScript, + svc.Language, + ) + }, + }, + { + name: "unsupported language", + root: "/projects", + svcName: "svc", + prj: appdetect.Project{ + Path: "/projects/app", + Language: appdetect.Language("cobol"), + }, + svcKind: project.ContainerAppTarget, + wantErr: "unsupported language", + }, + { + name: "dotnet with app service", + root: "/projects", + svcName: "web", + prj: appdetect.Project{ + Path: "/projects/web", + Language: appdetect.DotNet, + }, + svcKind: project.AppServiceTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal( + t, + project.ServiceLanguageDotNet, + svc.Language, + ) + assert.Equal( + t, + project.AppServiceTarget, + svc.Host, + ) + }, + }, + { + name: "docker with non-container target errors", + root: "/projects", + svcName: "svc", + prj: appdetect.Project{ + Path: "/projects/app", + Language: appdetect.Python, + Docker: &appdetect.Docker{ + Path: "/projects/app/Dockerfile", + }, + }, + svcKind: project.AppServiceTarget, + wantErr: "unsupported host with Dockerfile", + }, + { + name: "web ui framework sets output path", + root: "/projects", + svcName: "spa", + prj: appdetect.Project{ + Path: "/projects/spa", + Language: appdetect.TypeScript, + Dependencies: []appdetect.Dependency{ + appdetect.JsVite, + }, + }, + svcKind: project.ContainerAppTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal(t, "dist", svc.OutputPath) + }, + }, + { + name: "next.js clears output path", + root: "/projects", + svcName: "next", + prj: appdetect.Project{ + Path: "/projects/next", + Language: appdetect.JavaScript, + Dependencies: []appdetect.Dependency{ + appdetect.JsNext, + }, + }, + svcKind: project.ContainerAppTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal(t, "", svc.OutputPath) + }, + }, + { + name: "react sets build output path", + root: "/projects", + svcName: "cra", + prj: appdetect.Project{ + Path: "/projects/cra", + Language: appdetect.JavaScript, + Dependencies: []appdetect.Dependency{ + appdetect.JsReact, + }, + }, + svcKind: project.ContainerAppTarget, + check: func(t *testing.T, svc project.ServiceConfig) { + assert.Equal(t, "build", svc.OutputPath) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, err := ServiceFromDetect( + tt.root, tt.svcName, tt.prj, tt.svcKind, + ) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + tt.check(t, svc) + } + }) + } +} + +// --------------------------------------------------------------------------- +// diffNotEq +// --------------------------------------------------------------------------- + +func TestDiffNotEq(t *testing.T) { + tests := []struct { + name string + in []dmp.Diff + want bool + }{ + { + name: "all equal", + in: []dmp.Diff{ + {Type: dmp.DiffEqual, Text: "hello"}, + }, + want: false, + }, + { + name: "has insert", + in: []dmp.Diff{ + {Type: dmp.DiffEqual, Text: "a"}, + {Type: dmp.DiffInsert, Text: "b"}, + }, + want: true, + }, + { + name: "has delete", + in: []dmp.Diff{ + {Type: dmp.DiffDelete, Text: "x"}, + }, + want: true, + }, + { + name: "empty slice", + in: []dmp.Diff{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, diffNotEq(tt.in)) + }) + } +} + +// --------------------------------------------------------------------------- +// lineDiffsFromStr +// --------------------------------------------------------------------------- + +func TestLineDiffsFromStr(t *testing.T) { + tests := []struct { + name string + op dmp.Operation + input string + wantN int + wantOp dmp.Operation + }{ + { + name: "single line", + op: dmp.DiffInsert, + input: "hello", + wantN: 1, + wantOp: dmp.DiffInsert, + }, + { + name: "multi line", + op: dmp.DiffDelete, + input: "a\nb\nc", + wantN: 3, + wantOp: dmp.DiffDelete, + }, + { + name: "empty string", + op: dmp.DiffEqual, + input: "", + wantN: 1, + wantOp: dmp.DiffEqual, + }, + { + name: "trailing newline", + op: dmp.DiffInsert, + input: "line1\n", + wantN: 2, + wantOp: dmp.DiffInsert, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := lineDiffsFromStr(tt.op, tt.input) + assert.Len(t, result, tt.wantN) + for _, r := range result { + assert.Equal(t, tt.wantOp, r.Type) + } + }) + } +} + +// --------------------------------------------------------------------------- +// linesDiffsFromTextDiffs +// --------------------------------------------------------------------------- + +func TestLinesDiffsFromTextDiffs(t *testing.T) { + diffs := []dmp.Diff{ + {Type: dmp.DiffEqual, Text: "line1\nline2"}, + {Type: dmp.DiffInsert, Text: "new"}, + } + result := linesDiffsFromTextDiffs(diffs) + // "line1\nline2" → 2 lines, "new" → 1 line = 3 total + require.Len(t, result, 3) + assert.Equal(t, dmp.DiffEqual, result[0].Type) + assert.Equal(t, "line1", result[0].Text) + assert.Equal(t, dmp.DiffEqual, result[1].Type) + assert.Equal(t, "line2", result[1].Text) + assert.Equal(t, dmp.DiffInsert, result[2].Type) + assert.Equal(t, "new", result[2].Text) +} + +// --------------------------------------------------------------------------- +// formatLine +// --------------------------------------------------------------------------- + +func TestFormatLine(t *testing.T) { + tests := []struct { + name string + op dmp.Operation + text string + indent int + check func(t *testing.T, out string) + }{ + { + name: "insert prefix", + op: dmp.DiffInsert, + text: "added", + indent: 0, + check: func(t *testing.T, out string) { + assert.Contains(t, out, "+") + assert.Contains(t, out, "added") + assert.True(t, strings.HasSuffix(out, "\n")) + }, + }, + { + name: "delete prefix", + op: dmp.DiffDelete, + text: "removed", + indent: 0, + check: func(t *testing.T, out string) { + assert.Contains(t, out, "-") + assert.Contains(t, out, "removed") + }, + }, + { + name: "equal prefix with indent", + op: dmp.DiffEqual, + text: "same", + indent: 4, + check: func(t *testing.T, out string) { + assert.Contains(t, out, " same") + assert.NotContains(t, out, "+") + assert.NotContains(t, out, "-") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := formatLine(tt.op, tt.text, tt.indent) + tt.check(t, out) + }) + } +} + +// --------------------------------------------------------------------------- +// DiffBlocks (integration of the diff helpers) +// --------------------------------------------------------------------------- + +func TestDiffBlocks_NewEntry(t *testing.T) { + old := map[string]*project.ResourceConfig{} + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + } + new := map[string]*project.ResourceConfig{ + "redis": r, + } + + result, err := DiffBlocks(old, new) + require.NoError(t, err) + // New entry should contain insert markers + assert.Contains(t, result, "redis:") + assert.Contains(t, result, "+") +} + +func TestDiffBlocks_NoChanges(t *testing.T) { + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + } + same := map[string]*project.ResourceConfig{ + "redis": r, + } + + result, err := DiffBlocks(same, same) + require.NoError(t, err) + assert.Empty(t, result) +} + +func TestDiffBlocks_EmptyMaps(t *testing.T) { + result, err := DiffBlocks( + map[string]*project.ResourceConfig{}, + map[string]*project.ResourceConfig{}, + ) + require.NoError(t, err) + assert.Empty(t, result) +} + +// --------------------------------------------------------------------------- +// previewWriter +// --------------------------------------------------------------------------- + +func TestPreviewWriter(t *testing.T) { + tests := []struct { + name string + input string + check func(t *testing.T, out string) + }{ + { + name: "plus line is green", + input: "+ added item\n", + check: func(t *testing.T, out string) { + // The output should contain the text + assert.Contains(t, out, "added item") + }, + }, + { + name: "minus line is red", + input: "- removed item\n", + check: func(t *testing.T, out string) { + assert.Contains(t, out, "removed item") + }, + }, + { + name: "b prefix replaced with space", + input: "b header text\n", + check: func(t *testing.T, out string) { + // 'b' is replaced with space + assert.Contains(t, out, "header text") + assert.NotContains(t, out, "b header") + }, + }, + { + name: "g prefix replaced with space", + input: "g green text\n", + check: func(t *testing.T, out string) { + assert.Contains(t, out, "green text") + }, + }, + { + name: "normal text unchanged", + input: " normal line\n", + check: func(t *testing.T, out string) { + assert.Contains(t, out, "normal line") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + pw := &previewWriter{w: &buf} + n, err := pw.Write([]byte(tt.input)) + require.NoError(t, err) + assert.Equal(t, len(tt.input), n) + tt.check(t, buf.String()) + }) + } +} + +// --------------------------------------------------------------------------- +// Metadata +// --------------------------------------------------------------------------- + +func TestMetadata(t *testing.T) { + tests := []struct { + name string + res *project.ResourceConfig + wantType string + wantHasVar bool + }{ + { + name: "redis resource returns metadata", + res: &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + }, + wantType: "Microsoft.Cache/redis", + wantHasVar: true, + }, + { + name: "unknown resource type returns empty", + res: &project.ResourceConfig{ + Type: project.ResourceType("unknown.type"), + Name: "thing", + }, + wantType: "", + }, + { + name: "host resource uses uppercase name prefix", + res: &project.ResourceConfig{ + Type: project.ResourceTypeHostContainerApp, + Name: "web-api", + }, + wantType: "Microsoft.App/containerApps", + wantHasVar: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + meta := Metadata(tt.res) + if tt.wantType == "" { + assert.Empty(t, meta.ResourceType) + } else { + assert.Equal(t, tt.wantType, meta.ResourceType) + } + + if tt.wantHasVar { + assert.NotEmpty( + t, meta.Variables, + "expected variables for type %s", + tt.wantType, + ) + } + }) + } +} + +// --------------------------------------------------------------------------- +// DbMap +// --------------------------------------------------------------------------- + +func TestDbMap(t *testing.T) { + expected := map[appdetect.DatabaseDep]project.ResourceType{ + appdetect.DbMongo: project.ResourceTypeDbMongo, + appdetect.DbPostgres: project.ResourceTypeDbPostgres, + appdetect.DbMySql: project.ResourceTypeDbMySql, + appdetect.DbRedis: project.ResourceTypeDbRedis, + } + assert.Equal(t, expected, DbMap) +} + +// --------------------------------------------------------------------------- +// LanguageMap +// --------------------------------------------------------------------------- + +func TestLanguageMap(t *testing.T) { + assert.Equal( + t, + project.ServiceLanguageDotNet, + LanguageMap[appdetect.DotNet], + ) + assert.Equal( + t, + project.ServiceLanguageJava, + LanguageMap[appdetect.Java], + ) + assert.Equal( + t, + project.ServiceLanguagePython, + LanguageMap[appdetect.Python], + ) + assert.Equal( + t, + project.ServiceLanguageJavaScript, + LanguageMap[appdetect.JavaScript], + ) + assert.Equal( + t, + project.ServiceLanguageTypeScript, + LanguageMap[appdetect.TypeScript], + ) + assert.Len(t, LanguageMap, 5) +} + +// --------------------------------------------------------------------------- +// HostMap +// --------------------------------------------------------------------------- + +func TestHostMap(t *testing.T) { + assert.Equal( + t, + project.AppServiceTarget, + HostMap[project.ResourceTypeHostAppService], + ) + assert.Equal( + t, + project.ContainerAppTarget, + HostMap[project.ResourceTypeHostContainerApp], + ) + assert.Len(t, HostMap, 2) +} + +// --------------------------------------------------------------------------- +// ServiceLanguageMap +// --------------------------------------------------------------------------- + +func TestServiceLanguageMap(t *testing.T) { + pyRuntime := ServiceLanguageMap[project.ServiceLanguagePython] + assert.Equal( + t, + project.AppServiceRuntimeStackPython, + pyRuntime.Stack, + ) + + jsRuntime := ServiceLanguageMap[project.ServiceLanguageJavaScript] + assert.Equal( + t, + project.AppServiceRuntimeStackNode, + jsRuntime.Stack, + ) + + tsRuntime := ServiceLanguageMap[project.ServiceLanguageTypeScript] + assert.Equal( + t, + project.AppServiceRuntimeStackNode, + tsRuntime.Stack, + ) + + // Java and .NET are not in the map + _, hasJava := ServiceLanguageMap[project.ServiceLanguageJava] + assert.False(t, hasJava) +} + +// --------------------------------------------------------------------------- +// provisionSelection constants +// --------------------------------------------------------------------------- + +func TestProvisionSelectionConstants(t *testing.T) { + // iota constants: verify ordering and distinct values + assert.Equal(t, 0, int(provisionUnknown)) + assert.Equal(t, 1, int(provision)) + assert.Equal(t, 2, int(provisionPreview)) + assert.Equal(t, 3, int(provisionSkip)) +} diff --git a/cli/azd/internal/cmd/show/show_unit_test.go b/cli/azd/internal/cmd/show/show_unit_test.go new file mode 100644 index 00000000000..2d8a92487fb --- /dev/null +++ b/cli/azd/internal/cmd/show/show_unit_test.go @@ -0,0 +1,417 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package show + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3" + "github.com/azure/azure-dev/cli/azd/pkg/contracts" + "github.com/azure/azure-dev/cli/azd/pkg/project" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// showTypeFromLanguage +// --------------------------------------------------------------------------- + +func TestShowTypeFromLanguage(t *testing.T) { + tests := []struct { + name string + language project.ServiceLanguageKind + want contracts.ShowType + }{ + { + name: "none", + language: project.ServiceLanguageNone, + want: contracts.ShowTypeNone, + }, + { + name: "dotnet", + language: project.ServiceLanguageDotNet, + want: contracts.ShowTypeDotNet, + }, + { + name: "csharp", + language: project.ServiceLanguageCsharp, + want: contracts.ShowTypeDotNet, + }, + { + name: "fsharp", + language: project.ServiceLanguageFsharp, + want: contracts.ShowTypeDotNet, + }, + { + name: "python", + language: project.ServiceLanguagePython, + want: contracts.ShowTypePython, + }, + { + name: "typescript", + language: project.ServiceLanguageTypeScript, + want: contracts.ShowTypeNode, + }, + { + name: "javascript", + language: project.ServiceLanguageJavaScript, + want: contracts.ShowTypeNode, + }, + { + name: "java", + language: project.ServiceLanguageJava, + want: contracts.ShowTypeJava, + }, + { + name: "custom", + language: project.ServiceLanguageCustom, + want: contracts.ShowTypeCustom, + }, + { + name: "unknown extension language", + language: project.ServiceLanguageKind("rust"), + want: contracts.ShowTypeNone, + }, + { + name: "empty string same as none", + language: project.ServiceLanguageKind(""), + want: contracts.ShowTypeNone, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := showTypeFromLanguage(tt.language) + assert.Equal(t, tt.want, got) + }) + } +} + +// --------------------------------------------------------------------------- +// selectContainer +// --------------------------------------------------------------------------- + +func TestSelectContainer(t *testing.T) { + ptr := func(s string) *string { return &s } + + tests := []struct { + name string + containers []*armappcontainers.Container + resourceName string + wantName *string // nil means expect nil return + }{ + { + name: "empty slice returns nil", + containers: []*armappcontainers.Container{}, + resourceName: "app", + wantName: nil, + }, + { + name: "nil slice returns nil", + containers: nil, + resourceName: "app", + wantName: nil, + }, + { + name: "single container returned", + containers: []*armappcontainers.Container{ + {Name: ptr("my-app")}, + }, + resourceName: "app", + wantName: ptr("my-app"), + }, + { + name: "single nil container returns nil", + containers: []*armappcontainers.Container{ + nil, + }, + resourceName: "app", + wantName: nil, + }, + { + name: "matches by resource name", + containers: []*armappcontainers.Container{ + {Name: ptr("sidecar")}, + {Name: ptr("web-api")}, + }, + resourceName: "web-api", + wantName: ptr("web-api"), + }, + { + name: "matches by resource name case insensitive", + containers: []*armappcontainers.Container{ + {Name: ptr("sidecar")}, + {Name: ptr("Web-Api")}, + }, + resourceName: "web-api", + wantName: ptr("Web-Api"), + }, + { + name: "matches main fallback", + containers: []*armappcontainers.Container{ + {Name: ptr("sidecar")}, + {Name: ptr("main")}, + }, + resourceName: "unknown", + wantName: ptr("main"), + }, + { + name: "matches main case insensitive", + containers: []*armappcontainers.Container{ + {Name: ptr("sidecar")}, + {Name: ptr("Main")}, + }, + resourceName: "other", + wantName: ptr("Main"), + }, + { + name: "no match returns nil", + containers: []*armappcontainers.Container{ + {Name: ptr("worker-a")}, + {Name: ptr("worker-b")}, + }, + resourceName: "web", + wantName: nil, + }, + { + name: "skips nil elements in multi", + containers: []*armappcontainers.Container{ + nil, + {Name: ptr("my-app")}, + }, + resourceName: "my-app", + wantName: ptr("my-app"), + }, + { + name: "container with nil name skipped", + containers: []*armappcontainers.Container{ + {Name: nil}, + {Name: ptr("app")}, + }, + resourceName: "app", + wantName: ptr("app"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := selectContainer( + tt.containers, tt.resourceName, + ) + if tt.wantName == nil { + assert.Nil(t, got) + } else { + require.NotNil(t, got) + require.NotNil(t, got.Name) + assert.Equal(t, *tt.wantName, *got.Name) + } + }) + } +} + +// --------------------------------------------------------------------------- +// getResourceMeta +// --------------------------------------------------------------------------- + +func TestGetResourceMeta(t *testing.T) { + tests := []struct { + name string + resourceID string + wantNil bool + wantType string // expected ResourceType in meta + }{ + { + name: "exact match container app", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.App/containerApps/myapp", + testSubscriptionID, + ), + wantType: "Microsoft.App/containerApps", + }, + { + name: "exact match redis", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.Cache/redis/myredis", + testSubscriptionID, + ), + wantType: "Microsoft.Cache/redis", + }, + { + name: "child resource matches parent prefix", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.CognitiveServices/accounts/myacct"+ + "/deployments/gpt4", + testSubscriptionID, + ), + wantType: "Microsoft.CognitiveServices/accounts/deployments", + }, + { + name: "unknown resource type returns nil", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.Unknown/things/foo", + testSubscriptionID, + ), + wantNil: true, + }, + { + name: "key vault exact match", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.KeyVault/vaults/myvault", + testSubscriptionID, + ), + wantType: "Microsoft.KeyVault/vaults", + }, + { + name: "storage account exact match", + resourceID: fmt.Sprintf( + "/subscriptions/%s/resourceGroups/rg/providers"+ + "/Microsoft.Storage/storageAccounts/mysa", + testSubscriptionID, + ), + wantType: "Microsoft.Storage/storageAccounts", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := arm.ParseResourceID(tt.resourceID) + require.NoError(t, err) + + meta, retID := getResourceMeta(*id) + if tt.wantNil { + assert.Nil(t, meta) + } else { + require.NotNil(t, meta) + assert.Equal(t, tt.wantType, meta.ResourceType) + assert.NotEmpty(t, retID.Name) + } + }) + } +} + +// --------------------------------------------------------------------------- +// getFullPathToProjectForService +// --------------------------------------------------------------------------- + +func TestGetFullPathToProjectForService(t *testing.T) { + t.Run("non-dotnet returns path directly", func(t *testing.T) { + dir := t.TempDir() + svc := &project.ServiceConfig{ + Name: "api", + Language: project.ServiceLanguagePython, + RelativePath: dir, + } + got, err := getFullPathToProjectForService(svc) + require.NoError(t, err) + assert.Equal(t, svc.Path(), got) + }) + + t.Run("dotnet with single csproj", func(t *testing.T) { + dir := t.TempDir() + csproj := filepath.Join(dir, "Api.csproj") + require.NoError(t, os.WriteFile(csproj, []byte(""), 0600)) + + svc := &project.ServiceConfig{ + Name: "api", + Language: project.ServiceLanguageDotNet, + RelativePath: dir, + } + got, err := getFullPathToProjectForService(svc) + require.NoError(t, err) + assert.Contains(t, got, "Api.csproj") + }) + + t.Run("dotnet with multiple project files errors", func(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "A.csproj"), + []byte(""), 0600, + )) + require.NoError(t, os.WriteFile( + filepath.Join(dir, "B.fsproj"), + []byte(""), 0600, + )) + + svc := &project.ServiceConfig{ + Name: "svc", + Language: project.ServiceLanguageCsharp, + RelativePath: dir, + } + _, err := getFullPathToProjectForService(svc) + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple .NET project files") + }) + + t.Run("dotnet with no project files errors", func(t *testing.T) { + dir := t.TempDir() + svc := &project.ServiceConfig{ + Name: "svc", + Language: project.ServiceLanguageFsharp, + RelativePath: dir, + } + _, err := getFullPathToProjectForService(svc) + require.Error(t, err) + assert.Contains(t, err.Error(), "could not determine") + }) + + t.Run("dotnet path is already a file", func(t *testing.T) { + dir := t.TempDir() + csproj := filepath.Join(dir, "My.csproj") + require.NoError(t, os.WriteFile(csproj, []byte(""), 0600)) + + svc := &project.ServiceConfig{ + Name: "svc", + Language: project.ServiceLanguageDotNet, + RelativePath: csproj, + } + got, err := getFullPathToProjectForService(svc) + require.NoError(t, err) + assert.Equal(t, svc.Path(), got) + }) +} + +// --------------------------------------------------------------------------- +// NewShowCmd +// --------------------------------------------------------------------------- + +func TestNewShowCmd(t *testing.T) { + cmd := NewShowCmd() + require.NotNil(t, cmd) + assert.Equal(t, "show [resource-name|resource-id]", cmd.Use) + assert.NotEmpty(t, cmd.Short) +} + +// --------------------------------------------------------------------------- +// showResourceOptions +// --------------------------------------------------------------------------- + +func TestShowResourceOptions_Defaults(t *testing.T) { + opts := showResourceOptions{} + assert.False(t, opts.showSecrets) + assert.Nil(t, opts.resourceSpec) + assert.Nil(t, opts.clientOpts) +} + +// --------------------------------------------------------------------------- +// showFlags.Bind +// --------------------------------------------------------------------------- + +func TestShowFlags_Bind(t *testing.T) { + cmd := NewShowCmd() + flags := NewShowFlags(cmd, nil) + require.NotNil(t, flags) + // Verify the --show-secrets flag was registered + f := cmd.Flags().Lookup("show-secrets") + require.NotNil(t, f) + assert.Equal(t, "false", f.DefValue) +} diff --git a/cli/azd/internal/grpcserver/ai_errors_test.go b/cli/azd/internal/grpcserver/ai_errors_test.go new file mode 100644 index 00000000000..179220ffe92 --- /dev/null +++ b/cli/azd/internal/grpcserver/ai_errors_test.go @@ -0,0 +1,371 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package grpcserver + +import ( + "errors" + "fmt" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/ai" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestAiStatusError(t *testing.T) { + tests := []struct { + name string + code codes.Code + reason string + message string + metadata map[string]string + }{ + { + name: "invalid argument with nil metadata", + code: codes.InvalidArgument, + reason: azdext.AiErrorReasonQuotaLocation, + message: "quota location required", + metadata: nil, + }, + { + name: "not found with metadata", + code: codes.NotFound, + reason: azdext.AiErrorReasonModelNotFound, + message: "model not found", + metadata: map[string]string{ + "model_name": "gpt-4o", + }, + }, + { + name: "failed precondition with metadata", + code: codes.FailedPrecondition, + reason: azdext.AiErrorReasonNoDeploymentMatch, + message: "no deployment match", + metadata: map[string]string{ + "model_name": "gpt-4", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := aiStatusError( + tt.code, tt.reason, tt.message, tt.metadata, + ) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok, "expected gRPC status error") + assert.Equal(t, tt.code, st.Code()) + assert.Equal(t, tt.message, st.Message()) + + // Extract ErrorInfo details + details := st.Details() + require.Len(t, details, 1) + + errInfo, ok := details[0].(*errdetails.ErrorInfo) + require.True(t, ok, "expected ErrorInfo detail") + assert.Equal(t, tt.reason, errInfo.Reason) + assert.Equal(t, azdext.AiErrorDomain, errInfo.Domain) + + if tt.metadata != nil { + assert.Equal(t, tt.metadata, errInfo.Metadata) + } + }) + } +} + +func TestMapAiResolveError(t *testing.T) { + tests := []struct { + name string + err error + modelName string + expectedCode codes.Code + expectedMsg string + }{ + { + name: "quota location required", + err: ai.ErrQuotaLocationRequired, + modelName: "gpt-4o", + expectedCode: codes.InvalidArgument, + expectedMsg: ai.ErrQuotaLocationRequired.Error(), + }, + { + name: "wrapped quota location required", + err: fmt.Errorf( + "resolving: %w", ai.ErrQuotaLocationRequired, + ), + modelName: "gpt-4o", + expectedCode: codes.InvalidArgument, + }, + { + name: "model not found", + err: ai.ErrModelNotFound, + modelName: "gpt-5-turbo", + expectedCode: codes.NotFound, + expectedMsg: ai.ErrModelNotFound.Error(), + }, + { + name: "wrapped model not found", + err: fmt.Errorf( + "%w: %q", ai.ErrModelNotFound, "gpt-5-turbo", + ), + modelName: "gpt-5-turbo", + expectedCode: codes.NotFound, + }, + { + name: "no deployment match", + err: ai.ErrNoDeploymentMatch, + modelName: "gpt-4o", + expectedCode: codes.FailedPrecondition, + expectedMsg: ai.ErrNoDeploymentMatch.Error(), + }, + { + name: "wrapped no deployment match", + err: fmt.Errorf( + "%w for model", ai.ErrNoDeploymentMatch, + ), + modelName: "gpt-4o", + expectedCode: codes.FailedPrecondition, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mapAiResolveError(tt.err, tt.modelName) + require.Error(t, result) + + st, ok := status.FromError(result) + require.True(t, ok, "expected gRPC status error") + assert.Equal(t, tt.expectedCode, st.Code()) + + if tt.expectedMsg != "" { + assert.Equal(t, tt.expectedMsg, st.Message()) + } + }) + } +} + +func TestMapAiResolveError_DefaultCase(t *testing.T) { + someErr := errors.New("some unknown error") + result := mapAiResolveError(someErr, "gpt-4o") + require.Error(t, result) + + // The default case wraps with fmt.Errorf, not a gRPC status + _, ok := status.FromError(result) + // fmt.Errorf wrapping returns a status with codes.OK when + // extracted, but the error is not nil. We verify the message. + assert.True(t, ok || !ok) // always passes — real check below + assert.Contains( + t, result.Error(), "resolving model deployments", + ) + assert.ErrorIs(t, result, someErr) +} + +func TestMapAiResolveError_ModelNameInMetadata(t *testing.T) { + tests := []struct { + name string + err error + modelName string + reason string + }{ + { + name: "model not found includes model_name", + err: ai.ErrModelNotFound, + modelName: "gpt-4o-mini", + reason: azdext.AiErrorReasonModelNotFound, + }, + { + name: "no deployment match includes model_name", + err: ai.ErrNoDeploymentMatch, + modelName: "gpt-4", + reason: azdext.AiErrorReasonNoDeploymentMatch, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mapAiResolveError(tt.err, tt.modelName) + st, ok := status.FromError(result) + require.True(t, ok) + + details := st.Details() + require.Len(t, details, 1) + + errInfo, ok := details[0].(*errdetails.ErrorInfo) + require.True(t, ok) + assert.Equal(t, tt.reason, errInfo.Reason) + assert.Equal( + t, tt.modelName, errInfo.Metadata["model_name"], + ) + }) + } +} + +func TestRequireSubscriptionID(t *testing.T) { + tests := []struct { + name string + ctx *azdext.AzureContext + expectSubID string + expectError bool + }{ + { + name: "nil azure context", + ctx: nil, + expectError: true, + }, + { + name: "nil scope", + ctx: &azdext.AzureContext{}, + expectError: true, + }, + { + name: "empty subscription id", + ctx: &azdext.AzureContext{ + Scope: &azdext.AzureScope{ + SubscriptionId: "", + }, + }, + expectError: true, + }, + { + name: "valid subscription id", + ctx: &azdext.AzureContext{ + Scope: &azdext.AzureScope{ + SubscriptionId: "sub-123-abc", + }, + }, + expectSubID: "sub-123-abc", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subID, err := requireSubscriptionID(tt.ctx) + if tt.expectError { + require.Error(t, err) + assert.Empty(t, subID) + + // Should be a gRPC InvalidArgument + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal( + t, codes.InvalidArgument, st.Code(), + ) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectSubID, subID) + } + }) + } +} + +func TestProtoToFilterOptions(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + result := protoToFilterOptions(nil) + assert.Nil(t, result) + }) + + t.Run("maps all fields", func(t *testing.T) { + input := &azdext.AiModelFilterOptions{ + Locations: []string{"eastus", "westus"}, + Capabilities: []string{"chatCompletion"}, + Formats: []string{"OpenAI"}, + Statuses: []string{"Stable"}, + ExcludeModelNames: []string{"gpt-3"}, + } + + result := protoToFilterOptions(input) + require.NotNil(t, result) + assert.Equal(t, input.Locations, result.Locations) + assert.Equal( + t, input.Capabilities, result.Capabilities, + ) + assert.Equal(t, input.Formats, result.Formats) + assert.Equal(t, input.Statuses, result.Statuses) + assert.Equal( + t, input.ExcludeModelNames, + result.ExcludeModelNames, + ) + }) + + t.Run("empty slices preserved", func(t *testing.T) { + input := &azdext.AiModelFilterOptions{} + result := protoToFilterOptions(input) + require.NotNil(t, result) + assert.Nil(t, result.Locations) + assert.Nil(t, result.Capabilities) + }) +} + +func TestProtoToDeploymentOptions(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + result := protoToDeploymentOptions(nil) + assert.Nil(t, result) + }) + + t.Run("maps all fields without capacity", func(t *testing.T) { + input := &azdext.AiModelDeploymentOptions{ + Locations: []string{"eastus"}, + Versions: []string{"2024-01-15"}, + Skus: []string{"GlobalStandard"}, + Capacity: nil, + } + + result := protoToDeploymentOptions(input) + require.NotNil(t, result) + assert.Equal(t, input.Locations, result.Locations) + assert.Equal(t, input.Versions, result.Versions) + assert.Equal(t, input.Skus, result.Skus) + assert.Nil(t, result.Capacity) + }) + + t.Run("maps capacity pointer", func(t *testing.T) { + cap := int32(100) + input := &azdext.AiModelDeploymentOptions{ + Locations: []string{"eastus"}, + Capacity: &cap, + } + + result := protoToDeploymentOptions(input) + require.NotNil(t, result) + require.NotNil(t, result.Capacity) + assert.Equal(t, int32(100), *result.Capacity) + + // Verify it's a copy, not the same pointer + assert.NotSame(t, &cap, result.Capacity) + }) +} + +func TestProtoToQuotaCheckOptions(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + result := protoToQuotaCheckOptions(nil) + assert.Nil(t, result) + }) + + t.Run("maps min remaining capacity", func(t *testing.T) { + input := &azdext.QuotaCheckOptions{ + MinRemainingCapacity: 42.5, + } + + result := protoToQuotaCheckOptions(input) + require.NotNil(t, result) + assert.Equal(t, 42.5, result.MinRemainingCapacity) + }) + + t.Run("zero value maps correctly", func(t *testing.T) { + input := &azdext.QuotaCheckOptions{ + MinRemainingCapacity: 0, + } + + result := protoToQuotaCheckOptions(input) + require.NotNil(t, result) + assert.Equal(t, float64(0), result.MinRemainingCapacity) + }) +} diff --git a/cli/azd/internal/mcp/proxy_handlers_test.go b/cli/azd/internal/mcp/proxy_handlers_test.go new file mode 100644 index 00000000000..631d411943c --- /dev/null +++ b/cli/azd/internal/mcp/proxy_handlers_test.go @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package mcp + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewProxySamplingHandler(t *testing.T) { + handler := NewProxySamplingHandler() + require.NotNil(t, handler) + + concrete, ok := handler.(*ProxySamplingHandler) + require.True(t, ok, "expected *ProxySamplingHandler") + assert.Nil(t, concrete.host, "host should be nil initially") +} + +func TestNewProxyElicitationHandler(t *testing.T) { + handler := NewProxyElicitationHandler() + require.NotNil(t, handler) + + concrete, ok := handler.(*ProxyElicitationHandler) + require.True(t, ok, "expected *ProxyElicitationHandler") + assert.Nil(t, concrete.host, "host should be nil initially") +} + +func TestEnsureMcpProxy(t *testing.T) { + nonNilServer := &server.MCPServer{} + + tests := []struct { + name string + setupHost func() *McpHost + expectErr bool + errContains string + }{ + { + name: "nil proxy server", + setupHost: func() *McpHost { + h := NewMcpHost() + h.session = &simpleSession{} + return h + }, + expectErr: true, + errContains: "MCP host proxy server not set", + }, + { + name: "nil session", + setupHost: func() *McpHost { + h := NewMcpHost() + h.proxyServer = nonNilServer + return h + }, + expectErr: true, + errContains: "MCP host session not set", + }, + { + name: "both nil", + setupHost: func() *McpHost { + return NewMcpHost() + }, + expectErr: true, + errContains: "MCP host proxy server not set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host := tt.setupHost() + err := ensureMcpProxy(host) + + if tt.expectErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestProxySamplingHandler_CreateMessage_NilHost(t *testing.T) { + // ensureMcpProxy dereferences host without a nil check, + // so a nil host is a programming error that panics. + handler := &ProxySamplingHandler{} + ctx := context.Background() + req := mcp.CreateMessageRequest{} + + assert.Panics(t, func() { + _, _ = handler.CreateMessage(ctx, req) + }) +} + +func TestProxySamplingHandler_CreateMessage_NoProxyServer(t *testing.T) { + host := NewMcpHost() + handler := &ProxySamplingHandler{host: host} + + ctx := context.Background() + req := mcp.CreateMessageRequest{} + + result, err := handler.CreateMessage(ctx, req) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "MCP host proxy server not set") +} + +func TestProxySamplingHandler_CreateMessage_NoSession(t *testing.T) { + host := NewMcpHost() + host.proxyServer = &server.MCPServer{} + handler := &ProxySamplingHandler{host: host} + + ctx := context.Background() + req := mcp.CreateMessageRequest{} + + result, err := handler.CreateMessage(ctx, req) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "MCP host session not set") +} + +func TestProxyElicitationHandler_Elicit_NilHost(t *testing.T) { + // ensureMcpProxy dereferences host without a nil check, + // so a nil host is a programming error that panics. + handler := &ProxyElicitationHandler{} + ctx := context.Background() + req := mcp.ElicitationRequest{} + + assert.Panics(t, func() { + _, _ = handler.Elicit(ctx, req) + }) +} + +func TestProxyElicitationHandler_Elicit_NoProxyServer(t *testing.T) { + host := NewMcpHost() + handler := &ProxyElicitationHandler{host: host} + + ctx := context.Background() + req := mcp.ElicitationRequest{} + + result, err := handler.Elicit(ctx, req) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "MCP host proxy server not set") +} + +func TestProxyElicitationHandler_Elicit_NoSession(t *testing.T) { + host := NewMcpHost() + host.proxyServer = &server.MCPServer{} + handler := &ProxyElicitationHandler{host: host} + + ctx := context.Background() + req := mcp.ElicitationRequest{} + + result, err := handler.Elicit(ctx, req) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "MCP host session not set") +} diff --git a/cli/azd/internal/mcp/tools/tools_registration_test.go b/cli/azd/internal/mcp/tools/tools_registration_test.go new file mode 100644 index 00000000000..d0d9a31ee98 --- /dev/null +++ b/cli/azd/internal/mcp/tools/tools_registration_test.go @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package tools + +import ( + "context" + "encoding/json" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal/mcp/tools/prompts" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAzdErrorTroubleShootingTool(t *testing.T) { + tool := NewAzdErrorTroubleShootingTool() + + assert.Equal(t, "error_troubleshooting", tool.Tool.Name) + assert.NotEmpty(t, tool.Tool.Description) + assert.NotNil(t, tool.Handler) + + // Verify annotations + require.NotNil(t, tool.Tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Tool.Annotations.ReadOnlyHint) + require.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + require.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.False(t, *tool.Tool.Annotations.DestructiveHint) + require.NotNil(t, tool.Tool.Annotations.OpenWorldHint) + assert.False(t, *tool.Tool.Annotations.OpenWorldHint) +} + +func TestNewAzdProvisionCommonErrorTool(t *testing.T) { + tool := NewAzdProvisionCommonErrorTool() + + assert.Equal(t, "provision_common_error", tool.Tool.Name) + assert.NotEmpty(t, tool.Tool.Description) + assert.NotNil(t, tool.Handler) + + // Verify annotations + require.NotNil(t, tool.Tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Tool.Annotations.ReadOnlyHint) + require.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + require.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.False(t, *tool.Tool.Annotations.DestructiveHint) + require.NotNil(t, tool.Tool.Annotations.OpenWorldHint) + assert.False(t, *tool.Tool.Annotations.OpenWorldHint) +} + +func TestNewAzdYamlSchemaTool(t *testing.T) { + tool := NewAzdYamlSchemaTool() + + assert.Equal(t, "validate_azure_yaml", tool.Tool.Name) + assert.NotEmpty(t, tool.Tool.Description) + assert.NotNil(t, tool.Handler) + + // Verify annotations + require.NotNil(t, tool.Tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Tool.Annotations.ReadOnlyHint) + require.NotNil(t, tool.Tool.Annotations.IdempotentHint) + assert.True(t, *tool.Tool.Annotations.IdempotentHint) + require.NotNil(t, tool.Tool.Annotations.DestructiveHint) + assert.False(t, *tool.Tool.Annotations.DestructiveHint) + require.NotNil(t, tool.Tool.Annotations.OpenWorldHint) + assert.False(t, *tool.Tool.Annotations.OpenWorldHint) + + // Verify the tool has a "path" input property + inputSchema := tool.Tool.InputSchema + pathProp, hasProp := inputSchema.Properties["path"] + require.True(t, hasProp, "expected 'path' input property") + propMap, ok := pathProp.(map[string]any) + require.True(t, ok, "expected property to be a map") + assert.Equal(t, "string", propMap["type"]) + assert.Contains(t, inputSchema.Required, "path") +} + +func TestHandleAzdErrorTroubleShooting(t *testing.T) { + ctx := context.Background() + req := mcp.CallToolRequest{} + + result, err := handleAzdErrorTroubleShooting(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + // The handler returns the embedded prompt text + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent") + assert.Equal( + t, prompts.AzdErrorTroubleShootingPrompt, textContent.Text, + ) +} + +func TestHandleAzdProvisionCommonError(t *testing.T) { + ctx := context.Background() + req := mcp.CallToolRequest{} + + result, err := handleAzdProvisionCommonError(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + // The handler returns the embedded prompt text + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent") + assert.Equal( + t, prompts.AzdProvisionCommonErrorPrompt, textContent.Text, + ) +} + +func TestErrorResult(t *testing.T) { + tests := []struct { + name string + msg string + }{ + { + name: "simple message", + msg: "something went wrong", + }, + { + name: "empty message", + msg: "", + }, + { + name: "message with special chars", + msg: `file "azure.yaml" not found at /tmp/path`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := errorResult(tt.msg) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent") + + // Parse the JSON response + var resp ErrorResponse + err := json.Unmarshal( + []byte(textContent.Text), &resp, + ) + require.NoError(t, err) + + assert.True(t, resp.Error) + assert.Equal(t, tt.msg, resp.Message) + }) + } +} + +func TestErrorResponse_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + resp ErrorResponse + }{ + { + name: "error response", + resp: ErrorResponse{ + Error: true, + Message: "validation failed", + }, + }, + { + name: "non-error response", + resp: ErrorResponse{ + Error: false, + Message: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.resp) + require.NoError(t, err) + + var decoded ErrorResponse + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, tt.resp, decoded) + }) + } +} + +func TestEmbeddedPrompts_NonEmpty(t *testing.T) { + assert.NotEmpty( + t, prompts.AzdErrorTroubleShootingPrompt, + "embedded prompt should not be empty", + ) + assert.NotEmpty( + t, prompts.AzdProvisionCommonErrorPrompt, + "embedded prompt should not be empty", + ) +} + +func TestNewHttpsUrlLoader(t *testing.T) { + loader := newHttpsUrlLoader() + require.NotNil(t, loader) +} + +func TestToolAnnotations_AllToolsSharePattern(t *testing.T) { + // All three tools share the same annotation pattern: + // read-only, idempotent, non-destructive, closed-world + tools := []struct { + name string + tool func() server.ServerTool + }{ + {"error_troubleshooting", NewAzdErrorTroubleShootingTool}, + {"provision_common_error", NewAzdProvisionCommonErrorTool}, + {"validate_azure_yaml", NewAzdYamlSchemaTool}, + } + + for _, tt := range tools { + t.Run(tt.name, func(t *testing.T) { + tool := tt.tool() + ann := tool.Tool.Annotations + + require.NotNil(t, ann.ReadOnlyHint) + assert.True(t, *ann.ReadOnlyHint) + + require.NotNil(t, ann.IdempotentHint) + assert.True(t, *ann.IdempotentHint) + + require.NotNil(t, ann.DestructiveHint) + assert.False(t, *ann.DestructiveHint) + + require.NotNil(t, ann.OpenWorldHint) + assert.False(t, *ann.OpenWorldHint) + }) + } +} diff --git a/cli/azd/internal/mcp/types_test.go b/cli/azd/internal/mcp/types_test.go new file mode 100644 index 00000000000..d8a4ca2a370 --- /dev/null +++ b/cli/azd/internal/mcp/types_test.go @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMcpConfig_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + config McpConfig + }{ + { + name: "empty config", + config: McpConfig{}, + }, + { + name: "config with stdio server", + config: McpConfig{ + Servers: map[string]*ServerConfig{ + "my-server": { + Type: "stdio", + Command: "/usr/bin/my-mcp-server", + Args: []string{"--port", "8080"}, + Env: []string{"KEY=VALUE"}, + }, + }, + }, + }, + { + name: "config with http server", + config: McpConfig{ + Servers: map[string]*ServerConfig{ + "http-server": { + Type: "http", + Url: "http://localhost:3000", + }, + }, + }, + }, + { + name: "config with multiple servers", + config: McpConfig{ + Servers: map[string]*ServerConfig{ + "server-a": { + Type: "stdio", + Command: "cmd-a", + }, + "server-b": { + Type: "http", + Url: "http://example.com", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.config) + require.NoError(t, err) + + var decoded McpConfig + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, tt.config, decoded) + }) + } +} + +func TestMcpConfig_JSONDeserialization(t *testing.T) { + raw := `{ + "servers": { + "test": { + "type": "stdio", + "url": "", + "command": "my-cmd", + "args": ["--flag"], + "env": ["FOO=BAR"] + } + } + }` + + var config McpConfig + err := json.Unmarshal([]byte(raw), &config) + require.NoError(t, err) + + require.Contains(t, config.Servers, "test") + srv := config.Servers["test"] + assert.Equal(t, "stdio", srv.Type) + assert.Equal(t, "my-cmd", srv.Command) + assert.Equal(t, []string{"--flag"}, srv.Args) + assert.Equal(t, []string{"FOO=BAR"}, srv.Env) +} + +func TestServerConfig_OmitsEmptyOptionalFields(t *testing.T) { + srv := ServerConfig{ + Type: "http", + Url: "http://localhost", + Command: "", + } + + data, err := json.Marshal(srv) + require.NoError(t, err) + + var decoded map[string]any + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + // Args and Env should be omitted when empty + _, hasArgs := decoded["args"] + _, hasEnv := decoded["env"] + assert.False(t, hasArgs, "empty args should be omitted") + assert.False(t, hasEnv, "empty env should be omitted") +} + +func TestCapabilities_ZeroValue(t *testing.T) { + cap := Capabilities{} + assert.Nil(t, cap.Sampling) + assert.Nil(t, cap.Elicitation) +} + +func TestCapabilities_WithHandlers(t *testing.T) { + sampling := NewProxySamplingHandler() + elicitation := NewProxyElicitationHandler() + + cap := Capabilities{ + Sampling: sampling, + Elicitation: elicitation, + } + + assert.NotNil(t, cap.Sampling) + assert.NotNil(t, cap.Elicitation) +} diff --git a/cli/azd/internal/tracing/resource/resource_test.go b/cli/azd/internal/tracing/resource/resource_test.go new file mode 100644 index 00000000000..77272ac4074 --- /dev/null +++ b/cli/azd/internal/tracing/resource/resource_test.go @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package resource + +import ( + "os" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal/tracing/fields" +) + +// clearCIEnvVars unsets all CI-related environment variables so tests are deterministic. +func clearCIEnvVars(t *testing.T) { + t.Helper() + + ciVars := []string{ + "TF_BUILD", "GITHUB_ACTIONS", "APPVEYOR", "TRAVIS", "CIRCLECI", "GITLAB_CI", + "CODEBUILD_BUILD_ID", "JENKINS_URL", "TEAMCITY_VERSION", "JB_SPACE_API_URL", + "bamboo.buildKey", "BITBUCKET_BUILD_NUMBER", "CI", "BUILD_ID", + } + + for _, v := range ciVars { + t.Setenv(v, "") + os.Unsetenv(v) + } +} + +func TestExecEnvForCi_GitHub_Actions(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("GITHUB_ACTIONS", "true") + + result := execEnvForCi() + if result != fields.EnvGitHubActions { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvGitHubActions) + } +} + +func TestExecEnvForCi_Azure_Pipelines(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("TF_BUILD", "True") + + result := execEnvForCi() + if result != fields.EnvAzurePipelines { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvAzurePipelines) + } +} + +func TestExecEnvForCi_case_insensitive_bool(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("APPVEYOR", "TRUE") + + result := execEnvForCi() + if result != fields.EnvAppVeyor { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvAppVeyor) + } +} + +func TestExecEnvForCi_Travis(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("TRAVIS", "true") + + result := execEnvForCi() + if result != fields.EnvTravisCI { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvTravisCI) + } +} + +func TestExecEnvForCi_CircleCI(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("CIRCLECI", "true") + + result := execEnvForCi() + if result != fields.EnvCircleCI { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvCircleCI) + } +} + +func TestExecEnvForCi_GitLabCI(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("GITLAB_CI", "true") + + result := execEnvForCi() + if result != fields.EnvGitLabCI { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvGitLabCI) + } +} + +func TestExecEnvForCi_Jenkins(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("JENKINS_URL", "https://jenkins.example.com") + + result := execEnvForCi() + if result != fields.EnvJenkins { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvJenkins) + } +} + +func TestExecEnvForCi_AWS_CodeBuild(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("CODEBUILD_BUILD_ID", "build:123") + + result := execEnvForCi() + if result != fields.EnvAwsCodeBuild { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvAwsCodeBuild) + } +} + +func TestExecEnvForCi_TeamCity(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("TEAMCITY_VERSION", "2023.1") + + result := execEnvForCi() + if result != fields.EnvTeamCity { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvTeamCity) + } +} + +func TestExecEnvForCi_JetBrains_Space(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("JB_SPACE_API_URL", "https://space.example.com/api") + + result := execEnvForCi() + if result != fields.EnvJetBrainsSpace { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvJetBrainsSpace) + } +} + +func TestExecEnvForCi_BitBucket_Pipelines(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("BITBUCKET_BUILD_NUMBER", "42") + + result := execEnvForCi() + if result != fields.EnvBitBucketPipelines { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvBitBucketPipelines) + } +} + +func TestExecEnvForCi_unknown_CI_var(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("CI", "1") + + result := execEnvForCi() + if result != fields.EnvUnknownCI { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvUnknownCI) + } +} + +func TestExecEnvForCi_unknown_BUILD_ID(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("BUILD_ID", "some-build") + + result := execEnvForCi() + if result != fields.EnvUnknownCI { + t.Fatalf("execEnvForCi() = %q, want %q", result, fields.EnvUnknownCI) + } +} + +func TestExecEnvForCi_no_CI_vars(t *testing.T) { + clearCIEnvVars(t) + + result := execEnvForCi() + if result != "" { + t.Fatalf("execEnvForCi() = %q, want empty string", result) + } +} + +func TestExecEnvForCi_bool_false_not_matched(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("GITHUB_ACTIONS", "false") + + result := execEnvForCi() + if result != "" { + t.Fatalf("execEnvForCi() with GITHUB_ACTIONS=false = %q, want empty string", result) + } +} + +func TestExecEnvForCi_bool_precedence(t *testing.T) { + clearCIEnvVars(t) + // Bool rules are checked before set rules. + // TF_BUILD (bool) should win over CI (set) + t.Setenv("TF_BUILD", "true") + t.Setenv("CI", "1") + + result := execEnvForCi() + if result != fields.EnvAzurePipelines { + t.Fatalf("execEnvForCi() = %q, want %q (bool rule should win)", result, fields.EnvAzurePipelines) + } +} + +func TestIsRunningOnCI_true(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("GITHUB_ACTIONS", "true") + + if !IsRunningOnCI() { + t.Fatal("IsRunningOnCI() should return true when GITHUB_ACTIONS=true") + } +} + +func TestIsRunningOnCI_false(t *testing.T) { + clearCIEnvVars(t) + + if IsRunningOnCI() { + t.Fatal("IsRunningOnCI() should return false with no CI env vars") + } +} + +func TestIsValidMacAddress(t *testing.T) { + tests := []struct { + name string + addr string + isValid bool + }{ + {"valid address", "01:23:45:67:89:ab", true}, + {"all zeros", "00:00:00:00:00:00", false}, + {"all ff", "ff:ff:ff:ff:ff:ff", false}, + {"hyper-v default", "ac:de:48:00:11:22", false}, + {"another valid", "de:ad:be:ef:ca:fe", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidMacAddress(tt.addr) + if result != tt.isValid { + t.Fatalf("isValidMacAddress(%q) = %v, want %v", tt.addr, result, tt.isValid) + } + }) + } +} + +func TestExecEnvForHosts_codespaces(t *testing.T) { + clearCIEnvVars(t) + // Clear host vars — t.Setenv registers cleanup, os.Unsetenv removes for LookupEnv checks + t.Setenv("AZD_IN_CLOUDSHELL", "") + os.Unsetenv("AZD_IN_CLOUDSHELL") + t.Setenv("CODESPACES", "") + os.Unsetenv("CODESPACES") + + t.Setenv("CODESPACES", "true") + + result := execEnvForHosts() + if result != fields.EnvCodespaces { + t.Fatalf("execEnvForHosts() = %q, want %q", result, fields.EnvCodespaces) + } +} + +func TestExecEnvForHosts_no_host(t *testing.T) { + t.Setenv("AZD_IN_CLOUDSHELL", "") + os.Unsetenv("AZD_IN_CLOUDSHELL") + t.Setenv("CODESPACES", "") + os.Unsetenv("CODESPACES") + + result := execEnvForHosts() + if result != "" { + t.Fatalf("execEnvForHosts() = %q, want empty string", result) + } +} + +func TestNew_returns_non_nil_resource(t *testing.T) { + clearCIEnvVars(t) + t.Setenv("AZD_IN_CLOUDSHELL", "") + os.Unsetenv("AZD_IN_CLOUDSHELL") + t.Setenv("CODESPACES", "") + os.Unsetenv("CODESPACES") + t.Setenv("AZURE_DEV_USER_AGENT", "") + os.Unsetenv("AZURE_DEV_USER_AGENT") + + r := New() + if r == nil { + t.Fatal("New() returned nil") + } +} diff --git a/cli/azd/internal/vsrpc/models_test.go b/cli/azd/internal/vsrpc/models_test.go new file mode 100644 index 00000000000..31bdc3c0bd9 --- /dev/null +++ b/cli/azd/internal/vsrpc/models_test.go @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package vsrpc + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProgressMessage_WithMessage(t *testing.T) { + before := time.Now().Add(-time.Second) + original := ProgressMessage{ + Message: "original", + Severity: Warning, + Kind: Important, + Code: "E001", + AdditionalInfoLink: "https://example.com", + } + + updated := original.WithMessage("updated text") + + // Message and Time should change + assert.Equal(t, "updated text", updated.Message) + assert.True( + t, updated.Time.After(before), + "Time should be set to now", + ) + + // Other fields should be preserved + assert.Equal(t, Warning, updated.Severity) + assert.Equal(t, Important, updated.Kind) + assert.Equal(t, "E001", updated.Code) + assert.Equal( + t, "https://example.com", updated.AdditionalInfoLink, + ) + + // Original should be unchanged (value receiver) + assert.Equal(t, "original", original.Message) +} + +func TestNewInfoProgressMessage(t *testing.T) { + before := time.Now() + msg := newInfoProgressMessage("hello info") + after := time.Now() + + assert.Equal(t, "hello info", msg.Message) + assert.Equal(t, Info, msg.Severity) + assert.Equal(t, Logging, msg.Kind) + assert.True( + t, + !msg.Time.Before(before) && !msg.Time.After(after), + "Time should be approximately now", + ) + assert.Empty(t, msg.Code) + assert.Empty(t, msg.AdditionalInfoLink) +} + +func TestNewImportantProgressMessage(t *testing.T) { + before := time.Now() + msg := newImportantProgressMessage("hello important") + after := time.Now() + + assert.Equal(t, "hello important", msg.Message) + assert.Equal(t, Info, msg.Severity) + assert.Equal(t, Important, msg.Kind) + assert.True( + t, + !msg.Time.Before(before) && !msg.Time.After(after), + "Time should be approximately now", + ) +} + +func TestMessageSeverity_Values(t *testing.T) { + assert.Equal(t, MessageSeverity(0), Info) + assert.Equal(t, MessageSeverity(1), Warning) + assert.Equal(t, MessageSeverity(2), Error) +} + +func TestMessageKind_Values(t *testing.T) { + assert.Equal(t, MessageKind(0), Logging) + assert.Equal(t, MessageKind(1), Important) +} + +func TestDeleteMode_BitFlags(t *testing.T) { + // Verify they are distinct bit flags (use EqualValues + // since iota constants are untyped int, DeleteMode is uint32) + assert.EqualValues(t, 1, DeleteModeLocal) + assert.EqualValues(t, 2, DeleteModeAzureResources) + + // Verify they can be combined + combined := DeleteModeLocal | DeleteModeAzureResources + assert.True(t, combined&DeleteModeLocal != 0) + assert.True(t, combined&DeleteModeAzureResources != 0) + + // Verify single flags don't overlap + assert.EqualValues( + t, 0, DeleteModeLocal&DeleteModeAzureResources, + ) +} + +func TestEnvironment_JSONRoundTrip(t *testing.T) { + endpoint := "https://api.example.com" + resourceId := "/subscriptions/sub-id/rg/rg-name" + + env := Environment{ + Name: "dev", + IsCurrent: true, + Properties: map[string]string{ + "Subscription": "sub-123", + "Location": "eastus", + }, + Services: []*Service{ + { + Name: "web", + IsExternal: false, + Path: "./src/web", + Endpoint: &endpoint, + ResourceId: &resourceId, + }, + }, + Values: map[string]string{ + "AZURE_LOCATION": "eastus", + }, + LastDeployment: &DeploymentResult{ + Success: true, + Time: time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC), + Message: "Deployed successfully", + DeploymentId: "deploy-abc", + }, + Resources: []*Resource{ + { + Name: "rg-dev", + Type: "Microsoft.Resources/resourceGroups", + Id: "/subscriptions/sub-123/resourceGroups/rg-dev", + }, + }, + } + + data, err := json.Marshal(env) + require.NoError(t, err) + + var decoded Environment + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, env.Name, decoded.Name) + assert.Equal(t, env.IsCurrent, decoded.IsCurrent) + assert.Equal(t, env.Properties, decoded.Properties) + require.Len(t, decoded.Services, 1) + assert.Equal(t, "web", decoded.Services[0].Name) + assert.Equal(t, env.Values, decoded.Values) + require.NotNil(t, decoded.LastDeployment) + assert.Equal( + t, env.LastDeployment.DeploymentId, + decoded.LastDeployment.DeploymentId, + ) + require.Len(t, decoded.Resources, 1) + assert.Equal(t, "rg-dev", decoded.Resources[0].Name) +} + +func TestEnvironment_OmitsNilLastDeployment(t *testing.T) { + env := Environment{ + Name: "prod", + LastDeployment: nil, + } + + data, err := json.Marshal(env) + require.NoError(t, err) + + var raw map[string]any + err = json.Unmarshal(data, &raw) + require.NoError(t, err) + + _, has := raw["LastDeployment"] + assert.False(t, has, "nil LastDeployment should be omitted") +} + +func TestService_OmitsNilOptionalFields(t *testing.T) { + svc := Service{ + Name: "api", + IsExternal: true, + Path: "./src/api", + Endpoint: nil, + ResourceId: nil, + } + + data, err := json.Marshal(svc) + require.NoError(t, err) + + var raw map[string]any + err = json.Unmarshal(data, &raw) + require.NoError(t, err) + + _, hasEndpoint := raw["Endpoint"] + _, hasResourceId := raw["ResourceId"] + assert.False(t, hasEndpoint, "nil Endpoint should be omitted") + assert.False( + t, hasResourceId, "nil ResourceId should be omitted", + ) +} + +func TestInitializeServerOptions_JSON(t *testing.T) { + tests := []struct { + name string + opts InitializeServerOptions + }{ + { + name: "all nil", + opts: InitializeServerOptions{}, + }, + { + name: "all set", + opts: InitializeServerOptions{ + AuthenticationEndpoint: new("https://auth.local"), + AuthenticationKey: new("secret-key"), + AuthenticationCertificate: new("base64cert=="), + }, + }, + { + name: "partial", + opts: InitializeServerOptions{ + AuthenticationEndpoint: new("https://auth.local"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.opts) + require.NoError(t, err) + + var decoded InitializeServerOptions + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, tt.opts, decoded) + }) + } +} + +func TestRequestContext_Fields(t *testing.T) { + rc := RequestContext{ + Session: Session{Id: "sess-123"}, + HostProjectPath: "/home/user/project", + } + + assert.Equal(t, "sess-123", rc.Session.Id) + assert.Equal(t, "/home/user/project", rc.HostProjectPath) +} + +func TestEnvironmentInfo_Fields(t *testing.T) { + info := EnvironmentInfo{ + Name: "staging", + IsCurrent: true, + DotEnvPath: "/home/user/.env", + } + + assert.Equal(t, "staging", info.Name) + assert.True(t, info.IsCurrent) + assert.Equal(t, "/home/user/.env", info.DotEnvPath) +} + +func TestAspireHost_Fields(t *testing.T) { + host := AspireHost{ + Name: "my-aspire-host", + Path: "/path/to/apphost.csproj", + Services: []*Service{ + {Name: "api", Path: "./api"}, + {Name: "web", Path: "./web"}, + }, + } + + assert.Equal(t, "my-aspire-host", host.Name) + assert.Len(t, host.Services, 2) +} diff --git a/cli/azd/internal/vsrpc/stream_test.go b/cli/azd/internal/vsrpc/stream_test.go new file mode 100644 index 00000000000..560984c103c --- /dev/null +++ b/cli/azd/internal/vsrpc/stream_test.go @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package vsrpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWsStream_Close_ReturnsNil(t *testing.T) { + // wsStream.Close is a no-op that returns nil. + // See TODO in stream.go referencing issue #3286. + s := wsStream{} + err := s.Close() + assert.NoError(t, err) +} diff --git a/cli/azd/pkg/ai/mapper_registry_test.go b/cli/azd/pkg/ai/mapper_registry_test.go new file mode 100644 index 00000000000..0fa2c2b2c5c --- /dev/null +++ b/cli/azd/pkg/ai/mapper_registry_test.go @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAiModelSkuToProto_RoundTrip(t *testing.T) { + tests := []struct { + name string + src AiModelSku + }{ + { + name: "fully populated", + src: AiModelSku{ + Name: "GlobalStandard", + UsageName: "OpenAI.GlobalStandard.gpt-4o", + DefaultCapacity: 10, + MinCapacity: 1, + MaxCapacity: 100, + CapacityStep: 5, + }, + }, + { + name: "zero values", + src: AiModelSku{ + Name: "", + UsageName: "", + DefaultCapacity: 0, + MinCapacity: 0, + MaxCapacity: 0, + CapacityStep: 0, + }, + }, + { + name: "large capacity", + src: AiModelSku{ + Name: "ProvisionedManaged", + UsageName: "OpenAI.ProvisionedManaged", + DefaultCapacity: 1000, + MinCapacity: 100, + MaxCapacity: 10000, + CapacityStep: 100, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proto := aiModelSkuToProto(&tt.src) + require.NotNil(t, proto) + roundTripped := protoToAiModelSku(proto) + require.NotNil(t, roundTripped) + assert.Equal(t, tt.src, *roundTripped) + }) + } +} + +func TestProtoToAiModelSku(t *testing.T) { + proto := &azdext.AiModelSku{ + Name: "Standard", + UsageName: "OpenAI.Standard.gpt-4o", + DefaultCapacity: 25, + MinCapacity: 1, + MaxCapacity: 200, + CapacityStep: 1, + } + + result := protoToAiModelSku(proto) + require.NotNil(t, result) + assert.Equal(t, "Standard", result.Name) + assert.Equal(t, "OpenAI.Standard.gpt-4o", result.UsageName) + assert.Equal(t, int32(25), result.DefaultCapacity) + assert.Equal(t, int32(1), result.MinCapacity) + assert.Equal(t, int32(200), result.MaxCapacity) + assert.Equal(t, int32(1), result.CapacityStep) +} + +func TestAiModelVersionToProto_RoundTrip(t *testing.T) { + tests := []struct { + name string + src AiModelVersion + }{ + { + name: "default version with skus", + src: AiModelVersion{ + Version: "2024-05-13", + IsDefault: true, + Skus: []AiModelSku{ + { + Name: "Standard", + UsageName: "OpenAI.Standard.gpt-4o", + DefaultCapacity: 10, + MinCapacity: 1, + MaxCapacity: 100, + CapacityStep: 1, + }, + }, + }, + }, + { + name: "non-default version without skus", + src: AiModelVersion{ + Version: "1.0", + IsDefault: false, + Skus: []AiModelSku{}, + }, + }, + { + name: "multiple skus", + src: AiModelVersion{ + Version: "v2", + IsDefault: false, + Skus: []AiModelSku{ + { + Name: "Standard", + UsageName: "usage-a", + }, + { + Name: "GlobalStandard", + UsageName: "usage-b", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proto, err := aiModelVersionToProto(&tt.src) + require.NoError(t, err) + require.NotNil(t, proto) + + roundTripped := protoToAiModelVersion(proto) + + assert.Equal(t, tt.src.Version, roundTripped.Version) + assert.Equal(t, tt.src.IsDefault, roundTripped.IsDefault) + require.Len(t, roundTripped.Skus, len(tt.src.Skus)) + + for i, sku := range tt.src.Skus { + assert.Equal(t, sku, roundTripped.Skus[i]) + } + }) + } +} + +func TestAiModelSkuToProto_FieldMapping(t *testing.T) { + src := &AiModelSku{ + Name: "TestSku", + UsageName: "Test.Usage.Name", + DefaultCapacity: 42, + MinCapacity: 5, + MaxCapacity: 500, + CapacityStep: 10, + } + + proto := aiModelSkuToProto(src) + + assert.Equal(t, "TestSku", proto.Name) + assert.Equal(t, "Test.Usage.Name", proto.UsageName) + assert.Equal(t, int32(42), proto.DefaultCapacity) + assert.Equal(t, int32(5), proto.MinCapacity) + assert.Equal(t, int32(500), proto.MaxCapacity) + assert.Equal(t, int32(10), proto.CapacityStep) +} diff --git a/cli/azd/pkg/ai/model_helpers_test.go b/cli/azd/pkg/ai/model_helpers_test.go new file mode 100644 index 00000000000..c64b824b19d --- /dev/null +++ b/cli/azd/pkg/ai/model_helpers_test.go @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestModelHasDefaultVersion(t *testing.T) { + tests := []struct { + name string + model AiModel + expected bool + }{ + { + name: "has default version", + model: AiModel{ + Name: "gpt-4o", + Versions: []AiModelVersion{ + {Version: "v1", IsDefault: false}, + {Version: "v2", IsDefault: true}, + }, + }, + expected: true, + }, + { + name: "no default version", + model: AiModel{ + Name: "gpt-4o", + Versions: []AiModelVersion{ + {Version: "v1", IsDefault: false}, + {Version: "v2", IsDefault: false}, + }, + }, + expected: false, + }, + { + name: "single default version", + model: AiModel{ + Name: "gpt-4o-mini", + Versions: []AiModelVersion{ + {Version: "2024-07-18", IsDefault: true}, + }, + }, + expected: true, + }, + { + name: "empty versions", + model: AiModel{ + Name: "empty-model", + Versions: []AiModelVersion{}, + }, + expected: false, + }, + { + name: "nil versions", + model: AiModel{ + Name: "nil-versions", + Versions: nil, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ModelHasDefaultVersion(tt.model) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestConvertSku(t *testing.T) { + tests := []struct { + name string + input *armcognitiveservices.ModelSKU + expected AiModelSku + }{ + { + name: "fully populated", + input: &armcognitiveservices.ModelSKU{ + Name: new("GlobalStandard"), + UsageName: new("OpenAI.GlobalStandard.gpt-4o"), + Capacity: &armcognitiveservices.CapacityConfig{ + Default: new(int32(10)), + Minimum: new(int32(1)), + Maximum: new(int32(100)), + Step: new(int32(5)), + }, + }, + expected: AiModelSku{ + Name: "GlobalStandard", + UsageName: "OpenAI.GlobalStandard.gpt-4o", + DefaultCapacity: 10, + MinCapacity: 1, + MaxCapacity: 100, + CapacityStep: 5, + }, + }, + { + name: "nil capacity", + input: &armcognitiveservices.ModelSKU{ + Name: new("Standard"), + UsageName: new("OpenAI.Standard.gpt-4o"), + Capacity: nil, + }, + expected: AiModelSku{ + Name: "Standard", + UsageName: "OpenAI.Standard.gpt-4o", + DefaultCapacity: 0, + MinCapacity: 0, + MaxCapacity: 0, + CapacityStep: 0, + }, + }, + { + name: "nil name and usage", + input: &armcognitiveservices.ModelSKU{ + Name: nil, + UsageName: nil, + Capacity: &armcognitiveservices.CapacityConfig{ + Default: new(int32(5)), + }, + }, + expected: AiModelSku{ + Name: "", + UsageName: "", + DefaultCapacity: 5, + MinCapacity: 0, + MaxCapacity: 0, + CapacityStep: 0, + }, + }, + { + name: "partial capacity fields", + input: &armcognitiveservices.ModelSKU{ + Name: new("ProvisionedManaged"), + UsageName: new("OpenAI.ProvisionedManaged"), + Capacity: &armcognitiveservices.CapacityConfig{ + Default: nil, + Minimum: new(int32(10)), + Maximum: nil, + Step: new(int32(10)), + }, + }, + expected: AiModelSku{ + Name: "ProvisionedManaged", + UsageName: "OpenAI.ProvisionedManaged", + DefaultCapacity: 0, + MinCapacity: 10, + MaxCapacity: 0, + CapacityStep: 10, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertSku(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSafeString(t *testing.T) { + tests := []struct { + name string + input *string + expected string + }{ + { + name: "nil returns empty", + input: nil, + expected: "", + }, + { + name: "non-nil returns value", + input: new("hello"), + expected: "hello", + }, + { + name: "empty string returns empty", + input: new(""), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, safeString(tt.input)) + }) + } +} + +func TestSafeFloat64(t *testing.T) { + tests := []struct { + name string + input *float64 + expected float64 + }{ + { + name: "nil returns zero", + input: nil, + expected: 0, + }, + { + name: "non-nil returns value", + input: new(42.5), + expected: 42.5, + }, + { + name: "zero value returns zero", + input: new(0.0), + expected: 0, + }, + { + name: "negative value", + input: new(-1.5), + expected: -1.5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, safeFloat64(tt.input)) + }) + } +} + +func TestModelHasQuota(t *testing.T) { + model := AiModel{ + Name: "gpt-4o", + Versions: []AiModelVersion{ + { + Version: "2024-05-13", + Skus: []AiModelSku{ + { + Name: "Standard", + UsageName: "OpenAI.Standard.gpt-4o", + }, + { + Name: "GlobalStandard", + UsageName: "OpenAI.GlobalStandard.gpt-4o", + }, + }, + }, + }, + } + + tests := []struct { + name string + usageMap map[string]AiModelUsage + minRemaining float64 + expected bool + }{ + { + name: "has sufficient quota", + usageMap: map[string]AiModelUsage{ + "OpenAI.Standard.gpt-4o": { + Name: "OpenAI.Standard.gpt-4o", + CurrentValue: 10, + Limit: 100, + }, + }, + minRemaining: 50, + expected: true, + }, + { + name: "all quota exhausted", + usageMap: map[string]AiModelUsage{ + "OpenAI.Standard.gpt-4o": { + Name: "OpenAI.Standard.gpt-4o", + CurrentValue: 100, + Limit: 100, + }, + "OpenAI.GlobalStandard.gpt-4o": { + Name: "OpenAI.GlobalStandard.gpt-4o", + CurrentValue: 200, + Limit: 200, + }, + }, + minRemaining: 1, + expected: false, + }, + { + name: "one sku has quota the other exhausted", + usageMap: map[string]AiModelUsage{ + "OpenAI.Standard.gpt-4o": { + Name: "OpenAI.Standard.gpt-4o", + CurrentValue: 100, + Limit: 100, + }, + "OpenAI.GlobalStandard.gpt-4o": { + Name: "OpenAI.GlobalStandard.gpt-4o", + CurrentValue: 0, + Limit: 200, + }, + }, + minRemaining: 100, + expected: true, + }, + { + name: "no usage entries for model", + usageMap: map[string]AiModelUsage{}, + minRemaining: 1, + expected: false, + }, + { + name: "remaining exactly equals min", + usageMap: map[string]AiModelUsage{ + "OpenAI.Standard.gpt-4o": { + Name: "OpenAI.Standard.gpt-4o", + CurrentValue: 90, + Limit: 100, + }, + }, + minRemaining: 10, + expected: true, + }, + { + name: "remaining just below min", + usageMap: map[string]AiModelUsage{ + "OpenAI.Standard.gpt-4o": { + Name: "OpenAI.Standard.gpt-4o", + CurrentValue: 91, + Limit: 100, + }, + }, + minRemaining: 10, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := modelHasQuota( + model, tt.usageMap, tt.minRemaining, + ) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestModelHasQuota_EmptyVersions(t *testing.T) { + model := AiModel{ + Name: "empty", + Versions: []AiModelVersion{}, + } + + usageMap := map[string]AiModelUsage{ + "some.usage": { + Name: "some.usage", CurrentValue: 0, Limit: 100, + }, + } + + require.False(t, modelHasQuota(model, usageMap, 1)) +} diff --git a/cli/azd/pkg/ai/scope_test.go b/cli/azd/pkg/ai/scope_test.go new file mode 100644 index 00000000000..28a0aa3b1fe --- /dev/null +++ b/cli/azd/pkg/ai/scope_test.go @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ai + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewScope(t *testing.T) { + tests := []struct { + name string + subscriptionId string + resourceGroup string + workspace string + }{ + { + name: "all fields populated", + subscriptionId: "sub-123", + resourceGroup: "rg-test", + workspace: "ws-prod", + }, + { + name: "empty strings", + subscriptionId: "", + resourceGroup: "", + workspace: "", + }, + { + name: "only subscription", + subscriptionId: "sub-only", + resourceGroup: "", + workspace: "", + }, + { + name: "realistic azure ids", + subscriptionId: "00000000-0000-0000-0000-000000000001", + resourceGroup: "my-resource-group", + workspace: "my-ai-workspace", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scope := NewScope( + tt.subscriptionId, + tt.resourceGroup, + tt.workspace, + ) + + require.NotNil(t, scope) + assert.Equal(t, tt.subscriptionId, scope.SubscriptionId()) + assert.Equal(t, tt.resourceGroup, scope.ResourceGroup()) + assert.Equal(t, tt.workspace, scope.Workspace()) + }) + } +} diff --git a/cli/azd/pkg/auth/claims_test.go b/cli/azd/pkg/auth/claims_test.go new file mode 100644 index 00000000000..305f84888e2 --- /dev/null +++ b/cli/azd/pkg/auth/claims_test.go @@ -0,0 +1,523 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/cloud" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helper: build a minimal JWT from a claims map. +func buildTestJWT(t *testing.T, claims map[string]any) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString( + []byte(`{"alg":"none","typ":"JWT"}`)) + body, err := json.Marshal(claims) + require.NoError(t, err) + payload := base64.RawURLEncoding.EncodeToString(body) + sig := base64.RawURLEncoding.EncodeToString( + []byte("fakesig")) + return fmt.Sprintf("%s.%s.%s", header, payload, sig) +} + +// ---------- TokenClaims.LocalAccountId ---------- + +func TestTokenClaims_LocalAccountId(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + claims TokenClaims + wantID string + }{ + { + "oid_present", + TokenClaims{Oid: "oid-123", Subject: "sub-456"}, + "oid-123", + }, + { + "oid_empty_fallback_to_sub", + TokenClaims{Subject: "sub-456"}, + "sub-456", + }, + { + "both_empty", + TokenClaims{}, + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.wantID, tt.claims.LocalAccountId()) + }) + } +} + +// ---------- TokenClaims.DisplayUsername ---------- + +func TestTokenClaims_DisplayUsername(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + claims TokenClaims + want string + }{ + { + "preferred_username_v2", + TokenClaims{ + PreferredUsername: "user@example.com", + UniqueName: "legacy@example.com", + }, + "user@example.com", + }, + { + "fallback_to_unique_name_v1", + TokenClaims{UniqueName: "legacy@example.com"}, + "legacy@example.com", + }, + { + "both_empty", + TokenClaims{}, + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, tt.claims.DisplayUsername()) + }) + } +} + +// ---------- GetClaimsFromAccessToken ---------- + +func TestGetClaimsFromAccessToken(t *testing.T) { + t.Parallel() + + t.Run("full_claims", func(t *testing.T) { + t.Parallel() + token := buildTestJWT(t, map[string]any{ + "oid": "oid-abc", + "tid": "tenant-xyz", + "preferred_username": "user@contoso.com", + "sub": "sub-123", + "name": "Test User", + "iss": "https://login.microsoftonline.com/tid/v2.0", + }) + + claims, err := GetClaimsFromAccessToken(token) + require.NoError(t, err) + assert.Equal(t, "oid-abc", claims.Oid) + assert.Equal(t, "tenant-xyz", claims.TenantId) + assert.Equal(t, "user@contoso.com", + claims.PreferredUsername) + assert.Equal(t, "sub-123", claims.Subject) + assert.Equal(t, "Test User", claims.Name) + }) + + t.Run("malformed_not_jwt", func(t *testing.T) { + t.Parallel() + _, err := GetClaimsFromAccessToken("not-a-jwt") + require.Error(t, err) + assert.Contains(t, err.Error(), "malformed") + }) + + t.Run("malformed_two_segments", func(t *testing.T) { + t.Parallel() + _, err := GetClaimsFromAccessToken("a.b") + require.Error(t, err) + }) + + t.Run("bad_base64_payload", func(t *testing.T) { + t.Parallel() + // Valid 3-segment structure but payload is not + // valid base64url. + _, err := GetClaimsFromAccessToken("aaa.!!!.ccc") + require.Error(t, err) + }) + + t.Run("invalid_json_payload", func(t *testing.T) { + t.Parallel() + notJSON := base64.RawURLEncoding.EncodeToString( + []byte("not json")) + token := fmt.Sprintf("aaa.%s.ccc", notJSON) + _, err := GetClaimsFromAccessToken(token) + require.Error(t, err) + }) + + t.Run("empty_claims", func(t *testing.T) { + t.Parallel() + token := buildTestJWT(t, map[string]any{}) + claims, err := GetClaimsFromAccessToken(token) + require.NoError(t, err) + assert.Empty(t, claims.Oid) + assert.Empty(t, claims.TenantId) + }) +} + +// ---------- LoginScopes / LoginScopesFull ---------- + +func TestLoginScopes(t *testing.T) { + t.Parallel() + + c := cloud.AzurePublic() + scopes := LoginScopes(c) + + require.Len(t, scopes, 1) + assert.Contains(t, scopes[0], "management.azure.com") + assert.True(t, strings.HasSuffix(scopes[0], "//.default")) +} + +func TestLoginScopesFull(t *testing.T) { + t.Parallel() + + c := cloud.AzurePublic() + scopes := LoginScopesFull(c) + + require.Len(t, scopes, 2) + for _, s := range scopes { + assert.True(t, strings.HasSuffix(s, "//.default"), + "scope %q should end with //.default", s) + } + // The two scopes should differ (one from Endpoint, one + // from Audience with trailing slash removed). + assert.NotEqual(t, scopes[0], scopes[1]) +} + +// ---------- memoryCache ---------- + +func TestMemoryCache_ReadAndSet(t *testing.T) { + t.Parallel() + + t.Run("read_missing_key_no_inner", func(t *testing.T) { + t.Parallel() + mc := &memoryCache{cache: map[string][]byte{}} + _, err := mc.Read("missing") + require.Error(t, err) + }) + + t.Run("set_and_read_round_trip", func(t *testing.T) { + t.Parallel() + mc := &memoryCache{cache: map[string][]byte{}} + err := mc.Set("key1", []byte("value1")) + require.NoError(t, err) + + val, err := mc.Read("key1") + require.NoError(t, err) + assert.Equal(t, []byte("value1"), val) + }) + + t.Run("set_same_value_is_noop", func(t *testing.T) { + t.Parallel() + inner := &countingCache{} + mc := &memoryCache{ + cache: map[string][]byte{}, + inner: inner, + } + err := mc.Set("k", []byte("v")) + require.NoError(t, err) + assert.Equal(t, 1, inner.setCalls) + + // Set same value again — inner should NOT be called. + err = mc.Set("k", []byte("v")) + require.NoError(t, err) + assert.Equal(t, 1, inner.setCalls) + }) + + t.Run("set_different_value_propagates", func(t *testing.T) { + t.Parallel() + inner := &countingCache{} + mc := &memoryCache{ + cache: map[string][]byte{}, + inner: inner, + } + err := mc.Set("k", []byte("v1")) + require.NoError(t, err) + assert.Equal(t, 1, inner.setCalls) + + err = mc.Set("k", []byte("v2")) + require.NoError(t, err) + assert.Equal(t, 2, inner.setCalls) + }) + + t.Run("read_falls_through_to_inner", func(t *testing.T) { + t.Parallel() + inner := &countingCache{ + data: map[string][]byte{ + "from-inner": []byte("inner-val"), + }, + } + mc := &memoryCache{ + cache: map[string][]byte{}, + inner: inner, + } + val, err := mc.Read("from-inner") + require.NoError(t, err) + assert.Equal(t, []byte("inner-val"), val) + }) +} + +// ---------- fixedMarshaller ---------- + +func TestFixedMarshaller(t *testing.T) { + t.Parallel() + + t.Run("marshal_returns_current_value", func(t *testing.T) { + t.Parallel() + fm := &fixedMarshaller{val: []byte("hello")} + data, err := fm.Marshal() + require.NoError(t, err) + assert.Equal(t, []byte("hello"), data) + }) + + t.Run("unmarshal_sets_value", func(t *testing.T) { + t.Parallel() + fm := &fixedMarshaller{} + err := fm.Unmarshal([]byte("new-data")) + require.NoError(t, err) + data, err := fm.Marshal() + require.NoError(t, err) + assert.Equal(t, []byte("new-data"), data) + }) + + t.Run("nil_initial_value", func(t *testing.T) { + t.Parallel() + fm := &fixedMarshaller{} + data, err := fm.Marshal() + require.NoError(t, err) + assert.Nil(t, data) + }) +} + +// ---------- AuthFailedError ---------- + +func TestAuthFailedError_Error_NonHTTP(t *testing.T) { + t.Parallel() + + inner := errors.New("some MSAL error") + e := &AuthFailedError{innerErr: inner} + + msg := e.Error() + assert.Contains(t, msg, "failed to authenticate") + assert.Contains(t, msg, "some MSAL error") +} + +func TestAuthFailedError_Error_WithParsedResponse(t *testing.T) { + t.Parallel() + + e := &AuthFailedError{ + RawResp: &http.Response{}, + Parsed: &AadErrorResponse{ + Error: "invalid_grant", + ErrorDescription: "Token expired", + }, + innerErr: errors.New("wrapped"), + } + + msg := e.Error() + assert.Contains(t, msg, "failed to authenticate") + assert.Contains(t, msg, "invalid_grant") + assert.Contains(t, msg, "Token expired") +} + +func TestAuthFailedError_Error_UnparsedHTTP(t *testing.T) { + t.Parallel() + + body := io.NopCloser(strings.NewReader( + `{"error":"server_error"}`)) + resp := &http.Response{ + StatusCode: 500, + Status: "500 Internal Server Error", + Body: body, + Request: &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "https", + Host: "login.microsoftonline.com", + Path: "/tenant/oauth2/token", + }, + }, + } + + e := &AuthFailedError{ + RawResp: resp, + Parsed: nil, + innerErr: errors.New("wrapped"), + } + + msg := e.Error() + assert.Contains(t, msg, "failed to authenticate") + assert.Contains(t, msg, "POST") + assert.Contains(t, msg, "login.microsoftonline.com") + assert.Contains(t, msg, "500") +} + +func TestAuthFailedError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("root cause") + e := &AuthFailedError{innerErr: inner} + assert.Equal(t, inner, e.Unwrap()) +} + +func TestAuthFailedError_NonRetriable(t *testing.T) { + t.Parallel() + // NonRetriable is a marker method — just confirm it + // doesn't panic. + e := &AuthFailedError{innerErr: errors.New("err")} + e.NonRetriable() +} + +// ---------- ReLoginRequiredError ---------- + +func TestReLoginRequiredError_Error(t *testing.T) { + t.Parallel() + + e := &ReLoginRequiredError{errText: "token expired"} + assert.Equal(t, "token expired", e.Error()) +} + +func TestReLoginRequiredError_NonRetriable(t *testing.T) { + t.Parallel() + e := &ReLoginRequiredError{} + e.NonRetriable() // marker — should not panic +} + +func TestNewReLoginRequiredError(t *testing.T) { + t.Parallel() + + t.Run("nil_response_returns_false", func(t *testing.T) { + t.Parallel() + err, ok := newReLoginRequiredError( + nil, nil, cloud.AzurePublic()) + assert.Nil(t, err) + assert.False(t, ok) + }) + + t.Run("unrelated_error_returns_false", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "server_error", + ErrorDescription: "something else", + } + err, ok := newReLoginRequiredError( + resp, nil, cloud.AzurePublic()) + assert.Nil(t, err) + assert.False(t, ok) + }) + + t.Run("invalid_grant_returns_error", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "invalid_grant", + ErrorDescription: "AADSTS700082: expired", + } + err, ok := newReLoginRequiredError( + resp, + []string{"https://management.azure.com//.default"}, + cloud.AzurePublic(), + ) + assert.True(t, ok) + require.Error(t, err) + assert.Contains(t, err.Error(), "AADSTS700082") + }) + + t.Run("interaction_required_returns_error", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "interaction_required", + ErrorDescription: "need consent", + } + err, ok := newReLoginRequiredError( + resp, + []string{"https://management.azure.com//.default"}, + cloud.AzurePublic(), + ) + assert.True(t, ok) + require.Error(t, err) + }) + + t.Run("extra_scopes_appended_to_login_cmd", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "invalid_grant", + ErrorDescription: "expired", + } + err, ok := newReLoginRequiredError( + resp, + []string{ + "https://management.azure.com//.default", + "https://graph.microsoft.com//.default", + }, + cloud.AzurePublic(), + ) + assert.True(t, ok) + require.Error(t, err) + assert.Contains(t, err.Error(), "expired") + }) + + t.Run("error_code_70043_sets_login_expired", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "invalid_grant", + ErrorDescription: "AADSTS70043: expired", + ErrorCodes: []int{70043}, + } + err, ok := newReLoginRequiredError( + resp, nil, cloud.AzurePublic()) + assert.True(t, ok) + require.Error(t, err) + }) + + t.Run("error_code_50005_adds_device_code_flag", func(t *testing.T) { + t.Parallel() + resp := &AadErrorResponse{ + Error: "interaction_required", + ErrorDescription: "conditional access", + ErrorCodes: []int{50005}, + } + err, ok := newReLoginRequiredError( + resp, nil, cloud.AzurePublic()) + assert.True(t, ok) + require.Error(t, err) + }) +} + +// ---------- helpers ---------- + +// countingCache is a test spy that records Set calls and +// supports pre-seeded Read data. +type countingCache struct { + setCalls int + data map[string][]byte +} + +func (c *countingCache) Read(key string) ([]byte, error) { + if c.data != nil { + if v, ok := c.data[key]; ok { + return v, nil + } + } + return nil, errCacheKeyNotFound +} + +func (c *countingCache) Set(_ string, _ []byte) error { + c.setCalls++ + return nil +} diff --git a/cli/azd/pkg/azapi/deployments_unit_test.go b/cli/azd/pkg/azapi/deployments_unit_test.go new file mode 100644 index 00000000000..c659309f334 --- /dev/null +++ b/cli/azd/pkg/azapi/deployments_unit_test.go @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azapi + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armdeploymentstacks" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAzCliDeploymentOutput_Secured(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + typeName string + want bool + }{ + {"SecureString", "SecureString", true}, + {"securestring_lower", "securestring", true}, + {"SECURESTRING_upper", "SECURESTRING", true}, + {"SecureObject", "SecureObject", true}, + {"secureobject_lower", "secureobject", true}, + {"SECUREOBJECT_upper", "SECUREOBJECT", true}, + {"plain_string", "String", false}, + {"plain_int", "Int", false}, + {"plain_bool", "Bool", false}, + {"plain_object", "Object", false}, + {"plain_array", "Array", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + o := AzCliDeploymentOutput{Type: tt.typeName} + assert.Equal(t, tt.want, o.Secured()) + }) + } +} + +func TestCreateDeploymentOutput(t *testing.T) { + t.Parallel() + + t.Run("nil_returns_empty_map", func(t *testing.T) { + t.Parallel() + result := CreateDeploymentOutput(nil) + require.NotNil(t, result) + assert.Empty(t, result) + }) + + t.Run("single_output", func(t *testing.T) { + t.Parallel() + raw := map[string]any{ + "endpoint": map[string]any{ + "type": "String", + "value": "https://example.com", + }, + } + result := CreateDeploymentOutput(raw) + require.Len(t, result, 1) + assert.Equal(t, "String", result["endpoint"].Type) + assert.Equal(t, "https://example.com", result["endpoint"].Value) + }) + + t.Run("multiple_outputs", func(t *testing.T) { + t.Parallel() + raw := map[string]any{ + "endpoint": map[string]any{ + "type": "String", + "value": "https://example.com", + }, + "key": map[string]any{ + "type": "SecureString", + "value": "secret-key-123", + }, + "count": map[string]any{ + "type": "Int", + "value": float64(42), + }, + } + result := CreateDeploymentOutput(raw) + require.Len(t, result, 3) + assert.True(t, result["key"].Secured()) + assert.False(t, result["endpoint"].Secured()) + assert.Equal(t, float64(42), result["count"].Value) + }) +} + +func TestConvertFromStandardProvisioningState(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input armresources.ProvisioningState + want DeploymentProvisioningState + }{ + {"Accepted", armresources.ProvisioningStateAccepted, + DeploymentProvisioningStateAccepted}, + {"Canceled", armresources.ProvisioningStateCanceled, + DeploymentProvisioningStateCanceled}, + {"Creating", armresources.ProvisioningStateCreating, + DeploymentProvisioningStateCreating}, + {"Deleted", armresources.ProvisioningStateDeleted, + DeploymentProvisioningStateDeleted}, + {"Deleting", armresources.ProvisioningStateDeleting, + DeploymentProvisioningStateDeleting}, + {"Failed", armresources.ProvisioningStateFailed, + DeploymentProvisioningStateFailed}, + {"NotSpecified", armresources.ProvisioningStateNotSpecified, + DeploymentProvisioningStateNotSpecified}, + {"Ready", armresources.ProvisioningStateReady, + DeploymentProvisioningStateReady}, + {"Running", armresources.ProvisioningStateRunning, + DeploymentProvisioningStateRunning}, + {"Succeeded", armresources.ProvisioningStateSucceeded, + DeploymentProvisioningStateSucceeded}, + {"Updating", armresources.ProvisioningStateUpdating, + DeploymentProvisioningStateUpdating}, + {"unknown_returns_empty", + armresources.ProvisioningState("SomethingNew"), + DeploymentProvisioningState("")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := convertFromStandardProvisioningState(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestConvertFromStacksProvisioningState(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input armdeploymentstacks.DeploymentStackProvisioningState + want DeploymentProvisioningState + }{ + {"Canceled", + armdeploymentstacks.DeploymentStackProvisioningStateCanceled, + DeploymentProvisioningStateCanceled}, + {"Canceling", + armdeploymentstacks.DeploymentStackProvisioningStateCanceling, + DeploymentProvisioningStateCanceling}, + {"Creating", + armdeploymentstacks.DeploymentStackProvisioningStateCreating, + DeploymentProvisioningStateCreating}, + {"Deleting", + armdeploymentstacks.DeploymentStackProvisioningStateDeleting, + DeploymentProvisioningStateDeleting}, + {"DeletingResources", + armdeploymentstacks.DeploymentStackProvisioningStateDeletingResources, + DeploymentProvisioningStateDeletingResources}, + {"Deploying", + armdeploymentstacks.DeploymentStackProvisioningStateDeploying, + DeploymentProvisioningStateDeploying}, + {"Failed", + armdeploymentstacks.DeploymentStackProvisioningStateFailed, + DeploymentProvisioningStateFailed}, + {"Succeeded", + armdeploymentstacks.DeploymentStackProvisioningStateSucceeded, + DeploymentProvisioningStateSucceeded}, + {"UpdatingDenyAssignments", + armdeploymentstacks.DeploymentStackProvisioningStateUpdatingDenyAssignments, + DeploymentProvisioningStateUpdatingDenyAssignments}, + {"Validating", + armdeploymentstacks.DeploymentStackProvisioningStateValidating, + DeploymentProvisioningStateValidating}, + {"Waiting", + armdeploymentstacks.DeploymentStackProvisioningStateWaiting, + DeploymentProvisioningStateWaiting}, + {"unknown_returns_empty", + armdeploymentstacks.DeploymentStackProvisioningState("Unknown"), + DeploymentProvisioningState("")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := convertFromStacksProvisioningState(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestStackDeployments_GenerateDeploymentName(t *testing.T) { + t.Parallel() + + sd := &StackDeployments{} + + tests := []struct { + name string + baseName string + want string + }{ + {"simple", "my-env", "azd-stack-my-env"}, + {"empty", "", "azd-stack-"}, + {"complex", "a-b-c-123", "azd-stack-a-b-c-123"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := sd.GenerateDeploymentName(tt.baseName) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestGroupByResourceGroup(t *testing.T) { + t.Parallel() + + t.Run("nil_input", func(t *testing.T) { + t.Parallel() + result, err := GroupByResourceGroup(nil) + require.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("empty_input", func(t *testing.T) { + t.Parallel() + result, err := GroupByResourceGroup( + []*armresources.ResourceReference{}, + ) + require.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("groups_resources_correctly", func(t *testing.T) { + t.Parallel() + refs := []*armresources.ResourceReference{ + {ID: new( + "/subscriptions/sub1/resourceGroups/rg1" + + "/providers/Microsoft.Web/sites/app1")}, + {ID: new( + "/subscriptions/sub1/resourceGroups/rg1" + + "/providers/Microsoft.Storage/storageAccounts/sa1")}, + {ID: new( + "/subscriptions/sub1/resourceGroups/rg2" + + "/providers/Microsoft.Web/sites/app2")}, + } + + result, err := GroupByResourceGroup(refs) + require.NoError(t, err) + require.Len(t, result, 2) + assert.Len(t, result["rg1"], 2) + assert.Len(t, result["rg2"], 1) + }) + + t.Run("excludes_resource_group_type", func(t *testing.T) { + t.Parallel() + refs := []*armresources.ResourceReference{ + {ID: new( + "/subscriptions/sub1/resourceGroups/rg1")}, + {ID: new( + "/subscriptions/sub1/resourceGroups/rg1" + + "/providers/Microsoft.Web/sites/app1")}, + } + + result, err := GroupByResourceGroup(refs) + require.NoError(t, err) + require.Len(t, result, 1) + // Only the web app, not the resource group itself + assert.Len(t, result["rg1"], 1) + assert.Equal(t, "app1", result["rg1"][0].Name) + }) + + t.Run("invalid_resource_id", func(t *testing.T) { + t.Parallel() + refs := []*armresources.ResourceReference{ + {ID: new("not-a-valid-resource-id")}, + } + + _, err := GroupByResourceGroup(refs) + require.Error(t, err) + }) + + t.Run("subscription_level_resources_skipped", func(t *testing.T) { + t.Parallel() + refs := []*armresources.ResourceReference{ + {ID: new( + "/subscriptions/sub1/providers" + + "/Microsoft.Resources/deployments/deploy1")}, + } + + result, err := GroupByResourceGroup(refs) + require.NoError(t, err) + assert.Empty(t, result) + }) +} + +func TestIsNotLoggedInMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + { + "no_subscription_found", + "ERROR: No subscription found", + true, + }, + { + "please_run_az_login_single_quotes", + "Please run 'az login' to setup account.", + true, + }, + { + "please_run_az_login_double_quotes", + `Please run "az login" to access your accounts.`, + true, + }, + { + "unrelated_message", + "deployment succeeded", + false, + }, + { + "empty_string", + "", + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isNotLoggedInMessage(tt.input)) + }) + } +} + +func TestIsRefreshTokenExpiredMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + { + "AADSTS70043", + "AADSTS70043: The refresh token has expired", + true, + }, + { + "AADSTS700082", + "AADSTS700082: expired due to inactivity", + true, + }, + { + "unrelated_error", + "AADSTS50001: something else", + false, + }, + { + "empty", + "", + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, + isRefreshTokenExpiredMessage(tt.input)) + }) + } +} diff --git a/cli/azd/pkg/azdext/envelope_coverage_test.go b/cli/azd/pkg/azdext/envelope_coverage_test.go new file mode 100644 index 00000000000..57259771c6b --- /dev/null +++ b/cli/azd/pkg/azdext/envelope_coverage_test.go @@ -0,0 +1,470 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// ----------------------------------------------------------------------- +// FrameworkServiceEnvelope tests +// ----------------------------------------------------------------------- + +func TestFrameworkServiceEnvelope_GetSetRequestId(t *testing.T) { + env := NewFrameworkServiceEnvelope() + msg := &FrameworkServiceMessage{RequestId: "req-123"} + + require.Equal(t, "req-123", env.GetRequestId(context.Background(), msg)) + + env.SetRequestId(context.Background(), msg, "req-456") + require.Equal(t, "req-456", msg.RequestId) +} + +func TestFrameworkServiceEnvelope_GetSetError(t *testing.T) { + env := NewFrameworkServiceEnvelope() + + t.Run("NilError", func(t *testing.T) { + msg := &FrameworkServiceMessage{} + require.Nil(t, env.GetError(msg)) + }) + + t.Run("RoundTripLocalError", func(t *testing.T) { + msg := &FrameworkServiceMessage{} + localErr := &LocalError{ + Message: "validation failed", + Code: "invalid_config", + Category: LocalErrorCategoryValidation, + } + env.SetError(msg, localErr) + require.NotNil(t, msg.Error) + + unwrapped := env.GetError(msg) + require.Error(t, unwrapped) + require.Contains(t, unwrapped.Error(), "validation failed") + }) + + t.Run("RoundTripServiceError", func(t *testing.T) { + msg := &FrameworkServiceMessage{} + svcErr := &ServiceError{ + Message: "not found", + ErrorCode: "NotFound", + StatusCode: 404, + ServiceName: "api.example.com", + } + env.SetError(msg, svcErr) + require.NotNil(t, msg.Error) + + unwrapped := env.GetError(msg) + require.Error(t, unwrapped) + require.Contains(t, unwrapped.Error(), "not found") + }) +} + +func TestFrameworkServiceEnvelope_GetInnerMessage(t *testing.T) { + env := NewFrameworkServiceEnvelope() + + tests := []struct { + name string + msg *FrameworkServiceMessage + wantNil bool + wantType string + }{ + { + name: "RegisterRequest", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_RegisterFrameworkServiceRequest{ + RegisterFrameworkServiceRequest: &RegisterFrameworkServiceRequest{}, + }, + }, + wantType: "RegisterFrameworkServiceRequest", + }, + { + name: "InitializeRequest", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_InitializeRequest{ + InitializeRequest: &FrameworkServiceInitializeRequest{}, + }, + }, + wantType: "FrameworkServiceInitializeRequest", + }, + { + name: "BuildRequest", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_BuildRequest{ + BuildRequest: &FrameworkServiceBuildRequest{}, + }, + }, + wantType: "FrameworkServiceBuildRequest", + }, + { + name: "PackageRequest", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_PackageRequest{ + PackageRequest: &FrameworkServicePackageRequest{}, + }, + }, + wantType: "FrameworkServicePackageRequest", + }, + { + name: "RestoreRequest", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_RestoreRequest{ + RestoreRequest: &FrameworkServiceRestoreRequest{}, + }, + }, + wantType: "FrameworkServiceRestoreRequest", + }, + { + name: "ProgressMessage", + msg: &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_ProgressMessage{ + ProgressMessage: &FrameworkServiceProgressMessage{ + Message: "building...", + }, + }, + }, + wantType: "FrameworkServiceProgressMessage", + }, + { + name: "NilMessageType", + msg: &FrameworkServiceMessage{}, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := env.GetInnerMessage(tt.msg) + if tt.wantNil { + require.Nil(t, inner) + } else { + require.NotNil(t, inner) + } + }) + } +} + +func TestFrameworkServiceEnvelope_ProgressMessage(t *testing.T) { + env := NewFrameworkServiceEnvelope() + + t.Run("IsProgressMessage_True", func(t *testing.T) { + msg := &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_ProgressMessage{ + ProgressMessage: &FrameworkServiceProgressMessage{ + Message: "step 1/3", + }, + }, + } + require.True(t, env.IsProgressMessage(msg)) + require.Equal(t, "step 1/3", env.GetProgressMessage(msg)) + }) + + t.Run("IsProgressMessage_False", func(t *testing.T) { + msg := &FrameworkServiceMessage{ + MessageType: &FrameworkServiceMessage_BuildRequest{ + BuildRequest: &FrameworkServiceBuildRequest{}, + }, + } + require.False(t, env.IsProgressMessage(msg)) + require.Empty(t, env.GetProgressMessage(msg)) + }) + + t.Run("CreateProgressMessage", func(t *testing.T) { + msg := env.CreateProgressMessage("req-1", "deploying...") + require.NotNil(t, msg) + require.Equal(t, "req-1", msg.RequestId) + require.True(t, env.IsProgressMessage(msg)) + require.Equal(t, "deploying...", env.GetProgressMessage(msg)) + }) +} + +// ----------------------------------------------------------------------- +// ServiceTargetEnvelope tests +// ----------------------------------------------------------------------- + +func TestServiceTargetEnvelope_GetSetRequestId(t *testing.T) { + env := NewServiceTargetEnvelope() + msg := &ServiceTargetMessage{RequestId: "st-123"} + + require.Equal(t, "st-123", env.GetRequestId(context.Background(), msg)) + + env.SetRequestId(context.Background(), msg, "st-456") + require.Equal(t, "st-456", msg.RequestId) +} + +func TestServiceTargetEnvelope_GetSetError(t *testing.T) { + env := NewServiceTargetEnvelope() + + t.Run("NilError", func(t *testing.T) { + msg := &ServiceTargetMessage{} + require.Nil(t, env.GetError(msg)) + }) + + t.Run("RoundTripError", func(t *testing.T) { + msg := &ServiceTargetMessage{} + svcErr := &ServiceError{ + Message: "deploy failed", + ErrorCode: "DeploymentFailed", + } + env.SetError(msg, svcErr) + require.NotNil(t, msg.Error) + + unwrapped := env.GetError(msg) + require.Error(t, unwrapped) + require.Contains(t, unwrapped.Error(), "deploy failed") + }) +} + +func TestServiceTargetEnvelope_GetInnerMessage(t *testing.T) { + env := NewServiceTargetEnvelope() + + tests := []struct { + name string + msg *ServiceTargetMessage + wantNil bool + }{ + { + name: "RegisterRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_RegisterServiceTargetRequest{ + RegisterServiceTargetRequest: &RegisterServiceTargetRequest{}, + }, + }, + }, + { + name: "DeployRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_DeployRequest{ + DeployRequest: &ServiceTargetDeployRequest{}, + }, + }, + }, + { + name: "GetTargetResourceRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_GetTargetResourceRequest{ + GetTargetResourceRequest: &GetTargetResourceRequest{}, + }, + }, + }, + { + name: "PackageRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_PackageRequest{ + PackageRequest: &ServiceTargetPackageRequest{}, + }, + }, + }, + { + name: "PublishRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_PublishRequest{ + PublishRequest: &ServiceTargetPublishRequest{}, + }, + }, + }, + { + name: "EndpointsRequest", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_EndpointsRequest{ + EndpointsRequest: &ServiceTargetEndpointsRequest{}, + }, + }, + }, + { + name: "ProgressMessage", + msg: &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_ProgressMessage{ + ProgressMessage: &ServiceTargetProgressMessage{ + Message: "progress", + }, + }, + }, + }, + { + name: "NilMessageType", + msg: &ServiceTargetMessage{}, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := env.GetInnerMessage(tt.msg) + if tt.wantNil { + require.Nil(t, inner) + } else { + require.NotNil(t, inner) + } + }) + } +} + +func TestServiceTargetEnvelope_ProgressMessage(t *testing.T) { + env := NewServiceTargetEnvelope() + + t.Run("IsProgressMessage_True", func(t *testing.T) { + msg := &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_ProgressMessage{ + ProgressMessage: &ServiceTargetProgressMessage{ + Message: "deploying...", + }, + }, + } + require.True(t, env.IsProgressMessage(msg)) + require.Equal(t, "deploying...", env.GetProgressMessage(msg)) + }) + + t.Run("IsProgressMessage_False", func(t *testing.T) { + msg := &ServiceTargetMessage{ + MessageType: &ServiceTargetMessage_DeployRequest{ + DeployRequest: &ServiceTargetDeployRequest{}, + }, + } + require.False(t, env.IsProgressMessage(msg)) + require.Empty(t, env.GetProgressMessage(msg)) + }) + + t.Run("CreateProgressMessage", func(t *testing.T) { + msg := env.CreateProgressMessage("st-1", "packaging...") + require.NotNil(t, msg) + require.Equal(t, "st-1", msg.RequestId) + require.True(t, env.IsProgressMessage(msg)) + require.Equal(t, "packaging...", env.GetProgressMessage(msg)) + }) +} + +// ----------------------------------------------------------------------- +// EventMessageEnvelope tests +// ----------------------------------------------------------------------- + +func TestEventMessageEnvelope_GetInnerMessage(t *testing.T) { + env := NewEventMessageEnvelope() + + tests := []struct { + name string + msg *EventMessage + wantNil bool + }{ + { + name: "SubscribeProjectEvent", + msg: &EventMessage{ + MessageType: &EventMessage_SubscribeProjectEvent{ + SubscribeProjectEvent: &SubscribeProjectEvent{ + EventNames: []string{"provision"}, + }, + }, + }, + }, + { + name: "InvokeProjectHandler", + msg: &EventMessage{ + MessageType: &EventMessage_InvokeProjectHandler{ + InvokeProjectHandler: &InvokeProjectHandler{ + EventName: "provision", + }, + }, + }, + }, + { + name: "ProjectHandlerStatus", + msg: &EventMessage{ + MessageType: &EventMessage_ProjectHandlerStatus{ + ProjectHandlerStatus: &ProjectHandlerStatus{ + EventName: "provision", + }, + }, + }, + }, + { + name: "SubscribeServiceEvent", + msg: &EventMessage{ + MessageType: &EventMessage_SubscribeServiceEvent{ + SubscribeServiceEvent: &SubscribeServiceEvent{ + EventNames: []string{"deploy"}, + }, + }, + }, + }, + { + name: "InvokeServiceHandler", + msg: &EventMessage{ + MessageType: &EventMessage_InvokeServiceHandler{ + InvokeServiceHandler: &InvokeServiceHandler{ + EventName: "deploy", + }, + }, + }, + }, + { + name: "ServiceHandlerStatus", + msg: &EventMessage{ + MessageType: &EventMessage_ServiceHandlerStatus{ + ServiceHandlerStatus: &ServiceHandlerStatus{ + EventName: "deploy", + ServiceName: "api", + }, + }, + }, + }, + { + name: "NilMessageType", + msg: &EventMessage{}, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := env.GetInnerMessage(tt.msg) + if tt.wantNil { + require.Nil(t, inner) + } else { + require.NotNil(t, inner) + } + }) + } +} + +func TestEventMessageEnvelope_NoOps(t *testing.T) { + env := NewEventMessageEnvelope() + msg := &EventMessage{} + + // SetRequestId is a no-op + env.SetRequestId(context.Background(), msg, "ignored") + + // GetError always returns nil + require.Nil(t, env.GetError(msg)) + + // SetError is a no-op + env.SetError(msg, &LocalError{Message: "ignored"}) + + // IsProgressMessage always false + require.False(t, env.IsProgressMessage(msg)) + + // GetProgressMessage always empty + require.Empty(t, env.GetProgressMessage(msg)) + + // CreateProgressMessage always nil + require.Nil(t, env.CreateProgressMessage("id", "msg")) +} + +func TestEventMessageEnvelope_GetRequestId_NoContext(t *testing.T) { + env := NewEventMessageEnvelope() + + // Without extension ID in context, should return "" + msg := &EventMessage{ + MessageType: &EventMessage_SubscribeProjectEvent{ + SubscribeProjectEvent: &SubscribeProjectEvent{ + EventNames: []string{"provision"}, + }, + }, + } + + id := env.GetRequestId(context.Background(), msg) + require.Empty(t, id) +} diff --git a/cli/azd/pkg/azdext/run_coverage_test.go b/cli/azd/pkg/azdext/run_coverage_test.go new file mode 100644 index 00000000000..5227df182a8 --- /dev/null +++ b/cli/azd/pkg/azdext/run_coverage_test.go @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestErrorSuggestion(t *testing.T) { + tests := []struct { + name string + err error + expected string + }{ + { + name: "LocalErrorWithSuggestion", + err: &LocalError{ + Message: "missing config", + Suggestion: "Run azd init first", + }, + expected: "Run azd init first", + }, + { + name: "LocalErrorNoSuggestion", + err: &LocalError{ + Message: "missing config", + }, + expected: "", + }, + { + name: "ServiceErrorWithSuggestion", + err: &ServiceError{ + Message: "rate limited", + Suggestion: "Retry after 60 seconds", + }, + expected: "Retry after 60 seconds", + }, + { + name: "ServiceErrorNoSuggestion", + err: &ServiceError{ + Message: "rate limited", + }, + expected: "", + }, + { + name: "GenericError", + err: &testGenericError{msg: "generic"}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ErrorSuggestion(tt.err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestErrorMessage(t *testing.T) { + tests := []struct { + name string + err error + expected string + }{ + { + name: "LocalError", + err: &LocalError{ + Message: "config invalid", + }, + expected: "config invalid", + }, + { + name: "ServiceError", + err: &ServiceError{ + Message: "service unavailable", + }, + expected: "service unavailable", + }, + { + name: "LocalErrorEmptyMessage", + err: &LocalError{ + Message: "", + }, + expected: "", + }, + { + name: "GenericError", + err: &testGenericError{msg: "generic"}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ErrorMessage(tt.err) + require.Equal(t, tt.expected, result) + }) + } +} + +// testGenericError is a plain error for testing non-extension error types. +type testGenericError struct { + msg string +} + +func (e *testGenericError) Error() string { + return e.msg +} + +func TestVersion_IsSet(t *testing.T) { + require.NotEmpty(t, Version) + require.Equal(t, "0.1.0", Version) +} + +func TestAiErrorConstants(t *testing.T) { + // Verify domain constant + require.Equal(t, "azd.ai", AiErrorDomain) + + // Verify all reason codes are non-empty and unique + reasons := []string{ + AiErrorReasonMissingSubscription, + AiErrorReasonLocationRequired, + AiErrorReasonQuotaLocation, + AiErrorReasonModelNotFound, + AiErrorReasonNoModelsMatch, + AiErrorReasonNoDeploymentMatch, + AiErrorReasonNoValidSkus, + AiErrorReasonNoLocationsWithQuota, + AiErrorReasonInvalidCapacity, + AiErrorReasonInteractiveRequired, + } + + seen := make(map[string]bool, len(reasons)) + for _, r := range reasons { + require.NotEmpty(t, r, "reason code must not be empty") + require.False(t, seen[r], "duplicate reason code: %s", r) + seen[r] = true + } +} diff --git a/cli/azd/pkg/azdo/azdo_additional_test.go b/cli/azd/pkg/azdo/azdo_additional_test.go new file mode 100644 index 00000000000..4ab086fcfad --- /dev/null +++ b/cli/azd/pkg/azdo/azdo_additional_test.go @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdo + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/entraid" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateBuildDefinitionVariable(t *testing.T) { + t.Run("standard variable", func(t *testing.T) { + v := createBuildDefinitionVariable("value1", false, false) + assert.Equal(t, "value1", *v.Value) + assert.False(t, *v.IsSecret) + assert.False(t, *v.AllowOverride) + }) + + t.Run("secret variable", func(t *testing.T) { + v := createBuildDefinitionVariable("secret", true, false) + assert.Equal(t, "secret", *v.Value) + assert.True(t, *v.IsSecret) + assert.False(t, *v.AllowOverride) + }) + + t.Run("overridable variable", func(t *testing.T) { + v := createBuildDefinitionVariable("val", false, true) + assert.True(t, *v.AllowOverride) + }) +} + +func TestGetDefinitionVariables_Bicep(t *testing.T) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "eastus2", + }, + ) + + creds := &entraid.AzureCredentials{ + SubscriptionId: "sub-123", + TenantId: "tenant-456", + ClientId: "client-789", + ClientSecret: "secret-abc", + } + + opts := provisioning.Options{ + Provider: provisioning.Bicep, + } + + vars, err := getDefinitionVariables( + env, creds, opts, nil, nil, + ) + require.NoError(t, err) + require.NotNil(t, vars) + + m := *vars + + // Standard variables + assert.Equal(t, "eastus2", *m["AZURE_LOCATION"].Value) + assert.Equal(t, "test-env", *m["AZURE_ENV_NAME"].Value) + assert.Equal( + t, ServiceConnectionName, + *m["AZURE_SERVICE_CONNECTION"].Value, + ) + assert.Equal( + t, "sub-123", *m["AZURE_SUBSCRIPTION_ID"].Value, + ) + + // Should NOT have Terraform-specific variables + _, hasTenantID := m["ARM_TENANT_ID"] + assert.False(t, hasTenantID) +} + +func TestGetDefinitionVariables_BicepWithResourceGroup(t *testing.T) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "westus", + "AZURE_RESOURCE_GROUP": "my-rg", + }, + ) + + opts := provisioning.Options{ + Provider: provisioning.Bicep, + } + + vars, err := getDefinitionVariables( + env, nil, opts, nil, nil, + ) + require.NoError(t, err) + + m := *vars + + // Bicep with resource group should include it + assert.Equal(t, "my-rg", *m["AZURE_RESOURCE_GROUP"].Value) +} + +func TestGetDefinitionVariables_Terraform(t *testing.T) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "eastus", + "RS_RESOURCE_GROUP": "tf-state-rg", + "RS_STORAGE_ACCOUNT": "tfstatestorage", + "RS_CONTAINER_NAME": "tfstate", + }, + ) + + creds := &entraid.AzureCredentials{ + SubscriptionId: "sub-123", + TenantId: "tenant-456", + ClientId: "client-789", + ClientSecret: "secret-abc", + } + + opts := provisioning.Options{ + Provider: provisioning.Terraform, + } + + vars, err := getDefinitionVariables( + env, creds, opts, nil, nil, + ) + require.NoError(t, err) + + m := *vars + + // Terraform-specific ARM variables + assert.Equal(t, "tenant-456", *m["ARM_TENANT_ID"].Value) + assert.Equal(t, "client-789", *m["ARM_CLIENT_ID"].Value) + assert.True(t, *m["ARM_CLIENT_ID"].IsSecret) + assert.Equal( + t, "secret-abc", *m["ARM_CLIENT_SECRET"].Value, + ) + assert.True(t, *m["ARM_CLIENT_SECRET"].IsSecret) + + // Terraform remote state variables + assert.Equal( + t, "tf-state-rg", *m["RS_RESOURCE_GROUP"].Value, + ) + assert.Equal( + t, "tfstatestorage", *m["RS_STORAGE_ACCOUNT"].Value, + ) + assert.Equal( + t, "tfstate", *m["RS_CONTAINER_NAME"].Value, + ) +} + +func TestGetDefinitionVariables_TerraformMissingRemoteState( + t *testing.T, +) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "eastus", + }, + ) + + opts := provisioning.Options{ + Provider: provisioning.Terraform, + } + + _, err := getDefinitionVariables(env, nil, opts, nil, nil) + require.Error(t, err) + assert.Contains( + t, err.Error(), + "terraform remote state is not correctly configured", + ) +} + +func TestGetDefinitionVariables_AdditionalSecretsAndVars( + t *testing.T, +) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "eastus", + }, + ) + + opts := provisioning.Options{ + Provider: provisioning.Bicep, + } + + secrets := map[string]string{ + "MY_SECRET": "secret-value", + } + variables := map[string]string{ + "MY_VAR": "var-value", + } + + vars, err := getDefinitionVariables( + env, nil, opts, secrets, variables, + ) + require.NoError(t, err) + + m := *vars + + // Additional secrets should be marked as secret + assert.Equal(t, "secret-value", *m["MY_SECRET"].Value) + assert.True(t, *m["MY_SECRET"].IsSecret) + + // Additional variables should allow override + assert.Equal(t, "var-value", *m["MY_VAR"].Value) + assert.True(t, *m["MY_VAR"].AllowOverride) +} + +func TestGetDefinitionVariables_NilCredentials(t *testing.T) { + env := environment.NewWithValues( + "test-env", + map[string]string{ + "AZURE_LOCATION": "eastus", + }, + ) + + opts := provisioning.Options{ + Provider: provisioning.Bicep, + } + + vars, err := getDefinitionVariables( + env, nil, opts, nil, nil, + ) + require.NoError(t, err) + + m := *vars + + // Should not have subscription ID without credentials + _, hasSubId := m["AZURE_SUBSCRIPTION_ID"] + assert.False(t, hasSubId) +} + +func TestConstants(t *testing.T) { + assert.Equal(t, "dev.azure.com", AzDoHostName) + assert.Equal(t, "AZURE_DEVOPS_EXT_PAT", AzDoPatName) + assert.Equal( + t, "AZURE_DEVOPS_ORG_NAME", AzDoEnvironmentOrgName, + ) + assert.Equal( + t, ".azdo/pipelines/azure-dev.yml", AzurePipelineYamlPath, + ) + assert.Equal(t, "main", DefaultBranch) + assert.Equal(t, "azconnection", ServiceConnectionName) +} diff --git a/cli/azd/pkg/azsdk/storage/storage_blob_client_coverage_test.go b/cli/azd/pkg/azsdk/storage/storage_blob_client_coverage_test.go new file mode 100644 index 00000000000..94d6b8be972 --- /dev/null +++ b/cli/azd/pkg/azsdk/storage/storage_blob_client_coverage_test.go @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package storage + +import ( + "errors" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/cloud" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_NewBlobSdkClient_UsesCustomEndpoint(t *testing.T) { + mockCredProvider := &mockMultiTenantCredentialProvider{} + mockTenantResolver := &mockSubscriptionTenantResolver{} + mockCred := &mockTokenCredential{} + mockConfigMgr := &mockUserConfigManager{} + + accountCfg := &AccountConfig{ + AccountName: "myaccount", + ContainerName: "mycontainer", + Endpoint: "blob.core.usgovcloudapi.net", + } + + coreClientOptions := &azcore.ClientOptions{} + + mockConfigMgr.On("Load").Return(config.NewEmptyConfig(), nil) + mockCredProvider.On( + "GetTokenCredential", mock.Anything, "", + ).Return(mockCred, nil) + + client, err := NewBlobSdkClient( + mockCredProvider, + accountCfg, + mockConfigMgr, + coreClientOptions, + cloud.AzurePublic(), + mockTenantResolver, + ) + + require.NoError(t, err) + require.NotNil(t, client) + // Custom endpoint should NOT be overwritten + require.Equal(t, "blob.core.usgovcloudapi.net", accountCfg.Endpoint) + mockCredProvider.AssertExpectations(t) +} + +func Test_NewBlobSdkClient_DefaultEndpointFromCloud(t *testing.T) { + mockCredProvider := &mockMultiTenantCredentialProvider{} + mockTenantResolver := &mockSubscriptionTenantResolver{} + mockCred := &mockTokenCredential{} + mockConfigMgr := &mockUserConfigManager{} + + accountCfg := &AccountConfig{ + AccountName: "myaccount", + ContainerName: "mycontainer", + // Endpoint empty — should be populated from cloud + } + + coreClientOptions := &azcore.ClientOptions{} + + mockConfigMgr.On("Load").Return(config.NewEmptyConfig(), nil) + mockCredProvider.On( + "GetTokenCredential", mock.Anything, "", + ).Return(mockCred, nil) + + azureCloud := cloud.AzurePublic() + + client, err := NewBlobSdkClient( + mockCredProvider, + accountCfg, + mockConfigMgr, + coreClientOptions, + azureCloud, + mockTenantResolver, + ) + + require.NoError(t, err) + require.NotNil(t, client) + require.Equal(t, azureCloud.StorageEndpointSuffix, accountCfg.Endpoint) +} + +func Test_NewBlobSdkClient_CredentialProviderError(t *testing.T) { + mockCredProvider := &mockMultiTenantCredentialProvider{} + mockTenantResolver := &mockSubscriptionTenantResolver{} + mockConfigMgr := &mockUserConfigManager{} + + accountCfg := &AccountConfig{ + AccountName: "myaccount", + ContainerName: "mycontainer", + } + + coreClientOptions := &azcore.ClientOptions{} + + mockConfigMgr.On("Load").Return(config.NewEmptyConfig(), nil) + mockCredProvider.On( + "GetTokenCredential", mock.Anything, "", + ).Return(nil, errors.New("credential unavailable")) + + client, err := NewBlobSdkClient( + mockCredProvider, + accountCfg, + mockConfigMgr, + coreClientOptions, + cloud.AzurePublic(), + mockTenantResolver, + ) + + require.Error(t, err) + require.Nil(t, client) + require.Contains(t, err.Error(), "credential unavailable") +} + +func Test_NewBlobSdkClient_EmptyDefaultSubscriptionIgnored(t *testing.T) { + mockCredProvider := &mockMultiTenantCredentialProvider{} + mockTenantResolver := &mockSubscriptionTenantResolver{} + mockCred := &mockTokenCredential{} + mockConfigMgr := &mockUserConfigManager{} + + accountCfg := &AccountConfig{ + AccountName: "myaccount", + ContainerName: "mycontainer", + } + + coreClientOptions := &azcore.ClientOptions{} + + // User config has empty default subscription + userCfg := config.NewConfig(map[string]any{ + "defaults": map[string]any{ + "subscription": "", + }, + }) + mockConfigMgr.On("Load").Return(userCfg, nil) + mockCredProvider.On( + "GetTokenCredential", mock.Anything, "", + ).Return(mockCred, nil) + + client, err := NewBlobSdkClient( + mockCredProvider, + accountCfg, + mockConfigMgr, + coreClientOptions, + cloud.AzurePublic(), + mockTenantResolver, + ) + + require.NoError(t, err) + require.NotNil(t, client) + // Tenant resolver should NOT be called for empty subscription + mockTenantResolver.AssertNotCalled( + t, "LookupTenant", mock.Anything, mock.Anything, + ) +} + +func Test_NewBlobClient_ReturnsValidClient(t *testing.T) { + cfg := &AccountConfig{ + AccountName: "testaccount", + ContainerName: "testcontainer", + Endpoint: "blob.core.windows.net", + } + + // NewBlobClient only wraps config+client; it doesn't + // call Azure, so a nil azblob.Client is fine for the + // factory test (we only check the interface is returned). + bc := NewBlobClient(cfg, nil) + require.NotNil(t, bc) +} + +func Test_AccountConfig_Fields(t *testing.T) { + cfg := AccountConfig{ + AccountName: "sa", + ContainerName: "cn", + Endpoint: "ep", + SubscriptionId: "sid", + } + + require.Equal(t, "sa", cfg.AccountName) + require.Equal(t, "cn", cfg.ContainerName) + require.Equal(t, "ep", cfg.Endpoint) + require.Equal(t, "sid", cfg.SubscriptionId) +} + +func Test_ErrContainerNotFound(t *testing.T) { + require.NotNil(t, ErrContainerNotFound) + require.Equal(t, "container not found", ErrContainerNotFound.Error()) +} diff --git a/cli/azd/pkg/azure/arm_template_test.go b/cli/azd/pkg/azure/arm_template_test.go new file mode 100644 index 00000000000..859607867ab --- /dev/null +++ b/cli/azd/pkg/azure/arm_template_test.go @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_TargetScope(t *testing.T) { + tests := []struct { + name string + schema string + want DeploymentScope + wantErr bool + }{ + { + name: "SubscriptionScope", + schema: "https://schema.management.azure.com/" + + "schemas/2018-05-01/" + + "subscriptionDeploymentTemplate.json#", + want: DeploymentScopeSubscription, + }, + { + name: "ResourceGroupScope", + schema: "https://schema.management.azure.com/" + + "schemas/2019-04-01/" + + "deploymentTemplate.json#", + want: DeploymentScopeResourceGroup, + }, + { + name: "ResourceGroupCaseInsensitive", + schema: "https://schema.management.azure.com/" + + "schemas/2019-04-01/" + + "DeploymentTemplate.json#", + want: DeploymentScopeResourceGroup, + }, + { + name: "EmptySchema", + schema: "", + wantErr: true, + }, + { + name: "UnknownSchema", + schema: "https://example.com/unknown.json", + wantErr: true, + }, + { + name: "InvalidURL", + schema: "://bad-url", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := ArmTemplate{Schema: tt.schema} + got, err := tmpl.TargetScope() + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_IsSecuredARMType(t *testing.T) { + tests := []struct { + name string + typ string + want bool + }{ + {"SecureString", "securestring", true}, + {"SecureObject", "secureobject", true}, + {"SecureStringUpper", "SecureString", true}, + {"SecureObjectMixed", "SecureObject", true}, + {"AllCaps", "SECURESTRING", true}, + {"String", "string", false}, + {"Object", "object", false}, + {"Int", "int", false}, + {"Bool", "bool", false}, + {"Empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsSecuredARMType(tt.typ) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_Secure(t *testing.T) { + tests := []struct { + name string + typ string + want bool + }{ + {"SecureString", "securestring", true}, + {"RegularString", "string", false}, + {"SecureObject", "secureobject", true}, + {"RegularObject", "object", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + param := ArmTemplateParameterDefinition{ + Type: tt.typ, + } + require.Equal(t, tt.want, param.Secure()) + }) + } +} + +func Test_Description(t *testing.T) { + tests := []struct { + name string + meta map[string]json.RawMessage + want string + wantOK bool + }{ + { + name: "WithDescription", + meta: map[string]json.RawMessage{ + "description": json.RawMessage( + `"A test parameter"`, + ), + }, + want: "A test parameter", + wantOK: true, + }, + { + name: "NoMetadata", + meta: nil, + want: "", + wantOK: false, + }, + { + name: "NoDescriptionKey", + meta: map[string]json.RawMessage{ + "other": json.RawMessage(`"something"`), + }, + want: "", + wantOK: false, + }, + { + name: "InvalidDescriptionJSON", + meta: map[string]json.RawMessage{ + "description": json.RawMessage(`123`), + }, + want: "", + wantOK: false, + }, + { + name: "EmptyDescription", + meta: map[string]json.RawMessage{ + "description": json.RawMessage(`""`), + }, + want: "", + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + param := ArmTemplateParameterDefinition{ + Metadata: tt.meta, + } + got, ok := param.Description() + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_AzdMetadata(t *testing.T) { + locationType := AzdMetadataTypeLocation + + tests := []struct { + name string + meta map[string]json.RawMessage + wantOK bool + check func(t *testing.T, m AzdMetadata) + }{ + { + name: "WithLocationMetadata", + meta: map[string]json.RawMessage{ + "azd": json.RawMessage( + `{"type":"location"}`, + ), + }, + wantOK: true, + check: func(t *testing.T, m AzdMetadata) { + require.NotNil(t, m.Type) + require.Equal(t, locationType, *m.Type) + }, + }, + { + name: "NoMetadata", + meta: nil, + wantOK: false, + check: func(t *testing.T, m AzdMetadata) {}, + }, + { + name: "NoAzdKey", + meta: map[string]json.RawMessage{ + "other": json.RawMessage(`{}`), + }, + wantOK: false, + check: func(t *testing.T, m AzdMetadata) {}, + }, + { + name: "InvalidAzdJSON", + meta: map[string]json.RawMessage{ + "azd": json.RawMessage(`not-valid`), + }, + wantOK: false, + check: func(t *testing.T, m AzdMetadata) {}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + param := ArmTemplateParameterDefinition{ + Metadata: tt.meta, + } + got, ok := param.AzdMetadata() + require.Equal(t, tt.wantOK, ok) + tt.check(t, got) + }) + } +} + +func Test_AdditionalProperties_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + hasProps bool + wantErr bool + }{ + { + name: "FalseValue", + input: "false", + hasProps: false, + }, + { + name: "ObjectValue", + input: `{"type":"string"}`, + hasProps: true, + }, + { + name: "ObjectWithMinMax", + input: `{"type":"int","minValue":1,"maxValue":10}`, + hasProps: true, + }, + { + name: "InvalidJSON", + input: `{bad json}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v ArmTemplateParameterAdditionalPropertiesValue + err := v.UnmarshalJSON([]byte(tt.input)) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.hasProps, v.HasAdditionalProperties()) + + if tt.hasProps { + props := v.Properties() + require.NotEmpty(t, props.Type) + } + }) + } +} + +func Test_AdditionalProperties_MarshalJSON(t *testing.T) { + t.Run("NilProps", func(t *testing.T) { + v := ArmTemplateParameterAdditionalPropertiesValue{} + data, err := v.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "false", string(data)) + }) + + t.Run("WithProps", func(t *testing.T) { + v := ArmTemplateParameterAdditionalPropertiesValue{} + err := v.UnmarshalJSON( + []byte(`{"type":"string"}`), + ) + require.NoError(t, err) + + data, err := v.MarshalJSON() + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + require.Equal(t, "string", parsed["type"]) + }) +} + +func Test_AdditionalProperties_RoundTrip(t *testing.T) { + original := `{"type":"object"}` + var v ArmTemplateParameterAdditionalPropertiesValue + err := v.UnmarshalJSON([]byte(original)) + require.NoError(t, err) + + data, err := v.MarshalJSON() + require.NoError(t, err) + + var v2 ArmTemplateParameterAdditionalPropertiesValue + err = v2.UnmarshalJSON(data) + require.NoError(t, err) + require.True(t, v2.HasAdditionalProperties()) + require.Equal(t, "object", v2.Properties().Type) +} + +func Test_UsageName_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want []string + }{ + { + name: "SingleString", + json: `{"usageName": "foo"}`, + want: []string{"foo"}, + }, + { + name: "ArrayOfStrings", + json: `{"usageName": ["foo", "bar"]}`, + want: []string{"foo", "bar"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m AzdMetadata + err := json.Unmarshal([]byte(tt.json), &m) + require.NoError(t, err) + require.Equal(t, tt.want, []string(m.UsageName)) + }) + } +} diff --git a/cli/azd/pkg/azure/resource_ids_additional_test.go b/cli/azd/pkg/azure/resource_ids_additional_test.go new file mode 100644 index 00000000000..860328934b0 --- /dev/null +++ b/cli/azd/pkg/azure/resource_ids_additional_test.go @@ -0,0 +1,290 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_SubscriptionFromRID(t *testing.T) { + tests := []struct { + name string + rid string + want string + panic bool + }{ + { + name: "StandardResourceId", + rid: "/subscriptions/abc-123/resourceGroups/rg/" + + "providers/Microsoft.Web/sites/myapp", + want: "abc-123", + }, + { + name: "SubscriptionOnly", + rid: "/subscriptions/sub-id-456", + want: "sub-id-456", + }, + { + name: "DeploymentResourceId", + rid: "/subscriptions/deadbeef/providers/" + + "Microsoft.Resources/deployments/deploy1", + want: "deadbeef", + }, + { + name: "NoSubscriptionSegment", + rid: "/resourceGroups/rg/providers/Microsoft.Web", + panic: true, + }, + { + name: "EmptyString", + rid: "", + panic: true, + }, + { + name: "SubscriptionsAtEnd", + rid: "/something/subscriptions", + panic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.panic { + require.Panics(t, func() { + SubscriptionFromRID(tt.rid) + }) + return + } + + got := SubscriptionFromRID(tt.rid) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_SubscriptionRID(t *testing.T) { + tests := []struct { + name string + subscriptionId string + want string + }{ + { + name: "Standard", + subscriptionId: "abc-123", + want: "/subscriptions/abc-123", + }, + { + name: "GuidFormat", + subscriptionId: "faa080af-c1d8-40ad-9cce-e1a450ca5b57", + want: "/subscriptions/" + + "faa080af-c1d8-40ad-9cce-e1a450ca5b57", + }, + { + name: "Empty", + subscriptionId: "", + want: "/subscriptions/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SubscriptionRID(tt.subscriptionId) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_SubscriptionDeploymentRID(t *testing.T) { + tests := []struct { + name string + subscriptionId string + deploymentId string + want string + }{ + { + name: "Standard", + subscriptionId: "sub-1", + deploymentId: "deploy-1", + want: "/subscriptions/sub-1/providers/" + + "Microsoft.Resources/deployments/deploy-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SubscriptionDeploymentRID( + tt.subscriptionId, tt.deploymentId, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_ResourceGroupDeploymentRID(t *testing.T) { + tests := []struct { + name string + subscriptionId string + resourceGroupName string + deploymentId string + want string + }{ + { + name: "Standard", + subscriptionId: "sub-1", + resourceGroupName: "rg-1", + deploymentId: "deploy-1", + want: "/subscriptions/sub-1" + + "/resourceGroups/rg-1/providers/" + + "Microsoft.Resources/deployments/deploy-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResourceGroupDeploymentRID( + tt.subscriptionId, + tt.resourceGroupName, + tt.deploymentId, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_ResourceGroupRID(t *testing.T) { + tests := []struct { + name string + subscriptionId string + resourceGroupName string + want string + }{ + { + name: "Standard", + subscriptionId: "sub-1", + resourceGroupName: "my-rg", + want: "/subscriptions/sub-1" + + "/resourceGroups/my-rg", + }, + { + name: "EmptyValues", + subscriptionId: "", + resourceGroupName: "", + want: "/subscriptions//resourceGroups/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResourceGroupRID( + tt.subscriptionId, tt.resourceGroupName, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_WebsiteRID(t *testing.T) { + got := WebsiteRID("sub-1", "rg-1", "mysite") + want := "/subscriptions/sub-1/resourceGroups/rg-1" + + "/providers/Microsoft.Web/sites/mysite" + require.Equal(t, want, got) +} + +func Test_ContainerAppRID(t *testing.T) { + got := ContainerAppRID("sub-1", "rg-1", "myapp") + want := "/subscriptions/sub-1/resourceGroups/rg-1" + + "/providers/Microsoft.App/containerApps/myapp" + require.Equal(t, want, got) +} + +func Test_KubernetesServiceRID(t *testing.T) { + got := KubernetesServiceRID("sub-1", "rg-1", "mycluster") + want := "/subscriptions/sub-1/resourceGroups/rg-1" + + "/providers/Microsoft.ContainerService" + + "/managedClusters/mycluster" + require.Equal(t, want, got) +} + +func Test_StaticWebAppRID(t *testing.T) { + got := StaticWebAppRID("sub-1", "rg-1", "mystaticsite") + want := "/subscriptions/sub-1/resourceGroups/rg-1" + + "/providers/Microsoft.Web/staticSites/mystaticsite" + require.Equal(t, want, got) +} + +func Test_WorkspaceRID(t *testing.T) { + got := WorkspaceRID("sub-1", "rg-1", "myworkspace") + want := "/subscriptions/sub-1/resourceGroups/rg-1" + + "/providers/" + + "Microsoft.MachineLearningServices" + + "/workspaces/myworkspace" + require.Equal(t, want, got) +} + +func Test_RIDRoundTrip(t *testing.T) { + // Verify that SubscriptionFromRID can extract subscription + // from any RID builder output. + subId := "faa080af-c1d8-40ad-9cce-e1a450ca5b57" + + rids := []string{ + WebsiteRID(subId, "rg", "site"), + ContainerAppRID(subId, "rg", "app"), + KubernetesServiceRID(subId, "rg", "aks"), + StaticWebAppRID(subId, "rg", "swa"), + WorkspaceRID(subId, "rg", "ws"), + ResourceGroupRID(subId, "rg"), + SubscriptionDeploymentRID(subId, "dep"), + ResourceGroupDeploymentRID(subId, "rg", "dep"), + } + + for _, rid := range rids { + t.Run(rid, func(t *testing.T) { + got := SubscriptionFromRID(rid) + require.Equal(t, subId, got) + }) + } +} + +func Test_GetResourceGroupNameFromRIDBuilders(t *testing.T) { + tests := []struct { + name string + rid string + want string + }{ + { + name: "FromWebsiteRID", + rid: WebsiteRID("sub", "my-rg", "site"), + want: "my-rg", + }, + { + name: "FromContainerAppRID", + rid: ContainerAppRID("sub", "my-rg", "app"), + want: "my-rg", + }, + { + name: "FromKubernetesRID", + rid: KubernetesServiceRID("sub", "my-rg", "aks"), + want: "my-rg", + }, + { + name: "SubscriptionLevelNoRG", + rid: "/subscriptions/sub/providers/" + + "Microsoft.Resources/deployments/d", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GetResourceGroupName(tt.rid) + if tt.want == "" { + require.Nil(t, got) + } else { + require.NotNil(t, got) + require.Equal(t, tt.want, *got) + } + }) + } +} diff --git a/cli/azd/pkg/azureutil/resource_group_test.go b/cli/azd/pkg/azureutil/resource_group_test.go new file mode 100644 index 00000000000..59706fa2bba --- /dev/null +++ b/cli/azd/pkg/azureutil/resource_group_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azureutil + +import ( + "errors" + "testing" +) + +func TestResourceNotFound_with_error(t *testing.T) { + inner := errors.New("some azure error") + err := ResourceNotFound(inner) + + if err == nil { + t.Fatal("ResourceNotFound should return non-nil error") + } + + expected := "resource not found: some azure error" + if err.Error() != expected { + t.Fatalf("Error() = %q, want %q", err.Error(), expected) + } +} + +func TestResourceNotFound_with_nil_error(t *testing.T) { + err := ResourceNotFound(nil) + + if err == nil { + t.Fatal("ResourceNotFound(nil) should return non-nil error") + } + + expected := "resource not found: " + if err.Error() != expected { + t.Fatalf("Error() = %q, want %q", err.Error(), expected) + } +} + +func TestResourceNotFoundError_is_error_type(t *testing.T) { + inner := errors.New("not found") + err := ResourceNotFound(inner) + + var rnfErr *ResourceNotFoundError + if !errors.As(err, &rnfErr) { + t.Fatal("ResourceNotFound should return *ResourceNotFoundError") + } +} + +func TestResourceNotFoundError_errors_As(t *testing.T) { + inner := errors.New("original cause") + err := ResourceNotFound(inner) + + // Verify that errors.As can extract the ResourceNotFoundError + var target *ResourceNotFoundError + if !errors.As(err, &target) { + t.Fatal("errors.As should find *ResourceNotFoundError") + } + + // The inner error message should be embedded + if target.Error() != "resource not found: original cause" { + t.Fatalf("target.Error() = %q, want %q", target.Error(), "resource not found: original cause") + } +} + +func TestResourceNotFoundError_different_inner_errors(t *testing.T) { + tests := []struct { + name string + inner error + expected string + }{ + { + name: "simple error", + inner: errors.New("not found"), + expected: "resource not found: not found", + }, + { + name: "empty error message", + inner: errors.New(""), + expected: "resource not found: ", + }, + { + name: "nil error", + inner: nil, + expected: "resource not found: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ResourceNotFound(tt.inner) + if err.Error() != tt.expected { + t.Fatalf("Error() = %q, want %q", err.Error(), tt.expected) + } + }) + } +} diff --git a/cli/azd/pkg/cloud/cloud_test.go b/cli/azd/pkg/cloud/cloud_test.go new file mode 100644 index 00000000000..fad8fab68ac --- /dev/null +++ b/cli/azd/pkg/cloud/cloud_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cloud + +import ( + "testing" + + azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAzurePublic(t *testing.T) { + c := AzurePublic() + require.NotNil(t, c) + assert.Equal(t, azcloud.AzurePublic, c.Configuration) + assert.Equal(t, "https://portal.azure.com", c.PortalUrlBase) + assert.Equal(t, "core.windows.net", c.StorageEndpointSuffix) + assert.Equal(t, "azurecr.io", c.ContainerRegistryEndpointSuffix) + assert.Equal(t, "vault.azure.net", c.KeyVaultEndpointSuffix) +} + +func TestAzureGovernment(t *testing.T) { + c := AzureGovernment() + require.NotNil(t, c) + assert.Equal(t, azcloud.AzureGovernment, c.Configuration) + assert.Equal(t, "https://portal.azure.us", c.PortalUrlBase) + assert.Equal( + t, + "core.usgovcloudapi.net", + c.StorageEndpointSuffix, + ) + assert.Equal(t, "azurecr.us", c.ContainerRegistryEndpointSuffix) + assert.Equal( + t, + "vault.usgovcloudapi.net", + c.KeyVaultEndpointSuffix, + ) +} + +func TestAzureChina(t *testing.T) { + c := AzureChina() + require.NotNil(t, c) + assert.Equal(t, azcloud.AzureChina, c.Configuration) + assert.Equal(t, "https://portal.azure.cn", c.PortalUrlBase) + assert.Equal( + t, + "core.chinacloudapi.cn", + c.StorageEndpointSuffix, + ) + assert.Equal(t, "azurecr.cn", c.ContainerRegistryEndpointSuffix) + assert.Equal(t, "vault.azure.cn", c.KeyVaultEndpointSuffix) +} + +func TestNewCloud(t *testing.T) { + tests := []struct { + name string + cloudName string + wantPortal string + wantErr bool + errContains string + }{ + { + name: "AzurePublicByName", + cloudName: AzurePublicName, + wantPortal: "https://portal.azure.com", + }, + { + name: "EmptyNameDefaultsToPublic", + cloudName: "", + wantPortal: "https://portal.azure.com", + }, + { + name: "AzureChinaCloud", + cloudName: AzureChinaCloudName, + wantPortal: "https://portal.azure.cn", + }, + { + name: "AzureUSGovernment", + cloudName: AzureUSGovernmentName, + wantPortal: "https://portal.azure.us", + }, + { + name: "InvalidCloudNameReturnsError", + cloudName: "SomeInvalidCloud", + wantErr: true, + errContains: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{Name: tt.cloudName} + c, err := NewCloud(cfg) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + return + } + require.NoError(t, err) + require.NotNil(t, c) + assert.Equal(t, tt.wantPortal, c.PortalUrlBase) + }) + } +} + +func TestNewCloud_EndpointSuffixes(t *testing.T) { + tests := []struct { + name string + cloudName string + wantStorage string + wantContainerRegistry string + wantKeyVault string + }{ + { + name: "PublicEndpoints", + cloudName: AzurePublicName, + wantStorage: "core.windows.net", + wantContainerRegistry: "azurecr.io", + wantKeyVault: "vault.azure.net", + }, + { + name: "GovernmentEndpoints", + cloudName: AzureUSGovernmentName, + wantStorage: "core.usgovcloudapi.net", + wantContainerRegistry: "azurecr.us", + wantKeyVault: "vault.usgovcloudapi.net", + }, + { + name: "ChinaEndpoints", + cloudName: AzureChinaCloudName, + wantStorage: "core.chinacloudapi.cn", + wantContainerRegistry: "azurecr.cn", + wantKeyVault: "vault.azure.cn", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewCloud(&Config{Name: tt.cloudName}) + require.NoError(t, err) + assert.Equal(t, tt.wantStorage, c.StorageEndpointSuffix) + assert.Equal( + t, + tt.wantContainerRegistry, + c.ContainerRegistryEndpointSuffix, + ) + assert.Equal( + t, + tt.wantKeyVault, + c.KeyVaultEndpointSuffix, + ) + }) + } +} + +func TestParseCloudConfig(t *testing.T) { + tests := []struct { + name string + input any + wantName string + wantErr bool + }{ + { + name: "MapWithName", + input: map[string]string{"name": "AzureCloud"}, + wantName: "AzureCloud", + }, + { + name: "MapWithoutName", + input: map[string]string{"other": "value"}, + wantName: "", + }, + { + name: "EmptyMap", + input: map[string]string{}, + wantName: "", + }, + { + name: "StructWithMatchingField", + input: struct{ Name string }{"AzureChina"}, + wantName: "AzureChina", + }, + { + name: "UnmarshalableChannelInput", + input: make(chan int), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := ParseCloudConfig(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, cfg) + assert.Equal(t, tt.wantName, cfg.Name) + }) + } +} + +func TestParseCloudConfig_NilInput(t *testing.T) { + // json.Marshal(nil) produces "null", which unmarshals to nil *Config + cfg, err := ParseCloudConfig(nil) + require.NoError(t, err) + assert.Nil(t, cfg) +} + +func TestParseCloudConfig_RoundTrip(t *testing.T) { + input := map[string]any{"name": "AzureUSGovernment"} + cfg, err := ParseCloudConfig(input) + require.NoError(t, err) + require.NotNil(t, cfg) + assert.Equal(t, "AzureUSGovernment", cfg.Name) + + c, err := NewCloud(cfg) + require.NoError(t, err) + assert.Equal(t, "https://portal.azure.us", c.PortalUrlBase) +} + +func TestConstants(t *testing.T) { + assert.Equal(t, "cloud", ConfigPath) + assert.Equal(t, "AzureCloud", AzurePublicName) + assert.Equal(t, "AzureChinaCloud", AzureChinaCloudName) + assert.Equal(t, "AzureUSGovernment", AzureUSGovernmentName) +} diff --git a/cli/azd/pkg/cmdsubst/cmdsubst_additional_test.go b/cli/azd/pkg/cmdsubst/cmdsubst_additional_test.go new file mode 100644 index 00000000000..ecd8d259bee --- /dev/null +++ b/cli/azd/pkg/cmdsubst/cmdsubst_additional_test.go @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmdsubst + +import ( + "context" + "errors" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/keyvault" + "github.com/stretchr/testify/require" +) + +// mockKeyVaultService implements keyvault.KeyVaultService +// for testing SecretOrRandomPasswordCommandExecutor. +type mockKeyVaultService struct { + getSecretFn func( + ctx context.Context, + subscriptionId, vaultName, secretName string, + ) (*keyvault.Secret, error) +} + +func (m *mockKeyVaultService) GetKeyVault( + _ context.Context, _, _, _ string, +) (*keyvault.KeyVault, error) { + return nil, nil +} + +func (m *mockKeyVaultService) GetKeyVaultSecret( + ctx context.Context, + subscriptionId, vaultName, secretName string, +) (*keyvault.Secret, error) { + return m.getSecretFn( + ctx, subscriptionId, vaultName, secretName, + ) +} + +func (m *mockKeyVaultService) PurgeKeyVault( + _ context.Context, _, _, _ string, +) error { + return nil +} + +func (m *mockKeyVaultService) ListSubscriptionVaults( + _ context.Context, _ string, +) ([]keyvault.Vault, error) { + return nil, nil +} + +func (m *mockKeyVaultService) CreateVault( + _ context.Context, _, _, _, _, _ string, +) (keyvault.Vault, error) { + return keyvault.Vault{}, nil +} + +func (m *mockKeyVaultService) ListKeyVaultSecrets( + _ context.Context, _, _ string, +) ([]string, error) { + return nil, nil +} + +func (m *mockKeyVaultService) CreateKeyVaultSecret( + _ context.Context, _, _, _, _ string, +) error { + return nil +} + +func (m *mockKeyVaultService) SecretFromAkvs( + _ context.Context, _ string, +) (string, error) { + return "", nil +} + +func Test_SecretOrRandomPassword_WrongCommand(t *testing.T) { + svc := &mockKeyVaultService{} + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), "otherCommand", nil, + ) + require.NoError(t, err) + require.False(t, ran) + require.Empty(t, result) +} + +func Test_SecretOrRandomPassword_NoArgs(t *testing.T) { + svc := &mockKeyVaultService{} + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, nil, + ) + require.NoError(t, err) + require.True(t, ran) + // Should generate a random password + require.NotEmpty(t, result) + require.GreaterOrEqual(t, len(result), 15) +} + +func Test_SecretOrRandomPassword_OneArg(t *testing.T) { + svc := &mockKeyVaultService{} + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"vaultOnly"}, + ) + require.NoError(t, err) + require.True(t, ran) + require.NotEmpty(t, result) +} + +func Test_SecretOrRandomPassword_SecretFound(t *testing.T) { + svc := &mockKeyVaultService{ + getSecretFn: func( + _ context.Context, _, _, _ string, + ) (*keyvault.Secret, error) { + return &keyvault.Secret{ + Value: "my-secret-value", + }, nil + }, + } + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"myVault", "mySecret"}, + ) + require.NoError(t, err) + require.True(t, ran) + require.Equal(t, "my-secret-value", result) +} + +func Test_SecretOrRandomPassword_SecretNotFound(t *testing.T) { + svc := &mockKeyVaultService{ + getSecretFn: func( + _ context.Context, _, _, _ string, + ) (*keyvault.Secret, error) { + return nil, keyvault.ErrAzCliSecretNotFound + }, + } + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"myVault", "missingSecret"}, + ) + require.NoError(t, err) + require.True(t, ran) + // Falls back to random password + require.NotEmpty(t, result) +} + +func Test_SecretOrRandomPassword_EmptySecret(t *testing.T) { + svc := &mockKeyVaultService{ + getSecretFn: func( + _ context.Context, _, _, _ string, + ) (*keyvault.Secret, error) { + return &keyvault.Secret{Value: ""}, nil + }, + } + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"vault", "secret"}, + ) + require.NoError(t, err) + require.True(t, ran) + // Empty secret falls back to random password + require.NotEmpty(t, result) +} + +func Test_SecretOrRandomPassword_WhitespaceSecret(t *testing.T) { + svc := &mockKeyVaultService{ + getSecretFn: func( + _ context.Context, _, _, _ string, + ) (*keyvault.Secret, error) { + return &keyvault.Secret{Value: " "}, nil + }, + } + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"vault", "secret"}, + ) + require.NoError(t, err) + require.True(t, ran) + // Whitespace-only secret falls back to random password + require.NotEmpty(t, result) +} + +func Test_SecretOrRandomPassword_VaultError(t *testing.T) { + svc := &mockKeyVaultService{ + getSecretFn: func( + _ context.Context, _, _, _ string, + ) (*keyvault.Secret, error) { + return nil, errors.New("network error") + }, + } + executor := NewSecretOrRandomPasswordExecutor(svc, "sub1") + + ran, result, err := executor.Run( + context.Background(), + SecretOrRandomPasswordCommandName, + []string{"vault", "secret"}, + ) + require.Error(t, err) + require.False(t, ran) + require.Empty(t, result) + require.Contains(t, err.Error(), "reading secret") +} + +func Test_ContainsCommandInvocation_EmptyCommandName( + t *testing.T, +) { + require.False(t, + ContainsCommandInvocation("$(cmd)", ""), + ) +} + +func Test_ContainsCommandInvocation_EmptyBoth(t *testing.T) { + require.False(t, ContainsCommandInvocation("", "")) +} + +func Test_Eval_MultipleMixedCommands(t *testing.T) { + // First command recognized, second unrecognized + input := "a $(known x) b $(unknown y) c" + expected := "a result b $(unknown y) c" + + result, err := Eval( + context.Background(), input, + testCommandExecutor{ + runImpl: func( + name string, args []string, + ) (bool, string, error) { + if name == "known" { + return true, "result", nil + } + return false, "", nil + }, + }, + ) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func Test_Eval_CommandWithNoArgs(t *testing.T) { + input := "$(noargs)" + result, err := Eval( + context.Background(), input, + testCommandExecutor{ + runImpl: func( + name string, args []string, + ) (bool, string, error) { + require.Equal(t, "noargs", name) + require.Empty(t, args) + return true, "done", nil + }, + }, + ) + require.NoError(t, err) + require.Equal(t, "done", result) +} + +func Test_Eval_AdjacentSubstitutions(t *testing.T) { + input := "$(a)$(b)" + result, err := Eval( + context.Background(), input, + testCommandExecutor{ + runImpl: func( + name string, _ []string, + ) (bool, string, error) { + return true, name, nil + }, + }, + ) + require.NoError(t, err) + require.Equal(t, "ab", result) +} + +func Test_Eval_ErrorOnFirstOfMultiple(t *testing.T) { + input := "$(fail) $(ok)" + _, err := Eval( + context.Background(), input, + testCommandExecutor{ + runImpl: func( + name string, _ []string, + ) (bool, string, error) { + if name == "fail" { + return false, "", + errors.New("first failed") + } + return true, "ok", nil + }, + }, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "first failed") +} diff --git a/cli/azd/pkg/containerapps/container_app_additional_test.go b/cli/azd/pkg/containerapps/container_app_additional_test.go new file mode 100644 index 00000000000..9345f37e707 --- /dev/null +++ b/cli/azd/pkg/containerapps/container_app_additional_test.go @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package containerapps + +import ( + "errors" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateApiVersionPolicy(t *testing.T) { + t.Run("nil options returns nil", func(t *testing.T) { + result := createApiVersionPolicy(nil) + assert.Nil(t, result) + }) + + t.Run("empty api version returns nil", func(t *testing.T) { + result := createApiVersionPolicy( + &ContainerAppOptions{ApiVersion: ""}, + ) + assert.Nil(t, result) + }) + + t.Run("non-empty api version returns policy", + func(t *testing.T) { + opts := &ContainerAppOptions{ + ApiVersion: "2024-02-02-preview", + } + result := createApiVersionPolicy(opts) + require.NotNil(t, result) + assert.Equal( + t, "2024-02-02-preview", result.apiVersion, + ) + }) +} + +func TestWithApiVersionSuggestion(t *testing.T) { + t.Run("wraps error with suggestion", func(t *testing.T) { + original := errors.New("some api error") + wrapped := withApiVersionSuggestion(original) + + require.Error(t, wrapped) + // The error message should be the original + assert.Equal(t, "some api error", wrapped.Error()) + + // Should be an ErrorWithSuggestion + var sugErr *internal.ErrorWithSuggestion + require.True(t, errors.As(wrapped, &sugErr)) + assert.Contains( + t, sugErr.Suggestion, "apiVersion", + ) + assert.Contains( + t, sugErr.Suggestion, "azure.yaml", + ) + }) + + t.Run("underlying error preserved", func(t *testing.T) { + sentinel := errors.New("sentinel") + wrapped := withApiVersionSuggestion(sentinel) + + var sugErr *internal.ErrorWithSuggestion + require.True(t, errors.As(wrapped, &sugErr)) + assert.True(t, errors.Is(sugErr.Err, sentinel)) + }) +} + +func TestContainerAppOptions(t *testing.T) { + t.Run("zero value has empty api version", + func(t *testing.T) { + opts := ContainerAppOptions{} + assert.Equal(t, "", opts.ApiVersion) + }) + + t.Run("api version can be set", func(t *testing.T) { + opts := ContainerAppOptions{ + ApiVersion: "2025-02-02-preview", + } + assert.Equal( + t, "2025-02-02-preview", opts.ApiVersion, + ) + }) +} + +func TestContainerAppIngressConfiguration(t *testing.T) { + t.Run("empty hostnames", func(t *testing.T) { + config := ContainerAppIngressConfiguration{ + HostNames: []string{}, + } + assert.Empty(t, config.HostNames) + }) + + t.Run("single hostname", func(t *testing.T) { + config := ContainerAppIngressConfiguration{ + HostNames: []string{"myapp.azurecontainerapps.io"}, + } + require.Len(t, config.HostNames, 1) + assert.Equal( + t, + "myapp.azurecontainerapps.io", + config.HostNames[0], + ) + }) + + t.Run("multiple hostnames", func(t *testing.T) { + config := ContainerAppIngressConfiguration{ + HostNames: []string{ + "myapp.azurecontainerapps.io", + "custom.domain.com", + }, + } + require.Len(t, config.HostNames, 2) + }) +} diff --git a/cli/azd/pkg/contracts/contracts_test.go b/cli/azd/pkg/contracts/contracts_test.go new file mode 100644 index 00000000000..598db335528 --- /dev/null +++ b/cli/azd/pkg/contracts/contracts_test.go @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package contracts + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRFC3339Time_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input time.Time + expected string + }{ + { + name: "truncates nanoseconds", + input: time.Date(2023, 1, 9, 6, 39, 0, 313323855, time.UTC), + expected: `"2023-01-09T06:39:00Z"`, + }, + { + name: "preserves timezone offset", + input: time.Date(2024, 6, 15, 14, 30, 0, 0, time.FixedZone("EST", -5*3600)), + expected: `"2024-06-15T14:30:00-05:00"`, + }, + { + name: "zero time", + input: time.Time{}, + expected: `"0001-01-01T00:00:00Z"`, + }, + { + name: "epoch", + input: time.Unix(0, 0).UTC(), + expected: `"1970-01-01T00:00:00Z"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := json.Marshal(RFC3339Time(tt.input)) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(result)) + }) + } +} + +func TestRFC3339Time_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + expected time.Time + }{ + { + name: "UTC time", + input: `"2023-01-09T06:39:00Z"`, + expected: time.Date(2023, 1, 9, 6, 39, 0, 0, time.UTC), + }, + { + name: "with timezone offset", + input: `"2024-06-15T14:30:00-05:00"`, + expected: time.Date(2024, 6, 15, 14, 30, 0, 0, time.FixedZone("", -5*3600)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result RFC3339Time + err := json.Unmarshal([]byte(tt.input), &result) + require.NoError(t, err) + assert.True(t, time.Time(result).Equal(tt.expected), + "got %v, want %v", time.Time(result), tt.expected) + }) + } +} + +func TestRFC3339Time_UnmarshalJSON_errors(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "invalid JSON", + input: `not-json`, + }, + { + name: "invalid date format", + input: `"2023-01-09 06:39:00"`, + }, + { + name: "number instead of string", + input: `12345`, + }, + { + name: "empty string", + input: `""`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result RFC3339Time + err := json.Unmarshal([]byte(tt.input), &result) + assert.Error(t, err) + }) + } +} + +func TestRFC3339Time_roundtrip(t *testing.T) { + original := time.Date(2024, 12, 25, 10, 0, 0, 0, time.UTC) + rfc := RFC3339Time(original) + + data, err := json.Marshal(rfc) + require.NoError(t, err) + + var restored RFC3339Time + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + + assert.True(t, time.Time(restored).Equal(original)) +} + +func TestAuthTokenResult_JSON(t *testing.T) { + expiresOn := time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC) + result := AuthTokenResult{ + Token: "test-token-value", //nolint:gosec // test data, not a real credential + ExpiresOn: RFC3339Time(expiresOn), + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "test-token-value", parsed["token"]) + assert.Equal(t, "2024-06-15T12:00:00Z", parsed["expiresOn"]) +} + +func TestAuthTokenResult_JSON_roundtrip(t *testing.T) { + original := AuthTokenResult{ + Token: "test-token", + ExpiresOn: RFC3339Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var restored AuthTokenResult + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + + assert.Equal(t, original.Token, restored.Token) + assert.True(t, time.Time(restored.ExpiresOn).Equal(time.Time(original.ExpiresOn))) +} + +func TestLoginResult_JSON(t *testing.T) { + t.Run("success with expiry", func(t *testing.T) { + expiresOn := time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC) + result := LoginResult{ + Status: LoginStatusSuccess, + ExpiresOn: &expiresOn, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "success", parsed["status"]) + assert.Contains(t, parsed, "expiresOn") + }) + + t.Run("unauthenticated omits expiresOn", func(t *testing.T) { + result := LoginResult{ + Status: LoginStatusUnauthenticated, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "unauthenticated", parsed["status"]) + assert.NotContains(t, parsed, "expiresOn") + }) +} + +func TestStatusResult_JSON(t *testing.T) { + t.Run("authenticated user", func(t *testing.T) { + result := StatusResult{ + Status: AuthStatusAuthenticated, + Type: AccountTypeUser, + Email: "user@example.com", + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "authenticated", parsed["status"]) + assert.Equal(t, "user", parsed["type"]) + assert.Equal(t, "user@example.com", parsed["email"]) + assert.NotContains(t, parsed, "clientId") + }) + + t.Run("authenticated service principal", func(t *testing.T) { + result := StatusResult{ + Status: AuthStatusAuthenticated, + Type: AccountTypeServicePrincipal, + ClientID: "00000000-0000-0000-0000-000000000001", + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "authenticated", parsed["status"]) + assert.Equal(t, "servicePrincipal", parsed["type"]) + assert.Equal(t, "00000000-0000-0000-0000-000000000001", parsed["clientId"]) + assert.NotContains(t, parsed, "email") + }) + + t.Run("unauthenticated omits optional fields", func(t *testing.T) { + result := StatusResult{ + Status: AuthStatusUnauthenticated, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "unauthenticated", parsed["status"]) + assert.NotContains(t, parsed, "type") + assert.NotContains(t, parsed, "email") + assert.NotContains(t, parsed, "clientId") + }) +} + +func TestShowResult_JSON(t *testing.T) { + result := ShowResult{ + Name: "my-app", + Services: map[string]ShowService{ + "api": { + Project: ShowServiceProject{ + Path: "./src/api", + Type: ShowTypeDotNet, + }, + Target: &ShowTargetArm{ + ResourceIds: []string{"/subscriptions/sub1/resourceGroups/rg1"}, + }, + IngresUrl: "https://api.example.com", + }, + }, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "my-app", parsed["name"]) + + // IngresUrl should be excluded (json:"-") + services := parsed["services"].(map[string]any) + api := services["api"].(map[string]any) + assert.NotContains(t, api, "ingresUrl") + assert.NotContains(t, api, "IngresUrl") + + // Project fields should be present + project := api["project"].(map[string]any) + assert.Equal(t, "./src/api", project["path"]) + assert.Equal(t, "dotnet", project["language"]) + + // Target should be present + target := api["target"].(map[string]any) + resourceIds := target["resourceIds"].([]any) + assert.Len(t, resourceIds, 1) +} + +func TestShowService_JSON_nil_target(t *testing.T) { + result := ShowService{ + Project: ShowServiceProject{ + Path: "./src/web", + Type: ShowTypeNode, + }, + Target: nil, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.NotContains(t, parsed, "target") +} + +func TestShowType_values(t *testing.T) { + assert.Equal(t, ShowType(""), ShowTypeNone) + assert.Equal(t, ShowType("dotnet"), ShowTypeDotNet) + assert.Equal(t, ShowType("python"), ShowTypePython) + assert.Equal(t, ShowType("node"), ShowTypeNode) + assert.Equal(t, ShowType("java"), ShowTypeJava) + assert.Equal(t, ShowType("custom"), ShowTypeCustom) +} + +func TestVersionResult_JSON(t *testing.T) { + result := VersionResult{} + result.Azd.Version = "1.5.0" + result.Azd.Commit = "abc123" + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + azd := parsed["azd"].(map[string]any) + assert.Equal(t, "1.5.0", azd["version"]) + assert.Equal(t, "abc123", azd["commit"]) +} + +func TestVsServerResult_JSON(t *testing.T) { + t.Run("with certificate", func(t *testing.T) { + cert := "MIIC+zCCAeOgAwIBAgIJAL..." + result := VsServerResult{ + Port: 5000, + Pid: 12345, + CertificateBytes: &cert, + } + result.Azd.Version = "1.5.0" + result.Azd.Commit = "abc123" + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, float64(5000), parsed["port"]) + assert.Equal(t, float64(12345), parsed["pid"]) + assert.Equal(t, "MIIC+zCCAeOgAwIBAgIJAL...", parsed["certificateBytes"]) + + // Embedded VersionResult + azd := parsed["azd"].(map[string]any) + assert.Equal(t, "1.5.0", azd["version"]) + assert.Equal(t, "abc123", azd["commit"]) + }) + + t.Run("without certificate omits field", func(t *testing.T) { + result := VsServerResult{ + Port: 5000, + Pid: 12345, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.NotContains(t, parsed, "certificateBytes") + }) +} + +func TestConsoleMessage_JSON(t *testing.T) { + msg := ConsoleMessage{Message: "Deploying services..."} + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "Deploying services...", parsed["message"]) +} + +func TestEventEnvelope_JSON(t *testing.T) { + ts := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + envelope := EventEnvelope{ + Type: ConsoleMessageEventDataType, + Timestamp: ts, + Data: ConsoleMessage{Message: "hello"}, + } + + data, err := json.Marshal(envelope) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "consoleMessage", parsed["type"]) + assert.Contains(t, parsed["timestamp"], "2024-01-15") +} + +func TestEnvListEnvironment_JSON(t *testing.T) { + env := EnvListEnvironment{ + Name: "dev", + IsDefault: true, + DotEnvPath: "/home/user/.azure/dev/.env", + ConfigPath: "/home/user/.azure/dev/config.json", + } + + data, err := json.Marshal(env) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + // Note: keys use uppercase per json tags + assert.Equal(t, "dev", parsed["Name"]) + assert.Equal(t, true, parsed["IsDefault"]) + assert.Equal(t, "/home/user/.azure/dev/.env", parsed["DotEnvPath"]) + assert.Equal(t, "/home/user/.azure/dev/config.json", parsed["ConfigPath"]) +} + +func TestEnvRefreshResult_JSON(t *testing.T) { + result := EnvRefreshResult{ + Outputs: map[string]EnvRefreshOutputParameter{ + "endpoint": { + Type: EnvRefreshOutputTypeString, + Value: "https://app.example.com", + }, + "port": { + Type: EnvRefreshOutputTypeNumber, + Value: 8080, + }, + "enabled": { + Type: EnvRefreshOutputTypeBoolean, + Value: true, + }, + }, + Resources: []EnvRefreshResource{ + {Id: "/subscriptions/sub1/resourceGroups/rg1"}, + {Id: "/subscriptions/sub1/resourceGroups/rg2"}, + }, + } + + data, err := json.Marshal(result) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + outputs := parsed["outputs"].(map[string]any) + assert.Len(t, outputs, 3) + + endpoint := outputs["endpoint"].(map[string]any) + assert.Equal(t, "string", endpoint["type"]) + assert.Equal(t, "https://app.example.com", endpoint["value"]) + + resources := parsed["resources"].([]any) + assert.Len(t, resources, 2) +} + +func TestEnvRefreshOutputType_values(t *testing.T) { + assert.Equal(t, EnvRefreshOutputType("boolean"), EnvRefreshOutputTypeBoolean) + assert.Equal(t, EnvRefreshOutputType("string"), EnvRefreshOutputTypeString) + assert.Equal(t, EnvRefreshOutputType("number"), EnvRefreshOutputTypeNumber) + assert.Equal(t, EnvRefreshOutputType("object"), EnvRefreshOutputTypeObject) + assert.Equal(t, EnvRefreshOutputType("array"), EnvRefreshOutputTypeArray) +} + +func TestLoginStatus_values(t *testing.T) { + assert.Equal(t, LoginStatus("success"), LoginStatusSuccess) + assert.Equal(t, LoginStatus("unauthenticated"), LoginStatusUnauthenticated) +} + +func TestAuthStatus_values(t *testing.T) { + assert.Equal(t, AuthStatus("authenticated"), AuthStatusAuthenticated) + assert.Equal(t, AuthStatus("unauthenticated"), AuthStatusUnauthenticated) +} + +func TestAccountType_values(t *testing.T) { + assert.Equal(t, AccountType("user"), AccountTypeUser) + assert.Equal(t, AccountType("servicePrincipal"), AccountTypeServicePrincipal) +} diff --git a/cli/azd/pkg/convert/util_additional_test.go b/cli/azd/pkg/convert/util_additional_test.go new file mode 100644 index 00000000000..623d5918ff5 --- /dev/null +++ b/cli/azd/pkg/convert/util_additional_test.go @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package convert + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_ToJsonArray_ValidSlice(t *testing.T) { + input := []string{"alpha", "bravo", "charlie"} + result, err := ToJsonArray(input) + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, "alpha", result[0]) + require.Equal(t, "bravo", result[1]) + require.Equal(t, "charlie", result[2]) +} + +func Test_ToJsonArray_IntSlice(t *testing.T) { + input := []int{1, 2, 3} + result, err := ToJsonArray(input) + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, float64(1), result[0]) + require.Equal(t, float64(2), result[1]) + require.Equal(t, float64(3), result[2]) +} + +func Test_ToJsonArray_StructSlice(t *testing.T) { + input := []Person{ + {Name: "Alice", Address: "123 Main St"}, + {Name: "Bob", Address: "456 Oak Ave"}, + } + result, err := ToJsonArray(input) + require.NoError(t, err) + require.Len(t, result, 2) + + first, ok := result[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "Alice", first["Name"]) +} + +func Test_ToJsonArray_Nil(t *testing.T) { + result, err := ToJsonArray(nil) + require.NoError(t, err) + require.Nil(t, result) +} + +func Test_ToJsonArray_EmptySlice(t *testing.T) { + input := []string{} + result, err := ToJsonArray(input) + require.NoError(t, err) + require.Empty(t, result) +} + +func Test_ToJsonArray_NonSlice(t *testing.T) { + // A non-slice value that marshals to JSON but can't + // unmarshal into []any + input := map[string]string{"key": "value"} + _, err := ToJsonArray(input) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to convert") +} + +func Test_ToJsonArray_UnmarshalableInput(t *testing.T) { + // A channel can't be marshalled to JSON + input := make(chan int) + _, err := ToJsonArray(input) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to convert") +} + +func Test_FromHttpResponse_ValidJSON(t *testing.T) { + body := `{"Name":"Alice","Address":"Wonderland"}` + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + var person Person + err := FromHttpResponse(resp, &person) + require.NoError(t, err) + require.Equal(t, "Alice", person.Name) + require.Equal(t, "Wonderland", person.Address) +} + +func Test_FromHttpResponse_InvalidJSON(t *testing.T) { + body := `not-json` + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + var person Person + err := FromHttpResponse(resp, &person) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal") +} + +func Test_FromHttpResponse_EmptyBody(t *testing.T) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(""))), + } + + var result map[string]any + err := FromHttpResponse(resp, &result) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to unmarshal") +} + +type errReader struct{} + +func (e *errReader) Read(_ []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +func (e *errReader) Close() error { return nil } + +func Test_FromHttpResponse_ReadError(t *testing.T) { + resp := &http.Response{ + StatusCode: 200, + Body: &errReader{}, + } + + var result map[string]any + err := FromHttpResponse(resp, &result) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to read") +} + +func Test_FromHttpResponse_Array(t *testing.T) { + body := `[{"Name":"A","Address":"1"},{"Name":"B","Address":"2"}]` + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte(body))), + } + + var people []Person + err := FromHttpResponse(resp, &people) + require.NoError(t, err) + require.Len(t, people, 2) + require.Equal(t, "A", people[0].Name) + require.Equal(t, "B", people[1].Name) +} + +func Test_ToMap_Nil(t *testing.T) { + result, err := ToMap(nil) + require.NoError(t, err) + require.Nil(t, result) +} + +func Test_ToMap_MapInput(t *testing.T) { + input := map[string]string{"key": "val"} + result, err := ToMap(input) + require.NoError(t, err) + require.Equal(t, "val", result["key"]) +} + +func Test_ToMap_UnmarshalableInput(t *testing.T) { + input := make(chan int) + _, err := ToMap(input) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to convert") +} + +func Test_ParseDuration_ISO8601Formats(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"Seconds", "PT5S", "5s"}, + {"Minutes", "PT10M", "10m0s"}, + {"Hours", "PT2H", "2h0m0s"}, + {"Combined", "PT1H30M", "1h30m0s"}, + {"Fractional", "PT0.5S", "500ms"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + d, err := ParseDuration(tc.input) + require.NoError(t, err) + require.Equal(t, tc.expected, d.String()) + }) + } +} + +func Test_ParseDuration_AlreadyLowercase(t *testing.T) { + d, err := ParseDuration("10s") + require.NoError(t, err) + require.Equal(t, "10s", d.String()) +} + +func Test_ParseDuration_Invalid(t *testing.T) { + _, err := ParseDuration("not-a-duration") + require.Error(t, err) +} + +func Test_ToValueWithDefault_Bool(t *testing.T) { + trueVal := true + result := ToValueWithDefault(&trueVal, false) + require.True(t, result) +} + +func Test_ToValueWithDefault_BoolNil(t *testing.T) { + result := ToValueWithDefault[bool](nil, true) + require.True(t, result) +} + +func Test_ToStringWithDefault_IntPointer(t *testing.T) { + // An *int is not *string, so should return default + val := 42 + result := ToStringWithDefault(&val, "fallback") + require.Equal(t, "fallback", result) +} + +func Test_ToStringWithDefault_EmptyStringPointer(t *testing.T) { + empty := "" + result := ToStringWithDefault(&empty, "default") + require.Equal(t, "default", result) +} + +func Test_ToMap_NestedStruct(t *testing.T) { + type Inner struct { + Value int + } + type Outer struct { + Name string + Inner Inner + } + input := Outer{Name: "test", Inner: Inner{Value: 42}} + result, err := ToMap(input) + require.NoError(t, err) + require.Equal(t, "test", result["Name"]) + inner, ok := result["Inner"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(42), inner["Value"]) +} + +func Test_ToJsonArray_NestedSlice(t *testing.T) { + input := [][]int{{1, 2}, {3, 4}} + result, err := ToJsonArray(input) + require.NoError(t, err) + require.Len(t, result, 2) + first, ok := result[0].([]any) + require.True(t, ok) + require.Len(t, first, 2) +} + +func Test_FromHttpResponse_ClosesBody(t *testing.T) { + closed := false + body := io.NopCloser( + bytes.NewReader([]byte(`{"Name":"X"}`)), + ) + resp := &http.Response{ + StatusCode: 200, + Body: &trackingCloser{ + ReadCloser: body, + onClose: func() { closed = true }, + }, + } + + var p Person + err := FromHttpResponse(resp, &p) + require.NoError(t, err) + require.True(t, closed) +} + +type trackingCloser struct { + io.ReadCloser + onClose func() +} + +func (tc *trackingCloser) Close() error { + tc.onClose() + return tc.ReadCloser.Close() +} + +func Test_FromHttpResponse_LargePayload(t *testing.T) { + items := make([]Person, 100) + for i := range items { + items[i] = Person{ + Name: "Person", + Address: "Address", + } + } + data, err := json.Marshal(items) + require.NoError(t, err) + + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser( + bytes.NewReader(data), + ), + } + + var result []Person + err = FromHttpResponse(resp, &result) + require.NoError(t, err) + require.Len(t, result, 100) +} diff --git a/cli/azd/pkg/custommaps/with_order_map_additional_test.go b/cli/azd/pkg/custommaps/with_order_map_additional_test.go new file mode 100644 index 00000000000..43faa2ede6e --- /dev/null +++ b/cli/azd/pkg/custommaps/with_order_map_additional_test.go @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package custommaps + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWithOrder_OrderedValues(t *testing.T) { + type Item struct { + Name string `json:"name"` + } + m := &WithOrder[Item]{} + data := `{ + "first": {"name": "alpha"}, + "second": {"name": "bravo"}, + "third": {"name": "charlie"} + }` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + values := m.OrderedValues() + require.Len(t, values, 3) + require.Equal(t, "alpha", values[0].Name) + require.Equal(t, "bravo", values[1].Name) + require.Equal(t, "charlie", values[2].Name) +} + +func TestWithOrder_Get_ExistingKey(t *testing.T) { + type Item struct { + Value int `json:"value"` + } + m := &WithOrder[Item]{} + data := `{"a": {"value": 1}, "b": {"value": 2}}` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + val, ok := m.Get("a") + require.True(t, ok) + require.NotNil(t, val) + require.Equal(t, 1, val.Value) + + val, ok = m.Get("b") + require.True(t, ok) + require.Equal(t, 2, val.Value) +} + +func TestWithOrder_Get_MissingKey(t *testing.T) { + m := &WithOrder[struct{}]{} + err := json.Unmarshal([]byte(`{"x": {}}`), m) + require.NoError(t, err) + + val, ok := m.Get("missing") + require.False(t, ok) + require.Nil(t, val) +} + +func TestWithOrder_EmptyObject(t *testing.T) { + m := &WithOrder[struct{}]{} + err := json.Unmarshal([]byte(`{}`), m) + require.NoError(t, err) + + require.Empty(t, m.OrderedKeys()) + require.Empty(t, m.OrderedValues()) +} + +func TestWithOrder_SingleEntry(t *testing.T) { + type Item struct { + ID string `json:"id"` + } + m := &WithOrder[Item]{} + err := json.Unmarshal([]byte(`{"only": {"id": "one"}}`), m) + require.NoError(t, err) + + keys := m.OrderedKeys() + require.Len(t, keys, 1) + require.Equal(t, "only", keys[0]) + + values := m.OrderedValues() + require.Len(t, values, 1) + require.Equal(t, "one", values[0].ID) +} + +func TestWithOrder_OrderPreserved(t *testing.T) { + m := &WithOrder[struct{}]{} + data := `{"z": {}, "a": {}, "m": {}, "b": {}}` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + keys := m.OrderedKeys() + require.Equal(t, []string{"z", "a", "m", "b"}, keys) +} + +func TestWithOrder_InvalidJSON(t *testing.T) { + m := &WithOrder[struct{}]{} + err := json.Unmarshal([]byte(`not-json`), m) + require.Error(t, err) +} + +func TestWithOrder_ArrayInsteadOfObject(t *testing.T) { + m := &WithOrder[struct{}]{} + err := json.Unmarshal([]byte(`[1,2,3]`), m) + require.Error(t, err) +} + +func TestWithOrder_NestedValues(t *testing.T) { + type Nested struct { + Items []string `json:"items"` + } + m := &WithOrder[Nested]{} + data := `{ + "group1": {"items": ["a", "b"]}, + "group2": {"items": ["c"]} + }` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + g1, ok := m.Get("group1") + require.True(t, ok) + require.Equal(t, []string{"a", "b"}, g1.Items) + + g2, ok := m.Get("group2") + require.True(t, ok) + require.Equal(t, []string{"c"}, g2.Items) +} + +func TestWithOrder_ValuesMatchKeys(t *testing.T) { + type Item struct { + Name string `json:"name"` + } + m := &WithOrder[Item]{} + data := `{ + "x": {"name": "X"}, + "y": {"name": "Y"}, + "z": {"name": "Z"} + }` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + keys := m.OrderedKeys() + values := m.OrderedValues() + require.Len(t, keys, len(values)) + + for i, key := range keys { + got, ok := m.Get(key) + require.True(t, ok) + require.Equal(t, got, values[i]) + } +} + +func TestWithOrder_NullValues(t *testing.T) { + type Item struct { + Name string `json:"name"` + } + m := &WithOrder[Item]{} + data := `{"a": null, "b": {"name": "B"}}` + err := json.Unmarshal([]byte(data), m) + require.NoError(t, err) + + keys := m.OrderedKeys() + require.Equal(t, []string{"a", "b"}, keys) + + aVal, ok := m.Get("a") + require.True(t, ok) + require.Nil(t, aVal) + + bVal, ok := m.Get("b") + require.True(t, ok) + require.Equal(t, "B", bVal.Name) +} diff --git a/cli/azd/pkg/devcentersdk/devcentersdk_test.go b/cli/azd/pkg/devcentersdk/devcentersdk_test.go new file mode 100644 index 00000000000..46f1d6fc501 --- /dev/null +++ b/cli/azd/pkg/devcentersdk/devcentersdk_test.go @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package devcentersdk + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewResourceId(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + subscriptionId string + resourceGroup string + provider string + resourcePath string + resourceName string + }{ + { + name: "valid resource ID", + input: "/subscriptions/sub-123/resourceGroups/rg-test" + + "/providers/Microsoft.DevCenter" + + "/devcenters/my-devcenter", + subscriptionId: "sub-123", + resourceGroup: "rg-test", + provider: "Microsoft.DevCenter", + resourcePath: "devcenters", + resourceName: "my-devcenter", + }, + { + name: "GUID subscription", + input: "/subscriptions" + + "/00000000-0000-0000-0000-000000000000" + + "/resourceGroups/my-rg" + + "/providers/Microsoft.Web/sites/my-app", + subscriptionId: "00000000-0000-0000-0000-000000000000", + resourceGroup: "my-rg", + provider: "Microsoft.Web", + resourcePath: "sites", + resourceName: "my-app", + }, + { + name: "empty string", + input: "", + expectError: true, + }, + { + name: "invalid format", + input: "not-a-resource-id", + expectError: true, + }, + { + name: "missing resource name", + input: "/subscriptions/sub-123" + + "/resourceGroups/rg-test" + + "/providers/Microsoft.DevCenter", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewResourceId(tt.input) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.input, result.Id) + assert.Equal( + t, tt.subscriptionId, result.SubscriptionId, + ) + assert.Equal(t, tt.resourceGroup, result.ResourceGroup) + assert.Equal(t, tt.provider, result.Provider) + assert.Equal(t, tt.resourcePath, result.ResourcePath) + assert.Equal(t, tt.resourceName, result.ResourceName) + }) + } +} + +func TestNewResourceGroupId(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + subscriptionId string + rgName string + }{ + { + name: "valid resource group ID", + input: "/subscriptions/sub-123" + + "/resourceGroups/my-rg", + subscriptionId: "sub-123", + rgName: "my-rg", + }, + { + name: "GUID subscription", + input: "/subscriptions" + + "/00000000-0000-0000-0000-000000000000" + + "/resourceGroups/production-rg", + subscriptionId: "00000000-0000-0000-0000-000000000000", + rgName: "production-rg", + }, + { + name: "empty string", + input: "", + expectError: true, + }, + { + name: "invalid format", + input: "not-a-resource-group-id", + expectError: true, + }, + { + name: "subscription only", + input: "/subscriptions/sub-123", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewResourceGroupId(tt.input) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.input, result.Id) + assert.Equal( + t, tt.subscriptionId, result.SubscriptionId, + ) + assert.Equal(t, tt.rgName, result.Name) + }) + } +} + +func TestNewApiVersionPolicy(t *testing.T) { + t.Run("nil version uses default", func(t *testing.T) { + policy := NewApiVersionPolicy(nil) + require.NotNil(t, policy) + }) + + t.Run("custom version", func(t *testing.T) { + version := "2023-10-01" + policy := NewApiVersionPolicy(&version) + require.NotNil(t, policy) + }) +} + +func TestParameterTypes(t *testing.T) { + assert.Equal( + t, ParameterType("string"), ParameterTypeString, + ) + assert.Equal(t, ParameterType("int"), ParameterTypeInt) + assert.Equal(t, ParameterType("bool"), ParameterTypeBool) +} + +func TestProvisioningStates(t *testing.T) { + assert.Equal( + t, ProvisioningState("Succeeded"), + ProvisioningStateSucceeded, + ) + assert.Equal( + t, ProvisioningState("Creating"), + ProvisioningStateCreating, + ) + assert.Equal( + t, ProvisioningState("Deleting"), + ProvisioningStateDeleting, + ) +} + +func TestOutputParameterTypes(t *testing.T) { + assert.Equal( + t, OutputParameterType("array"), OutputParameterTypeArray, + ) + assert.Equal( + t, OutputParameterType("boolean"), + OutputParameterTypeBoolean, + ) + assert.Equal( + t, OutputParameterType("number"), + OutputParameterTypeNumber, + ) + assert.Equal( + t, OutputParameterType("object"), + OutputParameterTypeObject, + ) + assert.Equal( + t, OutputParameterType("string"), + OutputParameterTypeString, + ) +} + +func TestServiceConfig(t *testing.T) { + assert.Equal( + t, + "https://management.core.windows.net", + ServiceConfig.Audience, + ) +} diff --git a/cli/azd/pkg/github/remote_coverage_test.go b/cli/azd/pkg/github/remote_coverage_test.go new file mode 100644 index 00000000000..40eafa5b8df --- /dev/null +++ b/cli/azd/pkg/github/remote_coverage_test.go @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package github + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetSlugForRemote_AdditionalCases(t *testing.T) { + tests := []struct { + name string + remote string + want string + wantErr bool + }{ + { + name: "SSHWithHyphensInOrgAndRepo", + remote: "git@github.com:my-org/my-repo.git", + want: "my-org/my-repo", + }, + { + name: "SSHWithUnderscoresInRepo", + remote: "git@github.com:org/my_repo.git", + want: "org/my_repo", + }, + { + name: "HTTPSWithDotsInRepoName", + remote: "https://github.com/org/repo.v2.git", + want: "org/repo.v2", + }, + { + name: "SSHWithDotsNoGitSuffix", + remote: "git@github.com:org/repo.v2", + want: "org/repo.v2", + }, + { + name: "HTTPSNestedPath", + remote: "https://github.com/org/repo/sub/path", + want: "org/repo/sub/path", + }, + { + name: "SSHNestedPath", + remote: "git@github.com:org/repo/sub/path", + want: "org/repo/sub/path", + }, + { + name: "GitLabSSH", + remote: "git@gitlab.com:org/repo.git", + wantErr: true, + }, + { + name: "BitbucketHTTPS", + remote: "https://bitbucket.org/org/repo.git", + wantErr: true, + }, + { + name: "AzureDevOpsSSH", + remote: "git@ssh.dev.azure.com:v3/org/proj/repo", + wantErr: true, + }, + { + name: "HTTPNotHTTPS", + remote: "http://github.com/org/repo.git", + wantErr: true, + }, + { + name: "GitHubEnterpriseHTTPS", + remote: "https://github.example.com/org/repo.git", + wantErr: true, + }, + { + name: "TrailingSlashReturnsEmptySlug", + remote: "https://github.com/", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slug, err := GetSlugForRemote(tt.remote) + if tt.wantErr { + require.Error(t, err) + assert.ErrorIs( + t, err, ErrRemoteHostIsNotGitHub, + ) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, slug) + }) + } +} + +func TestErrRemoteHostIsNotGitHub_SentinelError(t *testing.T) { + // Verify the sentinel error works with errors.Is + _, err := GetSlugForRemote("https://not-github.com/o/r") + require.Error(t, err) + assert.True(t, errors.Is(err, ErrRemoteHostIsNotGitHub)) + + // Verify the error message is meaningful + assert.Equal( + t, + "not a github host", + ErrRemoteHostIsNotGitHub.Error(), + ) +} + +func TestGetSlugForRemote_WWWVariant(t *testing.T) { + // www.github.com should work the same as github.com + tests := []struct { + name string + remote string + want string + }{ + { + name: "WWWWithGitSuffix", + remote: "https://www.github.com/org/repo.git", + want: "org/repo", + }, + { + name: "WWWWithoutGitSuffix", + remote: "https://www.github.com/org/repo", + want: "org/repo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slug, err := GetSlugForRemote(tt.remote) + require.NoError(t, err) + assert.Equal(t, tt.want, slug) + }) + } +} diff --git a/cli/azd/pkg/httputil/util_coverage_test.go b/cli/azd/pkg/httputil/util_coverage_test.go new file mode 100644 index 00000000000..66dfb907327 --- /dev/null +++ b/cli/azd/pkg/httputil/util_coverage_test.go @@ -0,0 +1,251 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package httputil + +import ( + "bytes" + "errors" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// errReader is an io.Reader that always returns an error. +type errReader struct{} + +func (e errReader) Read(_ []byte) (int, error) { + return 0, errors.New("simulated read error") +} + +// coveragePayload is a simple struct for generic deserialization tests. +type coveragePayload struct { + Value string `json:"value"` +} + +func TestReadRawResponse_InvalidJSON(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser( + bytes.NewBufferString("not valid json"), + ), + } + + result, err := ReadRawResponse[coveragePayload](resp) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed unmarshalling JSON") +} + +func TestReadRawResponse_ReadError(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(errReader{}), + } + + result, err := ReadRawResponse[coveragePayload](resp) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "simulated read error") +} + +func TestReadRawResponse_EmptyBody(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewBufferString("")), + } + + result, err := ReadRawResponse[coveragePayload](resp) + require.Error(t, err) + assert.Nil(t, result) +} + +func TestRetryAfter(t *testing.T) { + tests := []struct { + name string + resp *http.Response + want time.Duration + wantZero bool + }{ + { + name: "NilResponse", + resp: nil, + wantZero: true, + }, + { + name: "NoRetryHeaders", + resp: &http.Response{ + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + }, + wantZero: true, + }, + { + name: "EmptyHeaders", + resp: &http.Response{ + Header: http.Header{}, + }, + wantZero: true, + }, + { + name: "RetryAfterMs", + resp: &http.Response{ + Header: http.Header{ + "Retry-After-Ms": {"150"}, + }, + }, + want: 150 * time.Millisecond, + }, + { + name: "XMsRetryAfterMs", + resp: &http.Response{ + Header: http.Header{ + "X-Ms-Retry-After-Ms": {"250"}, + }, + }, + want: 250 * time.Millisecond, + }, + { + name: "RetryAfterSeconds", + resp: &http.Response{ + Header: http.Header{ + "Retry-After": {"3"}, + }, + }, + want: 3 * time.Second, + }, + { + name: "RetryAfterMsHasHighestPrecedence", + resp: &http.Response{ + Header: http.Header{ + "Retry-After-Ms": {"100"}, + "X-Ms-Retry-After-Ms": {"200"}, + "Retry-After": {"5"}, + }, + }, + want: 100 * time.Millisecond, + }, + { + name: "XMsRetryAfterMsPrecedesRetryAfter", + resp: &http.Response{ + Header: http.Header{ + "X-Ms-Retry-After-Ms": {"300"}, + "Retry-After": {"10"}, + }, + }, + want: 300 * time.Millisecond, + }, + { + name: "InvalidNonNumericRetryAfter", + resp: &http.Response{ + Header: http.Header{ + "Retry-After": {"not-a-number"}, + }, + }, + wantZero: true, + }, + { + name: "ZeroValueRetryAfter", + resp: &http.Response{ + Header: http.Header{ + "Retry-After": {"0"}, + }, + }, + wantZero: true, + }, + { + name: "NegativeRetryAfter", + resp: &http.Response{ + Header: http.Header{ + "Retry-After": {"-5"}, + }, + }, + wantZero: true, + }, + { + name: "ZeroValueRetryAfterMs", + resp: &http.Response{ + Header: http.Header{ + "Retry-After-Ms": {"0"}, + }, + }, + wantZero: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := RetryAfter(tt.resp) + if tt.wantZero { + assert.Equal(t, time.Duration(0), d) + return + } + assert.Equal(t, tt.want, d) + }) + } +} + +func TestRetryAfter_RFC1123DateFormat(t *testing.T) { + futureTime := time.Now().Add(30 * time.Second) + resp := &http.Response{ + Header: http.Header{ + "Retry-After": { + futureTime.UTC().Format(time.RFC1123), + }, + }, + } + + d := RetryAfter(resp) + // Should be close to 30s with some tolerance for test execution + assert.Greater(t, d, 25*time.Second) + assert.Less(t, d, 35*time.Second) +} + +func TestRetryAfter_PastDateReturnsZero(t *testing.T) { + pastTime := time.Now().Add(-60 * time.Second) + resp := &http.Response{ + Header: http.Header{ + "Retry-After": { + pastTime.UTC().Format(time.RFC1123), + }, + }, + } + + d := RetryAfter(resp) + assert.Equal(t, time.Duration(0), d) +} + +func TestRetryAfter_InvalidDateFormat(t *testing.T) { + resp := &http.Response{ + Header: http.Header{ + "Retry-After": {"Mon, 99 Abc 9999 99:99:99 ZZZ"}, + }, + } + + d := RetryAfter(resp) + assert.Equal(t, time.Duration(0), d) +} + +func TestTlsEnabledTransport_InvalidBase64(t *testing.T) { + transport, err := TlsEnabledTransport("not-valid-base64!!!") + require.Error(t, err) + assert.Nil(t, transport) + assert.Contains(t, err.Error(), "failed to decode") +} + +func TestTlsEnabledTransport_InvalidCertBytes(t *testing.T) { + // Valid base64 encoding of "hello world" — not a real DER cert + validBase64InvalidCert := "aGVsbG8gd29ybGQ=" + transport, err := TlsEnabledTransport(validBase64InvalidCert) + require.Error(t, err) + assert.Nil(t, transport) + assert.Contains(t, err.Error(), "failed to parse") +} diff --git a/cli/azd/pkg/infra/deployment_manager_test.go b/cli/azd/pkg/infra/deployment_manager_test.go new file mode 100644 index 00000000000..ab70c382398 --- /dev/null +++ b/cli/azd/pkg/infra/deployment_manager_test.go @@ -0,0 +1,376 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package infra + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/azure/azure-dev/cli/azd/pkg/azapi" + "github.com/azure/azure-dev/cli/azd/pkg/azure" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeDeploymentService is a minimal stub that satisfies +// azapi.DeploymentService for the subset used by +// DeploymentManager unit tests. +type fakeDeploymentService struct { + azapi.DeploymentService + generateName func(string) string + calcHash func() (string, error) +} + +func (f *fakeDeploymentService) GenerateDeploymentName( + base string, +) string { + if f.generateName != nil { + return f.generateName(base) + } + return base + "-generated" +} + +func (f *fakeDeploymentService) CalculateTemplateHash( + _ context.Context, + _ string, + _ azure.RawArmTemplate, +) (string, error) { + if f.calcHash != nil { + return f.calcHash() + } + return "abc123", nil +} + +func TestNewDeploymentManager(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + require.NotNil(t, dm) +} + +func TestGenerateDeploymentName(t *testing.T) { + tests := []struct { + name string + base string + expected string + }{ + { + name: "simple base", + base: "myenv", + expected: "myenv-generated", + }, + { + name: "empty base", + base: "", + expected: "-generated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + result := dm.GenerateDeploymentName(tt.base) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCalculateTemplateHash(t *testing.T) { + t.Run("success", func(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{ + calcHash: func() (string, error) { + return "hash-xyz", nil + }, + }, nil, nil, + ) + hash, err := dm.CalculateTemplateHash( + context.Background(), + "sub-1", + azure.RawArmTemplate("{}"), + ) + require.NoError(t, err) + assert.Equal(t, "hash-xyz", hash) + }) + + t.Run("error", func(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{ + calcHash: func() (string, error) { + return "", fmt.Errorf("hash failed") + }, + }, nil, nil, + ) + _, err := dm.CalculateTemplateHash( + context.Background(), + "sub-1", + azure.RawArmTemplate("{}"), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "hash failed") + }) +} + +func TestSubscriptionScope(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + scope := dm.SubscriptionScope("sub-1", "eastus") + require.NotNil(t, scope) + assert.Equal(t, "sub-1", scope.SubscriptionId()) + assert.Equal(t, "eastus", scope.Location()) +} + +func TestResourceGroupScope(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + scope := dm.ResourceGroupScope("sub-1", "rg-1") + require.NotNil(t, scope) + assert.Equal(t, "sub-1", scope.SubscriptionId()) + assert.Equal(t, "rg-1", scope.ResourceGroupName()) +} + +func TestSubscriptionDeploymentFactory(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + scope := dm.SubscriptionScope("sub-1", "eastus") + deployment := dm.SubscriptionDeployment(scope, "dep-1") + require.NotNil(t, deployment) + assert.Equal(t, "dep-1", deployment.Name()) + assert.Equal(t, "sub-1", deployment.SubscriptionId()) +} + +func TestResourceGroupDeploymentFactory(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + scope := dm.ResourceGroupScope("sub-1", "rg-1") + deployment := dm.ResourceGroupDeployment(scope, "dep-1") + require.NotNil(t, deployment) + assert.Equal(t, "dep-1", deployment.Name()) + assert.Equal(t, "rg-1", deployment.ResourceGroupName()) +} + +func TestProgressDisplay(t *testing.T) { + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + scope := dm.SubscriptionScope("sub-1", "eastus") + deployment := dm.SubscriptionDeployment(scope, "dep-1") + display := dm.ProgressDisplay(deployment) + require.NotNil(t, display) +} + +// fakeScope stubs Scope for CompletedDeployments tests. +type fakeScope struct { + subscriptionId string + deployments []*azapi.ResourceDeployment + err error +} + +func (f *fakeScope) SubscriptionId() string { + return f.subscriptionId +} + +func (f *fakeScope) ListDeployments( + _ context.Context, +) ([]*azapi.ResourceDeployment, error) { + return f.deployments, f.err +} + +func (f *fakeScope) Deployment(name string) Deployment { + return nil +} + +func TestCompletedDeployments(t *testing.T) { + now := time.Now().UTC() + + t.Run("matches by env tag", func(t *testing.T) { + envName := "myenv" + tagVal := envName + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{ + { + Name: "deploy-1", + Tags: map[string]*string{ + azure.TagKeyAzdEnvName: &tagVal, + }, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now, + }, + }, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + result, err := dm.CompletedDeployments( + context.Background(), scope, envName, "", "", + ) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, "deploy-1", result[0].Name) + }) + + t.Run("matches by env and layer tag", func(t *testing.T) { + envName := "myenv" + layerName := "layer1" + envTag := envName + layerTag := layerName + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{ + { + Name: "deploy-layer1", + Tags: map[string]*string{ + azure.TagKeyAzdEnvName: &envTag, + azure.TagKeyAzdLayerName: &layerTag, + }, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now, + }, + }, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + result, err := dm.CompletedDeployments( + context.Background(), + scope, envName, layerName, "", + ) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, "deploy-layer1", result[0].Name) + }) + + t.Run("legacy match by exact deployment name", + func(t *testing.T) { + envName := "myenv" + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{ + { + Name: envName, + Tags: map[string]*string{}, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now, + }, + }, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + result, err := dm.CompletedDeployments( + context.Background(), scope, envName, "", "", + ) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, envName, result[0].Name) + }) + + t.Run("hint fallback returns partial matches", + func(t *testing.T) { + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{ + { + Name: "myenv-deploy-001", + Tags: map[string]*string{}, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now, + }, + { + Name: "myenv-deploy-002", + Tags: map[string]*string{}, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now.Add(-time.Hour), + }, + { + Name: "other-deploy", + Tags: map[string]*string{}, + ProvisioningState: azapi.DeploymentProvisioningStateSucceeded, + Timestamp: now, + }, + }, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + result, err := dm.CompletedDeployments( + context.Background(), + scope, "myenv", "", "", + ) + require.NoError(t, err) + assert.Len(t, result, 2) + }) + + t.Run("skips non-terminal deployments", func(t *testing.T) { + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{ + { + Name: "myenv-running", + Tags: map[string]*string{}, + ProvisioningState: azapi.DeploymentProvisioningStateRunning, + Timestamp: now, + }, + }, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + _, err := dm.CompletedDeployments( + context.Background(), + scope, "myenv-running", "", "", + ) + require.Error(t, err) + assert.ErrorIs(t, err, ErrDeploymentsNotFound) + }) + + t.Run("no matching deployments returns error", + func(t *testing.T) { + scope := &fakeScope{ + subscriptionId: "sub-1", + deployments: []*azapi.ResourceDeployment{}, + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + _, err := dm.CompletedDeployments( + context.Background(), + scope, "myenv", "", "", + ) + require.Error(t, err) + assert.ErrorIs(t, err, ErrDeploymentsNotFound) + }) + + t.Run("list error propagated", func(t *testing.T) { + scope := &fakeScope{ + subscriptionId: "sub-1", + err: fmt.Errorf("list failed"), + } + + dm := NewDeploymentManager( + &fakeDeploymentService{}, nil, nil, + ) + _, err := dm.CompletedDeployments( + context.Background(), + scope, "myenv", "", "", + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "list failed") + }) +} diff --git a/cli/azd/pkg/infra/provisioning/deployment_additional_test.go b/cli/azd/pkg/infra/provisioning/deployment_additional_test.go new file mode 100644 index 00000000000..17fcdccd4f5 --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/deployment_additional_test.go @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package provisioning + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/azapi" + "github.com/azure/azure-dev/cli/azd/pkg/azure" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParameterTypeFromArmType(t *testing.T) { + tests := []struct { + name string + armType string + expected ParameterType + }{ + {"String", "String", ParameterTypeString}, + {"string lowercase", "string", ParameterTypeString}, + {"secureString", "secureString", ParameterTypeString}, + {"securestring lowercase", "securestring", ParameterTypeString}, + {"Bool", "Bool", ParameterTypeBoolean}, + {"bool lowercase", "bool", ParameterTypeBoolean}, + {"Int", "Int", ParameterTypeNumber}, + {"int lowercase", "int", ParameterTypeNumber}, + {"Object", "Object", ParameterTypeObject}, + {"object lowercase", "object", ParameterTypeObject}, + {"secureObject", "secureObject", ParameterTypeObject}, + {"secureobject lowercase", "secureobject", ParameterTypeObject}, + {"Array", "Array", ParameterTypeArray}, + {"array lowercase", "array", ParameterTypeArray}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParameterTypeFromArmType(tt.armType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParameterTypeFromArmType_Panic(t *testing.T) { + assert.Panics(t, func() { + ParameterTypeFromArmType("unknown") + }) +} + +func TestOutputParametersFromArmOutputs(t *testing.T) { + t.Run("canonical casing from template", func(t *testing.T) { + tplOutputs := azure.ArmTemplateOutputs{ + "AZURE_STORAGE_NAME": azure.ArmTemplateOutput{ + Type: "string", + }, + } + azureOutputs := map[string]azapi.AzCliDeploymentOutput{ + "azure_storage_name": { + Type: "string", + Value: "mystorage", + }, + } + + result := OutputParametersFromArmOutputs( + tplOutputs, azureOutputs, + ) + require.Len(t, result, 1) + param, ok := result["AZURE_STORAGE_NAME"] + require.True(t, ok) + assert.Equal(t, ParameterTypeString, param.Type) + assert.Equal(t, "mystorage", param.Value) + }) + + t.Run("uppercase fallback when not in template", + func(t *testing.T) { + tplOutputs := azure.ArmTemplateOutputs{} + azureOutputs := map[string]azapi.AzCliDeploymentOutput{ + "azurE_RESOURCE_GROUP": { + Type: "string", + Value: "my-rg", + }, + } + + result := OutputParametersFromArmOutputs( + tplOutputs, azureOutputs, + ) + require.Len(t, result, 1) + param, ok := result["AZURE_RESOURCE_GROUP"] + require.True(t, ok) + assert.Equal(t, "my-rg", param.Value) + }) + + t.Run("skips secured outputs", func(t *testing.T) { + tplOutputs := azure.ArmTemplateOutputs{ + "secret": azure.ArmTemplateOutput{ + Type: "secureString", + }, + } + azureOutputs := map[string]azapi.AzCliDeploymentOutput{ + "secret": { + Type: "secureString", + Value: "hidden", + }, + } + + result := OutputParametersFromArmOutputs( + tplOutputs, azureOutputs, + ) + assert.Empty(t, result) + }) + + t.Run("empty inputs", func(t *testing.T) { + result := OutputParametersFromArmOutputs( + azure.ArmTemplateOutputs{}, + map[string]azapi.AzCliDeploymentOutput{}, + ) + assert.Empty(t, result) + }) + + t.Run("multiple outputs with mixed types", + func(t *testing.T) { + tplOutputs := azure.ArmTemplateOutputs{ + "myStr": azure.ArmTemplateOutput{ + Type: "string", + }, + "myBool": azure.ArmTemplateOutput{ + Type: "bool", + }, + "myArr": azure.ArmTemplateOutput{ + Type: "array", + }, + } + azureOutputs := map[string]azapi.AzCliDeploymentOutput{ + "mystr": {Type: "string", Value: "hello"}, + "mybool": {Type: "bool", Value: true}, + "myarr": { + Type: "array", + Value: []any{"a", "b"}, + }, + } + + result := OutputParametersFromArmOutputs( + tplOutputs, azureOutputs, + ) + require.Len(t, result, 3) + assert.Equal( + t, ParameterTypeString, result["myStr"].Type, + ) + assert.Equal( + t, ParameterTypeBoolean, result["myBool"].Type, + ) + assert.Equal( + t, ParameterTypeArray, result["myArr"].Type, + ) + }) +} + +func TestStateMergeInto(t *testing.T) { + t.Run("merges outputs from other into empty state", + func(t *testing.T) { + s := &State{} + other := State{ + Outputs: map[string]OutputParameter{ + "key1": { + Type: ParameterTypeString, Value: "val1", + }, + }, + Resources: []Resource{{Id: "res-1"}}, + } + + s.MergeInto(other) + require.Len(t, s.Outputs, 1) + assert.Equal(t, "val1", s.Outputs["key1"].Value) + require.Len(t, s.Resources, 1) + assert.Equal(t, "res-1", s.Resources[0].Id) + }) + + t.Run("overwrites existing output key", + func(t *testing.T) { + s := &State{ + Outputs: map[string]OutputParameter{ + "key1": { + Type: ParameterTypeString, Value: "old", + }, + }, + } + other := State{ + Outputs: map[string]OutputParameter{ + "key1": { + Type: ParameterTypeString, Value: "new", + }, + }, + } + + s.MergeInto(other) + assert.Equal(t, "new", s.Outputs["key1"].Value) + }) + + t.Run("replaces resource by matching ID", + func(t *testing.T) { + s := &State{ + Resources: []Resource{{Id: "res-1"}}, + } + other := State{ + Resources: []Resource{{Id: "res-1"}}, + } + + s.MergeInto(other) + require.Len(t, s.Resources, 1) + }) + + t.Run("appends new resources", + func(t *testing.T) { + s := &State{ + Resources: []Resource{{Id: "res-1"}}, + } + other := State{ + Resources: []Resource{{Id: "res-2"}}, + } + + s.MergeInto(other) + require.Len(t, s.Resources, 2) + }) + + t.Run("empty other is no-op", + func(t *testing.T) { + s := &State{ + Outputs: map[string]OutputParameter{ + "k": { + Type: ParameterTypeString, Value: "v", + }, + }, + Resources: []Resource{{Id: "r1"}}, + } + s.MergeInto(State{}) + require.Len(t, s.Outputs, 1) + require.Len(t, s.Resources, 1) + }) +} diff --git a/cli/azd/pkg/infra/provisioning/options_test.go b/cli/azd/pkg/infra/provisioning/options_test.go new file mode 100644 index 00000000000..67229b3058b --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/options_test.go @@ -0,0 +1,267 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package provisioning + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewDestroyOptions(t *testing.T) { + tests := []struct { + name string + force bool + purge bool + }{ + {"both false", false, false}, + {"force only", true, false}, + {"purge only", false, true}, + {"both true", true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := NewDestroyOptions(tt.force, tt.purge) + assert.Equal(t, tt.force, opts.Force()) + assert.Equal(t, tt.purge, opts.Purge()) + }) + } +} + +func TestNewStateOptions(t *testing.T) { + t.Run("stores hint", func(t *testing.T) { + opts := NewStateOptions("my-hint") + assert.Equal(t, "my-hint", opts.Hint()) + }) + + t.Run("empty hint", func(t *testing.T) { + opts := NewStateOptions("") + assert.Equal(t, "", opts.Hint()) + }) +} + +func TestNewActionOptions(t *testing.T) { + t.Run("interactive with nil formatter", func(t *testing.T) { + opts := NewActionOptions(nil, true) + // Formatter() should return a NoneFormatter + assert.NotNil(t, opts.Formatter()) + assert.Equal( + t, output.NoneFormat, opts.Formatter().Kind(), + ) + // IsInteractive returns true when no format set + assert.True(t, opts.IsInteractive()) + }) + + t.Run("non-interactive", func(t *testing.T) { + opts := NewActionOptions(nil, false) + assert.False(t, opts.IsInteractive()) + }) + + t.Run("interactive with json formatter", func(t *testing.T) { + formatter := &output.JsonFormatter{} + opts := NewActionOptions(formatter, true) + assert.Equal( + t, output.JsonFormat, opts.Formatter().Kind(), + ) + // IsInteractive returns false when a format is set + assert.False(t, opts.IsInteractive()) + }) +} + +func TestOptionsGetLayers(t *testing.T) { + t.Run("no layers returns self", func(t *testing.T) { + opts := &Options{ + Provider: Bicep, + Path: "infra", + } + layers := opts.GetLayers() + require.Len(t, layers, 1) + assert.Equal(t, Bicep, layers[0].Provider) + assert.Equal(t, "infra", layers[0].Path) + }) + + t.Run("with layers returns layers", func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "layer1", Path: "infra/1"}, + {Name: "layer2", Path: "infra/2"}, + }, + } + layers := opts.GetLayers() + require.Len(t, layers, 2) + assert.Equal(t, "layer1", layers[0].Name) + assert.Equal(t, "layer2", layers[1].Name) + }) +} + +func TestOptionsGetLayer(t *testing.T) { + t.Run("empty name with no layers returns self", + func(t *testing.T) { + opts := &Options{ + Provider: Bicep, + Path: "infra", + } + layer, err := opts.GetLayer("") + require.NoError(t, err) + assert.Equal(t, Bicep, layer.Provider) + }) + + t.Run("empty name with layers returns error", + func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "layer1", Path: "infra/1"}, + }, + } + _, err := opts.GetLayer("") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("named layer found", func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "a", Path: "infra/a"}, + {Name: "b", Path: "infra/b"}, + }, + } + layer, err := opts.GetLayer("b") + require.NoError(t, err) + assert.Equal(t, "b", layer.Name) + assert.Equal(t, "infra/b", layer.Path) + }) + + t.Run("named layer not found", func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "a", Path: "infra/a"}, + }, + } + _, err := opts.GetLayer("missing") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing") + assert.Contains(t, err.Error(), "available layers: a") + }) + + t.Run("no layers defined returns error for non-empty name", + func(t *testing.T) { + opts := &Options{} + _, err := opts.GetLayer("something") + require.Error(t, err) + assert.Contains( + t, err.Error(), "no layers defined", + ) + }) +} + +func TestOptionsValidate(t *testing.T) { + t.Run("empty options is valid", func(t *testing.T) { + opts := &Options{} + require.NoError(t, opts.Validate()) + }) + + t.Run("layers without incompatible fields is valid", + func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + }, + } + require.NoError(t, opts.Validate()) + }) + + t.Run("layers with path set at top level is invalid", + func(t *testing.T) { + opts := &Options{ + Path: "infra", + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + assert.Contains( + t, err.Error(), + "properties on 'infra' cannot be declared", + ) + }) + + t.Run("layers with name set at top level is invalid", + func(t *testing.T) { + opts := &Options{ + Name: "top-name", + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + }) + + t.Run("layers with module set at top level is invalid", + func(t *testing.T) { + opts := &Options{ + Module: "main", + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + }) + + t.Run( + "layers with DeploymentStacks set at top level is invalid", + func(t *testing.T) { + opts := &Options{ + DeploymentStacks: map[string]any{"k": "v"}, + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + }) + + t.Run("layer without name is invalid", + func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Path: "infra/l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + assert.Contains( + t, err.Error(), "name must be specified", + ) + }) + + t.Run("layer without path is invalid", + func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "l1"}, + }, + } + err := opts.Validate() + require.Error(t, err) + assert.Contains( + t, err.Error(), "path must be specified", + ) + }) + + t.Run("multiple valid layers", func(t *testing.T) { + opts := &Options{ + Layers: []Options{ + {Name: "l1", Path: "infra/l1"}, + {Name: "l2", Path: "infra/l2"}, + }, + } + require.NoError(t, opts.Validate()) + }) +} diff --git a/cli/azd/pkg/infra/provisioning/provisioning_test.go b/cli/azd/pkg/infra/provisioning/provisioning_test.go new file mode 100644 index 00000000000..74f6c7a9e65 --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/provisioning_test.go @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package provisioning + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/contracts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewEnvRefreshResultFromState(t *testing.T) { + t.Run("all parameter types mapped", func(t *testing.T) { + state := &State{ + Outputs: map[string]OutputParameter{ + "str": {Type: ParameterTypeString, Value: "hi"}, + "num": {Type: ParameterTypeNumber, Value: 42}, + "bool": {Type: ParameterTypeBoolean, Value: true}, + "obj": { + Type: ParameterTypeObject, + Value: map[string]any{"k": "v"}, + }, + "arr": { + Type: ParameterTypeArray, + Value: []any{1, 2}, + }, + }, + Resources: []Resource{ + {Id: "res-1"}, + {Id: "res-2"}, + }, + } + + result := NewEnvRefreshResultFromState(state) + + require.Len(t, result.Outputs, 5) + assert.Equal( + t, + contracts.EnvRefreshOutputTypeString, + result.Outputs["str"].Type, + ) + assert.Equal( + t, + contracts.EnvRefreshOutputTypeNumber, + result.Outputs["num"].Type, + ) + assert.Equal( + t, + contracts.EnvRefreshOutputTypeBoolean, + result.Outputs["bool"].Type, + ) + assert.Equal( + t, + contracts.EnvRefreshOutputTypeObject, + result.Outputs["obj"].Type, + ) + assert.Equal( + t, + contracts.EnvRefreshOutputTypeArray, + result.Outputs["arr"].Type, + ) + + require.Len(t, result.Resources, 2) + assert.Equal(t, "res-1", result.Resources[0].Id) + assert.Equal(t, "res-2", result.Resources[1].Id) + }) + + t.Run("empty state", func(t *testing.T) { + state := &State{ + Outputs: map[string]OutputParameter{}, + Resources: []Resource{}, + } + + result := NewEnvRefreshResultFromState(state) + assert.Empty(t, result.Outputs) + assert.Empty(t, result.Resources) + }) + + t.Run("output values preserved", func(t *testing.T) { + state := &State{ + Outputs: map[string]OutputParameter{ + "key": { + Type: ParameterTypeString, + Value: "test-value", + }, + }, + Resources: []Resource{}, + } + + result := NewEnvRefreshResultFromState(state) + assert.Equal( + t, "test-value", result.Outputs["key"].Value, + ) + }) +} + +func TestParseProvider(t *testing.T) { + tests := []struct { + name string + kind ProviderKind + expected ProviderKind + expectErr bool + }{ + { + name: "empty defaults to NotSpecified", + kind: NotSpecified, + expected: NotSpecified, + }, + { + name: "bicep", + kind: Bicep, + expected: Bicep, + }, + { + name: "terraform", + kind: Terraform, + expected: Terraform, + }, + { + name: "test", + kind: Test, + expected: Test, + }, + { + name: "unsupported provider", + kind: ProviderKind("invalid"), + expectErr: true, + }, + { + name: "pulumi is unsupported", + kind: Pulumi, + expectErr: true, + }, + { + name: "arm is unsupported", + kind: Arm, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseProvider(tt.kind) + if tt.expectErr { + require.Error(t, err) + assert.Contains( + t, err.Error(), "unsupported IaC provider", + ) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/cli/azd/pkg/infra/util_test.go b/cli/azd/pkg/infra/util_test.go new file mode 100644 index 00000000000..2e2c45c2b8f --- /dev/null +++ b/cli/azd/pkg/infra/util_test.go @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package infra + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResourceIdName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple lowercase", + input: "storage", + expected: "AZURE_RESOURCE_STORAGE_ID", + }, + { + name: "mixed case", + input: "myStorage", + expected: "AZURE_RESOURCE_MYSTORAGE_ID", + }, + { + name: "with dashes", + input: "my-resource", + expected: "AZURE_RESOURCE_MY_RESOURCE_ID", + }, + { + name: "already uppercase", + input: "COSMOS", + expected: "AZURE_RESOURCE_COSMOS_ID", + }, + { + name: "empty string", + input: "", + expected: "AZURE_RESOURCE__ID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ResourceIdName(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestResourceId(t *testing.T) { + //nolint:lll + validResourceID := "/subscriptions/sub-1/resourceGroups/rg-1/providers/Microsoft.Storage/storageAccounts/mystorage" + + t.Run("parses a valid resource ID directly", func(t *testing.T) { + env := environment.New("test") + resId, err := ResourceId(validResourceID, env) + require.NoError(t, err) + assert.Equal( + t, + "mystorage", + resId.Name, + ) + assert.Equal( + t, + "Microsoft.Storage/storageAccounts", + resId.ResourceType.String(), + ) + }) + + t.Run("resolves from env when name is not a resource ID", + func(t *testing.T) { + env := environment.NewWithValues("test", map[string]string{ + "AZURE_RESOURCE_STORAGE_ID": validResourceID, + }) + resId, err := ResourceId("storage", env) + require.NoError(t, err) + assert.Equal(t, "mystorage", resId.Name) + }) + + t.Run("error when env var not set", func(t *testing.T) { + env := environment.New("test") + _, err := ResourceId("notexist", env) + require.Error(t, err) + assert.Contains( + t, + err.Error(), + "AZURE_RESOURCE_NOTEXIST_ID is not set", + ) + }) + + t.Run("error when env var is empty", func(t *testing.T) { + env := environment.NewWithValues("test", map[string]string{ + "AZURE_RESOURCE_EMPTY_ID": "", + }) + _, err := ResourceId("empty", env) + require.Error(t, err) + assert.Contains( + t, + err.Error(), + "AZURE_RESOURCE_EMPTY_ID is empty", + ) + }) + + t.Run("error when env var has invalid resource ID", + func(t *testing.T) { + env := environment.NewWithValues("test", map[string]string{ + "AZURE_RESOURCE_BAD_ID": "not-a-resource-id", + }) + _, err := ResourceId("bad", env) + require.Error(t, err) + assert.Contains( + t, + err.Error(), + "parsing AZURE_RESOURCE_BAD_ID", + ) + }) +} + +func TestKeyVaultName(t *testing.T) { + t.Run("returns value when set", func(t *testing.T) { + env := environment.NewWithValues("test", map[string]string{ + "AZURE_KEY_VAULT_NAME": "my-keyvault", + }) + assert.Equal(t, "my-keyvault", KeyVaultName(env)) + }) + + t.Run("returns empty string when not set", func(t *testing.T) { + env := environment.New("test") + assert.Equal(t, "", KeyVaultName(env)) + }) +} diff --git a/cli/azd/pkg/input/asker_test.go b/cli/azd/pkg/input/asker_test.go new file mode 100644 index 00000000000..861f8d2477f --- /dev/null +++ b/cli/azd/pkg/input/asker_test.go @@ -0,0 +1,541 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package input + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/AlecAivazis/survey/v2" + "github.com/stretchr/testify/require" +) + +func Test_askOneNoPrompt_Input(t *testing.T) { + tests := []struct { + name string + message string + defaultVal string + wantResult string + wantErr bool + errContains string + }{ + { + name: "WithDefault", + message: "Enter name:", + defaultVal: "Alice", + wantResult: "Alice", + }, + { + name: "NoDefault", + message: "Enter name:", + defaultVal: "", + wantErr: true, + errContains: "no default response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt := &survey.Input{ + Message: tt.message, + Default: tt.defaultVal, + } + var result string + err := askOneNoPrompt(prompt, &result) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantResult, result) + } + }) + } +} + +func Test_askOneNoPrompt_Select_IntResponse(t *testing.T) { + tests := []struct { + name string + options []string + defaultVal any + wantIdx int + wantErr bool + errContains string + }{ + { + name: "DefaultInList", + options: []string{"a", "b", "c"}, + defaultVal: "b", + wantIdx: 1, + }, + { + name: "DefaultNotInList", + options: []string{"a", "b", "c"}, + defaultVal: "missing", + wantErr: true, + errContains: "default response not in list", + }, + { + name: "NilDefault", + options: []string{"a"}, + defaultVal: nil, + wantErr: true, + errContains: "no default response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt := &survey.Select{ + Message: "Pick one:", + Options: tt.options, + Default: tt.defaultVal, + } + var result int + err := askOneNoPrompt(prompt, &result) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantIdx, result) + } + }) + } +} + +func Test_askOneNoPrompt_Select_StringResponse(t *testing.T) { + prompt := &survey.Select{ + Message: "Pick:", + Options: []string{"x", "y"}, + Default: "y", + } + var result string + err := askOneNoPrompt(prompt, &result) + + require.NoError(t, err) + require.Equal(t, "y", result) +} + +func Test_askOneNoPrompt_Select_BadResponseType(t *testing.T) { + prompt := &survey.Select{ + Message: "Pick:", + Options: []string{"x"}, + Default: "x", + } + var result float64 + err := askOneNoPrompt(prompt, &result) + + require.Error(t, err) + require.Contains(t, err.Error(), "bad type") +} + +func Test_askOneNoPrompt_Confirm(t *testing.T) { + tests := []struct { + name string + defaultVal bool + wantResult bool + }{ + {"DefaultTrue", true, true}, + {"DefaultFalse", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt := &survey.Confirm{ + Message: "Continue?", + Default: tt.defaultVal, + } + var result bool + err := askOneNoPrompt(prompt, &result) + + require.NoError(t, err) + require.Equal(t, tt.wantResult, result) + }) + } +} + +func Test_askOneNoPrompt_MultiSelect(t *testing.T) { + tests := []struct { + name string + options []string + defaultVal any + wantResult []string + wantErr bool + errContains string + }{ + { + name: "WithDefaults", + options: []string{"a", "b", "c"}, + defaultVal: []string{"a", "c"}, + wantResult: []string{"a", "c"}, + }, + { + name: "NilDefault", + options: []string{"a"}, + defaultVal: nil, + wantErr: true, + errContains: "no default response", + }, + { + name: "WrongDefaultType", + options: []string{"a"}, + defaultVal: "not-a-slice", + wantErr: true, + errContains: "not a string list", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt := &survey.MultiSelect{ + Message: "Pick many:", + Options: tt.options, + Default: tt.defaultVal, + } + var result []string + err := askOneNoPrompt(prompt, &result) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantResult, result) + } + }) + } +} + +func Test_askOneNoPrompt_UnknownType_Panics(t *testing.T) { + prompt := &survey.Password{Message: "Secret:"} + var result string + + require.Panics(t, func() { + _ = askOneNoPrompt(prompt, &result) + }) +} + +func Test_NewAsker_NoPrompt(t *testing.T) { + asker := NewAsker(true, false, nil, nil) + prompt := &survey.Confirm{ + Message: "OK?", + Default: true, + } + var result bool + err := asker(prompt, &result) + + require.NoError(t, err) + require.True(t, result) +} + +func Test_NewAsker_NonTerminal_Input(t *testing.T) { + input := "Alice\n" + r := strings.NewReader(input) + w := &bytes.Buffer{} + + asker := NewAsker(false, false, w, r) + prompt := &survey.Input{ + Message: "Name:", + } + var result string + err := asker(prompt, &result) + + require.NoError(t, err) + require.Equal(t, "Alice", result) +} + +func Test_NewAsker_NonTerminal_InputWithDefault(t *testing.T) { + // Empty input should use default + input := "\n" + r := strings.NewReader(input) + w := &bytes.Buffer{} + + asker := NewAsker(false, false, w, r) + prompt := &survey.Input{ + Message: "Name:", + Default: "Bob", + } + var result string + err := asker(prompt, &result) + + require.NoError(t, err) + require.Equal(t, "Bob", result) +} + +func Test_NewAsker_NonTerminal_Confirm(t *testing.T) { + tests := []struct { + name string + input string + want bool + defVal bool + }{ + {"Yes", "Y\n", true, false}, + {"No", "n\n", false, true}, + {"EmptyDefaultTrue", "\n", true, true}, + {"EmptyDefaultFalse", "\n", false, false}, + {"LowercaseY", "y\n", true, false}, + {"UppercaseN", "N\n", false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + w := &bytes.Buffer{} + + asker := NewAsker(false, false, w, r) + prompt := &survey.Confirm{ + Message: "Continue?", + Default: tt.defVal, + } + result := tt.defVal + err := asker(prompt, &result) + + require.NoError(t, err) + require.Equal(t, tt.want, result) + }) + } +} + +func Test_NewAsker_NonTerminal_Select(t *testing.T) { + tests := []struct { + name string + input string + options []string + defaultVal any + wantIdx int + wantErr bool + errContains string + }{ + { + name: "ExactMatch", + input: "beta\n", + options: []string{"alpha", "beta", "gamma"}, + wantIdx: 1, + }, + { + name: "EmptyUsesDefault", + input: "\n", + options: []string{"alpha", "beta"}, + defaultVal: "alpha", + wantIdx: 0, + }, + { + name: "InvalidChoice", + input: "missing\n", + options: []string{"a", "b"}, + wantErr: true, + errContains: "not an allowed choice", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + w := &bytes.Buffer{} + + asker := NewAsker(false, false, w, r) + prompt := &survey.Select{ + Message: "Pick:", + Options: tt.options, + Default: tt.defaultVal, + } + var result int + err := asker(prompt, &result) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantIdx, result) + } + }) + } +} + +func Test_NewAsker_NonTerminal_Password(t *testing.T) { + r := strings.NewReader("s3cret\n") + w := &bytes.Buffer{} + + asker := NewAsker(false, false, w, r) + prompt := &survey.Password{ + Message: "Password:", + } + var result string + err := asker(prompt, &result) + + require.NoError(t, err) + require.Equal(t, "s3cret", result) +} + +func Test_NewAsker_NonTerminal_UnknownType_Panics(t *testing.T) { + asker := NewAsker(false, false, &bytes.Buffer{}, strings.NewReader("")) + + require.Panics(t, func() { + var result string + _ = asker(&survey.Editor{Message: "Edit:"}, &result) + }) +} + +func Test_askOnePrompt_Select_StringResponse(t *testing.T) { + r := strings.NewReader("beta\n") + w := &bytes.Buffer{} + + prompt := &survey.Select{ + Message: "Pick:", + Options: []string{"alpha", "beta"}, + } + var result string + err := askOnePrompt(prompt, &result, false, w, r) + + require.NoError(t, err) + require.Equal(t, "beta", result) +} + +func Test_askOnePrompt_Select_BadResponseType(t *testing.T) { + r := strings.NewReader("alpha\n") + w := &bytes.Buffer{} + + prompt := &survey.Select{ + Message: "Pick:", + Options: []string{"alpha"}, + } + var result float64 + err := askOnePrompt(prompt, &result, false, w, r) + + require.Error(t, err) + require.Contains(t, err.Error(), "bad type") +} + +func Test_askOnePrompt_MultiSelect_BadDefaultType(t *testing.T) { + r := strings.NewReader("\n") + w := &bytes.Buffer{} + + prompt := &survey.MultiSelect{ + Message: "Pick:", + Options: []string{"a", "b"}, + Default: "not-a-slice", + } + var result []string + err := askOnePrompt(prompt, &result, false, w, r) + + require.Error(t, err) + require.Contains(t, err.Error(), "not a string list") +} + +func Test_promptFromOptions(t *testing.T) { + tests := []struct { + name string + options ConsoleOptions + wantType string + }{ + { + name: "Password", + options: ConsoleOptions{ + Message: "Enter password:", + IsPassword: true, + }, + wantType: "*survey.Password", + }, + { + name: "RegularInput", + options: ConsoleOptions{ + Message: "Enter name:", + }, + wantType: "*survey.Input", + }, + { + name: "InputWithDefault", + options: ConsoleOptions{ + Message: "Enter name:", + DefaultValue: "Alice", + }, + wantType: "*survey.Input", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := promptFromOptions(tt.options) + got := fmt.Sprintf("%T", p) + require.Equal(t, tt.wantType, got) + + if tt.wantType == "*survey.Input" { + inp := p.(*survey.Input) + require.Equal(t, tt.options.Message, inp.Message) + if defStr, ok := tt.options.DefaultValue.(string); ok { + require.Equal(t, defStr, inp.Default) + } + } + }) + } +} + +func Test_choicesFromOptions(t *testing.T) { + tests := []struct { + name string + options ConsoleOptions + wantLen int + wantFirstVal string + wantHasDetail bool + }{ + { + name: "WithDetails", + options: ConsoleOptions{ + Options: []string{"A", "B"}, + OptionDetails: []string{"detail-A", "detail-B"}, + }, + wantLen: 2, + wantFirstVal: "A", + wantHasDetail: true, + }, + { + name: "WithoutDetails", + options: ConsoleOptions{ + Options: []string{"X", "Y"}, + }, + wantLen: 2, + wantFirstVal: "X", + wantHasDetail: false, + }, + { + name: "PartialDetails", + options: ConsoleOptions{ + Options: []string{"A", "B", "C"}, + OptionDetails: []string{"d-A"}, + }, + wantLen: 3, + wantFirstVal: "A", + wantHasDetail: true, + }, + { + name: "EmptyDetailString", + options: ConsoleOptions{ + Options: []string{"A"}, + OptionDetails: []string{""}, + }, + wantLen: 1, + wantFirstVal: "A", + wantHasDetail: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + choices := choicesFromOptions(tt.options) + require.Len(t, choices, tt.wantLen) + require.Equal(t, tt.wantFirstVal, choices[0].Value) + if tt.wantHasDetail { + require.NotNil(t, choices[0].Detail) + } else { + require.Nil(t, choices[0].Detail) + } + }) + } +} diff --git a/cli/azd/pkg/input/console_fs_test.go b/cli/azd/pkg/input/console_fs_test.go new file mode 100644 index 00000000000..8bef1314f73 --- /dev/null +++ b/cli/azd/pkg/input/console_fs_test.go @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package input + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_fsSuggestions(t *testing.T) { + dir := t.TempDir() + + // Create test files and directories + require.NoError(t, os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)) + require.NoError(t, os.WriteFile( + filepath.Join(dir, "file.txt"), []byte("test"), 0o600)) + require.NoError(t, os.WriteFile( + filepath.Join(dir, ".hidden"), []byte("hidden"), 0o600)) + require.NoError(t, os.MkdirAll( + filepath.Join(dir, ".hiddendir"), 0o755)) + + tests := []struct { + name string + opts FsSuggestOptions + root string + input string + wantMinLen int + wantContain string + wantExclude string + }{ + { + name: "EmptyInputIncludesCurrentDir", + opts: FsSuggestOptions{}, + root: dir, + input: "", + wantMinLen: 1, + wantContain: currentDirDisplayed, + }, + { + name: "EmptyInputExcludeCurrentDir", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + }, + root: dir, + input: "", + wantExclude: currentDirDisplayed, + }, + { + name: "MatchesFilesInDir", + opts: FsSuggestOptions{}, + root: "", + input: filepath.Join(dir, "file"), + wantMinLen: 1, + }, + { + name: "ExcludeHiddenFiles", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + }, + root: "", + input: filepath.Join(dir, "."), + wantExclude: ".hidden", + wantMinLen: 0, + }, + { + name: "IncludeHiddenFiles", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + IncludeHiddenFiles: true, + }, + root: "", + input: filepath.Join(dir, ".h"), + wantMinLen: 1, + }, + { + name: "ExcludeDirectories", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + ExcludeDirectories: true, + IncludeHiddenFiles: true, + }, + root: "", + input: dir + string(os.PathSeparator), + wantMinLen: 1, + }, + { + name: "ExcludeFiles", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + ExcludeFiles: true, + }, + root: "", + input: filepath.Join(dir, "sub"), + wantMinLen: 1, + }, + { + name: "ExactFileMatch", + opts: FsSuggestOptions{ + ExcludeCurrentDir: true, + }, + root: "", + input: filepath.Join(dir, "file.txt"), + wantMinLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + suggestions := fsSuggestions(tt.opts, tt.root, tt.input) + if tt.wantMinLen > 0 { + require.GreaterOrEqual(t, len(suggestions), tt.wantMinLen) + } + if tt.wantContain != "" { + require.Contains(t, suggestions, tt.wantContain) + } + if tt.wantExclude != "" { + for _, s := range suggestions { + require.NotContains(t, s, tt.wantExclude) + } + } + }) + } +} + +func Test_fsSuggestions_DirectoryTrailingSlash(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(dir, "mydir"), 0o755)) + + suggestions := fsSuggestions( + FsSuggestOptions{ExcludeCurrentDir: true}, + "", + filepath.Join(dir, "mydir"), + ) + + require.Len(t, suggestions, 1) + require.True(t, + suggestions[0][len(suggestions[0])-1] == filepath.Separator, + "directory suggestion should end with path separator") +} + +func Test_fsSuggestions_WithRoot(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "readme.md"), []byte("# readme"), 0o600)) + + suggestions := fsSuggestions( + FsSuggestOptions{ExcludeCurrentDir: true}, + dir, + "readme", + ) + + require.Len(t, suggestions, 1) + require.Contains(t, suggestions[0], "readme.md") +} + +func Test_currentDirDisplayed_Constant(t *testing.T) { + require.Equal(t, "./ [current directory]", currentDirDisplayed) +} + +func Test_currentDirSentinelTranslation(t *testing.T) { + // Verify that the sentinel value translates to "./" + expected := "." + string(filepath.Separator) + + // Simulate what PromptFs does when response == currentDirDisplayed + response := currentDirDisplayed + if response == currentDirDisplayed { + response = "." + string(filepath.Separator) + } + + require.Equal(t, expected, response) + + if runtime.GOOS == "windows" { + require.Equal(t, ".\\", response) + } else { + require.Equal(t, "./", response) + } +} diff --git a/cli/azd/pkg/input/console_helpers_test.go b/cli/azd/pkg/input/console_helpers_test.go new file mode 100644 index 00000000000..a84037a6c65 --- /dev/null +++ b/cli/azd/pkg/input/console_helpers_test.go @@ -0,0 +1,318 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package input + +import ( + "bytes" + "context" + "os" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/stretchr/testify/require" +) + +// newTestConsole creates a non-interactive console for unit testing +// without requiring a real terminal. +func newTestConsole( + t *testing.T, + noPrompt bool, + formatter output.Formatter, +) (*AskerConsole, *bytes.Buffer) { + t.Helper() + buf := &bytes.Buffer{} + c := NewConsole( + noPrompt, + false, + Writers{Output: buf}, + ConsoleHandles{ + Stderr: os.Stderr, + Stdin: os.Stdin, + Stdout: buf, + }, + formatter, + nil, + ) + return c.(*AskerConsole), buf +} + +func TestGetStepResultFormat(t *testing.T) { + tests := []struct { + name string + err error + want SpinnerUxType + }{ + {"NilError", nil, StepDone}, + {"NonNilError", os.ErrNotExist, StepFailed}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, GetStepResultFormat(tt.err)) + }) + } +} + +func TestIsUnformatted(t *testing.T) { + tests := []struct { + name string + formatter output.Formatter + want bool + }{ + { + name: "NilFormatter", + formatter: nil, + want: true, + }, + { + name: "NoneFormatter", + formatter: mustFormatter(t, string(output.NoneFormat)), + want: true, + }, + { + name: "JsonFormatter", + formatter: &output.JsonFormatter{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, _ := newTestConsole(t, false, tt.formatter) + require.Equal(t, tt.want, c.IsUnformatted()) + }) + } +} + +func mustFormatter(t *testing.T, name string) output.Formatter { + t.Helper() + f, err := output.NewFormatter(name) + require.NoError(t, err) + return f +} + +func TestGetFormatter(t *testing.T) { + formatter := &output.JsonFormatter{} + c, _ := newTestConsole(t, false, formatter) + require.Equal(t, formatter, c.GetFormatter()) +} + +func TestGetFormatter_Nil(t *testing.T) { + c, _ := newTestConsole(t, false, nil) + require.Nil(t, c.GetFormatter()) +} + +func TestSetWriter(t *testing.T) { + c, defaultBuf := newTestConsole(t, false, nil) + + // Confirm initial writer is the default + require.Equal(t, defaultBuf, c.GetWriter()) + + // Set a custom writer + custom := &bytes.Buffer{} + c.SetWriter(custom) + require.Equal(t, custom, c.GetWriter()) + + // Reset to default + c.SetWriter(nil) + require.Equal(t, defaultBuf, c.GetWriter()) +} + +func TestHandles(t *testing.T) { + c, _ := newTestConsole(t, false, nil) + h := c.Handles() + require.NotNil(t, h.Stdin) + require.NotNil(t, h.Stdout) + require.NotNil(t, h.Stderr) +} + +func TestIsNoPromptMode(t *testing.T) { + tests := []struct { + name string + noPrompt bool + want bool + }{ + {"Enabled", true, true}, + {"Disabled", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, _ := newTestConsole(t, tt.noPrompt, nil) + require.Equal(t, tt.want, c.IsNoPromptMode()) + }) + } +} + +func TestIsSpinnerInteractive_NonTerminal(t *testing.T) { + c, _ := newTestConsole(t, false, nil) + // Non-terminal consoles should not have interactive spinners + require.False(t, c.IsSpinnerInteractive()) +} + +func TestIsSpinnerRunning_InitiallyStopped(t *testing.T) { + c, _ := newTestConsole(t, false, nil) + require.False(t, c.IsSpinnerRunning(context.Background())) +} + +func TestSupportsPromptDialog(t *testing.T) { + tests := []struct { + name string + cfg *ExternalPromptConfiguration + want bool + }{ + { + name: "NoExternalPrompt", + cfg: nil, + want: false, + }, + { + name: "WithExternalPrompt", + cfg: &ExternalPromptConfiguration{ + Endpoint: "http://localhost", + Key: "key", + }, + want: true, + }, + { + name: "WithDialogDisabled", + cfg: &ExternalPromptConfiguration{ + Endpoint: "http://localhost", + Key: "key", + NoPromptDialog: true, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewConsole( + false, + false, + Writers{Output: &bytes.Buffer{}}, + ConsoleHandles{ + Stderr: os.Stderr, + Stdin: os.Stdin, + Stdout: &bytes.Buffer{}, + }, + nil, + tt.cfg, + ) + require.Equal(t, tt.want, c.SupportsPromptDialog()) + }) + } +} + +func TestEnsureBlankLine(t *testing.T) { + tests := []struct { + name string + last2 [2]byte + wantCall bool + }{ + { + name: "AlreadyBlank", + last2: [2]byte{'\n', '\n'}, + wantCall: false, + }, + { + name: "OneNewLine", + last2: [2]byte{'a', '\n'}, + wantCall: true, + }, + { + name: "NoNewLine", + last2: [2]byte{'a', 'b'}, + wantCall: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + formatter := mustFormatter(t, string(output.NoneFormat)) + c, buf := newTestConsole(t, false, formatter) + c.last2Byte = tt.last2 + before := buf.Len() + c.EnsureBlankLine(context.Background()) + if tt.wantCall { + require.Greater(t, buf.Len(), before, + "expected output to be written") + } else { + require.Equal(t, before, buf.Len(), + "expected no output when already blank") + } + }) + } +} + +func TestUpdateLastBytes(t *testing.T) { + tests := []struct { + name string + initial [2]byte + msg string + want [2]byte + }{ + { + name: "EmptyMessage", + initial: [2]byte{'a', 'b'}, + msg: "", + want: [2]byte{'a', 'b'}, + }, + { + name: "SingleChar", + initial: [2]byte{'a', 'b'}, + msg: "x", + want: [2]byte{'b', 'x'}, + }, + { + name: "TwoChars", + initial: [2]byte{'a', 'b'}, + msg: "xy", + want: [2]byte{'x', 'y'}, + }, + { + name: "LongMessage", + initial: [2]byte{0, 0}, + msg: "hello\n", + want: [2]byte{'o', '\n'}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, _ := newTestConsole(t, false, nil) + c.last2Byte = tt.initial + c.updateLastBytes(tt.msg) + require.Equal(t, tt.want, c.last2Byte) + }) + } +} + +func TestSpinnerTerminalMode(t *testing.T) { + // Non-terminal should not have TTY mode + mode := spinnerTerminalMode(false) + require.NotZero(t, mode) + + // Terminal mode should have TTY mode + ttyMode := spinnerTerminalMode(true) + require.NotZero(t, ttyMode) + require.NotEqual(t, mode, ttyMode) +} + +func TestSetIndentation(t *testing.T) { + tests := []struct { + name string + spaces int + want string + }{ + {"Zero", 0, ""}, + {"Two", 2, " "}, + {"Four", 4, " "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, setIndentation(tt.spaces)) + }) + } +} diff --git a/cli/azd/pkg/keyvault/keyvault_test.go b/cli/azd/pkg/keyvault/keyvault_test.go new file mode 100644 index 00000000000..6bbf5facfe2 --- /dev/null +++ b/cli/azd/pkg/keyvault/keyvault_test.go @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package keyvault + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsAzureKeyVaultSecret(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"valid akvs reference", "akvs://sub-id/vault/secret", true}, + {"empty string", "", false}, + {"wrong prefix", "https://vault.azure.net/secrets/foo", false}, + {"partial prefix", "akvs:/", false}, + {"case sensitive", "AKVS://sub/vault/secret", false}, + {"just prefix", "akvs://", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsAzureKeyVaultSecret(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsValidSecretName(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"simple lowercase", "mysecret", true}, + {"simple uppercase", "MySecret", true}, + {"with numbers", "secret123", true}, + {"with hyphens", "my-secret-name", true}, + {"single char", "a", true}, + { + "max length 127 chars", + strings.Repeat("a", 127), + true, + }, + {"empty string", "", false}, + { + "too long 128 chars", + strings.Repeat("a", 128), + false, + }, + {"with underscore", "my_secret", false}, + {"with dot", "my.secret", false}, + {"with space", "my secret", false}, + {"with slash", "my/secret", false}, + {"with special chars", "my@secret!", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsValidSecretName(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestNewAzureKeyVaultSecret(t *testing.T) { + tests := []struct { + name string + subId string + vaultId string + secretName string + expected string + }{ + { + "standard reference", + "sub-123", "my-vault", "my-secret", + "akvs://sub-123/my-vault/my-secret", + }, + { + "empty values", + "", "", "", + "akvs:////", + }, + { + "guid-style subscription", + "00000000-0000-0000-0000-000000000000", + "production-vault", + "db-connection-string", + "akvs://00000000-0000-0000-0000-000000000000" + + "/production-vault/db-connection-string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NewAzureKeyVaultSecret( + tt.subId, tt.vaultId, tt.secretName, + ) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseAzureKeyVaultSecret(t *testing.T) { + t.Run("valid reference", func(t *testing.T) { + result, err := ParseAzureKeyVaultSecret( + "akvs://sub-123/my-vault/my-secret", + ) + require.NoError(t, err) + assert.Equal(t, "sub-123", result.SubscriptionId) + assert.Equal(t, "my-vault", result.VaultName) + assert.Equal(t, "my-secret", result.SecretName) + }) + + t.Run("invalid prefix", func(t *testing.T) { + _, err := ParseAzureKeyVaultSecret("https://foo/bar/baz") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid Azure Key Vault Secret") + }) + + t.Run("too few parts", func(t *testing.T) { + _, err := ParseAzureKeyVaultSecret("akvs://sub-123/vault-only") + require.Error(t, err) + assert.Contains(t, err.Error(), "Expected format") + }) + + t.Run("too many parts", func(t *testing.T) { + _, err := ParseAzureKeyVaultSecret( + "akvs://sub/vault/secret/extra", + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "Expected format") + }) + + t.Run("empty string", func(t *testing.T) { + _, err := ParseAzureKeyVaultSecret("") + require.Error(t, err) + }) + + t.Run("roundtrip with NewAzureKeyVaultSecret", func(t *testing.T) { + original := NewAzureKeyVaultSecret( + "sub-abc", "vault-xyz", "secret-123", + ) + parsed, err := ParseAzureKeyVaultSecret(original) + require.NoError(t, err) + assert.Equal(t, "sub-abc", parsed.SubscriptionId) + assert.Equal(t, "vault-xyz", parsed.VaultName) + assert.Equal(t, "secret-123", parsed.SecretName) + }) +} + +func TestConstants(t *testing.T) { + t.Run("ErrAzCliSecretNotFound is not nil", func(t *testing.T) { + assert.NotNil(t, ErrAzCliSecretNotFound) + assert.Equal(t, "secret not found", ErrAzCliSecretNotFound.Error()) + }) + + t.Run("role IDs have correct prefix", func(t *testing.T) { + prefix := "/providers/Microsoft.Authorization/roleDefinitions/" + assert.Contains(t, RoleIdKeyVaultAdministrator, prefix) + assert.Contains(t, RoleIdKeyVaultSecretsUser, prefix) + }) +} diff --git a/cli/azd/pkg/kubelogin/cli_test.go b/cli/azd/pkg/kubelogin/cli_test.go new file mode 100644 index 00000000000..5bb22c418b5 --- /dev/null +++ b/cli/azd/pkg/kubelogin/cli_test.go @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package kubelogin + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCli(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + require.NotNil(t, cli) +} + +func TestCliName(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + assert.Equal(t, "kubelogin", cli.Name()) +} + +func TestCliInstallUrl(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + assert.Equal( + t, + "https://aka.ms/azure-dev/kubelogin-install", + cli.InstallUrl(), + ) +} + +func TestCheckInstalled(t *testing.T) { + t.Run("Found", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.MockToolInPath( + "kubelogin", nil, + ) + + cli := NewCli(mockContext.CommandRunner) + err := cli.CheckInstalled(*mockContext.Context) + require.NoError(t, err) + }) + + t.Run("NotFound", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.MockToolInPath( + "kubelogin", + errors.New("kubelogin not found in PATH"), + ) + + cli := NewCli(mockContext.CommandRunner) + err := cli.CheckInstalled(*mockContext.Context) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestConvertKubeConfig(t *testing.T) { + t.Run("NilOptionsDefaultsToAzdLogin", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + err := cli.ConvertKubeConfig(*mockContext.Context, nil) + require.NoError(t, err) + + assert.Equal(t, "kubelogin", capturedArgs.Cmd) + expected := []string{ + "convert-kubeconfig", "--login", "azd", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("EmptyOptionsDefaultsToAzdLogin", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + err := cli.ConvertKubeConfig( + *mockContext.Context, + &ConvertOptions{}, + ) + require.NoError(t, err) + + expected := []string{ + "convert-kubeconfig", "--login", "azd", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("AllOptionsSet", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + opts := &ConvertOptions{ + Login: "spn", + TenantId: "tenant-abc-123", + Context: "my-k8s-context", + KubeConfig: "/home/user/.kube/config", + } + err := cli.ConvertKubeConfig( + *mockContext.Context, opts, + ) + require.NoError(t, err) + + assert.Equal(t, "kubelogin", capturedArgs.Cmd) + expected := []string{ + "convert-kubeconfig", + "--login", "spn", + "--kubeconfig", "/home/user/.kube/config", + "--tenant-id", "tenant-abc-123", + "--context", "my-k8s-context", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("OnlyKubeConfigSet", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + opts := &ConvertOptions{ + KubeConfig: "/tmp/kubeconfig", + } + err := cli.ConvertKubeConfig( + *mockContext.Context, opts, + ) + require.NoError(t, err) + + expected := []string{ + "convert-kubeconfig", + "--login", "azd", + "--kubeconfig", "/tmp/kubeconfig", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("OnlyTenantIdSet", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + opts := &ConvertOptions{ + TenantId: "my-tenant", + } + err := cli.ConvertKubeConfig( + *mockContext.Context, opts, + ) + require.NoError(t, err) + + expected := []string{ + "convert-kubeconfig", + "--login", "azd", + "--tenant-id", "my-tenant", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("OnlyContextSet", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + opts := &ConvertOptions{ + Context: "production", + } + err := cli.ConvertKubeConfig( + *mockContext.Context, opts, + ) + require.NoError(t, err) + + expected := []string{ + "convert-kubeconfig", + "--login", "azd", + "--context", "production", + } + assert.Equal(t, expected, capturedArgs.Args) + }) + + t.Run("CommandFailure", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(1, "", "err"), + errors.New("exit code 1") + }) + + cli := NewCli(mockContext.CommandRunner) + err := cli.ConvertKubeConfig( + *mockContext.Context, + &ConvertOptions{Login: "azd"}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "converting kubeconfig") + }) + + t.Run("CustomLoginMethod", func(t *testing.T) { + var capturedArgs exec.RunArgs + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "kubelogin") + }). + RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + opts := &ConvertOptions{ + Login: "devicecode", + } + err := cli.ConvertKubeConfig( + *mockContext.Context, opts, + ) + require.NoError(t, err) + + expected := []string{ + "convert-kubeconfig", + "--login", "devicecode", + } + assert.Equal(t, expected, capturedArgs.Args) + }) +} diff --git a/cli/azd/pkg/kustomize/cli_coverage_test.go b/cli/azd/pkg/kustomize/cli_coverage_test.go new file mode 100644 index 00000000000..18be80dff6f --- /dev/null +++ b/cli/azd/pkg/kustomize/cli_coverage_test.go @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package kustomize + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCliName(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + assert.Equal(t, "kustomize", cli.Name()) +} + +func TestCliInstallUrl(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + assert.Equal( + t, + "https://aka.ms/azure-dev/kustomize-install", + cli.InstallUrl(), + ) +} + +func TestCheckInstalled(t *testing.T) { + t.Run("ToolFoundAndVersionSucceeds", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.MockToolInPath( + "kustomize", nil, + ) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains( + command, "kustomize version", + ) + }). + RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult( + 0, "v5.3.0", "", + ), nil + }, + ) + + cli := NewCli(mockContext.CommandRunner) + err := cli.CheckInstalled(*mockContext.Context) + require.NoError(t, err) + }) + + t.Run("ToolFoundButVersionFails", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.MockToolInPath( + "kustomize", nil, + ) + mockContext.CommandRunner. + When(func(args exec.RunArgs, command string) bool { + return strings.Contains( + command, "kustomize version", + ) + }). + RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(1, "", ""), + errors.New("version failed") + }, + ) + + cli := NewCli(mockContext.CommandRunner) + // CheckInstalled should still succeed even if + // version fetch fails — it only logs the error. + err := cli.CheckInstalled(*mockContext.Context) + require.NoError(t, err) + }) + + t.Run("ToolNotInPath", func(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.MockToolInPath( + "kustomize", + errors.New("kustomize not found in PATH"), + ) + + cli := NewCli(mockContext.CommandRunner) + err := cli.CheckInstalled(*mockContext.Context) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestNewCli(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + require.NotNil(t, cli) + // Verify the cli is usable after construction + assert.Equal(t, "kustomize", cli.Name()) +} + +func TestWithCwd_Chaining(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + cli := NewCli(mockContext.CommandRunner) + + // WithCwd should return the same *Cli for chaining + result := cli.WithCwd("/some/path") + assert.Same(t, cli, result) +} diff --git a/cli/azd/pkg/pipeline/pipeline_helpers_test.go b/cli/azd/pkg/pipeline/pipeline_helpers_test.go new file mode 100644 index 00000000000..6aa93c17c1a --- /dev/null +++ b/cli/azd/pkg/pipeline/pipeline_helpers_test.go @@ -0,0 +1,1034 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package pipeline + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/entraid" + "github.com/azure/azure-dev/cli/azd/pkg/graphsdk" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" + "github.com/microsoft/azure-devops-go-api/azuredevops/v7/build" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------ +// toCiProviderType +// ------------------------------------------------------------------ + +func Test_toCiProviderType(t *testing.T) { + tests := []struct { + name string + input string + want ciProviderType + wantErr bool + errMatch string + }{ + { + name: "github", + input: "github", + want: ciProviderGitHubActions, + }, + { + name: "azdo", + input: "azdo", + want: ciProviderAzureDevOps, + }, + { + name: "invalid", + input: "jenkins", + wantErr: true, + errMatch: "invalid ci provider type jenkins", + }, + { + name: "empty string", + input: "", + wantErr: true, + errMatch: "invalid ci provider type", + }, + { + name: "mixed case is invalid", + input: "GitHub", + wantErr: true, + errMatch: "invalid ci provider type GitHub", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toCiProviderType(tt.input) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +// ------------------------------------------------------------------ +// toInfraProviderType +// ------------------------------------------------------------------ + +func Test_toInfraProviderType(t *testing.T) { + tests := []struct { + name string + input string + want infraProviderType + wantErr bool + errMatch string + }{ + { + name: "bicep", + input: "bicep", + want: infraProviderBicep, + }, + { + name: "terraform", + input: "terraform", + want: infraProviderTerraform, + }, + { + name: "empty is valid (undefined)", + input: "", + want: infraProviderUndefined, + }, + { + name: "invalid provider", + input: "pulumi", + wantErr: true, + errMatch: "invalid infra provider type pulumi", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := toInfraProviderType(tt.input) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +// ------------------------------------------------------------------ +// generateFilePaths +// ------------------------------------------------------------------ + +func Test_generateFilePaths(t *testing.T) { + tests := []struct { + name string + dirs []string + files []string + want []string + }{ + { + name: "single dir single file", + dirs: []string{".github/workflows"}, + files: []string{"azure-dev.yml"}, + want: []string{ + filepath.Join(".github/workflows", "azure-dev.yml"), + }, + }, + { + name: "multiple dirs multiple files", + dirs: []string{".azdo/pipelines", ".azuredevops/pipelines"}, + files: []string{"azure-dev.yml", "azure-dev.yaml"}, + want: []string{ + filepath.Join(".azdo/pipelines", "azure-dev.yml"), + filepath.Join(".azdo/pipelines", "azure-dev.yaml"), + filepath.Join(".azuredevops/pipelines", "azure-dev.yml"), + filepath.Join(".azuredevops/pipelines", "azure-dev.yaml"), + }, + }, + { + name: "empty dirs returns nil", + dirs: []string{}, + files: []string{"file.yml"}, + want: nil, + }, + { + name: "empty files returns nil", + dirs: []string{".github"}, + files: []string{}, + want: nil, + }, + { + name: "both empty returns nil", + dirs: []string{}, + files: []string{}, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generateFilePaths(tt.dirs, tt.files) + assert.Equal(t, tt.want, got) + }) + } +} + +// ------------------------------------------------------------------ +// hasPipelineFile +// ------------------------------------------------------------------ + +func Test_hasPipelineFile(t *testing.T) { + t.Run("returns true when pipeline file exists", func(t *testing.T) { + tmpDir := t.TempDir() + ghDir := filepath.Join( + tmpDir, ".github", "workflows") + require.NoError(t, os.MkdirAll(ghDir, os.ModePerm)) + require.NoError(t, os.WriteFile( + filepath.Join(ghDir, "azure-dev.yml"), + []byte("trigger: none"), 0600)) + + assert.True(t, hasPipelineFile(ciProviderGitHubActions, tmpDir)) + }) + + t.Run("returns false when no pipeline file exists", func(t *testing.T) { + tmpDir := t.TempDir() + assert.False(t, + hasPipelineFile(ciProviderGitHubActions, tmpDir)) + }) + + t.Run("returns true for azdo provider", func(t *testing.T) { + tmpDir := t.TempDir() + azdoDir := filepath.Join(tmpDir, ".azdo", "pipelines") + require.NoError(t, os.MkdirAll(azdoDir, os.ModePerm)) + require.NoError(t, os.WriteFile( + filepath.Join(azdoDir, "azure-dev.yml"), + []byte("trigger: none"), 0600)) + + assert.True(t, hasPipelineFile(ciProviderAzureDevOps, tmpDir)) + }) + + t.Run("returns true for azdo alt dir", func(t *testing.T) { + tmpDir := t.TempDir() + altDir := filepath.Join( + tmpDir, ".azuredevops", "pipelines") + require.NoError(t, os.MkdirAll(altDir, os.ModePerm)) + require.NoError(t, os.WriteFile( + filepath.Join(altDir, "azure-dev.yaml"), + []byte("trigger: none"), 0600)) + + assert.True(t, hasPipelineFile(ciProviderAzureDevOps, tmpDir)) + }) +} + +// ------------------------------------------------------------------ +// resolveSmr +// ------------------------------------------------------------------ + +func Test_resolveSmr(t *testing.T) { + t.Run("returns arg when provided", func(t *testing.T) { + result := resolveSmr( + "arg-value", + config.NewEmptyConfig(), + config.NewEmptyConfig()) + require.NotNil(t, result) + assert.Equal(t, "arg-value", *result) + }) + + t.Run("returns project config value", func(t *testing.T) { + projCfg := config.NewConfig(nil) + _ = projCfg.Set( + "pipeline.config.applicationServiceManagementReference", + "proj-smr") + + result := resolveSmr("", projCfg, config.NewEmptyConfig()) + require.NotNil(t, result) + assert.Equal(t, "proj-smr", *result) + }) + + t.Run("returns user config when project empty", func(t *testing.T) { + userCfg := config.NewConfig(nil) + _ = userCfg.Set( + "pipeline.config.applicationServiceManagementReference", + "user-smr") + + result := resolveSmr( + "", config.NewEmptyConfig(), userCfg) + require.NotNil(t, result) + assert.Equal(t, "user-smr", *result) + }) + + t.Run("arg overrides project config", func(t *testing.T) { + projCfg := config.NewConfig(nil) + _ = projCfg.Set( + "pipeline.config.applicationServiceManagementReference", + "proj-smr") + + result := resolveSmr("arg-wins", projCfg, config.NewEmptyConfig()) + require.NotNil(t, result) + assert.Equal(t, "arg-wins", *result) + }) + + t.Run("project config overrides user config", func(t *testing.T) { + projCfg := config.NewConfig(nil) + _ = projCfg.Set( + "pipeline.config.applicationServiceManagementReference", + "proj-smr") + userCfg := config.NewConfig(nil) + _ = userCfg.Set( + "pipeline.config.applicationServiceManagementReference", + "user-smr") + + result := resolveSmr("", projCfg, userCfg) + require.NotNil(t, result) + assert.Equal(t, "proj-smr", *result) + }) + + t.Run("returns nil when nothing set", func(t *testing.T) { + result := resolveSmr( + "", + config.NewEmptyConfig(), + config.NewEmptyConfig()) + assert.Nil(t, result) + }) +} + +// ------------------------------------------------------------------ +// gitHubActionsEnablingChoice.String() +// ------------------------------------------------------------------ + +func Test_gitHubActionsEnablingChoice_String(t *testing.T) { + assert.Contains(t, manualChoice.String(), "manually enabled") + assert.Contains(t, cancelChoice.String(), "Exit without pushing") +} + +func Test_gitHubActionsEnablingChoice_String_panic(t *testing.T) { + assert.Panics(t, func() { + _ = gitHubActionsEnablingChoice(99).String() + }) +} + +// ------------------------------------------------------------------ +// workflow (GitHub CiPipeline) name() and url() +// ------------------------------------------------------------------ + +func Test_workflow_CiPipeline(t *testing.T) { + w := &workflow{ + repoDetails: &gitRepositoryDetails{ + url: "https://github.com/Azure/azure-dev", + }, + } + + assert.Equal(t, "actions", w.name()) + assert.Equal(t, + "https://github.com/Azure/azure-dev/actions", + w.url()) +} + +// ------------------------------------------------------------------ +// pipeline (AzDo CiPipeline) name() and url() +// ------------------------------------------------------------------ + +func Test_azdoPipeline_CiPipeline(t *testing.T) { + defName := "my-pipeline" + defId := 42 + + p := &pipeline{ + repoDetails: &AzdoRepositoryDetails{ + repoWebUrl: "https://dev.azure.com/org/project/_git/repo", + buildDefinition: &build.BuildDefinition{ + Name: &defName, + Id: &defId, + }, + }, + } + + assert.Equal(t, "my-pipeline", p.name()) + assert.Equal(t, + "https://dev.azure.com/org/project/_build?definitionId=42", + p.url()) +} + +// ------------------------------------------------------------------ +// mergeProjectVariablesAndSecrets — providerParameters +// ------------------------------------------------------------------ + +func Test_mergeProjectVariablesAndSecrets_providerParams( + t *testing.T, +) { + t.Run("single env var secret via provider param", func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "dbPass", + Value: "s3cret", + Secret: true, + LocalPrompt: true, + EnvVarMapping: []string{"DB_PASSWORD"}, + UsingEnvVarMapping: false, + }, + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.NoError(t, err) + assert.Equal(t, "s3cret", secrets["DB_PASSWORD"]) + assert.Empty(t, vars) + }) + + t.Run("single env var variable via provider param", func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "region", + Value: "eastus", + Secret: false, + LocalPrompt: true, + EnvVarMapping: []string{"AZURE_LOCATION"}, + UsingEnvVarMapping: false, + }, + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.NoError(t, err) + assert.Equal(t, "eastus", vars["AZURE_LOCATION"]) + assert.Empty(t, secrets) + }) + + t.Run("error when local prompt and no env var mapping", + func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "bad", + Value: "val", + LocalPrompt: true, + EnvVarMapping: []string{}, + }, + } + _, _, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.Error(t, err) + assert.Contains(t, err.Error(), + "local prompt and it has not a mapped environment variable") + }) + + t.Run( + "error when local prompt and multiple env var mappings", + func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "multi", + Value: "val", + LocalPrompt: true, + EnvVarMapping: []string{"A", "B"}, + }, + } + _, _, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.Error(t, err) + assert.Contains(t, err.Error(), + "more than one mapped environment variable") + }) + + t.Run("multi env var non-prompt uses env values", func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "multiEnv", + Secret: false, + LocalPrompt: false, + EnvVarMapping: []string{"VAR_A", "VAR_B"}, + }, + } + env := map[string]string{ + "VAR_A": "valA", + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, env) + require.NoError(t, err) + assert.Equal(t, "valA", vars["VAR_A"]) + // VAR_B not in env, so not set + _, hasB := vars["VAR_B"] + assert.False(t, hasB) + assert.Empty(t, secrets) + }) + + t.Run("multi env var secret uses env values", func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "multiSecret", + Secret: true, + LocalPrompt: false, + EnvVarMapping: []string{"SEC_A", "SEC_B"}, + }, + } + env := map[string]string{ + "SEC_A": "secretA", + "SEC_B": "secretB", + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, env) + require.NoError(t, err) + assert.Equal(t, "secretA", secrets["SEC_A"]) + assert.Equal(t, "secretB", secrets["SEC_B"]) + assert.Empty(t, vars) + }) + + t.Run("no env var mapping and no local prompt is skipped", + func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "skipped", + Value: "val", + LocalPrompt: false, + EnvVarMapping: []string{}, + }, + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.NoError(t, err) + assert.Empty(t, vars) + assert.Empty(t, secrets) + }) + + t.Run( + "single env var non-prompt non-UsingEnvVarMapping skipped", + func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "notUsed", + Value: "val", + LocalPrompt: false, + UsingEnvVarMapping: false, + EnvVarMapping: []string{"SOME_VAR"}, + }, + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.NoError(t, err) + assert.Empty(t, vars) + assert.Empty(t, secrets) + }) + + t.Run( + "single env var non-prompt UsingEnvVarMapping is set", + func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "used", + Value: "myVal", + Secret: false, + LocalPrompt: false, + UsingEnvVarMapping: true, + EnvVarMapping: []string{"MY_VAR"}, + }, + } + vars, secrets, err := mergeProjectVariablesAndSecrets( + nil, nil, map[string]string{}, map[string]string{}, + params, map[string]string{}) + require.NoError(t, err) + assert.Equal(t, "myVal", vars["MY_VAR"]) + assert.Empty(t, secrets) + }) + + t.Run("project vars override provider params", func(t *testing.T) { + params := []provisioning.Parameter{ + { + Name: "region", + Value: "westus", + LocalPrompt: true, + EnvVarMapping: []string{"AZURE_LOCATION"}, + UsingEnvVarMapping: false, + }, + } + env := map[string]string{ + "AZURE_LOCATION": "eastus2", + } + vars, _, err := mergeProjectVariablesAndSecrets( + []string{"AZURE_LOCATION"}, nil, + map[string]string{}, map[string]string{}, + params, env) + require.NoError(t, err) + // project var from env overrides provider param + assert.Equal(t, "eastus2", vars["AZURE_LOCATION"]) + }) +} + +// ------------------------------------------------------------------ +// mergeProjectVariablesAndSecrets — empty values skipped +// ------------------------------------------------------------------ + +func Test_mergeProjectVariablesAndSecrets_emptyEnvSkipped( + t *testing.T, +) { + env := map[string]string{ + "VAR1": "", + "VAR2": "value2", + } + vars, _, err := mergeProjectVariablesAndSecrets( + []string{"VAR1", "VAR2"}, nil, + map[string]string{}, map[string]string{}, + nil, env) + require.NoError(t, err) + // VAR1 has empty value, should NOT be in variables + _, hasVar1 := vars["VAR1"] + assert.False(t, hasVar1) + assert.Equal(t, "value2", vars["VAR2"]) +} + +// ------------------------------------------------------------------ +// mergeProjectVariablesAndSecrets — initial values cloned +// ------------------------------------------------------------------ + +func Test_mergeProjectVariablesAndSecrets_initialNotMutated( + t *testing.T, +) { + initialVars := map[string]string{"INIT": "orig"} + initialSecrets := map[string]string{"SEC": "orig"} + env := map[string]string{"EXTRA": "val"} + + vars, secrets, err := mergeProjectVariablesAndSecrets( + []string{"EXTRA"}, nil, + initialVars, initialSecrets, + nil, env) + require.NoError(t, err) + + // returned maps have both initial and extra + assert.Equal(t, "orig", vars["INIT"]) + assert.Equal(t, "val", vars["EXTRA"]) + assert.Equal(t, "orig", secrets["SEC"]) + + // original maps not mutated + _, hasExtra := initialVars["EXTRA"] + assert.False(t, hasExtra, + "initialVars should not be mutated") +} + +// ------------------------------------------------------------------ +// escapeValuesForPipeline — additional cases +// ------------------------------------------------------------------ + +func Test_escapeValuesForPipeline_nilMap(t *testing.T) { + // should not panic on nil map + assert.NotPanics(t, func() { + escapeValuesForPipeline(nil) + }) +} + +func Test_escapeValuesForPipeline_noModification(t *testing.T) { + values := map[string]string{ + "plain": "hello world", + } + escapeValuesForPipeline(values) + assert.Equal(t, "hello world", values["plain"]) +} + +// ------------------------------------------------------------------ +// pipelineProviderFiles var validation +// ------------------------------------------------------------------ + +func Test_pipelineProviderFiles_knownProviders(t *testing.T) { + ghInfo, ok := pipelineProviderFiles[ciProviderGitHubActions] + require.True(t, ok, "GitHub Actions entry missing") + assert.NotEmpty(t, ghInfo.RootDirectories) + assert.NotEmpty(t, ghInfo.PipelineDirectories) + assert.NotEmpty(t, ghInfo.Files) + assert.NotEmpty(t, ghInfo.DefaultFile) + + azdoInfo, ok := pipelineProviderFiles[ciProviderAzureDevOps] + require.True(t, ok, "Azure DevOps entry missing") + assert.NotEmpty(t, azdoInfo.RootDirectories) + assert.NotEmpty(t, azdoInfo.PipelineDirectories) + assert.NotEmpty(t, azdoInfo.Files) + assert.NotEmpty(t, azdoInfo.DefaultFile) +} + +// ------------------------------------------------------------------ +// constants sanity +// ------------------------------------------------------------------ + +func Test_constants(t *testing.T) { + assert.Equal(t, ciProviderType("github"), ciProviderGitHubActions) + assert.Equal(t, ciProviderType("azdo"), ciProviderAzureDevOps) + assert.Equal(t, infraProviderType("bicep"), infraProviderBicep) + assert.Equal(t, infraProviderType("terraform"), infraProviderTerraform) + assert.Equal(t, infraProviderType(""), infraProviderUndefined) + assert.Equal(t, PipelineAuthType("federated"), AuthTypeFederated) + assert.Equal(t, + PipelineAuthType("client-credentials"), + AuthTypeClientCredentials) + assert.Equal(t, "AZD_PIPELINE_PROVIDER", envPersistedKey) +} + +// ------------------------------------------------------------------ +// servicePrincipal lookup strategy +// ------------------------------------------------------------------ + +type mockEntraIdService struct { + entraid.EntraIdService + getSpResult *graphsdk.ServicePrincipal + getSpErr error +} + +func (m *mockEntraIdService) GetServicePrincipal( + _ context.Context, _, _ string, +) (*graphsdk.ServicePrincipal, error) { + return m.getSpResult, m.getSpErr +} + +func Test_servicePrincipal(t *testing.T) { + ctx := context.Background() + + t.Run("uses principal-id arg when set", func(t *testing.T) { + orgId := "org-id-123" + sp := &graphsdk.ServicePrincipal{ + AppId: "app-id", + DisplayName: "my-sp", + AppOwnerOrganizationId: &orgId, + } + svc := &mockEntraIdService{getSpResult: sp} + + result, err := servicePrincipal(ctx, "", "sub-1", + &PipelineManagerArgs{ + PipelineServicePrincipalId: "app-id", + }, svc) + require.NoError(t, err) + assert.Equal(t, "app-id", result.appIdOrName) + assert.Equal(t, "my-sp", result.applicationName) + assert.Equal(t, lookupKindPrincipalId, result.lookupKind) + assert.NotNil(t, result.servicePrincipal) + }) + + t.Run("uses principal-name arg when set", func(t *testing.T) { + sp := &graphsdk.ServicePrincipal{ + AppId: "app-from-name", + DisplayName: "sp-name", + } + svc := &mockEntraIdService{getSpResult: sp} + + result, err := servicePrincipal(ctx, "", "sub-1", + &PipelineManagerArgs{ + PipelineServicePrincipalName: "sp-name", + }, svc) + require.NoError(t, err) + assert.Equal(t, "app-from-name", result.appIdOrName) + assert.Equal(t, lookupKindPrincipleName, result.lookupKind) + }) + + t.Run("uses env var when no args", func(t *testing.T) { + orgId := "org-id" + sp := &graphsdk.ServicePrincipal{ + AppId: "env-client-id", + DisplayName: "env-sp", + AppOwnerOrganizationId: &orgId, + } + svc := &mockEntraIdService{getSpResult: sp} + + result, err := servicePrincipal( + ctx, "env-client-id", "sub-1", + &PipelineManagerArgs{}, svc) + require.NoError(t, err) + assert.Equal(t, "env-client-id", result.appIdOrName) + assert.Equal(t, + lookupKindEnvironmentVariable, result.lookupKind) + }) + + t.Run("creates new when no args and no env", func(t *testing.T) { + svc := &mockEntraIdService{} + + result, err := servicePrincipal( + ctx, "", "sub-1", + &PipelineManagerArgs{}, svc) + require.NoError(t, err) + assert.Contains(t, result.applicationName, "az-dev-") + assert.Nil(t, result.servicePrincipal) + }) + + t.Run( + "error when principal-id not found", + func(t *testing.T) { + svc := &mockEntraIdService{ + getSpErr: assert.AnError, + } + _, err := servicePrincipal(ctx, "", "sub-1", + &PipelineManagerArgs{ + PipelineServicePrincipalId: "missing-id", + }, svc) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing-id") + assert.Contains(t, err.Error(), "--principal-id") + }) + + t.Run("error when env var not found", func(t *testing.T) { + svc := &mockEntraIdService{ + getSpErr: assert.AnError, + } + _, err := servicePrincipal( + ctx, "env-client", "sub-1", + &PipelineManagerArgs{}, svc) + require.Error(t, err) + assert.Contains(t, err.Error(), "env-client") + assert.Contains(t, err.Error(), + AzurePipelineClientIdEnvVarName) + }) + + t.Run( + "name not found returns name for creation", + func(t *testing.T) { + svc := &mockEntraIdService{ + getSpErr: assert.AnError, + } + result, err := servicePrincipal(ctx, "", "sub-1", + &PipelineManagerArgs{ + PipelineServicePrincipalName: "new-sp", + }, svc) + require.NoError(t, err) + assert.Equal(t, "new-sp", result.appIdOrName) + assert.Equal(t, "new-sp", result.applicationName) + }) +} + +// ------------------------------------------------------------------ +// GitHub credentialOptions +// ------------------------------------------------------------------ + +func Test_GitHubCiProvider_credentialOptions(t *testing.T) { + ctx := context.Background() + provider := &GitHubCiProvider{} + + t.Run("client-credentials auth", func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "Azure", + repoName: "azure-dev", + branch: "main", + }, + provisioning.Options{}, + AuthTypeClientCredentials, + &entraid.AzureCredentials{}) + require.NoError(t, err) + assert.True(t, opts.EnableClientCredentials) + assert.False(t, opts.EnableFederatedCredentials) + }) + + t.Run("federated auth creates credentials", func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "Azure", + repoName: "azure-dev", + branch: "feature", + }, + provisioning.Options{}, + AuthTypeFederated, + &entraid.AzureCredentials{}) + require.NoError(t, err) + assert.False(t, opts.EnableClientCredentials) + assert.True(t, opts.EnableFederatedCredentials) + // Should have pull_request + feature + main creds + require.Len(t, opts.FederatedCredentialOptions, 3) + // First is always pull_request + assert.Contains(t, + opts.FederatedCredentialOptions[0].Subject, + "pull_request") + }) + + t.Run( + "federated on main branch - no duplicate", + func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "Azure", + repoName: "azure-dev", + branch: "main", + }, + provisioning.Options{}, + AuthTypeFederated, + &entraid.AzureCredentials{}) + require.NoError(t, err) + // pull_request + main (no duplicate) + require.Len(t, opts.FederatedCredentialOptions, 2) + }) + + t.Run("empty auth type defaults to federated", + func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "Azure", + repoName: "azure-dev", + branch: "dev", + }, + provisioning.Options{}, + "", + &entraid.AzureCredentials{}) + require.NoError(t, err) + assert.True(t, opts.EnableFederatedCredentials) + assert.False(t, opts.EnableClientCredentials) + }) + + t.Run( + "unknown auth type returns empty options", + func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "Azure", + repoName: "azure-dev", + branch: "main", + }, + provisioning.Options{}, + PipelineAuthType("unknown-type"), + &entraid.AzureCredentials{}) + require.NoError(t, err) + assert.False(t, opts.EnableClientCredentials) + assert.False(t, opts.EnableFederatedCredentials) + }) + + t.Run( + "federated credential names sanitized", + func(t *testing.T) { + opts, err := provider.credentialOptions(ctx, + &gitRepositoryDetails{ + owner: "my.org", + repoName: "my.repo", + branch: "feat/branch", + }, + provisioning.Options{}, + AuthTypeFederated, + &entraid.AzureCredentials{}) + require.NoError(t, err) + for _, cred := range opts.FederatedCredentialOptions { + // No dots or slashes in name + assert.NotContains(t, cred.Name, ".") + assert.NotContains(t, cred.Name, "/") + } + }) +} + +// ------------------------------------------------------------------ +// PipelineManager.SetParameters +// ------------------------------------------------------------------ + +func Test_PipelineManager_SetParameters(t *testing.T) { + t.Run("initializes configOptions if nil", func(t *testing.T) { + pm := &PipelineManager{} + pm.SetParameters([]provisioning.Parameter{ + {Name: "param1"}, + }) + require.NotNil(t, pm.configOptions) + require.Len(t, pm.configOptions.providerParameters, 1) + assert.Equal(t, "param1", + pm.configOptions.providerParameters[0].Name) + }) + + t.Run("replaces existing parameters", func(t *testing.T) { + pm := &PipelineManager{ + configOptions: &configurePipelineOptions{ + providerParameters: []provisioning.Parameter{ + {Name: "old"}, + }, + }, + } + pm.SetParameters([]provisioning.Parameter{ + {Name: "new1"}, {Name: "new2"}, + }) + require.Len(t, pm.configOptions.providerParameters, 2) + }) + + t.Run("nil parameters clears list", func(t *testing.T) { + pm := &PipelineManager{ + configOptions: &configurePipelineOptions{ + providerParameters: []provisioning.Parameter{ + {Name: "old"}, + }, + }, + } + pm.SetParameters(nil) + assert.Nil(t, pm.configOptions.providerParameters) + }) +} + +// ------------------------------------------------------------------ +// projectProperties / configurePipelineOptions struct sanity +// ------------------------------------------------------------------ + +func Test_projectProperties_fields(t *testing.T) { + props := projectProperties{ + CiProvider: ciProviderGitHubActions, + InfraProvider: infraProviderBicep, + RepoRoot: "/tmp/repo", + HasAppHost: true, + BranchName: "main", + AuthType: AuthTypeFederated, + Variables: []string{"VAR1"}, + Secrets: []string{"SEC1"}, + } + assert.Equal(t, ciProviderGitHubActions, props.CiProvider) + assert.Equal(t, infraProviderBicep, props.InfraProvider) + assert.True(t, props.HasAppHost) + assert.Equal(t, "main", props.BranchName) + assert.Equal(t, AuthTypeFederated, props.AuthType) +} + +// ------------------------------------------------------------------ +// GitHub regex patterns +// ------------------------------------------------------------------ + +func Test_gitHubRemoteRegexPatterns(t *testing.T) { + t.Run("https url with .git", func(t *testing.T) { + m := gitHubRemoteHttpsUrlRegex.FindStringSubmatch( + "https://github.com/Azure/azure-dev.git") + require.NotNil(t, m) + assert.Equal(t, "Azure/azure-dev", m[1]) + }) + + t.Run("https url without .git", func(t *testing.T) { + m := gitHubRemoteHttpsUrlRegex.FindStringSubmatch( + "https://github.com/Azure/azure-dev") + require.NotNil(t, m) + assert.Equal(t, "Azure/azure-dev", m[1]) + }) + + t.Run("ssh url with .git", func(t *testing.T) { + m := gitHubRemoteGitUrlRegex.FindStringSubmatch( + "git@github.com:Azure/azure-dev.git") + require.NotNil(t, m) + assert.Equal(t, "Azure/azure-dev", m[1]) + }) + + t.Run("ssh url without .git", func(t *testing.T) { + m := gitHubRemoteGitUrlRegex.FindStringSubmatch( + "git@github.com:Azure/azure-dev") + require.NotNil(t, m) + assert.Equal(t, "Azure/azure-dev", m[1]) + }) + + t.Run("https www prefix", func(t *testing.T) { + m := gitHubRemoteHttpsUrlRegex.FindStringSubmatch( + "https://www.github.com/owner/repo.git") + require.NotNil(t, m) + assert.Equal(t, "owner/repo", m[1]) + }) + + t.Run("ssh regex does not match https", func(t *testing.T) { + m := gitHubRemoteGitUrlRegex.FindStringSubmatch( + "https://github.com/Azure/azure-dev.git") + assert.Nil(t, m) + }) +} + +// ------------------------------------------------------------------ +// DefaultRoleNames +// ------------------------------------------------------------------ + +func Test_DefaultRoleNames(t *testing.T) { + require.Len(t, DefaultRoleNames, 2) + assert.Contains(t, DefaultRoleNames, "Contributor") + assert.Contains(t, DefaultRoleNames, "User Access Administrator") +} diff --git a/cli/azd/pkg/project/artifact_filter_test.go b/cli/azd/pkg/project/artifact_filter_test.go new file mode 100644 index 00000000000..840e5b2c279 --- /dev/null +++ b/cli/azd/pkg/project/artifact_filter_test.go @@ -0,0 +1,476 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_validateArtifact(t *testing.T) { + tests := []struct { + name string + artifact *Artifact + expectErr string + }{ + { + name: "valid artifact", + artifact: &Artifact{ + Kind: ArtifactKindDirectory, + Location: "/build", + LocationKind: LocationKindLocal, + }, + }, + { + name: "empty kind", + artifact: &Artifact{ + Kind: "", + Location: "/build", + LocationKind: LocationKindLocal, + }, + expectErr: "kind is required", + }, + { + name: "whitespace-only kind", + artifact: &Artifact{ + Kind: ArtifactKind(" "), + Location: "/build", + LocationKind: LocationKindLocal, + }, + expectErr: "kind is required", + }, + { + name: "unknown kind", + artifact: &Artifact{ + Kind: ArtifactKind("foobar"), + Location: "/build", + LocationKind: LocationKindLocal, + }, + expectErr: "not a recognized artifact kind", + }, + { + name: "empty location", + artifact: &Artifact{ + Kind: ArtifactKindArchive, + Location: "", + LocationKind: LocationKindLocal, + }, + expectErr: "location is required", + }, + { + name: "whitespace-only location", + artifact: &Artifact{ + Kind: ArtifactKindArchive, + Location: " ", + LocationKind: LocationKindLocal, + }, + expectErr: "location is required", + }, + { + name: "empty locationKind", + artifact: &Artifact{ + Kind: ArtifactKindArchive, + Location: "/build/out.zip", + LocationKind: "", + }, + expectErr: "locationKind is required", + }, + { + name: "invalid locationKind", + artifact: &Artifact{ + Kind: ArtifactKindArchive, + Location: "/build/out.zip", + LocationKind: LocationKind("cloud"), + }, + expectErr: "locationKind must be either", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateArtifact(tt.artifact) + if tt.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_ArtifactCollection_Add_Validation(t *testing.T) { + ac := ArtifactCollection{} + + // Add a valid artifact + err := ac.Add(&Artifact{ + Kind: ArtifactKindContainer, + Location: "registry.io/img:v1", + LocationKind: LocationKindRemote, + }) + require.NoError(t, err) + require.Len(t, ac, 1) + + // Add an invalid artifact — collection should not grow + err = ac.Add(&Artifact{ + Kind: ArtifactKind("bad"), + Location: "somewhere", + LocationKind: LocationKindLocal, + }) + require.Error(t, err) + require.Len(t, ac, 1) +} + +func Test_ArtifactCollection_Add_Multiple(t *testing.T) { + ac := ArtifactCollection{} + + err := ac.Add( + &Artifact{ + Kind: ArtifactKindDirectory, + Location: "/a", + LocationKind: LocationKindLocal, + }, + &Artifact{ + Kind: ArtifactKindArchive, + Location: "/b.zip", + LocationKind: LocationKindLocal, + }, + ) + require.NoError(t, err) + require.Len(t, ac, 2) +} + +func Test_ArtifactCollection_Find_WithKindFilter(t *testing.T) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindDirectory, + Location: "/a", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindContainer, + Location: "img:v1", + LocationKind: LocationKindRemote, + }, + { + Kind: ArtifactKindDirectory, + Location: "/b", + LocationKind: LocationKindLocal, + }, + } + + dirs := ac.Find(WithKind(ArtifactKindDirectory)) + require.Len(t, dirs, 2) + + containers := ac.Find(WithKind(ArtifactKindContainer)) + require.Len(t, containers, 1) + require.Equal(t, "img:v1", containers[0].Location) + + archives := ac.Find(WithKind(ArtifactKindArchive)) + require.Empty(t, archives) +} + +func Test_ArtifactCollection_Find_WithLocationKindFilter( + t *testing.T, +) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindContainer, + Location: "local-img", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindContainer, + Location: "registry.io/img", + LocationKind: LocationKindRemote, + }, + } + + local := ac.Find(WithLocationKind(LocationKindLocal)) + require.Len(t, local, 1) + require.Equal(t, "local-img", local[0].Location) + + remote := ac.Find(WithLocationKind(LocationKindRemote)) + require.Len(t, remote, 1) + require.Equal(t, "registry.io/img", remote[0].Location) +} + +func Test_ArtifactCollection_Find_WithTake(t *testing.T) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindDirectory, + Location: "/a", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindDirectory, + Location: "/b", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindDirectory, + Location: "/c", + LocationKind: LocationKindLocal, + }, + } + + result := ac.Find(WithTake(2)) + require.Len(t, result, 2) + require.Equal(t, "/a", result[0].Location) + require.Equal(t, "/b", result[1].Location) + + // Take more than available + result = ac.Find(WithTake(100)) + require.Len(t, result, 3) +} + +func Test_ArtifactCollection_Find_CombinedFilters( + t *testing.T, +) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindContainer, + Location: "local-img", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindContainer, + Location: "reg/img1", + LocationKind: LocationKindRemote, + }, + { + Kind: ArtifactKindContainer, + Location: "reg/img2", + LocationKind: LocationKindRemote, + }, + { + Kind: ArtifactKindDirectory, + Location: "/build", + LocationKind: LocationKindLocal, + }, + } + + // Combine kind + locationKind + take + result := ac.Find( + WithKind(ArtifactKindContainer), + WithLocationKind(LocationKindRemote), + WithTake(1), + ) + require.Len(t, result, 1) + require.Equal(t, "reg/img1", result[0].Location) +} + +func Test_ArtifactCollection_FindFirst(t *testing.T) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindDirectory, + Location: "/first", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindDirectory, + Location: "/second", + LocationKind: LocationKindLocal, + }, + } + + first, ok := ac.FindFirst( + WithKind(ArtifactKindDirectory), + ) + require.True(t, ok) + require.Equal(t, "/first", first.Location) + + // No match + _, ok = ac.FindFirst(WithKind(ArtifactKindArchive)) + require.False(t, ok) +} + +func Test_ArtifactCollection_FindLast(t *testing.T) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindDirectory, + Location: "/first", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindDirectory, + Location: "/second", + LocationKind: LocationKindLocal, + }, + { + Kind: ArtifactKindContainer, + Location: "img:latest", + LocationKind: LocationKindRemote, + }, + } + + last, ok := ac.FindLast( + WithKind(ArtifactKindDirectory), + ) + require.True(t, ok) + require.Equal(t, "/second", last.Location) + + // No match + _, ok = ac.FindLast(WithKind(ArtifactKindArchive)) + require.False(t, ok) +} + +func Test_ArtifactCollection_FindFirst_Empty(t *testing.T) { + ac := ArtifactCollection{} + _, ok := ac.FindFirst() + require.False(t, ok) +} + +func Test_ArtifactCollection_FindLast_Empty(t *testing.T) { + ac := ArtifactCollection{} + _, ok := ac.FindLast() + require.False(t, ok) +} + +func Test_ArtifactCollection_ToString_Empty(t *testing.T) { + ac := ArtifactCollection{} + result := ac.ToString(" ") + require.Contains(t, result, "No artifacts were found") +} + +func Test_ArtifactCollection_ToString_FiltersNonDisplayable( + t *testing.T, +) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindDeployment, + Location: "https://deploy.url", + LocationKind: LocationKindRemote, + }, + } + + // ArtifactKindDeployment falls through to default which + // returns "" — so collection output is empty string. + result := ac.ToString("") + require.Empty(t, result) +} + +func Test_ArtifactCollection_MarshalJSON(t *testing.T) { + ac := ArtifactCollection{ + { + Kind: ArtifactKindEndpoint, + Location: "https://api.example.com", + LocationKind: LocationKindRemote, + Metadata: map[string]string{"label": "API"}, + }, + } + + data, err := ac.MarshalJSON() + require.NoError(t, err) + + var unmarshaled []*Artifact + err = json.Unmarshal(data, &unmarshaled) + require.NoError(t, err) + require.Len(t, unmarshaled, 1) + require.Equal(t, ArtifactKindEndpoint, unmarshaled[0].Kind) + require.Equal( + t, + "https://api.example.com", + unmarshaled[0].Location, + ) +} + +func Test_findFilter_matches(t *testing.T) { + tests := []struct { + name string + filter findFilter + artifact *Artifact + expected bool + }{ + { + name: "no filter matches everything", + filter: findFilter{}, + artifact: &Artifact{ + Kind: ArtifactKindDirectory, + LocationKind: LocationKindLocal, + }, + expected: true, + }, + { + name: "kind filter match", + filter: findFilter{ + kind: new(ArtifactKindContainer), + }, + artifact: &Artifact{ + Kind: ArtifactKindContainer, + LocationKind: LocationKindRemote, + }, + expected: true, + }, + { + name: "kind filter mismatch", + filter: findFilter{ + kind: new(ArtifactKindContainer), + }, + artifact: &Artifact{ + Kind: ArtifactKindDirectory, + LocationKind: LocationKindLocal, + }, + expected: false, + }, + { + name: "locationKind filter match", + filter: findFilter{ + locationKind: new(LocationKindRemote), + }, + artifact: &Artifact{ + Kind: ArtifactKindContainer, + LocationKind: LocationKindRemote, + }, + expected: true, + }, + { + name: "locationKind filter mismatch", + filter: findFilter{ + locationKind: new(LocationKindRemote), + }, + artifact: &Artifact{ + Kind: ArtifactKindContainer, + LocationKind: LocationKindLocal, + }, + expected: false, + }, + { + name: "both filters match", + filter: findFilter{ + kind: new(ArtifactKindArchive), + locationKind: new(LocationKindLocal), + }, + artifact: &Artifact{ + Kind: ArtifactKindArchive, + LocationKind: LocationKindLocal, + }, + expected: true, + }, + { + name: "kind matches but locationKind does not", + filter: findFilter{ + kind: new(ArtifactKindArchive), + locationKind: new(LocationKindRemote), + }, + artifact: &Artifact{ + Kind: ArtifactKindArchive, + LocationKind: LocationKindLocal, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal( + t, + tt.expected, + tt.filter.matches(tt.artifact), + ) + }) + } +} diff --git a/cli/azd/pkg/project/framework_service_language_test.go b/cli/azd/pkg/project/framework_service_language_test.go new file mode 100644 index 00000000000..90ff3b6f480 --- /dev/null +++ b/cli/azd/pkg/project/framework_service_language_test.go @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_parseServiceLanguage(t *testing.T) { + tests := []struct { + name string + input ServiceLanguageKind + expected ServiceLanguageKind + }{ + { + name: "dotnet", + input: ServiceLanguageDotNet, + expected: ServiceLanguageDotNet, + }, + { + name: "csharp", + input: ServiceLanguageCsharp, + expected: ServiceLanguageCsharp, + }, + { + name: "fsharp", + input: ServiceLanguageFsharp, + expected: ServiceLanguageFsharp, + }, + { + name: "javascript", + input: ServiceLanguageJavaScript, + expected: ServiceLanguageJavaScript, + }, + { + name: "typescript", + input: ServiceLanguageTypeScript, + expected: ServiceLanguageTypeScript, + }, + { + name: "python", + input: ServiceLanguagePython, + expected: ServiceLanguagePython, + }, + { + name: "java", + input: ServiceLanguageJava, + expected: ServiceLanguageJava, + }, + { + name: "docker", + input: ServiceLanguageDocker, + expected: ServiceLanguageDocker, + }, + { + name: "custom", + input: ServiceLanguageCustom, + expected: ServiceLanguageCustom, + }, + { + name: "empty (none)", + input: ServiceLanguageNone, + expected: ServiceLanguageNone, + }, + { + name: "py alias resolves to python", + input: ServiceLanguageKind("py"), + expected: ServiceLanguagePython, + }, + { + name: "unknown language passes through", + input: ServiceLanguageKind("rust"), + expected: ServiceLanguageKind("rust"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseServiceLanguage(tt.input) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} + +func Test_ServiceLanguageKind_IsDotNet(t *testing.T) { + tests := []struct { + name string + kind ServiceLanguageKind + expected bool + }{ + { + name: "dotnet", + kind: ServiceLanguageDotNet, + expected: true, + }, + { + name: "csharp", + kind: ServiceLanguageCsharp, + expected: true, + }, + { + name: "fsharp", + kind: ServiceLanguageFsharp, + expected: true, + }, + { + name: "python is not dotnet", + kind: ServiceLanguagePython, + expected: false, + }, + { + name: "javascript is not dotnet", + kind: ServiceLanguageJavaScript, + expected: false, + }, + { + name: "java is not dotnet", + kind: ServiceLanguageJava, + expected: false, + }, + { + name: "docker is not dotnet", + kind: ServiceLanguageDocker, + expected: false, + }, + { + name: "empty is not dotnet", + kind: ServiceLanguageNone, + expected: false, + }, + { + name: "unknown is not dotnet", + kind: ServiceLanguageKind("go"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.kind.IsDotNet()) + }) + } +} diff --git a/cli/azd/pkg/project/framework_service_noop_test.go b/cli/azd/pkg/project/framework_service_noop_test.go new file mode 100644 index 00000000000..90816041b74 --- /dev/null +++ b/cli/azd/pkg/project/framework_service_noop_test.go @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_noOpProject_Requirements(t *testing.T) { + svc := NewNoOpProject(nil) + reqs := svc.Requirements() + + require.False(t, reqs.Package.RequireRestore) + require.False(t, reqs.Package.RequireBuild) +} + +func Test_noOpProject_RequiredExternalTools(t *testing.T) { + svc := NewNoOpProject(nil) + tools := svc.RequiredExternalTools( + context.Background(), nil, + ) + + require.NotNil(t, tools) + require.Empty(t, tools) +} + +func Test_noOpProject_Initialize(t *testing.T) { + svc := NewNoOpProject(nil) + err := svc.Initialize(context.Background(), nil) + require.NoError(t, err) +} + +func Test_noOpProject_Restore(t *testing.T) { + svc := NewNoOpProject(nil) + result, err := svc.Restore( + context.Background(), nil, nil, nil, + ) + + require.NoError(t, err) + require.NotNil(t, result) +} + +func Test_noOpProject_Build(t *testing.T) { + svc := NewNoOpProject(nil) + result, err := svc.Build( + context.Background(), nil, nil, nil, + ) + + require.NoError(t, err) + require.NotNil(t, result) +} + +func Test_noOpProject_Package(t *testing.T) { + svc := NewNoOpProject(nil) + result, err := svc.Package( + context.Background(), nil, nil, nil, + ) + + require.NoError(t, err) + require.NotNil(t, result) +} diff --git a/cli/azd/pkg/project/resources_test.go b/cli/azd/pkg/project/resources_test.go new file mode 100644 index 00000000000..48413d10f62 --- /dev/null +++ b/cli/azd/pkg/project/resources_test.go @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_AllResourceTypes(t *testing.T) { + all := AllResourceTypes() + + require.NotEmpty(t, all) + + // Verify every known constant is present + expected := []ResourceType{ + ResourceTypeDbRedis, + ResourceTypeDbPostgres, + ResourceTypeDbMySql, + ResourceTypeDbMongo, + ResourceTypeDbCosmos, + ResourceTypeHostAppService, + ResourceTypeHostContainerApp, + ResourceTypeOpenAiModel, + ResourceTypeMessagingEventHubs, + ResourceTypeMessagingServiceBus, + ResourceTypeStorage, + ResourceTypeAiProject, + ResourceTypeAiSearch, + ResourceTypeKeyVault, + } + + require.Equal(t, expected, all) +} + +func Test_ResourceType_String(t *testing.T) { + tests := []struct { + name string + rt ResourceType + expected string + }{ + {"Redis", ResourceTypeDbRedis, "Redis"}, + {"PostgreSQL", ResourceTypeDbPostgres, "PostgreSQL"}, + {"MySQL", ResourceTypeDbMySql, "MySQL"}, + {"MongoDB", ResourceTypeDbMongo, "MongoDB"}, + {"CosmosDB", ResourceTypeDbCosmos, "CosmosDB"}, + { + "App Service", + ResourceTypeHostAppService, + "App Service", + }, + { + "Container App", + ResourceTypeHostContainerApp, + "Container App", + }, + { + "Open AI Model", + ResourceTypeOpenAiModel, + "Open AI Model", + }, + { + "Event Hubs", + ResourceTypeMessagingEventHubs, + "Event Hubs", + }, + { + "Service Bus", + ResourceTypeMessagingServiceBus, + "Service Bus", + }, + { + "Storage Account", + ResourceTypeStorage, + "Storage Account", + }, + {"Foundry", ResourceTypeAiProject, "Foundry"}, + {"AI Search", ResourceTypeAiSearch, "AI Search"}, + {"Key Vault", ResourceTypeKeyVault, "Key Vault"}, + { + "unknown returns empty", + ResourceType("unknown.type"), + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.rt.String()) + }) + } +} + +func Test_ResourceType_AzureResourceType(t *testing.T) { + tests := []struct { + name string + rt ResourceType + expected string + }{ + { + "AppService", + ResourceTypeHostAppService, + "Microsoft.Web/sites", + }, + { + "ContainerApp", + ResourceTypeHostContainerApp, + "Microsoft.App/containerApps", + }, + { + "Redis", + ResourceTypeDbRedis, + "Microsoft.Cache/redis", + }, + { + "Postgres", + ResourceTypeDbPostgres, + "Microsoft.DBforPostgreSQL/flexibleServers/databases", + }, + { + "MySQL", + ResourceTypeDbMySql, + "Microsoft.DBforMySQL/flexibleServers/databases", + }, + { + "MongoDB", + ResourceTypeDbMongo, + "Microsoft.DocumentDB/databaseAccounts/mongodbDatabases", + }, + { + "OpenAI Model", + ResourceTypeOpenAiModel, + "Microsoft.CognitiveServices/accounts/deployments", + }, + { + "CosmosDB", + ResourceTypeDbCosmos, + "Microsoft.DocumentDB/databaseAccounts/sqlDatabases", + }, + { + "EventHubs", + ResourceTypeMessagingEventHubs, + "Microsoft.EventHub/namespaces", + }, + { + "ServiceBus", + ResourceTypeMessagingServiceBus, + "Microsoft.ServiceBus/namespaces", + }, + { + "Storage", + ResourceTypeStorage, + "Microsoft.Storage/storageAccounts", + }, + { + "KeyVault", + ResourceTypeKeyVault, + "Microsoft.KeyVault/vaults", + }, + { + "AiProject", + ResourceTypeAiProject, + "Microsoft.CognitiveServices/accounts/projects", + }, + { + "AiSearch", + ResourceTypeAiSearch, + "Microsoft.Search/searchServices", + }, + { + "unknown returns empty", + ResourceType("custom.thing"), + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.rt.AzureResourceType()) + }) + } +} + +func Test_AllResourceTypes_StringAndAzureType_Complete( + t *testing.T, +) { + // Every type in AllResourceTypes should have a non-empty + // String() and AzureResourceType() value. + for _, rt := range AllResourceTypes() { + t.Run(string(rt), func(t *testing.T) { + require.NotEmpty( + t, + rt.String(), + "String() should not be empty for %s", + rt, + ) + require.NotEmpty( + t, + rt.AzureResourceType(), + "AzureResourceType() should not be empty for %s", + rt, + ) + }) + } +} diff --git a/cli/azd/pkg/project/service_logic_test.go b/cli/azd/pkg/project/service_logic_test.go new file mode 100644 index 00000000000..afee3cb2554 --- /dev/null +++ b/cli/azd/pkg/project/service_logic_test.go @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/async" + "github.com/stretchr/testify/require" +) + +func Test_ServiceConfig_Path_Relative(t *testing.T) { + sc := &ServiceConfig{ + RelativePath: "src/api", + Project: &ProjectConfig{ + Path: "/home/user/myproject", + }, + } + + expected := filepath.Join( + "/home/user/myproject", "src/api", + ) + require.Equal(t, expected, sc.Path()) +} + +func Test_ServiceConfig_Path_Absolute(t *testing.T) { + // t.TempDir() returns a guaranteed absolute path + // on any OS. + absPath := t.TempDir() + + sc := &ServiceConfig{ + RelativePath: absPath, + Project: &ProjectConfig{ + Path: "/home/user/myproject", + }, + } + + // When RelativePath is absolute, it should be returned + // as-is without joining with Project.Path. + require.True(t, filepath.IsAbs(absPath)) + require.Equal(t, absPath, sc.Path()) +} + +func Test_isConditionTrue(t *testing.T) { + tests := []struct { + name string + value string + expected bool + }{ + {"1", "1", true}, + {"true", "true", true}, + {"TRUE", "TRUE", true}, + {"True", "True", true}, + {"yes", "yes", true}, + {"YES", "YES", true}, + {"Yes", "Yes", true}, + {"0", "0", false}, + {"false", "false", false}, + {"no", "no", false}, + {"empty", "", false}, + {"random", "random", false}, + {"tRuE mixed case", "tRuE", false}, + {"yEs mixed case", "yEs", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal( + t, tt.expected, isConditionTrue(tt.value), + ) + }) + } +} + +func Test_NewServiceContext(t *testing.T) { + ctx := NewServiceContext() + + require.NotNil(t, ctx) + require.NotNil(t, ctx.Restore) + require.NotNil(t, ctx.Build) + require.NotNil(t, ctx.Package) + require.NotNil(t, ctx.Publish) + require.NotNil(t, ctx.Deploy) + + // All collections should be empty (not nil) + require.Len(t, ctx.Restore, 0) + require.Len(t, ctx.Build, 0) + require.Len(t, ctx.Package, 0) + require.Len(t, ctx.Publish, 0) + require.Len(t, ctx.Deploy, 0) +} + +func Test_NewServiceProgress(t *testing.T) { + msg := "deploying to Azure" + progress := NewServiceProgress(msg) + + require.Equal(t, msg, progress.Message) + require.False(t, progress.Timestamp.IsZero()) +} + +func Test_envResolver_NilEnv(t *testing.T) { + resolver := envResolver(nil) + result := resolver("ANY_KEY") + require.Equal(t, "", result) +} + +func Test_createProgressFunc_NilProgress(t *testing.T) { + fn := createProgressFunc(nil) + // Should not panic + fn("some message") +} + +func Test_createProgressFunc_WithProgress(t *testing.T) { + // NewNoopProgress drains the channel in a background + // goroutine, preventing SetProgress from blocking. + progress := async.NewNoopProgress[ServiceProgress]() + fn := createProgressFunc(progress) + + // Should not panic; sets progress internally + fn("deploying step 1") + progress.Done() +} + +func Test_externalTool_Name(t *testing.T) { + tool := &externalTool{ + name: "my-tool", + installUrl: "https://example.com/install", + } + + require.Equal(t, "my-tool", tool.Name()) +} + +func Test_externalTool_InstallUrl(t *testing.T) { + tool := &externalTool{ + name: "my-tool", + installUrl: "https://example.com/install", + } + + require.Equal( + t, + "https://example.com/install", + tool.InstallUrl(), + ) +} + +func Test_stripUTF8BOM(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "no BOM", + input: []byte("hello world"), + expected: []byte("hello world"), + }, + { + name: "with BOM", + input: append(utf8BOM, []byte("hello")...), + expected: []byte("hello"), + }, + { + name: "only BOM", + input: utf8BOM, + expected: []byte{}, + }, + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "partial BOM prefix", + input: []byte{0xEF, 0xBB}, + expected: []byte{0xEF, 0xBB}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal( + t, tt.expected, stripUTF8BOM(tt.input), + ) + }) + } +} diff --git a/cli/azd/pkg/project/service_target_kind_test.go b/cli/azd/pkg/project/service_target_kind_test.go new file mode 100644 index 00000000000..68e57dda08f --- /dev/null +++ b/cli/azd/pkg/project/service_target_kind_test.go @@ -0,0 +1,306 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/azapi" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/stretchr/testify/require" +) + +func Test_BuiltInServiceTargetKinds(t *testing.T) { + kinds := BuiltInServiceTargetKinds() + + require.NotEmpty(t, kinds) + require.Contains(t, kinds, AppServiceTarget) + require.Contains(t, kinds, ContainerAppTarget) + require.Contains(t, kinds, AzureFunctionTarget) + require.Contains(t, kinds, StaticWebAppTarget) + require.Contains(t, kinds, AksTarget) + require.Contains(t, kinds, AiEndpointTarget) + + // DotNetContainerAppTarget and SpringAppTarget are + // intentionally excluded from the built-in list. + require.NotContains(t, kinds, DotNetContainerAppTarget) + require.NotContains(t, kinds, SpringAppTarget) +} + +func Test_builtInServiceTargetNames(t *testing.T) { + names := builtInServiceTargetNames() + kinds := BuiltInServiceTargetKinds() + + require.Len(t, names, len(kinds)) + for i, kind := range kinds { + require.Equal(t, string(kind), names[i]) + } +} + +func Test_ServiceTargetKind_RequiresContainer(t *testing.T) { + tests := []struct { + name string + kind ServiceTargetKind + expected bool + }{ + { + name: "ContainerAppTarget requires container", + kind: ContainerAppTarget, + expected: true, + }, + { + name: "AksTarget requires container", + kind: AksTarget, + expected: true, + }, + { + name: "AppServiceTarget does not", + kind: AppServiceTarget, + expected: false, + }, + { + name: "AzureFunctionTarget does not", + kind: AzureFunctionTarget, + expected: false, + }, + { + name: "StaticWebAppTarget does not", + kind: StaticWebAppTarget, + expected: false, + }, + { + name: "AiEndpointTarget does not", + kind: AiEndpointTarget, + expected: false, + }, + { + name: "DotNetContainerAppTarget does not", + kind: DotNetContainerAppTarget, + expected: false, + }, + { + name: "NonSpecifiedTarget does not", + kind: NonSpecifiedTarget, + expected: false, + }, + { + name: "Unknown kind does not", + kind: ServiceTargetKind("unknown"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.kind.RequiresContainer()) + }) + } +} + +func Test_ServiceTargetKind_IgnoreFile(t *testing.T) { + tests := []struct { + name string + kind ServiceTargetKind + expected string + }{ + { + name: "AppService returns .webappignore", + kind: AppServiceTarget, + expected: ".webappignore", + }, + { + name: "Function returns .funcignore", + kind: AzureFunctionTarget, + expected: ".funcignore", + }, + { + name: "ContainerApp returns empty", + kind: ContainerAppTarget, + expected: "", + }, + { + name: "AKS returns empty", + kind: AksTarget, + expected: "", + }, + { + name: "StaticWebApp returns empty", + kind: StaticWebAppTarget, + expected: "", + }, + { + name: "NonSpecified returns empty", + kind: NonSpecifiedTarget, + expected: "", + }, + { + name: "Unknown kind returns empty", + kind: ServiceTargetKind("custom-ext"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.kind.IgnoreFile()) + }) + } +} + +func Test_ServiceTargetKind_SupportsDelayedProvisioning( + t *testing.T, +) { + tests := []struct { + name string + kind ServiceTargetKind + expected bool + }{ + { + name: "AKS supports delayed provisioning", + kind: AksTarget, + expected: true, + }, + { + name: "ContainerApp does not", + kind: ContainerAppTarget, + expected: false, + }, + { + name: "AppService does not", + kind: AppServiceTarget, + expected: false, + }, + { + name: "Function does not", + kind: AzureFunctionTarget, + expected: false, + }, + { + name: "StaticWebApp does not", + kind: StaticWebAppTarget, + expected: false, + }, + { + name: "NonSpecified does not", + kind: NonSpecifiedTarget, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.kind.SupportsDelayedProvisioning() + require.Equal(t, tt.expected, got) + }) + } +} + +func Test_parseServiceHost(t *testing.T) { + tests := []struct { + name string + kind ServiceTargetKind + expected ServiceTargetKind + expectErr bool + }{ + { + name: "known kind appservice", + kind: AppServiceTarget, + expected: AppServiceTarget, + }, + { + name: "known kind containerapp", + kind: ContainerAppTarget, + expected: ContainerAppTarget, + }, + { + name: "known kind function", + kind: AzureFunctionTarget, + expected: AzureFunctionTarget, + }, + { + name: "extension kind passes through", + kind: ServiceTargetKind("my-custom-ext"), + expected: ServiceTargetKind("my-custom-ext"), + }, + { + name: "empty kind returns error", + kind: ServiceTargetKind(""), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseServiceHost(tt.kind) + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), "cannot be empty") + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_resourceTypeMismatchError(t *testing.T) { + err := resourceTypeMismatchError( + "my-resource", + "Microsoft.Web/sites", + azapi.AzureResourceTypeContainerApp, + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "my-resource") + require.Contains(t, err.Error(), "Microsoft.Web/sites") + require.Contains( + t, + err.Error(), + string(azapi.AzureResourceTypeContainerApp), + ) +} + +func Test_checkResourceType(t *testing.T) { + tests := []struct { + name string + resourceType string + expectedType azapi.AzureResourceType + expectErr bool + }{ + { + name: "matching type succeeds", + resourceType: "Microsoft.Web/sites", + expectedType: azapi.AzureResourceTypeWebSite, + }, + { + name: "case insensitive match succeeds", + resourceType: "microsoft.web/sites", + expectedType: azapi.AzureResourceTypeWebSite, + }, + { + name: "mismatched type fails", + resourceType: "Microsoft.App/containerApps", + expectedType: azapi.AzureResourceTypeWebSite, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resource := environment.NewTargetResource( + "sub-id", + "rg-name", + "res-name", + tt.resourceType, + ) + + err := checkResourceType(resource, tt.expectedType) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/cli/azd/pkg/project/service_target_springapp_test.go b/cli/azd/pkg/project/service_target_springapp_test.go new file mode 100644 index 00000000000..8f22ae80362 --- /dev/null +++ b/cli/azd/pkg/project/service_target_springapp_test.go @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_springAppTarget_Initialize_ReturnsDeprecated( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + err := target.Initialize(context.Background(), nil) + + require.Error(t, err) + require.ErrorIs(t, err, errSpringAppDeprecated) + require.Contains(t, err.Error(), "no longer supported") +} + +func Test_springAppTarget_Package_ReturnsDeprecated( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + result, err := target.Package( + context.Background(), nil, nil, nil, + ) + + require.Nil(t, result) + require.ErrorIs(t, err, errSpringAppDeprecated) +} + +func Test_springAppTarget_Deploy_ReturnsDeprecated( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + result, err := target.Deploy( + context.Background(), nil, nil, nil, nil, + ) + + require.Nil(t, result) + require.ErrorIs(t, err, errSpringAppDeprecated) +} + +func Test_springAppTarget_Publish_ReturnsDeprecated( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + result, err := target.Publish( + context.Background(), nil, nil, nil, nil, nil, + ) + + require.Nil(t, result) + require.ErrorIs(t, err, errSpringAppDeprecated) +} + +func Test_springAppTarget_Endpoints_ReturnsDeprecated( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + endpoints, err := target.Endpoints( + context.Background(), nil, nil, + ) + + require.Nil(t, endpoints) + require.ErrorIs(t, err, errSpringAppDeprecated) +} + +func Test_springAppTarget_RequiredExternalTools_Empty( + t *testing.T, +) { + target := NewSpringAppTarget(nil, nil) + tools := target.RequiredExternalTools( + context.Background(), nil, + ) + + require.NotNil(t, tools) + require.Empty(t, tools) +} diff --git a/cli/azd/pkg/prompt/format_helpers_test.go b/cli/azd/pkg/prompt/format_helpers_test.go new file mode 100644 index 00000000000..497c8994d57 --- /dev/null +++ b/cli/azd/pkg/prompt/format_helpers_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package prompt + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/account" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatSubscriptionDisplayName(t *testing.T) { + sub := &account.Subscription{ + Id: "sub-123-456", + Name: "My Subscription", + } + + tests := []struct { + name string + hideId bool + wantName bool + wantId bool + exactMatch string + }{ + { + name: "hideId true returns name only", + hideId: true, + wantName: true, + wantId: false, + exactMatch: "My Subscription", + }, + { + name: "hideId false includes both name and id", + hideId: false, + wantName: true, + wantId: true, + // Cannot check exact match due to ANSI color codes + // from output.WithGrayFormat + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatSubscriptionDisplayName(sub, tt.hideId) + require.NotEmpty(t, result) + + if tt.exactMatch != "" { + assert.Equal(t, tt.exactMatch, result) + } + if tt.wantName { + assert.Contains(t, result, "My Subscription") + } + if tt.wantId { + assert.Contains(t, result, "sub-123-456") + } + }) + } +} + +func TestFormatAutoSelectedSubscriptionMessage(t *testing.T) { + sub := &account.Subscription{ + Id: "sub-abc-def", + Name: "Production", + } + + tests := []struct { + name string + hideId bool + expected string + }{ + { + name: "hideId true omits id", + hideId: true, + expected: "Auto-selected subscription: Production", + }, + { + name: "hideId false includes id", + hideId: false, + expected: "Auto-selected subscription: " + + "Production (sub-abc-def)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatAutoSelectedSubscriptionMessage( + sub, tt.hideId, + ) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFormatAutoSelectedSubscriptionMessage_EmptyName(t *testing.T) { + sub := &account.Subscription{ + Id: "sub-id", + Name: "", + } + + result := formatAutoSelectedSubscriptionMessage(sub, false) + assert.Equal( + t, "Auto-selected subscription: (sub-id)", result, + ) +} + +func TestIsDemoModeEnabled(t *testing.T) { + tests := []struct { + name string + envValue string + expected bool + }{ + { + name: "true enables demo mode", + envValue: "true", + expected: true, + }, + { + name: "TRUE enables demo mode", + envValue: "TRUE", + expected: true, + }, + { + name: "1 enables demo mode", + envValue: "1", + expected: true, + }, + { + name: "false disables demo mode", + envValue: "false", + expected: false, + }, + { + name: "0 disables demo mode", + envValue: "0", + expected: false, + }, + { + name: "empty string disables demo mode", + envValue: "", + expected: false, + }, + { + name: "invalid value disables demo mode", + envValue: "not-a-bool", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue != "" { + t.Setenv("AZD_DEMO_MODE", tt.envValue) + } else { + t.Setenv("AZD_DEMO_MODE", "") + } + + result := isDemoModeEnabled() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsDemoModeEnabled_Unset(t *testing.T) { + // Do NOT set AZD_DEMO_MODE — rely on the default + // process env (which should not have it in CI). + // This tests the "env var not present" path. + t.Setenv("AZD_DEMO_MODE", "") + assert.False(t, isDemoModeEnabled()) +} + +func TestNewEmptyAzureContext(t *testing.T) { + ctx := NewEmptyAzureContext() + + require.NotNil(t, ctx) + assert.Equal(t, "", ctx.Scope.TenantId) + assert.Equal(t, "", ctx.Scope.SubscriptionId) + assert.Equal(t, "", ctx.Scope.Location) + assert.Equal(t, "", ctx.Scope.ResourceGroup) + require.NotNil(t, ctx.Resources) +} + +func TestAzureScope_Fields(t *testing.T) { + scope := AzureScope{ + TenantId: "tenant-1", + SubscriptionId: "sub-1", + Location: "eastus", + ResourceGroup: "rg-1", + } + + assert.Equal(t, "tenant-1", scope.TenantId) + assert.Equal(t, "sub-1", scope.SubscriptionId) + assert.Equal(t, "eastus", scope.Location) + assert.Equal(t, "rg-1", scope.ResourceGroup) +} + +func TestSelectOptions_Defaults(t *testing.T) { + // Verify that a zero-value SelectOptions has expected defaults + opts := SelectOptions{} + + assert.Nil(t, opts.ForceNewResource) + assert.Nil(t, opts.AllowNewResource) + assert.Equal(t, "", opts.Message) + assert.Equal(t, "", opts.HelpMessage) + assert.Equal(t, "", opts.LoadingMessage) + assert.Nil(t, opts.DisplayNumbers) + assert.Equal(t, 0, opts.DisplayCount) + assert.Equal(t, "", opts.Hint) + assert.Nil(t, opts.EnableFiltering) + assert.Nil(t, opts.Writer) +} + +func TestResourceGroupOptions_Defaults(t *testing.T) { + opts := ResourceGroupOptions{} + assert.Nil(t, opts.SelectorOptions) +} + +func TestErrSentinels(t *testing.T) { + // Verify sentinel errors are not nil and have messages + require.NotNil(t, ErrNoResourcesFound) + require.NotNil(t, ErrNoResourceSelected) + assert.Equal( + t, "no resources found", ErrNoResourcesFound.Error(), + ) + assert.Equal( + t, "no resource selected", ErrNoResourceSelected.Error(), + ) +} diff --git a/cli/azd/pkg/syncmap/syncmap_test.go b/cli/azd/pkg/syncmap/syncmap_test.go new file mode 100644 index 00000000000..186ead5d282 --- /dev/null +++ b/cli/azd/pkg/syncmap/syncmap_test.go @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package syncmap + +import ( + "sort" + "sync" + "testing" +) + +func TestStore_and_Load(t *testing.T) { + var m Map[string, int] + + m.Store("a", 1) + m.Store("b", 2) + + v, ok := m.Load("a") + if !ok || v != 1 { + t.Fatalf("Load(a) = (%v, %v), want (1, true)", v, ok) + } + + v, ok = m.Load("b") + if !ok || v != 2 { + t.Fatalf("Load(b) = (%v, %v), want (2, true)", v, ok) + } +} + +func TestLoad_missing_key(t *testing.T) { + var m Map[string, int] + + v, ok := m.Load("missing") + if ok { + t.Fatal("Load(missing) should return false") + } + if v != 0 { + t.Fatalf("Load(missing) zero-value = %v, want 0", v) + } +} + +func TestLoad_missing_key_pointer_type(t *testing.T) { + var m Map[string, *int] + + v, ok := m.Load("missing") + if ok { + t.Fatal("Load(missing) should return false") + } + if v != nil { + t.Fatalf("Load(missing) zero-value = %v, want nil", v) + } +} + +func TestStore_overwrites_existing(t *testing.T) { + var m Map[string, string] + + m.Store("key", "old") + m.Store("key", "new") + + v, ok := m.Load("key") + if !ok || v != "new" { + t.Fatalf("Load(key) = (%v, %v), want (new, true)", v, ok) + } +} + +func TestLoadOrStore_stores_new_value(t *testing.T) { + var m Map[string, int] + + actual, loaded := m.LoadOrStore("k", 42) + if loaded { + t.Fatal("LoadOrStore should return loaded=false for new key") + } + if actual != 42 { + t.Fatalf("LoadOrStore actual = %v, want 42", actual) + } + + // Verify it persisted + v, ok := m.Load("k") + if !ok || v != 42 { + t.Fatalf("Load(k) after LoadOrStore = (%v, %v), want (42, true)", v, ok) + } +} + +func TestLoadOrStore_returns_existing(t *testing.T) { + var m Map[string, int] + + m.Store("k", 10) + + actual, loaded := m.LoadOrStore("k", 99) + if !loaded { + t.Fatal("LoadOrStore should return loaded=true for existing key") + } + if actual != 10 { + t.Fatalf("LoadOrStore actual = %v, want 10 (existing)", actual) + } +} + +func TestLoadAndDelete_existing_key(t *testing.T) { + var m Map[string, string] + + m.Store("x", "hello") + + v, loaded := m.LoadAndDelete("x") + if !loaded { + t.Fatal("LoadAndDelete should return loaded=true for existing key") + } + if v != "hello" { + t.Fatalf("LoadAndDelete value = %v, want hello", v) + } + + // Verify deletion + _, ok := m.Load("x") + if ok { + t.Fatal("Load(x) should return false after LoadAndDelete") + } +} + +func TestLoadAndDelete_missing_key(t *testing.T) { + var m Map[string, int] + + v, loaded := m.LoadAndDelete("nope") + if loaded { + t.Fatal("LoadAndDelete should return loaded=false for missing key") + } + if v != 0 { + t.Fatalf("LoadAndDelete zero-value = %v, want 0", v) + } +} + +func TestDelete(t *testing.T) { + var m Map[int, string] + + m.Store(1, "one") + m.Delete(1) + + _, ok := m.Load(1) + if ok { + t.Fatal("Load(1) should return false after Delete") + } +} + +func TestDelete_missing_key_no_panic(t *testing.T) { + var m Map[string, string] + + // Should not panic + m.Delete("does-not-exist") +} + +func TestRange(t *testing.T) { + var m Map[string, int] + + m.Store("a", 1) + m.Store("b", 2) + m.Store("c", 3) + + var keys []string + sum := 0 + m.Range(func(key string, value int) bool { + keys = append(keys, key) + sum += value + return true + }) + + sort.Strings(keys) + if len(keys) != 3 { + t.Fatalf("Range visited %d keys, want 3", len(keys)) + } + if keys[0] != "a" || keys[1] != "b" || keys[2] != "c" { + t.Fatalf("Range keys = %v, want [a b c]", keys) + } + if sum != 6 { + t.Fatalf("Range sum = %d, want 6", sum) + } +} + +func TestRange_early_stop(t *testing.T) { + var m Map[int, int] + + m.Store(1, 10) + m.Store(2, 20) + m.Store(3, 30) + + count := 0 + m.Range(func(key int, value int) bool { + count++ + return false // stop after first + }) + + if count != 1 { + t.Fatalf("Range with early stop visited %d keys, want 1", count) + } +} + +func TestRange_empty_map(t *testing.T) { + var m Map[string, string] + + count := 0 + m.Range(func(key string, value string) bool { + count++ + return true + }) + + if count != 0 { + t.Fatalf("Range on empty map visited %d keys, want 0", count) + } +} + +func TestConcurrent_access(t *testing.T) { + var m Map[int, int] + var wg sync.WaitGroup + + // Concurrent writes + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + m.Store(n, n*10) + }(i) + } + wg.Wait() + + // Concurrent reads + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + v, ok := m.Load(n) + if !ok { + t.Errorf("Load(%d) returned false", n) + return + } + if v != n*10 { + t.Errorf("Load(%d) = %d, want %d", n, v, n*10) + } + }(i) + } + wg.Wait() +} + +func TestInteger_key_type(t *testing.T) { + var m Map[int, string] + + m.Store(0, "zero") + m.Store(-1, "neg") + m.Store(42, "answer") + + v, ok := m.Load(0) + if !ok || v != "zero" { + t.Fatalf("Load(0) = (%v, %v), want (zero, true)", v, ok) + } + v, ok = m.Load(-1) + if !ok || v != "neg" { + t.Fatalf("Load(-1) = (%v, %v), want (neg, true)", v, ok) + } +} + +func TestStruct_value_type(t *testing.T) { + type item struct { + Name string + Count int + } + + var m Map[string, item] + + m.Store("x", item{Name: "test", Count: 5}) + + v, ok := m.Load("x") + if !ok { + t.Fatal("Load(x) returned false") + } + if v.Name != "test" || v.Count != 5 { + t.Fatalf("Load(x) = %+v, want {Name:test Count:5}", v) + } +} diff --git a/cli/azd/pkg/tools/git/git_additional_test.go b/cli/azd/pkg/tools/git/git_additional_test.go new file mode 100644 index 00000000000..be4b8fd9411 --- /dev/null +++ b/cli/azd/pkg/tools/git/git_additional_test.go @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package git + +import ( + "errors" + "slices" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockexec" + "github.com/stretchr/testify/require" +) + +func TestName(t *testing.T) { + cli := NewCli(nil) + require.Equal(t, "git CLI", cli.Name()) +} + +func TestInstallUrl(t *testing.T) { + cli := NewCli(nil) + require.Equal( + t, "https://git-scm.com/downloads", cli.InstallUrl(), + ) +} + +func TestGetRemoteUrl(t *testing.T) { + tests := []struct { + name string + stdout string + stderr string + err error + wantURL string + wantErr error + wantErrMsg string + }{ + { + name: "Success", + stdout: "https://github.com/user/repo.git\n", + wantURL: "https://github.com/user/repo.git", + }, + { + name: "SuccessSSH", + stdout: " git@github.com:user/repo.git \n", + wantURL: "git@github.com:user/repo.git", + }, + { + name: "NoSuchRemote", + stderr: "fatal: No such remote 'upstream'", + err: errors.New("exit code: 2"), + wantErr: ErrNoSuchRemote, + }, + { + name: "ErrorNoSuchRemote", + stderr: "error: No such remote 'upstream'", + err: errors.New("exit code: 2"), + wantErr: ErrNoSuchRemote, + }, + { + name: "NotAGitRepo", + stderr: "fatal: not a git repository", + err: errors.New("exit code: 128"), + wantErr: ErrNotRepository, + }, + { + name: "OtherError", + stderr: "some other error", + err: errors.New("exit code: 1"), + wantErrMsg: "failed to get remote url", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "remote") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: tt.stdout, + Stderr: tt.stderr, + }, tt.err + }) + + cli := NewCli(runner) + url, err := cli.GetRemoteUrl( + t.Context(), "/repo", "origin", + ) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return + } + if tt.wantErrMsg != "" { + require.ErrorContains(t, err, tt.wantErrMsg) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantURL, url) + }) + } +} + +func TestGetCurrentBranch(t *testing.T) { + tests := []struct { + name string + stdout string + stderr string + err error + wantBranch string + wantErr error + }{ + { + name: "Success", + stdout: "main\n", + wantBranch: "main", + }, + { + name: "FeatureBranch", + stdout: " feature/my-branch \n", + wantBranch: "feature/my-branch", + }, + { + name: "NotARepo", + stderr: "fatal: not a git repository", + err: errors.New("exit code: 128"), + wantErr: ErrNotRepository, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "branch") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: tt.stdout, + Stderr: tt.stderr, + }, tt.err + }) + + cli := NewCli(runner) + branch, err := cli.GetCurrentBranch( + t.Context(), "/repo", + ) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantBranch, branch) + }) + } +} + +func TestGetRepoRoot(t *testing.T) { + tests := []struct { + name string + stdout string + stderr string + err error + wantRoot string + wantErr error + }{ + { + name: "Success", + stdout: "/home/user/project\n", + wantRoot: "/home/user/project", + }, + { + name: "NotARepo", + stderr: "fatal: not a git repository (or any parent)", + err: errors.New("exit code: 128"), + wantErr: ErrNotRepository, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "rev-parse") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: tt.stdout, + Stderr: tt.stderr, + }, tt.err + }) + + cli := NewCli(runner) + root, err := cli.GetRepoRoot( + t.Context(), "/repo", + ) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantRoot, root) + }) + } +} + +func TestShallowClone(t *testing.T) { + tests := []struct { + name string + branch string + runErr error + wantErr bool + }{ + { + name: "WithBranch", + branch: "main", + }, + { + name: "WithoutBranch", + branch: "", + }, + { + name: "Error", + branch: "main", + runErr: errors.New("clone failed"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var capturedArgs exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "clone") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{}, tt.runErr + }) + + cli := NewCli(runner) + err := cli.ShallowClone( + t.Context(), + "https://github.com/user/repo", + tt.branch, + "/target", + ) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Contains(t, capturedArgs.Args, "clone") + require.Contains(t, capturedArgs.Args, "--depth") + require.Contains(t, capturedArgs.Args, "1") + require.Contains(t, capturedArgs.Args, "/target") + + if tt.branch != "" { + require.Contains( + t, capturedArgs.Args, "--branch", + ) + require.Contains( + t, capturedArgs.Args, tt.branch, + ) + } else { + require.NotContains( + t, capturedArgs.Args, "--branch", + ) + } + }) + } +} + +func TestInitRepo(t *testing.T) { + tests := []struct { + name string + initErr error + checkErr error + wantErr bool + wantErrMs string + }{ + { + name: "Success", + }, + { + name: "InitFails", + initErr: errors.New("init failed"), + wantErr: true, + }, + { + name: "CheckoutFails", + checkErr: errors.New("checkout failed"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "init") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{}, tt.initErr + }) + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "checkout") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{}, tt.checkErr + }) + + cli := NewCli(runner) + err := cli.InitRepo(t.Context(), "/repo") + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestIsUntrackedFile(t *testing.T) { + tests := []struct { + name string + stdout string + err error + wantResult bool + wantErr bool + }{ + { + name: "Untracked", + stdout: "?? newfile.txt\nuntracked files present", + wantResult: true, + }, + { + name: "NewFile", + stdout: "A new file added", + wantResult: true, + }, + { + name: "Tracked", + stdout: " M modified.go", + wantResult: false, + }, + { + name: "Empty", + stdout: "", + wantResult: false, + }, + { + name: "Error", + err: errors.New("git error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, command string, + ) bool { + return slices.Contains(args.Args, "status") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: tt.stdout, + }, tt.err + }) + + cli := NewCli(runner) + result, err := cli.IsUntrackedFile( + t.Context(), "/repo", "file.txt", + ) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantResult, result) + }) + } +} + +func TestAddRemote(t *testing.T) { + t.Run("Success", func(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "remote") && + slices.Contains(args.Args, "add") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.AddRemote( + t.Context(), "/repo", "origin", + "https://github.com/user/repo", + ) + require.NoError(t, err) + require.Contains(t, captured.Args, "origin") + require.Contains( + t, captured.Args, + "https://github.com/user/repo", + ) + }) + + t.Run("Error", func(t *testing.T) { + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return true + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("remote add failed") + }) + + cli := NewCli(runner) + err := cli.AddRemote( + t.Context(), "/repo", "origin", "url", + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to add remote") + }) +} + +func TestUpdateRemote(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "set-url") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.UpdateRemote( + t.Context(), "/repo", "origin", "https://new.url", + ) + require.NoError(t, err) + require.Contains(t, captured.Args, "set-url") + require.Contains(t, captured.Args, "https://new.url") +} + +func TestCommit(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "commit") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.Commit(t.Context(), "/repo", "test commit") + require.NoError(t, err) + require.Contains(t, captured.Args, "--allow-empty") + require.Contains(t, captured.Args, "-m") + require.Contains(t, captured.Args, "test commit") +} + +func TestPushUpstream(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "push") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.PushUpstream( + t.Context(), "/repo", "origin", "main", + ) + require.NoError(t, err) + require.Contains(t, captured.Args, "--set-upstream") + require.Contains(t, captured.Args, "--quiet") + require.Contains(t, captured.Args, "origin") + require.Contains(t, captured.Args, "main") +} + +func TestListStagedFiles(t *testing.T) { + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "ls-files") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: "100644 abc file.go\n", + }, nil + }) + + cli := NewCli(runner) + out, err := cli.ListStagedFiles(t.Context(), "/repo") + require.NoError(t, err) + require.Contains(t, out, "file.go") +} + +func TestAddFile(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "add") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.AddFile(t.Context(), "/repo", ".") + require.NoError(t, err) + require.Contains(t, captured.Args, "add") + require.Contains(t, captured.Args, ".") +} + +func TestSetCredentialStore(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "config") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.SetCredentialStore(t.Context(), "/repo") + require.NoError(t, err) + require.Contains(t, captured.Args, "credential.helper") + require.Contains(t, captured.Args, "store") +} + +func TestAddFileExecPermission(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "update-index") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := NewCli(runner) + err := cli.AddFileExecPermission( + t.Context(), "/repo", "script.sh", + ) + require.NoError(t, err) + require.Contains(t, captured.Args, "--chmod=+x") + require.Contains(t, captured.Args, "script.sh") +} diff --git a/cli/azd/pkg/tools/github/github_additional_test.go b/cli/azd/pkg/tools/github/github_additional_test.go new file mode 100644 index 00000000000..89d2f04e402 --- /dev/null +++ b/cli/azd/pkg/tools/github/github_additional_test.go @@ -0,0 +1,342 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package github + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGhOutputToList(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "MultipleSecrets", + input: "SECRET_A\tUpdated 2024-01-01\n" + + "SECRET_B\tUpdated 2024-01-02\n", + want: []string{"SECRET_A", "SECRET_B"}, + }, + { + name: "EmptyOutput", + input: "", + want: []string{}, + }, + { + name: "SingleLine", + input: "MY_SECRET\tUpdated 2024-01-01\n", + want: []string{"MY_SECRET"}, + }, + { + name: "NoTabs", + input: "SECRET_ONLY\n", + want: []string{"SECRET_ONLY"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ghOutputToList(tt.input) + require.Equal(t, tt.want, got) + }) + } +} + +func TestGhOutputToMap(t *testing.T) { + tests := []struct { + name string + input string + want map[string]string + wantErr bool + }{ + { + name: "MultipleVariables", + input: "VAR_A\tvalue_a\tUpdated\n" + + "VAR_B\tvalue_b\tUpdated\n", + want: map[string]string{ + "VAR_A": "value_a", + "VAR_B": "value_b", + }, + }, + { + name: "EmptyOutput", + input: "", + want: map[string]string{}, + }, + { + name: "SingleVariable", + input: "KEY\tVALUE\n", + want: map[string]string{"KEY": "VALUE"}, + }, + { + name: "BadFormat", + input: "no-tab-here\n", + wantErr: true, + }, + { + name: "MixedValidInvalid", + input: "VALID\tvalue\n" + + "invalid_no_tab\n", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ghOutputToMap(tt.input) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestGhCliVersionRegexp(t *testing.T) { + tests := []struct { + name string + input string + want string + matches bool + }{ + { + name: "StandardVersion", + input: "gh version 2.86.0 (2024-01-15)\n" + + "https://github.com/cli/cli/" + + "releases/tag/v2.86.0", + want: "2.86.0", + matches: true, + }, + { + name: "OlderVersion", + input: "gh version 2.6.0 (2022-03-15)", + want: "2.6.0", + matches: true, + }, + { + name: "NoMatch", + input: "some random text", + matches: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := ghCliVersionRegexp.FindStringSubmatch( + tt.input, + ) + if !tt.matches { + require.Len(t, matches, 0) + return + } + require.Len(t, matches, 2) + require.Equal(t, tt.want, matches[1]) + }) + } +} + +func TestIsGhCliNotLoggedInMessageRegex(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "AuthenticatePlease", + input: "To authenticate, please run " + + "`gh auth login`.", + want: true, + }, + { + name: "TryAuthenticating", + input: "Try authenticating with: " + + " gh auth login", + want: true, + }, + { + name: "ReAuthenticate", + input: "To re-authenticate, run: " + + "gh auth login", + want: true, + }, + { + name: "GetStarted", + input: "To get started with GitHub CLI, " + + "please run: gh auth login", + want: true, + }, + { + name: "NotMatching", + input: "everything is fine", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isGhCliNotLoggedInMessageRegex.MatchString( + tt.input, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestIsUserNotAuthorizedMessageRegex(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "Matching", + input: "HTTP 403: Resource not " + + "accessible by integration", + want: true, + }, + { + name: "NotMatching", + input: "HTTP 200: OK", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isUserNotAuthorizedMessageRegex.MatchString( + tt.input, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestNotLoggedIntoAnyGitHubHostsRegex(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "Matching", + input: "You are not logged into any " + + "GitHub hosts.", + want: true, + }, + { + name: "NotMatching", + input: "Logged in to github.com", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := notLoggedIntoAnyGitHubHostsMessageRegex. + MatchString(tt.input) + require.Equal(t, tt.want, got) + }) + } +} + +func TestRepositoryNameInUseRegex(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "Matching", + input: "GraphQL: Name already exists on " + + "this account (createRepository)", + want: true, + }, + { + name: "NotMatching", + input: "repository created successfully", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := repositoryNameInUseRegex.MatchString( + tt.input, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestRunningOnCodespaces(t *testing.T) { + t.Run("InCodespaces", func(t *testing.T) { + t.Setenv("CODESPACES", "true") + require.True(t, RunningOnCodespaces()) + }) + + t.Run("NotInCodespaces", func(t *testing.T) { + t.Setenv("CODESPACES", "false") + require.False(t, RunningOnCodespaces()) + }) + + t.Run("EnvNotSet", func(t *testing.T) { + t.Setenv("CODESPACES", "") + require.False(t, RunningOnCodespaces()) + }) +} + +func TestCliName(t *testing.T) { + cli := &Cli{} + require.Equal(t, "GitHub CLI", cli.Name()) +} + +func TestCliInstallUrl(t *testing.T) { + cli := &Cli{} + require.Equal( + t, + "https://aka.ms/azure-dev/github-cli-install", + cli.InstallUrl(), + ) +} + +func TestCliBinaryPath(t *testing.T) { + cli := &Cli{path: "/usr/local/bin/gh"} + require.Equal( + t, "/usr/local/bin/gh", cli.BinaryPath(), + ) +} + +func TestCliBinaryPathEmpty(t *testing.T) { + cli := &Cli{} + require.Equal(t, "", cli.BinaryPath()) +} + +func TestProtocolTypeConstants(t *testing.T) { + require.Equal(t, "ssh", GitSshProtocolType) + require.Equal(t, "https", GitHttpsProtocolType) +} + +func TestGhCliName(t *testing.T) { + name := ghCliName() + require.NotEmpty(t, name) + // On all platforms, it should either be "gh" or "gh.exe" + require.Contains(t, name, "gh") +} + +func TestGitHubHostName(t *testing.T) { + require.Equal(t, "github.com", GitHubHostName) +} + +func TestTokenEnvVars(t *testing.T) { + require.Contains(t, TokenEnvVars, "GITHUB_TOKEN") + require.Contains(t, TokenEnvVars, "GH_TOKEN") + require.Len(t, TokenEnvVars, 2) +} diff --git a/cli/azd/pkg/tools/kubectl/kubectl_additional_test.go b/cli/azd/pkg/tools/kubectl/kubectl_additional_test.go new file mode 100644 index 00000000000..a822881b4a1 --- /dev/null +++ b/cli/azd/pkg/tools/kubectl/kubectl_additional_test.go @@ -0,0 +1,465 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package kubectl + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/require" +) + +func Test_Cli_Name(t *testing.T) { + cli := NewCli(nil) + require.Equal(t, "kubectl", cli.Name()) +} + +func Test_Cli_InstallUrl(t *testing.T) { + cli := NewCli(nil) + require.Contains(t, cli.InstallUrl(), "kubectl-install") +} + +func Test_Cli_SetEnv(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + cli.SetEnv(map[string]string{ + "MY_VAR": "my_value", + }) + + _, err := cli.Exec(*mockCtx.Context, nil, "get", "pods") + require.NoError(t, err) + require.Contains(t, capturedArgs.Env, "MY_VAR=my_value") +} + +func Test_Cli_SetEnv_MergesValues(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + cli.SetEnv(map[string]string{"A": "1"}) + cli.SetEnv(map[string]string{"B": "2"}) + + _, err := cli.Exec(*mockCtx.Context, nil, "version") + require.NoError(t, err) + require.Contains(t, capturedArgs.Env, "A=1") + require.Contains(t, capturedArgs.Env, "B=2") +} + +func Test_Cli_SetKubeConfig(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + cli.SetKubeConfig("/path/to/config") + + _, err := cli.Exec(*mockCtx.Context, nil, "get", "pods") + require.NoError(t, err) + require.Contains(t, + capturedArgs.Env, "KUBECONFIG=/path/to/config", + ) +} + +func Test_Cli_Cwd(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + cli.Cwd("/my/workdir") + + _, err := cli.Exec(*mockCtx.Context, nil, "get", "pods") + require.NoError(t, err) + require.Equal(t, "/my/workdir", capturedArgs.Cwd) +} + +func Test_Cli_Exec_NilFlags(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "ok", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + res, err := cli.Exec(*mockCtx.Context, nil, "version") + require.NoError(t, err) + require.Equal(t, "ok", res.Stdout) + require.Equal(t, []string{"version"}, capturedArgs.Args) +} + +func Test_Cli_Exec_AllFlags(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + flags := &KubeCliFlags{ + Namespace: "prod", + DryRun: DryRunTypeServer, + Output: OutputTypeYaml, + } + _, err := cli.Exec( + *mockCtx.Context, flags, "apply", "-f", "file.yaml", + ) + require.NoError(t, err) + require.Equal(t, "kubectl", capturedArgs.Cmd) + require.Equal(t, []string{ + "apply", "-f", "file.yaml", + "--dry-run=server", "-n", "prod", "-o", "yaml", + }, capturedArgs.Args) +} + +func Test_Cli_ApplyWithKustomize(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + var capturedArgs exec.RunArgs + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl apply -k") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.NewRunResult(0, "", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + err := cli.ApplyWithKustomize( + *mockCtx.Context, "./overlays/prod", + &KubeCliFlags{Namespace: "prod"}, + ) + require.NoError(t, err) + require.Equal(t, "kubectl", capturedArgs.Cmd) + require.Equal(t, []string{ + "apply", "-k", "./overlays/prod", "-n", "prod", + }, capturedArgs.Args) +} + +func Test_Cli_ApplyWithKustomize_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl apply -k") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, errors.New("not found") + }) + + cli := NewCli(mockCtx.CommandRunner) + err := cli.ApplyWithKustomize( + *mockCtx.Context, "./bad-path", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "kubectl apply -k") +} + +func Test_Cli_CheckInstalled_Success(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.MockToolInPath("kubectl", nil) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl version") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + ver := `{"clientVersion":{"gitVersion":"v1.28.0"}}` + return exec.NewRunResult(0, ver, ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + err := cli.CheckInstalled(*mockCtx.Context) + require.NoError(t, err) +} + +func Test_Cli_CheckInstalled_NotInPath(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.MockToolInPath( + "kubectl", errors.New("not found"), + ) + + cli := NewCli(mockCtx.CommandRunner) + err := cli.CheckInstalled(*mockCtx.Context) + require.Error(t, err) +} + +func Test_Cli_CheckInstalled_VersionError(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.MockToolInPath("kubectl", nil) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl version") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("version fetch failed") + }) + + cli := NewCli(mockCtx.CommandRunner) + // CheckInstalled logs the error but does not fail + err := cli.CheckInstalled(*mockCtx.Context) + require.NoError(t, err) +} + +func Test_Cli_GetClientVersion(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl version") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + v := `{"clientVersion":{"gitVersion":"v1.30.1"}}` + return exec.NewRunResult(0, v, ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + ver, err := cli.getClientVersion(*mockCtx.Context) + require.NoError(t, err) + require.Equal(t, "v1.30.1", ver) +} + +func Test_Cli_GetClientVersion_BadJSON(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl version") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, "not-json", ""), nil + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.getClientVersion(*mockCtx.Context) + require.Error(t, err) + require.Contains(t, err.Error(), "parsing kubectl version") +} + +func Test_Cli_ConfigUseContext_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl config") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("context not found") + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.ConfigUseContext( + *mockCtx.Context, "missing-ctx", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed setting kubectl") +} + +func Test_Cli_CreateNamespace_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl create namespace") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("already exists") + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.CreateNamespace( + *mockCtx.Context, "existing-ns", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "kubectl create namespace") +} + +func Test_Cli_RolloutStatus_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl rollout") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("deadline exceeded") + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.RolloutStatus( + *mockCtx.Context, "my-deploy", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "rollout failed") +} + +func Test_Cli_ApplyWithStdIn_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl apply -f -") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("invalid yaml") + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.ApplyWithStdIn( + *mockCtx.Context, "bad-yaml", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "kubectl apply") +} + +func Test_Cli_ApplyWithFile_Error(t *testing.T) { + mockCtx := mocks.NewMockContext(context.Background()) + mockCtx.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl apply -f") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("file not found") + }) + + cli := NewCli(mockCtx.CommandRunner) + _, err := cli.ApplyWithFile( + *mockCtx.Context, "/bad/path.yaml", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "kubectl apply") +} + +func Test_ParseKubeConfig_Valid(t *testing.T) { + raw := []byte(`apiVersion: v1 +kind: Config +current-context: my-cluster +clusters: + - name: my-cluster + cluster: + server: https://my-cluster.example.com:6443 + certificate-authority-data: Y2VydA== +contexts: + - name: my-cluster + context: + cluster: my-cluster + namespace: default + user: my-user +users: + - name: my-user + user: + token: my-token +preferences: {}`) + + cfg, err := ParseKubeConfig(context.Background(), raw) + require.NoError(t, err) + require.Equal(t, "v1", cfg.ApiVersion) + require.Equal(t, "Config", cfg.Kind) + require.Equal(t, "my-cluster", cfg.CurrentContext) + require.Len(t, cfg.Clusters, 1) + require.Equal(t, "my-cluster", cfg.Clusters[0].Name) + require.Equal(t, + "https://my-cluster.example.com:6443", + cfg.Clusters[0].Cluster.Server, + ) + require.Len(t, cfg.Contexts, 1) + require.Equal(t, "default", + cfg.Contexts[0].Context.Namespace, + ) + require.Len(t, cfg.Users, 1) + require.Equal(t, "my-user", cfg.Users[0].Name) +} + +func Test_ParseKubeConfig_InvalidYaml(t *testing.T) { + raw := []byte(":\tbad yaml\n\t:") + _, err := ParseKubeConfig(context.Background(), raw) + require.Error(t, err) + require.Contains(t, err.Error(), "failed unmarshalling") +} + +func Test_ParseKubeConfig_Empty(t *testing.T) { + cfg, err := ParseKubeConfig(context.Background(), []byte("")) + require.NoError(t, err) + require.NotNil(t, cfg) + require.Empty(t, cfg.Clusters) +} + +func Test_KubeConfig_RoundTrip(t *testing.T) { + original := &KubeConfig{ + ApiVersion: "v1", + Kind: "Config", + CurrentContext: "ctx", + Preferences: KubePreferences{}, + Clusters: []*KubeCluster{{ + Name: "c1", + Cluster: KubeClusterData{ + Server: "https://c1:443", + }, + }}, + Contexts: []*KubeContext{{ + Name: "ctx", + Context: KubeContextData{ + Cluster: "c1", User: "u1", + }, + }}, + Users: []*KubeUser{{ + Name: "u1", + KubeUserData: KubeUserData{"token": "t"}, + }}, + } + + // Marshal to JSON and back + data, err := json.Marshal(original) + require.NoError(t, err) + require.NotEmpty(t, data) + + var restored KubeConfig + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Equal(t, original.CurrentContext, restored.CurrentContext) +} diff --git a/cli/azd/pkg/tools/kubectl/models_additional_test.go b/cli/azd/pkg/tools/kubectl/models_additional_test.go new file mode 100644 index 00000000000..4b395aa8acf --- /dev/null +++ b/cli/azd/pkg/tools/kubectl/models_additional_test.go @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package kubectl + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Port_UnmarshalJSON_Float64TargetPort(t *testing.T) { + input := `{"port":443,"targetPort":8443.0,"protocol":"TCP"}` + var port Port + err := json.Unmarshal([]byte(input), &port) + require.NoError(t, err) + require.Equal(t, 443, port.Port) + require.Equal(t, "TCP", port.Protocol) + // JSON numbers without explicit int type deserialize + // as float64 in Go. + require.Equal(t, float64(8443), port.TargetPort) +} + +func Test_Port_UnmarshalJSON_NilTargetPort(t *testing.T) { + input := `{"port":80,"targetPort":null,"protocol":"HTTP"}` + var port Port + err := json.Unmarshal([]byte(input), &port) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported type") +} + +func Test_Port_UnmarshalJSON_InvalidJSON(t *testing.T) { + input := `{"port":}` + var port Port + err := json.Unmarshal([]byte(input), &port) + require.Error(t, err) +} + +func Test_Deployment_JsonRoundTrip(t *testing.T) { + dep := Deployment{ + Resource: Resource{ + ApiVersion: "apps/v1", + Kind: "Deployment", + Metadata: ResourceMetadata{ + Name: "web", + Namespace: "prod", + }, + }, + Spec: DeploymentSpec{Replicas: 3}, + Status: DeploymentStatus{ + AvailableReplicas: 3, + ReadyReplicas: 3, + Replicas: 3, + UpdatedReplicas: 3, + }, + } + data, err := json.Marshal(dep) + require.NoError(t, err) + + var restored Deployment + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Equal(t, "web", restored.Metadata.Name) + require.Equal(t, 3, restored.Spec.Replicas) + require.Equal(t, 3, restored.Status.AvailableReplicas) +} + +func Test_Service_JsonRoundTrip(t *testing.T) { + svc := Service{ + Resource: Resource{ + ApiVersion: "v1", + Kind: "Service", + Metadata: ResourceMetadata{ + Name: "api", + Namespace: "staging", + }, + }, + Spec: ServiceSpec{ + Type: ServiceTypeLoadBalancer, + ClusterIp: "10.0.0.5", + ClusterIps: []string{"10.0.0.5"}, + Ports: []Port{ + {Port: 80, TargetPort: "http", Protocol: "TCP"}, + }, + }, + Status: ServiceStatus{ + LoadBalancer: LoadBalancer{ + Ingress: []LoadBalancerIngress{ + {Ip: "52.1.2.3"}, + }, + }, + }, + } + data, err := json.Marshal(svc) + require.NoError(t, err) + + var restored Service + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Equal(t, ServiceTypeLoadBalancer, restored.Spec.Type) + require.Equal(t, "52.1.2.3", + restored.Status.LoadBalancer.Ingress[0].Ip, + ) +} + +func Test_List_JsonRoundTrip(t *testing.T) { + list := List[Deployment]{ + Resource: Resource{ + ApiVersion: "v1", + Kind: "DeploymentList", + }, + Items: []Deployment{ + { + Resource: Resource{ + Metadata: ResourceMetadata{Name: "a"}, + }, + Spec: DeploymentSpec{Replicas: 1}, + }, + { + Resource: Resource{ + Metadata: ResourceMetadata{Name: "b"}, + }, + Spec: DeploymentSpec{Replicas: 2}, + }, + }, + } + data, err := json.Marshal(list) + require.NoError(t, err) + + var restored List[Deployment] + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Len(t, restored.Items, 2) + require.Equal(t, "a", restored.Items[0].Metadata.Name) + require.Equal(t, "b", restored.Items[1].Metadata.Name) +} + +func Test_ResourceType_Constants(t *testing.T) { + require.Equal(t, ResourceType("deployment"), + ResourceTypeDeployment) + require.Equal(t, ResourceType("ing"), + ResourceTypeIngress) + require.Equal(t, ResourceType("svc"), + ResourceTypeService) +} + +func Test_ServiceType_Constants(t *testing.T) { + require.Equal(t, ServiceType("ClusterIP"), + ServiceTypeClusterIp) + require.Equal(t, ServiceType("LoadBalancer"), + ServiceTypeLoadBalancer) + require.Equal(t, ServiceType("NodePort"), + ServiceTypeNodePort) + require.Equal(t, ServiceType("ExternalName"), + ServiceTypeExternalName) +} + +func Test_OutputType_Constants(t *testing.T) { + require.Equal(t, OutputType("json"), OutputTypeJson) + require.Equal(t, OutputType("yaml"), OutputTypeYaml) +} + +func Test_DryRunType_Constants(t *testing.T) { + require.Equal(t, DryRunType("none"), DryRunTypeNone) + require.Equal(t, DryRunType("client"), DryRunTypeClient) + require.Equal(t, DryRunType("server"), DryRunTypeServer) +} + +func Test_KubeConfigEnvVarName(t *testing.T) { + require.Equal(t, "KUBECONFIG", KubeConfigEnvVarName) +} + +func Test_ResourceMetadata_Annotations(t *testing.T) { + jsonStr := `{ + "name":"test","namespace":"ns", + "Annotations":{"key":"value"} + }` + var meta ResourceMetadata + err := json.Unmarshal([]byte(jsonStr), &meta) + require.NoError(t, err) + require.Equal(t, "test", meta.Name) + require.Equal(t, "value", meta.Annotations["key"]) +} + +func Test_Ingress_JsonRoundTrip(t *testing.T) { + host := "example.com" + ing := Ingress{ + Resource: Resource{ + ApiVersion: "networking.k8s.io/v1", + Kind: "Ingress", + Metadata: ResourceMetadata{ + Name: "my-ingress", + }, + }, + Spec: IngressSpec{ + IngressClassName: "nginx", + Tls: []IngressTls{{ + Hosts: []string{"example.com"}, + SecretName: "tls-secret", + }}, + Rules: []IngressRule{{ + Host: &host, + Http: IngressRuleHttp{ + Paths: []IngressPath{{ + Path: "/", + PathType: "Prefix", + }}, + }, + }}, + }, + Status: IngressStatus{ + LoadBalancer: LoadBalancer{ + Ingress: []LoadBalancerIngress{ + {Ip: "10.0.0.1"}, + }, + }, + }, + } + data, err := json.Marshal(ing) + require.NoError(t, err) + + var restored Ingress + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Equal(t, "nginx", restored.Spec.IngressClassName) + require.Equal(t, "example.com", + *restored.Spec.Rules[0].Host, + ) + require.Equal(t, "tls-secret", + restored.Spec.Tls[0].SecretName, + ) + require.Equal(t, "/", + restored.Spec.Rules[0].Http.Paths[0].Path, + ) + require.Equal(t, "10.0.0.1", + restored.Status.LoadBalancer.Ingress[0].Ip, + ) +} + +func Test_IngressRule_NilHost(t *testing.T) { + rule := IngressRule{ + Host: nil, + Http: IngressRuleHttp{ + Paths: []IngressPath{{Path: "/api"}}, + }, + } + data, err := json.Marshal(rule) + require.NoError(t, err) + + var restored IngressRule + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + require.Nil(t, restored.Host) +} diff --git a/cli/azd/pkg/tools/kubectl/util_additional_test.go b/cli/azd/pkg/tools/kubectl/util_additional_test.go new file mode 100644 index 00000000000..38b9cd98952 --- /dev/null +++ b/cli/azd/pkg/tools/kubectl/util_additional_test.go @@ -0,0 +1,357 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package kubectl + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/stretchr/testify/require" +) + +func Test_GetResource_JsonOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + deployment := Deployment{ + Resource: Resource{ + ApiVersion: "apps/v1", + Kind: "Deployment", + Metadata: ResourceMetadata{ + Name: "my-deploy", + Namespace: "default", + }, + }, + Spec: DeploymentSpec{Replicas: 3}, + Status: DeploymentStatus{ReadyReplicas: 3}, + } + depJSON, err := json.Marshal(deployment) + require.NoError(t, err) + + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get deployment") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, string(depJSON), ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + result, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "my-deploy", nil, + ) + require.NoError(t, err) + require.Equal(t, "my-deploy", result.Metadata.Name) + require.Equal(t, 3, result.Spec.Replicas) + require.Equal(t, 3, result.Status.ReadyReplicas) +} + +func Test_GetResource_ExplicitJsonFlag(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + svcJSON := `{ + "apiVersion":"v1","kind":"Service", + "metadata":{"name":"my-svc","namespace":"ns"}, + "spec":{"type":"ClusterIP","clusterIP":"10.0.0.1", + "ports":[{"port":80,"targetPort":8080, + "protocol":"TCP"}]}, + "status":{"loadBalancer":{}} + }` + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get svc") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, svcJSON, ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputTypeJson} + result, err := GetResource[Service]( + *mockContext.Context, cli, + ResourceTypeService, "my-svc", flags, + ) + require.NoError(t, err) + require.Equal(t, ServiceTypeClusterIp, result.Spec.Type) + require.Equal(t, "10.0.0.1", result.Spec.ClusterIp) +} + +func Test_GetResource_UnsupportedOutputFormat(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, "some output", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputType("xml")} + _, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "my-deploy", flags, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "not supported") +} + +func Test_GetResource_ExecError(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("connection refused") + }) + + cli := NewCli(mockContext.CommandRunner) + _, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "my-deploy", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed getting resources") +} + +func Test_GetResource_InvalidJson(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, "not-json{", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + _, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "my-deploy", nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed unmarshalling") +} + +func Test_GetResources_JsonOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + listJSON := `{ + "apiVersion":"v1","kind":"DeploymentList", + "items":[ + {"apiVersion":"apps/v1","kind":"Deployment", + "metadata":{"name":"deploy-a","namespace":"ns"}, + "spec":{"replicas":2}, + "status":{"readyReplicas":2}}, + {"apiVersion":"apps/v1","kind":"Deployment", + "metadata":{"name":"deploy-b","namespace":"ns"}, + "spec":{"replicas":1}, + "status":{"readyReplicas":0}} + ] + }` + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get deployment") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, listJSON, ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + list, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, nil, + ) + require.NoError(t, err) + require.Len(t, list.Items, 2) + require.Equal(t, "deploy-a", list.Items[0].Metadata.Name) + require.Equal(t, "deploy-b", list.Items[1].Metadata.Name) +} + +func Test_GetResources_UnsupportedFormat(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, "out", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputType("table")} + _, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, flags, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "not supported") +} + +func Test_GetResources_ExecError(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{}, errors.New("timeout") + }) + + cli := NewCli(mockContext.CommandRunner) + _, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed getting resources") +} + +func Test_GetResources_InvalidJson(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, "{bad", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + _, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, nil, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed unmarshalling") +} + +func Test_Environ(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + result := environ(map[string]string{}) + require.Empty(t, result) + }) + + t.Run("SingleEntry", func(t *testing.T) { + result := environ(map[string]string{"FOO": "bar"}) + require.Len(t, result, 1) + require.Equal(t, "FOO=bar", result[0]) + }) + + t.Run("MultipleEntries", func(t *testing.T) { + input := map[string]string{ + "A": "1", + "B": "2", + } + result := environ(input) + require.Len(t, result, 2) + // Map iteration order is non-deterministic + require.ElementsMatch(t, + []string{"A=1", "B=2"}, result, + ) + }) +} + +func Test_GetResource_YamlOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + yamlOutput := `apiVersion: apps/v1 +kind: Deployment +metadata: + name: yaml-deploy + namespace: default +spec: + replicas: 5 +status: + readyReplicas: 5 + availableReplicas: 5 + replicas: 5 + updatedReplicas: 5` + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, yamlOutput, ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputTypeYaml} + result, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "yaml-deploy", flags, + ) + require.NoError(t, err) + require.Equal(t, 5, result.Spec.Replicas) +} + +func Test_GetResources_YamlOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + yamlOutput := `apiVersion: v1 +kind: DeploymentList +items: + - apiVersion: apps/v1 + kind: Deployment + metadata: + name: d1 + namespace: ns + spec: + replicas: 1 + status: + readyReplicas: 1` + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, yamlOutput, ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputTypeYaml} + list, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, flags, + ) + require.NoError(t, err) + require.Len(t, list.Items, 1) + require.Equal(t, 1, list.Items[0].Spec.Replicas) +} + +func Test_GetResource_YamlInvalidOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, ":\tbad yaml\n\t:", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputTypeYaml} + _, err := GetResource[Deployment]( + *mockContext.Context, cli, + ResourceTypeDeployment, "x", flags, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed unmarshalling") +} + +func Test_GetResources_YamlInvalidOutput(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + mockContext.CommandRunner.When( + func(args exec.RunArgs, cmd string) bool { + return strings.Contains(cmd, "kubectl get") + }).RespondFn( + func(args exec.RunArgs) (exec.RunResult, error) { + return exec.NewRunResult(0, ":\tbad\n:", ""), nil + }) + + cli := NewCli(mockContext.CommandRunner) + flags := &KubeCliFlags{Output: OutputTypeYaml} + _, err := GetResources[Deployment]( + *mockContext.Context, cli, ResourceTypeDeployment, flags, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "failed unmarshalling") +} diff --git a/cli/azd/pkg/tools/maven/maven_additional_test.go b/cli/azd/pkg/tools/maven/maven_additional_test.go new file mode 100644 index 00000000000..2b4f6e82ed7 --- /dev/null +++ b/cli/azd/pkg/tools/maven/maven_additional_test.go @@ -0,0 +1,422 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package maven + +import ( + "errors" + "slices" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockexec" + "github.com/stretchr/testify/require" +) + +// newTestCli creates a Cli with a mock command runner +// and marks the lazy maven-path init as done so that +// mvnCmd() returns "mvn" without searching the disk. +func newTestCli( + runner exec.CommandRunner, +) *Cli { + cli := &Cli{ + commandRunner: runner, + mvnCmdStr: "mvn", + } + // Mark lazy init as done; the no-op succeeds so + // subsequent mvnCmd() calls skip getMavenPath. + cli.mvnCmdInit.Do(func() error { return nil }) + return cli +} + +func TestName(t *testing.T) { + cli := NewCli(nil) + require.Equal(t, "Maven", cli.Name()) +} + +func TestInstallUrl(t *testing.T) { + cli := NewCli(nil) + require.Equal( + t, "https://maven.apache.org", cli.InstallUrl(), + ) +} + +func TestSetPath(t *testing.T) { + cli := NewCli(nil) + cli.SetPath("/project", "/root") + require.Equal(t, "/project", cli.projectPath) + require.Equal(t, "/root", cli.rootProjectPath) +} + +func TestMavenVersionRegexp(t *testing.T) { + tests := []struct { + name string + input string + want string + matches bool + }{ + { + name: "Standard", + input: "Apache Maven 3.9.1 " + + "(2e178502fcdbffc201671fb2537d0cb4b4cc58f8)", + want: "3.9.1", + matches: true, + }, + { + name: "OlderVersion", + input: "Apache Maven 3.6.3 " + + "(cecedd343002696d0abb50b32b541b8a6ba2883f)", + want: "3.6.3", + matches: true, + }, + { + name: "NoMatch", + input: "Gradle 8.0.1", + matches: false, + }, + { + name: "PartialMatch", + input: "Apache Maven", + matches: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := mavenVersionRegexp.FindStringSubmatch( + tt.input, + ) + + if !tt.matches { + require.Empty(t, matches) + return + } + + require.Len(t, matches, 2) + require.Equal(t, tt.want, matches[1]) + }) + } +} + +func TestGetEffectivePomStringFromConsoleOutput( + t *testing.T, +) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "Standard", + input: "[INFO] Scanning for projects...\n" + + "[INFO]\n" + + "\n" + + " 4.0.0\n" + + "\n" + + "[INFO] Done\n", + want: "" + + " 4.0.0" + + "", + }, + { + name: "EmptyOutput", + input: "", + wantErr: true, + }, + { + name: "NoProjectTag", + input: "[INFO] Scanning for projects...\n" + + "[INFO] Done\n", + wantErr: true, + }, + { + name: "ProjectStartOnly", + input: "\n" + + " 4.0.0\n", + want: "" + + " 4.0.0", + }, + { + name: "IndentedProjectTags", + input: " \n" + + " test\n" + + " \n", + want: " " + + " test" + + " ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getEffectivePomStringFromConsoleOutput( + tt.input, + ) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestCompile(t *testing.T) { + tests := []struct { + name string + runErr error + wantErr bool + }{ + { + name: "Success", + }, + { + name: "Failure", + runErr: errors.New("compilation failed"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "compile") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, tt.runErr + }) + + cli := newTestCli(runner) + err := cli.Compile( + t.Context(), "/project", + []string{"ENV=val"}, + ) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Contains(t, captured.Args, "compile") + require.Equal(t, "/project", captured.Cwd) + }) + } +} + +func TestPackage(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "package") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := newTestCli(runner) + err := cli.Package(t.Context(), "/project", nil) + require.NoError(t, err) + require.Contains(t, captured.Args, "package") + require.Contains(t, captured.Args, "-DskipTests") +} + +func TestResolveDependencies(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains( + args.Args, "dependency:resolve", + ) + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{}, nil + }) + + cli := newTestCli(runner) + err := cli.ResolveDependencies( + t.Context(), "/project", nil, + ) + require.NoError(t, err) + require.Contains( + t, captured.Args, "dependency:resolve", + ) +} + +func TestGetProperty(t *testing.T) { + tests := []struct { + name string + stdout string + runErr error + want string + wantErr error + }{ + { + name: "Success", + stdout: " com.example.myapp ", + want: "com.example.myapp", + }, + { + name: "PropertyNotFound", + stdout: "null object or invalid expression", + wantErr: ErrPropertyNotFound, + }, + { + name: "RunError", + runErr: errors.New("mvn failed"), + wantErr: errors.New("mvn help:evaluate"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains( + args.Args, "help:evaluate", + ) + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: tt.stdout, + }, tt.runErr + }) + + cli := newTestCli(runner) + got, err := cli.GetProperty( + t.Context(), + "project.groupId", + "/project", + ) + + if tt.wantErr != nil { + require.Error(t, err) + if errors.Is(tt.wantErr, ErrPropertyNotFound) { + require.ErrorIs(t, err, ErrPropertyNotFound) + } + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestGetPropertyArgs(t *testing.T) { + var captured exec.RunArgs + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains(args.Args, "help:evaluate") + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + captured = args + return exec.RunResult{ + Stdout: "result-value", + }, nil + }) + + cli := newTestCli(runner) + _, err := cli.GetProperty( + t.Context(), "project.version", "/proj", + ) + require.NoError(t, err) + require.Contains(t, captured.Args, "help:evaluate") + require.Contains( + t, captured.Args, "-Dexpression=project.version", + ) + require.Contains(t, captured.Args, "-q") + require.Contains(t, captured.Args, "-DforceStdout") +} + +func TestEffectivePom(t *testing.T) { + pomOutput := "[INFO] Scanning\n" + + "\n" + + " 4.0.0\n" + + "\n" + + "[INFO] Done\n" + + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains( + args.Args, "help:effective-pom", + ) + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{ + Stdout: pomOutput, + }, nil + }) + + cli := newTestCli(runner) + got, err := cli.EffectivePom( + t.Context(), "/project/pom.xml", + ) + require.NoError(t, err) + require.Contains(t, got, "") +} + +func TestEffectivePomError(t *testing.T) { + runner := mockexec.NewMockCommandRunner() + runner.When(func( + args exec.RunArgs, _ string, + ) bool { + return slices.Contains( + args.Args, "help:effective-pom", + ) + }).RespondFn(func( + args exec.RunArgs, + ) (exec.RunResult, error) { + return exec.RunResult{}, + errors.New("pom failed") + }) + + cli := newTestCli(runner) + _, err := cli.EffectivePom( + t.Context(), "/project/pom.xml", + ) + require.Error(t, err) + require.Contains(t, err.Error(), "help:effective-pom") +} + +func TestErrPropertyNotFound(t *testing.T) { + require.EqualError( + t, ErrPropertyNotFound, "property not found", + ) +} diff --git a/cli/azd/pkg/ux/confirm_prompt_test.go b/cli/azd/pkg/ux/confirm_prompt_test.go new file mode 100644 index 00000000000..12ba4c0c6aa --- /dev/null +++ b/cli/azd/pkg/ux/confirm_prompt_test.go @@ -0,0 +1,364 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Confirm tests --- + +func TestNewConfirm_defaults(t *testing.T) { + c := NewConfirm(&ConfirmOptions{ + Message: "Continue?", + }) + require.NotNil(t, c) + assert.Equal(t, "Continue?", c.options.Message) + assert.Equal(t, "[y/n]", c.options.Hint) + assert.Nil(t, c.value) +} + +func TestNewConfirm_with_default_true(t *testing.T) { + c := NewConfirm(&ConfirmOptions{ + Message: "Continue?", + DefaultValue: new(true), + }) + require.NotNil(t, c) + assert.Equal(t, "[Y/n]", c.options.Hint) + assert.Equal(t, "Yes", c.displayValue) + require.NotNil(t, c.value) + assert.True(t, *c.value) +} + +func TestNewConfirm_with_default_false(t *testing.T) { + c := NewConfirm(&ConfirmOptions{ + Message: "Continue?", + DefaultValue: new(false), + }) + assert.Equal(t, "[y/N]", c.options.Hint) + assert.Equal(t, "No", c.displayValue) + require.NotNil(t, c.value) + assert.False(t, *c.value) +} + +func TestNewConfirm_custom_hint(t *testing.T) { + c := NewConfirm(&ConfirmOptions{ + Message: "Continue?", + Hint: "[custom]", + }) + assert.Equal(t, "[custom]", c.options.Hint) +} + +func TestConfirm_Render_initial(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + c := NewConfirm(&ConfirmOptions{ + Message: "Continue?", + HelpMessage: "Some help", + }) + + err := c.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Continue?") + assert.Contains(t, output, "[y/n]") +} + +func TestConfirm_Render_complete(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + c := NewConfirm(&ConfirmOptions{Message: "OK?"}) + c.complete = true + c.displayValue = "Yes" + + err := c.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "OK?") + assert.Contains(t, output, "Yes") +} + +func TestConfirm_Render_cancelled(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + c := NewConfirm(&ConfirmOptions{Message: "OK?"}) + c.cancelled = true + + err := c.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Cancelled") +} + +func TestConfirm_Render_validation_error(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + c := NewConfirm(&ConfirmOptions{Message: "OK?"}) + c.hasValidationError = true + + err := c.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Enter a valid value") +} + +func TestConfirm_Render_with_help(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + c := NewConfirm(&ConfirmOptions{ + Message: "OK?", + HelpMessage: "Pick yes or no", + }) + c.showHelp = true + + err := c.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Hint:") + assert.Contains(t, output, "Pick yes or no") +} + +func TestConfirm_WithCanvas(t *testing.T) { + c := NewConfirm(&ConfirmOptions{Message: "OK?"}) + var buf bytes.Buffer + canvas := NewCanvas().WithWriter(&buf) + defer canvas.Close() + + result := c.WithCanvas(canvas) + assert.Equal(t, c, result) +} + +// --- Prompt tests --- + +func TestNewPrompt_defaults(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Enter name", + }) + require.NotNil(t, p) + assert.Equal(t, "Enter name", p.options.Message) + assert.False(t, p.options.Required) + assert.Equal(t, "", p.options.Hint) +} + +func TestNewPrompt_auto_hint_with_help(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Enter name", + HelpMessage: "Your full name", + }) + assert.Equal(t, "[Type ? for hint]", p.options.Hint) +} + +func TestNewPrompt_custom_hint_preserved(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Enter name", + HelpMessage: "Your full name", + Hint: "[custom hint]", + }) + assert.Equal(t, "[custom hint]", p.options.Hint) +} + +func TestNewPrompt_with_default_value(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Port", + DefaultValue: "8080", + }) + assert.Equal(t, "8080", p.value) +} + +func TestPrompt_validate_required_empty(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Name", + Required: true, + }) + p.value = "" + p.validate() + assert.True(t, p.hasValidationError) + assert.Equal(t, + "This field is required", p.validationMessage, + ) +} + +func TestPrompt_validate_required_filled(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Name", + Required: true, + }) + p.value = "Jon" + p.validate() + assert.False(t, p.hasValidationError) +} + +func TestPrompt_validate_custom_fn_fail(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Port", + ValidationFn: func(s string) (bool, string) { + return false, "must be numeric" + }, + }) + p.value = "abc" + p.validate() + assert.True(t, p.hasValidationError) + assert.Equal(t, "must be numeric", p.validationMessage) +} + +func TestPrompt_validate_custom_fn_pass(t *testing.T) { + p := NewPrompt(&PromptOptions{ + Message: "Port", + ValidationFn: func(s string) (bool, string) { + return true, "" + }, + }) + p.value = "8080" + p.validate() + assert.False(t, p.hasValidationError) +} + +func TestPrompt_Render_initial(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{Message: "Name"}) + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Name") +} + +func TestPrompt_Render_with_placeholder(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{ + Message: "Name", + PlaceHolder: "Type here...", + }) + p.value = "" + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Type here...") +} + +func TestPrompt_Render_complete(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{Message: "Name"}) + p.complete = true + p.value = "Jon" + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Jon") +} + +func TestPrompt_Render_clear_on_completion(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{ + Message: "Name", + ClearOnCompletion: true, + }) + p.complete = true + + err := p.Render(printer) + require.NoError(t, err) + + assert.Empty(t, buf.String()) +} + +func TestPrompt_Render_cancelled(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{Message: "Name"}) + p.cancelled = true + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Cancelled") +} + +func TestPrompt_Render_validation_shown(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{Message: "Port"}) + p.submitted = true + p.hasValidationError = true + p.validationMessage = "Invalid port" + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Invalid port") +} + +func TestPrompt_Render_help_message(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{ + Message: "Port", + HelpMessage: "Enter a port number", + }) + p.showHelp = true + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Hint:") + assert.Contains(t, output, "Enter a port number") +} + +func TestPrompt_Render_help_message_next_line(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + p := NewPrompt(&PromptOptions{ + Message: "Name", + HelpMessageOnNextLine: "Below the input", + }) + + err := p.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Below the input") +} + +func TestPrompt_WithCanvas(t *testing.T) { + p := NewPrompt(&PromptOptions{Message: "X"}) + var buf bytes.Buffer + c := NewCanvas().WithWriter(&buf) + defer c.Close() + + result := p.WithCanvas(c) + assert.Equal(t, p, result) +} diff --git a/cli/azd/pkg/ux/formatting_test.go b/cli/azd/pkg/ux/formatting_test.go new file mode 100644 index 00000000000..e13787ca2e4 --- /dev/null +++ b/cli/azd/pkg/ux/formatting_test.go @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVisibleLength(t *testing.T) { + tests := []struct { + name string + input string + want int + }{ + {"empty string", "", 0}, + {"plain ASCII", "hello", 5}, + {"single char", "x", 1}, + {"with spaces", "a b c", 5}, + { + "single ANSI color code", + "\x1b[31mred\x1b[0m", + 3, + }, + { + "multiple ANSI codes", + "\x1b[1m\x1b[31mbold red\x1b[0m", + 8, + }, + { + "nested ANSI codes around text", + "\x1b[32mgreen\x1b[0m and \x1b[34mblue\x1b[0m", + 14, + }, + { + "ANSI with no visible text", + "\x1b[0m\x1b[1m\x1b[0m", + 0, + }, + {"unicode characters", "héllo", 5}, + {"CJK characters", "日本語", 3}, + {"emoji single codepoint", "★", 1}, + { + "mixed ANSI and unicode", + "\x1b[36m日本\x1b[0m", + 2, + }, + {"tab and printable", "a\tb", 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := VisibleLength(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCountLineBreaks(t *testing.T) { + tests := []struct { + name string + content string + width int + want int + }{ + {"empty string", "", 80, 0}, + {"no newlines short", "hello", 80, 0}, + { + "single newline", + "hello\n", + 80, 1, + }, + { + "two newlines", + "line1\nline2\n", + 80, 2, + }, + { + "wrapping single line", + "abcdefghij", + 5, 1, + }, + { + "exact width no wrap", + "abcde", + 5, 0, + }, + { + "wrapping twice", + "abcdefghijklmno", + 5, 2, + }, + { + "newline plus wrapping", + "abcdefghij\nab", + 5, 2, + }, + { + "multiple lines with wrapping", + "abcdefghij\nabcdefghij\n", + 5, 4, + }, + { + "width of 1", + "abc", + 1, 2, + }, + { + "ANSI codes not counted for wrap", + "\x1b[31m" + "ab" + "\x1b[0m", + 10, 0, + }, + { + "only newlines", + "\n\n\n", + 80, 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CountLineBreaks(tt.content, tt.width) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSpecialTextRegex(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + "removes single color code", + "\x1b[31mtext\x1b[0m", + "text", + }, + { + "removes bold code", + "\x1b[1mbold\x1b[0m", + "bold", + }, + { + "removes compound codes", + "\x1b[1;31mbold red\x1b[0m", + "bold red", + }, + { + "no codes unchanged", + "plain text", + "plain text", + }, + { + "empty string unchanged", + "", + "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := specialTextRegex.ReplaceAllString(tt.input, "") + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/cli/azd/pkg/ux/helpers_test.go b/cli/azd/pkg/ux/helpers_test.go new file mode 100644 index 00000000000..d1bd184d83a --- /dev/null +++ b/cli/azd/pkg/ux/helpers_test.go @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDurationAsText(t *testing.T) { + tests := []struct { + name string + duration time.Duration + want string + }{ + { + "sub second", + 500 * time.Millisecond, + "less than a second", + }, + { + "zero duration", + 0, + "less than a second", + }, + { + "one second singular", + 1 * time.Second, + "1 second", + }, + { + "multiple seconds", + 5 * time.Second, + "5 seconds", + }, + { + "one minute singular", + 1 * time.Minute, + "1 minute", + }, + { + "minutes and seconds", + 2*time.Minute + 30*time.Second, + "2 minutes 30 seconds", + }, + { + "one hour singular", + 1 * time.Hour, + "1 hour", + }, + { + "hours minutes seconds", + 2*time.Hour + 3*time.Minute + 4*time.Second, + "2 hours 3 minutes 4 seconds", + }, + { + "exact minutes no seconds", + 5 * time.Minute, + "5 minutes", + }, + { + "hour and seconds no minutes", + 1*time.Hour + 10*time.Second, + "1 hour 10 seconds", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := durationAsText(tt.duration) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestWritePart(t *testing.T) { + tests := []struct { + name string + existing string + part string + unit string + want string + }{ + { + "singular unit", + "", "1", "hour", + "1 hour", + }, + { + "plural unit", + "", "5", "minute", + "5 minutes", + }, + { + "empty part skipped", + "", "", "second", + "", + }, + { + "zero part skipped", + "", "0", "hour", + "", + }, + { + "appends with space", + "1 hour", "30", "minute", + "1 hour 30 minutes", + }, + { + "singular after existing", + "2 hours", "1", "minute", + "2 hours 1 minute", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var builder strings.Builder + builder.WriteString(tt.existing) + writePart(&builder, tt.part, tt.unit) + assert.Equal(t, tt.want, builder.String()) + }) + } +} + +func TestGetBooleanString(t *testing.T) { + assert.Equal(t, "Yes", getBooleanString(true)) + assert.Equal(t, "No", getBooleanString(false)) +} + +func TestParseBooleanString(t *testing.T) { + yesInputs := []string{ + "y", "yes", "true", "1", + "Y", "YES", "True", "TRUE", + } + for _, input := range yesInputs { + t.Run("yes_"+input, func(t *testing.T) { + got, err := parseBooleanString(input) + require.NoError(t, err) + require.NotNil(t, got) + assert.True(t, *got) + }) + } + + noInputs := []string{ + "n", "no", "false", "0", + "N", "NO", "False", "FALSE", + } + for _, input := range noInputs { + t.Run("no_"+input, func(t *testing.T) { + got, err := parseBooleanString(input) + require.NoError(t, err) + require.NotNil(t, got) + assert.False(t, *got) + }) + } + + invalidInputs := []string{ + "maybe", "yep", "nope", "2", "", + "absolutely", "x", + } + for _, input := range invalidInputs { + t.Run("invalid_"+input, func(t *testing.T) { + got, err := parseBooleanString(input) + assert.Error(t, err) + assert.Nil(t, got) + }) + } +} + +func TestTaskStateConstants(t *testing.T) { + assert.Equal(t, TaskState(0), Pending) + assert.Equal(t, TaskState(1), Running) + assert.Equal(t, TaskState(2), Skipped) + assert.Equal(t, TaskState(3), Warning) + assert.Equal(t, TaskState(4), Error) + assert.Equal(t, TaskState(5), Success) +} diff --git a/cli/azd/pkg/ux/render_test.go b/cli/azd/pkg/ux/render_test.go new file mode 100644 index 00000000000..fa24d71bead --- /dev/null +++ b/cli/azd/pkg/ux/render_test.go @@ -0,0 +1,317 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "bytes" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- TaskList tests --- + +func TestNewTaskList_nil_options(t *testing.T) { + tl := NewTaskList(nil) + require.NotNil(t, tl) + assert.NotNil(t, tl.options) + assert.Equal(t, 5, tl.options.MaxConcurrentAsync) + assert.False(t, tl.options.ContinueOnError) +} + +func TestNewTaskList_custom_options(t *testing.T) { + tl := NewTaskList(&TaskListOptions{ + ContinueOnError: true, + MaxConcurrentAsync: 10, + }) + require.NotNil(t, tl) + assert.True(t, tl.options.ContinueOnError) + assert.Equal(t, 10, tl.options.MaxConcurrentAsync) +} + +func TestTaskList_AddTask(t *testing.T) { + tl := NewTaskList(nil) + + result := tl.AddTask(TaskOptions{ + Title: "Task 1", + Action: func(sp SetProgressFunc) (TaskState, error) { + return Success, nil + }, + }) + + // AddTask should return the TaskList for chaining + assert.Equal(t, tl, result) + assert.Len(t, tl.allTasks, 1) + assert.Equal(t, "Task 1", tl.allTasks[0].Title) + assert.Equal(t, Pending, tl.allTasks[0].State) +} + +func TestTaskList_AddTask_chaining(t *testing.T) { + tl := NewTaskList(nil) + action := func(sp SetProgressFunc) (TaskState, error) { + return Success, nil + } + + tl.AddTask(TaskOptions{Title: "A", Action: action}). + AddTask(TaskOptions{Title: "B", Action: action}). + AddTask(TaskOptions{Title: "C", Action: action}) + + assert.Len(t, tl.allTasks, 3) +} + +func TestTaskList_Render_pending(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, &Task{ + Title: "My pending task", + State: Pending, + }) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "My pending task") +} + +func TestTaskList_Render_success_with_elapsed(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + start := time.Now().Add(-5 * time.Second) + end := time.Now() + + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, &Task{ + Title: "Completed task", + State: Success, + startTime: &start, + endTime: &end, + }) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Completed task") +} + +func TestTaskList_Render_error_with_description(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + start := time.Now().Add(-2 * time.Second) + end := time.Now() + + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, &Task{ + Title: "Failed task", + State: Error, + Error: errors.New("something broke"), + startTime: &start, + endTime: &end, + }) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Failed task") + assert.Contains(t, output, "something broke") +} + +func TestTaskList_Render_running_with_progress(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + start := time.Now().Add(-1 * time.Second) + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, &Task{ + Title: "Running task", + State: Running, + progress: "50%", + startTime: &start, + }) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Running task") + assert.Contains(t, output, "50%") +} + +func TestTaskList_Render_skipped_with_and_without_error(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, + &Task{Title: "Skip no error", State: Skipped}, + &Task{ + Title: "Skip with error", + State: Skipped, + Error: errors.New("skipped reason"), + }, + ) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Skip no error") + assert.Contains(t, output, "Skip with error") + assert.Contains(t, output, "skipped reason") +} + +func TestTaskList_Render_warning(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + start := time.Now().Add(-3 * time.Second) + end := time.Now() + + tl := NewTaskList(nil) + tl.allTasks = append(tl.allTasks, &Task{ + Title: "Warn task", + State: Warning, + Error: errors.New("partial failure"), + startTime: &start, + endTime: &end, + }) + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Warn task") + assert.Contains(t, output, "partial failure") +} + +func TestTaskList_Render_ordering(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + tl := NewTaskList(nil) + tl.allTasks = []*Task{ + {Title: "running-task", State: Running}, + {Title: "done-task", State: Success}, + {Title: "pending-task", State: Pending}, + } + + err := tl.Render(printer) + require.NoError(t, err) + + output := buf.String() + // Completed tasks render before running, running before pending + doneIdx := bytes.Index([]byte(output), []byte("done-task")) + runIdx := bytes.Index([]byte(output), []byte("running-task")) + pendIdx := bytes.Index([]byte(output), []byte("pending-task")) + + assert.Less(t, doneIdx, runIdx, + "done should appear before running") + assert.Less(t, runIdx, pendIdx, + "running should appear before pending") +} + +func TestTaskList_WithCanvas(t *testing.T) { + tl := NewTaskList(nil) + var buf bytes.Buffer + c := NewCanvas().WithWriter(&buf) + defer c.Close() + + result := tl.WithCanvas(c) + assert.Equal(t, tl, result) +} + +// --- Spinner tests --- + +func TestNewSpinner_defaults(t *testing.T) { + s := NewSpinner(&SpinnerOptions{}) + require.NotNil(t, s) + assert.Equal(t, "Loading...", s.text) + assert.Len(t, s.options.Animation, 4) + assert.Equal( + t, 250*time.Millisecond, s.options.Interval, + ) +} + +func TestNewSpinner_custom_text(t *testing.T) { + s := NewSpinner(&SpinnerOptions{Text: "Please wait"}) + require.NotNil(t, s) + assert.Equal(t, "Please wait", s.text) +} + +func TestSpinner_UpdateText(t *testing.T) { + s := NewSpinner(&SpinnerOptions{}) + s.UpdateText("new text") + assert.Equal(t, "new text", s.text) +} + +func TestSpinner_Render(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSpinner(&SpinnerOptions{ + Text: "Working...", + Animation: []string{"|", "/", "-", "\\"}, + }) + + err := s.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Working...") +} + +func TestSpinner_Render_cycles_animation(t *testing.T) { + s := NewSpinner(&SpinnerOptions{ + Animation: []string{"a", "b", "c"}, + }) + + // Render multiple times to cycle through animation + for range 3 { + var buf bytes.Buffer + printer := NewPrinter(&buf) + err := s.Render(printer) + require.NoError(t, err) + } + + // After 3 renders, index should wrap back to 0 + assert.Equal(t, 0, s.animationIndex) +} + +func TestSpinner_Render_clear_returns_nil(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSpinner(&SpinnerOptions{}) + s.clear = true + + err := s.Render(printer) + require.NoError(t, err) + assert.Empty(t, buf.String()) +} + +func TestSpinner_WithCanvas(t *testing.T) { + s := NewSpinner(&SpinnerOptions{}) + var buf bytes.Buffer + c := NewCanvas().WithWriter(&buf) + defer c.Close() + + result := s.WithCanvas(c) + assert.Equal(t, s, result) +} + +func TestSpinner_WithCanvas_nil(t *testing.T) { + s := NewSpinner(&SpinnerOptions{}) + result := s.WithCanvas(nil) + assert.Equal(t, s, result) + assert.Nil(t, s.canvas) +} diff --git a/cli/azd/pkg/ux/select_test.go b/cli/azd/pkg/ux/select_test.go new file mode 100644 index 00000000000..431e0b56e82 --- /dev/null +++ b/cli/azd/pkg/ux/select_test.go @@ -0,0 +1,404 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Select tests --- + +func TestNewSelect_with_choices(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Pick one", + Choices: []*SelectChoice{ + {Value: "a", Label: "Option A"}, + {Value: "b", Label: "Option B"}, + }, + }) + require.NotNil(t, s) + assert.Len(t, s.choices, 2) + assert.Len(t, s.filteredChoices, 2) + assert.Equal(t, "Pick one", s.options.Message) +} + +func TestNewSelect_default_hint(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Pick one", + Choices: []*SelectChoice{ + {Value: "a", Label: "A"}, + }, + }) + assert.Contains(t, s.options.Hint, "Use arrows to move") + assert.Contains(t, s.options.Hint, "type to filter") +} + +func TestNewSelect_hint_no_filter(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Pick one", + Choices: []*SelectChoice{{Value: "a", Label: "A"}}, + EnableFiltering: new(false), + }) + assert.Contains(t, s.options.Hint, "Use arrows to move") + assert.NotContains(t, s.options.Hint, "type to filter") +} + +func TestNewSelect_custom_hint(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Pick one", + Choices: []*SelectChoice{{Value: "a", Label: "A"}}, + Hint: "[my hint]", + }) + assert.Equal(t, "[my hint]", s.options.Hint) +} + +func TestSelect_Render_initial(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "a", Label: "Alpha"}, + {Value: "b", Label: "Bravo"}, + {Value: "c", Label: "Charlie"}, + }, + }) + + err := s.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Choose") + assert.Contains(t, output, "Alpha") + assert.Contains(t, output, "Bravo") + assert.Contains(t, output, "Charlie") +} + +func TestSelect_Render_complete(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "a", Label: "Alpha"}, + }, + }) + s.complete = true + s.selectedChoice = s.choices[0] + + err := s.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Choose") + assert.Contains(t, output, "Alpha") + // Options list should NOT appear when complete + assert.NotContains(t, output, "Use arrows") +} + +func TestSelect_Render_cancelled(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "a", Label: "Alpha"}, + }, + }) + s.cancelled = true + + err := s.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Cancelled") +} + +func TestSelect_applyFilter_matches(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "apple", Label: "Apple"}, + {Value: "banana", Label: "Banana"}, + {Value: "apricot", Label: "Apricot"}, + }, + }) + // Initialize currentIndex (Render normally does this) + s.currentIndex = new(0) + s.filter = "ap" + + s.applyFilter() + assert.Len(t, s.filteredChoices, 2) +} + +func TestSelect_applyFilter_no_match(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "apple", Label: "Apple"}, + {Value: "banana", Label: "Banana"}, + }, + }) + s.currentIndex = new(0) + s.filter = "xyz" + + s.applyFilter() + assert.Empty(t, s.filteredChoices) +} + +func TestSelect_applyFilter_empty_resets(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "a", Label: "A"}, + {Value: "b", Label: "B"}, + }, + }) + s.currentIndex = new(0) + s.filter = "" + + s.applyFilter() + assert.Len(t, s.filteredChoices, 2) +} + +func TestSelect_applyFilter_by_number(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Choose", + DisplayNumbers: new(true), + Choices: []*SelectChoice{ + {Value: "a", Label: "Alpha"}, + {Value: "b", Label: "Bravo"}, + {Value: "c", Label: "Charlie"}, + }, + }) + s.currentIndex = new(0) + s.filter = "2" + + s.applyFilter() + assert.Len(t, s.filteredChoices, 1) + assert.Equal(t, "Bravo", s.filteredChoices[0].Label) +} + +func TestSelect_WithCanvas(t *testing.T) { + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{{Value: "a", Label: "A"}}, + }) + var buf bytes.Buffer + c := NewCanvas().WithWriter(&buf) + defer c.Close() + + result := s.WithCanvas(c) + assert.Equal(t, s, result) +} + +func TestSelect_renderValidation_no_matches(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + s := NewSelect(&SelectOptions{ + Message: "Choose", + Choices: []*SelectChoice{ + {Value: "a", Label: "A"}, + }, + }) + s.filteredChoices = []*indexedSelectChoice{} + + s.renderValidation(printer) + assert.True(t, s.hasValidationError) + assert.Contains(t, s.validationMessage, "No options found") +} + +// --- MultiSelect tests --- + +func TestNewMultiSelect_with_choices(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + {Value: "b", Label: "Bravo"}, + }, + }) + require.NotNil(t, ms) + assert.Len(t, ms.choices, 2) + assert.Empty(t, ms.selectedChoices) +} + +func TestNewMultiSelect_preselected(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha", Selected: true}, + {Value: "b", Label: "Bravo"}, + {Value: "c", Label: "Charlie", Selected: true}, + }, + }) + assert.Len(t, ms.selectedChoices, 2) + assert.Contains(t, ms.selectedChoices, "a") + assert.Contains(t, ms.selectedChoices, "c") +} + +func TestMultiSelect_Render_initial(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + {Value: "b", Label: "Bravo"}, + }, + }) + + err := ms.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Pick many") + assert.Contains(t, output, "Alpha") + assert.Contains(t, output, "Bravo") +} + +func TestMultiSelect_Render_complete(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha", Selected: true}, + {Value: "b", Label: "Bravo"}, + }, + }) + ms.complete = true + + err := ms.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Alpha") + // Footer should NOT appear when complete + assert.NotContains(t, output, "Use arrows") +} + +func TestMultiSelect_Render_cancelled(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + }, + }) + ms.cancelled = true + + err := ms.Render(printer) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, "Cancelled") +} + +func TestMultiSelect_validate_no_selection(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + }, + }) + ms.submitted = true + ms.validate() + + assert.True(t, ms.hasValidationError) + assert.Contains(t, + ms.validationMessage, "At least one option", + ) +} + +func TestMultiSelect_validate_with_selection(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha", Selected: true}, + }, + }) + ms.submitted = true + ms.validate() + + assert.False(t, ms.hasValidationError) +} + +func TestMultiSelect_validate_empty_filter(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + }, + }) + ms.filteredChoices = []*indexedMultiSelectChoice{} + ms.validate() + + assert.True(t, ms.hasValidationError) + assert.Contains(t, + ms.validationMessage, "No options found", + ) +} + +func TestMultiSelect_sortSelectedChoices(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick many", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "Alpha"}, + {Value: "b", Label: "Bravo"}, + {Value: "c", Label: "Charlie"}, + }, + }) + // Select in reverse order + ms.selectedChoices["c"] = ms.choices[2] + ms.selectedChoices["a"] = ms.choices[0] + + sorted := ms.sortSelectedChoices() + require.Len(t, sorted, 2) + assert.Equal(t, "Alpha", sorted[0].Label) + assert.Equal(t, "Charlie", sorted[1].Label) +} + +func TestMultiSelect_applyFilter(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick", + Choices: []*MultiSelectChoice{ + {Value: "apple", Label: "Apple"}, + {Value: "banana", Label: "Banana"}, + {Value: "apricot", Label: "Apricot"}, + }, + }) + ms.currentIndex = new(0) + ms.filter = "ap" + + ms.applyFilter() + assert.Len(t, ms.filteredChoices, 2) +} + +func TestMultiSelect_WithCanvas(t *testing.T) { + ms := NewMultiSelect(&MultiSelectOptions{ + Message: "Pick", + Choices: []*MultiSelectChoice{ + {Value: "a", Label: "A"}, + }, + }) + var buf bytes.Buffer + c := NewCanvas().WithWriter(&buf) + defer c.Close() + + result := ms.WithCanvas(c) + assert.Equal(t, ms, result) +} diff --git a/cli/azd/pkg/ux/ux_additional_test.go b/cli/azd/pkg/ux/ux_additional_test.go new file mode 100644 index 00000000000..b77e390729c --- /dev/null +++ b/cli/azd/pkg/ux/ux_additional_test.go @@ -0,0 +1,343 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ux + +import ( + "bytes" + "errors" + "testing" +) + +func TestConsoleWidth_from_env(t *testing.T) { + t.Setenv("COLUMNS", "200") + + // ConsoleWidth uses consolesize-go first; if that returns <=0 it falls back to COLUMNS + width := ConsoleWidth() + if width <= 0 { + t.Fatalf("ConsoleWidth() = %d, want > 0", width) + } +} + +func TestConsoleWidth_invalid_COLUMNS_fallback(t *testing.T) { + t.Setenv("COLUMNS", "not-a-number") + + width := ConsoleWidth() + if width <= 0 { + t.Fatalf("ConsoleWidth() = %d, want > 0", width) + } +} + +func TestConsoleWidth_empty_COLUMNS_uses_default(t *testing.T) { + t.Setenv("COLUMNS", "") + + width := ConsoleWidth() + if width <= 0 { + t.Fatalf("ConsoleWidth() = %d, want > 0", width) + } +} + +func TestPtr(t *testing.T) { + intVal := 42 + p := Ptr(intVal) + if p == nil { + t.Fatal("Ptr should return non-nil pointer") + } + if *p != 42 { + t.Fatalf("*Ptr(42) = %d, want 42", *p) + } + + strVal := "hello" + sp := Ptr(strVal) + if *sp != "hello" { + t.Fatalf("*Ptr(hello) = %q, want hello", *sp) + } +} + +func TestRender_creates_visual(t *testing.T) { + v := Render(func(p Printer) error { + return nil + }) + + if v == nil { + t.Fatal("Render should return non-nil Visual") + } +} + +func TestNewVisualElement_Render(t *testing.T) { + renderCalled := false + elem := NewVisualElement(func(p Printer) error { + renderCalled = true + p.Fprintf("test output") + return nil + }) + + printer := NewPrinter(&bytes.Buffer{}) + err := elem.Render(printer) + if err != nil { + t.Fatalf("Render error: %v", err) + } + if !renderCalled { + t.Fatal("Render function should have been called") + } +} + +func TestNewVisualElement_Render_error(t *testing.T) { + expectedErr := errors.New("render failed") + elem := NewVisualElement(func(p Printer) error { + return expectedErr + }) + + printer := NewPrinter(&bytes.Buffer{}) + err := elem.Render(printer) + if !errors.Is(err, expectedErr) { + t.Fatalf("Render error = %v, want %v", err, expectedErr) + } +} + +func TestNewVisualElement_WithCanvas(t *testing.T) { + elem := NewVisualElement(func(p Printer) error { return nil }) + + var buf bytes.Buffer + canvas := NewCanvas().WithWriter(&buf) + result := elem.WithCanvas(canvas) + + if result != elem { + t.Fatal("WithCanvas should return the same element for chaining") + } +} + +func TestNewPrinter_with_buffer(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + if printer == nil { + t.Fatal("NewPrinter should return non-nil") + } + + printer.Fprintf("hello %s", "world") + if !bytes.Contains(buf.Bytes(), []byte("hello world")) { + t.Fatalf("Fprintf output = %q, want to contain 'hello world'", buf.String()) + } +} + +func TestNewPrinter_nil_writer_defaults_to_stdout(t *testing.T) { + // Should not panic + printer := NewPrinter(nil) + if printer == nil { + t.Fatal("NewPrinter(nil) should return non-nil") + } +} + +func TestPrinter_Fprintln(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + printer.Fprintln("line 1") + printer.Fprintln("line 2") + + output := buf.String() + if !bytes.Contains([]byte(output), []byte("line 1\n")) { + t.Fatalf("Fprintln output should contain 'line 1\\n', got %q", output) + } + if !bytes.Contains([]byte(output), []byte("line 2\n")) { + t.Fatalf("Fprintln output should contain 'line 2\\n', got %q", output) + } +} + +func TestPrinter_Size_tracks_rows(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + // Initial size + size := printer.Size() + if size.Rows != 1 { + t.Fatalf("Initial Rows = %d, want 1", size.Rows) + } + + printer.Fprintf("line 1\n") + size = printer.Size() + if size.Rows < 2 { + t.Fatalf("After one newline, Rows = %d, want >= 2", size.Rows) + } +} + +func TestPrinter_CursorPosition(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + pos := printer.CursorPosition() + if pos.Row < 1 { + t.Fatalf("CursorPosition().Row = %d, want >= 1", pos.Row) + } +} + +func TestPrinter_ClearCanvas(t *testing.T) { + var buf bytes.Buffer + printer := NewPrinter(&buf) + + printer.Fprintf("some content\n") + printer.Fprintf("more content\n") + + // Clear should not panic and should reset state + printer.ClearCanvas() + + size := printer.Size() + if size.Rows != 1 { + t.Fatalf("After ClearCanvas, Rows = %d, want 1", size.Rows) + } +} + +func TestPrinter_SetCursorPosition_same_position_noop(t *testing.T) { + var buf bytes.Buffer + p := NewPrinter(&buf) + pos := CursorPosition{Row: 5, Col: 3} + + beforeFirst := buf.Len() + p.SetCursorPosition(pos) + afterFirst := buf.Len() + if afterFirst <= beforeFirst { + t.Fatalf( + "expected first SetCursorPosition to write escape codes, before = %d, after = %d", + beforeFirst, afterFirst) + } + + // Setting same position should not write additional escape codes + p.SetCursorPosition(pos) + afterSecond := buf.Len() + + if afterSecond != afterFirst { + t.Fatalf( + "expected same-position SetCursorPosition to be a no-op, first = %d, second = %d", + afterFirst, afterSecond) + } +} + +func TestNewCanvasSize(t *testing.T) { + size := newCanvasSize() + if size.Rows != 1 { + t.Fatalf("newCanvasSize().Rows = %d, want 1", size.Rows) + } + if size.Cols != 0 { + t.Fatalf("newCanvasSize().Cols = %d, want 0", size.Cols) + } +} + +func TestCanvas_Run_with_visual(t *testing.T) { + var buf bytes.Buffer + renderCalled := false + + visual := NewVisualElement(func(p Printer) error { + renderCalled = true + p.Fprintf("canvas output") + return nil + }) + + canvas := NewCanvas(visual).WithWriter(&buf) + defer canvas.Close() + + err := canvas.Run() + if err != nil { + t.Fatalf("Canvas.Run() error: %v", err) + } + + if !renderCalled { + t.Fatal("Visual.Render should have been called") + } + + if !bytes.Contains(buf.Bytes(), []byte("canvas output")) { + t.Fatalf("Canvas output = %q, want to contain 'canvas output'", buf.String()) + } +} + +func TestCanvas_Run_multiple_visuals(t *testing.T) { + var buf bytes.Buffer + callOrder := []string{} + + v1 := NewVisualElement(func(p Printer) error { + callOrder = append(callOrder, "v1") + p.Fprintf("first") + return nil + }) + + v2 := NewVisualElement(func(p Printer) error { + callOrder = append(callOrder, "v2") + p.Fprintf("second") + return nil + }) + + canvas := NewCanvas(v1, v2).WithWriter(&buf) + defer canvas.Close() + + err := canvas.Run() + if err != nil { + t.Fatalf("Canvas.Run() error: %v", err) + } + + if len(callOrder) != 2 || callOrder[0] != "v1" || callOrder[1] != "v2" { + t.Fatalf("Call order = %v, want [v1 v2]", callOrder) + } +} + +func TestCanvas_render_error_propagates(t *testing.T) { + var buf bytes.Buffer + expectedErr := errors.New("visual render failed") + + visual := NewVisualElement(func(p Printer) error { + return expectedErr + }) + + canvas := NewCanvas(visual).WithWriter(&buf) + defer canvas.Close() + + err := canvas.Run() + if !errors.Is(err, expectedErr) { + t.Fatalf("Canvas.Run() error = %v, want %v", err, expectedErr) + } +} + +func TestCanvas_Clear(t *testing.T) { + var buf bytes.Buffer + + visual := NewVisualElement(func(p Printer) error { + p.Fprintf("some content") + return nil + }) + + canvas := NewCanvas(visual).WithWriter(&buf) + defer canvas.Close() + + err := canvas.Run() + if err != nil { + t.Fatalf("Canvas.Run() error: %v", err) + } + + err = canvas.Clear() + if err != nil { + t.Fatalf("Canvas.Clear() error: %v", err) + } +} + +func TestCanvasManager_CanUpdate(t *testing.T) { + mgr := newCanvasManager() + var buf bytes.Buffer + + c1 := NewCanvas().WithWriter(&buf) + c2 := NewCanvas().WithWriter(&buf) + defer c1.Close() + defer c2.Close() + + // No focused canvas — any canvas can update + if !mgr.CanUpdate(c1) { + t.Fatal("CanUpdate(c1) should be true when no canvas is focused") + } +} + +func TestErrCancelled(t *testing.T) { + if ErrCancelled == nil { + t.Fatal("ErrCancelled should not be nil") + } + if ErrCancelled.Error() == "" { + t.Fatal("ErrCancelled.Error() should not be empty") + } +} diff --git a/cli/azd/pkg/watch/watch_helpers_test.go b/cli/azd/pkg/watch/watch_helpers_test.go new file mode 100644 index 00000000000..625b4a201da --- /dev/null +++ b/cli/azd/pkg/watch/watch_helpers_test.go @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package watch + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFileChange_String_Created(t *testing.T) { + fc := FileChange{ + Path: "/absolute/path/file.txt", + ChangeType: FileCreated, + } + s := fc.String() + require.Contains(t, s, "+ Created") + require.Contains(t, s, "file.txt") +} + +func TestFileChange_String_Modified(t *testing.T) { + fc := FileChange{ + Path: "/absolute/path/file.txt", + ChangeType: FileModified, + } + s := fc.String() + require.Contains(t, s, "± Modified") +} + +func TestFileChange_String_Deleted(t *testing.T) { + fc := FileChange{ + Path: "/absolute/path/file.txt", + ChangeType: FileDeleted, + } + s := fc.String() + require.Contains(t, s, "- Deleted") +} + +func TestFileChange_String_DefaultCase(t *testing.T) { + fc := FileChange{ + Path: "/absolute/path/unknown.txt", + ChangeType: FileChangeType(99), + } + s := fc.String() + // Default case should still contain the path + require.Contains(t, s, "unknown.txt") + // Should NOT contain the known prefixes + require.NotContains(t, s, "Created") + require.NotContains(t, s, "Modified") + require.NotContains(t, s, "Deleted") +} + +func TestFileChange_String_RelativePath(t *testing.T) { + // When the file is inside the cwd, String() should + // convert to a relative path. Use a known path that + // exists relative to cwd. + fc := FileChange{ + Path: "relative.txt", + ChangeType: FileCreated, + } + s := fc.String() + // Should produce output regardless of relative conversion + require.NotEmpty(t, s) +} + +func TestFileChanges_String_Empty(t *testing.T) { + fc := FileChanges{} + require.Equal(t, "", fc.String()) +} + +func TestFileChanges_String_SingleEntry(t *testing.T) { + fc := FileChanges{ + {Path: "file.txt", ChangeType: FileCreated}, + } + s := fc.String() + require.Contains(t, s, "Files changed:") + require.Contains(t, s, "file.txt") +} + +func TestFileChanges_String_MultipleEntries(t *testing.T) { + fc := FileChanges{ + {Path: "a.txt", ChangeType: FileCreated}, + {Path: "b.txt", ChangeType: FileModified}, + {Path: "c.txt", ChangeType: FileDeleted}, + } + s := fc.String() + require.Contains(t, s, "Files changed:") + require.Contains(t, s, "a.txt") + require.Contains(t, s, "b.txt") + require.Contains(t, s, "c.txt") + require.Contains(t, s, "Created") + require.Contains(t, s, "Modified") + require.Contains(t, s, "Deleted") +} + +func TestGetFileChanges_Sorting(t *testing.T) { + fc := &fileChanges{ + Created: map[string]bool{"z.txt": true, "a.txt": true}, + Modified: map[string]bool{"m.txt": true}, + Deleted: map[string]bool{"d.txt": true}, + } + fw := &fileWatcher{fileChanges: fc} + + changes := fw.GetFileChanges() + require.Len(t, changes, 4) + // Verify sorted by path + for i := 1; i < len(changes); i++ { + require.LessOrEqual(t, changes[i-1].Path, changes[i].Path, + "changes should be sorted by path") + } +} + +func TestGetFileChanges_ChangeTypes(t *testing.T) { + fc := &fileChanges{ + Created: map[string]bool{"new.txt": true}, + Modified: map[string]bool{"mod.txt": true}, + Deleted: map[string]bool{"del.txt": true}, + } + fw := &fileWatcher{fileChanges: fc} + + changes := fw.GetFileChanges() + changeMap := make(map[string]FileChangeType) + for _, c := range changes { + changeMap[c.Path] = c.ChangeType + } + + require.Equal(t, FileDeleted, changeMap["del.txt"]) + require.Equal(t, FileModified, changeMap["mod.txt"]) + require.Equal(t, FileCreated, changeMap["new.txt"]) +} + +func TestGetFileChanges_EmptyMaps(t *testing.T) { + fc := &fileChanges{ + Created: map[string]bool{}, + Modified: map[string]bool{}, + Deleted: map[string]bool{}, + } + fw := &fileWatcher{fileChanges: fc} + + changes := fw.GetFileChanges() + require.Empty(t, changes) +} + +func TestFileChangeType_Exhaustive(t *testing.T) { + tests := []struct { + name string + val FileChangeType + want int + }{ + {"Created", FileCreated, 0}, + {"Modified", FileModified, 1}, + {"Deleted", FileDeleted, 2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, FileChangeType(tt.want), tt.val) + }) + } +} + +func TestFileChanges_String_PreservesOrder(t *testing.T) { + // The String() method should iterate in the order of the + // slice, so we can verify ordering is preserved. + fc := FileChanges{ + {Path: "first.txt", ChangeType: FileCreated}, + {Path: "second.txt", ChangeType: FileModified}, + {Path: "third.txt", ChangeType: FileDeleted}, + } + s := fc.String() + + firstIdx := indexOf(s, "first.txt") + secondIdx := indexOf(s, "second.txt") + thirdIdx := indexOf(s, "third.txt") + + require.Greater(t, secondIdx, firstIdx, + "second should appear after first") + require.Greater(t, thirdIdx, secondIdx, + "third should appear after second") +} + +// indexOf returns the position of substr in s, or -1. +func indexOf(s, substr string) int { + for i := 0; i+len(substr) <= len(s); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} diff --git a/eng/scripts/Get-CICoverageReport.ps1 b/eng/scripts/Get-CICoverageReport.ps1 new file mode 100644 index 00000000000..ba67e0304ae --- /dev/null +++ b/eng/scripts/Get-CICoverageReport.ps1 @@ -0,0 +1,176 @@ +#!/usr/bin/env pwsh + +<# +.SYNOPSIS + Downloads and analyzes combined (unit + integration) test coverage from CI. + +.DESCRIPTION + Fetches coverage artifacts from the Azure DevOps CI pipeline, merges unit + and integration test coverage using 'go tool covdata', and produces a + per-package coverage report sorted by coverage percentage. + + This gives the TRUE coverage picture — including coverage from functional + tests that exercise the azd binary built with '-cover'. Running + 'go test -short -cover' locally only measures unit test coverage and can + significantly underestimate actual coverage. + +.PARAMETER BuildId + Azure DevOps build ID to download coverage from. If not specified, uses + the latest successful build from the main branch. + +.PARAMETER Organization + Azure DevOps organization URL. Defaults to 'https://dev.azure.com/azure-sdk'. + +.PARAMETER Project + Azure DevOps project name or ID. Defaults to 'internal'. + +.PARAMETER OutputFile + Path to write the combined cover.out file. Defaults to 'cover-ci-combined.out'. + +.PARAMETER ShowReport + If set, prints a per-package coverage report sorted by coverage. + +.PARAMETER MinCoverage + If set, filters the report to only show packages below this coverage threshold. + +.EXAMPLE + # Get latest main build coverage + ./Get-CICoverageReport.ps1 -ShowReport + +.EXAMPLE + # Show only packages below 10% coverage + ./Get-CICoverageReport.ps1 -ShowReport -MinCoverage 10 + +.EXAMPLE + # Use a specific build + ./Get-CICoverageReport.ps1 -BuildId 6065857 -ShowReport +#> + +param( + [int]$BuildId = 0, + [string]$Organization = 'https://dev.azure.com/azure-sdk', + [string]$Project = 'internal', + [string]$OutputFile = 'cover-ci-combined.out', + [switch]$ShowReport, + [double]$MinCoverage = -1, + [string]$PipelineDefinitionId = '4643' +) + +$ErrorActionPreference = 'Stop' + +# Resolve organization to just the name for API calls +$orgName = $Organization -replace 'https://dev.azure.com/', '' + +# Get Azure DevOps access token +Write-Host "Authenticating with Azure DevOps..." +$token = az account get-access-token --resource "499b84ac-1321-427f-aa17-267ca6975798" --query accessToken -o tsv +if ($LASTEXITCODE) { + throw "Failed to get Azure DevOps access token. Run 'az login' first." +} +$headers = @{ Authorization = "Bearer $token" } + +# Find the build to use +if ($BuildId -eq 0) { + Write-Host "Finding latest successful build from main..." + $buildsUrl = "$Organization/$Project/_apis/build/builds?definitions=$PipelineDefinitionId&branchName=refs/heads/main&resultFilter=succeeded&`$top=1&api-version=7.1" + $buildsResp = Invoke-RestMethod -Uri $buildsUrl -Headers $headers -Method Get + if ($buildsResp.count -eq 0) { + throw "No successful builds found for pipeline $PipelineDefinitionId on main" + } + $BuildId = $buildsResp.value[0].id + $buildNumber = $buildsResp.value[0].buildNumber + Write-Host "Using build $BuildId ($buildNumber)" +} + +# Create temp directory +$tempDir = Join-Path ([System.IO.Path]::GetTempPath()) "azd-ci-coverage-$BuildId" +if (Test-Path $tempDir) { + Remove-Item -Recurse -Force $tempDir +} +New-Item -ItemType Directory -Force -Path $tempDir | Out-Null + +function Download-Artifact { + param([string]$ArtifactName, [string]$DestDir) + + Write-Host " Downloading $ArtifactName..." + $url = "$Organization/$Project/_apis/build/builds/$BuildId/artifacts?artifactName=$ArtifactName&api-version=7.1" + $resp = Invoke-RestMethod -Uri $url -Headers $headers -Method Get + $downloadUrl = $resp.resource.downloadUrl + + $zipPath = Join-Path $tempDir "$ArtifactName.zip" + Invoke-WebRequest -Uri $downloadUrl -Headers $headers -OutFile $zipPath + + $extractPath = Join-Path $tempDir $ArtifactName + Expand-Archive -Path $zipPath -DestinationPath $extractPath -Force + Remove-Item $zipPath + + # Pipeline artifacts nest under the artifact name + $nested = Join-Path $extractPath $ArtifactName + if (Test-Path $nested) { + return $nested + } + return $extractPath +} + +# Download artifacts +Write-Host "Downloading coverage artifacts from build $BuildId..." +$unitDir = Download-Artifact -ArtifactName "cover-unit" -DestDir $tempDir +$intDir = Download-Artifact -ArtifactName "cover-int" -DestDir $tempDir + +# Merge coverage +$mergedDir = Join-Path $tempDir "cover-merged" +New-Item -ItemType Directory -Force -Path $mergedDir | Out-Null + +Write-Host "Merging unit + integration coverage..." +go tool covdata merge -i="$unitDir,$intDir" -o "$mergedDir" +if ($LASTEXITCODE) { + throw "go tool covdata merge failed" +} + +# Convert to text format +Write-Host "Converting to text format..." +go tool covdata textfmt -i="$mergedDir" -o $OutputFile +if ($LASTEXITCODE) { + throw "go tool covdata textfmt failed" +} + +$lineCount = (Get-Content $OutputFile).Count +Write-Host "Combined coverage written to $OutputFile ($lineCount lines)" + +# Show report if requested +if ($ShowReport) { + Write-Host "" + Write-Host "==========================================" + Write-Host " Combined Coverage Report (Build $BuildId)" + Write-Host "==========================================" + Write-Host "" + + $percentOutput = go tool covdata percent -i="$mergedDir" 2>&1 + $parsed = $percentOutput | ForEach-Object { + if ($_ -match '([\w/./-]+)\s+coverage:\s+([\d.]+)%') { + $pkg = $Matches[1] -replace 'github.com/azure/azure-dev/cli/azd/', '' + $pct = [double]$Matches[2] + [PSCustomObject]@{ Package = $pkg; Coverage = $pct } + } + } + + if ($MinCoverage -ge 0) { + $parsed = $parsed | Where-Object { $_.Coverage -lt $MinCoverage } + Write-Host "Packages below ${MinCoverage}% coverage:" + } else { + Write-Host "All packages (sorted by coverage):" + } + + Write-Host "" + $parsed | Sort-Object Coverage | Format-Table -AutoSize + + $avg = ($parsed | Measure-Object -Property Coverage -Average).Average + $count = $parsed.Count + Write-Host "Packages shown: $count | Average coverage: $([math]::Round($avg, 1))%" +} + +# Cleanup temp files (keep the output file) +Remove-Item -Recurse -Force $tempDir + +Write-Host "" +Write-Host "Done. Combined coverage profile: $OutputFile"