diff --git a/internal/codegen/golang/opts/override.go b/internal/codegen/golang/opts/override.go index 6916c0c7f3..bdde72a78f 100644 --- a/internal/codegen/golang/opts/override.go +++ b/internal/codegen/golang/opts/override.go @@ -78,9 +78,10 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool { } func (o *Override) MatchesColumn(col *plugin.Column) bool { - columnType := sdk.DataType(col.Type) + columnType := canonicalPostgreSQLType(sdk.DataType(col.Type)) + overrideType := canonicalPostgreSQLType(o.DBType) notNull := col.NotNull || col.IsArray - return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned + return o.DBType != "" && overrideType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned } func (o *Override) parse(req *plugin.GenerateRequest) (err error) { diff --git a/internal/codegen/golang/opts/override_test.go b/internal/codegen/golang/opts/override_test.go index 8405666f36..b0439c7603 100644 --- a/internal/codegen/golang/opts/override_test.go +++ b/internal/codegen/golang/opts/override_test.go @@ -1,9 +1,12 @@ package opts import ( + "strings" "testing" "github.com/google/go-cmp/cmp" + + "github.com/sqlc-dev/sqlc/internal/plugin" ) func TestTypeOverrides(t *testing.T) { @@ -100,6 +103,104 @@ func TestTypeOverrides(t *testing.T) { } } +func TestMatchesColumnTimestamptzAliases(t *testing.T) { + t.Parallel() + + parseOverride := func(t *testing.T, dbType string, nullable bool) Override { + t.Helper() + o := Override{ + DBType: dbType, + Nullable: nullable, + GoType: GoType{Spec: "*time.Time"}, + } + if err := o.parse(nil); err != nil { + t.Fatalf("override parsing failed: %s", err) + } + return o + } + + column := func(typeName string, nullable bool) *plugin.Column { + typ := &plugin.Identifier{Name: typeName} + if schema, name, ok := strings.Cut(typeName, "."); ok && schema == "pg_catalog" { + typ = &plugin.Identifier{Schema: schema, Name: name} + } + return &plugin.Column{ + Type: typ, + NotNull: !nullable, + } + } + + for _, test := range []struct { + name string + override Override + column *plugin.Column + wantMatch bool + }{ + { + name: "timestamptz override matches timestamptz column", + override: parseOverride(t, "timestamptz", true), + column: column("timestamptz", true), + wantMatch: true, + }, + { + name: "timestamptz override matches timestamp with time zone column", + override: parseOverride(t, "timestamptz", true), + column: column("timestamp with time zone", true), + wantMatch: true, + }, + { + name: "timestamptz override matches pg_catalog.timestamptz column", + override: parseOverride(t, "timestamptz", true), + column: column("pg_catalog.timestamptz", true), + wantMatch: true, + }, + { + name: "pg_catalog.timestamptz override matches timestamptz column", + override: parseOverride(t, "pg_catalog.timestamptz", true), + column: column("timestamptz", true), + wantMatch: true, + }, + { + name: "pg_catalog.timestamptz override matches timestamp with time zone column", + override: parseOverride(t, "pg_catalog.timestamptz", true), + column: column("timestamp with time zone", true), + wantMatch: true, + }, + { + name: "timestamp with time zone override matches timestamptz column", + override: parseOverride(t, "timestamp with time zone", true), + column: column("timestamptz", true), + wantMatch: true, + }, + { + name: "timestamptz override does not match not-null column", + override: parseOverride(t, "timestamptz", true), + column: column("timestamptz", false), + wantMatch: false, + }, + { + name: "timestamptz override does not match timestamp without time zone", + override: parseOverride(t, "timestamptz", true), + column: column("timestamp", true), + wantMatch: false, + }, + { + name: "timestamptz override does not match timestamp without time zone long form", + override: parseOverride(t, "timestamptz", true), + column: column("timestamp without time zone", true), + wantMatch: false, + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + got := test.override.MatchesColumn(test.column) + if got != test.wantMatch { + t.Errorf("MatchesColumn() = %v, want %v", got, test.wantMatch) + } + }) + } +} + func FuzzOverride(f *testing.F) { for _, spec := range []string{ "string", diff --git a/internal/codegen/golang/opts/pg_type.go b/internal/codegen/golang/opts/pg_type.go new file mode 100644 index 0000000000..ac8845ac24 --- /dev/null +++ b/internal/codegen/golang/opts/pg_type.go @@ -0,0 +1,32 @@ +package opts + +var pgTypeCanonicalNames map[string]string + +func init() { + groups := []struct { + canonical string + aliases []string + }{ + {"pg_catalog.timestamptz", []string{"timestamptz", "timestamp with time zone"}}, + {"pg_catalog.timestamp", []string{"timestamp", "timestamp without time zone"}}, + {"pg_catalog.time", []string{"time", "time without time zone"}}, + {"pg_catalog.timetz", []string{"timetz", "time with time zone"}}, + } + + pgTypeCanonicalNames = make(map[string]string, len(groups)*3) + for _, g := range groups { + pgTypeCanonicalNames[g.canonical] = g.canonical + for _, alias := range g.aliases { + pgTypeCanonicalNames[alias] = g.canonical + } + } +} + +// canonicalPostgreSQLType maps PostgreSQL type aliases to a single canonical name +// so db_type overrides match regardless of spelling in schema SQL or config. +func canonicalPostgreSQLType(t string) string { + if canonical, ok := pgTypeCanonicalNames[t]; ok { + return canonical + } + return t +} diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/db.go b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/db.go new file mode 100644 index 0000000000..0057c62319 --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/models.go b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/models.go new file mode 100644 index 0000000000..0290897229 --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "time" +) + +type User struct { + ShortForm *time.Time + LongForm *time.Time +} diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/query.sql.go new file mode 100644 index 0000000000..588b15f774 --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/go/query.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const listUsers = `-- name: ListUsers :many +SELECT short_form, long_form FROM users +` + +func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, listUsers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan(&i.ShortForm, &i.LongForm); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/query.sql new file mode 100644 index 0000000000..dba5df7732 --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/query.sql @@ -0,0 +1,2 @@ +-- name: ListUsers :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/schema.sql b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/schema.sql new file mode 100644 index 0000000000..83d20e464d --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE users ( + short_form timestamptz, + long_form timestamp with time zone +); diff --git a/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/sqlc.json b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/sqlc.json new file mode 100644 index 0000000000..1383e444a3 --- /dev/null +++ b/internal/endtoend/testdata/overrides_timestamptz_alias/postgresql/pgx/v5/sqlc.json @@ -0,0 +1,24 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v5", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ], + "overrides": [ + { + "db_type": "timestamptz", + "nullable": true, + "go_type": { + "import": "time", + "type": "Time", + "pointer": true + } + } + ] +}