From 51cd99d272083770ed31b2e4cd73262d6b277995 Mon Sep 17 00:00:00 2001 From: Celso Alexandre Date: Tue, 26 May 2026 19:49:46 -0300 Subject: [PATCH 1/5] feat(compiler): sqlc.switch for compile-time query branches Add a sqlc.switch(@selector, sqlc.when('key', 'sql'), sqlc.else('sql')) macro that expands at compile time into one static query per branch, named . This implements the sqlc.switch() idea floated in discussions/364 ("generate multiple optimized queries at compile time rather than runtime CASE"). Each branch fragment is spliced into the query in place of the macro call and re-parsed as an ordinary query, so: - the generated SQL is fully static (no runtime CASE, planner uses indexes); - branch fragments are author-written constants, never runtime input, so there is no SQL injection surface; - a bad column reference in a fragment is a normal compile error; - a generated name colliding with another query is caught by the existing duplicate-query-name check. Recognition is AST-based, identical to sqlc.arg/sqlc.slice, so it works wherever those macros parse: WHERE on all engines, ORDER BY on PostgreSQL and MySQL. SQLite drops function calls in ORDER BY (see #4429), so it errors there instead of emitting the unexpanded call. The macro is rejected in the SELECT projection, where branches could change the result shape. The only change to existing code is a thin wrapper in parseQueries; all macro logic lives in the new internal/compiler/expand_switch.go. Includes unit tests and golden end-to-end tests for PostgreSQL (stdlib + pgx), MySQL, and SQLite, plus a design note in docs/proposals/sqlc-switch.md. --- docs/proposals/sqlc-switch.md | 163 +++++++++++ internal/compiler/compile.go | 52 ++-- internal/compiler/expand_switch.go | 254 ++++++++++++++++++ internal/compiler/expand_switch_test.go | 58 ++++ .../testdata/sqlc_switch/mysql/go/db.go | 31 +++ .../testdata/sqlc_switch/mysql/go/models.go | 15 ++ .../sqlc_switch/mysql/go/query.sql.go | 97 +++++++ .../testdata/sqlc_switch/mysql/query.sql | 7 + .../testdata/sqlc_switch/mysql/schema.sql | 5 + .../testdata/sqlc_switch/mysql/sqlc.json | 12 + .../sqlc_switch/postgresql/pgx/go/db.go | 32 +++ .../sqlc_switch/postgresql/pgx/go/models.go | 15 ++ .../postgresql/pgx/go/query.sql.go | 88 ++++++ .../sqlc_switch/postgresql/pgx/query.sql | 7 + .../sqlc_switch/postgresql/pgx/schema.sql | 5 + .../sqlc_switch/postgresql/pgx/sqlc.json | 13 + .../sqlc_switch/postgresql/stdlib/go/db.go | 31 +++ .../postgresql/stdlib/go/models.go | 15 ++ .../postgresql/stdlib/go/query.sql.go | 97 +++++++ .../sqlc_switch/postgresql/stdlib/query.sql | 7 + .../sqlc_switch/postgresql/stdlib/schema.sql | 5 + .../sqlc_switch/postgresql/stdlib/sqlc.json | 12 + .../testdata/sqlc_switch/sqlite/go/db.go | 31 +++ .../testdata/sqlc_switch/sqlite/go/models.go | 11 + .../sqlc_switch/sqlite/go/query.sql.go | 66 +++++ .../testdata/sqlc_switch/sqlite/query.sql | 5 + .../testdata/sqlc_switch/sqlite/schema.sql | 5 + .../testdata/sqlc_switch/sqlite/sqlc.json | 12 + 28 files changed, 1130 insertions(+), 21 deletions(-) create mode 100644 docs/proposals/sqlc-switch.md create mode 100644 internal/compiler/expand_switch.go create mode 100644 internal/compiler/expand_switch_test.go create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/query.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/schema.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/query.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql create mode 100644 internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json diff --git a/docs/proposals/sqlc-switch.md b/docs/proposals/sqlc-switch.md new file mode 100644 index 0000000000..74287747b9 --- /dev/null +++ b/docs/proposals/sqlc-switch.md @@ -0,0 +1,163 @@ +# Proposal: `sqlc.switch` — bounded dynamic ORDER BY / WHERE via compile-time branch expansion + +Status: **draft** (evox-it fork) +Tracking upstream: discussions/364, issues #2061, #3414, #2060; PRs #4005, #4260, #2859 (all blocked on a canonical design) + +## Problem + +sqlc has no way to vary the *structure* of a query (sort order, filter shape) at +runtime. The community has asked for "dynamic queries" since 2020. The two +existing workarounds both fail: + +1. **`CASE WHEN` in `ORDER BY`/`WHERE`** — defeats the query planner. Postgres + will not use an index when the sort key is hidden inside a `CASE` expression. +2. **Runtime string interpolation** (`sqlc.raw`, `@filter::text`) — every such + proposal has been closed because it reintroduces SQL injection and breaks + sqlc's "type-safe, it's just SQL" guarantee. + +## Design goals (derived from maintainer objections) + +- **No SQL injection, ever.** User input must never reach the query string. +- **Schema-validated.** Every dynamic fragment must parse as real SQL and + reference real columns at compile time. Bad column = compile error. +- **Planner/index friendly.** The emitted SQL must be a clean static + `ORDER BY col DESC`, never a `CASE` wrapper. +- **Finite + enumerable.** The set of runtime choices is fixed at compile time. +- **Modeled on existing precedent** (`sqlc.slice`, the `sqlc.*` macro family). + +## Syntax + +```sql +-- name: ListAuthors :many +SELECT * FROM authors +WHERE deleted_at IS NULL +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC') +); +``` + +- `sqlc.switch(@selector, branches…)` sits where a value/expression is grammatically + legal (`ORDER BY` position, `WHERE` position). `@selector` is the runtime chooser. +- `sqlc.when('key', 'sql-fragment')` — `'key'` is the enum value; `'sql-fragment'` + is a **string literal** (must be — `ASC`/`DESC` are not valid inside a function + arg list in any engine grammar). The fragment is an author-authored compile-time + constant. +- `sqlc.else('sql-fragment')` — optional default branch. + +## Why string-literal fragments are still safe + +The fragment is a constant in the `.sql` file written by the developer, exactly +like the rest of the query. It is **not** runtime input. sqlc re-parses each +fragment in its grammatical context (e.g. `SELECT 1 FROM authors ORDER BY `) +and validates every column reference against the catalog. A typo or unknown +column fails `sqlc generate`. The only thing that varies at runtime is *which +already-validated branch* is chosen — a closed enum. Injection is structurally +impossible. + +## Codegen strategy: compile-time expansion (one function per branch) + +This follows Kyle Conroy's own `sqlc.switch()` suggestion in discussions/364 +("generate multiple optimized queries at compile time rather than runtime CASE"). + +A query containing `sqlc.switch` is **expanded by the compiler into N concrete +queries**, one per branch. Each clone has the whole `sqlc.switch(...)` call +replaced in the SQL by that branch's fragment. The resulting query strings are +fully static constants — no runtime `strings.Replace`, no markers. + +### Recognition is AST-based, like the other sqlc.* macros + +`sqlc.switch`/`when`/`else` are recognized exactly the way `sqlc.arg` and +`sqlc.slice` are: by searching the parsed AST for a `FuncCall` whose schema is +`sqlc` (`astutils.Search`). There is no bespoke SQL lexer. A consequence is that +the feature works in precisely the clauses where the engine parser produces such +a node — i.e. **wherever `sqlc.arg` works**: + +| Position | PostgreSQL | MySQL | SQLite | +|---|---|---|---| +| WHERE | ✅ | ✅ | ✅ | +| ORDER BY | ✅ | ✅ | ❌ parser drops the clause | + +SQLite's parser discards *any* function call in `ORDER BY` (true of plain +`ORDER BY upper(name)` too — see upstream PR #4429), so `sqlc.switch` there is a +compile error rather than silently emitting the unexpanded call. This is the +same limitation `sqlc.arg` has. + +Once recognized, the compiler replaces the `sqlc.switch(...)` text span with each +branch's fragment, renames the `-- name:` comment to ``, +and **re-parses each branch as an ordinary query**. Every branch therefore goes +through the normal parser + analyzer, so a bad column reference in a fragment is +a compile error, and the generated query strings are fully static constants — no +runtime markers, no `strings.Replace`. + +### Generated Go — v1 (implemented) + +One static function per branch: + +```go +const listAuthorsNameAsc = `SELECT ... ORDER BY authors.name ASC` +func (q *Queries) ListAuthorsNameAsc(ctx context.Context) ([]Author, error) { ... } + +const listAuthorsRecent = `SELECT ... ORDER BY authors.created_at DESC, authors.id DESC` +func (q *Queries) ListAuthorsRecent(ctx context.Context) ([]Author, error) { ... } + +const listAuthorsElse = `SELECT ... ORDER BY authors.id ASC` +func (q *Queries) ListAuthorsElse(ctx context.Context) ([]Author, error) { ... } +``` + +This is the whole upstreamable primitive: pure compile-time expansion in the +compiler, **zero codegen changes** (the branches are ordinary queries, so every +language's codegen gets them for free). It is the conservative core to propose +first. + +### Generated Go — v2 (proposed extension) + +A generated enum for the selector plus one exported dispatcher that switches on +it. This is a codegen convenience layered on top of v1; it adds codegen surface +(enum synthesis + dispatcher emission) and is best proposed as a follow-up once +the primitive is accepted: + +```go +type ListAuthorsSort string +const ( + ListAuthorsSortNameAsc ListAuthorsSort = "name_asc" + ListAuthorsSortRecent ListAuthorsSort = "recent" +) +func (q *Queries) ListAuthors(ctx context.Context, sort ListAuthorsSort) ([]Author, error) { + switch sort { + case ListAuthorsSortNameAsc: return q.ListAuthorsNameAsc(ctx) + case ListAuthorsSortRecent: return q.ListAuthorsRecent(ctx) + default: return q.ListAuthorsElse(ctx) + } +} +``` + +## Naming rules + +| Element | Source | Example | +|---|---|---| +| Branch fn | `` + camelize(key) | `ListAuthorsNameAsc` | +| `sqlc.else` fn | `` + `Else` | `ListAuthorsElse` | +| Enum type (v2) | `` + selector name | `ListAuthorsSort` | +| Enum const (v2) | enum type + camelize(key) | `ListAuthorsSortNameAsc` | +| Dispatcher (v2) | `` | `ListAuthors` | + +## v1 scope (implemented + tested) + +- Recognized in WHERE (all engines) and ORDER BY (PostgreSQL, MySQL); SQLite + ORDER BY is a clear compile error (parser limitation, parity with `sqlc.arg`). +- **Not allowed in the SELECT projection** — branches there could change the + result columns; rejected at compile time. +- One static function per branch. Enum + dispatcher (v2) are the proposed + follow-up. +- Engine-agnostic: the compiler emits ordinary queries, so all codegens benefit; + no codegen changes in v1. +- Golden end-to-end tests for PostgreSQL (stdlib + pgx), MySQL, and SQLite. + +## Open questions for upstream + +1. Dispatcher on by default, or opt-in via a query annotation? +2. Should `sqlc.else` be mandatory (compile error if a non-exhaustive switch is + possible) or optional (zero-value selector → else)? +3. Fragment validation depth: parse-only vs full type-check of the sort expr. diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index b6bba42e16..a3a24d8bef 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -98,33 +98,43 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { continue } for _, stmt := range stmts { - query, err := c.parseQuery(stmt.Raw, src, o) + // A statement may expand into several: a sqlc.case(...) macro + // produces one static query per branch. Without a macro this + // yields the original statement unchanged. + sources, err := c.statementSources(stmt.Raw, src) if err != nil { - var e *sqlerr.Error - loc := stmt.Raw.Pos() - if errors.As(err, &e) && e.Location != 0 { - loc = e.Location - } - merr.Add(filename, src, loc, err) - // If this rpc unauthenticated error bubbles up, then all future parsing/analysis will fail - if errors.Is(err, rpc.ErrUnauthenticated) { - return nil, merr - } - continue - } - if query == nil { + merr.Add(filename, src, stmt.Raw.Pos(), err) continue } - query.Metadata.Filename = filepath.Base(filename) - queryName := query.Metadata.Name - if queryName != "" { - if _, exists := set[queryName]; exists { - merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName)) + for _, ss := range sources { + query, err := c.parseQuery(ss.raw, ss.src, o) + if err != nil { + var e *sqlerr.Error + loc := ss.raw.Pos() + if errors.As(err, &e) && e.Location != 0 { + loc = e.Location + } + merr.Add(filename, ss.src, loc, err) + // If this rpc unauthenticated error bubbles up, then all future parsing/analysis will fail + if errors.Is(err, rpc.ErrUnauthenticated) { + return nil, merr + } continue } - set[queryName] = struct{}{} + if query == nil { + continue + } + query.Metadata.Filename = filepath.Base(filename) + queryName := query.Metadata.Name + if queryName != "" { + if _, exists := set[queryName]; exists { + merr.Add(filename, ss.src, ss.raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName)) + continue + } + set[queryName] = struct{}{} + } + q = append(q, query) } - q = append(q, query) } } if len(merr.Errs()) > 0 { diff --git a/internal/compiler/expand_switch.go b/internal/compiler/expand_switch.go new file mode 100644 index 0000000000..c4152f58c2 --- /dev/null +++ b/internal/compiler/expand_switch.go @@ -0,0 +1,254 @@ +package compiler + +import ( + "fmt" + "strings" + "unicode" + + "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/source" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" +) + +// sqlcSwitchBranch is a single branch of a sqlc.switch(...) macro: a key that +// names the generated query variant and a SQL fragment that is spliced into the +// query in place of the whole sqlc.switch(...) call. +type sqlcSwitchBranch struct { + key string // "name_asc", or "else" for the default branch + fragment string // author-authored SQL, e.g. "authors.name ASC" +} + +// isSqlcFunc reports whether node is a call to sqlc.. +func isSqlcFunc(node ast.Node, name string) bool { + call, ok := node.(*ast.FuncCall) + if !ok || call.Func == nil { + return false + } + return call.Func.Schema == "sqlc" && call.Func.Name == name +} + +// stringConst extracts a string literal value from an A_Const node. +func stringConst(node ast.Node) (string, bool) { + c, ok := node.(*ast.A_Const) + if !ok { + return "", false + } + s, ok := c.Val.(*ast.String) + if !ok { + return "", false + } + return s.Str, true +} + +// camelize turns a branch key like "name_asc" into "NameAsc" so it can be +// appended to a query name and remain a valid Go identifier. +func camelize(s string) string { + var b strings.Builder + upper := true + for _, r := range s { + if r == '_' || r == '-' || r == ' ' { + upper = true + continue + } + if upper { + b.WriteRune(unicode.ToUpper(r)) + upper = false + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// stmtSource pairs a statement to compile with the source text its byte +// locations are relative to. +type stmtSource struct { + raw *ast.RawStmt + src string +} + +// statementSources returns the statements to compile for a single parsed +// statement. Normally that is just the statement itself; if it contains a +// sqlc.switch(...) macro, it is the re-parsed branch variants it expands into. +func (c *Compiler) statementSources(raw *ast.RawStmt, src string) ([]stmtSource, error) { + variants, err := c.expandSqlcSwitch(raw, src) + if err != nil { + return nil, err + } + if variants == nil { + return []stmtSource{{raw: raw, src: src}}, nil + } + var sources []stmtSource + for _, v := range variants { + stmts, err := c.parser.Parse(strings.NewReader(v)) + if err != nil { + return nil, err + } + for i := range stmts { + sources = append(sources, stmtSource{raw: stmts[i].Raw, src: v}) + } + } + return sources, nil +} + +// expandSqlcSwitch looks for a sqlc.switch(...) macro in a statement and, if +// present, returns one rewritten SQL string per branch. Each rewritten string +// is a normal query: the whole sqlc.switch(...) call is replaced by that +// branch's SQL fragment and the "-- name:" comment is renamed to +// . Each variant is re-parsed and analyzed as an ordinary +// query, so a bad column reference in a fragment is a compile error, and a +// generated name that collides with another query is caught by the normal +// duplicate-query-name check. +// +// The macro is recognized from the AST the same way sqlc.arg/sqlc.slice are +// (a FuncCall with schema "sqlc"), so it works in exactly the clauses where +// those macros parse. Returning (nil, nil) means there is no sqlc.switch and the +// statement should be compiled as-is. +func (c *Compiler) expandSqlcSwitch(raw *ast.RawStmt, src string) ([]string, error) { + found := astutils.Search(raw, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(found.Items) == 0 { + // Some parsers (e.g. SQLite for ORDER BY) silently discard a function + // call they cannot place rather than erroring. If the text clearly + // contains the macro but no node survived, fail loudly instead of + // emitting the unexpanded call into the generated SQL. + if stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen); err == nil && + strings.Contains(stmtSQL, "sqlc.switch") { + return nil, fmt.Errorf("sqlc.switch() is not supported in this position for this engine") + } + return nil, nil + } + if len(found.Items) > 1 { + return nil, fmt.Errorf("only one sqlc.switch() per query is supported") + } + call := found.Items[0].(*ast.FuncCall) + + // sqlc.switch() is only allowed where it does not change the result shape + // (WHERE, ORDER BY, ...), never in the SELECT projection: different branches + // there could produce different columns and break the shared row type. + if sel, ok := raw.Stmt.(*ast.SelectStmt); ok && sel.TargetList != nil { + inTarget := astutils.Search(sel.TargetList, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(inTarget.Items) > 0 { + return nil, fmt.Errorf("sqlc.switch() is not allowed in the SELECT list; use it in WHERE or ORDER BY") + } + } + + if call.Args == nil || len(call.Args.Items) < 2 { + return nil, fmt.Errorf("sqlc.switch() requires a selector and at least one sqlc.when()/sqlc.else() branch") + } + + // args[0] is the runtime selector (e.g. @sort). It plays no role in the + // generated code (one function per branch, named by branch key) but is + // required so the intent is explicit. + branches := make([]sqlcSwitchBranch, 0, len(call.Args.Items)-1) + seenElse := false + for _, arg := range call.Args.Items[1:] { + switch { + case isSqlcFunc(arg, "when"): + when := arg.(*ast.FuncCall) + if when.Args == nil || len(when.Args.Items) != 2 { + return nil, fmt.Errorf("sqlc.when() requires exactly 2 arguments: a key and a SQL fragment") + } + key, ok := stringConst(when.Args.Items[0]) + if !ok { + return nil, fmt.Errorf("sqlc.when() key must be a string literal") + } + frag, ok := stringConst(when.Args.Items[1]) + if !ok { + return nil, fmt.Errorf("sqlc.when() fragment must be a string literal") + } + branches = append(branches, sqlcSwitchBranch{key: key, fragment: frag}) + case isSqlcFunc(arg, "else"): + if seenElse { + return nil, fmt.Errorf("sqlc.switch() allows at most one sqlc.else()") + } + seenElse = true + els := arg.(*ast.FuncCall) + if els.Args == nil || len(els.Args.Items) != 1 { + return nil, fmt.Errorf("sqlc.else() requires exactly 1 argument: a SQL fragment") + } + frag, ok := stringConst(els.Args.Items[0]) + if !ok { + return nil, fmt.Errorf("sqlc.else() fragment must be a string literal") + } + branches = append(branches, sqlcSwitchBranch{key: "else", fragment: frag}) + default: + return nil, fmt.Errorf("sqlc.switch() branches must be sqlc.when() or sqlc.else() calls") + } + } + + // Locate the byte span of the whole sqlc.switch(...) call within its + // statement so it can be replaced textually with each branch fragment. + stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen) + if err != nil { + return nil, err + } + switchStart := call.Location - raw.StmtLocation + if switchStart < 0 || switchStart >= len(stmtSQL) { + return nil, fmt.Errorf("could not locate sqlc.switch() in source") + } + switchEnd, err := matchParen(stmtSQL, switchStart) + if err != nil { + return nil, err + } + + name, _, err := metadata.ParseQueryNameAndType(stmtSQL, metadata.CommentSyntax(c.parser.CommentSyntax())) + if err != nil { + return nil, err + } + if name == "" { + return nil, fmt.Errorf("sqlc.switch() requires the query to have a -- name: annotation") + } + + variants := make([]string, 0, len(branches)) + for _, br := range branches { + spliced := stmtSQL[:switchStart] + br.fragment + stmtSQL[switchEnd+1:] + // The plucked statement may exclude its trailing ";" (it can fall + // outside StmtLen), so re-parsing a branch without one could yield an + // empty statement. Normalize to exactly one terminator. + spliced = strings.TrimRight(spliced, " \t\r\n;") + ";" + newName := name + camelize(br.key) + // Rename only the name comment. "name: " is shared by all comment + // styles (-- /* #), so a single replacement is enough. + spliced = strings.Replace(spliced, "name: "+name, "name: "+newName, 1) + variants = append(variants, spliced) + } + return variants, nil +} + +// matchParen returns the index of the ')' that closes the first '(' at or after +// start in s, skipping single-quoted string literals so parentheses inside a +// fragment (e.g. coalesce(x, 0)) do not throw off the depth count. +func matchParen(s string, start int) (int, error) { + i := start + for i < len(s) && s[i] != '(' { + i++ + } + if i >= len(s) { + return 0, fmt.Errorf("could not locate opening parenthesis of sqlc.switch()") + } + depth := 0 + for ; i < len(s); i++ { + switch s[i] { + case '\'': + // Advance to the closing quote, honoring '' escapes. + for i++; i < len(s); i++ { + if s[i] == '\'' { + if i+1 < len(s) && s[i+1] == '\'' { + i++ + continue + } + break + } + } + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return i, nil + } + } + } + return 0, fmt.Errorf("could not locate closing parenthesis of sqlc.switch()") +} diff --git a/internal/compiler/expand_switch_test.go b/internal/compiler/expand_switch_test.go new file mode 100644 index 0000000000..151feaef30 --- /dev/null +++ b/internal/compiler/expand_switch_test.go @@ -0,0 +1,58 @@ +package compiler + +import "testing" + +func TestCamelize(t *testing.T) { + for _, tc := range []struct { + in string + want string + }{ + {"name_asc", "NameAsc"}, + {"recent", "Recent"}, + {"else", "Else"}, + {"created-at-desc", "CreatedAtDesc"}, + {"two words", "TwoWords"}, + {"", ""}, + } { + if got := camelize(tc.in); got != tc.want { + t.Errorf("camelize(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestMatchParen(t *testing.T) { + for _, tc := range []struct { + name string + in string + start int + want int + wantErr bool + }{ + {"simple", "f(x)", 0, 3, false}, + {"nested", "f(g(x), h(y))", 0, 12, false}, + {"string with parens", "f('a)b', x)", 0, 10, false}, + {"escaped quote in string", "f('a''b)', x)", 0, 12, false}, + {"fragment with call", "case('coalesce(x, 0) ASC')", 0, 25, false}, + {"unbalanced", "f(x", 0, 0, true}, + {"no open paren", "abc", 0, 0, true}, + } { + t.Run(tc.name, func(t *testing.T) { + got, err := matchParen(tc.in, tc.start) + if tc.wantErr { + if err == nil { + t.Fatalf("matchParen(%q) expected error, got %d", tc.in, got) + } + return + } + if err != nil { + t.Fatalf("matchParen(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Errorf("matchParen(%q) = %d, want %d", tc.in, got, tc.want) + } + if tc.in[got] != ')' { + t.Errorf("matchParen(%q) index %d is %q, not ')'", tc.in, got, tc.in[got]) + } + }) + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go new file mode 100644 index 0000000000..f864877022 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "time" +) + +type Author struct { + ID int64 + Name string + CreatedAt time.Time +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go new file mode 100644 index 0000000000..aee79c120f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go @@ -0,0 +1,97 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/query.sql b/internal/endtoend/testdata/sqlc_switch/mysql/query.sql new file mode 100644 index 0000000000..87c86717b7 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = ? +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql b/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql new file mode 100644 index 0000000000..98cfc6b3f4 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + name varchar(255) NOT NULL, + created_at datetime NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json new file mode 100644 index 0000000000..974aa9ff9e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..0057c62319 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..bb085de6f3 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Author struct { + ID int64 + Name string + CreatedAt pgtype.Timestamptz +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..0962dfd25e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go @@ -0,0 +1,88 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql new file mode 100644 index 0000000000..c9b8da1615 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = $1 +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql new file mode 100644 index 0000000000..f9f2d25ab8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + created_at timestamptz NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..d12b82a6c6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "sql_package": "pgx/v5", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..f864877022 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "time" +) + +type Author struct { + ID int64 + Name string + CreatedAt time.Time +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..948865d030 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,97 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..c9b8da1615 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = $1 +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..f9f2d25ab8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + created_at timestamptz NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..cd518671ac --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go new file mode 100644 index 0000000000..7a8fcad68e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Author struct { + ID int64 + Name string + CreatedAt string +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go new file mode 100644 index 0000000000..893b01e848 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go @@ -0,0 +1,66 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const findAuthorsElse = `-- name: FindAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE 1 = 1 +` + +func (q *Queries) FindAuthorsElse(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, findAuthorsElse) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findAuthorsNamed = `-- name: FindAuthorsNamed :many +SELECT id, name, created_at FROM authors +WHERE name IS NOT NULL +` + +func (q *Queries) FindAuthorsNamed(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, findAuthorsNamed) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql new file mode 100644 index 0000000000..570e08759d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql @@ -0,0 +1,5 @@ +-- name: FindAuthors :many +SELECT * FROM authors +WHERE sqlc.switch(@filter, + sqlc.when('named', 'name IS NOT NULL'), + sqlc.else( '1 = 1')); diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql b/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql new file mode 100644 index 0000000000..9acd8d6107 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name text NOT NULL, + created_at text NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json b/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json new file mode 100644 index 0000000000..1f9d43df5d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} From df7296200ec9f0b6c7c431605e4223cd8c1479a9 Mon Sep 17 00:00:00 2001 From: Celso Alexandre Date: Tue, 26 May 2026 20:23:00 -0300 Subject: [PATCH 2/5] feat(golang): multiple sqlc.switch per query + shared Params/Row structs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow several sqlc.switch() calls in one query as long as they declare the same keys (e.g. the same sort applied in a CTE pre-sort and the final ORDER BY). Expansion stays linear in the number of keys — one function per key, each call contributing its own fragment — not the cross product. Branches expanded from a sqlc.switch() now share a single Params and Row struct named after the original query, instead of an identical copy per branch. All branches have the same parameters and result columns (only the spliced fragment differs), so the per-branch structs were byte-identical. A new SwitchGroup field links the branches from compiler metadata through to the Go generator, which emits the shared struct once and points every branch at it. --- internal/cmd/shim.go | 1 + internal/codegen/golang/gen.go | 1 + internal/codegen/golang/query.go | 13 +- internal/codegen/golang/result.go | 42 ++++++ internal/compiler/compile.go | 1 + internal/compiler/expand_switch.go | 198 +++++++++++++++++++---------- internal/metadata/meta.go | 5 + internal/plugin/codegen.pb.go | 8 ++ 8 files changed, 200 insertions(+), 69 deletions(-) diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 654500429a..7b456795e8 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -161,6 +161,7 @@ func pluginQueries(r *compiler.Result) []*plugin.Query { Params: params, Filename: q.Metadata.Filename, InsertIntoTable: iit, + SwitchGroup: q.Metadata.SwitchGroup, }) } return out diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 5b81c149c3..645c41633e 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -125,6 +125,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat if err != nil { return nil, err } + shareSwitchGroupStructs(queries, options) if options.OmitUnusedStructs { enums, structs = filterUnusedStructs(enums, structs, queries, options.ModelsTypeQualifier()) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 27c596c24e..dc7bb74bfd 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -18,6 +18,12 @@ type QueryValue struct { Typ string SQLDriver opts.SQLDriver + // DefinedElsewhere suppresses emitting this value's struct definition + // because another query already emits an identical one (e.g. sqlc.switch + // branches that share a single Params/Row struct). The struct is still used + // by value in the method signature. + DefinedElsewhere bool + // ModelQualifier prefixes references to model types when the models file // lives in a different Go package (e.g. "model."). Empty otherwise. ModelQualifier string @@ -28,7 +34,7 @@ type QueryValue struct { } func (v QueryValue) EmitStruct() bool { - return v.Emit + return v.Emit && !v.DefinedElsewhere } func (v QueryValue) IsStruct() bool { @@ -62,7 +68,7 @@ func (v QueryValue) Pairs() []Argument { if v.isEmpty() { return nil } - if !v.EmitStruct() && v.IsStruct() { + if !v.Emit && v.IsStruct() { var out []Argument for _, f := range v.Struct.Fields { out = append(out, Argument{ @@ -279,6 +285,9 @@ type Query struct { Arg QueryValue // Used for :copyfrom Table *plugin.Identifier + // SwitchGroup links the branches expanded from one sqlc.switch() macro so + // they can share a single Params/Row struct. + SwitchGroup string } func (q Query) hasRetType() bool { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 5bfa7f795e..fe5a990104 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -226,6 +226,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, enums []En SQL: query.Text, Comments: comments, Table: query.InsertIntoTable, + SwitchGroup: query.SwitchGroup, } sqlpkg := parseDriver(options.SqlPackage) @@ -471,3 +472,44 @@ func checkIncompatibleFieldTypes(fields []Field) error { } return nil } + +// shareSwitchGroupStructs makes the branch functions expanded from a single +// sqlc.switch() macro use one shared Params and Row struct instead of an +// identical copy per branch. All branches of a macro have the same parameters +// and result columns (only the spliced ORDER BY/WHERE fragment differs), so the +// per-query structs are byte-identical; this collapses them to one named after +// the original query (the SwitchGroup), emitted once. +func shareSwitchGroupStructs(queries []Query, options *opts.Options) { + groups := map[string][]int{} + for i := range queries { + if g := queries[i].SwitchGroup; g != "" { + groups[g] = append(groups[g], i) + } + } + for group, idx := range groups { + if len(idx) < 2 { + continue + } + canon := &queries[idx[0]] + + // Params: a single struct arg becomes shared. (Few-param queries inline + // their args instead of using a struct, so they are already identical.) + if canon.Arg.IsStruct() && canon.Arg.Emit { + canon.Arg.Struct.Name = StructName(group+"Params", options) + for _, i := range idx[1:] { + queries[i].Arg.Struct = canon.Arg.Struct + queries[i].Arg.DefinedElsewhere = true + } + } + + // Row: only when the branch built its own *Row struct. If it reused a + // table model, every branch already shares that model. + if canon.Ret.IsStruct() && canon.Ret.Emit { + canon.Ret.Struct.Name = StructName(group+"Row", options) + for _, i := range idx[1:] { + queries[i].Ret.Struct = canon.Ret.Struct + queries[i].Ret.Emit = false + } + } + } +} diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index a3a24d8bef..966a6500f3 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -124,6 +124,7 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { if query == nil { continue } + query.Metadata.SwitchGroup = ss.group query.Metadata.Filename = filepath.Base(filename) queryName := query.Metadata.Name if queryName != "" { diff --git a/internal/compiler/expand_switch.go b/internal/compiler/expand_switch.go index c4152f58c2..0f9a1f5f80 100644 --- a/internal/compiler/expand_switch.go +++ b/internal/compiler/expand_switch.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "sort" "strings" "unicode" @@ -62,17 +63,19 @@ func camelize(s string) string { } // stmtSource pairs a statement to compile with the source text its byte -// locations are relative to. +// locations are relative to. group is the sqlc.switch() group name (the +// original query name) for branch variants, empty otherwise. type stmtSource struct { - raw *ast.RawStmt - src string + raw *ast.RawStmt + src string + group string } // statementSources returns the statements to compile for a single parsed // statement. Normally that is just the statement itself; if it contains a // sqlc.switch(...) macro, it is the re-parsed branch variants it expands into. func (c *Compiler) statementSources(raw *ast.RawStmt, src string) ([]stmtSource, error) { - variants, err := c.expandSqlcSwitch(raw, src) + variants, group, err := c.expandSqlcSwitch(raw, src) if err != nil { return nil, err } @@ -86,60 +89,27 @@ func (c *Compiler) statementSources(raw *ast.RawStmt, src string) ([]stmtSource, return nil, err } for i := range stmts { - sources = append(sources, stmtSource{raw: stmts[i].Raw, src: v}) + sources = append(sources, stmtSource{raw: stmts[i].Raw, src: v, group: group}) } } return sources, nil } -// expandSqlcSwitch looks for a sqlc.switch(...) macro in a statement and, if -// present, returns one rewritten SQL string per branch. Each rewritten string -// is a normal query: the whole sqlc.switch(...) call is replaced by that -// branch's SQL fragment and the "-- name:" comment is renamed to -// . Each variant is re-parsed and analyzed as an ordinary -// query, so a bad column reference in a fragment is a compile error, and a -// generated name that collides with another query is caught by the normal -// duplicate-query-name check. -// -// The macro is recognized from the AST the same way sqlc.arg/sqlc.slice are -// (a FuncCall with schema "sqlc"), so it works in exactly the clauses where -// those macros parse. Returning (nil, nil) means there is no sqlc.switch and the -// statement should be compiled as-is. -func (c *Compiler) expandSqlcSwitch(raw *ast.RawStmt, src string) ([]string, error) { - found := astutils.Search(raw, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) - if len(found.Items) == 0 { - // Some parsers (e.g. SQLite for ORDER BY) silently discard a function - // call they cannot place rather than erroring. If the text clearly - // contains the macro but no node survived, fail loudly instead of - // emitting the unexpanded call into the generated SQL. - if stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen); err == nil && - strings.Contains(stmtSQL, "sqlc.switch") { - return nil, fmt.Errorf("sqlc.switch() is not supported in this position for this engine") - } - return nil, nil - } - if len(found.Items) > 1 { - return nil, fmt.Errorf("only one sqlc.switch() per query is supported") - } - call := found.Items[0].(*ast.FuncCall) - - // sqlc.switch() is only allowed where it does not change the result shape - // (WHERE, ORDER BY, ...), never in the SELECT projection: different branches - // there could produce different columns and break the shared row type. - if sel, ok := raw.Stmt.(*ast.SelectStmt); ok && sel.TargetList != nil { - inTarget := astutils.Search(sel.TargetList, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) - if len(inTarget.Items) > 0 { - return nil, fmt.Errorf("sqlc.switch() is not allowed in the SELECT list; use it in WHERE or ORDER BY") - } - } +// parsedSwitch is one sqlc.switch(...) call: its byte span within the statement +// text and its branches in declaration order. +type parsedSwitch struct { + start, end int + branches []sqlcSwitchBranch +} +// switchBranches parses the when()/else() branches of a single sqlc.switch call. +func switchBranches(call *ast.FuncCall) ([]sqlcSwitchBranch, error) { if call.Args == nil || len(call.Args.Items) < 2 { return nil, fmt.Errorf("sqlc.switch() requires a selector and at least one sqlc.when()/sqlc.else() branch") } - - // args[0] is the runtime selector (e.g. @sort). It plays no role in the - // generated code (one function per branch, named by branch key) but is - // required so the intent is explicit. + // args[0] is the runtime selector (e.g. @sort). It is not used by the + // generated code (one function per branch, named by key) but is required so + // the intent is explicit. branches := make([]sqlcSwitchBranch, 0, len(call.Args.Items)-1) seenElse := false for _, arg := range call.Args.Items[1:] { @@ -176,44 +146,138 @@ func (c *Compiler) expandSqlcSwitch(raw *ast.RawStmt, src string) ([]string, err return nil, fmt.Errorf("sqlc.switch() branches must be sqlc.when() or sqlc.else() calls") } } + return branches, nil +} - // Locate the byte span of the whole sqlc.switch(...) call within its - // statement so it can be replaced textually with each branch fragment. - stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen) - if err != nil { - return nil, err +// expandSqlcSwitch looks for sqlc.switch(...) macros in a statement and, if +// present, returns one rewritten SQL string per branch key. In each variant +// every sqlc.switch(...) call is replaced by that key's SQL fragment and the +// "-- name:" comment is renamed to . Each variant is +// re-parsed and analyzed as an ordinary query, so a bad column reference in a +// fragment is a compile error and a generated name that collides with another +// query is caught by the normal duplicate-query-name check. +// +// A query may contain several sqlc.switch() calls (e.g. the same sort applied in +// a CTE pre-sort and in the final ORDER BY). They must all declare the same set +// of keys; expansion stays linear in the number of keys (one function per key), +// not the cross product, with each call contributing its own fragment per key. +// +// The macro is recognized from the AST the same way sqlc.arg/sqlc.slice are +// (a FuncCall with schema "sqlc"), so it works in exactly the clauses where +// those macros parse. Returning (nil, nil) means there is no sqlc.switch and the +// statement should be compiled as-is. +func (c *Compiler) expandSqlcSwitch(raw *ast.RawStmt, src string) ([]string, string, error) { + found := astutils.Search(raw, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(found.Items) == 0 { + // Some parsers (e.g. SQLite for ORDER BY) silently discard a function + // call they cannot place rather than erroring. If the text clearly + // contains the macro but no node survived, fail loudly instead of + // emitting the unexpanded call into the generated SQL. + if stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen); err == nil && + strings.Contains(stmtSQL, "sqlc.switch") { + return nil, "", fmt.Errorf("sqlc.switch() is not supported in this position for this engine") + } + return nil, "", nil } - switchStart := call.Location - raw.StmtLocation - if switchStart < 0 || switchStart >= len(stmtSQL) { - return nil, fmt.Errorf("could not locate sqlc.switch() in source") + + // sqlc.switch() is only allowed where it does not change the result shape + // (WHERE, ORDER BY, ...), never in the SELECT projection: different branches + // there could produce different columns and break the shared row type. + if sel, ok := raw.Stmt.(*ast.SelectStmt); ok && sel.TargetList != nil { + inTarget := astutils.Search(sel.TargetList, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(inTarget.Items) > 0 { + return nil, "", fmt.Errorf("sqlc.switch() is not allowed in the SELECT list; use it in WHERE or ORDER BY") + } } - switchEnd, err := matchParen(stmtSQL, switchStart) + + stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen) if err != nil { - return nil, err + return nil, "", err } - name, _, err := metadata.ParseQueryNameAndType(stmtSQL, metadata.CommentSyntax(c.parser.CommentSyntax())) if err != nil { - return nil, err + return nil, "", err } if name == "" { - return nil, fmt.Errorf("sqlc.switch() requires the query to have a -- name: annotation") + return nil, "", fmt.Errorf("sqlc.switch() requires the query to have a -- name: annotation") } - variants := make([]string, 0, len(branches)) - for _, br := range branches { - spliced := stmtSQL[:switchStart] + br.fragment + stmtSQL[switchEnd+1:] + // Parse every switch and locate its byte span. + switches := make([]parsedSwitch, 0, len(found.Items)) + for _, item := range found.Items { + call := item.(*ast.FuncCall) + branches, err := switchBranches(call) + if err != nil { + return nil, "", err + } + start := call.Location - raw.StmtLocation + if start < 0 || start >= len(stmtSQL) { + return nil, "", fmt.Errorf("could not locate sqlc.switch() in source") + } + end, err := matchParen(stmtSQL, start) + if err != nil { + return nil, "", err + } + switches = append(switches, parsedSwitch{start: start, end: end, branches: branches}) + } + + // All switches must use the same keys; the first one fixes the order. + canonical := switches[0].branches + for _, sw := range switches[1:] { + if !sameKeys(canonical, sw.branches) { + return nil, "", fmt.Errorf("all sqlc.switch() in a query must use the same when()/else() keys") + } + } + + // Apply switch spans right-to-left so earlier (leftward) spans keep their + // original offsets while later ones are replaced. + ordered := append([]parsedSwitch(nil), switches...) + sort.Slice(ordered, func(i, j int) bool { return ordered[i].start > ordered[j].start }) + + variants := make([]string, 0, len(canonical)) + for _, cb := range canonical { + spliced := stmtSQL + for _, sw := range ordered { + spliced = spliced[:sw.start] + fragmentForKey(sw.branches, cb.key) + spliced[sw.end+1:] + } // The plucked statement may exclude its trailing ";" (it can fall // outside StmtLen), so re-parsing a branch without one could yield an // empty statement. Normalize to exactly one terminator. spliced = strings.TrimRight(spliced, " \t\r\n;") + ";" - newName := name + camelize(br.key) + newName := name + camelize(cb.key) // Rename only the name comment. "name: " is shared by all comment // styles (-- /* #), so a single replacement is enough. spliced = strings.Replace(spliced, "name: "+name, "name: "+newName, 1) variants = append(variants, spliced) } - return variants, nil + return variants, name, nil +} + +// sameKeys reports whether two branch lists declare the same set of keys. +func sameKeys(a, b []sqlcSwitchBranch) bool { + if len(a) != len(b) { + return false + } + set := make(map[string]bool, len(a)) + for _, br := range a { + set[br.key] = true + } + for _, br := range b { + if !set[br.key] { + return false + } + } + return true +} + +// fragmentForKey returns the SQL fragment a switch declares for the given key. +func fragmentForKey(branches []sqlcSwitchBranch, key string) string { + for _, br := range branches { + if br.key == key { + return br.fragment + } + } + return "" } // matchParen returns the index of the ')' that closes the first '(' at or after diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 8f63624d2c..cf9117e6c4 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -24,6 +24,11 @@ type Metadata struct { RuleSkiplist map[string]struct{} Filename string + + // SwitchGroup is set on queries generated by expanding a sqlc.switch() macro. + // All branches of the same macro share the value (the original query name), + // letting codegen give them a single shared Params/Row struct. + SwitchGroup string } const ( diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index 525ffc72ef..660ee1f3bd 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -816,6 +816,14 @@ type Query struct { Comments []string `protobuf:"bytes,6,rep,name=comments,proto3" json:"comments,omitempty"` Filename string `protobuf:"bytes,7,opt,name=filename,proto3" json:"filename,omitempty"` InsertIntoTable *Identifier `protobuf:"bytes,8,opt,name=insert_into_table,proto3" json:"insert_into_table,omitempty"` + SwitchGroup string `protobuf:"bytes,9,opt,name=switch_group,json=switchGroup,proto3" json:"switch_group,omitempty"` +} + +func (x *Query) GetSwitchGroup() string { + if x != nil { + return x.SwitchGroup + } + return "" } func (x *Query) Reset() { From 348b5da201614029989cfd3a7157df9093c0085f Mon Sep 17 00:00:00 2001 From: Celso Alexandre Date: Tue, 26 May 2026 20:28:09 -0300 Subject: [PATCH 3/5] fix(golang): reference shared struct fields as arg.Field in query body VariableForField used EmitStruct() to decide between arg.Field and an inlined bare name. A value whose struct is DefinedElsewhere (a shared sqlc.switch Params struct) still takes a single struct arg, so it must use arg.Field. Switch to v.Emit to match Pairs(). --- internal/codegen/golang/query.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index dc7bb74bfd..50eb4dedd3 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -266,7 +266,10 @@ func (v QueryValue) VariableForField(f Field) string { if !v.IsStruct() { return v.Name } - if !v.EmitStruct() { + // Use v.Emit (single struct param) rather than EmitStruct(): a value whose + // struct is DefinedElsewhere still receives a single struct arg, so fields + // are referenced as arg.Field, not as inlined bare names. + if !v.Emit { return toLowerCase(f.Name) } return v.Name + "." + f.Name From f7bf98b7d61457edd4320314efc57a545753bc03 Mon Sep 17 00:00:00 2001 From: fahmifan Date: Sat, 6 Jun 2026 11:37:47 +0700 Subject: [PATCH 4/5] fix(sqlite): populate SelectStmt.SortClause during AST conversion SQLite ORDER BY clauses were parsed but not propagated into SelectStmt.SortClause during AST conversion. This caused ORDER BY expressions, including sqlc.switch(...) and function calls such as upper(name), to be invisible to later AST visitors and macro recognition passes. Co-authored-by: OpenAI Codex --- .../sqlc_switch/sqlite/go/query.sql.go | 84 +++++++++++++++++++ .../testdata/sqlc_switch/sqlite/query.sql | 7 ++ internal/engine/sqlite/convert.go | 64 ++++++++++---- 3 files changed, 140 insertions(+), 15 deletions(-) diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go index 893b01e848..8e41d3c353 100644 --- a/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go @@ -64,3 +64,87 @@ func (q *Queries) FindAuthorsNamed(ctx context.Context) ([]Author, error) { } return items, nil } + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql index 570e08759d..adf1674908 100644 --- a/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql @@ -3,3 +3,10 @@ SELECT * FROM authors WHERE sqlc.switch(@filter, sqlc.when('named', 'name IS NOT NULL'), sqlc.else( '1 = 1')); + +-- name: ListAuthors :many +SELECT id, name, created_at FROM authors +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index e9868f5be6..0c4c6aa9da 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -514,6 +514,13 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset + + if n.Order_by_stmt() != nil { + if sortClause, ok := c.convert(n.Order_by_stmt()).(*ast.List); ok { + selectStmt.SortClause = sortClause + } + } + // Only set WithClause if there are CTEs if len(ctes.Items) > 0 { selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} @@ -622,21 +629,48 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef } func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node { - if orderBy, ok := n.(*parser.Order_by_stmtContext); ok { - list := &ast.List{Items: []ast.Node{}} - for _, o := range orderBy.AllOrdering_term() { - term, ok := o.(*parser.Ordering_termContext) - if !ok { - continue - } - list.Items = append(list.Items, &ast.CaseExpr{ - Xpr: c.convert(term.Expr()), - Location: term.Expr().GetStart().GetStart(), - }) + orderBy, ok := n.(*parser.Order_by_stmtContext) + if !ok || orderBy == nil { + return &ast.List{} + } + + list := &ast.List{Items: []ast.Node{}} + for _, o := range orderBy.AllOrdering_term() { + term, ok := o.(*parser.Ordering_termContext) + if !ok { + continue } - return list + list.Items = append(list.Items, c.convertOrderingTerm(term)) + } + + return list +} + +func (c *cc) convertOrderingTerm(term *parser.Ordering_termContext) *ast.SortBy { + sortByDir := ast.SortByDirDefault + if ad := term.Asc_desc(); ad != nil { + if ad.ASC_() != nil { + sortByDir = ast.SortByDirAsc + } else { + sortByDir = ast.SortByDirDesc + } + } + + sortByNulls := ast.SortByNullsDefault + if term.NULLS_() != nil { + if term.FIRST_() != nil { + sortByNulls = ast.SortByNullsFirst + } else { + sortByNulls = ast.SortByNullsLast + } + } + + return &ast.SortBy{ + Node: c.convert(term.Expr()), + SortbyDir: sortByDir, + SortbyNulls: sortByNulls, + UseOp: &ast.List{}, } - return todo("convertOrderby_stmtContext", n) } func (c *cc) convertLimit_stmtContext(n parser.ILimit_stmtContext) (ast.Node, ast.Node) { @@ -826,7 +860,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.MINUS() != nil { // Negative number: -expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, Rexpr: expr, } } @@ -837,7 +871,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.TILDE() != nil { // Bitwise NOT: ~expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, Rexpr: expr, } } From 3142c2614cf8960714c5ad6b64155aaf1ac87ddd Mon Sep 17 00:00:00 2001 From: Celso Alexandre Date: Fri, 12 Jun 2026 17:39:27 -0300 Subject: [PATCH 5/5] fix(compiler): allow sqlite implicit rowid columns in strict ORDER BY validation Populating SelectStmt.SortClause for SQLite (006f4e8bd) made the strict_order_by validation apply to SQLite queries for the first time. That broke queries ordering by rowid, _rowid_, or oid, which exist on most SQLite tables but never appear in the declared schema. Skip validation for those implicit names on the sqlite engine only; typos in regular column names still fail. Co-Authored-By: Claude Fable 5 --- internal/compiler/output_columns.go | 23 +++++++ .../testdata/order_by_rowid/sqlite/go/db.go | 31 +++++++++ .../order_by_rowid/sqlite/go/models.go | 10 +++ .../order_by_rowid/sqlite/go/query.sql.go | 64 +++++++++++++++++++ .../testdata/order_by_rowid/sqlite/query.sql | 5 ++ .../testdata/order_by_rowid/sqlite/schema.sql | 4 ++ .../testdata/order_by_rowid/sqlite/sqlc.json | 12 ++++ 7 files changed, 149 insertions(+) create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/query.sql create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql create mode 100644 internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index dbd486359a..8054984ae1 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" "github.com/sqlc-dev/sqlc/internal/sql/catalog" @@ -85,6 +86,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er if !ok { continue } + if c.conf.Engine == config.EngineSQLite && isSQLiteImplicitColumn(sb.Node) { + continue + } if err := findColumnForNode(sb.Node, tables, targets); err != nil { return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err) } @@ -714,6 +718,25 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) return cols, nil } +// isSQLiteImplicitColumn reports whether the node references one of SQLite's +// implicit rowid columns, which exist on most tables but never appear in the +// declared schema. +func isSQLiteImplicitColumn(node ast.Node) bool { + ref, ok := node.(*ast.ColumnRef) + if !ok { + return false + } + parts := stringSlice(ref.Fields) + if len(parts) == 0 { + return false + } + switch parts[len(parts)-1] { + case "rowid", "_rowid_", "oid": + return true + } + return false +} + func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) error { ref, ok := item.(*ast.ColumnRef) if !ok { diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go new file mode 100644 index 0000000000..b334d0651a --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Author struct { + ID int64 + Name string +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go new file mode 100644 index 0000000000..7ded692519 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsByQualifiedRowid = `-- name: ListAuthorsByQualifiedRowid :many +SELECT id, name FROM authors ORDER BY authors._rowid_ DESC +` + +func (q *Queries) ListAuthorsByQualifiedRowid(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsByQualifiedRowid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsByRowid = `-- name: ListAuthorsByRowid :many +SELECT id, name FROM authors ORDER BY rowid +` + +func (q *Queries) ListAuthorsByRowid(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsByRowid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql b/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql new file mode 100644 index 0000000000..5723bead81 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql @@ -0,0 +1,5 @@ +-- name: ListAuthorsByRowid :many +SELECT id, name FROM authors ORDER BY rowid; + +-- name: ListAuthorsByQualifiedRowid :many +SELECT id, name FROM authors ORDER BY authors._rowid_ DESC; diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql b/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql new file mode 100644 index 0000000000..0f2208b2a6 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name text NOT NULL +); diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json b/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json new file mode 100644 index 0000000000..1f9d43df5d --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +}