diff --git a/github/repos_rules.go b/github/repos_rules.go index 14a90b00596..ddf8929e23b 100644 --- a/github/repos_rules.go +++ b/github/repos_rules.go @@ -8,14 +8,15 @@ package github import ( "context" "fmt" + "iter" ) -// GetRulesForBranch gets all the repository rules that apply to the specified branch. +// ListRulesForBranch gets all the repository rules that apply to the specified branch. // // GitHub API docs: https://docs.github.com/rest/repos/rules?apiVersion=2022-11-28#get-rules-for-a-branch // //meta:operation GET /repos/{owner}/{repo}/rules/branches/{branch} -func (s *RepositoriesService) GetRulesForBranch(ctx context.Context, owner, repo, branch string, opts *ListOptions) (*BranchRules, *Response, error) { +func (s *RepositoriesService) ListRulesForBranch(ctx context.Context, owner, repo, branch string, opts *ListOptions) (*BranchRules, *Response, error) { u := fmt.Sprintf("repos/%v/%v/rules/branches/%v", owner, repo, branch) u, err := addOptions(u, opts) @@ -37,6 +38,147 @@ func (s *RepositoriesService) GetRulesForBranch(ctx context.Context, owner, repo return rules, resp, nil } +// ListRulesForBranchIter returns an iterator that paginates through all results of ListRulesForBranch. +// +// Note that since [BranchRules] contains a large number of slices, this iterator +// returns type `any` and it is therefore the responsibility of the caller to perform a +// type switch to determine what item is being returned for each iteration. +func (s *RepositoriesService) ListRulesForBranchIter(ctx context.Context, owner, repo, branch string, opts *ListOptions) iter.Seq2[any, error] { + return func(yield func(any, error) bool) { + // Create a copy of opts to avoid mutating the caller's struct + if opts == nil { + opts = &ListOptions{} + } else { + opts = Ptr(*opts) + } + + for { + results, resp, err := s.ListRulesForBranch(ctx, owner, repo, branch, opts) + if err != nil { + yield(nil, err) + return + } + + // Now iterate through ALL possible results from [BranchRules]. + for _, item := range results.Creation { + if !yield(item, nil) { + return + } + } + for _, item := range results.Update { + if !yield(item, nil) { + return + } + } + for _, item := range results.Deletion { + if !yield(item, nil) { + return + } + } + for _, item := range results.RequiredLinearHistory { + if !yield(item, nil) { + return + } + } + for _, item := range results.MergeQueue { + if !yield(item, nil) { + return + } + } + for _, item := range results.RequiredDeployments { + if !yield(item, nil) { + return + } + } + for _, item := range results.RequiredSignatures { + if !yield(item, nil) { + return + } + } + for _, item := range results.PullRequest { + if !yield(item, nil) { + return + } + } + for _, item := range results.RequiredStatusChecks { + if !yield(item, nil) { + return + } + } + for _, item := range results.NonFastForward { + if !yield(item, nil) { + return + } + } + for _, item := range results.CommitMessagePattern { + if !yield(item, nil) { + return + } + } + for _, item := range results.CommitAuthorEmailPattern { + if !yield(item, nil) { + return + } + } + for _, item := range results.CommitterEmailPattern { + if !yield(item, nil) { + return + } + } + for _, item := range results.BranchNamePattern { + if !yield(item, nil) { + return + } + } + for _, item := range results.TagNamePattern { + if !yield(item, nil) { + return + } + } + for _, item := range results.Workflows { + if !yield(item, nil) { + return + } + } + for _, item := range results.CodeScanning { + if !yield(item, nil) { + return + } + } + for _, item := range results.CopilotCodeReview { + if !yield(item, nil) { + return + } + } + for _, item := range results.FileExtensionRestriction { + if !yield(item, nil) { + return + } + } + for _, item := range results.FilePathRestriction { + if !yield(item, nil) { + return + } + } + for _, item := range results.MaxFilePathLength { + if !yield(item, nil) { + return + } + } + for _, item := range results.MaxFileSize { + if !yield(item, nil) { + return + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + } +} + // RepositoryListRulesetsOptions specifies optional parameters to the // RepositoriesService.GetAllRulesets method. type RepositoryListRulesetsOptions struct { diff --git a/github/repos_rules_test.go b/github/repos_rules_test.go index 4b20e672519..ff93fbd4887 100644 --- a/github/repos_rules_test.go +++ b/github/repos_rules_test.go @@ -14,7 +14,7 @@ import ( "github.com/google/go-cmp/cmp" ) -func TestRepositoriesService_GetRulesForBranch(t *testing.T) { +func TestRepositoriesService_ListRulesForBranch(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -40,9 +40,9 @@ func TestRepositoriesService_GetRulesForBranch(t *testing.T) { }) ctx := t.Context() - rules, _, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", nil) + rules, _, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", nil) if err != nil { - t.Errorf("Repositories.GetRulesForBranch returned error: %v", err) + t.Errorf("Repositories.ListRulesForBranch returned error: %v", err) } want := &BranchRules{ @@ -51,12 +51,12 @@ func TestRepositoriesService_GetRulesForBranch(t *testing.T) { } if !cmp.Equal(rules, want) { - t.Errorf("Repositories.GetRulesForBranch returned %+v, want %+v", rules, want) + t.Errorf("Repositories.ListRulesForBranch returned %+v, want %+v", rules, want) } - const methodName = "GetRulesForBranch" + const methodName = "ListRulesForBranch" testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) { - got, resp, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", nil) + got, resp, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", nil) if got != nil { t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got) } @@ -64,6 +64,76 @@ func TestRepositoriesService_GetRulesForBranch(t *testing.T) { }) } +func TestRepositoriesService_ListRulesForBranchIter(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + var callNum int + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + callNum++ + switch callNum { + case 1: + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"},{"type":"update"}]`) + case 2: + fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"},{"type":"update"},{"type":"workflows"}]`) + case 3, 5: + fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"}]`) + case 4: + w.WriteHeader(http.StatusNotFound) + } + }) + + iter := client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil) + var gotItems int + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 7; gotItems != want { + t.Errorf("client.Repositories.ListRulesForBranchIter call 1 got %v items; want %v", gotItems, want) + } + + opts := &ListOptions{} + iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", opts) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 2; gotItems != want { + t.Errorf("client.Repositories.ListRulesForBranchIter call 2 got %v items; want %v", gotItems, want) + } + + iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err == nil { + t.Error("expected error; got nil") + } + } + if gotItems != 1 { + t.Errorf("client.Repositories.ListRulesForBranchIter call 3 got %v items; want 1 (an error)", gotItems) + } + + iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil) + gotItems = 0 + iter(func(_ any, err error) bool { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + return false + }) + if gotItems != 1 { + t.Errorf("client.Repositories.ListRulesForBranchIter call 4 got %v items; want 1 (an error)", gotItems) + } +} + func TestRepositoriesService_UpdateRuleset_OmitZero_Nil(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -132,7 +202,7 @@ func TestRepositoriesService_UpdateRuleset_OmitZero_EmptySlice(t *testing.T) { } } -func TestRepositoriesService_GetRulesForBranch_ListOptions(t *testing.T) { +func TestRepositoriesService_ListRulesForBranch_ListOptions(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -152,9 +222,9 @@ func TestRepositoriesService_GetRulesForBranch_ListOptions(t *testing.T) { opts := &ListOptions{Page: 2, PerPage: 35} ctx := t.Context() - rules, _, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", opts) + rules, _, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", opts) if err != nil { - t.Errorf("Repositories.GetRulesForBranch returned error: %v", err) + t.Errorf("Repositories.ListRulesForBranch returned error: %v", err) } want := &BranchRules{ @@ -162,17 +232,17 @@ func TestRepositoriesService_GetRulesForBranch_ListOptions(t *testing.T) { } if !cmp.Equal(rules, want) { - t.Errorf("Repositories.GetRulesForBranch returned %+v, want %+v", rules, want) + t.Errorf("Repositories.ListRulesForBranch returned %+v, want %+v", rules, want) } - const methodName = "GetRulesForBranch" + const methodName = "ListRulesForBranch" testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Repositories.GetRulesForBranch(ctx, "\n", "\n", "\n", opts) + _, _, err = client.Repositories.ListRulesForBranch(ctx, "\n", "\n", "\n", opts) return err }) testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) { - got, resp, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", opts) + got, resp, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", opts) if got != nil { t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got) }