From afb7e6b0ebae6457c06488852be753487c8126c3 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Thu, 18 Jun 2026 19:33:44 +0300 Subject: [PATCH 01/10] feat: introduce self-healing SSH client with connection management and tunneling Signed-off-by: Daniil Studenikin --- docs/ARCHITECTURE.md | 70 ++++- internal/infrastructure/ssh/v2/client.go | 84 +++++ internal/infrastructure/ssh/v2/conn.go | 289 ++++++++++++++++++ internal/infrastructure/ssh/v2/conn_test.go | 278 +++++++++++++++++ internal/infrastructure/ssh/v2/dialer.go | 221 ++++++++++++++ internal/infrastructure/ssh/v2/dialer_test.go | 152 +++++++++ internal/infrastructure/ssh/v2/endpoint.go | 178 +++++++++++ .../infrastructure/ssh/v2/endpoint_test.go | 144 +++++++++ internal/infrastructure/ssh/v2/errors.go | 111 +++++++ internal/infrastructure/ssh/v2/errors_test.go | 87 ++++++ internal/infrastructure/ssh/v2/options.go | 108 +++++++ .../infrastructure/ssh/v2/testserver_test.go | 254 +++++++++++++++ internal/infrastructure/ssh/v2/tunnel.go | 223 ++++++++++++++ internal/infrastructure/ssh/v2/tunnel_test.go | 253 +++++++++++++++ 14 files changed, 2446 insertions(+), 6 deletions(-) create mode 100644 internal/infrastructure/ssh/v2/client.go create mode 100644 internal/infrastructure/ssh/v2/conn.go create mode 100644 internal/infrastructure/ssh/v2/conn_test.go create mode 100644 internal/infrastructure/ssh/v2/dialer.go create mode 100644 internal/infrastructure/ssh/v2/dialer_test.go create mode 100644 internal/infrastructure/ssh/v2/endpoint.go create mode 100644 internal/infrastructure/ssh/v2/endpoint_test.go create mode 100644 internal/infrastructure/ssh/v2/errors.go create mode 100644 internal/infrastructure/ssh/v2/errors_test.go create mode 100644 internal/infrastructure/ssh/v2/options.go create mode 100644 internal/infrastructure/ssh/v2/testserver_test.go create mode 100644 internal/infrastructure/ssh/v2/tunnel.go create mode 100644 internal/infrastructure/ssh/v2/tunnel_test.go diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index b1718f2..5f26c21 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -51,11 +51,19 @@ storage-e2e/ │ │ └── vm_block_device.go │ │ │ ├── infrastructure/ # Infrastructure layer -│ │ └── ssh/ # SSH operations +│ │ └── ssh/ # SSH operations (legacy) │ │ ├── client.go │ │ ├── interface.go │ │ ├── tunnel.go -│ │ └── types.go +│ │ ├── types.go +│ │ └── v2/ # Self-healing SSH client (Dialer/Route + Tunnel) +│ │ ├── client.go # New, Client, Close + package docs +│ │ ├── conn.go # connection core: snapshot/refresh/keepalive + withConn +│ │ ├── dialer.go # Dialer interface, Route, chain closer +│ │ ├── endpoint.go # Endpoint, auth, host/key resolution +│ │ ├── errors.go # transient classification, ExitError +│ │ ├── options.go # functional options +│ │ └── tunnel.go # Tunnel, accept loop │ │ │ └── logger/ # Structured logging │ ├── logger.go # Logger implementation @@ -473,10 +481,18 @@ internal/kubernetes/ # Internal Kubernetes clients ``` infrastructure/ssh/ -├── client.go # SSH client implementation (Exec, ExecCapture, tunnels) -├── interface.go # SSH client interface -├── tunnel.go # Port forwarding and tunneling -└── types.go # SSH-related types +├── client.go # SSH client implementation (Exec, ExecCapture, tunnels) [legacy] +├── interface.go # SSH client interface [legacy] +├── tunnel.go # Port forwarding and tunneling [legacy] +├── types.go # SSH-related types [legacy] +└── v2/ # Self-healing SSH client (see below) + ├── client.go # New, Client, Close + package docs + ├── conn.go # connection core: snapshot/refresh/keepalive + withConn executor + ├── dialer.go # Dialer interface, Route, chain closer + ├── endpoint.go # Endpoint, auth, host/key resolution + ├── errors.go # transient classification, ExitError + ├── options.go # functional options + └── tunnel.go # Tunnel, accept loop ``` **Responsibilities**: @@ -494,6 +510,48 @@ infrastructure/ssh/ - `ExecCapture` keeps stdout and stderr separate while preserving retry/reconnect behavior - Proper resource cleanup +#### 3.4.1 Self-healing SSH client (`internal/infrastructure/ssh/v2/`) + +A ground-up rewrite that lives in parallel with the legacy package (no consumers +migrated yet). It separates **how we connect** (directly or via jump hosts) from +**what we do over the connection** (currently only tunneling), and hides every +reconnect from callers. + +**Design**: + +- `Dialer` is the injection point: `Dial(ctx) (*ssh.Client, io.Closer, error)` + + `Describe()`. `Route(first Endpoint, more ...Endpoint)` builds the built-in + implementation; the last hop is always the target, so the `(first, more...)` + signature guarantees at least one hop at compile time. The returned `io.Closer` + tears down the whole chain (target + every jump + ssh-agent connections). +- `Endpoint` describes a single host: `User`, `Addr` (`host` or `host:port`, + default `:22`), `KeyPath` (`~` expanded), optional `Passphrase` + (falls back to `SSH_PASSPHRASE` then ssh-agent), optional per-hop `HostKey`. +- The unexported `conn` core owns the current `*ssh.Client`, its chain `Closer`, + and a generation counter under a mutex. `snapshot` reads them; `refresh` + re-dials via `singleflight` keyed on the failed generation so concurrent + reconnects collapse into one and a stale generation never tears down a freshly + healed link. The slow `Dial` runs outside the lock on a detached context + (`context.WithoutCancel` + timeout) so one caller's cancellation can't abort + the shared flight. +- A single generic executor `withConn[T]` runs an operation against the live + client and heals on transient failures (bounded by `WithRetries`); the tunnel + uses it today and `Run`/`Upload` are designed to reuse it unchanged. +- Optional keepalive (`WithKeepalive`) probes the link and heals through the same + `refresh` path; every heal is logged at WARN. + +**Public API v1**: `New(ctx, Dialer, ...Option)`, `Client.Tunnel(ctx, remotePort)` +(self-healing local forward on a free `127.0.0.1` port; `Tunnel.LocalAddr`, +`Tunnel.Close`), `Client.Close`. Options: `WithKeepalive`, `WithRetries`, +`WithLogger`, `WithHostKeyCallback`, `WithInsecureIgnoreHostKey` (host key +defaults to `InsecureIgnoreHostKey` — a conscious default for ephemeral e2e VMs). + +**Extension points (designed, not yet implemented)**: `Run` (transparent retry +only when the session fails to open; mid-flight drops heal but surface the error +to avoid double side effects; opt-in `Idempotent` for true retry) and `Upload`. +Transient-error classification uses `errors.Is`/`errors.As` against standard +types — never error-string matching. + ### 3.5 Logger Module (`internal/logger/`) ``` diff --git a/internal/infrastructure/ssh/v2/client.go b/internal/infrastructure/ssh/v2/client.go new file mode 100644 index 0000000..56cd3dc --- /dev/null +++ b/internal/infrastructure/ssh/v2/client.go @@ -0,0 +1,84 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package ssh provides a self-healing SSH client whose connection strategy +// ("how we connect" — directly or through jump hosts) is separated from the +// operations performed over it ("what we do" — currently tunneling). +// +// The injection point is the Dialer: Route builds one for a direct connection or +// an arbitrary chain of jump hosts. New opens a Client over a Dialer and hides +// every reconnect: callers invoke methods and never reason about reconnection. +// All operations funnel through a single reconnect-aware executor (withConn) over +// a shared connection core (conn), so future operations such as Run and Upload +// can be added without touching the healing logic. +// +// The primary use case is opening a tunnel to the API server of a closed +// Kubernetes cluster and pointing a kubeconfig at it: +// +// c, _ := ssh.New(ctx, ssh.Route(jumpEp, targetEp)) +// defer c.Close() +// t, _ := c.Tunnel(ctx, 6443) +// defer t.Close() +// rest := &rest.Config{Host: "https://" + t.LocalAddr()} +package ssh + +import ( + "context" + "errors" + "log/slog" +) + +// Client is a self-healing SSH client over a Dialer-provided connection. It is +// safe for concurrent use; reconnects are transparent to callers. +type Client struct { + conn *conn + retries int + log *slog.Logger +} + +// New connects immediately over d, starts keepalive when enabled, and returns a +// ready Client. The context bounds the initial connection. If d implements the +// internal host-key defaulter (as the built-in Route does), the resolved +// host-key option is pushed into it so per-hop Endpoint.HostKey values take +// precedence over the Client-level default. +func New(ctx context.Context, d Dialer, opts ...Option) (*Client, error) { + if d == nil { + return nil, errors.New("ssh: nil dialer") + } + + o := defaultOptions() + for _, opt := range opts { + opt(&o) + } + + if hkd, ok := d.(hostKeyDefaulter); ok { + hkd.setDefaultHostKey(o.hostKey) + } + + core, err := newConn(ctx, d, o) + if err != nil { + return nil, err + } + + return &Client{conn: core, retries: o.retries, log: o.log}, nil +} + +// Close tears down the connection and its whole chain and stops keepalive. It is +// idempotent and safe for concurrent use. Open tunnels keep their listeners; the +// caller should Close those separately. +func (c *Client) Close() error { + return c.conn.Close() +} diff --git a/internal/infrastructure/ssh/v2/conn.go b/internal/infrastructure/ssh/v2/conn.go new file mode 100644 index 0000000..eb0c5a7 --- /dev/null +++ b/internal/infrastructure/ssh/v2/conn.go @@ -0,0 +1,289 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "fmt" + "io" + "log/slog" + "strconv" + "sync" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/sync/singleflight" +) + +// conn is the connection core shared by every high-level operation. It owns the +// current *ssh.Client together with the Closer for its whole chain and a +// monotonically increasing generation counter. All reconnect logic lives here so +// callers (Tunnel today, Run/Upload later) never see a reconnect: they ask for +// the live client, run their operation, and on a transient failure call withConn +// which heals the connection underneath them. +type conn struct { + dialer Dialer + log *slog.Logger + dialTimeout time.Duration + + // flight deduplicates concurrent reconnects keyed by the failed generation, + // preventing a reconnect storm from tearing down a freshly healed link. + flight singleflight.Group + + mu sync.Mutex + client *ssh.Client + closer io.Closer + gen uint64 + closed bool + + // keepalive lifecycle. + kaCancel context.CancelFunc + wg sync.WaitGroup +} + +// newConn establishes the initial connection and, when keepalive > 0, starts the +// background keepalive goroutine. The initial dial uses the caller's context so +// startup honors their deadline and cancellation. +func newConn(ctx context.Context, d Dialer, o options) (*conn, error) { + client, closer, err := d.Dial(ctx) + if err != nil { + return nil, fmt.Errorf("connect to %s: %w", d.Describe(), err) + } + + c := &conn{ + dialer: d, + log: o.log, + dialTimeout: o.dialTimeout, + client: client, + closer: closer, + gen: 1, + } + + if o.keepalive > 0 { + // Keepalive must outlive the caller's setup context: the connection + // stays alive until Close, not until the New call returns. A fresh root + // context canceled by Close is therefore correct here. + kaCtx, cancel := context.WithCancel(context.Background()) + c.kaCancel = cancel + c.wg.Add(1) + //nolint:contextcheck // intentional: keepalive lifetime is bound to Close, not the setup context. + go c.keepaliveLoop(kaCtx, o.keepalive) + } + + return c, nil +} + +// snapshot returns the current client and its generation under the lock. The +// generation lets callers tell refresh which connection failed them, so a +// concurrent heal is not duplicated. +func (c *conn) snapshot() (client *ssh.Client, gen uint64) { + c.mu.Lock() + defer c.mu.Unlock() + return c.client, c.gen +} + +// refresh re-establishes the connection that failed at generation failedGen and +// returns the now-current client and generation. Concurrent callers that failed +// on the same generation are collapsed into a single dial via singleflight; a +// caller whose failedGen is already stale (someone else healed first) gets the +// current client back without dialing. The actual Dial runs outside the lock and +// on a detached context with its own timeout, so one caller's cancellation can +// never abort the shared reconnect that others are waiting on. +func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint64, error) { + key := strconv.FormatUint(failedGen, 10) + + type healed struct { + client *ssh.Client + gen uint64 + } + + v, err, _ := c.flight.Do(key, func() (interface{}, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errClosed + } + // Someone already healed past failedGen — reuse the live client. + if c.gen != failedGen { + cur := healed{client: c.client, gen: c.gen} + c.mu.Unlock() + return cur, nil + } + c.mu.Unlock() + + // Detach from the caller's context so one cancellation does not abort the + // shared flight, but still bound the dial with our own timeout. + dialCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.dialTimeout) + defer cancel() + + client, closer, dialErr := c.dialer.Dial(dialCtx) + if dialErr != nil { + return nil, fmt.Errorf("reconnect to %s: %w", c.dialer.Describe(), dialErr) + } + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + _ = closer.Close() + return nil, errClosed + } + old := c.closer + c.client = client + c.closer = closer + c.gen++ + newGen := c.gen + c.mu.Unlock() + + // Tear down the dead chain outside the lock. + if old != nil { + _ = old.Close() + } + // Self-healing must be loud, not silent. + c.log.Warn("ssh: connection re-established", + "route", c.dialer.Describe(), "generation", newGen) + + return healed{client: client, gen: newGen}, nil + }) + if err != nil { + return nil, 0, err + } + r, ok := v.(healed) + if !ok { + return nil, 0, fmt.Errorf("ssh: unexpected refresh result type %T", v) + } + return r.client, r.gen, nil +} + +// keepaliveLoop periodically probes the connection. A failed probe is not just a +// reason to exit: it routes through refresh so the link is proactively healed via +// the same single path as a failed operation. Keepalive only narrows the window +// in which a dead connection is noticed; the authoritative "heal now" signal is +// still a failed operation. +func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { + defer c.wg.Done() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + client, gen := c.snapshot() + if client == nil { + continue + } + if _, _, err := client.SendRequest("keepalive@openssh.com", true, nil); err == nil { + continue + } + c.log.Warn("ssh: keepalive failed, healing connection", + "route", c.dialer.Describe()) + if _, _, err := c.refresh(ctx, gen); err != nil { + if c.isClosed() || ctx.Err() != nil { + return + } + c.log.Warn("ssh: keepalive-triggered reconnect failed", + "route", c.dialer.Describe(), "err", err) + } + } + } +} + +// isClosed reports whether Close has been called. +func (c *conn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +// Close tears down the connection and its whole chain and stops the keepalive +// goroutine. It is idempotent and safe for concurrent use. +func (c *conn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + closer := c.closer + cancel := c.kaCancel + c.client = nil + c.closer = nil + c.mu.Unlock() + + if cancel != nil { + cancel() + } + c.wg.Wait() + + if closer != nil { + if err := closer.Close(); err != nil && !isTransient(err) { + return err + } + } + return nil +} + +// withConn runs op against the live client and heals the connection on transient +// failures, retrying up to retries times. It is the single reconnect-aware +// executor that every high-level operation builds on, so the reconnect policy +// lives in exactly one place. op MUST be safe to invoke more than once; callers +// whose operation is not idempotent (e.g. a command that already started running) +// must classify their own mid-flight failures as non-transient before they reach +// here. +// +// It is a generic free function rather than a method because Go methods cannot +// have type parameters; T lets callers return a typed result (a net.Conn for a +// tunnel dial, a session for Run, …) without boxing. +func withConn[T any](ctx context.Context, c *conn, retries int, op func(context.Context, *ssh.Client) (T, error)) (T, error) { + var zero T + + client, gen := c.snapshot() + for attempt := 0; ; attempt++ { + if err := ctx.Err(); err != nil { + return zero, err + } + if client == nil { + return zero, errClosed + } + + result, err := op(ctx, client) + if err == nil { + return result, nil + } + + // An explicit cancellation outranks any transient classification. + if ctx.Err() != nil { + return zero, ctx.Err() + } + if !isTransient(err) { + return zero, err + } + if attempt >= retries { + return zero, fmt.Errorf("after %d attempt(s): %w", attempt+1, err) + } + + c.log.Warn("ssh: operation failed on broken connection, healing", + "route", c.dialer.Describe(), "attempt", attempt+1, "err", err) + + client, gen, err = c.refresh(ctx, gen) + if err != nil { + return zero, fmt.Errorf("heal connection: %w", err) + } + } +} diff --git a/internal/infrastructure/ssh/v2/conn_test.go b/internal/infrastructure/ssh/v2/conn_test.go new file mode 100644 index 0000000..3aa6aa3 --- /dev/null +++ b/internal/infrastructure/ssh/v2/conn_test.go @@ -0,0 +1,278 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func newTestConn(t *testing.T, d Dialer, keepalive time.Duration) *conn { + t.Helper() + o := defaultOptions() + o.log = quietLogger() + o.keepalive = keepalive + o.dialTimeout = 5 * time.Second + c, err := newConn(context.Background(), d, o) + if err != nil { + t.Fatalf("newConn: %v", err) + } + t.Cleanup(func() { _ = c.Close() }) + return c +} + +func TestConnSnapshotInitialGeneration(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + + c := newTestConn(t, d, 0) + client, gen := c.snapshot() + if client == nil { + t.Fatalf("snapshot returned nil client") + } + if gen != 1 { + t.Fatalf("initial generation = %d, want 1", gen) + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1", d.dialCount()) + } +} + +func TestConnRefreshStaleGenerationDoesNotReconnect(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + // Pretend we failed on generation 0, but the live connection is already at + // generation 1 — nobody should reconnect a healthy link. + client, gen, err := c.refresh(context.Background(), 0) + if err != nil { + t.Fatalf("refresh: %v", err) + } + if gen != 1 { + t.Fatalf("generation = %d, want 1 (unchanged)", gen) + } + if client == nil { + t.Fatalf("refresh returned nil client") + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1 (no reconnect)", d.dialCount()) + } +} + +func TestConnRefreshDeduplicatesConcurrentReconnects(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + if d.dialCount() != 1 { + t.Fatalf("setup dial count = %d, want 1", d.dialCount()) + } + + // Gate the next dial so all concurrent refreshers pile into one flight. + gate := make(chan struct{}) + d.setGate(gate) + + const n = 8 + var wg sync.WaitGroup + gens := make([]uint64, n) + errs := make([]error, n) + start := make(chan struct{}) + + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + <-start + _, gen, err := c.refresh(context.Background(), 1) + gens[i] = gen + errs[i] = err + }(i) + } + + close(start) + // Give the goroutines time to coalesce in singleflight before releasing the + // single dial. Polling the dial count avoids a fixed sleep race. + waitFor(t, 2*time.Second, func() bool { return d.dialCount() == 2 }) + close(gate) + wg.Wait() + + for i := 0; i < n; i++ { + if errs[i] != nil { + t.Fatalf("refresher %d error: %v", i, errs[i]) + } + if gens[i] != 2 { + t.Fatalf("refresher %d generation = %d, want 2", i, gens[i]) + } + } + if d.dialCount() != 2 { + t.Fatalf("dial count = %d, want 2 (one reconnect for all callers)", d.dialCount()) + } +} + +func TestWithConnHealsOnTransientFailure(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + var calls int + got, err := withConn(context.Background(), c, 3, func(_ context.Context, client *ssh.Client) (string, error) { + calls++ + if calls == 1 { + return "", io.EOF // looks like a dropped session + } + if client == nil { + return "", errors.New("nil client after heal") + } + return "ok", nil + }) + if err != nil { + t.Fatalf("withConn: %v", err) + } + if got != "ok" { + t.Fatalf("result = %q, want ok", got) + } + if calls != 2 { + t.Fatalf("op calls = %d, want 2", calls) + } + if d.dialCount() != 2 { + t.Fatalf("dial count = %d, want 2 (one heal)", d.dialCount()) + } +} + +func TestWithConnDoesNotRetryNonTransient(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + sentinel := errors.New("application error") + var calls int + _, err := withConn(context.Background(), c, 3, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, sentinel + }) + if !errors.Is(err, sentinel) { + t.Fatalf("err = %v, want %v", err, sentinel) + } + if calls != 1 { + t.Fatalf("op calls = %d, want 1 (no retry)", calls) + } + if d.dialCount() != 1 { + t.Fatalf("dial count = %d, want 1 (no reconnect)", d.dialCount()) + } +} + +func TestWithConnRespectsContextCancellation(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var calls int + _, err := withConn(ctx, c, 3, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } + if calls != 0 { + t.Fatalf("op calls = %d, want 0 (ctx already canceled)", calls) + } +} + +func TestWithConnExhaustsRetries(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + var calls int + _, err := withConn(context.Background(), c, 2, func(_ context.Context, _ *ssh.Client) (struct{}, error) { + calls++ + return struct{}{}, io.EOF + }) + if err == nil { + t.Fatalf("expected error after exhausting retries") + } + if !errors.Is(err, io.EOF) { + t.Fatalf("err = %v, want wrapped io.EOF", err) + } + // retries=2 means: initial attempt + 2 heals = 3 op calls. + if calls != 3 { + t.Fatalf("op calls = %d, want 3", calls) + } +} + +func TestConnCloseIsIdempotent(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestConn(t, d, 0) + + if err := c.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + if _, _, err := c.refresh(context.Background(), 1); !errors.Is(err, errClosed) { + t.Fatalf("refresh after close = %v, want errClosed", err) + } +} + +func TestKeepaliveHealsDroppedConnection(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + _ = newTestConn(t, d, 100*time.Millisecond) + + // Kill the live transport; the keepalive probe should notice and heal. + srv.dropConns() + + waitFor(t, 5*time.Second, func() bool { return d.dialCount() >= 2 }) + if d.dialCount() < 2 { + t.Fatalf("dial count = %d, want >= 2 (keepalive heal)", d.dialCount()) + } +} + +// waitFor polls cond until it is true or the timeout elapses. It is an eventual +// assertion with a bound, not a fixed sleep. +func waitFor(t *testing.T, timeout time.Duration, cond func() bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } +} diff --git a/internal/infrastructure/ssh/v2/dialer.go b/internal/infrastructure/ssh/v2/dialer.go new file mode 100644 index 0000000..f1064f8 --- /dev/null +++ b/internal/infrastructure/ssh/v2/dialer.go @@ -0,0 +1,221 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +// Dialer is the injection point that decides how a live connection to the +// target host is established. Implementations hide whether the path is direct +// or routed through one or more jump hosts; the rest of the package only sees a +// ready *ssh.Client plus a Closer for the whole chain. +type Dialer interface { + // Dial brings up a live connection to the target host, transparently + // traversing any intermediate jump hops. The returned io.Closer tears down + // the ENTIRE chain (target + every jump + any ssh-agent connection). It + // must honor ctx for cancellation and deadlines. + Dial(ctx context.Context) (*ssh.Client, io.Closer, error) + // Describe returns a human-readable description of the route for logs and + // error messages. + Describe() string +} + +// hostKeyDefaulter lets the Client push its host-key default into a Dialer that +// supports per-hop host-key resolution (the built-in route). It is unexported +// on purpose: third-party Dialers simply ignore the Client host-key options and +// own their verification policy entirely. +type hostKeyDefaulter interface { + setDefaultHostKey(ssh.HostKeyCallback) +} + +// Route builds a Dialer for a path of one or more hops. first is the entry +// point; more lists subsequent hops in travel order, and the LAST element is +// always the target host. A single argument means a direct connection, two +// means one jump, and so on. The (first, more...) signature guarantees at least +// one hop at compile time. +func Route(first Endpoint, more ...Endpoint) Dialer { + hops := make([]Endpoint, 0, 1+len(more)) + hops = append(hops, first) + hops = append(hops, more...) + return &route{hops: hops} +} + +// route is the built-in Dialer implementation produced by Route. +type route struct { + hops []Endpoint + defaultHostKey ssh.HostKeyCallback +} + +// setDefaultHostKey implements hostKeyDefaulter. +func (r *route) setDefaultHostKey(cb ssh.HostKeyCallback) { r.defaultHostKey = cb } + +// Describe renders the route as "user@host -> user@host -> ...". +func (r *route) Describe() string { + labels := make([]string, len(r.hops)) + for i, hop := range r.hops { + labels[i] = hop.label() + } + return strings.Join(labels, " -> ") +} + +// Dial establishes the full chain: it dials the first hop over TCP, then for +// every subsequent hop opens a forwarded connection from the previous hop and +// performs a fresh SSH handshake on top of it. On any failure every resource +// opened so far is closed before returning. +func (r *route) Dial(ctx context.Context) (cl *ssh.Client, closer io.Closer, err error) { + chain := &chainCloser{} + // Unwind everything on error so a partially-built chain never leaks. + defer func() { + if err != nil { + _ = chain.Close() + } + }() + + first := r.hops[0] + cfg, agentCloser, cfgErr := first.clientConfig(ctx, r.defaultHostKey) + if cfgErr != nil { + return nil, nil, fmt.Errorf("build config for %s: %w", first.label(), cfgErr) + } + chain.add(agentCloser) + + current, dialErr := dialSSH(ctx, first.addr(), cfg) + if dialErr != nil { + return nil, nil, fmt.Errorf("dial %s: %w", first.label(), dialErr) + } + chain.add(current) + + for _, hop := range r.hops[1:] { + hopCfg, hopAgentCloser, hopErr := hop.clientConfig(ctx, r.defaultHostKey) + if hopErr != nil { + return nil, nil, fmt.Errorf("build config for %s: %w", hop.label(), hopErr) + } + chain.add(hopAgentCloser) + + next, jumpErr := dialThroughJump(ctx, current, hop.addr()) + if jumpErr != nil { + return nil, nil, fmt.Errorf("dial %s via %s: %w", hop.label(), first.label(), jumpErr) + } + + hopClient, handshakeErr := handshakeOver(ctx, next, hop.addr(), hopCfg) + if handshakeErr != nil { + _ = next.Close() + return nil, nil, fmt.Errorf("handshake to %s: %w", hop.label(), handshakeErr) + } + chain.add(hopClient) + current = hopClient + } + + return current, chain, nil +} + +// dialSSH performs a context-aware TCP dial followed by an SSH handshake. The +// context bounds the TCP connect, and its deadline (if any) bounds the +// handshake; the deadline is cleared once the handshake succeeds. +func dialSSH(ctx context.Context, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + client, err := handshakeOver(ctx, conn, addr, cfg) + if err != nil { + _ = conn.Close() + return nil, err + } + return client, nil +} + +// handshakeOver runs the SSH client handshake on an existing net.Conn, honoring +// the context deadline during the handshake. +func handshakeOver(ctx context.Context, conn net.Conn, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, cfg) + if err != nil { + return nil, err + } + // Clear the handshake deadline so it does not bleed into later traffic. + _ = conn.SetDeadline(time.Time{}) + return ssh.NewClient(sshConn, chans, reqs), nil +} + +// dialThroughJump opens a forwarded TCP connection to addr from the jump client +// while respecting ctx. ssh.Client.Dial has no context variant, so the dial runs +// in a goroutine and ctx cancellation abandons (and later closes) the result +// without leaking the goroutine. +func dialThroughJump(ctx context.Context, jump *ssh.Client, addr string) (net.Conn, error) { + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, err := jump.Dial("tcp", addr) + ch <- result{conn: conn, err: err} + }() + + select { + case <-ctx.Done(): + go func() { + if r := <-ch; r.conn != nil { + _ = r.conn.Close() + } + }() + return nil, ctx.Err() + case r := <-ch: + return r.conn, r.err + } +} + +// chainCloser closes a set of resources in reverse order of registration, so +// the target host is torn down before the jump hops that carry it. nil entries +// are ignored, letting callers register optional ssh-agent closers +// unconditionally. +type chainCloser struct { + closers []io.Closer +} + +// add registers c for later closing. A nil closer is ignored. +func (cc *chainCloser) add(c io.Closer) { + if c != nil { + cc.closers = append(cc.closers, c) + } +} + +// Close closes every registered resource in reverse order and joins any errors. +func (cc *chainCloser) Close() error { + var errs []error + for i := len(cc.closers) - 1; i >= 0; i-- { + if err := cc.closers[i].Close(); err != nil && !isTransient(err) { + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil + } + return fmt.Errorf("close ssh chain: %w", errors.Join(errs...)) +} diff --git a/internal/infrastructure/ssh/v2/dialer_test.go b/internal/infrastructure/ssh/v2/dialer_test.go new file mode 100644 index 0000000..10bda67 --- /dev/null +++ b/internal/infrastructure/ssh/v2/dialer_test.go @@ -0,0 +1,152 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" +) + +func TestRouteHopsAndDescribe(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + first Endpoint + more []Endpoint + wantHops int + wantDesc string + }{ + { + name: "direct", + first: Endpoint{User: "root", Addr: "target"}, + wantHops: 1, + wantDesc: "root@target:22", + }, + { + name: "single jump", + first: Endpoint{User: "bastion", Addr: "jump:2222"}, + more: []Endpoint{{User: "root", Addr: "target"}}, + wantHops: 2, + wantDesc: "bastion@jump:2222 -> root@target:22", + }, + { + name: "two jumps preserve order", + first: Endpoint{User: "a", Addr: "h1"}, + more: []Endpoint{ + {User: "b", Addr: "h2"}, + {User: "c", Addr: "h3"}, + }, + wantHops: 3, + wantDesc: "a@h1:22 -> b@h2:22 -> c@h3:22", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + d := Route(tc.first, tc.more...) + r, ok := d.(*route) + if !ok { + t.Fatalf("Route returned %T, want *route", d) + } + if len(r.hops) != tc.wantHops { + t.Fatalf("hops = %d, want %d", len(r.hops), tc.wantHops) + } + if got := d.Describe(); got != tc.wantDesc { + t.Fatalf("Describe() = %q, want %q", got, tc.wantDesc) + } + }) + } +} + +// recordCloser records the order in which it is closed and can fail on demand. +type recordCloser struct { + id int + order *[]int + mu *sync.Mutex + err error +} + +func (c recordCloser) Close() error { + c.mu.Lock() + *c.order = append(*c.order, c.id) + c.mu.Unlock() + return c.err +} + +func TestChainCloserReverseOrderAndNilSkip(t *testing.T) { + t.Parallel() + + var order []int + var mu sync.Mutex + cc := &chainCloser{} + + cc.add(recordCloser{id: 1, order: &order, mu: &mu}) + cc.add(nil) // must be skipped without panicking + cc.add(recordCloser{id: 2, order: &order, mu: &mu}) + cc.add(recordCloser{id: 3, order: &order, mu: &mu}) + + if err := cc.Close(); err != nil { + t.Fatalf("Close() unexpected error: %v", err) + } + + want := []int{3, 2, 1} + if len(order) != len(want) { + t.Fatalf("close order = %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Fatalf("close order = %v, want %v", order, want) + } + } +} + +func TestChainCloserAggregatesErrors(t *testing.T) { + t.Parallel() + + var order []int + var mu sync.Mutex + boom := errors.New("close boom") + cc := &chainCloser{} + cc.add(recordCloser{id: 1, order: &order, mu: &mu, err: boom}) + cc.add(recordCloser{id: 2, order: &order, mu: &mu}) + + err := cc.Close() + if err == nil || !errors.Is(err, boom) { + t.Fatalf("Close() = %v, want error wrapping %v", err, boom) + } +} + +// transientCloser returns a transient error from Close; chainCloser must ignore +// it (an already-dead peer is not a close failure worth surfacing). +type transientCloser struct{} + +func (transientCloser) Close() error { return fmt.Errorf("read: %w", io.EOF) } + +func TestChainCloserIgnoresTransientCloseErrors(t *testing.T) { + t.Parallel() + + cc := &chainCloser{} + cc.add(transientCloser{}) + if err := cc.Close(); err != nil { + t.Fatalf("Close() = %v, want nil (transient close errors ignored)", err) + } +} diff --git a/internal/infrastructure/ssh/v2/endpoint.go b/internal/infrastructure/ssh/v2/endpoint.go new file mode 100644 index 0000000..56f3251 --- /dev/null +++ b/internal/infrastructure/ssh/v2/endpoint.go @@ -0,0 +1,178 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "path/filepath" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// Endpoint describes a single SSH host along a route: how to address it and how +// to authenticate to it. The zero value is not useful; at minimum User and Addr +// must be set, plus a usable credential (KeyPath, an ssh-agent, or both). +type Endpoint struct { + // User is the login name. + User string + // Addr is "host" or "host:port"; the default port is 22. + Addr string + // KeyPath is the path to a private key file. A leading "~" is expanded to + // the current user's home directory. It may be empty to rely solely on an + // ssh-agent. + KeyPath string + // Passphrase decrypts an encrypted KeyPath. It is optional: when empty the + // SSH_PASSPHRASE environment variable is consulted, and failing that the key + // is skipped in favor of the ssh-agent. + Passphrase string + // HostKey verifies the server's host key. When nil the Client-level callback + // (see WithHostKeyCallback) applies, defaulting to InsecureIgnoreHostKey. + HostKey ssh.HostKeyCallback +} + +// addr returns the dial address with a default :22 port when none is present. +func (e Endpoint) addr() string { + if e.Addr == "" { + return "" + } + if _, _, err := net.SplitHostPort(e.Addr); err == nil { + return e.Addr + } + return net.JoinHostPort(e.Addr, "22") +} + +// label is a short human-readable identity for logs and route descriptions. +func (e Endpoint) label() string { + return fmt.Sprintf("%s@%s", e.User, e.addr()) +} + +// clientConfig builds the ssh.ClientConfig for this endpoint and returns an +// io.Closer that owns any ssh-agent connection opened for authentication. The +// caller (the route's connection chain) is responsible for closing it so the +// agent socket is not leaked on every reconnect. The closer is nil when no +// agent connection was opened. +func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCallback) (*ssh.ClientConfig, io.Closer, error) { + var signers []ssh.Signer + + if e.KeyPath != "" { + keyPath, err := expandTilde(e.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("resolve key path %q: %w", e.KeyPath, err) + } + raw, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read private key %q: %w", keyPath, err) + } + signer, err := parseSigner(raw, e.Passphrase) + if err != nil { + return nil, nil, fmt.Errorf("parse private key %q: %w", keyPath, err) + } + if signer != nil { + signers = append(signers, signer) + } + } + + agentCloser := io.Closer(nil) + if sock := os.Getenv("SSH_AUTH_SOCK"); sock != "" { + var dialer net.Dialer + //nolint:gosec // G704: SSH_AUTH_SOCK is the standard, operator-controlled ssh-agent socket path. + if conn, err := dialer.DialContext(ctx, "unix", sock); err == nil { + if agentSigners, err := agent.NewClient(conn).Signers(); err == nil { + signers = append(signers, agentSigners...) + } + // The connection must stay open for the agent signers to sign; the + // route's chain closer owns and closes it. + agentCloser = conn + } + } + + if len(signers) == 0 { + return nil, nil, fmt.Errorf("no usable credentials for %s: set KeyPath or start an ssh-agent", e.label()) + } + + hostKey := e.HostKey + if hostKey == nil { + hostKey = defaultHostKey + } + if hostKey == nil { + //nolint:gosec // G106: last-resort default for ephemeral e2e VMs; overridable per Endpoint or via WithHostKeyCallback. + hostKey = ssh.InsecureIgnoreHostKey() + } + + cfg := &ssh.ClientConfig{ + User: e.User, + Auth: []ssh.AuthMethod{ssh.PublicKeys(signers...)}, + HostKeyCallback: hostKey, + Timeout: defaultDialTimeout, + } + return cfg, agentCloser, nil +} + +// parseSigner parses a private key, transparently handling passphrase-protected +// keys. When the key is encrypted but no passphrase is available (neither the +// explicit value nor SSH_PASSPHRASE), it returns (nil, nil) so the caller falls +// back to the ssh-agent. Passphrase protection is detected structurally via +// *ssh.PassphraseMissingError, not by inspecting error text. +func parseSigner(raw []byte, passphrase string) (ssh.Signer, error) { + signer, err := ssh.ParsePrivateKey(raw) + if err == nil { + return signer, nil + } + + var missing *ssh.PassphraseMissingError + if !errors.As(err, &missing) { + return nil, err + } + + pass := passphrase + if pass == "" { + pass = os.Getenv("SSH_PASSPHRASE") + } + if pass == "" { + // Encrypted key with no passphrase: defer to the ssh-agent fallback. + return nil, nil + } + + signer, err = ssh.ParsePrivateKeyWithPassphrase(raw, []byte(pass)) + if err != nil { + return nil, fmt.Errorf("decrypt private key with passphrase: %w", err) + } + return signer, nil +} + +// expandTilde expands a leading "~" or "~/" to the current user's home dir. +func expandTilde(path string) (string, error) { + if !strings.HasPrefix(path, "~") { + return path, nil + } + usr, err := user.Current() + if err != nil { + return "", fmt.Errorf("look up current user: %w", err) + } + if path == "~" { + return usr.HomeDir, nil + } + return filepath.Join(usr.HomeDir, strings.TrimPrefix(path, "~/")), nil +} diff --git a/internal/infrastructure/ssh/v2/endpoint_test.go b/internal/infrastructure/ssh/v2/endpoint_test.go new file mode 100644 index 0000000..89a7ebf --- /dev/null +++ b/internal/infrastructure/ssh/v2/endpoint_test.go @@ -0,0 +1,144 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestEndpointAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + want string + }{ + {name: "host only gets default port", addr: "example.com", want: "example.com:22"}, + {name: "host with port preserved", addr: "example.com:2222", want: "example.com:2222"}, + {name: "ipv4 with port", addr: "10.0.0.1:6443", want: "10.0.0.1:6443"}, + {name: "empty stays empty", addr: "", want: ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + e := Endpoint{User: "u", Addr: tc.addr} + if got := e.addr(); got != tc.want { + t.Fatalf("addr() = %q, want %q", got, tc.want) + } + }) + } +} + +func TestExpandTilde(t *testing.T) { + t.Parallel() + + t.Run("no tilde unchanged", func(t *testing.T) { + t.Parallel() + got, err := expandTilde("/etc/ssh/key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "/etc/ssh/key" { + t.Fatalf("got %q, want /etc/ssh/key", got) + } + }) + + t.Run("tilde expands to home", func(t *testing.T) { + t.Parallel() + got, err := expandTilde("~/.ssh/id_ed25519") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got == "~/.ssh/id_ed25519" { + t.Fatalf("tilde was not expanded: %q", got) + } + }) +} + +func TestParseSigner(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + plain, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + t.Fatalf("marshal plain key: %v", err) + } + plainPEM := pem.EncodeToMemory(plain) + + encrypted, err := ssh.MarshalPrivateKeyWithPassphrase(priv, "", []byte("s3cret")) + if err != nil { + t.Fatalf("marshal encrypted key: %v", err) + } + encryptedPEM := pem.EncodeToMemory(encrypted) + + t.Run("plain key parses", func(t *testing.T) { + signer, err := parseSigner(plainPEM, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatalf("expected a signer, got nil") + } + }) + + t.Run("encrypted without passphrase defers to agent", func(t *testing.T) { + t.Setenv("SSH_PASSPHRASE", "") + signer, err := parseSigner(encryptedPEM, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer != nil { + t.Fatalf("expected nil signer (agent fallback), got one") + } + }) + + t.Run("encrypted with explicit passphrase parses", func(t *testing.T) { + signer, err := parseSigner(encryptedPEM, "s3cret") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatalf("expected a signer, got nil") + } + }) + + t.Run("encrypted with env passphrase parses", func(t *testing.T) { + t.Setenv("SSH_PASSPHRASE", "s3cret") + signer, err := parseSigner(encryptedPEM, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatalf("expected a signer, got nil") + } + }) + + t.Run("garbage fails", func(t *testing.T) { + if _, err := parseSigner([]byte("not a key"), ""); err == nil { + t.Fatalf("expected error for garbage input") + } + }) +} diff --git a/internal/infrastructure/ssh/v2/errors.go b/internal/infrastructure/ssh/v2/errors.go new file mode 100644 index 0000000..7715426 --- /dev/null +++ b/internal/infrastructure/ssh/v2/errors.go @@ -0,0 +1,111 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" +) + +// errClosed is returned by the connection core once Close has been called. +var errClosed = errors.New("ssh: client is closed") + +// isTransient reports whether err denotes a recoverable transport failure that +// healing the SSH connection might fix (a dropped session, a reset peer, a +// timed-out read, …). Classification is done structurally via errors.Is and +// errors.As against standard error values and types — never by matching error +// text — so it stays correct as wrapping changes. +// +// Context cancellation (context.Canceled, context.DeadlineExceeded) is +// deliberately NOT transient: those mean the caller asked to stop, so retrying +// would ignore an explicit signal. +func isTransient(err error) bool { + if err == nil { + return false + } + + // Context cancellation outranks everything: it is an explicit stop signal, + // not a recoverable transport failure. Check it first because + // context.DeadlineExceeded also satisfies net.Error with Timeout()==true. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + // A clean or truncated EOF is the most common symptom of a session that + // died underneath us (the x/crypto/ssh mux surfaces the stored disconnect + // error, usually io.EOF, to pending channel/session opens). + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + // Operating on a connection/listener that was already closed by a peer or + // by our own reconnect. + if errors.Is(err, net.ErrClosed) { + return true + } + + // Low-level socket failures that a fresh dial typically recovers from. + if errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.ECONNREFUSED) || + errors.Is(err, syscall.ECONNABORTED) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ETIMEDOUT) || + errors.Is(err, syscall.EHOSTUNREACH) || + errors.Is(err, syscall.ENETUNREACH) { + return true + } + + // Any net.Error that reports a timeout (covers i/o timeouts that are not a + // bare syscall.ETIMEDOUT, e.g. deadline-driven failures). + var nerr net.Error + if errors.As(err, &nerr) && nerr.Timeout() { + return true + } + + return false +} + +// ExitError reports that a remote command ran to completion but exited with a +// non-zero status. It is intentionally distinct from a transport error: a +// non-zero exit is a normal program outcome, not a broken connection, so the +// operation core must never retry it. +// +// It is part of the contract for the future Run operation (see package docs); +// the connection core already treats *ExitError as non-transient because +// isTransient returns false for it. +type ExitError struct { + // Cmd is the command line that was executed. + Cmd string + // ExitCode is the process exit status reported by the remote end. + ExitCode int + // Stderr holds captured standard error, when available. + Stderr string + // Err is the underlying error returned by the SSH library, if any. + Err error +} + +// Error implements the error interface. +func (e *ExitError) Error() string { + return fmt.Sprintf("ssh: command %q exited with code %d", e.Cmd, e.ExitCode) +} + +// Unwrap exposes the underlying SSH library error for errors.Is/As. +func (e *ExitError) Unwrap() error { return e.Err } diff --git a/internal/infrastructure/ssh/v2/errors_test.go b/internal/infrastructure/ssh/v2/errors_test.go new file mode 100644 index 0000000..1c08efa --- /dev/null +++ b/internal/infrastructure/ssh/v2/errors_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "syscall" + "testing" +) + +func TestIsTransient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io.EOF", err: io.EOF, want: true}, + {name: "wrapped EOF", err: fmt.Errorf("dial: %w", io.EOF), want: true}, + {name: "unexpected EOF", err: io.ErrUnexpectedEOF, want: true}, + {name: "net closed", err: net.ErrClosed, want: true}, + {name: "wrapped net closed", err: fmt.Errorf("accept: %w", net.ErrClosed), want: true}, + {name: "ECONNRESET", err: syscall.ECONNRESET, want: true}, + {name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true}, + {name: "EPIPE", err: syscall.EPIPE, want: true}, + {name: "timeout net error", err: timeoutErr{}, want: true}, + {name: "context canceled", err: context.Canceled, want: false}, + {name: "context deadline", err: context.DeadlineExceeded, want: false}, + {name: "plain error", err: errors.New("boom"), want: false}, + {name: "exit error", err: &ExitError{Cmd: "false", ExitCode: 1}, want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isTransient(tc.err); got != tc.want { + t.Fatalf("isTransient(%v) = %v, want %v", tc.err, got, tc.want) + } + }) + } +} + +// timeoutErr is a net.Error that reports a timeout but is not a syscall errno, +// exercising the net.Error/Timeout() branch of the classifier. +type timeoutErr struct{} + +func (timeoutErr) Error() string { return "i/o timeout" } +func (timeoutErr) Timeout() bool { return true } +func (timeoutErr) Temporary() bool { return true } + +func TestExitErrorUnwrap(t *testing.T) { + t.Parallel() + + underlying := errors.New("session: exited") + exit := &ExitError{Cmd: "do-thing", ExitCode: 2, Stderr: "nope", Err: underlying} + + if !errors.Is(exit, underlying) { + t.Fatalf("errors.Is should find the wrapped error") + } + var target *ExitError + if !errors.As(error(exit), &target) { + t.Fatalf("errors.As should match *ExitError") + } + if target.ExitCode != 2 { + t.Fatalf("ExitCode = %d, want 2", target.ExitCode) + } +} diff --git a/internal/infrastructure/ssh/v2/options.go b/internal/infrastructure/ssh/v2/options.go new file mode 100644 index 0000000..1959ce8 --- /dev/null +++ b/internal/infrastructure/ssh/v2/options.go @@ -0,0 +1,108 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "log/slog" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/deckhouse/storage-e2e/internal/config" + "github.com/deckhouse/storage-e2e/internal/logger" +) + +// defaultDialTimeout bounds a single (re)connect attempt performed by the +// connection core. It is deliberately internal: callers shape overall +// patience through context deadlines and WithRetries, while this only caps one +// detached dial so a reconnect storm can never hang indefinitely. +const defaultDialTimeout = 30 * time.Second + +// options holds the resolved configuration for a Client. The zero value is not +// used directly; defaultOptions seeds sensible defaults that individual Option +// funcs then override. +type options struct { + keepalive time.Duration + retries int + log *slog.Logger + hostKey ssh.HostKeyCallback + dialTimeout time.Duration +} + +// defaultOptions returns the baseline configuration. Host key verification +// defaults to InsecureIgnoreHostKey because this package targets ephemeral e2e +// VMs whose host keys are not known ahead of time; this is a conscious default +// that WithHostKeyCallback overrides. +func defaultOptions() options { + return options{ + keepalive: 0, + retries: config.SSHRetryCount, + log: logger.GetLogger(), + //nolint:gosec // G106: ephemeral e2e VMs have no known host key; conscious default, overridable via WithHostKeyCallback. + hostKey: ssh.InsecureIgnoreHostKey(), + dialTimeout: defaultDialTimeout, + } +} + +// Option configures a Client. Options are applied in order; later options win. +type Option func(*options) + +// WithKeepalive enables a background keepalive probe at interval d. A non-zero +// interval starts a goroutine that sends "keepalive@openssh.com" and proactively +// heals the connection on failure. The zero value (default) disables keepalive. +func WithKeepalive(d time.Duration) Option { + return func(o *options) { o.keepalive = d } +} + +// WithRetries sets how many times an operation re-establishes the connection +// before giving up. Negative values are clamped to zero (no reconnect retries). +func WithRetries(n int) Option { + return func(o *options) { + if n < 0 { + n = 0 + } + o.retries = n + } +} + +// WithLogger sets the structured logger used for healing WARN messages and +// diagnostics. A nil logger is ignored so the default logger remains in place. +func WithLogger(l *slog.Logger) Option { + return func(o *options) { + if l != nil { + o.log = l + } + } +} + +// WithHostKeyCallback sets the host key verification callback used for every hop +// that does not carry its own Endpoint.HostKey. A nil callback is ignored. +func WithHostKeyCallback(cb ssh.HostKeyCallback) Option { + return func(o *options) { + if cb != nil { + o.hostKey = cb + } + } +} + +// WithInsecureIgnoreHostKey disables host key verification for hops without an +// explicit Endpoint.HostKey. This is the default, but the option exists so the +// intent can be made explicit at the call site. +func WithInsecureIgnoreHostKey() Option { + //nolint:gosec // G106: explicit opt-in to skip host key verification. + return func(o *options) { o.hostKey = ssh.InsecureIgnoreHostKey() } +} diff --git a/internal/infrastructure/ssh/v2/testserver_test.go b/internal/infrastructure/ssh/v2/testserver_test.go new file mode 100644 index 0000000..f487109 --- /dev/null +++ b/internal/infrastructure/ssh/v2/testserver_test.go @@ -0,0 +1,254 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "io" + "log/slog" + "net" + "strconv" + "sync" + "testing" + + "golang.org/x/crypto/ssh" +) + +// quietLogger returns a logger that discards output, keeping test logs clean. +func quietLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// testServer is an in-process SSH server on 127.0.0.1 used by tests. It accepts +// any client (NoClientAuth), answers keepalive global requests, and serves +// "direct-tcpip" channels by dialing the requested address and proxying bytes — +// enough to exercise tunnels end to end. dropConns force-closes live transports +// to simulate a dropped session. +type testServer struct { + ln net.Listener + cfg *ssh.ServerConfig + wg sync.WaitGroup + closeOnce sync.Once + + mu sync.Mutex + conns []net.Conn +} + +func newTestServer(t *testing.T) *testServer { + t.Helper() + + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate host key: %v", err) + } + signer, err := ssh.NewSignerFromSigner(priv) + if err != nil { + t.Fatalf("build host signer: %v", err) + } + + cfg := &ssh.ServerConfig{NoClientAuth: true} + cfg.AddHostKey(signer) + + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + s := &testServer{ln: ln, cfg: cfg} + s.wg.Add(1) + go s.acceptLoop() + t.Cleanup(s.Close) + return s +} + +func (s *testServer) addr() string { return s.ln.Addr().String() } + +func (s *testServer) acceptLoop() { + defer s.wg.Done() + for { + nConn, err := s.ln.Accept() + if err != nil { + return + } + s.mu.Lock() + s.conns = append(s.conns, nConn) + s.mu.Unlock() + + s.wg.Add(1) + go s.handleConn(nConn) + } +} + +func (s *testServer) handleConn(nConn net.Conn) { + defer s.wg.Done() + + sconn, chans, reqs, err := ssh.NewServerConn(nConn, s.cfg) + if err != nil { + _ = nConn.Close() + return + } + defer sconn.Close() + + go func() { + for req := range reqs { + if req.WantReply { + _ = req.Reply(true, nil) + } + } + }() + + for newCh := range chans { + if newCh.ChannelType() != "direct-tcpip" { + _ = newCh.Reject(ssh.UnknownChannelType, "only direct-tcpip is supported") + continue + } + go handleDirectTCPIP(newCh) + } +} + +// directTCPIPMsg is the extra data layout of a direct-tcpip channel open. +type directTCPIPMsg struct { + DestAddr string + DestPort uint32 + OrigAddr string + OrigPort uint32 +} + +func handleDirectTCPIP(newCh ssh.NewChannel) { + var msg directTCPIPMsg + if err := ssh.Unmarshal(newCh.ExtraData(), &msg); err != nil { + _ = newCh.Reject(ssh.ConnectionFailed, "bad direct-tcpip payload") + return + } + + target := net.JoinHostPort(msg.DestAddr, strconv.Itoa(int(msg.DestPort))) + var dialer net.Dialer + remote, err := dialer.DialContext(context.Background(), "tcp", target) + if err != nil { + _ = newCh.Reject(ssh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newCh.Accept() + if err != nil { + _ = remote.Close() + return + } + go ssh.DiscardRequests(reqs) + + go func() { + _, _ = io.Copy(ch, remote) + _ = ch.Close() + }() + go func() { + _, _ = io.Copy(remote, ch) + _ = remote.Close() + }() +} + +// dropConns force-closes all live transport connections, simulating a session +// drop (a Wi-Fi flap on the developer's laptop). +func (s *testServer) dropConns() { + s.mu.Lock() + defer s.mu.Unlock() + for _, c := range s.conns { + _ = c.Close() + } + s.conns = nil +} + +func (s *testServer) Close() { + s.closeOnce.Do(func() { + _ = s.ln.Close() + s.dropConns() + s.wg.Wait() + }) +} + +// serverDialer is a test Dialer that connects to a testServer. It counts dials +// and can gate each dial on a channel to make reconnect concurrency +// deterministic. +type serverDialer struct { + addr string + + mu sync.Mutex + dials int + gate chan struct{} +} + +func (d *serverDialer) Dial(ctx context.Context) (*ssh.Client, io.Closer, error) { + d.mu.Lock() + d.dials++ + gate := d.gate + d.mu.Unlock() + + if gate != nil { + <-gate + } + + client, err := dialSSH(ctx, d.addr, &ssh.ClientConfig{ + User: "test", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, nil, err + } + return client, client, nil +} + +func (d *serverDialer) Describe() string { return "test://" + d.addr } + +func (d *serverDialer) dialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dials +} + +func (d *serverDialer) setGate(gate chan struct{}) { + d.mu.Lock() + d.gate = gate + d.mu.Unlock() +} + +// newEchoServer starts a TCP echo server on 127.0.0.1 and returns its port. +func newEchoServer(t *testing.T) int { + t.Helper() + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("echo listen: %v", err) + } + t.Cleanup(func() { _ = ln.Close() }) + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + _, _ = io.Copy(c, c) + }(c) + } + }() + + return ln.Addr().(*net.TCPAddr).Port +} diff --git a/internal/infrastructure/ssh/v2/tunnel.go b/internal/infrastructure/ssh/v2/tunnel.go new file mode 100644 index 0000000..00eebbc --- /dev/null +++ b/internal/infrastructure/ssh/v2/tunnel.go @@ -0,0 +1,223 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "strconv" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +// acceptDeadline bounds each listener Accept so the serve loop re-checks its +// context promptly even when no client is connecting. +const acceptDeadline = 500 * time.Millisecond + +// Tunnel is a local TCP forward to a port on the target host. It listens on +// 127.0.0.1 on an automatically chosen free port and heals transparently: when +// the SSH session drops, the next forwarded connection re-opens it via the +// connection core and the listener keeps serving instead of dying. +type Tunnel struct { + // LocalPort is the chosen local port on 127.0.0.1. + LocalPort int + // RemotePort is the forwarded port on the target host. + RemotePort int + + listener net.Listener + cancel context.CancelFunc + wg sync.WaitGroup + closeOnce sync.Once + closeErr error +} + +// Tunnel forwards remotePort on the target host to a fresh local port on +// 127.0.0.1 and returns once the listener is up. The returned Tunnel serves +// until its Close is called or ctx is canceled. Establishing each forwarded +// connection is reconnect-aware and bounded by the Client's retry budget; every +// heal is logged at WARN. +func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("tunnel setup: %w", err) + } + + var lc net.ListenConfig + listener, err := lc.Listen(ctx, "tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("listen on local port: %w", err) + } + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + _ = listener.Close() + return nil, fmt.Errorf("unexpected listener address type %T", listener.Addr()) + } + localPort := tcpAddr.Port + + // The serve loop outlives the setup call, so derive a cancellable context + // from the caller's: caller cancellation stops the tunnel, and so does Close. + serveCtx, cancel := context.WithCancel(ctx) + + t := &Tunnel{ + LocalPort: localPort, + RemotePort: remotePort, + listener: listener, + cancel: cancel, + } + + t.wg.Add(1) + go t.serve(serveCtx, c.conn, c.retries, c.log) + + c.log.Info("ssh: tunnel established", + "local", t.LocalAddr(), "remote_port", remotePort, "route", c.conn.dialer.Describe()) + + return t, nil +} + +// LocalAddr returns the local "127.0.0.1:" address of the tunnel. +func (t *Tunnel) LocalAddr() string { + return "127.0.0.1:" + strconv.Itoa(t.LocalPort) +} + +// Close stops the tunnel: it cancels the serve loop, closes the listener, and +// waits for all in-flight connections to drain. It is idempotent and safe for +// concurrent use. It does not close the underlying SSH connection, which the +// owning Client manages. +func (t *Tunnel) Close() error { + t.closeOnce.Do(func() { + t.cancel() + t.closeErr = t.listener.Close() + t.wg.Wait() + }) + return t.closeErr +} + +// serve accepts local connections and forwards each one over the SSH connection. +// A short Accept deadline keeps the loop responsive to ctx; a dead session does +// not stop the loop — it is healed per connection in handle. +func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.Logger) { + defer t.wg.Done() + + for { + select { + case <-ctx.Done(): + return + default: + } + + if tcp, ok := t.listener.(*net.TCPListener); ok { + _ = tcp.SetDeadline(time.Now().Add(acceptDeadline)) + } + + local, err := t.listener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + continue + } + // The listener was closed by Close (not via ctx); stop serving. + return + } + + t.wg.Add(1) + go func() { + defer t.wg.Done() + t.handle(ctx, core, retries, local, log) + }() + } +} + +// handle forwards a single accepted local connection to the remote port. The +// remote dial is reconnect-aware: a transient failure heals the SSH connection +// and retries within the budget. Once both ends are connected, bytes are copied +// in both directions; closing the conns on completion or cancellation unblocks +// any read still in flight. +func (t *Tunnel) handle(ctx context.Context, core *conn, retries int, local net.Conn, log *slog.Logger) { + defer local.Close() + + remotePort := t.RemotePort + remote, err := withConn(ctx, core, retries, func(ctx context.Context, client *ssh.Client) (net.Conn, error) { + return dialChannel(ctx, client, "127.0.0.1:"+strconv.Itoa(remotePort)) + }) + if err != nil { + if ctx.Err() == nil { + log.Warn("ssh: tunnel forward failed", + "local", t.LocalAddr(), "remote_port", remotePort, "err", err) + } + return + } + defer remote.Close() + + // Closing both conns on cancellation unblocks the copy goroutines, which + // would otherwise sit in a blocking Read. + stop := make(chan struct{}) + defer close(stop) + go func() { + select { + case <-ctx.Done(): + _ = local.Close() + _ = remote.Close() + case <-stop: + } + }() + + done := make(chan struct{}, 2) + go func() { _, _ = io.Copy(remote, local); done <- struct{}{} }() + go func() { _, _ = io.Copy(local, remote); done <- struct{}{} }() + + // When one direction ends, close both ends to unblock the other. + <-done + _ = local.Close() + _ = remote.Close() + <-done +} + +// dialChannel opens a forwarded TCP connection to addr over the SSH client while +// respecting ctx. ssh.Client.Dial has no context variant, so the dial runs in a +// goroutine and ctx cancellation abandons (and later closes) the result without +// leaking the goroutine. +func dialChannel(ctx context.Context, client *ssh.Client, addr string) (net.Conn, error) { + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, err := client.Dial("tcp", addr) + ch <- result{conn: conn, err: err} + }() + + select { + case <-ctx.Done(): + go func() { + if r := <-ch; r.conn != nil { + _ = r.conn.Close() + } + }() + return nil, ctx.Err() + case r := <-ch: + return r.conn, r.err + } +} diff --git a/internal/infrastructure/ssh/v2/tunnel_test.go b/internal/infrastructure/ssh/v2/tunnel_test.go new file mode 100644 index 0000000..d615400 --- /dev/null +++ b/internal/infrastructure/ssh/v2/tunnel_test.go @@ -0,0 +1,253 @@ +/* +Copyright 2025 Flant JSC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ssh + +import ( + "context" + "net" + "testing" + "time" +) + +func newTestClient(t *testing.T, d Dialer, keepalive time.Duration) *Client { + t.Helper() + c, err := New(context.Background(), d, + WithLogger(quietLogger()), + WithKeepalive(keepalive), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + t.Cleanup(func() { _ = c.Close() }) + return c +} + +// dialTimeout dials addr with a bounded context, satisfying the noctx linter. +func dialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + var d net.Dialer + return d.DialContext(ctx, "tcp", addr) +} + +// roundtrip writes payload to addr and reads the echoed reply back. +func roundtrip(t *testing.T, addr, payload string) string { + t.Helper() + conn, err := dialTimeout(addr, 3*time.Second) + if err != nil { + t.Fatalf("dial tunnel %s: %v", addr, err) + } + defer conn.Close() + + _ = conn.SetDeadline(time.Now().Add(3 * time.Second)) + if _, err := conn.Write([]byte(payload)); err != nil { + t.Fatalf("write: %v", err) + } + buf := make([]byte, len(payload)) + if _, err := readFull(conn, buf); err != nil { + t.Fatalf("read: %v", err) + } + return string(buf) +} + +func readFull(conn net.Conn, buf []byte) (int, error) { + total := 0 + for total < len(buf) { + n, err := conn.Read(buf[total:]) + total += n + if err != nil { + return total, err + } + } + return total, nil +} + +func TestTunnelForwardsTraffic(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.Tunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("Tunnel: %v", err) + } + defer tun.Close() + + if tun.LocalPort == 0 { + t.Fatalf("expected a non-zero local port") + } + if tun.RemotePort != echoPort { + t.Fatalf("RemotePort = %d, want %d", tun.RemotePort, echoPort) + } + if got := tun.LocalAddr(); got == "" { + t.Fatalf("LocalAddr empty") + } + + if got := roundtrip(t, tun.LocalAddr(), "hello-tunnel"); got != "hello-tunnel" { + t.Fatalf("echo = %q, want hello-tunnel", got) + } +} + +func TestTunnelHealsAfterDroppedSession(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.Tunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("Tunnel: %v", err) + } + defer tun.Close() + + if got := roundtrip(t, tun.LocalAddr(), "before"); got != "before" { + t.Fatalf("echo before drop = %q, want before", got) + } + + // Simulate the SSH session dying mid-test. + srv.dropConns() + + // The next forwarded connection must transparently heal the session and + // keep serving. Retry the roundtrip until it works within a bound. + var lastErr error + deadline := time.Now().Add(8 * time.Second) + for time.Now().Before(deadline) { + got, err := tryRoundtrip(tun.LocalAddr(), "after") + if err == nil && got == "after" { + lastErr = nil + break + } + lastErr = err + time.Sleep(20 * time.Millisecond) + } + if lastErr != nil { + t.Fatalf("tunnel did not heal after dropped session: %v", lastErr) + } + if d.dialCount() < 2 { + t.Fatalf("dial count = %d, want >= 2 (healed)", d.dialCount()) + } +} + +func tryRoundtrip(addr, payload string) (string, error) { + conn, err := dialTimeout(addr, 2*time.Second) + if err != nil { + return "", err + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + if _, err := conn.Write([]byte(payload)); err != nil { + return "", err + } + buf := make([]byte, len(payload)) + if _, err := readFull(conn, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func TestTunnelCloseIsIdempotentAndStopsListener(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + tun, err := c.Tunnel(context.Background(), echoPort) + if err != nil { + t.Fatalf("Tunnel: %v", err) + } + addr := tun.LocalAddr() + + if err := tun.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := tun.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + + // The listener must be gone after Close. + waitFor(t, 2*time.Second, func() bool { + conn, err := dialTimeout(addr, 200*time.Millisecond) + if err != nil { + return true + } + _ = conn.Close() + return false + }) + if conn, err := dialTimeout(addr, 200*time.Millisecond); err == nil { + _ = conn.Close() + t.Fatalf("listener still accepting after Close") + } +} + +func TestTunnelStopsWhenContextCancelled(t *testing.T) { + t.Parallel() + echoPort := newEchoServer(t) + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c := newTestClient(t, d, 0) + + ctx, cancel := context.WithCancel(context.Background()) + tun, err := c.Tunnel(ctx, echoPort) + if err != nil { + t.Fatalf("Tunnel: %v", err) + } + defer tun.Close() + + addr := tun.LocalAddr() + cancel() + + waitFor(t, 2*time.Second, func() bool { + conn, err := dialTimeout(addr, 200*time.Millisecond) + if err != nil { + return true + } + _ = conn.Close() + return false + }) + if conn, err := dialTimeout(addr, 200*time.Millisecond); err == nil { + _ = conn.Close() + t.Fatalf("listener still accepting after context cancel") + } +} + +func TestNewRejectsNilDialer(t *testing.T) { + t.Parallel() + _, err := New(context.Background(), nil) + if err == nil { + t.Fatalf("expected error for nil dialer") + } +} + +func TestClientCloseIdempotent(t *testing.T) { + t.Parallel() + srv := newTestServer(t) + d := &serverDialer{addr: srv.addr()} + c, err := New(context.Background(), d, WithLogger(quietLogger())) + if err != nil { + t.Fatalf("New: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("first Close: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } +} From a4bef1b6724ccc8adb005da6e5a609bb7a7868d0 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Mon, 22 Jun 2026 11:51:22 +0300 Subject: [PATCH 02/10] refactor: remove redundant comments and improve code clarity in SSH client implementation Signed-off-by: Daniil Studenikin --- internal/infrastructure/ssh/v2/client.go | 10 ---- internal/infrastructure/ssh/v2/conn.go | 51 ------------------- internal/infrastructure/ssh/v2/conn_test.go | 9 ---- internal/infrastructure/ssh/v2/dialer.go | 43 ---------------- internal/infrastructure/ssh/v2/dialer_test.go | 1 - internal/infrastructure/ssh/v2/endpoint.go | 50 +++--------------- .../infrastructure/ssh/v2/endpoint_test.go | 22 -------- internal/infrastructure/ssh/v2/errors.go | 39 ++------------ internal/infrastructure/ssh/v2/errors_test.go | 2 - internal/infrastructure/ssh/v2/options.go | 31 ++--------- .../infrastructure/ssh/v2/testserver_test.go | 13 ----- internal/infrastructure/ssh/v2/tunnel.go | 41 +-------------- internal/infrastructure/ssh/v2/tunnel_test.go | 6 --- 13 files changed, 16 insertions(+), 302 deletions(-) diff --git a/internal/infrastructure/ssh/v2/client.go b/internal/infrastructure/ssh/v2/client.go index 56cd3dc..331b2bf 100644 --- a/internal/infrastructure/ssh/v2/client.go +++ b/internal/infrastructure/ssh/v2/client.go @@ -41,19 +41,12 @@ import ( "log/slog" ) -// Client is a self-healing SSH client over a Dialer-provided connection. It is -// safe for concurrent use; reconnects are transparent to callers. type Client struct { conn *conn retries int log *slog.Logger } -// New connects immediately over d, starts keepalive when enabled, and returns a -// ready Client. The context bounds the initial connection. If d implements the -// internal host-key defaulter (as the built-in Route does), the resolved -// host-key option is pushed into it so per-hop Endpoint.HostKey values take -// precedence over the Client-level default. func New(ctx context.Context, d Dialer, opts ...Option) (*Client, error) { if d == nil { return nil, errors.New("ssh: nil dialer") @@ -76,9 +69,6 @@ func New(ctx context.Context, d Dialer, opts ...Option) (*Client, error) { return &Client{conn: core, retries: o.retries, log: o.log}, nil } -// Close tears down the connection and its whole chain and stops keepalive. It is -// idempotent and safe for concurrent use. Open tunnels keep their listeners; the -// caller should Close those separately. func (c *Client) Close() error { return c.conn.Close() } diff --git a/internal/infrastructure/ssh/v2/conn.go b/internal/infrastructure/ssh/v2/conn.go index eb0c5a7..1c63207 100644 --- a/internal/infrastructure/ssh/v2/conn.go +++ b/internal/infrastructure/ssh/v2/conn.go @@ -29,19 +29,11 @@ import ( "golang.org/x/sync/singleflight" ) -// conn is the connection core shared by every high-level operation. It owns the -// current *ssh.Client together with the Closer for its whole chain and a -// monotonically increasing generation counter. All reconnect logic lives here so -// callers (Tunnel today, Run/Upload later) never see a reconnect: they ask for -// the live client, run their operation, and on a transient failure call withConn -// which heals the connection underneath them. type conn struct { dialer Dialer log *slog.Logger dialTimeout time.Duration - // flight deduplicates concurrent reconnects keyed by the failed generation, - // preventing a reconnect storm from tearing down a freshly healed link. flight singleflight.Group mu sync.Mutex @@ -50,14 +42,10 @@ type conn struct { gen uint64 closed bool - // keepalive lifecycle. kaCancel context.CancelFunc wg sync.WaitGroup } -// newConn establishes the initial connection and, when keepalive > 0, starts the -// background keepalive goroutine. The initial dial uses the caller's context so -// startup honors their deadline and cancellation. func newConn(ctx context.Context, d Dialer, o options) (*conn, error) { client, closer, err := d.Dial(ctx) if err != nil { @@ -74,35 +62,21 @@ func newConn(ctx context.Context, d Dialer, o options) (*conn, error) { } if o.keepalive > 0 { - // Keepalive must outlive the caller's setup context: the connection - // stays alive until Close, not until the New call returns. A fresh root - // context canceled by Close is therefore correct here. kaCtx, cancel := context.WithCancel(context.Background()) c.kaCancel = cancel c.wg.Add(1) - //nolint:contextcheck // intentional: keepalive lifetime is bound to Close, not the setup context. go c.keepaliveLoop(kaCtx, o.keepalive) } return c, nil } -// snapshot returns the current client and its generation under the lock. The -// generation lets callers tell refresh which connection failed them, so a -// concurrent heal is not duplicated. func (c *conn) snapshot() (client *ssh.Client, gen uint64) { c.mu.Lock() defer c.mu.Unlock() return c.client, c.gen } -// refresh re-establishes the connection that failed at generation failedGen and -// returns the now-current client and generation. Concurrent callers that failed -// on the same generation are collapsed into a single dial via singleflight; a -// caller whose failedGen is already stale (someone else healed first) gets the -// current client back without dialing. The actual Dial runs outside the lock and -// on a detached context with its own timeout, so one caller's cancellation can -// never abort the shared reconnect that others are waiting on. func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint64, error) { key := strconv.FormatUint(failedGen, 10) @@ -117,7 +91,6 @@ func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint c.mu.Unlock() return nil, errClosed } - // Someone already healed past failedGen — reuse the live client. if c.gen != failedGen { cur := healed{client: c.client, gen: c.gen} c.mu.Unlock() @@ -125,8 +98,6 @@ func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint } c.mu.Unlock() - // Detach from the caller's context so one cancellation does not abort the - // shared flight, but still bound the dial with our own timeout. dialCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.dialTimeout) defer cancel() @@ -148,11 +119,9 @@ func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint newGen := c.gen c.mu.Unlock() - // Tear down the dead chain outside the lock. if old != nil { _ = old.Close() } - // Self-healing must be loud, not silent. c.log.Warn("ssh: connection re-established", "route", c.dialer.Describe(), "generation", newGen) @@ -168,11 +137,6 @@ func (c *conn) refresh(ctx context.Context, failedGen uint64) (*ssh.Client, uint return r.client, r.gen, nil } -// keepaliveLoop periodically probes the connection. A failed probe is not just a -// reason to exit: it routes through refresh so the link is proactively healed via -// the same single path as a failed operation. Keepalive only narrows the window -// in which a dead connection is noticed; the authoritative "heal now" signal is -// still a failed operation. func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { defer c.wg.Done() @@ -204,15 +168,12 @@ func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { } } -// isClosed reports whether Close has been called. func (c *conn) isClosed() bool { c.mu.Lock() defer c.mu.Unlock() return c.closed } -// Close tears down the connection and its whole chain and stops the keepalive -// goroutine. It is idempotent and safe for concurrent use. func (c *conn) Close() error { c.mu.Lock() if c.closed { @@ -239,17 +200,6 @@ func (c *conn) Close() error { return nil } -// withConn runs op against the live client and heals the connection on transient -// failures, retrying up to retries times. It is the single reconnect-aware -// executor that every high-level operation builds on, so the reconnect policy -// lives in exactly one place. op MUST be safe to invoke more than once; callers -// whose operation is not idempotent (e.g. a command that already started running) -// must classify their own mid-flight failures as non-transient before they reach -// here. -// -// It is a generic free function rather than a method because Go methods cannot -// have type parameters; T lets callers return a typed result (a net.Conn for a -// tunnel dial, a session for Run, …) without boxing. func withConn[T any](ctx context.Context, c *conn, retries int, op func(context.Context, *ssh.Client) (T, error)) (T, error) { var zero T @@ -267,7 +217,6 @@ func withConn[T any](ctx context.Context, c *conn, retries int, op func(context. return result, nil } - // An explicit cancellation outranks any transient classification. if ctx.Err() != nil { return zero, ctx.Err() } diff --git a/internal/infrastructure/ssh/v2/conn_test.go b/internal/infrastructure/ssh/v2/conn_test.go index 3aa6aa3..c8ced89 100644 --- a/internal/infrastructure/ssh/v2/conn_test.go +++ b/internal/infrastructure/ssh/v2/conn_test.go @@ -65,8 +65,6 @@ func TestConnRefreshStaleGenerationDoesNotReconnect(t *testing.T) { d := &serverDialer{addr: srv.addr()} c := newTestConn(t, d, 0) - // Pretend we failed on generation 0, but the live connection is already at - // generation 1 — nobody should reconnect a healthy link. client, gen, err := c.refresh(context.Background(), 0) if err != nil { t.Fatalf("refresh: %v", err) @@ -92,7 +90,6 @@ func TestConnRefreshDeduplicatesConcurrentReconnects(t *testing.T) { t.Fatalf("setup dial count = %d, want 1", d.dialCount()) } - // Gate the next dial so all concurrent refreshers pile into one flight. gate := make(chan struct{}) d.setGate(gate) @@ -114,8 +111,6 @@ func TestConnRefreshDeduplicatesConcurrentReconnects(t *testing.T) { } close(start) - // Give the goroutines time to coalesce in singleflight before releasing the - // single dial. Polling the dial count avoids a fixed sleep race. waitFor(t, 2*time.Second, func() bool { return d.dialCount() == 2 }) close(gate) wg.Wait() @@ -226,7 +221,6 @@ func TestWithConnExhaustsRetries(t *testing.T) { if !errors.Is(err, io.EOF) { t.Fatalf("err = %v, want wrapped io.EOF", err) } - // retries=2 means: initial attempt + 2 heals = 3 op calls. if calls != 3 { t.Fatalf("op calls = %d, want 3", calls) } @@ -255,7 +249,6 @@ func TestKeepaliveHealsDroppedConnection(t *testing.T) { d := &serverDialer{addr: srv.addr()} _ = newTestConn(t, d, 100*time.Millisecond) - // Kill the live transport; the keepalive probe should notice and heal. srv.dropConns() waitFor(t, 5*time.Second, func() bool { return d.dialCount() >= 2 }) @@ -264,8 +257,6 @@ func TestKeepaliveHealsDroppedConnection(t *testing.T) { } } -// waitFor polls cond until it is true or the timeout elapses. It is an eventual -// assertion with a bound, not a fixed sleep. func waitFor(t *testing.T, timeout time.Duration, cond func() bool) { t.Helper() deadline := time.Now().Add(timeout) diff --git a/internal/infrastructure/ssh/v2/dialer.go b/internal/infrastructure/ssh/v2/dialer.go index f1064f8..37981e9 100644 --- a/internal/infrastructure/ssh/v2/dialer.go +++ b/internal/infrastructure/ssh/v2/dialer.go @@ -28,34 +28,15 @@ import ( "golang.org/x/crypto/ssh" ) -// Dialer is the injection point that decides how a live connection to the -// target host is established. Implementations hide whether the path is direct -// or routed through one or more jump hosts; the rest of the package only sees a -// ready *ssh.Client plus a Closer for the whole chain. type Dialer interface { - // Dial brings up a live connection to the target host, transparently - // traversing any intermediate jump hops. The returned io.Closer tears down - // the ENTIRE chain (target + every jump + any ssh-agent connection). It - // must honor ctx for cancellation and deadlines. Dial(ctx context.Context) (*ssh.Client, io.Closer, error) - // Describe returns a human-readable description of the route for logs and - // error messages. Describe() string } -// hostKeyDefaulter lets the Client push its host-key default into a Dialer that -// supports per-hop host-key resolution (the built-in route). It is unexported -// on purpose: third-party Dialers simply ignore the Client host-key options and -// own their verification policy entirely. type hostKeyDefaulter interface { setDefaultHostKey(ssh.HostKeyCallback) } -// Route builds a Dialer for a path of one or more hops. first is the entry -// point; more lists subsequent hops in travel order, and the LAST element is -// always the target host. A single argument means a direct connection, two -// means one jump, and so on. The (first, more...) signature guarantees at least -// one hop at compile time. func Route(first Endpoint, more ...Endpoint) Dialer { hops := make([]Endpoint, 0, 1+len(more)) hops = append(hops, first) @@ -63,16 +44,13 @@ func Route(first Endpoint, more ...Endpoint) Dialer { return &route{hops: hops} } -// route is the built-in Dialer implementation produced by Route. type route struct { hops []Endpoint defaultHostKey ssh.HostKeyCallback } -// setDefaultHostKey implements hostKeyDefaulter. func (r *route) setDefaultHostKey(cb ssh.HostKeyCallback) { r.defaultHostKey = cb } -// Describe renders the route as "user@host -> user@host -> ...". func (r *route) Describe() string { labels := make([]string, len(r.hops)) for i, hop := range r.hops { @@ -81,13 +59,8 @@ func (r *route) Describe() string { return strings.Join(labels, " -> ") } -// Dial establishes the full chain: it dials the first hop over TCP, then for -// every subsequent hop opens a forwarded connection from the previous hop and -// performs a fresh SSH handshake on top of it. On any failure every resource -// opened so far is closed before returning. func (r *route) Dial(ctx context.Context) (cl *ssh.Client, closer io.Closer, err error) { chain := &chainCloser{} - // Unwind everything on error so a partially-built chain never leaks. defer func() { if err != nil { _ = chain.Close() @@ -131,9 +104,6 @@ func (r *route) Dial(ctx context.Context) (cl *ssh.Client, closer io.Closer, err return current, chain, nil } -// dialSSH performs a context-aware TCP dial followed by an SSH handshake. The -// context bounds the TCP connect, and its deadline (if any) bounds the -// handshake; the deadline is cleared once the handshake succeeds. func dialSSH(ctx context.Context, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { var d net.Dialer conn, err := d.DialContext(ctx, "tcp", addr) @@ -148,8 +118,6 @@ func dialSSH(ctx context.Context, addr string, cfg *ssh.ClientConfig) (*ssh.Clie return client, nil } -// handshakeOver runs the SSH client handshake on an existing net.Conn, honoring -// the context deadline during the handshake. func handshakeOver(ctx context.Context, conn net.Conn, addr string, cfg *ssh.ClientConfig) (*ssh.Client, error) { if deadline, ok := ctx.Deadline(); ok { _ = conn.SetDeadline(deadline) @@ -158,15 +126,10 @@ func handshakeOver(ctx context.Context, conn net.Conn, addr string, cfg *ssh.Cli if err != nil { return nil, err } - // Clear the handshake deadline so it does not bleed into later traffic. _ = conn.SetDeadline(time.Time{}) return ssh.NewClient(sshConn, chans, reqs), nil } -// dialThroughJump opens a forwarded TCP connection to addr from the jump client -// while respecting ctx. ssh.Client.Dial has no context variant, so the dial runs -// in a goroutine and ctx cancellation abandons (and later closes) the result -// without leaking the goroutine. func dialThroughJump(ctx context.Context, jump *ssh.Client, addr string) (net.Conn, error) { type result struct { conn net.Conn @@ -191,22 +154,16 @@ func dialThroughJump(ctx context.Context, jump *ssh.Client, addr string) (net.Co } } -// chainCloser closes a set of resources in reverse order of registration, so -// the target host is torn down before the jump hops that carry it. nil entries -// are ignored, letting callers register optional ssh-agent closers -// unconditionally. type chainCloser struct { closers []io.Closer } -// add registers c for later closing. A nil closer is ignored. func (cc *chainCloser) add(c io.Closer) { if c != nil { cc.closers = append(cc.closers, c) } } -// Close closes every registered resource in reverse order and joins any errors. func (cc *chainCloser) Close() error { var errs []error for i := len(cc.closers) - 1; i >= 0; i-- { diff --git a/internal/infrastructure/ssh/v2/dialer_test.go b/internal/infrastructure/ssh/v2/dialer_test.go index 10bda67..0c41c9c 100644 --- a/internal/infrastructure/ssh/v2/dialer_test.go +++ b/internal/infrastructure/ssh/v2/dialer_test.go @@ -77,7 +77,6 @@ func TestRouteHopsAndDescribe(t *testing.T) { } } -// recordCloser records the order in which it is closed and can fail on demand. type recordCloser struct { id int order *[]int diff --git a/internal/infrastructure/ssh/v2/endpoint.go b/internal/infrastructure/ssh/v2/endpoint.go index 56f3251..01773b7 100644 --- a/internal/infrastructure/ssh/v2/endpoint.go +++ b/internal/infrastructure/ssh/v2/endpoint.go @@ -31,28 +31,14 @@ import ( "golang.org/x/crypto/ssh/agent" ) -// Endpoint describes a single SSH host along a route: how to address it and how -// to authenticate to it. The zero value is not useful; at minimum User and Addr -// must be set, plus a usable credential (KeyPath, an ssh-agent, or both). type Endpoint struct { - // User is the login name. - User string - // Addr is "host" or "host:port"; the default port is 22. - Addr string - // KeyPath is the path to a private key file. A leading "~" is expanded to - // the current user's home directory. It may be empty to rely solely on an - // ssh-agent. - KeyPath string - // Passphrase decrypts an encrypted KeyPath. It is optional: when empty the - // SSH_PASSPHRASE environment variable is consulted, and failing that the key - // is skipped in favor of the ssh-agent. + User string + Addr string + KeyPath string Passphrase string - // HostKey verifies the server's host key. When nil the Client-level callback - // (see WithHostKeyCallback) applies, defaulting to InsecureIgnoreHostKey. - HostKey ssh.HostKeyCallback + HostKey ssh.HostKeyCallback } -// addr returns the dial address with a default :22 port when none is present. func (e Endpoint) addr() string { if e.Addr == "" { return "" @@ -63,16 +49,10 @@ func (e Endpoint) addr() string { return net.JoinHostPort(e.Addr, "22") } -// label is a short human-readable identity for logs and route descriptions. func (e Endpoint) label() string { return fmt.Sprintf("%s@%s", e.User, e.addr()) } -// clientConfig builds the ssh.ClientConfig for this endpoint and returns an -// io.Closer that owns any ssh-agent connection opened for authentication. The -// caller (the route's connection chain) is responsible for closing it so the -// agent socket is not leaked on every reconnect. The closer is nil when no -// agent connection was opened. func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCallback) (*ssh.ClientConfig, io.Closer, error) { var signers []ssh.Signer @@ -97,13 +77,10 @@ func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCa agentCloser := io.Closer(nil) if sock := os.Getenv("SSH_AUTH_SOCK"); sock != "" { var dialer net.Dialer - //nolint:gosec // G704: SSH_AUTH_SOCK is the standard, operator-controlled ssh-agent socket path. if conn, err := dialer.DialContext(ctx, "unix", sock); err == nil { if agentSigners, err := agent.NewClient(conn).Signers(); err == nil { signers = append(signers, agentSigners...) } - // The connection must stay open for the agent signers to sign; the - // route's chain closer owns and closes it. agentCloser = conn } } @@ -117,7 +94,6 @@ func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCa hostKey = defaultHostKey } if hostKey == nil { - //nolint:gosec // G106: last-resort default for ephemeral e2e VMs; overridable per Endpoint or via WithHostKeyCallback. hostKey = ssh.InsecureIgnoreHostKey() } @@ -130,39 +106,27 @@ func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCa return cfg, agentCloser, nil } -// parseSigner parses a private key, transparently handling passphrase-protected -// keys. When the key is encrypted but no passphrase is available (neither the -// explicit value nor SSH_PASSPHRASE), it returns (nil, nil) so the caller falls -// back to the ssh-agent. Passphrase protection is detected structurally via -// *ssh.PassphraseMissingError, not by inspecting error text. func parseSigner(raw []byte, passphrase string) (ssh.Signer, error) { signer, err := ssh.ParsePrivateKey(raw) if err == nil { return signer, nil } - var missing *ssh.PassphraseMissingError - if !errors.As(err, &missing) { + if _, ok := errors.AsType[*ssh.PassphraseMissingError](err); !ok { return nil, err } - pass := passphrase - if pass == "" { - pass = os.Getenv("SSH_PASSPHRASE") - } - if pass == "" { - // Encrypted key with no passphrase: defer to the ssh-agent fallback. + if passphrase == "" { return nil, nil } - signer, err = ssh.ParsePrivateKeyWithPassphrase(raw, []byte(pass)) + signer, err = ssh.ParsePrivateKeyWithPassphrase(raw, []byte(passphrase)) if err != nil { return nil, fmt.Errorf("decrypt private key with passphrase: %w", err) } return signer, nil } -// expandTilde expands a leading "~" or "~/" to the current user's home dir. func expandTilde(path string) (string, error) { if !strings.HasPrefix(path, "~") { return path, nil diff --git a/internal/infrastructure/ssh/v2/endpoint_test.go b/internal/infrastructure/ssh/v2/endpoint_test.go index 89a7ebf..a72d17a 100644 --- a/internal/infrastructure/ssh/v2/endpoint_test.go +++ b/internal/infrastructure/ssh/v2/endpoint_test.go @@ -104,17 +104,6 @@ func TestParseSigner(t *testing.T) { } }) - t.Run("encrypted without passphrase defers to agent", func(t *testing.T) { - t.Setenv("SSH_PASSPHRASE", "") - signer, err := parseSigner(encryptedPEM, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if signer != nil { - t.Fatalf("expected nil signer (agent fallback), got one") - } - }) - t.Run("encrypted with explicit passphrase parses", func(t *testing.T) { signer, err := parseSigner(encryptedPEM, "s3cret") if err != nil { @@ -125,17 +114,6 @@ func TestParseSigner(t *testing.T) { } }) - t.Run("encrypted with env passphrase parses", func(t *testing.T) { - t.Setenv("SSH_PASSPHRASE", "s3cret") - signer, err := parseSigner(encryptedPEM, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if signer == nil { - t.Fatalf("expected a signer, got nil") - } - }) - t.Run("garbage fails", func(t *testing.T) { if _, err := parseSigner([]byte("not a key"), ""); err == nil { t.Fatalf("expected error for garbage input") diff --git a/internal/infrastructure/ssh/v2/errors.go b/internal/infrastructure/ssh/v2/errors.go index 7715426..f583b72 100644 --- a/internal/infrastructure/ssh/v2/errors.go +++ b/internal/infrastructure/ssh/v2/errors.go @@ -25,7 +25,6 @@ import ( "syscall" ) -// errClosed is returned by the connection core once Close has been called. var errClosed = errors.New("ssh: client is closed") // isTransient reports whether err denotes a recoverable transport failure that @@ -33,36 +32,23 @@ var errClosed = errors.New("ssh: client is closed") // timed-out read, …). Classification is done structurally via errors.Is and // errors.As against standard error values and types — never by matching error // text — so it stays correct as wrapping changes. -// -// Context cancellation (context.Canceled, context.DeadlineExceeded) is -// deliberately NOT transient: those mean the caller asked to stop, so retrying -// would ignore an explicit signal. func isTransient(err error) bool { if err == nil { return false } - // Context cancellation outranks everything: it is an explicit stop signal, - // not a recoverable transport failure. Check it first because - // context.DeadlineExceeded also satisfies net.Error with Timeout()==true. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } - // A clean or truncated EOF is the most common symptom of a session that - // died underneath us (the x/crypto/ssh mux surfaces the stored disconnect - // error, usually io.EOF, to pending channel/session opens). if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { return true } - // Operating on a connection/listener that was already closed by a peer or - // by our own reconnect. if errors.Is(err, net.ErrClosed) { return true } - // Low-level socket failures that a fresh dial typically recovers from. if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ECONNABORTED) || @@ -73,39 +59,22 @@ func isTransient(err error) bool { return true } - // Any net.Error that reports a timeout (covers i/o timeouts that are not a - // bare syscall.ETIMEDOUT, e.g. deadline-driven failures). - var nerr net.Error - if errors.As(err, &nerr) && nerr.Timeout() { + if nerr, ok := errors.AsType[net.Error](err); ok && nerr.Timeout() { return true } return false } -// ExitError reports that a remote command ran to completion but exited with a -// non-zero status. It is intentionally distinct from a transport error: a -// non-zero exit is a normal program outcome, not a broken connection, so the -// operation core must never retry it. -// -// It is part of the contract for the future Run operation (see package docs); -// the connection core already treats *ExitError as non-transient because -// isTransient returns false for it. type ExitError struct { - // Cmd is the command line that was executed. - Cmd string - // ExitCode is the process exit status reported by the remote end. + Cmd string ExitCode int - // Stderr holds captured standard error, when available. - Stderr string - // Err is the underlying error returned by the SSH library, if any. - Err error + Stderr string + Err error } -// Error implements the error interface. func (e *ExitError) Error() string { return fmt.Sprintf("ssh: command %q exited with code %d", e.Cmd, e.ExitCode) } -// Unwrap exposes the underlying SSH library error for errors.Is/As. func (e *ExitError) Unwrap() error { return e.Err } diff --git a/internal/infrastructure/ssh/v2/errors_test.go b/internal/infrastructure/ssh/v2/errors_test.go index 1c08efa..bd917d7 100644 --- a/internal/infrastructure/ssh/v2/errors_test.go +++ b/internal/infrastructure/ssh/v2/errors_test.go @@ -60,8 +60,6 @@ func TestIsTransient(t *testing.T) { } } -// timeoutErr is a net.Error that reports a timeout but is not a syscall errno, -// exercising the net.Error/Timeout() branch of the classifier. type timeoutErr struct{} func (timeoutErr) Error() string { return "i/o timeout" } diff --git a/internal/infrastructure/ssh/v2/options.go b/internal/infrastructure/ssh/v2/options.go index 1959ce8..2f802ee 100644 --- a/internal/infrastructure/ssh/v2/options.go +++ b/internal/infrastructure/ssh/v2/options.go @@ -26,15 +26,8 @@ import ( "github.com/deckhouse/storage-e2e/internal/logger" ) -// defaultDialTimeout bounds a single (re)connect attempt performed by the -// connection core. It is deliberately internal: callers shape overall -// patience through context deadlines and WithRetries, while this only caps one -// detached dial so a reconnect storm can never hang indefinitely. const defaultDialTimeout = 30 * time.Second -// options holds the resolved configuration for a Client. The zero value is not -// used directly; defaultOptions seeds sensible defaults that individual Option -// funcs then override. type options struct { keepalive time.Duration retries int @@ -43,33 +36,22 @@ type options struct { dialTimeout time.Duration } -// defaultOptions returns the baseline configuration. Host key verification -// defaults to InsecureIgnoreHostKey because this package targets ephemeral e2e -// VMs whose host keys are not known ahead of time; this is a conscious default -// that WithHostKeyCallback overrides. func defaultOptions() options { return options{ - keepalive: 0, - retries: config.SSHRetryCount, - log: logger.GetLogger(), - //nolint:gosec // G106: ephemeral e2e VMs have no known host key; conscious default, overridable via WithHostKeyCallback. + keepalive: 0, + retries: config.SSHRetryCount, + log: logger.GetLogger(), hostKey: ssh.InsecureIgnoreHostKey(), dialTimeout: defaultDialTimeout, } } -// Option configures a Client. Options are applied in order; later options win. type Option func(*options) -// WithKeepalive enables a background keepalive probe at interval d. A non-zero -// interval starts a goroutine that sends "keepalive@openssh.com" and proactively -// heals the connection on failure. The zero value (default) disables keepalive. func WithKeepalive(d time.Duration) Option { return func(o *options) { o.keepalive = d } } -// WithRetries sets how many times an operation re-establishes the connection -// before giving up. Negative values are clamped to zero (no reconnect retries). func WithRetries(n int) Option { return func(o *options) { if n < 0 { @@ -79,8 +61,6 @@ func WithRetries(n int) Option { } } -// WithLogger sets the structured logger used for healing WARN messages and -// diagnostics. A nil logger is ignored so the default logger remains in place. func WithLogger(l *slog.Logger) Option { return func(o *options) { if l != nil { @@ -89,8 +69,6 @@ func WithLogger(l *slog.Logger) Option { } } -// WithHostKeyCallback sets the host key verification callback used for every hop -// that does not carry its own Endpoint.HostKey. A nil callback is ignored. func WithHostKeyCallback(cb ssh.HostKeyCallback) Option { return func(o *options) { if cb != nil { @@ -99,9 +77,6 @@ func WithHostKeyCallback(cb ssh.HostKeyCallback) Option { } } -// WithInsecureIgnoreHostKey disables host key verification for hops without an -// explicit Endpoint.HostKey. This is the default, but the option exists so the -// intent can be made explicit at the call site. func WithInsecureIgnoreHostKey() Option { //nolint:gosec // G106: explicit opt-in to skip host key verification. return func(o *options) { o.hostKey = ssh.InsecureIgnoreHostKey() } diff --git a/internal/infrastructure/ssh/v2/testserver_test.go b/internal/infrastructure/ssh/v2/testserver_test.go index f487109..c51191b 100644 --- a/internal/infrastructure/ssh/v2/testserver_test.go +++ b/internal/infrastructure/ssh/v2/testserver_test.go @@ -30,16 +30,10 @@ import ( "golang.org/x/crypto/ssh" ) -// quietLogger returns a logger that discards output, keeping test logs clean. func quietLogger() *slog.Logger { return slog.New(slog.NewTextHandler(io.Discard, nil)) } -// testServer is an in-process SSH server on 127.0.0.1 used by tests. It accepts -// any client (NoClientAuth), answers keepalive global requests, and serves -// "direct-tcpip" channels by dialing the requested address and proxying bytes — -// enough to exercise tunnels end to end. dropConns force-closes live transports -// to simulate a dropped session. type testServer struct { ln net.Listener cfg *ssh.ServerConfig @@ -123,7 +117,6 @@ func (s *testServer) handleConn(nConn net.Conn) { } } -// directTCPIPMsg is the extra data layout of a direct-tcpip channel open. type directTCPIPMsg struct { DestAddr string DestPort uint32 @@ -163,8 +156,6 @@ func handleDirectTCPIP(newCh ssh.NewChannel) { }() } -// dropConns force-closes all live transport connections, simulating a session -// drop (a Wi-Fi flap on the developer's laptop). func (s *testServer) dropConns() { s.mu.Lock() defer s.mu.Unlock() @@ -182,9 +173,6 @@ func (s *testServer) Close() { }) } -// serverDialer is a test Dialer that connects to a testServer. It counts dials -// and can gate each dial on a channel to make reconnect concurrency -// deterministic. type serverDialer struct { addr string @@ -227,7 +215,6 @@ func (d *serverDialer) setGate(gate chan struct{}) { d.mu.Unlock() } -// newEchoServer starts a TCP echo server on 127.0.0.1 and returns its port. func newEchoServer(t *testing.T) int { t.Helper() var lc net.ListenConfig diff --git a/internal/infrastructure/ssh/v2/tunnel.go b/internal/infrastructure/ssh/v2/tunnel.go index 00eebbc..7bedf3f 100644 --- a/internal/infrastructure/ssh/v2/tunnel.go +++ b/internal/infrastructure/ssh/v2/tunnel.go @@ -30,18 +30,10 @@ import ( "golang.org/x/crypto/ssh" ) -// acceptDeadline bounds each listener Accept so the serve loop re-checks its -// context promptly even when no client is connecting. const acceptDeadline = 500 * time.Millisecond -// Tunnel is a local TCP forward to a port on the target host. It listens on -// 127.0.0.1 on an automatically chosen free port and heals transparently: when -// the SSH session drops, the next forwarded connection re-opens it via the -// connection core and the listener keeps serving instead of dying. type Tunnel struct { - // LocalPort is the chosen local port on 127.0.0.1. - LocalPort int - // RemotePort is the forwarded port on the target host. + LocalPort int RemotePort int listener net.Listener @@ -51,11 +43,6 @@ type Tunnel struct { closeErr error } -// Tunnel forwards remotePort on the target host to a fresh local port on -// 127.0.0.1 and returns once the listener is up. The returned Tunnel serves -// until its Close is called or ctx is canceled. Establishing each forwarded -// connection is reconnect-aware and bounded by the Client's retry budget; every -// heal is logged at WARN. func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { if err := ctx.Err(); err != nil { return nil, fmt.Errorf("tunnel setup: %w", err) @@ -73,8 +60,6 @@ func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { } localPort := tcpAddr.Port - // The serve loop outlives the setup call, so derive a cancellable context - // from the caller's: caller cancellation stops the tunnel, and so does Close. serveCtx, cancel := context.WithCancel(ctx) t := &Tunnel{ @@ -93,15 +78,10 @@ func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { return t, nil } -// LocalAddr returns the local "127.0.0.1:" address of the tunnel. func (t *Tunnel) LocalAddr() string { return "127.0.0.1:" + strconv.Itoa(t.LocalPort) } -// Close stops the tunnel: it cancels the serve loop, closes the listener, and -// waits for all in-flight connections to drain. It is idempotent and safe for -// concurrent use. It does not close the underlying SSH connection, which the -// owning Client manages. func (t *Tunnel) Close() error { t.closeOnce.Do(func() { t.cancel() @@ -111,9 +91,6 @@ func (t *Tunnel) Close() error { return t.closeErr } -// serve accepts local connections and forwards each one over the SSH connection. -// A short Accept deadline keeps the loop responsive to ctx; a dead session does -// not stop the loop — it is healed per connection in handle. func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.Logger) { defer t.wg.Done() @@ -133,11 +110,9 @@ func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.L if ctx.Err() != nil { return } - var ne net.Error - if errors.As(err, &ne) && ne.Timeout() { + if ne, ok := errors.AsType[net.Error](err); ok && ne.Timeout() { continue } - // The listener was closed by Close (not via ctx); stop serving. return } @@ -149,11 +124,6 @@ func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.L } } -// handle forwards a single accepted local connection to the remote port. The -// remote dial is reconnect-aware: a transient failure heals the SSH connection -// and retries within the budget. Once both ends are connected, bytes are copied -// in both directions; closing the conns on completion or cancellation unblocks -// any read still in flight. func (t *Tunnel) handle(ctx context.Context, core *conn, retries int, local net.Conn, log *slog.Logger) { defer local.Close() @@ -170,8 +140,6 @@ func (t *Tunnel) handle(ctx context.Context, core *conn, retries int, local net. } defer remote.Close() - // Closing both conns on cancellation unblocks the copy goroutines, which - // would otherwise sit in a blocking Read. stop := make(chan struct{}) defer close(stop) go func() { @@ -187,17 +155,12 @@ func (t *Tunnel) handle(ctx context.Context, core *conn, retries int, local net. go func() { _, _ = io.Copy(remote, local); done <- struct{}{} }() go func() { _, _ = io.Copy(local, remote); done <- struct{}{} }() - // When one direction ends, close both ends to unblock the other. <-done _ = local.Close() _ = remote.Close() <-done } -// dialChannel opens a forwarded TCP connection to addr over the SSH client while -// respecting ctx. ssh.Client.Dial has no context variant, so the dial runs in a -// goroutine and ctx cancellation abandons (and later closes) the result without -// leaking the goroutine. func dialChannel(ctx context.Context, client *ssh.Client, addr string) (net.Conn, error) { type result struct { conn net.Conn diff --git a/internal/infrastructure/ssh/v2/tunnel_test.go b/internal/infrastructure/ssh/v2/tunnel_test.go index d615400..eb378e4 100644 --- a/internal/infrastructure/ssh/v2/tunnel_test.go +++ b/internal/infrastructure/ssh/v2/tunnel_test.go @@ -36,7 +36,6 @@ func newTestClient(t *testing.T, d Dialer, keepalive time.Duration) *Client { return c } -// dialTimeout dials addr with a bounded context, satisfying the noctx linter. func dialTimeout(addr string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -44,7 +43,6 @@ func dialTimeout(addr string, timeout time.Duration) (net.Conn, error) { return d.DialContext(ctx, "tcp", addr) } -// roundtrip writes payload to addr and reads the echoed reply back. func roundtrip(t *testing.T, addr, payload string) string { t.Helper() conn, err := dialTimeout(addr, 3*time.Second) @@ -121,11 +119,8 @@ func TestTunnelHealsAfterDroppedSession(t *testing.T) { t.Fatalf("echo before drop = %q, want before", got) } - // Simulate the SSH session dying mid-test. srv.dropConns() - // The next forwarded connection must transparently heal the session and - // keep serving. Retry the roundtrip until it works within a bound. var lastErr error deadline := time.Now().Add(8 * time.Second) for time.Now().Before(deadline) { @@ -182,7 +177,6 @@ func TestTunnelCloseIsIdempotentAndStopsListener(t *testing.T) { t.Fatalf("second Close: %v", err) } - // The listener must be gone after Close. waitFor(t, 2*time.Second, func() bool { conn, err := dialTimeout(addr, 200*time.Millisecond) if err != nil { From e5eeb885c44da9f60c71f4f57d7a31c48d103694 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Mon, 22 Jun 2026 12:07:39 +0300 Subject: [PATCH 03/10] refactor: remove ExitError type and update documentation in errors.go Signed-off-by: Daniil Studenikin --- docs/ARCHITECTURE.md | 4 ++-- internal/infrastructure/ssh/v2/errors.go | 14 -------------- internal/infrastructure/ssh/v2/errors_test.go | 19 ------------------- 3 files changed, 2 insertions(+), 35 deletions(-) diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 5f26c21..093a5e9 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -61,7 +61,7 @@ storage-e2e/ │ │ ├── conn.go # connection core: snapshot/refresh/keepalive + withConn │ │ ├── dialer.go # Dialer interface, Route, chain closer │ │ ├── endpoint.go # Endpoint, auth, host/key resolution -│ │ ├── errors.go # transient classification, ExitError +│ │ ├── errors.go # transient classification │ │ ├── options.go # functional options │ │ └── tunnel.go # Tunnel, accept loop │ │ @@ -490,7 +490,7 @@ infrastructure/ssh/ ├── conn.go # connection core: snapshot/refresh/keepalive + withConn executor ├── dialer.go # Dialer interface, Route, chain closer ├── endpoint.go # Endpoint, auth, host/key resolution - ├── errors.go # transient classification, ExitError + ├── errors.go # transient classification ├── options.go # functional options └── tunnel.go # Tunnel, accept loop ``` diff --git a/internal/infrastructure/ssh/v2/errors.go b/internal/infrastructure/ssh/v2/errors.go index f583b72..6d42020 100644 --- a/internal/infrastructure/ssh/v2/errors.go +++ b/internal/infrastructure/ssh/v2/errors.go @@ -19,7 +19,6 @@ package ssh import ( "context" "errors" - "fmt" "io" "net" "syscall" @@ -65,16 +64,3 @@ func isTransient(err error) bool { return false } - -type ExitError struct { - Cmd string - ExitCode int - Stderr string - Err error -} - -func (e *ExitError) Error() string { - return fmt.Sprintf("ssh: command %q exited with code %d", e.Cmd, e.ExitCode) -} - -func (e *ExitError) Unwrap() error { return e.Err } diff --git a/internal/infrastructure/ssh/v2/errors_test.go b/internal/infrastructure/ssh/v2/errors_test.go index bd917d7..6303177 100644 --- a/internal/infrastructure/ssh/v2/errors_test.go +++ b/internal/infrastructure/ssh/v2/errors_test.go @@ -47,7 +47,6 @@ func TestIsTransient(t *testing.T) { {name: "context canceled", err: context.Canceled, want: false}, {name: "context deadline", err: context.DeadlineExceeded, want: false}, {name: "plain error", err: errors.New("boom"), want: false}, - {name: "exit error", err: &ExitError{Cmd: "false", ExitCode: 1}, want: false}, } for _, tc := range tests { @@ -65,21 +64,3 @@ type timeoutErr struct{} func (timeoutErr) Error() string { return "i/o timeout" } func (timeoutErr) Timeout() bool { return true } func (timeoutErr) Temporary() bool { return true } - -func TestExitErrorUnwrap(t *testing.T) { - t.Parallel() - - underlying := errors.New("session: exited") - exit := &ExitError{Cmd: "do-thing", ExitCode: 2, Stderr: "nope", Err: underlying} - - if !errors.Is(exit, underlying) { - t.Fatalf("errors.Is should find the wrapped error") - } - var target *ExitError - if !errors.As(error(exit), &target) { - t.Fatalf("errors.As should match *ExitError") - } - if target.ExitCode != 2 { - t.Fatalf("ExitCode = %d, want 2", target.ExitCode) - } -} From 6becdda02e018bd72884bae5213a74000029d92a Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Mon, 22 Jun 2026 15:30:44 +0300 Subject: [PATCH 04/10] refactor: improve keepalive handling and encapsulate listener closure logic Signed-off-by: Daniil Studenikin --- internal/infrastructure/ssh/v2/conn.go | 25 +++++++++++++++++++++++- internal/infrastructure/ssh/v2/tunnel.go | 13 +++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/internal/infrastructure/ssh/v2/conn.go b/internal/infrastructure/ssh/v2/conn.go index 1c63207..b8f5cb9 100644 --- a/internal/infrastructure/ssh/v2/conn.go +++ b/internal/infrastructure/ssh/v2/conn.go @@ -152,9 +152,12 @@ func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { if client == nil { continue } - if _, _, err := client.SendRequest("keepalive@openssh.com", true, nil); err == nil { + if err := probeKeepalive(ctx, client, interval); err == nil { continue } + if ctx.Err() != nil { + return + } c.log.Warn("ssh: keepalive failed, healing connection", "route", c.dialer.Describe()) if _, _, err := c.refresh(ctx, gen); err != nil { @@ -168,6 +171,26 @@ func (c *conn) keepaliveLoop(ctx context.Context, interval time.Duration) { } } +func probeKeepalive(ctx context.Context, client *ssh.Client, timeout time.Duration) error { + errc := make(chan error, 1) + go func() { + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + errc <- err + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return fmt.Errorf("ssh: keepalive probe timed out after %s", timeout) + case err := <-errc: + return err + } +} + func (c *conn) isClosed() bool { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/infrastructure/ssh/v2/tunnel.go b/internal/infrastructure/ssh/v2/tunnel.go index 7bedf3f..41db1e6 100644 --- a/internal/infrastructure/ssh/v2/tunnel.go +++ b/internal/infrastructure/ssh/v2/tunnel.go @@ -41,6 +41,9 @@ type Tunnel struct { wg sync.WaitGroup closeOnce sync.Once closeErr error + + lnCloseOnce sync.Once + lnCloseErr error } func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { @@ -85,14 +88,22 @@ func (t *Tunnel) LocalAddr() string { func (t *Tunnel) Close() error { t.closeOnce.Do(func() { t.cancel() - t.closeErr = t.listener.Close() + t.closeErr = t.closeListener() t.wg.Wait() }) return t.closeErr } +func (t *Tunnel) closeListener() error { + t.lnCloseOnce.Do(func() { + t.lnCloseErr = t.listener.Close() + }) + return t.lnCloseErr +} + func (t *Tunnel) serve(ctx context.Context, core *conn, retries int, log *slog.Logger) { defer t.wg.Done() + defer func() { _ = t.closeListener() }() for { select { From 5775165a1cf792fc34795641d1afa936f4653a96 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Wed, 24 Jun 2026 18:19:32 +0300 Subject: [PATCH 05/10] refactor: rename Tunnel method to OpenTunnel and update related references This commit renames the Tunnel method to OpenTunnel in the SSH client implementation and updates all related references in the codebase, including tests and configuration files. Additionally, it improves the configuration structure by renaming environment variable keys for clarity. Signed-off-by: Daniil Studenikin --- internal/infrastructure/ssh/v2/client.go | 2 +- internal/infrastructure/ssh/v2/tunnel.go | 2 +- internal/infrastructure/ssh/v2/tunnel_test.go | 16 ++-- internal/provisioning/dvp/config.go | 41 ++------ internal/provisioning/dvp/connection.go | 85 ---------------- internal/provisioning/dvp/kubeconfig.go | 51 ---------- internal/provisioning/dvp/provider.go | 96 +++++++++++++++---- pkg/cluster/cluster.go | 5 +- pkg/kubernetes/modules.go | 3 - 9 files changed, 99 insertions(+), 202 deletions(-) delete mode 100644 internal/provisioning/dvp/connection.go diff --git a/internal/infrastructure/ssh/v2/client.go b/internal/infrastructure/ssh/v2/client.go index 331b2bf..5c761a5 100644 --- a/internal/infrastructure/ssh/v2/client.go +++ b/internal/infrastructure/ssh/v2/client.go @@ -30,7 +30,7 @@ limitations under the License. // // c, _ := ssh.New(ctx, ssh.Route(jumpEp, targetEp)) // defer c.Close() -// t, _ := c.Tunnel(ctx, 6443) +// t, _ := c.OpenTunnel(ctx, 6443) // defer t.Close() // rest := &rest.Config{Host: "https://" + t.LocalAddr()} package ssh diff --git a/internal/infrastructure/ssh/v2/tunnel.go b/internal/infrastructure/ssh/v2/tunnel.go index 41db1e6..9f7fd10 100644 --- a/internal/infrastructure/ssh/v2/tunnel.go +++ b/internal/infrastructure/ssh/v2/tunnel.go @@ -46,7 +46,7 @@ type Tunnel struct { lnCloseErr error } -func (c *Client) Tunnel(ctx context.Context, remotePort int) (*Tunnel, error) { +func (c *Client) OpenTunnel(ctx context.Context, remotePort int) (*Tunnel, error) { if err := ctx.Err(); err != nil { return nil, fmt.Errorf("tunnel setup: %w", err) } diff --git a/internal/infrastructure/ssh/v2/tunnel_test.go b/internal/infrastructure/ssh/v2/tunnel_test.go index eb378e4..93877d5 100644 --- a/internal/infrastructure/ssh/v2/tunnel_test.go +++ b/internal/infrastructure/ssh/v2/tunnel_test.go @@ -81,9 +81,9 @@ func TestTunnelForwardsTraffic(t *testing.T) { d := &serverDialer{addr: srv.addr()} c := newTestClient(t, d, 0) - tun, err := c.Tunnel(context.Background(), echoPort) + tun, err := c.OpenTunnel(context.Background(), echoPort) if err != nil { - t.Fatalf("Tunnel: %v", err) + t.Fatalf("OpenTunnel: %v", err) } defer tun.Close() @@ -109,9 +109,9 @@ func TestTunnelHealsAfterDroppedSession(t *testing.T) { d := &serverDialer{addr: srv.addr()} c := newTestClient(t, d, 0) - tun, err := c.Tunnel(context.Background(), echoPort) + tun, err := c.OpenTunnel(context.Background(), echoPort) if err != nil { - t.Fatalf("Tunnel: %v", err) + t.Fatalf("OpenTunnel: %v", err) } defer tun.Close() @@ -164,9 +164,9 @@ func TestTunnelCloseIsIdempotentAndStopsListener(t *testing.T) { d := &serverDialer{addr: srv.addr()} c := newTestClient(t, d, 0) - tun, err := c.Tunnel(context.Background(), echoPort) + tun, err := c.OpenTunnel(context.Background(), echoPort) if err != nil { - t.Fatalf("Tunnel: %v", err) + t.Fatalf("OpenTunnel: %v", err) } addr := tun.LocalAddr() @@ -199,9 +199,9 @@ func TestTunnelStopsWhenContextCancelled(t *testing.T) { c := newTestClient(t, d, 0) ctx, cancel := context.WithCancel(context.Background()) - tun, err := c.Tunnel(ctx, echoPort) + tun, err := c.OpenTunnel(ctx, echoPort) if err != nil { - t.Fatalf("Tunnel: %v", err) + t.Fatalf("OpenTunnel: %v", err) } defer tun.Close() diff --git a/internal/provisioning/dvp/config.go b/internal/provisioning/dvp/config.go index e34fcc9..0c470b2 100644 --- a/internal/provisioning/dvp/config.go +++ b/internal/provisioning/dvp/config.go @@ -16,49 +16,24 @@ limitations under the License. package dvp -import ( - "fmt" - "os" -) +const apiServerRemotePort = 6445 type Config struct { SSHUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_USER,required"` SSHHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_HOST,required"` - SSHKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_KEY_PATH,required"` + SSHKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY_PATH,required"` SSHPassphrase string `env:"E2E_DVP_BASE_CLUSTER_SSH_PASSPHRASE"` - SSHJumpHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_HOST"` - SSHJumpUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_USER"` - SSHJumpKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_KEY_PATH"` + SSHJumpHost string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_HOST"` + SSHJumpUser string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_USER"` + SSHJumpKeyPath string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_PRIVATE_KEY_PATH"` + SSHJumpPassphrase string `env:"E2E_DVP_BASE_CLUSTER_SSH_JUMP_KEY_PASSPHRASE"` KubeConfigPath string `env:"E2E_DVP_BASE_CLUSTER_KUBECONFIG_PATH,required"` Namespace string `env:"E2E_DVP_BASE_CLUSTER_NAMESPACE" envDefault:"e2e-test-cluster"` } -func (c *Config) SetPassphrase() error { - if c.SSHPassphrase == "" { - return nil - } - if err := os.Setenv("SSH_PASSPHRASE", c.SSHPassphrase); err != nil { - return fmt.Errorf("failed to set SSH_PASSPHRASE: %w", err) - } - return nil -} - -func (c *Config) baseEndpoint() sshEndpoint { - ep := sshEndpoint{User: c.SSHUser, Host: c.SSHHost, KeyPath: c.SSHKeyPath} - if c.SSHJumpHost == "" { - return ep - } - - jump := sshEndpoint{User: c.SSHJumpUser, Host: c.SSHJumpHost, KeyPath: c.SSHJumpKeyPath} - if jump.User == "" { - jump.User = c.SSHUser - } - if jump.KeyPath == "" { - jump.KeyPath = c.SSHKeyPath - } - ep.Jump = &jump - return ep +func (c *Config) HasJumpHost() bool { + return c.SSHJumpUser != "" && c.SSHJumpHost != "" && c.SSHJumpKeyPath != "" } diff --git a/internal/provisioning/dvp/connection.go b/internal/provisioning/dvp/connection.go deleted file mode 100644 index 4bab332..0000000 --- a/internal/provisioning/dvp/connection.go +++ /dev/null @@ -1,85 +0,0 @@ -/* -Copyright 2026 Flant JSC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package dvp - -import ( - "context" - "errors" - "fmt" - - "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh" -) - -const apiServerRemotePort = "6445" - -type sshEndpoint struct { - User string - Host string - KeyPath string - Jump *sshEndpoint -} - -func (e sshEndpoint) dial() (ssh.SSHClient, error) { - if e.Jump != nil { - return ssh.NewClientWithJumpHost( - e.Jump.User, e.Jump.Host, e.Jump.KeyPath, - e.User, e.Host, e.KeyPath, - ) - } - return ssh.NewClient(e.User, e.Host, e.KeyPath) -} - -type clusterConnection struct { - ssh ssh.SSHClient - tunnel *ssh.TunnelInfo -} - -func openTunnel(ctx context.Context, ep sshEndpoint) (*clusterConnection, error) { - sshClient, err := ep.dial() - if err != nil { - return nil, fmt.Errorf("ssh dial %s@%s: %w", ep.User, ep.Host, err) - } - - conn := &clusterConnection{ssh: sshClient} - - conn.tunnel, err = sshClient.OpenTunnel(ctx, apiServerRemotePort) - if err != nil { - _ = conn.Close() - return nil, fmt.Errorf("establish API server tunnel: %w", err) - } - - return conn, nil -} - -func (c *clusterConnection) Close() error { - if c == nil { - return nil - } - - var errs []error - if c.tunnel != nil && c.tunnel.StopFunc != nil { - if err := c.tunnel.StopFunc(); err != nil { - errs = append(errs, fmt.Errorf("stop API server tunnel: %w", err)) - } - } - if c.ssh != nil { - if err := c.ssh.Close(); err != nil { - errs = append(errs, fmt.Errorf("close ssh client: %w", err)) - } - } - return errors.Join(errs...) -} diff --git a/internal/provisioning/dvp/kubeconfig.go b/internal/provisioning/dvp/kubeconfig.go index fd94755..89016d1 100644 --- a/internal/provisioning/dvp/kubeconfig.go +++ b/internal/provisioning/dvp/kubeconfig.go @@ -25,7 +25,6 @@ import ( "time" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" ) func readKubeconfig(path string) ([]byte, error) { @@ -58,57 +57,7 @@ func expandUserPath(path string) (string, error) { return filepath.Join(home, strings.TrimPrefix(expanded, "~/")), nil } -func loadKubeconfigViaTunnel(localPort int, kubeconfigDir, host, kubeconfigSrcPath string) (*rest.Config, string, error) { - raw, err := readKubeconfig(kubeconfigSrcPath) - if err != nil { - return nil, "", fmt.Errorf("load base cluster kubeconfig: %w", err) - } - - path, err := kubeconfigFilePath(kubeconfigDir, host) - if err != nil { - return nil, "", err - } - - server := fmt.Sprintf("https://127.0.0.1:%d", localPort) - cfg, err := buildKubeconfig(raw, server, path) - if err != nil { - return nil, "", fmt.Errorf("build kubeconfig: %w", err) - } - return cfg, path, nil -} - -func buildKubeconfig(raw []byte, server, path string) (*rest.Config, error) { - apiCfg, err := clientcmd.Load(raw) - if err != nil { - return nil, fmt.Errorf("parse kubeconfig: %w", err) - } - for _, cluster := range apiCfg.Clusters { - cluster.Server = server - } - - if writeErr := clientcmd.WriteToFile(*apiCfg, path); writeErr != nil { - return nil, fmt.Errorf("write kubeconfig %q: %w", path, writeErr) - } - - restCfg, err := clientcmd.NewDefaultClientConfig(*apiCfg, &clientcmd.ConfigOverrides{}).ClientConfig() - - if err != nil { - return nil, fmt.Errorf("build rest config: %w", err) - } - configureTunnelTimeouts(restCfg) - return restCfg, nil -} - -func kubeconfigFilePath(dir, host string) (string, error) { - if err := os.MkdirAll(dir, 0o700); err != nil { - return "", fmt.Errorf("create kubeconfig dir %q: %w", dir, err) - } - return filepath.Join(dir, fmt.Sprintf("kubeconfig-%s.yml", host)), nil -} - func configureTunnelTimeouts(cfg *rest.Config) { - cfg.Timeout = 2 * time.Minute - prev := cfg.WrapTransport cfg.WrapTransport = func(rt http.RoundTripper) http.RoundTripper { if prev != nil { diff --git a/internal/provisioning/dvp/provider.go b/internal/provisioning/dvp/provider.go index 8d4d8a1..a0ad26f 100644 --- a/internal/provisioning/dvp/provider.go +++ b/internal/provisioning/dvp/provider.go @@ -20,10 +20,15 @@ import ( "context" "fmt" "log/slog" + "time" "github.com/caarlos0/env/v11" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/deckhouse/storage-e2e/internal/config" + "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh/v2" "github.com/deckhouse/storage-e2e/pkg/clusterprovider" "github.com/deckhouse/storage-e2e/pkg/kubernetes" ) @@ -39,10 +44,6 @@ func NewDVPProvider(logger *slog.Logger, cfg *clusterprovider.ClusterConfig) (cl if err := env.Parse(dvpConf); err != nil { return nil, err } - err := dvpConf.SetPassphrase() - if err != nil { - return nil, err - } return &dvpProvider{ cfg: cfg, @@ -53,6 +54,63 @@ func NewDVPProvider(logger *slog.Logger, cfg *clusterprovider.ClusterConfig) (cl func (p *dvpProvider) Name() string { return clusterprovider.ModeDVP } +func (p *dvpProvider) buildSshClient(ctx context.Context) (*ssh.Client, error) { + var dialer ssh.Dialer + if p.dvpConf.HasJumpHost() { + dialer = ssh.Route(ssh.Endpoint{ + User: p.dvpConf.SSHJumpUser, + Addr: p.dvpConf.SSHJumpHost, + KeyPath: p.dvpConf.SSHJumpKeyPath, + Passphrase: p.dvpConf.SSHJumpPassphrase, + }, ssh.Endpoint{ + User: p.dvpConf.SSHUser, + Addr: p.dvpConf.SSHHost, + KeyPath: p.dvpConf.SSHKeyPath, + Passphrase: p.dvpConf.SSHPassphrase, + }) + } else { + dialer = ssh.Route(ssh.Endpoint{ + User: p.dvpConf.SSHUser, + Addr: p.dvpConf.SSHHost, + KeyPath: p.dvpConf.SSHKeyPath, + Passphrase: p.dvpConf.SSHPassphrase, + }) + } + + sshClient, sshNewErr := ssh.New(ctx, dialer) + if sshNewErr != nil { + return nil, fmt.Errorf("creating ssh client: %w", sshNewErr) + } + return sshClient, nil +} + +func (p *dvpProvider) buildRestConfig(tun *ssh.Tunnel) (*rest.Config, error) { + rawKubeconfig, readErr := readKubeconfig(p.dvpConf.KubeConfigPath) + if readErr != nil { + return nil, fmt.Errorf("reading kubeconfig: %w", readErr) + } + + apiCfg, err := clientcmd.Load(rawKubeconfig) + overrides := &clientcmd.ConfigOverrides{ + ClusterInfo: clientcmdapi.Cluster{ + Server: tun.LocalAddr(), + }, + Timeout: (2 * time.Minute).String(), + } + + if err != nil { + return nil, fmt.Errorf("parsing kubeconfig: %w", err) + } + + restConfig, clientConfigErr := clientcmd.NewDefaultClientConfig(*apiCfg, overrides).ClientConfig() + if clientConfigErr != nil { + return nil, fmt.Errorf("creating client config: %w", clientConfigErr) + } + + configureTunnelTimeouts(restConfig) + return restConfig, nil +} + func (p *dvpProvider) Bootstrap(ctx context.Context) error { clusterDef, err := config.LoadClusterDefinition(p.cfg.ClusterBootstrapConfigPath) if err != nil { @@ -70,26 +128,28 @@ func (p *dvpProvider) Bootstrap(ctx context.Context) error { "jumpHost", p.dvpConf.SSHJumpHost, "kubeconfigSource", p.dvpConf.KubeConfigPath, ) - conn, err := openTunnel(ctx, p.dvpConf.baseEndpoint()) - if err != nil { - return fmt.Errorf("open tunnel to DVP base cluster: %w", err) + + sshClient, sshNewErr := p.buildSshClient(ctx) + if sshNewErr != nil { + return fmt.Errorf("creating ssh client: %w", sshNewErr) + } + + tun, tunErr := sshClient.OpenTunnel(ctx, apiServerRemotePort) + + if tunErr != nil { + return fmt.Errorf("creating tunnel: %w", tunErr) } defer func() { - if cerr := conn.Close(); cerr != nil { - p.logger.Warn("close DVP base cluster connection", "err", cerr) + tunCloseErr := tun.Close() + if tunCloseErr != nil { + p.logger.Warn("failed to close tunnel", "err", tunCloseErr) } }() - kubeconfig, kubeconfigPath, err := loadKubeconfigViaTunnel( - conn.tunnel.LocalPort, config.E2ETempDir, p.dvpConf.SSHHost, p.dvpConf.KubeConfigPath, - ) - if err != nil { - return fmt.Errorf("build kubeconfig for DVP base cluster: %w", err) + kubeconfig, buildRestConfErr := p.buildRestConfig(tun) + if buildRestConfErr != nil { + return fmt.Errorf("creating rest config: %w", buildRestConfErr) } - p.logger.Info("connected to DVP base cluster", - "kubeconfig", kubeconfigPath, - "apiServer", kubeconfig.Host, - ) p.logger.Info("waiting for virtualization module to become ready", "timeout", config.ModuleCheckTimeout, diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 327f321..d06f002 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -39,6 +39,8 @@ import ( "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" + "github.com/deckhouse/virtualization/api/core/v1alpha2" + internalcluster "github.com/deckhouse/storage-e2e/internal/cluster" "github.com/deckhouse/storage-e2e/internal/config" "github.com/deckhouse/storage-e2e/internal/infrastructure/ssh" @@ -47,7 +49,6 @@ import ( "github.com/deckhouse/storage-e2e/internal/logger" "github.com/deckhouse/storage-e2e/pkg/kubernetes" "github.com/deckhouse/storage-e2e/pkg/testkit" - "github.com/deckhouse/virtualization/api/core/v1alpha2" ) // extraCommanderValues stores additional values to be passed to Commander cluster creation @@ -1607,7 +1608,7 @@ func CleanupTestCluster(ctx context.Context, resources *TestClusterResources) er } } } else { - // Tunnel already exists, use it + // OpenTunnel already exists, use it logger.Success("Base cluster tunnel already exists") baseTunnel = resources.BaseTunnelInfo cleanupKubeconfig = resources.BaseKubeconfig diff --git a/pkg/kubernetes/modules.go b/pkg/kubernetes/modules.go index 896ab8f..2285f83 100644 --- a/pkg/kubernetes/modules.go +++ b/pkg/kubernetes/modules.go @@ -590,9 +590,6 @@ const moduleReadyPollInterval = 2 * time.Second // - On timeout the error carries the last observed phase and the IsReady // condition message so a stuck module is diagnosable from logs alone. func WaitForModuleReady(ctx context.Context, kubeconfig *rest.Config, moduleName string, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - var lastPhase, lastCondition string // ready re-reads the module and reports whether it has converged, recording From 15fdf63a9082ea8e22de564668f7311a393dc248 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Fri, 26 Jun 2026 11:33:01 +0300 Subject: [PATCH 06/10] fix Signed-off-by: Daniil Studenikin --- internal/provisioning/dvp/config_test.go | 133 ----------------------- 1 file changed, 133 deletions(-) delete mode 100644 internal/provisioning/dvp/config_test.go diff --git a/internal/provisioning/dvp/config_test.go b/internal/provisioning/dvp/config_test.go deleted file mode 100644 index 6963a4f..0000000 --- a/internal/provisioning/dvp/config_test.go +++ /dev/null @@ -1,133 +0,0 @@ -/* -Copyright 2026 Flant JSC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package dvp - -import ( - "testing" -) - -func TestConfigBaseEndpoint(t *testing.T) { - tests := []struct { - name string - cfg Config - want sshEndpoint - }{ - { - name: "no jump host", - cfg: Config{ - SSHUser: "deckhouse", - SSHHost: "10.0.0.1", - SSHKeyPath: "/keys/id_rsa", - }, - want: sshEndpoint{ - User: "deckhouse", - Host: "10.0.0.1", - KeyPath: "/keys/id_rsa", - Jump: nil, - }, - }, - { - name: "jump host fully specified", - cfg: Config{ - SSHUser: "deckhouse", - SSHHost: "10.0.0.1", - SSHKeyPath: "/keys/target", - SSHJumpHost: "jump.example.com", - SSHJumpUser: "jumper", - SSHJumpKeyPath: "/keys/jump", - }, - want: sshEndpoint{ - User: "deckhouse", - Host: "10.0.0.1", - KeyPath: "/keys/target", - Jump: &sshEndpoint{ - User: "jumper", - Host: "jump.example.com", - KeyPath: "/keys/jump", - }, - }, - }, - { - name: "jump host inherits user and key from target", - cfg: Config{ - SSHUser: "deckhouse", - SSHHost: "10.0.0.1", - SSHKeyPath: "/keys/target", - SSHJumpHost: "jump.example.com", - }, - want: sshEndpoint{ - User: "deckhouse", - Host: "10.0.0.1", - KeyPath: "/keys/target", - Jump: &sshEndpoint{ - User: "deckhouse", - Host: "jump.example.com", - KeyPath: "/keys/target", - }, - }, - }, - { - name: "jump host inherits only missing fields", - cfg: Config{ - SSHUser: "deckhouse", - SSHHost: "10.0.0.1", - SSHKeyPath: "/keys/target", - SSHJumpHost: "jump.example.com", - SSHJumpUser: "jumper", - }, - want: sshEndpoint{ - User: "deckhouse", - Host: "10.0.0.1", - KeyPath: "/keys/target", - Jump: &sshEndpoint{ - User: "jumper", - Host: "jump.example.com", - KeyPath: "/keys/target", - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.cfg.baseEndpoint() - - if got.User != tt.want.User || got.Host != tt.want.Host || got.KeyPath != tt.want.KeyPath { - t.Errorf("endpoint = {User:%q Host:%q KeyPath:%q}, want {User:%q Host:%q KeyPath:%q}", - got.User, got.Host, got.KeyPath, tt.want.User, tt.want.Host, tt.want.KeyPath) - } - - switch { - case tt.want.Jump == nil && got.Jump != nil: - t.Errorf("Jump = %+v, want nil", got.Jump) - case tt.want.Jump != nil && got.Jump == nil: - t.Fatalf("Jump = nil, want %+v", tt.want.Jump) - case tt.want.Jump != nil && got.Jump != nil: - if got.Jump.User != tt.want.Jump.User || - got.Jump.Host != tt.want.Jump.Host || - got.Jump.KeyPath != tt.want.Jump.KeyPath { - t.Errorf("Jump = {User:%q Host:%q KeyPath:%q}, want {User:%q Host:%q KeyPath:%q}", - got.Jump.User, got.Jump.Host, got.Jump.KeyPath, - tt.want.Jump.User, tt.want.Jump.Host, tt.want.Jump.KeyPath) - } - if got.Jump.Jump != nil { - t.Errorf("Jump.Jump = %+v, want nil (no nested jump chain)", got.Jump.Jump) - } - } - }) - } -} From d38ed36f25d4ba2366d5dcea4cea08c04a4f2abd Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Fri, 26 Jun 2026 11:43:57 +0300 Subject: [PATCH 07/10] refactor: update keepalive context creation and rename buildSshClient method Signed-off-by: Daniil Studenikin --- docs/WORKLOG.md | 3 +++ internal/infrastructure/ssh/v2/conn.go | 2 +- internal/provisioning/dvp/provider.go | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/WORKLOG.md b/docs/WORKLOG.md index f255462..4112dc8 100644 --- a/docs/WORKLOG.md +++ b/docs/WORKLOG.md @@ -224,3 +224,6 @@ All notable changes to this repository are documented here. New entries are appe `log.Fatalf` instead of glued key/value arguments. - **Bugfix** `.github/workflows/e2e.yml`: checkout `storage-e2e` into `_storage-e2e` in the `run-tests` job before invoking `.github/scripts/e2e-run-tests.sh`. +- **Bugfix** `internal/infrastructure/ssh/v2/conn.go` (`newConn`): derive keepalive context via + `context.WithCancel(context.WithoutCancel(ctx))` instead of `context.Background()` to satisfy `contextcheck` while + keeping the loop lifetime tied to the connection (still cancelled in `Close`). diff --git a/internal/infrastructure/ssh/v2/conn.go b/internal/infrastructure/ssh/v2/conn.go index b8f5cb9..f6a441f 100644 --- a/internal/infrastructure/ssh/v2/conn.go +++ b/internal/infrastructure/ssh/v2/conn.go @@ -62,7 +62,7 @@ func newConn(ctx context.Context, d Dialer, o options) (*conn, error) { } if o.keepalive > 0 { - kaCtx, cancel := context.WithCancel(context.Background()) + kaCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) c.kaCancel = cancel c.wg.Add(1) go c.keepaliveLoop(kaCtx, o.keepalive) diff --git a/internal/provisioning/dvp/provider.go b/internal/provisioning/dvp/provider.go index a0ad26f..7ea776c 100644 --- a/internal/provisioning/dvp/provider.go +++ b/internal/provisioning/dvp/provider.go @@ -54,7 +54,7 @@ func NewDVPProvider(logger *slog.Logger, cfg *clusterprovider.ClusterConfig) (cl func (p *dvpProvider) Name() string { return clusterprovider.ModeDVP } -func (p *dvpProvider) buildSshClient(ctx context.Context) (*ssh.Client, error) { +func (p *dvpProvider) buildSSHClient(ctx context.Context) (*ssh.Client, error) { var dialer ssh.Dialer if p.dvpConf.HasJumpHost() { dialer = ssh.Route(ssh.Endpoint{ @@ -129,7 +129,7 @@ func (p *dvpProvider) Bootstrap(ctx context.Context) error { "kubeconfigSource", p.dvpConf.KubeConfigPath, ) - sshClient, sshNewErr := p.buildSshClient(ctx) + sshClient, sshNewErr := p.buildSSHClient(ctx) if sshNewErr != nil { return fmt.Errorf("creating ssh client: %w", sshNewErr) } From 2c6fdf0d234ab1671c2e3b83bbfa21e4a1454923 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Fri, 26 Jun 2026 11:48:52 +0300 Subject: [PATCH 08/10] refactor: replace direct use of InsecureIgnoreHostKey with a wrapper function Signed-off-by: Daniil Studenikin --- internal/infrastructure/ssh/v2/endpoint.go | 2 +- internal/infrastructure/ssh/v2/options.go | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/internal/infrastructure/ssh/v2/endpoint.go b/internal/infrastructure/ssh/v2/endpoint.go index 01773b7..1678042 100644 --- a/internal/infrastructure/ssh/v2/endpoint.go +++ b/internal/infrastructure/ssh/v2/endpoint.go @@ -94,7 +94,7 @@ func (e Endpoint) clientConfig(ctx context.Context, defaultHostKey ssh.HostKeyCa hostKey = defaultHostKey } if hostKey == nil { - hostKey = ssh.InsecureIgnoreHostKey() + hostKey = insecureIgnoreHostKey() } cfg := &ssh.ClientConfig{ diff --git a/internal/infrastructure/ssh/v2/options.go b/internal/infrastructure/ssh/v2/options.go index 2f802ee..23fafd7 100644 --- a/internal/infrastructure/ssh/v2/options.go +++ b/internal/infrastructure/ssh/v2/options.go @@ -41,11 +41,15 @@ func defaultOptions() options { keepalive: 0, retries: config.SSHRetryCount, log: logger.GetLogger(), - hostKey: ssh.InsecureIgnoreHostKey(), + hostKey: insecureIgnoreHostKey(), dialTimeout: defaultDialTimeout, } } +func insecureIgnoreHostKey() ssh.HostKeyCallback { + return ssh.InsecureIgnoreHostKey() //nolint:gosec // G106: deliberate, see doc comment. +} + type Option func(*options) func WithKeepalive(d time.Duration) Option { @@ -78,6 +82,5 @@ func WithHostKeyCallback(cb ssh.HostKeyCallback) Option { } func WithInsecureIgnoreHostKey() Option { - //nolint:gosec // G106: explicit opt-in to skip host key verification. - return func(o *options) { o.hostKey = ssh.InsecureIgnoreHostKey() } + return func(o *options) { o.hostKey = insecureIgnoreHostKey() } } From 0d9245c752da665878d68557570ea63fbcfe2901 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Fri, 26 Jun 2026 12:38:07 +0300 Subject: [PATCH 09/10] refactor: rename environment variables for SSH credentials in e2e workflow Signed-off-by: Daniil Studenikin --- .github/scripts/e2e-prepare-creds.sh | 10 +++++----- .github/workflows/e2e.yml | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/scripts/e2e-prepare-creds.sh b/.github/scripts/e2e-prepare-creds.sh index 6b7b009..a5337ad 100755 --- a/.github/scripts/e2e-prepare-creds.sh +++ b/.github/scripts/e2e-prepare-creds.sh @@ -4,8 +4,8 @@ # GITHUB_ENV. Never echoes secret values. # # Inputs (env): -# E2E_SSH_PRIVATE_KEY SSH private key contents (required) -# E2E_CLUSTER_KUBECONFIG base64-encoded kubeconfig (required) +# E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY SSH private key contents (required) +# E2E_DVP_BASE_CLUSTER_KUBECONFIG base64-encoded kubeconfig (required) # GITHUB_ENV file to append env exports to (required) # GITHUB_WORKSPACE workspace root to prune (optional) # RUNNER_TEMP dir for temp files (falls back to TMPDIR, then /tmp) @@ -16,14 +16,14 @@ tmp_dir="${RUNNER_TEMP:-${TMPDIR:-/tmp}}" ssh_key_path="$(mktemp "${tmp_dir%/}/e2e_ssh_key.XXXXXX")" kubeconfig_path="$(mktemp "${tmp_dir%/}/e2e_kubeconfig.XXXXXX")" -printf '%s\n' "${E2E_SSH_PRIVATE_KEY:?E2E_SSH_PRIVATE_KEY is required}" >"$ssh_key_path" +printf '%s\n' "${E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY:?E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY is required}" >"$ssh_key_path" chmod 600 "$ssh_key_path" -printf '%s' "${E2E_CLUSTER_KUBECONFIG:?E2E_CLUSTER_KUBECONFIG is required}" | base64 -d >"$kubeconfig_path" +printf '%s' "${E2E_DVP_BASE_CLUSTER_KUBECONFIG:?E2E_DVP_BASE_CLUSTER_KUBECONFIG is required}" | base64 -d >"$kubeconfig_path" chmod 600 "$kubeconfig_path" { - echo "E2E_DVP_BASE_CLUSTER_SSH_KEY_PATH=${ssh_key_path}" + echo "E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY_PATH=${ssh_key_path}" echo "E2E_DVP_BASE_CLUSTER_KUBECONFIG_PATH=${kubeconfig_path}" } >>"${GITHUB_ENV:?GITHUB_ENV is required}" diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index b40d302..b9a13b2 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -122,8 +122,8 @@ jobs: - name: Prepare credentials env: - E2E_SSH_PRIVATE_KEY: ${{ secrets.E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY }} - E2E_CLUSTER_KUBECONFIG: ${{ secrets.E2E_DVP_BASE_CLUSTER_KUBECONFIG }} + E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY: ${{ secrets.E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY }} + E2E_DVP_BASE_CLUSTER_KUBECONFIG: ${{ secrets.E2E_DVP_BASE_CLUSTER_KUBECONFIG }} run: bash _storage-e2e/.github/scripts/e2e-prepare-creds.sh - name: Bootstrap cluster From 17b5a2a6ddc1eb98b064fa0173b150ae9576bda2 Mon Sep 17 00:00:00 2001 From: Daniil Studenikin Date: Fri, 26 Jun 2026 12:39:19 +0300 Subject: [PATCH 10/10] refactor: rename environment variables for SSH credentials in e2e workflow Signed-off-by: Daniil Studenikin --- .github/workflows/e2e.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index b9a13b2..9323a7b 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -229,8 +229,8 @@ jobs: - name: Prepare credentials env: - E2E_SSH_PRIVATE_KEY: ${{ secrets.E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY }} - E2E_CLUSTER_KUBECONFIG: ${{ secrets.E2E_DVP_BASE_CLUSTER_KUBECONFIG }} + E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY: ${{ secrets.E2E_DVP_BASE_CLUSTER_SSH_PRIVATE_KEY }} + E2E_DVP_BASE_CLUSTER_KUBECONFIG: ${{ secrets.E2E_DVP_BASE_CLUSTER_KUBECONFIG }} run: bash _storage-e2e/.github/scripts/e2e-prepare-creds.sh - name: Remove cluster