From 7e6779bd406cd3829ef4121ef436a6e3150c34c7 Mon Sep 17 00:00:00 2001 From: Steve Hipwell Date: Fri, 8 May 2026 16:21:15 +0100 Subject: [PATCH] feat!: refactor client constructor to use options pattern Signed-off-by: Steve Hipwell --- README.md | 65 +- example/actionpermissions/main.go | 5 +- example/auditlogstream/main.go | 2 +- example/basicauth/main.go | 6 +- .../newreposecretwithxcrypto/main.go | 5 +- .../newusersecretwithxcrypto/main.go | 5 +- example/commitpr/main.go | 6 +- example/contents/main.go | 6 +- example/listenvironments/main.go | 5 +- example/migrations/main.go | 5 +- example/newfilewithappauth/main.go | 14 +- example/newrepo/main.go | 5 +- example/newreposecretwithxcrypto/main.go | 5 +- example/otel/main.go | 17 +- example/ratelimit/main.go | 6 +- example/simple/main.go | 5 +- example/tokenauth/main.go | 5 +- example/topics/main.go | 6 +- example/uploadreleaseassetfromrelease/main.go | 5 +- example/verifyartifact/main.go | 5 +- github/actions_artifacts.go | 2 +- github/actions_artifacts_test.go | 25 +- github/actions_workflow_jobs.go | 2 +- github/actions_workflow_jobs_test.go | 21 +- github/actions_workflow_runs.go | 4 +- github/actions_workflow_runs_test.go | 16 +- github/actions_workflows_test.go | 12 +- github/copilot_test.go | 34 +- github/doc.go | 36 +- github/example_iterators_test.go | 7 +- github/examples_test.go | 67 +- github/github.go | 518 ++++++---- github/github_test.go | 924 +++++++++++++++--- github/rate_limit.go | 2 +- github/rate_limit_test.go | 2 +- github/repos_contents.go | 2 +- github/repos_contents_test.go | 6 +- github/repos_releases_test.go | 2 +- test/fields/fields.go | 14 +- test/integration/authorizations_test.go | 8 +- test/integration/github_test.go | 15 +- tools/metadata/main.go | 7 +- 42 files changed, 1436 insertions(+), 473 deletions(-) diff --git a/README.md b/README.md index 8bec62bb186..2770cd86c7e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ go-github will require the N-1 major release of Go by default. [support-policy]: https://golang.org/doc/devel/release.html#policy -## Development +## Development ## If you're interested in using the [GraphQL API v4][], the recommended library is [shurcooL/githubv4][]. @@ -66,7 +66,10 @@ Construct a new GitHub client, then use the various services on the client to access different parts of the GitHub API. For example: ```go -client := github.NewClient(nil) +client, err := github.NewClient() +if err != nil { + // Handle error. +} // list all organizations for user "willnorris" orgs, _, err := client.Organizations.List(context.Background(), "willnorris", nil) @@ -75,7 +78,10 @@ orgs, _, err := client.Organizations.List(context.Background(), "willnorris", ni Some API methods have optional parameters that can be passed. For example: ```go -client := github.NewClient(nil) +client, err := github.NewClient() +if err != nil { + // Handle error. +} // list public repositories for org "github" opt := &github.RepositoryListByOrgOptions{Type: "public"} @@ -95,12 +101,15 @@ For more sample code snippets, head over to the ### Authentication ### -Use the `WithAuthToken` method to configure your client to authenticate using an +Use the `github.WithAuthToken` options method to configure your client to authenticate using an OAuth token (for example, a [personal access token][]). This is what is needed for a majority of use cases aside from GitHub Apps. ```go -client := github.NewClient(nil).WithAuthToken("... your access token ...") +client, err := github.NewClient(github.WithAuthToken("... your access token ...")) +if err != nil { + // Handle error. +} ``` Note that when using an authenticated Client, all calls made by the client will @@ -146,7 +155,10 @@ func main() { } // Use installation transport with client. - client := github.NewClient(&http.Client{Transport: itr}) + client, err := github.NewClient(github.WithTransport(itr)) + if err != nil { + // Handle error. + } // Use client... } @@ -186,11 +198,14 @@ func main() { // InstallationTokenSource has the mechanism to refresh the token when it expires. httpClient := oauth2.NewClient(context.Background(), installationTokenSource) - client := github.NewClient(httpClient) + client, err := github.NewClient(github.WithHTTPClient(httpClient)) + if err != nil { + // Handle error. + } } ``` -*Note*: In order to interact with certain APIs, for example writing a file to a repo, one must generate an installation token +_Note_: In order to interact with certain APIs, for example writing a file to a repo, one must generate an installation token using the installation ID of the GitHub app and authenticate with the OAuth method mentioned above. See the examples. ### Rate Limiting ### @@ -296,9 +311,10 @@ import ( _ "github.com/bartventer/httpcache/store/memcache" // Register the in-memory backend ) -client := github.NewClient( - httpcache.NewClient("memcache://"), -).WithAuthToken(os.Getenv("GITHUB_TOKEN")) +client, err := github.NewClient(github.WithHTTPClient(httpcache.NewClient("memcache://")), github.WithAuthToken(os.Getenv("GITHUB_TOKEN"))) +if err != nil { + // Handle error. +} ``` Alternatively, the [bored-engineer/github-conditional-http-transport](https://github.com/bored-engineer/github-conditional-http-transport) @@ -334,7 +350,10 @@ embedded type of a more specific list options struct (for example `github.Response` struct. ```go -client := github.NewClient(nil) +client, err := github.NewClient() +if err != nil { + // Handle error. +} opt := &github.RepositoryListByOrgOptions{ ListOptions: github.ListOptions{PerPage: 10}, @@ -372,7 +391,10 @@ To handle rate limiting issues, make sure to use a rate-limiting transport. To use these methods, simply create an iterator and then range over it, for example: ```go -client := github.NewClient(nil) +client, err := github.NewClient() +if err != nil { + // Handle error. +} var allRepos []*github.Repository // create an iterator and start looping through all the results @@ -389,7 +411,10 @@ Alternatively, if you wish to use an external package, there is `enrichman/gh-it Its iterator will handle pagination for you, looping through all the available results. ```go -client := github.NewClient(nil) +client, err := github.NewClient() +if err != nil { + // Handle error. +} var allRepos []*github.Repository // create an iterator and start looping through all the results @@ -465,12 +490,14 @@ implementing preview features of the GitHub API, we've adopted the following versioning policy: * We increment the **major version** with any incompatible change to - non-preview functionality, including changes to the exported Go API surface - or behavior of the API. + non-preview functionality, including changes to the exported Go API surface + or behavior of the API. + * We increment the **minor version** with any backwards-compatible changes to - functionality, as well as any changes to preview functionality in the GitHub - API. GitHub makes no guarantee about the stability of preview functionality, - so neither do we consider it a stable part of the go-github API. + functionality, as well as any changes to preview functionality in the GitHub + API. GitHub makes no guarantee about the stability of preview functionality, + so neither do we consider it a stable part of the go-github API. + * We increment the **patch version** with any backwards-compatible bug fixes. Preview functionality may take the form of entire methods or simply additional diff --git a/example/actionpermissions/main.go b/example/actionpermissions/main.go index 91939f9236c..fb76272bbea 100644 --- a/example/actionpermissions/main.go +++ b/example/actionpermissions/main.go @@ -35,7 +35,10 @@ func main() { log.Fatal("No owner: owner of repo must be given") } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } actionsPermissionsRepository, _, err := client.Repositories.GetActionsPermissions(ctx, *owner, *name) if err != nil { diff --git a/example/auditlogstream/main.go b/example/auditlogstream/main.go index 773e22b6dfc..8be23d8f280 100644 --- a/example/auditlogstream/main.go +++ b/example/auditlogstream/main.go @@ -157,7 +157,7 @@ func runDelete(args []string) { } func newClient(token, apiURL string) *github.Client { - client, err := github.NewClient(nil).WithAuthToken(token).WithEnterpriseURLs(apiURL, apiURL) + client, err := github.NewClient(github.WithAuthToken(token), github.WithEnterpriseURLs(apiURL, apiURL)) if err != nil { log.Fatalf("Error creating GitHub client: %v", err) } diff --git a/example/basicauth/main.go b/example/basicauth/main.go index f4cffc4b887..9ff69a3915e 100644 --- a/example/basicauth/main.go +++ b/example/basicauth/main.go @@ -39,7 +39,11 @@ func main() { Password: strings.TrimSpace(string(password)), } - client := github.NewClient(tp.Client()) + client, err := github.NewClient(github.WithHTTPClient(tp.Client())) + if err != nil { + fmt.Printf("\nerror: %v\n", err) + return + } ctx := context.Background() user, _, err := client.Users.Get(ctx, "") diff --git a/example/codespaces/newreposecretwithxcrypto/main.go b/example/codespaces/newreposecretwithxcrypto/main.go index 25fbdc85ec1..cc2fbd4b9c8 100644 --- a/example/codespaces/newreposecretwithxcrypto/main.go +++ b/example/codespaces/newreposecretwithxcrypto/main.go @@ -73,7 +73,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } if err := addRepoSecret(ctx, client, *owner, *repo, secretName, secretValue); err != nil { log.Fatal(err) diff --git a/example/codespaces/newusersecretwithxcrypto/main.go b/example/codespaces/newusersecretwithxcrypto/main.go index 2b336f7082f..fc3537bc09d 100644 --- a/example/codespaces/newusersecretwithxcrypto/main.go +++ b/example/codespaces/newusersecretwithxcrypto/main.go @@ -66,7 +66,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } if err := addUserSecret(ctx, client, secretName, secretValue, *owner, *repo); err != nil { log.Fatal(err) diff --git a/example/commitpr/main.go b/example/commitpr/main.go index 3d7d0073fc8..d9a7f364323 100644 --- a/example/commitpr/main.go +++ b/example/commitpr/main.go @@ -221,7 +221,11 @@ func main() { if *sourceOwner == "" || *sourceRepo == "" || *commitBranch == "" || *sourceFiles == "" || *authorName == "" || *authorEmail == "" { log.Fatal("You need to specify a non-empty value for the flags `-source-owner`, `-source-repo`, `-commit-branch`, `-files`, `-author-name` and `-author-email`") } - client = github.NewClient(nil).WithAuthToken(token) + c, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } + client = c ref, err := getRef() if err != nil { diff --git a/example/contents/main.go b/example/contents/main.go index da99099a18b..aca9ea3224d 100644 --- a/example/contents/main.go +++ b/example/contents/main.go @@ -50,7 +50,11 @@ func main() { fmt.Printf("\nDownloading %v/%v/%v at ref %v to %v...\n", owner, repo, repoPath, ref, outputPath) - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + fmt.Printf("Error creating GitHub client: %v\n", err) + os.Exit(1) + } rc, _, err := client.Repositories.DownloadContents(context.Background(), owner, repo, repoPath, &github.RepositoryContentGetOptions{Ref: ref}) if err != nil { diff --git a/example/listenvironments/main.go b/example/listenvironments/main.go index 7f0464f60b4..1718d157dbe 100644 --- a/example/listenvironments/main.go +++ b/example/listenvironments/main.go @@ -31,7 +31,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } expectedPageSize := 2 diff --git a/example/migrations/main.go b/example/migrations/main.go index cfd21923cc0..3d9f88c3e5a 100644 --- a/example/migrations/main.go +++ b/example/migrations/main.go @@ -17,7 +17,10 @@ import ( func fetchAllUserMigrations() ([]*github.UserMigration, error) { ctx := context.Background() - client := github.NewClient(nil).WithAuthToken("") + client, err := github.NewClient(github.WithAuthToken("")) + if err != nil { + return nil, err + } migrations, _, err := client.Migrations.ListUserMigrations(ctx, &github.ListOptions{Page: 1}) return migrations, err diff --git a/example/newfilewithappauth/main.go b/example/newfilewithappauth/main.go index ec14a2d193a..fcc63a0439f 100644 --- a/example/newfilewithappauth/main.go +++ b/example/newfilewithappauth/main.go @@ -34,12 +34,10 @@ func main() { itr.BaseURL = gitHost // create git client with app transport - client, err := github.NewClient( - &http.Client{ - Transport: itr, - Timeout: time.Second * 30, - }, - ).WithEnterpriseURLs(gitHost, gitHost) + client, err := github.NewClient(github.WithHTTPClient(&http.Client{ + Transport: itr, + Timeout: time.Second * 30, + }), github.WithEnterpriseURLs(gitHost, gitHost)) if err != nil { log.Fatalf("failed to create git client for app: %v\n", err) } @@ -64,9 +62,7 @@ func main() { log.Fatalf("failed to create installation token: %v\n", err) } - apiClient, err := github.NewClient(nil).WithAuthToken( - token.GetToken(), - ).WithEnterpriseURLs(gitHost, gitHost) + apiClient, err := github.NewClient(github.WithAuthToken(token.GetToken()), github.WithEnterpriseURLs(gitHost, gitHost)) if err != nil { log.Fatalf("failed to create new git client with token: %v\n", err) } diff --git a/example/newrepo/main.go b/example/newrepo/main.go index a3b78119cde..93bd4f61f43 100644 --- a/example/newrepo/main.go +++ b/example/newrepo/main.go @@ -36,7 +36,10 @@ func main() { log.Fatal("No name: New repos must be given a name") } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } r := &github.Repository{Name: name, Private: private, Description: description, AutoInit: autoInit} repo, _, err := client.Repositories.Create(ctx, "", r) diff --git a/example/newreposecretwithxcrypto/main.go b/example/newreposecretwithxcrypto/main.go index ceb440535c5..b95edbafb60 100644 --- a/example/newreposecretwithxcrypto/main.go +++ b/example/newreposecretwithxcrypto/main.go @@ -73,7 +73,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatal(err) + } if err := addRepoSecret(ctx, client, *owner, *repo, secretName, secretValue); err != nil { log.Fatal(err) diff --git a/example/otel/main.go b/example/otel/main.go index c726d5dc7bb..638e613bcd3 100644 --- a/example/otel/main.go +++ b/example/otel/main.go @@ -35,15 +35,16 @@ func main() { } }() - // Configure HTTP client with OTel transport - httpClient := &http.Client{ - Transport: otel.NewTransport( - http.DefaultTransport, - otel.WithTracerProvider(tp), - ), - } + // Configure OTel transport + t := otel.NewTransport( + http.DefaultTransport, + otel.WithTracerProvider(tp), + ) - client := github.NewClient(httpClient) + client, err := github.NewClient(github.WithTransport(t)) + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } // Make a request (Get Rate Limits is public and cheap) limits, resp, err := client.RateLimit.Get(context.Background()) diff --git a/example/ratelimit/main.go b/example/ratelimit/main.go index 014b7c90f63..585377987e1 100644 --- a/example/ratelimit/main.go +++ b/example/ratelimit/main.go @@ -37,7 +37,11 @@ func main() { paginator := githubpagination.NewClient(rateLimiter, githubpagination.WithPerPage(100), // default to 100 results per page ) - client := github.NewClient(paginator) + client, err := github.NewClient(github.WithHTTPClient(paginator)) + if err != nil { + fmt.Printf("Error creating GitHub client: %v\n", err) + return + } // Example usage of the client repos, _, err := client.Repositories.ListByUser(context.Background(), username, nil) diff --git a/example/simple/main.go b/example/simple/main.go index 2b5849e3a7f..3ac8dd7400d 100644 --- a/example/simple/main.go +++ b/example/simple/main.go @@ -17,7 +17,10 @@ import ( // Fetch all the public organizations' membership of a user. func fetchOrganizations(username string) ([]*github.Organization, error) { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + return nil, err + } orgs, _, err := client.Organizations.List(context.Background(), username, nil) return orgs, err } diff --git a/example/tokenauth/main.go b/example/tokenauth/main.go index c91592dbe94..9b86034749e 100644 --- a/example/tokenauth/main.go +++ b/example/tokenauth/main.go @@ -25,7 +25,10 @@ func main() { fmt.Println() ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(string(token)) + client, err := github.NewClient(github.WithAuthToken(string(token))) + if err != nil { + log.Fatalf("Error creating GitHub client: %v\n", err) + } user, resp, err := client.Users.Get(ctx, "") if err != nil { diff --git a/example/topics/main.go b/example/topics/main.go index df6f481a62a..9a3e8393daf 100644 --- a/example/topics/main.go +++ b/example/topics/main.go @@ -17,7 +17,11 @@ import ( // Fetch and lists all the public topics associated with the specified GitHub topic. func fetchTopics(topic string) (*github.TopicsSearchResult, error) { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + return nil, err + } + topics, _, err := client.Search.Topics(context.Background(), topic, nil) return topics, err } diff --git a/example/uploadreleaseassetfromrelease/main.go b/example/uploadreleaseassetfromrelease/main.go index 9e9378d23a2..532c351d496 100644 --- a/example/uploadreleaseassetfromrelease/main.go +++ b/example/uploadreleaseassetfromrelease/main.go @@ -24,7 +24,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatalf("Error creating GitHub client: %v\n", err) + } owner := "OWNER" repo := "REPO" diff --git a/example/verifyartifact/main.go b/example/verifyartifact/main.go index 2faa1266c2b..f4d271c6a28 100644 --- a/example/verifyartifact/main.go +++ b/example/verifyartifact/main.go @@ -78,7 +78,10 @@ func main() { } ctx := context.Background() - client := github.NewClient(nil).WithAuthToken(token) + client, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + log.Fatalf("Error creating GitHub client: %v\n", err) + } // Fetch attestations from the GitHub API. // The attestations API doesn't differentiate between users and orgs, diff --git a/github/actions_artifacts.go b/github/actions_artifacts.go index 022be81b898..2a345edbc9c 100644 --- a/github/actions_artifacts.go +++ b/github/actions_artifacts.go @@ -161,7 +161,7 @@ func (s *ActionsService) GetArtifact(ctx context.Context, owner, repo string, ar func (s *ActionsService) DownloadArtifact(ctx context.Context, owner, repo string, artifactID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/artifacts/%v/zip", owner, repo, artifactID) - if s.client.RateLimitRedirectionalEndpoints { + if s.client.rateLimitRedirectionalEndpoints { return s.downloadArtifactWithRateLimit(ctx, u, maxRedirects) } diff --git a/github/actions_artifacts_test.go b/github/actions_artifacts_test.go index 694a5e4e757..d3256f49ad5 100644 --- a/github/actions_artifacts_test.go +++ b/github/actions_artifacts_test.go @@ -291,7 +291,7 @@ func TestActionsService_DownloadArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -318,14 +318,15 @@ func TestActionsService_DownloadArtifact(t *testing.T) { return err }) - // Add custom round tripper - client.client.Transport = roundTripperFunc(func(*http.Request) (*http.Response, error) { + // Create "bad" client with custom round tripper + badClient, err := client.Clone(WithTransport(roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("failed to download artifact") - }) - // propagate custom round tripper to client without CheckRedirect - client.initialize() + }))) + if err != nil { + t.Fatalf("failed to clone client: %v", err) + } testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) + _, _, err = badClient.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) return err }) }) @@ -352,7 +353,7 @@ func TestActionsService_DownloadArtifact_invalidOwner(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, _, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits ctx := t.Context() _, _, err := client.Actions.DownloadArtifact(ctx, "%", "r", 1, 1) @@ -381,7 +382,7 @@ func TestActionsService_DownloadArtifact_invalidRepo(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, _, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits ctx := t.Context() _, _, err := client.Actions.DownloadArtifact(ctx, "o", "%", 1, 1) @@ -410,7 +411,7 @@ func TestActionsService_DownloadArtifact_StatusMovedPermanently_dontFollowRedire t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -446,7 +447,7 @@ func TestActionsService_DownloadArtifact_StatusMovedPermanently_followRedirects( t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -494,7 +495,7 @@ func TestActionsService_DownloadArtifact_unexpectedCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") diff --git a/github/actions_workflow_jobs.go b/github/actions_workflow_jobs.go index d18fe5f6b77..9419cf89711 100644 --- a/github/actions_workflow_jobs.go +++ b/github/actions_workflow_jobs.go @@ -150,7 +150,7 @@ func (s *ActionsService) GetWorkflowJobByID(ctx context.Context, owner, repo str func (s *ActionsService) GetWorkflowJobLogs(ctx context.Context, owner, repo string, jobID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/jobs/%v/logs", owner, repo, jobID) - if s.client.RateLimitRedirectionalEndpoints { + if s.client.rateLimitRedirectionalEndpoints { return s.getWorkflowJobLogsWithRateLimit(ctx, u, maxRedirects) } diff --git a/github/actions_workflow_jobs_test.go b/github/actions_workflow_jobs_test.go index c20fec965f8..2593eba1c7e 100644 --- a/github/actions_workflow_jobs_test.go +++ b/github/actions_workflow_jobs_test.go @@ -202,7 +202,7 @@ func TestActionsService_GetWorkflowJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -228,14 +228,15 @@ func TestActionsService_GetWorkflowJobLogs(t *testing.T) { return err }) - // Add custom round tripper - client.client.Transport = roundTripperFunc(func(*http.Request) (*http.Response, error) { + // Create "bad" client with custom round tripper + badClient, err := client.Clone(WithTransport(roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("failed to get workflow logs") - }) - // propagate custom round tripper to client without CheckRedirect - client.initialize() + }))) + if err != nil { + t.Fatalf("failed to clone client: %v", err) + } testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) + _, _, err = badClient.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) return err }) }) @@ -262,7 +263,7 @@ func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_dontFollowRedi t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -298,7 +299,7 @@ func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_followRedirect t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { @@ -350,7 +351,7 @@ func TestActionsService_GetWorkflowJobLogs_unexpectedCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { diff --git a/github/actions_workflow_runs.go b/github/actions_workflow_runs.go index 30fc954ef40..6e46328dfbd 100644 --- a/github/actions_workflow_runs.go +++ b/github/actions_workflow_runs.go @@ -263,7 +263,7 @@ func (s *ActionsService) GetWorkflowRunAttempt(ctx context.Context, owner, repo func (s *ActionsService) GetWorkflowRunAttemptLogs(ctx context.Context, owner, repo string, runID int64, attemptNumber, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/runs/%v/attempts/%v/logs", owner, repo, runID, attemptNumber) - if s.client.RateLimitRedirectionalEndpoints { + if s.client.rateLimitRedirectionalEndpoints { return s.getWorkflowRunAttemptLogsWithRateLimit(ctx, u, maxRedirects) } @@ -383,7 +383,7 @@ func (s *ActionsService) CancelWorkflowRunByID(ctx context.Context, owner, repo func (s *ActionsService) GetWorkflowRunLogs(ctx context.Context, owner, repo string, runID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/runs/%v/logs", owner, repo, runID) - if s.client.RateLimitRedirectionalEndpoints { + if s.client.rateLimitRedirectionalEndpoints { return s.getWorkflowRunLogsWithRateLimit(ctx, u, maxRedirects) } diff --git a/github/actions_workflow_runs_test.go b/github/actions_workflow_runs_test.go index 9730a78944d..42a7a1ff33e 100644 --- a/github/actions_workflow_runs_test.go +++ b/github/actions_workflow_runs_test.go @@ -208,7 +208,7 @@ func TestActionsService_GetWorkflowRunAttemptLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -257,7 +257,7 @@ func TestActionsService_GetWorkflowRunAttemptLogs_StatusMovedPermanently_dontFol t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -293,7 +293,7 @@ func TestActionsService_GetWorkflowRunAttemptLogs_StatusMovedPermanently_followR t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { @@ -351,7 +351,7 @@ func TestActionsService_GetWorkflowRunAttemptLogs_unexpectedCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { @@ -519,7 +519,7 @@ func TestActionsService_GetWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -568,7 +568,7 @@ func TestActionsService_GetWorkflowRunLogs_StatusMovedPermanently_dontFollowRedi t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -604,7 +604,7 @@ func TestActionsService_GetWorkflowRunLogs_StatusMovedPermanently_followRedirect t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { @@ -662,7 +662,7 @@ func TestActionsService_GetWorkflowRunLogs_unexpectedCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { diff --git a/github/actions_workflows_test.go b/github/actions_workflows_test.go index eb41e338fb1..fefe382626b 100644 --- a/github/actions_workflows_test.go +++ b/github/actions_workflows_test.go @@ -263,7 +263,7 @@ func TestActionsService_CreateWorkflowDispatchEventByID(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, _, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, "o", "r", 72844, event) if err == nil { t.Error("client.BaseURL.Path='' CreateWorkflowDispatchEventByID err = nil, want error") @@ -318,7 +318,7 @@ func TestActionsService_CreateWorkflowDispatchEventByFileName(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, _, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, "o", "r", "main.yml", event) if err == nil { t.Error("client.BaseURL.Path='' CreateWorkflowDispatchEventByFileName err = nil, want error") @@ -411,7 +411,7 @@ func TestActionsService_EnableWorkflowByID(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, err = client.Actions.EnableWorkflowByID(ctx, "o", "r", 72844) if err == nil { t.Error("client.BaseURL.Path='' EnableWorkflowByID err = nil, want error") @@ -446,7 +446,7 @@ func TestActionsService_EnableWorkflowByFilename(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, err = client.Actions.EnableWorkflowByFileName(ctx, "o", "r", "main.yml") if err == nil { t.Error("client.BaseURL.Path='' EnableWorkflowByFilename err = nil, want error") @@ -481,7 +481,7 @@ func TestActionsService_DisableWorkflowByID(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, err = client.Actions.DisableWorkflowByID(ctx, "o", "r", 72844) if err == nil { t.Error("client.BaseURL.Path='' DisableWorkflowByID err = nil, want error") @@ -516,7 +516,7 @@ func TestActionsService_DisableWorkflowByFileName(t *testing.T) { } // Test s.client.NewRequest failure - client.BaseURL.Path = "" + client.baseURL.Path = "" _, err = client.Actions.DisableWorkflowByFileName(ctx, "o", "r", "main.yml") if err == nil { t.Error("client.BaseURL.Path='' DisableWorkflowByFileName err = nil, want error") diff --git a/github/copilot_test.go b/github/copilot_test.go index 25a645f0456..468eb2f50d7 100644 --- a/github/copilot_test.go +++ b/github/copilot_test.go @@ -2974,7 +2974,7 @@ func TestCopilotService_DownloadCopilotMetrics(t *testing.T) { }) ctx := t.Context() - url := client.BaseURL.String() + "path/to/download" + url := client.baseURL.String() + "path/to/download" got, resp, err := client.Copilot.DownloadCopilotMetrics(ctx, url) if err != nil { t.Errorf("Copilot.DownloadCopilotMetrics returned error: %v", err) @@ -3033,7 +3033,7 @@ func TestCopilotService_DownloadCopilotMetrics(t *testing.T) { w.WriteHeader(http.StatusNotFound) }) - urlErr := client.BaseURL.String() + "path/to/download/error" + urlErr := client.baseURL.String() + "path/to/download/error" _, _, err = client.Copilot.DownloadCopilotMetrics(ctx, urlErr) if err == nil { t.Error("Copilot.DownloadCopilotMetrics expected error but got none") @@ -3056,7 +3056,7 @@ func TestCopilotService_DownloadCopilotMetrics(t *testing.T) { testMethod(t, r, "GET") fmt.Fprint(w, `[{invalid JSON`) }) - urlBadJSON := client.BaseURL.String() + "path/to/download/badjson" + urlBadJSON := client.baseURL.String() + "path/to/download/badjson" _, _, err = client.Copilot.DownloadCopilotMetrics(ctx, urlBadJSON) if err == nil { t.Error("Copilot.DownloadCopilotMetrics expected error for bad JSON, got none") @@ -3115,7 +3115,7 @@ func TestCopilotService_DownloadDailyMetrics(t *testing.T) { }) ctx := t.Context() - url := client.BaseURL.String() + "path/to/daily" + url := client.baseURL.String() + "path/to/daily" got, resp, err := client.Copilot.DownloadDailyMetrics(ctx, url) if err != nil { t.Errorf("Copilot.DownloadDailyMetrics returned error: %v", err) @@ -3178,7 +3178,7 @@ func TestCopilotService_DownloadDailyMetrics(t *testing.T) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) }) - if _, _, err := client.Copilot.DownloadDailyMetrics(ctx, client.BaseURL.String()+"path/to/daily/error"); err == nil { + if _, _, err := client.Copilot.DownloadDailyMetrics(ctx, client.baseURL.String()+"path/to/daily/error"); err == nil { t.Error("Copilot.DownloadDailyMetrics expected error but got none") } if _, _, err := client.Copilot.DownloadDailyMetrics(ctx, "\n"); err == nil { @@ -3192,7 +3192,7 @@ func TestCopilotService_DownloadDailyMetrics(t *testing.T) { testMethod(t, r, "GET") fmt.Fprint(w, `{invalid`) }) - if _, _, err := client.Copilot.DownloadDailyMetrics(ctx, client.BaseURL.String()+"path/to/daily/badjson"); err == nil { + if _, _, err := client.Copilot.DownloadDailyMetrics(ctx, client.baseURL.String()+"path/to/daily/badjson"); err == nil { t.Error("Copilot.DownloadDailyMetrics expected error for bad JSON, got none") } } @@ -3235,7 +3235,7 @@ func TestCopilotService_DownloadPeriodicMetrics(t *testing.T) { }) ctx := t.Context() - url := client.BaseURL.String() + "path/to/periodic" + url := client.baseURL.String() + "path/to/periodic" got, resp, err := client.Copilot.DownloadPeriodicMetrics(ctx, url) if err != nil { t.Errorf("Copilot.DownloadPeriodicMetrics returned error: %v", err) @@ -3282,7 +3282,7 @@ func TestCopilotService_DownloadPeriodicMetrics(t *testing.T) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) }) - if _, _, err := client.Copilot.DownloadPeriodicMetrics(ctx, client.BaseURL.String()+"path/to/periodic/error"); err == nil { + if _, _, err := client.Copilot.DownloadPeriodicMetrics(ctx, client.baseURL.String()+"path/to/periodic/error"); err == nil { t.Error("Copilot.DownloadPeriodicMetrics expected error but got none") } if _, _, err := client.Copilot.DownloadPeriodicMetrics(ctx, "\n"); err == nil { @@ -3296,7 +3296,7 @@ func TestCopilotService_DownloadPeriodicMetrics(t *testing.T) { testMethod(t, r, "GET") fmt.Fprint(w, `{invalid`) }) - if _, _, err := client.Copilot.DownloadPeriodicMetrics(ctx, client.BaseURL.String()+"path/to/periodic/badjson"); err == nil { + if _, _, err := client.Copilot.DownloadPeriodicMetrics(ctx, client.baseURL.String()+"path/to/periodic/badjson"); err == nil { t.Error("Copilot.DownloadPeriodicMetrics expected error for bad JSON, got none") } } @@ -3313,7 +3313,7 @@ func TestCopilotService_DownloadUserDailyMetrics(t *testing.T) { }) ctx := t.Context() - url := client.BaseURL.String() + "path/to/users-daily" + url := client.baseURL.String() + "path/to/users-daily" got, resp, err := client.Copilot.DownloadUserDailyMetrics(ctx, url) if err != nil { t.Errorf("Copilot.DownloadUserDailyMetrics returned error: %v", err) @@ -3376,7 +3376,7 @@ func TestCopilotService_DownloadUserDailyMetrics(t *testing.T) { mux.HandleFunc("/path/to/users-daily/empty", func(_ http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") }) - gotEmpty, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.BaseURL.String()+"path/to/users-daily/empty") + gotEmpty, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.baseURL.String()+"path/to/users-daily/empty") if err != nil { t.Errorf("Copilot.DownloadUserDailyMetrics empty body returned error: %v", err) } @@ -3388,7 +3388,7 @@ func TestCopilotService_DownloadUserDailyMetrics(t *testing.T) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) }) - if _, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.BaseURL.String()+"path/to/users-daily/error"); err == nil { + if _, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.baseURL.String()+"path/to/users-daily/error"); err == nil { t.Error("Copilot.DownloadUserDailyMetrics expected error but got none") } if _, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, "\n"); err == nil { @@ -3403,7 +3403,7 @@ func TestCopilotService_DownloadUserDailyMetrics(t *testing.T) { testMethod(t, r, "GET") fmt.Fprint(w, "{\"user_id\":1,\"day\":\"2026-04-01\"}\n{bad\n") }) - if _, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.BaseURL.String()+"path/to/users-daily/badjson"); err == nil { + if _, _, err := client.Copilot.DownloadUserDailyMetrics(ctx, client.baseURL.String()+"path/to/users-daily/badjson"); err == nil { t.Error("Copilot.DownloadUserDailyMetrics expected error for bad JSON, got none") } } @@ -3420,7 +3420,7 @@ func TestCopilotService_DownloadUserPeriodicMetrics(t *testing.T) { }) ctx := t.Context() - url := client.BaseURL.String() + "path/to/users-periodic" + url := client.baseURL.String() + "path/to/users-periodic" got, resp, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, url) if err != nil { t.Errorf("Copilot.DownloadUserPeriodicMetrics returned error: %v", err) @@ -3469,7 +3469,7 @@ func TestCopilotService_DownloadUserPeriodicMetrics(t *testing.T) { mux.HandleFunc("/path/to/users-periodic/empty", func(_ http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") }) - gotEmpty, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.BaseURL.String()+"path/to/users-periodic/empty") + gotEmpty, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.baseURL.String()+"path/to/users-periodic/empty") if err != nil { t.Errorf("Copilot.DownloadUserPeriodicMetrics empty body returned error: %v", err) } @@ -3481,7 +3481,7 @@ func TestCopilotService_DownloadUserPeriodicMetrics(t *testing.T) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) }) - if _, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.BaseURL.String()+"path/to/users-periodic/error"); err == nil { + if _, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.baseURL.String()+"path/to/users-periodic/error"); err == nil { t.Error("Copilot.DownloadUserPeriodicMetrics expected error but got none") } if _, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, "\n"); err == nil { @@ -3495,7 +3495,7 @@ func TestCopilotService_DownloadUserPeriodicMetrics(t *testing.T) { testMethod(t, r, "GET") fmt.Fprint(w, "{not json\n") }) - if _, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.BaseURL.String()+"path/to/users-periodic/badjson"); err == nil { + if _, _, err := client.Copilot.DownloadUserPeriodicMetrics(ctx, client.baseURL.String()+"path/to/users-periodic/badjson"); err == nil { t.Error("Copilot.DownloadUserPeriodicMetrics expected error for bad JSON, got none") } } diff --git a/github/doc.go b/github/doc.go index 08784214ab6..df5d1354a31 100644 --- a/github/doc.go +++ b/github/doc.go @@ -10,17 +10,23 @@ Usage: import "github.com/google/go-github/v86/github" -Construct a new GitHub client, then use the various services on the client to +Construct a new GitHub client using [NewClient], then use the various services on the client to access different parts of the GitHub API. For example: - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + // Handle error. + } // list all organizations for user "willnorris" orgs, _, err := client.Organizations.List(ctx, "willnorris", nil) Some API methods have optional parameters that can be passed. For example: - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + // Handle error. + } // list public repositories for org "github" opt := &github.RepositoryListByOrgOptions{Type: "public"} @@ -39,11 +45,14 @@ For more sample code snippets, head over to the https://github.com/google/go-git # Authentication -Use [Client.WithAuthToken] to configure your client to authenticate using an OAuth token +Use [WithAuthToken] to configure your client to authenticate using an OAuth token (for example, a personal access token). This is what is needed for a majority of use cases aside from GitHub Apps. - client := github.NewClient(nil).WithAuthToken("... your access token ...") + client, err := github.NewClient(github.WithAuthToken("... your access token ...")) + if err != nil { + // Handle error. + } Note that when using an authenticated [Client], all calls made by the client will include the specified OAuth token. Therefore, authenticated clients should @@ -55,7 +64,7 @@ For API methods that require HTTP Basic Authentication, use the GitHub Apps authentication can be provided by the https://github.com/bradleyfalzon/ghinstallation package. It supports both authentication as an installation, using an installation access token, -and as an app, using a JWT. +and as an app, using a JWT. Use the [WithTransport] option to configure your client to use the appropriate transport. To authenticate as an installation: @@ -69,7 +78,10 @@ To authenticate as an installation: } // Use installation transport with client - client := github.NewClient(&http.Client{Transport: itr}) + client, err := github.NewClient(github.WithTransport(itr)) + if err != nil { + // Handle error. + } // Use client... } @@ -86,7 +98,10 @@ To authenticate as an app, using a JWT: } // Use app transport with client - client := github.NewClient(&http.Client{Transport: atr}) + client, err := github.NewClient(github.WithTransport(atr)) + if err != nil { + // Handle error. + } // Use client... } @@ -177,7 +192,10 @@ embedded type of a more specific list options struct (for example [PullRequestListOptions]). Pages information is available via the [Response] struct. - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + // Handle error. + } opt := &github.RepositoryListByOrgOptions{ ListOptions: github.ListOptions{PerPage: 10}, diff --git a/github/example_iterators_test.go b/github/example_iterators_test.go index 36d1afa8dcf..cf8a87e8049 100644 --- a/github/example_iterators_test.go +++ b/github/example_iterators_test.go @@ -14,7 +14,10 @@ import ( ) func ExampleRepositoriesService_ListByUserIter() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } ctx := context.Background() // List all repositories for a user using the iterator. @@ -23,7 +26,7 @@ func ExampleRepositoriesService_ListByUserIter() { opts := &github.RepositoryListByUserOptions{Type: "public"} for repo, err := range client.Repositories.ListByUserIter(ctx, "octocat", opts) { if err != nil { - log.Fatalf("Error listing repos: %v", err) + log.Fatalf("Error listing repositories by user: %v", err) } fmt.Println(repo.GetName()) } diff --git a/github/examples_test.go b/github/examples_test.go index a95a12dfb4d..eb77cf7942d 100644 --- a/github/examples_test.go +++ b/github/examples_test.go @@ -16,7 +16,10 @@ import ( ) func ExampleMarkdownService_Render() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } input := "# heading #\n\nLink to issue #1" opt := &github.MarkdownOptions{Mode: "gfm", Context: "google/go-github"} @@ -24,33 +27,37 @@ func ExampleMarkdownService_Render() { ctx := context.Background() output, _, err := client.Markdown.Render(ctx, input, opt) if err != nil { - fmt.Println(err) + log.Fatalf("Error rendering markdown: %v", err) } fmt.Println(output) } func ExampleRepositoriesService_GetReadme() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } ctx := context.Background() readme, _, err := client.Repositories.GetReadme(ctx, "google", "go-github", nil) if err != nil { - fmt.Println(err) - return + log.Fatalf("Error getting README: %v", err) } content, err := readme.GetContent() if err != nil { - fmt.Println(err) - return + log.Fatalf("Error getting README content: %v", err) } fmt.Printf("google/go-github README:\n%v\n", content) } func ExampleRepositoriesService_ListByUser() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } user := "willnorris" opt := &github.RepositoryListByUserOptions{Type: "owner", Sort: "updated", Direction: "desc"} @@ -58,7 +65,7 @@ func ExampleRepositoriesService_ListByUser() { ctx := context.Background() repos, _, err := client.Repositories.ListByUser(ctx, user, opt) if err != nil { - fmt.Println(err) + log.Fatalf("error listing repositories by user: %v", err) } fmt.Printf("Recently updated repositories by %q: %v", user, github.Stringify(repos)) @@ -73,7 +80,10 @@ func ExampleRepositoriesService_CreateFile() { // github.NewClient() instead of nil. See the following documentation for more // information on how to authenticate with the client: // https://pkg.go.dev/github.com/google/go-github/v86/github#hdr-Authentication - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } ctx := context.Background() fileContent := []byte("This is the content of my file\nand the 2nd line of it") @@ -86,21 +96,23 @@ func ExampleRepositoriesService_CreateFile() { Branch: github.Ptr("master"), Committer: &github.CommitAuthor{Name: github.Ptr("FirstName LastName"), Email: github.Ptr("user@example.com")}, } - _, _, err := client.Repositories.CreateFile(ctx, "myOrganization", "myRepository", "myNewFile.md", opts) - if err != nil { - fmt.Println(err) - return + if _, _, err := client.Repositories.CreateFile(ctx, "myOrganization", "myRepository", "myNewFile.md", opts); err != nil { + log.Fatalf("Error creating file: %v", err) } } func ExampleUsersService_ListAll() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } + ctx := context.Background() opts := &github.UserListOptions{} for { users, _, err := client.Users.ListAll(ctx, opts) if err != nil { - log.Fatalf("error listing users: %v", err) + log.Fatalf("Error listing users: %v", err) } if len(users) == 0 { break @@ -118,7 +130,10 @@ func ExamplePullRequestsService_Create() { // github.NewClient() instead of nil. See the following documentation for more // information on how to authenticate with the client: // https://pkg.go.dev/github.com/google/go-github/v86/github#hdr-Authentication - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } newPR := &github.NewPullRequest{ Title: github.Ptr("My awesome pull request"), @@ -131,8 +146,7 @@ func ExamplePullRequestsService_Create() { ctx := context.Background() pr, _, err := client.PullRequests.Create(ctx, "myOrganization", "myRepository", newPR) if err != nil { - fmt.Println(err) - return + log.Fatalf("Error creating pull request: %v", err) } fmt.Printf("PR created: %v\n", pr.GetHTMLURL()) @@ -147,7 +161,10 @@ func ExampleTeamsService_ListTeams() { // See the following documentation for more information on how to authenticate // with the client: // https://pkg.go.dev/github.com/google/go-github/v86/github#hdr-Authentication - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } teamName := "Developers team" ctx := context.Background() @@ -156,8 +173,7 @@ func ExampleTeamsService_ListTeams() { for { teams, resp, err := client.Teams.ListTeams(ctx, "myOrganization", opts) if err != nil { - fmt.Println(err) - return + log.Fatalf("error listing teams: %v", err) } for _, t := range teams { if t.GetName() == teamName { @@ -175,13 +191,16 @@ func ExampleTeamsService_ListTeams() { } func ExampleUsersService_ListUserSocialAccounts() { - client := github.NewClient(nil) + client, err := github.NewClient() + if err != nil { + log.Fatalf("Error creating GitHub client: %v", err) + } ctx := context.Background() opts := &github.ListOptions{} for { accounts, resp, err := client.Users.ListUserSocialAccounts(ctx, "shreyjain13", opts) if err != nil { - log.Fatalf("Failed to list user social accounts: %v", err) + log.Fatalf("Error listing user social accounts: %v", err) } if resp.NextPage == 0 || len(accounts) == 0 { break diff --git a/github/github.go b/github/github.go index 9f1eee4b90e..756d96365cc 100644 --- a/github/github.go +++ b/github/github.go @@ -162,25 +162,24 @@ var ErrPathForbidden = errors.New("path must not contain '..' due to auth vulner // A Client manages communication with the GitHub API. type Client struct { - clientMu sync.Mutex // clientMu protects the client fields during copy and Client calls. client *http.Client // HTTP client used to communicate with the API. clientIgnoreRedirects *http.Client // HTTP client used to communicate with the API on endpoints where we don't want to follow redirects. // Base URL for API requests. Defaults to the public GitHub API, but can be - // set to a domain endpoint to use with GitHub Enterprise. BaseURL should + // set to a domain endpoint to use with GitHub Enterprise. baseURL should // always be specified with a trailing slash. - BaseURL *url.URL + baseURL *url.URL // Base URL for uploading files. - UploadURL *url.URL + uploadURL *url.URL // User agent used when communicating with the GitHub API. - UserAgent string + userAgent string - // DisableRateLimitCheck stops the client checking for rate limits or tracking + // disableRateLimitCheck stops the client checking for rate limits or tracking // them. This is different to setting BypassRateLimitCheck in the context, // as that still tracks the rate limits. - DisableRateLimitCheck bool + disableRateLimitCheck bool rateMu sync.Mutex rateLimits [Categories]Rate // Rate limits for the client as determined by the most recent API calls. @@ -188,10 +187,10 @@ type Client struct { // If specified, Client will block requests for at most this duration in case of reaching a secondary // rate limit - MaxSecondaryRateLimitRetryAfterDuration time.Duration + maxSecondaryRateLimitRetryAfterDuration time.Duration // Whether to respect rate limit headers on endpoints that return 302 redirections to artifacts - RateLimitRedirectionalEndpoints bool + rateLimitRedirectionalEndpoints bool common service // Reuse a single struct instead of allocating one for each service on the heap. @@ -248,8 +247,6 @@ type service struct { // This should only be used for requests to the GitHub API because // request headers will contain an authorization token. func (c *Client) Client() *http.Client { - c.clientMu.Lock() - defer c.clientMu.Unlock() clientCopy := *c.client return &clientCopy } @@ -337,113 +334,271 @@ func addOptions[P structPtr[T], T any](s string, opts P) (string, error) { return u.String(), nil } -// NewClient returns a new GitHub API client. If a nil httpClient is -// provided, a new http.Client will be used. To use API methods which require -// authentication, either use Client.WithAuthToken or provide NewClient with -// an http.Client that will perform the authentication for you (such as that -// provided by the golang.org/x/oauth2 library). -// -// Note: When using a nil httpClient, the default client has no timeout set. -// This may not be suitable for production environments. It is recommended to -// provide a custom http.Client with an appropriate timeout. -func NewClient(httpClient *http.Client) *Client { - if httpClient == nil { - httpClient = &http.Client{} - } - httpClient2 := *httpClient - c := &Client{client: &httpClient2} - c.initialize() - return c -} - -// WithAuthToken returns a copy of the client configured to use the provided token for the Authorization header. -func (c *Client) WithAuthToken(token string) *Client { - c2 := c.copy() - defer c2.initialize() - transport := c2.client.Transport - if transport == nil { - transport = http.DefaultTransport - } - c2.client.Transport = roundTripperFunc( - func(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - if token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) - } - return transport.RoundTrip(req) - }, - ) - return c2 +// errUninitialized is returned when an uninitialized Client is used. +var errUninitialized = errors.New("client is not initialized") + +// clientOptions holds the configuration options for a Client. +type clientOptions struct { + httpClient *http.Client + transport http.RoundTripper + userAgent *string + envProxy bool + token *string + baseURL *url.URL + uploadURL *url.URL + disableRateLimitCheck bool + rateLimitRedirectionalEndpoints bool + maxSecondaryRateLimitRetryAfterDuration *time.Duration + marketplaceStubbed bool +} + +// ClientOptionsFunc is a functional option for providing configuration options +// to a Client. +type ClientOptionsFunc func(*clientOptions) error + +// WithHTTPClient returns a ClientOptionsFunc that sets the http.Client +// for a Client. If not set, a default http.Client will be used. +func WithHTTPClient(httpClient *http.Client) ClientOptionsFunc { + return func(o *clientOptions) error { + if httpClient == nil { + return errors.New("http client must not be nil") + } + + httpClient := *httpClient + o.httpClient = &httpClient + return nil + } } -// WithEnterpriseURLs returns a copy of the client configured to use the provided base and -// upload URLs. If the base URL does not have the suffix "/api/v3/", it will be added -// automatically. If the upload URL does not have the suffix "/api/uploads", it will be -// added automatically. -// -// Note that WithEnterpriseURLs is a convenience helper only; -// its behavior is equivalent to setting the BaseURL and UploadURL fields. -// -// Another important thing is that by default, the GitHub Enterprise URL format -// should be http(s)://[hostname]/api/v3/ or you will always receive the 406 status code. -// The upload URL format should be http(s)://[hostname]/api/uploads/. -func (c *Client) WithEnterpriseURLs(baseURL, uploadURL string) (*Client, error) { - c2 := c.copy() - defer c2.initialize() - var err error - c2.BaseURL, err = url.Parse(baseURL) - if err != nil { - return nil, err +// WithTransport returns a ClientOptionsFunc that sets the http.RoundTripper +// for a Client. This overrides the transport set by [WithHTTPClient]. If not +// set and no HTTP client is provided, the default http.RoundTripper will be used. +func WithTransport(transport http.RoundTripper) ClientOptionsFunc { + return func(o *clientOptions) error { + if transport == nil { + return errors.New("transport must not be nil") + } + + o.transport = transport + return nil } +} - if !strings.HasSuffix(c2.BaseURL.Path, "/") { - c2.BaseURL.Path += "/" +// WithUserAgent returns a ClientOptionsFunc that sets the User-Agent header +// for a Client. If not set, a default User-Agent will be used. +func WithUserAgent(userAgent string) ClientOptionsFunc { + return func(o *clientOptions) error { + o.userAgent = &userAgent + return nil } - if !strings.HasSuffix(c2.BaseURL.Path, "/api/v3/") && - !strings.HasPrefix(c2.BaseURL.Host, "api.") && - !strings.Contains(c2.BaseURL.Host, ".api.") { - c2.BaseURL.Path += "api/v3/" +} + +// WithEnvProxy returns a ClientOptionsFunc that configures the Client to use +// the HTTP proxy settings from the environment variables +// (e.g., HTTP_PROXY, HTTPS_PROXY, NO_PROXY). +// If not set, the client will not use environment proxy settings. +func WithEnvProxy() ClientOptionsFunc { + return func(o *clientOptions) error { + o.envProxy = true + return nil } +} - c2.UploadURL, err = url.Parse(uploadURL) - if err != nil { - return nil, err +// WithAuthToken returns a ClientOptionsFunc that sets the authentication token +// for a Client. If not set, the client will make unauthenticated requests. +func WithAuthToken(token string) ClientOptionsFunc { + return func(o *clientOptions) error { + if token == "" { + return errors.New("token must not be empty") + } + + o.token = &token + return nil + } +} + +// WithEnterpriseURLs returns a ClientOptionsFunc that sets the base and upload +// URLs for a Client. +func WithEnterpriseURLs(baseURL, uploadURL string) ClientOptionsFunc { + return func(o *clientOptions) error { + if baseURL == "" { + return errors.New("base url must not be empty") + } + + if uploadURL == "" { + return errors.New("upload url must not be empty") + } + + b, err := url.Parse(baseURL) + if err != nil { + return err + } + + if !strings.HasSuffix(b.Path, "/") { + b.Path += "/" + } + + if !strings.HasSuffix(b.Path, "/api/v3/") && + !strings.HasPrefix(b.Host, "api.") && + !strings.Contains(b.Host, ".api.") { + b.Path += "api/v3/" + } + + o.baseURL = b + + u, err := url.Parse(uploadURL) + if err != nil { + return err + } + + if !strings.HasSuffix(u.Path, "/") { + u.Path += "/" + } + if !strings.HasSuffix(u.Path, "/api/uploads/") && + !strings.HasPrefix(u.Host, "api.") && + !strings.Contains(u.Host, ".api.") && + !strings.HasPrefix(u.Host, "uploads.") { + u.Path += "api/uploads/" + } + + o.uploadURL = u + + return nil } +} - if !strings.HasSuffix(c2.UploadURL.Path, "/") { - c2.UploadURL.Path += "/" +// WithDisableRateLimitCheck returns a ClientOptionsFunc that disables rate +// limit checking for a Client. If not set, the client will check for rate +// limits and track them. +func WithDisableRateLimitCheck() ClientOptionsFunc { + return func(o *clientOptions) error { + o.disableRateLimitCheck = true + return nil } - if !strings.HasSuffix(c2.UploadURL.Path, "/api/uploads/") && - !strings.HasPrefix(c2.UploadURL.Host, "api.") && - !strings.Contains(c2.UploadURL.Host, ".api.") && - !strings.HasPrefix(c2.UploadURL.Host, "uploads.") { - c2.UploadURL.Path += "api/uploads/" +} + +// WithRateLimitRedirectionalEndpoints returns a ClientOptionsFunc that +// configures the Client to respect rate limit headers on endpoints that +// return 302 redirection to artifacts. If not set, the client will not +// respect rate limit headers on these endpoints. +func WithRateLimitRedirectionalEndpoints() ClientOptionsFunc { + return func(o *clientOptions) error { + o.rateLimitRedirectionalEndpoints = true + return nil } - return c2, nil } -// initialize sets default values and initializes services. -func (c *Client) initialize() { - if c.client == nil { +// WithSecondaryRateLimitOptions returns a ClientOptionsFunc that configures the Client +// secondary rate limits. +func WithSecondaryRateLimitOptions(maxRetryAfterDuration time.Duration) ClientOptionsFunc { + return func(o *clientOptions) error { + o.maxSecondaryRateLimitRetryAfterDuration = &maxRetryAfterDuration + return nil + } +} + +// NewClient returns a new GitHub API client configured with the provided +// options. The default configuration is suitable for making unauthenticated +// requests to the public GitHub API. To make authenticated requests, +// use [WithAuthToken] or provide an http.Client that performs authentication +// (e.g., using golang.org/x/oauth2) to [WithHTTPClient]. For GitHub +// Enterprise, use [WithEnterpriseURLs] to set the base and upload URLs. If no +// http.Client is provided, a default one will be used, but it is recommended +// to provide a custom http.Client with an appropriate timeout for production +// environments. +// +// Note: When using a nil httpClient, the default client has no timeout set. +// This may not be suitable for production environments. It is recommended to +// provide a custom http.Client with an appropriate timeout. +func NewClient(opts ...ClientOptionsFunc) (*Client, error) { + o := clientOptions{} + for _, opt := range opts { + if err := opt(&o); err != nil { + return nil, err + } + } + + return newClient(o) +} + +// newClient creates a new Client with the provided options. This is an internal +// helper function that is called by [NewClient] and [Client.Clone]. +func newClient(opts clientOptions) (*Client, error) { + c := &Client{} + + if opts.httpClient != nil { + c.client = opts.httpClient + } else { c.client = &http.Client{} } - // Copy the main http client into the IgnoreRedirects one, overriding the `CheckRedirect` func - c.clientIgnoreRedirects = &http.Client{} - c.clientIgnoreRedirects.Transport = c.client.Transport - c.clientIgnoreRedirects.Timeout = c.client.Timeout - c.clientIgnoreRedirects.Jar = c.client.Jar - c.clientIgnoreRedirects.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse + + if opts.transport != nil { + c.client.Transport = opts.transport + } + + if opts.envProxy { + transport := c.client.Transport + if transport == nil { + transport = http.DefaultTransport + } + + t, ok := transport.(*http.Transport) + if !ok { + return nil, errors.New("cannot set environment proxy on non-http transport") + } + + t2 := t.Clone() + t2.Proxy = http.ProxyFromEnvironment + c.client.Transport = t2 + } + + if opts.token != nil { + transport := c.client.Transport + if transport == nil { + transport = http.DefaultTransport + } + c.client.Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", *opts.token)) + return transport.RoundTrip(req) + }) + } + + c.clientIgnoreRedirects = &http.Client{ + Transport: c.client.Transport, + Timeout: c.client.Timeout, + Jar: c.client.Jar, + CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }, } - if c.BaseURL == nil { - c.BaseURL, _ = url.Parse(defaultBaseURL) + + if opts.userAgent != nil { + c.userAgent = *opts.userAgent + } else { + c.userAgent = defaultUserAgent } - if c.UploadURL == nil { - c.UploadURL, _ = url.Parse(uploadBaseURL) + + if opts.baseURL != nil { + c.baseURL = opts.baseURL + } else { + c.baseURL, _ = url.Parse(defaultBaseURL) } - if c.UserAgent == "" { - c.UserAgent = defaultUserAgent + + if opts.uploadURL != nil { + c.uploadURL = opts.uploadURL + } else { + c.uploadURL, _ = url.Parse(uploadBaseURL) + } + + c.disableRateLimitCheck = opts.disableRateLimitCheck + + if !c.disableRateLimitCheck { + c.rateLimitRedirectionalEndpoints = opts.rateLimitRedirectionalEndpoints + + if opts.maxSecondaryRateLimitRetryAfterDuration != nil { + c.maxSecondaryRateLimitRetryAfterDuration = *opts.maxSecondaryRateLimitRetryAfterDuration + } } + c.common.client = c c.Actions = (*ActionsService)(&c.common) c.Activity = (*ActivityService)(&c.common) @@ -470,11 +625,7 @@ func (c *Client) initialize() { c.Issues = (*IssuesService)(&c.common) c.Licenses = (*LicensesService)(&c.common) c.Markdown = (*MarkdownService)(&c.common) - var marketplaceStubbed bool - if c.Marketplace != nil { - marketplaceStubbed = c.Marketplace.Stubbed - } - c.Marketplace = &MarketplaceService{client: c, Stubbed: marketplaceStubbed} + c.Marketplace = &MarketplaceService{client: c, Stubbed: opts.marketplaceStubbed} c.Meta = (*MetaService)(&c.common) c.Migrations = (*MigrationService)(&c.common) c.Organizations = (*OrganizationsService)(&c.common) @@ -491,55 +642,84 @@ func (c *Client) initialize() { c.SubIssue = (*SubIssueService)(&c.common) c.Teams = (*TeamsService)(&c.common) c.Users = (*UsersService)(&c.common) + + return c, nil } -// copy returns a copy of the current client. It must be initialized before use. -func (c *Client) copy() *Client { - c.clientMu.Lock() - // can't use *c here because that would copy mutexes by value. - clone := Client{ - client: &http.Client{}, - UserAgent: c.UserAgent, - BaseURL: c.BaseURL, - UploadURL: c.UploadURL, - RateLimitRedirectionalEndpoints: c.RateLimitRedirectionalEndpoints, - secondaryRateLimitReset: c.secondaryRateLimitReset, +// UserAgent returns the User-Agent header value for the client. +func (c *Client) UserAgent() string { + return c.userAgent +} + +// BaseURL returns the base URL for API requests. +func (c *Client) BaseURL() string { + if c.baseURL == nil { + return "" } + return c.baseURL.String() +} + +// UploadURL returns the base URL for upload API requests. +func (c *Client) UploadURL() string { + if c.uploadURL == nil { + return "" + } + return c.uploadURL.String() +} + +// Clone returns a copy of the client with the same configuration and services. +// The returned client has its own http.Client but shares the client +// configuration such as transport and timeout. The returned client starts with +// the same rate limit information as the original client, but it is not +// updated when the original client's rate limit information is updated. +// The returned client is independent of the original client and can be +// modified without affecting the original client. +func (c *Client) Clone(opts ...ClientOptionsFunc) (*Client, error) { + if c.client == nil { + return nil, errUninitialized + } + + o := clientOptions{ + userAgent: &c.userAgent, + baseURL: Ptr(*c.baseURL), + uploadURL: Ptr(*c.uploadURL), + disableRateLimitCheck: c.disableRateLimitCheck, + rateLimitRedirectionalEndpoints: c.rateLimitRedirectionalEndpoints, + maxSecondaryRateLimitRetryAfterDuration: &c.maxSecondaryRateLimitRetryAfterDuration, + } + if c.Marketplace != nil { - clone.Marketplace = &MarketplaceService{Stubbed: c.Marketplace.Stubbed} + o.marketplaceStubbed = c.Marketplace.Stubbed + } + + for _, opt := range opts { + if err := opt(&o); err != nil { + return nil, err + } } - c.clientMu.Unlock() - if c.client != nil { - clone.client.Transport = c.client.Transport - clone.client.CheckRedirect = c.client.CheckRedirect - clone.client.Jar = c.client.Jar - clone.client.Timeout = c.client.Timeout + + if o.httpClient == nil { + o.httpClient = &http.Client{ + Transport: c.client.Transport, + CheckRedirect: c.client.CheckRedirect, + Jar: c.client.Jar, + Timeout: c.client.Timeout, + } } - c.rateMu.Lock() - clone.rateLimits = c.rateLimits - c.rateMu.Unlock() - return &clone -} -// NewClientWithEnvProxy enhances NewClient with the HttpProxy env. -func NewClientWithEnvProxy() *Client { - return NewClient(&http.Client{Transport: &http.Transport{Proxy: http.ProxyFromEnvironment}}) -} + clone, err := newClient(o) + if err != nil { + return nil, err + } -// NewTokenClient returns a new GitHub API client authenticated with the provided token. -// -// Deprecated: Use NewClient(nil).WithAuthToken(token) instead. -func NewTokenClient(_ context.Context, token string) *Client { - // This always returns a nil error. - return NewClient(nil).WithAuthToken(token) -} + if !clone.disableRateLimitCheck { + c.rateMu.Lock() + clone.rateLimits = c.rateLimits + clone.secondaryRateLimitReset = c.secondaryRateLimitReset + c.rateMu.Unlock() + } -// NewEnterpriseClient returns a new GitHub API client with provided -// base URL and upload URL (often is your GitHub Enterprise hostname). -// -// Deprecated: Use NewClient(httpClient).WithEnterpriseURLs(baseURL, uploadURL) instead. -func NewEnterpriseClient(baseURL, uploadURL string, httpClient *http.Client) (*Client, error) { - return NewClient(httpClient).WithEnterpriseURLs(baseURL, uploadURL) + return clone, nil } // RequestOption represents an option that can modify an http.Request. @@ -560,15 +740,15 @@ func WithVersion(version string) RequestOption { // specified, the value pointed to by body is JSON encoded and included as the // request body. func (c *Client) NewRequest(ctx context.Context, method, urlStr string, body any, opts ...RequestOption) (*http.Request, error) { - if !strings.HasSuffix(c.BaseURL.Path, "/") { - return nil, fmt.Errorf("baseURL must have a trailing slash, but %q does not", c.BaseURL) + if !strings.HasSuffix(c.baseURL.Path, "/") { + return nil, fmt.Errorf("baseURL must have a trailing slash, but %q does not", c.baseURL) } if err := checkURLPathTraversal(urlStr); err != nil { return nil, err } - u, err := c.BaseURL.Parse(urlStr) + u, err := c.baseURL.Parse(urlStr) if err != nil { return nil, err } @@ -593,8 +773,8 @@ func (c *Client) NewRequest(ctx context.Context, method, urlStr string, body any req.Header.Set("Content-Type", "application/json") } req.Header.Set("Accept", mediaTypeV3) - if c.UserAgent != "" { - req.Header.Set("User-Agent", c.UserAgent) + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) } req.Header.Set(headerAPIVersion, defaultAPIVersion) @@ -610,15 +790,15 @@ func (c *Client) NewRequest(ctx context.Context, method, urlStr string, body any // Relative URLs should always be specified without a preceding slash. // Body is sent with Content-Type: application/x-www-form-urlencoded. func (c *Client) NewFormRequest(ctx context.Context, urlStr string, body io.Reader, opts ...RequestOption) (*http.Request, error) { - if !strings.HasSuffix(c.BaseURL.Path, "/") { - return nil, fmt.Errorf("baseURL must have a trailing slash, but %q does not", c.BaseURL) + if !strings.HasSuffix(c.baseURL.Path, "/") { + return nil, fmt.Errorf("baseURL must have a trailing slash, but %q does not", c.baseURL) } if err := checkURLPathTraversal(urlStr); err != nil { return nil, err } - u, err := c.BaseURL.Parse(urlStr) + u, err := c.baseURL.Parse(urlStr) if err != nil { return nil, err } @@ -630,8 +810,8 @@ func (c *Client) NewFormRequest(ctx context.Context, urlStr string, body io.Read req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", mediaTypeV3) - if c.UserAgent != "" { - req.Header.Set("User-Agent", c.UserAgent) + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) } req.Header.Set(headerAPIVersion, defaultAPIVersion) @@ -665,15 +845,15 @@ func checkURLPathTraversal(urlStr string) error { // urlStr, in which case it is resolved relative to the UploadURL of the Client. // Relative URLs should always be specified without a preceding slash. func (c *Client) NewUploadRequest(ctx context.Context, urlStr string, reader io.Reader, size int64, mediaType string, opts ...RequestOption) (*http.Request, error) { - if !strings.HasSuffix(c.UploadURL.Path, "/") { - return nil, fmt.Errorf("uploadURL must have a trailing slash, but %q does not", c.UploadURL) + if !strings.HasSuffix(c.uploadURL.Path, "/") { + return nil, fmt.Errorf("uploadURL must have a trailing slash, but %q does not", c.uploadURL) } if err := checkURLPathTraversal(urlStr); err != nil { return nil, err } - u, err := c.UploadURL.Parse(urlStr) + u, err := c.uploadURL.Parse(urlStr) if err != nil { return nil, err } @@ -703,7 +883,7 @@ func (c *Client) NewUploadRequest(ctx context.Context, urlStr string, reader io. } req.Header.Set("Content-Type", mediaType) req.Header.Set("Accept", mediaTypeV3) - req.Header.Set("User-Agent", c.UserAgent) + req.Header.Set("User-Agent", c.userAgent) req.Header.Set(headerAPIVersion, defaultAPIVersion) for _, opt := range opts { @@ -935,7 +1115,7 @@ func (c *Client) bareDo(caller *http.Client, req *http.Request) (*Response, erro rateLimitCategory := CoreCategory - if !c.DisableRateLimitCheck { + if !c.disableRateLimitCheck { rateLimitCategory = GetRateLimitCategory(req.Method, req.URL.Path) if bypass := ctx.Value(BypassRateLimitCheck); bypass == nil { @@ -986,7 +1166,7 @@ func (c *Client) bareDo(caller *http.Client, req *http.Request) (*Response, erro // Don't update the rate limits if the client has rate limits disabled or if // this was a cached response. The X-From-Cache is set by // https://github.com/bartventer/httpcache if it's enabled. - if !c.DisableRateLimitCheck && response.Header.Get("X-From-Cache") == "" { + if !c.disableRateLimitCheck && response.Header.Get("X-From-Cache") == "" { c.rateMu.Lock() c.rateLimits[rateLimitCategory] = response.Rate c.rateMu.Unlock() @@ -1026,8 +1206,8 @@ func (c *Client) bareDo(caller *http.Client, req *http.Request) (*Response, erro var rerr *AbuseRateLimitError if errors.As(err, &rerr) && rerr.RetryAfter != nil { // if a max duration is specified, make sure that we are waiting at most this duration - if c.MaxSecondaryRateLimitRetryAfterDuration > 0 && rerr.GetRetryAfter() > c.MaxSecondaryRateLimitRetryAfterDuration { - rerr.RetryAfter = &c.MaxSecondaryRateLimitRetryAfterDuration + if c.maxSecondaryRateLimitRetryAfterDuration > 0 && rerr.GetRetryAfter() > c.maxSecondaryRateLimitRetryAfterDuration { + rerr.RetryAfter = &c.maxSecondaryRateLimitRetryAfterDuration } c.rateMu.Lock() c.secondaryRateLimitReset = time.Now().Add(*rerr.RetryAfter) @@ -1069,7 +1249,7 @@ func (c *Client) bareDoUntilFound(req *http.Request, maxRedirects int) (*url.URL if rerr.Location == nil { return nil, nil, errInvalidLocation } - newURL := c.BaseURL.ResolveReference(rerr.Location) + newURL := c.baseURL.ResolveReference(rerr.Location) return newURL, response, nil } // If permanent redirect response is returned, follow it @@ -1077,12 +1257,12 @@ func (c *Client) bareDoUntilFound(req *http.Request, maxRedirects int) (*url.URL if rerr.Location == nil { return nil, nil, errInvalidLocation } - newURL := c.BaseURL.ResolveReference(rerr.Location) + newURL := c.baseURL.ResolveReference(rerr.Location) // Refuse to follow a permanent redirect to a different host: // req.Clone preserves Authorization headers added by the auth // transport, so a cross-host target would leak credentials. - if newURL.Host != c.BaseURL.Host { - return nil, response, fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, newURL.Host) + if newURL.Host != c.baseURL.Host { + return nil, response, fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.baseURL.Host, newURL.Host) } newRequest := req.Clone(req.Context()) newRequest.URL = newURL @@ -1864,9 +2044,9 @@ func (c *Client) checkRedirectHost(location string) error { return fmt.Errorf("invalid redirect location %q: %w", location, err) } // Resolve relative locations against BaseURL so relative paths are allowed. - target = c.BaseURL.ResolveReference(target) - if target.Host != c.BaseURL.Host { - return fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, target.Host) + target = c.baseURL.ResolveReference(target) + if target.Host != c.baseURL.Host { + return fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.baseURL.Host, target.Host) } return nil } diff --git a/github/github_test.go b/github/github_test.go index 87703736508..2995ef2fe48 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -45,6 +45,26 @@ type raceSafeTestConn struct { net.Conn } +// mustNewClient is a helper function that creates a new Client and fails the test if there is an error. +func mustNewClient(t *testing.T, opts ...ClientOptionsFunc) *Client { + t.Helper() + c, err := NewClient(opts...) + if err != nil { + t.Fatal(err) + } + return c +} + +// mustParseURL is a helper function that parses a URL and fails the test if there is an error. +func mustParseURL(t *testing.T, rawurl string) *url.URL { + t.Helper() + u, err := url.Parse(rawurl) + if err != nil { + t.Fatalf("Failed to parse URL %q: %v", rawurl, err) + } + return u +} + // setup sets up a test HTTP server along with a github.Client that is // configured to talk to that test server. Tests should register handlers on // mux which provide mock responses for the API method being tested. @@ -103,11 +123,11 @@ func setup(t *testing.T) (client *Client, mux *http.ServeMux, serverURL string) } // client is the GitHub client being tested and is // configured to use test server. - client = NewClient(httpClient) + client = mustNewClient(t, WithHTTPClient(httpClient)) url, _ := url.Parse(server.URL + baseURLPath + "/") - client.BaseURL = url - client.UploadURL = url + client.baseURL = url + client.uploadURL = url t.Cleanup(server.Close) @@ -352,7 +372,7 @@ func testNewRequestAndDoFailureCategory(t *testing.T, methodName string, client t.Error("testNewRequestAndDoFailure: must supply method methodName") } - client.BaseURL.Path = "" + client.baseURL.Path = "" resp, err := f() if resp != nil { t.Errorf("client.BaseURL.Path='' %v resp = %#v, want nil", methodName, resp) @@ -361,10 +381,10 @@ func testNewRequestAndDoFailureCategory(t *testing.T, methodName string, client t.Errorf("client.BaseURL.Path='' %v err = nil, want error", methodName) } - client.BaseURL.Path = "/api-v3/" + client.baseURL.Path = "/api-v3/" client.rateLimits[category].Reset.Time = time.Now().Add(10 * time.Minute) resp, err = f() - if client.DisableRateLimitCheck { + if client.disableRateLimitCheck { return } if bypass := resp.Request.Context().Value(BypassRateLimitCheck); bypass != nil { @@ -427,123 +447,155 @@ func assertWrite(t *testing.T, w io.Writer, data []byte) { assertNilError(t, err) } -func TestNewClient(t *testing.T) { +func TestWithHTTPClient(t *testing.T) { t.Parallel() - c := NewClient(nil) - if got, want := c.BaseURL.String(), defaultBaseURL; got != want { - t.Errorf("NewClient BaseURL is %v, want %v", got, want) - } - if got, want := c.UserAgent, defaultUserAgent; got != want { - t.Errorf("NewClient UserAgent is %v, want %v", got, want) - } + t.Run("nil_client", func(t *testing.T) { + t.Parallel() - c2 := NewClient(nil) - if c.client == c2.client { - t.Error("NewClient returned same http.Clients, but they should differ") - } -} + opts := clientOptions{} + err := WithHTTPClient(nil)(&opts) + if err == nil || err.Error() != "http client must not be nil" { + t.Errorf("WithHTTPClient errored: %v", err) + } + }) -func TestNewClientWithEnvProxy(t *testing.T) { - t.Parallel() - client := NewClientWithEnvProxy() - if got, want := client.BaseURL.String(), defaultBaseURL; got != want { - t.Errorf("NewClient BaseURL is %v, want %v", got, want) - } -} + t.Run("non_nil_client", func(t *testing.T) { + t.Parallel() -func TestClient(t *testing.T) { - t.Parallel() - c := NewClient(nil) - c2 := c.Client() - if c.client == c2 { - t.Error("Client returned same http.Client, but should be different") - } + customClient := &http.Client{Timeout: 10 * time.Second} + opts := clientOptions{} + err := WithHTTPClient(customClient)(&opts) + if err != nil { + t.Errorf("WithHTTPClient errored: %v", err) + } + + if opts.httpClient == nil { + t.Error("httpClient is nil") + } + + if opts.httpClient == customClient { + t.Error("httpClient should be a shallow copy of the provided client, but is the same instance") + } + + if opts.httpClient.Timeout != customClient.Timeout { + t.Errorf("httpClient Timeout = %v, want %v", opts.httpClient.Timeout, customClient.Timeout) + } + }) } -func TestWithAuthToken(t *testing.T) { +func TestWithTransport(t *testing.T) { t.Parallel() - token := "gh_test_token" - validate := func(t *testing.T, c *http.Client, token string) { - t.Helper() - want := token - if want != "" { - want = "Bearer " + want + t.Run("nil_transport", func(t *testing.T) { + t.Parallel() + + opts := clientOptions{} + err := WithTransport(nil)(&opts) + if err == nil || err.Error() != "transport must not be nil" { + t.Errorf("WithTransport errored: %v", err) } - gotReq := false - headerVal := "" - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - gotReq = true - headerVal = r.Header.Get("Authorization") - })) - _, err := c.Get(srv.URL) - assertNilError(t, err) - if !gotReq { - t.Error("request not sent") + }) + + t.Run("non_nil_transport", func(t *testing.T) { + t.Parallel() + + customTransport := &http.Transport{IdleConnTimeout: 10 * time.Second} + opts := clientOptions{} + err := WithTransport(customTransport)(&opts) + if err != nil { + t.Errorf("WithTransport errored: %v", err) } - if headerVal != want { - t.Errorf("Authorization header is %v, want %v", headerVal, want) + + if opts.transport == nil { + t.Error("transport is nil") } - } - t.Run("zero-value Client", func(t *testing.T) { - t.Parallel() - c := new(Client).WithAuthToken(token) - validate(t, c.Client(), token) + if opts.transport != customTransport { + t.Errorf("transport = %v, want %v", opts.transport, customTransport) + } }) +} + +func TestWithUserAgent(t *testing.T) { + t.Parallel() - t.Run("NewClient", func(t *testing.T) { + t.Run("empty_user_agent", func(t *testing.T) { t.Parallel() - httpClient := &http.Client{} - client := NewClient(httpClient).WithAuthToken(token) - validate(t, client.Client(), token) - // make sure the original client isn't setting auth headers now - validate(t, httpClient, "") + + opts := clientOptions{} + err := WithUserAgent("")(&opts) + if err != nil { + t.Errorf("WithUserAgent errored: %v", err) + } + + if *opts.userAgent != "" { + t.Errorf("userAgent = %v, want empty string", opts.userAgent) + } }) - t.Run("NewTokenClient", func(t *testing.T) { + t.Run("custom_user_agent", func(t *testing.T) { t.Parallel() - validate(t, NewTokenClient(t.Context(), token).Client(), token) + + customUserAgent := "MyCustomUserAgent/1.0" + opts := clientOptions{} + err := WithUserAgent(customUserAgent)(&opts) + if err != nil { + t.Errorf("WithUserAgent errored: %v", err) + } + + if opts.userAgent == nil || *opts.userAgent != customUserAgent { + t.Errorf("userAgent = %v, want %v", opts.userAgent, customUserAgent) + } }) +} + +func TestWithEnvProxy(t *testing.T) { + t.Parallel() - t.Run("do not set Authorization when empty token", func(t *testing.T) { + opts := clientOptions{} + err := WithEnvProxy()(&opts) + if err != nil { + t.Errorf("WithEnvProxy errored: %v", err) + } + + if !opts.envProxy { + t.Error("envProxy is false, want true") + } +} + +func TestWithAuthToken(t *testing.T) { + t.Parallel() + + t.Run("empty_token", func(t *testing.T) { t.Parallel() - c := new(Client).WithAuthToken("") - - gotReq := false - ifAuthorizationSet := false - srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - gotReq = true - _, ifAuthorizationSet = r.Header["Authorization"] - })) - _, err := c.client.Get(srv.URL) - assertNilError(t, err) - if !gotReq { - t.Error("request not sent") - } - if ifAuthorizationSet { - t.Error("The header 'Authorization' must not be set") + + opts := clientOptions{} + err := WithAuthToken("")(&opts) + if err == nil || err.Error() != "token must not be empty" { + t.Error("expected error for empty token, got nil") } }) - t.Run("preserves Marketplace Stubbed field", func(t *testing.T) { + t.Run("valid_token", func(t *testing.T) { t.Parallel() - c := NewClient(nil) - c.Marketplace.Stubbed = true - - c2 := c.WithAuthToken("token") + validToken := "ghp_exampletoken1234567890" + opts := clientOptions{} + err := WithAuthToken(validToken)(&opts) + if err != nil { + t.Errorf("WithAuthToken errored: %v", err) + } - if !c2.Marketplace.Stubbed { - t.Fatal("WithAuthToken reset Marketplace.Stubbed; want true") + if opts.token == nil || *opts.token != validToken { + t.Errorf("token = %v, want %v", opts.token, validToken) } }) } func TestWithEnterpriseURLs(t *testing.T) { t.Parallel() - for _, test := range []struct { + for _, tt := range []struct { name string baseURL string wantBaseURL string @@ -640,34 +692,615 @@ func TestWithEnterpriseURLs(t *testing.T) { uploadURL: "https://uploads.custom-upload-url/", wantUploadURL: "https://uploads.custom-upload-url/", }, + { + name: "missing_base_url", + baseURL: "", + uploadURL: "https://custom-upload-url/api/uploads/", + wantErr: "base url must not be empty", + }, + { + name: "missing_upload_url", + baseURL: "https://custom-url/api/v3/", + uploadURL: "", + wantErr: "upload url must not be empty", + }, } { - t.Run(test.name, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { t.Parallel() - validate := func(c *Client, err error) { - t.Helper() - if test.wantErr != "" { - if err == nil || !strings.Contains(err.Error(), test.wantErr) { - t.Fatalf("error does not contain expected string %q: %v", test.wantErr, err) - } - return + + opts := clientOptions{} + err := WithEnterpriseURLs(tt.baseURL, tt.uploadURL)(&opts) + if err != nil { + if tt.wantErr == "" { + t.Fatalf("WithEnterpriseURLs returned unexpected error: %v", err) } - if err != nil { - t.Fatalf("got unexpected error: %v", err) + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error does not contain expected string %q: %v", tt.wantErr, err) } - if c.BaseURL.String() != test.wantBaseURL { - t.Errorf("BaseURL is %v, want %v", c.BaseURL, test.wantBaseURL) + return + } + if tt.wantErr != "" { + t.Fatalf("WithEnterpriseURLs did not return expected error containing %q", tt.wantErr) + } + + if opts.baseURL.String() != tt.wantBaseURL { + t.Errorf("BaseURL is %v, want %v", opts.baseURL, tt.wantBaseURL) + } + if opts.uploadURL.String() != tt.wantUploadURL { + t.Errorf("UploadURL is %v, want %v", opts.uploadURL, tt.wantUploadURL) + } + }) + } +} + +func TestWithDisableRateLimitCheck(t *testing.T) { + t.Parallel() + + opts := clientOptions{} + err := WithDisableRateLimitCheck()(&opts) + if err != nil { + t.Errorf("WithDisableRateLimitCheck errored: %v", err) + } + + if !opts.disableRateLimitCheck { + t.Error("disableRateLimitCheck is false, want true") + } +} + +func TestWithRateLimitRedirectionalEndpoints(t *testing.T) { + t.Parallel() + + opts := clientOptions{} + err := WithRateLimitRedirectionalEndpoints()(&opts) + if err != nil { + t.Errorf("WithRateLimitRedirectionalEndpoints errored: %v", err) + } + + if !opts.rateLimitRedirectionalEndpoints { + t.Error("rateLimitRedirectionalEndpoints is false, want true") + } +} + +func TestWithSecondaryRateLimitOptions(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + maxRetryAfterDuration time.Duration + }{ + { + name: "maxRetryAfterDuration is 0 (default)", + maxRetryAfterDuration: 0, + }, + { + name: "maxRetryAfterDuration is 1 minute", + maxRetryAfterDuration: time.Minute, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + opts := clientOptions{} + err := WithSecondaryRateLimitOptions(tt.maxRetryAfterDuration)(&opts) + if err != nil { + t.Errorf("WithSecondaryRateLimitOptions errored: %v", err) + } + if *opts.maxSecondaryRateLimitRetryAfterDuration != tt.maxRetryAfterDuration { + t.Errorf("maxSecondaryRateLimitRetryAfterDuration is %v, want %v", *opts.maxSecondaryRateLimitRetryAfterDuration, tt.maxRetryAfterDuration) + } + }) + } +} + +func TestNewClient(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + opts []ClientOptionsFunc + wantErr string + }{ + { + name: "no_options", + opts: []ClientOptionsFunc{}, + wantErr: "", + }, + { + name: "with_options", + opts: []ClientOptionsFunc{ + WithHTTPClient(&http.Client{Timeout: 10 * time.Second}), + }, + wantErr: "", + }, + { + name: "with_bad_options", + opts: []ClientOptionsFunc{ + func(_ *clientOptions) error { + return errors.New("bad option error") + }, + }, + wantErr: "bad option error", + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c, err := NewClient(tt.opts...) + if err != nil { + if tt.wantErr == "" { + t.Fatalf("NewClient returned unexpected error: %v", err) } - if c.UploadURL.String() != test.wantUploadURL { - t.Errorf("UploadURL is %v, want %v", c.UploadURL, test.wantUploadURL) + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error does not contain expected string %q: %v", tt.wantErr, err) } + return + } + if tt.wantErr != "" { + t.Fatalf("NewClient did not return expected error containing %q", tt.wantErr) + } + + if c.client == nil { + t.Error("NewClient client is not initialized") } - validate(NewClient(nil).WithEnterpriseURLs(test.baseURL, test.uploadURL)) - validate(new(Client).WithEnterpriseURLs(test.baseURL, test.uploadURL)) - validate(NewEnterpriseClient(test.baseURL, test.uploadURL, nil)) }) } } +func Test_newClient(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + opts clientOptions + wantErr string + }{ + { + name: "default_options", + opts: clientOptions{}, + wantErr: "", + }, + { + name: "with_http_client", + opts: clientOptions{ + httpClient: &http.Client{Transport: &http.Transport{IdleConnTimeout: 5 * time.Second}}, + }, + wantErr: "", + }, + { + name: "with_transport", + opts: clientOptions{ + transport: &http.Transport{IdleConnTimeout: 10 * time.Second}, + }, + wantErr: "", + }, + { + name: "with_all_options", + opts: clientOptions{ + httpClient: &http.Client{Transport: &http.Transport{IdleConnTimeout: 5 * time.Second}}, + transport: &http.Transport{IdleConnTimeout: 10 * time.Second}, + userAgent: Ptr("CustomUserAgent/1.0"), + baseURL: mustParseURL(t, "https://custom-url/api/v3/"), + uploadURL: mustParseURL(t, "https://custom-upload-url/api/uploads/"), + disableRateLimitCheck: true, + rateLimitRedirectionalEndpoints: true, + maxSecondaryRateLimitRetryAfterDuration: Ptr(2 * time.Minute), + }, + wantErr: "", + }, + { + name: "with_rate_limit_options", + opts: clientOptions{ + disableRateLimitCheck: false, + rateLimitRedirectionalEndpoints: true, + maxSecondaryRateLimitRetryAfterDuration: Ptr(2 * time.Minute), + }, + wantErr: "", + }, + { + name: "with_incompatible_transport_for_env_proxy", + opts: clientOptions{ + transport: roundTripperFunc(func(_ *http.Request) (*http.Response, error) { + return nil, nil + }), + envProxy: true, + }, + wantErr: "cannot set environment proxy on non-http transport", + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c, err := newClient(tt.opts) + if err != nil { + if tt.wantErr == "" { + t.Fatalf("newClient returned unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error does not contain expected string %q: %v", tt.wantErr, err) + } + return + } + if tt.wantErr != "" { + t.Fatalf("newClient did not return expected error containing %q", tt.wantErr) + } + + if c.client == nil { + t.Error("newClient http.Client is not initialized") + } + if tt.opts.httpClient != nil && c.client != tt.opts.httpClient { + t.Error("newClient http.Client should be the same instance as the provided httpClient option") + } + if tt.opts.transport != nil && !tt.opts.envProxy && tt.opts.token == nil && c.client.Transport != tt.opts.transport { + t.Error("newClient http.Client.Transport should be the same instance as the provided transport option") + } + + if c.clientIgnoreRedirects == nil { + t.Error("newClient http.Client used for redirects is not initialized") + } + if c.clientIgnoreRedirects.Transport != c.client.Transport { + t.Error("newClient http.Client and http.Client used for redirects should share the same Transport instance") + } + if c.clientIgnoreRedirects.Timeout != c.client.Timeout { + t.Errorf("newClient http.Client and http.Client used for redirects should have the same Timeout, got %v and %v", c.client.Timeout, c.clientIgnoreRedirects.Timeout) + } + if c.clientIgnoreRedirects.Jar != c.client.Jar { + t.Error("newClient http.Client and http.Client used for redirects should share the same Jar instance") + } + if c.clientIgnoreRedirects.CheckRedirect == nil { + t.Error("newClient http.Client used for redirects should have a CheckRedirect function") + } + + if tt.opts.userAgent != nil && c.userAgent != *tt.opts.userAgent { + t.Errorf("newClient userAgent is %v, want %v", c.userAgent, *tt.opts.userAgent) + } + if tt.opts.userAgent == nil && c.userAgent != defaultUserAgent { + t.Errorf("newClient userAgent is %v, want %v", c.userAgent, defaultUserAgent) + } + + if tt.opts.baseURL != nil && c.baseURL.String() != tt.opts.baseURL.String() { + t.Errorf("newClient baseURL is %v, want %v", c.baseURL.String(), tt.opts.baseURL.String()) + } + if tt.opts.baseURL == nil && c.baseURL.String() != defaultBaseURL { + t.Errorf("newClient baseURL is %v, want %v", c.baseURL.String(), defaultBaseURL) + } + + if tt.opts.uploadURL != nil && c.uploadURL.String() != tt.opts.uploadURL.String() { + t.Errorf("newClient uploadURL is %v, want %v", c.uploadURL.String(), tt.opts.uploadURL.String()) + } + if tt.opts.uploadURL == nil && c.uploadURL.String() != uploadBaseURL { + t.Errorf("newClient uploadURL is %v, want %v", c.uploadURL.String(), uploadBaseURL) + } + + if c.disableRateLimitCheck != tt.opts.disableRateLimitCheck { + t.Errorf("newClient disableRateLimitCheck is %v, want %v", c.disableRateLimitCheck, tt.opts.disableRateLimitCheck) + } + if tt.opts.disableRateLimitCheck && (c.rateLimitRedirectionalEndpoints || c.maxSecondaryRateLimitRetryAfterDuration != 0) { + t.Error("newClient should not set rate limit options when disableRateLimitCheck is true") + } + + if !tt.opts.disableRateLimitCheck && c.rateLimitRedirectionalEndpoints != tt.opts.rateLimitRedirectionalEndpoints { + t.Errorf("newClient rateLimitRedirectionalEndpoints is %v, want %v", c.rateLimitRedirectionalEndpoints, tt.opts.rateLimitRedirectionalEndpoints) + } + if !tt.opts.disableRateLimitCheck && tt.opts.maxSecondaryRateLimitRetryAfterDuration != nil && c.maxSecondaryRateLimitRetryAfterDuration != *tt.opts.maxSecondaryRateLimitRetryAfterDuration { + t.Errorf("newClient maxSecondaryRateLimitRetryAfterDuration is %v, want %v", c.maxSecondaryRateLimitRetryAfterDuration, *tt.opts.maxSecondaryRateLimitRetryAfterDuration) + } + + if c.common.client != c { + t.Error("newClient common.client is not initialized or does not point to the client") + } + if c.Actions == nil || c.Activity == nil || c.Admin == nil || c.Apps == nil || c.Authorizations == nil || c.Billing == nil || c.Checks == nil || c.Classroom == nil || c.CodeScanning == nil || c.CodesOfConduct == nil || c.Codespaces == nil || c.Copilot == nil || c.Credentials == nil || c.Dependabot == nil || c.DependencyGraph == nil || c.Emojis == nil || c.Enterprise == nil || c.Gists == nil || c.Git == nil || c.Gitignores == nil || c.Interactions == nil || c.IssueImport == nil || c.Issues == nil || c.Licenses == nil || c.Markdown == nil || c.Marketplace == nil || c.Meta == nil || c.Migrations == nil || c.Organizations == nil || c.PrivateRegistries == nil || c.Projects == nil || c.PullRequests == nil || c.RateLimit == nil || c.Reactions == nil || c.Repositories == nil || c.SCIM == nil || c.Search == nil || c.SecretScanning == nil || c.SecurityAdvisories == nil || c.SubIssue == nil || c.Teams == nil || c.Users == nil { + t.Error("newClient service fields are not all initialized") + } + + if c.Marketplace.Stubbed != tt.opts.marketplaceStubbed { + t.Errorf("newClient marketplaceStubbed is %v, want %v", c.Marketplace.Stubbed, tt.opts.marketplaceStubbed) + } + }) + } +} + +func TestClient_UserAgent(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + + if got, want := c.UserAgent(), defaultUserAgent; got != want { + t.Errorf("Client.UserAgent() = %v, want %v", got, want) + } + + customUserAgent := "CustomUserAgent/1.0" + c.userAgent = customUserAgent + + if got, want := c.UserAgent(), customUserAgent; got != want { + t.Errorf("Client.UserAgent() = %v, want %v", got, want) + } +} + +func TestClient_BaseURL(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + + if got, want := c.BaseURL(), defaultBaseURL; got != want { + t.Errorf("Client.BaseURL() = %v, want %v", got, want) + } + + customBaseURL := "https://custom-url/api/v3/" + c.baseURL = mustParseURL(t, customBaseURL) + + if got, want := c.BaseURL(), customBaseURL; got != want { + t.Errorf("Client.BaseURL() = %v, want %v", got, want) + } +} + +func TestClient_UploadURL(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + + if got, want := c.UploadURL(), uploadBaseURL; got != want { + t.Errorf("Client.UploadURL() = %v, want %v", got, want) + } + + customUploadURL := "https://custom-upload-url/api/uploads/" + c.uploadURL = mustParseURL(t, customUploadURL) + + if got, want := c.UploadURL(), customUploadURL; got != want { + t.Errorf("Client.UploadURL() = %v, want %v", got, want) + } +} + +func TestClient_Clone(t *testing.T) { + t.Parallel() + + t.Run("uninitialized_client", func(t *testing.T) { + t.Parallel() + + var c Client + + _, err := c.Clone() + if err == nil || !errors.Is(err, errUninitialized) { + t.Fatalf("Client.Clone returned unexpected error: %v", err) + } + }) + + t.Run("initialized_client", func(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.userAgent = "CustomUserAgent/1.0" + c.baseURL.Path = "/custom/" + c.uploadURL.Path = "/custom-upload/" + c.disableRateLimitCheck = false + c.rateLimitRedirectionalEndpoints = true + c.maxSecondaryRateLimitRetryAfterDuration = 2 * time.Minute + c.Marketplace.Stubbed = true + c.client.Transport = &http.Transport{IdleConnTimeout: 10 * time.Second} + c.client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return nil } + c.client.Timeout = 15 * time.Second + c.rateLimits[CoreCategory].Remaining = 100 + c.secondaryRateLimitReset = time.Now().Add(30 * time.Second) + + cloned, err := c.Clone() + if err != nil { + t.Fatalf("Client.Clone returned error: %v", err) + } + + if cloned.client == c.client { + t.Error("Cloned Client has same http.Client instance, but should be different") + } + if cloned.client.Transport != c.client.Transport { + t.Error("Cloned Client http.Client.Transport is not the same instance as original") + } + if cloned.client.CheckRedirect == nil || fmt.Sprintf("%p", cloned.client.CheckRedirect) != fmt.Sprintf("%p", c.client.CheckRedirect) { + t.Error("Cloned Client http.Client.CheckRedirect is not the same function instance as original") + } + if cloned.client.Jar != c.client.Jar { + t.Error("Cloned Client http.Client.Jar is not the same instance as original") + } + if cloned.client.Timeout != c.client.Timeout { + t.Errorf("Cloned Client http.Client.Timeout is %v, want %v", cloned.client.Timeout, c.client.Timeout) + } + if got, want := cloned.userAgent, c.userAgent; got != want { + t.Errorf("Cloned Client userAgent is %v, want %v", got, want) + } + if got, want := cloned.baseURL.String(), c.baseURL.String(); got != want { + t.Errorf("Cloned Client baseURL is %v, want %v", got, want) + } + if got, want := cloned.uploadURL.String(), c.uploadURL.String(); got != want { + t.Errorf("Cloned Client uploadURL is %v, want %v", got, want) + } + if cloned.disableRateLimitCheck != c.disableRateLimitCheck { + t.Errorf("Cloned Client disableRateLimitCheck is %v, want %v", cloned.disableRateLimitCheck, c.disableRateLimitCheck) + } + if cloned.rateLimitRedirectionalEndpoints != c.rateLimitRedirectionalEndpoints { + t.Errorf("Cloned Client rateLimitRedirectionalEndpoints is %v, want %v", cloned.rateLimitRedirectionalEndpoints, c.rateLimitRedirectionalEndpoints) + } + if cloned.maxSecondaryRateLimitRetryAfterDuration != c.maxSecondaryRateLimitRetryAfterDuration { + t.Errorf("Cloned Client maxSecondaryRateLimitRetryAfterDuration is %v, want %v", cloned.maxSecondaryRateLimitRetryAfterDuration, c.maxSecondaryRateLimitRetryAfterDuration) + } + if cloned.Marketplace.Stubbed != c.Marketplace.Stubbed { + t.Errorf("Cloned Client Marketplace.Stubbed is %v, want %v", cloned.Marketplace.Stubbed, c.Marketplace.Stubbed) + } + if cloned.rateLimits[CoreCategory].Remaining != c.rateLimits[CoreCategory].Remaining { + t.Errorf("Cloned Client rateLimits[CoreCategory].Remaining is %v, want %v", cloned.rateLimits[CoreCategory].Remaining, c.rateLimits[CoreCategory].Remaining) + } + if !cloned.secondaryRateLimitReset.Equal(c.secondaryRateLimitReset) { + t.Errorf("Cloned Client secondaryRateLimitReset is %v, want %v", cloned.secondaryRateLimitReset, c.secondaryRateLimitReset) + } + }) + + t.Run("initialized_client_no_rate_limit_check", func(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.userAgent = "CustomUserAgent/1.0" + c.baseURL.Path = "/custom/" + c.uploadURL.Path = "/custom-upload/" + c.disableRateLimitCheck = true + c.Marketplace.Stubbed = true + c.client.Transport = &http.Transport{IdleConnTimeout: 10 * time.Second} + c.client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return nil } + c.client.Timeout = 15 * time.Second + c.rateLimits[CoreCategory].Remaining = 100 + c.secondaryRateLimitReset = time.Now().Add(30 * time.Second) + + cloned, err := c.Clone() + if err != nil { + t.Fatalf("Client.Clone returned error: %v", err) + } + + if cloned.client == c.client { + t.Error("Cloned Client has same http.Client instance, but should be different") + } + if cloned.client.Transport != c.client.Transport { + t.Error("Cloned Client http.Client.Transport is not the same instance as original") + } + if cloned.client.CheckRedirect == nil || fmt.Sprintf("%p", cloned.client.CheckRedirect) != fmt.Sprintf("%p", c.client.CheckRedirect) { + t.Error("Cloned Client http.Client.CheckRedirect is not the same function instance as original") + } + if cloned.client.Jar != c.client.Jar { + t.Error("Cloned Client http.Client.Jar is not the same instance as original") + } + if cloned.client.Timeout != c.client.Timeout { + t.Errorf("Cloned Client http.Client.Timeout is %v, want %v", cloned.client.Timeout, c.client.Timeout) + } + if got, want := cloned.userAgent, c.userAgent; got != want { + t.Errorf("Cloned Client userAgent is %v, want %v", got, want) + } + if got, want := cloned.baseURL.String(), c.baseURL.String(); got != want { + t.Errorf("Cloned Client baseURL is %v, want %v", got, want) + } + if got, want := cloned.uploadURL.String(), c.uploadURL.String(); got != want { + t.Errorf("Cloned Client uploadURL is %v, want %v", got, want) + } + if cloned.disableRateLimitCheck != c.disableRateLimitCheck { + t.Errorf("Cloned Client disableRateLimitCheck is %v, want %v", cloned.disableRateLimitCheck, c.disableRateLimitCheck) + } + if cloned.Marketplace.Stubbed != c.Marketplace.Stubbed { + t.Errorf("Cloned Client Marketplace.Stubbed is %v, want %v", cloned.Marketplace.Stubbed, c.Marketplace.Stubbed) + } + }) + + t.Run("initialized_client_with_transport", func(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.client.Transport = &http.Transport{IdleConnTimeout: 10 * time.Second} + c.client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return nil } + c.client.Timeout = 15 * time.Second + + tr := &http.Transport{IdleConnTimeout: 30 * time.Second} + + cloned, err := c.Clone(WithTransport(tr)) + if err != nil { + t.Fatalf("Client.Clone returned error: %v", err) + } + + if cloned.client == c.client { + t.Error("Cloned Client has same http.Client instance, but should be different") + } + if cloned.client.Transport != tr { + t.Error("Cloned Client http.Client.Transport is not the same instance as original") + } + if cloned.client.CheckRedirect == nil || fmt.Sprintf("%p", cloned.client.CheckRedirect) != fmt.Sprintf("%p", c.client.CheckRedirect) { + t.Error("Cloned Client http.Client.CheckRedirect is not the same function instance as original") + } + if cloned.client.Jar != c.client.Jar { + t.Error("Cloned Client http.Client.Jar is not the same instance as original") + } + if cloned.client.Timeout != c.client.Timeout { + t.Errorf("Cloned Client http.Client.Timeout is %v, want %v", cloned.client.Timeout, c.client.Timeout) + } + }) + + t.Run("initialized_client_with_http_client", func(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.client.Transport = &http.Transport{IdleConnTimeout: 10 * time.Second} + c.client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return nil } + c.client.Timeout = 15 * time.Second + + h := &http.Client{ + Transport: &http.Transport{IdleConnTimeout: 20 * time.Second}, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse }, + Timeout: 25 * time.Second, + } + + cloned, err := c.Clone(WithHTTPClient(h)) + if err != nil { + t.Fatalf("Client.Clone returned error: %v", err) + } + + if cloned.client == h { + t.Error("Cloned Client has same http.Client instance as provided in WithHTTPClient, but should be a different instance") + } + if cloned.client.Transport != h.Transport { + t.Error("Cloned Client http.Client.Transport is not the same instance as original") + } + if cloned.client.CheckRedirect == nil || fmt.Sprintf("%p", cloned.client.CheckRedirect) != fmt.Sprintf("%p", h.CheckRedirect) { + t.Error("Cloned Client http.Client.CheckRedirect is not the same function instance as original") + } + if cloned.client.Jar != h.Jar { + t.Error("Cloned Client http.Client.Jar is not the same instance as original") + } + if cloned.client.Timeout != h.Timeout { + t.Errorf("Cloned Client http.Client.Timeout is %v, want %v", cloned.client.Timeout, h.Timeout) + } + }) + + t.Run("initialized_client_with_http_client_and_transport", func(t *testing.T) { + t.Parallel() + + c := mustNewClient(t) + c.client.Transport = &http.Transport{IdleConnTimeout: 10 * time.Second} + c.client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return nil } + c.client.Timeout = 15 * time.Second + + h := &http.Client{ + Transport: &http.Transport{IdleConnTimeout: 20 * time.Second}, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse }, + Timeout: 25 * time.Second, + } + + tr := &http.Transport{IdleConnTimeout: 30 * time.Second} + + cloned, err := c.Clone(WithHTTPClient(h), WithTransport(tr)) + if err != nil { + t.Fatalf("Client.Clone returned error: %v", err) + } + + if cloned.client == h { + t.Error("Cloned Client has same http.Client instance as provided in WithHTTPClient, but should be a different instance") + } + if cloned.client.Transport != tr { + t.Error("Cloned Client http.Client.Transport is not the same instance as original") + } + if cloned.client.CheckRedirect == nil || fmt.Sprintf("%p", cloned.client.CheckRedirect) != fmt.Sprintf("%p", h.CheckRedirect) { + t.Error("Cloned Client http.Client.CheckRedirect is not the same function instance as original") + } + if cloned.client.Jar != h.Jar { + t.Error("Cloned Client http.Client.Jar is not the same instance as original") + } + if cloned.client.Timeout != h.Timeout { + t.Errorf("Cloned Client http.Client.Timeout is %v, want %v", cloned.client.Timeout, h.Timeout) + } + }) +} + +func TestClient_Client(t *testing.T) { + t.Parallel() + c, err := NewClient() + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + + c2 := c.Client() + if c.client == c2 { + t.Error("Client returned same http.Client, but should be different") + } +} + // Ensure that length of Client.rateLimits is the same as number of fields in RateLimits struct. func TestClient_rateLimits(t *testing.T) { t.Parallel() @@ -678,7 +1311,7 @@ func TestClient_rateLimits(t *testing.T) { func TestNewRequest(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) inURL, outURL := "/foo", defaultBaseURL+"foo" inBody, outBody := &User{Login: Ptr("l")}, `{"login":"l"}`+"\n" @@ -698,7 +1331,7 @@ func TestNewRequest(t *testing.T) { userAgent := req.Header.Get("User-Agent") // test that default user-agent is attached to the request - if got, want := userAgent, c.UserAgent; got != want { + if got, want := userAgent, c.userAgent; got != want { t.Errorf("NewRequest() User-Agent is %v, want %v", got, want) } @@ -720,7 +1353,7 @@ func TestNewRequest(t *testing.T) { func TestNewRequest_invalidJSON(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) type T struct { F func() @@ -734,14 +1367,14 @@ func TestNewRequest_invalidJSON(t *testing.T) { func TestNewRequest_badURL(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) _, err := c.NewRequest(t.Context(), "GET", ":", nil) testURLParseError(t, err) } func TestNewRequest_badMethod(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) if _, err := c.NewRequest(t.Context(), "BOGUS\nMETHOD", ".", nil); err == nil { t.Fatal("NewRequest returned nil; expected error") } @@ -751,8 +1384,8 @@ func TestNewRequest_badMethod(t *testing.T) { // This caused a problem with Google's internal http client. func TestNewRequest_emptyUserAgent(t *testing.T) { t.Parallel() - c := NewClient(nil) - c.UserAgent = "" + c := mustNewClient(t) + c.userAgent = "" req, err := c.NewRequest(t.Context(), "GET", ".", nil) if err != nil { t.Fatalf("NewRequest returned unexpected error: %v", err) @@ -770,7 +1403,7 @@ func TestNewRequest_emptyUserAgent(t *testing.T) { // subtle errors. func TestNewRequest_emptyBody(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) req, err := c.NewRequest(t.Context(), "GET", ".", nil) if err != nil { t.Fatalf("NewRequest returned unexpected error: %v", err) @@ -789,13 +1422,13 @@ func TestNewRequest_errorForNoTrailingSlash(t *testing.T) { {rawurl: "https://example.com/api/v3", wantError: true}, {rawurl: "https://example.com/api/v3/", wantError: false}, } - c := NewClient(nil) + c := mustNewClient(t) for _, test := range tests { u, err := url.Parse(test.rawurl) if err != nil { t.Fatalf("url.Parse returned unexpected error: %v.", err) } - c.BaseURL = u + c.baseURL = u if _, err := c.NewRequest(t.Context(), "GET", "test", nil); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { @@ -840,7 +1473,7 @@ func TestCheckURLPathTraversal(t *testing.T) { func TestNewRequest_pathTraversal(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) tests := []struct { urlStr string @@ -863,7 +1496,7 @@ func TestNewRequest_pathTraversal(t *testing.T) { func TestNewFormRequest_pathTraversal(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) _, err := c.NewFormRequest(t.Context(), "repos/x/../../../admin", nil) if !errors.Is(err, ErrPathForbidden) { @@ -873,7 +1506,7 @@ func TestNewFormRequest_pathTraversal(t *testing.T) { func TestNewUploadRequest_pathTraversal(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) _, err := c.NewUploadRequest(t.Context(), "repos/x/../../../admin", nil, 0, "") if !errors.Is(err, ErrPathForbidden) { @@ -883,7 +1516,7 @@ func TestNewUploadRequest_pathTraversal(t *testing.T) { func TestNewFormRequest(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) inURL, outURL := "/foo", defaultBaseURL+"foo" form := url.Values{} @@ -909,7 +1542,7 @@ func TestNewFormRequest(t *testing.T) { } // test that default user-agent is attached to the request - if got, want := req.Header.Get("User-Agent"), c.UserAgent; got != want { + if got, want := req.Header.Get("User-Agent"), c.userAgent; got != want { t.Errorf("NewFormRequest() User-Agent is %v, want %v", got, want) } @@ -930,15 +1563,15 @@ func TestNewFormRequest(t *testing.T) { func TestNewFormRequest_badURL(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) _, err := c.NewFormRequest(t.Context(), ":", nil) testURLParseError(t, err) } func TestNewFormRequest_emptyUserAgent(t *testing.T) { t.Parallel() - c := NewClient(nil) - c.UserAgent = "" + c := mustNewClient(t) + c.userAgent = "" req, err := c.NewFormRequest(t.Context(), ".", nil) if err != nil { t.Fatalf("NewFormRequest returned unexpected error: %v", err) @@ -950,7 +1583,7 @@ func TestNewFormRequest_emptyUserAgent(t *testing.T) { func TestNewFormRequest_emptyBody(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) req, err := c.NewFormRequest(t.Context(), ".", nil) if err != nil { t.Fatalf("NewFormRequest returned unexpected error: %v", err) @@ -969,13 +1602,13 @@ func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { {rawURL: "https://example.com/api/v3", wantError: true}, {rawURL: "https://example.com/api/v3/", wantError: false}, } - c := NewClient(nil) + c := mustNewClient(t) for _, test := range tests { u, err := url.Parse(test.rawURL) if err != nil { t.Fatalf("url.Parse returned unexpected error: %v.", err) } - c.BaseURL = u + c.baseURL = u if _, err := c.NewFormRequest(t.Context(), "test", nil); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { @@ -986,7 +1619,7 @@ func TestNewFormRequest_errorForNoTrailingSlash(t *testing.T) { func TestNewUploadRequest_WithVersion(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) req, _ := c.NewUploadRequest(t.Context(), "https://example.com/", nil, 0, "") apiVersion := req.Header.Get(headerAPIVersion) @@ -1003,7 +1636,7 @@ func TestNewUploadRequest_WithVersion(t *testing.T) { func TestNewUploadRequest_badURL(t *testing.T) { t.Parallel() - c := NewClient(nil) + c := mustNewClient(t) _, err := c.NewUploadRequest(t.Context(), ":", nil, 0, "") testURLParseError(t, err) @@ -1023,13 +1656,13 @@ func TestNewUploadRequest_errorForNoTrailingSlash(t *testing.T) { {rawurl: "https://example.com/api/uploads", wantError: true}, {rawurl: "https://example.com/api/uploads/", wantError: false}, } - c := NewClient(nil) + c := mustNewClient(t) for _, test := range tests { u, err := url.Parse(test.rawurl) if err != nil { t.Fatalf("url.Parse returned unexpected error: %v.", err) } - c.UploadURL = u + c.uploadURL = u if _, err = c.NewUploadRequest(t.Context(), "test", nil, 0, ""); test.wantError && err == nil { t.Fatal("Expected error to be returned.") } else if !test.wantError && err != nil { @@ -1438,8 +2071,8 @@ func TestDo_sanitizeURL(t *testing.T) { ClientID: "id", ClientSecret: "secret", } - unauthedClient := NewClient(tp.Client()) - unauthedClient.BaseURL = &url.URL{Scheme: "http", Host: "127.0.0.1:0", Path: "/"} // Use port 0 on purpose to trigger a dial TCP error, expect to get "dial tcp 127.0.0.1:0: connect: can't assign requested address". + unauthedClient := mustNewClient(t, WithHTTPClient(tp.Client())) + unauthedClient.baseURL = &url.URL{Scheme: "http", Host: "127.0.0.1:0", Path: "/"} // Use port 0 on purpose to trigger a dial TCP error, expect to get "dial tcp 127.0.0.1:0: connect: can't assign requested address". req, err := unauthedClient.NewRequest(t.Context(), "GET", ".", nil) if err != nil { t.Fatalf("NewRequest returned unexpected error: %v", err) @@ -2129,7 +2762,7 @@ func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { t.Parallel() client, mux, _ := setup(t) // specify a max retry after duration of 1 min - client.MaxSecondaryRateLimitRetryAfterDuration = 60 * time.Second + client.maxSecondaryRateLimitRetryAfterDuration = 60 * time.Second // x-ratelimit-reset value of 1h into the future, to make sure we are way over the max wait time duration. blockUntil := time.Now().Add(1 * time.Hour).Unix() @@ -2159,7 +2792,7 @@ func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { t.Fatal("abuseRateLimitErr RetryAfter is nil, expected not-nil") } // check that the retry after is set to be the max allowed duration - if got, want := *abuseRateLimitErr.RetryAfter, client.MaxSecondaryRateLimitRetryAfterDuration; got != want { + if got, want := *abuseRateLimitErr.RetryAfter, client.maxSecondaryRateLimitRetryAfterDuration; got != want { t.Errorf("abuseRateLimitErr RetryAfter = %v, want %v", got, want) } } @@ -2168,7 +2801,7 @@ func TestDo_rateLimit_abuseRateLimitError_maxDuration(t *testing.T) { func TestDo_rateLimit_disableRateLimitCheck(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.DisableRateLimitCheck = true + client.disableRateLimitCheck = true reset := time.Now().UTC().Add(60 * time.Second) client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}} @@ -3155,8 +3788,8 @@ func TestUnauthenticatedRateLimitedTransport(t *testing.T) { ClientID: clientID, ClientSecret: clientSecret, } - unauthedClient := NewClient(tp.Client()) - unauthedClient.BaseURL = client.BaseURL + unauthedClient := mustNewClient(t, WithHTTPClient(tp.Client())) + unauthedClient.baseURL = client.baseURL req, _ := unauthedClient.NewRequest(t.Context(), "GET", ".", nil) _, err := unauthedClient.Do(req, nil) assertNilError(t, err) @@ -3232,8 +3865,8 @@ func TestBasicAuthTransport(t *testing.T) { Password: password, OTP: otp, } - basicAuthClient := NewClient(tp.Client()) - basicAuthClient.BaseURL = client.BaseURL + basicAuthClient := mustNewClient(t, WithHTTPClient(tp.Client())) + basicAuthClient.baseURL = client.baseURL req, _ := basicAuthClient.NewRequest(t.Context(), "GET", ".", nil) _, err := basicAuthClient.Do(req, nil) assertNilError(t, err) @@ -3632,13 +4265,16 @@ func TestClientCopy_leak_transport(t *testing.T) { accessToken := r.Header.Get("Authorization") _, _ = fmt.Fprintf(w, `{"login": "%v"}`, accessToken) })) - clientPreconfiguredWithURLs, err := NewClient(nil).WithEnterpriseURLs(srv.URL, srv.URL) + clientPreconfiguredWithURLs := mustNewClient(t, WithEnterpriseURLs(srv.URL, srv.URL)) + + aliceClient, err := clientPreconfiguredWithURLs.Clone(WithAuthToken("alice")) + if err != nil { + t.Fatal(err) + } + bobClient, err := clientPreconfiguredWithURLs.Clone(WithAuthToken("bob")) if err != nil { t.Fatal(err) } - - aliceClient := clientPreconfiguredWithURLs.WithAuthToken("alice") - bobClient := clientPreconfiguredWithURLs.WithAuthToken("bob") alice, _, err := aliceClient.Users.Get(t.Context(), "") if err != nil { diff --git a/github/rate_limit.go b/github/rate_limit.go index 47aec9d8216..a63f8382d00 100644 --- a/github/rate_limit.go +++ b/github/rate_limit.go @@ -77,7 +77,7 @@ func (r RateLimits) String() string { //meta:operation GET /rate_limit func (s *RateLimitService) Get(ctx context.Context) (*RateLimits, *Response, error) { // This resource is not subject to rate limits. - if !s.client.DisableRateLimitCheck { + if !s.client.disableRateLimitCheck { ctx = context.WithValue(ctx, BypassRateLimitCheck, true) } diff --git a/github/rate_limit_test.go b/github/rate_limit_test.go index 3eb5a817999..e02c1e1284a 100644 --- a/github/rate_limit_test.go +++ b/github/rate_limit_test.go @@ -409,7 +409,7 @@ func TestRateLimits_bypassRateLimitCheckContext(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.DisableRateLimitCheck = tt.disableRateLimitCheck + client.disableRateLimitCheck = tt.disableRateLimitCheck mux.HandleFunc("/rate_limit", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") diff --git a/github/repos_contents.go b/github/repos_contents.go index 9932deb873e..024496fe596 100644 --- a/github/repos_contents.go +++ b/github/repos_contents.go @@ -340,7 +340,7 @@ func (s *RepositoriesService) GetArchiveLink(ctx context.Context, owner, repo st u += fmt.Sprintf("/%v", opts.Ref) } - if s.client.RateLimitRedirectionalEndpoints { + if s.client.rateLimitRedirectionalEndpoints { return s.getArchiveLinkWithRateLimit(ctx, u, maxRedirects) } diff --git a/github/repos_contents_test.go b/github/repos_contents_test.go index dcaf602411e..99c96d0f22a 100644 --- a/github/repos_contents_test.go +++ b/github/repos_contents_test.go @@ -972,7 +972,7 @@ func TestRepositoriesService_GetArchiveLink(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/tarball/yo", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -1029,7 +1029,7 @@ func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_dontFollowRed t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, _ := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") @@ -1064,7 +1064,7 @@ func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_followRedirec t.Run(tc.name, func(t *testing.T) { t.Parallel() client, mux, serverURL := setup(t) - client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + client.rateLimitRedirectionalEndpoints = tc.respectRateLimits // Mock a redirect link, which leads to an archive link mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { diff --git a/github/repos_releases_test.go b/github/repos_releases_test.go index 552ec760e04..0a94aaa090d 100644 --- a/github/repos_releases_test.go +++ b/github/repos_releases_test.go @@ -983,7 +983,7 @@ func TestRepositoriesService_UploadReleaseAssetFromRelease_AbsoluteTemplate(t *t size := int64(len(body)) // Build an absolute URL using the test client's BaseURL. - absoluteUploadURL := client.BaseURL.String() + "repos/o/r/releases/1/assets{?name,label}" + absoluteUploadURL := client.baseURL.String() + "repos/o/r/releases/1/assets{?name,label}" release := &RepositoryRelease{UploadURL: &absoluteUploadURL} opts := &UploadOptions{Name: "abs.txt"} diff --git a/test/fields/fields.go b/test/fields/fields.go index 501d8dbaec0..85c07f5676a 100644 --- a/test/fields/fields.go +++ b/test/fields/fields.go @@ -40,9 +40,19 @@ func main() { token := os.Getenv("GITHUB_AUTH_TOKEN") if token == "" { fmt.Print("!!! No OAuth token. Some tests won't run. !!!\n\n") - client = github.NewClient(nil) + c, err := github.NewClient() + if err != nil { + fmt.Printf("Error creating GitHub client: %v\n", err) + os.Exit(1) + } + client = c } else { - client = github.NewClient(nil).WithAuthToken(token) + c, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + fmt.Printf("Error creating GitHub client with token: %v\n", err) + os.Exit(1) + } + client = c } for _, tt := range []struct { diff --git a/test/integration/authorizations_test.go b/test/integration/authorizations_test.go index 307091c3d6c..96ad60c8599 100644 --- a/test/integration/authorizations_test.go +++ b/test/integration/authorizations_test.go @@ -125,6 +125,8 @@ func failIfNotStatusCode(t *testing.T, resp *github.Response, expectedCode int) // // See GitHub API docs: https://developer.github.com/v3/oauth_authorizations/#check-an-authorization func getOAuthAppClient(t *testing.T) *github.Client { + t.Helper() + username, ok := os.LookupEnv(envKeyClientID) if !ok { t.Skipf(msgEnvMissing, envKeyClientID) @@ -140,5 +142,9 @@ func getOAuthAppClient(t *testing.T) *github.Client { Password: strings.TrimSpace(password), } - return github.NewClient(tp.Client()) + c, err := github.NewClient(github.WithHTTPClient(tp.Client())) + if err != nil { + t.Fatal(err) + } + return c } diff --git a/test/integration/github_test.go b/test/integration/github_test.go index 787c8f145fe..8d25d255af5 100644 --- a/test/integration/github_test.go +++ b/test/integration/github_test.go @@ -24,9 +24,20 @@ import ( var client, auth = sync.OnceValues(func() (*github.Client, bool) { token := os.Getenv("GITHUB_AUTH_TOKEN") if token == "" { - return github.NewClient(nil), false + c, err := github.NewClient() + if err != nil { + fmt.Printf("Error creating GitHub client: %v\n", err) + os.Exit(1) + } + return c, false + } + + c, err := github.NewClient(github.WithAuthToken(token)) + if err != nil { + fmt.Printf("Error creating GitHub client with token: %v\n", err) + os.Exit(1) } - return github.NewClient(nil).WithAuthToken(token), true + return c, true })() func skipIfMissingAuth(t *testing.T) { diff --git a/tools/metadata/main.go b/tools/metadata/main.go index cb4c9d76894..8c70152c720 100644 --- a/tools/metadata/main.go +++ b/tools/metadata/main.go @@ -57,6 +57,7 @@ type rootCmd struct { // for testing GithubURL string `kong:"hidden,default='https://api.github.com'"` + UploadURL string `kong:"hidden,default='https://uploads.github.com'"` } func (c *rootCmd) opsFile() (string, *operationsFile, error) { @@ -68,12 +69,12 @@ func (c *rootCmd) opsFile() (string, *operationsFile, error) { return filename, opsFile, nil } -func githubClient(apiURL string) (*github.Client, error) { +func githubClient(apiURL, uploadURL string) (*github.Client, error) { token := os.Getenv("GITHUB_TOKEN") if token == "" { return nil, errors.New("GITHUB_TOKEN environment variable must be set to a GitHub personal access token with the public_repo scope") } - return github.NewClient(nil).WithAuthToken(token).WithEnterpriseURLs(apiURL, "") + return github.NewClient(github.WithAuthToken(token), github.WithEnterpriseURLs(apiURL, uploadURL)) } type updateOpenAPICmd struct { @@ -95,7 +96,7 @@ func (c *updateOpenAPICmd) Run(root *rootCmd) error { for i := range origOps { origOps[i] = origOps[i].clone() } - client, err := githubClient(root.GithubURL) + client, err := githubClient(root.GithubURL, root.UploadURL) if err != nil { return err }