Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions browser.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"os/exec"
"runtime"
Expand All @@ -9,16 +10,16 @@ import (
// openBrowser attempts to open url in the user's default browser.
// Returns an error if launching the browser fails, but callers should
// always print the URL as a fallback regardless of the error.
func openBrowser(url string) error {
func openBrowser(ctx context.Context, url string) error {
var cmd *exec.Cmd

switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
cmd = exec.CommandContext(ctx, "open", url)
case "windows":
cmd = exec.Command("cmd", "/c", "start", url)
cmd = exec.CommandContext(ctx, "cmd", "/c", "start", url)
default:
cmd = exec.Command("xdg-open", url)
cmd = exec.CommandContext(ctx, "xdg-open", url)
}

if err := cmd.Start(); err != nil {
Expand Down
12 changes: 6 additions & 6 deletions browser_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
// - (storage, true, nil) on success
// - (nil, false, nil) when openBrowser() fails — caller should fall back to Device Code Flow
// - (nil, false, err) on a hard error (CSRF mismatch, token exchange failure, etc.)
func performBrowserFlow() (*TokenStorage, bool, error) {
func performBrowserFlow(ctx context.Context) (*TokenStorage, bool, error) {
state, err := generateState()
if err != nil {
return nil, false, fmt.Errorf("failed to generate state: %w", err)
Expand All @@ -34,7 +34,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
fmt.Println("Step 1: Opening browser for authorization...")
fmt.Printf("\n %s\n\n", authURL)

if err := openBrowser(authURL); err != nil {
if err := openBrowser(ctx, authURL); err != nil {
// Browser failed to open — signal the caller to fall back immediately.
fmt.Printf("Could not open browser: %v\n", err)
return nil, false, nil
Expand All @@ -43,7 +43,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
fmt.Println("Browser opened. Please complete authorization in your browser.")
fmt.Printf("Step 2: Waiting for callback on http://localhost:%d/callback ...\n", callbackPort)

code, err := startCallbackServer(callbackPort, state)
code, err := startCallbackServer(ctx, callbackPort, state)
if err != nil {
if errors.Is(err, ErrCallbackTimeout) {
// User opened the browser but didn't complete authorization in time.
Expand All @@ -59,7 +59,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
fmt.Println("Authorization code received!")

fmt.Println("Step 3: Exchanging authorization code for tokens...")
storage, err := exchangeCode(code, pkce.Verifier)
storage, err := exchangeCode(ctx, code, pkce.Verifier)
if err != nil {
return nil, false, fmt.Errorf("token exchange failed: %w", err)
}
Expand Down Expand Up @@ -88,8 +88,8 @@ func buildAuthURL(state string, pkce *PKCEParams) string {
}

// exchangeCode exchanges an authorization code for access + refresh tokens.
func exchangeCode(code, codeVerifier string) (*TokenStorage, error) {
ctx, cancel := context.WithTimeout(context.Background(), tokenExchangeTimeout)
func exchangeCode(ctx context.Context, code, codeVerifier string) (*TokenStorage, error) {
ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout)
defer cancel()

data := url.Values{}
Expand Down
8 changes: 4 additions & 4 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type callbackResult struct {
// and returns the authorization code (or an error).
//
// The server shuts itself down after the first request.
func startCallbackServer(port int, expectedState string) (string, error) {
func startCallbackServer(ctx context.Context, port int, expectedState string) (string, error) {
resultCh := make(chan callbackResult, 1)

var once sync.Once
Expand Down Expand Up @@ -80,7 +80,7 @@ func startCallbackServer(port int, expectedState string) (string, error) {
WriteTimeout: 10 * time.Second,
}

ln, err := net.Listen("tcp", srv.Addr)
ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", srv.Addr)
if err != nil {
return "", fmt.Errorf("failed to start callback server on port %d: %w", port, err)
}
Expand All @@ -90,9 +90,9 @@ func startCallbackServer(port int, expectedState string) (string, error) {
}()

defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
_ = srv.Shutdown(shutdownCtx)
}()

select {
Expand Down
27 changes: 15 additions & 12 deletions callback_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -11,11 +12,13 @@ import (

// startCallbackServerAsync starts the callback server in a goroutine and
// returns a channel that will receive the authorization code (or error string).
func startCallbackServerAsync(t *testing.T, port int, state string) chan string {
func startCallbackServerAsync(
t *testing.T, ctx context.Context, port int, state string,
) chan string {
t.Helper()
ch := make(chan string, 1)
go func() {
code, err := startCallbackServer(port, state)
code, err := startCallbackServer(ctx, port, state)
if err != nil {
ch <- "ERROR:" + err.Error()
} else {
Expand All @@ -31,13 +34,13 @@ func TestCallbackServer_Success(t *testing.T) {
const port = 19101
state := "test-state-success"

ch := startCallbackServerAsync(t, port, state)
ch := startCallbackServerAsync(t, context.Background(), port, state)

callbackURL := fmt.Sprintf(
"http://127.0.0.1:%d/callback?code=mycode123&state=%s",
port, state,
)
resp, err := http.Get(callbackURL) //nolint:noctx
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
if err != nil {
t.Fatalf("GET callback failed: %v", err)
}
Expand Down Expand Up @@ -65,13 +68,13 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
const port = 19102
state := "expected-state"

ch := startCallbackServerAsync(t, port, state)
ch := startCallbackServerAsync(t, context.Background(), port, state)

callbackURL := fmt.Sprintf(
"http://127.0.0.1:%d/callback?code=mycode&state=wrong-state",
port,
)
resp, err := http.Get(callbackURL) //nolint:noctx
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
if err != nil {
t.Fatalf("GET callback failed: %v", err)
}
Expand All @@ -96,13 +99,13 @@ func TestCallbackServer_OAuthError(t *testing.T) {
const port = 19103
state := "state-for-error"

ch := startCallbackServerAsync(t, port, state)
ch := startCallbackServerAsync(t, context.Background(), port, state)

callbackURL := fmt.Sprintf(
"http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s",
port, state,
)
resp, err := http.Get(callbackURL) //nolint:noctx
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
if err != nil {
t.Fatalf("GET callback failed: %v", err)
}
Expand Down Expand Up @@ -130,14 +133,14 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
const port = 19105
state := "test-state-double"

ch := startCallbackServerAsync(t, port, state)
ch := startCallbackServerAsync(t, context.Background(), port, state)

url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=mycode&state=%s", port, state)

done := make(chan error, 2)
for range 2 {
go func() {
resp, err := http.Get(url) //nolint:noctx
resp, err := http.Get(url) //nolint:noctx,gosec
if err == nil {
resp.Body.Close()
}
Expand Down Expand Up @@ -167,13 +170,13 @@ func TestCallbackServer_MissingCode(t *testing.T) {
const port = 19104
state := "state-for-missing-code"

ch := startCallbackServerAsync(t, port, state)
ch := startCallbackServerAsync(t, context.Background(), port, state)

callbackURL := fmt.Sprintf(
"http://127.0.0.1:%d/callback?state=%s",
port, state,
)
resp, err := http.Get(callbackURL) //nolint:noctx
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
if err != nil {
t.Fatalf("GET callback failed: %v", err)
}
Expand Down
6 changes: 4 additions & 2 deletions detect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"net"
"os"
Expand All @@ -23,7 +24,7 @@ type BrowserAvailability struct {
// This function never attempts to open a browser itself; it only inspects
// the environment. Callers that pass the check should still handle
// openBrowser() failures as a secondary fallback.
func checkBrowserAvailability(port int) BrowserAvailability {
func checkBrowserAvailability(ctx context.Context, port int) BrowserAvailability {
// Stage 1a: SSH without X11/Wayland forwarding.
// SSH_TTY / SSH_CLIENT / SSH_CONNECTION indicate a remote shell.
// If a display is also present (X11 forwarding), the browser can still open.
Expand All @@ -44,7 +45,8 @@ func checkBrowserAvailability(port int) BrowserAvailability {

// Stage 2: Verify the callback port can be bound.
// A busy port means the redirect server cannot start.
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
lc := &net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return BrowserAvailability{
false,
Expand Down
23 changes: 12 additions & 11 deletions detect_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"net"
"testing"
)
Expand All @@ -12,7 +13,7 @@ func TestCheckBrowserAvailability_SSH_NoDisplay(t *testing.T) {
t.Setenv("DISPLAY", "")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(18888)
avail := checkBrowserAvailability(context.Background(), 18888)

if avail.Available {
t.Error("expected browser unavailable in SSH session without display")
Expand All @@ -29,7 +30,7 @@ func TestCheckBrowserAvailability_SSHClient_NoDisplay(t *testing.T) {
t.Setenv("DISPLAY", "")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(18888)
avail := checkBrowserAvailability(context.Background(), 18888)

if avail.Available {
t.Error("expected browser unavailable when SSH_CLIENT set and no display")
Expand All @@ -43,7 +44,7 @@ func TestCheckBrowserAvailability_SSHConnection_NoDisplay(t *testing.T) {
t.Setenv("DISPLAY", "")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(18888)
avail := checkBrowserAvailability(context.Background(), 18888)

if avail.Available {
t.Error("expected browser unavailable when SSH_CONNECTION set and no display")
Expand All @@ -58,14 +59,14 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) {

// Use a port that is definitely free (bind to :0 and get the port,
// then close it; the brief gap is acceptable for a unit test).
ln, err := net.Listen("tcp", "127.0.0.1:0")
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
if err != nil {
t.Skip("cannot allocate test port")
}
port := ln.Addr().(*net.TCPAddr).Port
ln.Close()

avail := checkBrowserAvailability(port)
avail := checkBrowserAvailability(context.Background(), port)

// X11 forwarding over SSH should be detected as browser-capable
// (DISPLAY is set, port is free).
Expand All @@ -76,7 +77,7 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) {

func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {
// Bind a port and keep it busy during the test.
ln, err := net.Listen("tcp", "127.0.0.1:0")
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
if err != nil {
t.Skip("cannot bind test port")
}
Expand All @@ -92,7 +93,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {
t.Setenv("DISPLAY", ":0")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(port)
avail := checkBrowserAvailability(context.Background(), port)

if avail.Available {
t.Errorf("expected browser unavailable when port %d is busy", port)
Expand All @@ -104,7 +105,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {

func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
// Find a free port.
ln, err := net.Listen("tcp", "127.0.0.1:0")
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
if err != nil {
t.Skip("cannot allocate test port")
}
Expand All @@ -117,7 +118,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
t.Setenv("DISPLAY", ":0")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(port)
avail := checkBrowserAvailability(context.Background(), port)

if !avail.Available {
t.Errorf(
Expand All @@ -128,7 +129,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
}

func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
if err != nil {
t.Skip("cannot allocate test port")
}
Expand All @@ -141,7 +142,7 @@ func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) {
t.Setenv("DISPLAY", ":0")
t.Setenv("WAYLAND_DISPLAY", "")

avail := checkBrowserAvailability(port)
avail := checkBrowserAvailability(context.Background(), port)

if avail.Available && avail.Reason != "" {
t.Errorf("expected empty reason when browser is available, got: %s", avail.Reason)
Expand Down
4 changes: 2 additions & 2 deletions filelock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestConcurrentLocks(t *testing.T) {
concurrent--
mu.Unlock()

lock.release()
_ = lock.release()
}(i)
}

Expand All @@ -89,5 +89,5 @@ func TestStaleLockRemoval(t *testing.T) {
if err != nil {
t.Fatalf("acquireFileLock() with stale lock: %v", err)
}
lock.release()
_ = lock.release()
}
Loading
Loading