Skip to content

Commit c95a3e2

Browse files
- Default client override for testing.
1 parent 28d53d9 commit c95a3e2

3 files changed

Lines changed: 29 additions & 20 deletions

File tree

anysdk/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,11 @@ type anySdkHTTPClientConfigurator struct {
144144
func NewAnySdkClientConfigurator(
145145
rtCtx dto.RuntimeCtx,
146146
provName string,
147+
defaultClient *http.Client,
147148
) client.AnySdkClientConfigurator {
148149
return &anySdkHTTPClientConfigurator{
149150
runtimeCtx: rtCtx,
150-
authUtil: auth_util.NewAuthUtility(),
151+
authUtil: auth_util.NewAuthUtility(defaultClient),
151152
providerName: provName,
152153
}
153154
}

cmd/argparse/query.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"io"
7+
"net/http"
78
"os"
89
"runtime/pprof"
910

@@ -54,15 +55,16 @@ func parseExecPayload(
5455
}
5556

5657
type queryCmdPayload struct {
57-
rtCtx dto.RuntimeCtx
58-
provFilePath string
59-
svcFilePath string
60-
resourceStr string
61-
methodName string
62-
payload string
63-
payloadType string
64-
parameters map[string]interface{}
65-
auth map[string]*dto.AuthCtx
58+
rtCtx dto.RuntimeCtx
59+
provFilePath string
60+
svcFilePath string
61+
resourceStr string
62+
methodName string
63+
payload string
64+
payloadType string
65+
parameters map[string]interface{}
66+
auth map[string]*dto.AuthCtx
67+
defaultHttpClient *http.Client // for testing purposes
6668
}
6769

6870
func (qcp *queryCmdPayload) getService() (anysdk.Service, error) {
@@ -217,6 +219,7 @@ func runQueryCommand(authCtx *dto.AuthCtx, payload *queryCmdPayload) error {
217219
cc := anysdk.NewAnySdkClientConfigurator(
218220
payload.rtCtx,
219221
prov.GetName(),
222+
payload.defaultHttpClient,
220223
)
221224
response, apiErr := anysdk.CallFromSignature(
222225
cc, payload.rtCtx, authCtx, authCtx.Type, false, os.Stderr, prov, anysdk.NewAnySdkOpStoreDesignation(opStore), argList)

pkg/auth_util/auth_util.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,16 @@ type AuthUtility interface {
9999
}
100100

101101
type authUtil struct {
102-
// Placeholder for future implementation
102+
defaultClient *http.Client
103103
}
104104

105-
func NewAuthUtility() AuthUtility {
106-
return &authUtil{}
105+
func NewAuthUtility(defaultClient *http.Client) AuthUtility {
106+
if defaultClient == nil {
107+
defaultClient = http.DefaultClient
108+
}
109+
return &authUtil{
110+
defaultClient: defaultClient,
111+
}
107112
}
108113

109114
type transport struct {
@@ -341,7 +346,7 @@ func (au *authUtil) GoogleOauthServiceAccount(
341346
return nil, errToken
342347
}
343348
au.ActivateAuth(authCtx, "", dto.AuthServiceAccountStr)
344-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
349+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
345350
return config.Client(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)), nil
346351
}
347352

@@ -355,7 +360,7 @@ func (au *authUtil) GenericOauthClientCredentials(
355360
return nil, errToken
356361
}
357362
au.ActivateAuth(authCtx, "", dto.ClientCredentialsStr)
358-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
363+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
359364
return config.Client(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)), nil
360365
}
361366

@@ -365,7 +370,7 @@ func (au *authUtil) ApiTokenAuth(authCtx *dto.AuthCtx, httpContext netutils.HTTP
365370
return nil, fmt.Errorf("credentials error: %w", err)
366371
}
367372
au.ActivateAuth(authCtx, "", "api_key")
368-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
373+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
369374
valPrefix := authCtx.ValuePrefix
370375
if enforceBearer {
371376
valPrefix = "Bearer "
@@ -404,7 +409,7 @@ func (au *authUtil) AwsSigningAuth(authCtx *dto.AuthCtx, httpContext netutils.HT
404409
au.ActivateAuth(authCtx, "", dto.AuthAWSSigningv4Str)
405410

406411
// Get the HTTP client from the runtime context.
407-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
412+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
408413

409414
// Initialize the AWS signing transport with credentials and optional session token.
410415
tr, err := awssign.NewAwsSignTransport(httpClient.Transport, keyID, keyStr, sessionToken)
@@ -424,7 +429,7 @@ func (au *authUtil) BasicAuth(authCtx *dto.AuthCtx, httpContext netutils.HTTPCon
424429
return nil, fmt.Errorf("credentials error: %w", err)
425430
}
426431
au.ActivateAuth(authCtx, "", "basic")
427-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
432+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
428433
tr, err := newTransport(b, AuthTypeBasic, authCtx.ValuePrefix, LocationHeader, "", httpClient.Transport)
429434
if err != nil {
430435
return nil, err
@@ -439,7 +444,7 @@ func (au *authUtil) CustomAuth(authCtx *dto.AuthCtx, httpContext netutils.HTTPCo
439444
return nil, fmt.Errorf("credentials error: %w", err)
440445
}
441446
au.ActivateAuth(authCtx, "", "custom")
442-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
447+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
443448
tr, err := newTransport(b, AuthTypeCustom, authCtx.ValuePrefix, authCtx.Location, authCtx.Name, httpClient.Transport)
444449
if err != nil {
445450
return nil, err
@@ -482,7 +487,7 @@ func (au *authUtil) AzureDefaultAuth(authCtx *dto.AuthCtx, httpContext netutils.
482487
}
483488
tokenString := token.Token
484489
au.ActivateAuth(authCtx, "", "azure_default")
485-
httpClient := netutils.GetHTTPClient(httpContext, http.DefaultClient)
490+
httpClient := netutils.GetHTTPClient(httpContext, au.defaultClient)
486491
tr, err := newTransport([]byte(tokenString), AuthTypeBearer, "Bearer ", LocationHeader, "", httpClient.Transport)
487492
if err != nil {
488493
return nil, err

0 commit comments

Comments
 (0)