Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 144 additions & 2 deletions github/repos_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ package github
import (
"context"
"fmt"
"iter"
)

// GetRulesForBranch gets all the repository rules that apply to the specified branch.
// ListRulesForBranch gets all the repository rules that apply to the specified branch.
//
// GitHub API docs: https://docs.github.com/rest/repos/rules?apiVersion=2022-11-28#get-rules-for-a-branch
//
//meta:operation GET /repos/{owner}/{repo}/rules/branches/{branch}
func (s *RepositoriesService) GetRulesForBranch(ctx context.Context, owner, repo, branch string, opts *ListOptions) (*BranchRules, *Response, error) {
func (s *RepositoriesService) ListRulesForBranch(ctx context.Context, owner, repo, branch string, opts *ListOptions) (*BranchRules, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/rules/branches/%v", owner, repo, branch)

u, err := addOptions(u, opts)
Expand All @@ -37,6 +38,147 @@ func (s *RepositoriesService) GetRulesForBranch(ctx context.Context, owner, repo
return rules, resp, nil
}

// ListRulesForBranchIter returns an iterator that paginates through all results of ListRulesForBranch.
//
// Note that since [BranchRules] contains a large number of slices, this iterator
// returns type `any` and it is therefore the responsibility of the caller to perform a
// type switch to determine what item is being returned for each iteration.
func (s *RepositoriesService) ListRulesForBranchIter(ctx context.Context, owner, repo, branch string, opts *ListOptions) iter.Seq2[any, error] {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see why this iterator isn't being generated, but do you need to add some configuration to block the generator or will it see that this already exists?

I'd also like a go refactoring it once it's merged as I think we could make the API a bit more specific and remove any.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see why this iterator isn't being generated, but do you need to add some configuration to block the generator or will it see that this already exists?

No, the generator already skips this one due to the BranchRules struct having many slices and one is not called out as "the" slice to use.

I'd also like a go refactoring it once it's merged as I think we could make the API a bit more specific and remove any.

OK, sure.

return func(yield func(any, error) bool) {
// Create a copy of opts to avoid mutating the caller's struct
if opts == nil {
opts = &ListOptions{}
} else {
opts = Ptr(*opts)
}

for {
results, resp, err := s.ListRulesForBranch(ctx, owner, repo, branch, opts)
if err != nil {
yield(nil, err)
return
}

// Now iterate through ALL possible results from [BranchRules].
for _, item := range results.Creation {
if !yield(item, nil) {
return
}
}
for _, item := range results.Update {
if !yield(item, nil) {
return
}
}
for _, item := range results.Deletion {
if !yield(item, nil) {
return
}
}
for _, item := range results.RequiredLinearHistory {
if !yield(item, nil) {
return
}
}
for _, item := range results.MergeQueue {
if !yield(item, nil) {
return
}
}
for _, item := range results.RequiredDeployments {
if !yield(item, nil) {
return
}
}
for _, item := range results.RequiredSignatures {
if !yield(item, nil) {
return
}
}
for _, item := range results.PullRequest {
if !yield(item, nil) {
return
}
}
for _, item := range results.RequiredStatusChecks {
if !yield(item, nil) {
return
}
}
for _, item := range results.NonFastForward {
if !yield(item, nil) {
return
}
}
for _, item := range results.CommitMessagePattern {
if !yield(item, nil) {
return
}
}
for _, item := range results.CommitAuthorEmailPattern {
if !yield(item, nil) {
return
}
}
for _, item := range results.CommitterEmailPattern {
if !yield(item, nil) {
return
}
}
for _, item := range results.BranchNamePattern {
if !yield(item, nil) {
return
}
}
for _, item := range results.TagNamePattern {
if !yield(item, nil) {
return
}
}
for _, item := range results.Workflows {
if !yield(item, nil) {
return
}
}
for _, item := range results.CodeScanning {
if !yield(item, nil) {
return
}
}
for _, item := range results.CopilotCodeReview {
if !yield(item, nil) {
return
}
}
for _, item := range results.FileExtensionRestriction {
if !yield(item, nil) {
return
}
}
for _, item := range results.FilePathRestriction {
if !yield(item, nil) {
return
}
}
for _, item := range results.MaxFilePathLength {
if !yield(item, nil) {
return
}
}
for _, item := range results.MaxFileSize {
if !yield(item, nil) {
return
}
}

if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}
}
}

// RepositoryListRulesetsOptions specifies optional parameters to the
// RepositoriesService.GetAllRulesets method.
type RepositoryListRulesetsOptions struct {
Expand Down
96 changes: 83 additions & 13 deletions github/repos_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/google/go-cmp/cmp"
)

func TestRepositoriesService_GetRulesForBranch(t *testing.T) {
func TestRepositoriesService_ListRulesForBranch(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

Expand All @@ -40,9 +40,9 @@ func TestRepositoriesService_GetRulesForBranch(t *testing.T) {
})

ctx := t.Context()
rules, _, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", nil)
rules, _, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", nil)
if err != nil {
t.Errorf("Repositories.GetRulesForBranch returned error: %v", err)
t.Errorf("Repositories.ListRulesForBranch returned error: %v", err)
}

want := &BranchRules{
Expand All @@ -51,19 +51,89 @@ func TestRepositoriesService_GetRulesForBranch(t *testing.T) {
}

if !cmp.Equal(rules, want) {
t.Errorf("Repositories.GetRulesForBranch returned %+v, want %+v", rules, want)
t.Errorf("Repositories.ListRulesForBranch returned %+v, want %+v", rules, want)
}

const methodName = "GetRulesForBranch"
const methodName = "ListRulesForBranch"
testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) {
got, resp, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", nil)
got, resp, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", nil)
if got != nil {
t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got)
}
return resp, err
})
}

