From 9ad25e5688916d06504616654ae0ca6b0a5e754c Mon Sep 17 00:00:00 2001 From: Peter Guy Date: Tue, 10 Mar 2026 22:04:28 -0700 Subject: [PATCH 1/2] refactor: validate and parse the endpoint and proxy at program load Amp-Thread-ID: https://ampcode.com/threads/T-019cdb3f-f7de-750b-b4c3-13762c7dfc11 Co-authored-by: Amp --- cmd/src/batch_common.go | 4 +- cmd/src/batch_remote.go | 16 +- cmd/src/batch_repositories.go | 2 +- cmd/src/code_intel_upload.go | 26 +-- cmd/src/debug_compose.go | 2 +- cmd/src/debug_kube.go | 2 +- cmd/src/debug_server.go | 2 +- cmd/src/login.go | 54 +++--- cmd/src/login_oauth.go | 35 ++-- cmd/src/login_test.go | 76 ++++---- cmd/src/login_validate.go | 23 +-- cmd/src/main.go | 123 +++++++----- cmd/src/main_test.go | 215 ++++++++++++--------- cmd/src/search.go | 2 +- cmd/src/search_jobs.go | 4 +- cmd/src/search_jobs_logs.go | 2 +- cmd/src/search_jobs_results.go | 2 +- cmd/src/search_stream.go | 6 +- cmd/src/search_stream_test.go | 4 +- cmd/src/users_prune.go | 4 +- internal/api/api.go | 12 +- internal/batches/executor/executor_test.go | 7 +- internal/batches/repozip/fetcher_test.go | 16 +- internal/oauth/flow.go | 50 +++-- internal/oauth/flow_test.go | 40 ++-- internal/oauth/http_transport.go | 14 +- internal/secrets/keyring.go | 6 +- internal/secrets/keyring_test.go | 17 +- 28 files changed, 441 insertions(+), 325 deletions(-) diff --git a/cmd/src/batch_common.go b/cmd/src/batch_common.go index 21fdce2e21..e294347dcf 100644 --- a/cmd/src/batch_common.go +++ b/cmd/src/batch_common.go @@ -537,7 +537,7 @@ func executeBatchSpec(ctx context.Context, opts executeBatchSpecOpts) (err error if err != nil { return execUI.CreatingBatchSpecError(lr.MaxUnlicensedChangesets, err) } - previewURL := cfg.Endpoint + url + previewURL := cfg.endpointURL.JoinPath(url).String() execUI.CreatingBatchSpecSuccess(previewURL) hasWorkspaceFiles := false @@ -567,7 +567,7 @@ func executeBatchSpec(ctx context.Context, opts executeBatchSpecOpts) (err error if err != nil { return err } - execUI.ApplyingBatchSpecSuccess(cfg.Endpoint + batch.URL) + execUI.ApplyingBatchSpecSuccess(cfg.endpointURL.JoinPath(batch.URL).String()) return nil } diff --git a/cmd/src/batch_remote.go b/cmd/src/batch_remote.go index 7dd7628c03..c1d7209675 100644 --- a/cmd/src/batch_remote.go +++ b/cmd/src/batch_remote.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" cliLog "log" - "strings" "time" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -155,13 +154,14 @@ Examples: } ui.ExecutingBatchSpecSuccess() - executionURL := fmt.Sprintf( - "%s/%s/batch-changes/%s/executions/%s", - strings.TrimSuffix(cfg.Endpoint, "/"), - strings.TrimPrefix(namespace.URL, "/"), - batchChangeName, - batchSpecID, - ) + executionURL := cfg.endpointURL.JoinPath( + fmt.Sprintf( + "%s/batch-changes/%s/executions/%s", + namespace.URL, + batchChangeName, + batchSpecID, + ), + ).String() ui.RemoteSuccess(executionURL) return nil diff --git a/cmd/src/batch_repositories.go b/cmd/src/batch_repositories.go index b02a4f5e58..97a7ea6cfe 100644 --- a/cmd/src/batch_repositories.go +++ b/cmd/src/batch_repositories.go @@ -131,7 +131,7 @@ Examples: Max: max, RepoCount: len(repos), Repos: repos, - SourcegraphEndpoint: cfg.Endpoint, + SourcegraphEndpoint: cfg.endpointURL.String(), }); err != nil { return err } diff --git a/cmd/src/code_intel_upload.go b/cmd/src/code_intel_upload.go index 303e2199e0..d5e686babc 100644 --- a/cmd/src/code_intel_upload.go +++ b/cmd/src/code_intel_upload.go @@ -7,7 +7,6 @@ import ( "flag" "fmt" "io" - "net/url" "os" "strings" "time" @@ -87,10 +86,7 @@ func handleCodeIntelUpload(args []string) error { return handleUploadError(uploadOptions.SourcegraphInstanceOptions.AccessToken, err) } - uploadURL, err := makeCodeIntelUploadURL(uploadID) - if err != nil { - return err - } + uploadURL := makeCodeIntelUploadURL(uploadID) if codeintelUploadFlags.json { serialized, err := json.Marshal(map[string]any{ @@ -132,7 +128,7 @@ func codeintelUploadOptions(out *output.Output) upload.UploadOptions { associatedIndexID = &codeintelUploadFlags.associatedIndexID } - cfg.AdditionalHeaders["Content-Type"] = "application/x-protobuf+scip" + cfg.additionalHeaders["Content-Type"] = "application/x-protobuf+scip" logger := upload.NewRequestLogger( os.Stdout, @@ -153,9 +149,9 @@ func codeintelUploadOptions(out *output.Output) upload.UploadOptions { AssociatedIndexID: associatedIndexID, }, SourcegraphInstanceOptions: upload.SourcegraphInstanceOptions{ - SourcegraphURL: cfg.Endpoint, - AccessToken: cfg.AccessToken, - AdditionalHeaders: cfg.AdditionalHeaders, + SourcegraphURL: cfg.endpointURL.String(), + AccessToken: cfg.accessToken, + AdditionalHeaders: cfg.additionalHeaders, MaxRetries: 5, RetryInterval: time.Second, Path: codeintelUploadFlags.uploadRoute, @@ -191,16 +187,12 @@ func printInferredArguments(out *output.Output) { // makeCodeIntelUploadURL constructs a URL to the upload with the given internal identifier. // The base of the URL is constructed from the configured Sourcegraph instance. -func makeCodeIntelUploadURL(uploadID int) (string, error) { - url, err := url.Parse(cfg.Endpoint) - if err != nil { - return "", err - } - +func makeCodeIntelUploadURL(uploadID int) string { + // Careful: copy by dereference makes a shallow copy, so User is not duplicated. + url := *cfg.endpointURL graphqlID := base64.URLEncoding.EncodeToString(fmt.Appendf(nil, `SCIPUpload:%d`, uploadID)) url.Path = codeintelUploadFlags.repo + "/-/code-intelligence/uploads/" + graphqlID - url.User = nil - return url.String(), nil + return url.String() } type errorWithHint struct { diff --git a/cmd/src/debug_compose.go b/cmd/src/debug_compose.go index 8e26d95b04..32c746065d 100644 --- a/cmd/src/debug_compose.go +++ b/cmd/src/debug_compose.go @@ -75,7 +75,7 @@ Examples: return errors.Wrap(err, "failed to get containers for subcommand with err") } // Safety check user knows what they are targeting with this debug command - log.Printf("This command will archive docker-cli data for %d containers\n SRC_ENDPOINT: %v\n Output filename: %v", len(containers), cfg.Endpoint, base) + log.Printf("This command will archive docker-cli data for %d containers\n SRC_ENDPOINT: %v\n Output filename: %v", len(containers), cfg.endpointURL, base) if verified, _ := verify("Do you want to start writing to an archive?"); !verified { return nil } diff --git a/cmd/src/debug_kube.go b/cmd/src/debug_kube.go index 69af7571e9..24f0b955c0 100644 --- a/cmd/src/debug_kube.go +++ b/cmd/src/debug_kube.go @@ -84,7 +84,7 @@ Examples: return errors.Wrapf(err, "failed to get current-context") } // Safety check user knows what they've targeted with this command - log.Printf("Archiving kubectl data for %d pods\n SRC_ENDPOINT: %v\n Context: %s Namespace: %v\n Output filename: %v", len(pods.Items), cfg.Endpoint, kubectx, namespace, base) + log.Printf("Archiving kubectl data for %d pods\n SRC_ENDPOINT: %v\n Context: %s Namespace: %v\n Output filename: %v", len(pods.Items), cfg.endpointURL, kubectx, namespace, base) if verified, _ := verify("Do you want to start writing to an archive?"); !verified { return nil } diff --git a/cmd/src/debug_server.go b/cmd/src/debug_server.go index 8ef59fc02a..d219ec7fdb 100644 --- a/cmd/src/debug_server.go +++ b/cmd/src/debug_server.go @@ -72,7 +72,7 @@ Examples: defer zw.Close() // Safety check user knows what they are targeting with this debug command - log.Printf("This command will archive docker-cli data for container: %s\n SRC_ENDPOINT: %s\n Output filename: %s", container, cfg.Endpoint, base) + log.Printf("This command will archive docker-cli data for container: %s\n SRC_ENDPOINT: %s\n Output filename: %s", container, cfg.endpointURL, base) if verified, _ := verify("Do you want to start writing to an archive?"); !verified { return nil } diff --git a/cmd/src/login.go b/cmd/src/login.go index 37e8bbe1bd..60cc1698b6 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "io" + "net/url" "os" "github.com/sourcegraph/src-cli/internal/api" @@ -48,23 +49,26 @@ Examples: if err := flagSet.Parse(args); err != nil { return err } - endpoint := cfg.Endpoint + + var loginEndpointURL *url.URL if flagSet.NArg() >= 1 { - endpoint = flagSet.Arg(0) - } - if endpoint == "" { - return cmderrors.Usage("expected exactly one argument: the Sourcegraph URL, or SRC_ENDPOINT to be set") + arg := flagSet.Arg(0) + u, err := parseEndpoint(arg) + if err != nil { + return cmderrors.Usage(fmt.Sprintf("invalid endpoint URL: %s", arg)) + } + loginEndpointURL = u } client := cfg.apiClient(apiFlags, io.Discard) return loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: client, - endpoint: endpoint, - out: os.Stdout, - apiFlags: apiFlags, - oauthClient: oauth.NewClient(oauth.DefaultClientID), + cfg: cfg, + client: client, + out: os.Stdout, + apiFlags: apiFlags, + oauthClient: oauth.NewClient(oauth.DefaultClientID), + loginEndpointURL: loginEndpointURL, }) } @@ -76,12 +80,12 @@ Examples: } type loginParams struct { - cfg *config - client api.Client - endpoint string - out io.Writer - apiFlags *api.Flags - oauthClient oauth.Client + cfg *config + client api.Client + out io.Writer + apiFlags *api.Flags + oauthClient oauth.Client + loginEndpointURL *url.URL } type loginFlow func(context.Context, loginParams) error @@ -96,9 +100,9 @@ const ( ) func loginCmd(ctx context.Context, p loginParams) error { - if p.cfg.ConfigFilePath != "" { + if p.cfg.configFilePath != "" { fmt.Fprintln(p.out) - fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath) + fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath) } _, flow := selectLoginFlow(p) @@ -107,15 +111,13 @@ func loginCmd(ctx context.Context, p loginParams) error { // selectLoginFlow decides what login flow to run based on configured AuthMode. func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) { - endpointArg := cleanEndpoint(p.endpoint) - + if p.loginEndpointURL != nil && p.loginEndpointURL.String() != p.cfg.endpointURL.String() { + return loginFlowEndpointConflict, runEndpointConflictLogin + } switch p.cfg.AuthMode() { case AuthModeOAuth: return loginFlowOAuth, runOAuthLogin case AuthModeAccessToken: - if endpointArg != p.cfg.Endpoint { - return loginFlowEndpointConflict, runEndpointConflictLogin - } return loginFlowValidate, runValidatedLogin default: return loginFlowMissingAuth, runMissingAuthLogin @@ -126,7 +128,7 @@ func printLoginProblem(out io.Writer, problem string) { fmt.Fprintf(out, "❌ Problem: %s\n", problem) } -func loginAccessTokenMessage(endpoint string) string { +func loginAccessTokenMessage(endpointURL *url.URL) string { return fmt.Sprintf("\n"+`🛠 To fix: Create an access token by going to %s/user/settings/tokens, then set the following environment variables in your terminal: export SRC_ENDPOINT=%s @@ -135,5 +137,5 @@ func loginAccessTokenMessage(endpoint string) string { To verify that it's working, run the login command again. Alternatively, you can try logging in interactively by running: src login %s -`, endpoint, endpoint, endpoint) +`, endpointURL, endpointURL, endpointURL) } diff --git a/cmd/src/login_oauth.go b/cmd/src/login_oauth.go index 5074083da8..1e404d6447 100644 --- a/cmd/src/login_oauth.go +++ b/cmd/src/login_oauth.go @@ -18,15 +18,14 @@ import ( var loadStoredOAuthToken = oauth.LoadToken func runOAuthLogin(ctx context.Context, p loginParams) error { - endpointArg := cleanEndpoint(p.endpoint) - client, err := oauthLoginClient(ctx, p, endpointArg) + client, err := oauthLoginClient(ctx, p) if err != nil { printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) - fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL)) return cmderrors.ExitCode1 } - if err := validateCurrentUser(ctx, client, p.out, endpointArg); err != nil { + if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil { return err } @@ -39,13 +38,13 @@ func runOAuthLogin(ctx context.Context, p loginParams) error { // oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token // and use it if one is present. // If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage. -func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) { +func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) { // if we have a stored token, used it. Otherwise run the device flow - if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil { - return newOAuthAPIClient(p, endpoint, token), nil + if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil { + return newOAuthAPIClient(p, token), nil } - token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient) + token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient) if err != nil { return nil, err } @@ -55,23 +54,23 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api. fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) } - return newOAuthAPIClient(p, endpoint, token), nil + return newOAuthAPIClient(p, token), nil } -func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client { +func newOAuthAPIClient(p loginParams, token *oauth.Token) api.Client { return api.NewClient(api.ClientOpts{ - Endpoint: endpoint, - AdditionalHeaders: p.cfg.AdditionalHeaders, + EndpointURL: p.cfg.endpointURL, + AdditionalHeaders: p.cfg.additionalHeaders, Flags: p.apiFlags, Out: p.out, - ProxyURL: p.cfg.ProxyURL, - ProxyPath: p.cfg.ProxyPath, + ProxyURL: p.cfg.proxyURL, + ProxyPath: p.cfg.proxyPath, OAuthToken: token, }) } -func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) { - authResp, err := client.Start(ctx, endpoint, nil) +func runOAuthDeviceFlow(ctx context.Context, endpointURL *url.URL, out io.Writer, client oauth.Client) (*oauth.Token, error) { + authResp, err := client.Start(ctx, endpointURL, nil) if err != nil { return nil, err } @@ -95,12 +94,12 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli interval = 5 * time.Second } - resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn) + resp, err := client.Poll(ctx, endpointURL, authResp.DeviceCode, interval, authResp.ExpiresIn) if err != nil { return nil, err } - token := resp.Token(endpoint) + token := resp.Token(endpointURL) token.ClientID = client.ClientID() return token, nil } diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 6577c19f9e..85e79816b2 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -15,33 +16,43 @@ import ( "github.com/sourcegraph/src-cli/internal/oauth" ) +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("failed to parse URL %q: %v", raw, err) + } + return u +} + func TestLogin(t *testing.T) { - check := func(t *testing.T, cfg *config, endpointArg string) (output string, err error) { + check := func(t *testing.T, cfg *config, endpointArgURL *url.URL) (output string, err error) { t.Helper() var out bytes.Buffer err = loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: cfg.apiClient(nil, io.Discard), - endpoint: endpointArg, - out: &out, - oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")}, + cfg: cfg, + client: cfg.apiClient(nil, io.Discard), + out: &out, + oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")}, + loginEndpointURL: endpointArgURL, }) return strings.TrimSpace(out.String()), err } t.Run("different endpoint in config vs. arg", func(t *testing.T) { - out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com") + out, err := check(t, &config{endpointURL: &url.URL{Scheme: "https", Host: "example.com"}}, &url.URL{Scheme: "https", Host: "sourcegraph.example.com"}) if err == nil { t.Fatal(err) } - if !strings.Contains(out, "OAuth Device flow authentication failed:") { - t.Errorf("got output %q, want oauth failure output", out) + if !strings.Contains(out, "The configured endpoint is https://example.com, not https://sourcegraph.example.com.") { + t.Errorf("got output %q, want configured endpoint error", out) } }) t.Run("no access token triggers oauth flow", func(t *testing.T) { - out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com") + u := &url.URL{Scheme: "https", Host: "example.com"} + out, err := check(t, &config{endpointURL: u}, u) if err == nil { t.Fatal(err) } @@ -51,8 +62,9 @@ func TestLogin(t *testing.T) { }) t.Run("warning when using config file", func(t *testing.T) { - out, err := check(t, &config{Endpoint: "https://example.com", ConfigFilePath: "f"}, "https://example.com") - if err == nil { + endpoint := &url.URL{Scheme: "https", Host: "example.com"} + out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint) + if err != cmderrors.ExitCode1 { t.Fatal(err) } if !strings.Contains(out, "Configuring src with a JSON file is deprecated") { @@ -69,13 +81,13 @@ func TestLogin(t *testing.T) { })) defer s.Close() - endpoint := s.URL - out, err := check(t, &config{Endpoint: endpoint, AccessToken: "x"}, endpoint) + u := mustParseURL(t, s.URL) + out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u) if err != cmderrors.ExitCode1 { t.Fatal(err) } wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in interactively by running: src login $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)" - wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint) + wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -87,13 +99,13 @@ func TestLogin(t *testing.T) { })) defer s.Close() - endpoint := s.URL - out, err := check(t, &config{Endpoint: endpoint, AccessToken: "x"}, endpoint) + u := mustParseURL(t, s.URL) + out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u) if err != nil { t.Fatal(err) } wantOut := "✔︎ Authenticated as alice on $ENDPOINT" - wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint) + wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } @@ -105,7 +117,7 @@ func TestLogin(t *testing.T) { })) defer s.Close() - restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) { + restoreStoredOAuthLoader(t, func(_ context.Context, _ *url.URL) (*oauth.Token, error) { return &oauth.Token{ Endpoint: s.URL, ClientID: oauth.DefaultClientID, @@ -114,13 +126,13 @@ func TestLogin(t *testing.T) { }, nil }) + u, _ := url.ParseRequestURI(s.URL) startCalled := false var out bytes.Buffer err := loginCmd(context.Background(), loginParams{ - cfg: &config{Endpoint: s.URL}, - client: (&config{Endpoint: s.URL}).apiClient(nil, io.Discard), - endpoint: s.URL, - out: &out, + cfg: &config{endpointURL: u}, + client: (&config{endpointURL: u}).apiClient(nil, io.Discard), + out: &out, oauthClient: fakeOAuthClient{ startErr: fmt.Errorf("unexpected call to Start"), startCalled: &startCalled, @@ -150,18 +162,18 @@ func (f fakeOAuthClient) ClientID() string { return oauth.DefaultClientID } -func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfiguration, error) { +func (f fakeOAuthClient) Discover(context.Context, *url.URL) (*oauth.OIDCConfiguration, error) { return nil, fmt.Errorf("unexpected call to Discover") } -func (f fakeOAuthClient) Start(context.Context, string, []string) (*oauth.DeviceAuthResponse, error) { +func (f fakeOAuthClient) Start(context.Context, *url.URL, []string) (*oauth.DeviceAuthResponse, error) { if f.startCalled != nil { *f.startCalled = true } return nil, f.startErr } -func (f fakeOAuthClient) Poll(context.Context, string, string, time.Duration, int) (*oauth.TokenResponse, error) { +func (f fakeOAuthClient) Poll(context.Context, *url.URL, string, time.Duration, int) (*oauth.TokenResponse, error) { return nil, fmt.Errorf("unexpected call to Poll") } @@ -172,8 +184,7 @@ func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenRes func TestSelectLoginFlow(t *testing.T) { t.Run("uses oauth flow when no access token is configured", func(t *testing.T) { params := loginParams{ - cfg: &config{Endpoint: "https://example.com"}, - endpoint: "https://sourcegraph.example.com", + cfg: &config{endpointURL: mustParseURL(t, "https://example.com")}, } if got, _ := selectLoginFlow(params); got != loginFlowOAuth { @@ -183,8 +194,8 @@ func TestSelectLoginFlow(t *testing.T) { t.Run("uses endpoint conflict flow when auth exists for a different endpoint", func(t *testing.T) { params := loginParams{ - cfg: &config{Endpoint: "https://example.com", AccessToken: "x"}, - endpoint: "https://sourcegraph.example.com", + cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"}, + loginEndpointURL: mustParseURL(t, "https://sourcegraph.example.com"), } if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict { @@ -194,8 +205,7 @@ func TestSelectLoginFlow(t *testing.T) { t.Run("uses validation flow when auth exists for the selected endpoint", func(t *testing.T) { params := loginParams{ - cfg: &config{Endpoint: "https://example.com", AccessToken: "x"}, - endpoint: "https://example.com", + cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"}, } if got, _ := selectLoginFlow(params); got != loginFlowValidate { @@ -270,7 +280,7 @@ func TestValidateBrowserURL_WindowsRundll32Escape(t *testing.T) { } } -func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, string) (*oauth.Token, error)) { +func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, *url.URL) (*oauth.Token, error)) { t.Helper() prev := loadStoredOAuthToken diff --git a/cmd/src/login_validate.go b/cmd/src/login_validate.go index c54a6df421..9aa65cdcef 100644 --- a/cmd/src/login_validate.go +++ b/cmd/src/login_validate.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/url" "strings" "github.com/sourcegraph/src-cli/internal/api" @@ -11,28 +12,24 @@ import ( ) func runMissingAuthLogin(_ context.Context, p loginParams) error { - endpointArg := cleanEndpoint(p.endpoint) - fmt.Fprintln(p.out) printLoginProblem(p.out, "No access token is configured.") - fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL)) return cmderrors.ExitCode1 } func runEndpointConflictLogin(_ context.Context, p loginParams) error { - endpointArg := cleanEndpoint(p.endpoint) - fmt.Fprintln(p.out) - printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.Endpoint, endpointArg)) - fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg)) + printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.endpointURL, p.loginEndpointURL)) + fmt.Fprintln(p.out, loginAccessTokenMessage(p.loginEndpointURL)) return cmderrors.ExitCode1 } func runValidatedLogin(ctx context.Context, p loginParams) error { - return validateCurrentUser(ctx, p.client, p.out, cleanEndpoint(p.endpoint)) + return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL) } -func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpoint string) error { +func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointURL *url.URL) error { query := `query CurrentUser { currentUser { username } }` var result struct { CurrentUser *struct{ Username string } @@ -41,9 +38,9 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") { printLoginProblem(out, "Invalid access token.") } else { - printLoginProblem(out, fmt.Sprintf("Error communicating with %s: %s", endpoint, err)) + printLoginProblem(out, fmt.Sprintf("Error communicating with %s: %s", endpointURL, err)) } - fmt.Fprintln(out, loginAccessTokenMessage(endpoint)) + fmt.Fprintln(out, loginAccessTokenMessage(endpointURL)) fmt.Fprintln(out, " (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)") return cmderrors.ExitCode1 } @@ -51,11 +48,11 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, if result.CurrentUser == nil { // This should never happen; we verified there is an access token, so there should always be // a user. - printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpoint)) + printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointURL)) return cmderrors.ExitCode1 } fmt.Fprintln(out) - fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpoint) + fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointURL) fmt.Fprintln(out) return nil } diff --git a/cmd/src/main.go b/cmd/src/main.go index c7f2ff450c..00f0366781 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -77,8 +77,8 @@ var ( verbose = flag.Bool("v", false, "print verbose output") // The following arguments are deprecated which is why they are no longer documented - configPath = flag.String("config", "", "") - endpoint = flag.String("endpoint", "", "") + configPath = flag.String("config", "", "") + endpointFlag = flag.String("endpoint", "", "") errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set") errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set") @@ -110,17 +110,42 @@ func normalizeDashHelp(args []string) []string { return args } +func parseEndpoint(endpoint string) (*url.URL, error) { + u, err := url.ParseRequestURI(strings.TrimSuffix(endpoint, "/")) + if err != nil { + return nil, err + } + if !(u.Scheme == "http" || u.Scheme == "https") { + return nil, errors.Newf("invalid scheme %s: require http or https", u.Scheme) + } + if u.Host == "" { + return nil, errors.Newf("empty host") + } + // auth in the URL is not used, and could be explosed in log output. + // Explicitly clear it in case it's accidentally set in SRC_ENDPOINT or the config file. + u.User = nil + return u, nil +} + var cfg *config -// config represents the config format. +// config holds the resolved configuration used at runtime. type config struct { + accessToken string + additionalHeaders map[string]string + proxyURL *url.URL + proxyPath string + configFilePath string + endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig +} + +// configFromFile holds the config as read from the config file, +// which is validated and parsed into the config struct. +type configFromFile struct { Endpoint string `json:"endpoint"` AccessToken string `json:"accessToken"` AdditionalHeaders map[string]string `json:"additionalHeaders"` Proxy string `json:"proxy"` - ProxyURL *url.URL - ProxyPath string - ConfigFilePath string } type AuthMode int @@ -131,7 +156,7 @@ const ( ) func (c *config) AuthMode() AuthMode { - if c.AccessToken != "" { + if c.accessToken != "" { return AuthModeAccessToken } return AuthModeOAuth @@ -140,18 +165,18 @@ func (c *config) AuthMode() AuthMode { // apiClient returns an api.Client built from the configuration. func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { opts := api.ClientOpts{ - Endpoint: c.Endpoint, - AccessToken: c.AccessToken, - AdditionalHeaders: c.AdditionalHeaders, + EndpointURL: c.endpointURL, + AccessToken: c.accessToken, + AdditionalHeaders: c.additionalHeaders, Flags: flags, Out: out, - ProxyURL: c.ProxyURL, - ProxyPath: c.ProxyPath, + ProxyURL: c.proxyURL, + ProxyPath: c.proxyPath, } // Only use OAuth if we do not have SRC_ACCESS_TOKEN set - if c.AccessToken == "" { - if t, err := oauth.LoadToken(context.Background(), c.Endpoint); err == nil { + if c.accessToken == "" { + if t, err := oauth.LoadToken(context.Background(), c.endpointURL); err == nil { opts.OAuthToken = t } } @@ -177,12 +202,20 @@ func readConfig() (*config, error) { if err != nil && (!os.IsNotExist(err) || userSpecified) { return nil, err } + + var cfgFromFile configFromFile var cfg config + var endpointStr string + var proxyStr string if err == nil { - cfg.ConfigFilePath = cfgPath - if err := json.Unmarshal(data, &cfg); err != nil { + cfg.configFilePath = cfgPath + if err := json.Unmarshal(data, &cfgFromFile); err != nil { return nil, err } + endpointStr = cfgFromFile.Endpoint + cfg.accessToken = cfgFromFile.AccessToken + cfg.additionalHeaders = cfgFromFile.AdditionalHeaders + proxyStr = cfgFromFile.Proxy } envToken := os.Getenv("SRC_ACCESS_TOKEN") @@ -203,21 +236,32 @@ func readConfig() (*config, error) { // Apply config overrides. if envToken != "" { - cfg.AccessToken = envToken + cfg.accessToken = envToken } if envEndpoint != "" { - cfg.Endpoint = envEndpoint + endpointStr = envEndpoint } - if cfg.Endpoint == "" { - cfg.Endpoint = "https://sourcegraph.com" + if endpointStr == "" { + endpointStr = "https://sourcegraph.com" } if envProxy != "" { - cfg.Proxy = envProxy + proxyStr = envProxy } - if cfg.Proxy != "" { + // Lastly, apply endpoint flag if set + if endpointFlag != nil && *endpointFlag != "" { + endpointStr = *endpointFlag + } + + if endpointURL, err := parseEndpoint(endpointStr); err != nil { + return nil, errors.Newf("invalid endpoint: %s", endpointStr) + } else { + cfg.endpointURL = endpointURL + } - parseEndpoint := func(endpoint string) (scheme string, address string) { + if proxyStr != "" { + + parseProxyEndpoint := func(endpoint string) (scheme string, address string) { parts := strings.SplitN(endpoint, "://", 2) if len(parts) == 2 { return parts[0], parts[1] @@ -231,15 +275,15 @@ func readConfig() (*config, error) { return slices.Contains(urlSchemes, scheme) } - scheme, address := parseEndpoint(cfg.Proxy) + scheme, address := parseProxyEndpoint(proxyStr) if isURLScheme(scheme) { - endpoint := cfg.Proxy + endpoint := proxyStr // assume socks means socks5, because that's all we support if scheme == "socks" { endpoint = "socks5://" + address } - cfg.ProxyURL, err = url.Parse(endpoint) + cfg.proxyURL, err = url.Parse(endpoint) if err != nil { return nil, err } @@ -250,32 +294,25 @@ func readConfig() (*config, error) { } isValidUDS, err := isValidUnixSocket(path) if err != nil { - return nil, errors.Newf("Invalid proxy configuration: %w", err) + return nil, errors.Newf("invalid proxy configuration: %w", err) } if !isValidUDS { return nil, errors.Newf("invalid proxy socket: %s", path) } - cfg.ProxyPath = path + cfg.proxyPath = path } else { - return nil, errors.Newf("invalid proxy endpoint: %s", cfg.Proxy) + return nil, errors.Newf("invalid proxy endpoint: %s", proxyStr) } } - cfg.AdditionalHeaders = parseAdditionalHeaders() + cfg.additionalHeaders = parseAdditionalHeaders() // Ensure that we're not clashing additonal headers - _, hasAuthorizationAdditonalHeader := cfg.AdditionalHeaders["authorization"] - if cfg.AccessToken != "" && hasAuthorizationAdditonalHeader { + _, hasAuthorizationAdditonalHeader := cfg.additionalHeaders["authorization"] + if cfg.accessToken != "" && hasAuthorizationAdditonalHeader { return nil, errConfigAuthorizationConflict } - // Lastly, apply endpoint flag if set - if endpoint != nil && *endpoint != "" { - cfg.Endpoint = *endpoint - } - - cfg.Endpoint = cleanEndpoint(cfg.Endpoint) - - if isCI() && cfg.AccessToken == "" { + if isCI() && cfg.accessToken == "" { return nil, errCIAccessTokenRequired } @@ -287,10 +324,6 @@ func isCI() bool { return ok && value != "" } -func cleanEndpoint(urlStr string) string { - return strings.TrimSuffix(urlStr, "/") -} - // isValidUnixSocket checks if the given path is a valid Unix socket. // // Parameters: @@ -310,7 +343,7 @@ func isValidUnixSocket(path string) (bool, error) { if os.IsNotExist(err) { return false, nil } - return false, errors.Newf("Not a UNIX Domain Socket: %v: %w", path, err) + return false, errors.Newf("not a UNIX Domain Socket: %v: %w", path, err) } defer conn.Close() diff --git a/cmd/src/main_test.go b/cmd/src/main_test.go index a2656b1929..ee95616796 100644 --- a/cmd/src/main_test.go +++ b/cmd/src/main_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/sourcegraph/src-cli/internal/api" ) @@ -29,7 +30,7 @@ func TestReadConfig(t *testing.T) { tests := []struct { name string - fileContents *config + fileContents *configFromFile envCI string envToken string envFooHeader string @@ -43,24 +44,29 @@ func TestReadConfig(t *testing.T) { { name: "defaults", want: &config{ - Endpoint: "https://sourcegraph.com", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + additionalHeaders: map[string]string{}, }, }, { name: "config file, no overrides, trim slash", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", Proxy: "https://proxy.com:8080", }, want: &config{ - Endpoint: "https://example.com", - AccessToken: "deadbeef", - AdditionalHeaders: map[string]string{}, - Proxy: "https://proxy.com:8080", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, + accessToken: "deadbeef", + additionalHeaders: map[string]string{}, + proxyPath: "", + proxyURL: &url.URL{ Scheme: "https", Host: "proxy.com:8080", }, @@ -68,7 +74,7 @@ func TestReadConfig(t *testing.T) { }, { name: "config file, token override only", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", }, @@ -78,7 +84,7 @@ func TestReadConfig(t *testing.T) { }, { name: "config file, endpoint override only", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", }, @@ -88,27 +94,29 @@ func TestReadConfig(t *testing.T) { }, { name: "config file, proxy override only (allow)", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", Proxy: "https://proxy.com:8080", }, envProxy: "socks5://other.proxy.com:9999", want: &config{ - Endpoint: "https://example.com", - AccessToken: "deadbeef", - Proxy: "socks5://other.proxy.com:9999", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, + accessToken: "deadbeef", + proxyPath: "", + proxyURL: &url.URL{ Scheme: "socks5", Host: "other.proxy.com:9999", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { name: "config file, all override", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", Proxy: "https://proxy.com:8080", @@ -117,48 +125,58 @@ func TestReadConfig(t *testing.T) { envEndpoint: "https://override.com", envProxy: "socks5://other.proxy.com:9999", want: &config{ - Endpoint: "https://override.com", - AccessToken: "abc", - Proxy: "socks5://other.proxy.com:9999", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, + accessToken: "abc", + proxyPath: "", + proxyURL: &url.URL{ Scheme: "socks5", Host: "other.proxy.com:9999", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { name: "no config file, token from environment", envToken: "abc", want: &config{ - Endpoint: "https://sourcegraph.com", - AccessToken: "abc", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + accessToken: "abc", + additionalHeaders: map[string]string{}, }, }, { name: "no config file, endpoint from environment", envEndpoint: "https://example.com", want: &config{ - Endpoint: "https://example.com", - AccessToken: "", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, + accessToken: "", + additionalHeaders: map[string]string{}, }, }, { name: "no config file, proxy from environment", envProxy: "https://proxy.com:8080", want: &config{ - Endpoint: "https://sourcegraph.com", - AccessToken: "", - Proxy: "https://proxy.com:8080", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + accessToken: "", + proxyPath: "", + proxyURL: &url.URL{ Scheme: "https", Host: "proxy.com:8080", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { @@ -167,79 +185,92 @@ func TestReadConfig(t *testing.T) { envToken: "abc", envProxy: "https://proxy.com:8080", want: &config{ - Endpoint: "https://example.com", - AccessToken: "abc", - Proxy: "https://proxy.com:8080", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, + accessToken: "abc", + proxyPath: "", + proxyURL: &url.URL{ Scheme: "https", Host: "proxy.com:8080", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { name: "UNIX Domain Socket proxy using scheme and absolute path", envProxy: "unix://" + socketPath, want: &config{ - Endpoint: "https://sourcegraph.com", - Proxy: "unix://" + socketPath, - ProxyPath: socketPath, - ProxyURL: nil, - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + proxyPath: socketPath, + proxyURL: nil, + additionalHeaders: map[string]string{}, }, }, { name: "UNIX Domain Socket proxy with absolute path", envProxy: socketPath, want: &config{ - Endpoint: "https://sourcegraph.com", - Proxy: socketPath, - ProxyPath: socketPath, - ProxyURL: nil, - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + proxyPath: socketPath, + proxyURL: nil, + additionalHeaders: map[string]string{}, }, }, { name: "socks --> socks5", envProxy: "socks://localhost:1080", want: &config{ - Endpoint: "https://sourcegraph.com", - Proxy: "socks://localhost:1080", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + proxyPath: "", + proxyURL: &url.URL{ Scheme: "socks5", Host: "localhost:1080", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { name: "socks5h", envProxy: "socks5h://localhost:1080", want: &config{ - Endpoint: "https://sourcegraph.com", - Proxy: "socks5h://localhost:1080", - ProxyPath: "", - ProxyURL: &url.URL{ + endpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, + proxyPath: "", + proxyURL: &url.URL{ Scheme: "socks5h", Host: "localhost:1080", }, - AdditionalHeaders: map[string]string{}, + additionalHeaders: map[string]string{}, }, }, { name: "endpoint flag should override config", flagEndpoint: "https://override.com/", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", AdditionalHeaders: map[string]string{}, }, want: &config{ - Endpoint: "https://override.com", - AccessToken: "deadbeef", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, + accessToken: "deadbeef", + additionalHeaders: map[string]string{}, }, }, { @@ -248,9 +279,12 @@ func TestReadConfig(t *testing.T) { envEndpoint: "https://example.com", envToken: "abc", want: &config{ - Endpoint: "https://override.com", - AccessToken: "abc", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, + accessToken: "abc", + additionalHeaders: map[string]string{}, }, }, { @@ -260,9 +294,12 @@ func TestReadConfig(t *testing.T) { envToken: "abc", envFooHeader: "bar", want: &config{ - Endpoint: "https://override.com", - AccessToken: "abc", - AdditionalHeaders: map[string]string{"foo": "bar"}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, + accessToken: "abc", + additionalHeaders: map[string]string{"foo": "bar"}, }, }, { @@ -272,9 +309,12 @@ func TestReadConfig(t *testing.T) { envToken: "abc", envHeaders: "foo:bar\nfoo-bar:bar-baz", want: &config{ - Endpoint: "https://override.com", - AccessToken: "abc", - AdditionalHeaders: map[string]string{"foo-bar": "bar-baz", "foo": "bar"}, + endpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, + accessToken: "abc", + additionalHeaders: map[string]string{"foo-bar": "bar-baz", "foo": "bar"}, }, }, { @@ -292,14 +332,14 @@ func TestReadConfig(t *testing.T) { { name: "CI allows access token from config file", envCI: "1", - fileContents: &config{ + fileContents: &configFromFile{ Endpoint: "https://example.com/", AccessToken: "deadbeef", }, want: &config{ - Endpoint: "https://example.com", - AccessToken: "deadbeef", - AdditionalHeaders: map[string]string{}, + endpointURL: &url.URL{Scheme: "https", Host: "example.com"}, + accessToken: "deadbeef", + additionalHeaders: map[string]string{}, }, }, } @@ -323,8 +363,8 @@ func TestReadConfig(t *testing.T) { if test.flagEndpoint != "" { val := test.flagEndpoint - endpoint = &val - t.Cleanup(func() { endpoint = nil }) + endpointFlag = &val + t.Cleanup(func() { endpointFlag = nil }) } if test.fileContents != nil { @@ -351,8 +391,11 @@ func TestReadConfig(t *testing.T) { t.Fatal(err) } - config, err := readConfig() - if diff := cmp.Diff(test.want, config); diff != "" { + got, err := readConfig() + if diff := cmp.Diff(test.want, got, + cmp.AllowUnexported(config{}), + cmpopts.IgnoreFields(config{}, "configFilePath"), + ); diff != "" { t.Errorf("config: %v", diff) } var errMsg string @@ -374,7 +417,7 @@ func TestConfigAuthMode(t *testing.T) { }) t.Run("access token when configured", func(t *testing.T) { - if got := (&config{AccessToken: "token"}).AuthMode(); got != AuthModeAccessToken { + if got := (&config{accessToken: "token"}).AuthMode(); got != AuthModeAccessToken { t.Fatalf("AuthMode() = %v, want %v", got, AuthModeAccessToken) } }) diff --git a/cmd/src/search.go b/cmd/src/search.go index 385ac7c72d..a2cd1a0bbe 100644 --- a/cmd/src/search.go +++ b/cmd/src/search.go @@ -266,7 +266,7 @@ Other tips: } improved := searchResultsImproved{ - SourcegraphEndpoint: cfg.Endpoint, + SourcegraphEndpoint: cfg.endpointURL.String(), Query: queryString, Site: result.Site, searchResults: result.Search.Results, diff --git a/cmd/src/search_jobs.go b/cmd/src/search_jobs.go index 96c5d9070a..15bf5d8a25 100644 --- a/cmd/src/search_jobs.go +++ b/cmd/src/search_jobs.go @@ -156,8 +156,8 @@ func parseColumns(columnsFlag string) []string { // createSearchJobsClient creates a reusable API client for search jobs commands func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client { return api.NewClient(api.ClientOpts{ - Endpoint: cfg.Endpoint, - AccessToken: cfg.AccessToken, + EndpointURL: cfg.endpointURL, + AccessToken: cfg.accessToken, Out: out.Output(), Flags: apiFlags, }) diff --git a/cmd/src/search_jobs_logs.go b/cmd/src/search_jobs_logs.go index 2fe9b6bfee..6327a609ab 100644 --- a/cmd/src/search_jobs_logs.go +++ b/cmd/src/search_jobs_logs.go @@ -22,7 +22,7 @@ func fetchJobLogs(jobID string, logURL string) (io.ReadCloser, error) { return nil, err } - req.Header.Add("Authorization", "token "+cfg.AccessToken) + req.Header.Add("Authorization", "token "+cfg.accessToken) resp, err := http.DefaultClient.Do(req) if err != nil { diff --git a/cmd/src/search_jobs_results.go b/cmd/src/search_jobs_results.go index 2e45f8219f..9d8bc7a9ab 100644 --- a/cmd/src/search_jobs_results.go +++ b/cmd/src/search_jobs_results.go @@ -22,7 +22,7 @@ func fetchJobResults(jobID string, resultsURL string) (io.ReadCloser, error) { return nil, err } - req.Header.Add("Authorization", "token "+cfg.AccessToken) + req.Header.Add("Authorization", "token "+cfg.accessToken) resp, err := http.DefaultClient.Do(req) if err != nil { diff --git a/cmd/src/search_stream.go b/cmd/src/search_stream.go index 512e7f9f75..9b5415851a 100644 --- a/cmd/src/search_stream.go +++ b/cmd/src/search_stream.go @@ -160,7 +160,7 @@ func textDecoder(query string, t *template.Template, w io.Writer) streaming.Deco SourcegraphEndpoint string *streaming.EventRepoMatch }{ - SourcegraphEndpoint: cfg.Endpoint, + SourcegraphEndpoint: cfg.endpointURL.String(), EventRepoMatch: match, }) if err != nil { @@ -172,7 +172,7 @@ func textDecoder(query string, t *template.Template, w io.Writer) streaming.Deco SourcegraphEndpoint string *streaming.EventCommitMatch }{ - SourcegraphEndpoint: cfg.Endpoint, + SourcegraphEndpoint: cfg.endpointURL.String(), EventCommitMatch: match, }) if err != nil { @@ -184,7 +184,7 @@ func textDecoder(query string, t *template.Template, w io.Writer) streaming.Deco SourcegraphEndpoint string *streaming.EventSymbolMatch }{ - SourcegraphEndpoint: cfg.Endpoint, + SourcegraphEndpoint: cfg.endpointURL.String(), EventSymbolMatch: match, }, ) diff --git a/cmd/src/search_stream_test.go b/cmd/src/search_stream_test.go index 1653b273ac..71b099abb9 100644 --- a/cmd/src/search_stream_test.go +++ b/cmd/src/search_stream_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -126,8 +127,9 @@ func TestSearchStream(t *testing.T) { s := testServer(t, http.HandlerFunc(mockStreamHandler)) defer s.Close() + u, _ := url.ParseRequestURI(s.URL) cfg = &config{ - Endpoint: s.URL, + endpointURL: u, } defer func() { cfg = nil }() diff --git a/cmd/src/users_prune.go b/cmd/src/users_prune.go index 90de530c68..d67fff6be0 100644 --- a/cmd/src/users_prune.go +++ b/cmd/src/users_prune.go @@ -225,7 +225,7 @@ type UserToDelete struct { // Verify user wants to remove users with table of users and a command prompt for [y/N] func confirmUserRemoval(usersToDelete []UserToDelete, daysThreshold int, displayUsers bool) (bool, error) { if displayUsers { - fmt.Printf("Users to remove from %s\n", cfg.Endpoint) + fmt.Printf("Users to remove from %s\n", cfg.endpointURL) t := table.NewWriter() t.SetOutputMirror(os.Stdout) t.AppendHeader(table.Row{"Username", "Email", "Days Since Last Active"}) @@ -243,7 +243,7 @@ func confirmUserRemoval(usersToDelete []UserToDelete, daysThreshold int, display } input := "" for strings.ToLower(input) != "y" && strings.ToLower(input) != "n" { - fmt.Printf("%v users were inactive for more than %v days on %v.\nDo you wish to proceed with user removal [y/N]: ", len(usersToDelete), daysThreshold, cfg.Endpoint) + fmt.Printf("%v users were inactive for more than %v days on %v.\nDo you wish to proceed with user removal [y/N]: ", len(usersToDelete), daysThreshold, cfg.endpointURL) if _, err := fmt.Scanln(&input); err != nil { return false, err } diff --git a/internal/api/api.go b/internal/api/api.go index adc6af53d5..a824fae1a5 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -12,7 +12,6 @@ import ( "net/url" "os" "runtime" - "strings" ioaux "github.com/jig/teereadcloser" "github.com/kballard/go-shellquote" @@ -72,7 +71,7 @@ type request struct { // ClientOpts encapsulates the options given to NewClient. type ClientOpts struct { - Endpoint string + EndpointURL *url.URL AccessToken string AdditionalHeaders map[string]string @@ -139,7 +138,7 @@ func NewClient(opts ClientOpts) Client { return &client{ opts: ClientOpts{ - Endpoint: opts.Endpoint, + EndpointURL: opts.EndpointURL, AccessToken: opts.AccessToken, AdditionalHeaders: opts.AdditionalHeaders, Flags: flags, @@ -174,7 +173,8 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R } func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.opts.Endpoint, "/")+"/"+p, body) + // Can't use c.opts.EndpointURL.JoinPath(p) here because `p` could contain a query string + req, err := http.NewRequestWithContext(ctx, method, c.opts.EndpointURL.String()+"/"+p, body) if err != nil { return nil, err } @@ -269,7 +269,7 @@ func (r *request) do(ctx context.Context, result any) (bool, error) { if oauth.IsOAuthTransport(r.client.httpClient.Transport) { fmt.Println("The OAuth token is invalid. Please check that the Sourcegraph CLI client is still authorized.") fmt.Println("") - fmt.Printf("To re-authorize, run: src login %s\n", r.client.opts.Endpoint) + fmt.Printf("To re-authorize, run: src login %s\n", r.client.opts.EndpointURL) fmt.Println("") fmt.Println("Learn more at https://github.com/sourcegraph/src-cli#readme") fmt.Println("") @@ -360,6 +360,6 @@ func (r *request) curlCmd() (string, error) { s += fmt.Sprintf(" %s \\\n", shellquote.Join("-H", k+": "+v)) } s += fmt.Sprintf(" %s \\\n", shellquote.Join("-d", string(data))) - s += fmt.Sprintf(" %s", shellquote.Join(r.client.opts.Endpoint+"/.api/graphql")) + s += fmt.Sprintf(" %s", shellquote.Join(r.client.opts.EndpointURL.JoinPath(".api/graphql").String())) return s, nil } diff --git a/internal/batches/executor/executor_test.go b/internal/batches/executor/executor_test.go index 9fc96d927d..03f25e08a8 100644 --- a/internal/batches/executor/executor_test.go +++ b/internal/batches/executor/executor_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "runtime" @@ -406,7 +407,8 @@ func TestExecutor_Integration(t *testing.T) { // Setup an api.Client that points to this test server var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) // Temp dir for log files and downloaded archives testTempDir := t.TempDir() @@ -827,7 +829,8 @@ func testExecuteTasks(t *testing.T, tasks []*Task, archives ...mock.RepoArchive) t.Cleanup(ts.Close) var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) // Prepare images // diff --git a/internal/batches/repozip/fetcher_test.go b/internal/batches/repozip/fetcher_test.go index f871237e45..56d03d85a6 100644 --- a/internal/batches/repozip/fetcher_test.go +++ b/internal/batches/repozip/fetcher_test.go @@ -5,6 +5,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "os" "path" "path/filepath" @@ -44,7 +45,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -89,7 +91,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -153,7 +156,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -193,7 +197,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -262,7 +267,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index f34d9ae656..ec14a22826 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -72,6 +72,10 @@ type Token struct { ExpiresAt time.Time `json:"expires_at"` } +func (t *Token) EndpointURL() (*url.URL, error) { + return url.ParseRequestURI(t.Endpoint) +} + type ErrorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description,omitempty"` @@ -79,9 +83,9 @@ type ErrorResponse struct { type Client interface { ClientID() string - Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) - Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) - Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) + Discover(ctx context.Context, endpointURL *url.URL) (*OIDCConfiguration, error) + Start(ctx context.Context, endpointURL *url.URL, scopes []string) (*DeviceAuthResponse, error) + Poll(ctx context.Context, endpointURL *url.URL, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) } @@ -115,14 +119,14 @@ func (c *httpClient) ClientID() string { // // Before making any requests, the configCache is checked and if there is a cache hit, the // cached config is returned. -func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) { - endpoint = strings.TrimRight(endpoint, "/") +func (c *httpClient) Discover(ctx context.Context, endpointURL *url.URL) (*OIDCConfiguration, error) { + endpoint := endpointURL.String() if config, ok := c.configCache[endpoint]; ok { return config, nil } - reqURL := endpoint + wellKnownPath + reqURL := endpointURL.JoinPath(wellKnownPath).String() req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) if err != nil { @@ -158,11 +162,9 @@ func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfig // Start starts the OAuth device flow with the given endpoint. If no scopes are given the default scopes are used. // // Default Scopes: "openid" "profile" "email" "offline_access" "user:all" -func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) { - endpoint = strings.TrimRight(endpoint, "/") - +func (c *httpClient) Start(ctx context.Context, endpointURL *url.URL, scopes []string) (*DeviceAuthResponse, error) { // Discover OIDC configuration - caches on first call - config, err := c.Discover(ctx, endpoint) + config, err := c.Discover(ctx, endpointURL) if err != nil { return nil, errors.Wrap(err, "OIDC discovery failed") } @@ -221,11 +223,9 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string // - Device is authorized, and a token is returned // - Device code has expried // - User denied authorization -func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) { - endpoint = strings.TrimRight(endpoint, "/") - +func (c *httpClient) Poll(ctx context.Context, endpointURL *url.URL, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) { // Discover OIDC configuration - caches on first call - config, err := c.Discover(ctx, endpoint) + config, err := c.Discover(ctx, endpointURL) if err != nil { return nil, errors.Wrap(err, "OIDC discovery failed") } @@ -326,7 +326,12 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str // Refresh exchanges a refresh token for a new access token. func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) { - config, err := c.Discover(ctx, token.Endpoint) + endpointURL, err := token.EndpointURL() + if err != nil { + return nil, errors.Wrap(err, "invlaid token endpoint") + } + + config, err := c.Discover(ctx, endpointURL) if err != nil { return nil, errors.Wrap(err, "failed to discover OIDC configuration") } @@ -374,9 +379,9 @@ func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, return &tokenResp, err } -func (t *TokenResponse) Token(endpoint string) *Token { +func (t *TokenResponse) Token(endpointURL *url.URL) *Token { return &Token{ - Endpoint: strings.TrimRight(endpoint, "/"), + Endpoint: endpointURL.String(), RefreshToken: t.RefreshToken, AccessToken: t.AccessToken, ExpiresAt: time.Now().Add(time.Second * time.Duration(t.ExpiresIn)), @@ -397,7 +402,12 @@ func StoreToken(ctx context.Context, token *Token) error { return errors.New("token endpoint cannot be empty when storing the token") } - store, err := secrets.Open(ctx, token.Endpoint) + u, err := token.EndpointURL() + if err != nil { + return errors.Wrap(err, "invalid token endpoint") + } + + store, err := secrets.Open(ctx, u) if err != nil { return err } @@ -409,8 +419,8 @@ func StoreToken(ctx context.Context, token *Token) error { return store.Put(oauthKey, data) } -func LoadToken(ctx context.Context, endpoint string) (*Token, error) { - store, err := secrets.Open(ctx, endpoint) +func LoadToken(ctx context.Context, endpointURL *url.URL) (*Token, error) { + store, err := secrets.Open(ctx, endpointURL) if err != nil { return nil, err } diff --git a/internal/oauth/flow_test.go b/internal/oauth/flow_test.go index 0b1ad5dc93..a6312e1a6c 100644 --- a/internal/oauth/flow_test.go +++ b/internal/oauth/flow_test.go @@ -5,12 +5,22 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "sync/atomic" "testing" "time" ) +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("failed to parse URL %q: %v", raw, err) + } + return u +} + const ( testDeviceAuthPath = "/device/code" testTokenPath = "/token" @@ -51,7 +61,7 @@ func TestDiscover_Success(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - config, err := client.Discover(context.Background(), server.URL) + config, err := client.Discover(context.Background(), mustParseURL(t, server.URL)) if err != nil { t.Fatalf("Discover() error = %v", err) } @@ -81,13 +91,13 @@ func TestDiscover_Caching(t *testing.T) { client := NewClient(DefaultClientID) // Populate the cache - _, err := client.Discover(context.Background(), server.URL) + _, err := client.Discover(context.Background(), mustParseURL(t, server.URL)) if err != nil { t.Fatalf("Discover() error = %v", err) } // Second call should use cache - _, err = client.Discover(context.Background(), server.URL) + _, err = client.Discover(context.Background(), mustParseURL(t, server.URL)) if err != nil { t.Fatalf("Discover() error = %v", err) } @@ -106,7 +116,7 @@ func TestDiscover_Error(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - _, err := client.Discover(context.Background(), server.URL) + _, err := client.Discover(context.Background(), mustParseURL(t, server.URL)) if err == nil { t.Fatal("Discover() expected error, got nil") } @@ -153,7 +163,7 @@ func TestStart_Success(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - resp, err := client.Start(context.Background(), server.URL, nil) + resp, err := client.Start(context.Background(), mustParseURL(t, server.URL), nil) if err != nil { t.Fatalf("Start() error = %v", err) } @@ -205,7 +215,7 @@ func TestStart_WithScopes(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - _, err := client.Start(context.Background(), server.URL, []string{"read", "write"}) + _, err := client.Start(context.Background(), mustParseURL(t, server.URL), []string{"read", "write"}) if err != nil { t.Fatalf("Start() error = %v", err) } @@ -231,7 +241,7 @@ func TestStart_Error(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - _, err := client.Start(context.Background(), server.URL, nil) + _, err := client.Start(context.Background(), mustParseURL(t, server.URL), nil) if err == nil { t.Fatal("Start() expected error, got nil") } @@ -254,7 +264,7 @@ func TestStart_NoDeviceEndpoint(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID) - _, err := client.Start(context.Background(), server.URL, nil) + _, err := client.Start(context.Background(), mustParseURL(t, server.URL), nil) if err == nil { t.Fatal("Start() expected error, got nil") } @@ -302,7 +312,7 @@ func TestPoll_Success(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + resp, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) } @@ -345,7 +355,7 @@ func TestPoll_AuthorizationPending(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + resp, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) } @@ -387,7 +397,7 @@ func TestPoll_SlowDown(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + resp, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 60) if err != nil { t.Fatalf("Poll() error = %v", err) } @@ -417,7 +427,7 @@ func TestPoll_ExpiredToken(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + _, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") } @@ -444,7 +454,7 @@ func TestPoll_AccessDenied(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60) + _, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 60) if err == nil { t.Fatal("Poll() expected error, got nil") } @@ -470,7 +480,7 @@ func TestPoll_Timeout(t *testing.T) { defer server.Close() client := NewClient(DefaultClientID).(*httpClient) - _, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0) + _, err := client.Poll(context.Background(), mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 0) if err == nil { t.Fatal("Poll() expected error, got nil") } @@ -499,7 +509,7 @@ func TestPoll_ContextCancellation(t *testing.T) { cancel() client := NewClient(DefaultClientID).(*httpClient) - _, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600) + _, err := client.Poll(ctx, mustParseURL(t, server.URL), "test-device-code", 10*time.Millisecond, 3600) if err == nil { t.Fatal("Poll() expected error, got nil") } diff --git a/internal/oauth/http_transport.go b/internal/oauth/http_transport.go index 1664c711f0..854dd61138 100644 --- a/internal/oauth/http_transport.go +++ b/internal/oauth/http_transport.go @@ -5,6 +5,8 @@ import ( "net/http" "sync" "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) var _ http.Transport @@ -22,7 +24,7 @@ type Transport struct { } // storeRefreshedTokenFn is the function the transport should use to persist the token - mainly used during -// tests to swap out the implementation out with a mock +// tests to swap out the implementation with a mock var storeRefreshedTokenFn = StoreToken // RoundTrip implements http.RoundTripper. @@ -59,7 +61,8 @@ func (t *Transport) getToken(ctx context.Context) (Token, error) { } t.Token = token if token != prevToken { - // try to save the token if we fail let the request continue with in memory token + // Try to save the token. + // If we fail let the request continue with the in-memory token _ = storeRefreshedTokenFn(ctx, token) } @@ -81,7 +84,12 @@ func maybeRefresh(ctx context.Context, token *Token) (*Token, error) { return nil, err } - next := resp.Token(token.Endpoint) + endpointURL, err := token.EndpointURL() + if err != nil { + return nil, errors.Wrap(err, "invalid token endpoint") + } + + next := resp.Token(endpointURL) next.ClientID = token.ClientID return next, nil } diff --git a/internal/secrets/keyring.go b/internal/secrets/keyring.go index 9464b1054c..e9bd8fbeb1 100644 --- a/internal/secrets/keyring.go +++ b/internal/secrets/keyring.go @@ -2,7 +2,7 @@ package secrets import ( "context" - "strings" + "net/url" "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/zalando/go-keyring" @@ -18,8 +18,8 @@ type keyringStore struct { } // Open opens the system keyring for the Sourcegraph CLI. -func Open(ctx context.Context, endpoint string) (*keyringStore, error) { - endpoint = strings.TrimRight(strings.TrimSpace(endpoint), "/") +func Open(ctx context.Context, endpointURL *url.URL) (*keyringStore, error) { + endpoint := endpointURL.String() if endpoint == "" { return nil, errors.New("endpoint cannot be empty") } diff --git a/internal/secrets/keyring_test.go b/internal/secrets/keyring_test.go index 65d2002782..a8910318e6 100644 --- a/internal/secrets/keyring_test.go +++ b/internal/secrets/keyring_test.go @@ -2,34 +2,35 @@ package secrets import ( "context" + "net/url" "testing" ) func TestOpen(t *testing.T) { tests := []struct { name string - endpoint string + endpoint *url.URL wantServiceName string wantErr bool }{ { - name: "normalized endpoint", - endpoint: " https://sourcegraph.example.com/ ", + name: "simple endpoint", + endpoint: &url.URL{Scheme: "https", Host: "sourcegraph.example.com"}, wantServiceName: "Sourcegraph CLI ", }, { - name: "normalized endpoint with path", - endpoint: " https://sourcegraph.example.com/sourcegraph/ ", + name: "endpoint with path", + endpoint: &url.URL{Scheme: "https", Host: "sourcegraph.example.com", Path: "/sourcegraph"}, wantServiceName: "Sourcegraph CLI ", }, { - name: "normalized endpoint with nested path", - endpoint: "https://sourcegraph.example.com/custom/path///", + name: "endpoint with nested path", + endpoint: &url.URL{Scheme: "https", Host: "sourcegraph.example.com", Path: "/custom/path"}, wantServiceName: "Sourcegraph CLI ", }, { name: "empty endpoint", - endpoint: " / ", + endpoint: &url.URL{Scheme: "", Host: ""}, wantErr: true, }, } From 93dc82afd0e6ae55ff8e6572509425f0d437e2ce Mon Sep 17 00:00:00 2001 From: Peter Guy Date: Wed, 11 Mar 2026 15:48:25 -0700 Subject: [PATCH 2/2] Update internal/oauth/flow.go Co-authored-by: William Bezuidenhout --- internal/oauth/flow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index ec14a22826..5357466f57 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -328,7 +328,7 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) { endpointURL, err := token.EndpointURL() if err != nil { - return nil, errors.Wrap(err, "invlaid token endpoint") + return nil, errors.Wrap(err, "invalid token endpoint") } config, err := c.Discover(ctx, endpointURL)