diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/go/db.go b/internal/endtoend/testdata/order_by_binds/sqlite/go/db.go new file mode 100644 index 0000000000..0ca2c6d0dd --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package order_by_binds + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/go/models.go b/internal/endtoend/testdata/order_by_binds/sqlite/go/models.go new file mode 100644 index 0000000000..e9a4756502 --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package order_by_binds + +import ( + "database/sql" +) + +type Author struct { + ID int64 + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/go/query.sql.go b/internal/endtoend/testdata/order_by_binds/sqlite/go/query.sql.go new file mode 100644 index 0000000000..55c26a5aed --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/go/query.sql.go @@ -0,0 +1,173 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package order_by_binds + +import ( + "context" +) + +const listAuthorsColumnSort = `-- name: ListAuthorsColumnSort :many +SELECT id, name, bio FROM authors +WHERE id > ?1 +ORDER BY CASE WHEN ?2 = 'name' THEN name END +` + +type ListAuthorsColumnSortParams struct { + MinID int64 + SortColumn interface{} +} + +func (q *Queries) ListAuthorsColumnSort(ctx context.Context, arg ListAuthorsColumnSortParams) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsColumnSort, arg.MinID, arg.SortColumn) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsColumnSortDirection = `-- name: ListAuthorsColumnSortDirection :many +SELECT id, name, bio FROM authors +WHERE id > ? +ORDER BY + CASE + WHEN ?2 = 'asc' THEN name + END ASC, + CASE + WHEN ?2 = 'desc' OR ?2 IS NULL THEN name + END DESC +` + +type ListAuthorsColumnSortDirectionParams struct { + ID int64 + OrderBy interface{} +} + +func (q *Queries) ListAuthorsColumnSortDirection(ctx context.Context, arg ListAuthorsColumnSortDirectionParams) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsColumnSortDirection, arg.ID, arg.OrderBy) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsColumnSortFnWtihArg = `-- name: ListAuthorsColumnSortFnWtihArg :many +SELECT id, name, bio FROM authors +ORDER BY id % ? +` + +func (q *Queries) ListAuthorsColumnSortFnWtihArg(ctx context.Context, id int64) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsColumnSortFnWtihArg, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameSort = `-- name: ListAuthorsNameSort :many +SELECT id, name, bio FROM authors +WHERE id > ?1 +ORDER BY name ASC +` + +func (q *Queries) ListAuthorsNameSort(ctx context.Context, minID int64) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameSort, minID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNamedParamsOnly = `-- name: ListAuthorsNamedParamsOnly :many +SELECT id, name, bio FROM authors +ORDER BY + CASE + WHEN ?1 = 'name' THEN name + WHEN ?1 = 'bio' THEN bio + END ASC +` + +func (q *Queries) ListAuthorsNamedParamsOnly(ctx context.Context, sort interface{}) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNamedParamsOnly, sort) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/query.sql b/internal/endtoend/testdata/order_by_binds/sqlite/query.sql new file mode 100644 index 0000000000..4fa2e5d88e --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/query.sql @@ -0,0 +1,32 @@ +-- name: ListAuthorsColumnSort :many +SELECT * FROM authors +WHERE id > sqlc.arg(min_id) +ORDER BY CASE WHEN sqlc.arg(sort_column) = 'name' THEN name END; + +-- name: ListAuthorsColumnSortDirection :many +SELECT * FROM authors +WHERE id > ? +ORDER BY + CASE + WHEN @order_by = 'asc' THEN name + END ASC, + CASE + WHEN @order_by = 'desc' OR @order_by IS NULL THEN name + END DESC; + +-- name: ListAuthorsColumnSortFnWtihArg :many +SELECT * FROM authors +ORDER BY id % ?; + +-- name: ListAuthorsNameSort :many +SELECT * FROM authors +WHERE id > sqlc.arg(min_id) +ORDER BY name ASC; + +-- name: ListAuthorsNamedParamsOnly :many +SELECT * FROM authors +ORDER BY + CASE + WHEN @sort = 'name' THEN name + WHEN @sort = 'bio' THEN bio + END ASC; diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/schema.sql b/internal/endtoend/testdata/order_by_binds/sqlite/schema.sql new file mode 100644 index 0000000000..eaa9aab806 --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + bio TEXT +); diff --git a/internal/endtoend/testdata/order_by_binds/sqlite/sqlc.json b/internal/endtoend/testdata/order_by_binds/sqlite/sqlc.json new file mode 100644 index 0000000000..36d9a495cc --- /dev/null +++ b/internal/endtoend/testdata/order_by_binds/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "sqlite", + "name": "order_by_binds", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index e9868f5be6..35259a6597 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -512,8 +512,11 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No } limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) + sortClause := c.convertOrderby_stmtContext(n.Order_by_stmt()) + selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset + selectStmt.SortClause = sortClause // Only set WithClause if there are CTEs if len(ctes.Items) > 0 { selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} @@ -621,22 +624,59 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef } } -func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node { - if orderBy, ok := n.(*parser.Order_by_stmtContext); ok { - list := &ast.List{Items: []ast.Node{}} - for _, o := range orderBy.AllOrdering_term() { - term, ok := o.(*parser.Ordering_termContext) - if !ok { - continue +func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) *ast.List { + if n == nil { + return nil + } + + orderBy, ok := n.(*parser.Order_by_stmtContext) + if !ok { + if debug.Active { + log.Printf("sqlite.convertOrderby_stmtContext: unexpected type %T", n) + } + return nil + } + + if len(orderBy.AllOrdering_term()) == 0 { + return nil + } + + sortItems := &ast.List{Items: []ast.Node{}} + for _, o := range orderBy.AllOrdering_term() { + term, ok := o.(*parser.Ordering_termContext) + if !ok { + continue + } + + // Sort direction: ASC, DESC, or default + sortByDir := ast.SortByDirDefault + if adNode := term.Asc_desc(); adNode != nil { + if adNode.ASC_() != nil { + sortByDir = ast.SortByDirAsc + } else if adNode.DESC_() != nil { + sortByDir = ast.SortByDirDesc } - list.Items = append(list.Items, &ast.CaseExpr{ - Xpr: c.convert(term.Expr()), - Location: term.Expr().GetStart().GetStart(), - }) } - return list + + // Nulls ordering: NULLS FIRST, NULLS LAST, or default + sortByNulls := ast.SortByNullsDefault + if term.NULLS_() != nil { + if term.FIRST_() != nil { + sortByNulls = ast.SortByNullsFirst + } else if term.LAST_() != nil { + sortByNulls = ast.SortByNullsLast + } + } + + sortItems.Items = append(sortItems.Items, &ast.SortBy{ + Node: c.convert(term.Expr()), + SortbyDir: sortByDir, + SortbyNulls: sortByNulls, + UseOp: &ast.List{}, + Location: term.GetStart().GetStart(), + }) } - return todo("convertOrderby_stmtContext", n) + return sortItems } func (c *cc) convertLimit_stmtContext(n parser.ILimit_stmtContext) (ast.Node, ast.Node) { @@ -826,7 +866,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.MINUS() != nil { // Negative number: -expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, Rexpr: expr, } } @@ -837,7 +877,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.TILDE() != nil { // Bitwise NOT: ~expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, Rexpr: expr, } } @@ -1313,9 +1353,6 @@ func (c *cc) convert(node node) ast.Node { case *parser.Insert_stmtContext: return c.convertInsert_stmtContext(n) - case *parser.Order_by_stmtContext: - return c.convertOrderby_stmtContext(n) - case *parser.Select_stmtContext: return c.convertMultiSelect_stmtContext(n) diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..c0454049d6 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -140,10 +140,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, // TODO: This code assumes that @foo::bool is on a single line var replace string - if engine == config.EngineMySQL || !dollar { - replace = "?" - } else if engine == config.EngineSQLite { + if engine == config.EngineSQLite { replace = fmt.Sprintf("?%d", argn) + } else if engine == config.EngineMySQL || !dollar { + replace = "?" } else { replace = fmt.Sprintf("$%d", argn) } @@ -168,10 +168,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, // TODO: This code assumes that @foo is on a single line var replace string - if engine == config.EngineMySQL || !dollar { - replace = "?" - } else if engine == config.EngineSQLite { + if engine == config.EngineSQLite { replace = fmt.Sprintf("?%d", argn) + } else if engine == config.EngineMySQL || !dollar { + replace = "?" } else { replace = fmt.Sprintf("$%d", argn) }