func TestRepositoriesService_ListRulesForBranchIter(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)
var callNum int
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
callNum++
switch callNum {
case 1:
w.Header().Set("Link", `<https://api.github.com/?page=1>; rel="next"`)
fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"},{"type":"update"}]`)
case 2:
fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"},{"type":"update"},{"type":"workflows"}]`)
case 3, 5:
fmt.Fprint(w, `[{"type":"creation"},{"type":"deletion"}]`)
case 4:
w.WriteHeader(http.StatusNotFound)
}
})

iter := client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil)
var gotItems int
for _, err := range iter {
gotItems++
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
if want := 7; gotItems != want {
t.Errorf("client.Repositories.ListRulesForBranchIter call 1 got %v items; want %v", gotItems, want)
}

opts := &ListOptions{}
iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", opts)
gotItems = 0
for _, err := range iter {
gotItems++
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
if want := 2; gotItems != want {
t.Errorf("client.Repositories.ListRulesForBranchIter call 2 got %v items; want %v", gotItems, want)
}

iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil)
gotItems = 0
for _, err := range iter {
gotItems++
if err == nil {
t.Error("expected error; got nil")
}
}
if gotItems != 1 {
t.Errorf("client.Repositories.ListRulesForBranchIter call 3 got %v items; want 1 (an error)", gotItems)
}

iter = client.Repositories.ListRulesForBranchIter(t.Context(), "o", "r", "b", nil)
gotItems = 0
iter(func(_ any, err error) bool {
gotItems++
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
return false
})
if gotItems != 1 {
t.Errorf("client.Repositories.ListRulesForBranchIter call 4 got %v items; want 1 (an error)", gotItems)
}
}

func TestRepositoriesService_UpdateRuleset_OmitZero_Nil(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)
Expand Down Expand Up @@ -132,7 +202,7 @@ func TestRepositoriesService_UpdateRuleset_OmitZero_EmptySlice(t *testing.T) {
}
}

func TestRepositoriesService_GetRulesForBranch_ListOptions(t *testing.T) {
func TestRepositoriesService_ListRulesForBranch_ListOptions(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

Expand All @@ -152,27 +222,27 @@ func TestRepositoriesService_GetRulesForBranch_ListOptions(t *testing.T) {

opts := &ListOptions{Page: 2, PerPage: 35}
ctx := t.Context()
rules, _, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", opts)
rules, _, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", opts)
if err != nil {
t.Errorf("Repositories.GetRulesForBranch returned error: %v", err)
t.Errorf("Repositories.ListRulesForBranch returned error: %v", err)
}

want := &BranchRules{
Creation: []*BranchRuleMetadata{{RulesetID: 42069}},
}

if !cmp.Equal(rules, want) {
t.Errorf("Repositories.GetRulesForBranch returned %+v, want %+v", rules, want)
t.Errorf("Repositories.ListRulesForBranch returned %+v, want %+v", rules, want)
}

const methodName = "GetRulesForBranch"
const methodName = "ListRulesForBranch"
testBadOptions(t, methodName, func() (err error) {
_, _, err = client.Repositories.GetRulesForBranch(ctx, "\n", "\n", "\n", opts)
_, _, err = client.Repositories.ListRulesForBranch(ctx, "\n", "\n", "\n", opts)
return err
})

testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) {
got, resp, err := client.Repositories.GetRulesForBranch(ctx, "o", "repo", "branch", opts)
got, resp, err := client.Repositories.ListRulesForBranch(ctx, "o", "repo", "branch", opts)
if got != nil {
t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got)
}
Expand Down
Loading