Skip to content

Commit a1d3851

Browse files
committed
feat: add context support for HTTP and token operations
- Add context cancellation support to all HTTP and token-related operations for better handling of interruptions and timeouts - Pass context to refreshAccessToken, verifyToken, and makeAPICallWithAutoRefresh functions and update all relevant call sites - Use signal.NotifyContext in main for graceful shutdown on SIGTERM or interrupt signals - Update test cases to provide context to refreshAccessToken Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent 16eb90a commit a1d3851

2 files changed

Lines changed: 31 additions & 21 deletions

File tree

main.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"net/http"
1212
"net/url"
1313
"os"
14+
"os/signal"
1415
"strings"
16+
"syscall"
1517
"time"
1618

1719
retry "github.com/appleboy/go-httpretry"
@@ -223,7 +225,9 @@ func main() {
223225

224226
fmt.Printf("=== OAuth Device Code Flow CLI Demo (with Refresh Token) ===\n\n")
225227

226-
ctx := context.Background()
228+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
229+
defer stop()
230+
227231
var storage *TokenStorage
228232

229233
// Try to load existing tokens
@@ -238,7 +242,7 @@ func main() {
238242
fmt.Println("Access token expired, refreshing...")
239243

240244
// Try to refresh
241-
newStorage, err := refreshAccessToken(storage.RefreshToken)
245+
newStorage, err := refreshAccessToken(ctx, storage.RefreshToken)
242246
if err != nil {
243247
fmt.Printf("Refresh failed: %v\n", err)
244248
fmt.Println("Starting new device flow...")
@@ -275,15 +279,15 @@ func main() {
275279

276280
// Verify token
277281
fmt.Println("\nVerifying token...")
278-
if err := verifyToken(storage.AccessToken); err != nil {
282+
if err := verifyToken(ctx, storage.AccessToken); err != nil {
279283
fmt.Printf("Token verification failed: %v\n", err)
280284
} else {
281285
fmt.Println("Token verified successfully!")
282286
}
283287

284288
// Demonstrate automatic refresh on 401
285289
fmt.Println("\nDemonstrating automatic refresh on API call...")
286-
if err := makeAPICallWithAutoRefresh(storage); err != nil {
290+
if err := makeAPICallWithAutoRefresh(ctx, storage); err != nil {
287291
// Check if error is due to expired refresh token
288292
if err == ErrRefreshTokenExpired {
289293
fmt.Println("Refresh token expired, re-authenticating...")
@@ -295,7 +299,7 @@ func main() {
295299

296300
// Retry API call with new tokens
297301
fmt.Println("Retrying API call with new tokens...")
298-
if err := makeAPICallWithAutoRefresh(storage); err != nil {
302+
if err := makeAPICallWithAutoRefresh(ctx, storage); err != nil {
299303
fmt.Printf("API call failed after re-authentication: %v\n", err)
300304
os.Exit(1)
301305
}
@@ -606,19 +610,19 @@ func exchangeDeviceCode(
606610
return token, nil
607611
}
608612

609-
func verifyToken(accessToken string) error {
613+
func verifyToken(ctx context.Context, accessToken string) error {
610614
// Create request with timeout
611-
ctx, cancel := context.WithTimeout(context.Background(), tokenVerificationTimeout)
615+
reqCtx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout)
612616
defer cancel()
613617

614-
req, err := http.NewRequestWithContext(ctx, "GET", serverURL+"/oauth/tokeninfo", nil)
618+
req, err := http.NewRequestWithContext(reqCtx, "GET", serverURL+"/oauth/tokeninfo", nil)
615619
if err != nil {
616620
return fmt.Errorf("failed to create request: %w", err)
617621
}
618622
req.Header.Set("Authorization", "Bearer "+accessToken)
619623

620624
// Execute request with retry logic
621-
resp, err := retryClient.DoWithContext(ctx, req)
625+
resp, err := retryClient.DoWithContext(reqCtx, req)
622626
if err != nil {
623627
return fmt.Errorf("request failed: %w", err)
624628
}
@@ -727,9 +731,9 @@ func saveTokens(storage *TokenStorage) error {
727731
}
728732

729733
// refreshAccessToken refreshes the access token using refresh token
730-
func refreshAccessToken(refreshToken string) (*TokenStorage, error) {
734+
func refreshAccessToken(ctx context.Context, refreshToken string) (*TokenStorage, error) {
731735
// Create request with timeout
732-
ctx, cancel := context.WithTimeout(context.Background(), refreshTokenTimeout)
736+
reqCtx, cancel := context.WithTimeout(ctx, refreshTokenTimeout)
733737
defer cancel()
734738

735739
data := url.Values{}
@@ -738,7 +742,7 @@ func refreshAccessToken(refreshToken string) (*TokenStorage, error) {
738742
data.Set("client_id", clientID)
739743

740744
req, err := http.NewRequestWithContext(
741-
ctx,
745+
reqCtx,
742746
"POST",
743747
serverURL+"/oauth/token",
744748
strings.NewReader(data.Encode()),
@@ -749,7 +753,7 @@ func refreshAccessToken(refreshToken string) (*TokenStorage, error) {
749753
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
750754

751755
// Execute request with retry logic
752-
resp, err := retryClient.DoWithContext(ctx, req)
756+
resp, err := retryClient.DoWithContext(reqCtx, req)
753757
if err != nil {
754758
return nil, fmt.Errorf("refresh request failed: %w", err)
755759
}
@@ -819,15 +823,18 @@ func refreshAccessToken(refreshToken string) (*TokenStorage, error) {
819823
}
820824

821825
// makeAPICallWithAutoRefresh demonstrates automatic refresh on 401
822-
func makeAPICallWithAutoRefresh(storage *TokenStorage) error {
826+
func makeAPICallWithAutoRefresh(ctx context.Context, storage *TokenStorage) error {
823827
// Try with current access token
824-
req, err := http.NewRequest("GET", serverURL+"/oauth/tokeninfo", nil)
828+
reqCtx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout)
829+
defer cancel()
830+
831+
req, err := http.NewRequestWithContext(reqCtx, "GET", serverURL+"/oauth/tokeninfo", nil)
825832
if err != nil {
826833
return fmt.Errorf("failed to create request: %w", err)
827834
}
828835
req.Header.Set("Authorization", "Bearer "+storage.AccessToken)
829836

830-
resp, err := retryClient.DoWithContext(context.Background(), req)
837+
resp, err := retryClient.DoWithContext(reqCtx, req)
831838
if err != nil {
832839
return fmt.Errorf("API request failed: %w", err)
833840
}
@@ -837,7 +844,7 @@ func makeAPICallWithAutoRefresh(storage *TokenStorage) error {
837844
if resp.StatusCode == http.StatusUnauthorized {
838845
fmt.Println("Access token rejected (401), refreshing...")
839846

840-
newStorage, err := refreshAccessToken(storage.RefreshToken)
847+
newStorage, err := refreshAccessToken(ctx, storage.RefreshToken)
841848
if err != nil {
842849
// If refresh token is expired, propagate the error to trigger device flow
843850
if err == ErrRefreshTokenExpired {
@@ -855,13 +862,16 @@ func makeAPICallWithAutoRefresh(storage *TokenStorage) error {
855862
fmt.Println("Token refreshed, retrying API call...")
856863

857864
// Retry with new token
858-
req, err = http.NewRequest("GET", serverURL+"/oauth/tokeninfo", nil)
865+
retryCtx, retryCancel := context.WithTimeout(ctx, tokenVerificationTimeout)
866+
defer retryCancel()
867+
868+
req, err = http.NewRequestWithContext(retryCtx, "GET", serverURL+"/oauth/tokeninfo", nil)
859869
if err != nil {
860870
return fmt.Errorf("failed to create retry request: %w", err)
861871
}
862872
req.Header.Set("Authorization", "Bearer "+storage.AccessToken)
863873

864-
resp, err = retryClient.DoWithContext(context.Background(), req)
874+
resp, err = retryClient.DoWithContext(retryCtx, req)
865875
if err != nil {
866876
return fmt.Errorf("retry failed: %w", err)
867877
}

main_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) {
396396
serverURL = server.URL
397397

398398
// Call refreshAccessToken
399-
storage, err := refreshAccessToken(tt.oldRefreshToken)
399+
storage, err := refreshAccessToken(context.Background(), tt.oldRefreshToken)
400400
if err != nil {
401401
t.Fatalf("refreshAccessToken() error = %v", err)
402402
}
@@ -527,7 +527,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) {
527527
serverURL = server.URL
528528

529529
// Call refreshAccessToken
530-
_, err := refreshAccessToken("test-refresh-token")
530+
_, err := refreshAccessToken(context.Background(), "test-refresh-token")
531531

532532
if tt.wantErr {
533533
if err == nil {

0 commit comments

Comments
 (0)