Skip to content

Commit dfb0f24

Browse files
authored
Sanitize column names after applying column name transformation (#38)
* fix: sanitize column names in pagination and add test for column function * fix: update pagination defaults and enhance column sorting functionality * fix: correct SQL query string formatting in pagination test
1 parent 6890ffb commit dfb0f24

2 files changed

Lines changed: 136 additions & 59 deletions

File tree

page.go

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,11 @@ type Sort struct {
2929
}
3030

3131
func (s Sort) String() string {
32-
if s.Column == "" {
33-
return ""
34-
}
35-
s.Order = sanitizeOrder(s.Order)
3632
return fmt.Sprintf("%s %s", s.Column, s.Order)
3733
}
3834

3935
var _MatcherOrderBy = regexp.MustCompile(`-?([a-zA-Z0-9]+)`)
4036

41-
func sanitizeOrder(order Order) Order {
42-
switch strings.ToUpper(strings.TrimSpace(string(order))) {
43-
case string(Desc):
44-
return Desc
45-
case string(Asc):
46-
return Asc
47-
default:
48-
return Asc
49-
}
50-
}
51-
5237
func NewSort(s string) (Sort, bool) {
5338
s = strings.TrimSpace(s)
5439
if s == "" || !_MatcherOrderBy.MatchString(s) {
@@ -74,12 +59,6 @@ type Page struct {
7459
}
7560

7661
func NewPage(size, page uint32, sort ...Sort) *Page {
77-
if size == 0 {
78-
size = DefaultPageSize
79-
}
80-
if page == 0 {
81-
page = 1
82-
}
8362
return &Page{
8463
Size: size,
8564
Page: page,
@@ -105,40 +84,48 @@ func (p *Page) SetDefaults(o *PaginatorSettings) {
10584
}
10685
}
10786

108-
func (p *Page) GetOrder(defaultSort ...string) []Sort {
109-
// if page has sort, use it
87+
func (p *Page) GetOrder(columnFunc func(string) string, defaultSort ...string) []Sort {
88+
var sorts []Sort
11089
if p != nil && len(p.Sort) != 0 {
111-
for i, s := range p.Sort {
112-
s.Column = strings.TrimSpace(s.Column)
113-
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
114-
s.Order = sanitizeOrder(s.Order)
115-
p.Sort[i] = s
90+
// use sort
91+
sorts = p.Sort
92+
}
93+
// fall back to column
94+
if len(sorts) == 0 {
95+
if p != nil && p.Column != "" {
96+
for part := range strings.SplitSeq(p.Column, ",") {
97+
if s, ok := NewSort(part); ok {
98+
sorts = append(sorts, s)
99+
}
100+
}
116101
}
117-
return p.Sort
118102
}
119-
// if page has column, use default sort
120-
if p == nil || p.Column == "" {
121-
sort := make([]Sort, 0, len(defaultSort))
103+
if len(sorts) == 0 {
122104
for _, s := range defaultSort {
123105
if s, ok := NewSort(s); ok {
124-
sort = append(sort, s)
106+
sorts = append(sorts, s)
125107
}
126108
}
127-
return sort
128109
}
129-
// use column
130-
sort := make([]Sort, 0)
131-
for part := range strings.SplitSeq(p.Column, ",") {
132-
part = strings.TrimSpace(part)
133-
if part == "" {
134-
continue
110+
111+
for i := range sorts {
112+
s := &sorts[i]
113+
s.Column = strings.TrimSpace(s.Column)
114+
if columnFunc != nil {
115+
s.Column = columnFunc(s.Column)
135116
}
136-
if s, ok := NewSort(part); ok {
137-
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
138-
sort = append(sort, s)
117+
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
118+
119+
switch strings.ToUpper(strings.TrimSpace(string(s.Order))) {
120+
case string(Desc):
121+
s.Order = Desc
122+
case string(Asc):
123+
s.Order = Asc
124+
default:
125+
s.Order = Asc
139126
}
140127
}
141-
return sort
128+
return sorts
142129
}
143130

144131
func (p *Page) Offset() uint64 {
@@ -229,13 +216,10 @@ type Paginator[T any] struct {
229216
}
230217

231218
func (p Paginator[T]) getOrder(page *Page) []string {
232-
sort := page.GetOrder(p.settings.Sort...)
219+
sort := page.GetOrder(p.settings.ColumnFunc, p.settings.Sort...)
233220
list := make([]string, len(sort))
234-
for i, s := range sort {
235-
if p.settings.ColumnFunc != nil {
236-
s.Column = p.settings.ColumnFunc(s.Column)
237-
}
238-
list[i] = s.String()
221+
for i := range sort {
222+
list[i] = sort[i].String()
239223
}
240224
return list
241225
}
@@ -253,6 +237,11 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele
253237
}
254238

255239
func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) {
240+
if page == nil {
241+
page = &Page{}
242+
}
243+
page.SetDefaults(&p.settings)
244+
256245
limit, offset := page.Limit(), page.Offset()
257246

258247
q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ")

page_test.go

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,24 @@ func TestPagination(t *testing.T) {
2626
page := pgkit.NewPage(0, 0)
2727
result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
2828
require.Len(t, result, 0)
29-
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
29+
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)
3030

3131
sql, args, err := query.ToSql()
3232
require.NoError(t, err)
33-
require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 6 OFFSET 0", sql)
33+
require.Equal(t, `SELECT * FROM t ORDER BY "id" ASC LIMIT 3 OFFSET 0`, sql)
3434
require.Empty(t, args)
3535

3636
result = paginator.PrepareResult(make([]T, 0), page)
3737
require.Len(t, result, 0)
38-
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
38+
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)
3939

40-
result = paginator.PrepareResult(make([]T, MaxSize), page)
41-
require.Len(t, result, MaxSize)
42-
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
40+
result = paginator.PrepareResult(make([]T, DefaultSize), page)
41+
require.Len(t, result, DefaultSize)
42+
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)
4343

44-
result = paginator.PrepareResult(make([]T, MaxSize+2), page)
45-
require.Len(t, result, MaxSize)
46-
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page)
44+
result = paginator.PrepareResult(make([]T, DefaultSize+2), page)
45+
require.Len(t, result, DefaultSize)
46+
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize, More: true}, page)
4747
}
4848

4949
func TestInvalidSort(t *testing.T) {
@@ -150,3 +150,91 @@ func TestPaginationEdgeCases(t *testing.T) {
150150
require.NoError(t, err4)
151151
require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4)
152152
}
153+
154+
func TestColumnFunc(t *testing.T) {
155+
fn := func(column string) string {
156+
switch column {
157+
case "id":
158+
return "ID"
159+
case "name":
160+
return "NAME"
161+
default:
162+
return column
163+
}
164+
}
165+
paginator := pgkit.NewPaginator[T](
166+
pgkit.WithColumnFunc(fn),
167+
)
168+
page := &pgkit.Page{
169+
Page: 1,
170+
Size: 10,
171+
Sort: []pgkit.Sort{
172+
{Column: "id", Order: pgkit.Asc},
173+
{Column: "name", Order: pgkit.Desc},
174+
{Column: "created_at", Order: pgkit.Asc},
175+
},
176+
}
177+
_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
178+
179+
sql, args, err := query.ToSql()
180+
require.NoError(t, err)
181+
require.Equal(t, `SELECT * FROM t ORDER BY "ID" ASC, "NAME" DESC, "created_at" ASC LIMIT 11 OFFSET 0`, sql)
182+
require.Empty(t, args)
183+
}
184+
185+
func TestColumnFallbackUsesColumnFunc(t *testing.T) {
186+
paginator := pgkit.NewPaginator[T](
187+
pgkit.WithColumnFunc(strings.ToUpper),
188+
pgkit.WithSort("id"),
189+
)
190+
page := &pgkit.Page{
191+
Page: 1,
192+
Size: 10,
193+
Column: "name",
194+
}
195+
196+
_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
197+
198+
sql, args, err := query.ToSql()
199+
require.NoError(t, err)
200+
require.Equal(t, `SELECT * FROM t ORDER BY "NAME" ASC LIMIT 11 OFFSET 0`, sql)
201+
require.Empty(t, args)
202+
}
203+
204+
func TestSortTakesPrecedenceOverColumn(t *testing.T) {
205+
paginator := pgkit.NewPaginator[T]()
206+
page := &pgkit.Page{
207+
Page: 1,
208+
Size: 10,
209+
Column: "name",
210+
Sort: []pgkit.Sort{
211+
{Column: "id", Order: pgkit.Desc},
212+
},
213+
}
214+
215+
_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
216+
217+
sql, args, err := query.ToSql()
218+
require.NoError(t, err)
219+
require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11 OFFSET 0`, sql)
220+
require.Empty(t, args)
221+
}
222+
223+
func TestPaginationOffsetAndPageRecompute(t *testing.T) {
224+
paginator := pgkit.NewPaginator[T]()
225+
page := &pgkit.Page{
226+
Page: 3,
227+
Size: 2,
228+
}
229+
230+
_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
231+
232+
sql, args, err := query.ToSql()
233+
require.NoError(t, err)
234+
require.Equal(t, "SELECT * FROM t LIMIT 3 OFFSET 4", sql)
235+
require.Empty(t, args)
236+
237+
result := paginator.PrepareResult(make([]T, 3), page)
238+
require.Len(t, result, 2)
239+
require.Equal(t, &pgkit.Page{Page: 3, Size: 2, More: true}, page)
240+
}

0 commit comments

Comments
 (0)