Skip to content

Commit 57ae461

Browse files
committed
refactor: refactor callback server for in-server token exchange and testing
- Refactor callback handling to exchange the authorization code for tokens within the callback server, passing an exchange function as a parameter. - Adjust error handling and response types to propagate token exchange errors and results instead of raw codes. - Update related tests to use new exchange function injection, improving test coverage for error cases and exchange failures. - Add comprehensive unit tests for callback server scenarios, including token exchange failure and better validation of test outcomes. - Increase HTTP server write timeout from 10 to 15 seconds during the callback. Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent c4242cc commit 57ae461

3 files changed

Lines changed: 119 additions & 48 deletions

File tree

browser_flow.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ func performBrowserFlow(ctx context.Context) (*TokenStorage, bool, error) {
4343
fmt.Println("Browser opened. Please complete authorization in your browser.")
4444
fmt.Printf("Step 2: Waiting for callback on http://localhost:%d/callback ...\n", callbackPort)
4545

46-
code, err := startCallbackServer(ctx, callbackPort, state)
46+
storage, err := startCallbackServer(ctx, callbackPort, state,
47+
func(callbackCtx context.Context, code string) (*TokenStorage, error) {
48+
fmt.Println("Step 3: Exchanging authorization code for tokens...")
49+
return exchangeCode(callbackCtx, code, pkce.Verifier)
50+
})
4751
if err != nil {
4852
if errors.Is(err, ErrCallbackTimeout) {
4953
// User opened the browser but didn't complete authorization in time.
@@ -54,14 +58,7 @@ func performBrowserFlow(ctx context.Context) (*TokenStorage, bool, error) {
5458
)
5559
return nil, false, nil
5660
}
57-
return nil, false, fmt.Errorf("authorization failed: %w", err)
58-
}
59-
fmt.Println("Authorization code received!")
60-
61-
fmt.Println("Step 3: Exchanging authorization code for tokens...")
62-
storage, err := exchangeCode(ctx, code, pkce.Verifier)
63-
if err != nil {
64-
return nil, false, fmt.Errorf("token exchange failed: %w", err)
61+
return nil, false, fmt.Errorf("authentication failed: %w", err)
6562
}
6663
storage.Flow = "browser"
6764

callback.go

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,20 @@ var ErrCallbackTimeout = fmt.Errorf("browser authorization timed out")
2222

2323
// callbackResult holds the outcome of the local callback round-trip.
2424
type callbackResult struct {
25-
Code string
26-
Error string
27-
Desc string
25+
Storage *TokenStorage
26+
Error string
27+
Desc string
2828
}
2929

3030
// startCallbackServer starts a local HTTP server on the given port and waits
31-
// for the OAuth callback. It validates the returned state against expectedState
32-
// and returns the authorization code (or an error).
31+
// for the OAuth callback. It validates the returned state against expectedState,
32+
// calls exchangeFn to exchange the code for tokens, and returns the resulting
33+
// TokenStorage (or an error).
3334
//
3435
// The server shuts itself down after the first request.
35-
func startCallbackServer(ctx context.Context, port int, expectedState string) (string, error) {
36+
func startCallbackServer(ctx context.Context, port int, expectedState string,
37+
exchangeFn func(context.Context, string) (*TokenStorage, error),
38+
) (*TokenStorage, error) {
3639
resultCh := make(chan callbackResult, 1)
3740

3841
var once sync.Once
@@ -69,20 +72,26 @@ func startCallbackServer(ctx context.Context, port int, expectedState string) (s
6972
return
7073
}
7174

75+
storage, exchangeErr := exchangeFn(r.Context(), code)
76+
if exchangeErr != nil {
77+
writeCallbackPage(w, false, "token_exchange_failed", exchangeErr.Error())
78+
sendResult(callbackResult{Error: "token_exchange_failed", Desc: exchangeErr.Error()})
79+
return
80+
}
7281
writeCallbackPage(w, true, "", "")
73-
sendResult(callbackResult{Code: code})
82+
sendResult(callbackResult{Storage: storage})
7483
})
7584

7685
srv := &http.Server{
7786
Addr: fmt.Sprintf("127.0.0.1:%d", port),
7887
Handler: mux,
7988
ReadTimeout: 10 * time.Second,
80-
WriteTimeout: 10 * time.Second,
89+
WriteTimeout: 15 * time.Second,
8190
}
8291

8392
ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", srv.Addr)
8493
if err != nil {
85-
return "", fmt.Errorf("failed to start callback server on port %d: %w", port, err)
94+
return nil, fmt.Errorf("failed to start callback server on port %d: %w", port, err)
8695
}
8796

8897
go func() {
@@ -99,14 +108,14 @@ func startCallbackServer(ctx context.Context, port int, expectedState string) (s
99108
case result := <-resultCh:
100109
if result.Error != "" {
101110
if result.Desc != "" {
102-
return "", fmt.Errorf("%s: %s", result.Error, result.Desc)
111+
return nil, fmt.Errorf("%s: %s", result.Error, result.Desc)
103112
}
104-
return "", fmt.Errorf("%s", result.Error)
113+
return nil, fmt.Errorf("%s", result.Error)
105114
}
106-
return result.Code, nil
115+
return result.Storage, nil
107116

108117
case <-time.After(callbackTimeout):
109-
return "", fmt.Errorf("%w after %s", ErrCallbackTimeout, callbackTimeout)
118+
return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, callbackTimeout)
110119
}
111120
}
112121

