Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 102 additions & 50 deletions cmd/pyrycode-relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,71 @@
package main

import (
"context"
"errors"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"slices"
"syscall"
"time"

"github.com/pyrycode/pyrycode-relay/internal/relay"
)

// drainDeadline bounds how long Shutdown will wait for in-flight WS
// close handshakes before force-closing. nhooyr.io/websocket.Conn.Close
// waits up to 5s per conn for the peer's reciprocal close (lessons.md);
// 10s leaves ~5s of headroom while keeping a stuck drain from delaying
// a fly machine update indefinitely. Lives at the wiring site per the
// established policy-values-in-main convention (#21, #60).
const drainDeadline = 10 * time.Second

// Version is overridden at build time via -ldflags.
var Version = "dev"

func main() {
os.Exit(run(os.Args[1:], signalContextFor(syscall.SIGTERM, syscall.SIGINT)))
}

// signalContextFor returns a context that is cancelled when any of the
// listed signals is received. Equivalent to signal.NotifyContext; broken
// out so tests can drive the shutdown path with a synthetic cancellation
// instead of the real signal handler.
func signalContextFor(sigs ...os.Signal) context.Context {
ctx, _ := signal.NotifyContext(context.Background(), sigs...)
return ctx
}

func run(args []string, sigCtx context.Context) int {
fs := flag.NewFlagSet("pyrycode-relay", flag.ExitOnError)
var (
domain = flag.String("domain", "", "Public domain for Let's Encrypt cert issuance (required unless --insecure-listen is set).")
certCache = flag.String("cert-cache", defaultCertCache(), "Directory for autocert's TLS certificate cache.")
insecureListen = flag.String("insecure-listen", "", "Listen address for plain HTTP (e.g. :8080). Disables autocert; use only when fronted by a reverse proxy.")
metricsListen = flag.String("metrics-listen", "127.0.0.1:9090", "Listen address for the /metrics endpoint. Must be a loopback IP literal (e.g. 127.0.0.1:9090, [::1]:9090). Empty disables.")
trustXFF = flag.Bool("trust-x-forwarded-for", false,
domain = fs.String("domain", "", "Public domain for Let's Encrypt cert issuance (required unless --insecure-listen is set).")
certCache = fs.String("cert-cache", defaultCertCache(), "Directory for autocert's TLS certificate cache.")
insecureListen = fs.String("insecure-listen", "", "Listen address for plain HTTP (e.g. :8080). Disables autocert; use only when fronted by a reverse proxy.")
metricsListen = fs.String("metrics-listen", "127.0.0.1:9090", "Listen address for the /metrics endpoint. Must be a loopback IP literal (e.g. 127.0.0.1:9090, [::1]:9090). Empty disables.")
trustXFF = fs.Bool("trust-x-forwarded-for", false,
"Trust the X-Forwarded-For header as the source IP for per-IP rate limiting. "+
"WARNING: enabling this without a trusted reverse proxy in front of the relay "+
"allows clients to spoof their source IP and bypass per-IP rate limits.")
showVersion = flag.Bool("version", false, "Print version and exit.")
showVersion = fs.Bool("version", false, "Print version and exit.")
)
flag.Parse()
_ = fs.Parse(args)

if *showVersion {
fmt.Println(Version)
return
return 0
}

logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
slog.SetDefault(logger)

if *insecureListen == "" && *domain == "" {
logger.Error("either --domain (for autocert) or --insecure-listen (for behind-proxy mode) must be set")
os.Exit(2)
return 2
}

// Boot-time env-var validation runs BEFORE CheckInsecureListenInProduction
Expand All @@ -65,15 +89,15 @@ func main() {
} else {
logger.Error("refusing to start: invalid env-var config", "err", err)
}
os.Exit(2)
return 2
}

if err := relay.CheckInsecureListenInProduction(*insecureListen, os.Getenv); err != nil {
logger.Error("refusing to start: production-mode misconfiguration",
"err", err,
"env_var", "PYRYCODE_RELAY_PRODUCTION",
"fix", "remove --insecure-listen and set --domain, or unset PYRYCODE_RELAY_PRODUCTION")
os.Exit(2)
return 2
}

// CheckRunningAsRoot is the in-process backstop for the CI non-root-build
Expand All @@ -89,14 +113,14 @@ func main() {
"env_var", "PYRYCODE_RELAY_PRODUCTION",
"effective_uid", syscall.Geteuid(),
"fix", "drop privileges before exec (e.g. Dockerfile USER directive or --user <non-zero>, kubernetes securityContext.runAsUser), or unset PYRYCODE_RELAY_PRODUCTION if the deploy is truly dev")
os.Exit(2)
return 2
}

if err := relay.CheckCapabilities(); err != nil {
logger.Error("refusing to start: unexpected Linux capabilities",
"err", err,
"fix", "drop extra capabilities (e.g. --cap-drop=ALL --cap-add=NET_BIND_SERVICE on docker, or securityContext.capabilities on kubernetes)")
os.Exit(2)
return 2
}

startedAt := time.Now()
Expand All @@ -117,7 +141,7 @@ func main() {
"err", err,
"value", *metricsListen,
"fix", "use a loopback IP literal such as 127.0.0.1:9090 or [::1]:9090, or pass --metrics-listen= to disable")
os.Exit(2)
return 2
}

