@@ -10,31 +10,53 @@ import (
1010 "time"
1111)
1212
13+ type callbackServerResult struct {
14+ storage * TokenStorage
15+ err error
16+ }
17+
1318// startCallbackServerAsync starts the callback server in a goroutine and
14- // returns a channel that will receive the authorization code ( or error string ).
19+ // returns a channel that will receive the result (storage or error).
1520func startCallbackServerAsync (
1621 t * testing.T , ctx context.Context , port int , state string ,
17- ) chan string {
22+ exchangeFn func (context.Context , string ) (* TokenStorage , error ),
23+ ) chan callbackServerResult {
1824 t .Helper ()
19- ch := make (chan string , 1 )
25+ ch := make (chan callbackServerResult , 1 )
2026 go func () {
21- code , err := startCallbackServer (ctx , port , state )
22- if err != nil {
23- ch <- "ERROR:" + err .Error ()
24- } else {
25- ch <- code
26- }
27+ storage , err := startCallbackServer (ctx , port , state , exchangeFn )
28+ ch <- callbackServerResult {storage , err }
2729 }()
2830 // Give the server a moment to bind.
2931 time .Sleep (50 * time .Millisecond )
3032 return ch
3133}
3234
35+ // noExchangeFn returns an exchange function that fails the test if called.
36+ func noExchangeFn (t * testing.T ) func (context.Context , string ) (* TokenStorage , error ) {
37+ t .Helper ()
38+ return func (_ context.Context , _ string ) (* TokenStorage , error ) {
39+ t .Error ("exchangeFn should not be called" )
40+ return nil , fmt .Errorf ("should not be called" )
41+ }
42+ }
43+
44+ // stubExchangeFn returns an exchange function that validates the received code
45+ // and returns a minimal TokenStorage on success.
46+ func stubExchangeFn (wantCode string ) func (context.Context , string ) (* TokenStorage , error ) {
47+ return func (_ context.Context , gotCode string ) (* TokenStorage , error ) {
48+ if gotCode != wantCode {
49+ return nil , fmt .Errorf ("unexpected code: got %q, want %q" , gotCode , wantCode )
50+ }
51+ return & TokenStorage {AccessToken : "test-token" }, nil
52+ }
53+ }
54+
3355func TestCallbackServer_Success (t * testing.T ) {
3456 const port = 19101
3557 state := "test-state-success"
3658
37- ch := startCallbackServerAsync (t , context .Background (), port , state )
59+ ch := startCallbackServerAsync (t , context .Background (), port , state , stubExchangeFn ( "mycode123" ) )
3860
3961 callbackURL := fmt .Sprintf (
4062 "http://127.0.0.1:%d/callback?code=mycode123&state=%s" ,
@@ -56,8 +78,11 @@ func TestCallbackServer_Success(t *testing.T) {
5678
5779 select {
5880 case result := <- ch :
59- if result != "mycode123" {
60- t .Errorf ("expected code mycode123, got: %s" , result )
81+ if result .err != nil {
82+ t .Errorf ("expected success, got error: %v" , result .err )
83+ }
84+ if result .storage == nil {
85+ t .Error ("expected non-nil storage" )
6186 }
6287 case <- time .After (3 * time .Second ):
6388 t .Fatal ("timed out waiting for callback result" )
@@ -68,7 +93,7 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
6893 const port = 19102
6994 state := "expected-state"
7095
71- ch := startCallbackServerAsync (t , context .Background (), port , state )
96+ ch := startCallbackServerAsync (t , context .Background (), port , state , noExchangeFn ( t ) )
7297
7398 callbackURL := fmt .Sprintf (
7499 "http://127.0.0.1:%d/callback?code=mycode&state=wrong-state" ,
@@ -87,8 +112,8 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
87112
88113 select {
89114 case result := <- ch :
90- if ! strings . HasPrefix ( result , "ERROR:" ) {
91- t .Errorf ("expected error for state mismatch, got: %s" , result )
115+ if result . err == nil {
116+ t .Error ("expected error for state mismatch, got nil" )
92117 }
93118 case <- time .After (3 * time .Second ):
94119 t .Fatal ("timed out waiting for callback result" )
@@ -99,7 +124,7 @@ func TestCallbackServer_OAuthError(t *testing.T) {
99124 const port = 19103
100125 state := "state-for-error"
101126
102- ch := startCallbackServerAsync (t , context .Background (), port , state )
127+ ch := startCallbackServerAsync (t , context .Background (), port , state , noExchangeFn ( t ) )
103128
104129 callbackURL := fmt .Sprintf (
105130 "http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s" ,
@@ -118,11 +143,48 @@ func TestCallbackServer_OAuthError(t *testing.T) {
118143
119144 select {
120145 case result := <- ch :
121- if ! strings . HasPrefix ( result , "ERROR:" ) {
122- t .Errorf ("expected error for access_denied, got: %s" , result )
146+ if result . err == nil {
147+ t .Error ("expected error for access_denied, got nil" )
123148 }
124- if ! strings .Contains (result , "access_denied" ) {
125- t .Errorf ("expected error to mention access_denied, got: %s" , result )
149+ if ! strings .Contains (result .err .Error (), "access_denied" ) {
150+ t .Errorf ("expected error to mention access_denied, got: %v" , result .err )
151+ }
152+ case <- time .After (3 * time .Second ):
153+ t .Fatal ("timed out waiting for callback result" )
154+ }
155+ }
156+
157+ func TestCallbackServer_ExchangeFailure (t * testing.T ) {
158+ const port = 19106
159+ state := "state-for-exchange-failure"
160+
161+ ch := startCallbackServerAsync (t , context .Background (), port , state ,
162+ func (_ context.Context , _ string ) (* TokenStorage , error ) {
163+ return nil , fmt .Errorf ("unauthorized_client: unauthorized_client" )
164+ })
165+
166+ callbackURL := fmt .Sprintf (
167+ "http://127.0.0.1:%d/callback?code=mycode&state=%s" ,
168+ port , state ,
169+ )
170+ resp , err := http .Get (callbackURL ) //nolint:noctx,gosec
171+ if err != nil {
172+ t .Fatalf ("GET callback failed: %v" , err )
173+ }
174+ defer resp .Body .Close ()
175+
176+ body , _ := io .ReadAll (resp .Body )
177+ if ! strings .Contains (string (body ), "Authorization Failed" ) {
178+ t .Errorf ("expected failure page for exchange error, got: %s" , string (body ))
179+ }
180+
181+ select {
182+ case result := <- ch :
183+ if result .err == nil {
184+ t .Error ("expected error for exchange failure, got nil" )
185+ }
186+ if ! strings .Contains (result .err .Error (), "unauthorized_client" ) {
187+ t .Errorf ("expected error to mention unauthorized_client, got: %v" , result .err )
126188 }
127189 case <- time .After (3 * time .Second ):
128190 t .Fatal ("timed out waiting for callback result" )
@@ -133,7 +195,7 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
133195 const port = 19105
134196 state := "test-state-double"
135197
136- ch := startCallbackServerAsync (t , context .Background (), port , state )
198+ ch := startCallbackServerAsync (t , context .Background (), port , state , stubExchangeFn ( "mycode" ) )
137199
138200 url := fmt .Sprintf ("http://127.0.0.1:%d/callback?code=mycode&state=%s" , port , state )
139201
@@ -158,8 +220,11 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
158220
159221 select {
160222 case result := <- ch :
161- if result != "mycode" {
162- t .Errorf ("expected mycode, got: %s" , result )
223+ if result .err != nil {
224+ t .Errorf ("expected success, got error: %v" , result .err )
225+ }
226+ if result .storage == nil {
227+ t .Error ("expected non-nil storage" )
163228 }
164229 case <- time .After (3 * time .Second ):
165230 t .Fatal ("timed out waiting for callback result" )
@@ -170,7 +235,7 @@ func TestCallbackServer_MissingCode(t *testing.T) {
170235 const port = 19104
171236 state := "state-for-missing-code"
172237
173- ch := startCallbackServerAsync (t , context .Background (), port , state )
238+ ch := startCallbackServerAsync (t , context .Background (), port , state , noExchangeFn ( t ) )
174239
175240 callbackURL := fmt .Sprintf (
176241 "http://127.0.0.1:%d/callback?state=%s" ,
@@ -184,8 +249,8 @@ func TestCallbackServer_MissingCode(t *testing.T) {
184249
185250 select {
186251 case result := <- ch :
187- if ! strings . HasPrefix ( result , "ERROR:" ) {
188- t .Errorf ("expected error for missing code, got: %s" , result )
252+ if result . err == nil {
253+ t .Error ("expected error for missing code, got nil" )
189254 }
190255 case <- time .After (3 * time .Second ):
191256 t .Fatal ("timed out waiting for callback result" )
0 commit comments