callback_test.go

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,53 @@ import (
1010
"time"
1111
)
1212

13+
type callbackServerResult struct {
14+
storage *TokenStorage
15+
err error
16+
}
17+
1318
// startCallbackServerAsync starts the callback server in a goroutine and
14-
// returns a channel that will receive the authorization code (or error string).
19+
// returns a channel that will receive the result (storage or error).
1520
func startCallbackServerAsync(
1621
t *testing.T, ctx context.Context, port int, state string,
17-
) chan string {
22+
exchangeFn func(context.Context, string) (*TokenStorage, error),
23+
) chan callbackServerResult {
1824
t.Helper()
19-
ch := make(chan string, 1)
25+
ch := make(chan callbackServerResult, 1)
2026
go func() {
21-
code, err := startCallbackServer(ctx, port, state)
22-
if err != nil {
23-
ch <- "ERROR:" + err.Error()
24-
} else {
25-
ch <- code
26-
}
27+
storage, err := startCallbackServer(ctx, port, state, exchangeFn)
28+
ch <- callbackServerResult{storage, err}
2729
}()
2830
// Give the server a moment to bind.
2931
time.Sleep(50 * time.Millisecond)
3032
return ch
3133
}
3234

35+
// noExchangeFn returns an exchange function that fails the test if called.
36+
func noExchangeFn(t *testing.T) func(context.Context, string) (*TokenStorage, error) {
37+
t.Helper()
38+
return func(_ context.Context, _ string) (*TokenStorage, error) {
39+
t.Error("exchangeFn should not be called")
40+
return nil, fmt.Errorf("should not be called")
41+
}
42+
}
43+
44+
// stubExchangeFn returns an exchange function that validates the received code
45+
// and returns a minimal TokenStorage on success.
46+
func stubExchangeFn(wantCode string) func(context.Context, string) (*TokenStorage, error) {
47+
return func(_ context.Context, gotCode string) (*TokenStorage, error) {
48+
if gotCode != wantCode {
49+
return nil, fmt.Errorf("unexpected code: got %q, want %q", gotCode, wantCode)
50+
}
51+
return &TokenStorage{AccessToken: "test-token"}, nil
52+
}
53+
}
54+
3355
func TestCallbackServer_Success(t *testing.T) {
3456
const port = 19101
3557
state := "test-state-success"
3658

37-
ch := startCallbackServerAsync(t, context.Background(), port, state)
59+
ch := startCallbackServerAsync(t, context.Background(), port, state, stubExchangeFn("mycode123"))
3860

3961
callbackURL := fmt.Sprintf(
4062
"http://127.0.0.1:%d/callback?code=mycode123&state=%s",
@@ -56,8 +78,11 @@ func TestCallbackServer_Success(t *testing.T) {
5678

5779
select {
5880
case result := <-ch:
59-
if result != "mycode123" {
60-
t.Errorf("expected code mycode123, got: %s", result)
81+
if result.err != nil {
82+
t.Errorf("expected success, got error: %v", result.err)
83+
}
84+
if result.storage == nil {
85+
t.Error("expected non-nil storage")
6186
}
6287
case <-time.After(3 * time.Second):
6388
t.Fatal("timed out waiting for callback result")
@@ -68,7 +93,7 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
6893
const port = 19102
6994
state := "expected-state"
7095

71-
ch := startCallbackServerAsync(t, context.Background(), port, state)
96+
ch := startCallbackServerAsync(t, context.Background(), port, state, noExchangeFn(t))
7297

7398
callbackURL := fmt.Sprintf(
7499
"http://127.0.0.1:%d/callback?code=mycode&state=wrong-state",
@@ -87,8 +112,8 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
87112

88113
select {
89114
case result := <-ch:
90-
if !strings.HasPrefix(result, "ERROR:") {
91-
t.Errorf("expected error for state mismatch, got: %s", result)
115+
if result.err == nil {
116+
t.Error("expected error for state mismatch, got nil")
92117
}
93118
case <-time.After(3 * time.Second):
94119
t.Fatal("timed out waiting for callback result")
@@ -99,7 +124,7 @@ func TestCallbackServer_OAuthError(t *testing.T) {
99124
const port = 19103
100125
state := "state-for-error"
101126

102-
ch := startCallbackServerAsync(t, context.Background(), port, state)
127+
ch := startCallbackServerAsync(t, context.Background(), port, state, noExchangeFn(t))
103128

104129
callbackURL := fmt.Sprintf(
105130
"http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s",
@@ -118,11 +143,48 @@ func TestCallbackServer_OAuthError(t *testing.T) {
118143

119144
select {
120145
case result := <-ch:
121-
if !strings.HasPrefix(result, "ERROR:") {
122-
t.Errorf("expected error for access_denied, got: %s", result)
146+
if result.err == nil {
147+
t.Error("expected error for access_denied, got nil")
123148
}
124-
if !strings.Contains(result, "access_denied") {
125-
t.Errorf("expected error to mention access_denied, got: %s", result)
149+
if !strings.Contains(result.err.Error(), "access_denied") {
150+
t.Errorf("expected error to mention access_denied, got: %v", result.err)
151+
}
152+
case <-time.After(3 * time.Second):
153+
t.Fatal("timed out waiting for callback result")
154+
}
155+
}
156+
157+
func TestCallbackServer_ExchangeFailure(t *testing.T) {
158+
const port = 19106
159+
state := "state-for-exchange-failure"
160+
161+
ch := startCallbackServerAsync(t, context.Background(), port, state,
162+
func(_ context.Context, _ string) (*TokenStorage, error) {
163+
return nil, fmt.Errorf("unauthorized_client: unauthorized_client")
164+
})
165+
166+
callbackURL := fmt.Sprintf(
167+
"http://127.0.0.1:%d/callback?code=mycode&state=%s",
168+
port, state,
169+
)
170+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
171+
if err != nil {
172+
t.Fatalf("GET callback failed: %v", err)
173+
}
174+
defer resp.Body.Close()
175+
176+
body, _ := io.ReadAll(resp.Body)
177+
if !strings.Contains(string(body), "Authorization Failed") {
178+
t.Errorf("expected failure page for exchange error, got: %s", string(body))
179+
}
180+
181+
select {
182+
case result := <-ch:
183+
if result.err == nil {
184+
t.Error("expected error for exchange failure, got nil")
185+
}
186+
if !strings.Contains(result.err.Error(), "unauthorized_client") {
187+
t.Errorf("expected error to mention unauthorized_client, got: %v", result.err)
126188
}
127189
case <-time.After(3 * time.Second):
128190
t.Fatal("timed out waiting for callback result")
@@ -133,7 +195,7 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
133195
const port = 19105
134196
state := "test-state-double"
135197

136-
ch := startCallbackServerAsync(t, context.Background(), port, state)
198+
ch := startCallbackServerAsync(t, context.Background(), port, state, stubExchangeFn("mycode"))
137199

138200
url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=mycode&state=%s", port, state)
139201

@@ -158,8 +220,11 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
158220

159221
select {
160222
case result := <-ch:
161-
if result != "mycode" {
162-
t.Errorf("expected mycode, got: %s", result)
223+
if result.err != nil {
224+
t.Errorf("expected success, got error: %v", result.err)
225+
}
226+
if result.storage == nil {
227+
t.Error("expected non-nil storage")
163228
}
164229
case <-time.After(3 * time.Second):
165230
t.Fatal("timed out waiting for callback result")
@@ -170,7 +235,7 @@ func TestCallbackServer_MissingCode(t *testing.T) {
170235
const port = 19104
171236
state := "state-for-missing-code"
172237

173-
ch := startCallbackServerAsync(t, context.Background(), port, state)
238+
ch := startCallbackServerAsync(t, context.Background(), port, state, noExchangeFn(t))
174239

175240
callbackURL := fmt.Sprintf(
176241
"http://127.0.0.1:%d/callback?state=%s",
@@ -184,8 +249,8 @@ func TestCallbackServer_MissingCode(t *testing.T) {
184249

185250
select {
186251
case result := <-ch:
187-
if !strings.HasPrefix(result, "ERROR:") {
188-
t.Errorf("expected error for missing code, got: %s", result)
252+
if result.err == nil {
253+
t.Error("expected error for missing code, got nil")
189254
}
190255
case <-time.After(3 * time.Second):
191256
t.Fatal("timed out waiting for callback result")

0 commit comments

Comments
 (0)