From 0d196e5e6866e8e57e8b3db0cb9d151dc639a41a Mon Sep 17 00:00:00 2001 From: Jason Lynch Date: Tue, 16 Jun 2026 09:14:04 -0400 Subject: [PATCH 1/3] refactor: add validation package Moves the validation error and path to a new `validation` package. We'll add validation helper functions to this package and possibly move it to a top level package in a future commit. PLAT-611 --- server/internal/api/apiv1/errors.go | 3 +- server/internal/api/apiv1/validate.go | 319 +++++++++------------ server/internal/api/apiv1/validate_test.go | 19 -- server/internal/validation/error.go | 64 +++++ server/internal/validation/error_test.go | 28 ++ 5 files changed, 233 insertions(+), 200 deletions(-) create mode 100644 server/internal/validation/error.go create mode 100644 server/internal/validation/error_test.go diff --git a/server/internal/api/apiv1/errors.go b/server/internal/api/apiv1/errors.go index 1f4d6243..72e05323 100644 --- a/server/internal/api/apiv1/errors.go +++ b/server/internal/api/apiv1/errors.go @@ -10,6 +10,7 @@ import ( "github.com/pgEdge/control-plane/server/internal/database" "github.com/pgEdge/control-plane/server/internal/etcd" "github.com/pgEdge/control-plane/server/internal/task" + "github.com/pgEdge/control-plane/server/internal/validation" "github.com/pgEdge/control-plane/server/internal/workflows" ) @@ -51,7 +52,7 @@ func ErrHostAlreadyExistsWithID(hostID string) *api.APIError { func apiErr(err error) error { var goaErr *goa.ServiceError var apiErr *api.APIError - var vErr *validationError + var vErr *validation.Error switch { case err == nil: return nil diff --git a/server/internal/api/apiv1/validate.go b/server/internal/api/apiv1/validate.go index 585b41b5..c6e3c780 100644 --- a/server/internal/api/apiv1/validate.go +++ b/server/internal/api/apiv1/validate.go @@ -20,63 +20,21 @@ import ( "github.com/pgEdge/control-plane/server/internal/postgres/hba" "github.com/pgEdge/control-plane/server/internal/storage" "github.com/pgEdge/control-plane/server/internal/utils" + "github.com/pgEdge/control-plane/server/internal/validation" ) -type validationError struct { - path []string - err error -} - -func newValidationError(err error, path []string) *validationError { - return &validationError{ - path: path, - err: err, - } -} - -func (v *validationError) Unwrap() error { - return v.err -} - -func (v *validationError) Error() string { - if len(v.path) == 0 { - return v.err.Error() - } - - var path strings.Builder - for i, ele := range v.path { - if i > 0 && !strings.HasPrefix(ele, "[") { - path.WriteString(".") - } - path.WriteString(ele) - } - return fmt.Sprintf("%s: %s", path.String(), v.err.Error()) -} - -func arrayIndexPath(idx int) string { - return fmt.Sprintf("[%d]", idx) -} - -func mapKeyPath(key string) string { - return fmt.Sprintf("[%s]", key) -} - -func appendPath(path []string, new ...string) []string { - return append(slices.Clone(path), new...) -} - // validateAuthFileGUCs rejects postgresql_conf settings that would make // user-supplied pg_hba_conf/pg_ident_conf entries ineffective. When hba_file // or ident_file is set, Patroni ignores the pg_hba/pg_ident arrays it manages, // so the control-plane-generated file (including user entries) would never be // written. GUC names are case-insensitive in PostgreSQL, so we compare lower. -func validateAuthFileGUCs(conf map[string]any, path []string) []error { +func validateAuthFileGUCs(conf map[string]any, path validation.Path) []error { var errs []error for key := range conf { switch strings.ToLower(strings.TrimSpace(key)) { case "hba_file", "ident_file": err := fmt.Errorf("%q is not allowed: it overrides the control-plane-managed pg_hba.conf/pg_ident.conf and would make pg_hba_conf/pg_ident_conf entries ineffective", key) - errs = append(errs, newValidationError(err, appendPath(path, mapKeyPath(key)))) + errs = append(errs, validation.NewError(err, path.AppendMapKey(key))) } } return errs @@ -85,7 +43,7 @@ func validateAuthFileGUCs(conf map[string]any, path []string) []error { // validatePgHbaConf checks that every non-comment pg_hba_conf entry parses. // Blank and comment lines are allowed and skipped. Validation is intentionally // minimal — see server/internal/postgres/hba/parse.go. -func validatePgHbaConf(lines []string, path []string) []error { +func validatePgHbaConf(lines []string, path validation.Path) []error { var errs []error for i, line := range lines { if hba.IsComment(line) { @@ -93,14 +51,14 @@ func validatePgHbaConf(lines []string, path []string) []error { } if _, err := hba.ParseEntry(line); err != nil { wrapped := fmt.Errorf("invalid pg_hba entry %q: %w", line, err) - errs = append(errs, newValidationError(wrapped, appendPath(path, arrayIndexPath(i)))) + errs = append(errs, validation.NewError(wrapped, path.AppendArrayIndex(i))) } } return errs } // validatePgIdentConf checks that every non-comment pg_ident_conf entry parses. -func validatePgIdentConf(lines []string, path []string) []error { +func validatePgIdentConf(lines []string, path validation.Path) []error { var errs []error for i, line := range lines { if hba.IsComment(line) { @@ -108,7 +66,7 @@ func validatePgIdentConf(lines []string, path []string) []error { } if _, err := hba.ParseIdent(line); err != nil { wrapped := fmt.Errorf("invalid pg_ident entry %q: %w", line, err) - errs = append(errs, newValidationError(wrapped, appendPath(path, arrayIndexPath(i)))) + errs = append(errs, validation.NewError(wrapped, path.AppendArrayIndex(i))) } } return errs @@ -117,23 +75,24 @@ func validatePgIdentConf(lines []string, path []string) []error { func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, spec *api.DatabaseSpec) error { var errs []error - errs = append(errs, validateCPUs(spec.Cpus, []string{"cpus"})...) - errs = append(errs, validateMemory(spec.Memory, []string{"memory"})...) - errs = append(errs, validatePorts(spec.Port, spec.PatroniPort, []string{"port"})) - errs = append(errs, validateUsers(spec.DatabaseUsers, []string{"database_users"})...) - errs = append(errs, validateScripts(spec.Scripts, []string{"scripts"})...) + errs = append(errs, validateCPUs(spec.Cpus, validation.NewPath("cpus"))...) + errs = append(errs, validateMemory(spec.Memory, validation.NewPath("memory"))...) + errs = append(errs, validatePorts(spec.Port, spec.PatroniPort, validation.NewPath("port"))) + errs = append(errs, validateUsers(spec.DatabaseUsers, validation.NewPath("database_users"))...) + errs = append(errs, validateScripts(spec.Scripts, validation.NewPath("scripts"))...) // Track node-name uniqueness and prepare set for cross-node checks. seenNodeNames := make(ds.Set[string], len(spec.Nodes)) // Track nodes that themselves have a source_node (treated as "new" nodes). newNodesWithSource := make(ds.Set[string], len(spec.Nodes)) + nodesPath := validation.NewPath("nodes") for i, node := range spec.Nodes { - nodePath := []string{"nodes", arrayIndexPath(i)} + nodePath := nodesPath.AppendArrayIndex(i) if seenNodeNames.Has(node.Name) { err := errors.New("node names must be unique within a database") - errs = append(errs, newValidationError(err, nodePath)) + errs = append(errs, validation.NewError(err, nodePath)) } seenNodeNames.Add(node.Name) @@ -154,11 +113,11 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s continue } - srcPath := []string{"nodes", arrayIndexPath(i), "source_node"} + srcPath := nodesPath.AppendArrayIndex(i).Append("source_node") if !seenNodeNames.Has(src) { // Attach error to the specific field path - errs = append(errs, newValidationError(errors.New("source node does not exist"), + errs = append(errs, validation.NewError(errors.New("source node does not exist"), srcPath)) continue } @@ -166,7 +125,7 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s // prevent using a "new" node (one that has its own source_node) // as the source for another node. if newNodesWithSource.Has(src) { - errs = append(errs, newValidationError( + errs = append(errs, validation.NewError( errors.New("source node must refer to an existing node"), srcPath, )) @@ -175,40 +134,40 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s // Reject postgresql_conf GUCs that would make user-supplied pg_hba/pg_ident // entries ineffective, then validate the entries themselves. - errs = append(errs, validateAuthFileGUCs(spec.PostgresqlConf, []string{"postgresql_conf"})...) - errs = append(errs, validatePgHbaConf(spec.PgHbaConf, []string{"pg_hba_conf"})...) - errs = append(errs, validatePgIdentConf(spec.PgIdentConf, []string{"pg_ident_conf"})...) + errs = append(errs, validateAuthFileGUCs(spec.PostgresqlConf, validation.NewPath("postgresql_conf"))...) + errs = append(errs, validatePgHbaConf(spec.PgHbaConf, validation.NewPath("pg_hba_conf"))...) + errs = append(errs, validatePgIdentConf(spec.PgIdentConf, validation.NewPath("pg_ident_conf"))...) if spec.BackupConfig != nil { - errs = append(errs, validateBackupConfig(spec.BackupConfig, []string{"backup_config"})...) + errs = append(errs, validateBackupConfig(spec.BackupConfig, validation.NewPath("backup_config"))...) } if spec.RestoreConfig != nil { - errs = append(errs, validateRestoreConfig(spec.RestoreConfig, []string{"restore_config"})...) + errs = append(errs, validateRestoreConfig(spec.RestoreConfig, validation.NewPath("restore_config"))...) } // Validate orchestrator_opts (spec-level) - errs = append(errs, validateOrchestratorOpts(spec.OrchestratorOpts, []string{"orchestrator_opts"})...) + errs = append(errs, validateOrchestratorOpts(spec.OrchestratorOpts, validation.NewPath("orchestrator_opts"))...) // Validate services — seed portOwner with Postgres ports so services can't collide with the database. portOwner := make(servicePortOwnerMap) seedPostgresPorts(spec, portOwner) - servicesPath := []string{"services"} + servicesPath := validation.NewPath("services") switch orchestrator { case config.OrchestratorSystemD: if len(spec.Services) != 0 { - errs = append(errs, newValidationError(errors.New("services are not yet supported for systemd clusters"), servicesPath)) + errs = append(errs, validation.NewError(errors.New("services are not yet supported for systemd clusters"), servicesPath)) } default: seenServiceIDs := make(ds.Set[string], len(spec.Services)) for i, svc := range spec.Services { - svcPath := appendPath(servicesPath, arrayIndexPath(i)) + svcPath := servicesPath.AppendArrayIndex(i) // Check for duplicate service IDs if seenServiceIDs.Has(string(svc.ServiceID)) { err := errors.New("service IDs must be unique within a database") - errs = append(errs, newValidationError(err, svcPath)) + errs = append(errs, validation.NewError(err, svcPath)) } seenServiceIDs.Add(string(svc.ServiceID)) @@ -243,8 +202,8 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { if !existingNodeNames.Has(src) { // Newly added node is trying to use a new/non-existing node as source. - path := []string{"nodes", arrayIndexPath(i), "source_node"} - errs = append(errs, newValidationError( + path := validation.NewPath("nodes", validation.ArrayIndexElement(i), "source_node") + errs = append(errs, validation.NewError( errors.New("source node must refer to an existing node"), path, )) @@ -277,7 +236,7 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { // first time so that bootstrap-only fields are accepted. For service types that // have no bootstrap fields (e.g. postgrest) the flag has no effect. for i, svc := range new.Services { - svcPath := []string{"services", arrayIndexPath(i)} + svcPath := validation.NewPath("services", validation.ArrayIndexElement(i)) isExistingService := existingServiceIDs.Has(string(svc.ServiceID)) errs = append(errs, validateServicePortConflicts(svc, svcPath, portOwner)...) @@ -291,14 +250,14 @@ func validateNode( orchestrator config.Orchestrator, db *api.DatabaseSpec, node *api.DatabaseNodeSpec, - path []string, + path validation.Path, ) []error { var errs []error - cpusPath := appendPath(path, "cpus") + cpusPath := path.Append("cpus") errs = append(errs, validateCPUs(node.Cpus, cpusPath)...) - memPath := appendPath(path, "memory") + memPath := path.Append("memory") errs = append(errs, validateMemory(node.Memory, memPath)...) port := db.Port @@ -309,19 +268,19 @@ func validateNode( if node.PatroniPort != nil { patroniPort = node.PatroniPort } - portPath := appendPath(path, "port") + portPath := path.Append("port") errs = append(errs, validatePorts(port, patroniPort, portPath)) seenHostIDs := make(ds.Set[string], len(node.HostIds)) for i, h := range node.HostIds { hostID := string(h) - hostPath := appendPath(path, "host_ids", arrayIndexPath(i)) + hostPath := path.Append("host_ids").AppendArrayIndex(i) errs = append(errs, validateIdentifier(hostID, hostPath)) if seenHostIDs.Has(hostID) { err := errors.New("host IDs must be unique within a node") - errs = append(errs, newValidationError(err, hostPath)) + errs = append(errs, validation.NewError(err, hostPath)) } seenHostIDs.Add(hostID) @@ -329,88 +288,88 @@ func validateNode( // source_node + restore_config validation (field-level) src := utils.FromPointer(node.SourceNode) - srcPath := appendPath(path, "source_node") + srcPath := path.Append("source_node") // If restore_config is provided, source_node must be empty if node.RestoreConfig != nil && src != "" { - errs = append(errs, newValidationError(errors.New("specify either source_node or restore_config"), srcPath)) + errs = append(errs, validation.NewError(errors.New("specify either source_node or restore_config"), srcPath)) } else if src != "" { // Self-reference is invalid if src == node.Name { - errs = append(errs, newValidationError(errors.New("a node cannot use itself as a source node"), srcPath)) + errs = append(errs, validation.NewError(errors.New("a node cannot use itself as a source node"), srcPath)) } } - errs = append(errs, validateAuthFileGUCs(node.PostgresqlConf, appendPath(path, "postgresql_conf"))...) - errs = append(errs, validatePgHbaConf(node.PgHbaConf, appendPath(path, "pg_hba_conf"))...) - errs = append(errs, validatePgIdentConf(node.PgIdentConf, appendPath(path, "pg_ident_conf"))...) + errs = append(errs, validateAuthFileGUCs(node.PostgresqlConf, path.Append("postgresql_conf"))...) + errs = append(errs, validatePgHbaConf(node.PgHbaConf, path.Append("pg_hba_conf"))...) + errs = append(errs, validatePgIdentConf(node.PgIdentConf, path.Append("pg_ident_conf"))...) if node.BackupConfig != nil { - backupConfigPath := appendPath(path, "backup_config") + backupConfigPath := path.Append("backup_config") errs = append(errs, validateBackupConfig(node.BackupConfig, backupConfigPath)...) } if node.RestoreConfig != nil { - restoreConfigPath := appendPath(path, "restore_config") + restoreConfigPath := path.Append("restore_config") errs = append(errs, validateRestoreConfig(node.RestoreConfig, restoreConfigPath)...) } switch orchestrator { case config.OrchestratorSystemD: if db.Port == nil && node.Port == nil { - portPath := appendPath(path, "port") - errs = append(errs, newValidationError(errors.New("port must be defined"), portPath)) + portPath := path.Append("port") + errs = append(errs, validation.NewError(errors.New("port must be defined"), portPath)) } if db.PatroniPort == nil && node.PatroniPort == nil { - portPath := appendPath(path, "patroni_port") - errs = append(errs, newValidationError(errors.New("patroni_port must be defined"), portPath)) + portPath := path.Append("patroni_port") + errs = append(errs, validation.NewError(errors.New("patroni_port must be defined"), portPath)) } } // Validate orchestrator_opts (per-node) - errs = append(errs, validateOrchestratorOpts(node.OrchestratorOpts, appendPath(path, "orchestrator_opts"))...) + errs = append(errs, validateOrchestratorOpts(node.OrchestratorOpts, path.Append("orchestrator_opts"))...) return errs } -func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, databaseID string, dbUsers []*api.DatabaseUserSpec, nodeNames ...ds.Set[string]) []error { +func validateServiceSpec(svc *api.ServiceSpec, path validation.Path, isUpdate bool, databaseID string, dbUsers []*api.DatabaseUserSpec, nodeNames ...ds.Set[string]) []error { var errs []error // Validate service_id - serviceIDPath := appendPath(path, "service_id") + serviceIDPath := path.Append("service_id") errs = append(errs, validateIdentifier(string(svc.ServiceID), serviceIDPath)) // Enforce Docker Swarm service name budget: "{databaseID}-{serviceID}-{8charHash}" must be ≤63 chars. if len(databaseID)+len(string(svc.ServiceID)) > 53 { err := fmt.Errorf("database ID and service ID combined must not exceed 53 characters (got %d)", len(databaseID)+len(string(svc.ServiceID))) - errs = append(errs, newValidationError(err, serviceIDPath)) + errs = append(errs, validation.NewError(err, serviceIDPath)) } // Validate service_type allowlist - supportedServiceTypes := []string{"mcp", "postgrest", "rag"} + supportedServiceTypes := validation.NewPath("mcp", "postgrest", "rag") if !slices.Contains(supportedServiceTypes, svc.ServiceType) { err := fmt.Errorf("unsupported service type %q (supported: %s)", svc.ServiceType, strings.Join(supportedServiceTypes, ", ")) - errs = append(errs, newValidationError(err, appendPath(path, "service_type"))) + errs = append(errs, validation.NewError(err, path.Append("service_type"))) } // Validate version (semver pattern or "latest") if svc.Version != "latest" && !semverPattern.MatchString(svc.Version) { err := errors.New("version must be in semver format (e.g., '1.0.0') or 'latest'") - errs = append(errs, newValidationError(err, appendPath(path, "version"))) + errs = append(errs, validation.NewError(err, path.Append("version"))) } // Validate host_ids (uniqueness and format) seenHostIDs := make(ds.Set[string], len(svc.HostIds)) for i, hostID := range svc.HostIds { hostIDStr := string(hostID) - hostIDPath := appendPath(path, "host_ids", arrayIndexPath(i)) + hostIDPath := path.Append("host_ids").AppendArrayIndex(i) errs = append(errs, validateIdentifier(hostIDStr, hostIDPath)) // may need to relax this if there is a use-case for multiple service instances on the same host if seenHostIDs.Has(hostIDStr) { err := errors.New("host IDs must be unique within a service") - errs = append(errs, newValidationError(err, hostIDPath)) + errs = append(errs, validation.NewError(err, hostIDPath)) } seenHostIDs.Add(hostIDStr) } @@ -418,16 +377,16 @@ func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, dat // Validate config based on service_type switch svc.ServiceType { case "mcp": - errs = append(errs, validateMCPServiceConfig(svc.Config, appendPath(path, "config"), isUpdate)...) + errs = append(errs, validateMCPServiceConfig(svc.Config, path.Append("config"), isUpdate)...) case "postgrest": - errs = append(errs, validatePostgRESTServiceConfig(svc.Config, appendPath(path, "config"))...) + errs = append(errs, validatePostgRESTServiceConfig(svc.Config, path.Append("config"))...) case "rag": - errs = append(errs, validateRAGServiceConfig(svc.Config, appendPath(path, "config"), isUpdate)...) + errs = append(errs, validateRAGServiceConfig(svc.Config, path.Append("config"), isUpdate)...) } // Validate database_connection if provided if svc.DatabaseConnection != nil { - dcPath := appendPath(path, "database_connection") + dcPath := path.Append("database_connection") var nn ds.Set[string] if len(nodeNames) > 0 { nn = nodeNames[0] @@ -445,31 +404,31 @@ func validateServiceSpec(svc *api.ServiceSpec, path []string, isUpdate bool, dat writeSafe := map[string]bool{database.TargetSessionAttrsPrimary: true, database.TargetSessionAttrsReadWrite: true} if tsa != "" && !writeSafe[tsa] { err := fmt.Errorf("allow_writes requires target_session_attrs 'primary' or 'read-write', got '%s'", tsa) - errs = append(errs, newValidationError(err, appendPath(path, "database_connection", "target_session_attrs"))) + errs = append(errs, validation.NewError(err, path.Append("database_connection", "target_session_attrs"))) } } } // Validate cpus if provided if svc.Cpus != nil { - errs = append(errs, validateCPUs(svc.Cpus, appendPath(path, "cpus"))...) + errs = append(errs, validateCPUs(svc.Cpus, path.Append("cpus"))...) } // Validate memory if provided if svc.Memory != nil { - errs = append(errs, validateMemory(svc.Memory, appendPath(path, "memory"))...) + errs = append(errs, validateMemory(svc.Memory, path.Append("memory"))...) } // Validate orchestrator_opts (service-specific restrictions on top of shared checks) - errs = append(errs, validateServiceOrchestratorOpts(svc.OrchestratorOpts, appendPath(path, "orchestrator_opts"))...) + errs = append(errs, validateServiceOrchestratorOpts(svc.OrchestratorOpts, path.Append("orchestrator_opts"))...) return errs } -func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, path []string) []error { - connectAsPath := appendPath(path, "connect_as") +func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, path validation.Path) []error { + connectAsPath := path.Append("connect_as") if svc.ConnectAs == "" { - return []error{newValidationError(errors.New("connect_as is required"), connectAsPath)} + return []error{validation.NewError(errors.New("connect_as is required"), connectAsPath)} } for _, u := range dbUsers { @@ -479,42 +438,42 @@ func validateConnectAs(svc *api.ServiceSpec, dbUsers []*api.DatabaseUserSpec, pa } err := fmt.Errorf("connect_as %q does not match any database_users entry", svc.ConnectAs) - return []error{newValidationError(err, connectAsPath)} + return []error{validation.NewError(err, connectAsPath)} } -func validateMCPServiceConfig(config map[string]any, path []string, isUpdate bool) []error { +func validateMCPServiceConfig(config map[string]any, path validation.Path, isUpdate bool) []error { _, errs := database.ParseMCPServiceConfig(config, isUpdate) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validatePostgRESTServiceConfig(config map[string]any, path []string) []error { +func validatePostgRESTServiceConfig(config map[string]any, path validation.Path) []error { _, errs := database.ParsePostgRESTServiceConfig(config) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validateDatabaseConnection(dc *api.DatabaseConnection, path []string, nodeNames ds.Set[string]) []error { +func validateDatabaseConnection(dc *api.DatabaseConnection, path validation.Path, nodeNames ds.Set[string]) []error { var errs []error // Validate target_nodes: no duplicates, no empty strings, must exist in spec if dc.TargetNodes != nil { seen := make(ds.Set[string], len(dc.TargetNodes)) for i, node := range dc.TargetNodes { - nodePath := appendPath(path, "target_nodes", arrayIndexPath(i)) + nodePath := path.Append("target_nodes").AppendArrayIndex(i) if node == "" { - errs = append(errs, newValidationError(errors.New("node name must not be empty"), nodePath)) + errs = append(errs, validation.NewError(errors.New("node name must not be empty"), nodePath)) } else if nodeNames != nil && !nodeNames.Has(node) { - errs = append(errs, newValidationError(fmt.Errorf("node %q does not exist in the database spec", node), nodePath)) + errs = append(errs, validation.NewError(fmt.Errorf("node %q does not exist in the database spec", node), nodePath)) } if seen.Has(node) { - errs = append(errs, newValidationError(fmt.Errorf("duplicate node name %q", node), nodePath)) + errs = append(errs, validation.NewError(fmt.Errorf("duplicate node name %q", node), nodePath)) } seen.Add(node) } @@ -531,74 +490,74 @@ func validateDatabaseConnection(dc *api.DatabaseConnection, path []string, nodeN } if !valid[*dc.TargetSessionAttrs] { err := fmt.Errorf("invalid target_session_attrs %q (must be primary, prefer-standby, standby, read-write, or any)", *dc.TargetSessionAttrs) - errs = append(errs, newValidationError(err, appendPath(path, "target_session_attrs"))) + errs = append(errs, validation.NewError(err, path.Append("target_session_attrs"))) } } return errs } -func validateRAGServiceConfig(config map[string]any, path []string, isUpdate bool) []error { +func validateRAGServiceConfig(config map[string]any, path validation.Path, isUpdate bool) []error { _, errs := database.ParseRAGServiceConfig(config, isUpdate) var result []error for _, err := range errs { - result = append(result, newValidationError(err, path)) + result = append(result, validation.NewError(err, path)) } return result } -func validateCPUs(value *string, path []string) []error { +func validateCPUs(value *string, path validation.Path) []error { var errs []error cpus, err := parseCPUs(value) if err != nil { - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } if cpus != 0 && cpus < 0.001 { err := errors.New("cannot be less than 1 millicpu") - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } return errs } -func validateMemory(value *string, path []string) []error { +func validateMemory(value *string, path validation.Path) []error { var errs []error _, err := parseBytes(value) if err != nil { - errs = append(errs, newValidationError(err, path)) + errs = append(errs, validation.NewError(err, path)) } return errs } -func validatePorts(postgresPort, patroniPort *int, path []string) error { +func validatePorts(postgresPort, patroniPort *int, path validation.Path) error { postgres := utils.FromPointer(postgresPort) patroni := utils.FromPointer(patroniPort) if postgres > 0 && postgres == patroni { - return newValidationError(errors.New("postgres and patroni ports must not conflict"), path) + return validation.NewError(errors.New("postgres and patroni ports must not conflict"), path) } return nil } -func validateUsers(users []*api.DatabaseUserSpec, path []string) []error { +func validateUsers(users []*api.DatabaseUserSpec, path validation.Path) []error { var errs []error seenNames := ds.NewSet[string]() var hasOwner bool for i, user := range users { - userPath := appendPath(path, arrayIndexPath(i)) + userPath := path.AppendArrayIndex(i) if seenNames.Has(user.Username) { err := errors.New("usernames must be unique within a database") - errs = append(errs, newValidationError(err, userPath)) + errs = append(errs, validation.NewError(err, userPath)) } if user.DbOwner != nil && *user.DbOwner && hasOwner { err := errors.New("cannot have multiple users with db_owner = true") - errs = append(errs, newValidationError(err, userPath)) + errs = append(errs, validation.NewError(err, userPath)) } seenNames.Add(user.Username) @@ -643,7 +602,7 @@ type servicePortOwnerMap map[hostPort]string // validateServicePortConflicts checks that the service's explicit port (if any) // does not collide with a port already claimed by another service on the same host. -func validateServicePortConflicts(svc *api.ServiceSpec, path []string, owner servicePortOwnerMap) []error { +func validateServicePortConflicts(svc *api.ServiceSpec, path validation.Path, owner servicePortOwnerMap) []error { if svc.Port == nil || *svc.Port <= 0 { return nil } @@ -653,7 +612,7 @@ func validateServicePortConflicts(svc *api.ServiceSpec, path []string, owner ser key := hostPort{hostID: string(hostID), port: *svc.Port} if prev, exists := owner[key]; exists { err := fmt.Errorf("port %d conflicts with service %q on the same host", *svc.Port, prev) - errs = append(errs, newValidationError(err, appendPath(path, "port"))) + errs = append(errs, validation.NewError(err, path.Append("port"))) } else { owner[key] = string(svc.ServiceID) } @@ -661,33 +620,33 @@ func validateServicePortConflicts(svc *api.ServiceSpec, path []string, owner ser return errs } -func validateBackupConfig(cfg *api.BackupConfigSpec, path []string) []error { +func validateBackupConfig(cfg *api.BackupConfigSpec, path validation.Path) []error { var errs []error for i, repo := range cfg.Repositories { - repoPath := appendPath(path, "repositories", arrayIndexPath(i)) + repoPath := path.Append("repositories").AppendArrayIndex(i) errs = append(errs, validateBackupRepository(repo, repoPath)...) } return errs } -func validateRestoreConfig(cfg *api.RestoreConfigSpec, path []string) []error { +func validateRestoreConfig(cfg *api.RestoreConfigSpec, path validation.Path) []error { var errs []error - sourceDbIdPath := appendPath(path, "source_database_id") + sourceDbIdPath := path.Append("source_database_id") errs = append(errs, validateIdentifier(string(cfg.SourceDatabaseID), sourceDbIdPath)) - repoPath := appendPath(path, "repository") + repoPath := path.Append("repository") errs = append(errs, validateRestoreRepository(cfg.Repository, repoPath)...) - restoreOptsPath := appendPath(path, "restore_options") + restoreOptsPath := path.Append("restore_options") errs = append(errs, validatePgBackRestOptions(cfg.RestoreOptions, restoreOptsPath)...) return errs } -func validateBackupRepository(cfg *api.BackupRepositorySpec, path []string) []error { +func validateBackupRepository(cfg *api.BackupRepositorySpec, path validation.Path) []error { props := repoProperties{ id: cfg.ID, repoType: cfg.Type, @@ -704,7 +663,7 @@ func validateBackupRepository(cfg *api.BackupRepositorySpec, path []string) []er return validateRepoProperties(props, path) } -func validateRestoreRepository(cfg *api.RestoreRepositorySpec, path []string) []error { +func validateRestoreRepository(cfg *api.RestoreRepositorySpec, path validation.Path) []error { props := repoProperties{ id: cfg.ID, repoType: cfg.Type, @@ -734,12 +693,12 @@ type repoProperties struct { customOptions map[string]string } -func validateRepoProperties(props repoProperties, path []string) []error { +func validateRepoProperties(props repoProperties, path validation.Path) []error { var errs []error id := utils.FromPointer(props.id) if id != "" { - idPath := appendPath(path, "id") + idPath := path.Append("id") errs = append(errs, validateIdentifier(string(id), idPath)) } @@ -754,70 +713,70 @@ func validateRepoProperties(props repoProperties, path []string) []error { case pgbackrest.RepositoryTypeS3: errs = append(errs, validateS3RepoProperties(props, path)...) default: - err := newValidationError( + err := validation.NewError( fmt.Errorf("unsupported repo type '%s'", repoType), - appendPath(path, "type"), + path.Append("type"), ) errs = append(errs, err) } - customOptsPath := appendPath(path, "custom_options") + customOptsPath := path.Append("custom_options") errs = append(errs, validatePgBackRestOptions(props.customOptions, customOptsPath)...) return errs } -func validateAzureRepoProperties(props repoProperties, path []string) []error { +func validateAzureRepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.azureAccount) == "" { err := errors.New("azure_account is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_account"))) + errs = append(errs, validation.NewError(err, path.Append("azure_account"))) } if utils.FromPointer(props.azureContainer) == "" { err := errors.New("azure_container is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_container"))) + errs = append(errs, validation.NewError(err, path.Append("azure_container"))) } if utils.FromPointer(props.azureKey) == "" { err := errors.New("azure_key is required for azure repositories") - errs = append(errs, newValidationError(err, appendPath(path, "azure_key"))) + errs = append(errs, validation.NewError(err, path.Append("azure_key"))) } return errs } -func validateFSRepoProperties(props repoProperties, path []string) []error { +func validateFSRepoProperties(props repoProperties, path validation.Path) []error { var errs []error basePath := utils.FromPointer(props.basePath) if basePath == "" { err := fmt.Errorf("base_path is required for %s repositories", props.repoType) - errs = append(errs, newValidationError(err, appendPath(path, "base_path"))) + errs = append(errs, validation.NewError(err, path.Append("base_path"))) } else if !filepath.IsAbs(*props.basePath) { err := fmt.Errorf("base_path must be absolute for %s repositories", props.repoType) - errs = append(errs, newValidationError(err, appendPath(path, "base_path"))) + errs = append(errs, validation.NewError(err, path.Append("base_path"))) } return errs } -func validateGCSRepoProperties(props repoProperties, path []string) []error { +func validateGCSRepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.gcsBucket) == "" { err := errors.New("gcs_bucket is required for gcs repositories") - errs = append(errs, newValidationError(err, appendPath(path, "gcs_bucket"))) + errs = append(errs, validation.NewError(err, path.Append("gcs_bucket"))) } return errs } -func validateS3RepoProperties(props repoProperties, path []string) []error { +func validateS3RepoProperties(props repoProperties, path validation.Path) []error { var errs []error if utils.FromPointer(props.s3Bucket) == "" { err := errors.New("s3_bucket is required for s3 repositories") - errs = append(errs, newValidationError(err, appendPath(path, "s3_bucket"))) + errs = append(errs, validation.NewError(err, path.Append("s3_bucket"))) } return errs @@ -829,7 +788,7 @@ var semverPattern = regexp.MustCompile(`^\d+\.\d+(\.\d+)?$`) // reservedLabelPrefix is the label key prefix reserved for system use. const reservedLabelPrefix = "pgedge." -func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error { +func validateOrchestratorOpts(opts *api.OrchestratorOpts, path validation.Path) []error { if opts == nil || opts.Swarm == nil { return nil } @@ -837,9 +796,9 @@ func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error var errs []error for key := range opts.Swarm.ExtraLabels { if strings.HasPrefix(key, reservedLabelPrefix) { - labelPath := appendPath(path, "swarm", "extra_labels", mapKeyPath(key)) + labelPath := path.Append("swarm", "extra_labels").AppendMapKey(key) err := fmt.Errorf("labels starting with %q are reserved for system use", reservedLabelPrefix) - errs = append(errs, newValidationError(err, labelPath)) + errs = append(errs, validation.NewError(err, labelPath)) } } return errs @@ -848,7 +807,7 @@ func validateOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error // validateServiceOrchestratorOpts runs the shared orchestrator_opts checks and // adds service-specific restrictions. Services do not support extra_volumes // (bind mounts are configured per service type) or driver_opts on extra_networks. -func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path []string) []error { +func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path validation.Path) []error { errs := validateOrchestratorOpts(opts, path) if opts == nil || opts.Swarm == nil { @@ -857,28 +816,28 @@ func validateServiceOrchestratorOpts(opts *api.OrchestratorOpts, path []string) if len(opts.Swarm.ExtraVolumes) > 0 { err := errors.New("extra_volumes is not supported for services") - errs = append(errs, newValidationError(err, appendPath(path, "swarm", "extra_volumes"))) + errs = append(errs, validation.NewError(err, path.Append("swarm", "extra_volumes"))) } for i, net := range opts.Swarm.ExtraNetworks { if len(net.DriverOpts) > 0 { - netPath := appendPath(path, "swarm", "extra_networks", arrayIndexPath(i), "driver_opts") + netPath := path.Append("swarm", "extra_networks").AppendArrayIndex(i).Append("driver_opts") err := errors.New("driver_opts is not supported for services") - errs = append(errs, newValidationError(err, netPath)) + errs = append(errs, validation.NewError(err, netPath)) } } return errs } -func validatePgBackRestOptions(opts map[string]string, path []string) []error { +func validatePgBackRestOptions(opts map[string]string, path validation.Path) []error { var errs []error for key := range opts { if !pgBackRestOptionPattern.MatchString(key) { - optPath := appendPath(path, mapKeyPath(key)) + optPath := path.AppendMapKey(key) err := errors.New("invalid option name") - errs = append(errs, newValidationError(err, optPath)) + errs = append(errs, validation.NewError(err, optPath)) } } @@ -888,15 +847,15 @@ func validatePgBackRestOptions(opts map[string]string, path []string) []error { func validateBackupOptions(opts *api.BackupOptions) error { var errs []error - optsPath := []string{"backup_options"} + optsPath := validation.NewPath("backup_options") errs = append(errs, validatePgBackRestOptions(opts.BackupOptions, optsPath)...) return errors.Join(errs...) } -func validateIdentifier(ident string, path []string) error { +func validateIdentifier(ident string, path validation.Path) error { if err := utils.ValidateID(ident); err != nil { - return newValidationError(err, path) + return validation.NewError(err, path) } return nil @@ -919,20 +878,20 @@ func validateHostIDUniqueness(ctx context.Context, hostSvc *host.Service, hostID } } -func validateScripts(scripts *api.DatabaseScripts, path []string) []error { +func validateScripts(scripts *api.DatabaseScripts, path validation.Path) []error { if scripts == nil { return nil } return slices.Concat( - validateScript(scripts.PostInit, appendPath(path, "post_init")), - validateScript(scripts.PostDatabaseCreate, appendPath(path, "post_database_create")), + validateScript(scripts.PostInit, path.Append("post_init")), + validateScript(scripts.PostDatabaseCreate, path.Append("post_database_create")), ) } -func validateScript(statements []string, path []string) []error { +func validateScript(statements []string, path validation.Path) []error { var errs []error for i, statement := range statements { - statementPath := appendPath(path, arrayIndexPath(i)) + statementPath := path.AppendArrayIndex(i) if err := validateSQLStatement(statement, statementPath); err != nil { errs = append(errs, err) } @@ -940,11 +899,11 @@ func validateScript(statements []string, path []string) []error { return errs } -func validateSQLStatement(statement string, path []string) error { +func validateSQLStatement(statement string, path validation.Path) error { _, err := postgresparser.ParseSQLStrict(statement) if err != nil { err = fmt.Errorf("failed to parse SQL statement: %w", err) - return newValidationError(err, path) + return validation.NewError(err, path) } return nil } diff --git a/server/internal/api/apiv1/validate_test.go b/server/internal/api/apiv1/validate_test.go index 97f1a84e..40637a2e 100644 --- a/server/internal/api/apiv1/validate_test.go +++ b/server/internal/api/apiv1/validate_test.go @@ -12,25 +12,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidationError(t *testing.T) { - t.Run("with path", func(t *testing.T) { - err := newValidationError(errors.New("test error"), []string{ - "array", - arrayIndexPath(0), - "map", - mapKeyPath("key"), - }) - - assert.ErrorContains(t, err, "array[0].map[key]: test error") - }) - - t.Run("without path", func(t *testing.T) { - err := newValidationError(errors.New("test error"), nil) - - assert.ErrorContains(t, err, "test error") - }) -} - func TestValidateCPUs(t *testing.T) { for _, tc := range []struct { name string diff --git a/server/internal/validation/error.go b/server/internal/validation/error.go new file mode 100644 index 00000000..4f6e1c25 --- /dev/null +++ b/server/internal/validation/error.go @@ -0,0 +1,64 @@ +package validation + +import ( + "fmt" + "slices" + "strings" +) + +type Path []string + +func NewPath(elems ...string) Path { + return elems +} + +func ArrayIndexElement(idx int) string { + return fmt.Sprintf("[%d]", idx) +} + +func MapKeyElement(key string) string { + return fmt.Sprintf("[%s]", key) +} + +func (p Path) String() string { + var path strings.Builder + for i, ele := range p { + if i > 0 && !strings.HasPrefix(ele, "[") { + path.WriteString(".") + } + path.WriteString(ele) + } + return path.String() +} + +func (p Path) Append(elem ...string) Path { + return append(slices.Clone(p), elem...) +} + +func (p Path) AppendArrayIndex(idx int) Path { + return p.Append(ArrayIndexElement(idx)) +} + +func (p Path) AppendMapKey(key string) Path { + return p.Append(MapKeyElement(key)) +} + +type Error struct { + Path Path + Err error +} + +func NewError(err error, path Path) *Error { + return &Error{Err: err, Path: path} +} + +func (e *Error) Unwrap() error { + return e.Err +} + +func (e *Error) Error() string { + if len(e.Path) == 0 { + return e.Err.Error() + } + return fmt.Sprintf("%s: %s", e.Path.String(), e.Err.Error()) +} diff --git a/server/internal/validation/error_test.go b/server/internal/validation/error_test.go new file mode 100644 index 00000000..6e24d84e --- /dev/null +++ b/server/internal/validation/error_test.go @@ -0,0 +1,28 @@ +package validation_test + +import ( + "errors" + "testing" + + "github.com/pgEdge/control-plane/server/internal/validation" + "github.com/stretchr/testify/assert" +) + +func TestValidationError(t *testing.T) { + t.Run("with path", func(t *testing.T) { + err := validation.NewError(errors.New("test error"), validation.NewPath( + "array", + validation.ArrayIndexElement(0), + "map", + validation.MapKeyElement("key"), + )) + + assert.ErrorContains(t, err, "array[0].map[key]: test error") + }) + + t.Run("without path", func(t *testing.T) { + err := validation.NewError(errors.New("test error"), nil) + + assert.ErrorContains(t, err, "test error") + }) +} From 83a6d9e77d547a4e71182ac3e09f3261fd1f4255 Mon Sep 17 00:00:00 2001 From: Jason Lynch Date: Tue, 16 Jun 2026 12:58:29 -0400 Subject: [PATCH 2/3] feat: consolidate duplicate port validation We were validating duplicate ports in a few spots, each with different error message shapes, and neither was able to catch duplicated ports from nodes deployed to the same host. This commit consolidates and improves those checks so that we'll catch duplicated ports from either postgres, patroni, or a service on the same host. PLAT-611 --- server/internal/api/apiv1/validate.go | 129 +++++------ server/internal/api/apiv1/validate_test.go | 249 +++++++++------------ server/internal/ds/set.go | 16 ++ server/internal/validation/validators.go | 50 +++++ 4 files changed, 225 insertions(+), 219 deletions(-) create mode 100644 server/internal/validation/validators.go diff --git a/server/internal/api/apiv1/validate.go b/server/internal/api/apiv1/validate.go index c6e3c780..947d94bc 100644 --- a/server/internal/api/apiv1/validate.go +++ b/server/internal/api/apiv1/validate.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "path/filepath" "regexp" "slices" @@ -77,7 +78,7 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s errs = append(errs, validateCPUs(spec.Cpus, validation.NewPath("cpus"))...) errs = append(errs, validateMemory(spec.Memory, validation.NewPath("memory"))...) - errs = append(errs, validatePorts(spec.Port, spec.PatroniPort, validation.NewPath("port"))) + errs = append(errs, validateUniquePorts(spec)...) errs = append(errs, validateUsers(spec.DatabaseUsers, validation.NewPath("database_users"))...) errs = append(errs, validateScripts(spec.Scripts, validation.NewPath("scripts"))...) @@ -148,10 +149,6 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s // Validate orchestrator_opts (spec-level) errs = append(errs, validateOrchestratorOpts(spec.OrchestratorOpts, validation.NewPath("orchestrator_opts"))...) - // Validate services — seed portOwner with Postgres ports so services can't collide with the database. - portOwner := make(servicePortOwnerMap) - seedPostgresPorts(spec, portOwner) - servicesPath := validation.NewPath("services") switch orchestrator { @@ -171,7 +168,6 @@ func validateDatabaseSpec(orchestrator config.Orchestrator, databaseID string, s } seenServiceIDs.Add(string(svc.ServiceID)) - errs = append(errs, validateServicePortConflicts(svc, svcPath, portOwner)...) errs = append(errs, validateServiceSpec(svc, svcPath, false, databaseID, spec.DatabaseUsers, seenNodeNames)...) } } @@ -228,10 +224,6 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { existingServiceIDs.Add(svc.ServiceID) } - // Seed portOwner with Postgres ports so services can't collide with the database. - portOwner := make(servicePortOwnerMap) - seedPostgresPorts(new, portOwner) - // Validate each service. Pass isUpdate=false for services being added for the // first time so that bootstrap-only fields are accepted. For service types that // have no bootstrap fields (e.g. postgrest) the flag has no effect. @@ -239,7 +231,6 @@ func validateDatabaseUpdate(old *database.Spec, new *api.DatabaseSpec) error { svcPath := validation.NewPath("services", validation.ArrayIndexElement(i)) isExistingService := existingServiceIDs.Has(string(svc.ServiceID)) - errs = append(errs, validateServicePortConflicts(svc, svcPath, portOwner)...) errs = append(errs, validateServiceSpec(svc, svcPath, isExistingService, old.DatabaseID, new.DatabaseUsers, newNodeNames)...) } @@ -260,17 +251,6 @@ func validateNode( memPath := path.Append("memory") errs = append(errs, validateMemory(node.Memory, memPath)...) - port := db.Port - if node.Port != nil { - port = node.Port - } - patroniPort := db.PatroniPort - if node.PatroniPort != nil { - patroniPort = node.PatroniPort - } - portPath := path.Append("port") - errs = append(errs, validatePorts(port, patroniPort, portPath)) - seenHostIDs := make(ds.Set[string], len(node.HostIds)) for i, h := range node.HostIds { hostID := string(h) @@ -532,15 +512,58 @@ func validateMemory(value *string, path validation.Path) []error { return errs } -func validatePorts(postgresPort, patroniPort *int, path validation.Path) error { - postgres := utils.FromPointer(postgresPort) - patroni := utils.FromPointer(patroniPort) +func validateUniquePorts(spec *api.DatabaseSpec) []error { + hostPorts := map[string]*validation.Unique[int]{} + specPostgresPort := utils.FromPointer(spec.Port) + specPatroniPort := utils.FromPointer(spec.PatroniPort) + + nodesPath := validation.NewPath("nodes") + for i, node := range spec.Nodes { + postgresPort := utils.FromPointer(node.Port) + if postgresPort == 0 { + postgresPort = specPostgresPort + } + patroniPort := utils.FromPointer(node.PatroniPort) + if patroniPort == 0 { + patroniPort = specPatroniPort + } - if postgres > 0 && postgres == patroni { - return validation.NewError(errors.New("postgres and patroni ports must not conflict"), path) + nodePath := nodesPath.AppendArrayIndex(i) + for _, h := range node.HostIds { + hostID := string(h) + if _, ok := hostPorts[hostID]; !ok { + hostPorts[hostID] = validation.NewUnique[int]() + } + if postgresPort != 0 { + hostPorts[hostID].RecordSeen(nodePath.Append("port"), postgresPort) + } + if patroniPort != 0 { + hostPorts[hostID].RecordSeen(nodePath.Append("patroni_port"), patroniPort) + } + } } - return nil + servicesPath := validation.NewPath("services") + for i, service := range spec.Services { + servicePath := servicesPath.AppendArrayIndex(i) + + for _, h := range service.HostIds { + hostID := string(h) + if _, ok := hostPorts[hostID]; !ok { + hostPorts[hostID] = validation.NewUnique[int]() + } + if port := utils.FromPointer(service.Port); port != 0 { + hostPorts[hostID].RecordSeen(servicePath.Append("port"), port) + } + } + } + + var errs []error + for _, hostID := range slices.Sorted(maps.Keys(hostPorts)) { + errs = append(errs, hostPorts[hostID].Validate(fmt.Errorf("duplicate ports allocated on host '%s'", hostID))...) + } + + return errs } func validateUsers(users []*api.DatabaseUserSpec, path validation.Path) []error { @@ -570,56 +593,6 @@ func validateUsers(users []*api.DatabaseUserSpec, path validation.Path) []error return errs } -// seedPostgresPorts registers each node's effective Postgres port in the -// portOwner map so that service port validation can detect collisions with -// the database. A node-level port override (node.Port) takes precedence -// over the spec-level default (spec.Port). -func seedPostgresPorts(spec *api.DatabaseSpec, owner servicePortOwnerMap) { - for _, node := range spec.Nodes { - pgPort := utils.FromPointer(spec.Port) - if node.Port != nil { - pgPort = *node.Port - } - if pgPort > 0 { - for _, hostID := range node.HostIds { - owner[hostPort{hostID: string(hostID), port: pgPort}] = "postgres" - } - } - } -} - -// hostPort identifies a unique (host, port) binding for cross-service -// port conflict detection. -type hostPort struct { - hostID string - port int -} - -// servicePortOwnerMap tracks which service owns a given (host, port) pair. -// Callers create one map and pass it to validateServicePortConflicts for -// each service in the spec. -type servicePortOwnerMap map[hostPort]string - -// validateServicePortConflicts checks that the service's explicit port (if any) -// does not collide with a port already claimed by another service on the same host. -func validateServicePortConflicts(svc *api.ServiceSpec, path validation.Path, owner servicePortOwnerMap) []error { - if svc.Port == nil || *svc.Port <= 0 { - return nil - } - - var errs []error - for _, hostID := range svc.HostIds { - key := hostPort{hostID: string(hostID), port: *svc.Port} - if prev, exists := owner[key]; exists { - err := fmt.Errorf("port %d conflicts with service %q on the same host", *svc.Port, prev) - errs = append(errs, validation.NewError(err, path.Append("port"))) - } else { - owner[key] = string(svc.ServiceID) - } - } - return errs -} - func validateBackupConfig(cfg *api.BackupConfigSpec, path validation.Path) []error { var errs []error diff --git a/server/internal/api/apiv1/validate_test.go b/server/internal/api/apiv1/validate_test.go index 40637a2e..b43aa701 100644 --- a/server/internal/api/apiv1/validate_test.go +++ b/server/internal/api/apiv1/validate_test.go @@ -79,61 +79,124 @@ func TestValidateMemory(t *testing.T) { } } -func TestValidatePorts(t *testing.T) { +func TestValidateUniquePorts(t *testing.T) { for _, tc := range []struct { - name string - postgresPort *int - patroniPort *int - expected string + name string + spec *api.DatabaseSpec + expected []string }{ { - name: "both nil", - postgresPort: nil, - patroniPort: nil, - }, - { - name: "postgres port nil", - postgresPort: nil, - patroniPort: utils.PointerTo(8888), - }, - { - name: "patroni port nil", - postgresPort: utils.PointerTo(8888), - patroniPort: nil, - }, - { - name: "both zero", - postgresPort: utils.PointerTo(0), - patroniPort: utils.PointerTo(0), + name: "no ports", + spec: &api.DatabaseSpec{ + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, }, { - name: "postgres port zero", - postgresPort: nil, - patroniPort: utils.PointerTo(0), + name: "patroni and postgres port conflict", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(5432), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + }, }, { - name: "patroni port zero", - postgresPort: utils.PointerTo(0), - patroniPort: nil, + name: "patroni and postgres port conflict with override", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(8888), + PatroniPort: utils.PointerTo(8888), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(5432), + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + }, }, { - name: "both defined non-equal", - postgresPort: utils.PointerTo(5432), - patroniPort: utils.PointerTo(8888), + name: "service port conflict", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + Services: []*api.ServiceSpec{ + { + ServiceID: "mcp", + ServiceType: "mcp", + Port: utils.PointerTo(5432), + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, services[0].port`, + }, }, { - name: "conflicting", - postgresPort: utils.PointerTo(5432), - patroniPort: utils.PointerTo(5432), - expected: "postgres and patroni ports must not conflict", + name: "two nodes on same host", + spec: &api.DatabaseSpec{ + Port: utils.PointerTo(5432), + PatroniPort: utils.PointerTo(8888), + Nodes: []*api.DatabaseNodeSpec{ + { + Name: "n1", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + { + Name: "n2", + HostIds: []api.Identifier{ + api.Identifier("host-1"), + }, + }, + }, + }, + expected: []string{ + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, nodes[1].port`, + `duplicate ports allocated on host 'host-1': '8888' duplicated in: nodes[0].patroni_port, nodes[1].patroni_port`, + }, }, } { t.Run(tc.name, func(t *testing.T) { - err := validatePorts(tc.postgresPort, tc.patroniPort, nil) - if tc.expected == "" { + err := errors.Join(validateUniquePorts(tc.spec)...) + if len(tc.expected) < 1 { assert.NoError(t, err) } else { - assert.ErrorContains(t, err, tc.expected) + for _, expected := range tc.expected { + assert.ErrorContains(t, err, expected) + } } }) } @@ -474,71 +537,6 @@ func TestValidateNode(t *testing.T) { "patroni_port: patroni_port must be defined", }, }, - { - name: "invalid inherited ports", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(5432), - }, - node: &api.DatabaseNodeSpec{ - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid inherited db port", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(8888), - }, - node: &api.DatabaseNodeSpec{ - PatroniPort: utils.PointerTo(5432), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid inherited patroni port", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(8888), - }, - node: &api.DatabaseNodeSpec{ - Port: utils.PointerTo(8888), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, - { - name: "invalid node ports", - orchestrator: config.OrchestratorSystemD, - db: &api.DatabaseSpec{}, - node: &api.DatabaseNodeSpec{ - Port: utils.PointerTo(5432), - PatroniPort: utils.PointerTo(5432), - HostIds: []api.Identifier{ - api.Identifier("host-1"), - }, - }, - expected: []string{ - "port: postgres and patroni ports must not conflict", - }, - }, { name: "invalid", orchestrator: config.OrchestratorSwarm, @@ -783,7 +781,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `services[1].port: port 8080 conflicts with service "mcp-server" on the same host`, + `duplicate ports allocated on host 'host-1': '8080' duplicated in: services[0].port, services[1].port`, }, }, { @@ -972,7 +970,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `port 5432 conflicts with service "postgres" on the same host`, + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].port, services[0].port`, }, }, { @@ -1008,7 +1006,7 @@ func TestValidateDatabaseSpec(t *testing.T) { }, }, expected: []string{ - `port 5433 conflicts with service "postgres" on the same host`, + `duplicate ports allocated on host 'host-1': '5433' duplicated in: nodes[0].port, services[0].port`, }, }, { @@ -1158,8 +1156,9 @@ func TestValidateDatabaseSpec(t *testing.T) { "nodes[2]: node names must be unique within a database", "backup_config.repositories[0].base_path: base_path must be absolute for posix repositories", "restore_config.repository.base_path: base_path must be absolute for posix repositories", - "port: postgres and patroni ports must not conflict", - "nodes[1].port: postgres and patroni ports must not conflict", + `duplicate ports allocated on host 'host-1': '5432' duplicated in: nodes[0].patroni_port, nodes[0].port`, + `duplicate ports allocated on host 'host-2': '8888' duplicated in: nodes[1].patroni_port, nodes[1].port`, + `duplicate ports allocated on host 'host-3': '5432' duplicated in: nodes[2].patroni_port, nodes[2].port`, }, }, { @@ -2323,38 +2322,6 @@ func TestValidateDatabaseUpdate_ServiceBootstrapFields(t *testing.T) { }, }, }, - { - name: "port conflict on update-database", - old: &database.Spec{}, - new: &api.DatabaseSpec{ - DatabaseUsers: []*api.DatabaseUserSpec{ - {Username: "app", DbOwner: utils.PointerTo(true)}, - }, - Services: []*api.ServiceSpec{ - { - ServiceID: "mcp-server", - ServiceType: "mcp", - Version: "latest", - HostIds: []api.Identifier{"host-1"}, - ConnectAs: "app", - Port: utils.PointerTo(8080), - Config: validMCPConfig, - }, - { - ServiceID: "postgrest-server", - ServiceType: "postgrest", - Version: "latest", - HostIds: []api.Identifier{"host-1"}, - ConnectAs: "app", - Port: utils.PointerTo(8080), - Config: map[string]any{}, - }, - }, - }, - expected: []string{ - `port 8080 conflicts with service "mcp-server" on the same host`, - }, - }, } { t.Run(tc.name, func(t *testing.T) { err := validateDatabaseUpdate(tc.old, tc.new) diff --git a/server/internal/ds/set.go b/server/internal/ds/set.go index ef829c77..004a77a5 100644 --- a/server/internal/ds/set.go +++ b/server/internal/ds/set.go @@ -1,7 +1,9 @@ package ds import ( + "cmp" "slices" + "strings" ) // Set is a generic set type. @@ -131,3 +133,17 @@ func SetDifference[T comparable](a, b []T) Set[T] { func SetSymmetricDifference[T comparable](a, b []T) Set[T] { return NewSet(a...).SymmetricDifference(NewSet(b...)) } + +// SetToString is a shortcut for producing a sorted, comma-separated string +// representation of a Set of string-ish values. +func SetToString[T ~string](s Set[T]) string { + lastIdx := s.Size() - 1 + var builder strings.Builder + for i, element := range s.ToSortedSlice(cmp.Compare) { + builder.WriteString(string(element)) + if i < lastIdx { + builder.WriteString(", ") + } + } + return builder.String() +} diff --git a/server/internal/validation/validators.go b/server/internal/validation/validators.go new file mode 100644 index 00000000..7603b012 --- /dev/null +++ b/server/internal/validation/validators.go @@ -0,0 +1,50 @@ +package validation + +import ( + "cmp" + "errors" + "fmt" + "maps" + "slices" + + "github.com/pgEdge/control-plane/server/internal/ds" +) + +var ErrUnique = errors.New("must be unique") + +type Unique[T cmp.Ordered] struct { + seen map[T]ds.Set[string] +} + +func NewUnique[T cmp.Ordered]() *Unique[T] { + return &Unique[T]{ + seen: map[T]ds.Set[string]{}, + } +} + +func (u *Unique[T]) RecordSeen(path Path, value T) { + if u.seen == nil { + u.seen = make(map[T]ds.Set[string]) + } + if _, ok := u.seen[value]; !ok { + u.seen[value] = ds.NewSet[string]() + } + u.seen[value].Add(path.String()) +} + +func (u *Unique[T]) Validate(base error) []error { + if base == nil { + base = ErrUnique + } + var errs []error + for _, key := range slices.Sorted(maps.Keys(u.seen)) { + paths := u.seen[key] + if len(paths) <= 1 { + continue + } + errs = append(errs, &Error{ + Err: fmt.Errorf("%w: '%v' duplicated in: %s", base, key, ds.SetToString(paths)), + }) + } + return errs +} From 2a90ed15bd38e7eeb505be83df1689f40740e0f0 Mon Sep 17 00:00:00 2001 From: Jason Lynch Date: Tue, 16 Jun 2026 15:02:16 -0400 Subject: [PATCH 3/3] docs: add changelog entry for unique port validation PLAT-611 --- changes/unreleased/Fixed-20260616-150152.yaml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changes/unreleased/Fixed-20260616-150152.yaml diff --git a/changes/unreleased/Fixed-20260616-150152.yaml b/changes/unreleased/Fixed-20260616-150152.yaml new file mode 100644 index 00000000..23098c1e --- /dev/null +++ b/changes/unreleased/Fixed-20260616-150152.yaml @@ -0,0 +1,3 @@ +kind: Fixed +body: Database specs with duplicate allocated ports on a single host are now rejected by the API. +time: 2026-06-16T15:01:52.449574-04:00