// maxFrameBytes: 256 KiB per-frame read cap. Derivation:
Expand All @@ -139,11 +163,9 @@ func main() {
rateLimitEvictionInterval = 5 * time.Minute
)
limiter := relay.NewIPRateLimiter(rateLimitRefillEvery, rateLimitBurst, rateLimitEvictionInterval)
// Best-effort: the current listener block calls os.Exit on error, which
// skips defers. A real graceful-shutdown path (signal handler + server
// Shutdown) is out of scope per #47; this defer runs on clean returns
// (e.g. --version above does not reach here, but future test entry
// points might).
// Graceful shutdown (#31) reaches limiter.Close on the clean return
// path: a signal or a listener-goroutine error triggers relay.Shutdown
// before run returns.
defer limiter.Close()
rateLimit := relay.NewRateLimitMiddleware(limiter, logger, *trustXFF)

Expand Down Expand Up @@ -172,7 +194,7 @@ func main() {
if err != nil {
logger.Error("refusing to start: invalid listener address",
"err", err, "addr", srv.Addr)
os.Exit(2)
return 2
}
expected := map[uint16]struct{}{port: {}}
actual := map[uint16]struct{}{port: {}}
Expand All @@ -181,7 +203,7 @@ func main() {
if err != nil {
logger.Error("refusing to start: invalid listener address",
"err", err, "addr", metricsSrv.Addr)
os.Exit(2)
return 2
}
expected[mp] = struct{}{}
actual[mp] = struct{}{}
Expand All @@ -192,28 +214,22 @@ func main() {
"err", err,
"unexpected_ports", surplus,
"expected_ports", expectedList)
os.Exit(2)
return 2
}
servers := []*http.Server{srv}
if metricsSrv != nil {
servers = append(servers, metricsSrv)
logger.Info("starting metrics listener", "listen", metricsSrv.Addr)
go func() {
if err := metricsSrv.ListenAndServe(); err != nil {
logger.Error("metrics listener failed", "err", err)
os.Exit(1)
}
}()
}
if err := srv.ListenAndServe(); err != nil {
logger.Error("listen failed", "err", err)
os.Exit(1)
}
return
return runServers(sigCtx, logger, reg, servers, func(s *http.Server) error {
return s.ListenAndServe()
})
}

mgr, err := relay.NewAutocertManager(*domain, *certCache)
if err != nil {
logger.Error("autocert setup failed", "err", err)
os.Exit(1)
return 1
}

httpsSrv := &http.Server{
Expand Down Expand Up @@ -242,13 +258,13 @@ func main() {
if err != nil {
logger.Error("refusing to start: invalid listener address",
"err", err, "addr", httpsSrv.Addr)
os.Exit(2)
return 2
}
httpPort, err := relay.ListenerPort(httpSrv.Addr)
if err != nil {
logger.Error("refusing to start: invalid listener address",
"err", err, "addr", httpSrv.Addr)
os.Exit(2)
return 2
}
expected := map[uint16]struct{}{443: {}, 80: {}}
actual := map[uint16]struct{}{httpsPort: {}, httpPort: {}}
Expand All @@ -257,7 +273,7 @@ func main() {
if err != nil {
logger.Error("refusing to start: invalid listener address",
"err", err, "addr", metricsSrv.Addr)
os.Exit(2)
return 2
}
expected[mp] = struct{}{}
actual[mp] = struct{}{}
Expand All @@ -268,33 +284,69 @@ func main() {
"err", err,
"unexpected_ports", surplus,
"expected_ports", expectedList)
os.Exit(2)
return 2
}

logger.Info("starting", "version", Version, "mode", "autocert",
"domain", *domain, "cert_cache", *certCache)

servers := []*http.Server{httpsSrv, httpSrv}
if metricsSrv != nil {
servers = append(servers, metricsSrv)
logger.Info("starting metrics listener", "listen", metricsSrv.Addr)
}
return runServers(sigCtx, logger, reg, servers, func(s *http.Server) error {
if s == httpsSrv {
return s.ListenAndServeTLS("", "")
}
return s.ListenAndServe()
})
}

// runServers launches one goroutine per *http.Server invoking listen(s)
// and blocks until either sigCtx is cancelled (operator signal) or one
// of the listeners returns a non-ErrServerClosed error. Either way it
// runs relay.Shutdown(drainCtx, …) and returns:
//
// - 0 on signal-triggered drain (clean operator action),
// - 1 on listener-error-triggered drain (process supervisor restart).
//
// Errors are logged with their source listener's Addr; only the first
// listener error wins (buffered chan).
func runServers(sigCtx context.Context, logger *slog.Logger, reg *relay.Registry, servers []*http.Server, listen func(*http.Server) error) int {
listenerErr := make(chan error, 1)
for _, s := range servers {
s := s
go func() {
if err := metricsSrv.ListenAndServe(); err != nil {
logger.Error("metrics listener failed", "err", err)
os.Exit(1)
if err := listen(s); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("listener failed", "addr", s.Addr, "err", err)
select {
case listenerErr <- err:
default:
}
}
}()
}

go func() {
if err := httpSrv.ListenAndServe(); err != nil {
logger.Error("http-01 listener failed", "err", err)
os.Exit(1)
}
}()
var triggeredByErr bool
select {
case <-sigCtx.Done():
logger.Info("shutdown signal received; draining")
case <-listenerErr:
triggeredByErr = true
logger.Error("listener error; draining")
}

drainCtx, cancel := context.WithTimeout(context.Background(), drainDeadline)
defer cancel()
if err := relay.Shutdown(drainCtx, logger, reg, servers...); err != nil {
logger.Warn("drain incomplete", "err", err)
}

if err := httpsSrv.ListenAndServeTLS("", ""); err != nil {
logger.Error("https listener failed", "err", err)
os.Exit(1)
if triggeredByErr {
return 1
}
return 0
}

// listenerPortLists returns the ascending-sorted surplus (actual\expected)
Expand Down
Loading
Loading