Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/codegen/golang/opts/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
}

func (o *Override) MatchesColumn(col *plugin.Column) bool {
columnType := sdk.DataType(col.Type)
columnType := canonicalPostgreSQLType(sdk.DataType(col.Type))
overrideType := canonicalPostgreSQLType(o.DBType)
notNull := col.NotNull || col.IsArray
return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned
return o.DBType != "" && overrideType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned
}

func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
Expand Down
101 changes: 101 additions & 0 deletions internal/codegen/golang/opts/override_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package opts

import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"

"github.com/sqlc-dev/sqlc/internal/plugin"
)

func TestTypeOverrides(t *testing.T) {
Expand Down Expand Up @@ -100,6 +103,104 @@ func TestTypeOverrides(t *testing.T) {
}
}

func TestMatchesColumnTimestamptzAliases(t *testing.T) {
t.Parallel()

parseOverride := func(t *testing.T, dbType string, nullable bool) Override {
t.Helper()
o := Override{
DBType: dbType,
Nullable: nullable,
GoType: GoType{Spec: "*time.Time"},
}
if err := o.parse(nil); err != nil {
t.Fatalf("override parsing failed: %s", err)
}
return o
}

column := func(typeName string, nullable bool) *plugin.Column {
typ := &plugin.Identifier{Name: typeName}
if schema, name, ok := strings.Cut(typeName, "."); ok && schema == "pg_catalog" {
typ = &plugin.Identifier{Schema: schema, Name: name}
}
return &plugin.Column{
Type: typ,
NotNull: !nullable,
}
}

for _, test := range []struct {
name string
override Override
column *plugin.Column
wantMatch bool
}{
{
name: "timestamptz override matches timestamptz column",
override: parseOverride(t, "timestamptz", true),
column: column("timestamptz", true),
wantMatch: true,
},
{
name: "timestamptz override matches timestamp with time zone column",
override: parseOverride(t, "timestamptz", true),
column: column("timestamp with time zone", true),
wantMatch: true,
},
{
name: "timestamptz override matches pg_catalog.timestamptz column",
override: parseOverride(t, "timestamptz", true),
column: column("pg_catalog.timestamptz", true),
wantMatch: true,
},
{
name: "pg_catalog.timestamptz override matches timestamptz column",
override: parseOverride(t, "pg_catalog.timestamptz", true),
column: column("timestamptz", true),
wantMatch: true,
},
{
name: "pg_catalog.timestamptz override matches timestamp with time zone column",
override: parseOverride(t, "pg_catalog.timestamptz", true),
column: column("timestamp with time zone", true),
wantMatch: true,
},
{
name: "timestamp with time zone override matches timestamptz column",
override: parseOverride(t, "timestamp with time zone", true),
column: column("timestamptz", true),
wantMatch: true,
},
{
name: "timestamptz override does not match not-null column",
override: parseOverride(t, "timestamptz", true),
column: column("timestamptz", false),
wantMatch: false,
},
{
name: "timestamptz override does not match timestamp without time zone",
override: parseOverride(t, "timestamptz", true),
column: column("timestamp", true),
wantMatch: false,
},
{
name: "timestamptz override does not match timestamp without time zone long form",
override: parseOverride(t, "timestamptz", true),
column: column("timestamp without time zone", true),
wantMatch: false,
},
} {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got := test.override.MatchesColumn(test.column)
if got != test.wantMatch {
t.Errorf("MatchesColumn() = %v, want %v", got, test.wantMatch)
}
})
}
}

func FuzzOverride(f *testing.F) {
for _, spec := range []string{
"string",
Expand Down
32 changes: 32 additions & 0 deletions internal/codegen/golang/opts/pg_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package opts

var pgTypeCanonicalNames map[string]string

func init() {
groups := []struct {
canonical string
aliases []string
}{
{"pg_catalog.timestamptz", []string{"timestamptz", "timestamp with time zone"}},
{"pg_catalog.timestamp", []string{"timestamp", "timestamp without time zone"}},
{"pg_catalog.time", []string{"time", "time without time zone"}},
{"pg_catalog.timetz", []string{"timetz", "time with time zone"}},
}

pgTypeCanonicalNames = make(map[string]string, len(groups)*3)
for _, g := range groups {
pgTypeCanonicalNames[g.canonical] = g.canonical
for _, alias := range g.aliases {
pgTypeCanonicalNames[alias] = g.canonical
}
}
}

// canonicalPostgreSQLType maps PostgreSQL type aliases to a single canonical name
// so db_type overrides match regardless of spelling in schema SQL or config.
func canonicalPostgreSQLType(t string) string {
if canonical, ok := pgTypeCanonicalNames[t]; ok {
return canonical
}
return t
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- name: ListUsers :many
SELECT * FROM users;
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE users (
short_form timestamptz,
long_form timestamp with time zone
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "postgresql",
"sql_package": "pgx/v5",
"name": "querytest",
"schema": "schema.sql",
"queries": "query.sql"
}
],
"overrides": [
{
"db_type": "timestamptz",
"nullable": true,
"go_type": {
"import": "time",
"type": "Time",
"pointer": true
}
}
]
}
Loading