From 020501744cc4c8b04e13adb9a974c73c3b598aff Mon Sep 17 00:00:00 2001 From: appleboy Date: Fri, 20 Feb 2026 09:52:01 +0800 Subject: [PATCH] refactor: enable context-aware operations and graceful shutdown throughout - Refactor functions throughout the codebase to accept and propagate context.Context for improved cancellation support - Add signal-based context initialization for graceful shutdown in main - Move application logic into a run function that returns an exit code for easier control flow - Use net.ListenConfig with context in place of net.Listen to support context-aware socket binding - Update browser open, callback server, token exchange, token verification, and refresh operations to be context-aware - Update all related unit tests to provide context explicitly and use ListenConfig for port binding - Replace plain lock.release calls with error-ignoring variants to avoid unused return value errors - Add error handling for JSON encoding failures in all HTTP test handlers - Introduce a helper function for error condition OAuth device flow tests to reduce duplicate test code - Minor test improvements including commenting, error messages, and use of constants over literals for tokens Signed-off-by: appleboy --- browser.go | 9 ++-- browser_flow.go | 12 ++--- callback.go | 8 +-- callback_test.go | 27 ++++++----- detect.go | 6 ++- detect_test.go | 23 ++++----- filelock_test.go | 4 +- main.go | 35 +++++++++----- main_test.go | 14 ++++-- polling_test.go | 123 ++++++++++++++++++++++++----------------------- tokens.go | 2 +- 11 files changed, 144 insertions(+), 119 deletions(-) diff --git a/browser.go b/browser.go index 7582915..fe72b4d 100644 --- a/browser.go +++ b/browser.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os/exec" "runtime" @@ -9,16 +10,16 @@ import ( // openBrowser attempts to open url in the user's default browser. // Returns an error if launching the browser fails, but callers should // always print the URL as a fallback regardless of the error. -func openBrowser(url string) error { +func openBrowser(ctx context.Context, url string) error { var cmd *exec.Cmd switch runtime.GOOS { case "darwin": - cmd = exec.Command("open", url) + cmd = exec.CommandContext(ctx, "open", url) case "windows": - cmd = exec.Command("cmd", "/c", "start", url) + cmd = exec.CommandContext(ctx, "cmd", "/c", "start", url) default: - cmd = exec.Command("xdg-open", url) + cmd = exec.CommandContext(ctx, "xdg-open", url) } if err := cmd.Start(); err != nil { diff --git a/browser_flow.go b/browser_flow.go index 535759c..c0dada7 100644 --- a/browser_flow.go +++ b/browser_flow.go @@ -18,7 +18,7 @@ import ( // - (storage, true, nil) on success // - (nil, false, nil) when openBrowser() fails — caller should fall back to Device Code Flow // - (nil, false, err) on a hard error (CSRF mismatch, token exchange failure, etc.) -func performBrowserFlow() (*TokenStorage, bool, error) { +func performBrowserFlow(ctx context.Context) (*TokenStorage, bool, error) { state, err := generateState() if err != nil { return nil, false, fmt.Errorf("failed to generate state: %w", err) @@ -34,7 +34,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) { fmt.Println("Step 1: Opening browser for authorization...") fmt.Printf("\n %s\n\n", authURL) - if err := openBrowser(authURL); err != nil { + if err := openBrowser(ctx, authURL); err != nil { // Browser failed to open — signal the caller to fall back immediately. fmt.Printf("Could not open browser: %v\n", err) return nil, false, nil @@ -43,7 +43,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) { fmt.Println("Browser opened. Please complete authorization in your browser.") fmt.Printf("Step 2: Waiting for callback on http://localhost:%d/callback ...\n", callbackPort) - code, err := startCallbackServer(callbackPort, state) + code, err := startCallbackServer(ctx, callbackPort, state) if err != nil { if errors.Is(err, ErrCallbackTimeout) { // User opened the browser but didn't complete authorization in time. @@ -59,7 +59,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) { fmt.Println("Authorization code received!") fmt.Println("Step 3: Exchanging authorization code for tokens...") - storage, err := exchangeCode(code, pkce.Verifier) + storage, err := exchangeCode(ctx, code, pkce.Verifier) if err != nil { return nil, false, fmt.Errorf("token exchange failed: %w", err) } @@ -88,8 +88,8 @@ func buildAuthURL(state string, pkce *PKCEParams) string { } // exchangeCode exchanges an authorization code for access + refresh tokens. -func exchangeCode(code, codeVerifier string) (*TokenStorage, error) { - ctx, cancel := context.WithTimeout(context.Background(), tokenExchangeTimeout) +func exchangeCode(ctx context.Context, code, codeVerifier string) (*TokenStorage, error) { + ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout) defer cancel() data := url.Values{} diff --git a/callback.go b/callback.go index 9837d62..e272a9b 100644 --- a/callback.go +++ b/callback.go @@ -32,7 +32,7 @@ type callbackResult struct { // and returns the authorization code (or an error). // // The server shuts itself down after the first request. -func startCallbackServer(port int, expectedState string) (string, error) { +func startCallbackServer(ctx context.Context, port int, expectedState string) (string, error) { resultCh := make(chan callbackResult, 1) var once sync.Once @@ -80,7 +80,7 @@ func startCallbackServer(port int, expectedState string) (string, error) { WriteTimeout: 10 * time.Second, } - ln, err := net.Listen("tcp", srv.Addr) + ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", srv.Addr) if err != nil { return "", fmt.Errorf("failed to start callback server on port %d: %w", port, err) } @@ -90,9 +90,9 @@ func startCallbackServer(port int, expectedState string) (string, error) { }() defer func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - _ = srv.Shutdown(ctx) + _ = srv.Shutdown(shutdownCtx) }() select { diff --git a/callback_test.go b/callback_test.go index 60249ca..8c8ad10 100644 --- a/callback_test.go +++ b/callback_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "net/http" @@ -11,11 +12,13 @@ import ( // startCallbackServerAsync starts the callback server in a goroutine and // returns a channel that will receive the authorization code (or error string). -func startCallbackServerAsync(t *testing.T, port int, state string) chan string { +func startCallbackServerAsync( + t *testing.T, ctx context.Context, port int, state string, +) chan string { t.Helper() ch := make(chan string, 1) go func() { - code, err := startCallbackServer(port, state) + code, err := startCallbackServer(ctx, port, state) if err != nil { ch <- "ERROR:" + err.Error() } else { @@ -31,13 +34,13 @@ func TestCallbackServer_Success(t *testing.T) { const port = 19101 state := "test-state-success" - ch := startCallbackServerAsync(t, port, state) + ch := startCallbackServerAsync(t, context.Background(), port, state) callbackURL := fmt.Sprintf( "http://127.0.0.1:%d/callback?code=mycode123&state=%s", port, state, ) - resp, err := http.Get(callbackURL) //nolint:noctx + resp, err := http.Get(callbackURL) //nolint:noctx,gosec if err != nil { t.Fatalf("GET callback failed: %v", err) } @@ -65,13 +68,13 @@ func TestCallbackServer_StateMismatch(t *testing.T) { const port = 19102 state := "expected-state" - ch := startCallbackServerAsync(t, port, state) + ch := startCallbackServerAsync(t, context.Background(), port, state) callbackURL := fmt.Sprintf( "http://127.0.0.1:%d/callback?code=mycode&state=wrong-state", port, ) - resp, err := http.Get(callbackURL) //nolint:noctx + resp, err := http.Get(callbackURL) //nolint:noctx,gosec if err != nil { t.Fatalf("GET callback failed: %v", err) } @@ -96,13 +99,13 @@ func TestCallbackServer_OAuthError(t *testing.T) { const port = 19103 state := "state-for-error" - ch := startCallbackServerAsync(t, port, state) + ch := startCallbackServerAsync(t, context.Background(), port, state) callbackURL := fmt.Sprintf( "http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s", port, state, ) - resp, err := http.Get(callbackURL) //nolint:noctx + resp, err := http.Get(callbackURL) //nolint:noctx,gosec if err != nil { t.Fatalf("GET callback failed: %v", err) } @@ -130,14 +133,14 @@ func TestCallbackServer_DoubleCallback(t *testing.T) { const port = 19105 state := "test-state-double" - ch := startCallbackServerAsync(t, port, state) + ch := startCallbackServerAsync(t, context.Background(), port, state) url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=mycode&state=%s", port, state) done := make(chan error, 2) for range 2 { go func() { - resp, err := http.Get(url) //nolint:noctx + resp, err := http.Get(url) //nolint:noctx,gosec if err == nil { resp.Body.Close() } @@ -167,13 +170,13 @@ func TestCallbackServer_MissingCode(t *testing.T) { const port = 19104 state := "state-for-missing-code" - ch := startCallbackServerAsync(t, port, state) + ch := startCallbackServerAsync(t, context.Background(), port, state) callbackURL := fmt.Sprintf( "http://127.0.0.1:%d/callback?state=%s", port, state, ) - resp, err := http.Get(callbackURL) //nolint:noctx + resp, err := http.Get(callbackURL) //nolint:noctx,gosec if err != nil { t.Fatalf("GET callback failed: %v", err) } diff --git a/detect.go b/detect.go index 88f074a..b7c5aa4 100644 --- a/detect.go +++ b/detect.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net" "os" @@ -23,7 +24,7 @@ type BrowserAvailability struct { // This function never attempts to open a browser itself; it only inspects // the environment. Callers that pass the check should still handle // openBrowser() failures as a secondary fallback. -func checkBrowserAvailability(port int) BrowserAvailability { +func checkBrowserAvailability(ctx context.Context, port int) BrowserAvailability { // Stage 1a: SSH without X11/Wayland forwarding. // SSH_TTY / SSH_CLIENT / SSH_CONNECTION indicate a remote shell. // If a display is also present (X11 forwarding), the browser can still open. @@ -44,7 +45,8 @@ func checkBrowserAvailability(port int) BrowserAvailability { // Stage 2: Verify the callback port can be bound. // A busy port means the redirect server cannot start. - ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + lc := &net.ListenConfig{} + ln, err := lc.Listen(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return BrowserAvailability{ false, diff --git a/detect_test.go b/detect_test.go index f6e8e79..1b5b4f8 100644 --- a/detect_test.go +++ b/detect_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "net" "testing" ) @@ -12,7 +13,7 @@ func TestCheckBrowserAvailability_SSH_NoDisplay(t *testing.T) { t.Setenv("DISPLAY", "") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(18888) + avail := checkBrowserAvailability(context.Background(), 18888) if avail.Available { t.Error("expected browser unavailable in SSH session without display") @@ -29,7 +30,7 @@ func TestCheckBrowserAvailability_SSHClient_NoDisplay(t *testing.T) { t.Setenv("DISPLAY", "") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(18888) + avail := checkBrowserAvailability(context.Background(), 18888) if avail.Available { t.Error("expected browser unavailable when SSH_CLIENT set and no display") @@ -43,7 +44,7 @@ func TestCheckBrowserAvailability_SSHConnection_NoDisplay(t *testing.T) { t.Setenv("DISPLAY", "") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(18888) + avail := checkBrowserAvailability(context.Background(), 18888) if avail.Available { t.Error("expected browser unavailable when SSH_CONNECTION set and no display") @@ -58,14 +59,14 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) { // Use a port that is definitely free (bind to :0 and get the port, // then close it; the brief gap is acceptable for a unit test). - ln, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Skip("cannot allocate test port") } port := ln.Addr().(*net.TCPAddr).Port ln.Close() - avail := checkBrowserAvailability(port) + avail := checkBrowserAvailability(context.Background(), port) // X11 forwarding over SSH should be detected as browser-capable // (DISPLAY is set, port is free). @@ -76,7 +77,7 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) { func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) { // Bind a port and keep it busy during the test. - ln, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Skip("cannot bind test port") } @@ -92,7 +93,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) { t.Setenv("DISPLAY", ":0") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(port) + avail := checkBrowserAvailability(context.Background(), port) if avail.Available { t.Errorf("expected browser unavailable when port %d is busy", port) @@ -104,7 +105,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) { func TestCheckBrowserAvailability_PortAvailable(t *testing.T) { // Find a free port. - ln, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Skip("cannot allocate test port") } @@ -117,7 +118,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) { t.Setenv("DISPLAY", ":0") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(port) + avail := checkBrowserAvailability(context.Background(), port) if !avail.Available { t.Errorf( @@ -128,7 +129,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) { } func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Skip("cannot allocate test port") } @@ -141,7 +142,7 @@ func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) { t.Setenv("DISPLAY", ":0") t.Setenv("WAYLAND_DISPLAY", "") - avail := checkBrowserAvailability(port) + avail := checkBrowserAvailability(context.Background(), port) if avail.Available && avail.Reason != "" { t.Errorf("expected empty reason when browser is available, got: %s", avail.Reason) diff --git a/filelock_test.go b/filelock_test.go index e058e8c..4425188 100644 --- a/filelock_test.go +++ b/filelock_test.go @@ -62,7 +62,7 @@ func TestConcurrentLocks(t *testing.T) { concurrent-- mu.Unlock() - lock.release() + _ = lock.release() }(i) } @@ -89,5 +89,5 @@ func TestStaleLockRemoval(t *testing.T) { if err != nil { t.Fatalf("acquireFileLock() with stale lock: %v", err) } - lock.release() + _ = lock.release() } diff --git a/main.go b/main.go index 9a14ead..d78814b 100644 --- a/main.go +++ b/main.go @@ -8,13 +8,22 @@ import ( "net/http" "net/url" "os" + "os/signal" "strings" + "syscall" "time" ) func main() { initConfig() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + exitCode := run(ctx) + stop() + os.Exit(exitCode) +} + +func run(ctx context.Context) int { clientMode := "public (PKCE)" if !isPublicClient() { clientMode = "confidential" @@ -25,7 +34,6 @@ func main() { fmt.Printf("Client ID : %s\n", clientID) fmt.Println() - ctx := context.Background() var storage *TokenStorage // Try to reuse or refresh existing tokens. @@ -37,7 +45,7 @@ func main() { storage = existing } else { fmt.Println("Access token expired, attempting refresh...") - newStorage, err := refreshAccessToken(existing.RefreshToken) + newStorage, err := refreshAccessToken(ctx, existing.RefreshToken) if err != nil { fmt.Printf("Refresh failed: %v\n", err) fmt.Println("Starting new authentication flow...") @@ -55,7 +63,7 @@ func main() { storage, err = authenticate(ctx) if err != nil { fmt.Fprintf(os.Stderr, "Authentication failed: %v\n", err) - os.Exit(1) + return 1 } } @@ -76,7 +84,7 @@ func main() { // Verify token against server. fmt.Println("\nVerifying token with server...") - if err := verifyToken(storage.AccessToken); err != nil { + if err := verifyToken(ctx, storage.AccessToken); err != nil { fmt.Printf("Token verification failed: %v\n", err) } else { fmt.Println("Token verified successfully.") @@ -90,17 +98,18 @@ func main() { storage, err = authenticate(ctx) if err != nil { fmt.Fprintf(os.Stderr, "Re-authentication failed: %v\n", err) - os.Exit(1) + return 1 } if err := makeAPICallWithAutoRefresh(ctx, storage); err != nil { fmt.Fprintf(os.Stderr, "API call failed after re-authentication: %v\n", err) - os.Exit(1) + return 1 } fmt.Println("API call successful after re-authentication.") } else { fmt.Fprintf(os.Stderr, "API call failed: %v\n", err) } } + return 0 } // authenticate selects and runs the appropriate OAuth flow: @@ -115,14 +124,14 @@ func authenticate(ctx context.Context) (*TokenStorage, error) { return performDeviceFlow(ctx) } - avail := checkBrowserAvailability(callbackPort) + avail := checkBrowserAvailability(ctx, callbackPort) if !avail.Available { fmt.Printf("Auth method : Device Code Flow (%s)\n", avail.Reason) return performDeviceFlow(ctx) } fmt.Println("Auth method : Authorization Code Flow (browser)") - storage, ok, err := performBrowserFlow() + storage, ok, err := performBrowserFlow(ctx) if err != nil { return nil, err } @@ -138,8 +147,8 @@ func authenticate(ctx context.Context) (*TokenStorage, error) { // Token refresh // ----------------------------------------------------------------------- -func refreshAccessToken(refreshToken string) (*TokenStorage, error) { - ctx, cancel := context.WithTimeout(context.Background(), refreshTokenTimeout) +func refreshAccessToken(ctx context.Context, refreshToken string) (*TokenStorage, error) { + ctx, cancel := context.WithTimeout(ctx, refreshTokenTimeout) defer cancel() data := url.Values{} @@ -225,8 +234,8 @@ func refreshAccessToken(refreshToken string) (*TokenStorage, error) { // Token verification / API demo // ----------------------------------------------------------------------- -func verifyToken(accessToken string) error { - ctx, cancel := context.WithTimeout(context.Background(), tokenVerificationTimeout) +func verifyToken(ctx context.Context, accessToken string) error { + ctx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, serverURL+"/oauth/tokeninfo", nil) @@ -275,7 +284,7 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *TokenStorage) erro if resp.StatusCode == http.StatusUnauthorized { fmt.Println("Access token rejected (401), refreshing...") - newStorage, err := refreshAccessToken(storage.RefreshToken) + newStorage, err := refreshAccessToken(ctx, storage.RefreshToken) if err != nil { if err == ErrRefreshTokenExpired { return ErrRefreshTokenExpired diff --git a/main_test.go b/main_test.go index 0237066..72e7346 100644 --- a/main_test.go +++ b/main_test.go @@ -25,7 +25,7 @@ func init() { clientID = "test-client" } if tokenFile == "" { - tokenFile = ".authgate-tokens.json" + tokenFile = ".authgate-tokens.json" //nolint:gosec } if scope == "" { scope = "read write" @@ -352,14 +352,16 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { resp["refresh_token"] = tt.responseRefreshToken } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } }), ) defer srv.Close() serverURL = srv.URL - storage, err := refreshAccessToken(tt.oldRefreshToken) + storage, err := refreshAccessToken(context.Background(), tt.oldRefreshToken) if err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } @@ -398,14 +400,16 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) { return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]interface{}{ "device_code": "test-device-code", "user_code": "TEST-CODE", "verification_uri": testServer.URL + "/device", "verification_uri_complete": testServer.URL + "/device?user_code=TEST-CODE", "expires_in": 600, "interval": 5, - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer testServer.Close() diff --git a/polling_test.go b/polling_test.go index 90ded7b..70c6c06 100644 --- a/polling_test.go +++ b/polling_test.go @@ -13,6 +13,8 @@ import ( "golang.org/x/oauth2" ) +const testAccessToken = "test-access-token" + func TestPollForToken_AuthorizationPending(t *testing.T) { attempts := atomic.Int32{} server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -21,20 +23,24 @@ func TestPollForToken_AuthorizationPending(t *testing.T) { if attempts.Load() < 3 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "error": "authorization_pending", "error_description": "User has not yet authorized", - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "access_token": "test-access-token", + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": testAccessToken, "refresh_token": "test-refresh-token", "token_type": "Bearer", "expires_in": 3600, - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer server.Close() @@ -54,8 +60,8 @@ func TestPollForToken_AuthorizationPending(t *testing.T) { if err != nil { t.Fatalf("expected success, got error: %v", err) } - if token.AccessToken != "test-access-token" { - t.Errorf("access token = %q, want %q", token.AccessToken, "test-access-token") + if token.AccessToken != testAccessToken { + t.Errorf("access token = %q, want %q", token.AccessToken, testAccessToken) } if attempts.Load() < 3 { t.Errorf("expected at least 3 attempts, got %d", attempts.Load()) @@ -73,30 +79,36 @@ func TestPollForToken_SlowDown(t *testing.T) { slowDownCount.Add(1) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "error": "slow_down", "error_description": "Polling too frequently", - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } return } if attempts.Load() < 5 { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "error": "authorization_pending", "error_description": "User has not yet authorized", - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } return } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "access_token": "test-access-token", + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": testAccessToken, "refresh_token": "test-refresh-token", "token_type": "Bearer", "expires_in": 3600, - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer server.Close() @@ -116,22 +128,27 @@ func TestPollForToken_SlowDown(t *testing.T) { if err != nil { t.Fatalf("expected success, got error: %v", err) } - if token.AccessToken != "test-access-token" { - t.Errorf("access token = %q, want %q", token.AccessToken, "test-access-token") + if token.AccessToken != testAccessToken { + t.Errorf("access token = %q, want %q", token.AccessToken, testAccessToken) } if slowDownCount.Load() < 2 { t.Errorf("expected at least 2 slow_down responses, got %d", slowDownCount.Load()) } } -func TestPollForToken_ExpiredToken(t *testing.T) { +// pollForTokenErrorTest is a shared helper for tests that expect pollForTokenWithProgress +// to return a specific error when the server responds with a terminal OAuth error code. +func pollForTokenErrorTest(t *testing.T, errCode, errDesc, expectedMsg string) { + t.Helper() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ - "error": "expired_token", - "error_description": "Device code has expired", - }) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": errCode, + "error_description": errDesc, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer server.Close() @@ -149,53 +166,39 @@ func TestPollForToken_ExpiredToken(t *testing.T) { _, err := pollForTokenWithProgress(ctx, config, deviceAuth) if err == nil { - t.Fatal("expected error for expired token, got nil") + t.Fatal("expected error, got nil") } - if err.Error() != "device code expired, please restart the flow" { + if err.Error() != expectedMsg { t.Errorf("unexpected error message: %v", err) } } -func TestPollForToken_AccessDenied(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ - "error": "access_denied", - "error_description": "User denied the authorization request", - }) - })) - defer server.Close() - - config := &oauth2.Config{ - ClientID: "test-client", - Endpoint: oauth2.Endpoint{TokenURL: server.URL}, - } - deviceAuth := &oauth2.DeviceAuthResponse{ - DeviceCode: "test-device-code", - Interval: 1, - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() +func TestPollForToken_ExpiredToken(t *testing.T) { + pollForTokenErrorTest(t, + "expired_token", + "Device code has expired", + "device code expired, please restart the flow", + ) +} - _, err := pollForTokenWithProgress(ctx, config, deviceAuth) - if err == nil { - t.Fatal("expected error for access denied, got nil") - } - if err.Error() != "user denied authorization" { - t.Errorf("unexpected error message: %v", err) - } +func TestPollForToken_AccessDenied(t *testing.T) { + pollForTokenErrorTest(t, + "access_denied", + "User denied the authorization request", + "user denied authorization", + ) } func TestPollForToken_ContextTimeout(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]string{ + if err := json.NewEncoder(w).Encode(map[string]string{ "error": "authorization_pending", "error_description": "User has not yet authorized", - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer server.Close() @@ -236,12 +239,14 @@ func TestExchangeDeviceCode_Success(t *testing.T) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "access_token": "test-access-token", + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": testAccessToken, "refresh_token": "test-refresh-token", "token_type": "Bearer", "expires_in": 3600, - }) + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } })) defer server.Close() @@ -250,8 +255,8 @@ func TestExchangeDeviceCode_Success(t *testing.T) { if err != nil { t.Fatalf("expected success, got error: %v", err) } - if token.AccessToken != "test-access-token" { - t.Errorf("access token = %q, want %q", token.AccessToken, "test-access-token") + if token.AccessToken != testAccessToken { + t.Errorf("access token = %q, want %q", token.AccessToken, testAccessToken) } if token.RefreshToken != "test-refresh-token" { t.Errorf("refresh token = %q, want %q", token.RefreshToken, "test-refresh-token") diff --git a/tokens.go b/tokens.go index c2ded43..e18409f 100644 --- a/tokens.go +++ b/tokens.go @@ -58,7 +58,7 @@ func saveTokens(storage *TokenStorage) error { if err != nil { return fmt.Errorf("failed to acquire lock: %w", err) } - defer lock.release() + defer func() { _ = lock.release() }() var storageMap TokenStorageMap if existing, err := os.ReadFile(tokenFile); err == nil {