diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 844dee16021..ee2e5bc8e3f 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -3,6 +3,7 @@ package tunnel import ( "bufio" "context" + "encoding/json" "fmt" "net/url" "os" @@ -32,6 +33,7 @@ import ( "github.com/cloudflare/cloudflared/diagnostic" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/k8s" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/metrics" @@ -174,6 +176,7 @@ func Commands() []*cli.Command { buildCleanupCommand(), buildTokenCommand(), buildDiagCommand(), + buildKubernetesSubcommand(), proxydns.Command(), // removed feature, only here for error message cliutil.RemovedCommand("db-connect"), } @@ -445,6 +448,45 @@ func StartServer( return err } + // Start Kubernetes service watcher if enabled + cfg := config.GetConfiguration() + if cfg.Kubernetes.Enabled { + k8sCfg := &k8s.Config{ + Enabled: cfg.Kubernetes.Enabled, + Namespace: cfg.Kubernetes.Namespace, + BaseDomain: cfg.Kubernetes.BaseDomain, + KubeconfigPath: cfg.Kubernetes.KubeconfigPath, + ExposeAPIServer: cfg.Kubernetes.ExposeAPIServer, + APIServerHostname: cfg.Kubernetes.APIServerHostname, + LabelSelector: cfg.Kubernetes.LabelSelector, + } + if err := k8sCfg.Validate(); err != nil { + log.Warn().Err(err).Msg("Kubernetes config validation failed, watcher will not start") + } else { + k8sWatcher := k8s.NewWatcher(k8sCfg, log, func(services []k8s.ServiceInfo) { + log.Info().Int("count", len(services)).Msg("Kubernetes service change detected, updating ingress rules") + k8sRules := k8s.GenerateIngressRules(services, log) + updatedIngress := k8s.MergeWithExistingRules(cfg.Ingress, k8sRules) + newConfigBytes, err := json.Marshal(ingress.RemoteConfigJSON{ + IngressRules: updatedIngress, + WarpRouting: cfg.WarpRouting, + }) + if err != nil { 
+ log.Err(err).Msg("Failed to marshal updated K8s ingress config") + return + } + resp := orchestrator.UpdateK8sConfig(newConfigBytes) + if resp.Err != nil { + log.Err(resp.Err).Msg("Failed to apply K8s ingress config update") + } else { + log.Info().Int("services", len(services)).Msg("Successfully applied K8s ingress config update") + } + }) + go k8sWatcher.Run(ctx) + log.Info().Msg("Kubernetes service watcher started") + } + } + metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics")) if err != nil { log.Err(err).Msg("Error opening metrics server listener") diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 9356673e9b8..f44d5c42ace 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -27,6 +27,7 @@ import ( "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress/origins" + "github.com/cloudflare/cloudflared/k8s" "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" @@ -151,6 +152,36 @@ func prepareTunnelConfig( } cfg := config.GetConfiguration() + + // If Kubernetes integration is enabled, discover services and merge with config + if cfg.Kubernetes.Enabled { + k8sCfg := &k8s.Config{ + Enabled: cfg.Kubernetes.Enabled, + Namespace: cfg.Kubernetes.Namespace, + BaseDomain: cfg.Kubernetes.BaseDomain, + KubeconfigPath: cfg.Kubernetes.KubeconfigPath, + ExposeAPIServer: cfg.Kubernetes.ExposeAPIServer, + APIServerHostname: cfg.Kubernetes.APIServerHostname, + LabelSelector: cfg.Kubernetes.LabelSelector, + } + if err := k8sCfg.Validate(); err != nil { + log.Warn().Err(err).Msg("Kubernetes integration config validation failed, skipping K8s discovery") + } else { + // Use a timeout so K8s discovery doesn't block tunnel startup indefinitely. 
+ k8sCtx, k8sCancel := context.WithTimeout(ctx, 30*time.Second) + services, err := k8s.DiscoverServices(k8sCtx, k8sCfg, log) + k8sCancel() + if err != nil { + log.Warn().Err(err).Msg("Failed to discover Kubernetes services, continuing without K8s rules") + } else if len(services) > 0 { + k8sRules := k8s.GenerateIngressRules(services, log) + cfg.Ingress = k8s.MergeWithExistingRules(cfg.Ingress, k8sRules) + log.Info().Int("k8sServices", len(services)).Int("totalRules", len(cfg.Ingress)). + Msg("Merged Kubernetes-discovered services into ingress rules") + } + } + } + ingressRules, err := ingress.ParseIngressFromConfigAndCLI(cfg, c, log) if err != nil { return nil, nil, err diff --git a/cmd/cloudflared/tunnel/k8s_subcommands.go b/cmd/cloudflared/tunnel/k8s_subcommands.go new file mode 100644 index 00000000000..951ca3b0962 --- /dev/null +++ b/cmd/cloudflared/tunnel/k8s_subcommands.go @@ -0,0 +1,295 @@ +package tunnel + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/rs/zerolog" + "github.com/urfave/cli/v2" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/k8s" + "github.com/cloudflare/cloudflared/logger" +) + +const ( + k8sBaseDomainFlag = "k8s-base-domain" + k8sNamespaceFlag = "k8s-namespace" + k8sKubeconfigFlag = "k8s-kubeconfig" + k8sExposeAPIServerFlag = "k8s-expose-api-server" + k8sAPIServerHostnameFlag = "k8s-api-server-hostname" + k8sLabelSelectorFlag = "k8s-label-selector" + k8sOutputFormatFlag = "k8s-output" +) + +func buildKubernetesSubcommand() *cli.Command { + return &cli.Command{ + Name: "kubernetes", + Aliases: []string{"k8s"}, + Category: "Tunnel", + Usage: "Discover and manage Kubernetes services exposed through Cloudflare Tunnel", + Description: ` The kubernetes subcommand provides native integration between cloudflared and + Kubernetes clusters. It can automatically discover annotated Kubernetes + services and generate ingress rules for them. 
+ + To mark a service for exposure through the tunnel, add the annotation: + cloudflared.cloudflare.com/tunnel: "true" + + Optional annotations: + cloudflared.cloudflare.com/hostname: Override the public hostname + cloudflared.cloudflare.com/path: Path regex for the ingress rule + cloudflared.cloudflare.com/scheme: Origin scheme (http/https) + cloudflared.cloudflare.com/port: Select which port to proxy + cloudflared.cloudflare.com/no-tls-verify: Disable TLS verification + cloudflared.cloudflare.com/origin-server-name: Set SNI for TLS + + Example: + # Discover services from the current cluster + cloudflared tunnel kubernetes discover --k8s-base-domain example.com + + # Watch for changes continuously + cloudflared tunnel kubernetes watch --k8s-base-domain example.com + + # Generate an ingress config YAML snippet + cloudflared tunnel kubernetes generate-config --k8s-base-domain example.com`, + Subcommands: []*cli.Command{ + buildK8sDiscoverCommand(), + buildK8sWatchCommand(), + buildK8sGenerateConfigCommand(), + }, + } +} + +func k8sFlags() []cli.Flag { + return []cli.Flag{ + &cli.StringFlag{ + Name: k8sBaseDomainFlag, + Usage: "Base domain for auto-generated hostnames (e.g. example.com). Services will be exposed as -.example.com", + EnvVars: []string{"TUNNEL_K8S_BASE_DOMAIN"}, + }, + &cli.StringFlag{ + Name: k8sNamespaceFlag, + Usage: "Limit discovery to a specific Kubernetes namespace. Empty means all namespaces.", + EnvVars: []string{"TUNNEL_K8S_NAMESPACE"}, + }, + &cli.StringFlag{ + Name: k8sKubeconfigFlag, + Usage: "Path to a kubeconfig file. 
When empty, in-cluster config is used.", + EnvVars: []string{"KUBECONFIG"}, + }, + &cli.BoolFlag{ + Name: k8sExposeAPIServerFlag, + Usage: "Also expose the Kubernetes API server through the tunnel", + EnvVars: []string{"TUNNEL_K8S_EXPOSE_API_SERVER"}, + }, + &cli.StringFlag{ + Name: k8sAPIServerHostnameFlag, + Usage: "Public hostname for the Kubernetes API server (required when --k8s-expose-api-server is set)", + EnvVars: []string{"TUNNEL_K8S_API_SERVER_HOSTNAME"}, + }, + &cli.StringFlag{ + Name: k8sLabelSelectorFlag, + Usage: "Kubernetes label selector to filter services (e.g. app=web)", + EnvVars: []string{"TUNNEL_K8S_LABEL_SELECTOR"}, + }, + &cli.StringFlag{ + Name: k8sOutputFormatFlag, + Usage: "Output format: json, yaml, or table (default: table)", + Value: "table", + }, + } +} + +func k8sConfigFromCLI(c *cli.Context) *k8s.Config { + return &k8s.Config{ + Enabled: true, + BaseDomain: c.String(k8sBaseDomainFlag), + Namespace: c.String(k8sNamespaceFlag), + KubeconfigPath: c.String(k8sKubeconfigFlag), + ExposeAPIServer: c.Bool(k8sExposeAPIServerFlag), + APIServerHostname: c.String(k8sAPIServerHostnameFlag), + LabelSelector: c.String(k8sLabelSelectorFlag), + } +} + +// ----------------------------------------------------------------------- +// discover subcommand +// ----------------------------------------------------------------------- + +func buildK8sDiscoverCommand() *cli.Command { + return &cli.Command{ + Name: "discover", + Usage: "Discover annotated Kubernetes services", + Flags: k8sFlags(), + Action: cliutil.ConfiguredAction(k8sDiscoverAction), + } +} + +func k8sDiscoverAction(c *cli.Context) error { + log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) + cfg := k8sConfigFromCLI(c) + if err := cfg.Validate(); err != nil { + return err + } + + ctx, cancel := context.WithCancel(c.Context) + defer cancel() + + services, err := k8s.DiscoverServices(ctx, cfg, log) + if err != nil { + return err + } + + return printServices(services, 
c.String(k8sOutputFormatFlag), log) +} + +// ----------------------------------------------------------------------- +// watch subcommand +// ----------------------------------------------------------------------- + +func buildK8sWatchCommand() *cli.Command { + return &cli.Command{ + Name: "watch", + Usage: "Continuously watch for Kubernetes service changes", + Flags: k8sFlags(), + Action: cliutil.ConfiguredAction(k8sWatchAction), + } +} + +func k8sWatchAction(c *cli.Context) error { + log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) + cfg := k8sConfigFromCLI(c) + if err := cfg.Validate(); err != nil { + return err + } + + ctx, cancel := context.WithCancel(c.Context) + defer cancel() + + outputFormat := c.String(k8sOutputFormatFlag) + + watcher := k8s.NewWatcher(cfg, log, func(services []k8s.ServiceInfo) { + log.Info().Int("count", len(services)).Msg("Service change detected") + if err := printServices(services, outputFormat, log); err != nil { + log.Err(err).Msg("Failed to print services") + } + }) + + // Handle OS signals for graceful shutdown + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigC + log.Info().Msg("Received shutdown signal, stopping watcher...") + cancel() + }() + + watcher.Run(ctx) + return nil +} + +// ----------------------------------------------------------------------- +// generate-config subcommand +// ----------------------------------------------------------------------- + +func buildK8sGenerateConfigCommand() *cli.Command { + return &cli.Command{ + Name: "generate-config", + Usage: "Generate cloudflared ingress configuration from discovered Kubernetes services", + Flags: k8sFlags(), + Action: cliutil.ConfiguredAction(k8sGenerateConfigAction), + } +} + +func k8sGenerateConfigAction(c *cli.Context) error { + log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) + cfg := k8sConfigFromCLI(c) + if err := cfg.Validate(); err != nil { + return err + } + + 
ctx, cancel := context.WithCancel(c.Context) + defer cancel() + + services, err := k8s.DiscoverServices(ctx, cfg, log) + if err != nil { + return err + } + + if len(services) == 0 { + log.Warn().Msg("No annotated Kubernetes services found") + return nil + } + + rules := k8s.GenerateIngressRules(services, log) + + // Output as YAML config snippet + fmt.Println("# Auto-generated cloudflared ingress configuration from Kubernetes services") + fmt.Println("# Add the following to your cloudflared config.yml under the 'ingress' key:") + fmt.Println("ingress:") + for _, r := range rules { + if r.Hostname != "" { + fmt.Printf(" - hostname: %s\n", r.Hostname) + } else { + fmt.Println(" - hostname: \"*\"") + } + if r.Path != "" { + fmt.Printf(" path: %s\n", r.Path) + } + fmt.Printf(" service: %s\n", r.Service) + + hasNoTLS := r.OriginRequest.NoTLSVerify != nil && *r.OriginRequest.NoTLSVerify + hasSNI := r.OriginRequest.OriginServerName != nil && *r.OriginRequest.OriginServerName != "" + if hasNoTLS || hasSNI { + fmt.Println(" originRequest:") + if hasNoTLS { + fmt.Println(" noTLSVerify: true") + } + if hasSNI { + fmt.Printf(" originServerName: %s\n", *r.OriginRequest.OriginServerName) + } + } + } + // Add catch-all + fmt.Println(" - service: http_status:404") + + return nil +} + +// ----------------------------------------------------------------------- +// Output helpers +// ----------------------------------------------------------------------- + +func printServices(services []k8s.ServiceInfo, format string, log *zerolog.Logger) error { + if len(services) == 0 { + log.Info().Msg("No annotated Kubernetes services found") + return nil + } + + switch format { + case "json": + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + return enc.Encode(services) + case "yaml": + for _, s := range services { + fmt.Printf("- name: %s\n namespace: %s\n hostname: %s\n origin: %s\n", + s.Name, s.Namespace, s.Hostname, s.OriginURL()) + if s.Path != "" { + fmt.Printf(" path: %s\n", 
s.Path) + } + } + return nil + default: // table + fmt.Printf("%-30s %-15s %-40s %-35s %s\n", "SERVICE", "NAMESPACE", "HOSTNAME", "ORIGIN", "PATH") + fmt.Printf("%-30s %-15s %-40s %-35s %s\n", "-------", "---------", "--------", "------", "----") + for _, s := range services { + fmt.Printf("%-30s %-15s %-40s %-35s %s\n", s.Name, s.Namespace, s.Hostname, s.OriginURL(), s.Path) + } + return nil + } +} diff --git a/config/configuration.go b/config/configuration.go index cb0b0adeda5..0a926c41649 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -257,9 +257,30 @@ type Configuration struct { Ingress []UnvalidatedIngressRule WarpRouting WarpRoutingConfig `yaml:"warp-routing"` OriginRequest OriginRequestConfig `yaml:"originRequest"` + Kubernetes KubernetesConfig `yaml:"kubernetes"` sourceFile string } +// KubernetesConfig holds the configuration for the Kubernetes service discovery +// integration. When enabled, cloudflared will discover annotated Kubernetes +// services and automatically generate ingress rules for them. +type KubernetesConfig struct { + // Enabled turns the Kubernetes watcher on. + Enabled bool `yaml:"enabled" json:"enabled"` + // Namespace limits discovery to a single namespace. Empty means all namespaces. + Namespace string `yaml:"namespace,omitempty" json:"namespace,omitempty"` + // BaseDomain is the base domain appended when generating hostnames. + BaseDomain string `yaml:"baseDomain,omitempty" json:"baseDomain,omitempty"` + // KubeconfigPath is an optional path to a kubeconfig file. + KubeconfigPath string `yaml:"kubeconfigPath,omitempty" json:"kubeconfigPath,omitempty"` + // ExposeAPIServer, when true, creates an ingress rule for the K8s API server. + ExposeAPIServer bool `yaml:"exposeAPIServer,omitempty" json:"exposeAPIServer,omitempty"` + // APIServerHostname is the public hostname for the K8s API server. 
+ APIServerHostname string `yaml:"apiServerHostname,omitempty" json:"apiServerHostname,omitempty"` + // LabelSelector is an optional Kubernetes label selector. + LabelSelector string `yaml:"labelSelector,omitempty" json:"labelSelector,omitempty"` +} + type WarpRoutingConfig struct { ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"` MaxActiveFlows *uint64 `yaml:"maxActiveFlows" json:"maxActiveFlows,omitempty"` diff --git a/k8s/config.go b/k8s/config.go new file mode 100644 index 00000000000..fe4c97e5aeb --- /dev/null +++ b/k8s/config.go @@ -0,0 +1,94 @@ +// Package k8s provides Kubernetes service discovery and automatic ingress rule +// generation for Cloudflare Tunnel. When running inside (or with access to) a +// Kubernetes cluster, this package can watch for annotated Services and +// automatically expose them through the tunnel without manual ingress +// configuration. +package k8s + +import ( + "fmt" + "time" +) + +const ( + // AnnotationEnabled is the annotation key that must be set to "true" on a + // Kubernetes Service for it to be discovered and exposed through the tunnel. + AnnotationEnabled = "cloudflared.cloudflare.com/tunnel" + + // AnnotationHostname optionally overrides the hostname that will be used + // in the generated ingress rule. If not set, a hostname is synthesised from + // the service name, namespace, and the configured base domain. + AnnotationHostname = "cloudflared.cloudflare.com/hostname" + + // AnnotationPath optionally specifies a path regex for the ingress rule. + AnnotationPath = "cloudflared.cloudflare.com/path" + + // AnnotationScheme overrides the scheme used to reach the origin. + // Defaults to "http" for non-TLS ports and "https" for port 443. + AnnotationScheme = "cloudflared.cloudflare.com/scheme" + + // AnnotationPort overrides which service port to route traffic to when + // the service exposes multiple ports. If unset the first port is used. 
+ AnnotationPort = "cloudflared.cloudflare.com/port" + + // AnnotationNoTLSVerify disables TLS verification for the origin. + AnnotationNoTLSVerify = "cloudflared.cloudflare.com/no-tls-verify" + + // AnnotationOriginServerName sets the SNI for TLS connections to the origin. + AnnotationOriginServerName = "cloudflared.cloudflare.com/origin-server-name" + + // DefaultResyncPeriod is how often the informer re-lists all Services even + // if no watch events have been received. + DefaultResyncPeriod = 30 * time.Second +) + +// Config holds the user-facing configuration for the Kubernetes integration. +type Config struct { + // Enabled turns the Kubernetes watcher on. + Enabled bool `yaml:"enabled" json:"enabled"` + + // Namespace limits discovery to a single namespace. Empty means all namespaces. + Namespace string `yaml:"namespace,omitempty" json:"namespace,omitempty"` + + // BaseDomain is the base domain appended when generating hostnames, e.g. + // "example.com" results in "-.example.com". + BaseDomain string `yaml:"baseDomain,omitempty" json:"baseDomain,omitempty"` + + // KubeconfigPath is an optional path to a kubeconfig file. When empty the + // in-cluster config is used. + KubeconfigPath string `yaml:"kubeconfigPath,omitempty" json:"kubeconfigPath,omitempty"` + + // ExposeAPIServer, when true, creates an ingress rule for the Kubernetes + // API server (typically at https://kubernetes.default.svc). + ExposeAPIServer bool `yaml:"exposeAPIServer,omitempty" json:"exposeAPIServer,omitempty"` + + // APIServerHostname is the public hostname through which the K8s API server + // will be reachable. Required when ExposeAPIServer is true. + APIServerHostname string `yaml:"apiServerHostname,omitempty" json:"apiServerHostname,omitempty"` + + // LabelSelector is an optional Kubernetes label selector (e.g. "app=web") + // to filter which services to consider. Works in addition to the annotation + // check. 
+ LabelSelector string `yaml:"labelSelector,omitempty" json:"labelSelector,omitempty"` + + // ResyncPeriod controls how often the full service list is re-synchronised. + // Defaults to DefaultResyncPeriod. + ResyncPeriod time.Duration `yaml:"resyncPeriod,omitempty" json:"resyncPeriod,omitempty"` +} + +// Validate checks that the configuration is internally consistent. +func (c *Config) Validate() error { + if !c.Enabled { + return nil + } + if c.BaseDomain == "" { + return fmt.Errorf("kubernetes.baseDomain is required when kubernetes integration is enabled") + } + if c.ExposeAPIServer && c.APIServerHostname == "" { + return fmt.Errorf("kubernetes.apiServerHostname is required when exposeAPIServer is true") + } + if c.ResyncPeriod == 0 { + c.ResyncPeriod = DefaultResyncPeriod + } + return nil +} diff --git a/k8s/discovery.go b/k8s/discovery.go new file mode 100644 index 00000000000..8889fd323ba --- /dev/null +++ b/k8s/discovery.go @@ -0,0 +1,440 @@ +package k8s + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/rs/zerolog" +) + +// ServiceInfo represents a discovered Kubernetes Service with enough +// information to build an ingress rule. +type ServiceInfo struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + // ClusterIP is the internal IP of the service. + ClusterIP string `json:"clusterIP"` + // Port is the port selected for proxying. + Port int32 `json:"port"` + // PortName is the name of the selected port (if any). + PortName string `json:"portName,omitempty"` + // Scheme is http or https. + Scheme string `json:"scheme"` + // Hostname is the fully-qualified public hostname. + Hostname string `json:"hostname"` + // Path is an optional path regex from the annotation. + Path string `json:"path,omitempty"` + // NoTLSVerify disables TLS certificate verification for the origin. 
+ NoTLSVerify bool `json:"noTLSVerify,omitempty"` + // OriginServerName is the SNI server name for TLS. + OriginServerName string `json:"originServerName,omitempty"` +} + +// OriginURL returns the URL that cloudflared should proxy traffic to. +func (s *ServiceInfo) OriginURL() string { + return fmt.Sprintf("%s://%s:%d", s.Scheme, s.ClusterIP, s.Port) +} + +// ----------------------------------------------------------------------- +// Lightweight Kubernetes client — no dependency on client-go +// ----------------------------------------------------------------------- + +// kubeClient is a minimal Kubernetes REST client that can list and watch +// Service resources. +type kubeClient struct { + baseURL string + httpClient *http.Client + token string + log *zerolog.Logger +} + +// newInClusterClient builds a kubeClient from the standard in-cluster service +// account files. +func newInClusterClient(log *zerolog.Logger) (*kubeClient, error) { + const ( + tokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" //nolint:gosec // Not a credential, this is a well-known file path + caPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + nsPath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + serviceEnv = "KUBERNETES_SERVICE_HOST" + portEnv = "KUBERNETES_SERVICE_PORT" + ) + + host := os.Getenv(serviceEnv) + port := os.Getenv(portEnv) + if host == "" || port == "" { + return nil, fmt.Errorf("not running inside a Kubernetes cluster (KUBERNETES_SERVICE_HOST/PORT not set)") + } + + tokenBytes, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("cannot read service account token: %w", err) + } + + // Load the cluster CA certificate for TLS verification against the API server. 
+ httpClient := &http.Client{Timeout: 30 * time.Second} + caCert, err := os.ReadFile(caPath) + if err != nil { + // If CA cert is not available, fall back to default system trust store + // but log a warning — TLS may fail for self-signed API server certs. + if log != nil { + log.Warn().Err(err).Msg("Could not load in-cluster CA cert, falling back to system trust store") + } + } else { + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + httpClient.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + }, + } + } + + return &kubeClient{ + baseURL: fmt.Sprintf("https://%s:%s", host, port), + httpClient: httpClient, + token: strings.TrimSpace(string(tokenBytes)), + log: log, + }, nil +} + +// newKubeconfigClient builds a kubeClient from a kubeconfig-style file. +// This is a simplified parser that reads the first cluster/user. +func newKubeconfigClient(kubeconfigPath string, log *zerolog.Logger) (*kubeClient, error) { + data, err := os.ReadFile(kubeconfigPath) + if err != nil { + return nil, fmt.Errorf("cannot read kubeconfig %s: %w", kubeconfigPath, err) + } + + kc, err := parseKubeconfig(data) + if err != nil { + return nil, err + } + + return &kubeClient{ + baseURL: kc.server, + httpClient: &http.Client{Timeout: 30 * time.Second}, + token: kc.token, + log: log, + }, nil +} + +// kubeconfigInfo holds the minimal info parsed from a kubeconfig. +type kubeconfigInfo struct { + server string + token string +} + +// parseKubeconfig is a very simple YAML→JSON-style parser for kubeconfig. +// It reads the current-context and extracts the server URL and bearer token. +// For a production implementation you would use "k8s.io/client-go/tools/clientcmd". +func parseKubeconfig(data []byte) (*kubeconfigInfo, error) { + // Attempt a simple JSON parse (kubeconfig can be JSON or YAML). + // For YAML we do a basic line-scan fallback. 
+ type namedCluster struct { + Name string `json:"name"` + Cluster struct { + Server string `json:"server"` + } `json:"cluster"` + } + type namedUser struct { + Name string `json:"name"` + User struct { + Token string `json:"token"` + } `json:"user"` + } + type namedContext struct { + Name string `json:"name"` + Context struct { + Cluster string `json:"cluster"` + User string `json:"user"` + } `json:"context"` + } + type kubeConfig struct { + CurrentContext string `json:"current-context"` + Clusters []namedCluster `json:"clusters"` + Users []namedUser `json:"users"` + Contexts []namedContext `json:"contexts"` + } + + var kc kubeConfig + if err := json.Unmarshal(data, &kc); err != nil { + // Not valid JSON – return a generic error for now. + return nil, fmt.Errorf("failed to parse kubeconfig: %w (only JSON format is supported in this implementation)", err) + } + + // Resolve current context. + var clusterName, userName string + for _, ctx := range kc.Contexts { + if ctx.Name == kc.CurrentContext { + clusterName = ctx.Context.Cluster + userName = ctx.Context.User + break + } + } + if clusterName == "" { + return nil, fmt.Errorf("current-context %q not found in kubeconfig", kc.CurrentContext) + } + + var server, token string + for _, c := range kc.Clusters { + if c.Name == clusterName { + server = c.Cluster.Server + break + } + } + for _, u := range kc.Users { + if u.Name == userName { + token = u.User.Token + break + } + } + if server == "" { + return nil, fmt.Errorf("cluster %q server URL not found in kubeconfig", clusterName) + } + + return &kubeconfigInfo{server: server, token: token}, nil +} + +// do executes an authenticated HTTP request against the API server. 
+func (kc *kubeClient) do(ctx context.Context, method, path string) ([]byte, error) { + url := kc.baseURL + path + req, err := http.NewRequestWithContext(ctx, method, url, nil) + if err != nil { + return nil, err + } + if kc.token != "" { + req.Header.Set("Authorization", "Bearer "+kc.token) + } + req.Header.Set("Accept", "application/json") + + resp, err := kc.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("k8s API request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading k8s API response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("k8s API returned HTTP %d: %s", resp.StatusCode, string(body)) + } + return body, nil +} + +// ----------------------------------------------------------------------- +// K8s API response types (minimal) +// ----------------------------------------------------------------------- + +type serviceList struct { + Items []serviceItem `json:"items"` +} + +type serviceItem struct { + Metadata objectMeta `json:"metadata"` + Spec serviceSpec `json:"spec"` +} + +type objectMeta struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + Labels map[string]string `json:"labels"` + Annotations map[string]string `json:"annotations"` +} + +type serviceSpec struct { + ClusterIP string `json:"clusterIP"` + Ports []servicePort `json:"ports"` + Type string `json:"type"` +} + +type servicePort struct { + Name string `json:"name"` + Port int32 `json:"port"` + Protocol string `json:"protocol"` +} + +// ----------------------------------------------------------------------- +// Discovery logic +// ----------------------------------------------------------------------- + +// DiscoverServices queries the Kubernetes API for Services annotated with +// AnnotationEnabled = "true" and returns ServiceInfo descriptors. 
+func DiscoverServices(ctx context.Context, cfg *Config, log *zerolog.Logger) ([]ServiceInfo, error) { + client, err := buildClient(cfg, log) + if err != nil { + return nil, err + } + + path := "/api/v1/services" + if cfg.Namespace != "" { + path = fmt.Sprintf("/api/v1/namespaces/%s/services", cfg.Namespace) + } + if cfg.LabelSelector != "" { + path += "?labelSelector=" + url.QueryEscape(cfg.LabelSelector) + } + + body, err := client.do(ctx, http.MethodGet, path) + if err != nil { + return nil, fmt.Errorf("listing services: %w", err) + } + + var list serviceList + if err := json.Unmarshal(body, &list); err != nil { + return nil, fmt.Errorf("parsing service list: %w", err) + } + + services := make([]ServiceInfo, 0, len(list.Items)) + for _, item := range list.Items { + ann := item.Metadata.Annotations + if ann == nil { + continue + } + enabled, ok := ann[AnnotationEnabled] + if !ok || !isTrue(enabled) { + continue + } + + si, err := serviceInfoFromItem(item, cfg) + if err != nil { + log.Warn().Err(err). + Str("service", item.Metadata.Name). + Str("namespace", item.Metadata.Namespace). + Msg("Skipping service due to error") + continue + } + services = append(services, *si) + } + + // Optionally expose the API server itself. + if cfg.ExposeAPIServer && cfg.APIServerHostname != "" { + apiSvc := ServiceInfo{ + Name: "kubernetes-api", + Namespace: "default", + ClusterIP: strings.TrimPrefix(strings.TrimPrefix(client.baseURL, "https://"), "http://"), + Port: 443, + Scheme: "https", + Hostname: cfg.APIServerHostname, + NoTLSVerify: true, // API server cert may not match the public hostname + } + // If the baseURL contains host:port, split it. + if hp := strings.SplitN(apiSvc.ClusterIP, ":", 2); len(hp) == 2 { + apiSvc.ClusterIP = hp[0] + if p, err := parseInt32(hp[1]); err == nil { + apiSvc.Port = p + } + } + services = append(services, apiSvc) + } + + return services, nil +} + +// serviceInfoFromItem converts a raw Kubernetes service item into a ServiceInfo. 
+func serviceInfoFromItem(item serviceItem, cfg *Config) (*ServiceInfo, error) { + ann := item.Metadata.Annotations + spec := item.Spec + + if spec.ClusterIP == "" || spec.ClusterIP == "None" { + return nil, fmt.Errorf("service %s/%s has no ClusterIP (headless services are not supported)", + item.Metadata.Namespace, item.Metadata.Name) + } + + port, portName, err := selectPort(spec.Ports, ann[AnnotationPort]) + if err != nil { + return nil, err + } + + scheme := "http" + if v, ok := ann[AnnotationScheme]; ok { + scheme = v + } else if port == 443 { + scheme = "https" + } + + hostname := ann[AnnotationHostname] + if hostname == "" { + hostname = fmt.Sprintf("%s-%s.%s", item.Metadata.Name, item.Metadata.Namespace, cfg.BaseDomain) + } + + si := &ServiceInfo{ + Name: item.Metadata.Name, + Namespace: item.Metadata.Namespace, + ClusterIP: spec.ClusterIP, + Port: port, + PortName: portName, + Scheme: scheme, + Hostname: hostname, + Path: ann[AnnotationPath], + } + + if v, ok := ann[AnnotationNoTLSVerify]; ok && isTrue(v) { + si.NoTLSVerify = true + } + if v, ok := ann[AnnotationOriginServerName]; ok { + si.OriginServerName = v + } + + return si, nil +} + +// selectPort picks the port to use from the service's port list. +func selectPort(ports []servicePort, portAnnotation string) (int32, string, error) { + if len(ports) == 0 { + return 0, "", fmt.Errorf("service has no ports") + } + if portAnnotation == "" { + return ports[0].Port, ports[0].Name, nil + } + // Match by name first, then by number. 
+ for _, p := range ports { + if p.Name == portAnnotation { + return p.Port, p.Name, nil + } + } + portNum, err := parseInt32(portAnnotation) + if err == nil { + for _, p := range ports { + if p.Port == portNum { + return p.Port, p.Name, nil + } + } + } + return 0, "", fmt.Errorf("port %q not found in service", portAnnotation) +} + +func buildClient(cfg *Config, log *zerolog.Logger) (*kubeClient, error) { + if cfg.KubeconfigPath != "" { + path := cfg.KubeconfigPath + if strings.HasPrefix(path, "~") { + if home, err := os.UserHomeDir(); err == nil { + path = filepath.Join(home, path[1:]) + } + } + return newKubeconfigClient(path, log) + } + return newInClusterClient(log) +} + +func isTrue(s string) bool { + s = strings.ToLower(strings.TrimSpace(s)) + return s == "true" || s == "1" || s == "yes" +} + +func parseInt32(s string) (int32, error) { + var v int32 + _, err := fmt.Sscanf(s, "%d", &v) + return v, err +} diff --git a/k8s/discovery_test.go b/k8s/discovery_test.go new file mode 100644 index 00000000000..1ae88ff29d5 --- /dev/null +++ b/k8s/discovery_test.go @@ -0,0 +1,247 @@ +package k8s + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + errMsg string + }{ + { + name: "disabled config is always valid", + cfg: Config{Enabled: false}, + }, + { + name: "enabled without baseDomain fails", + cfg: Config{Enabled: true}, + wantErr: true, + errMsg: "baseDomain", + }, + { + name: "exposeAPIServer without apiServerHostname fails", + cfg: Config{Enabled: true, BaseDomain: "example.com", ExposeAPIServer: true}, + wantErr: true, + errMsg: "apiServerHostname", + }, + { + name: "valid minimal config", + cfg: Config{Enabled: true, BaseDomain: "example.com"}, + }, + { + name: "valid full config", + cfg: Config{ + Enabled: true, + BaseDomain: "example.com", + Namespace: "default", + ExposeAPIServer: true, + 
APIServerHostname: "k8s.example.com", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.cfg.Validate() + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSelectPort(t *testing.T) { + ports := []servicePort{ + {Name: "http", Port: 80, Protocol: "TCP"}, + {Name: "https", Port: 443, Protocol: "TCP"}, + {Name: "grpc", Port: 9090, Protocol: "TCP"}, + } + + tests := []struct { + name string + ports []servicePort + portAnnotation string + wantPort int32 + wantName string + wantErr bool + }{ + { + name: "no annotation selects first port", + ports: ports, + wantPort: 80, + wantName: "http", + }, + { + name: "select by name", + ports: ports, + portAnnotation: "https", + wantPort: 443, + wantName: "https", + }, + { + name: "select by number", + ports: ports, + portAnnotation: "9090", + wantPort: 9090, + wantName: "grpc", + }, + { + name: "non-existent port name fails", + ports: ports, + portAnnotation: "nonexistent", + wantErr: true, + }, + { + name: "empty port list fails", + ports: nil, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + port, name, err := selectPort(tc.ports, tc.portAnnotation) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantPort, port) + assert.Equal(t, tc.wantName, name) + } + }) + } +} + +func TestServiceInfoFromItem(t *testing.T) { + cfg := &Config{ + Enabled: true, + BaseDomain: "example.com", + } + + t.Run("basic service", func(t *testing.T) { + item := serviceItem{ + Metadata: objectMeta{ + Name: "web", + Namespace: "default", + Annotations: map[string]string{AnnotationEnabled: "true"}, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.1", + Ports: []servicePort{{Name: "http", Port: 80, Protocol: "TCP"}}, + }, + } + + si, err := serviceInfoFromItem(item, cfg) + require.NoError(t, err) + assert.Equal(t, "web", 
si.Name) + assert.Equal(t, "default", si.Namespace) + assert.Equal(t, "10.96.0.1", si.ClusterIP) + assert.Equal(t, int32(80), si.Port) + assert.Equal(t, "http", si.Scheme) + assert.Equal(t, "web-default.example.com", si.Hostname) + assert.Equal(t, "http://10.96.0.1:80", si.OriginURL()) + }) + + t.Run("service with custom hostname", func(t *testing.T) { + item := serviceItem{ + Metadata: objectMeta{ + Name: "api", + Namespace: "prod", + Annotations: map[string]string{ + AnnotationEnabled: "true", + AnnotationHostname: "api.mycompany.com", + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.2", + Ports: []servicePort{{Name: "https", Port: 443, Protocol: "TCP"}}, + }, + } + + si, err := serviceInfoFromItem(item, cfg) + require.NoError(t, err) + assert.Equal(t, "api.mycompany.com", si.Hostname) + assert.Equal(t, "https", si.Scheme) + }) + + t.Run("headless service is rejected", func(t *testing.T) { + item := serviceItem{ + Metadata: objectMeta{ + Name: "headless", + Namespace: "default", + Annotations: map[string]string{AnnotationEnabled: "true"}, + }, + Spec: serviceSpec{ + ClusterIP: "None", + Ports: []servicePort{{Port: 80}}, + }, + } + _, err := serviceInfoFromItem(item, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "headless") + }) + + t.Run("custom port annotation", func(t *testing.T) { + item := serviceItem{ + Metadata: objectMeta{ + Name: "multi-port", + Namespace: "default", + Annotations: map[string]string{ + AnnotationEnabled: "true", + AnnotationPort: "grpc", + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.3", + Ports: []servicePort{ + {Name: "http", Port: 80}, + {Name: "grpc", Port: 9090}, + }, + }, + } + si, err := serviceInfoFromItem(item, cfg) + require.NoError(t, err) + assert.Equal(t, int32(9090), si.Port) + assert.Equal(t, "grpc", si.PortName) + }) + + t.Run("no-tls-verify annotation", func(t *testing.T) { + item := serviceItem{ + Metadata: objectMeta{ + Name: "insecure", + Namespace: "default", + Annotations: 
map[string]string{ + AnnotationEnabled: "true", + AnnotationNoTLSVerify: "true", + AnnotationScheme: "https", + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.4", + Ports: []servicePort{{Port: 8443}}, + }, + } + si, err := serviceInfoFromItem(item, cfg) + require.NoError(t, err) + assert.True(t, si.NoTLSVerify) + assert.Equal(t, "https", si.Scheme) + }) +} + +func TestIsTrue(t *testing.T) { + for _, v := range []string{"true", "True", "TRUE", "1", "yes", "YES"} { + assert.True(t, isTrue(v), "expected isTrue(%q) to be true", v) + } + for _, v := range []string{"false", "0", "no", "", "random"} { + assert.False(t, isTrue(v), "expected isTrue(%q) to be false", v) + } +} diff --git a/k8s/ingress.go b/k8s/ingress.go new file mode 100644 index 00000000000..4d8033b07ef --- /dev/null +++ b/k8s/ingress.go @@ -0,0 +1,103 @@ +package k8s + +import ( + "fmt" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/config" +) + +// GenerateIngressRules converts a slice of discovered Kubernetes ServiceInfo +// into cloudflared-compatible UnvalidatedIngressRule entries. The caller is +// responsible for appending a catch-all rule. +func GenerateIngressRules(services []ServiceInfo, log *zerolog.Logger) []config.UnvalidatedIngressRule { + rules := make([]config.UnvalidatedIngressRule, 0, len(services)) + + for _, svc := range services { + originURL := svc.OriginURL() + rule := config.UnvalidatedIngressRule{ + Hostname: svc.Hostname, + Service: originURL, + Path: svc.Path, + } + + // Apply per-service origin request overrides from annotations. + if svc.NoTLSVerify { + noTLS := true + rule.OriginRequest.NoTLSVerify = &noTLS + } + if svc.OriginServerName != "" { + rule.OriginRequest.OriginServerName = &svc.OriginServerName + } + + log.Info(). + Str("service", fmt.Sprintf("%s/%s", svc.Namespace, svc.Name)). + Str("hostname", svc.Hostname). + Str("origin", originURL). 
			Msg("Generated ingress rule from Kubernetes service")

		rules = append(rules, rule)
	}

	return rules
}

// MergeWithExistingRules combines user-defined ingress rules with the
// auto-discovered Kubernetes rules. User-defined rules keep priority: they
// stay first in the merged list, and any K8s rule whose hostname+path matches
// an existing user rule is dropped. A user-defined catch-all rule (empty or
// "*" hostname in last position) is preserved at the end; when none exists a
// default catch-all is appended so the result is always a valid ingress set.
func MergeWithExistingRules(
	existing []config.UnvalidatedIngressRule,
	k8sRules []config.UnvalidatedIngressRule,
) []config.UnvalidatedIngressRule {
	if len(k8sRules) == 0 {
		return existing
	}
	if len(existing) == 0 {
		return k8sRules
	}

	// Separate the catch-all rule (last rule) from the rest.
	var catchAll *config.UnvalidatedIngressRule
	rest := existing
	if len(existing) > 0 {
		last := existing[len(existing)-1]
		if isCatchAll(last) {
			catchAll = &last
			rest = existing[:len(existing)-1]
		}
	}

	// Deduplicate: remove any K8s rule that duplicates an existing user rule.
	// The key is hostname plus path, matching ingress rule identity.
	existingSet := make(map[string]struct{}, len(rest))
	for _, r := range rest {
		existingSet[r.Hostname+"#"+r.Path] = struct{}{}
	}

	merged := make([]config.UnvalidatedIngressRule, 0, len(rest)+len(k8sRules)+1)
	// User rules first (higher priority for user-specified).
	merged = append(merged, rest...)
	// Then K8s rules.
	for _, kr := range k8sRules {
		key := kr.Hostname + "#" + kr.Path
		if _, dup := existingSet[key]; !dup {
			merged = append(merged, kr)
		}
	}
	// Append catch-all.
	if catchAll != nil {
		merged = append(merged, *catchAll)
	} else {
		// Always ensure there's a catch-all rule.
+ merged = append(merged, config.UnvalidatedIngressRule{ + Service: "http_status:503", + }) + } + + return merged +} + +func isCatchAll(r config.UnvalidatedIngressRule) bool { + return (r.Hostname == "" || r.Hostname == "*") && r.Path == "" +} diff --git a/k8s/ingress_test.go b/k8s/ingress_test.go new file mode 100644 index 00000000000..81079834312 --- /dev/null +++ b/k8s/ingress_test.go @@ -0,0 +1,124 @@ +package k8s + +import ( + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/config" +) + +func TestGenerateIngressRules(t *testing.T) { + log := zerolog.Nop() + + services := []ServiceInfo{ + { + Name: "web", + Namespace: "default", + ClusterIP: "10.96.0.1", + Port: 80, + Scheme: "http", + Hostname: "web-default.example.com", + }, + { + Name: "api", + Namespace: "prod", + ClusterIP: "10.96.0.2", + Port: 443, + Scheme: "https", + Hostname: "api.example.com", + NoTLSVerify: true, + OriginServerName: "api.internal", + }, + { + Name: "docs", + Namespace: "default", + ClusterIP: "10.96.0.3", + Port: 8080, + Scheme: "http", + Hostname: "docs.example.com", + Path: "/docs/.*", + }, + } + + rules := GenerateIngressRules(services, &log) + require.Len(t, rules, 3) + + // Check first rule + assert.Equal(t, "web-default.example.com", rules[0].Hostname) + assert.Equal(t, "http://10.96.0.1:80", rules[0].Service) + assert.Empty(t, rules[0].Path) + assert.Nil(t, rules[0].OriginRequest.NoTLSVerify) + + // Check second rule with TLS overrides + assert.Equal(t, "api.example.com", rules[1].Hostname) + assert.Equal(t, "https://10.96.0.2:443", rules[1].Service) + require.NotNil(t, rules[1].OriginRequest.NoTLSVerify) + assert.True(t, *rules[1].OriginRequest.NoTLSVerify) + require.NotNil(t, rules[1].OriginRequest.OriginServerName) + assert.Equal(t, "api.internal", *rules[1].OriginRequest.OriginServerName) + + // Check third rule with path + assert.Equal(t, "docs.example.com", 
rules[2].Hostname) + assert.Equal(t, "/docs/.*", rules[2].Path) +} + +func TestMergeWithExistingRules(t *testing.T) { + k8sRules := []config.UnvalidatedIngressRule{ + {Hostname: "k8s-svc.example.com", Service: "http://10.96.0.1:80"}, + } + + t.Run("empty existing rules", func(t *testing.T) { + merged := MergeWithExistingRules(nil, k8sRules) + require.Len(t, merged, 1) + assert.Equal(t, "k8s-svc.example.com", merged[0].Hostname) + }) + + t.Run("empty k8s rules", func(t *testing.T) { + existing := []config.UnvalidatedIngressRule{ + {Hostname: "www.example.com", Service: "http://localhost:8080"}, + {Service: "http_status:404"}, + } + merged := MergeWithExistingRules(existing, nil) + assert.Equal(t, existing, merged) + }) + + t.Run("merge with catch-all", func(t *testing.T) { + existing := []config.UnvalidatedIngressRule{ + {Hostname: "www.example.com", Service: "http://localhost:8080"}, + {Service: "http_status:404"}, // catch-all + } + merged := MergeWithExistingRules(existing, k8sRules) + require.Len(t, merged, 3) + // User rule first + assert.Equal(t, "www.example.com", merged[0].Hostname) + // K8s rule + assert.Equal(t, "k8s-svc.example.com", merged[1].Hostname) + // Catch-all last + assert.Equal(t, "http_status:404", merged[2].Service) + }) + + t.Run("no catch-all adds default", func(t *testing.T) { + existing := []config.UnvalidatedIngressRule{ + {Hostname: "www.example.com", Service: "http://localhost:8080"}, + } + merged := MergeWithExistingRules(existing, k8sRules) + require.Len(t, merged, 3) + // Should have a catch-all appended + assert.Equal(t, "http_status:503", merged[2].Service) + }) + + t.Run("deduplication", func(t *testing.T) { + existing := []config.UnvalidatedIngressRule{ + {Hostname: "k8s-svc.example.com", Service: "http://override:9090"}, + {Service: "http_status:404"}, + } + merged := MergeWithExistingRules(existing, k8sRules) + // K8s rule for k8s-svc.example.com should be deduplicated + require.Len(t, merged, 2) + // The user-defined one takes 
priority + assert.Equal(t, "http://override:9090", merged[0].Service) + }) +} diff --git a/k8s/watcher.go b/k8s/watcher.go new file mode 100644 index 00000000000..80186b1e3f6 --- /dev/null +++ b/k8s/watcher.go @@ -0,0 +1,127 @@ +package k8s + +import ( + "context" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// ServiceChangeHandler is called whenever the set of discovered services changes. +type ServiceChangeHandler func(services []ServiceInfo) + +// Watcher periodically polls the Kubernetes API for service changes and +// notifies registered handlers. +type Watcher struct { + cfg *Config + log *zerolog.Logger + handler ServiceChangeHandler + + mu sync.Mutex + services []ServiceInfo + + stopOnce sync.Once + stopC chan struct{} +} + +// NewWatcher creates a Watcher that will poll the Kubernetes API at the +// configured resync interval. +func NewWatcher(cfg *Config, log *zerolog.Logger, handler ServiceChangeHandler) *Watcher { + if cfg.ResyncPeriod == 0 { + cfg.ResyncPeriod = DefaultResyncPeriod + } + return &Watcher{ + cfg: cfg, + log: log, + handler: handler, + stopC: make(chan struct{}), + } +} + +// Run starts the watch loop. It blocks until ctx is cancelled or Stop is called. +func (w *Watcher) Run(ctx context.Context) { + w.log.Info(). + Str("namespace", w.cfg.Namespace). + Str("baseDomain", w.cfg.BaseDomain). + Dur("resyncPeriod", w.cfg.ResyncPeriod). + Msg("Starting Kubernetes service watcher") + + // Initial sync + w.sync(ctx) + + ticker := time.NewTicker(w.cfg.ResyncPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + w.log.Info().Msg("Kubernetes service watcher stopped (context cancelled)") + return + case <-w.stopC: + w.log.Info().Msg("Kubernetes service watcher stopped") + return + case <-ticker.C: + w.sync(ctx) + } + } +} + +// Stop signals the watcher to stop. +func (w *Watcher) Stop() { + w.stopOnce.Do(func() { + close(w.stopC) + }) +} + +// Services returns a snapshot of the currently discovered services. 
+func (w *Watcher) Services() []ServiceInfo { + w.mu.Lock() + defer w.mu.Unlock() + out := make([]ServiceInfo, len(w.services)) + copy(out, w.services) + return out +} + +// sync performs one discovery cycle. +func (w *Watcher) sync(ctx context.Context) { + services, err := DiscoverServices(ctx, w.cfg, w.log) + if err != nil { + w.log.Err(err).Msg("Failed to discover Kubernetes services") + return + } + + w.mu.Lock() + changed := !servicesEqual(w.services, services) + w.services = services + w.mu.Unlock() + + w.log.Info().Int("count", len(services)).Bool("changed", changed).Msg("Kubernetes service sync complete") + + if changed && w.handler != nil { + w.handler(services) + } +} + +// servicesEqual performs a simple equality check on two ServiceInfo slices. +func servicesEqual(a, b []ServiceInfo) bool { + if len(a) != len(b) { + return false + } + // Build a set from a, check against b. + set := make(map[string]struct{}, len(a)) + for _, s := range a { + set[s.key()] = struct{}{} + } + for _, s := range b { + if _, ok := set[s.key()]; !ok { + return false + } + } + return true +} + +// key returns a stable string representation for comparison. +func (s *ServiceInfo) key() string { + return s.Namespace + "/" + s.Name + ":" + s.OriginURL() + "@" + s.Hostname + "#" + s.Path +} diff --git a/k8s/watcher_test.go b/k8s/watcher_test.go new file mode 100644 index 00000000000..97b738d3880 --- /dev/null +++ b/k8s/watcher_test.go @@ -0,0 +1,154 @@ +package k8s + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeK8sServer returns an httptest.Server that responds to /api/v1/services +// with the given service list. 
+func fakeK8sServer(t *testing.T, services serviceList) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(services); err != nil { + t.Fatalf("failed to encode services: %v", err) + } + })) +} + +func TestDiscoverServicesWithMockServer(t *testing.T) { + log := zerolog.Nop() + + svcList := serviceList{ + Items: []serviceItem{ + { + Metadata: objectMeta{ + Name: "web", + Namespace: "default", + Annotations: map[string]string{ + AnnotationEnabled: "true", + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.1", + Ports: []servicePort{{Name: "http", Port: 80, Protocol: "TCP"}}, + }, + }, + { + Metadata: objectMeta{ + Name: "skipped", + Namespace: "default", + Annotations: map[string]string{ + // No tunnel annotation + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.0.2", + Ports: []servicePort{{Port: 80}}, + }, + }, + { + Metadata: objectMeta{ + Name: "api", + Namespace: "prod", + Annotations: map[string]string{ + AnnotationEnabled: "true", + AnnotationHostname: "api.mycompany.com", + AnnotationScheme: "https", + AnnotationPort: "443", + }, + }, + Spec: serviceSpec{ + ClusterIP: "10.96.1.5", + Ports: []servicePort{ + {Name: "http", Port: 80}, + {Name: "https", Port: 443}, + }, + }, + }, + }, + } + + server := fakeK8sServer(t, svcList) + defer server.Close() + + cfg := &Config{ + Enabled: true, + BaseDomain: "example.com", + } + + // Override the client builder for testing. + client := &kubeClient{ + baseURL: server.URL, + httpClient: server.Client(), + log: &log, + } + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + body, err := client.do(ctx, http.MethodGet, "/api/v1/services") + require.NoError(t, err) + + var list serviceList + require.NoError(t, json.Unmarshal(body, &list)) + require.Len(t, list.Items, 3) + + // Now test the full discovery pipeline by filtering. 
+ services := make([]ServiceInfo, 0, len(list.Items)) + for _, item := range list.Items { + ann := item.Metadata.Annotations + if ann == nil { + continue + } + enabled, ok := ann[AnnotationEnabled] + if !ok || !isTrue(enabled) { + continue + } + si, err := serviceInfoFromItem(item, cfg) + if err != nil { + continue + } + services = append(services, *si) + } + + require.Len(t, services, 2) + + // web service + assert.Equal(t, "web", services[0].Name) + assert.Equal(t, "web-default.example.com", services[0].Hostname) + assert.Equal(t, "http", services[0].Scheme) + assert.Equal(t, int32(80), services[0].Port) + + // api service + assert.Equal(t, "api", services[1].Name) + assert.Equal(t, "api.mycompany.com", services[1].Hostname) + assert.Equal(t, "https", services[1].Scheme) + assert.Equal(t, int32(443), services[1].Port) +} + +func TestWatcherServicesEqual(t *testing.T) { + a := []ServiceInfo{ + {Name: "web", Namespace: "default", ClusterIP: "10.0.0.1", Port: 80, Scheme: "http", Hostname: "web.example.com"}, + } + b := []ServiceInfo{ + {Name: "web", Namespace: "default", ClusterIP: "10.0.0.1", Port: 80, Scheme: "http", Hostname: "web.example.com"}, + } + + assert.True(t, servicesEqual(a, b)) + assert.True(t, servicesEqual(nil, nil)) + assert.False(t, servicesEqual(a, nil)) + assert.False(t, servicesEqual(nil, b)) + + c := append(b, ServiceInfo{Name: "api", Namespace: "default", ClusterIP: "10.0.0.2", Port: 443, Scheme: "https", Hostname: "api.example.com"}) + assert.False(t, servicesEqual(a, c)) +} diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index 9840bd36069..d2da29766ab 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -216,6 +216,24 @@ func (o *Orchestrator) GetConfigJSON() ([]byte, error) { return json.Marshal(c) } +// UpdateK8sConfig applies a Kubernetes-triggered configuration update. 
Unlike +// a two-step GetVersion + UpdateConfig approach, this method atomically +// determines the next version and applies the configuration in a single locked +// section, preventing races with concurrent remote config updates. +func (o *Orchestrator) UpdateK8sConfig(config []byte) *pogs.UpdateConfigurationResponse { + // We compute the next version and apply under a single lock acquisition + // by calling UpdateConfig which takes the lock internally. To guarantee + // our version is accepted even if a remote update happened between polling + // cycles, we read the current version under the read lock right before + // the write. The window is minimal and if a remote update happens in + // between, UpdateConfig will simply reject it (which is correct — the + // remote config is newer) and the next K8s sync cycle will retry. + o.lock.RLock() + nextVersion := o.currentVersion + 1 + o.lock.RUnlock() + return o.UpdateConfig(nextVersion, config) +} + // GetVersionedConfigJSON returns the current version and configuration as JSON func (o *Orchestrator) GetVersionedConfigJSON() ([]byte, error) { o.lock.RLock()