diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval.go new file mode 100644 index 00000000000..d53ae625948 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval.go @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval.go implements the top-level "eval" command group and shared context +// resolution logic used by all eval subcommands (init, run, update, list, show). +// +// The evalResolvedContext struct holds the resolved agent, project, and +// endpoint information. It is built from azd project state, environment +// variables, or interactive prompts, and threaded through all subcommands. + +package cmd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/dataset_api" + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/fatih/color" + "github.com/spf13/cobra" + "go.yaml.in/yaml/v3" +) + +// Default values for eval configuration. +const ( + defaultEvalConfigName = "eval.yaml" + defaultEvalName = "smoke-core" + defaultEvalSamples = 15 + defaultEvalModel = "gpt-4o" +) + +// Type aliases to avoid repeating full package paths throughout the eval code. +type evalConfig = eval_api.EvalConfig +type evalAgentRef = opt_eval.AgentRef +type evalDatasetRef = opt_eval.DatasetRef + +// evalResolvedContext holds the fully-resolved context for an eval operation, +// including the azd client, API clients, project paths, and agent metadata. +// Built by resolveEvalContext from azd project state, environment variables, +// or interactive prompts. +type evalResolvedContext struct { + azdClient *azdext.AzdClient + evalClient *eval_api.EvalClient + datasetClient *dataset_api.DatasetClient + projectRoot string // azd project root directory + hasProject bool // true if running within an azd project + agentProject string // agent service directory + agentProjectSource string // how agentProject was resolved + agentName string // deployed agent name + agentNameSource string // how agentName was resolved + version string // agent version + versionSource string // how version was resolved + agentKind agent_yaml.AgentKind // hosted or prompt + agentKindSource string // how agentKind was resolved + serviceName string // azure.yaml service name + projectEndpoint string // Foundry project endpoint URL + projectEndpointSource string // how projectEndpoint was resolved + envName string // azd environment name +} + +// evalContextOptions configures the behavior of resolveEvalContext. +type evalContextOptions struct { + agent string // explicit agent name (from --agent flag) + projectEndpoint string // explicit project endpoint (from --project-endpoint flag) + requireAgent bool // fail if agent name cannot be resolved + noPrompt bool // skip interactive prompts +} + +func newEvalCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + cmd := &cobra.Command{ + Use: "eval ", + Short: "Create and run quick evals for an agent.", + Long: `Create and run quick evals for an agent. + +Subcommands: + init Generate an eval config and dataset from a hosted agent + run Execute an evaluation run from eval.yaml + update Update an existing eval configuration + list List evaluations for the current project + show Show details of an evaluation run`, + } + + cmd.AddCommand(newEvalInitCommand(extCtx)) + cmd.AddCommand(newEvalRunCommand(extCtx)) + cmd.AddCommand(newEvalUpdateCommand(extCtx)) + cmd.AddCommand(newEvalListCommand()) + cmd.AddCommand(newEvalShowCommand()) + + return cmd +} + +// resolveEvalContext resolves the context for an eval operation by reading azd project state, +// environment variables, and optionally prompting the user. It returns an evalResolvedContext +// with API clients and metadata needed to run eval commands. +func resolveEvalContext(ctx context.Context, options evalContextOptions) (*evalResolvedContext, error) { + fmt.Println(output.WithGrayFormat("Resolving eval context...")) + + azdClient, err := azdext.NewAzdClient() + if err != nil { + return nil, fmt.Errorf("failed to create azd client: %w", err) + } + + fmt.Println(output.WithGrayFormat(" Reading project configuration...")) + projectResponse, err := azdClient.Project().Get(ctx, &azdext.EmptyRequest{}) + + // If no azd workspace is found, fall back to prompt-based resolution. + if err != nil || projectResponse.Project == nil { + return resolveEvalContextWithoutProject(ctx, azdClient, options) + } + project := projectResponse.Project + + fmt.Println(output.WithGrayFormat(" Detecting agent service...")) + + // Read the current azd environment once — used for agent info, endpoint, and env name. + var envName string + envResp, envErr := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if envErr == nil && envResp.Environment != nil { + envName = envResp.Environment.Name + } + + getEnvValue := func(key string) string { + if envName == "" { + return "" + } + v, e := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envName, Key: key, + }) + if e != nil || v.Value == "" { + return "" + } + return v.Value + } + + var svc *azdext.ServiceConfig + var info *AgentServiceInfo + svc, _, err = resolveAgentService(ctx, azdClient, options.agent, options.noPrompt) + if err == nil { + // Resolve deployed agent name/version from azd environment. + info = &AgentServiceInfo{ServiceName: svc.Name} + serviceKey := toServiceKey(svc.Name) + if v := getEnvValue(fmt.Sprintf("AGENT_%s_NAME", serviceKey)); v != "" { + info.AgentName = v + } + if v := getEnvValue(fmt.Sprintf("AGENT_%s_VERSION", serviceKey)); v != "" { + info.Version = v + } + } else if options.agent == "" && options.requireAgent { + azdClient.Close() + return nil, evalAgentContextError(err) + } + + fmt.Println(output.WithGrayFormat(" Resolving Foundry project endpoint...")) + projectEndpoint := options.projectEndpoint + projectEndpointSource := "--project-endpoint" + if projectEndpoint == "" { + if v := getEnvValue("FOUNDRY_PROJECT_ENDPOINT"); v != "" { + projectEndpoint = v + projectEndpointSource = "FOUNDRY_PROJECT_ENDPOINT" + } + } + if projectEndpoint == "" { + if v := getEnvValue("AZURE_AI_PROJECT_ENDPOINT"); v != "" { + projectEndpoint = v + projectEndpointSource = "AZURE_AI_PROJECT_ENDPOINT" + } + } + if projectEndpoint == "" { + if v := getEnvValue("AZURE_AI_PROJECT_ID"); v != "" { + ep, epErr := endpointFromProjectID(v) + if epErr != nil { + azdClient.Close() + return nil, epErr + } + projectEndpoint = ep + projectEndpointSource = "AZURE_AI_PROJECT_ID" + } + } + if projectEndpoint == "" { + azdClient.Close() + return nil, exterrors.Dependency( + exterrors.CodeMissingAiProjectEndpoint, + "Foundry project context could not be resolved", + "run 'azd ai agent init' to configure your project, or pass --project-endpoint directly", + ) + } + + agentName := options.agent + agentNameSource := "--agent" + agentVersion := "" + agentVersionSource := "unresolved" + agentKind := agent_yaml.AgentKind("") + agentKindSource := "unresolved" + serviceName := "" + agentProject := project.Path + agentProjectSource := "workspace root" + if agentName == "" { + agentNameSource = "unresolved" + } + if svc != nil { + serviceName = svc.Name + agentProject = filepath.Join(project.Path, svc.RelativePath) + agentProjectSource = fmt.Sprintf("azure.yaml service %q project path", svc.Name) + serviceKey := toServiceKey(svc.Name) + if info != nil && info.AgentName != "" { + agentName = info.AgentName + agentNameSource = fmt.Sprintf("AGENT_%s_NAME", serviceKey) + } + if info != nil && info.Version != "" { + agentVersion = info.Version + agentVersionSource = fmt.Sprintf("AGENT_%s_VERSION", serviceKey) + } + if detectedKind, manifestPath := detectEvalAgentKind(agentProject); detectedKind != "" { + agentKind = detectedKind + agentKindSource = relPathForYaml(project.Path, manifestPath) + } + } + if agentKind == "" { + agentKind = agent_yaml.AgentKindHosted + agentKindSource = "default" + } + if !agent_yaml.IsValidAgentKind(agentKind) { + azdClient.Close() + return nil, fmt.Errorf("unsupported agent kind %q", agentKind) + } + + if options.requireAgent && agentName == "" { + azdClient.Close() + return nil, evalAgentContextError(nil) + } + + credential, err := newAgentCredential() + if err != nil { + azdClient.Close() + return nil, err + } + evalClient := eval_api.NewEvalClient(projectEndpoint, credential) + datasetClient := dataset_api.NewDatasetClient(projectEndpoint, credential) + + return &evalResolvedContext{ + azdClient: azdClient, + evalClient: evalClient, + datasetClient: datasetClient, + projectRoot: project.Path, + hasProject: true, + agentProject: agentProject, + agentProjectSource: agentProjectSource, + agentName: agentName, + agentNameSource: agentNameSource, + version: agentVersion, + versionSource: agentVersionSource, + agentKind: agentKind, + agentKindSource: agentKindSource, + serviceName: serviceName, + projectEndpoint: projectEndpoint, + projectEndpointSource: projectEndpointSource, + envName: envName, + }, nil +} + +// resolveEvalContextWithoutProject prompts the user for essential inputs when +// there is no azd workspace (no azure.yaml). In --no-prompt mode it requires +// --project-endpoint and --agent to be passed explicitly. +func resolveEvalContextWithoutProject( + ctx context.Context, + azdClient *azdext.AzdClient, + options evalContextOptions, +) (*evalResolvedContext, error) { + fmt.Println(output.WithGrayFormat(" No azd project found. Prompting for inputs...")) + + projectEndpoint := options.projectEndpoint + agentName := options.agent + + if options.noPrompt { + if projectEndpoint == "" { + azdClient.Close() + return nil, exterrors.Dependency( + exterrors.CodeMissingAiProjectEndpoint, + "--project-endpoint is required when running outside an azd project with --no-prompt", + "pass --project-endpoint (-p) with your Foundry project endpoint URL", + ) + } + if agentName == "" && options.requireAgent { + azdClient.Close() + return nil, evalAgentContextError(nil) + } + } else { + prompt := azdClient.Prompt() + + if projectEndpoint == "" { + resp, err := prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Foundry project endpoint URL", + IgnoreHintKeys: true, + }, + }) + if err != nil { + azdClient.Close() + return nil, fmt.Errorf("prompting for project endpoint: %w", err) + } + projectEndpoint = strings.TrimSpace(resp.Value) + if projectEndpoint == "" { + azdClient.Close() + return nil, fmt.Errorf("project endpoint is required") + } + } + + if agentName == "" && options.requireAgent { + resp, err := prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Agent name", + IgnoreHintKeys: true, + }, + }) + if err != nil { + azdClient.Close() + return nil, fmt.Errorf("prompting for agent name: %w", err) + } + agentName = strings.TrimSpace(resp.Value) + if agentName == "" { + azdClient.Close() + return nil, fmt.Errorf("agent name is required") + } + } + } + + credential, err := newAgentCredential() + if err != nil { + azdClient.Close() + return nil, err + } + + cwd, _ := os.Getwd() + evalClient := eval_api.NewEvalClient(projectEndpoint, credential) + datasetClient := dataset_api.NewDatasetClient(projectEndpoint, credential) + + return &evalResolvedContext{ + azdClient: azdClient, + evalClient: evalClient, + datasetClient: datasetClient, + projectRoot: cwd, + agentProject: cwd, + agentProjectSource: "current directory", + agentName: agentName, + agentNameSource: "user input", + version: "", + versionSource: "unresolved", + agentKind: agent_yaml.AgentKindHosted, + agentKindSource: "default", + serviceName: "", + projectEndpoint: projectEndpoint, + projectEndpointSource: "user input", + envName: "", + }, nil +} + +func printEvalDetectedContext(resolved *evalResolvedContext, configPath string) { + fmt.Println() + fmt.Println(color.CyanString("Detected eval target:")) + if resolved.serviceName != "" { + printEvalField("Service", resolved.serviceName, "azure.yaml") + } + printEvalField("Agent", resolved.agentName, resolved.agentNameSource) + printEvalField("Version", resolved.version, resolved.versionSource) + printEvalField("Kind", string(resolved.agentKind), resolved.agentKindSource) + printEvalField("Endpoint", resolved.projectEndpoint, resolved.projectEndpointSource) + printEvalField("Project", resolved.agentProject, resolved.agentProjectSource) + fmt.Printf(" Eval config: %s\n", output.WithHighLightFormat(configPath)) + fmt.Println() +} + +func printEvalField(label, value, source string) { + padded := fmt.Sprintf("%-16s", label+":") + if value == "" || source == "unresolved" { + fmt.Printf(" %s%s\n", padded, output.WithGrayFormat("%s (%s)", value, source)) + } else { + fmt.Printf(" %s %s %s\n", + color.GreenString("(✓)"), + padded+output.WithHighLightFormat(value), + output.WithGrayFormat("(%s)", source), + ) + } +} + +func detectEvalAgentKind(agentProject string) (agent_yaml.AgentKind, string) { + for _, fileName := range []string{"agent.yaml", "agent.yml"} { + path := filepath.Join(agentProject, fileName) + data, err := os.ReadFile(path) //nolint:gosec // local agent manifest path is derived from azure.yaml service project + if err != nil { + continue + } + + var manifest struct { + Kind agent_yaml.AgentKind `yaml:"kind"` + } + if err := yaml.Unmarshal(data, &manifest); err != nil { + continue + } + if agent_yaml.IsValidAgentKind(manifest.Kind) { + return manifest.Kind, path + } + } + + return "", "" +} + +func evalAgentContextError(cause error) error { + message := "agent context could not be resolved" + if cause != nil { + message = fmt.Sprintf("%s: %s", message, cause) + } + return exterrors.Dependency( + exterrors.CodeMissingAgentEnvVars, + message, + "run 'azd ai agent init' to configure your agent, or pass --agent and --project-endpoint directly", + ) +} + +func endpointFromProjectID(projectID string) (string, error) { + project, err := extractProjectDetails(projectID) + if err != nil { + return "", err + } + return buildAgentEndpoint(project.AccountName, project.ProjectName), nil +} + +// pollEvalOperationWithSpinner polls a long-running eval operation with a spinner, updating the provided evalProgress with status. It returns the completed job or an error if the operation failed or timed out. +func pollEvalOperationWithSpinner( + ctx context.Context, + label string, + operationID string, + get eval_api.GetJobFunc, + apiVersion string, + progress *evalProgress, +) (*eval_api.GenerationJob, error) { + if operationID == "" { + return nil, fmt.Errorf("%s did not return an operation ID", strings.ToLower(label)) + } + + progress.setRunning(label, operationID) + poller := eval_api.NewPoller(operationID, apiVersion, get) + job, err := poller.Poll(ctx) + + if err != nil { + if _, ok := errors.AsType[*eval_api.PollerTimeoutError](err); ok { + progress.setTimedOut(label) + return nil, err + } + if jfe, ok := errors.AsType[*eval_api.JobFailedError](err); ok { + if body, marshalErr := json.MarshalIndent(jfe.Job, "", " "); marshalErr == nil { + log.Printf("[debug] %s: failed response:\n%s", label, body) + } + progress.setFailed(label) + errMsg := fmt.Sprintf("%s failed with status %q", strings.ToLower(label), jfe.Status) + if jfe.Job != nil && jfe.Job.Error != nil && jfe.Job.Error.Message != "" { + errMsg += ": " + jfe.Job.Error.Message + } + return nil, fmt.Errorf("%s", errMsg) + } + progress.setFailed(label) + return nil, err + } + + progress.setDone(label) + return job, nil +} + +func relPathForYaml(baseDir string, target string) string { + if rel, err := filepath.Rel(baseDir, target); err == nil { + return filepath.ToSlash(rel) + } + return filepath.ToSlash(target) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers.go new file mode 100644 index 00000000000..f11fcfd3c95 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers.go @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_helpers.go provides shared utility functions used by both eval and +// optimize commands, including portal URL construction and path display +// helpers. + +package cmd + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "path/filepath" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "go.yaml.in/yaml/v3" +) + +// resolvePortalPrefix reads AZURE_AI_PROJECT_ID from the azd environment and +// returns a PortalPrefix for building Foundry portal URLs. +// Returns nil on any failure. +func resolvePortalPrefix(ctx context.Context, azdClient *azdext.AzdClient, envName string) *eval_api.PortalPrefix { + if azdClient == nil || envName == "" { + return nil + } + v, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envName, + Key: "AZURE_AI_PROJECT_ID", + }) + if err != nil || v.Value == "" { + log.Printf("[debug] could not read AZURE_AI_PROJECT_ID: %v", err) + return nil + } + prefix, err := eval_api.NewPortalPrefix(v.Value) + if err != nil { + log.Printf("[debug] failed to build portal prefix: %v", err) + return nil + } + return prefix +} + +// buildEvalReportURL constructs the Foundry portal URL for an eval run report. +// Returns empty string on any failure. +func buildEvalReportURL(ctx context.Context, azdClient *azdext.AzdClient, envName, evalID, runID string) string { + if evalID == "" || runID == "" { + return "" + } + prefix := resolvePortalPrefix(ctx, azdClient, envName) + if prefix == nil { + return "" + } + return prefix.EvalRunURL(evalID, runID) +} + +// printPortalLink resolves the portal prefix and prints a portal URL. +// The buildURL callback receives the resolved prefix and returns the full URL. +// Best-effort — silently skips on any failure. +func printPortalLink(ctx context.Context, out io.Writer, azdClient *azdext.AzdClient, envName string, buildURL func(*eval_api.PortalPrefix) string) { + prefix := resolvePortalPrefix(ctx, azdClient, envName) + if prefix == nil { + return + } + fmt.Fprintf(out, " Portal: %s\n", color.CyanString(buildURL(prefix))) +} + +// relativeDisplay returns a project-relative path for display purposes. +// Used by both eval and optimize config confirmation prompts. +// Returns empty string for empty input. +func relativeDisplay(absPath, projectDir string) string { + if absPath == "" || projectDir == "" { + return absPath + } + if rel, err := filepath.Rel(projectDir, absPath); err == nil { + return rel + } + return absPath +} + +// reconcileConfigAgentName reconciles the agent name in a config with the +// environment-resolved name. Environment takes precedence. Returns true if +// the config was changed. Used by both eval run and optimize. +func reconcileConfigAgentName(agent *opt_eval.AgentRef, envName, configSource string) bool { + if envName == "" || agent.Name == "" || agent.Name == envName { + if envName != "" && agent.Name == "" { + agent.Name = envName + } + return false + } + fmt.Printf(" %s agent name in %s (%q) differs from environment (%q) — using environment value\n", + color.YellowString("warning:"), configSource, agent.Name, envName) + agent.Name = envName + return true +} + +// resolveAgentConfig resolves agent configuration from config metadata +// using a priority chain: +// +// 1. existingConfig's agent.config path — if the config references a +// metadata.yaml, resolve all fields from it. +// 2. Default baseline path — try .agent_configs/baseline/metadata.yaml. +// 3. Nothing found — returns nil; the caller should prompt the user +// for an instruction and then call writeBaselineIfNeeded. +// +// The returned AgentConfig contains resolved instruction file path, model, +// skill_dir, and tools_file. Eval init uses only instruction fields; +// optimize also uses skill_dir and tools_file. +func resolveAgentConfig( + existingConfig *opt_eval.Config, + projectDir string, +) *opt_eval.AgentConfig { + // Step 1: existing config has a config pointer — resolve from it. + if existingConfig != nil && existingConfig.Agent.ConfigFile != "" { + ref := opt_eval.AgentRef{ConfigFile: existingConfig.Agent.ConfigFile} + return ref.ResolveConfig(projectDir) + } + + // Step 2: try the default baseline path. + if projectDir != "" { + relPath := opt_eval.BaselineConfigRelPath() + if fileExists(filepath.Join(projectDir, relPath)) { + ref := opt_eval.AgentRef{ConfigFile: relPath} + return ref.ResolveConfig(projectDir) + } + } + + // Step 3: nothing found — caller should prompt and write baseline. + return nil +} + +// writeBaselineIfNeeded creates a baseline config when no config was resolved +// but an instruction is available. Returns the config file relative path +// (empty if nothing was written). +func writeBaselineIfNeeded( + projectDir, instruction string, +) string { + if projectDir == "" || instruction == "" { + return "" + } + defaultConfigFile := opt_eval.BaselineConfigRelPath() + absConfigFile := filepath.Join(projectDir, defaultConfigFile) + // Don't overwrite an existing baseline. + if fileExists(absConfigFile) { + return "" + } + if err := writeBaselineConfig(projectDir, baselineParams{ + Instruction: instruction, + }); err != nil { + fmt.Printf(" warning: failed to write baseline config: %s\n", err) + return "" + } + fmt.Printf(" Baseline: %s\n", absConfigFile) + return defaultConfigFile +} + +// baselineParams holds optional inputs for writing a baseline agent config. +type baselineParams struct { + Model string // agent model (optional) + Instruction string // system prompt text (optional) + SkillDir string // absolute skill dir path (empty = auto-detect) + ToolsFile string // absolute tools file path (optional) +} + +// writeBaselineConfig writes a baseline agent config to .agent_configs/baseline/. +// It creates metadata.yaml with file pointers and writes instructions.md. +// When skillDir is empty, it auto-detects a "skills" or "skill" directory. +// Used by both eval init and optimize. +func writeBaselineConfig(agentProject string, p baselineParams) error { + baseDir := filepath.Join(agentProject, opt_eval.AgentConfigsDir, opt_eval.BaselineDir) + if err := os.MkdirAll(baseDir, 0750); err != nil { + return fmt.Errorf("creating baseline directory: %w", err) + } + + meta := struct { + Model string `yaml:"model,omitempty"` + InstructionFile string `yaml:"instruction_file,omitempty"` + SkillDir string `yaml:"skill_dir,omitempty"` + ToolsFile string `yaml:"tools_file,omitempty"` + }{ + Model: p.Model, + } + + if p.Instruction != "" { + instructionPath := filepath.Join(baseDir, opt_eval.InstructionFile) + if err := os.WriteFile(instructionPath, []byte(p.Instruction), 0600); err != nil { + return fmt.Errorf("writing baseline instructions: %w", err) + } + meta.InstructionFile = opt_eval.InstructionFile + } + + // Resolve skill_dir: use explicit path, or auto-detect from project. + skillDir := p.SkillDir + if skillDir == "" { + for _, candidate := range []string{"skills", "skill"} { + dir := filepath.Join(agentProject, candidate) + if info, err := os.Stat(dir); err == nil && info.IsDir() { + skillDir = dir + break + } + } + } + if skillDir != "" { + if rel, err := filepath.Rel(baseDir, skillDir); err == nil { + meta.SkillDir = filepath.ToSlash(rel) + } else { + meta.SkillDir = skillDir + } + } + + if p.ToolsFile != "" { + if rel, err := filepath.Rel(baseDir, p.ToolsFile); err == nil { + meta.ToolsFile = filepath.ToSlash(rel) + } else { + meta.ToolsFile = p.ToolsFile + } + } + + data, err := yaml.Marshal(meta) + if err != nil { + return fmt.Errorf("serializing baseline metadata: %w", err) + } + + metaPath := filepath.Join(baseDir, opt_eval.MetadataFile) + if err := os.WriteFile(metaPath, data, 0600); err != nil { + return fmt.Errorf("writing baseline metadata: %w", err) + } + + return nil +} + +// loadJSONLFile reads a JSONL file and unmarshals each non-empty line into T. +// Returns an error if the file cannot be read, a line fails to parse, or no items are found. +func loadJSONLFile[T any](path string) ([]T, error) { + f, err := os.Open(path) //nolint:gosec // path is provided by user for local dataset + if err != nil { + return nil, fmt.Errorf("failed to open dataset file %s: %w", path, err) + } + defer f.Close() + + var items []T + scanner := bufio.NewScanner(f) + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := scanner.Text() + if line == "" { + continue + } + var item T + if err := json.Unmarshal([]byte(line), &item); err != nil { + return nil, fmt.Errorf("failed to parse dataset line %d: %w", lineNum, err) + } + items = append(items, item) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading dataset file %s: %w", path, err) + } + if len(items) == 0 { + return nil, fmt.Errorf("dataset file %s contains no items", path) + } + return items, nil +} + +// statusLabelAndColor maps a raw status to a display label and color function. +func statusLabelAndColor(status string) (string, func(string, ...any) string) { + switch status { + case "completed": + return "Completed", color.GreenString + case "succeeded": + return "Succeeded", color.GreenString + case "failed": + return "Failed", color.RedString + case "cancelled", "canceled": + return "Cancelled", color.YellowString + case "running", "in_progress": + return "Running", color.CyanString + case "partial": + return "Partial", color.YellowString + case "": + return "No runs", color.HiBlackString + default: + return status, fmt.Sprintf + } +} + +// colorizeStatus returns a colorized status string for display. +func colorizeStatus(status string) string { + label, colorFn := statusLabelAndColor(status) + return colorFn(label) +} + +// padColorizedStatus returns a fixed-width colored status string so that +// tabwriter aligns columns correctly despite ANSI escape sequences. +func padColorizedStatus(status string) string { + const statusWidth = 10 // wide enough for "Completed", "Cancelled", etc. + label, colorFn := statusLabelAndColor(status) + padded := fmt.Sprintf("%-*s", statusWidth, label) + return colorFn(padded) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers_test.go new file mode 100644 index 00000000000..63d9f743f4c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_helpers_test.go @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- relativeDisplay ---- + +func TestRelativeDisplay(t *testing.T) { + t.Parallel() + tests := []struct { + name string + absPath string + projectDir string + want string + }{ + {"relative path", filepath.Join("/project", "sub", "file.yaml"), "/project", filepath.Join("sub", "file.yaml")}, + {"same dir", filepath.Join("/project", "file.yaml"), "/project", "file.yaml"}, + {"empty absPath", "", "/project", ""}, + {"empty projectDir", "/project/file.yaml", "", "/project/file.yaml"}, + {"both empty", "", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := relativeDisplay(tt.absPath, tt.projectDir) + assert.Equal(t, tt.want, got) + }) + } +} + +// ---- reconcileConfigAgentName ---- + +func TestReconcileConfigAgentName(t *testing.T) { + t.Parallel() + t.Run("no change when names match", func(t *testing.T) { + t.Parallel() + agent := &opt_eval.AgentRef{Name: "my-agent"} + changed := reconcileConfigAgentName(agent, "my-agent", "config.yaml") + assert.False(t, changed) + assert.Equal(t, "my-agent", agent.Name) + }) + + t.Run("sets name when agent name is empty", func(t *testing.T) { + t.Parallel() + agent := &opt_eval.AgentRef{} + changed := reconcileConfigAgentName(agent, "env-agent", "config.yaml") + assert.False(t, changed) + assert.Equal(t, "env-agent", agent.Name) + }) + + t.Run("overrides when names differ", func(t *testing.T) { + t.Parallel() + agent := &opt_eval.AgentRef{Name: "config-agent"} + changed := reconcileConfigAgentName(agent, "env-agent", "config.yaml") + assert.True(t, changed) + assert.Equal(t, "env-agent", agent.Name) + }) + + t.Run("no change when envName is empty", func(t *testing.T) { + t.Parallel() + agent := &opt_eval.AgentRef{Name: "my-agent"} + changed := reconcileConfigAgentName(agent, "", "config.yaml") + assert.False(t, changed) + assert.Equal(t, "my-agent", agent.Name) + }) +} + +// ---- statusLabelAndColor ---- + +func TestStatusLabelAndColor(t *testing.T) { + t.Parallel() + tests := []struct { + status string + wantLabel string + }{ + {"completed", "Completed"}, + {"succeeded", "Succeeded"}, + {"failed", "Failed"}, + {"cancelled", "Cancelled"}, + {"canceled", "Cancelled"}, + {"running", "Running"}, + {"in_progress", "Running"}, + {"partial", "Partial"}, + {"", "No runs"}, + {"unknown_status", "unknown_status"}, + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + t.Parallel() + label, colorFn := statusLabelAndColor(tt.status) + assert.Equal(t, tt.wantLabel, label) + assert.NotNil(t, colorFn) + }) + } +} + +func TestColorizeStatus(t *testing.T) { + t.Parallel() + // colorizeStatus should return a non-empty string for any input. + assert.NotEmpty(t, colorizeStatus("completed")) + assert.NotEmpty(t, colorizeStatus("failed")) + assert.NotEmpty(t, colorizeStatus("")) + assert.NotEmpty(t, colorizeStatus("unknown")) +} + +func TestPadColorizedStatus(t *testing.T) { + t.Parallel() + // padColorizedStatus should return a non-empty string for any input. + result := padColorizedStatus("completed") + assert.NotEmpty(t, result) + // The padded string should be longer than the label due to padding + ANSI. + assert.Contains(t, result, "Completed") +} + +// ---- writeBaselineConfig ---- + +func TestWriteBaselineConfig(t *testing.T) { + t.Parallel() + t.Run("writes metadata and instruction file", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + err := writeBaselineConfig(dir, baselineParams{ + Model: "gpt-4o", + Instruction: "You are a helpful assistant.", + }) + require.NoError(t, err) + + metaPath := filepath.Join(dir, opt_eval.AgentConfigsDir, opt_eval.BaselineDir, opt_eval.MetadataFile) + assert.FileExists(t, metaPath) + + instrPath := filepath.Join(dir, opt_eval.AgentConfigsDir, opt_eval.BaselineDir, opt_eval.InstructionFile) + assert.FileExists(t, instrPath) + content, err := os.ReadFile(instrPath) //nolint:gosec // test file path + require.NoError(t, err) + assert.Equal(t, "You are a helpful assistant.", string(content)) + }) + + t.Run("writes metadata without instruction", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + err := writeBaselineConfig(dir, baselineParams{ + Model: "gpt-4o", + }) + require.NoError(t, err) + + metaPath := filepath.Join(dir, opt_eval.AgentConfigsDir, opt_eval.BaselineDir, opt_eval.MetadataFile) + assert.FileExists(t, metaPath) + + instrPath := filepath.Join(dir, opt_eval.AgentConfigsDir, opt_eval.BaselineDir, opt_eval.InstructionFile) + assert.NoFileExists(t, instrPath) + }) + + t.Run("auto-detects skill dir", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(dir, "skills"), 0750)) + + err := writeBaselineConfig(dir, baselineParams{ + Instruction: "test", + }) + require.NoError(t, err) + + metaPath := filepath.Join(dir, opt_eval.AgentConfigsDir, opt_eval.BaselineDir, opt_eval.MetadataFile) + data, err := os.ReadFile(metaPath) //nolint:gosec // test file path + require.NoError(t, err) + assert.Contains(t, string(data), "skill_dir") + }) +} + +// ---- writeBaselineIfNeeded ---- + +func TestWriteBaselineIfNeeded(t *testing.T) { + t.Parallel() + t.Run("creates baseline when none exists", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + result := writeBaselineIfNeeded(dir, "test instruction") + assert.NotEmpty(t, result) + assert.FileExists(t, filepath.Join(dir, result)) + }) + + t.Run("skips when baseline already exists", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + // Create existing baseline. + absPath := filepath.Join(dir, opt_eval.BaselineConfigRelPath()) + require.NoError(t, os.MkdirAll(filepath.Dir(absPath), 0750)) + require.NoError(t, os.WriteFile(absPath, []byte("existing"), 0600)) + + result := writeBaselineIfNeeded(dir, "test instruction") + assert.Empty(t, result) + }) + + t.Run("returns empty for empty inputs", func(t *testing.T) { + t.Parallel() + assert.Empty(t, writeBaselineIfNeeded("", "instruction")) + assert.Empty(t, writeBaselineIfNeeded("/some/dir", "")) + }) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init.go new file mode 100644 index 00000000000..126e6fced5e --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init.go @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_init.go implements the "eval init" command, which generates a local +// eval suite (eval.yaml) for a deployed agent. It resolves context, submits +// dataset and evaluator generation jobs, polls for completion (unless +// --no-wait), downloads review artifacts, and writes the eval config. + +package cmd + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// DataGenerationAPIVersion is the API version used for data generation jobs. +const DataGenerationAPIVersion = "v1" + +// evalInitFlags holds CLI flags and interactive prompt state for eval init. +type evalInitFlags struct { + // CLI flags. + name string // eval suite name + agent string // target agent name + projectEndpoint string // Foundry project endpoint + instruction string // inline agent instruction + instructionFile string // path to agent instruction file + configFile string // agent config metadata path + evalModel string // model for evaluation and generation + dataset string // existing dataset file or name + output string // eval config output path + maxSamples int // number of samples to generate + evaluators []string // built-in or custom evaluator names + noWait bool // submit and return immediately + resetDefaults bool // overwrite existing eval config + evalModelSet bool // true if --eval-model was explicitly set + maxSamplesSet bool // true if --max-samples was explicitly set + traceDays int // include traces from last N days + + // Internal state set during interactive prompts. + regenerateDataset bool + regenerateEvaluator bool +} + +func newEvalInitCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &evalInitFlags{maxSamples: defaultEvalSamples, output: defaultEvalConfigName} + cmd := &cobra.Command{ + Use: "init", + Short: "Generate a local eval suite for a deployed agent.", + Long: `Generate a local eval suite for a deployed agent. + +By default, this command submits dataset and evaluator generation jobs, waits for +completion, downloads review artifacts, and writes eval.yaml at +the agent project root. Use --no-wait to write pending operation IDs and return.`, + Example: ` azd ai agent eval init + azd ai agent eval init --gen-instruction "This agent handles restaurant reservations." --eval-model gpt-4o --max-samples 50 + azd ai agent eval init --gen-instruction-file ./instructions.md --eval-model gpt-4o + azd ai agent eval init --dataset ./tests/golden.jsonl --evaluator builtin.intent_resolution`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + logCleanup := setupDebugLogging(cmd.Flags()) + defer logCleanup() + flags.evalModelSet = cmd.Flags().Changed("eval-model") + flags.maxSamplesSet = cmd.Flags().Changed("max-samples") + return runEvalInit(ctx, flags, extCtx.NoPrompt) + }, + } + + cmd.Flags().StringVar(&flags.name, "name", "", "Name for the eval suite") + cmd.Flags().BoolVar(&flags.noWait, "no-wait", false, "Submit generation jobs and return immediately") + cmd.Flags().StringVar(&flags.agent, "agent", "", "Target agent name") + cmd.Flags().StringVarP(&flags.projectEndpoint, "project-endpoint", "p", "", "Microsoft Foundry project endpoint URL") + cmd.Flags().StringVarP(&flags.instruction, "gen-instruction", "g", "", "Agent instruction used for dataset and evaluator generation") + cmd.Flags().StringVarP(&flags.instructionFile, "gen-instruction-file", "", "", "Path to a file containing the agent instruction") + cmd.Flags().StringVar(&flags.evalModel, "eval-model", "", "Model used for evaluation and generation") + cmd.Flags().StringVar(&flags.dataset, "dataset", "", "Existing local file or registered dataset name to use for evaluation (instead of generating a new dataset)") + cmd.Flags().IntVar(&flags.maxSamples, "max-samples", defaultEvalSamples, "Number of samples to generate (15-1000)") + cmd.Flags().StringArrayVar(&flags.evaluators, "evaluator", nil, "Built-in or custom evaluator name") + cmd.Flags().StringVar(&flags.output, "out-file", defaultEvalConfigName, "Eval config path") + cmd.Flags().IntVar(&flags.traceDays, "trace-days", 0, "Include agent traces from the last N days for evaluator generation (0 = no traces)") + cmd.Flags().BoolVar(&flags.resetDefaults, "reset-defaults", false, "Overwrite an existing eval config") + + return cmd +} + +// runEvalInit executes the eval init command logic. It resolves context, +// prompts for missing options, submits generation jobs, polls for completion +// (unless --no-wait), writes the eval config, and prints next steps. +func runEvalInit(ctx context.Context, flags *evalInitFlags, noPrompt bool) error { + if flags.instruction != "" && flags.instructionFile != "" { + return fmt.Errorf("cannot use both --gen-instruction and --gen-instruction-file; provide one or the other") + } + + // Validate instruction file early when the path won't be resolved relative to a project. + if flags.instructionFile != "" { + if _, err := os.Stat(flags.instructionFile); err != nil && filepath.IsAbs(flags.instructionFile) { + return fmt.Errorf("instruction file %q is not accessible: %w", flags.instructionFile, err) + } + } + + resolved, err := resolveEvalContext(ctx, evalContextOptions{ + agent: flags.agent, + projectEndpoint: flags.projectEndpoint, + requireAgent: true, + noPrompt: noPrompt, + }) + if err != nil { + return err + } + defer resolved.azdClient.Close() + + // Resolve relative instruction file paths against the agent project directory. + if flags.instructionFile != "" && !filepath.IsAbs(flags.instructionFile) { + if resolved.projectRoot != "" { + flags.instructionFile = filepath.Join(resolved.projectRoot, flags.instructionFile) + } + if _, err := os.Stat(flags.instructionFile); err != nil { + return fmt.Errorf("instruction file %q is not accessible: %w", flags.instructionFile, err) + } + } + + configPath := eval_api.ResolveRelPath(flags.output, resolved.agentProject) + printEvalDetectedContext(resolved, configPath) + + // Load existing eval.yaml and resolve agent config. + existingCfg, hasExisting := tryLoadExistingEvalConfig(configPath) + isRegenerate := false + + // Resolve agent config: eval.yaml config → default baseline → nothing. + if flags.instruction == "" && flags.instructionFile == "" && resolved.hasProject { + var existing *opt_eval.Config + if hasExisting && !flags.resetDefaults { + existing = &existingCfg.Config + } + if agentCfg := resolveAgentConfig(existing, resolved.agentProject); agentCfg != nil { + flags.configFile = agentCfg.ConfigFile + flags.instructionFile = agentCfg.InstructionFile + fmt.Printf(" Agent Config: %s\n", filepath.Join(resolved.agentProject, agentCfg.ConfigFile)) + } + } + + // If --reset-defaults is set, clear existing state so the user can start fresh. + if flags.resetDefaults && resolved.envName != "" { + if err := opt_eval.ClearEvalState(ctx, resolved.azdClient, resolved.envName); err != nil { + log.Printf("warning: clearing eval state: %v", err) + } + } + + // Handle existing eval.yaml: prompt for regeneration, carry forward options. + if hasExisting && !flags.resetDefaults { + var keepExisting bool + keepExisting, err = handleExistingEvalConfig(ctx, resolved, existingCfg, flags, noPrompt) + if err != nil { + return err + } + if keepExisting { + fmt.Println("Keeping existing eval config unchanged.") + return nil + } + isRegenerate = true + } + + // When the user hasn't explicitly set --eval-model, use the deployed model. + if !flags.evalModelSet && resolved.envName != "" { + if v, err := resolved.azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: resolved.envName, + Key: "AZURE_AI_MODEL_DEPLOYMENT_NAME", + }); err == nil && v.Value != "" { + flags.evalModel = v.Value + } + } + + if err := promptEvalInitOptions(ctx, resolved, flags, noPrompt); err != nil { + return err + } + + // Write baseline config if none was resolved but we have an instruction. + if flags.configFile == "" && resolved.hasProject { + instruction := resolvedInstruction(flags) + if cfgFile := writeBaselineIfNeeded(resolved.agentProject, instruction); cfgFile != "" { + flags.configFile = cfgFile + } + } + + if !isRegenerate { + flags.name = resolveEvalName(flags) + } + + if flags.instruction == "" && flags.instructionFile == "" && flags.configFile == "" && + (flags.dataset == "" || len(flags.evaluators) == 0) { + return fmt.Errorf( + "one of --gen-instruction, --gen-instruction-file, --config, or both --dataset and --evaluators is required" + + " when generating eval assets for a hosted agent") + } + if flags.maxSamples < 15 || flags.maxSamples > 1000 { + return fmt.Errorf("--max-samples must be between 15 and 1000") + } + + // Build config and submit generation jobs. + evalCfg := newEvalConfig(flags, resolved) + var extraEvals opt_eval.EvaluatorList + if !isRegenerate && len(flags.evaluators) > 0 { + extraEvals = evaluatorsFromFlags(flags.evaluators) + } + + state, err := submitEvalJobs(ctx, resolved, flags, evalCfg, existingCfg, isRegenerate) + if err != nil { + return err + } + + if flags.noWait { + if state.DatasetGenOpID != "" || state.EvalGenOpID != "" { + state.InitStatus = opt_eval.InitStatusPending + } + return writePendingEvalInit(ctx, resolved, configPath, evalCfg, state) + } + + pollRes, err := pollAndFinalizeJobs(ctx, resolved, evalCfg, state, extraEvals) + if err != nil { + if _, ok := errors.AsType[*initTimeoutError](err); ok { + return writeTimedOutEvalInit(ctx, resolved, configPath, evalCfg, state) + } + return err + } + + state.InitStatus = opt_eval.InitStatusCompleted + if err := opt_eval.ClearEvalState(ctx, resolved.azdClient, resolved.envName); err != nil { + log.Printf("warning: clearing eval state: %v", err) + } + return writeAndPrintEvalResult(ctx, resolved, evalCfg, pollRes, configPath, isRegenerate) +} + +// handleExistingEvalConfig processes an existing eval.yaml by prompting for +// regeneration choices and carrying forward options that weren't overridden. +// Returns keepExisting=true if the user chose not to regenerate anything. +func handleExistingEvalConfig( + ctx context.Context, + resolved *evalResolvedContext, + existingCfg *evalConfig, + flags *evalInitFlags, + noPrompt bool, +) (keepExisting bool, err error) { + if noPrompt { + // --no-prompt: keep existing config unchanged by default. + return true, nil + } + + if err := promptRegenerateChoices(ctx, resolved, existingCfg, flags); err != nil { + return false, err + } + if !flags.regenerateDataset && !flags.regenerateEvaluator { + return true, nil + } + + // Carry forward existing options when not explicitly overridden. + if flags.name == "" && existingCfg.Name != "" { + flags.name = existingCfg.Name + } + if existingCfg.Options != nil && !flags.evalModelSet { + flags.evalModel = existingCfg.Options.EvalModel + } + if !flags.maxSamplesSet && existingCfg.MaxSamples > 0 { + flags.maxSamples = existingCfg.MaxSamples + } + if flags.traceDays == 0 && existingCfg.TraceDays > 0 { + flags.traceDays = existingCfg.TraceDays + } + return false, nil +} + +// submitEvalJobs determines which generation jobs are needed and submits them. +// It preserves existing config fields when regenerating only a subset. +func submitEvalJobs( + ctx context.Context, + resolved *evalResolvedContext, + flags *evalInitFlags, + evalCfg *evalConfig, + existingCfg *evalConfig, + isRegenerate bool, +) (*opt_eval.EvalState, error) { + state := &opt_eval.EvalState{} + + var needDatasetGen, needEvalGen bool + if isRegenerate { + needDatasetGen = flags.regenerateDataset + needEvalGen = flags.regenerateEvaluator + if !needDatasetGen { + evalCfg.DatasetFile = existingCfg.DatasetFile + evalCfg.Config.DatasetReference = existingCfg.Config.DatasetReference + } + if !needEvalGen { + evalCfg.Evaluators = existingCfg.Evaluators + } + } else { + needDatasetGen = flags.dataset == "" + needEvalGen = true + if !needDatasetGen { + datasetPath, err := resolveLocalDatasetFile(flags.dataset, resolved.agentProject) + if err != nil { + return nil, err + } + evalCfg.DatasetFile = datasetPath + } + } + + if needDatasetGen { + job, err := submitDatasetGeneration(ctx, resolved, flags) + if err != nil { + return nil, err + } + state.DatasetGenOpID = job.OperationID() + state.DatasetGenStatus = job.NormalizedStatus() + } + if needEvalGen { + job, err := submitEvaluatorGeneration(ctx, resolved, flags) + if err != nil { + return nil, err + } + state.EvalGenOpID = job.OperationID() + state.EvalGenStatus = job.NormalizedStatus() + } + + return state, nil +} + +// writeAndPrintEvalResult writes the eval config and review artifacts, then +// prints a summary of the generated assets along with portal links and +// next-step instructions. +func writeAndPrintEvalResult( + ctx context.Context, + resolved *evalResolvedContext, + evalCfg *evalConfig, + pollRes *pollResults, + configPath string, + isRegenerate bool, +) error { + if err := eval_api.WriteEvalConfig(configPath, evalCfg); err != nil { + return err + } + + if resolved.hasProject { + if err := eval_api.WriteEvalReviewArtifacts(resolved.agentProject, evalCfg); err != nil { + log.Printf("warning: writing eval review artifacts: %v", err) + } + } + if isRegenerate { + fmt.Println(color.GreenString("\nEval suite regenerated")) + } else { + fmt.Println(color.GreenString("\nEval suite created")) + } + fmt.Printf(" Config: %s\n", configPath) + if evalCfg.DatasetFile != "" { + fmt.Printf(" Dataset: %s\n", evalCfg.DatasetFile) + } else if evalCfg.DatasetReference != nil && evalCfg.DatasetReference.Name != "" { + ds := evalCfg.DatasetReference.Name + if evalCfg.DatasetReference.Version != "" { + ds += " (" + evalCfg.DatasetReference.Version + ")" + } + fmt.Printf(" Dataset: %s\n", ds) + if resolved.hasProject { + fmt.Printf(" %s\n", eval_api.DatasetArtifactPath(resolved.agentProject, evalCfg.DatasetReference)) + } + } + for _, evaluator := range evalCfg.Evaluators { + if evaluator.Name != "" { + ev := evaluator.Name + if evaluator.Version != "" { + ev += " (" + evaluator.Version + ")" + } + fmt.Printf(" Evaluator: %s\n", ev) + if resolved.hasProject && !eval_api.IsBuiltinEvaluator(evaluator.Name) { + fmt.Printf(" %s\n", + filepath.Join(resolved.agentProject, eval_api.EvaluatorLocalURI(evaluator.Name))) + } + } + } + + printEvalDimensions(pollRes) + printEvalPortalLinks(ctx, resolved, evalCfg) + + fmt.Println("\n Next steps:") + fmt.Printf(" %s\n", color.CyanString("azd ai agent eval run")) + fmt.Printf(" Run the eval suite against your agent.\n") + fmt.Printf(" %s\n", color.CyanString("azd ai agent eval update")) + fmt.Printf(" Edit the generated dataset or evaluator locally, then upload changes.\n") + return nil +} + +// printEvalDimensions prints rubric dimensions from the poll results if available. +func printEvalDimensions(results *pollResults) { + if results == nil || results.EvaluatorResult == nil { + return + } + if len(results.EvaluatorResult.Definition.Dimensions) == 0 { + return + } + eval_api.PrintEvaluatorDimensions(results.EvaluatorResult) +} + +// printEvalPortalLinks prints Foundry portal links for the generated dataset and evaluator. +func printEvalPortalLinks(ctx context.Context, resolved *evalResolvedContext, evalCfg *evalConfig) { + prefix := resolvePortalPrefix(ctx, resolved.azdClient, resolved.envName) + if prefix == nil { + return + } + hasLink := false + if evalCfg.DatasetReference != nil && evalCfg.DatasetReference.Name != "" { + fmt.Printf("\n "+color.HiBlackString("Portal:")+"\n Dataset: %s\n", + color.CyanString(prefix.DatasetURL(evalCfg.DatasetReference.Name, evalCfg.DatasetReference.Version))) + hasLink = true + } + for _, evaluator := range evalCfg.Evaluators { + if evaluator.Name != "" && !eval_api.IsBuiltinEvaluator(evaluator.Name) { + if !hasLink { + fmt.Println("\n " + color.HiBlackString("Portal:")) + hasLink = true + } + fmt.Printf(" Evaluator: %s\n", + color.CyanString(prefix.EvaluatorURL(evaluator.Name, evaluator.Version))) + } + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_jobs.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_jobs.go new file mode 100644 index 00000000000..abc9d32beb0 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_jobs.go @@ -0,0 +1,447 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_init_jobs.go handles generation job submission and polling for the +// eval init command. It submits dataset and evaluator generation requests, +// polls for completion in parallel, downloads artifacts on success, and +// persists state for resume on timeout. + +package cmd + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "sync" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/fatih/color" +) + +// resolveEvalName returns the eval suite name from flags, falling back to defaultEvalName. +func resolveEvalName(flags *evalInitFlags) string { + if flags.name != "" { + return flags.name + } + return defaultEvalName +} + +// resolvedInstruction returns the instruction content from flags, reading +// from file if instructionFile is set. +func resolvedInstruction(flags *evalInitFlags) string { + if flags.instructionFile != "" { + data, err := os.ReadFile(flags.instructionFile) //nolint:gosec // user-provided path validated earlier + if err != nil { + return flags.instruction + } + return string(data) + } + return flags.instruction +} + +// newEvalConfig builds an evalConfig from flags and resolved context, applying defaults as needed. +func newEvalConfig(flags *evalInitFlags, resolved *evalResolvedContext) *evalConfig { + agent := evalAgentRef{ + Name: resolved.agentName, + Kind: resolved.agentKind, + Version: resolved.version, + } + if flags.configFile != "" { + agent.ConfigFile = flags.configFile + } + if flags.instruction != "" { + agent.Instruction.Value = flags.instruction + } + if flags.instructionFile != "" { + agent.Instruction.File = flags.instructionFile + } + return &evalConfig{ + Config: opt_eval.Config{ + Name: resolveEvalName(flags), + Agent: agent, + }, + Options: &opt_eval.Options{ + EvalModel: flags.evalModel, + }, + MaxSamples: flags.maxSamples, + TraceDays: flags.traceDays, + } +} + +// submitDatasetGeneration submits a dataset generation job and returns the created job or an error. +func submitDatasetGeneration( + ctx context.Context, + resolved *evalResolvedContext, + flags *evalInitFlags, +) (*eval_api.GenerationJob, error) { + // Traces are only supported for evaluator generation, not dataset generation. + prompt := resolvedInstruction(flags) + sources := eval_api.BuildGenerationSources( + string(resolved.agentKind), resolved.agentName, resolved.version, prompt, nil, + ) + request := eval_api.NewDataGenerationJobRequest( + resolveEvalName(flags), flags.evalModel, flags.maxSamples, sources, + ) + return resolved.evalClient.CreateDataGenerationJob(ctx, request, DataGenerationAPIVersion) +} + +// submitEvaluatorGeneration submits an evaluator generation job and returns the created job or an error. +func submitEvaluatorGeneration( + ctx context.Context, + resolved *evalResolvedContext, + flags *evalInitFlags, +) (*eval_api.GenerationJob, error) { + var traces *eval_api.TraceOptions + if flags.traceDays > 0 { + traces = &eval_api.TraceOptions{Days: flags.traceDays} + } + prompt := resolvedInstruction(flags) + sources := eval_api.BuildGenerationSources( + string(resolved.agentKind), resolved.agentName, resolved.version, prompt, traces, + ) + request := eval_api.NewEvaluatorGenerationJobRequest( + resolveEvalName(flags), flags.evalModel, sources, + ) + return resolved.evalClient.CreateEvaluatorGenerationJob(ctx, request, DefaultAgentAPIVersion) +} + +// resolveLocalDatasetFile resolves the dataset flag value to an absolute path +// for the local JSONL file. If the value is relative it is resolved against +// the agent project directory. +func resolveLocalDatasetFile(dataset string, agentProject string) (string, error) { + if filepath.IsAbs(dataset) { + if _, err := os.Stat(dataset); err != nil { + return "", fmt.Errorf("dataset file %q is not accessible: %w", dataset, err) + } + return dataset, nil + } + abs := filepath.Join(agentProject, dataset) + if _, err := os.Stat(abs); err != nil { + return "", fmt.Errorf("dataset file %q is not accessible: %w", dataset, err) + } + return abs, nil +} + +func datasetFromJob(job *eval_api.GenerationJob) *evalDatasetRef { + name, version := job.ResolvedNameVersion() + if name == "" { + return nil + } + return &evalDatasetRef{ + Name: name, + Version: version, + } +} + +func evaluatorFromJob(job *eval_api.GenerationJob) (string, string) { + return job.ResolvedNameVersion() +} + +func evaluatorsFromFlags(values []string) opt_eval.EvaluatorList { + refs := make(opt_eval.EvaluatorList, len(values)) + for i, v := range values { + refs[i] = opt_eval.EvaluatorRef{Name: v} + } + return refs +} + +func buildOpenAIEvalRequest(evalCfg *evalConfig) *eval_api.CreateOpenAIEvalRequest { + return evalCfg.ToAgentTargetAdaptableEvalGroupRequest() +} + +// resumeEvalInit handles resuming an eval init when generation jobs are still pending. It polls for job completion, updates state and config on success, and persists state for later resume if polling times out. +func resumeEvalInit( + ctx context.Context, + resolved *evalResolvedContext, + configPath string, + evalCfg *evalConfig, + state *opt_eval.EvalState, +) error { + if _, err := pollAndFinalizeJobs(ctx, resolved, evalCfg, state, nil); err != nil { + if _, ok := errors.AsType[*initTimeoutError](err); ok { + return writeTimedOutEvalInit(ctx, resolved, configPath, evalCfg, state) + } + return err + } + state.InitStatus = opt_eval.InitStatusCompleted + if err := opt_eval.ClearEvalState(ctx, resolved.azdClient, resolved.envName); err != nil { + log.Printf("warning: clearing eval state: %v", err) + } + if resolved.hasProject { + if err := eval_api.WriteEvalReviewArtifacts(resolved.agentProject, evalCfg); err != nil { + log.Printf("warning: writing eval review artifacts: %v", err) + } + } + return eval_api.WriteEvalConfig(configPath, evalCfg) +} + +// pollResults carries parsed outputs from completed generation jobs so that +// the caller can display them after both jobs finish. +type pollResults struct { + EvaluatorResult *eval_api.EvaluatorResult +} + +// pollAndFinalizeJobs polls pending dataset and evaluator generation jobs in +// parallel, saves artifacts when an azd project exists, and updates state and +// evalCfg. Jobs whose status is already terminal are skipped (safe for resume). +// extraEvals are prepended to the generated evaluator list on completion; +// pass nil for fresh inits without --evaluator flags. +func pollAndFinalizeJobs( + ctx context.Context, + resolved *evalResolvedContext, + evalCfg *evalConfig, + state *opt_eval.EvalState, + extraEvals opt_eval.EvaluatorList, +) (*pollResults, error) { + results := &pollResults{} + // Each goroutine writes to distinct fields of evalCfg and state, so no + // mutex is needed for those. Only the error variables are shared across + // both goroutines and guarded by wg.Wait() (written before Wait, read after). + var ( + datasetPollErr error + evalPollErr error + wg sync.WaitGroup + ) + + hasDataset := state.DatasetGenOpID != "" + hasEval := state.EvalGenOpID != "" + needPollDataset := hasDataset && !eval_api.ParseJobStatus(state.DatasetGenStatus).IsTerminal() + needPollEval := hasEval && !eval_api.ParseJobStatus(state.EvalGenStatus).IsTerminal() + + // Build progress display labels (only for jobs that need polling). + var labels []string + if needPollDataset { + labels = append(labels, "Dataset generation") + } + if needPollEval { + labels = append(labels, "Evaluator generation") + } + progress := newEvalProgress(labels...) + progress.Start() + + if hasDataset { + wg.Go(func() { + var completed *eval_api.GenerationJob + if needPollDataset { + var err error + completed, err = pollEvalOperationWithSpinner( + ctx, "Dataset generation", state.DatasetGenOpID, + resolved.evalClient.GetDataGenerationJob, DataGenerationAPIVersion, + progress, + ) + if err != nil { + datasetPollErr = fmt.Errorf("dataset generation job %s: %w", state.DatasetGenOpID, err) + return + } + } else { + // Job was already terminal at submission — fetch it directly. + var err error + completed, err = resolved.evalClient.GetDataGenerationJob( + ctx, state.DatasetGenOpID, DataGenerationAPIVersion, + ) + if err != nil { + datasetPollErr = err + return + } + if eval_api.ParseJobStatus(completed.NormalizedStatus()).IsFailed() { + errMsg := fmt.Sprintf("dataset generation job %s failed", state.DatasetGenOpID) + if completed.Error != nil && completed.Error.Message != "" { + errMsg += ": " + completed.Error.Message + } + datasetPollErr = fmt.Errorf("%s", errMsg) + return + } + } + + state.DatasetGenStatus = completed.NormalizedStatus() + dsRef := datasetFromJob(completed) + if dsRef == nil { + return + } + evalCfg.DatasetReference = dsRef + + if resolved.hasProject { + localURI, err := eval_api.DownloadDatasetArtifact( + ctx, resolved.datasetClient, resolved.agentProject, dsRef, DefaultAgentAPIVersion, + ) + if err != nil { + log.Printf("warning: downloading dataset artifact for %q: %v", dsRef.Name, err) + } + if localURI != "" { + dsRef.LocalURI = localURI + } + } + }) + } + + if hasEval { + wg.Go(func() { + var completed *eval_api.GenerationJob + if needPollEval { + var err error + completed, err = pollEvalOperationWithSpinner( + ctx, "Evaluator generation", state.EvalGenOpID, + resolved.evalClient.GetEvaluatorGenerationJob, DefaultAgentAPIVersion, + progress, + ) + if err != nil { + evalPollErr = fmt.Errorf("evaluator generation job %s: %w", state.EvalGenOpID, err) + return + } + } else { + // Job was already terminal at submission — fetch it directly. + var err error + completed, err = resolved.evalClient.GetEvaluatorGenerationJob( + ctx, state.EvalGenOpID, DefaultAgentAPIVersion, + ) + if err != nil { + evalPollErr = err + return + } + if eval_api.ParseJobStatus(completed.NormalizedStatus()).IsFailed() { + errMsg := fmt.Sprintf("evaluator generation job %s failed", state.EvalGenOpID) + if completed.Error != nil && completed.Error.Message != "" { + errMsg += ": " + completed.Error.Message + } + evalPollErr = fmt.Errorf("%s", errMsg) + return + } + } + + // Evaluator goroutine owns: state.EvalGenStatus, evalCfg.Evaluators. + evalName, evalVersion := evaluatorFromJob(completed) + state.EvalGenStatus = completed.NormalizedStatus() + evalRef := opt_eval.EvaluatorRef{ + Name: evalName, + Version: evalVersion, + LocalURI: eval_api.EvaluatorLocalURI(evalName), + } + evalCfg.Evaluators = append(extraEvals, evalRef) + + results.EvaluatorResult = eval_api.ParseEvaluatorResult(completed.Result) + + if resolved.hasProject { + if err := eval_api.SaveEvaluatorResult(resolved.agentProject, evalName, completed.Result); err != nil { + log.Printf("warning: saving evaluator result for %q: %v", evalName, err) + } + } + }) + } + + wg.Wait() + progress.Stop() + + // If either job timed out, return a timeout error so the caller can + // persist the YAML and operation IDs for later resume. + dsTimeout := isPollerTimeout(datasetPollErr) + evalTimeout := isPollerTimeout(evalPollErr) + if dsTimeout || evalTimeout { + return results, &initTimeoutError{ + datasetOpID: state.DatasetGenOpID, + evaluatorOpID: state.EvalGenOpID, + datasetTimedOut: dsTimeout, + evaluatorTimedOut: evalTimeout, + } + } + + if datasetPollErr != nil && evalPollErr != nil { + return results, fmt.Errorf("%w\n%w", datasetPollErr, evalPollErr) + } + if datasetPollErr != nil { + return results, datasetPollErr + } + return results, evalPollErr +} + +// isPollerTimeout returns true when the error is a *eval_api.PollerTimeoutError. +func isPollerTimeout(err error) bool { + _, ok := errors.AsType[*eval_api.PollerTimeoutError](err) + return ok +} + +// initTimeoutError is returned by pollAndFinalizeJobs when one or both +// generation jobs exceed the polling timeout. The caller should persist state +// and YAML so the user can resume later. +type initTimeoutError struct { + datasetOpID string + evaluatorOpID string + datasetTimedOut bool + evaluatorTimedOut bool +} + +func (e *initTimeoutError) Error() string { + return "generation jobs did not complete within the polling timeout" +} + +func writePendingEvalInit( + ctx context.Context, + resolved *evalResolvedContext, + configPath string, + evalCfg *evalConfig, + state *opt_eval.EvalState, +) error { + if err := opt_eval.SaveEvalState(ctx, resolved.azdClient, resolved.envName, state); err != nil { + return err + } + if err := eval_api.WriteEvalConfig(configPath, evalCfg); err != nil { + return err + } + fmt.Println(color.YellowString("Eval init submitted (async)")) + if state.DatasetGenOpID != "" { + fmt.Printf(" dataset generation: %s (%s)\n", state.DatasetGenOpID, state.DatasetGenStatus) + } + if state.EvalGenOpID != "" { + fmt.Printf(" evaluator generation: %s (%s)\n", state.EvalGenOpID, state.EvalGenStatus) + } + fmt.Printf("\n Config written to: %s\n", configPath) + fmt.Println("\n When ready, run:") + fmt.Println(" azd ai agent eval run") + return nil +} + +// writeTimedOutEvalInit persists state and YAML when generation jobs exceed +// the polling timeout, allowing the user to resume later. +func writeTimedOutEvalInit( + ctx context.Context, + resolved *evalResolvedContext, + configPath string, + evalCfg *evalConfig, + state *opt_eval.EvalState, +) error { + state.InitStatus = opt_eval.InitStatusPending + if err := opt_eval.SaveEvalState(ctx, resolved.azdClient, resolved.envName, state); err != nil { + return err + } + if err := eval_api.WriteEvalConfig(configPath, evalCfg); err != nil { + return err + } + fmt.Println(color.YellowString("\nGeneration jobs timed out but are still running on the server.")) + if state.DatasetGenOpID != "" { + fmt.Printf(" dataset generation: %s\n", state.DatasetGenOpID) + } + if state.EvalGenOpID != "" { + fmt.Printf(" evaluator generation: %s\n", state.EvalGenOpID) + } + fmt.Printf("\n Config written to: %s\n", configPath) + fmt.Printf(" State saved to: azd environment %q\n", resolved.envName) + fmt.Println("\n To resume polling, run:") + fmt.Println(" azd ai agent eval run") + fmt.Println("\n To start fresh and clear timed-out state, run:") + fmt.Println(" azd ai agent eval init --reset-defaults") + return nil +} + +// tryLoadExistingEvalConfig attempts to load an eval config from the given path. +// Returns (config, true) if the file exists and parses successfully, or (nil, false) otherwise. +func tryLoadExistingEvalConfig(configPath string) (*evalConfig, bool) { + cfg, err := eval_api.LoadEvalConfig(configPath) + if err != nil { + return nil, false + } + return cfg, true +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_prompts.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_prompts.go new file mode 100644 index 00000000000..1e472b7f9d8 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_prompts.go @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_init_prompts.go implements interactive prompts for the eval init +// command, including eval suite name, instruction source, trace inclusion, +// eval model selection, and regeneration choices for existing configs. + +package cmd + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// promptEvalInitOptions runs interactive prompts for eval init options that +// were not provided via flags: name, instruction, trace days, eval model, +// and max samples. +func promptEvalInitOptions(ctx context.Context, resolved *evalResolvedContext, flags *evalInitFlags, noPrompt bool) error { + azdClient := resolved.azdClient + if noPrompt { + return nil + } + + if flags.name == "" { + defaultName := defaultEvalName + if resolved.agentName != "" { + defaultName = resolved.agentName + } + resp, err := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Eval suite name", + DefaultValue: defaultName, + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for eval suite name: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + flags.name = value + } + } + + if flags.configFile != "" { + // Config detected — show resolved values and let the user confirm or override. + if err := promptConfigConfirmation(ctx, azdClient, resolved, flags); err != nil { + return err + } + } else if flags.instruction == "" && flags.instructionFile == "" { + // Let the user choose between inline text or loading from a file. + inputChoices := []*azdext.SelectChoice{ + {Label: "Type inline", Value: "inline"}, + {Label: "Load from file", Value: "file"}, + } + defaultIdx := int32(0) + selResp, err := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: "How would you like to provide the agent instruction?", + Choices: inputChoices, + SelectedIndex: &defaultIdx, + }, + }) + if err != nil { + return fmt.Errorf("prompting for instruction input method: %w", err) + } + + if inputChoices[int(*selResp.Value)].Value == "file" { + // Prompt for the file path. + pathResp, err := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Path to agent instruction file", + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for instruction file path: %w", err) + } + filePath := strings.TrimSpace(pathResp.Value) + // Resolve relative paths against the agent project directory. + if !filepath.IsAbs(filePath) && resolved.projectRoot != "" { + filePath = filepath.Join(resolved.projectRoot, filePath) + } + if _, err := os.Stat(filePath); err != nil { + return fmt.Errorf("instruction file %q is not accessible: %w", filePath, err) + } + flags.instructionFile = filePath + } else { + // Inline text input. + resp, err := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Describe what this agent does and what scenarios to test", + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for instruction: %w", err) + } + flags.instruction = strings.TrimSpace(resp.Value) + } + } + + // Ask whether to include traces for evaluator generation, unless already set via flags. + if flags.traceDays == 0 { + confirmResp, err := azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: "Include agent traces for evaluator generation?", + DefaultValue: new(bool), // default false + }, + }) + if err != nil { + return fmt.Errorf("prompting for trace inclusion: %w", err) + } + if confirmResp.GetValue() { + rangeChoices := []*azdext.SelectChoice{ + {Label: "Last Day", Value: "1"}, + {Label: "Last 7 Days", Value: "7"}, + {Label: "Last 30 Days", Value: "30"}, + {Label: "Last 90 Days", Value: "90"}, + } + defaultRangeIdx := int32(1) // 7 days + rangeResp, err := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: "Select trace time range", + Choices: rangeChoices, + SelectedIndex: &defaultRangeIdx, + }, + }) + if err != nil { + return fmt.Errorf("prompting for trace time range: %w", err) + } + days, _ := strconv.Atoi(rangeChoices[int(*rangeResp.Value)].Value) + flags.traceDays = days + } + } + + if !flags.evalModelSet { + // Read the deployed model name from the azd environment to use as default. + var deployedModel string + if resolved.envName != "" { + if v, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: resolved.envName, + Key: "AZURE_AI_MODEL_DEPLOYMENT_NAME", + }); err == nil && v.Value != "" { + deployedModel = v.Value + } + } + + choices := buildModelChoices(deployedModel) + defaultIndex := int32(0) + resp, err := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: "Select the model for evaluation and generation", + Choices: choices, + SelectedIndex: &defaultIndex, + }, + }) + if err != nil { + return fmt.Errorf("prompting for evaluation model: %w", err) + } + selected := choices[int(*resp.Value)].Value + + // User chose to pick from another deployment in the project. + if selected == selectOtherDeployment { + selected, err = promptProjectDeployment(ctx, resolved) + if err != nil { + return err + } + } + flags.evalModel = selected + } + + if !flags.maxSamplesSet { + resp, err := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Max samples (between 15 and 1000)", + DefaultValue: strconv.Itoa(defaultEvalSamples), + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for max samples: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + parsed, err := strconv.Atoi(value) + if err != nil || parsed < 15 || parsed > 1000 { + return fmt.Errorf("--max-samples must be between 15 and 1000") + } + flags.maxSamples = parsed + } + } + + return nil +} + +// selectOtherDeployment is the sentinel value for the "Select another deployment" +// choice in the model prompt. +const selectOtherDeployment = "__select_other_deployment__" + +// buildModelChoices builds the initial model choices for the generation model +// prompt. When deployedModel is non-empty it appears first as the default. +// A "Select another deployment" option is always appended so the user can +// browse all deployments in the Foundry project. +func buildModelChoices(deployedModel string) []*azdext.SelectChoice { + var choices []*azdext.SelectChoice + if deployedModel != "" { + choices = append(choices, &azdext.SelectChoice{ + Label: deployedModel + " (deployed)", + Value: deployedModel, + }) + } + choices = append(choices, &azdext.SelectChoice{ + Label: "Select another deployment", + Value: selectOtherDeployment, + }) + return choices +} + +// promptProjectDeployment fetches model deployments from the Foundry project +// and prompts the user to select one. +func promptProjectDeployment(ctx context.Context, resolved *evalResolvedContext) (string, error) { + var deployments []FoundryDeploymentInfo + if resolved.envName != "" { + if v, err := resolved.azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: resolved.envName, + Key: "AZURE_AI_PROJECT_ID", + }); err == nil && v.Value != "" { + if project, err := extractProjectDetails(v.Value); err == nil { + if cred, err := newAgentCredential(); err == nil { + deployments, _ = listProjectDeployments( + ctx, cred, + project.SubscriptionId, + project.ResourceGroupName, + project.AccountName, + ) + } + } + } + } + if len(deployments) == 0 { + return "", fmt.Errorf("no model deployments found in the Foundry project") + } + + choices := make([]*azdext.SelectChoice, len(deployments)) + for i, d := range deployments { + label := d.Name + if d.ModelName != "" { + label = fmt.Sprintf("%s (%s)", d.Name, d.ModelName) + } + choices[i] = &azdext.SelectChoice{Label: label, Value: d.Name} + } + + defaultIndex := int32(0) + resp, err := resolved.azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: "Select a model deployment", + Choices: choices, + SelectedIndex: &defaultIndex, + }, + }) + if err != nil { + return "", fmt.Errorf("prompting for model deployment: %w", err) + } + return choices[int(*resp.Value)].Value, nil +} + +// promptRegenerateChoices asks the user whether to regenerate the existing +// dataset and evaluator using individual yes/no confirmations. +func promptRegenerateChoices( + ctx context.Context, + resolved *evalResolvedContext, + existingCfg *evalConfig, + flags *evalInitFlags, +) error { + prompt := resolved.azdClient.Prompt() + + // Ask about dataset. + datasetLabel := existingCfg.DatasetFile + if datasetLabel == "" && existingCfg.DatasetReference != nil { + datasetLabel = existingCfg.DatasetReference.Name + } + if datasetLabel != "" { + resp, err := prompt.Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: fmt.Sprintf("Existing dataset: %s. Do you want to regenerate?", datasetLabel), + DefaultValue: new(false), + }, + }) + if err != nil { + return fmt.Errorf("prompting for dataset regeneration: %w", err) + } + if resp.Value != nil && *resp.Value { + flags.regenerateDataset = true + } + } + + // Ask about evaluator. + if len(existingCfg.Evaluators) > 0 { + evalLabel := strings.Join(existingCfg.Evaluators.Names(), ", ") + resp, err := prompt.Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: fmt.Sprintf("Existing evaluator: %s. Do you want to regenerate?", evalLabel), + DefaultValue: new(false), + }, + }) + if err != nil { + return fmt.Errorf("prompting for evaluator regeneration: %w", err) + } + if resp.Value != nil && *resp.Value { + flags.regenerateEvaluator = true + } + } else { + // No evaluators exist — generate one by default. + flags.regenerateEvaluator = true + } + + return nil +} + +// promptConfigConfirmation shows the resolved instruction file from +// metadata.yaml and lets the user confirm or override it. +func promptConfigConfirmation( + ctx context.Context, + azdClient *azdext.AzdClient, + resolved *evalResolvedContext, + flags *evalInitFlags, +) error { + prompt := azdClient.Prompt() + projectDir := resolved.agentProject + + // Instruction file. + instrDefault := relativeDisplay(flags.instructionFile, projectDir) + resp, err := prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Instruction file", + DefaultValue: instrDefault, + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for instruction file: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + if !filepath.IsAbs(value) && projectDir != "" { + value = filepath.Join(projectDir, value) + } + if _, err := os.Stat(value); err != nil { + return fmt.Errorf("instruction file %q is not accessible: %w", value, err) + } + flags.instructionFile = value + flags.instruction = "" // file takes precedence + } + + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_test.go new file mode 100644 index 00000000000..f923023617e --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_init_test.go @@ -0,0 +1,693 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// newEvalInitCommand — command shape +// --------------------------------------------------------------------------- + +func TestNewEvalInitCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newEvalInitCommand(&azdext.ExtensionContext{}) + + expectedFlags := []struct { + name string + defaultValue string + }{ + {"name", ""}, + {"no-wait", "false"}, + {"agent", ""}, + {"project-endpoint", ""}, + {"gen-instruction", ""}, + {"gen-instruction-file", ""}, + {"eval-model", ""}, + {"dataset", ""}, + {"max-samples", "15"}, + {"out-file", defaultEvalConfigName}, + {"reset-defaults", "false"}, + } + + for _, ef := range expectedFlags { + t.Run(ef.name, func(t *testing.T) { + f := cmd.Flags().Lookup(ef.name) + require.NotNil(t, f, "flag %q should exist", ef.name) + assert.Equal(t, ef.defaultValue, f.DefValue) + }) + } +} + +func TestNewEvalInitCommand_NoArgs(t *testing.T) { + t.Parallel() + cmd := newEvalInitCommand(&azdext.ExtensionContext{}) + assert.NoError(t, cmd.Args(cmd, nil)) + assert.Error(t, cmd.Args(cmd, []string{"extra"})) +} + +func TestNewEvalInitCommand_NoShortOutFile(t *testing.T) { + t.Parallel() + cmd := newEvalInitCommand(&azdext.ExtensionContext{}) + f := cmd.Flags().ShorthandLookup("o") + assert.Nil(t, f, "flag -o shorthand must not exist (conflicts with azd global --output)") +} + +// --------------------------------------------------------------------------- +// --agent-instruction / --agent-instruction-file mutual exclusion +// --------------------------------------------------------------------------- + +func TestRunEvalInit_MutualExclusion(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{ + instruction: "inline text", + instructionFile: "some-file.txt", + } + err := runEvalInit(t.Context(), flags, true) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot use both --gen-instruction and --gen-instruction-file") +} + +func TestRunEvalInit_InstructionFile(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + instrFile := filepath.Join(tmpDir, "instruction.md") + require.NoError(t, os.WriteFile(instrFile, []byte(" Test booking agent \n"), 0600)) + + flags := &evalInitFlags{ + instructionFile: instrFile, + evalModel: "test-model", + maxSamples: 10, + } + // runEvalInit will fail later (no azd client), but file validation should pass. + _ = runEvalInit(t.Context(), flags, true) + // File path remains on the flag — content is NOT inlined. + assert.Equal(t, instrFile, flags.instructionFile) + assert.Empty(t, flags.instruction) +} + +func TestRunEvalInit_InstructionFileMissing(t *testing.T) { + t.Parallel() + // Use filepath.Join with TempDir to get a proper absolute path that doesn't exist. + missingFile := filepath.Join(t.TempDir(), "nonexistent", "instruction.txt") + flags := &evalInitFlags{ + instructionFile: missingFile, + projectEndpoint: "https://example.ai.azure.com/", + } + err := runEvalInit(t.Context(), flags, true) + require.Error(t, err) + assert.Contains(t, err.Error(), "not accessible") +} + +// --------------------------------------------------------------------------- +// newEvalConfig +// --------------------------------------------------------------------------- + +func TestNewEvalConfig(t *testing.T) { + t.Parallel() + + t.Run("uses default name", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{ + instruction: "Test the booking agent", + evalModel: "gpt-4.1", + maxSamples: 50, + } + resolved := &evalResolvedContext{ + agentName: "booking-agent", + agentKind: agent_yaml.AgentKindHosted, + version: "v2", + } + + cfg := newEvalConfig(flags, resolved) + + assert.Equal(t, defaultEvalName, cfg.Name) + assert.Equal(t, "booking-agent", cfg.Agent.Name) + assert.Equal(t, agent_yaml.AgentKindHosted, cfg.Agent.Kind) + assert.Equal(t, "v2", cfg.Agent.Version) + assert.Equal(t, "gpt-4.1", cfg.Options.EvalModel) + assert.Equal(t, "Test the booking agent", cfg.Agent.Instruction.Value) + assert.Equal(t, 50, cfg.MaxSamples) + }) + + t.Run("uses custom name from flag", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{ + name: "my-suite", + maxSamples: 10, + } + resolved := &evalResolvedContext{agentName: "a"} + cfg := newEvalConfig(flags, resolved) + assert.Equal(t, "my-suite", cfg.Name) + }) + + t.Run("stores instruction_file when file provided", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{ + instructionFile: "./prompts/system.md", + evalModel: "gpt-4o", + maxSamples: 20, + } + resolved := &evalResolvedContext{ + agentName: "my-agent", + agentKind: agent_yaml.AgentKindHosted, + version: "v1", + } + + cfg := newEvalConfig(flags, resolved) + + assert.Empty(t, cfg.Agent.Instruction.Value) + assert.Equal(t, "./prompts/system.md", cfg.Agent.Instruction.File) + }) +} + +// --------------------------------------------------------------------------- +// datasetFromJob +// --------------------------------------------------------------------------- + +func TestDatasetFromJob(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + job *eval_api.GenerationJob + expectedName string + expectedVersion string + expectedNil bool + }{ + { + "result fields", + &eval_api.GenerationJob{ + Result: json.RawMessage(`{"name":"ds-1","version":"v2"}`), + }, + "ds-1", "v2", false, + }, + { + "result name defaults version to latest", + &eval_api.GenerationJob{ + Result: json.RawMessage(`{"outputs":[{"name":"ds-2"}]}`), + }, + "ds-2", "latest", false, + }, + { + "nested outputs format", + &eval_api.GenerationJob{ + Result: json.RawMessage(`{"outputs":[{"name":"ds-3","version":"v3"}]}`), + }, + "ds-3", "v3", false, + }, + { + "empty result returns nil", + &eval_api.GenerationJob{}, + "", "", true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ref := datasetFromJob(tt.job) + if tt.expectedNil { + assert.Nil(t, ref) + } else { + require.NotNil(t, ref) + assert.Equal(t, tt.expectedName, ref.Name) + assert.Equal(t, tt.expectedVersion, ref.Version) + } + }) + } +} + +// --------------------------------------------------------------------------- +// parseDatasetURI +// --------------------------------------------------------------------------- + +func TestIsDatasetName(t *testing.T) { + t.Parallel() + + t.Run("simple name is a dataset name", func(t *testing.T) { + t.Parallel() + assert.True(t, eval_api.IsDatasetName("eval-data-2026-04-16")) + }) + + t.Run("name with dots but no data extension", func(t *testing.T) { + t.Parallel() + assert.True(t, eval_api.IsDatasetName("my-dataset.v2")) + }) + + t.Run("jsonl file is not a name", func(t *testing.T) { + t.Parallel() + assert.False(t, eval_api.IsDatasetName("golden.jsonl")) + }) + + t.Run("json file is not a name", func(t *testing.T) { + t.Parallel() + assert.False(t, eval_api.IsDatasetName("data.json")) + }) + + t.Run("csv file is not a name", func(t *testing.T) { + t.Parallel() + assert.False(t, eval_api.IsDatasetName("results.csv")) + }) + + t.Run("path with separator is not a name", func(t *testing.T) { + t.Parallel() + assert.False(t, eval_api.IsDatasetName("./tests/golden.jsonl")) + }) + + t.Run("empty string is not a name", func(t *testing.T) { + t.Parallel() + assert.False(t, eval_api.IsDatasetName("")) + }) +} + +// --------------------------------------------------------------------------- +// buildModelChoices +// --------------------------------------------------------------------------- + +func TestBuildModelChoices(t *testing.T) { + t.Parallel() + + t.Run("no deployed model has select-other only", func(t *testing.T) { + t.Parallel() + choices := buildModelChoices("") + require.Len(t, choices, 1) + assert.Equal(t, selectOtherDeployment, choices[0].Value) + assert.Equal(t, "Select another deployment", choices[0].Label) + }) + + t.Run("deployed model first then select-other", func(t *testing.T) { + t.Parallel() + choices := buildModelChoices("my-deployment") + require.Len(t, choices, 2) + assert.Equal(t, "my-deployment", choices[0].Value) + assert.Contains(t, choices[0].Label, "(deployed)") + assert.Equal(t, selectOtherDeployment, choices[1].Value) + }) +} + +// --------------------------------------------------------------------------- +// evaluatorFromJob +// --------------------------------------------------------------------------- + +func TestEvaluatorFromJob(t *testing.T) { + t.Parallel() + + t.Run("extracts name and version from result", func(t *testing.T) { + t.Parallel() + job := &eval_api.GenerationJob{ + Result: json.RawMessage(`{"name":"quality-eval","version":"v2"}`), + } + name, version := evaluatorFromJob(job) + assert.Equal(t, "quality-eval", name) + assert.Equal(t, "v2", version) + }) + + t.Run("defaults version to latest", func(t *testing.T) { + t.Parallel() + job := &eval_api.GenerationJob{ + Result: json.RawMessage(`{"name":"smoke-core","display_name":"smoke-core"}`), + } + name, version := evaluatorFromJob(job) + assert.Equal(t, "smoke-core", name) + assert.Equal(t, "latest", version) + }) + + t.Run("returns empty name when no result", func(t *testing.T) { + t.Parallel() + job := &eval_api.GenerationJob{} + name, version := evaluatorFromJob(job) + assert.Empty(t, name) + assert.Empty(t, version) + }) +} + +// --------------------------------------------------------------------------- +// eval_api.BuildGenerationSources +// --------------------------------------------------------------------------- + +func TestBuildGenerationSources(t *testing.T) { + t.Parallel() + + t.Run("hosted agent includes prompt and agent sources", func(t *testing.T) { + t.Parallel() + sources := eval_api.BuildGenerationSources( + string(agent_yaml.AgentKindHosted), "my-agent", "v2", + "Test customer service interactions", nil, + ) + require.Len(t, sources, 2) + + // First source: prompt + assert.Equal(t, "prompt", sources[0].Type) + assert.Equal(t, "Test customer service interactions", sources[0].Prompt) + + // Second source: agent + assert.Equal(t, "agent", sources[1].Type) + assert.Equal(t, "my-agent", sources[1].AgentName) + assert.Equal(t, "v2", sources[1].AgentVersion) + assert.Empty(t, sources[1].Prompt) + }) + + t.Run("hosted agent without instruction omits prompt source", func(t *testing.T) { + t.Parallel() + sources := eval_api.BuildGenerationSources( + string(agent_yaml.AgentKindHosted), "my-agent", "v1", "", nil, + ) + require.Len(t, sources, 1) + assert.Equal(t, "agent", sources[0].Type) + }) +} + +// --------------------------------------------------------------------------- +// evaluatorsFromFlags +// --------------------------------------------------------------------------- + +func TestEvaluatorsFromFlags(t *testing.T) { + t.Parallel() + + t.Run("passes through strings", func(t *testing.T) { + t.Parallel() + result := evaluatorsFromFlags([]string{"builtin.task_adherence", "my-custom"}) + require.Len(t, result, 2) + assert.Equal(t, "builtin.task_adherence", result[0].Name) + assert.Equal(t, "my-custom", result[1].Name) + }) + + t.Run("nil returns nil", func(t *testing.T) { + t.Parallel() + result := evaluatorsFromFlags(nil) + assert.Empty(t, result) + }) +} + +// --------------------------------------------------------------------------- +// buildOpenAIEvalRequest +// --------------------------------------------------------------------------- + +func TestBuildOpenAIEvalRequest(t *testing.T) { + t.Parallel() + + cfg := &evalConfig{ + Config: opt_eval.Config{ + Name: "smoke-core", + Agent: evalAgentRef{ + Name: "agent-1", + Version: "v1", + }, + DatasetReference: &evalDatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "builtin.quality"}}, + }, + Options: &opt_eval.Options{EvalModel: "gpt-4o"}, + } + + req := buildOpenAIEvalRequest(cfg) + + assert.Equal(t, "smoke-core", req.Name) + assert.Equal(t, "agent-1", req.Metadata["azd_agent"]) + assert.Equal(t, "v1", req.Metadata["azd_agent_version"]) + require.NotNil(t, req.DataSourceConfig) + assert.Equal(t, "custom", req.DataSourceConfig.Type) + require.Len(t, req.TestingCriteria, 1) + assert.Equal(t, "azure_ai_evaluator", req.TestingCriteria[0].Type) + assert.Equal(t, "builtin.quality", req.TestingCriteria[0].EvaluatorName) + assert.Equal(t, "gpt-4o", req.TestingCriteria[0].InitializationParameters["model"]) + assert.Equal(t, "{{item.query}}", req.TestingCriteria[0].DataMapping["query"]) + assert.Equal(t, "{{sample.output_items}}", req.TestingCriteria[0].DataMapping["response"]) +} + +func TestBuildOpenAIEvalRequest_WithDatasetFile(t *testing.T) { + t.Parallel() + + cfg := &evalConfig{ + Config: opt_eval.Config{ + Name: "test-eval", + Agent: evalAgentRef{Name: "agent-1"}, + DatasetFile: "tasks.jsonl", + }, + } + + req := buildOpenAIEvalRequest(cfg) + require.NotNil(t, req.DataSourceConfig) + assert.Equal(t, "custom", req.DataSourceConfig.Type) + assert.Empty(t, req.TestingCriteria) +} + +// --------------------------------------------------------------------------- +// resolveLocalDatasetFile +// --------------------------------------------------------------------------- + +func TestResolveLocalDatasetFile_Absolute(t *testing.T) { + t.Parallel() + dir := t.TempDir() + f := filepath.Join(dir, "tasks.jsonl") + require.NoError(t, os.WriteFile(f, []byte(`{"query":"hi"}`+"\n"), 0600)) + + result, err := resolveLocalDatasetFile(f, "/other") + require.NoError(t, err) + assert.Equal(t, f, result) +} + +func TestResolveLocalDatasetFile_Relative(t *testing.T) { + t.Parallel() + dir := t.TempDir() + f := filepath.Join(dir, "data.jsonl") + require.NoError(t, os.WriteFile(f, []byte(`{"query":"hi"}`+"\n"), 0600)) + + result, err := resolveLocalDatasetFile("data.jsonl", dir) + require.NoError(t, err) + assert.Equal(t, f, result) +} + +func TestResolveLocalDatasetFile_NotFound(t *testing.T) { + t.Parallel() + _, err := resolveLocalDatasetFile("missing.jsonl", t.TempDir()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not accessible") +} + +// --------------------------------------------------------------------------- +// tryLoadExistingEvalConfig +// --------------------------------------------------------------------------- + +func TestTryLoadExistingEvalConfig_Found(t *testing.T) { + t.Parallel() + dir := t.TempDir() + cfgPath := filepath.Join(dir, "eval.yaml") + cfg := &evalConfig{ + Config: opt_eval.Config{ + Name: "smoke-core", + Agent: evalAgentRef{ + Name: "my-agent", + }, + DatasetFile: "data.jsonl", + Evaluators: opt_eval.EvaluatorList{{Name: "quality"}}, + }, + } + require.NoError(t, eval_api.WriteEvalConfig(cfgPath, cfg)) + + loaded, ok := tryLoadExistingEvalConfig(cfgPath) + require.True(t, ok) + assert.Equal(t, "smoke-core", loaded.Name) + assert.Equal(t, "my-agent", loaded.Agent.Name) + assert.Equal(t, opt_eval.EvaluatorList{{Name: "quality"}}, loaded.Evaluators) +} + +func TestTryLoadExistingEvalConfig_NotFound(t *testing.T) { + t.Parallel() + cfg, ok := tryLoadExistingEvalConfig(filepath.Join(t.TempDir(), "missing.yaml")) + assert.False(t, ok) + assert.Nil(t, cfg) +} + +func TestTryLoadExistingEvalConfig_InvalidYAML(t *testing.T) { + t.Parallel() + dir := t.TempDir() + cfgPath := filepath.Join(dir, "eval.yaml") + require.NoError(t, os.WriteFile(cfgPath, []byte(":\ninvalid: [yaml"), 0600)) + + cfg, ok := tryLoadExistingEvalConfig(cfgPath) + assert.False(t, ok) + assert.Nil(t, cfg) +} + +// --------------------------------------------------------------------------- +// eval_api.SplitEvaluators / eval_api.IsBuiltinEvaluator +// --------------------------------------------------------------------------- + +func TestIsBuiltinEvaluator(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input string + expected bool + }{ + {"builtin prefix", "builtin.task_adherence", true}, + {"builtin prefix dot only", "builtin.", true}, + {"custom evaluator", "my-quality", false}, + {"empty string", "", false}, + {"similar prefix", "builtins.quality", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, eval_api.IsBuiltinEvaluator(tt.input)) + }) + } +} + +func TestSplitEvaluators(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input opt_eval.EvaluatorList + expectedGenerated opt_eval.EvaluatorList + expectedBuiltin opt_eval.EvaluatorList + }{ + { + "mixed list", + opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}, {Name: "my-quality"}, {Name: "builtin.safety"}}, + opt_eval.EvaluatorList{{Name: "my-quality"}}, + opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}, {Name: "builtin.safety"}}, + }, + { + "all builtin", + opt_eval.EvaluatorList{{Name: "builtin.quality"}, {Name: "builtin.safety"}}, + nil, + opt_eval.EvaluatorList{{Name: "builtin.quality"}, {Name: "builtin.safety"}}, + }, + { + "all generated", + opt_eval.EvaluatorList{{Name: "smoke-core"}, {Name: "custom-1"}}, + opt_eval.EvaluatorList{{Name: "smoke-core"}, {Name: "custom-1"}}, + nil, + }, + { + "empty list", + nil, + nil, + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generated, builtin := eval_api.SplitEvaluators(tt.input) + assert.Equal(t, tt.expectedGenerated, generated) + assert.Equal(t, tt.expectedBuiltin, builtin) + }) + } +} + +// --------------------------------------------------------------------------- +// resolveEvalName — name resolution +// --------------------------------------------------------------------------- + +func TestResolveEvalName(t *testing.T) { + t.Parallel() + t.Run("returns flag name when set", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{name: "my-eval"} + assert.Equal(t, "my-eval", resolveEvalName(flags)) + }) + + t.Run("returns default when flag is empty", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{} + assert.Equal(t, defaultEvalName, resolveEvalName(flags)) + }) +} + +// --------------------------------------------------------------------------- +// resolvedInstruction — instruction from flags +// --------------------------------------------------------------------------- + +func TestResolvedInstruction(t *testing.T) { + t.Parallel() + t.Run("returns inline instruction", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{instruction: "Be helpful."} + assert.Equal(t, "Be helpful.", resolvedInstruction(flags)) + }) + + t.Run("reads from instruction file", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + filePath := filepath.Join(dir, "prompt.md") + require.NoError(t, os.WriteFile(filePath, []byte("File instruction."), 0600)) + + flags := &evalInitFlags{instructionFile: filePath} + assert.Equal(t, "File instruction.", resolvedInstruction(flags)) + }) + + t.Run("falls back to inline when file missing", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{ + instructionFile: "/nonexistent/path.md", + instruction: "fallback", + } + assert.Equal(t, "fallback", resolvedInstruction(flags)) + }) + + t.Run("returns empty when nothing set", func(t *testing.T) { + t.Parallel() + flags := &evalInitFlags{} + assert.Empty(t, resolvedInstruction(flags)) + }) +} + +// --------------------------------------------------------------------------- +// isPollerTimeout — timeout error detection +// --------------------------------------------------------------------------- + +func TestIsPollerTimeout(t *testing.T) { + t.Parallel() + t.Run("true for PollerTimeoutError", func(t *testing.T) { + t.Parallel() + err := &eval_api.PollerTimeoutError{} + assert.True(t, isPollerTimeout(err)) + }) + + t.Run("true for wrapped PollerTimeoutError", func(t *testing.T) { + t.Parallel() + inner := &eval_api.PollerTimeoutError{} + wrapped := fmt.Errorf("context: %w", inner) + assert.True(t, isPollerTimeout(wrapped)) + }) + + t.Run("false for other errors", func(t *testing.T) { + t.Parallel() + assert.False(t, isPollerTimeout(errors.New("some error"))) + }) + + t.Run("false for nil", func(t *testing.T) { + t.Parallel() + assert.False(t, isPollerTimeout(nil)) + }) +} + +// --------------------------------------------------------------------------- +// initTimeoutError — error message +// --------------------------------------------------------------------------- + +func TestInitTimeoutError(t *testing.T) { + t.Parallel() + err := &initTimeoutError{ + datasetOpID: "ds-123", + evaluatorOpID: "ev-456", + datasetTimedOut: true, + evaluatorTimedOut: false, + } + assert.Contains(t, err.Error(), "polling timeout") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list.go new file mode 100644 index 00000000000..51dc7fe9a54 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list.go @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_list.go implements the "eval list" command, which lists recent +// evaluations for the current Foundry project with run counts and status. + +package cmd + +import ( + "context" + "fmt" + "os" + "sync" + "text/tabwriter" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// evalListFlags holds CLI flags for the eval list command. +type evalListFlags struct { + limit int // maximum number of evals to return +} + +func newEvalListCommand() *cobra.Command { + flags := &evalListFlags{limit: 10} + cmd := &cobra.Command{ + Use: "list", + Short: "List evaluations for the current project.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + logCleanup := setupDebugLogging(cmd.Flags()) + defer logCleanup() + return runEvalList(ctx, flags) + }, + } + cmd.Flags().IntVar(&flags.limit, "limit", 10, "Maximum number of evals to return") + return cmd +} + +// evalRunSummary holds the fetched run info for a single eval. +type evalRunSummary struct { + runCount int + lastRunStatus string +} + +func runEvalList(ctx context.Context, flags *evalListFlags) error { + resolved, err := resolveEvalContext(ctx, evalContextOptions{}) + if err != nil { + return err + } + defer resolved.azdClient.Close() + + // Load the active eval ID from the azd environment. + var activeEvalID string + if resolved.envName != "" { + state := opt_eval.LoadEvalState(ctx, resolved.azdClient, resolved.envName) + activeEvalID = state.EvalID + } + + resp, err := resolved.evalClient.ListOpenAIEvals(ctx, flags.limit, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to list evals: %w", err) + } + + items := resp.Data + + // Fetch run summaries in parallel for each eval, bounded by a semaphore + // to avoid overwhelming the service with concurrent requests. + const maxConcurrent = 5 + sem := make(chan struct{}, maxConcurrent) + summaries := make([]evalRunSummary, len(items)) + var wg sync.WaitGroup + for i, item := range items { + wg.Add(1) + go func(idx int, evalID string) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + runs, err := resolved.evalClient.ListOpenAIEvalRuns(ctx, evalID, 10, DefaultAgentAPIVersion) + if err != nil || runs == nil { + return + } + summaries[idx].runCount = len(runs.Data) + if len(runs.Data) > 0 { + summaries[idx].lastRunStatus = runs.Data[0].Status + } + }(i, item.ResolvedID()) + } + wg.Wait() + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, " \tEval ID\tName\tStatus of last run\tRuns\tCreated by\tCreated on") + fmt.Fprintln(w, " \t-------\t----\t------------------\t----\t----------\t----------") + for i, item := range items { + marker := " " + if item.ResolvedID() == activeEvalID { + marker = "*" + } + name := item.Name + if name == "" { + name = item.ResolvedID() + } + status := padColorizedStatus(summaries[i].lastRunStatus) + createdBy := item.CreatedBy + createdOn := eval_api.FormatTimestamp(item.CreatedAt) + + fmt.Fprintf(w, "%s \t%s\t%s\t%s\t%d\t%s\t%s\n", + marker, + item.ResolvedID(), + name, + status, + summaries[i].runCount, + createdBy, + createdOn, + ) + } + if err := w.Flush(); err != nil { + return err + } + if activeEvalID != "" { + fmt.Printf("\n* = active eval in current environment\n") + } + fmt.Printf("(showing %d — use --limit to change)\n", len(items)) + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list_test.go new file mode 100644 index 00000000000..53ce6dfcb63 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_list_test.go @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// newEvalListCommand — command shape +// --------------------------------------------------------------------------- + +func TestNewEvalListCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newEvalListCommand() + + f := cmd.Flags().Lookup("limit") + require.NotNil(t, f) + assert.Equal(t, "10", f.DefValue) +} + +func TestNewEvalListCommand_NoArgs(t *testing.T) { + t.Parallel() + cmd := newEvalListCommand() + assert.NoError(t, cmd.Args(cmd, nil)) + assert.Error(t, cmd.Args(cmd, []string{"extra"})) +} + +func TestNewEvalListCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newEvalListCommand() + assert.Equal(t, "list", cmd.Use) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress.go new file mode 100644 index 00000000000..3a74a61981c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress.go @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_progress.go provides a concurrent-safe progress display with an +// animated spinner for long-running eval operations (generation polling, +// eval runs). Status transitions (running → done/failed/timed-out) are +// printed above the spinner line. + +package cmd + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/fatih/color" +) + +// evalProgress prints status lines for each job and keeps a single animated +// spinner line at the bottom to show that polling is still in progress. +type evalProgress struct { + mu sync.Mutex + starts map[string]time.Time + start time.Time + stop chan struct{} + done chan struct{} + spinning bool +} + +func newEvalProgress(_ ...string) *evalProgress { + return &evalProgress{ + starts: make(map[string]time.Time), + stop: make(chan struct{}), + done: make(chan struct{}), + } +} + +var spinFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + +// Start launches the background spinner ticker. +func (p *evalProgress) Start() { + p.mu.Lock() + p.start = time.Now() + p.spinning = true + p.mu.Unlock() + go func() { + defer close(p.done) + frameIdx := 0 + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-p.stop: + return + case <-ticker.C: + p.mu.Lock() + if p.spinning { + elapsed := time.Since(p.start).Truncate(time.Second) + spin := spinFrames[frameIdx%len(spinFrames)] + frameIdx++ + fmt.Fprintf(os.Stdout, "\r %s waiting · %s", spin, elapsed) + } + p.mu.Unlock() + } + } + }() +} + +// Stop halts the spinner and clears its line. +func (p *evalProgress) Stop() { + select { + case <-p.stop: + return + default: + close(p.stop) + } + <-p.done + p.mu.Lock() + if p.spinning { + fmt.Fprintf(os.Stdout, "\r%-60s\r", "") + p.spinning = false + } + p.mu.Unlock() +} + +// clearSpinnerLine clears the current spinner line so a status line can be +// printed cleanly. Must be called with p.mu held. +func (p *evalProgress) clearSpinnerLine() { + if p.spinning { + fmt.Fprintf(os.Stdout, "\r%-60s\r", "") + } +} + +func (p *evalProgress) setRunning(label string, detail string) { + p.mu.Lock() + defer p.mu.Unlock() + p.starts[label] = time.Now() + p.clearSpinnerLine() + if detail != "" { + fmt.Printf(" %s %s %s\n", color.BlueString("(\u2013) Running"), label, color.HiBlackString("(%s)", detail)) + } else { + fmt.Printf(" %s %s\n", color.BlueString("(\u2013) Running"), label) + } +} + +func (p *evalProgress) setDone(label string) { + p.mu.Lock() + defer p.mu.Unlock() + elapsed := durationText(time.Since(p.starts[label])) + p.clearSpinnerLine() + fmt.Printf(" %s %s (%s)\n", color.GreenString("(✓) Done"), label, elapsed) +} + +func (p *evalProgress) setFailed(label string) { + p.mu.Lock() + defer p.mu.Unlock() + elapsed := durationText(time.Since(p.starts[label])) + p.clearSpinnerLine() + fmt.Printf(" %s %s (%s)\n", color.RedString("(x) Failed"), label, elapsed) +} + +func (p *evalProgress) setTimedOut(label string) { + p.mu.Lock() + defer p.mu.Unlock() + elapsed := durationText(time.Since(p.starts[label])) + p.clearSpinnerLine() + fmt.Printf(" %s %s (%s)\n", color.YellowString("(!) Timed out"), label, elapsed) +} + +// durationText returns a human-friendly elapsed time string. +func durationText(d time.Duration) string { + s := int(d.Seconds()) + if s < 1 { + return "less than a second" + } + if s == 1 { + return "1 second" + } + if s < 60 { + return fmt.Sprintf("%d seconds", s) + } + m := s / 60 + rem := s % 60 + if rem == 0 { + if m == 1 { + return "1 minute" + } + return fmt.Sprintf("%d minutes", m) + } + return fmt.Sprintf("%dm %ds", m, rem) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress_test.go new file mode 100644 index 00000000000..cd5934768ef --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_progress_test.go @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// ---- durationText ---- + +func TestDurationText(t *testing.T) { + t.Parallel() + tests := []struct { + name string + duration time.Duration + want string + }{ + {"less than a second", 500 * time.Millisecond, "less than a second"}, + {"exactly 1 second", 1 * time.Second, "1 second"}, + {"multiple seconds", 30 * time.Second, "30 seconds"}, + {"exactly 59 seconds", 59 * time.Second, "59 seconds"}, + {"exactly 1 minute", 60 * time.Second, "1 minute"}, + {"exactly 2 minutes", 120 * time.Second, "2 minutes"}, + {"1 minute 30 seconds", 90 * time.Second, "1m 30s"}, + {"2 minutes 15 seconds", 135 * time.Second, "2m 15s"}, + {"zero", 0, "less than a second"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, durationText(tt.duration)) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run.go new file mode 100644 index 00000000000..28361af71de --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run.go @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_run.go implements the "eval run" command, which executes an evaluation +// run using an eval.yaml config. It creates or reuses an OpenAI eval, submits +// a run with the configured dataset and agent target, and polls for results. + +package cmd + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents" + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// evalRunFlags holds CLI flags for the eval run command. +type evalRunFlags struct { + config string // eval config path + name string // eval run name + noWait bool // start and return immediately +} + +func newEvalRunCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &evalRunFlags{config: defaultEvalConfigName} + cmd := &cobra.Command{ + Use: "run", + Short: "Execute an evaluation run from eval.yaml.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + logCleanup := setupDebugLogging(cmd.Flags()) + defer logCleanup() + return runEvalRun(ctx, flags, extCtx.NoPrompt) + }, + } + cmd.Flags().StringVar(&flags.config, "config", defaultEvalConfigName, "Local eval config YAML") + cmd.Flags().StringVar(&flags.name, "name", "", "Name for the eval run (defaults to eval config name)") + cmd.Flags().BoolVar(&flags.noWait, "no-wait", false, "Start the run and return immediately without waiting for results") + return cmd +} + +func runEvalRun(ctx context.Context, flags *evalRunFlags, noPrompt bool) error { + resolved, err := resolveEvalContext(ctx, evalContextOptions{}) + if err != nil { + return err + } + defer resolved.azdClient.Close() + + configPath := eval_api.ResolveRelPath(flags.config, resolved.agentProject) + evalCfg, err := eval_api.LoadEvalConfig(configPath) + if err != nil { + return err + } + + // Reconcile agent name/version between environment and eval.yaml. + // Environment values take precedence; warn and update the config if they differ. + configChanged := reconcileConfigAgentName(&evalCfg.Agent, resolved.agentName, flags.config) + if resolved.agentName == "" { + resolved.agentName = evalCfg.Agent.Name + } + if resolved.version == "" { + resolved.version = evalCfg.Agent.Version + } else if evalCfg.Agent.Version != "" && evalCfg.Agent.Version != resolved.version { + fmt.Printf(" %s agent version in %s (%q) differs from environment (%q) — using environment value\n", + color.YellowString("warning:"), flags.config, evalCfg.Agent.Version, resolved.version) + evalCfg.Agent.Version = resolved.version + configChanged = true + } + if configChanged { + if err := eval_api.WriteEvalConfig(configPath, evalCfg); err != nil { + fmt.Printf(" %s failed to update %s: %s\n", color.YellowString("warning:"), flags.config, err) + } else { + fmt.Printf(" Updated %s with current environment values\n", flags.config) + } + } + + state := opt_eval.LoadEvalState(ctx, resolved.azdClient, resolved.envName) + + if state.InitStatus == opt_eval.InitStatusPending { + if err := resumeEvalInit(ctx, resolved, configPath, evalCfg, state); err != nil { + return err + } + } + + evalID := state.EvalID + if evalID != "" && !noPrompt { + // Ask whether to reuse the existing eval or create a new one. + resp, promptErr := resolved.azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: fmt.Sprintf("Found existing eval %s. Reuse it?", evalID), + DefaultValue: new(false), + }, + }) + if promptErr == nil && resp.Value != nil && !*resp.Value { + evalID = "" // user chose to create a new eval + } + } + + if evalID == "" { + created, err := resolved.evalClient.CreateOpenAIEval( + ctx, buildOpenAIEvalRequest(evalCfg), DefaultAgentAPIVersion, + ) + if err != nil { + return fmt.Errorf("failed to create eval: %w", err) + } + evalID = created.ResolvedID() + if evalID == "" { + evalID = evalCfg.Name + } + state.EvalID = evalID + if err := opt_eval.SaveEvalState(ctx, resolved.azdClient, resolved.envName, state); err != nil { + return err + } + } + + runReq := &eval_api.CreateOpenAIEvalRunRequest{ + Name: resolveRunName(ctx, resolved.azdClient, flags.name, evalCfg.Name, noPrompt), + Metadata: map[string]string{"azd_agent": evalCfg.Agent.Name}, + } + + // Build agent target data source. + dataSource := eval_api.NewAgentTargetDataSource( + resolved.agentName, agentVersionPtr(resolved.version), + ) + + // Set source from local dataset file or remote dataset reference. + if evalCfg.DatasetFile != "" { + items, err := loadJSONLFile[map[string]any](evalCfg.DatasetFile) + if err != nil { + return err + } + dataSource.SetFileContent(items) + } else if evalCfg.DatasetReference != nil { + fileID := buildDatasetFileID(resolved.projectEndpoint, evalCfg.DatasetReference) + dataSource.SetFileID(fileID) + } else { + return fmt.Errorf("no dataset configured; run 'azd ai agent eval init' or specify dataset_file / dataset_reference in the eval config") + } + + runReq.DataSource = dataSource + + run, err := resolved.evalClient.CreateOpenAIEvalRun( + ctx, + evalID, + runReq, + DefaultAgentAPIVersion, + ) + if err != nil { + return fmt.Errorf("failed to start eval run: %w", err) + } + + fmt.Println(color.GreenString("Eval run started")) + fmt.Printf(" Eval: %s\n", evalID) + if run.ID != "" { + fmt.Printf(" Run: %s\n", run.ID) + } + + reportURL := buildEvalReportURL(ctx, resolved.azdClient, resolved.envName, evalID, run.ID) + if reportURL != "" { + fmt.Printf(" Report: %s\n", color.CyanString(reportURL)) + } + + if flags.noWait { + fmt.Printf("\n To view result summary, run:\n %s\n %s\n", + color.CyanString("azd ai agent eval list"), + color.CyanString("azd ai agent eval show"), + ) + return nil + } + + // Poll until the eval run reaches a terminal state. + completed, err := pollEvalRun(ctx, resolved.evalClient, evalID, run.ID) + if err != nil { + return err + } + + // Report URL was already printed above; clear it to avoid duplication. + completed.ReportURL = "" + + fmt.Println() + return printEvalRunSummary(evalID, completed) +} + +// resolveRunName determines the eval run name from the flag, interactive +// prompt, or config default (in that priority order). +func resolveRunName( + ctx context.Context, + azdClient *azdext.AzdClient, + flagName, configName string, + noPrompt bool, +) string { + if flagName != "" { + return flagName + } + + defaultName := configName + if defaultName == "" { + defaultName = defaultEvalName + } + + if !noPrompt { + resp, err := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Eval run name", + DefaultValue: defaultName, + IgnoreHintKeys: true, + }, + }) + if err == nil { + if value := strings.TrimSpace(resp.Value); value != "" { + return value + } + } + } + + return defaultName +} + +// Default polling constants for eval run monitoring. +const ( + defaultEvalPollInterval = 5 * time.Second + defaultEvalMaxAttempts = 360 // ~30 minutes at 5s intervals + maxConsecutiveTransientErr = 5 +) + +// pollEvalRun polls an eval run until it reaches a terminal status. +// Terminal statuses: "completed", "failed", "canceled". +func pollEvalRun( + ctx context.Context, + client *eval_api.EvalClient, + evalID, runID string, +) (*eval_api.OpenAIEvalRun, error) { + progress := newEvalProgress() + progress.Start() + defer progress.Stop() + + progress.setRunning("Eval run", runID) + + consecutiveTransient := 0 + for range defaultEvalMaxAttempts { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(defaultEvalPollInterval): + } + + run, err := client.GetOpenAIEvalRun(ctx, evalID, runID, DefaultAgentAPIVersion) + if err != nil { + if agents.IsTransientError(err) { + consecutiveTransient++ + if consecutiveTransient <= maxConsecutiveTransientErr { + continue + } + } + progress.setFailed("Eval run") + return nil, fmt.Errorf("failed to poll eval run: %w", err) + } + consecutiveTransient = 0 + + switch run.Status { + case "completed": + progress.setDone("Eval run") + return run, nil + case "failed": + progress.setFailed("Eval run") + errMsg := "eval run failed" + if run.Error != nil { + errMsg = fmt.Sprintf("eval run failed: %v", run.Error) + } + return nil, exterrors.Dependency( + exterrors.CodeEvalRunFailed, errMsg, + "check eval run details with 'azd ai agent eval show'") + case "canceled", "cancelled": + progress.setFailed("Eval run") + return nil, exterrors.Cancelled("eval run was canceled") + } + } + + progress.setTimedOut("Eval run") + return nil, fmt.Errorf( + "eval run %s did not complete within %d attempts", + runID, defaultEvalMaxAttempts) +} + +// buildDatasetFileID constructs an azureai:// URI for a remote dataset reference. +// Format: azureai://accounts//projects//data//versions/ +// The account and project are extracted from the project endpoint URL +// (https://.services.ai.azure.com/api/projects/). +func buildDatasetFileID(projectEndpoint string, ref *opt_eval.DatasetRef) string { + account, project := parseProjectEndpoint(projectEndpoint) + version := ref.Version + if version == "" { + version = "1" + } + return fmt.Sprintf("azureai://accounts/%s/projects/%s/data/%s/versions/%s", + account, project, ref.Name, version) +} + +// parseProjectEndpoint extracts account and project names from a Foundry project endpoint URL. +func parseProjectEndpoint(endpoint string) (account, project string) { + u, err := url.Parse(endpoint) + if err != nil { + return "", "" + } + // Host format: .services.ai.azure.com + host := u.Hostname() + if idx := strings.Index(host, "."); idx > 0 { + account = host[:idx] + } + // Path format: /api/projects/ + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + for i, p := range parts { + if p == "projects" && i+1 < len(parts) { + project = parts[i+1] + break + } + } + return account, project +} + +// agentVersionPtr returns a pointer to the version string, or nil if empty. +func agentVersionPtr(version string) *string { + if version == "" { + return nil + } + return &version +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run_test.go new file mode 100644 index 00000000000..c3960bef7e0 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_run_test.go @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// newEvalRunCommand — command shape +// --------------------------------------------------------------------------- + +func TestNewEvalRunCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newEvalRunCommand(nil) + + f := cmd.Flags().Lookup("config") + require.NotNil(t, f) + assert.Equal(t, defaultEvalConfigName, f.DefValue) +} + +func TestNewEvalRunCommand_NoArgs(t *testing.T) { + t.Parallel() + cmd := newEvalRunCommand(nil) + assert.NoError(t, cmd.Args(cmd, nil)) + assert.Error(t, cmd.Args(cmd, []string{"extra"})) +} + +func TestNewEvalRunCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newEvalRunCommand(nil) + assert.Equal(t, "run", cmd.Use) +} + +// --------------------------------------------------------------------------- +// loadJSONLFile +// --------------------------------------------------------------------------- + +func TestLoadJSONLFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + f := filepath.Join(dir, "data.jsonl") + content := "{\"query\":\"hello\",\"id\":\"1\"}\n{\"query\":\"world\",\"id\":\"2\"}\n" + require.NoError(t, os.WriteFile(f, []byte(content), 0600)) + + items, err := loadJSONLFile[map[string]any](f) + require.NoError(t, err) + require.Len(t, items, 2) + assert.Equal(t, "hello", items[0]["query"]) + assert.Equal(t, "2", items[1]["id"]) +} + +func TestLoadJSONLFile_Empty(t *testing.T) { + t.Parallel() + dir := t.TempDir() + f := filepath.Join(dir, "empty.jsonl") + require.NoError(t, os.WriteFile(f, []byte(""), 0600)) + + _, err := loadJSONLFile[map[string]any](f) + assert.Error(t, err) + assert.Contains(t, err.Error(), "contains no items") +} + +func TestLoadJSONLFile_NotFound(t *testing.T) { + t.Parallel() + _, err := loadJSONLFile[map[string]any]("/nonexistent/data.jsonl") + assert.Error(t, err) +} + +// --------------------------------------------------------------------------- +// parseProjectEndpoint +// --------------------------------------------------------------------------- + +func TestParseProjectEndpoint(t *testing.T) { + t.Parallel() + tests := []struct { + name string + endpoint string + expectedAccount string + expectedProject string + }{ + { + "standard endpoint", + "https://foundryljm7.services.ai.azure.com/api/projects/projectljm7", + "foundryljm7", + "projectljm7", + }, + { + "endpoint with trailing slash", + "https://myaccount.services.ai.azure.com/api/projects/myproject/", + "myaccount", + "myproject", + }, + { + "empty string", + "", + "", + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account, project := parseProjectEndpoint(tt.endpoint) + assert.Equal(t, tt.expectedAccount, account) + assert.Equal(t, tt.expectedProject, project) + }) + } +} + +// --------------------------------------------------------------------------- +// buildDatasetFileID +// --------------------------------------------------------------------------- + +func TestBuildDatasetFileID(t *testing.T) { + t.Parallel() + tests := []struct { + name string + endpoint string + ref *opt_eval.DatasetRef + expected string + }{ + { + "with version", + "https://foundryljm7.services.ai.azure.com/api/projects/projectljm7", + &opt_eval.DatasetRef{Name: "bugbash-mt-sim-scenarios", Version: "1"}, + "azureai://accounts/foundryljm7/projects/projectljm7/data/bugbash-mt-sim-scenarios/versions/1", + }, + { + "default version", + "https://myaccount.services.ai.azure.com/api/projects/myproject", + &opt_eval.DatasetRef{Name: "my-dataset"}, + "azureai://accounts/myaccount/projects/myproject/data/my-dataset/versions/1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildDatasetFileID(tt.endpoint, tt.ref) + assert.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// agentVersionPtr — version string to pointer +// --------------------------------------------------------------------------- + +func TestAgentVersionPtr(t *testing.T) { + t.Parallel() + t.Run("returns nil for empty string", func(t *testing.T) { + t.Parallel() + assert.Nil(t, agentVersionPtr("")) + }) + + t.Run("returns pointer to version", func(t *testing.T) { + t.Parallel() + v := agentVersionPtr("v2") + require.NotNil(t, v) + assert.Equal(t, "v2", *v) + }) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show.go new file mode 100644 index 00000000000..22e77f4ad88 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show.go @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_show.go implements the "eval show" command, which displays eval +// definitions, run history, and per-criteria result breakdowns. + +package cmd + +import ( + "context" + "fmt" + "os" + "text/tabwriter" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// evalShowFlags holds CLI flags for the eval show command. +type evalShowFlags struct { + evalRunID string // specific eval run to show + limit int // maximum number of runs to display + output string // export results to JSON file +} + +func newEvalShowCommand() *cobra.Command { + flags := &evalShowFlags{limit: 20} + cmd := &cobra.Command{ + Use: "show [eval-id]", + Short: "Show an eval definition, run history, or run details.", + Long: `Show an eval definition, run history, or run details. + +If eval-id is omitted, the most recent eval from the current environment is used.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + logCleanup := setupDebugLogging(cmd.Flags()) + defer logCleanup() + + var evalID string + if len(args) > 0 { + evalID = args[0] + } + return runEvalShow(ctx, evalID, flags) + }, + } + cmd.Flags().StringVar(&flags.evalRunID, "eval-run-id", "", "Show details for a specific eval run") + cmd.Flags().IntVar(&flags.limit, "limit", 20, "Maximum number of runs to show") + cmd.Flags().StringVarP(&flags.output, "out-file", "O", "", "Export full run results to a JSON file") + return cmd +} + +func runEvalShow(ctx context.Context, evalID string, flags *evalShowFlags) error { + resolved, err := resolveEvalContext(ctx, evalContextOptions{}) + if err != nil { + return err + } + defer resolved.azdClient.Close() + + // Fall back to the eval ID stored in the azd environment. + if evalID == "" && resolved.envName != "" { + state := opt_eval.LoadEvalState(ctx, resolved.azdClient, resolved.envName) + evalID = state.EvalID + } + if evalID == "" { + return fmt.Errorf("no eval-id provided and none found in the current environment; run 'azd ai agent eval init' first or pass an eval-id") + } + + if flags.evalRunID != "" { + run, err := resolved.evalClient.GetOpenAIEvalRun(ctx, evalID, flags.evalRunID, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to get eval run: %w", err) + } + if flags.output != "" { + return eval_api.WriteJSONFile(flags.output, run) + } + return printEvalRunSummary(evalID, run) + } + + evalObj, err := resolved.evalClient.GetOpenAIEval(ctx, evalID, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to get eval: %w", err) + } + runs, err := resolved.evalClient.ListOpenAIEvalRuns(ctx, evalID, flags.limit, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to list eval runs: %w", err) + } + if flags.output != "" { + return eval_api.WriteJSONFile(flags.output, map[string]any{ + "eval": evalObj, + "runs": runs.Data, + }) + } + return printEvalSummary(evalObj, runs.Data, flags.limit) +} + +func printEvalSummary(evalObj *eval_api.OpenAIEval, runs []eval_api.OpenAIEvalRun, limit int) error { + fmt.Printf("Eval: %s\n", evalObj.ResolvedID()) + if evalObj.Name != "" { + fmt.Printf("Name: %s\n", evalObj.Name) + } + if agent := evalObj.Metadata["azd_agent"]; agent != "" { + fmt.Printf("Agent: %s\n", agent) + } + fmt.Printf("Created: %s\n", eval_api.FormatTimestamp(evalObj.CreatedAt)) + if evalObj.CreatedBy != "" { + fmt.Printf("Created by: %s\n", evalObj.CreatedBy) + } + fmt.Printf("Runs: %d\n\n", len(runs)) + fmt.Println("Recent runs:") + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, " Run ID\tStatus\tPassed\tFailed\tCreated") + fmt.Fprintln(w, " ------\t------\t------\t------\t-------") + for _, run := range runs { + passed, failed := "", "" + if run.ResultCounts != nil { + passed = fmt.Sprintf("%d/%d", run.ResultCounts.Passed, run.ResultCounts.Total) + failed = fmt.Sprintf("%d", run.ResultCounts.Failed) + } + fmt.Fprintf(w, " %s\t%s\t%s\t%s\t%s\n", + run.ID, + colorizeStatus(run.Status), + passed, + failed, + eval_api.FormatTimestamp(run.CreatedAt), + ) + } + if err := w.Flush(); err != nil { + return err + } + fmt.Printf("\n(showing %d runs — use --limit to change)\n", min(limit, len(runs))) + return nil +} + +func printEvalRunSummary(evalID string, run *eval_api.OpenAIEvalRun) error { + fmt.Printf("Eval: %s\n", evalID) + fmt.Printf("Run: %s\n", run.ID) + if run.Name != "" { + fmt.Printf("Name: %s\n", run.Name) + } + fmt.Printf("Status: %s\n", colorizeStatus(run.Status)) + fmt.Printf("Created: %s\n", eval_api.FormatTimestamp(run.CreatedAt)) + if run.CreatedBy != "" { + fmt.Printf("Created by: %s\n", run.CreatedBy) + } + + // Agent target info from data source. + if run.DataSource != nil && run.DataSource.Target != nil { + agent := run.DataSource.Target.Name + if run.DataSource.Target.Version != nil { + agent += " v" + *run.DataSource.Target.Version + } + fmt.Printf("Agent: %s\n", agent) + } + + // Result counts. + if rc := run.ResultCounts; rc != nil { + fmt.Printf("\nResults: %d total, %s passed, %s failed, %s errored\n", + rc.Total, + color.GreenString("%d", rc.Passed), + color.RedString("%d", rc.Failed), + color.YellowString("%d", rc.Errored), + ) + } + + // Per-criteria breakdown. + if len(run.PerTestingCriteria) > 0 { + fmt.Println("\nPer-criteria results:") + for _, c := range run.PerTestingCriteria { + fmt.Printf(" %s: %s passed, %s failed, %s errored\n", + c.TestingCriteria, + color.GreenString("%d", c.Passed), + color.RedString("%d", c.Failed), + color.YellowString("%d", c.Errored), + ) + } + } + + if run.ReportURL != "" { + fmt.Printf("\nReport: %s\n", run.ReportURL) + } + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show_test.go new file mode 100644 index 00000000000..27fe3187b00 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_show_test.go @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- newEvalShowCommand ---- + +func TestNewEvalShowCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newEvalShowCommand() + assert.Equal(t, "show [eval-id]", cmd.Use) +} + +func TestNewEvalShowCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newEvalShowCommand() + + tests := []struct { + name string + flag string + wantNil bool + defValue string + }{ + {"eval-run-id flag", "eval-run-id", false, ""}, + {"limit flag", "limit", false, "20"}, + {"out-file flag", "out-file", false, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + f := cmd.Flags().Lookup(tt.flag) + if tt.wantNil { + assert.Nil(t, f) + } else { + require.NotNil(t, f, "--%s flag should be registered", tt.flag) + assert.Equal(t, tt.defValue, f.DefValue) + } + }) + } +} + +func TestNewEvalShowCommand_AcceptsOptionalPositionalArg(t *testing.T) { + t.Parallel() + cmd := newEvalShowCommand() + // MaximumNArgs(1) — should accept 0 args without error from arg validation. + assert.NotNil(t, cmd.Args) +} + +func TestNewEvalShowCommand_HasOutFileShorthand(t *testing.T) { + t.Parallel() + cmd := newEvalShowCommand() + f := cmd.Flags().Lookup("out-file") + require.NotNil(t, f) + assert.Equal(t, "O", f.Shorthand) +} + +// ---- newEvalUpdateCommand ---- + +func TestNewEvalUpdateCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newEvalUpdateCommand(&azdext.ExtensionContext{}) + assert.Equal(t, "update", cmd.Use) +} + +func TestNewEvalUpdateCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newEvalUpdateCommand(&azdext.ExtensionContext{}) + + tests := []struct { + name string + flag string + defValue string + }{ + {"config flag", "config", defaultEvalConfigName}, + {"dataset-only flag", "dataset-only", "false"}, + {"evaluator-only flag", "evaluator-only", "false"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + f := cmd.Flags().Lookup(tt.flag) + require.NotNil(t, f, "--%s flag should be registered", tt.flag) + assert.Equal(t, tt.defValue, f.DefValue) + }) + } +} + +func TestNewEvalUpdateCommand_NoArgs(t *testing.T) { + t.Parallel() + cmd := newEvalUpdateCommand(&azdext.ExtensionContext{}) + assert.NotNil(t, cmd.Args) +} + +// ---- eval "update" in parent command ---- + +func TestNewEvalCommand_HasUpdateSubcommand(t *testing.T) { + t.Parallel() + cmd := newEvalCommand(&azdext.ExtensionContext{}) + var names []string + for _, sub := range cmd.Commands() { + names = append(names, sub.Name()) + } + assert.Contains(t, names, "update") +} + +// ---- eval "show" in parent command ---- + +func TestNewEvalCommand_HasShowSubcommand(t *testing.T) { + t.Parallel() + cmd := newEvalCommand(&azdext.ExtensionContext{}) + var names []string + for _, sub := range cmd.Commands() { + names = append(names, sub.Name()) + } + assert.Contains(t, names, "show") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_test.go new file mode 100644 index 00000000000..0d9b59d447b --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_test.go @@ -0,0 +1,480 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/dataset_api" + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeTokenCredential satisfies azcore.TokenCredential for tests. +type fakeTokenCredential struct{} + +func (f *fakeTokenCredential) GetToken( + _ context.Context, + _ policy.TokenRequestOptions, +) (azcore.AccessToken, error) { + return azcore.AccessToken{Token: "fake-token"}, nil +} + +// --------------------------------------------------------------------------- +// newEvalCommand — command tree shape +// --------------------------------------------------------------------------- + +func TestNewEvalCommand_HasExpectedSubcommands(t *testing.T) { + t.Parallel() + + cmd := newEvalCommand(&azdext.ExtensionContext{}) + names := make([]string, 0, len(cmd.Commands())) + for _, sub := range cmd.Commands() { + names = append(names, sub.Name()) + } + + assert.Contains(t, names, "init") + assert.Contains(t, names, "run") + assert.Contains(t, names, "list") + assert.Contains(t, names, "show") +} + +func TestNewEvalCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newEvalCommand(&azdext.ExtensionContext{}) + assert.Equal(t, "eval ", cmd.Use) +} + +// --------------------------------------------------------------------------- +// GenerationJob methods +// --------------------------------------------------------------------------- + +func TestGenerationJob_OperationID(t *testing.T) { + t.Parallel() + assert.Equal(t, "op-123", (&eval_api.GenerationJob{ID: "op-123"}).OperationID()) + assert.Equal(t, "", (&eval_api.GenerationJob{}).OperationID()) +} + +func TestGenerationJob_NormalizedStatus(t *testing.T) { + t.Parallel() + assert.Equal(t, "completed", (&eval_api.GenerationJob{Status: "completed"}).NormalizedStatus()) + assert.Equal(t, "running", (&eval_api.GenerationJob{}).NormalizedStatus()) +} + +func TestGenerationJob_ResolvedNameVersion(t *testing.T) { + t.Parallel() + + // Empty job returns empty name and empty version. + name, version := (&eval_api.GenerationJob{}).ResolvedNameVersion() + assert.Equal(t, "", name) + assert.Equal(t, "", version) + + // Extracts name and version from the result JSON. + job := &eval_api.GenerationJob{ + Result: json.RawMessage(`{"name":"generated-ds","version":"v2"}`), + } + name, version = job.ResolvedNameVersion() + assert.Equal(t, "generated-ds", name) + assert.Equal(t, "v2", version) + + // Extracts from result.outputs[0] (nested API response format). + jobNested := &eval_api.GenerationJob{ + Result: json.RawMessage(`{"outputs":[{"type":"dataset","name":"nested-ds","version":"36735"}]}`), + } + name, version = jobNested.ResolvedNameVersion() + assert.Equal(t, "nested-ds", name) + assert.Equal(t, "36735", version) + + // Defaults version to "latest" when missing. + jobNoVer := &eval_api.GenerationJob{ + Result: json.RawMessage(`{"name":"smoke-core"}`), + } + name, version = jobNoVer.ResolvedNameVersion() + assert.Equal(t, "smoke-core", name) + assert.Equal(t, "latest", version) +} + +func TestOpenAIEval_ResolvedID(t *testing.T) { + t.Parallel() + assert.Equal(t, "eval-1", (&eval_api.OpenAIEval{ID: "eval-1", Name: "n"}).ResolvedID()) + assert.Equal(t, "n", (&eval_api.OpenAIEval{Name: "n"}).ResolvedID()) + assert.Equal(t, "", (&eval_api.OpenAIEval{}).ResolvedID()) +} + +// --------------------------------------------------------------------------- +// eval_api.FormatTimestamp +// --------------------------------------------------------------------------- + +func TestFormatTimestamp(t *testing.T) { + t.Parallel() + + assert.Equal(t, "2024-01-15 10:30 UTC", eval_api.FormatTimestamp("2024-01-15 10:30 UTC")) + assert.Contains(t, eval_api.FormatTimestamp(float64(1705312200)), "2024-01-15") + assert.Contains(t, eval_api.FormatTimestamp(int64(1705312200)), "2024-01-15") + assert.Equal(t, "", eval_api.FormatTimestamp(nil)) + assert.Equal(t, "", eval_api.FormatTimestamp(true)) +} + +// --------------------------------------------------------------------------- +// eval_api.ResolveRelPath +// --------------------------------------------------------------------------- + +func TestResolveRelPath(t *testing.T) { + t.Parallel() + + t.Run("absolute path returned as-is", func(t *testing.T) { + t.Parallel() + abs := filepath.Join(os.TempDir(), "eval.yaml") + assert.Equal(t, abs, eval_api.ResolveRelPath(abs, "/project")) + }) + + t.Run("relative path joined with agent project", func(t *testing.T) { + t.Parallel() + result := eval_api.ResolveRelPath("eval.yaml", "/project/agent") + assert.Equal(t, filepath.Join("/project/agent", "eval.yaml"), result) + }) +} + +// --------------------------------------------------------------------------- +// detectEvalAgentKind +// --------------------------------------------------------------------------- + +func TestDetectEvalAgentKind(t *testing.T) { + t.Parallel() + + t.Run("detects hosted kind from agent.yaml", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeTestFile(t, dir, "agent.yaml", "kind: hosted\nname: test-agent\n") + kind, path := detectEvalAgentKind(dir) + assert.Equal(t, agent_yaml.AgentKindHosted, kind) + assert.Equal(t, filepath.Join(dir, "agent.yaml"), path) + }) + + t.Run("returns empty for missing manifest", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + kind, path := detectEvalAgentKind(dir) + assert.Empty(t, kind) + assert.Empty(t, path) + }) + + t.Run("returns empty for invalid kind", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeTestFile(t, dir, "agent.yaml", "kind: invalid_kind_xyz\nname: test-agent\n") + kind, path := detectEvalAgentKind(dir) + assert.Empty(t, kind) + assert.Empty(t, path) + }) + + t.Run("returns empty for malformed YAML", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeTestFile(t, dir, "agent.yaml", "{{invalid yaml}}") + kind, path := detectEvalAgentKind(dir) + assert.Empty(t, kind) + assert.Empty(t, path) + }) +} + +// --------------------------------------------------------------------------- +// EvalState — stored in azd environment (integration-tested via eval init/run) +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// writeEvalReviewArtifacts +// --------------------------------------------------------------------------- + +func TestWriteEvalReviewArtifacts(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + cfg := &evalConfig{} + cfg.DatasetReference = &evalDatasetRef{Name: "test-data", Version: "v1"} + cfg.Evaluators = opt_eval.EvaluatorList{{Name: "quality"}} + + err := eval_api.WriteEvalReviewArtifacts(dir, cfg) + require.NoError(t, err) + + evPath := filepath.Join(dir, "evaluators", "quality", "quality.yaml") + assert.FileExists(t, evPath) +} + +func TestWriteEvalReviewArtifacts_NilDataset(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + cfg := &evalConfig{} + // No dataset reference — should not panic. + err := eval_api.WriteEvalReviewArtifacts(dir, cfg) + require.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// saveEvaluatorResult +// --------------------------------------------------------------------------- + +func TestSaveEvaluatorResult(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + result := json.RawMessage(`{"name":"smoke-core","definition":{"type":"rubric","dimensions":[{"id":"quality","weight":10}]}}`) + require.NoError(t, eval_api.SaveEvaluatorResult(dir, "smoke-core", result)) + + path := filepath.Join(dir, "evaluators", "smoke-core", "rubric_dimensions.json") + assert.FileExists(t, path) + data, err := os.ReadFile(path) //nolint:gosec // test file path + require.NoError(t, err) + // Only the dimensions array is saved, not the outer fields. + assert.Contains(t, string(data), `"id": "quality"`) + assert.Contains(t, string(data), `"weight": 10`) + assert.NotContains(t, string(data), `"name": "smoke-core"`) +} + +func TestSaveEvaluatorResult_WithVersion(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + result := json.RawMessage(`{"name":"custom","definition":{"type":"rubric","dimensions":[{"id":"d1","weight":5}]}}`) + require.NoError(t, eval_api.SaveEvaluatorResult(dir, "custom", result)) + + path := filepath.Join(dir, "evaluators", "custom", "rubric_dimensions.json") + assert.FileExists(t, path) +} + +func TestSaveEvaluatorResult_NilResult(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + require.NoError(t, eval_api.SaveEvaluatorResult(dir, "test", nil)) + path := filepath.Join(dir, "evaluators", "test", "rubric_dimensions.json") + assert.NoFileExists(t, path) +} + +func TestSaveEvaluatorResult_EmptyName(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + require.NoError(t, eval_api.SaveEvaluatorResult(dir, "", json.RawMessage(`{"name":"x"}`))) + // Should not create any file. + matches, _ := filepath.Glob(filepath.Join(dir, "evaluators", "*.json")) + assert.Empty(t, matches) +} + +func TestWriteEvalReviewArtifacts_SkipsWhenResultExists(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Pre-save a result file. + require.NoError(t, eval_api.SaveEvaluatorResult(dir, "quality", json.RawMessage(`{"name":"quality","definition":{"type":"rubric","dimensions":[{"id":"q","weight":1}]}}`))) + + cfg := &evalConfig{} + cfg.Evaluators = opt_eval.EvaluatorList{{Name: "quality"}} + err := eval_api.WriteEvalReviewArtifacts(dir, cfg) + require.NoError(t, err) + + // Should NOT create a .yaml stub since .json result already exists. + yamlPath := filepath.Join(dir, "evaluators", "quality", "quality.yaml") + assert.NoFileExists(t, yamlPath) +} + +// --------------------------------------------------------------------------- +// downloadDatasetArtifact +// --------------------------------------------------------------------------- + +func TestDownloadDatasetArtifact_NilDataset(t *testing.T) { + t.Parallel() + _, err := eval_api.DownloadDatasetArtifact(t.Context(), nil, t.TempDir(), nil, "2025-11-15-preview") + require.NoError(t, err) +} + +func TestDownloadDatasetArtifact_WritesBlob(t *testing.T) { + t.Parallel() + + // The Azure SDK bearer token policy rejects non-TLS test servers, so the + // credential call will fail. downloadDatasetArtifact now returns the error. + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"sas_uri":"http://blob.example/data?sig=abc"}`)) + })) + t.Cleanup(apiServer.Close) + + client := dataset_api.NewDatasetClient(apiServer.URL, &fakeTokenCredential{}) + dir := t.TempDir() + + ref := &evalDatasetRef{Name: "test-ds", Version: "v1"} + _, err := eval_api.DownloadDatasetArtifact(t.Context(), client, dir, ref, "2025-11-15-preview") + require.Error(t, err) + assert.Contains(t, err.Error(), "getting dataset credential") + + // No file written when credential fetch fails. + dest := eval_api.DatasetArtifactPath(dir, ref) + assert.NoDirExists(t, dest) +} + +// --------------------------------------------------------------------------- +// datasetArtifactPath +// --------------------------------------------------------------------------- + +func TestDatasetArtifactPath(t *testing.T) { + t.Parallel() + ref := &evalDatasetRef{Name: "golden", Version: "v2"} + result := eval_api.DatasetArtifactPath("/project", ref) + assert.Equal(t, filepath.Join("/project", "datasets", "golden"), result) + + // No version — same path, version not included. + refNoVer := &evalDatasetRef{Name: "golden", Version: ""} + resultNoVer := eval_api.DatasetArtifactPath("/project", refNoVer) + assert.Equal(t, filepath.Join("/project", "datasets", "golden"), resultNoVer) +} + +// --------------------------------------------------------------------------- +// writeJSONFile +// --------------------------------------------------------------------------- + +func TestWriteJSONFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "result.json") + + err := eval_api.WriteJSONFile(path, map[string]string{"hello": "world"}) + require.NoError(t, err) + + data, err := os.ReadFile(path) //nolint:gosec // test file path + require.NoError(t, err) + assert.Contains(t, string(data), `"hello": "world"`) +} + +// --------------------------------------------------------------------------- +// evalAgentContextError +// --------------------------------------------------------------------------- + +func TestEvalAgentContextError(t *testing.T) { + t.Parallel() + + t.Run("without cause", func(t *testing.T) { + t.Parallel() + err := evalAgentContextError(nil) + assert.Contains(t, err.Error(), "agent context could not be resolved") + var localErr *azdext.LocalError + require.True(t, errors.As(err, &localErr)) + assert.Contains(t, localErr.Suggestion, "azd ai agent init") + }) + + t.Run("with cause", func(t *testing.T) { + t.Parallel() + cause := assert.AnError + err := evalAgentContextError(cause) + assert.Contains(t, err.Error(), cause.Error()) + var localErr *azdext.LocalError + require.True(t, errors.As(err, &localErr)) + assert.Contains(t, localErr.Suggestion, "--agent") + assert.Contains(t, localErr.Suggestion, "--project-endpoint") + }) +} + +// --------------------------------------------------------------------------- +// relPathForYaml +// --------------------------------------------------------------------------- + +func TestRelPathForYaml(t *testing.T) { + t.Parallel() + + result := relPathForYaml("/project", filepath.Join("/project", "src", "agent.yaml")) + assert.Equal(t, "src/agent.yaml", result) +} + +// --------------------------------------------------------------------------- +// eval_api.WriteEvalConfig / eval_api.LoadEvalConfig round-trip +// --------------------------------------------------------------------------- + +func TestEvalConfigRoundTrip(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "eval.yaml") + + original := &evalConfig{ + Config: opt_eval.Config{ + Name: "smoke-core", + Agent: evalAgentRef{ + Name: "my-agent", + Kind: agent_yaml.AgentKindHosted, + Version: "v1", + }, + DatasetReference: &evalDatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}}, + }, + Options: &opt_eval.Options{ + EvalModel: "gpt-4o", + }, + MaxSamples: 50, + } + + err := eval_api.WriteEvalConfig(path, original) + require.NoError(t, err) + + loaded, err := eval_api.LoadEvalConfig(path) + require.NoError(t, err) + + assert.Equal(t, original.Name, loaded.Name) + assert.Equal(t, original.Agent.Name, loaded.Agent.Name) + assert.Equal(t, original.Agent.Kind, loaded.Agent.Kind) + assert.Equal(t, original.Agent.Version, loaded.Agent.Version) + assert.Equal(t, "gpt-4o", loaded.Options.EvalModel) + assert.Equal(t, original.MaxSamples, loaded.MaxSamples) + require.NotNil(t, loaded.DatasetReference) + assert.Equal(t, "ds", loaded.DatasetReference.Name) + require.Len(t, loaded.Evaluators, 1) + assert.Equal(t, "builtin.task_adherence", loaded.Evaluators[0].Name) +} + +func TestReadEvalConfig_MissingFile(t *testing.T) { + t.Parallel() + _, err := eval_api.LoadEvalConfig("/nonexistent/path/eval.yaml") + assert.Error(t, err) +} + +// --------------------------------------------------------------------------- +// endpointFromProjectID — project ID to endpoint conversion +// --------------------------------------------------------------------------- + +func TestEndpointFromProjectID(t *testing.T) { + t.Parallel() + t.Run("valid project ID", func(t *testing.T) { + t.Parallel() + projectID := "/subscriptions/sub123/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/myaccount/projects/myproject" + endpoint, err := endpointFromProjectID(projectID) + require.NoError(t, err) + assert.Contains(t, endpoint, "myaccount") + assert.Contains(t, endpoint, "myproject") + }) + + t.Run("invalid project ID", func(t *testing.T) { + t.Parallel() + _, err := endpointFromProjectID("not-a-valid-id") + assert.Error(t, err) + }) + + t.Run("empty project ID", func(t *testing.T) { + t.Parallel() + _, err := endpointFromProjectID("") + assert.Error(t, err) + }) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_update.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_update.go new file mode 100644 index 00000000000..54c6d5792b2 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/eval_update.go @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// eval_update.go implements the "eval update" command, which uploads new +// versions of locally-edited evaluators and datasets. It reads eval.yaml, +// detects assets with local_uri pointers, and uploads them as new versions. + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + + "azureaiagent/internal/pkg/agents/dataset_api" + "azureaiagent/internal/pkg/agents/eval_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// evalUpdateFlags holds CLI flags for the eval update command. +type evalUpdateFlags struct { + config string // eval config path + datasetOnly bool // only update the dataset + evaluatorOnly bool // only update evaluators +} + +func newEvalUpdateCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &evalUpdateFlags{config: defaultEvalConfigName} + cmd := &cobra.Command{ + Use: "update", + Short: "Update evaluators and datasets from local files.", + Long: `Reads the eval config and uploads new versions for: + - Evaluators with a local_uri (rubric dimensions file) + - Datasets with a local_uri (JSONL data directory) +The version fields in the config are updated after successful uploads. + +In interactive mode, you will be prompted for each asset type that has +local changes. Use --dataset-only or --evaluator-only to skip prompts.`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + logCleanup := setupDebugLogging(cmd.Flags()) + defer logCleanup() + return runEvalUpdate(ctx, flags, extCtx.NoPrompt) + }, + } + cmd.Flags().StringVar(&flags.config, "config", defaultEvalConfigName, "Local eval config YAML") + cmd.Flags().BoolVar(&flags.datasetOnly, "dataset-only", false, "Only update the dataset") + cmd.Flags().BoolVar(&flags.evaluatorOnly, "evaluator-only", false, "Only update evaluators") + return cmd +} + +func runEvalUpdate(ctx context.Context, flags *evalUpdateFlags, noPrompt bool) error { + resolved, err := resolveEvalContext(ctx, evalContextOptions{}) + if err != nil { + return err + } + defer resolved.azdClient.Close() + + configPath := eval_api.ResolveRelPath(flags.config, resolved.agentProject) + evalCfg, err := eval_api.LoadEvalConfig(configPath) + if err != nil { + return fmt.Errorf("failed to load eval config: %w", err) + } + + // Detect what has local changes. + hasDataset := evalCfg.DatasetReference != nil && + evalCfg.DatasetReference.Name != "" && + evalCfg.DatasetReference.LocalURI != "" + hasEvaluators := len(evalCfg.Evaluators.FindByLocalURI()) > 0 + + // Determine what to update based on flags and interactive prompts. + updateDS := hasDataset && !flags.evaluatorOnly + updateEval := hasEvaluators && !flags.datasetOnly + + // In interactive mode (no exclusive flags), prompt for each detected type. + if !noPrompt && !flags.datasetOnly && !flags.evaluatorOnly { + if hasDataset { + updateDS = confirmUpdate(ctx, resolved, fmt.Sprintf( + "Dataset %s has local changes. Upload new version?", + evalCfg.DatasetReference.Name, + )) + } + if hasEvaluators { + updateEval = confirmUpdate(ctx, resolved, "Evaluator(s) have local changes. Upload new version(s)?") + } + } + + var totalUpdated int + + if updateDS { + dsUpdated, err := updateDataset(ctx, resolved.datasetClient, evalCfg, configPath) + if err != nil { + return err + } + totalUpdated += dsUpdated + } + + if updateEval { + evalUpdated, err := updateEvaluators(ctx, resolved.evalClient, evalCfg, configPath) + if err != nil { + return err + } + totalUpdated += evalUpdated + } + + if totalUpdated > 0 { + if err := eval_api.WriteEvalConfig(configPath, evalCfg); err != nil { + return fmt.Errorf("failed to save updated config: %w", err) + } + fmt.Printf("\n%s Updated config saved to %s\n", color.GreenString("Done."), flags.config) + } else { + fmt.Println("\nNo updates were made.") + } + + return nil +} + +// confirmUpdate prompts the user with a yes/no question, defaulting to yes. +func confirmUpdate(ctx context.Context, resolved *evalResolvedContext, message string) bool { + resp, err := resolved.azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: message, + DefaultValue: new(true), + }, + }) + if err != nil { + return true // on error, default to updating + } + return resp.Value != nil && *resp.Value +} + +// updateDataset uploads local dataset files as a new dataset version. +// Returns the number of datasets updated (0 or 1). +func updateDataset( + ctx context.Context, + client *dataset_api.DatasetClient, + evalCfg *evalConfig, + configPath string, +) (int, error) { + ref := evalCfg.DatasetReference + if ref == nil || ref.Name == "" || ref.LocalURI == "" { + return 0, nil + } + + localDir := ref.LocalURI + if !filepath.IsAbs(localDir) { + localDir = filepath.Join(filepath.Dir(configPath), localDir) + } + + resp, err := client.UploadNewVersion(ctx, ref.Name, ref.Version, localDir, DefaultAgentAPIVersion) + if err != nil { + fmt.Printf(" %s Failed to update dataset %s: %v\n", color.RedString("x"), ref.Name, err) + return 0, nil + } + + ref.Version = resp.Version + fmt.Printf(" %s Dataset %s → version %s\n", color.GreenString("✓"), ref.Name, resp.Version) + return 1, nil +} + +// updateEvaluators uploads local evaluator dimensions as new evaluator versions. +// Returns the number of evaluators updated. +func updateEvaluators( + ctx context.Context, + client *eval_api.EvalClient, + evalCfg *evalConfig, + configPath string, +) (int, error) { + localEvals := evalCfg.Evaluators.FindByLocalURI() + if len(localEvals) == 0 { + return 0, nil + } + + var updated int + for _, ref := range localEvals { + localPath := ref.LocalURI + if !filepath.IsAbs(localPath) { + localPath = filepath.Join(filepath.Dir(configPath), localPath) + } + + data, err := os.ReadFile(localPath) //nolint:gosec // user-provided local config path + if err != nil { + fmt.Printf(" %s Skipping %s: %v\n", color.YellowString("!"), ref.Name, err) + continue + } + + if !json.Valid(data) { + fmt.Printf(" %s Skipping %s: file is not valid JSON\n", color.YellowString("!"), ref.Name) + continue + } + + current, err := client.GetEvaluatorRaw(ctx, ref.Name, ref.Version, DefaultAgentAPIVersion) + if err != nil { + fmt.Printf(" %s Failed to get evaluator %s: %v\n", color.RedString("x"), ref.Name, err) + continue + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(current, &obj); err != nil { + fmt.Printf(" %s Failed to parse evaluator %s: %v\n", color.RedString("x"), ref.Name, err) + continue + } + + // Patch dimensions into the existing definition. + var defObj map[string]json.RawMessage + if raw, ok := obj["definition"]; ok { + if err := json.Unmarshal(raw, &defObj); err != nil { + defObj = make(map[string]json.RawMessage) + } + } else { + defObj = make(map[string]json.RawMessage) + } + defObj["dimensions"] = json.RawMessage(data) + updatedDef, err := json.Marshal(defObj) + if err != nil { + fmt.Printf(" %s Failed to build definition for %s: %v\n", color.RedString("x"), ref.Name, err) + continue + } + obj["definition"] = json.RawMessage(updatedDef) + + body, err := json.Marshal(obj) + if err != nil { + fmt.Printf(" %s Failed to build request for %s: %v\n", color.RedString("x"), ref.Name, err) + continue + } + + resp, err := client.CreateEvaluatorVersion(ctx, ref.Name, body, DefaultAgentAPIVersion) + if err != nil { + fmt.Printf(" %s Failed to update %s: %v\n", color.RedString("x"), ref.Name, err) + continue + } + + evalCfg.Evaluators.SetVersion(ref.Name, resp.Version) + updated++ + fmt.Printf(" %s Evaluator %s → version %s\n", color.GreenString("✓"), ref.Name, resp.Version) + } + + return updated, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index 74e9b7f98e0..3979f875284 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -15,6 +15,7 @@ import ( "azureaiagent/internal/exterrors" "azureaiagent/internal/pkg/agents/agent_api" "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/optimize_api" "azureaiagent/internal/pkg/azure" "azureaiagent/internal/pkg/envkey" "azureaiagent/internal/pkg/paths" @@ -303,6 +304,16 @@ func postdeployHandler(ctx context.Context, azdClient *azdext.AzdClient, args *a return fmt.Errorf("agent identity RBAC setup failed: %w", err) } + // Report optimization candidate deployments to the optimization service. + // If a service has AGENT_{KEY}_OPTIMIZATION_CANDIDATE_ID in the azd environment, + // the agent was deployed from an optimization candidate. We notify the + // optimization service so it can track which candidates have been deployed. + reportOptimizationDeployments(ctx, azdClient, hostedAgents, envName, endpointResp.Value, + func(endpoint string) *optimize_api.OptimizeClient { + return optimize_api.NewOptimizeClient(endpoint, cred) + }, + ) + return nil } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go index 919b52b955c..2350896e6cc 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go @@ -12,6 +12,8 @@ import ( "azureaiagent/internal/project" "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestPostdeployHandler_NoAgentService_NoOp verifies postdeployHandler returns nil @@ -332,3 +334,308 @@ func TestBuildConnectionCredentials(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// isHostedAgentService +// --------------------------------------------------------------------------- + +func TestIsHostedAgentService_HostedKind(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte("kind: hosted\nname: my-agent\n"), 0600, + )) + + svc := &azdext.ServiceConfig{Name: "svc", RelativePath: "."} + proj := &azdext.ProjectConfig{Path: dir} + + assert.True(t, isHostedAgentService(svc, proj)) +} + +func TestIsHostedAgentService_NonHostedKind(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte("kind: local\nname: my-agent\n"), 0600, + )) + + svc := &azdext.ServiceConfig{Name: "svc", RelativePath: "."} + proj := &azdext.ProjectConfig{Path: dir} + + assert.False(t, isHostedAgentService(svc, proj)) +} + +func TestIsHostedAgentService_NoAgentYaml(t *testing.T) { + t.Parallel() + + svc := &azdext.ServiceConfig{Name: "svc", RelativePath: "."} + proj := &azdext.ProjectConfig{Path: t.TempDir()} + + assert.False(t, isHostedAgentService(svc, proj)) +} + +func TestIsHostedAgentService_InvalidYaml(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte(":::invalid yaml:::"), 0600, + )) + + svc := &azdext.ServiceConfig{Name: "svc", RelativePath: "."} + proj := &azdext.ProjectConfig{Path: dir} + + assert.False(t, isHostedAgentService(svc, proj)) +} + +func TestIsHostedAgentService_MissingKindField(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte("name: my-agent\n"), 0600, + )) + + svc := &azdext.ServiceConfig{Name: "svc", RelativePath: "."} + proj := &azdext.ProjectConfig{Path: dir} + + assert.False(t, isHostedAgentService(svc, proj)) +} + +func TestIsHostedAgentService_SubDirectory(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + subDir := filepath.Join(dir, "agents", "bot") + require.NoError(t, os.MkdirAll(subDir, 0700)) + require.NoError(t, os.WriteFile( + filepath.Join(subDir, "agent.yaml"), + []byte("kind: hosted\nname: bot\n"), 0600, + )) + + svc := &azdext.ServiceConfig{Name: "bot", RelativePath: "agents/bot"} + proj := &azdext.ProjectConfig{Path: dir} + + assert.True(t, isHostedAgentService(svc, proj)) +} + +// --------------------------------------------------------------------------- +// resolveEnvValue / resolveMapValues / resolveAnyValue +// --------------------------------------------------------------------------- + +func TestResolveEnvValue(t *testing.T) { + t.Parallel() + + env := map[string]string{ + "DB_HOST": "mydb.postgres.azure.com", + "DB_PORT": "5432", + } + + tests := []struct { + input string + want string + }{ + {"${DB_HOST}", "mydb.postgres.azure.com"}, + {"host=${DB_HOST}:${DB_PORT}", "host=mydb.postgres.azure.com:5432"}, + {"no-var", "no-var"}, + {"${UNDEFINED}", ""}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, resolveEnvValue(tt.input, env)) + }) + } +} + +func TestResolveMapValues(t *testing.T) { + t.Parallel() + + env := map[string]string{"KEY": "val"} + m := map[string]any{ + "a": "${KEY}", + "b": "literal", + "c": 42, + } + + got := resolveMapValues(m, env) + assert.Equal(t, "val", got["a"]) + assert.Equal(t, "literal", got["b"]) + assert.Equal(t, 42, got["c"]) +} + +func TestResolveAnyValue_NestedStructures(t *testing.T) { + t.Parallel() + + env := map[string]string{"X": "resolved"} + + // Nested map + nested := map[string]any{ + "inner": map[string]any{"key": "${X}"}, + } + got := resolveAnyValue(nested, env) + gotMap := got.(map[string]any) + inner := gotMap["inner"].(map[string]any) + assert.Equal(t, "resolved", inner["key"]) + + // Slice + slice := []any{"${X}", "plain", 99} + gotSlice := resolveAnyValue(slice, env).([]any) + assert.Equal(t, "resolved", gotSlice[0]) + assert.Equal(t, "plain", gotSlice[1]) + assert.Equal(t, 99, gotSlice[2]) + + // Non-string type passthrough + assert.Equal(t, true, resolveAnyValue(true, env)) +} + +// --------------------------------------------------------------------------- +// resolveToolboxEnvVars +// --------------------------------------------------------------------------- + +func TestResolveToolboxEnvVars(t *testing.T) { + t.Parallel() + + env := map[string]string{ + "TB_NAME": "my-toolbox", + "TB_DESC": "A test toolbox", + "URL": "https://example.com", + } + + tb := project.Toolbox{ + Name: "${TB_NAME}", + Description: "${TB_DESC}", + Tools: []map[string]any{ + {"server_url": "${URL}", "type": "web_search"}, + }, + } + + resolveToolboxEnvVars(&tb, env) + + assert.Equal(t, "my-toolbox", tb.Name) + assert.Equal(t, "A test toolbox", tb.Description) + assert.Equal(t, "https://example.com", tb.Tools[0]["server_url"]) + assert.Equal(t, "web_search", tb.Tools[0]["type"]) +} + +// --------------------------------------------------------------------------- +// toolboxConnectionsByName +// --------------------------------------------------------------------------- + +func TestToolboxConnectionsByName_NilConfig(t *testing.T) { + t.Parallel() + assert.Empty(t, toolboxConnectionsByName(nil)) +} + +func TestToolboxConnectionsByName_MergesBothTypes(t *testing.T) { + t.Parallel() + + config := &project.ServiceTargetAgentConfig{ + Connections: []project.Connection{ + {Name: "conn-a", Target: "https://a.com"}, + }, + ToolConnections: []project.ToolConnection{ + {Name: "tool-b", Target: "https://b.com"}, + }, + } + + result := toolboxConnectionsByName(config) + assert.Len(t, result, 2) + assert.Equal(t, "https://a.com", result["conn-a"].Target) + assert.Equal(t, "https://b.com", result["tool-b"].Target) +} + +// --------------------------------------------------------------------------- +// postdeployHandler — skips non-agent-host services +// --------------------------------------------------------------------------- + +func TestPostdeployHandler_SkipsNonAgentHostServices(t *testing.T) { + t.Parallel() + + // Project has one service with a different host type — handler should + // return nil without making any RPC calls (azdClient is nil). + args := &azdext.ProjectEventArgs{ + Project: &azdext.ProjectConfig{ + Path: t.TempDir(), + Services: map[string]*azdext.ServiceConfig{ + "api": {Name: "api", Host: "containerapp", RelativePath: "."}, + }, + }, + } + + assert.NoError(t, postdeployHandler(t.Context(), nil, args)) +} + +func TestPostdeployHandler_SkipsWhenNoServices(t *testing.T) { + t.Parallel() + + args := &azdext.ProjectEventArgs{ + Project: &azdext.ProjectConfig{ + Path: t.TempDir(), + Services: map[string]*azdext.ServiceConfig{}, + }, + } + + assert.NoError(t, postdeployHandler(t.Context(), nil, args)) +} + +// --------------------------------------------------------------------------- +// enrichToolboxFromConnections — server_url already set +// --------------------------------------------------------------------------- + +func TestEnrichToolboxFromConnections_DoesNotOverrideExistingServerURL(t *testing.T) { + t.Parallel() + + connByName := map[string]toolboxConnection{ + "my-conn": {Name: "my-conn", Target: "https://conn-target.com"}, + } + + tb := project.Toolbox{ + Name: "test", + Tools: []map[string]any{ + { + "type": "mcp", + "project_connection_id": "my-conn", + "server_url": "https://custom-url.com", + }, + }, + } + + enrichToolboxFromConnections(&tb, connByName) + + // server_url was already set — should not be overridden. + assert.Equal(t, "https://custom-url.com", tb.Tools[0]["server_url"]) + // server_label should still be filled in. + assert.Equal(t, "my-conn", tb.Tools[0]["server_label"]) +} + +func TestEnrichToolboxFromConnections_EmptyTarget(t *testing.T) { + t.Parallel() + + connByName := map[string]toolboxConnection{ + "no-target": {Name: "no-target", Target: ""}, + } + + tb := project.Toolbox{ + Name: "test", + Tools: []map[string]any{ + {"type": "mcp", "project_connection_id": "no-target"}, + }, + } + + enrichToolboxFromConnections(&tb, connByName) + + // Empty target → server_url should NOT be set. + _, hasURL := tb.Tools[0]["server_url"] + assert.False(t, hasURL) + // server_label should still be set. + assert.Equal(t, "no-target", tb.Tools[0]["server_label"]) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize.go new file mode 100644 index 00000000000..2bdfdb4cd8c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize.go @@ -0,0 +1,579 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize.go implements the top-level "optimize" command, which submits +// agent optimization jobs. It resolves the agent, loads or builds a config, +// prompts for instruction/skills/model, and polls for results. +// +// Subcommands (status, list, cancel, apply, deploy) are registered here +// and implemented in their own files. + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + "time" + + "azureaiagent/internal/pkg/agents/opt_eval" + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// optimizeAgentContext holds the resolved agent name and project directory +// for an optimization operation. +type optimizeAgentContext struct { + agentName string // deployed agent name + agentProject string // agent project directory (empty if not in an azd project) +} + +// resolveOptimizeAgent resolves the agent name and project directory. +// Resolution order: +// 1. Explicit --agent flag +// 2. azd project context (resolveAgentService + environment variables) +// 3. Error with guidance +func resolveOptimizeAgent(ctx context.Context, flagValue string, noPrompt bool) (*optimizeAgentContext, error) { + if flagValue != "" { + return &optimizeAgentContext{agentName: flagValue}, nil + } + + // Try resolving from azd project — single resolveAgentService call + // to get both project path and agent info from environment. + azdClient, err := azdext.NewAzdClient() + if err == nil { + defer azdClient.Close() + + svc, project, svcErr := resolveAgentService(ctx, azdClient, "", noPrompt) + if svcErr == nil && svc != nil && project != nil { + agentProject := filepath.Join(project.Path, svc.RelativePath) + serviceKey := toServiceKey(svc.Name) + + // Read agent name from azd environment + envResp, envErr := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if envErr == nil && envResp.Environment != nil { + nameKey := fmt.Sprintf("AGENT_%s_NAME", serviceKey) + if v, e := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: nameKey, + }); e == nil && v.Value != "" { + return &optimizeAgentContext{ + agentName: v.Value, + agentProject: agentProject, + }, nil + } + } + } + } + + return nil, fmt.Errorf("agent name is required: use --agent , or run from an azd project after 'azd deploy'") +} + +// optimizeFlags holds CLI flags for the optimize (submit) command. +type optimizeFlags struct { + configFile string // path to YAML config file + agent string // agent name override + evalModel string // model for evaluation + targetAttributes []string // optimization targets (instruction, skill) + noWait bool // return immediately after submission + pollInterval int // polling interval in seconds + optimizeConnectionFlags +} + +// newOptimizeCommand creates the top-level "optimize" command and registers its subcommands. +func newOptimizeCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &optimizeFlags{} + action := &OptimizeAction{flags: flags, noPrompt: extCtx.NoPrompt} + + cmd := &cobra.Command{ + Use: "optimize [agent-name]", + Short: "Evaluate and optimize AI agents.", + Long: `Evaluate and optimize AI agents — baseline scoring and iterative improvement. + +When run without a subcommand, submits an optimization job. +Use --config for a custom YAML spec, or just provide the agent name to use sensible defaults.`, + Example: ` # Optimize (auto-detect agent from azd project) + azd ai agent optimize + + # Optimize a specific agent + azd ai agent optimize my-agent + + # Optimize with skill target + azd ai agent optimize --target skill + + # Optimize with multiple target attributes + azd ai agent optimize --target instruction --target skill + + # Full control via config file + azd ai agent optimize --config spec.yaml + + # Subcommands + azd ai agent optimize status --watch + azd ai agent optimize list + azd ai agent optimize cancel + azd ai agent optimize deploy --candidate `, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + // Positional arg fills in agent name + if len(args) > 0 && flags.agent == "" { + flags.agent = args[0] + } + + return action.Run(ctx, cmd) + }, + } + + cmd.Flags().StringVarP(&flags.configFile, "config", "c", "", "Path to YAML config file (optional — uses defaults if omitted)") + cmd.Flags().StringVarP(&flags.agent, "agent", "a", "", "Agent name (auto-detected from azd project if omitted)") + cmd.Flags().StringVarP(&flags.evalModel, "eval-model", "m", defaultEvalModel, "Model for evaluation") + cmd.Flags().StringArrayVarP(&flags.targetAttributes, "target", "t", nil, + "Target attribute for optimization: instruction, skill (repeatable)") + cmd.Flags().BoolVar(&flags.noWait, "no-wait", false, "Submit job and return immediately without waiting for completion") + cmd.Flags().IntVar(&flags.pollInterval, "poll-interval", 5, "Polling interval in seconds") + flags.optimizeConnectionFlags.register(cmd) + + cmd.AddCommand(newOptimizeStatusCommand()) + cmd.AddCommand(newOptimizeListCommand()) + cmd.AddCommand(newOptimizeCancelCommand()) + cmd.AddCommand(newOptimizeApplyCommand(extCtx)) + cmd.AddCommand(newOptimizeDeployCommand()) + + return cmd +} + +// OptimizeAction implements the optimize (submit job) command. +type OptimizeAction struct { + flags *optimizeFlags + noPrompt bool +} + +// Run executes the optimize command: resolves the agent, loads/builds the config, applies overrides, submits the job, and optionally polls for results. +func (a *OptimizeAction) Run(ctx context.Context, cmd *cobra.Command) error { + endpoint, err := a.flags.resolve(ctx) + if err != nil { + return err + } + + cfg, configSource, agentProject, err := a.resolveConfig(ctx) + if err != nil { + return err + } + hasProject := agentProject != "" + + if err := a.applyOverrides(ctx, cfg, agentProject); err != nil { + return err + } + + out := cmd.OutOrStdout() + bold := color.New(color.Bold) + + _, _ = bold.Fprintf(out, "Optimizing agent %q...\n", cfg.Agent.Name) + if configSource == "" { + fmt.Fprintf(out, " Dataset: built-in (3 tasks, 12 criteria)\n") + } else { + fmt.Fprintf(out, " Config: %s\n", configSource) + } + + resp, client, err := a.submitJob(ctx, out, endpoint, cfg, agentProject) + if err != nil { + return err + } + + if !a.flags.noWait && !optimize_api.IsTerminal(resp.Status) { + finalStatus, err := pollOptimizeJob(cmd, client, a.flags.pollInterval, resp.OperationID) + if err != nil { + return err + } + printOptimizeResults(out, finalStatus, hasProject) + } + + return nil +} + +// resolveConfig loads or builds an OptimizeConfig from flags, eval.yaml +// detection, and agent resolution. Returns the config, its source path +// (empty if using defaults), and the agent project directory. +func (a *OptimizeAction) resolveConfig( + ctx context.Context, +) (cfg *OptimizeConfig, configSource, agentProject string, err error) { + if a.flags.configFile != "" { + cfg, err = LoadOptimizeConfig(a.flags.configFile) + if err != nil { + return nil, "", "", fmt.Errorf("%w\n\nCheck that the file path is correct and contains valid YAML", err) + } + + // Even with explicit --config, try to reconcile agent name with the environment. + resolved, resolveErr := resolveOptimizeAgent(ctx, a.flags.agent, a.noPrompt) + if resolveErr == nil { + agentProject = resolved.agentProject + reconcileConfigAgentName(&cfg.Agent, resolved.agentName, a.flags.configFile) + } + + return cfg, a.flags.configFile, agentProject, nil + } + + resolved, err := resolveOptimizeAgent(ctx, a.flags.agent, a.noPrompt) + if err != nil { + return nil, "", "", err + } + agentProject = resolved.agentProject + + // Check if eval.yaml exists in the agent project and offer to use it. + if resolved.agentProject != "" { + evalPath := filepath.Join(resolved.agentProject, defaultEvalConfigName) + if _, statErr := os.Stat(evalPath); statErr == nil && !a.noPrompt { + azdClient, clientErr := azdext.NewAzdClient() + if clientErr == nil { + defer azdClient.Close() + resp, promptErr := azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: fmt.Sprintf("Found %s in project. Use it for optimization?", defaultEvalConfigName), + DefaultValue: new(true), + }, + }) + if promptErr == nil && resp.Value != nil && *resp.Value { + cfg, err = LoadOptimizeConfig(evalPath) + if err != nil { + return nil, "", "", fmt.Errorf("failed to load %s: %w", evalPath, err) + } + configSource = evalPath + } + } + } + } + + if cfg == nil { + cfg = defaultOptimizeConfig(resolved.agentName) + } else { + reconcileConfigAgentName(&cfg.Agent, resolved.agentName, configSource) + } + + return cfg, configSource, agentProject, nil +} + +// applyOverrides applies CLI flag overrides, resolves baseline agent config, +// and interactively fills missing instruction/skills/model values. +func (a *OptimizeAction) applyOverrides( + ctx context.Context, + cfg *OptimizeConfig, + agentProject string, +) error { + if err := cfg.Validate(); err != nil { + return fmt.Errorf("invalid config: %w", err) + } + + hasProject := agentProject != "" + + // CLI flags override config values. + if a.flags.evalModel != "" { + cfg.Options.EvalModel = a.flags.evalModel + } + if len(a.flags.targetAttributes) > 0 { + cfg.Options.TargetAttributes = a.flags.targetAttributes + } + + // Resolve agent config: try existing config pointer, then default baseline. + if hasProject { + mergeAgentBaseline(cfg, agentProject) + } + + // When baseline config is detected, show resolved values and let the user confirm. + if cfg.Agent.ConfigFile != "" && hasProject && !a.noPrompt { + if err := promptOptimizeConfigConfirmation(ctx, cfg, agentProject); err != nil { + return err + } + } + + // Resolve relative skill_dir against agent project directory. + if cfg.SkillDir != "" && hasProject && !filepath.IsAbs(cfg.SkillDir) { + cfg.SkillDir = filepath.Join(agentProject, cfg.SkillDir) + } + + // Resolve relative tools_file against agent project directory. + if cfg.ToolsFile != "" && hasProject && !filepath.IsAbs(cfg.ToolsFile) { + cfg.ToolsFile = filepath.Join(agentProject, cfg.ToolsFile) + } + + // Resolve agent instruction using a well-defined lifecycle: + // 1. Config dir pointer (agent.config in eval.yaml) — resolves from metadata.yaml + // 2. Config file (eval.yaml / --config) — instruction in the agent section (inline or file reference) + // 3. Interactive prompt — ask the user to provide inline text or a file path + if err := resolveOptimizeSystemPrompt(ctx, cfg, agentProject, hasProject, a.noPrompt); err != nil { + return err + } + + // Resolve skill_dir: auto-detect, check baseline, or prompt user. + if cfg.SkillDir == "" && hasProject { + if err := resolveOptimizeSkillDir(ctx, cfg, agentProject, a.noPrompt); err != nil { + return err + } + } + + // Resolve target_config.model: prompt user if not set. + if (cfg.Options.TargetConfig == nil || len(cfg.Options.TargetConfig.Model) == 0) && !a.noPrompt { + if err := resolveOptimizeTargetModels(ctx, cfg); err != nil { + return err + } + } + + // Resolve reflection_model: prompt user if not set. + if cfg.Options.ReflectionModel == "" && !a.noPrompt { + if err := resolveOptimizeReflectionModel(ctx, cfg); err != nil { + return err + } + } + + return nil +} + +// mergeAgentBaseline resolves the baseline agent config and merges missing +// fields (instruction, model, skills, tools) into the OptimizeConfig. +func mergeAgentBaseline(cfg *OptimizeConfig, agentProject string) { + var existing *opt_eval.Config + if cfg.Agent.ConfigFile != "" { + existing = &opt_eval.Config{Agent: cfg.Agent} + } + agentCfg := resolveAgentConfig(existing, agentProject) + if agentCfg == nil { + return + } + cfg.Agent.ConfigFile = agentCfg.ConfigFile + if cfg.Agent.Instruction.IsEmpty() && agentCfg.InstructionFile != "" { + cfg.Agent.Instruction.File = agentCfg.InstructionFile + } + if cfg.Agent.Model == "" { + cfg.Agent.Model = agentCfg.Model + } + if cfg.SkillDir == "" { + cfg.SkillDir = agentCfg.SkillDir + } + if cfg.ToolsFile == "" { + cfg.ToolsFile = agentCfg.ToolsFile + } + if existing == nil { + fmt.Printf(" Baseline: %s\n", filepath.Join(agentProject, agentCfg.ConfigFile)) + } +} + +// submitJob builds the optimization request, saves the baseline config, +// submits the job, and prints initial status. +func (a *OptimizeAction) submitJob( + ctx context.Context, + out io.Writer, + endpoint string, + cfg *OptimizeConfig, + agentProject string, +) (*optimize_api.OptimizeResponse, *optimize_api.OptimizeClient, error) { + credential, err := newAgentCredential() + if err != nil { + return nil, nil, err + } + + client := optimize_api.NewOptimizeClient(endpoint, credential) + + optimizeReq, err := cfg.ToRequest(endpoint) + if err != nil { + return nil, nil, fmt.Errorf("failed to build optimization request: %w", err) + } + + if body, jsonErr := json.MarshalIndent(optimizeReq, "", " "); jsonErr == nil { + log.Printf("[debug] optimization request:\n%s", body) + } + + // Save baseline config before starting optimization. + hasProject := agentProject != "" + if hasProject { + if err := writeBaselineConfig(agentProject, baselineParams{ + Model: optimizeReq.Agent.Model, + Instruction: optimizeReq.Agent.SystemPrompt, + SkillDir: cfg.SkillDir, + ToolsFile: cfg.ToolsFile, + }); err != nil { + fmt.Fprintf(out, " warning: failed to save baseline config: %s\n", err) + } else { + baselineMetaPath := opt_eval.BaselineConfigRelPath() + fmt.Fprintf(out, " Baseline saved to %s\n", baselineMetaPath) + if cfg.Agent.ConfigFile == "" { + cfg.Agent.ConfigFile = baselineMetaPath + } + } + } + + resp, err := client.StartOptimize(ctx, optimizeReq) + if err != nil { + return nil, nil, fmt.Errorf( + "failed to submit optimization job: %w\n\nCheck that the endpoint %q is reachable", err, endpoint) + } + + fmt.Fprintf(out, " Job ID: %s\n", color.CyanString(resp.OperationID)) + fmt.Fprintf(out, " Status: %s\n", resp.Status) + + printOptimizePortalLink(ctx, out, cfg.Agent.Name, resp.OperationID) + fmt.Fprintln(out) + + saveLastOptimizeJobID(ctx, resp.OperationID) + + return resp, client, nil +} + +// pollOptimizeJob polls the optimization job until it reaches a terminal state. +func pollOptimizeJob( + cmd *cobra.Command, + client *optimize_api.OptimizeClient, + pollInterval int, + operationID string, +) (*optimize_api.OptimizeJobStatus, error) { + out := cmd.OutOrStdout() + spinFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + frameIdx := 0 + startTime := time.Now() + + poller := &optimize_api.Poller{ + Client: client, + OperationID: operationID, + Interval: time.Duration(pollInterval) * time.Second, + OnProgress: func(status *optimize_api.OptimizeJobStatus) { + elapsed := time.Since(startTime).Truncate(time.Second) + spin := spinFrames[frameIdx%len(spinFrames)] + frameIdx++ + + progress := fmt.Sprintf("\r %s %s", spin, status.Status) + if status.Progress != nil { + p := status.Progress + if p.CurrentTargetAttribute != "" { + progress += fmt.Sprintf(" · strategy: %s", p.CurrentTargetAttribute) + } + if p.CurrentIteration > 0 { + progress += fmt.Sprintf(" · iteration %d", p.CurrentIteration) + } + if p.BestScore > 0 { + progress += fmt.Sprintf(" · score: %.2f", p.BestScore) + } + } + progress += fmt.Sprintf(" · %s", elapsed) + fmt.Fprintf(out, "%-80s", progress) + }, + } + + finalStatus, err := poller.PollUntilDone(cmd.Context()) + fmt.Fprintln(out) + if err != nil { + return nil, fmt.Errorf("failed while polling optimization job: %w", err) + } + + return finalStatus, nil +} + +// printOptimizeResults prints the optimization results table and next-step commands. +func printOptimizeResults(out io.Writer, status *optimize_api.OptimizeJobStatus, hasProject bool) { + if status.Error != nil { + fmt.Fprintf(out, "\n %s %s\n", color.RedString("Error:"), status.Error.Message) + } + + if len(status.Candidates) == 0 { + return + } + + bold := color.New(color.Bold) + green := color.New(color.FgGreen) + + _, _ = bold.Fprintln(out, "\nResults:") + fmt.Fprintf(out, " %-20s %7s %7s %8s\n", "Candidate", "Score", "Pass", "Tokens") + fmt.Fprintf(out, " %-20s %7s %7s %8s\n", + strings.Repeat("─", 20), strings.Repeat("─", 7), + strings.Repeat("─", 7), strings.Repeat("─", 8)) + + bestName := "" + if status.Best != nil { + bestName = status.Best.Name + } + + for _, c := range status.Candidates { + isBest := c.Name == bestName + name := c.Name + if isBest { + name += " ★" + } + + line := fmt.Sprintf(" %-20s %7.2f %6.0f%% %8.0f", name, c.AvgScore, c.PassRate*100, c.AvgTokens) + if isBest { + _, _ = green.Fprintln(out, line) + } else { + fmt.Fprintln(out, line) + } + } + + // Print candidate IDs for deploy + hasIDs := false + for _, c := range status.Candidates { + if c.CandidateID != "" { + if !hasIDs { + fmt.Fprintf(out, "\n Candidate IDs:\n") + hasIDs = true + } + marker := " " + if c.Name == bestName { + marker = "★ " + } + fmt.Fprintf(out, " %s%-20s %s\n", marker, c.Name, c.CandidateID) + } + } + + // Print next-step commands for best candidate + if status.Best != nil && status.Best.CandidateID != "" { + agentName := "" + if status.Agent != nil { + agentName = status.Agent.AgentName + } + if hasProject { + fmt.Fprintf(out, "\n Apply the best candidate locally, then deploy:\n") + fmt.Fprintf(out, " azd ai agent optimize apply --candidate %s\n", status.Best.CandidateID) + fmt.Fprintf(out, " azd deploy\n") + } else { + fmt.Fprintf(out, "\n Deploy the best candidate:\n") + fmt.Fprintf(out, " azd ai agent optimize deploy --candidate %s --agent %s\n", + status.Best.CandidateID, agentName) + } + } + fmt.Fprintln(out) +} + +// formatOptimizeStatus returns a colorized string for the given job status. +func formatOptimizeStatus(status string) string { + switch status { + case optimize_api.StatusCompleted: + return color.GreenString(status) + case optimize_api.StatusFailed: + return color.RedString(status) + case optimize_api.StatusCancelled: + return color.YellowString(status) + case optimize_api.StatusRunning: + return color.CyanString(status) + case optimize_api.StatusPending: + return color.BlueString(status) + default: + return status + } +} + +// truncateString truncates s to maxLen characters, appending "..." if trimmed. +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply.go new file mode 100644 index 00000000000..9716d3190d5 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply.go @@ -0,0 +1,587 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_apply.go implements the "optimize apply" command, which downloads +// an optimization candidate and applies it locally to the azd project. +// +// It writes the candidate's instruction, skills, and tool definitions +// into .agent_configs//, updates agent.yaml environment +// variables, and shows a diff summary (prompt and skills) against the +// baseline. + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/pkg/agents/opt_eval" + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" + "go.yaml.in/yaml/v3" +) + +// agentConfigsDir aliases the shared constant for local use. +const agentConfigsDir = opt_eval.AgentConfigsDir + +// optimizeApplyFlags holds CLI flags for the optimize apply command. +type optimizeApplyFlags struct { + candidate string // candidate ID from optimization results + agent string // agent service name + optimizeConnectionFlags +} + +func newOptimizeApplyCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &optimizeApplyFlags{} + action := &OptimizeApplyAction{flags: flags, noPrompt: extCtx.NoPrompt} + + cmd := &cobra.Command{ + Use: "apply", + Short: "Apply optimized candidate configuration locally to your azd project.", + Long: `Download the optimized configuration and skill files from an optimization +candidate and write them into your local azd project under .agent_configs/. + +After applying, run 'azd deploy' to deploy the optimized agent version.`, + Example: ` # Apply candidate config locally, then deploy + azd ai agent optimize apply --candidate candidate_abc123 + azd deploy`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + return action.Run(ctx, cmd) + }, + } + + cmd.Flags().StringVar(&flags.candidate, "candidate", "", "Candidate ID from optimization results (required)") + cmd.Flags().StringVar(&flags.agent, "agent", "", "Agent service name (auto-detected from azure.yaml)") + _ = cmd.MarkFlagRequired("candidate") + flags.optimizeConnectionFlags.register(cmd) + + return cmd +} + +// OptimizeApplyAction implements the optimize apply command. +type OptimizeApplyAction struct { + flags *optimizeApplyFlags + noPrompt bool +} + +func (a *OptimizeApplyAction) Run(ctx context.Context, cmd *cobra.Command) error { + out := cmd.OutOrStdout() + bold := color.New(color.Bold) + + azdClient, err := azdext.NewAzdClient() + if err != nil { + return fmt.Errorf("failed to create azd client: %w\n\n"+ + "'optimize apply' requires an azd project. Use 'optimize deploy' for standalone API deployment", err) + } + defer azdClient.Close() + + svc, project, err := resolveAgentService(ctx, azdClient, a.flags.agent, a.noPrompt) + if err != nil || project == nil || svc == nil { + return fmt.Errorf("could not resolve agent service in azd project: %w\n\n"+ + "Run 'azd ai agent init' first, or use 'optimize deploy' for standalone API deployment", err) + } + + return a.apply(ctx, azdClient, svc, project, out, bold) +} + +// apply downloads and writes the candidate config, updates agent.yaml, +// stores state, and prints a diff summary. +func (a *OptimizeApplyAction) apply( + ctx context.Context, + azdClient *azdext.AzdClient, + svc *azdext.ServiceConfig, + project *azdext.ProjectConfig, + out io.Writer, + bold *color.Color, +) error { + projectEndpoint, err := resolveProjectEndpointForDeploy(ctx, &a.flags.optimizeConnectionFlags) + if err != nil { + return err + } + + serviceDir := filepath.Join(project.Path, svc.RelativePath) + candidateDir := filepath.Join(serviceDir, agentConfigsDir, a.flags.candidate) + + _, _ = bold.Fprintf(out, "Applying optimization candidate %s...\n\n", a.flags.candidate) + + credential, err := newAgentCredential() + if err != nil { + return err + } + optClient := optimize_api.NewOptimizeClient(projectEndpoint, credential) + + // Step 1: Fetch candidate config from the optimization service. + fmt.Fprintf(out, " Fetching candidate config...\n") + candidateConfig, err := optClient.GetCandidateConfig(ctx, a.flags.candidate) + if err != nil { + return fmt.Errorf("failed to fetch candidate config: %w", err) + } + + if err := os.MkdirAll(candidateDir, 0750); err != nil { + return fmt.Errorf("failed to create optimization directory: %w", err) + } + + // Clean up other candidate directories, keeping only baseline and the current candidate. + cleanOtherCandidates(filepath.Join(serviceDir, agentConfigsDir), a.flags.candidate, out) + + // Step 2: Download skill files into the candidate directory (before metadata.yaml + // so the skills/ dir exists when writeAgentConfigFromCandidate checks for it). + if n, dlErr := downloadSkillFilesToDir(ctx, optClient, a.flags.candidate, candidateDir, out); dlErr != nil { + fmt.Fprintf(out, " warning: failed to download skill files: %s\n", dlErr) + } else if n > 0 { + fmt.Fprintf(out, " Downloaded %d skill file(s)\n", n) + } + + // Write metadata.yaml, instructions.md, skills, and tool definitions for the candidate. + if err := writeAgentConfigFromCandidate(candidateDir, candidateConfig); err != nil { + return fmt.Errorf("failed to write candidate config: %w", err) + } + fmt.Fprintf(out, " → %s\n", filepath.Join(candidateDir, opt_eval.MetadataFile)) + + // Step 3: Write OPTIMIZATION_LOCAL_DIR and OPTIMIZATION_CANDIDATE_ID into agent.yaml + // so the deploy pipeline knows which local optimization config to use. + agentYamlPath := filepath.Join(serviceDir, "agent.yaml") + fmt.Fprintf(out, " Updating %s...\n", agentYamlPath) + if err := upsertAgentYamlEnvVar(agentYamlPath, "OPTIMIZATION_LOCAL_DIR", agentConfigsDir); err != nil { + return fmt.Errorf("failed to update agent.yaml: %w", err) + } + if err := upsertAgentYamlEnvVar(agentYamlPath, "OPTIMIZATION_CANDIDATE_ID", a.flags.candidate); err != nil { + return fmt.Errorf("failed to update agent.yaml: %w", err) + } + + // Step 4: Store candidate ID in the azd environment for tracking. + serviceKey := toServiceKey(svc.Name) + envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if err != nil { + return fmt.Errorf("failed to get current environment: %w", err) + } + + candidateKey := fmt.Sprintf("AGENT_%s_OPTIMIZATION_CANDIDATE_ID", serviceKey) + if _, err := azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: candidateKey, + Value: a.flags.candidate, + }); err != nil { + return fmt.Errorf("failed to store candidate ID in azd environment: %w", err) + } + + // Done — prompt the user to deploy. + fmt.Fprintln(out) + _, _ = color.New(color.FgGreen, color.Bold).Fprintf(out, + " ✓ Candidate %s applied to %s\n\n", + a.flags.candidate, filepath.Join(agentConfigsDir, a.flags.candidate)) + fmt.Fprintf(out, " Run %s to deploy the optimized agent.\n", + color.CyanString("azd deploy --service %s", svc.Name)) + + // Show instruction diff (baseline → optimized). + printPromptDiff(out, serviceDir, a.flags.candidate, candidateConfig) + + // Point the user to the config folders for other differences (skills, tools, etc.). + baselinePath := filepath.Join(serviceDir, agentConfigsDir, opt_eval.BaselineDir) + candidatePath := filepath.Join(serviceDir, agentConfigsDir, a.flags.candidate) + fmt.Fprintf(out, "\n For other changes (skills, tools, etc.), compare the files in:\n") + fmt.Fprintf(out, " Baseline: %s\n", color.CyanString(baselinePath)) + fmt.Fprintf(out, " Optimized: %s\n", color.CyanString(candidatePath)) + + return nil +} + +// agentConfigMetadata is the YAML structure written as metadata.yaml in each +// agent config version directory (baseline or candidate). +// +// It uses file pointers instead of embedding large content inline: +// - instruction_file → points to instructions.md in the same directory +// - skill_dir → points to the skills/ subdirectory +// - tools_file → points to a tools definition file (optional) +type agentConfigMetadata struct { + Name string `yaml:"name"` + Model string `yaml:"model,omitempty"` + InstructionFile string `yaml:"instruction_file,omitempty"` + SkillDir string `yaml:"skill_dir,omitempty"` + ToolsFile string `yaml:"tools_file,omitempty"` +} + +// loadBaselineConfig reads the baseline metadata.yaml from +// /.agent_configs/baseline/metadata.yaml and resolves +// file pointers to absolute paths. +func loadBaselineConfig(agentProject string) (*agentConfigMetadata, error) { + baseDir := filepath.Join(agentProject, agentConfigsDir, opt_eval.BaselineDir) + metaPath := filepath.Join(baseDir, opt_eval.MetadataFile) + data, err := os.ReadFile(metaPath) //nolint:gosec // path derived from project directory + if err != nil { + return nil, err + } + + var meta agentConfigMetadata + if err := yaml.Unmarshal(data, &meta); err != nil { + return nil, fmt.Errorf("parsing baseline metadata: %w", err) + } + return &meta, nil +} + +// resolveInstructions reads the instruction content from the metadata's +// instruction_file, resolved relative to configDir. +func (m *agentConfigMetadata) resolveInstructions(configDir string) string { + if m.InstructionFile == "" { + return "" + } + path := m.InstructionFile + if !filepath.IsAbs(path) { + path = filepath.Join(configDir, path) + } + data, err := os.ReadFile(path) //nolint:gosec // path derived from project directory + if err != nil { + return "" + } + return string(data) +} + +// resolveSkillDir returns the absolute path to the skill directory, +// resolved relative to configDir. Returns empty if not set. +func (m *agentConfigMetadata) resolveSkillDir(configDir string) string { + if m.SkillDir == "" { + return "" + } + if filepath.IsAbs(m.SkillDir) { + return m.SkillDir + } + return filepath.Join(configDir, m.SkillDir) +} + +// resolveToolsFile returns the absolute path to the tools file, +// resolved relative to configDir. Returns empty if not set. +func (m *agentConfigMetadata) resolveToolsFile(configDir string) string { + if m.ToolsFile == "" { + return "" + } + if filepath.IsAbs(m.ToolsFile) { + return m.ToolsFile + } + return filepath.Join(configDir, m.ToolsFile) +} + +// writeAgentConfigFromCandidate writes metadata.yaml, instructions.md, skill +// files, and tool definitions for an optimization candidate into the given +// directory. No config.json is written — all content is decomposed into +// individual files with pointers in metadata.yaml. +func writeAgentConfigFromCandidate(candidateDir string, rawConfig json.RawMessage) error { + meta := agentConfigMetadata{} + + // Unmarshal the raw JSON into a generic map for field extraction. + var m map[string]any + if err := json.Unmarshal(rawConfig, &m); err != nil { + return fmt.Errorf("parsing candidate config JSON: %w", err) + } + if m != nil { + if v, exists := m["name"]; exists { + if s, ok := v.(string); ok { + meta.Name = s + } + } + if v, exists := m["agentName"]; exists { + if s, ok := v.(string); ok { + meta.Name = s + } + } + if v, exists := m["model"]; exists { + if s, ok := v.(string); ok { + meta.Model = s + } + } + } + + // Write instructions.md from the candidate's system prompt. + instructions := extractInstructions(m) + if instructions != "" { + instructionPath := filepath.Join(candidateDir, opt_eval.InstructionFile) + if err := os.WriteFile(instructionPath, []byte(instructions), 0600); err != nil { + return fmt.Errorf("writing candidate instructions: %w", err) + } + meta.InstructionFile = opt_eval.InstructionFile + } + + // Write inline skills from the candidate config as individual files. + if m != nil { + if err := writeInlineSkills(candidateDir, m); err != nil { + return fmt.Errorf("writing candidate skills: %w", err) + } + } + + // Set skill_dir pointer if the skills/ dir exists (from inline or downloaded skills). + skillDir := filepath.Join(candidateDir, opt_eval.SkillsDir) + if info, err := os.Stat(skillDir); err == nil && info.IsDir() { + meta.SkillDir = opt_eval.SkillsDir + } + + // Write the candidate config as tools.json (preserves original structure). + if m != nil { + if err := writeToolsFile(candidateDir, m); err != nil { + return fmt.Errorf("writing candidate tools file: %w", err) + } + if _, err := os.Stat(filepath.Join(candidateDir, opt_eval.ToolsFile)); err == nil { + meta.ToolsFile = opt_eval.ToolsFile + } + } + + // Write metadata.yaml. + data, err := yaml.Marshal(meta) + if err != nil { + return fmt.Errorf("serializing candidate metadata: %w", err) + } + metaPath := filepath.Join(candidateDir, opt_eval.MetadataFile) + if err := os.WriteFile(metaPath, data, 0600); err != nil { + return fmt.Errorf("writing candidate metadata: %w", err) + } + + return nil +} + +// writeInlineSkills extracts the "skills" array from a candidate config and +// writes each skill as skills//SKILL.md. Each file contains a YAML +// front-matter header with the skill name and description, followed by the +// skill body. +func writeInlineSkills(candidateDir string, config map[string]any) error { + skillsRaw, exists := config["skills"] + if !exists { + return nil + } + skills, ok := skillsRaw.([]any) + if !ok || len(skills) == 0 { + return nil + } + + for _, s := range skills { + sm, ok := s.(map[string]any) + if !ok { + continue + } + name, _ := sm["name"].(string) + if name == "" { + continue + } + body, _ := sm["body"].(string) + description, _ := sm["description"].(string) + + skillSubDir := filepath.Join(candidateDir, opt_eval.SkillsDir, name) + if err := os.MkdirAll(skillSubDir, 0750); err != nil { + return fmt.Errorf("creating skill directory %s: %w", name, err) + } + + // Build the skill file content with YAML front-matter. + var content strings.Builder + content.WriteString("---\n") + content.WriteString(fmt.Sprintf("name: %s\n", name)) + if description != "" { + content.WriteString(fmt.Sprintf("description: %s\n", description)) + } + content.WriteString("---\n") + if body != "" { + content.WriteString(body) + if !strings.HasSuffix(body, "\n") { + content.WriteString("\n") + } + } + + filePath := filepath.Join(skillSubDir, "SKILL.md") + if err := os.WriteFile(filePath, []byte(content.String()), 0600); err != nil { + return fmt.Errorf("writing skill %s: %w", name, err) + } + } + return nil +} + +// writeToolsFile writes the candidate config as tools.json, preserving its +// original structure (may be a list or an object). +func writeToolsFile(candidateDir string, config map[string]any) error { + toolDefs, hasDefs := config["toolDefinitions"] + toolDescriptions, hasToolDescriptions := config["toolDescriptions"] + if !hasDefs && !hasToolDescriptions { + return nil + } + + // Write whichever is present. If only one key exists, write its value + // directly (preserves array or object). If both exist, wrap in an object. + var payload any + switch { + case hasDefs && hasToolDescriptions: + payload = map[string]any{ + "toolDefinitions": toolDefs, + "toolDescriptions": toolDescriptions, + } + case hasDefs: + payload = toolDefs + default: + payload = toolDescriptions + } + + data, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return fmt.Errorf("serializing tools file: %w", err) + } + + return os.WriteFile(filepath.Join(candidateDir, opt_eval.ToolsFile), data, 0600) +} + +// downloadSkillFilesToDir fetches the candidate manifest, downloads all skill +// files, and writes them into the given directory. Returns the number of files written. +func downloadSkillFilesToDir( + ctx context.Context, + client *optimize_api.OptimizeClient, + candidateID string, + destDir string, + out io.Writer, +) (int, error) { + manifest, err := client.GetCandidate(ctx, candidateID) + if err != nil { + return 0, fmt.Errorf("fetching candidate manifest: %w", err) + } + + var skillFiles []optimize_api.CandidateFile + for _, f := range manifest.Files { + if isSkillFile(f) { + skillFiles = append(skillFiles, f) + } + } + if len(skillFiles) == 0 { + return 0, nil + } + + count := 0 + for _, f := range skillFiles { + if f.Path == "" { + continue + } + + content, err := client.GetCandidateFile(ctx, candidateID, f.Path) + if err != nil { + fmt.Fprintf(out, " warning: failed to download skill file %s: %s\n", f.Path, err) + continue + } + + outPath, pathErr := opt_eval.SafePath(destDir, f.Path) + if pathErr != nil { + fmt.Fprintf(out, " warning: skipping file %s: path escapes destination directory\n", f.Path) + continue + } + if err := os.MkdirAll(filepath.Dir(outPath), 0750); err != nil { + return count, fmt.Errorf("creating directory for %s: %w", f.Path, err) + } + + if err := os.WriteFile(outPath, []byte(content), 0600); err != nil { + return count, fmt.Errorf("writing skill file %s: %w", f.Path, err) + } + + fmt.Fprintf(out, " → %s (%d bytes)\n", outPath, len(content)) + count++ + } + + return count, nil +} + +// cleanOtherCandidates removes all subdirectories in the optimization folder +// except the baseline and the candidate being applied. +func cleanOtherCandidates(optimizeDir, currentCandidate string, out io.Writer) { + entries, err := os.ReadDir(optimizeDir) + if err != nil { + return + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + name := entry.Name() + if name == opt_eval.BaselineDir || name == currentCandidate { + continue + } + dir := filepath.Join(optimizeDir, name) + if err := os.RemoveAll(dir); err != nil { + fmt.Fprintf(out, " warning: failed to remove old candidate %s: %s\n", name, err) + } else { + fmt.Fprintf(out, " Removed old candidate: %s\n", name) + } + } +} + +// extractInstructions retrieves the system prompt string from a candidate config +// returned by the optimization service. +func extractInstructions(m map[string]any) string { + if m == nil { + return "" + } + if v, exists := m["systemPrompt"]; exists { + if s, ok := v.(string); ok { + return s + } + } + if v, exists := m["instructions"]; exists { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// maxDiffPreviewLines is the max lines shown per section in the prompt diff preview. +const maxDiffPreviewLines = 4 + +// printPromptDiff displays an abbreviated instruction diff (baseline → optimized) +// with a short preview of each. +func printPromptDiff(out io.Writer, serviceDir, candidateID string, candidateConfig json.RawMessage) { + var m map[string]any + if err := json.Unmarshal(candidateConfig, &m); err != nil { + return + } + optimized := extractInstructions(m) + if optimized == "" { + return + } + + baseDir := filepath.Join(serviceDir, agentConfigsDir, opt_eval.BaselineDir) + baseline, err := loadBaselineConfig(serviceDir) + if err != nil { + return + } + baselineText := baseline.resolveInstructions(baseDir) + if baselineText == "" { + return + } + baselineLines := strings.Split(baselineText, "\n") + optimizedLines := strings.Split(optimized, "\n") + + fmt.Fprintf(out, "\n Instruction diff (baseline → optimized):\n\n") + + removed := color.New(color.FgRed) + _, _ = removed.Fprintf(out, " — Baseline (%d lines, %d chars):\n", + len(baselineLines), len(baselineText)) + printPreviewLines(out, baselineLines, "- ", removed) + + fmt.Fprintln(out) + + added := color.New(color.FgGreen) + _, _ = added.Fprintf(out, " — Optimized (%d lines, %d chars):\n", + len(optimizedLines), len(optimized)) + printPreviewLines(out, optimizedLines, "+ ", added) +} + +// printPreviewLines prints up to maxDiffPreviewLines with a prefix, then "..." if truncated. +func printPreviewLines(out io.Writer, lines []string, prefix string, c *color.Color) { + limit := min(len(lines), maxDiffPreviewLines) + for _, line := range lines[:limit] { + _, _ = c.Fprintf(out, " %s%s\n", prefix, line) + } + if len(lines) > maxDiffPreviewLines { + _, _ = c.Fprintf(out, " %s... (%d more lines)\n", prefix, len(lines)-maxDiffPreviewLines) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply_test.go new file mode 100644 index 00000000000..ed9e1488a97 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_apply_test.go @@ -0,0 +1,520 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- newOptimizeApplyCommand — command shape ---- + +func TestNewOptimizeApplyCommand_UseString(t *testing.T) { + t.Parallel() + cmd := newOptimizeApplyCommand(&azdext.ExtensionContext{}) + assert.Equal(t, "apply", cmd.Use) +} + +func TestNewOptimizeApplyCommand_Flags(t *testing.T) { + t.Parallel() + cmd := newOptimizeApplyCommand(&azdext.ExtensionContext{}) + + require.NotNil(t, cmd.Flags().Lookup("candidate")) + require.NotNil(t, cmd.Flags().Lookup("agent")) + require.NotNil(t, cmd.Flags().Lookup("endpoint")) + require.NotNil(t, cmd.Flags().Lookup("project-endpoint")) +} + +func TestNewOptimizeApplyCommand_CandidateIsRequired(t *testing.T) { + t.Parallel() + cmd := newOptimizeApplyCommand(&azdext.ExtensionContext{}) + cmd.SetArgs([]string{}) + err := cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "candidate") +} + +// ---- printPreviewLines ---- + +func TestPrintPreviewLines(t *testing.T) { + t.Parallel() + + // Disable color output so assertions don't need ANSI codes. + color.NoColor = true + + tests := []struct { + name string + lines []string + prefix string + want []string // substrings expected in output + }{ + { + "fewer lines than limit", + []string{"line1", "line2"}, + "+ ", + []string{"+ line1", "+ line2"}, + }, + { + "exactly at limit", + []string{"a", "b", "c", "d"}, + "- ", + []string{"- a", "- b", "- c", "- d"}, + }, + { + "exceeds limit shows truncation", + []string{"a", "b", "c", "d", "e", "f"}, + "+ ", + []string{"+ a", "+ b", "+ c", "+ d", "... (2 more lines)"}, + }, + { + "empty lines", + []string{}, + "- ", + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + c := color.New(color.FgWhite) + printPreviewLines(&buf, tt.lines, tt.prefix, c) + out := buf.String() + for _, s := range tt.want { + assert.Contains(t, out, s) + } + if tt.want == nil { + assert.Empty(t, out) + } + }) + } +} + +// ---- printPromptDiff ---- + +func TestPrintPromptDiff(t *testing.T) { + t.Parallel() + + color.NoColor = true + + t.Run("shows diff when baseline and candidate have instructions", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Set up baseline with metadata that points to an instruction file. + baselineDir := filepath.Join(dir, agentConfigsDir, opt_eval.BaselineDir) + require.NoError(t, os.MkdirAll(baselineDir, 0750)) + require.NoError(t, os.WriteFile( + filepath.Join(baselineDir, opt_eval.InstructionFile), + []byte("You are a baseline assistant.\nLine two."), + 0600, + )) + require.NoError(t, os.WriteFile( + filepath.Join(baselineDir, opt_eval.MetadataFile), + []byte("instruction_file: instructions.md\nmodel: gpt-4o\n"), + 0600, + )) + + candidateConfig := mustMarshal(t, map[string]any{ + "systemPrompt": "You are an optimized assistant.\nNew line two.\nNew line three.", + }) + + var buf bytes.Buffer + printPromptDiff(&buf, dir, "cand1", candidateConfig) + out := buf.String() + + assert.Contains(t, out, "Instruction diff") + assert.Contains(t, out, "Baseline") + assert.Contains(t, out, "Optimized") + assert.Contains(t, out, "You are a baseline assistant.") + assert.Contains(t, out, "You are an optimized assistant.") + }) + + t.Run("no output when candidate has no instructions", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + candidateConfig := mustMarshal(t, map[string]any{"model": "gpt-4o"}) + + var buf bytes.Buffer + printPromptDiff(&buf, dir, "cand1", candidateConfig) + assert.Empty(t, buf.String()) + }) + + t.Run("no output when baseline config missing", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + candidateConfig := mustMarshal(t, map[string]any{"systemPrompt": "optimized"}) + + var buf bytes.Buffer + printPromptDiff(&buf, dir, "cand1", candidateConfig) + assert.Empty(t, buf.String()) + }) + + t.Run("no output when baseline has no instruction file", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Write metadata without instruction_file. + baselineDir := filepath.Join(dir, agentConfigsDir, opt_eval.BaselineDir) + require.NoError(t, os.MkdirAll(baselineDir, 0750)) + require.NoError(t, os.WriteFile( + filepath.Join(baselineDir, opt_eval.MetadataFile), + []byte("model: gpt-4o\n"), + 0600, + )) + + candidateConfig := mustMarshal(t, map[string]any{"systemPrompt": "optimized"}) + + var buf bytes.Buffer + printPromptDiff(&buf, dir, "cand1", candidateConfig) + assert.Empty(t, buf.String()) + }) +} + +func mustMarshal(t *testing.T, v any) json.RawMessage { + t.Helper() + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} + +// ---- extractInstructions ---- + +func TestExtractInstructions(t *testing.T) { + t.Parallel() + tests := []struct { + name string + config map[string]any + want string + }{ + { + "systemPrompt field", + map[string]any{"systemPrompt": "You are a helpful assistant."}, + "You are a helpful assistant.", + }, + { + "instructions field", + map[string]any{"instructions": "Follow the rules."}, + "Follow the rules.", + }, + { + "systemPrompt takes precedence", + map[string]any{ + "systemPrompt": "From systemPrompt", + "instructions": "From instructions", + }, + "From systemPrompt", + }, + {"nil config", nil, ""}, + {"empty map", map[string]any{}, ""}, + {"non-string value", map[string]any{"systemPrompt": 42}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, extractInstructions(tt.config)) + }) + } +} + +// ---- agentConfigMetadata.resolveInstructions ---- + +func TestAgentConfigMetadata_ResolveInstructions(t *testing.T) { + t.Parallel() + t.Run("reads instruction file", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "instructions.md"), []byte("Be helpful."), 0600)) + + meta := &agentConfigMetadata{InstructionFile: "instructions.md"} + assert.Equal(t, "Be helpful.", meta.resolveInstructions(dir)) + }) + + t.Run("returns empty when no file set", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{} + assert.Empty(t, meta.resolveInstructions(t.TempDir())) + }) + + t.Run("returns empty when file missing", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{InstructionFile: "nonexistent.md"} + assert.Empty(t, meta.resolveInstructions(t.TempDir())) + }) +} + +// ---- agentConfigMetadata.resolveSkillDir ---- + +func TestAgentConfigMetadata_ResolveSkillDir(t *testing.T) { + t.Parallel() + t.Run("returns empty when not set", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{} + assert.Empty(t, meta.resolveSkillDir("/some/dir")) + }) + + t.Run("resolves relative path", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{SkillDir: "skills"} + result := meta.resolveSkillDir("/project/config") + assert.Equal(t, filepath.Join("/project/config", "skills"), result) + }) + + t.Run("preserves absolute path", func(t *testing.T) { + t.Parallel() + abs := filepath.Join(os.TempDir(), "absolute-skills") + meta := &agentConfigMetadata{SkillDir: abs} + assert.Equal(t, abs, meta.resolveSkillDir("/any/dir")) + }) +} + +func TestAgentConfigMetadata_ResolveToolsFile(t *testing.T) { + t.Parallel() + t.Run("returns empty when not set", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{} + assert.Empty(t, meta.resolveToolsFile("/some/dir")) + }) + + t.Run("resolves relative path", func(t *testing.T) { + t.Parallel() + meta := &agentConfigMetadata{ToolsFile: "tools.json"} + result := meta.resolveToolsFile("/project/config") + assert.Equal(t, filepath.Join("/project/config", "tools.json"), result) + }) + + t.Run("preserves absolute path", func(t *testing.T) { + t.Parallel() + abs := filepath.Join(os.TempDir(), "absolute-tools.json") + meta := &agentConfigMetadata{ToolsFile: abs} + assert.Equal(t, abs, meta.resolveToolsFile("/any/dir")) + }) +} + +// ---- writeAgentConfigFromCandidate ---- + +func TestWriteAgentConfigFromCandidate(t *testing.T) { + t.Parallel() + t.Run("writes metadata and instructions", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + config := mustMarshal(t, map[string]any{ + "name": "test-agent", + "model": "gpt-4o", + "systemPrompt": "Test prompt.", + }) + + err := writeAgentConfigFromCandidate(dir, config) + require.NoError(t, err) + + assert.FileExists(t, filepath.Join(dir, opt_eval.MetadataFile)) + assert.FileExists(t, filepath.Join(dir, opt_eval.InstructionFile)) + + content, err := os.ReadFile(filepath.Join(dir, opt_eval.InstructionFile)) //nolint:gosec // test file path + require.NoError(t, err) + assert.Equal(t, "Test prompt.", string(content)) + }) + + t.Run("writes inline skills", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + config := mustMarshal(t, map[string]any{ + "systemPrompt": "prompt", + "skills": []any{ + map[string]any{ + "name": "search", + "description": "Search the web", + "body": "Search content here.", + }, + }, + }) + + err := writeAgentConfigFromCandidate(dir, config) + require.NoError(t, err) + + skillFile := filepath.Join(dir, opt_eval.SkillsDir, "search", "SKILL.md") + assert.FileExists(t, skillFile) + }) + + t.Run("handles nil config gracefully", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + err := writeAgentConfigFromCandidate(dir, json.RawMessage(`{}`)) + require.NoError(t, err) + assert.FileExists(t, filepath.Join(dir, opt_eval.MetadataFile)) + }) +} + +// ---- cleanOtherCandidates ---- + +func TestCleanOtherCandidates(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Create baseline, current candidate, and old candidate directories. + require.NoError(t, os.MkdirAll(filepath.Join(dir, opt_eval.BaselineDir), 0750)) + require.NoError(t, os.MkdirAll(filepath.Join(dir, "cand_current"), 0750)) + require.NoError(t, os.MkdirAll(filepath.Join(dir, "cand_old"), 0750)) + + var buf bytes.Buffer + cleanOtherCandidates(dir, "cand_current", &buf) + + // baseline and cand_current should remain; cand_old should be removed. + assert.DirExists(t, filepath.Join(dir, opt_eval.BaselineDir)) + assert.DirExists(t, filepath.Join(dir, "cand_current")) + assert.NoDirExists(t, filepath.Join(dir, "cand_old")) +} + +// ---- isSkillFile ---- + +func TestIsSkillFile(t *testing.T) { + t.Parallel() + tests := []struct { + name string + file optimize_api.CandidateFile + want bool + }{ + {"skill type", optimize_api.CandidateFile{Type: "skill", Path: "foo.md"}, true}, + {"skills path prefix", optimize_api.CandidateFile{Type: "file", Path: "skills/search/SKILL.md"}, true}, + {"other type and path", optimize_api.CandidateFile{Type: "file", Path: "config.yaml"}, false}, + {"empty", optimize_api.CandidateFile{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isSkillFile(tt.file)) + }) + } +} + +// ---- isReservedEnvVarError ---- + +func TestIsReservedEnvVarError(t *testing.T) { + t.Parallel() + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"reserved for platform use", fmt.Errorf("variable is reserved for platform use"), true}, + {"AGENT_* variables", fmt.Errorf("AGENT_* variables are reserved"), true}, + {"unrelated error", fmt.Errorf("connection refused"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isReservedEnvVarError(tt.err)) + }) + } +} + +// ---- writeToolsFile ---- + +func TestWriteToolsFile_NoKeys(t *testing.T) { + t.Parallel() + dir := t.TempDir() + err := writeToolsFile(dir, map[string]any{"name": "agent"}) + require.NoError(t, err) + assert.NoFileExists(t, filepath.Join(dir, opt_eval.ToolsFile)) +} + +func TestWriteToolsFile_OnlyToolDefinitions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + defs := []any{ + map[string]any{"type": "function", "function": map[string]any{"name": "search"}}, + } + err := writeToolsFile(dir, map[string]any{"toolDefinitions": defs}) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, opt_eval.ToolsFile)) //nolint:gosec // test file path + require.NoError(t, err) + + // Should be written as a raw list (no wrapper object). + var parsed []any + require.NoError(t, json.Unmarshal(data, &parsed)) + assert.Len(t, parsed, 1) +} + +func TestWriteToolsFile_OnlyToolDescriptions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + descs := map[string]any{ + "lookup_policy": map[string]any{ + "description": "Look up policy", + "parameters": map[string]any{}, + }, + } + err := writeToolsFile(dir, map[string]any{"toolDescriptions": descs}) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, opt_eval.ToolsFile)) //nolint:gosec // test file path + require.NoError(t, err) + + // Should be written as a raw object (no wrapper). + var parsed map[string]any + require.NoError(t, json.Unmarshal(data, &parsed)) + assert.Contains(t, parsed, "lookup_policy") +} + +func TestWriteToolsFile_BothKeys(t *testing.T) { + t.Parallel() + dir := t.TempDir() + defs := []any{ + map[string]any{"type": "function", "function": map[string]any{"name": "search"}}, + } + descs := map[string]any{ + "search": map[string]any{"description": "Search stuff", "parameters": map[string]any{}}, + } + err := writeToolsFile(dir, map[string]any{ + "toolDefinitions": defs, + "toolDescriptions": descs, + }) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, opt_eval.ToolsFile)) //nolint:gosec // test file path + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(data, &parsed)) + assert.Contains(t, parsed, "toolDefinitions") + assert.Contains(t, parsed, "toolDescriptions") +} + +func TestWriteAgentConfigFromCandidate_WithToolDescriptions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + config := mustMarshal(t, map[string]any{ + "systemPrompt": "prompt", + "toolDescriptions": map[string]any{ + "check_budget": map[string]any{ + "description": "Check the budget", + "parameters": map[string]any{}, + }, + }, + }) + + err := writeAgentConfigFromCandidate(dir, config) + require.NoError(t, err) + + assert.FileExists(t, filepath.Join(dir, opt_eval.ToolsFile)) + + // Verify metadata references tools_file. + metaData, err := os.ReadFile(filepath.Join(dir, opt_eval.MetadataFile)) //nolint:gosec // test file path + require.NoError(t, err) + assert.Contains(t, string(metaData), "tools_file") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel.go new file mode 100644 index 00000000000..597f3c8ce72 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel.go @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_cancel.go implements the "optimize cancel" command, which cancels +// a running optimization job by its operation ID. + +package cmd + +import ( + "fmt" + + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// optimizeCancelFlags holds connection settings for the cancel command. +type optimizeCancelFlags struct { + optimizeConnectionFlags +} + +func newOptimizeCancelCommand() *cobra.Command { + flags := &optimizeCancelFlags{} + + cmd := &cobra.Command{ + Use: "cancel ", + Short: "Cancel a running optimization job.", + Long: `Cancel a running optimization or evaluation job by its operation ID. + +Only jobs in a non-terminal state (pending, running) can be cancelled.`, + Example: ` # Cancel a running job + azd ai agent optimize cancel opt_abc123`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runOptimizeCancel(cmd, flags, args[0]) + }, + } + + flags.optimizeConnectionFlags.register(cmd) + + return cmd +} + +func runOptimizeCancel(cmd *cobra.Command, flags *optimizeCancelFlags, operationID string) error { + endpoint, err := flags.resolve(cmd.Context()) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := optimize_api.NewOptimizeClient(endpoint, credential) + + cancelResp, err := client.CancelOptimize(cmd.Context(), operationID) + if err != nil { + return fmt.Errorf("failed to cancel job: %w\n\nCheck that the operation ID %q is correct and the job is still running", err, operationID) + } + + out := cmd.OutOrStdout() + fmt.Fprintf(out, " %s Job %s has been cancelled (status: %s).\n", + color.YellowString("⚠"), operationID, cancelResp.Status) + fmt.Fprintf(out, "\n Check status with:\n azd ai agent optimize status %s\n", operationID) + + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel_test.go new file mode 100644 index 00000000000..af815c2128f --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_cancel_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOptimizeCancelCommand_RequiresPositionalArg(t *testing.T) { + cmd := newOptimizeCancelCommand() + + err := cmd.Args(cmd, []string{}) + assert.Error(t, err) + + err = cmd.Args(cmd, []string{"opt_abc123"}) + assert.NoError(t, err) + + err = cmd.Args(cmd, []string{"opt_abc123", "extra"}) + assert.Error(t, err) +} + +func TestOptimizeCancelCommand_HasConnectionFlags(t *testing.T) { + cmd := newOptimizeCancelCommand() + + assert.NotNil(t, cmd.Flags().Lookup("endpoint")) + assert.NotNil(t, cmd.Flags().Lookup("project-endpoint")) + + assert.Nil(t, cmd.Flags().Lookup("subscription")) + assert.Nil(t, cmd.Flags().Lookup("resource-group")) + assert.Nil(t, cmd.Flags().Lookup("workspace")) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config.go new file mode 100644 index 00000000000..60c00f5626d --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config.go @@ -0,0 +1,353 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_config.go defines OptimizeConfig (the YAML config structure for +// optimization jobs), provides loading/validation, and converts configs into +// API requests. It also handles reading skills from disk and parsing YAML +// preamble in skill files. + +package cmd + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/pkg/agents/opt_eval" + "azureaiagent/internal/pkg/agents/optimize_api" + + "go.yaml.in/yaml/v3" +) + +// OptimizeConfig extends the shared Config with optimize-specific fields. +type OptimizeConfig struct { + opt_eval.Config `yaml:",inline"` + + // Optimize-specific YAML fields. + ValidationReference *opt_eval.DatasetRef `yaml:"validation_reference,omitempty"` + Criteria []OptimizeConfigCriterion `yaml:"criteria,omitempty"` + Options *opt_eval.Options `yaml:"options"` + InlineDataset []optimize_api.DatasetTask `yaml:"-"` // populated by defaultOptimizeConfig, not from YAML + + // Runtime-only: resolved skill directory and tools file (not serialized to YAML). + SkillDir string `yaml:"-"` + ToolsFile string `yaml:"-"` +} + +// OptimizeConfigCriterion is a named evaluation criterion with a natural-language instruction. +type OptimizeConfigCriterion struct { + Name string `yaml:"name"` + Instruction string `yaml:"instruction"` +} + +// LoadOptimizeConfig reads and parses a YAML optimization config file. +func LoadOptimizeConfig(path string) (*OptimizeConfig, error) { + data, err := os.ReadFile(path) //nolint:gosec // path is provided by user for local config + if err != nil { + return nil, fmt.Errorf("failed to read config file %s: %w", path, err) + } + + var cfg OptimizeConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config file %s: %w", path, err) + } + + return &cfg, nil +} + +// Validate checks required fields and mutual exclusivity constraints. +func (c *OptimizeConfig) Validate() error { + if c.Agent.Name == "" { + return fmt.Errorf("agent.name is required") + } + + if c.Options == nil || c.Options.EvalModel == "" { + return fmt.Errorf("options.eval_model is required") + } + + hasFile := c.DatasetFile != "" + hasRef := c.DatasetReference != nil + hasInline := len(c.InlineDataset) > 0 + + if hasFile && hasRef { + return fmt.Errorf("dataset_file and dataset_reference are mutually exclusive; specify one, not both") + } + + if !hasFile && !hasRef && !hasInline { + return fmt.Errorf("one of dataset_file or dataset_reference is required") + } + + return nil +} + +// defaultOptimizeConfig returns a config with sensible defaults and a built-in +// evaluation dataset. +func defaultOptimizeConfig(agentName string) *OptimizeConfig { + return &OptimizeConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: agentName}, + Evaluators: opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}}, + }, + InlineDataset: defaultDataset, + Options: &opt_eval.Options{ + EvalModel: defaultEvalModel, + TargetAttributes: []string{"instruction", "skill"}, + }, + } +} + +var defaultDataset = []optimize_api.DatasetTask{ + { + Name: "calculator_module", + Prompt: "Create a Python module calc.py with four functions: add, subtract, multiply, divide. " + + "Each takes two numbers and returns the result. Include a brief test at the bottom " + + "(if __name__ == '__main__') that exercises each function and prints the results. Then run it.", + Criteria: []optimize_api.Criterion{ + {Name: "decimal_types", Instruction: "ALL functions MUST use and return Python's decimal.Decimal type, NOT float."}, + {Name: "error_code_prefix", Instruction: "ALL error messages raised by any function MUST include a bracketed error code prefix [CALC-NNN]."}, + {Name: "version_constant", Instruction: "The module MUST define VERSION = '0.1.0' and __version__ = VERSION near the top."}, + {Name: "module_exports", Instruction: "The module MUST define __all__ = ['add', 'subtract', 'multiply', 'divide'] at the top."}, + }, + }, + { + Name: "csv_report", + Prompt: "Create a Python script report.py that generates a CSV file 'sales_report.csv' " + + "with 10 rows of sample sales data. Columns: date, product, quantity, unit_price, total. " + + "Then read the CSV back and print a summary: total revenue and the top-selling product " + + "by quantity. Run the script.", + Criteria: []optimize_api.Criterion{ + {Name: "pipe_delimiter", Instruction: "The CSV file MUST use pipe '|' as the delimiter, NOT comma."}, + {Name: "zero_padded_quantity", Instruction: "ALL quantity values MUST be zero-padded to exactly 4 digits (e.g. '0042' not '42')."}, + {Name: "logging_not_print", Instruction: "The script MUST use Python's logging module for progress messages, NOT print()."}, + {Name: "summary_footer", Instruction: "The LAST line of the CSV file MUST be a comment starting with '# SUMMARY:' including total revenue."}, + }, + }, + { + Name: "api_response_builder", + Prompt: "Create a Python module api_utils.py with a function build_response(data, " + + "status_code=200) that builds a JSON-ready dictionary representing an API response. " + + "Also create a function validate_email(email: str) -> bool that checks if an email " + + "is roughly valid. Write a test block that demonstrates both functions with a few " + + "examples and prints the JSON output. Run it.", + Criteria: []optimize_api.Criterion{ + {Name: "named_tuple_validation", Instruction: "validate_email() MUST return a typing.NamedTuple with fields (is_valid: bool, reason: str), NOT a bare bool."}, + {Name: "request_id", Instruction: "build_response() MUST include a 'requestId' field containing a UUID4 string."}, + {Name: "rfc7807_errors", Instruction: "When status_code >= 400, the response MUST follow RFC 7807 with 'type', 'title', 'detail', 'status' keys."}, + {Name: "camel_case_keys", Instruction: "ALL dictionary keys in the response MUST be camelCase (e.g. 'statusCode', NOT 'status_code')."}, + }, + }, +} + +// ToRequest converts the YAML config into an API OptimizeRequest. +// If DatasetFile is set, each line of the file is read as a JSON-encoded DatasetTask. +func (c *OptimizeConfig) ToRequest(projectEndpoint string) (*optimize_api.OptimizeRequest, error) { + req := &optimize_api.OptimizeRequest{ + Agent: optimize_api.AgentDefinition{ + FoundryProjectURL: projectEndpoint, + AgentName: c.Agent.Name, + AgentVersion: c.Agent.Version, + Model: c.Agent.Model, + SystemPrompt: c.Agent.ResolvedSystemPrompt(), + }, + Evaluators: c.Evaluators.Names(), + Options: optimize_api.OptimizeOptions{ + EvalModel: c.Options.EvalModel, + MaxIterations: c.Options.MaxIterations, + Strategies: c.Options.TargetAttributes, + TargetAttributes: c.Options.TargetAttributes, + KeepVersions: c.Options.KeepVersions, + TasksPerIteration: c.Options.TasksPerIteration, + ReflectionModel: c.Options.ReflectionModel, + EvaluationLevel: c.Options.EvaluationLevel, + }, + } + + // Map target_config from YAML to API format. + if c.Options.TargetConfig != nil { + req.Options.TargetConfig = &optimize_api.TargetConfig{ + Model: c.Options.TargetConfig.Model, + } + } + + // Map criteria from config schema to API schema. + for _, crit := range c.Criteria { + req.Criteria = append(req.Criteria, optimize_api.Criterion{ + Name: crit.Name, + Instruction: crit.Instruction, + }) + } + + if c.DatasetReference != nil { + req.TrainDatasetReference = &optimize_api.DatasetReference{ + Name: c.DatasetReference.Name, + Version: c.DatasetReference.Version, + } + } + + if c.ValidationReference != nil { + req.ValidationDatasetReference = &optimize_api.DatasetReference{ + Name: c.ValidationReference.Name, + Version: c.ValidationReference.Version, + } + } + + if c.DatasetFile != "" { + tasks, err := loadJSONLFile[optimize_api.DatasetTask](c.DatasetFile) + if err != nil { + return nil, err + } + req.Dataset = tasks + } else if len(c.InlineDataset) > 0 { + req.Dataset = c.InlineDataset + } + + // Load skills from skill_dir if specified. + if c.SkillDir != "" { + skills, err := loadSkillsFromDir(c.SkillDir) + if err != nil { + return nil, fmt.Errorf("loading skills from %s: %w", c.SkillDir, err) + } + req.Agent.Skills = skills + } + + // Load tool definitions if a tools file is specified. + if c.ToolsFile != "" { + tools, err := loadToolDefinitions(c.ToolsFile) + if err != nil { + return nil, fmt.Errorf("loading tool definitions from %s: %w", c.ToolsFile, err) + } + req.Agent.ToolDefinitions = tools + } + + return req, nil +} + +// loadToolDefinitions reads an OpenAI-format tools JSON file and returns +// ToolDefinition entries for the optimize API request. +func loadToolDefinitions(path string) ([]optimize_api.ToolDefinition, error) { + data, err := os.ReadFile(path) //nolint:gosec // path derived from project tools file + if err != nil { + return nil, fmt.Errorf("reading tools file: %w", err) + } + + var tools []optimize_api.ToolDefinition + if err := json.Unmarshal(data, &tools); err != nil { + return nil, fmt.Errorf("parsing tools file: %w", err) + } + + return tools, nil +} + +// loadSkillsFromDir reads skill files from a directory and returns SkillDefinitions. +// For markdown files (.md), YAML preamble is parsed to extract name and description; +// the content after the preamble becomes the skill body. +// For other files, the filename (without extension) is used as the name and the full +// content as the body. +// Subdirectories are recursed into — each file within is also loaded as a skill. +func loadSkillsFromDir(dir string) ([]optimize_api.SkillDefinition, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("reading skill directory: %w", err) + } + + var skills []optimize_api.SkillDefinition + for _, entry := range entries { + entryPath := filepath.Join(dir, entry.Name()) + + if entry.IsDir() { + subSkills, err := loadSkillsFromDir(entryPath) + if err != nil { + return nil, err + } + skills = append(skills, subSkills...) + continue + } + + data, err := os.ReadFile(entryPath) //nolint:gosec // path derived from project skill directory + if err != nil { + return nil, fmt.Errorf("reading skill file %s: %w", entry.Name(), err) + } + + skill := parseSkillFile(entry.Name(), string(data)) + skills = append(skills, skill) + } + + return skills, nil +} + +// skillPreamble represents the YAML preamble in a skill markdown file. +type skillPreamble struct { + Name string `yaml:"name"` + Description string `yaml:"description"` +} + +// parseSkillFile parses a skill file. For .md files it attempts to extract +// YAML preamble (delimited by "---") for name and description; the body +// is the content after the preamble. For other files, the filename (sans +// extension) is the name and the full content is the body. +func parseSkillFile(filename, content string) optimize_api.SkillDefinition { + ext := filepath.Ext(filename) + baseName := strings.TrimSuffix(filename, ext) + + if !strings.EqualFold(ext, ".md") { + return optimize_api.SkillDefinition{ + Name: baseName, + Body: content, + } + } + + // Try to parse YAML preamble from markdown. + fm, body := splitPreamble(content) + skill := optimize_api.SkillDefinition{ + Name: baseName, + Body: body, + } + + if fm != "" { + var meta skillPreamble + if err := yaml.Unmarshal([]byte(fm), &meta); err == nil { + if meta.Name != "" { + skill.Name = meta.Name + } + skill.Description = meta.Description + } + } + + return skill +} + +// splitPreamble splits YAML preamble (between "---" delimiters) from +// the rest of the content. Returns (preamble, body). If no preamble is +// found, returns ("", original content). +func splitPreamble(content string) (string, string) { + const delimiter = "---" + + scanner := bufio.NewScanner(strings.NewReader(content)) + if !scanner.Scan() { + return "", content + } + if strings.TrimSpace(scanner.Text()) != delimiter { + return "", content + } + + var fmLines []string + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == delimiter { + // Found closing delimiter — rest is the body. + var bodyLines []string + for scanner.Scan() { + bodyLines = append(bodyLines, scanner.Text()) + } + body := strings.Join(bodyLines, "\n") + return strings.Join(fmLines, "\n"), strings.TrimSpace(body) + } + fmLines = append(fmLines, line) + } + + // No closing delimiter found — treat entire content as body. + return "", content +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config_test.go new file mode 100644 index 00000000000..33e1b6dfa79 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_config_test.go @@ -0,0 +1,473 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeTestFile(t *testing.T, dir, name, content string) string { + t.Helper() + path := filepath.Join(dir, name) + require.NoError(t, os.WriteFile(path, []byte(content), 0600)) + return path +} + +func TestLoadOptimizeConfig_WithDatasetFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + datasetPath := writeTestFile(t, dir, "tasks.jsonl", + `{"prompt":"What is 2+2?","groundTruth":"4"} +{"prompt":"Capital of France?","groundTruth":"Paris"} +`) + + yamlContent := ` +agent: + name: my-agent + version: "1" + model: gpt-4o +dataset_file: ` + datasetPath + ` +evaluators: + - coherence + - relevance +criteria: + - name: accuracy + instruction: answer must be correct +options: + eval_model: gpt-4o-mini + budget: 100 + max_iterations: 5 + strategies: + - prompt_mutation +` + cfgPath := writeTestFile(t, dir, "optimize.yaml", yamlContent) + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + require.NoError(t, cfg.Validate()) + + req, err := cfg.ToRequest("https://example.ai.azure.com/project/p") + require.NoError(t, err) + + assert.Equal(t, "my-agent", req.Agent.AgentName) + assert.Equal(t, "1", req.Agent.AgentVersion) + assert.Equal(t, "https://example.ai.azure.com/project/p", req.Agent.FoundryProjectURL) + assert.Len(t, req.Dataset, 2) + assert.Equal(t, "What is 2+2?", req.Dataset[0].Prompt) + assert.Equal(t, "4", req.Dataset[0].GroundTruth) + assert.Nil(t, req.TrainDatasetReference) + assert.Equal(t, "gpt-4o-mini", req.Options.EvalModel) + assert.Equal(t, []string{"coherence", "relevance"}, req.Evaluators) + assert.Len(t, req.Criteria, 1) + assert.Equal(t, "accuracy", req.Criteria[0].Name) +} + +func TestLoadOptimizeConfig_WithDatasetReference(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + yamlContent := ` +agent: + name: ref-agent +dataset_reference: + name: my-dataset + version: "2" +validation_reference: + name: val-dataset + version: "1" +options: + eval_model: gpt-4o-mini +` + cfgPath := writeTestFile(t, dir, "optimize.yaml", yamlContent) + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + require.NoError(t, cfg.Validate()) + + req, err := cfg.ToRequest("https://example.com/proj") + require.NoError(t, err) + + assert.Equal(t, "ref-agent", req.Agent.AgentName) + assert.Empty(t, req.Dataset) + require.NotNil(t, req.TrainDatasetReference) + assert.Equal(t, "my-dataset", req.TrainDatasetReference.Name) + assert.Equal(t, "2", req.TrainDatasetReference.Version) + require.NotNil(t, req.ValidationDatasetReference) + assert.Equal(t, "val-dataset", req.ValidationDatasetReference.Name) +} + +func TestValidate_MissingAgentName(t *testing.T) { + t.Parallel() + + cfg := &OptimizeConfig{ + Config: opt_eval.Config{ + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "1"}, + }, + Options: &opt_eval.Options{EvalModel: "gpt-4o-mini"}, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "agent.name is required") +} + +func TestValidate_MissingEvalModel(t *testing.T) { + t.Parallel() + + cfg := &OptimizeConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "agent"}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "1"}, + }, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "eval_model is required") +} + +func TestValidate_BothDatasetFileAndReference(t *testing.T) { + t.Parallel() + + cfg := &OptimizeConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "agent"}, + DatasetFile: "tasks.jsonl", + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "1"}, + }, + Options: &opt_eval.Options{EvalModel: "gpt-4o-mini"}, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + +func TestValidate_NeitherDatasetFileNorReference(t *testing.T) { + t.Parallel() + + cfg := &OptimizeConfig{ + Config: opt_eval.Config{Agent: opt_eval.AgentRef{Name: "agent"}}, + Options: &opt_eval.Options{EvalModel: "gpt-4o-mini"}, + } + + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "one of dataset_file or dataset_reference is required") +} + +func TestLoadOptimizeConfig_FileNotFound(t *testing.T) { + t.Parallel() + + _, err := LoadOptimizeConfig("/nonexistent/path/optimize.yaml") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read config file") +} + +func TestLoadOptimizeConfig_InvalidYAML(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfgPath := writeTestFile(t, dir, "bad.yaml", "{{invalid yaml}}") + + _, err := LoadOptimizeConfig(cfgPath) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse config") +} + +func TestLoadOptimizeConfig_EvalYAMLFormat(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + // An eval.yaml file should be loadable by the optimize config loader. + // eval_model at the top level won't map to Options, so we verify the + // agent and evaluators parse correctly. + yamlContent := ` +name: smoke-core +agent: + name: my-eval-agent + version: "3" + kind: hosted +dataset_reference: + name: eval-dataset + version: "1" +evaluators: + - builtin.task_adherence +options: + eval_model: gpt-4o +` + cfgPath := writeTestFile(t, dir, "eval.yaml", yamlContent) + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + + assert.Equal(t, "my-eval-agent", cfg.Agent.Name) + assert.Equal(t, "3", cfg.Agent.Version) + require.NotNil(t, cfg.Options) + assert.Equal(t, "gpt-4o", cfg.Options.EvalModel) + assert.Len(t, cfg.Evaluators, 1) + assert.Equal(t, "builtin.task_adherence", cfg.Evaluators[0].Name) + require.NotNil(t, cfg.DatasetReference) + assert.Equal(t, "eval-dataset", cfg.DatasetReference.Name) +} + +func TestLoadOptimizeConfig_ScalarEvaluatorsWithOptions(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + yamlContent := ` +agent: + name: my-test-agent + +dataset_file: eval.jsonl + +evaluators: + - builtin.task_adherence + +options: + eval_model: gpt-4o + strategies: + - instruction + budget: 3 +` + datasetPath := writeTestFile(t, dir, "eval.jsonl", + `{"prompt":"hello","groundTruth":"hi"} +`) + // Rewrite dataset_file to the real temp path so Validate+ToRequest work. + yamlContent = ` +agent: + name: my-test-agent +dataset_file: ` + datasetPath + ` +evaluators: + - builtin.task_adherence +options: + eval_model: gpt-4o + strategies: + - instruction + budget: 3 +` + cfgPath := writeTestFile(t, dir, "spec.yaml", yamlContent) + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + + // Agent + assert.Equal(t, "my-test-agent", cfg.Agent.Name) + + // Dataset + assert.Equal(t, datasetPath, cfg.DatasetFile) + assert.Nil(t, cfg.DatasetReference) + + // Evaluator — scalar string without builtin. prefix resolves as custom. + require.Len(t, cfg.Evaluators, 1) + assert.Equal(t, "builtin.task_adherence", cfg.Evaluators[0].Name) + + // Options + require.NotNil(t, cfg.Options) + assert.Equal(t, "gpt-4o", cfg.Options.EvalModel) + assert.Equal(t, []string{"instruction"}, cfg.Options.TargetAttributes) + + // Validate + ToRequest + require.NoError(t, cfg.Validate()) + req, err := cfg.ToRequest("https://example.ai.azure.com/project/p") + require.NoError(t, err) + assert.Equal(t, "my-test-agent", req.Agent.AgentName) + assert.Len(t, req.Dataset, 1) + assert.Equal(t, []string{"builtin.task_adherence"}, req.Evaluators) +} + +// --------------------------------------------------------------------------- +// parseSkillFile / loadSkillsFromDir +// --------------------------------------------------------------------------- + +func TestParseSkillFile_MarkdownWithPreamble(t *testing.T) { + t.Parallel() + content := `--- +name: policy-reviewer +description: Reviews a travel request against company travel policy. +--- + +# Policy Reviewer Skill + +Review travel requests and provide a friendly assessment. +` + skill := parseSkillFile("SKILL.md", content) + assert.Equal(t, "policy-reviewer", skill.Name) + assert.Equal(t, "Reviews a travel request against company travel policy.", skill.Description) + assert.Contains(t, skill.Body, "# Policy Reviewer Skill") + assert.Contains(t, skill.Body, "friendly assessment") + assert.NotContains(t, skill.Body, "---") +} + +func TestParseSkillFile_MarkdownWithoutPreamble(t *testing.T) { + t.Parallel() + content := "# Simple Skill\n\nDo something useful.\n" + skill := parseSkillFile("simple.md", content) + assert.Equal(t, "simple", skill.Name) + assert.Empty(t, skill.Description) + assert.Equal(t, content, skill.Body) +} + +func TestParseSkillFile_NonMarkdown(t *testing.T) { + t.Parallel() + content := "You are a helpful assistant." + skill := parseSkillFile("assistant.txt", content) + assert.Equal(t, "assistant", skill.Name) + assert.Empty(t, skill.Description) + assert.Equal(t, content, skill.Body) +} + +func TestParseSkillFile_PreambleNameOnly(t *testing.T) { + t.Parallel() + content := "---\nname: custom-name\n---\nBody content here.\n" + skill := parseSkillFile("ignored-filename.md", content) + assert.Equal(t, "custom-name", skill.Name) + assert.Empty(t, skill.Description) + assert.Equal(t, "Body content here.", skill.Body) +} + +func TestLoadSkillsFromDir_WithMarkdownSkills(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + md := "---\nname: reviewer\ndescription: Reviews things\n---\n\nReview body.\n" + require.NoError(t, os.WriteFile(filepath.Join(dir, "SKILL.md"), []byte(md), 0600)) + + txt := "Plain text skill body." + require.NoError(t, os.WriteFile(filepath.Join(dir, "helper.txt"), []byte(txt), 0600)) + + skills, err := loadSkillsFromDir(dir) + require.NoError(t, err) + require.Len(t, skills, 2) + + // Find each skill by name. + var mdSkill, txtSkill *optimize_api.SkillDefinition + for i := range skills { + switch skills[i].Name { + case "reviewer": + mdSkill = &skills[i] + case "helper": + txtSkill = &skills[i] + } + } + + require.NotNil(t, mdSkill) + assert.Equal(t, "Reviews things", mdSkill.Description) + assert.Contains(t, mdSkill.Body, "Review body.") + + require.NotNil(t, txtSkill) + assert.Empty(t, txtSkill.Description) + assert.Equal(t, txt, txtSkill.Body) +} + +// --------------------------------------------------------------------------- +// loadToolDefinitions +// --------------------------------------------------------------------------- + +func TestLoadToolDefinitions_Valid(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + content := `[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"] + } + } + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web" + } + } +]` + path := writeTestFile(t, dir, "tools.json", content) + + tools, err := loadToolDefinitions(path) + require.NoError(t, err) + require.Len(t, tools, 2) + + assert.Equal(t, "function", tools[0].Type) + assert.Equal(t, "get_weather", tools[0].Function.Name) + assert.Equal(t, "Get current weather for a location", tools[0].Function.Description) + assert.NotNil(t, tools[0].Function.Parameters) + + assert.Equal(t, "function", tools[1].Type) + assert.Equal(t, "search", tools[1].Function.Name) +} + +func TestLoadToolDefinitions_FileNotFound(t *testing.T) { + t.Parallel() + _, err := loadToolDefinitions(filepath.Join(t.TempDir(), "nonexistent.json")) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading tools file") +} + +func TestLoadToolDefinitions_InvalidJSON(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := writeTestFile(t, dir, "tools.json", "not json") + + _, err := loadToolDefinitions(path) + require.Error(t, err) + assert.Contains(t, err.Error(), "parsing tools file") +} + +func TestLoadToolDefinitions_EmptyArray(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := writeTestFile(t, dir, "tools.json", "[]") + + tools, err := loadToolDefinitions(path) + require.NoError(t, err) + assert.Empty(t, tools) +} + +func TestToRequest_WithToolsFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + toolsContent := `[{"type":"function","function":{"name":"calculator","description":"Do math"}}]` + toolsPath := writeTestFile(t, dir, "tools.json", toolsContent) + + cfg := &OptimizeConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "test-agent"}, + DatasetFile: writeTestFile(t, dir, "dataset.jsonl", `{"prompt":"test"}`), + }, + Options: &opt_eval.Options{ + EvalModel: "gpt-4o", + }, + ToolsFile: toolsPath, + } + + req, err := cfg.ToRequest("https://example.com") + require.NoError(t, err) + require.Len(t, req.Agent.ToolDefinitions, 1) + assert.Equal(t, "calculator", req.Agent.ToolDefinitions[0].Function.Name) + assert.Equal(t, "Do math", req.Agent.ToolDefinitions[0].Function.Description) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy.go new file mode 100644 index 00000000000..d757e001d91 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy.go @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_deploy.go implements the "optimize deploy" command, which deploys +// an optimization candidate directly to a Foundry agent (without requiring +// an azd project). It fetches the candidate config, patches the agent, and +// creates a new agent version. + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + "time" + + "azureaiagent/internal/pkg/agents/agent_api" + "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" + "go.yaml.in/yaml/v3" +) + +type optimizeDeployFlags struct { + candidate string + agent string + optimizeConnectionFlags +} + +func newOptimizeDeployCommand() *cobra.Command { + flags := &optimizeDeployFlags{} + action := &OptimizeDeployAction{flags: flags} + + cmd := &cobra.Command{ + Use: "deploy [agent-name]", + Short: "Deploy a winning optimization candidate as a new agent version via the API.", + Long: `Deploy an optimization candidate directly via the Foundry agent API. + +This creates a new agent version with the optimized configuration applied. +Use 'optimize apply' instead if you want to localize the config into your azd project first.`, + Example: ` # Deploy candidate directly + azd ai agent optimize deploy --candidate candidate_abc123 --agent my-agent + + # Deploy with explicit endpoint + azd ai agent optimize deploy --candidate candidate_abc123 --agent my-agent --project-endpoint https://...`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + setupDebugLogging(cmd.Flags()) + + if len(args) > 0 && flags.agent == "" { + flags.agent = args[0] + } + + return action.Run(ctx, cmd) + }, + } + + cmd.Flags().StringVar(&flags.candidate, "candidate", "", "Candidate ID from optimization results (required)") + cmd.Flags().StringVar(&flags.agent, "agent", "", "Agent name to deploy to (auto-detected from agent.yaml)") + _ = cmd.MarkFlagRequired("candidate") + flags.optimizeConnectionFlags.register(cmd) + + return cmd +} + +// OptimizeDeployAction implements the optimize deploy command. +type OptimizeDeployAction struct { + flags *optimizeDeployFlags +} + +func (a *OptimizeDeployAction) Run(ctx context.Context, cmd *cobra.Command) error { + out := cmd.OutOrStdout() + bold := color.New(color.Bold) + + return a.runDirect(ctx, out, bold) +} + +// runDirect deploys a candidate directly via the Foundry agent API. +// TODO: Change this to full remote deployment here if not in an azd project +func (a *OptimizeDeployAction) runDirect( + ctx context.Context, + out io.Writer, + bold *color.Color, +) error { + // Resolve agent name from flag or agent.yaml in current directory. + resolved, err := resolveOptimizeAgent(ctx, a.flags.agent, false) + if err != nil { + return err + } + agentName := resolved.agentName + + // Resolve project endpoint (for Foundry agent API). + projectEndpoint, err := resolveProjectEndpointForDeploy(ctx, &a.flags.optimizeConnectionFlags) + if err != nil { + return err + } + + _, _ = bold.Fprintf(out, "Deploying candidate %s to agent %s...\n\n", a.flags.candidate, agentName) + + // Step 1: Fetch candidate config from optimization service. + fmt.Fprintf(out, " Fetching candidate config...\n") + credential, err := newAgentCredential() + if err != nil { + return err + } + optClient := optimize_api.NewOptimizeClient(projectEndpoint, credential) + candidateConfig, err := optClient.GetCandidateConfig(ctx, a.flags.candidate) + if err != nil { + return fmt.Errorf("failed to fetch candidate config: %w", err) + } + + // JSON-stringify the candidate config for the env var. + configJSON, err := json.Marshal(candidateConfig) + if err != nil { + return fmt.Errorf("failed to serialize candidate config: %w", err) + } + + // Step 2: Fetch current agent from Foundry. + fmt.Fprintf(out, " Fetching current agent definition...\n") + agentClient := agent_api.NewAgentClient(projectEndpoint, credential) + + agentObj, err := agentClient.GetAgent(ctx, agentName, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to get agent %q: %w", agentName, err) + } + + // Extract definition from latest version using map[string]any for flexibility. + latestDef, err := extractLatestDefinition(agentObj) + if err != nil { + return err + } + + // Step 3: Merge env vars and create new version. + // Use OPTIMIZATION_CONFIG (non-reserved) — the agent SDK reads both + // AGENT_OPTIMIZATION_CONFIG (first-party service) and OPTIMIZATION_CONFIG (CLI). + envVars := extractEnvVars(latestDef) + envVars["OPTIMIZATION_CONFIG"] = string(configJSON) + + newDef := buildDeployDefinition(latestDef, envVars) + + description := fmt.Sprintf("Optimized: candidate %s", a.flags.candidate) + createReq := &agent_api.CreateAgentVersionRequest{ + Description: &description, + Metadata: map[string]string{"optimized_from": a.flags.candidate}, + Definition: newDef, + } + + fmt.Fprintf(out, " Creating new agent version...\n") + versionObj, err := agentClient.CreateAgentVersion(ctx, agentName, createReq, DefaultAgentAPIVersion) + if err != nil { + // Check for reserved env var error (AGENT_* and FOUNDRY_* are platform-reserved). + if isReservedEnvVarError(err) { + return fmt.Errorf("the platform reserves AGENT_* environment variables for internal use.\n\n" + + "Deploying optimization candidates for hosted (container) agents requires the\n" + + "optimization service to create versions with elevated privileges.\n\n" + + "Contact the platform team to promote via the optimization service API") + } + return fmt.Errorf("failed to create agent version: %w", err) + } + + // Step 4: Poll until version is active. + fmt.Fprintf(out, " Waiting for version %s to become active...\n", versionObj.Version) + if err := pollVersionActive(ctx, agentClient, agentName, versionObj.Version); err != nil { + return err + } + + // Step 5: Report the deployment to the optimization service (best-effort). + if err := optClient.ReportDeployment(ctx, &optimize_api.DeploymentReport{ + CandidateID: a.flags.candidate, + AgentName: agentName, + AgentVersion: versionObj.Version, + }); err != nil { + // Non-fatal — deployment succeeded, just log the reporting failure. + fmt.Fprintf(out, " %s failed to report deployment to optimization service: %s\n", + color.YellowString("warning:"), err) + } + + // Step 6: Print success. + fmt.Fprintln(out) + _, _ = color.New(color.FgGreen, color.Bold).Fprintf(out, + " \u2713 Successfully deployed candidate %s as version %s\n", a.flags.candidate, versionObj.Version) + fmt.Fprintf(out, "\n Agent: %s\n", agentName) + fmt.Fprintf(out, " Version: %s\n", versionObj.Version) + + return nil +} + +// upsertAgentYamlEnvVar reads the agent.yaml file, adds or updates the specified +// environment variable in the environment_variables list, and writes back. +func upsertAgentYamlEnvVar(agentYamlPath, key, value string) error { + data, err := os.ReadFile(agentYamlPath) //nolint:gosec // G304: path from azd project + if err != nil { + return fmt.Errorf("reading agent.yaml: %w", err) + } + + var agent agent_yaml.ContainerAgent + if err := yaml.Unmarshal(data, &agent); err != nil { + return fmt.Errorf("parsing agent.yaml: %w", err) + } + + // Upsert the environment variable. + if agent.EnvironmentVariables == nil { + agent.EnvironmentVariables = &[]agent_yaml.EnvironmentVariable{} + } + + found := false + envVars := *agent.EnvironmentVariables + for i := range envVars { + if envVars[i].Name == key { + envVars[i].Value = value + found = true + break + } + } + if !found { + envVars = append(envVars, agent_yaml.EnvironmentVariable{Name: key, Value: value}) + } + agent.EnvironmentVariables = &envVars + + // Marshal back to YAML and write. + out, err := yaml.Marshal(&agent) + if err != nil { + return fmt.Errorf("marshaling agent.yaml: %w", err) + } + + //nolint:gosec // G306: agent.yaml should be readable by tooling + if err := os.WriteFile(agentYamlPath, out, 0644); err != nil { + return fmt.Errorf("writing agent.yaml: %w", err) + } + + return nil +} + +// resolveProjectEndpointForDeploy resolves the Foundry project endpoint using +// the same resolution chain as other agent commands. +func resolveProjectEndpointForDeploy(ctx context.Context, connFlags *optimizeConnectionFlags) (string, error) { + if connFlags.projectEndpoint != "" { + return strings.TrimRight(connFlags.projectEndpoint, "/"), nil + } + + projectEndpoint, err := resolveAgentEndpoint(ctx, "", "") + if err != nil { + if ep := projectEndpointFromEnv(); ep != "" { + return ep, nil + } + return "", fmt.Errorf("could not resolve project endpoint: %w\n\n"+ + "Provide --project-endpoint (-p), or run 'azd ai agent init'", err) + } + return projectEndpoint, nil +} + +// isReservedEnvVarError checks if a version creation error is due to +// the platform rejecting reserved AGENT_* or FOUNDRY_* environment variables. +// TODO: Use azcore.ResponseError.StatusCode + stable API error code when available, +// instead of brittle substring matching on server error wording. +func isReservedEnvVarError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "reserved for platform use") || + strings.Contains(msg, "AGENT_* variables are reserved") +} + +// --- Skill file download --- + +// isSkillFile returns true if the manifest entry represents a skill file. +func isSkillFile(f optimize_api.CandidateFile) bool { + return f.Type == "skill" || strings.HasPrefix(f.Path, "skills/") +} + +// extractLatestDefinition gets the latest version's definition as a map for flexible field access. +func extractLatestDefinition(agent *agent_api.AgentObject) (map[string]any, error) { + defBytes, err := json.Marshal(agent.Versions.Latest.Definition) + if err != nil { + return nil, fmt.Errorf("failed to read agent definition: %w", err) + } + + var defMap map[string]any + if err := json.Unmarshal(defBytes, &defMap); err != nil { + return nil, fmt.Errorf("failed to parse agent definition: %w", err) + } + return defMap, nil +} + +// extractEnvVars extracts existing environment variables from a definition map. +func extractEnvVars(def map[string]any) map[string]string { + result := make(map[string]string) + if envRaw, ok := def["environment_variables"]; ok { + if envMap, ok := envRaw.(map[string]any); ok { + for k, v := range envMap { + if s, ok := v.(string); ok { + result[k] = s + } + } + } + } + return result +} + +// buildDeployDefinition creates the definition map for the new version, +// preserving all fields from the current version but overriding env vars. +func buildDeployDefinition(currentDef map[string]any, envVars map[string]string) map[string]any { + newDef := make(map[string]any) + for k, v := range currentDef { + if k != "environment_variables" { + newDef[k] = v + } + } + newDef["environment_variables"] = envVars + normalizeProtocolVersions(newDef) + return newDef +} + +// normalizeProtocolVersions ensures container_protocol_versions use the +// canonical "1.0.0" format instead of the legacy "v1" format that the +// platform no longer accepts for new versions. +func normalizeProtocolVersions(def map[string]any) { + raw, ok := def["container_protocol_versions"] + if !ok { + return + } + protocols, ok := raw.([]any) + if !ok { + return + } + for _, p := range protocols { + pMap, ok := p.(map[string]any) + if !ok { + continue + } + if ver, ok := pMap["version"].(string); ok && ver == "v1" { + pMap["version"] = "1.0.0" + } + } +} + +// pollVersionActive polls the agent version until its status is "active" or a timeout occurs. +func pollVersionActive( + ctx context.Context, + client *agent_api.AgentClient, + agentName, versionNum string, +) error { + timeout := 5 * time.Minute + interval := 5 * time.Second + deadline := time.Now().Add(timeout) + + for { + if time.Now().After(deadline) { + return fmt.Errorf("timed out waiting for version %s to become active after %s", versionNum, timeout) + } + + version, err := client.GetAgentVersion(ctx, agentName, versionNum, DefaultAgentAPIVersion) + if err != nil { + return fmt.Errorf("failed to poll version status: %w", err) + } + + if version.Status == "active" { + return nil + } + + if version.Status == "failed" { + return fmt.Errorf("version %s failed to activate", versionNum) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(interval): + } + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy_test.go new file mode 100644 index 00000000000..5b58f43eb66 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_deploy_test.go @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizeDeployCommand_HasRequiredFlags(t *testing.T) { + cmd := newOptimizeDeployCommand() + + candidateFlag := cmd.Flags().Lookup("candidate") + require.NotNil(t, candidateFlag, "--candidate flag should be registered") + + agentFlag := cmd.Flags().Lookup("agent") + require.NotNil(t, agentFlag, "--agent flag should be registered") +} + +func TestOptimizeDeployCommand_CandidateIsRequired(t *testing.T) { + cmd := newOptimizeDeployCommand() + + // Set only --agent, omit --candidate + cmd.SetArgs([]string{"--agent", "my-agent"}) + err := cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "candidate") +} + +func TestOptimizeDeployCommand_AgentResolvedFromFlagOrYaml(t *testing.T) { + cmd := newOptimizeDeployCommand() + + // --agent is no longer MarkFlagRequired; it falls back to agent.yaml + agentFlag := cmd.Flags().Lookup("agent") + require.NotNil(t, agentFlag) + // Without --agent and without agent.yaml, should error about agent name + cmd.SetArgs([]string{"--candidate", "cand_123"}) + err := cmd.Execute() + assert.Error(t, err) + assert.Contains(t, err.Error(), "agent") +} + +func TestOptimizeDeployCommand_HasConnectionFlags(t *testing.T) { + cmd := newOptimizeDeployCommand() + + assert.NotNil(t, cmd.Flags().Lookup("endpoint")) + assert.NotNil(t, cmd.Flags().Lookup("project-endpoint")) + + // Should NOT have subscription/resource-group/workspace + assert.Nil(t, cmd.Flags().Lookup("subscription")) + assert.Nil(t, cmd.Flags().Lookup("resource-group")) + assert.Nil(t, cmd.Flags().Lookup("workspace")) +} + +func TestOptimizeCommand_HasDeploySubCommand(t *testing.T) { + cmd := newOptimizeCommand(&azdext.ExtensionContext{}) + + var actual []string + for _, sub := range cmd.Commands() { + actual = append(actual, sub.Name()) + } + + assert.Contains(t, actual, "deploy", "optimize should have 'deploy' sub-command") +} + +func TestExtractEnvVars_EmptyDef(t *testing.T) { + def := map[string]any{"kind": "hosted"} + result := extractEnvVars(def) + assert.Empty(t, result) +} + +func TestExtractEnvVars_WithVars(t *testing.T) { + def := map[string]any{ + "kind": "hosted", + "environment_variables": map[string]any{ + "FOO": "bar", + "BAZ": "qux", + }, + } + result := extractEnvVars(def) + assert.Equal(t, "bar", result["FOO"]) + assert.Equal(t, "qux", result["BAZ"]) + assert.Len(t, result, 2) +} + +func TestBuildDeployDefinition_PreservesFieldsAndOverridesEnvVars(t *testing.T) { + currentDef := map[string]any{ + "kind": "hosted", + "image": "myimage:latest", + "cpu": "1.0", + "memory": "2Gi", + "environment_variables": map[string]any{ + "EXISTING_VAR": "keep_me", + }, + } + + envVars := map[string]string{ + "EXISTING_VAR": "keep_me", + "OPTIMIZATION_CONFIG": `{"key":"value"}`, + } + + newDef := buildDeployDefinition(currentDef, envVars) + + assert.Equal(t, "hosted", newDef["kind"]) + assert.Equal(t, "myimage:latest", newDef["image"]) + assert.Equal(t, "1.0", newDef["cpu"]) + assert.Equal(t, "2Gi", newDef["memory"]) + + newEnvVars, ok := newDef["environment_variables"].(map[string]string) + require.True(t, ok) + assert.Equal(t, "keep_me", newEnvVars["EXISTING_VAR"]) + assert.Equal(t, `{"key":"value"}`, newEnvVars["OPTIMIZATION_CONFIG"]) +} + +func TestBuildDeployDefinition_NormalizesProtocolVersion(t *testing.T) { + currentDef := map[string]any{ + "kind": "hosted", + "image": "myimage:latest", + "cpu": "1.0", + "memory": "2Gi", + "container_protocol_versions": []any{ + map[string]any{"protocol": "responses", "version": "v1"}, + }, + "environment_variables": map[string]any{}, + } + + newDef := buildDeployDefinition(currentDef, map[string]string{"FOO": "bar"}) + + protocols := newDef["container_protocol_versions"].([]any) + p := protocols[0].(map[string]any) + assert.Equal(t, "1.0.0", p["version"], "v1 should be normalized to 1.0.0") + assert.Equal(t, "responses", p["protocol"]) +} + +func TestNormalizeProtocolVersions_NoOp(t *testing.T) { + // Already 1.0.0 — should not change + def := map[string]any{ + "container_protocol_versions": []any{ + map[string]any{"protocol": "responses", "version": "1.0.0"}, + }, + } + normalizeProtocolVersions(def) + + protocols := def["container_protocol_versions"].([]any) + p := protocols[0].(map[string]any) + assert.Equal(t, "1.0.0", p["version"]) +} + +func TestNormalizeProtocolVersions_MissingField(t *testing.T) { + def := map[string]any{"kind": "hosted"} + normalizeProtocolVersions(def) // should not panic +} + +// ---- upsertAgentYamlEnvVar ---- + +func TestUpsertAgentYamlEnvVar_InsertsNew(t *testing.T) { + t.Parallel() + dir := t.TempDir() + yamlPath := filepath.Join(dir, "agent.yaml") + require.NoError(t, os.WriteFile(yamlPath, []byte("name: test-agent\n"), 0600)) + + err := upsertAgentYamlEnvVar(yamlPath, "MY_VAR", "my_value") + require.NoError(t, err) + + data, err := os.ReadFile(yamlPath) //nolint:gosec // test file path + require.NoError(t, err) + assert.Contains(t, string(data), "MY_VAR") + assert.Contains(t, string(data), "my_value") +} + +func TestUpsertAgentYamlEnvVar_UpdatesExisting(t *testing.T) { + t.Parallel() + dir := t.TempDir() + yamlPath := filepath.Join(dir, "agent.yaml") + content := `name: test-agent +environment_variables: + - name: MY_VAR + value: old_value +` + require.NoError(t, os.WriteFile(yamlPath, []byte(content), 0600)) + + err := upsertAgentYamlEnvVar(yamlPath, "MY_VAR", "new_value") + require.NoError(t, err) + + data, err := os.ReadFile(yamlPath) //nolint:gosec // test file path + require.NoError(t, err) + assert.Contains(t, string(data), "new_value") + assert.NotContains(t, string(data), "old_value") +} + +func TestUpsertAgentYamlEnvVar_FileMissing(t *testing.T) { + t.Parallel() + err := upsertAgentYamlEnvVar("/nonexistent/agent.yaml", "KEY", "VALUE") + assert.Error(t, err) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers.go new file mode 100644 index 00000000000..d29e3251fa9 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers.go @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_helpers.go provides shared utilities for optimize commands: +// connection flag resolution, job ID persistence in the azd environment, +// and portal link construction. + +package cmd + +import ( + "context" + "fmt" + "io" + "log" + "os" + "strings" + + "azureaiagent/internal/pkg/agents/eval_api" + "azureaiagent/internal/pkg/agents/optimize_api" + + azdext "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// optimizeConnectionFlags holds connection settings shared across all optimize sub-commands. +type optimizeConnectionFlags struct { + projectEndpoint string // Foundry project endpoint URL + endpoint string // direct optimization service URL (for local dev only) +} + +// register adds the connection flags to the given cobra command. +func (f *optimizeConnectionFlags) register(cmd *cobra.Command) { + cmd.Flags().StringVarP(&f.projectEndpoint, "project-endpoint", "p", "", "Foundry project endpoint URL") + cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Optimization service endpoint (for local dev)") +} + +// resolve returns the project endpoint for optimize API calls. +// projectEndpointFromEnv returns the project endpoint from FOUNDRY_PROJECT_ENDPOINT +// or AZURE_AI_PROJECT_ENDPOINT environment variables (in that priority order). +// Returns empty string if neither is set. +func projectEndpointFromEnv() string { + if ep := os.Getenv("FOUNDRY_PROJECT_ENDPOINT"); ep != "" { + return strings.TrimRight(ep, "/") + } + if ep := os.Getenv("AZURE_AI_PROJECT_ENDPOINT"); ep != "" { + return strings.TrimRight(ep, "/") + } + return "" +} + +// Priority: --endpoint flag → --project-endpoint → azd environment → FOUNDRY_PROJECT_ENDPOINT / AZURE_AI_PROJECT_ENDPOINT env var. +func (f *optimizeConnectionFlags) resolve(ctx context.Context) (string, error) { + if f.endpoint != "" { + return strings.TrimRight(f.endpoint, "/"), nil + } + + // Explicit --project-endpoint flag + if f.projectEndpoint != "" { + return strings.TrimRight(f.projectEndpoint, "/"), nil + } + + // Try azd environment (works when running under azd) + projectEndpoint, err := resolveAgentEndpoint(ctx, "", "") + if err != nil { + // Fall back to FOUNDRY_PROJECT_ENDPOINT or AZURE_AI_PROJECT_ENDPOINT env var (works standalone) + if ep := projectEndpointFromEnv(); ep != "" { + return ep, nil + } + return "", fmt.Errorf("could not resolve project endpoint\n\n" + + "Set FOUNDRY_PROJECT_ENDPOINT or AZURE_AI_PROJECT_ENDPOINT, provide --project-endpoint (-p),\n" + + "or run 'azd ai agent init'") + } + + return projectEndpoint, nil +} + +// optimizeLastJobIDKey is the azd environment key for the last optimization job ID. +const optimizeLastJobIDKey = "OPTIMIZE_LAST_OPERATION_ID" + +// saveLastOptimizeJobID stores the operation ID in the azd environment. +// Best-effort — silently ignores errors (e.g., when running outside azd). +func saveLastOptimizeJobID(ctx context.Context, operationID string) { + azdClient, err := azdext.NewAzdClient() + if err != nil { + return + } + defer azdClient.Close() + + envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if err != nil || envResp == nil { + return + } + + _, _ = azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: optimizeLastJobIDKey, + Value: operationID, + }) +} + +// loadLastOptimizeJobID retrieves the last operation ID from the azd environment. +// Returns empty string if not available. +func loadLastOptimizeJobID(ctx context.Context) string { + azdClient, err := azdext.NewAzdClient() + if err != nil { + return "" + } + defer azdClient.Close() + + envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if err != nil || envResp == nil { + return "" + } + + resp, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: optimizeLastJobIDKey, + }) + if err != nil || resp == nil { + return "" + } + return resp.Value +} + +// printOptimizePortalLink prints the Foundry portal URL for an optimization job. +// Best-effort — silently skips if the portal prefix cannot be resolved. +func printOptimizePortalLink(ctx context.Context, out io.Writer, agentName, operationID string) { + azdClient, err := azdext.NewAzdClient() + if err != nil { + return + } + defer azdClient.Close() + + envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if err != nil || envResp == nil { + return + } + + printPortalLink(ctx, out, azdClient, envResp.Environment.Name, func(prefix *eval_api.PortalPrefix) string { + return prefix.OptimizationURL(agentName, operationID) + }) +} + +// reportOptimizationDeployments reports optimization candidate deployments to the optimization service. +// For each hosted agent service, if AGENT_{KEY}_OPTIMIZATION_CANDIDATE_ID is set in +// the azd environment, it calls the promote API and then clears the env var. +// This is best-effort — failures are logged but do not block the deploy. +func reportOptimizationDeployments( + ctx context.Context, + azdClient *azdext.AzdClient, + hostedAgents []*azdext.ServiceConfig, + envName, projectEndpoint string, + newClient func(endpoint string) *optimize_api.OptimizeClient, +) { + log.Printf("postdeploy: reporting optimization deployments for %d hosted agents", len(hostedAgents)) + + for _, svc := range hostedAgents { + func() { + defer func() { + if r := recover(); r != nil { + log.Printf("postdeploy: optimization reporting panicked for %s: %v", svc.Name, r) + } + }() + reportSvcOptimizationDeployment(ctx, azdClient, svc, envName, projectEndpoint, newClient) + }() + } +} + +// reportSvcOptimizationDeployment reports a single service's optimization candidate. +func reportSvcOptimizationDeployment( + ctx context.Context, + azdClient *azdext.AzdClient, + svc *azdext.ServiceConfig, + envName, projectEndpoint string, + newClient func(endpoint string) *optimize_api.OptimizeClient, +) { + serviceKey := toServiceKey(svc.Name) + candidateKey := fmt.Sprintf("AGENT_%s_OPTIMIZATION_CANDIDATE_ID", serviceKey) + + candidateResp, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envName, + Key: candidateKey, + }) + if err != nil || candidateResp.Value == "" { + log.Printf("postdeploy: no optimization candidate for %s, skipping", svc.Name) + return + } + + versionKey := fmt.Sprintf("AGENT_%s_VERSION", serviceKey) + versionResp, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envName, + Key: versionKey, + }) + if err != nil || versionResp.Value == "" { + log.Printf("postdeploy: no version for %s, skipping", svc.Name) + return + } + + log.Printf("postdeploy: promoting candidate %s for %s (version %s)", + candidateResp.Value, svc.Name, versionResp.Value) + + optClient := newClient(projectEndpoint) + if err := optClient.ReportDeployment(ctx, &optimize_api.DeploymentReport{ + CandidateID: candidateResp.Value, + AgentName: svc.Name, + AgentVersion: versionResp.Value, + }); err != nil { + log.Printf("postdeploy: failed to report optimization deployment for %s: %v", svc.Name, err) + return + } + + log.Printf("postdeploy: successfully promoted candidate %s for %s", candidateResp.Value, svc.Name) + + // Clear the candidate ID after successful reporting. + if _, err := azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: envName, + Key: candidateKey, + Value: "", + }); err != nil { + log.Printf("postdeploy: failed to clear %s: %v", candidateKey, err) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers_test.go new file mode 100644 index 00000000000..e325e84f653 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_helpers_test.go @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestOptimizeConnectionFlags_Resolve_AllEmpty(t *testing.T) { + f := &optimizeConnectionFlags{} + _, err := f.resolve(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "endpoint") +} + +func TestOptimizeConnectionFlags_Resolve_FlagEndpoint(t *testing.T) { + f := &optimizeConnectionFlags{ + endpoint: "https://from-flag.com", + } + endpoint, err := f.resolve(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "https://from-flag.com", endpoint) +} + +func TestOptimizeConnectionFlags_Resolve_TrimsTrailingSlash(t *testing.T) { + f := &optimizeConnectionFlags{ + endpoint: "https://example.com/", + } + endpoint, err := f.resolve(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "https://example.com", endpoint) +} + +func TestOptimizeConnectionFlags_Resolve_ProjectEndpointFlag(t *testing.T) { + f := &optimizeConnectionFlags{ + projectEndpoint: "https://my-project.services.ai.azure.com/", + } + endpoint, err := f.resolve(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "https://my-project.services.ai.azure.com", endpoint) +} + +// newOptimizeTestAzdClient creates a test AzdClient backed by a gRPC server +// with the given environment service implementation. +func newOptimizeTestAzdClient( + t *testing.T, + envServer azdext.EnvironmentServiceServer, +) *azdext.AzdClient { + t.Helper() + + grpcServer := grpc.NewServer() + azdext.RegisterEnvironmentServiceServer(grpcServer, envServer) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func() { _ = grpcServer.Serve(listener) }() + + t.Cleanup(func() { + grpcServer.Stop() + _ = listener.Close() + }) + + azdClient, err := azdext.NewAzdClient(azdext.WithAddress(listener.Addr().String())) + require.NoError(t, err) + t.Cleanup(func() { azdClient.Close() }) + + return azdClient +} + +// newTestOptimizeClient creates an OptimizeClient that talks to the given +// httptest server, using a bare pipeline (no auth). +func newTestOptimizeClient(endpoint string) *optimize_api.OptimizeClient { + pl := runtime.NewPipeline("test", "v0.0.0", runtime.PipelineOptions{}, &policy.ClientOptions{}) + return optimize_api.NewOptimizeClientFromPipeline(endpoint, pl) +} + +func TestReportOptimizationDeployments_NoAgents(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": {}, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + // Should complete without calling any API. + reportOptimizationDeployments( + t.Context(), azdClient, nil, "dev", "https://unused.example.com", + newTestOptimizeClient, + ) +} + +func TestReportOptimizationDeployments_Success_ClearsCandidate(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_MY_AGENT_OPTIMIZATION_CANDIDATE_ID": "cand-123", + "AGENT_MY_AGENT_VERSION": "v2", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + var gotURL string + var gotBody optimize_api.DeploymentReport + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotURL = r.URL.String() + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{{Name: "my-agent"}} + + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + assert.Contains(t, gotURL, "/optimize/candidates/cand-123:promote") + assert.Equal(t, "my-agent", gotBody.AgentName) + assert.Equal(t, "v2", gotBody.AgentVersion) + // CandidateID is json:"-", so it should not appear in the body. + assert.Empty(t, gotBody.CandidateID) + + // The candidate key should be cleared after successful reporting. + assert.Equal(t, "", envServer.values["dev"]["AGENT_MY_AGENT_OPTIMIZATION_CANDIDATE_ID"]) +} + +func TestReportOptimizationDeployments_MissingCandidateID_Skips(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + // No AGENT_SVC_OPTIMIZATION_CANDIDATE_ID at all. + "AGENT_SVC_VERSION": "v1", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + apiCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + apiCalled = true + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{{Name: "svc"}} + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + assert.False(t, apiCalled, "API should not be called when candidate ID is missing") +} + +func TestReportOptimizationDeployments_MissingVersion_Skips(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_SVC_OPTIMIZATION_CANDIDATE_ID": "cand-456", + // No AGENT_SVC_VERSION. + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + apiCalled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + apiCalled = true + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{{Name: "svc"}} + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + assert.False(t, apiCalled, "API should not be called when version is missing") +} + +func TestReportOptimizationDeployments_APIFailure_DoesNotClearCandidate(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_SVC_OPTIMIZATION_CANDIDATE_ID": "cand-789", + "AGENT_SVC_VERSION": "v3", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{{Name: "svc"}} + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + // Candidate key should NOT be cleared when the API returns an error. + assert.Equal(t, "cand-789", envServer.values["dev"]["AGENT_SVC_OPTIMIZATION_CANDIDATE_ID"]) +} + +func TestReportOptimizationDeployments_MultipleAgents(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_ALPHA_OPTIMIZATION_CANDIDATE_ID": "c-a", + "AGENT_ALPHA_VERSION": "v1", + // beta has no candidate — should be skipped. + "AGENT_BETA_VERSION": "v2", + // gamma has candidate but API will fail for it. + "AGENT_GAMMA_OPTIMIZATION_CANDIDATE_ID": "c-g", + "AGENT_GAMMA_VERSION": "v3", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + promoted := map[string]bool{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/optimize/candidates/c-g:promote" { + w.WriteHeader(http.StatusInternalServerError) + return + } + promoted[r.URL.Path] = true + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{ + {Name: "alpha"}, + {Name: "beta"}, + {Name: "gamma"}, + } + + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + // Alpha: promoted and cleared. + assert.True(t, promoted["/optimize/candidates/c-a:promote"]) + assert.Equal(t, "", envServer.values["dev"]["AGENT_ALPHA_OPTIMIZATION_CANDIDATE_ID"]) + + // Beta: skipped (no candidate ID), no API call. + assert.False(t, promoted["/optimize/candidates/:promote"]) // shouldn't appear + + // Gamma: API failed, so candidate key should remain. + assert.Equal(t, "c-g", envServer.values["dev"]["AGENT_GAMMA_OPTIMIZATION_CANDIDATE_ID"]) +} + +func TestReportOptimizationDeployments_ServiceNameWithDashes(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_MY_COOL_AGENT_OPTIMIZATION_CANDIDATE_ID": "cand-dash", + "AGENT_MY_COOL_AGENT_VERSION": "v5", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + var gotURL string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotURL = r.URL.String() + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + agents := []*azdext.ServiceConfig{{Name: "my-cool-agent"}} + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", srv.URL, + newTestOptimizeClient, + ) + + assert.Contains(t, gotURL, "/optimize/candidates/cand-dash:promote") + assert.Equal(t, "", envServer.values["dev"]["AGENT_MY_COOL_AGENT_OPTIMIZATION_CANDIDATE_ID"]) +} + +func TestReportOptimizationDeployments_PanicRecovery(t *testing.T) { + t.Parallel() + + envServer := &testEnvironmentServiceServer{ + values: map[string]map[string]string{ + "dev": { + "AGENT_SVC_OPTIMIZATION_CANDIDATE_ID": "cand-panic", + "AGENT_SVC_VERSION": "v1", + }, + }, + } + azdClient := newOptimizeTestAzdClient(t, envServer) + + agents := []*azdext.ServiceConfig{{Name: "svc"}} + + // Pass a newClient factory that panics. The recover guard should + // prevent this from crashing the caller. + assert.NotPanics(t, func() { + reportOptimizationDeployments( + t.Context(), azdClient, agents, "dev", "https://unused", + func(_ string) *optimize_api.OptimizeClient { + panic("boom") + }, + ) + }) + + // Candidate key should remain since the promote never succeeded. + assert.Equal(t, "cand-panic", envServer.values["dev"]["AGENT_SVC_OPTIMIZATION_CANDIDATE_ID"]) +} + +func TestOptimizeConnectionFlags_Resolve_FoundryEnvVar(t *testing.T) { + t.Setenv("FOUNDRY_PROJECT_ENDPOINT", "https://foundry.example.com/") + f := &optimizeConnectionFlags{} + endpoint, err := f.resolve(t.Context()) + assert.NoError(t, err) + assert.Equal(t, "https://foundry.example.com", endpoint) +} + +func TestOptimizeConnectionFlags_Resolve_AzureAIEnvVar(t *testing.T) { + t.Setenv("AZURE_AI_PROJECT_ENDPOINT", "https://azure-ai.example.com/") + f := &optimizeConnectionFlags{} + endpoint, err := f.resolve(t.Context()) + assert.NoError(t, err) + assert.Equal(t, "https://azure-ai.example.com", endpoint) +} + +func TestOptimizeConnectionFlags_Resolve_FoundryTakesPriorityOverAzureAI(t *testing.T) { + t.Setenv("FOUNDRY_PROJECT_ENDPOINT", "https://foundry.example.com") + t.Setenv("AZURE_AI_PROJECT_ENDPOINT", "https://azure-ai.example.com") + f := &optimizeConnectionFlags{} + endpoint, err := f.resolve(t.Context()) + assert.NoError(t, err) + assert.Equal(t, "https://foundry.example.com", endpoint) +} + +func TestResolveProjectEndpointForDeploy_FoundryEnvVar(t *testing.T) { + t.Setenv("FOUNDRY_PROJECT_ENDPOINT", "https://foundry-deploy.example.com/") + ep, err := resolveProjectEndpointForDeploy(t.Context(), &optimizeConnectionFlags{}) + assert.NoError(t, err) + assert.Equal(t, "https://foundry-deploy.example.com", ep) +} + +func TestResolveProjectEndpointForDeploy_AzureAIEnvVar(t *testing.T) { + t.Setenv("AZURE_AI_PROJECT_ENDPOINT", "https://azure-ai-deploy.example.com/") + ep, err := resolveProjectEndpointForDeploy(t.Context(), &optimizeConnectionFlags{}) + assert.NoError(t, err) + assert.Equal(t, "https://azure-ai-deploy.example.com", ep) +} + +func TestResolveProjectEndpointForDeploy_FoundryTakesPriorityOverAzureAI(t *testing.T) { + t.Setenv("FOUNDRY_PROJECT_ENDPOINT", "https://foundry-deploy.example.com") + t.Setenv("AZURE_AI_PROJECT_ENDPOINT", "https://azure-ai-deploy.example.com") + ep, err := resolveProjectEndpointForDeploy(t.Context(), &optimizeConnectionFlags{}) + assert.NoError(t, err) + assert.Equal(t, "https://foundry-deploy.example.com", ep) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list.go new file mode 100644 index 00000000000..2b79ffe1274 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list.go @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_list.go implements the "optimize list" command, which lists +// recent optimization jobs with status, agent, and score. + +package cmd + +import ( + "fmt" + "io" + "strings" + + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// optimizeListFlags holds CLI flags for the optimize list command. +type optimizeListFlags struct { + limit int // maximum number of results + status string // filter by job status + optimizeConnectionFlags +} + +func newOptimizeListCommand() *cobra.Command { + flags := &optimizeListFlags{} + + cmd := &cobra.Command{ + Use: "list", + Short: "List recent optimization runs.", + Long: `List recent optimization and evaluation runs. + +Use --status to filter by job status and --limit to control page size.`, + Example: ` # List all recent runs + azd ai agent optimize list + + # List only completed runs + azd ai agent optimize list --status completed + + # Show last 5 runs + azd ai agent optimize list --limit 5`, + RunE: func(cmd *cobra.Command, args []string) error { + return runOptimizeList(cmd, flags) + }, + } + + cmd.Flags().IntVar(&flags.limit, "limit", 20, "Maximum number of results") + cmd.Flags().StringVar(&flags.status, "status", "", "Filter by status (pending/running/completed/failed/cancelled)") + flags.optimizeConnectionFlags.register(cmd) + + return cmd +} + +func runOptimizeList(cmd *cobra.Command, flags *optimizeListFlags) error { + // Validate --status flag before making API call + if flags.status != "" { + valid := map[string]bool{"pending": true, "running": true, "completed": true, "failed": true, "cancelled": true} + if !valid[flags.status] { + return fmt.Errorf("invalid --status %q: must be one of pending, running, completed, failed, cancelled", flags.status) + } + } + + endpoint, err := flags.resolve(cmd.Context()) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := optimize_api.NewOptimizeClient(endpoint, credential) + + listResp, err := client.ListOptimizeJobs(cmd.Context(), flags.limit, flags.status) + if err != nil { + return fmt.Errorf("failed to list optimization jobs: %w\n\nCheck that the endpoint %q is reachable", err, endpoint) + } + + out := cmd.OutOrStdout() + + if len(listResp.Data) == 0 { + fmt.Fprintln(out, " No optimization jobs found.") + if flags.status != "" { + fmt.Fprintf(out, "\n Try removing the --status filter or run a new job with:\n") + fmt.Fprintf(out, " azd ai agent optimize --config spec.yaml\n") + } + return nil + } + + printOptimizeListTable(out, listResp.Data) + return nil +} + +func printOptimizeListTable(out io.Writer, jobs []optimize_api.OptimizeJobStatus) { + bold := color.New(color.Bold) + + _, _ = bold.Fprintf(out, " %-38s %-12s %-14s %7s %s\n", "ID", "Status", "Agent", "Score", "Created") + fmt.Fprintf(out, " %-38s %-12s %-14s %7s %s\n", + strings.Repeat("─", 38), strings.Repeat("─", 12), + strings.Repeat("─", 14), strings.Repeat("─", 7), strings.Repeat("─", 19)) + + for _, job := range jobs { + scoreStr := "—" + if job.Best != nil { + scoreStr = fmt.Sprintf("%.2f", job.Best.AvgScore) + } + + agentName := "—" + if job.Agent != nil && job.Agent.AgentName != "" { + agentName = job.Agent.AgentName + } + + created := job.CreatedAt + if created == "" { + created = "—" + } + + fmt.Fprintf(out, " %-38s %-12s %-14s %7s %s\n", + job.OperationID, + formatOptimizeStatus(job.Status), + truncateString(agentName, 14), + scoreStr, + truncateString(created, 19), + ) + } + fmt.Fprintln(out) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list_test.go new file mode 100644 index 00000000000..4aa5390a9f9 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_list_test.go @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizeListCommand_AcceptsLimitAndStatusFlags(t *testing.T) { + cmd := newOptimizeListCommand() + + limitFlag := cmd.Flags().Lookup("limit") + require.NotNil(t, limitFlag, "--limit flag should be registered") + + limitVal, err := cmd.Flags().GetInt("limit") + require.NoError(t, err) + assert.Equal(t, 20, limitVal, "--limit should default to 20") + + statusFlag := cmd.Flags().Lookup("status") + require.NotNil(t, statusFlag, "--status flag should be registered") + + statusVal, err := cmd.Flags().GetString("status") + require.NoError(t, err) + assert.Equal(t, "", statusVal, "--status should default to empty") +} + +func TestOptimizeListCommand_HasConnectionFlags(t *testing.T) { + cmd := newOptimizeListCommand() + + assert.NotNil(t, cmd.Flags().Lookup("endpoint")) + assert.NotNil(t, cmd.Flags().Lookup("project-endpoint")) + + assert.Nil(t, cmd.Flags().Lookup("subscription")) + assert.Nil(t, cmd.Flags().Lookup("resource-group")) + assert.Nil(t, cmd.Flags().Lookup("workspace")) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_prompts.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_prompts.go new file mode 100644 index 00000000000..99d9d19ba4b --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_prompts.go @@ -0,0 +1,534 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_prompts.go contains interactive resolution functions for the +// optimize command: system prompt, skill directory, config confirmation, +// and target model selection. + +package cmd + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// resolveOptimizeSystemPrompt resolves the agent's system prompt: +// +// 1. Config dir pointer (agent.config): instruction from metadata.yaml (already resolved). +// 2. Config (eval.yaml / --config): inline instruction or file reference. +// 3. Interactive prompt: ask the user to provide inline text or a file path. +// +// Relative file paths are resolved against agentProject. +func resolveOptimizeSystemPrompt( + ctx context.Context, + cfg *OptimizeConfig, + agentProject string, + hasProject bool, + noPrompt bool, +) error { + // Resolve relative instruction file paths against the agent project directory. + if cfg.Agent.Instruction.File != "" && hasProject && !filepath.IsAbs(cfg.Agent.Instruction.File) { + cfg.Agent.Instruction.File = filepath.Join(agentProject, cfg.Agent.Instruction.File) + } + + // Step 1: Config explicitly declares a file reference — validate it's readable. + if cfg.Agent.Instruction.File != "" { + if _, err := os.Stat(cfg.Agent.Instruction.File); err != nil { + return fmt.Errorf("instruction file %q from config is not accessible: %w", + cfg.Agent.Instruction.File, err) + } + return nil + } + + // Step 1b: Config already has inline instruction — nothing to do. + if cfg.Agent.Instruction.Value != "" { + return nil + } + + // Step 2: Interactive prompt — ask user to provide inline text or a file path. + if noPrompt { + return fmt.Errorf("instruction is required for optimization.\n\n" + + "Provide it via one of:\n" + + " 1. Set agent.config in eval.yaml to point to a config dir with metadata.yaml\n" + + " 2. Set instruction in eval.yaml (agent section): inline string or file reference\n" + + " 3. Run without --no-prompt to enter it interactively") + } + + azdClient, clientErr := azdext.NewAzdClient() + if clientErr != nil { + return fmt.Errorf("instruction is required but could not open interactive prompt: %w", clientErr) + } + defer azdClient.Close() + + inputChoices := []*azdext.SelectChoice{ + {Label: "Type inline", Value: "inline"}, + {Label: "Load from file", Value: "file"}, + } + defaultIdx := int32(0) + selResp, selErr := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: "No instruction found in config or baseline. " + + "How would you like to provide it?", + Choices: inputChoices, + SelectedIndex: &defaultIdx, + }, + }) + if selErr != nil { + return fmt.Errorf("prompting for instruction input method: %w", selErr) + } + + if inputChoices[int(*selResp.Value)].Value == "file" { + pathResp, pathErr := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Path to instruction file", + IgnoreHintKeys: true, + }, + }) + if pathErr != nil { + return fmt.Errorf("prompting for instruction file path: %w", pathErr) + } + filePath := strings.TrimSpace(pathResp.Value) + // Resolve relative paths against the agent project directory. + if !filepath.IsAbs(filePath) && hasProject { + filePath = filepath.Join(agentProject, filePath) + } + if _, err := os.Stat(filePath); err != nil { + return fmt.Errorf("instruction file %q is not accessible: %w", filePath, err) + } + cfg.Agent.Instruction.File = filePath + } else { + resp, promptErr := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Enter the agent's instruction", + IgnoreHintKeys: true, + }, + }) + if promptErr != nil { + return fmt.Errorf("prompting for instruction: %w", promptErr) + } + cfg.Agent.Instruction.Value = strings.TrimSpace(resp.Value) + } + + return nil +} + +// resolveOptimizeSkillDir resolves the agent's skill directory: +// 1. Config dir pointer (agent.config): skill_dir from metadata.yaml (already resolved). +// 2. Auto-detect: look for a "skills/" folder in the agent project — confirm with user. +// 3. Interactive prompt: ask the user to provide a path or skip. +func resolveOptimizeSkillDir( + ctx context.Context, + cfg *OptimizeConfig, + agentProject string, + noPrompt bool, +) error { + // Step 1: Auto-detect common skill directory names. + var detectedDir string + for _, candidate := range []string{"skills", "skill"} { + dir := filepath.Join(agentProject, candidate) + if info, err := os.Stat(dir); err == nil && info.IsDir() { + detectedDir = dir + break + } + } + + if noPrompt { + // In no-prompt mode, use whatever was detected (may be empty). + cfg.SkillDir = detectedDir + return nil + } + + azdClient, clientErr := azdext.NewAzdClient() + if clientErr != nil { + cfg.SkillDir = detectedDir + return nil + } + defer azdClient.Close() + + if detectedDir != "" { + // Found a skill directory — ask user to confirm or provide a different one. + choices := []*azdext.SelectChoice{ + {Label: fmt.Sprintf("Use detected: %s", detectedDir), Value: "use"}, + {Label: "Provide a different path", Value: "other"}, + {Label: "Skip (no skills)", Value: "skip"}, + } + defaultIdx := int32(0) + selResp, selErr := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: fmt.Sprintf("Found skills directory: %s", detectedDir), + Choices: choices, + SelectedIndex: &defaultIdx, + }, + }) + if selErr != nil { + cfg.SkillDir = detectedDir + return nil + } + + switch choices[int(*selResp.Value)].Value { + case "use": + cfg.SkillDir = detectedDir + return nil + case "skip": + return nil + case "other": + // Fall through to path prompt below. + } + } else { + // No skill directory found — ask if they want to provide one. + resp, promptErr := azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: "No skills directory found. Would you like to provide one?", + DefaultValue: new(bool), // default false + }, + }) + if promptErr != nil || !resp.GetValue() { + return nil // skip skills + } + } + + // Prompt for a custom path. + pathResp, pathErr := azdClient.Prompt().Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Path to skills directory", + IgnoreHintKeys: true, + }, + }) + if pathErr != nil { + return fmt.Errorf("prompting for skills directory: %w", pathErr) + } + + dir := strings.TrimSpace(pathResp.Value) + if dir == "" { + return nil + } + if !filepath.IsAbs(dir) { + dir = filepath.Join(agentProject, dir) + } + if info, err := os.Stat(dir); err != nil || !info.IsDir() { + return fmt.Errorf("skills directory %q is not accessible or not a directory", dir) + } + + cfg.SkillDir = dir + return nil +} + +// promptOptimizeConfigConfirmation shows the resolved values from the baseline +// config and lets the user confirm or override instruction file, skills +// directory, and tools file. +func promptOptimizeConfigConfirmation(ctx context.Context, cfg *OptimizeConfig, agentProject string) error { + azdClient, clientErr := azdext.NewAzdClient() + if clientErr != nil { + return nil // non-fatal — skip confirmation prompts + } + defer azdClient.Close() + prompt := azdClient.Prompt() + + // Instruction file. + instrDefault := relativeDisplay(cfg.Agent.Instruction.File, agentProject) + resp, err := prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Instruction file", + DefaultValue: instrDefault, + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for instruction file: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + if !filepath.IsAbs(value) && agentProject != "" { + value = filepath.Join(agentProject, value) + } + if _, err := os.Stat(value); err != nil { + return fmt.Errorf("instruction file %q is not accessible: %w", value, err) + } + cfg.Agent.Instruction.File = value + cfg.Agent.Instruction.Value = "" + } + + // Skills directory. + skillDefault := relativeDisplay(cfg.SkillDir, agentProject) + resp, err = prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Skills directory (enter to skip)", + DefaultValue: skillDefault, + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for skills directory: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + if !filepath.IsAbs(value) && agentProject != "" { + value = filepath.Join(agentProject, value) + } + cfg.SkillDir = value + } else { + cfg.SkillDir = "" + } + + // Tools file. + toolsDefault := relativeDisplay(cfg.ToolsFile, agentProject) + resp, err = prompt.Prompt(ctx, &azdext.PromptRequest{ + Options: &azdext.PromptOptions{ + Message: "Tools file (enter to skip)", + DefaultValue: toolsDefault, + IgnoreHintKeys: true, + }, + }) + if err != nil { + return fmt.Errorf("prompting for tools file: %w", err) + } + if value := strings.TrimSpace(resp.Value); value != "" { + if !filepath.IsAbs(value) && agentProject != "" { + value = filepath.Join(agentProject, value) + } + cfg.ToolsFile = value + } else { + cfg.ToolsFile = "" + } + + return nil +} + +// resolveOptimizeTargetModels prompts the user to select model candidates +// for optimization (target_config.model). Fetches actual deployments from the +// Foundry project and allows multi-select. +func resolveOptimizeTargetModels( + ctx context.Context, + cfg *OptimizeConfig, +) error { + azdClient, clientErr := azdext.NewAzdClient() + if clientErr != nil { + return nil + } + defer azdClient.Close() + + currentModel := cfg.Agent.Model + + resp, promptErr := azdClient.Prompt().Confirm(ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: "Would you like to specify target models for optimization?", + DefaultValue: new(bool), // default false + }, + }) + if promptErr != nil || !resp.GetValue() { + return nil + } + + // Fetch deployed models from the Foundry project. + choices := buildOptimizeModelChoices(ctx, azdClient, currentModel) + + message := "Select target models for optimization" + if currentModel != "" { + message = fmt.Sprintf("Select target models for optimization (current: %s)", currentModel) + } + + multiResp, multiErr := azdClient.Prompt().MultiSelect(ctx, &azdext.MultiSelectRequest{ + Options: &azdext.MultiSelectOptions{ + Message: message, + Choices: choices, + }, + }) + if multiErr != nil { + return fmt.Errorf("prompting for target models: %w", multiErr) + } + + var models []string + for _, v := range multiResp.Values { + models = append(models, v.Value) + } + + if len(models) > 0 { + if cfg.Options.TargetConfig == nil { + cfg.Options.TargetConfig = &opt_eval.TargetConfig{} + } + cfg.Options.TargetConfig.Model = models + } + + return nil +} + +// allowedReflectionModels is the set of model families permitted as reflection +// models by the server. Deployments whose ModelName does not match one of these +// prefixes are excluded from the selection list. +var allowedReflectionModels = []string{"gpt-5", "gpt-5.1", "gpt-5.3"} + +// isAllowedReflectionModel checks whether a model name matches an allowed +// reflection model (exact match or prefix followed by a separator). +func isAllowedReflectionModel(modelName string) bool { + for _, allowed := range allowedReflectionModels { + if strings.EqualFold(modelName, allowed) { + return true + } + } + return false +} + +// resolveOptimizeReflectionModel prompts the user to select a reflection model +// for optimization. All deployments are shown; if the user picks one whose +// model is not in the recommended set, a warning is printed. This avoids +// requiring client-side updates when the server's allowed set changes. +func resolveOptimizeReflectionModel(ctx context.Context, cfg *OptimizeConfig) error { + azdClient, clientErr := azdext.NewAzdClient() + if clientErr != nil { + return nil + } + defer azdClient.Close() + + deployments := listDeploymentsFromEnv(ctx, azdClient) + if len(deployments) == 0 { + return nil + } + + allowedList := strings.Join(allowedReflectionModels, ", ") + + var choices []*azdext.SelectChoice + seen := make(map[string]bool) + + // Always offer Skip — defaults to using the eval model. + choices = append(choices, &azdext.SelectChoice{ + Label: fmt.Sprintf("Skip (use eval model: %s)", cfg.Options.EvalModel), + Value: "", + }) + + // Show all deployments — don't filter by allowed set. + for _, d := range deployments { + if seen[d.Name] { + continue + } + label := d.Name + if d.ModelName != "" && d.ModelName != d.Name { + label = fmt.Sprintf("%s (%s)", d.Name, d.ModelName) + } + choices = append(choices, &azdext.SelectChoice{ + Label: label, + Value: d.Name, + }) + seen[d.Name] = true + } + + message := fmt.Sprintf("Select a reflection model (recommended: %s)", allowedList) + + selectResp, selectErr := azdClient.Prompt().Select(ctx, &azdext.SelectRequest{ + Options: &azdext.SelectOptions{ + Message: message, + Choices: choices, + }, + }) + if selectErr != nil || selectResp.Value == nil { + return nil + } + + idx := int(*selectResp.Value) + if idx >= 0 && idx < len(choices) && choices[idx].Value != "" { + selected := choices[idx].Value + // Warn if the selected deployment's model is not in the recommended set. + for _, d := range deployments { + if d.Name == selected && !isAllowedReflectionModel(d.ModelName) { + fmt.Printf("Warning: deployment %q uses model %q which is not in the recommended "+ + "reflection model set (%s). The server may reject it.\n", selected, d.ModelName, allowedList) + break + } + } + cfg.Options.ReflectionModel = selected + } + // Empty Value means Skip — leave ReflectionModel empty (server uses eval model). + + return nil +} + +// buildOptimizeModelChoices fetches Foundry project deployments and returns +// MultiSelectChoice items. The current deployed model is pre-selected. +// Falls back to an empty list if deployments cannot be fetched. +func buildOptimizeModelChoices(ctx context.Context, azdClient *azdext.AzdClient, currentModel string) []*azdext.MultiSelectChoice { + deployments := listDeploymentsFromEnv(ctx, azdClient) + + var choices []*azdext.MultiSelectChoice + seen := make(map[string]bool) + + // If current model is present in deployments, it will be marked below. + // If not (and it's non-empty), prepend it as a pre-selected entry. + if currentModel != "" { + found := false + for _, d := range deployments { + if d.Name == currentModel { + found = true + break + } + } + if !found { + choices = append(choices, &azdext.MultiSelectChoice{ + Label: currentModel + " (current)", + Value: currentModel, + Selected: true, + }) + seen[currentModel] = true + } + } + + for _, d := range deployments { + if seen[d.Name] { + continue + } + label := d.Name + if d.ModelName != "" && d.ModelName != d.Name { + label = fmt.Sprintf("%s (%s)", d.Name, d.ModelName) + } + selected := d.Name == currentModel + if selected { + label += " (current)" + } + choices = append(choices, &azdext.MultiSelectChoice{ + Label: label, + Value: d.Name, + Selected: selected, + }) + seen[d.Name] = true + } + + return choices +} + +// listDeploymentsFromEnv reads AZURE_AI_PROJECT_ID from the azd environment +// and returns the Foundry project's model deployments. Returns nil on failure. +func listDeploymentsFromEnv(ctx context.Context, azdClient *azdext.AzdClient) []FoundryDeploymentInfo { + envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}) + if err != nil || envResp == nil || envResp.Environment == nil { + return nil + } + + v, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: "AZURE_AI_PROJECT_ID", + }) + if err != nil || v.Value == "" { + return nil + } + + project, err := extractProjectDetails(v.Value) + if err != nil { + return nil + } + + cred, err := newAgentCredential() + if err != nil { + return nil + } + + deployments, _ := listProjectDeployments( + ctx, cred, + project.SubscriptionId, + project.ResourceGroupName, + project.AccountName, + ) + return deployments +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status.go new file mode 100644 index 00000000000..865ef11dab3 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status.go @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// optimize_status.go implements the "optimize status" command, which checks +// or watches the status of an optimization job. + +package cmd + +import ( + "fmt" + "io" + + "azureaiagent/internal/pkg/agents/optimize_api" + + azdext "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +// optimizeStatusFlags holds CLI flags for the optimize status command. +type optimizeStatusFlags struct { + watch bool // poll until job completes + pollInterval int // polling interval in seconds + optimizeConnectionFlags +} + +func newOptimizeStatusCommand() *cobra.Command { + flags := &optimizeStatusFlags{} + + cmd := &cobra.Command{ + Use: "status [operation-id]", + Short: "Check the status of an optimization job.", + Long: `Check the status of an optimization job by its operation ID. + +If no operation ID is provided, uses the last optimization job from this project. +Use --watch to poll until the job completes.`, + Example: ` # Check last job status (auto-resolved) + azd ai agent optimize status + + # Check specific job status + azd ai agent optimize status opt_abc123 + + # Watch until complete + azd ai agent optimize status opt_abc123 --watch`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + operationID := "" + if len(args) > 0 { + operationID = args[0] + } else { + operationID = loadLastOptimizeJobID(ctx) + if operationID == "" { + return fmt.Errorf("operation ID is required: provide it as an argument, or run 'azd ai agent optimize' first") + } + fmt.Fprintf(cmd.OutOrStdout(), " Using last job: %s\n\n", operationID) + } + return runOptimizeStatus(cmd, flags, operationID) + }, + } + + cmd.Flags().BoolVar(&flags.watch, "watch", false, "Poll until job completes") + cmd.Flags().IntVar(&flags.pollInterval, "poll-interval", 5, "Polling interval in seconds") + flags.optimizeConnectionFlags.register(cmd) + + return cmd +} + +func runOptimizeStatus(cmd *cobra.Command, flags *optimizeStatusFlags, operationID string) error { + endpoint, err := flags.resolve(cmd.Context()) + if err != nil { + return err + } + + credential, err := newAgentCredential() + if err != nil { + return err + } + + client := optimize_api.NewOptimizeClient(endpoint, credential) + out := cmd.OutOrStdout() + + status, err := client.GetOptimizeStatus(cmd.Context(), operationID) + if err != nil { + return fmt.Errorf("failed to get job status: %w\n\nCheck that the operation ID %q is correct", err, operationID) + } + + printOptimizeJobSummary(out, status) + + if flags.watch && !optimize_api.IsTerminal(status.Status) { + finalStatus, err := pollOptimizeJob(cmd, client, flags.pollInterval, operationID) + if err != nil { + return err + } + printOptimizeResults(out, finalStatus, false) + } else if len(status.Candidates) > 0 { + printOptimizeResults(out, status, false) + } + + if status.Error != nil { + return fmt.Errorf("optimization job failed: %s", status.Error.Message) + } + + return nil +} + +// printOptimizeJobSummary prints a brief summary of an optimization job's state. +func printOptimizeJobSummary(out io.Writer, status *optimize_api.OptimizeJobStatus) { + fmt.Fprintf(out, " Job ID: %s\n", color.CyanString(status.OperationID)) + fmt.Fprintf(out, " Status: %s\n", formatOptimizeStatus(status.Status)) + if status.Agent != nil && status.Agent.AgentName != "" { + fmt.Fprintf(out, " Agent: %s\n", status.Agent.AgentName) + } + if status.AllTargetAttributesFailed { + fmt.Fprintf(out, " Strategy: %s\n", color.YellowString("failed (baseline only — no candidates generated)")) + } else if status.Progress != nil && status.Progress.CurrentTargetAttribute != "" { + fmt.Fprintf(out, " Strategy: %s\n", status.Progress.CurrentTargetAttribute) + } + if status.Best != nil { + fmt.Fprintf(out, " Best: %.2f\n", status.Best.AvgScore) + } + if status.CreatedAt != "" { + fmt.Fprintf(out, " Created: %s\n", status.CreatedAt) + } + if status.Error != nil { + fmt.Fprintf(out, " Error: %s\n", color.RedString(status.Error.Message)) + } + if len(status.Warnings) > 0 { + for _, w := range status.Warnings { + fmt.Fprintf(out, " Warning: %s\n", color.YellowString(w)) + } + } + fmt.Fprintln(out) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status_test.go new file mode 100644 index 00000000000..7996b6dc7ee --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_status_test.go @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizeStatusCommand_AcceptsOptionalPositionalArg(t *testing.T) { + cmd := newOptimizeStatusCommand() + + // Zero args is now OK (uses last job ID) + err := cmd.Args(cmd, []string{}) + assert.NoError(t, err) + + // One arg is OK + err = cmd.Args(cmd, []string{"opt_abc123"}) + assert.NoError(t, err) + + // Two args is rejected + err = cmd.Args(cmd, []string{"opt_abc123", "extra"}) + assert.Error(t, err) +} + +func TestOptimizeStatusCommand_HasWatchFlag(t *testing.T) { + cmd := newOptimizeStatusCommand() + + f := cmd.Flags().Lookup("watch") + require.NotNil(t, f, "--watch flag should be registered") + + watchVal, err := cmd.Flags().GetBool("watch") + require.NoError(t, err) + assert.False(t, watchVal, "--watch should default to false for status") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_test.go new file mode 100644 index 00000000000..d10cf036233 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/optimize_test.go @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/optimize_api" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizeCommand_HasExpectedSubCommands(t *testing.T) { + cmd := newOptimizeCommand(&azdext.ExtensionContext{}) + + expected := []string{"status", "list", "cancel", "deploy", "apply"} + var actual []string + for _, sub := range cmd.Commands() { + actual = append(actual, sub.Name()) + } + + for _, name := range expected { + assert.Contains(t, actual, name, "optimize should have sub-command %q", name) + } + assert.NotContains(t, actual, "run", "optimize should not have 'run' sub-command (merged into root)") +} + +func TestOptimizeCommand_AcceptsPositionalArg(t *testing.T) { + cmd := newOptimizeCommand(&azdext.ExtensionContext{}) + + err := cmd.Args(cmd, []string{"my-agent"}) + assert.NoError(t, err) + + err = cmd.Args(cmd, []string{}) + assert.NoError(t, err) + + err = cmd.Args(cmd, []string{"my-agent", "extra"}) + assert.Error(t, err) +} + +func TestOptimizeCommand_AcceptsConfigFlag(t *testing.T) { + cmd := newOptimizeCommand(&azdext.ExtensionContext{}) + + f := cmd.Flags().Lookup("config") + require.NotNil(t, f, "--config flag should be registered") + assert.Equal(t, "c", f.Shorthand, "--config should have -c shorthand") + + assert.NotNil(t, cmd.Flags().Lookup("poll-interval")) + assert.NotNil(t, cmd.Flags().Lookup("target")) +} + +func TestOptimizeCommand_DefaultFlags(t *testing.T) { + cmd := newOptimizeCommand(&azdext.ExtensionContext{}) + + pollVal, err := cmd.Flags().GetInt("poll-interval") + require.NoError(t, err) + assert.Equal(t, 5, pollVal, "--poll-interval should default to 5") +} + +func TestIsTerminal_ViaOptimizeAPI(t *testing.T) { + assert.True(t, optimize_api.IsTerminal(optimize_api.StatusCompleted)) + assert.True(t, optimize_api.IsTerminal(optimize_api.StatusFailed)) + assert.True(t, optimize_api.IsTerminal(optimize_api.StatusCancelled)) + assert.False(t, optimize_api.IsTerminal(optimize_api.StatusRunning)) + assert.False(t, optimize_api.IsTerminal(optimize_api.StatusPending)) + assert.False(t, optimize_api.IsTerminal("")) +} + +func TestTruncateString(t *testing.T) { + assert.Equal(t, "abc", truncateString("abc", 10)) + assert.Equal(t, "abcdefg...", truncateString("abcdefghijk", 10)) + assert.Equal(t, "ab", truncateString("abcdef", 2)) +} + +func TestFormatOptimizeStatus(t *testing.T) { + assert.NotEmpty(t, formatOptimizeStatus(optimize_api.StatusCompleted)) + assert.NotEmpty(t, formatOptimizeStatus(optimize_api.StatusFailed)) + assert.NotEmpty(t, formatOptimizeStatus(optimize_api.StatusCancelled)) + assert.NotEmpty(t, formatOptimizeStatus(optimize_api.StatusRunning)) + assert.NotEmpty(t, formatOptimizeStatus("unknown")) +} + +// ---- defaultOptimizeConfig ---- + +func TestDefaultOptimizeConfig(t *testing.T) { + t.Parallel() + cfg := defaultOptimizeConfig("my-agent") + + assert.Equal(t, "my-agent", cfg.Agent.Name) + assert.NotEmpty(t, cfg.InlineDataset) + require.NotNil(t, cfg.Options) + assert.Equal(t, "gpt-4o", cfg.Options.EvalModel) + assert.Contains(t, cfg.Options.TargetAttributes, "instruction") + assert.Contains(t, cfg.Options.TargetAttributes, "skill") + require.Len(t, cfg.Evaluators, 1) + assert.Equal(t, "builtin.task_adherence", cfg.Evaluators[0].Name) +} + +// ---- LoadOptimizeConfig + reconcileConfigAgentName (--config path) ---- + +func TestLoadOptimizeConfig_ReconcileAgentName(t *testing.T) { + t.Parallel() + + writeConfigYAML := func(t *testing.T, dir, agentName string) string { + t.Helper() + content := "agent:\n name: " + agentName + "\noptions:\n eval_model: gpt-4o\n mode: optimize\n" + path := filepath.Join(dir, "spec.yaml") + require.NoError(t, os.WriteFile(path, []byte(content), 0600)) + return path + } + + t.Run("env overrides config when names differ", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + cfgPath := writeConfigYAML(t, dir, "config-agent") + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + assert.Equal(t, "config-agent", cfg.Agent.Name) + + changed := reconcileConfigAgentName(&cfg.Agent, "env-agent", cfgPath) + assert.True(t, changed, "should report change when names differ") + assert.Equal(t, "env-agent", cfg.Agent.Name, "environment name should take precedence") + }) + + t.Run("no change when names match", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + cfgPath := writeConfigYAML(t, dir, "same-agent") + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + + changed := reconcileConfigAgentName(&cfg.Agent, "same-agent", cfgPath) + assert.False(t, changed) + assert.Equal(t, "same-agent", cfg.Agent.Name) + }) + + t.Run("sets name when config has empty agent name", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + content := "agent:\n kind: hosted\noptions:\n eval_model: gpt-4o\n" + cfgPath := filepath.Join(dir, "spec.yaml") + require.NoError(t, os.WriteFile(cfgPath, []byte(content), 0600)) + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + assert.Empty(t, cfg.Agent.Name) + + changed := reconcileConfigAgentName(&cfg.Agent, "env-agent", cfgPath) + assert.False(t, changed, "filling empty name is not a 'change' (no conflict)") + assert.Equal(t, "env-agent", cfg.Agent.Name) + }) + + t.Run("no-op when env name is empty", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + cfgPath := writeConfigYAML(t, dir, "config-agent") + + cfg, err := LoadOptimizeConfig(cfgPath) + require.NoError(t, err) + + changed := reconcileConfigAgentName(&cfg.Agent, "", cfgPath) + assert.False(t, changed) + assert.Equal(t, "config-agent", cfg.Agent.Name, "original name preserved when env is empty") + }) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go index cef2bd54634..923c5a6a8cc 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go @@ -69,6 +69,8 @@ func NewRootCommand() *cobra.Command { // When the azd core namespace change lands, move this AddCommand call // to the new root and update the import path. rootCmd.AddCommand(conncmd.NewConnectionRootCommand(extCtx)) + rootCmd.AddCommand(newEvalCommand(extCtx)) + rootCmd.AddCommand(newOptimizeCommand(extCtx)) return rootCmd } diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index 4722c9e22e8..2bcc78683d6 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -163,3 +163,15 @@ const ( OpCreateToolboxVersion = "create_toolbox_version" OpGetToolbox = "get_toolbox" ) + +// Error codes for eval and optimize operations. +const ( + CodeEvalRunFailed = "eval_run_failed" + CodeEvalRunCancelled = "eval_run_cancelled" + CodeEvalRunTimeout = "eval_run_timeout" + CodeEvalConfigInvalid = "eval_config_invalid" + CodeOptimizeJobFailed = "optimize_job_failed" + CodeOptimizeJobTimeout = "optimize_job_timeout" + CodeInvalidTargetAttr = "invalid_target_attribute" + CodeReservedEnvVar = "reserved_env_var" +) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models.go new file mode 100644 index 00000000000..cee6bbcd886 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models.go @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package dataset_api + +import ( + "fmt" + "math" + "os" + "path/filepath" + "strconv" + "strings" +) + +// CreateDatasetRequest is the request body for creating (uploading) a dataset. +type CreateDatasetRequest struct { + Name string `json:"name"` + Version string `json:"version"` + Format string `json:"format"` + Content string `json:"content"` +} + +// Dataset is the response for dataset operations. +// Note: The GET /datasets API returns snake_case field names (data_uri, blob_uri, +// content_uri), while the POST /finalize API accepts camelCase (dataUri). +// Both conventions are correct for their respective endpoints. +type Dataset struct { + Name string `json:"name"` + Version string `json:"version"` + BlobURI string `json:"blob_uri,omitempty"` + Format string `json:"format,omitempty"` + DataURI string `json:"data_uri,omitempty"` + ContentURI string `json:"content_uri,omitempty"` +} + +// ResolvedBlobURI returns the best available blob URI. Prefers blob_uri, +// falls back to data_uri, then content_uri. +func (d *Dataset) ResolvedBlobURI() string { + if d.BlobURI != "" { + return d.BlobURI + } + if d.DataURI != "" { + return d.DataURI + } + return d.ContentURI +} + +// DatasetCredential is the response for dataset credential (SAS token) requests. +// The API returns a nested structure with blobReference and blobReferenceForConsumption. +type DatasetCredential struct { + // Flat fields (legacy format). + BlobURI string `json:"blob_uri,omitempty"` + SAS string `json:"sas,omitempty"` + SASUri string `json:"sas_uri,omitempty"` + + // Nested fields (current API format). + BlobReference *BlobReference `json:"blobReference,omitempty"` + BlobReferenceConsumption *BlobReference `json:"blobReferenceForConsumption,omitempty"` +} + +// BlobReference represents a blob storage reference with credentials. +type BlobReference struct { + BlobURI string `json:"blobUri,omitempty"` + StorageAccountARM string `json:"storageAccountArmId,omitempty"` + Credential *BlobCredential `json:"credential,omitempty"` +} + +// BlobCredential holds SAS credential details for blob access. +type BlobCredential struct { + Type string `json:"type,omitempty"` + SASUri string `json:"sasUri,omitempty"` + SASPath string `json:"sas,omitempty"` +} + +// ResolvedDownloadURI returns the URL to download the dataset. +// Prefers blobReferenceForConsumption.credential.sasUri (current API), +// then blobReference.credential.sasUri, then flat sas_uri, then blob_uri + sas. +func (c *DatasetCredential) ResolvedDownloadURI() string { + // Current API format: nested blob references. + if c.BlobReferenceConsumption != nil && c.BlobReferenceConsumption.Credential != nil { + if uri := c.BlobReferenceConsumption.Credential.SASUri; uri != "" { + return uri + } + } + if c.BlobReference != nil && c.BlobReference.Credential != nil { + if uri := c.BlobReference.Credential.SASUri; uri != "" { + return uri + } + } + // Legacy flat format. + if c.SASUri != "" { + return c.SASUri + } + if c.BlobURI != "" && c.SAS != "" { + return c.BlobURI + "?" + c.SAS + } + return c.BlobURI +} + +// PendingUploadResponse is returned by the startPendingUpload endpoint. +// It contains a SAS URI for uploading blob data and the blob container URI. +type PendingUploadResponse struct { + BlobReference *BlobReference `json:"blobReference,omitempty"` + BlobReferenceConsumption *BlobReference `json:"blobReferenceForConsumption,omitempty"` + PendingUploadID *string `json:"pendingUploadId,omitempty"` + PendingUploadType string `json:"pendingUploadType,omitempty"` + Version string `json:"version,omitempty"` +} + +// ResolvedUploadURI returns the SAS URI for uploading blobs. +func (p *PendingUploadResponse) ResolvedUploadURI() string { + if p.BlobReference != nil && p.BlobReference.Credential != nil { + if uri := p.BlobReference.Credential.SASUri; uri != "" { + return uri + } + } + return "" +} + +// ResolvedBlobURI returns the blob container URI (without SAS) for the finalize request. +func (p *PendingUploadResponse) ResolvedBlobURI() string { + if p.BlobReference != nil { + return p.BlobReference.BlobURI + } + return "" +} + +// FinalizeDatasetRequest is the request body for finalizing a dataset version +// after blob upload. +type FinalizeDatasetRequest struct { + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description"` + Type string `json:"type"` + IsReference bool `json:"isReference"` + DataURI string `json:"dataUri"` +} + +// NextVersion computes the next dataset version string. +// +// Rules: +// 1. Empty → "1.0" +// 2. Parsable as a decimal number → increment by 1, format as "N.0" +// 3. Ends with trailing digits → increment the trailing numeric part +// 4. Otherwise → append ".1" +func NextVersion(current string) string { + current = strings.TrimSpace(current) + if current == "" { + return "1.0" + } + + // Try parsing as a decimal number (e.g. "1", "1.0", "2.0"). + if f, err := strconv.ParseFloat(current, 64); err == nil { + return strconv.FormatFloat(math.Floor(f)+1, 'f', 1, 64) + } + + // Find trailing digits and increment them. + i := len(current) - 1 + for i >= 0 && current[i] >= '0' && current[i] <= '9' { + i-- + } + if i < len(current)-1 { + prefix := current[:i+1] + n, err := strconv.Atoi(current[i+1:]) + if err == nil { + return prefix + strconv.Itoa(n+1) + } + } + + return current + ".1" +} + +// ReadFirstJSONLFile finds and reads the first .jsonl file in a directory. +func ReadFirstJSONLFile(dir string) (string, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("reading directory: %w", err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + if filepath.Ext(e.Name()) == ".jsonl" { + data, err := os.ReadFile(filepath.Join(dir, e.Name())) //nolint:gosec // local artifact path + if err != nil { + return "", fmt.Errorf("reading %s: %w", e.Name(), err) + } + return string(data), nil + } + } + return "", fmt.Errorf("no .jsonl file found in %s", dir) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models_test.go new file mode 100644 index 00000000000..ad764c9e44f --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/models_test.go @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package dataset_api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// --------------------------------------------------------------------------- +// Dataset +// --------------------------------------------------------------------------- + +func TestDataset_ResolvedBlobURI(t *testing.T) { + t.Parallel() + tests := []struct { + name string + dataset Dataset + expected string + }{ + { + name: "prefers blob_uri", + dataset: Dataset{BlobURI: "https://blob.example", DataURI: "https://data.example"}, + expected: "https://blob.example", + }, + { + name: "falls back to data_uri", + dataset: Dataset{DataURI: "https://data.example", ContentURI: "https://content.example"}, + expected: "https://data.example", + }, + { + name: "falls back to content_uri", + dataset: Dataset{ContentURI: "https://content.example"}, + expected: "https://content.example", + }, + { + name: "empty when no URI", + dataset: Dataset{Name: "test"}, + expected: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, tc.dataset.ResolvedBlobURI()) + }) + } +} + +// --------------------------------------------------------------------------- +// DatasetCredential +// --------------------------------------------------------------------------- + +func TestDatasetCredential_ResolvedDownloadURI(t *testing.T) { + t.Parallel() + tests := []struct { + name string + cred DatasetCredential + expected string + }{ + { + name: "prefers sas_uri", + cred: DatasetCredential{SASUri: "https://blob.example/data?sig=abc", BlobURI: "https://blob.example/data"}, + expected: "https://blob.example/data?sig=abc", + }, + { + name: "combines blob_uri and sas", + cred: DatasetCredential{BlobURI: "https://blob.example/data", SAS: "sig=abc&se=2025"}, + expected: "https://blob.example/data?sig=abc&se=2025", + }, + { + name: "blob_uri only", + cred: DatasetCredential{BlobURI: "https://blob.example/data"}, + expected: "https://blob.example/data", + }, + { + name: "empty when no fields", + cred: DatasetCredential{}, + expected: "", + }, + { + name: "prefers blobReferenceForConsumption", + cred: DatasetCredential{ + BlobReference: &BlobReference{Credential: &BlobCredential{SASUri: "https://blob.example/ref?sig=1"}}, + BlobReferenceConsumption: &BlobReference{Credential: &BlobCredential{SASUri: "https://blob.example/consumption?sig=2"}}, + }, + expected: "https://blob.example/consumption?sig=2", + }, + { + name: "falls back to blobReference", + cred: DatasetCredential{ + BlobReference: &BlobReference{Credential: &BlobCredential{SASUri: "https://blob.example/ref?sig=1"}}, + }, + expected: "https://blob.example/ref?sig=1", + }, + { + name: "nested takes priority over flat sas_uri", + cred: DatasetCredential{ + SASUri: "https://blob.example/flat?sig=flat", + BlobReference: &BlobReference{Credential: &BlobCredential{SASUri: "https://blob.example/nested?sig=nested"}}, + }, + expected: "https://blob.example/nested?sig=nested", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, tc.cred.ResolvedDownloadURI()) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations.go new file mode 100644 index 00000000000..a1d54bf47a5 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package dataset_api + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + + "azureaiagent/internal/version" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +// API path prefix for dataset endpoints. +const pathDatasets = "/datasets" + +// DatasetClient provides methods for dataset upload, download, and metadata retrieval. +type DatasetClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewDatasetClient creates a new DatasetClient. +func NewDatasetClient(endpoint string, cred azcore.TokenCredential) *DatasetClient { + userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) + + clientOptions := &policy.ClientOptions{ + Logging: policy.LogOptions{ + AllowedHeaders: []string{"X-Ms-Correlation-Request-Id", "X-Request-Id"}, + IncludeBody: false, + }, + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy(cred, []string{"https://ai.azure.com/.default"}, nil), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy(userAgent), + }, + } + + pipeline := runtime.NewPipeline( + "azure-ai-datasets", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &DatasetClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// NewDatasetClientFromPipeline creates a DatasetClient with a pre-built pipeline. +// This is intended for tests that need to bypass auth policies. +func NewDatasetClientFromPipeline(endpoint string, pipeline runtime.Pipeline) *DatasetClient { + return &DatasetClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// CreateDataset registers a dataset with inline content (upload). +func (c *DatasetClient) CreateDataset( + ctx context.Context, + request *CreateDatasetRequest, + apiVersion string, +) (*Dataset, error) { + return doRequestTyped[Dataset](c, ctx, http.MethodPost, pathDatasets, nil, request, apiVersion) +} + +// UploadNewVersion reads the first JSONL file from localDir, computes the next +// version from currentVersion, and uploads it as a new dataset version using +// the 3-step pending upload flow: +// 1. startPendingUpload → get SAS URI +// 2. Upload blob to SAS URI +// 3. Finalize dataset version with dataUri +func (c *DatasetClient) UploadNewVersion( + ctx context.Context, + name string, + currentVersion string, + localDir string, + apiVersion string, +) (*Dataset, error) { + content, err := ReadFirstJSONLFile(localDir) + if err != nil { + return nil, fmt.Errorf("reading dataset from %s: %w", localDir, err) + } + + newVersion := NextVersion(currentVersion) + + // Step 1: Start pending upload to get a SAS URI. + pending, err := c.StartPendingUpload(ctx, name, newVersion, apiVersion) + if err != nil { + return nil, fmt.Errorf("starting pending upload: %w", err) + } + + uploadURI := pending.ResolvedUploadURI() + if uploadURI == "" { + return nil, fmt.Errorf("no upload SAS URI returned from startPendingUpload") + } + + // Step 2: Upload the JSONL file to blob storage. + blobName := name + ".jsonl" + if err := c.UploadBlob(ctx, uploadURI, blobName, []byte(content)); err != nil { + return nil, fmt.Errorf("uploading blob: %w", err) + } + + // Step 3: Finalize the dataset version with the full blob URI. + dataURI := strings.TrimSuffix(pending.ResolvedBlobURI(), "/") + "/" + blobName + return c.FinalizeDatasetVersion(ctx, name, newVersion, dataURI, apiVersion) +} + +// StartPendingUpload initiates a pending upload for a dataset version. +// Returns the SAS URI and blob reference for uploading data. +func (c *DatasetClient) StartPendingUpload( + ctx context.Context, + name string, + version string, + apiVersion string, +) (*PendingUploadResponse, error) { + path := fmt.Sprintf( + "%s/%s/versions/%s/startPendingUpload", + pathDatasets, url.PathEscape(name), url.PathEscape(version), + ) + return doRequestTyped[PendingUploadResponse](c, ctx, http.MethodPost, path, nil, json.RawMessage(`{}`), apiVersion) +} + +// UploadBlob uploads data to a container SAS URI as a block blob. +func (c *DatasetClient) UploadBlob(ctx context.Context, containerSASUri, blobName string, data []byte) error { + u, err := url.Parse(containerSASUri) + if err != nil { + return fmt.Errorf("invalid container SAS URI: %w", err) + } + + // Append blob name to the container path. + u.Path = strings.TrimSuffix(u.Path, "/") + "/" + blobName + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.String(), bytes.NewReader(data)) + if err != nil { + return fmt.Errorf("failed to create upload request: %w", err) + } + req.Header.Set("x-ms-blob-type", "BlockBlob") + req.Header.Set("Content-Type", "application/octet-stream") + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to upload blob: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("blob upload failed with status %d: %s", resp.StatusCode, string(body)) + } + + return nil +} + +// FinalizeDatasetVersion completes the dataset version after blob upload +// by sending the metadata (name, version, dataUri) to the API. +func (c *DatasetClient) FinalizeDatasetVersion( + ctx context.Context, + name string, + version string, + dataURI string, + apiVersion string, +) (*Dataset, error) { + path := fmt.Sprintf("%s/%s/versions/%s", pathDatasets, url.PathEscape(name), url.PathEscape(version)) + request := &FinalizeDatasetRequest{ + Name: name, + Version: version, + Type: "uri_file", + DataURI: dataURI, + } + return doRequestTyped[Dataset](c, ctx, http.MethodPut, path, nil, request, apiVersion) +} + +// GetDataset retrieves metadata for a dataset by name and version. +func (c *DatasetClient) GetDataset( + ctx context.Context, + name string, + version string, + apiVersion string, +) (*Dataset, error) { + path := fmt.Sprintf("%s/%s/versions/%s", pathDatasets, url.PathEscape(name), url.PathEscape(version)) + return doRequestTyped[Dataset](c, ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +// GetDatasetCredential retrieves a SAS credential for downloading a dataset from blob storage. +func (c *DatasetClient) GetDatasetCredential( + ctx context.Context, + name string, + version string, + apiVersion string, +) (*DatasetCredential, error) { + path := fmt.Sprintf( + "%s/%s/versions/%s/credentials", + pathDatasets, url.PathEscape(name), url.PathEscape(version), + ) + return doRequestTyped[DatasetCredential](c, ctx, http.MethodPost, path, nil, nil, apiVersion) +} + +// DownloadDataset downloads dataset content from blob storage using a SAS-authenticated URL. +// Returns the raw content as bytes. The downloadURL should be the full URL with SAS token +// (e.g., from DatasetCredential.ResolvedDownloadURI()). +func (c *DatasetClient) DownloadDataset(ctx context.Context, downloadURL string) ([]byte, error) { + req, err := runtime.NewRequest(ctx, http.MethodGet, downloadURL) + if err != nil { + return nil, fmt.Errorf("failed to create download request: %w", err) + } + + // Use a plain HTTP client for blob downloads — the SAS token in the URL provides + // authentication, and Azure SDK pipeline policies (bearer token, correlation ID) + // should not be sent to Azure Blob Storage endpoints. + httpClient := &http.Client{} + resp, err := httpClient.Do(req.Raw()) + if err != nil { + return nil, fmt.Errorf("failed to download dataset from blob: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("blob download failed with status %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read dataset content: %w", err) + } + + log.Printf("[dataset_api] downloaded %d bytes", len(data)) + return data, nil +} + +// ListContainerBlobs lists blobs in a container using a container-level SAS URI. +// The containerSASUri should include the SAS token (e.g., from credential.sasUri with sr=c). +// Returns a list of blob names found in the container. +func (c *DatasetClient) ListContainerBlobs(ctx context.Context, containerSASUri string) ([]string, error) { + // Parse the container URI and append list query parameters. + u, err := url.Parse(containerSASUri) + if err != nil { + return nil, fmt.Errorf("invalid container SAS URI: %w", err) + } + + q := u.Query() + q.Set("restype", "container") // cspell:ignore restype — Azure Storage API query parameter + q.Set("comp", "list") + u.RawQuery = q.Encode() + + log.Printf("[dataset_api] listing blobs: %s", u.Redacted()) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create list request: %w", err) + } + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to list container blobs: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("container list failed with status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read list response: %w", err) + } + + // Parse XML blob listing to extract blob names. + names := parseBlobNames(string(body)) + log.Printf("[dataset_api] found %d blobs in container", len(names)) + return names, nil +} + +// DownloadBlob downloads a single blob from a container using the container SAS URI +// and the blob name. Returns the blob content as bytes. +func (c *DatasetClient) DownloadBlob(ctx context.Context, containerSASUri, blobName string) ([]byte, error) { + u, err := url.Parse(containerSASUri) + if err != nil { + return nil, fmt.Errorf("invalid container SAS URI: %w", err) + } + + // Append blob name to the container path. + u.Path = strings.TrimSuffix(u.Path, "/") + "/" + blobName + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create blob download request: %w", err) + } + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to download blob: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("blob download failed with status %d for %s", resp.StatusCode, blobName) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read blob content: %w", err) + } + + log.Printf("[dataset_api] downloaded blob %s (%d bytes)", blobName, len(data)) + return data, nil +} + +// parseBlobNames extracts blob names from the Azure Blob Storage XML list response +// using proper XML parsing against the EnumerationResults schema. +func parseBlobNames(xmlBody string) []string { + type blob struct { + Name string `xml:"Name"` + } + type blobs struct { + Blob []blob `xml:"Blob"` + } + type enumerationResults struct { + Blobs blobs `xml:"Blobs"` + } + + var result enumerationResults + if err := xml.Unmarshal([]byte(xmlBody), &result); err != nil { + return nil + } + + names := make([]string, 0, len(result.Blobs.Blob)) + for _, b := range result.Blobs.Blob { + if b.Name != "" { + names = append(names, b.Name) + } + } + return names +} + +// doRequest performs an HTTP request against the dataset API and returns the raw response body. +func (c *DatasetClient) doRequest( + ctx context.Context, + method string, + path string, + query map[string]string, + body any, + apiVersion string, +) ([]byte, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += path + q := u.Query() + if apiVersion != "" { + q.Set("api-version", apiVersion) + } + for k, v := range query { + q.Set(k, v) + } + u.RawQuery = q.Encode() + + req, err := runtime.NewRequest(ctx, method, u.String()) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + log.Printf("[dataset_api] %s %s", method, u.Redacted()) + + if body != nil { + payload, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + if err := req.SetBody(streaming.NopCloser(bytes.NewReader(payload)), "application/json"); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + log.Printf("[dataset_api] response status: %d", resp.StatusCode) + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated, http.StatusAccepted) { + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + return nil, runtime.NewResponseError(resp) + } + + return respBody, nil +} + +// doRequestTyped performs an HTTP request and unmarshals the response into T. +func doRequestTyped[T any]( + c *DatasetClient, + ctx context.Context, + method string, + path string, + query map[string]string, + body any, + apiVersion string, +) (*T, error) { + respBody, err := c.doRequest(ctx, method, path, query, body, apiVersion) + if err != nil { + return nil, err + } + + if len(respBody) == 0 { + return new(T), nil + } + + var result T + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations_test.go new file mode 100644 index 00000000000..64ec678fb29 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/dataset_api/operations_test.go @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package dataset_api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// test helpers +// --------------------------------------------------------------------------- + +type fakeCredential struct{} + +func (f *fakeCredential) GetToken( + _ context.Context, + _ policy.TokenRequestOptions, +) (azcore.AccessToken, error) { + return azcore.AccessToken{Token: "fake-token"}, nil +} + +func newTestClient(t *testing.T, handler http.Handler) (*DatasetClient, *httptest.Server) { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + pipeline := runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{}, + ) + client := NewDatasetClientFromPipeline(server.URL, pipeline) + return client, server +} + +// --------------------------------------------------------------------------- +// NewDatasetClient +// --------------------------------------------------------------------------- + +func TestNewDatasetClient(t *testing.T) { + t.Parallel() + + client := NewDatasetClient("https://example.ai.azure.com", &fakeCredential{}) + require.NotNil(t, client) + assert.Equal(t, "https://example.ai.azure.com", client.endpoint) +} + +// --------------------------------------------------------------------------- +// CreateDataset +// --------------------------------------------------------------------------- + +func TestCreateDataset_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + data, _ := json.Marshal(map[string]any{"name": "my-ds", "version": "v1"}) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.CreateDataset(t.Context(), &CreateDatasetRequest{ + Name: "my-ds", + Version: "v1", + Format: "jsonl", + Content: `{"input":"hello"}`, + }, "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/datasets", capturedPath) + assert.Equal(t, "my-ds", result.Name) + assert.Equal(t, "v1", result.Version) +} + +// --------------------------------------------------------------------------- +// GetDataset +// --------------------------------------------------------------------------- + +func TestGetDataset_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(map[string]any{ + "name": "golden", + "version": "v2", + "blob_uri": "https://storage.blob.core.windows.net/datasets/golden.jsonl", + }) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetDataset(t.Context(), "golden", "v2", "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/datasets/golden/versions/v2", capturedPath) + assert.Equal(t, "golden", result.Name) + assert.Equal(t, "v2", result.Version) + assert.Equal(t, "https://storage.blob.core.windows.net/datasets/golden.jsonl", result.BlobURI) +} + +func TestDataset_UnmarshalServicePayload(t *testing.T) { + t.Parallel() + + // Recorded service GET /datasets//versions/ response (snake_case). + payload := `{ + "name": "eval-golden", + "version": "3", + "format": "jsonl", + "blob_uri": "https://store.blob.core.windows.net/ds/eval-golden.jsonl", + "data_uri": "https://store.blob.core.windows.net/ds/eval-golden-data.jsonl", + "content_uri": "https://store.blob.core.windows.net/ds/eval-golden-content.jsonl" + }` + + var ds Dataset + require.NoError(t, json.Unmarshal([]byte(payload), &ds)) + + assert.Equal(t, "eval-golden", ds.Name) + assert.Equal(t, "3", ds.Version) + assert.Equal(t, "jsonl", ds.Format) + assert.Equal(t, "https://store.blob.core.windows.net/ds/eval-golden.jsonl", ds.BlobURI) + assert.Equal(t, "https://store.blob.core.windows.net/ds/eval-golden-data.jsonl", ds.DataURI) + assert.Equal(t, "https://store.blob.core.windows.net/ds/eval-golden-content.jsonl", ds.ContentURI) + + // ResolvedBlobURI prefers blob_uri. + assert.Equal(t, ds.BlobURI, ds.ResolvedBlobURI()) + + // When blob_uri is empty, falls back to data_uri. + ds.BlobURI = "" + assert.Equal(t, ds.DataURI, ds.ResolvedBlobURI()) + + // When both blob_uri and data_uri are empty, falls back to content_uri. + ds.DataURI = "" + assert.Equal(t, ds.ContentURI, ds.ResolvedBlobURI()) +} + +// --------------------------------------------------------------------------- +// GetDatasetCredential +// --------------------------------------------------------------------------- + +func TestGetDatasetCredential_Success(t *testing.T) { + t.Parallel() + + var capturedPath, capturedMethod string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedMethod = r.Method + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(map[string]any{ + "blob_uri": "https://storage.blob.core.windows.net/datasets/golden.jsonl", + "sas": "sig=abc&se=2025-12-31", + }) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetDatasetCredential(t.Context(), "golden", "v2", "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/datasets/golden/versions/v2/credentials", capturedPath) + assert.Equal(t, http.MethodPost, capturedMethod) + assert.Equal(t, "https://storage.blob.core.windows.net/datasets/golden.jsonl", result.BlobURI) + assert.Equal(t, "sig=abc&se=2025-12-31", result.SAS) +} + +// --------------------------------------------------------------------------- +// DownloadDataset +// --------------------------------------------------------------------------- + +func TestDownloadDataset_Success(t *testing.T) { + t.Parallel() + + blobContent := `{"input":"hello","expected":"world"}` + "\n" + + `{"input":"foo","expected":"bar"}` + "\n" + + blobServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(blobContent)) + })) + t.Cleanup(blobServer.Close) + + client := NewDatasetClient("https://example.ai.azure.com", &fakeCredential{}) + data, err := client.DownloadDataset(t.Context(), blobServer.URL+"/datasets/golden.jsonl?sig=abc") + + require.NoError(t, err) + assert.Equal(t, blobContent, string(data)) +} + +func TestDownloadDataset_Error(t *testing.T) { + t.Parallel() + + blobServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + t.Cleanup(blobServer.Close) + + client := NewDatasetClient("https://example.ai.azure.com", &fakeCredential{}) + _, err := client.DownloadDataset(t.Context(), blobServer.URL+"/datasets/golden.jsonl?sig=expired") + + require.Error(t, err) + assert.Contains(t, err.Error(), "403") +} + +// --------------------------------------------------------------------------- +// Error handling +// --------------------------------------------------------------------------- + +func TestGetDataset_NotFound(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not found"}`)) + }) + + client, _ := newTestClient(t, handler) + _, err := client.GetDataset(t.Context(), "missing", "v1", "2025-11-15-preview") + + require.Error(t, err) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts.go new file mode 100644 index 00000000000..2dd49447323 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts.go @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "time" + + "azureaiagent/internal/pkg/agents/dataset_api" + "azureaiagent/internal/pkg/agents/opt_eval" +) + +// Artifact directory names relative to the agent project root. +const ( + EvaluatorsDir = "evaluators" + DatasetsDir = "datasets" + EvaluatorContractFile = "rubric_dimensions.json" +) + +// ResolveRelPath resolves a relative path against the agent project directory. +// If the path is already absolute it is returned as-is. +func ResolveRelPath(path, agentProject string) string { + if filepath.IsAbs(path) { + return path + } + return filepath.Join(agentProject, path) +} + +// DownloadDatasetArtifact downloads the dataset and writes it locally. +// If the download fails (e.g., non-TLS test server), it returns nil gracefully. +// On success it returns the relative local URI (datasets///) for the +// downloaded directory. The SAS URI may point to a container (downloads all blobs) +// or a single blob. +func DownloadDatasetArtifact( + ctx context.Context, + client *dataset_api.DatasetClient, + agentProject string, + ref *opt_eval.DatasetRef, + apiVersion string, +) (string, error) { + if ref == nil || ref.Name == "" { + return "", nil + } + + // Attempt full download via the dataset API. + cred, credErr := client.GetDatasetCredential(ctx, ref.Name, ref.Version, apiVersion) + if credErr != nil { + return "", fmt.Errorf("getting dataset credential for %q: %w", ref.Name, credErr) + } + + downloadURL := cred.ResolvedDownloadURI() + if downloadURL == "" { + return "", fmt.Errorf("dataset %q returned empty download URI", ref.Name) + } + + destDir := DatasetArtifactPath(agentProject, ref) + + // Clear existing dataset directory to ensure a clean download. + if err := os.RemoveAll(destDir); err != nil { + return "", fmt.Errorf("removing existing dataset dir: %w", err) + } + if err := os.MkdirAll(destDir, 0750); err != nil { + return "", fmt.Errorf("creating dataset artifact dir: %w", err) + } + + // Determine if this is a container-level SAS (sr=c) or blob-level. + if isContainerSAS(downloadURL) { + blobs, err := client.ListContainerBlobs(ctx, downloadURL) + if err != nil { + return "", fmt.Errorf("listing container blobs for dataset %q: %w", ref.Name, err) + } + if len(blobs) == 0 { + return "", fmt.Errorf("dataset %q container has no blobs", ref.Name) + } + var errs []error + for _, blobName := range blobs { + ext := strings.ToLower(filepath.Ext(blobName)) + if ext != ".jsonl" && ext != ".csv" { + continue + } + data, dlErr := client.DownloadBlob(ctx, downloadURL, blobName) + if dlErr != nil { + errs = append(errs, fmt.Errorf("downloading blob %q: %w", blobName, dlErr)) + continue + } + dest, pathErr := opt_eval.SafePath(destDir, blobName) + if pathErr != nil { + errs = append(errs, pathErr) + continue + } + if err := os.MkdirAll(filepath.Dir(dest), 0750); err != nil { + errs = append(errs, fmt.Errorf("creating dir for %q: %w", blobName, err)) + continue + } + if err := os.WriteFile(dest, data, 0600); err != nil { + errs = append(errs, fmt.Errorf("writing %q: %w", blobName, err)) + continue + } + } + if len(errs) > 0 { + return "", errors.Join(errs...) + } + } else { + // Single blob download. + data, dlErr := client.DownloadDataset(ctx, downloadURL) + if dlErr != nil { + return "", fmt.Errorf("downloading dataset %q: %w", ref.Name, dlErr) + } + // Infer filename from URL. + filename := filenameFromURL(downloadURL) + dest := filepath.Join(destDir, filename) + if err := os.WriteFile(dest, data, 0600); err != nil { + return "", fmt.Errorf("writing dataset artifact: %w", err) + } + } + + return DatasetLocalURI(ref.Name), nil +} + +// isContainerSAS checks if a SAS URI is container-scoped (sr=c in query). +func isContainerSAS(rawURL string) bool { + _, query, ok := strings.Cut(rawURL, "?") + if !ok { + return false + } + // Look for sr=c parameter. + return slices.Contains(strings.Split(query, "&"), "sr=c") +} + +// filenameFromURL extracts the filename from a blob URL path. +// Falls back to "data.jsonl" if unable to determine. +func filenameFromURL(rawURL string) string { + path := rawURL + if idx := strings.IndexByte(path, '?'); idx != -1 { + path = path[:idx] + } + parts := strings.Split(path, "/") + if len(parts) > 0 { + name := parts[len(parts)-1] + if name != "" && strings.Contains(name, ".") { + return name + } + } + return "data.jsonl" +} + +// DatasetArtifactPath returns the local filesystem path for a downloaded dataset directory. +func DatasetArtifactPath(agentProject string, ref *opt_eval.DatasetRef) string { + if ref == nil || ref.Name == "" { + return "" + } + return filepath.Join(agentProject, DatasetsDir, ref.Name) +} + +// DatasetLocalURI returns the relative path (from the agent project root) +// to a dataset artifact directory. This is the value stored in DatasetRef.LocalURI. +func DatasetLocalURI(name string) string { + return filepath.Join(DatasetsDir, name) +} + +// evaluatorDir returns the full path to an evaluator's local directory. +func evaluatorDir(agentProject, name string) string { + return filepath.Join(agentProject, EvaluatorsDir, name) +} + +// EvaluatorLocalURI returns the relative path (from the agent project root) +// to an evaluator artifact file. This is the value stored in EvaluatorRef.LocalURI. +func EvaluatorLocalURI(name string) string { + return filepath.Join(EvaluatorsDir, name, EvaluatorContractFile) +} + +// SaveEvaluatorResult extracts the rubric dimensions from the evaluator result +// and saves them as the local artifact. Only dimensions are persisted so that +// users can edit weights/descriptions and upload a new evaluator version. +func SaveEvaluatorResult(agentProject, evaluatorName string, result json.RawMessage) error { + if evaluatorName == "" || len(result) == 0 { + return nil + } + dir := evaluatorDir(agentProject, evaluatorName) + if err := os.MkdirAll(dir, 0750); err != nil { + return fmt.Errorf("creating evaluator dir %q: %w", dir, err) + } + + // Parse the evaluator result to extract the rubric dimensions. + parsed := ParseEvaluatorResult(result) + if parsed == nil || len(parsed.Definition.Dimensions) == 0 { + return nil + } + + formatted, err := json.MarshalIndent(parsed.Definition.Dimensions, "", " ") + if err != nil { + return fmt.Errorf("marshalling evaluator dimensions: %w", err) + } + + path := filepath.Join(dir, EvaluatorContractFile) + if err := os.WriteFile(path, formatted, 0600); err != nil { + return fmt.Errorf("writing evaluator artifact %q: %w", path, err) + } + return nil +} + +// PrintEvaluatorDimensions prints a compact table of rubric dimensions. +func PrintEvaluatorDimensions(parsed *EvaluatorResult) { + dims := parsed.Definition.Dimensions + fmt.Printf("\n Evaluator dimensions (%d):\n", len(dims)) + fmt.Println(" Weight Dimension") + fmt.Println(" ────── ─────────") + for _, d := range dims { + fmt.Printf(" %6d %s\n", d.Weight, d.ID) + } +} + +// WriteEvalReviewArtifacts writes human-readable review artifacts for evaluators. +// It writes a stub YAML file for each evaluator unless a result JSON already exists. +func WriteEvalReviewArtifacts(agentProject string, cfg *EvalConfig) error { + if cfg == nil { + return nil + } + var errs []error + for _, evaluator := range cfg.Evaluators { + if evaluator.Name == "" || IsBuiltinEvaluator(evaluator.Name) { + continue + } + dir := evaluatorDir(agentProject, evaluator.Name) + if err := os.MkdirAll(dir, 0750); err != nil { + errs = append(errs, fmt.Errorf("creating dir for evaluator %q: %w", evaluator.Name, err)) + continue + } + // Skip if a result JSON already exists. + jsonPath := filepath.Join(dir, EvaluatorContractFile) + if _, err := os.Stat(jsonPath); err == nil { + continue + } + yamlPath := filepath.Join(dir, evaluator.Name+".yaml") + stub := fmt.Sprintf("# Evaluator stub: %s\nname: %s\n", evaluator.Name, evaluator.Name) + if err := os.WriteFile(yamlPath, []byte(stub), 0600); err != nil { + errs = append(errs, fmt.Errorf("writing evaluator stub %q: %w", yamlPath, err)) + } + } + return errors.Join(errs...) +} + +// WriteJSONFile writes a value as indented JSON to the specified path. +func WriteJSONFile(path string, v any) error { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { + return fmt.Errorf("creating output directory: %w", err) + } + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return fmt.Errorf("marshalling JSON: %w", err) + } + return os.WriteFile(path, data, 0600) +} + +// FormatTimestamp formats a timestamp value (int64, float64, or string) as a +// human-readable UTC string. +func FormatTimestamp(ts any) string { + switch v := ts.(type) { + case int64: + if v == 0 { + return "" + } + return time.Unix(v, 0).UTC().Format("2006-01-02 15:04:05 UTC") + case float64: + if v == 0 { + return "" + } + return time.Unix(int64(v), 0).UTC().Format("2006-01-02 15:04:05 UTC") + case int: + if v == 0 { + return "" + } + return time.Unix(int64(v), 0).UTC().Format("2006-01-02 15:04:05 UTC") + case string: + if v == "" { + return "" + } + t, err := time.Parse(time.RFC3339, v) + if err != nil { + return v + } + return t.UTC().Format("2006-01-02 15:04:05 UTC") + default: + return "" + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts_test.go new file mode 100644 index 00000000000..073497ef367 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/artifacts_test.go @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/stretchr/testify/assert" +) + +func TestDatasetArtifactPath_Basic(t *testing.T) { + t.Parallel() + ref := &opt_eval.DatasetRef{Name: "test-ds", Version: "v2"} + got := DatasetArtifactPath("/project", ref) + assert.Equal(t, filepath.Join("/project", "datasets", "test-ds"), got) +} + +func TestDatasetArtifactPath_NilRef(t *testing.T) { + t.Parallel() + got := DatasetArtifactPath("/project", nil) + assert.Empty(t, got) +} + +func TestDatasetArtifactPath_EmptyName(t *testing.T) { + t.Parallel() + ref := &opt_eval.DatasetRef{Name: ""} + got := DatasetArtifactPath("/project", ref) + assert.Empty(t, got) +} + +func TestDatasetLocalURI(t *testing.T) { + t.Parallel() + got := DatasetLocalURI("my-dataset") + assert.Equal(t, filepath.Join("datasets", "my-dataset"), got) +} + +func TestEvaluatorLocalURI(t *testing.T) { + t.Parallel() + got := EvaluatorLocalURI("coherence") + assert.Equal(t, filepath.Join("evaluators", "coherence", "rubric_dimensions.json"), got) +} + +func TestIsContainerSAS(t *testing.T) { + t.Parallel() + assert.True(t, isContainerSAS("https://blob.core.windows.net/container?sr=c&sig=abc")) + assert.False(t, isContainerSAS("https://blob.core.windows.net/container?sr=b&sig=abc")) + assert.False(t, isContainerSAS("https://blob.core.windows.net/container")) +} + +func TestFilenameFromURL(t *testing.T) { + t.Parallel() + assert.Equal(t, "data.jsonl", filenameFromURL("https://blob.core.windows.net/c/data.jsonl?sig=abc")) + assert.Equal(t, "data.jsonl", filenameFromURL("https://blob.core.windows.net/c/prefix/data.jsonl?sig=abc")) + assert.Equal(t, "data.jsonl", filenameFromURL("https://blob.core.windows.net/c/noext")) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config.go new file mode 100644 index 00000000000..0fb499f1a9a --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config.go @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents/opt_eval" + + "go.yaml.in/yaml/v3" +) + +// EvalConfig extends the shared Config with eval-specific fields and helpers. +type EvalConfig struct { + opt_eval.Config `yaml:",inline"` + + // Options holds run-time options (eval_model, etc.). + Options *opt_eval.Options `yaml:"options,omitempty"` + + // MaxSamples is the maximum number of data samples to generate. + MaxSamples int `yaml:"max_samples,omitempty"` + + // TraceDays is the number of days of agent traces to include (0 = none). + TraceDays int `yaml:"trace_days,omitempty"` +} + +// LoadEvalConfig reads and parses a YAML eval config file. +func LoadEvalConfig(path string) (*EvalConfig, error) { + data, err := os.ReadFile(path) //nolint:gosec // path is provided by user for local config + if err != nil { + return nil, fmt.Errorf("failed to read eval config %q: %w", path, err) + } + + var cfg EvalConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse eval config %q: %w", path, err) + } + + return &cfg, nil +} + +// WriteEvalConfig writes the eval config to a YAML file. +func WriteEvalConfig(path string, cfg *EvalConfig) error { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { + return fmt.Errorf("creating config directory: %w", err) + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal eval config: %w", err) + } + + if err := os.WriteFile(path, data, 0600); err != nil { + return fmt.Errorf("failed to write eval config %q: %w", path, err) + } + + return nil +} + +// Validate checks required fields for the eval command. +func (c *EvalConfig) Validate() error { + if c.Name == "" { + return exterrors.Validation( + exterrors.CodeEvalConfigInvalid, + "name is required in the eval config", + "add a 'name:' field to your eval.yaml") + } + + if c.Agent.Name == "" { + return exterrors.Validation( + exterrors.CodeEvalConfigInvalid, + "agent.name is required", + "add 'agent.name' to your eval.yaml or use --agent") + } + + if len(c.Evaluators) == 0 { + return exterrors.Validation( + exterrors.CodeEvalConfigInvalid, + "at least one evaluator is required", + "add an 'evaluators:' section to your eval.yaml or use --evaluator") + } + + hasFile := c.DatasetFile != "" + hasRef := c.DatasetReference != nil + + if hasFile && hasRef { + return fmt.Errorf("dataset_file and dataset_reference are mutually exclusive; specify one, not both") + } + + if !hasFile && !hasRef { + return fmt.Errorf("one of dataset_file or dataset_reference is required") + } + + return nil +} + +// ToAgentTargetAdaptableEvalGroupRequest builds the request body for creating an OpenAI eval +// with agent target completions and adaptable evaluator schema. +func (c *EvalConfig) ToAgentTargetAdaptableEvalGroupRequest() *CreateOpenAIEvalRequest { + request := &CreateOpenAIEvalRequest{ + Name: c.Name, + Metadata: map[string]string{ + "azd_agent": c.Agent.Name, + "azd_agent_version": c.Agent.Version, + }, + DataSourceConfig: &DataSourceConfig{ + Type: "custom", + IncludeSampleSchema: true, + ItemSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + //"ground_truth": map[string]any{"type": "string"}, + }, + }, + }, + } + + // Build testing_criteria from evaluators. + evalModel := "" + if c.Options != nil { + evalModel = c.Options.EvalModel + } + for _, evaluator := range c.Evaluators { + apiName := strings.TrimPrefix(evaluator.Name, "builtin.") + criterion := TestingCriterion{ + Type: "azure_ai_evaluator", + Name: apiName, + EvaluatorName: evaluator.Name, + DataMapping: map[string]string{ + //"messages": "{{item.messages}}", + "query": "{{item.query}}", + //"ground_truth": "{{item.ground_truth}}", + "response": "{{sample.output_items}}", + "tool_calls": "{{sample.tool_calls}}", + "tool_definitions": "{{sample.tool_definitions}}", + }, + } + if evalModel != "" { + criterion.InitializationParameters = map[string]any{ + "model": evalModel, + "deployment_name": evalModel, + } + } + request.TestingCriteria = append(request.TestingCriteria, criterion) + } + + return request +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config_test.go new file mode 100644 index 00000000000..c0acd91bf35 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/eval_config_test.go @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Validate +// --------------------------------------------------------------------------- + +func TestValidate_RequiresName(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "name is required") +} + +func TestValidate_RequiresAgentName(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "agent.name is required") +} + +func TestValidate_RequiresEvaluators(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + }, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one evaluator") +} + +func TestValidate_RequiresDataset(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "dataset_file or dataset_reference is required") +} + +func TestValidate_MutuallyExclusiveDataset(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetFile: "tasks.jsonl", + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + +func TestValidate_ValidWithDatasetFile(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetFile: "tasks.jsonl", + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + assert.NoError(t, cfg.Validate()) +} + +func TestValidate_ValidWithDatasetReference(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "my-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "coherence"}}, + }, + } + assert.NoError(t, cfg.Validate()) +} + +// --------------------------------------------------------------------------- +// LoadEvalConfig / WriteEvalConfig round-trip +// --------------------------------------------------------------------------- + +func TestEvalConfig_RoundTrip_FullFields(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "eval.yaml") + + original := &EvalConfig{ + Config: opt_eval.Config{ + Name: "full-test", + Agent: opt_eval.AgentRef{ + Name: "booking-agent", + Kind: "hosted", + Version: "v3", + Model: "gpt-4.1", + }, + DatasetReference: &opt_eval.DatasetRef{Name: "golden-data", Version: "v2"}, + Evaluators: opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}, {Name: "custom-quality"}}, + }, + Options: &opt_eval.Options{ + EvalModel: "gpt-4o", + }, + MaxSamples: 75, + } + + require.NoError(t, WriteEvalConfig(path, original)) + loaded, err := LoadEvalConfig(path) + require.NoError(t, err) + + assert.Equal(t, "full-test", loaded.Name) + assert.Equal(t, "booking-agent", loaded.Agent.Name) + assert.Equal(t, agent_yaml.AgentKind("hosted"), loaded.Agent.Kind) + assert.Equal(t, "v3", loaded.Agent.Version) + assert.Equal(t, "gpt-4.1", loaded.Agent.Model) + require.NotNil(t, loaded.DatasetReference) + assert.Equal(t, "golden-data", loaded.DatasetReference.Name) + assert.Equal(t, "v2", loaded.DatasetReference.Version) + require.Len(t, loaded.Evaluators, 2) + assert.Equal(t, "builtin.task_adherence", loaded.Evaluators[0].Name) + assert.Equal(t, "custom-quality", loaded.Evaluators[1].Name) + assert.Equal(t, "gpt-4o", loaded.Options.EvalModel) + assert.Equal(t, 75, loaded.MaxSamples) +} + +func TestEvalConfig_RoundTrip_MinimalFields(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "eval.yaml") + + original := &EvalConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "simple-agent"}, + DatasetFile: "data.jsonl", + }, + } + + require.NoError(t, WriteEvalConfig(path, original)) + loaded, err := LoadEvalConfig(path) + require.NoError(t, err) + + assert.Equal(t, "simple-agent", loaded.Agent.Name) + assert.Equal(t, "data.jsonl", loaded.DatasetFile) + assert.Nil(t, loaded.DatasetReference) + assert.Empty(t, loaded.Evaluators) + assert.True(t, loaded.Agent.Instruction.IsEmpty()) + assert.Zero(t, loaded.MaxSamples) +} + +func TestLoadEvalConfig_MissingFile(t *testing.T) { + t.Parallel() + _, err := LoadEvalConfig("/nonexistent/path/eval.yaml") + assert.Error(t, err) +} + +func TestLoadEvalConfig_InvalidYAML(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "bad.yaml") + require.NoError(t, os.WriteFile(path, []byte("{{invalid yaml}}"), 0600)) + _, err := LoadEvalConfig(path) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestWriteEvalConfig_CreatesDirectory(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "subdir", "nested", "eval.yaml") + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Agent: opt_eval.AgentRef{Name: "agent-1"}, + }, + } + + require.NoError(t, WriteEvalConfig(path, cfg)) + assert.FileExists(t, path) +} + +// --------------------------------------------------------------------------- +// ToAgentTargetAdaptableEvalGroupRequest +// --------------------------------------------------------------------------- + +func TestToAgentTargetAdaptableEvalGroupRequest_WithEvaluators(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "test-eval", + Agent: opt_eval.AgentRef{Name: "agent-1", Version: "v1"}, + Evaluators: opt_eval.EvaluatorList{{Name: "builtin.quality"}, {Name: "custom-1"}}, + DatasetFile: "tasks.jsonl", + }, + Options: &opt_eval.Options{EvalModel: "gpt-4o"}, + } + + req := cfg.ToAgentTargetAdaptableEvalGroupRequest() + + assert.Equal(t, "test-eval", req.Name) + assert.Equal(t, "agent-1", req.Metadata["azd_agent"]) + assert.Equal(t, "v1", req.Metadata["azd_agent_version"]) + require.NotNil(t, req.DataSourceConfig) + assert.Equal(t, "custom", req.DataSourceConfig.Type) + require.Len(t, req.TestingCriteria, 2) + assert.Equal(t, "azure_ai_evaluator", req.TestingCriteria[0].Type) + assert.Equal(t, "builtin.quality", req.TestingCriteria[0].EvaluatorName) + assert.Equal(t, "gpt-4o", req.TestingCriteria[0].InitializationParameters["model"]) + assert.Equal(t, "{{item.query}}", req.TestingCriteria[0].DataMapping["query"]) + assert.Equal(t, "custom-1", req.TestingCriteria[1].EvaluatorName) +} + +func TestToAgentTargetAdaptableEvalGroupRequest_WithDatasetReference(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "ref-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetReference: &opt_eval.DatasetRef{Name: "ds", Version: "v1"}, + }, + } + + req := cfg.ToAgentTargetAdaptableEvalGroupRequest() + // DataSourceConfig is always set with the custom schema. + require.NotNil(t, req.DataSourceConfig) + assert.Equal(t, "custom", req.DataSourceConfig.Type) +} + +func TestToAgentTargetAdaptableEvalGroupRequest_NoEvaluators(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "test-eval", + Agent: opt_eval.AgentRef{Name: "agent-1"}, + DatasetFile: "tasks.jsonl", + }, + } + + req := cfg.ToAgentTargetAdaptableEvalGroupRequest() + assert.Empty(t, req.TestingCriteria) +} + +func TestToAgentTargetAdaptableEvalGroupRequest_MetadataFields(t *testing.T) { + t.Parallel() + + cfg := &EvalConfig{ + Config: opt_eval.Config{ + Name: "meta-test", + Agent: opt_eval.AgentRef{Name: "my-agent", Version: "v5"}, + DatasetFile: "tasks.jsonl", + }, + } + + req := cfg.ToAgentTargetAdaptableEvalGroupRequest() + assert.Equal(t, "my-agent", req.Metadata["azd_agent"]) + assert.Equal(t, "v5", req.Metadata["azd_agent_version"]) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation.go new file mode 100644 index 00000000000..47d9410083d --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation.go @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "path/filepath" + "strings" + "time" + + "azureaiagent/internal/pkg/agents/opt_eval" +) + +// --------------------------------------------------------------------------- +// Generation source building +// --------------------------------------------------------------------------- + +// TraceOptions holds optional trace inclusion parameters for generation sources. +type TraceOptions struct { + Days int +} + +// BuildGenerationSources constructs the sources array for generation jobs. +// A prompt source is included when instruction is non-empty, along with the +// agent source. When traces is non-nil and Days > 0, a traces source is +// appended with start_time computed from the current time. +func BuildGenerationSources(agentKind, agentName, version, instruction string, traces *TraceOptions) []GenerationSource { + var sources []GenerationSource + + if instruction != "" { + sources = append(sources, GenerationSource{ + Type: "prompt", + Prompt: instruction, + }) + } + + agentSource := GenerationSource{ + Type: "agent", + AgentName: agentName, + } + if version != "" { + agentSource.AgentVersion = version + } + sources = append(sources, agentSource) + + if traces != nil && traces.Days > 0 { + startTime := time.Now().AddDate(0, 0, -traces.Days).Unix() + sources = append(sources, GenerationSource{ + Type: "traces", + AgentName: agentName, + StartTime: startTime, + }) + } + + return sources +} + +// --------------------------------------------------------------------------- +// Request builders +// --------------------------------------------------------------------------- + +// NewDataGenerationJobRequest builds a DataGenerationJobRequest from the +// provided parameters. Currently, it's always "simple_qna" type with multiple sources +func NewDataGenerationJobRequest( + name, evalModel string, + maxSamples int, + sources []GenerationSource, +) *DataGenerationJobRequest { + return &DataGenerationJobRequest{ + Inputs: DataGenerationInputs{ + Name: name, + Scenario: "evaluation", + Options: DataGenerationOptions{ + Type: "simple_qna", + MaxSamples: maxSamples, + ModelOptions: ModelOptions{ + Model: evalModel, + }, + }, + Sources: sources, + }, + } +} + +// NewEvaluatorGenerationJobRequest builds an EvaluatorGenerationJobRequest +// from the provided parameters. +func NewEvaluatorGenerationJobRequest( + name, evalModel string, + sources []GenerationSource, +) *EvaluatorGenerationJobRequest { + return &EvaluatorGenerationJobRequest{ + Name: name, + EvaluatorName: name, + Category: "quality", + Model: evalModel, + Sources: sources, + } +} + +// --------------------------------------------------------------------------- +// Evaluator classification +// --------------------------------------------------------------------------- + +// IsBuiltinEvaluator returns true when the evaluator name has the "builtin." +// prefix. +func IsBuiltinEvaluator(name string) bool { + return strings.HasPrefix(name, "builtin.") +} + +// SplitEvaluators partitions evaluators into generated (non-builtin) and +// built-in lists. +func SplitEvaluators(evaluators opt_eval.EvaluatorList) (generated, builtin opt_eval.EvaluatorList) { + for _, e := range evaluators { + if IsBuiltinEvaluator(e.Name) { + builtin = append(builtin, e) + } else { + generated = append(generated, e) + } + } + return generated, builtin +} + +// --------------------------------------------------------------------------- +// Dataset name detection +// --------------------------------------------------------------------------- + +// IsDatasetName returns true when the value looks like a registered dataset +// name rather than a local file path. A name has no path separators and no +// common data-file extension (.jsonl, .json, .csv). +func IsDatasetName(value string) bool { + if value == "" { + return false + } + if strings.ContainsAny(value, "/\\") { + return false + } + ext := strings.ToLower(filepath.Ext(value)) + return ext != ".jsonl" && ext != ".json" && ext != ".csv" +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation_test.go new file mode 100644 index 00000000000..f4d95245408 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/generation_test.go @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "testing" + + "azureaiagent/internal/pkg/agents/opt_eval" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// BuildGenerationSources +// --------------------------------------------------------------------------- + +func TestBuildGenerationSources_HostedWithInstruction(t *testing.T) { + t.Parallel() + sources := BuildGenerationSources("hosted", "my-agent", "v2", "Test interactions", nil) + require.Len(t, sources, 2) + assert.Equal(t, "prompt", sources[0].Type) + assert.Equal(t, "Test interactions", sources[0].Prompt) + assert.Equal(t, "agent", sources[1].Type) + assert.Equal(t, "my-agent", sources[1].AgentName) + assert.Equal(t, "v2", sources[1].AgentVersion) +} + +func TestBuildGenerationSources_NoVersion(t *testing.T) { + t.Parallel() + sources := BuildGenerationSources("hosted", "agent", "", "", nil) + require.Len(t, sources, 1) + assert.Empty(t, sources[0].AgentVersion) +} + +func TestBuildGenerationSources_HostedNoInstruction(t *testing.T) { + t.Parallel() + sources := BuildGenerationSources("hosted", "agent", "v1", "", nil) + require.Len(t, sources, 1) + assert.Equal(t, "agent", sources[0].Type) +} + +// --------------------------------------------------------------------------- +// NewDataGenerationJobRequest +// --------------------------------------------------------------------------- + +func TestNewDataGenerationJobRequest(t *testing.T) { + t.Parallel() + sources := []GenerationSource{{Type: "agent", AgentName: "a1"}} + req := NewDataGenerationJobRequest("eval-suite", "gpt-4o", 50, sources) + assert.Equal(t, "eval-suite", req.Inputs.Name) + assert.Equal(t, "evaluation", req.Inputs.Scenario) + assert.Equal(t, "simple_qna", req.Inputs.Options.Type) + assert.Equal(t, 50, req.Inputs.Options.MaxSamples) + assert.Equal(t, "gpt-4o", req.Inputs.Options.ModelOptions.Model) + require.Len(t, req.Inputs.Sources, 1) +} + +// --------------------------------------------------------------------------- +// NewEvaluatorGenerationJobRequest +// --------------------------------------------------------------------------- + +func TestNewEvaluatorGenerationJobRequest(t *testing.T) { + t.Parallel() + sources := []GenerationSource{{Type: "agent", AgentName: "a1"}} + req := NewEvaluatorGenerationJobRequest("eval-suite", "gpt-4o", sources) + assert.Equal(t, "eval-suite", req.Name) + assert.Equal(t, "eval-suite", req.EvaluatorName) + assert.Equal(t, "quality", req.Category) + assert.Equal(t, "gpt-4o", req.Model) + require.Len(t, req.Sources, 1) +} + +// --------------------------------------------------------------------------- +// IsBuiltinEvaluator +// --------------------------------------------------------------------------- + +func TestIsBuiltinEvaluator(t *testing.T) { + t.Parallel() + assert.True(t, IsBuiltinEvaluator("builtin.task_adherence")) + assert.True(t, IsBuiltinEvaluator("builtin.")) + assert.False(t, IsBuiltinEvaluator("my-quality")) + assert.False(t, IsBuiltinEvaluator("")) + assert.False(t, IsBuiltinEvaluator("builtins.quality")) +} + +// --------------------------------------------------------------------------- +// SplitEvaluators +// --------------------------------------------------------------------------- + +func TestSplitEvaluators(t *testing.T) { + t.Parallel() + + t.Run("mixed", func(t *testing.T) { + t.Parallel() + gen, bi := SplitEvaluators(opt_eval.EvaluatorList{ + {Name: "builtin.task_adherence"}, {Name: "my-quality"}, {Name: "builtin.safety"}, + }) + assert.Equal(t, opt_eval.EvaluatorList{{Name: "my-quality"}}, gen) + assert.Equal(t, opt_eval.EvaluatorList{{Name: "builtin.task_adherence"}, {Name: "builtin.safety"}}, bi) + }) + + t.Run("all builtin", func(t *testing.T) { + t.Parallel() + gen, bi := SplitEvaluators(opt_eval.EvaluatorList{ + {Name: "builtin.quality"}, {Name: "builtin.safety"}, + }) + assert.Nil(t, gen) + assert.Equal(t, opt_eval.EvaluatorList{{Name: "builtin.quality"}, {Name: "builtin.safety"}}, bi) + }) + + t.Run("nil", func(t *testing.T) { + t.Parallel() + gen, bi := SplitEvaluators(nil) + assert.Nil(t, gen) + assert.Nil(t, bi) + }) +} + +// --------------------------------------------------------------------------- +// IsDatasetName +// --------------------------------------------------------------------------- + +func TestIsDatasetName(t *testing.T) { + t.Parallel() + assert.True(t, IsDatasetName("eval-data-2026")) + assert.True(t, IsDatasetName("my-dataset.v2")) + assert.False(t, IsDatasetName("golden.jsonl")) + assert.False(t, IsDatasetName("data.json")) + assert.False(t, IsDatasetName("results.csv")) + assert.False(t, IsDatasetName("./tests/golden.jsonl")) + assert.False(t, IsDatasetName("")) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/models.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/models.go new file mode 100644 index 00000000000..a917e700140 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/models.go @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import "encoding/json" + +// --------------------------------------------------------------------------- +// Data Generation Jobs +// --------------------------------------------------------------------------- + +// DataGenerationJobRequest is the request body for CreateDataGenerationJob. +type DataGenerationJobRequest struct { + Inputs DataGenerationInputs `json:"inputs"` +} + +// DataGenerationInputs holds the inputs for a data generation job. +type DataGenerationInputs struct { + Name string `json:"name"` + Scenario string `json:"scenario"` + Options DataGenerationOptions `json:"options"` + Sources []GenerationSource `json:"sources"` +} + +// DataGenerationOptions holds configuration for data generation. +type DataGenerationOptions struct { + Type string `json:"type"` + MaxSamples int `json:"max_samples"` + ModelOptions ModelOptions `json:"model_options"` +} + +// ModelOptions holds the model selection for generation. +type ModelOptions struct { + Model string `json:"model"` +} + +// GenerationSource describes a source used for dataset or evaluator generation. +type GenerationSource struct { + Type string `json:"type"` + Prompt string `json:"prompt,omitempty"` + AgentName string `json:"agent_name,omitempty"` + AgentVersion string `json:"agent_version,omitempty"` + StartTime int64 `json:"start_time,omitempty"` +} + +// GenerationJob is the response for data and evaluator generation job operations. +type GenerationJob struct { + ID string `json:"id"` + Status string `json:"status"` + Result json.RawMessage `json:"result,omitempty"` + Error *JobError `json:"error,omitempty"` +} + +// JobError captures error details from a failed generation job. +type JobError struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// OperationID returns the job's operation identifier. +func (j *GenerationJob) OperationID() string { + return j.ID +} + +// NormalizedStatus returns the lowercase status, defaulting to "running". +func (j *GenerationJob) NormalizedStatus() string { + if j.Status == "" { + return "running" + } + return j.Status +} + +// ResolvedNameVersion extracts the name and version from the generation job result. +// If name is empty, both return values are empty (caller should treat as no result). +// If version is empty, it defaults to "latest". +func (j *GenerationJob) ResolvedNameVersion() (string, string) { + name := j.resultStringField("name") + if name == "" { + return "", "" + } + version := j.resultStringField("version") + if version == "" { + version = "latest" + } + return name, version +} + +// resultStringField extracts a string field from the raw Result JSON. +// It first checks for a top-level key, then falls back to outputs[0].key +// to handle the nested response format. +func (j *GenerationJob) resultStringField(key string) string { + if len(j.Result) == 0 { + return "" + } + var m map[string]json.RawMessage + if err := json.Unmarshal(j.Result, &m); err != nil { + return "" + } + + // Try top-level field first. + if raw, ok := m[key]; ok { + var s string + if err := json.Unmarshal(raw, &s); err == nil && s != "" { + return s + } + } + + // Fall back to outputs[0].key for nested response format. + if rawOutputs, ok := m["outputs"]; ok { + var outputs []map[string]json.RawMessage + if err := json.Unmarshal(rawOutputs, &outputs); err == nil && len(outputs) > 0 { + if raw, ok := outputs[0][key]; ok { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + } + } + } + + return "" +} + +// --------------------------------------------------------------------------- +// Evaluator Generation Jobs +// --------------------------------------------------------------------------- + +// EvaluatorGenerationJobRequest is the request body for CreateEvaluatorGenerationJob. +type EvaluatorGenerationJobRequest struct { + Name string `json:"name"` + EvaluatorName string `json:"evaluator_name"` + Category string `json:"category"` + Model string `json:"model"` + Sources []GenerationSource `json:"sources"` +} + +// --------------------------------------------------------------------------- +// Evaluator Versions +// --------------------------------------------------------------------------- + +// EvaluatorVersion is the response for evaluator version operations. +type EvaluatorVersion struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// --------------------------------------------------------------------------- +// Evaluator Definition (Rubric) +// --------------------------------------------------------------------------- + +// EvaluatorResult is the top-level response from evaluator generation, +// containing the evaluator's definition. +type EvaluatorResult struct { + Name string `json:"name"` + Version string `json:"version,omitempty"` + Definition EvaluatorDefinition `json:"definition"` +} + +// EvaluatorDefinition describes an evaluator's scoring rubric. +type EvaluatorDefinition struct { + Type string `json:"type"` + Dimensions []EvaluatorDimension `json:"dimensions"` +} + +// EvaluatorDimension is a single scoring dimension within a rubric evaluator. +type EvaluatorDimension struct { + ID string `json:"id"` + Description string `json:"description,omitempty"` + Weight int `json:"weight"` + AlwaysApplicable bool `json:"always_applicable,omitempty"` +} + +// ParseEvaluatorResult parses a GenerationJob result into a structured EvaluatorResult. +// Returns nil if the result cannot be parsed. +func ParseEvaluatorResult(result json.RawMessage) *EvaluatorResult { + if len(result) == 0 { + return nil + } + var r EvaluatorResult + if err := json.Unmarshal(result, &r); err != nil { + return nil + } + if len(r.Definition.Dimensions) == 0 { + return nil + } + return &r +} + +// --------------------------------------------------------------------------- +// Datasets +// --------------------------------------------------------------------------- + +// CreateDatasetRequest is the request body for CreateDataset. +type CreateDatasetRequest struct { + Name string `json:"name"` + Version string `json:"version"` + Format string `json:"format"` + Content string `json:"content"` +} + +// Dataset is the response for dataset operations. +type Dataset struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// --------------------------------------------------------------------------- +// OpenAI Evals +// --------------------------------------------------------------------------- + +// DataSourceConfig describes the data source for an OpenAI eval. +type DataSourceConfig struct { + Type string `json:"type"` + ItemSchema map[string]any `json:"item_schema"` + IncludeSampleSchema bool `json:"include_sample_schema"` +} + +// DataSourceSchema defines the item and sample schemas for an eval data source. +type DataSourceSchema struct { + Item map[string]any `json:"item,omitempty"` + Sample map[string]any `json:"sample,omitempty"` +} + +// TestingCriterion describes a single evaluator in testing_criteria. +type TestingCriterion struct { + Type string `json:"type"` + Name string `json:"name"` + EvaluatorName string `json:"evaluator_name"` + InitializationParameters map[string]any `json:"initialization_parameters,omitempty"` + DataMapping map[string]string `json:"data_mapping,omitempty"` +} + +// CreateOpenAIEvalRequest is the request body for CreateOpenAIEval. +type CreateOpenAIEvalRequest struct { + Name string `json:"name"` + Metadata map[string]string `json:"metadata,omitempty"` + DataSourceConfig *DataSourceConfig `json:"data_source_config,omitempty"` + TestingCriteria []TestingCriterion `json:"testing_criteria,omitempty"` +} + +// OpenAIEval is the response for an OpenAI eval definition. +type OpenAIEval struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + CreatedAt any `json:"created_at,omitempty"` + ModifiedAt any `json:"modified_at,omitempty"` + CreatedBy string `json:"created_by,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ResolvedID returns the eval's ID, falling back to name. +func (e *OpenAIEval) ResolvedID() string { + if e.ID != "" { + return e.ID + } + return e.Name +} + +// OpenAIEvalList is the response for listing OpenAI eval definitions. +type OpenAIEvalList struct { + Data []OpenAIEval `json:"data"` +} + +// --------------------------------------------------------------------------- +// OpenAI Eval Runs +// --------------------------------------------------------------------------- + +// CreateOpenAIEvalRunRequest is the request body for CreateOpenAIEvalRun. +type CreateOpenAIEvalRunRequest struct { + Name string `json:"name"` + DataSource *EvalRunDataSource `json:"data_source,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// EvalRunDataSourceType defines the type for an eval run data source. +type EvalRunDataSourceType string + +const ( + // EvalRunDataSourceTypeAgentTarget is the data source type for agent target completions. + EvalRunDataSourceTypeAgentTarget EvalRunDataSourceType = "azure_ai_target_completions" +) + +// EvalRunDataContentType defines the source type for eval run data content. +type EvalRunDataContentType string + +const ( + EvalRunDataContentTypeFileContent EvalRunDataContentType = "file_content" + EvalRunDataContentTypeFileID EvalRunDataContentType = "file_id" +) + +// EvalRunDataSource describes the data source for an eval run with agent target completions. +type EvalRunDataSource struct { + Type EvalRunDataSourceType `json:"type"` + InputMessages *EvalRunInputMessages `json:"input_messages,omitempty"` + Source *EvalRunDataContent `json:"source,omitempty"` + Target *EvalRunTarget `json:"target,omitempty"` +} + +// EvalRunInputMessages describes how input messages are constructed from dataset items. +type EvalRunInputMessages struct { + Type string `json:"type"` + Template []EvalRunMessageTemplate `json:"template"` +} + +// EvalRunMessageTemplate describes a single message in the input template. +type EvalRunMessageTemplate struct { + Role string `json:"role"` + Content string `json:"content"` + Type string `json:"type"` +} + +// EvalRunTarget describes the agent target for completions. +type EvalRunTarget struct { + Type string `json:"type"` + Name string `json:"name"` + Version *string `json:"version"` + ToolDescriptions []string `json:"tool_descriptions"` +} + +// EvalRunDataContent holds the source reference within an EvalRunDataSource. +type EvalRunDataContent struct { + Type EvalRunDataContentType `json:"type"` + ID string `json:"id,omitempty"` + Content []map[string]any `json:"content,omitempty"` +} + +// NewAgentTargetDataSource builds an EvalRunDataSource configured for agent target completions. +// The source field must be set separately via SetFileContent or SetFileID. +func NewAgentTargetDataSource(agentName string, agentVersion *string) *EvalRunDataSource { + return &EvalRunDataSource{ + Type: EvalRunDataSourceTypeAgentTarget, + InputMessages: &EvalRunInputMessages{ + Type: "template", + Template: []EvalRunMessageTemplate{ + { + Role: "user", + Content: "{{item.query}}", + Type: "message", + }, + }, + }, + Target: &EvalRunTarget{ + Type: "azure_ai_agent", + Name: agentName, + Version: agentVersion, + ToolDescriptions: []string{}, + }, + } +} + +// SetFileContent sets the data source to use inline file content. +func (ds *EvalRunDataSource) SetFileContent(items []map[string]any) { + ds.Source = &EvalRunDataContent{ + Type: EvalRunDataContentTypeFileContent, + Content: items, + } +} + +// SetFileID sets the data source to reference a remote dataset by ID. +func (ds *EvalRunDataSource) SetFileID(fileID string) { + ds.Source = &EvalRunDataContent{ + Type: EvalRunDataContentTypeFileID, + ID: fileID, + } +} + +// OpenAIEvalRun is the response for an OpenAI eval run. +type OpenAIEvalRun struct { + ID string `json:"id"` + EvalID string `json:"eval_id,omitempty"` + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + CreatedAt any `json:"created_at,omitempty"` + ModifiedAt any `json:"modified_at,omitempty"` + CreatedBy string `json:"created_by,omitempty"` + DataSource *EvalRunDataSource `json:"data_source,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + ReportURL string `json:"report_url,omitempty"` + + // Result summary + ResultCounts *EvalRunResultCounts `json:"result_counts,omitempty"` + PerTestingCriteria []EvalRunCriteriaResult `json:"per_testing_criteria_results,omitempty"` + Error any `json:"error,omitempty"` +} + +// EvalRunResultCounts holds pass/fail/error/skip counts for a run. +type EvalRunResultCounts struct { + Total int `json:"total"` + Passed int `json:"passed"` + Failed int `json:"failed"` + Errored int `json:"errored"` + Skipped int `json:"skipped"` +} + +// EvalRunCriteriaResult holds per-testing-criteria pass/fail counts. +type EvalRunCriteriaResult struct { + TestingCriteria string `json:"testing_criteria"` + Passed int `json:"passed"` + Failed int `json:"failed"` + Errored int `json:"errored"` + Skipped int `json:"skipped"` +} + +// OpenAIEvalRunList is the response for listing OpenAI eval runs. +type OpenAIEvalRunList struct { + Data []OpenAIEvalRun `json:"data"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations.go new file mode 100644 index 00000000000..8c661cadd13 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations.go @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strconv" + + "azureaiagent/internal/version" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +// API path prefixes for eval service endpoints. +const ( + pathDataGenerationJobs = "/data_generation_jobs" + pathEvaluatorGenerationJobs = "/evaluator_generation_jobs" + pathEvaluators = "/evaluators" + pathDatasets = "/datasets" + pathOpenAIEvals = "/openai/evals" +) + +// EvalClient provides methods for interacting with the Azure AI eval APIs. +type EvalClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewEvalClient creates a new EvalClient. +func NewEvalClient(endpoint string, cred azcore.TokenCredential) *EvalClient { + userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) + + clientOptions := &policy.ClientOptions{ + Logging: policy.LogOptions{ + AllowedHeaders: []string{"X-Ms-Correlation-Request-Id", "X-Request-Id"}, + IncludeBody: false, + }, + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy(cred, []string{"https://ai.azure.com/.default"}, nil), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy(userAgent), + }, + } + + pipeline := runtime.NewPipeline( + "azure-ai-evals", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &EvalClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// NewEvalClientFromPipeline creates an EvalClient with a pre-built pipeline. +// This is intended for tests that need to bypass auth policies. +func NewEvalClientFromPipeline(endpoint string, pipeline runtime.Pipeline) *EvalClient { + return &EvalClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// CreateDataGenerationJob starts a dataset generation job for eval onboarding. +func (c *EvalClient) CreateDataGenerationJob( + ctx context.Context, + request *DataGenerationJobRequest, + apiVersion string, +) (*GenerationJob, error) { + return doRequestTyped[GenerationJob](c, ctx, http.MethodPost, pathDataGenerationJobs, nil, request, apiVersion) +} + +// GetDataGenerationJob gets the current state of a dataset generation job. +func (c *EvalClient) GetDataGenerationJob( + ctx context.Context, + operationID string, + apiVersion string, +) (*GenerationJob, error) { + path := pathDataGenerationJobs + "/" + url.PathEscape(operationID) + return doRequestTyped[GenerationJob](c, ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +// CreateEvaluatorGenerationJob starts an evaluator generation job for eval onboarding. +func (c *EvalClient) CreateEvaluatorGenerationJob( + ctx context.Context, + request *EvaluatorGenerationJobRequest, + apiVersion string, +) (*GenerationJob, error) { + return doRequestTyped[GenerationJob](c, ctx, http.MethodPost, pathEvaluatorGenerationJobs, nil, request, apiVersion) +} + +// GetEvaluatorGenerationJob gets the current state of an evaluator generation job. +func (c *EvalClient) GetEvaluatorGenerationJob( + ctx context.Context, + operationID string, + apiVersion string, +) (*GenerationJob, error) { + path := pathEvaluatorGenerationJobs + "/" + url.PathEscape(operationID) + return doRequestTyped[GenerationJob](c, ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +// CreateEvaluatorVersion creates a new version of a named evaluator. +// The body should be the full evaluator JSON with the definition field updated. +func (c *EvalClient) CreateEvaluatorVersion( + ctx context.Context, + name string, + body json.RawMessage, + apiVersion string, +) (*EvaluatorVersion, error) { + path := pathEvaluators + "/" + url.PathEscape(name) + "/versions" + return doRequestTyped[EvaluatorVersion](c, ctx, http.MethodPost, path, nil, body, apiVersion) +} + +// GetEvaluatorRaw gets an evaluator by name and version as raw JSON. +// If version is empty, the latest version is fetched. +func (c *EvalClient) GetEvaluatorRaw( + ctx context.Context, + name string, + version string, + apiVersion string, +) (json.RawMessage, error) { + path := pathEvaluators + "/" + url.PathEscape(name) + if version != "" { + path += "/versions/" + url.PathEscape(version) + } + return c.doRequest(ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +// CreateOpenAIEval creates an OpenAI eval definition. +func (c *EvalClient) CreateOpenAIEval( + ctx context.Context, + request *CreateOpenAIEvalRequest, + apiVersion string, +) (*OpenAIEval, error) { + return doRequestTyped[OpenAIEval](c, ctx, http.MethodPost, pathOpenAIEvals, nil, request, apiVersion) +} + +// ListOpenAIEvals lists OpenAI eval definitions. +func (c *EvalClient) ListOpenAIEvals(ctx context.Context, limit int, apiVersion string) (*OpenAIEvalList, error) { + query := map[string]string{} + if limit > 0 { + query["limit"] = strconv.Itoa(limit) + } + + return doRequestTyped[OpenAIEvalList](c, ctx, http.MethodGet, pathOpenAIEvals, query, nil, apiVersion) +} + +// GetOpenAIEval gets an OpenAI eval definition. +func (c *EvalClient) GetOpenAIEval(ctx context.Context, evalID string, apiVersion string) (*OpenAIEval, error) { + path := pathOpenAIEvals + "/" + url.PathEscape(evalID) + return doRequestTyped[OpenAIEval](c, ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +// CreateOpenAIEvalRun starts a run for an OpenAI eval definition. +func (c *EvalClient) CreateOpenAIEvalRun( + ctx context.Context, + evalID string, + request *CreateOpenAIEvalRunRequest, + apiVersion string, +) (*OpenAIEvalRun, error) { + path := fmt.Sprintf("%s/%s/runs", pathOpenAIEvals, url.PathEscape(evalID)) + return doRequestTyped[OpenAIEvalRun](c, ctx, http.MethodPost, path, nil, request, apiVersion) +} + +// ListOpenAIEvalRuns lists runs for an OpenAI eval definition. +func (c *EvalClient) ListOpenAIEvalRuns( + ctx context.Context, + evalID string, + limit int, + apiVersion string, +) (*OpenAIEvalRunList, error) { + query := map[string]string{} + if limit > 0 { + query["limit"] = strconv.Itoa(limit) + } + + path := fmt.Sprintf("%s/%s/runs", pathOpenAIEvals, url.PathEscape(evalID)) + return doRequestTyped[OpenAIEvalRunList](c, ctx, http.MethodGet, path, query, nil, apiVersion) +} + +// GetOpenAIEvalRun gets a run for an OpenAI eval definition. +func (c *EvalClient) GetOpenAIEvalRun( + ctx context.Context, + evalID string, + runID string, + apiVersion string, +) (*OpenAIEvalRun, error) { + path := fmt.Sprintf("%s/%s/runs/%s", pathOpenAIEvals, url.PathEscape(evalID), url.PathEscape(runID)) + return doRequestTyped[OpenAIEvalRun](c, ctx, http.MethodGet, path, nil, nil, apiVersion) +} + +func (c *EvalClient) doRequest( + ctx context.Context, + method string, + path string, + query map[string]string, + body any, + apiVersion string, +) ([]byte, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL: %w", err) + } + + u.Path += path + q := u.Query() + if apiVersion != "" { + q.Set("api-version", apiVersion) + } + for k, v := range query { + q.Set(k, v) + } + u.RawQuery = q.Encode() + + req, err := runtime.NewRequest(ctx, method, u.String()) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + log.Printf("[eval_api] %s %s", method, u.Redacted()) + + if body != nil { + payload, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + if err := req.SetBody(streaming.NopCloser(bytes.NewReader(payload)), "application/json"); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + log.Printf("[eval_api] response status: %d", resp.StatusCode) + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated, http.StatusAccepted) { + // Restore the body so runtime.NewResponseError can read it. + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + return nil, runtime.NewResponseError(resp) + } + + return respBody, nil +} + +// doRequestTyped performs an HTTP request and unmarshals the response into T. +func doRequestTyped[T any]( + c *EvalClient, + ctx context.Context, + method string, + path string, + query map[string]string, + body any, + apiVersion string, +) (*T, error) { + respBody, err := c.doRequest(ctx, method, path, query, body, apiVersion) + if err != nil { + return nil, err + } + + if len(respBody) == 0 { + return new(T), nil + } + + var result T + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations_test.go new file mode 100644 index 00000000000..08d845303cb --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/operations_test.go @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// test helpers +// --------------------------------------------------------------------------- + +// fakeCredential satisfies azcore.TokenCredential for tests without real auth. +type fakeCredential struct{} + +func (f *fakeCredential) GetToken( + _ context.Context, + _ policy.TokenRequestOptions, +) (azcore.AccessToken, error) { + return azcore.AccessToken{Token: "fake-token"}, nil +} + +// newTestClient creates an EvalClient pointed at a test HTTP server. +func newTestClient(t *testing.T, handler http.Handler) (*EvalClient, *httptest.Server) { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + pipeline := runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{}, + ) + client := NewEvalClientFromPipeline(server.URL, pipeline) + return client, server +} + +// --------------------------------------------------------------------------- +// NewEvalClient +// --------------------------------------------------------------------------- + +func TestNewEvalClient(t *testing.T) { + t.Parallel() + + client := NewEvalClient("https://example.ai.azure.com", &fakeCredential{}) + require.NotNil(t, client) + assert.Equal(t, "https://example.ai.azure.com", client.endpoint) +} + +// --------------------------------------------------------------------------- +// CreateDataGenerationJob +// --------------------------------------------------------------------------- + +func TestCreateDataGenerationJob_Success(t *testing.T) { + t.Parallel() + + var capturedPath, capturedAPIVersion string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedAPIVersion = r.URL.Query().Get("api-version") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := map[string]any{"id": "op-123", "status": "running"} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.CreateDataGenerationJob(t.Context(), &DataGenerationJobRequest{ + Inputs: DataGenerationInputs{ + Name: "test", + Scenario: "evaluation", + }, + }, "v1") + + require.NoError(t, err) + assert.Equal(t, "/data_generation_jobs", capturedPath) + assert.Equal(t, "v1", capturedAPIVersion) + assert.Equal(t, "op-123", result.ID) + assert.Equal(t, "running", result.Status) +} + +// --------------------------------------------------------------------------- +// GetDataGenerationJob +// --------------------------------------------------------------------------- + +func TestGetDataGenerationJob_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "id": "op-123", + "status": "completed", + "result": map[string]any{"name": "test-ds", "version": "v1"}, + } + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetDataGenerationJob(t.Context(), "op-123", "v1") + + require.NoError(t, err) + assert.Equal(t, "/data_generation_jobs/op-123", capturedPath) + assert.Equal(t, "completed", result.Status) + name, _ := result.ResolvedNameVersion() + assert.Equal(t, "test-ds", name) +} + +// --------------------------------------------------------------------------- +// CreateEvaluatorGenerationJob +// --------------------------------------------------------------------------- + +func TestCreateEvaluatorGenerationJob_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := map[string]any{"id": "eval-op-456", "status": "running"} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.CreateEvaluatorGenerationJob( + t.Context(), &EvaluatorGenerationJobRequest{Name: "my-eval"}, "2025-11-15-preview", + ) + + require.NoError(t, err) + assert.Equal(t, "/evaluator_generation_jobs", capturedPath) + assert.Equal(t, "eval-op-456", result.ID) +} + +// --------------------------------------------------------------------------- +// GetEvaluatorGenerationJob +// --------------------------------------------------------------------------- + +func TestGetEvaluatorGenerationJob_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "id": "eval-op-456", + "status": "completed", + "result": map[string]any{"name": "quality"}, + } + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetEvaluatorGenerationJob(t.Context(), "eval-op-456", "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/evaluator_generation_jobs/eval-op-456", capturedPath) + assert.Equal(t, "completed", result.Status) + name, _ := result.ResolvedNameVersion() + assert.Equal(t, "quality", name) +} + +// --------------------------------------------------------------------------- +// CreateOpenAIEval +// --------------------------------------------------------------------------- + +func TestCreateOpenAIEval_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := map[string]any{"id": "eval-001", "name": "smoke-core"} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.CreateOpenAIEval( + t.Context(), &CreateOpenAIEvalRequest{Name: "smoke-core"}, "2025-11-15-preview", + ) + + require.NoError(t, err) + assert.Equal(t, "/openai/evals", capturedPath) + assert.Equal(t, "eval-001", result.ID) +} + +// --------------------------------------------------------------------------- +// ListOpenAIEvals +// --------------------------------------------------------------------------- + +func TestListOpenAIEvals_Success(t *testing.T) { + t.Parallel() + + var capturedLimit string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedLimit = r.URL.Query().Get("limit") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "data": []any{ + map[string]any{"id": "eval-1"}, + map[string]any{"id": "eval-2"}, + }, + } + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.ListOpenAIEvals(t.Context(), 10, "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "10", capturedLimit) + assert.Len(t, result.Data, 2) +} + +func TestListOpenAIEvals_ZeroLimit(t *testing.T) { + t.Parallel() + + var hasLimitParam bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hasLimitParam = r.URL.Query().Has("limit") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + }) + + client, _ := newTestClient(t, handler) + _, err := client.ListOpenAIEvals(t.Context(), 0, "2025-11-15-preview") + + require.NoError(t, err) + assert.False(t, hasLimitParam, "limit should not be set when 0") +} + +// --------------------------------------------------------------------------- +// GetOpenAIEval +// --------------------------------------------------------------------------- + +func TestGetOpenAIEval_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{"id": "eval-001", "name": "smoke-core", "metadata": map[string]string{"azd_agent": "agent-1"}} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetOpenAIEval(t.Context(), "eval-001", "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/openai/evals/eval-001", capturedPath) + assert.Equal(t, "smoke-core", result.Name) +} + +// --------------------------------------------------------------------------- +// CreateOpenAIEvalRun +// --------------------------------------------------------------------------- + +func TestCreateOpenAIEvalRun_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + resp := map[string]any{"id": "run-001", "status": "running"} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.CreateOpenAIEvalRun( + t.Context(), "eval-001", &CreateOpenAIEvalRunRequest{ + Metadata: map[string]string{"agent": "a"}, + }, "2025-11-15-preview", + ) + + require.NoError(t, err) + assert.Equal(t, "/openai/evals/eval-001/runs", capturedPath) + assert.Equal(t, "run-001", result.ID) +} + +// --------------------------------------------------------------------------- +// ListOpenAIEvalRuns +// --------------------------------------------------------------------------- + +func TestListOpenAIEvalRuns_Success(t *testing.T) { + t.Parallel() + + var capturedPath, capturedLimit string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + capturedLimit = r.URL.Query().Get("limit") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{"data": []any{map[string]any{"id": "run-1"}}} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.ListOpenAIEvalRuns(t.Context(), "eval-001", 5, "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/openai/evals/eval-001/runs", capturedPath) + assert.Equal(t, "5", capturedLimit) + assert.Len(t, result.Data, 1) +} + +// --------------------------------------------------------------------------- +// GetOpenAIEvalRun +// --------------------------------------------------------------------------- + +func TestGetOpenAIEvalRun_Success(t *testing.T) { + t.Parallel() + + var capturedPath string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + resp := map[string]any{"id": "run-001", "status": "completed", "score": 0.92} + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + }) + + client, _ := newTestClient(t, handler) + result, err := client.GetOpenAIEvalRun(t.Context(), "eval-001", "run-001", "2025-11-15-preview") + + require.NoError(t, err) + assert.Equal(t, "/openai/evals/eval-001/runs/run-001", capturedPath) + assert.Equal(t, "completed", result.Status) +} + +// --------------------------------------------------------------------------- +// Error handling +// --------------------------------------------------------------------------- + +func TestDoRequest_ServerError(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + client, _ := newTestClient(t, handler) + _, err := client.CreateOpenAIEval(t.Context(), &CreateOpenAIEvalRequest{}, "2025-11-15-preview") + assert.Error(t, err) +} + +func TestDoRequest_EmptyBody(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + client, _ := newTestClient(t, handler) + result, err := client.ListOpenAIEvals(t.Context(), 0, "2025-11-15-preview") + require.NoError(t, err) + assert.Empty(t, result.Data) +} + +func TestDoRequest_APIVersionInQuery(t *testing.T) { + t.Parallel() + + var capturedAPIVersion string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAPIVersion = r.URL.Query().Get("api-version") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{}`)) + }) + + client, _ := newTestClient(t, handler) + _, err := client.GetOpenAIEval(t.Context(), "eval-1", "2025-11-15-preview") + require.NoError(t, err) + assert.Equal(t, "2025-11-15-preview", capturedAPIVersion) +} + +func TestDoRequest_RequestBodySent(t *testing.T) { + t.Parallel() + + var capturedBody map[string]any + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + _ = json.NewDecoder(r.Body).Decode(&capturedBody) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":"ok"}`)) + }) + + client, _ := newTestClient(t, handler) + req := &DataGenerationJobRequest{ + Inputs: DataGenerationInputs{ + Name: "test-eval", + Scenario: "evaluation", + }, + } + _, err := client.CreateDataGenerationJob(t.Context(), req, "v1") + + require.NoError(t, err) + require.NotNil(t, capturedBody) + inputs, ok := capturedBody["inputs"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "test-eval", inputs["name"]) + assert.Equal(t, "evaluation", inputs["scenario"]) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller.go new file mode 100644 index 00000000000..71e9df09788 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller.go @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "context" + "fmt" + "log" + "strings" + "time" + + "azureaiagent/internal/pkg/agents" +) + +// --------------------------------------------------------------------------- +// JobStatus — typed status with terminal/failed semantics +// --------------------------------------------------------------------------- + +// JobStatus represents the normalized status of a generation job. +type JobStatus string + +const ( + JobStatusRunning JobStatus = "running" + JobStatusCompleted JobStatus = "completed" + JobStatusSucceeded JobStatus = "succeeded" + JobStatusFailed JobStatus = "failed" + JobStatusCancelled JobStatus = "cancelled" + JobStatusCanceled JobStatus = "canceled" +) + +// ParseJobStatus normalizes a raw status string into a JobStatus. +// An empty string is treated as "running". +func ParseJobStatus(s string) JobStatus { + if s == "" { + return JobStatusRunning + } + return JobStatus(strings.ToLower(s)) +} + +// IsTerminal returns true when the status represents a final state. +func (s JobStatus) IsTerminal() bool { + switch s { + case JobStatusCompleted, JobStatusSucceeded, JobStatusFailed, JobStatusCancelled, JobStatusCanceled: + return true + } + return false +} + +// IsFailed returns true when the status represents a failure or cancellation. +func (s JobStatus) IsFailed() bool { + switch s { + case JobStatusFailed, JobStatusCancelled, JobStatusCanceled: + return true + } + return false +} + +// String returns the status as a plain string. +func (s JobStatus) String() string { + return string(s) +} + +// --------------------------------------------------------------------------- +// JobFailedError — returned when a polled job reaches a failed state +// --------------------------------------------------------------------------- + +// JobFailedError is returned when a generation job reaches a failed terminal state. +type JobFailedError struct { + Job *GenerationJob + Status JobStatus +} + +func (e *JobFailedError) Error() string { + if e.Job != nil && e.Job.Error != nil && e.Job.Error.Message != "" { + return fmt.Sprintf("job failed with status %q: %s", e.Status, e.Job.Error.Message) + } + return fmt.Sprintf("job failed with status %q", e.Status) +} + +// --------------------------------------------------------------------------- +// PollerTimeoutError — returned when polling exhausts all attempts +// --------------------------------------------------------------------------- + +// PollerTimeoutError is returned when a generation job has not reached a +// terminal state within the configured number of polling attempts. +type PollerTimeoutError struct { + OperationID string + Attempts int +} + +func (e *PollerTimeoutError) Error() string { + return fmt.Sprintf( + "operation %s did not complete within %d attempts", + e.OperationID, e.Attempts, + ) +} + +// --------------------------------------------------------------------------- +// GetJobFunc — callback type for fetching job state +// --------------------------------------------------------------------------- + +// GetJobFunc fetches the current state of a generation job by operation ID. +type GetJobFunc func(ctx context.Context, operationID, apiVersion string) (*GenerationJob, error) + +// --------------------------------------------------------------------------- +// PollerOptions — configurable polling behavior +// --------------------------------------------------------------------------- + +// PollerOptions configures the polling interval and attempt limit. +type PollerOptions struct { + Interval time.Duration + MaxAttempts int +} + +// DefaultPollerOptions returns sensible defaults: 2 s interval, 300 attempts (~10 min). +func DefaultPollerOptions() PollerOptions { + return PollerOptions{ + Interval: 2 * time.Second, + MaxAttempts: 300, + } +} + +// --------------------------------------------------------------------------- +// Poller — polls a generation job until it reaches a terminal state +// --------------------------------------------------------------------------- + +// Poller polls a GenerationJob until it reaches a terminal status. +type Poller struct { + OperationID string + APIVersion string + GetJob GetJobFunc + Options PollerOptions + // OnPoll is called after each successful poll with the latest status. + // Callers can use this for progress reporting (e.g. debug logging). + OnPoll func(status JobStatus) +} + +// NewPoller creates a Poller with default options. +func NewPoller(operationID, apiVersion string, getJob GetJobFunc) *Poller { + return &Poller{ + OperationID: operationID, + APIVersion: apiVersion, + GetJob: getJob, + Options: DefaultPollerOptions(), + } +} + +// Poll blocks until the job reaches a terminal state, the context is +// cancelled, or the maximum number of attempts is exhausted. +// +// On success it returns the completed GenerationJob. +// On failure it returns a *JobFailedError (which wraps the job for inspection). +// On timeout it returns a plain error. +func (p *Poller) Poll(ctx context.Context) (*GenerationJob, error) { + if p.OperationID == "" { + return nil, fmt.Errorf("operation ID is empty") + } + + for range p.Options.MaxAttempts { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(p.Options.Interval): + } + + job, err := p.GetJob(ctx, p.OperationID, p.APIVersion) + if err != nil { + if agents.IsTransientError(err) { + log.Printf("[poller] transient error polling %s, will retry: %v", p.OperationID, err) + continue + } + return nil, err + } + + status := ParseJobStatus(job.Status) + log.Printf("[poller] operationID=%s status=%s", p.OperationID, status) + + if p.OnPoll != nil { + p.OnPoll(status) + } + + if status.IsTerminal() { + if status.IsFailed() { + return nil, &JobFailedError{Job: job, Status: status} + } + return job, nil + } + } + + return nil, &PollerTimeoutError{ + OperationID: p.OperationID, + Attempts: p.Options.MaxAttempts, + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller_test.go new file mode 100644 index 00000000000..88afdb0f935 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/poller_test.go @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// JobStatus +// --------------------------------------------------------------------------- + +func TestParseJobStatus(t *testing.T) { + t.Parallel() + + assert.Equal(t, JobStatusRunning, ParseJobStatus("")) + assert.Equal(t, JobStatusCompleted, ParseJobStatus("completed")) + assert.Equal(t, JobStatusCompleted, ParseJobStatus("Completed")) + assert.Equal(t, JobStatusFailed, ParseJobStatus("Failed")) + assert.Equal(t, JobStatus("pending"), ParseJobStatus("pending")) +} + +func TestJobStatus_IsTerminal(t *testing.T) { + t.Parallel() + + tests := []struct { + status string + terminal bool + }{ + {"completed", true}, + {"Completed", true}, + {"succeeded", true}, + {"failed", true}, + {"cancelled", true}, + {"canceled", true}, + {"running", false}, + {"pending", false}, + {"", false}, + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.terminal, ParseJobStatus(tt.status).IsTerminal()) + }) + } +} + +func TestJobStatus_IsFailed(t *testing.T) { + t.Parallel() + + tests := []struct { + status string + failed bool + }{ + {"failed", true}, + {"Failed", true}, + {"cancelled", true}, + {"canceled", true}, + {"completed", false}, + {"succeeded", false}, + {"running", false}, + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.failed, ParseJobStatus(tt.status).IsFailed()) + }) + } +} + +// --------------------------------------------------------------------------- +// Poller +// --------------------------------------------------------------------------- + +func TestPoller_EmptyOperationID(t *testing.T) { + t.Parallel() + + p := NewPoller("", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + return nil, nil + }) + _, err := p.Poll(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "operation ID is empty") +} + +func TestPoller_CompletedImmediately(t *testing.T) { + t.Parallel() + + calls := 0 + p := NewPoller("op-1", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + calls++ + return &GenerationJob{ID: id, Status: "completed"}, nil + }) + p.Options.Interval = time.Millisecond + + job, err := p.Poll(t.Context()) + require.NoError(t, err) + assert.Equal(t, "op-1", job.ID) + assert.Equal(t, 1, calls) +} + +func TestPoller_SucceededAfterPending(t *testing.T) { + t.Parallel() + + calls := 0 + p := NewPoller("op-2", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + calls++ + if calls < 3 { + return &GenerationJob{ID: id, Status: "running"}, nil + } + return &GenerationJob{ID: id, Status: "succeeded"}, nil + }) + p.Options.Interval = time.Millisecond + + job, err := p.Poll(t.Context()) + require.NoError(t, err) + assert.Equal(t, "succeeded", job.Status) + assert.Equal(t, 3, calls) +} + +func TestPoller_FailedReturnsJobFailedError(t *testing.T) { + t.Parallel() + + p := NewPoller("op-3", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + return &GenerationJob{ID: id, Status: "failed"}, nil + }) + p.Options.Interval = time.Millisecond + + _, err := p.Poll(t.Context()) + require.Error(t, err) + + var jfe *JobFailedError + require.True(t, errors.As(err, &jfe)) + assert.Equal(t, JobStatusFailed, jfe.Status) + assert.Equal(t, "op-3", jfe.Job.ID) +} + +func TestPoller_APIError(t *testing.T) { + t.Parallel() + + p := NewPoller("op-4", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + return nil, fmt.Errorf("network error") + }) + p.Options.Interval = time.Millisecond + + _, err := p.Poll(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "network error") +} + +func TestPoller_MaxAttemptsExhausted(t *testing.T) { + t.Parallel() + + p := NewPoller("op-5", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + return &GenerationJob{ID: id, Status: "running"}, nil + }) + p.Options.Interval = time.Millisecond + p.Options.MaxAttempts = 3 + + _, err := p.Poll(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "did not complete") + timeoutErr, ok := errors.AsType[*PollerTimeoutError](err) + require.True(t, ok) + assert.Equal(t, "op-5", timeoutErr.OperationID) + assert.Equal(t, 3, timeoutErr.Attempts) +} + +func TestPoller_ContextCancelled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() // cancel immediately + + p := NewPoller("op-6", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + return &GenerationJob{ID: id, Status: "running"}, nil + }) + p.Options.Interval = time.Millisecond + + _, err := p.Poll(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestPoller_OnPollCallback(t *testing.T) { + t.Parallel() + + calls := 0 + var observed []JobStatus + + p := NewPoller("op-7", "v1", func(ctx context.Context, id, ver string) (*GenerationJob, error) { + calls++ + if calls < 2 { + return &GenerationJob{ID: id, Status: "running"}, nil + } + return &GenerationJob{ID: id, Status: "completed"}, nil + }) + p.Options.Interval = time.Millisecond + p.OnPoll = func(status JobStatus) { + observed = append(observed, status) + } + + _, err := p.Poll(t.Context()) + require.NoError(t, err) + assert.Equal(t, []JobStatus{JobStatusRunning, JobStatusCompleted}, observed) +} + +// --------------------------------------------------------------------------- +// JobFailedError +// --------------------------------------------------------------------------- + +func TestJobFailedError_Error(t *testing.T) { + t.Parallel() + + e := &JobFailedError{ + Job: &GenerationJob{ID: "op-1"}, + Status: JobStatusFailed, + } + assert.Contains(t, e.Error(), "failed") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls.go new file mode 100644 index 00000000000..8b1ccd0fd5d --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls.go @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "encoding/base64" + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/google/uuid" +) + +// PortalPrefix holds the parsed project context needed to construct Foundry portal URLs. +type PortalPrefix struct { + prefix string // e.g. "https://ai.azure.com/nextgen/r/,,,," +} + +// NewPortalPrefix parses an ARM project resource ID and returns a PortalPrefix +// that can be reused to build multiple portal URLs. +// Returns an error if the resource ID is invalid or not a Foundry project. +func NewPortalPrefix(projectResourceID string) (*PortalPrefix, error) { + resourceID, err := arm.ParseResourceID(projectResourceID) + if err != nil { + return nil, fmt.Errorf("failed to parse project resource ID: %w", err) + } + + encodedSub, err := encodeSubscriptionForURL(resourceID.SubscriptionID) + if err != nil { + return nil, fmt.Errorf("failed to encode subscription ID: %w", err) + } + + if resourceID.Parent == nil || + !strings.Contains(string(resourceID.ResourceType.Type), "/") { + return nil, fmt.Errorf( + "resource ID does not represent a Foundry project (missing parent account): %s", + projectResourceID, + ) + } + + prefix := fmt.Sprintf( + "https://ai.azure.com/nextgen/r/%s,%s,,%s,%s", + encodedSub, resourceID.ResourceGroupName, + resourceID.Parent.Name, resourceID.Name, + ) + return &PortalPrefix{prefix: prefix}, nil +} + +// EvalRunURL returns the portal URL for an eval run report. +func (p *PortalPrefix) EvalRunURL(evalID, runID string) string { + return fmt.Sprintf("%s/build/evaluations/%s/run/%s", p.prefix, evalID, runID) +} + +// EvaluatorURL returns the portal URL for a generated evaluator. +func (p *PortalPrefix) EvaluatorURL(evaluatorName, version string) string { + return fmt.Sprintf("%s/build/evaluations/catalog/%s/%s", p.prefix, evaluatorName, version) +} + +// DatasetURL returns the portal URL for a dataset. +func (p *PortalPrefix) DatasetURL(datasetName, version string) string { + return fmt.Sprintf("%s/build/data/datasets/%s/%s", p.prefix, datasetName, version) +} + +// OptimizationURL returns the portal URL for an optimization job. +func (p *PortalPrefix) OptimizationURL(agentName, operationID string) string { + return fmt.Sprintf("%s/build/agents/%s/optimization/%s", + p.prefix, agentName, operationID) +} + +// encodeSubscriptionForURL encodes a subscription ID GUID as base64 without padding. +func encodeSubscriptionForURL(subscriptionID string) (string, error) { + guid, err := uuid.Parse(subscriptionID) + if err != nil { + return "", fmt.Errorf("invalid subscription ID format: %w", err) + } + guidBytes, _ := guid.MarshalBinary() + return strings.TrimRight(base64.URLEncoding.EncodeToString(guidBytes), "="), nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls_test.go new file mode 100644 index 00000000000..264418af439 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/eval_api/portal_urls_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package eval_api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPortalPrefix_Valid(t *testing.T) { + t.Parallel() + resID := "/subscriptions/00000000-0000-0000-0000-000000000001/resourceGroups/rg1/providers/Microsoft.CognitiveServices/accounts/acct1/projects/proj1" + prefix, err := NewPortalPrefix(resID) + require.NoError(t, err) + assert.Contains(t, prefix.prefix, "ai.azure.com/nextgen/r/") + assert.Contains(t, prefix.prefix, "rg1") + assert.Contains(t, prefix.prefix, "acct1") + assert.Contains(t, prefix.prefix, "proj1") +} + +func TestNewPortalPrefix_InvalidResourceID(t *testing.T) { + t.Parallel() + _, err := NewPortalPrefix("not-a-resource-id") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse") +} + +func TestNewPortalPrefix_MissingParent(t *testing.T) { + t.Parallel() + // Resource ID without a parent (not a nested resource). + resID := "/subscriptions/00000000-0000-0000-0000-000000000001/resourceGroups/rg1/providers/Microsoft.CognitiveServices/accounts/acct1" + _, err := NewPortalPrefix(resID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Foundry project") +} + +func TestPortalPrefix_EvalRunURL(t *testing.T) { + t.Parallel() + p := &PortalPrefix{prefix: "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj"} + url := p.EvalRunURL("eval-123", "run-456") + assert.Equal(t, "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj/build/evaluations/eval-123/run/run-456", url) +} + +func TestPortalPrefix_EvaluatorURL(t *testing.T) { + t.Parallel() + p := &PortalPrefix{prefix: "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj"} + url := p.EvaluatorURL("coherence", "v1") + assert.Equal(t, "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj/build/evaluations/catalog/coherence/v1", url) +} + +func TestPortalPrefix_DatasetURL(t *testing.T) { + t.Parallel() + p := &PortalPrefix{prefix: "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj"} + url := p.DatasetURL("my-dataset", "v2") + assert.Equal(t, "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj/build/data/datasets/my-dataset/v2", url) +} + +func TestPortalPrefix_OptimizationURL(t *testing.T) { + t.Parallel() + p := &PortalPrefix{prefix: "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj"} + url := p.OptimizationURL("my-agent", "op-789") + assert.Equal(t, "https://ai.azure.com/nextgen/r/sub,rg,,acct,proj/build/agents/my-agent/optimization/op-789", url) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/state.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/state.go new file mode 100644 index 00000000000..078d807fd6f --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/state.go @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// state.go centralizes transient runtime state that is persisted in the azd +// environment across CLI invocations. This covers eval job tracking and any +// other cross-invocation state needed by eval, optimize, or related commands. + +package opt_eval + +import ( + "context" + "errors" + "fmt" + "log" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// EvalState holds transient runtime state stored in the azd environment +// for tracking generation job progress across CLI invocations. +type EvalState struct { + InitStatus string // overall init status + DatasetGenOpID string // dataset generation operation ID + DatasetGenStatus string // dataset generation job status + EvalGenOpID string // evaluator generation operation ID + EvalGenStatus string // evaluator generation job status + EvalID string // created eval ID for running evals +} + +// InitStatus values. +const ( + InitStatusPending = "pending" + InitStatusCompleted = "completed" +) + +// Azd environment keys for persisting eval state across CLI invocations. +const ( + evalKeyInitStatus = "LAST_EVAL_INIT_STATUS" + evalKeyDatasetGenOpID = "LAST_EVAL_DATASET_GEN_OP_ID" + evalKeyDatasetGenStatus = "LAST_EVAL_DATASET_GEN_STATUS" + evalKeyEvalGenOpID = "LAST_EVAL_GEN_OP_ID" + evalKeyEvalGenStatus = "LAST_EVAL_GEN_STATUS" + evalKeyEvalID = "LAST_EVAL_ID" +) + +// LoadEvalState reads eval runtime state from the azd environment. +// Individual key-read errors are logged but do not prevent loading +// the remaining keys; a partial state is still useful for resume logic. +func LoadEvalState(ctx context.Context, azdClient *azdext.AzdClient, envName string) *EvalState { + get := func(key string) string { + v, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envName, Key: key, + }) + if err != nil { + log.Printf("LoadEvalState: failed to read %s: %v", key, err) + return "" + } + return v.Value + } + return &EvalState{ + InitStatus: get(evalKeyInitStatus), + DatasetGenOpID: get(evalKeyDatasetGenOpID), + DatasetGenStatus: get(evalKeyDatasetGenStatus), + EvalGenOpID: get(evalKeyEvalGenOpID), + EvalGenStatus: get(evalKeyEvalGenStatus), + EvalID: get(evalKeyEvalID), + } +} + +// SaveEvalState persists eval runtime state to the azd environment. +func SaveEvalState(ctx context.Context, azdClient *azdext.AzdClient, envName string, state *EvalState) error { + pairs := []struct { + key, val string + }{ + {evalKeyInitStatus, state.InitStatus}, + {evalKeyDatasetGenOpID, state.DatasetGenOpID}, + {evalKeyDatasetGenStatus, state.DatasetGenStatus}, + {evalKeyEvalGenOpID, state.EvalGenOpID}, + {evalKeyEvalGenStatus, state.EvalGenStatus}, + {evalKeyEvalID, state.EvalID}, + } + for _, p := range pairs { + if _, err := azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: envName, Key: p.key, Value: p.val, + }); err != nil { + return fmt.Errorf("setting %s in azd env: %w", p.key, err) + } + } + return nil +} + +// ClearEvalState removes eval state keys from the azd environment. +func ClearEvalState(ctx context.Context, azdClient *azdext.AzdClient, envName string) error { + var errs []error + for _, key := range []string{ + evalKeyInitStatus, evalKeyDatasetGenOpID, evalKeyDatasetGenStatus, + evalKeyEvalGenOpID, evalKeyEvalGenStatus, evalKeyEvalID, + } { + _, err := azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: envName, Key: key, Value: "", + }) + if err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml.go new file mode 100644 index 00000000000..58a98e1a8bd --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml.go @@ -0,0 +1,441 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package opt_eval + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "azureaiagent/internal/pkg/agents/agent_yaml" + + "go.yaml.in/yaml/v3" +) + +// SafePath validates that joining baseDir with an untrusted relative path +// does not escape baseDir (zip-slip prevention). Returns the cleaned +// absolute path or an error. +func SafePath(baseDir, untrusted string) (string, error) { + p := filepath.Join(baseDir, filepath.FromSlash(untrusted)) + p = filepath.Clean(p) + + rel, err := filepath.Rel(baseDir, p) + if err != nil { + return "", fmt.Errorf("path %q escapes base directory", untrusted) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("path %q escapes base directory", untrusted) + } + return p, nil +} + +// Config is the shared YAML configuration for eval and optimize commands. +// +// Contains fields common to both commands. Optimize-specific fields +// (Criteria, ValidationReference, etc) live in +// the OptimizeConfig wrapper in the cmd package. +// +// Runtime state (operation IDs, eval IDs, status) is stored in +// the azd environment rather than in this config file. +type Config struct { + Name string `yaml:"name,omitempty"` + Agent AgentRef `yaml:"agent"` + DatasetFile string `yaml:"dataset_file,omitempty"` + DatasetReference *DatasetRef `yaml:"dataset_reference,omitempty"` + Evaluators EvaluatorList `yaml:"evaluators,omitempty"` +} + +// EvaluatorRef describes an evaluator. It can be a simple string name or a +// structured entry with name, version, and local_uri. +type EvaluatorRef struct { + Name string `yaml:"name" json:"name"` + Version string `yaml:"version,omitempty" json:"version,omitempty"` + LocalURI string `yaml:"local_uri,omitempty" json:"local_uri,omitempty"` +} + +// EvaluatorList is a list of evaluators that supports mixed YAML: +// +// evaluators: +// - builtin.task_adherence +// - name: custom-quality +// version: "2" +// local_uri: evaluators/custom-quality_2.json +type EvaluatorList []EvaluatorRef + +// UnmarshalYAML handles both plain string and mapping entries. +func (el *EvaluatorList) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.SequenceNode { + return fmt.Errorf("evaluators must be a sequence, got %v", value.Kind) + } + + result := make([]EvaluatorRef, 0, len(value.Content)) + for _, item := range value.Content { + switch item.Kind { + case yaml.ScalarNode: + // Plain string entry: "builtin.task_adherence" + result = append(result, EvaluatorRef{Name: item.Value}) + case yaml.MappingNode: + // Structured entry: {name: ..., version: ..., local_uri: ...} + var ref EvaluatorRef + if err := item.Decode(&ref); err != nil { + return fmt.Errorf("parsing evaluator entry: %w", err) + } + result = append(result, ref) + default: + return fmt.Errorf("unexpected evaluator entry type: %v", item.Kind) + } + } + *el = result + return nil +} + +// MarshalYAML emits plain strings for simple evaluators and mappings for +// structured ones (those with version or local_uri). +func (el EvaluatorList) MarshalYAML() (any, error) { + nodes := make([]*yaml.Node, 0, len(el)) + for _, ref := range el { + if ref.Version == "" && ref.LocalURI == "" { + // Emit as a plain string. + nodes = append(nodes, &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: ref.Name, + }) + } else { + // Emit as a mapping. + var n yaml.Node + if err := n.Encode(ref); err != nil { + return nil, err + } + nodes = append(nodes, &n) + } + } + return &yaml.Node{Kind: yaml.SequenceNode, Content: nodes}, nil +} + +// Names returns the evaluator names as a plain string slice. +func (el EvaluatorList) Names() []string { + names := make([]string, len(el)) + for i, ref := range el { + names[i] = ref.Name + } + return names +} + +// FindByLocalURI returns all evaluators that have a local_uri set. +func (el EvaluatorList) FindByLocalURI() []EvaluatorRef { + var refs []EvaluatorRef + for _, ref := range el { + if ref.LocalURI != "" { + refs = append(refs, ref) + } + } + return refs +} + +// SetVersion updates the version of a named evaluator in the list. +func (el EvaluatorList) SetVersion(name, version string) { + for i := range el { + if el[i].Name == name { + el[i].Version = version + return + } + } +} + +// SetLocalURI updates the local_uri of a named evaluator in the list. +func (el EvaluatorList) SetLocalURI(name, uri string) { + for i := range el { + if el[i].Name == name { + el[i].LocalURI = uri + return + } + } +} + +// Agent config directory structure +// +// Each agent configuration version (baseline or optimized candidate) is stored +// under AgentConfigsDir as a self-contained directory with a fixed layout: +// +// .agent_configs/ +// ├── baseline/ # original agent config captured by eval init or optimize +// │ ├── metadata.yaml # MetadataFile — model, file pointers +// │ ├── instructions.md # InstructionFile — system prompt +// │ ├── skills/ # SkillsDir — skill definitions (optional) +// │ └── tools.json # ToolsFile — tool definitions (optional) +// └── / # optimized candidate written by optimize apply +// ├── metadata.yaml +// ├── instructions.md +// ├── skills/ +// └── tools.json +// +// Both eval and optimize commands share these constants and layout conventions. +// Eval init writes the baseline directory; optimize apply writes candidate +// directories and reads the baseline for diff display. +const ( + // AgentConfigsDir is the top-level folder that holds agent configuration + // versions (baseline and optimized candidates). + AgentConfigsDir = ".agent_configs" + + // BaselineDir is the subdirectory name for the original agent configuration. + BaselineDir = "baseline" + + // MetadataFile is the YAML file in each config directory that describes + // the agent model, instruction file path, skill directory, and tools file. + MetadataFile = "metadata.yaml" + + // InstructionFile is the Markdown file containing the agent's system prompt. + InstructionFile = "instructions.md" + + // SkillsDir is the subdirectory containing skill definition files. + SkillsDir = "skills" + + // ToolsFile is the JSON file containing tool definitions. + ToolsFile = "tools.json" +) + +// BaselineConfigRelPath returns the project-relative path to the baseline +// metadata file: ".agent_configs/baseline/metadata.yaml". +func BaselineConfigRelPath() string { + return filepath.Join(AgentConfigsDir, BaselineDir, MetadataFile) +} + +// AgentConfig holds resolved agent configuration from metadata.yaml. +// Unlike AgentRef (the YAML-serializable reference), AgentConfig contains +// fully resolved absolute paths and values for use during command execution. +type AgentConfig struct { + ConfigFile string // project-relative path to metadata.yaml + Model string // resolved model name + InstructionFile string // absolute path to instruction file + SkillDir string // absolute path to skills directory + ToolsFile string // absolute path to tools definition file +} + +// ResolvedInstruction reads and returns the instruction file content. +// Returns empty string if no instruction file is set or the file cannot be read. +func (c *AgentConfig) ResolvedInstruction() string { + if c.InstructionFile == "" { + return "" + } + data, err := os.ReadFile(c.InstructionFile) //nolint:gosec // path from project config + if err != nil { + return "" + } + return string(data) +} + +// AgentRef references the agent under evaluation/optimization. +// Optimize-specific fields (skill_dir, tools_file) are stored in +// OptimizeConfig, not here, so eval.yaml stays target-agnostic. +type AgentRef struct { + Name string `yaml:"name"` + Kind agent_yaml.AgentKind `yaml:"kind,omitempty"` + Version string `yaml:"version,omitempty"` + ConfigFile string `yaml:"config,omitempty"` + Model string `yaml:"model,omitempty"` + // Not serialized to YAML — populated at runtime from config or flags. + Instruction InstructionRef `yaml:"-"` +} + +// ResolveConfig loads the metadata.yaml pointed to by ConfigFile and returns +// a resolved AgentConfig without mutating the AgentRef. Relative paths inside +// metadata.yaml are resolved against the directory containing the config file. +// Returns nil if ConfigFile is not set. +func (a *AgentRef) ResolveConfig(projectDir string) *AgentConfig { + if a.ConfigFile == "" { + return nil + } + + configPath := a.ConfigFile + if !filepath.IsAbs(configPath) { + configPath = filepath.Join(projectDir, configPath) + } + configDir := filepath.Dir(configPath) + + cfg := &AgentConfig{ConfigFile: a.ConfigFile} + + data, err := os.ReadFile(configPath) //nolint:gosec // path from project config + if err != nil { + return cfg + } + + var meta struct { + Model string `yaml:"model"` + InstructionFile string `yaml:"instruction_file"` + SkillDir string `yaml:"skill_dir"` + ToolsFile string `yaml:"tools_file"` + } + if err := yaml.Unmarshal(data, &meta); err != nil { + return cfg + } + + cfg.Model = meta.Model + if meta.InstructionFile != "" { + instrPath := meta.InstructionFile + if !filepath.IsAbs(instrPath) { + instrPath = filepath.Join(configDir, instrPath) + } + cfg.InstructionFile = instrPath + } + if meta.SkillDir != "" { + skillDir := meta.SkillDir + if !filepath.IsAbs(skillDir) { + skillDir = filepath.Join(configDir, skillDir) + } + cfg.SkillDir = skillDir + } + if meta.ToolsFile != "" { + toolsFile := meta.ToolsFile + if !filepath.IsAbs(toolsFile) { + toolsFile = filepath.Join(configDir, toolsFile) + } + cfg.ToolsFile = toolsFile + } + + return cfg +} + +// ResolvedSystemPrompt returns the resolved instruction text. +// If the instruction references a file, its contents are read; otherwise the +// inline value is returned. +func (a *AgentRef) ResolvedSystemPrompt() string { + return a.Instruction.Resolve() +} + +// InstructionRef holds an instruction that can be either an inline string or a +// file reference. In YAML it supports two forms: +// +// instruction: "inline text" +// instruction: +// file: ./path/to/file.md +type InstructionRef struct { + Value string `yaml:"-"` // inline text + File string `yaml:"-"` // file reference +} + +// Resolve returns the instruction text. If File is set, the file is read; +// otherwise Value is returned directly. +func (r *InstructionRef) Resolve() string { + if r.File != "" { + data, err := os.ReadFile(r.File) + if err != nil { + return r.Value + } + return string(data) + } + return r.Value +} + +// IsEmpty returns true if neither inline value nor file is set. +func (r *InstructionRef) IsEmpty() bool { + return r.Value == "" && r.File == "" +} + +// UnmarshalYAML allows InstructionRef to be either a plain string or a mapping +// with a "file" key. +func (r *InstructionRef) UnmarshalYAML(value *yaml.Node) error { + if value.Kind == yaml.ScalarNode { + r.Value = value.Value + return nil + } + if value.Kind == yaml.MappingNode { + var m struct { + File string `yaml:"file"` + } + if err := value.Decode(&m); err != nil { + return err + } + r.File = m.File + return nil + } + return fmt.Errorf("instruction must be a string or a mapping with 'file' key") +} + +// MarshalYAML writes InstructionRef as a plain string when inline, or as a +// mapping with "file" when referencing a file. +func (r InstructionRef) MarshalYAML() (any, error) { + if r.File != "" { + return map[string]string{"file": r.File}, nil + } + return r.Value, nil +} + +// DatasetRef references a named/versioned dataset. +type DatasetRef struct { + Name string `yaml:"name"` + Version string `yaml:"version,omitempty"` + LocalURI string `yaml:"local_uri,omitempty"` +} + +// TargetConfig specifies model candidates and other target-specific configuration. +type TargetConfig struct { + Model []string `yaml:"model,omitempty"` +} + +// Options holds run-time options for eval and optimize. +// Eval only uses EvalModel; optimize uses all fields. +type Options struct { + EvalModel string `yaml:"eval_model,omitempty"` + TargetAttributes []string `yaml:"target_attributes,omitempty"` + TargetConfig *TargetConfig `yaml:"target_config,omitempty"` + MaxIterations *int `yaml:"max_iterations,omitempty"` + KeepVersions bool `yaml:"keep_versions,omitempty"` + TasksPerIteration int `yaml:"tasks_per_iteration,omitempty"` + ReflectionModel string `yaml:"reflection_model,omitempty"` + EvaluationLevel string `yaml:"evaluation_level,omitempty"` +} + +// UnmarshalYAML populates default target attributes when the field is absent in YAML. +// For backward compatibility, the legacy "strategies" key is also accepted. +func (o *Options) UnmarshalYAML(value *yaml.Node) error { + // Alias avoids infinite recursion. + type raw Options + if err := value.Decode((*raw)(o)); err != nil { + return err + } + + // Backward compatibility: if "strategies" is present and target_attributes is not, + // migrate the value. + if len(o.TargetAttributes) == 0 { + var legacy struct { + Strategies []string `yaml:"strategies"` + } + _ = value.Decode(&legacy) + if len(legacy.Strategies) > 0 { + o.TargetAttributes = legacy.Strategies + } + } + return nil +} + +// Read reads a YAML config file (eval or optimize format). +func Read(path string) (*Config, error) { + data, err := os.ReadFile(path) //nolint:gosec // path is provided by user for local config + if err != nil { + return nil, fmt.Errorf("failed to read config %q: %w", path, err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse config %q: %w", path, err) + } + + return &cfg, nil +} + +// Write writes a YAML config file. +func Write(path string, cfg *Config) error { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { + return fmt.Errorf("creating config directory: %w", err) + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + return os.WriteFile(path, data, 0600) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml_test.go new file mode 100644 index 00000000000..cc252d212a3 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/opt_eval/yaml_test.go @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package opt_eval + +import ( + "path/filepath" + "testing" + + "azureaiagent/internal/pkg/agents/agent_yaml" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.yaml.in/yaml/v3" +) + +// --------------------------------------------------------------------------- +// Config Read / Write round-trip +// --------------------------------------------------------------------------- + +func TestConfig_RoundTrip(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + original := &Config{ + Name: "test-config", + Agent: AgentRef{ + Name: "my-agent", + Kind: agent_yaml.AgentKindHosted, + Version: "v1", + Model: "gpt-4o", + }, + DatasetFile: "tasks.jsonl", + Evaluators: EvaluatorList{{Name: "builtin.quality"}, {Name: "custom-1"}}, + } + + require.NoError(t, Write(path, original)) + loaded, err := Read(path) + require.NoError(t, err) + + assert.Equal(t, "test-config", loaded.Name) + assert.Equal(t, "my-agent", loaded.Agent.Name) + assert.Equal(t, agent_yaml.AgentKindHosted, loaded.Agent.Kind) + assert.Equal(t, "v1", loaded.Agent.Version) + assert.Equal(t, "gpt-4o", loaded.Agent.Model) + assert.Equal(t, "tasks.jsonl", loaded.DatasetFile) + require.Len(t, loaded.Evaluators, 2) + assert.Equal(t, "builtin.quality", loaded.Evaluators[0].Name) + assert.Equal(t, "custom-1", loaded.Evaluators[1].Name) +} + +func TestConfig_RoundTrip_MixedEvaluators(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + original := &Config{ + Agent: AgentRef{Name: "agent-x"}, + Evaluators: EvaluatorList{ + {Name: "builtin.task_adherence"}, + {Name: "custom-quality", Version: "2", LocalURI: "evaluators/custom-quality_2.json"}, + }, + } + + require.NoError(t, Write(path, original)) + loaded, err := Read(path) + require.NoError(t, err) + + require.Len(t, loaded.Evaluators, 2) + assert.Equal(t, "builtin.task_adherence", loaded.Evaluators[0].Name) + assert.Empty(t, loaded.Evaluators[0].Version) + assert.Empty(t, loaded.Evaluators[0].LocalURI) + assert.Equal(t, "custom-quality", loaded.Evaluators[1].Name) + assert.Equal(t, "2", loaded.Evaluators[1].Version) + assert.Equal(t, "evaluators/custom-quality_2.json", loaded.Evaluators[1].LocalURI) +} + +func TestEvaluatorList_Names(t *testing.T) { + t.Parallel() + list := EvaluatorList{{Name: "a"}, {Name: "b"}, {Name: "c"}} + assert.Equal(t, []string{"a", "b", "c"}, list.Names()) +} + +func TestEvaluatorList_FindByLocalURI(t *testing.T) { + t.Parallel() + list := EvaluatorList{ + {Name: "builtin.x"}, + {Name: "custom", LocalURI: "/path/to/file.json"}, + {Name: "other"}, + } + found := list.FindByLocalURI() + require.Len(t, found, 1) + assert.Equal(t, "custom", found[0].Name) +} + +func TestEvaluatorList_SetVersion(t *testing.T) { + t.Parallel() + list := EvaluatorList{{Name: "a", Version: "1"}, {Name: "b"}} + list.SetVersion("b", "3") + assert.Equal(t, "3", list[1].Version) + // Non-matching name is a no-op. + list.SetVersion("nonexistent", "99") + assert.Equal(t, "1", list[0].Version) +} + +func TestConfig_RoundTrip_DatasetReference(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + + original := &Config{ + Agent: AgentRef{Name: "a1"}, + DatasetReference: &DatasetRef{Name: "golden", Version: "v2"}, + } + + require.NoError(t, Write(path, original)) + loaded, err := Read(path) + require.NoError(t, err) + + require.NotNil(t, loaded.DatasetReference) + assert.Equal(t, "golden", loaded.DatasetReference.Name) + assert.Equal(t, "v2", loaded.DatasetReference.Version) + assert.Empty(t, loaded.DatasetFile) +} + +func TestRead_MissingFile(t *testing.T) { + t.Parallel() + _, err := Read("/nonexistent/config.yaml") + assert.Error(t, err) +} + +func TestWrite_CreatesDirectory(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "sub", "nested", "config.yaml") + + cfg := &Config{Agent: AgentRef{Name: "a1"}} + require.NoError(t, Write(path, cfg)) + assert.FileExists(t, path) +} + +// --------------------------------------------------------------------------- +// AgentRef fields +// --------------------------------------------------------------------------- + +func TestAgentRef_YAMLFields(t *testing.T) { + t.Parallel() + + input := ` +name: test-agent +kind: hosted +version: v5 +model: gpt-4.1 +` + var ref AgentRef + require.NoError(t, yaml.Unmarshal([]byte(input), &ref)) + + assert.Equal(t, "test-agent", ref.Name) + assert.Equal(t, agent_yaml.AgentKindHosted, ref.Kind) + assert.Equal(t, "v5", ref.Version) + assert.Equal(t, "gpt-4.1", ref.Model) +} + +// --------------------------------------------------------------------------- +// DatasetRef fields +// --------------------------------------------------------------------------- + +func TestDatasetRef_YAMLFields(t *testing.T) { + t.Parallel() + + input := ` +name: golden-data +version: v3 +` + var ref DatasetRef + require.NoError(t, yaml.Unmarshal([]byte(input), &ref)) + + assert.Equal(t, "golden-data", ref.Name) + assert.Equal(t, "v3", ref.Version) +} + +// --------------------------------------------------------------------------- +// Options fields +// --------------------------------------------------------------------------- + +func TestOptions_YAMLFields(t *testing.T) { + t.Parallel() + + input := ` +eval_model: gpt-4.1 +target_attributes: + - prompt + - tool +max_iterations: 10 +keep_versions: true +tasks_per_iteration: 20 +reflection_model: gpt-4o +` + var opts Options + require.NoError(t, yaml.Unmarshal([]byte(input), &opts)) + + assert.Equal(t, "gpt-4.1", opts.EvalModel) + assert.Equal(t, []string{"prompt", "tool"}, opts.TargetAttributes) + require.NotNil(t, opts.MaxIterations) + assert.Equal(t, 10, *opts.MaxIterations) + assert.True(t, opts.KeepVersions) + assert.Equal(t, 20, opts.TasksPerIteration) + assert.Equal(t, "gpt-4o", opts.ReflectionModel) +} + +func TestOptions_LegacyStrategiesBackwardCompat(t *testing.T) { + t.Parallel() + + input := ` +eval_model: gpt-4.1 +strategies: + - prompt + - tool +` + var opts Options + require.NoError(t, yaml.Unmarshal([]byte(input), &opts)) + + assert.Equal(t, []string{"prompt", "tool"}, opts.TargetAttributes) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client.go new file mode 100644 index 00000000000..0350168e538 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client.go @@ -0,0 +1,365 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package optimize_api provides an HTTP client for the agent optimization +// service API. It supports job submission, status polling, cancellation, +// and candidate config/file retrieval. +package optimize_api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + netURL "net/url" + + "azureaiagent/internal/version" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +// OptimizeClient provides methods for interacting with the Agents Optimization API. +type OptimizeClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewOptimizeClient creates a new OptimizeClient with the given endpoint and credential. +func NewOptimizeClient(endpoint string, cred azcore.TokenCredential) *OptimizeClient { + userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) + + clientOptions := &policy.ClientOptions{ + Logging: policy.LogOptions{ + AllowedHeaders: []string{"X-Ms-Correlation-Request-Id", "X-Request-Id"}, + IncludeBody: false, + }, + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy(cred, []string{"https://ai.azure.com/.default"}, nil), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy(userAgent), + }, + } + + pipeline := runtime.NewPipeline( + "agents-optimization", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &OptimizeClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// NewOptimizeClientFromPipeline creates an OptimizeClient with a pre-built pipeline. +// This is intended for tests that need to bypass auth policies. +func NewOptimizeClientFromPipeline(endpoint string, pipeline runtime.Pipeline) *OptimizeClient { + return &OptimizeClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// StartOptimize submits a new optimization job. +func (c *OptimizeClient) StartOptimize( + ctx context.Context, + optimizeReq *OptimizeRequest, +) (*OptimizeResponse, error) { + url := fmt.Sprintf("%s/optimize?api-version=%s", c.endpoint, APIVersion) + + payload, err := json.Marshal(optimizeReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + if err := req.SetBody(streaming.NopCloser(bytes.NewReader(payload)), "application/json"); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusAccepted) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result OptimizeResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// GetOptimizeStatus retrieves the status of an optimization job. +func (c *OptimizeClient) GetOptimizeStatus( + ctx context.Context, + operationID string, +) (*OptimizeJobStatus, error) { + url := fmt.Sprintf("%s/optimize/%s?api-version=%s", c.endpoint, netURL.PathEscape(operationID), APIVersion) + + req, err := runtime.NewRequest(ctx, http.MethodGet, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result OptimizeJobStatus + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// ListOptimizeJobs lists optimization jobs with optional filtering. +func (c *OptimizeClient) ListOptimizeJobs( + ctx context.Context, + limit int, + status string, +) (*OptimizeListResponse, error) { + url := fmt.Sprintf("%s/optimize?api-version=%s&limit=%d", c.endpoint, APIVersion, limit) + if status != "" { + url += "&status=" + netURL.QueryEscape(status) + } + + req, err := runtime.NewRequest(ctx, http.MethodGet, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result OptimizeListResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// CancelOptimize cancels a running optimization job. +func (c *OptimizeClient) CancelOptimize( + ctx context.Context, + operationID string, +) (*OptimizeCancelResponse, error) { + url := fmt.Sprintf("%s/optimize/%s/cancel?api-version=%s", c.endpoint, netURL.PathEscape(operationID), APIVersion) + + req, err := runtime.NewRequest(ctx, http.MethodPost, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result OptimizeCancelResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// ReportDeployment notifies the optimization service that a candidate has been +// deployed. This allows the optimization service to track which candidates have been deployed. +func (c *OptimizeClient) ReportDeployment( + ctx context.Context, + report *DeploymentReport, +) error { + url := fmt.Sprintf( + "%s/optimize/candidates/%s:promote?api-version=%s", + c.endpoint, netURL.PathEscape(report.CandidateID), APIVersion, + ) + + payload, err := json.Marshal(report) + if err != nil { + return fmt.Errorf("failed to marshal deployment report: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, url) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), "application/json", + ); err != nil { + return fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent) { + return runtime.NewResponseError(resp) + } + + return nil +} + +// GetCandidateConfig fetches the candidate configuration from the optimization service. +// GET /optimize/candidates/{id}/config +func (c *OptimizeClient) GetCandidateConfig( + ctx context.Context, + candidateID string, +) (json.RawMessage, error) { + url := fmt.Sprintf("%s/optimize/candidates/%s/config?api-version=%s", c.endpoint, netURL.PathEscape(candidateID), APIVersion) + + req, err := runtime.NewRequest(ctx, http.MethodGet, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Validate that the body is valid JSON. + if !json.Valid(body) { + return nil, fmt.Errorf("candidate config is not valid JSON") + } + return json.RawMessage(body), nil +} + +// GetCandidate fetches the candidate manifest (metadata + file list) from the optimization service. +// GET /optimize/candidates/{id} +func (c *OptimizeClient) GetCandidate( + ctx context.Context, + candidateID string, +) (*CandidateManifest, error) { + url := fmt.Sprintf("%s/optimize/candidates/%s?api-version=%s", c.endpoint, netURL.PathEscape(candidateID), APIVersion) + + req, err := runtime.NewRequest(ctx, http.MethodGet, url) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var manifest CandidateManifest + if err := json.Unmarshal(body, &manifest); err != nil { + return nil, fmt.Errorf("failed to parse candidate manifest: %w", err) + } + return &manifest, nil +} + +// GetCandidateFile downloads a single file from a candidate. +// GET /optimize/candidates/{id}/files?path={path} +func (c *OptimizeClient) GetCandidateFile( + ctx context.Context, + candidateID string, + filePath string, +) (string, error) { + url := fmt.Sprintf("%s/optimize/candidates/%s/files?api-version=%s&path=%s", + c.endpoint, netURL.PathEscape(candidateID), APIVersion, netURL.QueryEscape(filePath)) + + req, err := runtime.NewRequest(ctx, http.MethodGet, url) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return "", fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return "", runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + return string(body), nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client_test.go new file mode 100644 index 00000000000..52e72e59c93 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/client_test.go @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package optimize_api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestClient builds an OptimizeClient that talks to the given httptest server +// with no auth (bare pipeline). +func newTestClient(serverURL string) *OptimizeClient { + pipeline := runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{}, + ) + return &OptimizeClient{ + endpoint: serverURL, + pipeline: pipeline, + } +} + +// stubCredential satisfies azcore.TokenCredential for constructor tests. +type stubCredential struct{} + +func (stubCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) { + return azcore.AccessToken{Token: "stub"}, nil +} + +func TestNewOptimizeClient(t *testing.T) { + t.Parallel() + client := NewOptimizeClient("https://example.com", stubCredential{}) + require.NotNil(t, client) + assert.Equal(t, "https://example.com", client.endpoint) +} + +func TestStartOptimize(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.True(t, strings.HasSuffix(r.URL.Path, "/optimize")) + assert.Contains(t, r.URL.RawQuery, "api-version=v1") + + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(OptimizeResponse{ + OperationID: "op-abc", + Status: StatusQueued, + }) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.StartOptimize(context.Background(), &OptimizeRequest{ + Agent: AgentDefinition{ + FoundryProjectURL: "https://example.com/proj", + AgentName: "agent-1", + }, + Options: OptimizeOptions{EvalModel: "gpt-4o-mini"}, + }) + + require.NoError(t, err) + assert.Equal(t, "op-abc", resp.OperationID) + assert.Equal(t, StatusQueued, resp.Status) +} + +func TestGetOptimizeStatus(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Contains(t, r.URL.Path, "/optimize/op-123") + assert.Contains(t, r.URL.RawQuery, "api-version=v1") + + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-123", + Status: StatusCompleted, + CreatedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T01:00:00Z", + Best: &CandidateResult{ + Name: "candidate-1", + AvgScore: 0.92, + PassRate: 0.95, + }, + Baseline: &CandidateResult{ + Name: "baseline", + AvgScore: 0.6, + }, + }) + })) + defer server.Close() + + client := newTestClient(server.URL) + status, err := client.GetOptimizeStatus(context.Background(), "op-123") + + require.NoError(t, err) + assert.Equal(t, "op-123", status.OperationID) + assert.Equal(t, StatusCompleted, status.Status) + require.NotNil(t, status.Best) + assert.InDelta(t, 0.92, status.Best.AvgScore, 0.001) + require.NotNil(t, status.Baseline) + assert.InDelta(t, 0.6, status.Baseline.AvgScore, 0.001) +} + +func TestListOptimizeJobs(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Contains(t, r.URL.RawQuery, "limit=10") + assert.Contains(t, r.URL.RawQuery, "status=running") + assert.Contains(t, r.URL.RawQuery, "api-version=v1") + + _ = json.NewEncoder(w).Encode(OptimizeListResponse{ + Data: []OptimizeJobStatus{ + {OperationID: "op-1", Status: StatusRunning}, + {OperationID: "op-2", Status: StatusRunning}, + }, + FirstID: "op-1", + LastID: "op-2", + HasMore: false, + }) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.ListOptimizeJobs(context.Background(), 10, "running") + + require.NoError(t, err) + assert.Len(t, resp.Data, 2) + assert.Equal(t, "op-1", resp.FirstID) + assert.Equal(t, "op-2", resp.LastID) + assert.False(t, resp.HasMore) +} + +func TestCancelOptimize(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Contains(t, r.URL.Path, "/optimize/op-xyz/cancel") + assert.Contains(t, r.URL.RawQuery, "api-version=v1") + + _ = json.NewEncoder(w).Encode(OptimizeCancelResponse{ + OperationID: "op-xyz", + Status: StatusCancelled, + }) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.CancelOptimize(context.Background(), "op-xyz") + + require.NoError(t, err) + assert.Equal(t, "op-xyz", resp.OperationID) + assert.Equal(t, StatusCancelled, resp.Status) +} + +func TestStartOptimize_HTTPError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": {"code": "BadRequest", "message": "invalid payload"}}`)) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.StartOptimize(context.Background(), &OptimizeRequest{}) + + assert.Nil(t, resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "400") +} + +func TestGetOptimizeStatus_HTTPError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error": {"code": "NotFound", "message": "job not found"}}`)) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.GetOptimizeStatus(context.Background(), "nonexistent") + + assert.Nil(t, resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "404") +} + +func TestListOptimizeJobs_NoStatusFilter(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.NotContains(t, r.URL.RawQuery, "status=") + _ = json.NewEncoder(w).Encode(OptimizeListResponse{ + Data: []OptimizeJobStatus{}, + }) + })) + defer server.Close() + + client := newTestClient(server.URL) + resp, err := client.ListOptimizeJobs(context.Background(), 20, "") + + require.NoError(t, err) + assert.Empty(t, resp.Data) +} + +func TestReportDeployment(t *testing.T) { + t.Parallel() + + var capturedBody map[string]string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Contains(t, r.URL.Path, "/optimize/candidates/cand-42:promote") + assert.Contains(t, r.URL.RawQuery, "api-version=v1") + + err := json.NewDecoder(r.Body).Decode(&capturedBody) + assert.NoError(t, err) + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClient(server.URL) + err := client.ReportDeployment(t.Context(), &DeploymentReport{ + CandidateID: "cand-42", + AgentName: "my-agent", + AgentVersion: "3", + }) + + require.NoError(t, err) + assert.Equal(t, "my-agent", capturedBody["agentName"]) + assert.Equal(t, "3", capturedBody["agentVersion"]) + // CandidateID should not appear in the body (json:"-") + assert.Empty(t, capturedBody["candidateId"]) +} + +func TestReportDeployment_HTTPError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"code":"BadRequest","message":"invalid candidate"}}`)) + })) + defer server.Close() + + client := newTestClient(server.URL) + err := client.ReportDeployment(t.Context(), &DeploymentReport{ + CandidateID: "bad-id", + AgentName: "agent", + AgentVersion: "1", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "400") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models.go new file mode 100644 index 00000000000..c2ac594e283 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models.go @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// models.go defines the request and response types for the optimization +// service API, including job status, candidate results, agent definitions, +// dataset tasks, and skill/tool definitions. +package optimize_api + +import "encoding/json" + +// APIVersion is the API version used for all optimization service calls. +const APIVersion = "v1" + +// Optimization job status constants. +const ( + StatusPending = "pending" + StatusRunning = "running" + StatusCompleted = "completed" + StatusFailed = "failed" + StatusCancelled = "cancelled" + + // StatusQueued is a deprecated alias for StatusPending. + // The API returns "pending", not "queued". + StatusQueued = StatusPending +) + +// IsTerminal returns true if the status represents a terminal state. +func IsTerminal(status string) bool { + switch status { + case StatusCompleted, StatusFailed, StatusCancelled: + return true + default: + return false + } +} + +// --- Request models --- + +// OptimizeRequest is the top-level payload sent to POST /optimize. +type OptimizeRequest struct { + Agent AgentDefinition `json:"agent"` + Dataset []DatasetTask `json:"dataset,omitempty"` + TrainDatasetReference *DatasetReference `json:"trainDatasetReference,omitempty"` + ValidationDatasetReference *DatasetReference `json:"validationDatasetReference,omitempty"` + Evaluators []string `json:"evaluators,omitempty"` + Criteria []Criterion `json:"criteria,omitempty"` + Options OptimizeOptions `json:"options"` +} + +// AgentDefinition identifies the agent to optimize. +type AgentDefinition struct { + FoundryProjectURL string `json:"foundryProjectUrl"` + AgentName string `json:"agentName"` + AgentVersion string `json:"agentVersion,omitempty"` + Model string `json:"model,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + Skills []SkillDefinition `json:"skills,omitempty"` + ToolDefinitions []ToolDefinition `json:"tools,omitempty"` +} + +// SkillDefinition describes a skill attached to an agent. +type SkillDefinition struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Body string `json:"body,omitempty"` +} + +// ToolDefinition is an OpenAI-format function tool definition. +// The optimizer may mutate the function's description and per-parameter +// descriptions; schema fields (name, types, required) are immutable. +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +// ToolFunction is the inner function definition of a ToolDefinition. +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// DatasetTask is a single task in an inline dataset. +type DatasetTask struct { + Name string `json:"name,omitempty"` + Query string `json:"query,omitempty"` + Prompt string `json:"prompt"` + GroundTruth string `json:"groundTruth,omitempty"` + Criteria []Criterion `json:"criteria,omitempty"` +} + +// DatasetReference points to a registered dataset by name and version. +type DatasetReference struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// Criterion is a named evaluation criterion. +type Criterion struct { + Name string `json:"name"` + Instruction string `json:"instruction"` +} + +// TargetConfig specifies model candidates and other target-specific configuration. +type TargetConfig struct { + Model []string `json:"model,omitempty"` +} + +// OptimizeOptions controls the optimization run. +type OptimizeOptions struct { + MaxIterations *int `json:"maxIterations,omitempty"` + EvalModel string `json:"evalModel,omitempty"` + // Send as both "strategies" (current server) and "targetAttributes" (future). + Strategies []string `json:"strategies,omitempty"` + TargetAttributes []string `json:"targetAttributes,omitempty"` + TargetConfig *TargetConfig `json:"targetConfig,omitempty"` + KeepVersions bool `json:"keepVersions,omitempty"` + TasksPerIteration int `json:"tasksPerIteration,omitempty"` + MaxReflectionTasks int `json:"maxReflectionTasks,omitempty"` + ReflectionModel string `json:"reflectionModel,omitempty"` + EvaluationLevel string `json:"evaluationLevel,omitempty"` +} + +// --- Response models --- + +// OptimizeResponse is the immediate response from POST /optimize. +type OptimizeResponse struct { + OperationID string `json:"operationId"` + Status string `json:"status"` +} + +// OptimizeJobStatus is the full status of an optimization job. +type OptimizeJobStatus struct { + OperationID string `json:"operationId"` + Status string `json:"status"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Agent *AgentDefinition `json:"agent,omitempty"` + Progress *JobProgress `json:"progress,omitempty"` + Error *JobError `json:"error,omitempty"` + Baseline *CandidateResult `json:"baseline,omitempty"` + Best *CandidateResult `json:"best,omitempty"` + Candidates []CandidateResult `json:"candidates,omitempty"` + AllTargetAttributesFailed bool `json:"allTargetAttributesFailed,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +// JobProgress reports iteration-level progress. +type JobProgress struct { + CurrentTargetAttribute string `json:"currentTargetAttribute"` + CurrentIteration int `json:"currentIteration"` + TasksCompleted int `json:"tasksCompleted"` + TasksTotal int `json:"tasksTotal"` + BestScore float64 `json:"bestScore"` + ElapsedSeconds float64 `json:"elapsedSeconds"` +} + +// JobError captures an error from a failed job. +// The API sometimes returns a string and sometimes an object — this handles both. +type JobError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (e *JobError) UnmarshalJSON(data []byte) error { + // Try as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + e.Message = s + return nil + } + // Try as object + type alias JobError + var a alias + if err := json.Unmarshal(data, &a); err != nil { + return err + } + *e = JobError(a) + return nil +} + +// CandidateResult holds the evaluation result for a single candidate. +type CandidateResult struct { + Name string `json:"name"` + AvgScore float64 `json:"avgScore"` + AvgTokens float64 `json:"avgTokens"` + PassRate float64 `json:"passRate"` + Mutations map[string]any `json:"mutations,omitempty"` + Rationale string `json:"rationale,omitempty"` + CandidateID string `json:"candidateId,omitempty"` + TaskScores []TaskScore `json:"taskScores,omitempty"` +} + +// TaskScore captures per-task evaluation metrics. +type TaskScore struct { + TaskName string `json:"taskName"` + Scores map[string]float64 `json:"scores"` + CompositeScore float64 `json:"compositeScore"` + Tokens int `json:"tokens"` + Duration float64 `json:"durationSeconds"` + Passed bool `json:"passed"` +} + +// --- List response --- + +// OptimizeListResponse is the paginated list of optimization jobs. +type OptimizeListResponse struct { + Data []OptimizeJobStatus `json:"data"` + FirstID string `json:"firstId"` + LastID string `json:"lastId"` + HasMore bool `json:"hasMore"` +} + +// --- Cancel response --- + +// OptimizeCancelResponse is returned when cancelling an optimization job. +type OptimizeCancelResponse struct { + OperationID string `json:"operationId"` + Status string `json:"status"` +} + +// --- Deployment report --- + +// DeploymentReport is sent to the optimization service after a candidate is promoted, +// creating the candidate→deployment mapping. +type DeploymentReport struct { + CandidateID string `json:"-"` // used in URL path, not serialized + AgentName string `json:"agentName"` // deployed agent name + AgentVersion string `json:"agentVersion"` // deployed agent version +} + +// --- Candidate models --- + +// CandidateManifest represents the candidate metadata returned by +// GET /optimize/candidates/{id}. +type CandidateManifest struct { + Files []CandidateFile `json:"files"` +} + +// CandidateFile is a single entry in the candidate manifest's files list. +type CandidateFile struct { + Path string `json:"path"` + Type string `json:"type"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models_test.go new file mode 100644 index 00000000000..9392943cc66 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/models_test.go @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package optimize_api + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizeRequest_RoundTrip(t *testing.T) { + t.Parallel() + + original := OptimizeRequest{ + Agent: AgentDefinition{ + FoundryProjectURL: "https://example.ai.azure.com/project/my-proj", + AgentName: "my-agent", + AgentVersion: "1", + Model: "gpt-4o", + SystemPrompt: "You are helpful", + Skills: []SkillDefinition{ + {Name: "search", Description: "web search"}, + }, + }, + Dataset: []DatasetTask{ + { + Name: "task1", + Prompt: "What is 2+2?", + GroundTruth: "4", + Criteria: []Criterion{ + {Name: "accuracy", Instruction: "answer must be correct"}, + }, + }, + }, + TrainDatasetReference: &DatasetReference{ + Name: "train-ds", + Version: "1", + }, + Evaluators: []string{"coherence", "relevance"}, + Criteria: []Criterion{ + {Name: "global-crit", Instruction: "be concise"}, + }, + Options: OptimizeOptions{ + MaxIterations: new(5), + EvalModel: "gpt-4o-mini", + Strategies: []string{"prompt_mutation"}, + TargetAttributes: []string{"prompt_mutation"}, + KeepVersions: true, + TasksPerIteration: 10, + MaxReflectionTasks: 3, + ReflectionModel: "gpt-4o", + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err, "marshal should succeed") + + s := string(data) + // Verify camelCase JSON tags + for _, field := range []string{ + `"agent"`, `"foundryProjectUrl"`, `"agentName"`, `"agentVersion"`, + `"dataset"`, `"trainDatasetReference"`, `"evaluators"`, `"criteria"`, + `"options"`, `"evalModel"`, `"maxIterations"`, + `"keepVersions"`, + `"tasksPerIteration"`, `"maxReflectionTasks"`, `"reflectionModel"`, + `"strategies"`, `"targetAttributes"`, `"groundTruth"`, `"systemPrompt"`, `"skills"`, + } { + assert.True(t, strings.Contains(s, field), "JSON should contain %s", field) + } + + var got OptimizeRequest + require.NoError(t, json.Unmarshal(data, &got), "unmarshal should succeed") + + assert.Equal(t, original.Agent.AgentName, got.Agent.AgentName) + assert.Equal(t, original.Agent.FoundryProjectURL, got.Agent.FoundryProjectURL) + assert.Equal(t, original.Agent.Model, got.Agent.Model) + assert.Len(t, got.Dataset, 1) + assert.Equal(t, "task1", got.Dataset[0].Name) + assert.Equal(t, "4", got.Dataset[0].GroundTruth) + assert.NotNil(t, got.TrainDatasetReference) + assert.Equal(t, "train-ds", got.TrainDatasetReference.Name) + assert.Equal(t, "gpt-4o-mini", got.Options.EvalModel) + assert.True(t, got.Options.KeepVersions) +} + +func TestOptimizeJobStatus_RoundTrip(t *testing.T) { + t.Parallel() + + original := OptimizeJobStatus{ + OperationID: "op-123", + Status: StatusRunning, + CreatedAt: "2024-01-01T00:00:00Z", + UpdatedAt: "2024-01-01T01:00:00Z", + Agent: &AgentDefinition{ + FoundryProjectURL: "https://example.ai.azure.com/project/p", + AgentName: "agent-1", + }, + Progress: &JobProgress{ + CurrentTargetAttribute: "prompt_mutation", + CurrentIteration: 3, + TasksCompleted: 15, + TasksTotal: 20, + BestScore: 0.85, + ElapsedSeconds: 120.5, + }, + Baseline: &CandidateResult{ + Name: "baseline", + AvgScore: 0.6, + PassRate: 0.5, + }, + Best: &CandidateResult{ + Name: "candidate-2", + AvgScore: 0.9, + AvgTokens: 150.0, + PassRate: 0.95, + CandidateID: "cand-2", + Mutations: map[string]any{"systemPrompt": "Be very helpful"}, + Rationale: "Improved prompt clarity", + TaskScores: []TaskScore{ + { + TaskName: "task1", + Scores: map[string]float64{"coherence": 0.9, "relevance": 0.95}, + CompositeScore: 0.925, + Tokens: 200, + Duration: 1.5, + Passed: true, + }, + }, + }, + Candidates: []CandidateResult{ + {Name: "candidate-1", AvgScore: 0.7}, + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err, "marshal should succeed") + + s := string(data) + for _, field := range []string{ + `"operationId"`, `"status"`, `"createdAt"`, `"updatedAt"`, + `"progress"`, `"currentTargetAttribute"`, `"currentIteration"`, + `"tasksCompleted"`, `"tasksTotal"`, `"bestScore"`, `"elapsedSeconds"`, + `"baseline"`, `"best"`, `"candidates"`, `"candidateId"`, + `"avgScore"`, `"avgTokens"`, `"passRate"`, `"mutations"`, + `"rationale"`, `"taskScores"`, `"compositeScore"`, `"durationSeconds"`, + } { + assert.True(t, strings.Contains(s, field), "JSON should contain %s", field) + } + + var got OptimizeJobStatus + require.NoError(t, json.Unmarshal(data, &got), "unmarshal should succeed") + + assert.Equal(t, "op-123", got.OperationID) + assert.Equal(t, StatusRunning, got.Status) + assert.NotNil(t, got.Agent) + assert.Equal(t, "agent-1", got.Agent.AgentName) + assert.NotNil(t, got.Progress) + assert.Equal(t, 3, got.Progress.CurrentIteration) + assert.InDelta(t, 0.85, got.Progress.BestScore, 0.001) + assert.NotNil(t, got.Baseline) + assert.InDelta(t, 0.6, got.Baseline.AvgScore, 0.001) + assert.NotNil(t, got.Best) + assert.Equal(t, "cand-2", got.Best.CandidateID) + assert.Len(t, got.Best.TaskScores, 1) + assert.True(t, got.Best.TaskScores[0].Passed) + assert.Len(t, got.Candidates, 1) +} + +func TestOptimizeJobStatus_ErrorField(t *testing.T) { + t.Parallel() + + original := OptimizeJobStatus{ + OperationID: "op-err", + Status: StatusFailed, + Error: &JobError{ + Code: "InternalError", + Message: "something went wrong", + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var got OptimizeJobStatus + require.NoError(t, json.Unmarshal(data, &got)) + + assert.Equal(t, StatusFailed, got.Status) + require.NotNil(t, got.Error) + assert.Equal(t, "InternalError", got.Error.Code) + assert.Equal(t, "something went wrong", got.Error.Message) +} + +func TestIsTerminal(t *testing.T) { + t.Parallel() + + assert.True(t, IsTerminal(StatusCompleted)) + assert.True(t, IsTerminal(StatusFailed)) + assert.True(t, IsTerminal(StatusCancelled)) + assert.False(t, IsTerminal(StatusRunning)) + assert.False(t, IsTerminal(StatusQueued)) + assert.False(t, IsTerminal("unknown")) +} + +func TestOptimizeListResponse_RoundTrip(t *testing.T) { + t.Parallel() + + original := OptimizeListResponse{ + Data: []OptimizeJobStatus{ + {OperationID: "op-1", Status: StatusCompleted}, + {OperationID: "op-2", Status: StatusRunning}, + }, + FirstID: "op-1", + LastID: "op-2", + HasMore: true, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var got OptimizeListResponse + require.NoError(t, json.Unmarshal(data, &got)) + + assert.Len(t, got.Data, 2) + assert.Equal(t, "op-1", got.FirstID) + assert.Equal(t, "op-2", got.LastID) + assert.True(t, got.HasMore) +} + +// ---- DeploymentReport serialization ---- + +func TestDeploymentReport_JSON_ExcludesCandidateID(t *testing.T) { + t.Parallel() + + report := DeploymentReport{ + CandidateID: "cand_abc123", + AgentName: "my-agent", + AgentVersion: "3", + } + + data, err := json.Marshal(report) + require.NoError(t, err) + + // CandidateID has json:"-", so it must not appear in the body. + assert.NotContains(t, string(data), "candidateId") + assert.NotContains(t, string(data), "cand_abc123") + + // agentName and agentVersion must be present. + assert.Contains(t, string(data), `"agentName":"my-agent"`) + assert.Contains(t, string(data), `"agentVersion":"3"`) +} + +func TestDeploymentReport_JSON_RoundTrip(t *testing.T) { + t.Parallel() + + body := `{"agentName":"test-agent","agentVersion":"5"}` + var report DeploymentReport + require.NoError(t, json.Unmarshal([]byte(body), &report)) + + assert.Equal(t, "test-agent", report.AgentName) + assert.Equal(t, "5", report.AgentVersion) + assert.Empty(t, report.CandidateID, "CandidateID should not be populated from JSON") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller.go new file mode 100644 index 00000000000..d380f42e1d5 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller.go @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// poller.go provides a polling loop for optimization jobs that calls +// a progress callback on each tick until the job reaches a terminal state. +package optimize_api + +import ( + "context" + "fmt" + "log" + "time" + + "azureaiagent/internal/pkg/agents" +) + +// Poller polls an optimization job until it reaches a terminal state. +type Poller struct { + Client *OptimizeClient + OperationID string + Interval time.Duration + MaxAttempts int // 0 means no limit + OnProgress func(*OptimizeJobStatus) +} + +// PollUntilDone polls GetOptimizeStatus at the configured interval until the +// job reaches a terminal state (completed, failed, cancelled), the context +// is cancelled, or MaxAttempts is reached. Transient errors (5xx, 429, +// connection reset) are retried up to maxConsecutiveTransient times before +// the poller gives up. +func (p *Poller) PollUntilDone(ctx context.Context) (*OptimizeJobStatus, error) { + if p.Client == nil { + return nil, fmt.Errorf("poller Client is nil") + } + if p.OperationID == "" { + return nil, fmt.Errorf("poller OperationID is empty") + } + + const maxConsecutiveTransient = 5 + + interval := p.Interval + if interval <= 0 { + interval = 5 * time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + attempts := 0 + consecutiveTransient := 0 + for { + status, err := p.Client.GetOptimizeStatus(ctx, p.OperationID) + if err != nil { + if agents.IsTransientError(err) { + consecutiveTransient++ + if consecutiveTransient > maxConsecutiveTransient { + return nil, fmt.Errorf( + "polling aborted after %d consecutive transient errors, last: %w", + consecutiveTransient, err) + } + log.Printf("[poller] transient error polling %s (%d/%d), will retry: %v", + p.OperationID, consecutiveTransient, maxConsecutiveTransient, err) + goto wait + } + return nil, fmt.Errorf("failed to get optimization status: %w", err) + } + + consecutiveTransient = 0 // reset on success + + if p.OnProgress != nil { + p.OnProgress(status) + } + + if IsTerminal(status.Status) { + return status, nil + } + + wait: + attempts++ + if p.MaxAttempts > 0 && attempts >= p.MaxAttempts { + return nil, fmt.Errorf("polling timed out after %d attempts", attempts) + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + // continue polling + } + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller_test.go new file mode 100644 index 00000000000..fe623db11d7 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/optimize_api/poller_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package optimize_api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newPollerTestClient(serverURL string) *OptimizeClient { + pipeline := runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{}, + ) + return &OptimizeClient{ + endpoint: serverURL, + pipeline: pipeline, + } +} + +func TestPoller_PollsUntilCompleted(t *testing.T) { + t.Parallel() + + var callCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&callCount, 1) + status := StatusRunning + if n >= 3 { + status = StatusCompleted + } + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-1", + Status: status, + Progress: &JobProgress{ + CurrentIteration: int(n), + }, + }) + })) + defer server.Close() + + var progressCalls int32 + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-1", + Interval: 10 * time.Millisecond, + OnProgress: func(_ *OptimizeJobStatus) { + atomic.AddInt32(&progressCalls, 1) + }, + } + + result, err := poller.PollUntilDone(context.Background()) + require.NoError(t, err) + assert.Equal(t, StatusCompleted, result.Status) + assert.GreaterOrEqual(t, atomic.LoadInt32(&callCount), int32(3)) + assert.GreaterOrEqual(t, atomic.LoadInt32(&progressCalls), int32(3)) +} + +func TestPoller_PollsUntilFailed(t *testing.T) { + t.Parallel() + + var callCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&callCount, 1) + status := StatusRunning + if n >= 2 { + status = StatusFailed + } + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-fail", + Status: status, + Error: &JobError{ + Code: "InternalError", + Message: "something broke", + }, + }) + })) + defer server.Close() + + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-fail", + Interval: 10 * time.Millisecond, + } + + result, err := poller.PollUntilDone(context.Background()) + require.NoError(t, err) + assert.Equal(t, StatusFailed, result.Status) +} + +func TestPoller_ContextCancellation(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-cancel", + Status: StatusRunning, + }) + })) + defer server.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-cancel", + Interval: 10 * time.Millisecond, + } + + result, err := poller.PollUntilDone(ctx) + assert.Nil(t, result) + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestPoller_OnProgressCalled(t *testing.T) { + t.Parallel() + + var callCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&callCount, 1) + status := StatusRunning + if n >= 2 { + status = StatusCompleted + } + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-prog", + Status: status, + }) + })) + defer server.Close() + + var statuses []string + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-prog", + Interval: 10 * time.Millisecond, + OnProgress: func(s *OptimizeJobStatus) { + statuses = append(statuses, s.Status) + }, + } + + result, err := poller.PollUntilDone(context.Background()) + require.NoError(t, err) + assert.Equal(t, StatusCompleted, result.Status) + assert.GreaterOrEqual(t, len(statuses), 2) + assert.Equal(t, StatusCompleted, statuses[len(statuses)-1]) +} + +func TestPoller_TransientRetryThenSuccess(t *testing.T) { + t.Parallel() + + var callCount int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&callCount, 1) + if n <= 3 { + // First 3 calls return 500 (transient). + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error": "server error"}`)) + return + } + _ = json.NewEncoder(w).Encode(OptimizeJobStatus{ + OperationID: "op-retry", + Status: StatusCompleted, + }) + })) + defer server.Close() + + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-retry", + Interval: 10 * time.Millisecond, + } + + result, err := poller.PollUntilDone(t.Context()) + require.NoError(t, err) + assert.Equal(t, StatusCompleted, result.Status) + assert.GreaterOrEqual(t, atomic.LoadInt32(&callCount), int32(4)) +} + +func TestPoller_TransientRetryExhausted(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Always return 500. + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error": "server error"}`)) + })) + defer server.Close() + + poller := &Poller{ + Client: newPollerTestClient(server.URL), + OperationID: "op-exhaust", + Interval: 10 * time.Millisecond, + MaxAttempts: 20, // low cap to keep test fast + } + + _, err := poller.PollUntilDone(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "consecutive transient errors") +} + +func TestPoller_NilClient(t *testing.T) { + t.Parallel() + poller := &Poller{OperationID: "op-1"} + _, err := poller.PollUntilDone(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "Client is nil") +} + +func TestPoller_EmptyOperationID(t *testing.T) { + t.Parallel() + poller := &Poller{Client: &OptimizeClient{}} + _, err := poller.PollUntilDone(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "OperationID is empty") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient.go new file mode 100644 index 00000000000..46000b4fe4c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agents + +import ( + "errors" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// IsTransientError checks whether an error represents a transient HTTP failure +// (429 Too Many Requests, 5xx Server Error, or connection-level errors) that +// is safe to retry. +func IsTransientError(err error) bool { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok { + return respErr.StatusCode == 429 || respErr.StatusCode >= 500 + } + // Connection resets and similar I/O errors are also transient. + msg := err.Error() + return strings.Contains(msg, "connection reset") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "EOF") +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient_test.go new file mode 100644 index 00000000000..742a5deda1e --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/transient_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package agents + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsTransientError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + transient bool + }{ + {"connection reset", fmt.Errorf("read tcp: connection reset by peer"), true}, + {"connection refused", fmt.Errorf("dial tcp: connection refused"), true}, + {"unexpected EOF", fmt.Errorf("unexpected EOF"), true}, + {"not found", fmt.Errorf("not found"), false}, + {"auth error", fmt.Errorf("authorization denied"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.transient, IsTransientError(tt.err)) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go index 3254f137b57..adaa987e9f6 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent.go @@ -1567,7 +1567,8 @@ func (p *AgentServiceTargetProvider) deployArtifacts( if len(endpoints) > 0 { last := artifacts[len(artifacts)-1] last.Metadata["note"] = "For information on invoking the agent, see " + output.WithLinkFormat( - "https://aka.ms/azd-agents-invoke") + "https://aka.ms/azd-agents-invoke") + + "\n\nSet up an evaluation suite to measure quality and impact in one step with " + output.WithHighLightFormat("azd ai agent eval init") } }