diff --git a/docs/proposals/sqlc-switch.md b/docs/proposals/sqlc-switch.md new file mode 100644 index 0000000000..74287747b9 --- /dev/null +++ b/docs/proposals/sqlc-switch.md @@ -0,0 +1,163 @@ +# Proposal: `sqlc.switch` — bounded dynamic ORDER BY / WHERE via compile-time branch expansion + +Status: **draft** (evox-it fork) +Tracking upstream: discussions/364, issues #2061, #3414, #2060; PRs #4005, #4260, #2859 (all blocked on a canonical design) + +## Problem + +sqlc has no way to vary the *structure* of a query (sort order, filter shape) at +runtime. The community has asked for "dynamic queries" since 2020. The two +existing workarounds both fail: + +1. **`CASE WHEN` in `ORDER BY`/`WHERE`** — defeats the query planner. Postgres + will not use an index when the sort key is hidden inside a `CASE` expression. +2. **Runtime string interpolation** (`sqlc.raw`, `@filter::text`) — every such + proposal has been closed because it reintroduces SQL injection and breaks + sqlc's "type-safe, it's just SQL" guarantee. + +## Design goals (derived from maintainer objections) + +- **No SQL injection, ever.** User input must never reach the query string. +- **Schema-validated.** Every dynamic fragment must parse as real SQL and + reference real columns at compile time. Bad column = compile error. +- **Planner/index friendly.** The emitted SQL must be a clean static + `ORDER BY col DESC`, never a `CASE` wrapper. +- **Finite + enumerable.** The set of runtime choices is fixed at compile time. +- **Modeled on existing precedent** (`sqlc.slice`, the `sqlc.*` macro family). + +## Syntax + +```sql +-- name: ListAuthors :many +SELECT * FROM authors +WHERE deleted_at IS NULL +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC') +); +``` + +- `sqlc.switch(@selector, branches…)` sits where a value/expression is grammatically + legal (`ORDER BY` position, `WHERE` position). `@selector` is the runtime chooser. +- `sqlc.when('key', 'sql-fragment')` — `'key'` is the enum value; `'sql-fragment'` + is a **string literal** (must be — `ASC`/`DESC` are not valid inside a function + arg list in any engine grammar). The fragment is an author-authored compile-time + constant. +- `sqlc.else('sql-fragment')` — optional default branch. + +## Why string-literal fragments are still safe + +The fragment is a constant in the `.sql` file written by the developer, exactly +like the rest of the query. It is **not** runtime input. sqlc re-parses each +fragment in its grammatical context (e.g. `SELECT 1 FROM authors ORDER BY `) +and validates every column reference against the catalog. A typo or unknown +column fails `sqlc generate`. The only thing that varies at runtime is *which +already-validated branch* is chosen — a closed enum. Injection is structurally +impossible. + +## Codegen strategy: compile-time expansion (one function per branch) + +This follows Kyle Conroy's own `sqlc.switch()` suggestion in discussions/364 +("generate multiple optimized queries at compile time rather than runtime CASE"). + +A query containing `sqlc.switch` is **expanded by the compiler into N concrete +queries**, one per branch. Each clone has the whole `sqlc.switch(...)` call +replaced in the SQL by that branch's fragment. The resulting query strings are +fully static constants — no runtime `strings.Replace`, no markers. + +### Recognition is AST-based, like the other sqlc.* macros + +`sqlc.switch`/`when`/`else` are recognized exactly the way `sqlc.arg` and +`sqlc.slice` are: by searching the parsed AST for a `FuncCall` whose schema is +`sqlc` (`astutils.Search`). There is no bespoke SQL lexer. A consequence is that +the feature works in precisely the clauses where the engine parser produces such +a node — i.e. **wherever `sqlc.arg` works**: + +| Position | PostgreSQL | MySQL | SQLite | +|---|---|---|---| +| WHERE | ✅ | ✅ | ✅ | +| ORDER BY | ✅ | ✅ | ❌ parser drops the clause | + +SQLite's parser discards *any* function call in `ORDER BY` (true of plain +`ORDER BY upper(name)` too — see upstream PR #4429), so `sqlc.switch` there is a +compile error rather than silently emitting the unexpanded call. This is the +same limitation `sqlc.arg` has. + +Once recognized, the compiler replaces the `sqlc.switch(...)` text span with each +branch's fragment, renames the `-- name:` comment to ``, +and **re-parses each branch as an ordinary query**. Every branch therefore goes +through the normal parser + analyzer, so a bad column reference in a fragment is +a compile error, and the generated query strings are fully static constants — no +runtime markers, no `strings.Replace`. + +### Generated Go — v1 (implemented) + +One static function per branch: + +```go +const listAuthorsNameAsc = `SELECT ... ORDER BY authors.name ASC` +func (q *Queries) ListAuthorsNameAsc(ctx context.Context) ([]Author, error) { ... } + +const listAuthorsRecent = `SELECT ... ORDER BY authors.created_at DESC, authors.id DESC` +func (q *Queries) ListAuthorsRecent(ctx context.Context) ([]Author, error) { ... } + +const listAuthorsElse = `SELECT ... ORDER BY authors.id ASC` +func (q *Queries) ListAuthorsElse(ctx context.Context) ([]Author, error) { ... } +``` + +This is the whole upstreamable primitive: pure compile-time expansion in the +compiler, **zero codegen changes** (the branches are ordinary queries, so every +language's codegen gets them for free). It is the conservative core to propose +first. + +### Generated Go — v2 (proposed extension) + +A generated enum for the selector plus one exported dispatcher that switches on +it. This is a codegen convenience layered on top of v1; it adds codegen surface +(enum synthesis + dispatcher emission) and is best proposed as a follow-up once +the primitive is accepted: + +```go +type ListAuthorsSort string +const ( + ListAuthorsSortNameAsc ListAuthorsSort = "name_asc" + ListAuthorsSortRecent ListAuthorsSort = "recent" +) +func (q *Queries) ListAuthors(ctx context.Context, sort ListAuthorsSort) ([]Author, error) { + switch sort { + case ListAuthorsSortNameAsc: return q.ListAuthorsNameAsc(ctx) + case ListAuthorsSortRecent: return q.ListAuthorsRecent(ctx) + default: return q.ListAuthorsElse(ctx) + } +} +``` + +## Naming rules + +| Element | Source | Example | +|---|---|---| +| Branch fn | `` + camelize(key) | `ListAuthorsNameAsc` | +| `sqlc.else` fn | `` + `Else` | `ListAuthorsElse` | +| Enum type (v2) | `` + selector name | `ListAuthorsSort` | +| Enum const (v2) | enum type + camelize(key) | `ListAuthorsSortNameAsc` | +| Dispatcher (v2) | `` | `ListAuthors` | + +## v1 scope (implemented + tested) + +- Recognized in WHERE (all engines) and ORDER BY (PostgreSQL, MySQL); SQLite + ORDER BY is a clear compile error (parser limitation, parity with `sqlc.arg`). +- **Not allowed in the SELECT projection** — branches there could change the + result columns; rejected at compile time. +- One static function per branch. Enum + dispatcher (v2) are the proposed + follow-up. +- Engine-agnostic: the compiler emits ordinary queries, so all codegens benefit; + no codegen changes in v1. +- Golden end-to-end tests for PostgreSQL (stdlib + pgx), MySQL, and SQLite. + +## Open questions for upstream + +1. Dispatcher on by default, or opt-in via a query annotation? +2. Should `sqlc.else` be mandatory (compile error if a non-exhaustive switch is + possible) or optional (zero-value selector → else)? +3. Fragment validation depth: parse-only vs full type-check of the sort expr. diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 654500429a..7b456795e8 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -161,6 +161,7 @@ func pluginQueries(r *compiler.Result) []*plugin.Query { Params: params, Filename: q.Metadata.Filename, InsertIntoTable: iit, + SwitchGroup: q.Metadata.SwitchGroup, }) } return out diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 332fe2400a..6cd6c1a432 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -126,6 +126,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat if err != nil { return nil, err } + shareSwitchGroupStructs(queries, options) if options.OmitUnusedStructs { enums, structs = filterUnusedStructs(enums, structs, queries, options.ModelsTypeQualifier()) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 27c596c24e..50eb4dedd3 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -18,6 +18,12 @@ type QueryValue struct { Typ string SQLDriver opts.SQLDriver + // DefinedElsewhere suppresses emitting this value's struct definition + // because another query already emits an identical one (e.g. sqlc.switch + // branches that share a single Params/Row struct). The struct is still used + // by value in the method signature. + DefinedElsewhere bool + // ModelQualifier prefixes references to model types when the models file // lives in a different Go package (e.g. "model."). Empty otherwise. ModelQualifier string @@ -28,7 +34,7 @@ type QueryValue struct { } func (v QueryValue) EmitStruct() bool { - return v.Emit + return v.Emit && !v.DefinedElsewhere } func (v QueryValue) IsStruct() bool { @@ -62,7 +68,7 @@ func (v QueryValue) Pairs() []Argument { if v.isEmpty() { return nil } - if !v.EmitStruct() && v.IsStruct() { + if !v.Emit && v.IsStruct() { var out []Argument for _, f := range v.Struct.Fields { out = append(out, Argument{ @@ -260,7 +266,10 @@ func (v QueryValue) VariableForField(f Field) string { if !v.IsStruct() { return v.Name } - if !v.EmitStruct() { + // Use v.Emit (single struct param) rather than EmitStruct(): a value whose + // struct is DefinedElsewhere still receives a single struct arg, so fields + // are referenced as arg.Field, not as inlined bare names. + if !v.Emit { return toLowerCase(f.Name) } return v.Name + "." + f.Name @@ -279,6 +288,9 @@ type Query struct { Arg QueryValue // Used for :copyfrom Table *plugin.Identifier + // SwitchGroup links the branches expanded from one sqlc.switch() macro so + // they can share a single Params/Row struct. + SwitchGroup string } func (q Query) hasRetType() bool { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index c5126602da..a400563532 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -226,6 +226,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, enums []En SQL: query.Text, Comments: comments, Table: query.InsertIntoTable, + SwitchGroup: query.SwitchGroup, } sqlpkg := parseDriver(options.SqlPackage) @@ -471,3 +472,44 @@ func checkIncompatibleFieldTypes(fields []Field) error { } return nil } + +// shareSwitchGroupStructs makes the branch functions expanded from a single +// sqlc.switch() macro use one shared Params and Row struct instead of an +// identical copy per branch. All branches of a macro have the same parameters +// and result columns (only the spliced ORDER BY/WHERE fragment differs), so the +// per-query structs are byte-identical; this collapses them to one named after +// the original query (the SwitchGroup), emitted once. +func shareSwitchGroupStructs(queries []Query, options *opts.Options) { + groups := map[string][]int{} + for i := range queries { + if g := queries[i].SwitchGroup; g != "" { + groups[g] = append(groups[g], i) + } + } + for group, idx := range groups { + if len(idx) < 2 { + continue + } + canon := &queries[idx[0]] + + // Params: a single struct arg becomes shared. (Few-param queries inline + // their args instead of using a struct, so they are already identical.) + if canon.Arg.IsStruct() && canon.Arg.Emit { + canon.Arg.Struct.Name = StructName(group+"Params", options) + for _, i := range idx[1:] { + queries[i].Arg.Struct = canon.Arg.Struct + queries[i].Arg.DefinedElsewhere = true + } + } + + // Row: only when the branch built its own *Row struct. If it reused a + // table model, every branch already shares that model. + if canon.Ret.IsStruct() && canon.Ret.Emit { + canon.Ret.Struct.Name = StructName(group+"Row", options) + for _, i := range idx[1:] { + queries[i].Ret.Struct = canon.Ret.Struct + queries[i].Ret.Emit = false + } + } + } +} diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index b6bba42e16..966a6500f3 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -98,33 +98,44 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) { continue } for _, stmt := range stmts { - query, err := c.parseQuery(stmt.Raw, src, o) + // A statement may expand into several: a sqlc.case(...) macro + // produces one static query per branch. Without a macro this + // yields the original statement unchanged. + sources, err := c.statementSources(stmt.Raw, src) if err != nil { - var e *sqlerr.Error - loc := stmt.Raw.Pos() - if errors.As(err, &e) && e.Location != 0 { - loc = e.Location - } - merr.Add(filename, src, loc, err) - // If this rpc unauthenticated error bubbles up, then all future parsing/analysis will fail - if errors.Is(err, rpc.ErrUnauthenticated) { - return nil, merr - } - continue - } - if query == nil { + merr.Add(filename, src, stmt.Raw.Pos(), err) continue } - query.Metadata.Filename = filepath.Base(filename) - queryName := query.Metadata.Name - if queryName != "" { - if _, exists := set[queryName]; exists { - merr.Add(filename, src, stmt.Raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName)) + for _, ss := range sources { + query, err := c.parseQuery(ss.raw, ss.src, o) + if err != nil { + var e *sqlerr.Error + loc := ss.raw.Pos() + if errors.As(err, &e) && e.Location != 0 { + loc = e.Location + } + merr.Add(filename, ss.src, loc, err) + // If this rpc unauthenticated error bubbles up, then all future parsing/analysis will fail + if errors.Is(err, rpc.ErrUnauthenticated) { + return nil, merr + } continue } - set[queryName] = struct{}{} + if query == nil { + continue + } + query.Metadata.SwitchGroup = ss.group + query.Metadata.Filename = filepath.Base(filename) + queryName := query.Metadata.Name + if queryName != "" { + if _, exists := set[queryName]; exists { + merr.Add(filename, ss.src, ss.raw.Pos(), fmt.Errorf("duplicate query name: %s", queryName)) + continue + } + set[queryName] = struct{}{} + } + q = append(q, query) } - q = append(q, query) } } if len(merr.Errs()) > 0 { diff --git a/internal/compiler/expand_switch.go b/internal/compiler/expand_switch.go new file mode 100644 index 0000000000..0f9a1f5f80 --- /dev/null +++ b/internal/compiler/expand_switch.go @@ -0,0 +1,318 @@ +package compiler + +import ( + "fmt" + "sort" + "strings" + "unicode" + + "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/source" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" +) + +// sqlcSwitchBranch is a single branch of a sqlc.switch(...) macro: a key that +// names the generated query variant and a SQL fragment that is spliced into the +// query in place of the whole sqlc.switch(...) call. +type sqlcSwitchBranch struct { + key string // "name_asc", or "else" for the default branch + fragment string // author-authored SQL, e.g. "authors.name ASC" +} + +// isSqlcFunc reports whether node is a call to sqlc.. +func isSqlcFunc(node ast.Node, name string) bool { + call, ok := node.(*ast.FuncCall) + if !ok || call.Func == nil { + return false + } + return call.Func.Schema == "sqlc" && call.Func.Name == name +} + +// stringConst extracts a string literal value from an A_Const node. +func stringConst(node ast.Node) (string, bool) { + c, ok := node.(*ast.A_Const) + if !ok { + return "", false + } + s, ok := c.Val.(*ast.String) + if !ok { + return "", false + } + return s.Str, true +} + +// camelize turns a branch key like "name_asc" into "NameAsc" so it can be +// appended to a query name and remain a valid Go identifier. +func camelize(s string) string { + var b strings.Builder + upper := true + for _, r := range s { + if r == '_' || r == '-' || r == ' ' { + upper = true + continue + } + if upper { + b.WriteRune(unicode.ToUpper(r)) + upper = false + } else { + b.WriteRune(r) + } + } + return b.String() +} + +// stmtSource pairs a statement to compile with the source text its byte +// locations are relative to. group is the sqlc.switch() group name (the +// original query name) for branch variants, empty otherwise. +type stmtSource struct { + raw *ast.RawStmt + src string + group string +} + +// statementSources returns the statements to compile for a single parsed +// statement. Normally that is just the statement itself; if it contains a +// sqlc.switch(...) macro, it is the re-parsed branch variants it expands into. +func (c *Compiler) statementSources(raw *ast.RawStmt, src string) ([]stmtSource, error) { + variants, group, err := c.expandSqlcSwitch(raw, src) + if err != nil { + return nil, err + } + if variants == nil { + return []stmtSource{{raw: raw, src: src}}, nil + } + var sources []stmtSource + for _, v := range variants { + stmts, err := c.parser.Parse(strings.NewReader(v)) + if err != nil { + return nil, err + } + for i := range stmts { + sources = append(sources, stmtSource{raw: stmts[i].Raw, src: v, group: group}) + } + } + return sources, nil +} + +// parsedSwitch is one sqlc.switch(...) call: its byte span within the statement +// text and its branches in declaration order. +type parsedSwitch struct { + start, end int + branches []sqlcSwitchBranch +} + +// switchBranches parses the when()/else() branches of a single sqlc.switch call. +func switchBranches(call *ast.FuncCall) ([]sqlcSwitchBranch, error) { + if call.Args == nil || len(call.Args.Items) < 2 { + return nil, fmt.Errorf("sqlc.switch() requires a selector and at least one sqlc.when()/sqlc.else() branch") + } + // args[0] is the runtime selector (e.g. @sort). It is not used by the + // generated code (one function per branch, named by key) but is required so + // the intent is explicit. + branches := make([]sqlcSwitchBranch, 0, len(call.Args.Items)-1) + seenElse := false + for _, arg := range call.Args.Items[1:] { + switch { + case isSqlcFunc(arg, "when"): + when := arg.(*ast.FuncCall) + if when.Args == nil || len(when.Args.Items) != 2 { + return nil, fmt.Errorf("sqlc.when() requires exactly 2 arguments: a key and a SQL fragment") + } + key, ok := stringConst(when.Args.Items[0]) + if !ok { + return nil, fmt.Errorf("sqlc.when() key must be a string literal") + } + frag, ok := stringConst(when.Args.Items[1]) + if !ok { + return nil, fmt.Errorf("sqlc.when() fragment must be a string literal") + } + branches = append(branches, sqlcSwitchBranch{key: key, fragment: frag}) + case isSqlcFunc(arg, "else"): + if seenElse { + return nil, fmt.Errorf("sqlc.switch() allows at most one sqlc.else()") + } + seenElse = true + els := arg.(*ast.FuncCall) + if els.Args == nil || len(els.Args.Items) != 1 { + return nil, fmt.Errorf("sqlc.else() requires exactly 1 argument: a SQL fragment") + } + frag, ok := stringConst(els.Args.Items[0]) + if !ok { + return nil, fmt.Errorf("sqlc.else() fragment must be a string literal") + } + branches = append(branches, sqlcSwitchBranch{key: "else", fragment: frag}) + default: + return nil, fmt.Errorf("sqlc.switch() branches must be sqlc.when() or sqlc.else() calls") + } + } + return branches, nil +} + +// expandSqlcSwitch looks for sqlc.switch(...) macros in a statement and, if +// present, returns one rewritten SQL string per branch key. In each variant +// every sqlc.switch(...) call is replaced by that key's SQL fragment and the +// "-- name:" comment is renamed to . Each variant is +// re-parsed and analyzed as an ordinary query, so a bad column reference in a +// fragment is a compile error and a generated name that collides with another +// query is caught by the normal duplicate-query-name check. +// +// A query may contain several sqlc.switch() calls (e.g. the same sort applied in +// a CTE pre-sort and in the final ORDER BY). They must all declare the same set +// of keys; expansion stays linear in the number of keys (one function per key), +// not the cross product, with each call contributing its own fragment per key. +// +// The macro is recognized from the AST the same way sqlc.arg/sqlc.slice are +// (a FuncCall with schema "sqlc"), so it works in exactly the clauses where +// those macros parse. Returning (nil, nil) means there is no sqlc.switch and the +// statement should be compiled as-is. +func (c *Compiler) expandSqlcSwitch(raw *ast.RawStmt, src string) ([]string, string, error) { + found := astutils.Search(raw, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(found.Items) == 0 { + // Some parsers (e.g. SQLite for ORDER BY) silently discard a function + // call they cannot place rather than erroring. If the text clearly + // contains the macro but no node survived, fail loudly instead of + // emitting the unexpanded call into the generated SQL. + if stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen); err == nil && + strings.Contains(stmtSQL, "sqlc.switch") { + return nil, "", fmt.Errorf("sqlc.switch() is not supported in this position for this engine") + } + return nil, "", nil + } + + // sqlc.switch() is only allowed where it does not change the result shape + // (WHERE, ORDER BY, ...), never in the SELECT projection: different branches + // there could produce different columns and break the shared row type. + if sel, ok := raw.Stmt.(*ast.SelectStmt); ok && sel.TargetList != nil { + inTarget := astutils.Search(sel.TargetList, func(n ast.Node) bool { return isSqlcFunc(n, "switch") }) + if len(inTarget.Items) > 0 { + return nil, "", fmt.Errorf("sqlc.switch() is not allowed in the SELECT list; use it in WHERE or ORDER BY") + } + } + + stmtSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen) + if err != nil { + return nil, "", err + } + name, _, err := metadata.ParseQueryNameAndType(stmtSQL, metadata.CommentSyntax(c.parser.CommentSyntax())) + if err != nil { + return nil, "", err + } + if name == "" { + return nil, "", fmt.Errorf("sqlc.switch() requires the query to have a -- name: annotation") + } + + // Parse every switch and locate its byte span. + switches := make([]parsedSwitch, 0, len(found.Items)) + for _, item := range found.Items { + call := item.(*ast.FuncCall) + branches, err := switchBranches(call) + if err != nil { + return nil, "", err + } + start := call.Location - raw.StmtLocation + if start < 0 || start >= len(stmtSQL) { + return nil, "", fmt.Errorf("could not locate sqlc.switch() in source") + } + end, err := matchParen(stmtSQL, start) + if err != nil { + return nil, "", err + } + switches = append(switches, parsedSwitch{start: start, end: end, branches: branches}) + } + + // All switches must use the same keys; the first one fixes the order. + canonical := switches[0].branches + for _, sw := range switches[1:] { + if !sameKeys(canonical, sw.branches) { + return nil, "", fmt.Errorf("all sqlc.switch() in a query must use the same when()/else() keys") + } + } + + // Apply switch spans right-to-left so earlier (leftward) spans keep their + // original offsets while later ones are replaced. + ordered := append([]parsedSwitch(nil), switches...) + sort.Slice(ordered, func(i, j int) bool { return ordered[i].start > ordered[j].start }) + + variants := make([]string, 0, len(canonical)) + for _, cb := range canonical { + spliced := stmtSQL + for _, sw := range ordered { + spliced = spliced[:sw.start] + fragmentForKey(sw.branches, cb.key) + spliced[sw.end+1:] + } + // The plucked statement may exclude its trailing ";" (it can fall + // outside StmtLen), so re-parsing a branch without one could yield an + // empty statement. Normalize to exactly one terminator. + spliced = strings.TrimRight(spliced, " \t\r\n;") + ";" + newName := name + camelize(cb.key) + // Rename only the name comment. "name: " is shared by all comment + // styles (-- /* #), so a single replacement is enough. + spliced = strings.Replace(spliced, "name: "+name, "name: "+newName, 1) + variants = append(variants, spliced) + } + return variants, name, nil +} + +// sameKeys reports whether two branch lists declare the same set of keys. +func sameKeys(a, b []sqlcSwitchBranch) bool { + if len(a) != len(b) { + return false + } + set := make(map[string]bool, len(a)) + for _, br := range a { + set[br.key] = true + } + for _, br := range b { + if !set[br.key] { + return false + } + } + return true +} + +// fragmentForKey returns the SQL fragment a switch declares for the given key. +func fragmentForKey(branches []sqlcSwitchBranch, key string) string { + for _, br := range branches { + if br.key == key { + return br.fragment + } + } + return "" +} + +// matchParen returns the index of the ')' that closes the first '(' at or after +// start in s, skipping single-quoted string literals so parentheses inside a +// fragment (e.g. coalesce(x, 0)) do not throw off the depth count. +func matchParen(s string, start int) (int, error) { + i := start + for i < len(s) && s[i] != '(' { + i++ + } + if i >= len(s) { + return 0, fmt.Errorf("could not locate opening parenthesis of sqlc.switch()") + } + depth := 0 + for ; i < len(s); i++ { + switch s[i] { + case '\'': + // Advance to the closing quote, honoring '' escapes. + for i++; i < len(s); i++ { + if s[i] == '\'' { + if i+1 < len(s) && s[i+1] == '\'' { + i++ + continue + } + break + } + } + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return i, nil + } + } + } + return 0, fmt.Errorf("could not locate closing parenthesis of sqlc.switch()") +} diff --git a/internal/compiler/expand_switch_test.go b/internal/compiler/expand_switch_test.go new file mode 100644 index 0000000000..151feaef30 --- /dev/null +++ b/internal/compiler/expand_switch_test.go @@ -0,0 +1,58 @@ +package compiler + +import "testing" + +func TestCamelize(t *testing.T) { + for _, tc := range []struct { + in string + want string + }{ + {"name_asc", "NameAsc"}, + {"recent", "Recent"}, + {"else", "Else"}, + {"created-at-desc", "CreatedAtDesc"}, + {"two words", "TwoWords"}, + {"", ""}, + } { + if got := camelize(tc.in); got != tc.want { + t.Errorf("camelize(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestMatchParen(t *testing.T) { + for _, tc := range []struct { + name string + in string + start int + want int + wantErr bool + }{ + {"simple", "f(x)", 0, 3, false}, + {"nested", "f(g(x), h(y))", 0, 12, false}, + {"string with parens", "f('a)b', x)", 0, 10, false}, + {"escaped quote in string", "f('a''b)', x)", 0, 12, false}, + {"fragment with call", "case('coalesce(x, 0) ASC')", 0, 25, false}, + {"unbalanced", "f(x", 0, 0, true}, + {"no open paren", "abc", 0, 0, true}, + } { + t.Run(tc.name, func(t *testing.T) { + got, err := matchParen(tc.in, tc.start) + if tc.wantErr { + if err == nil { + t.Fatalf("matchParen(%q) expected error, got %d", tc.in, got) + } + return + } + if err != nil { + t.Fatalf("matchParen(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Errorf("matchParen(%q) = %d, want %d", tc.in, got, tc.want) + } + if tc.in[got] != ')' { + t.Errorf("matchParen(%q) index %d is %q, not ')'", tc.in, got, tc.in[got]) + } + }) + } +} diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 6cc2567ebe..3dcfb30779 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" + "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" "github.com/sqlc-dev/sqlc/internal/sql/catalog" @@ -85,6 +86,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er if !ok { continue } + if c.conf.Engine == config.EngineSQLite && isSQLiteImplicitColumn(sb.Node) { + continue + } if err := findColumnForNode(sb.Node, tables, targets); err != nil { return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err) } @@ -713,6 +717,25 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) return cols, nil } +// isSQLiteImplicitColumn reports whether the node references one of SQLite's +// implicit rowid columns, which exist on most tables but never appear in the +// declared schema. +func isSQLiteImplicitColumn(node ast.Node) bool { + ref, ok := node.(*ast.ColumnRef) + if !ok { + return false + } + parts := stringSlice(ref.Fields) + if len(parts) == 0 { + return false + } + switch parts[len(parts)-1] { + case "rowid", "_rowid_", "oid": + return true + } + return false +} + func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) error { ref, ok := item.(*ast.ColumnRef) if !ok { diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go new file mode 100644 index 0000000000..b334d0651a --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Author struct { + ID int64 + Name string +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go b/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go new file mode 100644 index 0000000000..7ded692519 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/go/query.sql.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsByQualifiedRowid = `-- name: ListAuthorsByQualifiedRowid :many +SELECT id, name FROM authors ORDER BY authors._rowid_ DESC +` + +func (q *Queries) ListAuthorsByQualifiedRowid(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsByQualifiedRowid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsByRowid = `-- name: ListAuthorsByRowid :many +SELECT id, name FROM authors ORDER BY rowid +` + +func (q *Queries) ListAuthorsByRowid(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsByRowid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql b/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql new file mode 100644 index 0000000000..5723bead81 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/query.sql @@ -0,0 +1,5 @@ +-- name: ListAuthorsByRowid :many +SELECT id, name FROM authors ORDER BY rowid; + +-- name: ListAuthorsByQualifiedRowid :many +SELECT id, name FROM authors ORDER BY authors._rowid_ DESC; diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql b/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql new file mode 100644 index 0000000000..0f2208b2a6 --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name text NOT NULL +); diff --git a/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json b/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json new file mode 100644 index 0000000000..1f9d43df5d --- /dev/null +++ b/internal/endtoend/testdata/order_by_rowid/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go new file mode 100644 index 0000000000..f864877022 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "time" +) + +type Author struct { + ID int64 + Name string + CreatedAt time.Time +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go new file mode 100644 index 0000000000..aee79c120f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/go/query.sql.go @@ -0,0 +1,97 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = ? +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/query.sql b/internal/endtoend/testdata/sqlc_switch/mysql/query.sql new file mode 100644 index 0000000000..87c86717b7 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = ? +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql b/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql new file mode 100644 index 0000000000..98cfc6b3f4 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + name varchar(255) NOT NULL, + created_at datetime NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json new file mode 100644 index 0000000000..974aa9ff9e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..0057c62319 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..bb085de6f3 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type Author struct { + ID int64 + Name string + CreatedAt pgtype.Timestamptz +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..0962dfd25e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/go/query.sql.go @@ -0,0 +1,88 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.Query(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql new file mode 100644 index 0000000000..c9b8da1615 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = $1 +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql new file mode 100644 index 0000000000..f9f2d25ab8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + created_at timestamptz NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..d12b82a6c6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "sql_package": "pgx/v5", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..f864877022 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "time" +) + +type Author struct { + ID int64 + Name string + CreatedAt time.Time +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..948865d030 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,97 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +WHERE name = $1 +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context, name string) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent, name) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..c9b8da1615 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/query.sql @@ -0,0 +1,7 @@ +-- name: ListAuthors :many +SELECT * FROM authors +WHERE name = $1 +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..f9f2d25ab8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + created_at timestamptz NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..cd518671ac --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go new file mode 100644 index 0000000000..7a8fcad68e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Author struct { + ID int64 + Name string + CreatedAt string +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go new file mode 100644 index 0000000000..8e41d3c353 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/go/query.sql.go @@ -0,0 +1,150 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const findAuthorsElse = `-- name: FindAuthorsElse :many +SELECT id, name, created_at FROM authors +WHERE 1 = 1 +` + +func (q *Queries) FindAuthorsElse(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, findAuthorsElse) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const findAuthorsNamed = `-- name: FindAuthorsNamed :many +SELECT id, name, created_at FROM authors +WHERE name IS NOT NULL +` + +func (q *Queries) FindAuthorsNamed(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, findAuthorsNamed) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsElse = `-- name: ListAuthorsElse :many +SELECT id, name, created_at FROM authors +ORDER BY authors.id ASC +` + +func (q *Queries) ListAuthorsElse(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsElse) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsNameAsc = `-- name: ListAuthorsNameAsc :many +SELECT id, name, created_at FROM authors +ORDER BY authors.name ASC +` + +func (q *Queries) ListAuthorsNameAsc(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsNameAsc) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAuthorsRecent = `-- name: ListAuthorsRecent :many +SELECT id, name, created_at FROM authors +ORDER BY authors.created_at DESC, authors.id DESC +` + +func (q *Queries) ListAuthorsRecent(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthorsRecent) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.CreatedAt); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql new file mode 100644 index 0000000000..adf1674908 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/query.sql @@ -0,0 +1,12 @@ +-- name: FindAuthors :many +SELECT * FROM authors +WHERE sqlc.switch(@filter, + sqlc.when('named', 'name IS NOT NULL'), + sqlc.else( '1 = 1')); + +-- name: ListAuthors :many +SELECT id, name, created_at FROM authors +ORDER BY sqlc.switch(@sort, + sqlc.when('name_asc', 'authors.name ASC'), + sqlc.when('recent', 'authors.created_at DESC, authors.id DESC'), + sqlc.else( 'authors.id ASC')); diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql b/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql new file mode 100644 index 0000000000..9acd8d6107 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id INTEGER PRIMARY KEY, + name text NOT NULL, + created_at text NOT NULL +); diff --git a/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json b/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json new file mode 100644 index 0000000000..1f9d43df5d --- /dev/null +++ b/internal/endtoend/testdata/sqlc_switch/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "sqlite", + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index e9868f5be6..0c4c6aa9da 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -514,6 +514,13 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset + + if n.Order_by_stmt() != nil { + if sortClause, ok := c.convert(n.Order_by_stmt()).(*ast.List); ok { + selectStmt.SortClause = sortClause + } + } + // Only set WithClause if there are CTEs if len(ctes.Items) > 0 { selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} @@ -622,21 +629,48 @@ func (c *cc) convertWildCardField(n *parser.Result_columnContext) *ast.ColumnRef } func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node { - if orderBy, ok := n.(*parser.Order_by_stmtContext); ok { - list := &ast.List{Items: []ast.Node{}} - for _, o := range orderBy.AllOrdering_term() { - term, ok := o.(*parser.Ordering_termContext) - if !ok { - continue - } - list.Items = append(list.Items, &ast.CaseExpr{ - Xpr: c.convert(term.Expr()), - Location: term.Expr().GetStart().GetStart(), - }) + orderBy, ok := n.(*parser.Order_by_stmtContext) + if !ok || orderBy == nil { + return &ast.List{} + } + + list := &ast.List{Items: []ast.Node{}} + for _, o := range orderBy.AllOrdering_term() { + term, ok := o.(*parser.Ordering_termContext) + if !ok { + continue } - return list + list.Items = append(list.Items, c.convertOrderingTerm(term)) + } + + return list +} + +func (c *cc) convertOrderingTerm(term *parser.Ordering_termContext) *ast.SortBy { + sortByDir := ast.SortByDirDefault + if ad := term.Asc_desc(); ad != nil { + if ad.ASC_() != nil { + sortByDir = ast.SortByDirAsc + } else { + sortByDir = ast.SortByDirDesc + } + } + + sortByNulls := ast.SortByNullsDefault + if term.NULLS_() != nil { + if term.FIRST_() != nil { + sortByNulls = ast.SortByNullsFirst + } else { + sortByNulls = ast.SortByNullsLast + } + } + + return &ast.SortBy{ + Node: c.convert(term.Expr()), + SortbyDir: sortByDir, + SortbyNulls: sortByNulls, + UseOp: &ast.List{}, } - return todo("convertOrderby_stmtContext", n) } func (c *cc) convertLimit_stmtContext(n parser.ILimit_stmtContext) (ast.Node, ast.Node) { @@ -826,7 +860,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.MINUS() != nil { // Negative number: -expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, Rexpr: expr, } } @@ -837,7 +871,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.TILDE() != nil { // Bitwise NOT: ~expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, Rexpr: expr, } } diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 76ee992a7a..629f4ff75a 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -24,6 +24,11 @@ type Metadata struct { RuleSkiplist map[string]struct{} Filename string + + // SwitchGroup is set on queries generated by expanding a sqlc.switch() macro. + // All branches of the same macro share the value (the original query name), + // letting codegen give them a single shared Params/Row struct. + SwitchGroup string } const ( diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index 525ffc72ef..660ee1f3bd 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -816,6 +816,14 @@ type Query struct { Comments []string `protobuf:"bytes,6,rep,name=comments,proto3" json:"comments,omitempty"` Filename string `protobuf:"bytes,7,opt,name=filename,proto3" json:"filename,omitempty"` InsertIntoTable *Identifier `protobuf:"bytes,8,opt,name=insert_into_table,proto3" json:"insert_into_table,omitempty"` + SwitchGroup string `protobuf:"bytes,9,opt,name=switch_group,json=switchGroup,proto3" json:"switch_group,omitempty"` +} + +func (x *Query) GetSwitchGroup() string { + if x != nil { + return x.SwitchGroup + } + return "" } func (x *Query) Reset() {