Skip to content

Commit 12a6331

Browse files
committed
postgresql: Generate the functions in pg_catalog
Connecting to a default instance of PostgreSQL, generate the complete set of available functions
1 parent 64ac3a6 commit 12a6331

File tree

8 files changed

+33212
-165
lines changed

8 files changed

+33212
-165
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,8 @@ test:
77
sqlc-dev:
88
go build -o ~/bin/sqlc-dev ./cmd/sqlc/
99

10+
sqlc-pg-gen:
11+
go build -o ~/bin/sqlc-pg-gen ./internal/postgresql/cmd/gen
12+
1013
regen: sqlc-dev
1114
./scripts/regenerate.sh

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ require (
66
github.com/antlr/antlr4 v0.0.0-20200209180723-1177c0b58d07
77
github.com/davecgh/go-spew v1.1.1
88
github.com/google/go-cmp v0.3.0
9+
github.com/jackc/pgtype v1.3.0
10+
github.com/jackc/pgx/v4 v4.6.0
911
github.com/jinzhu/inflection v1.0.0
1012
github.com/lfittl/pg_query_go v1.0.0
1113
github.com/lib/pq v1.4.0

go.sum

Lines changed: 64 additions & 0 deletions
Large diffs are not rendered by default.

internal/compiler/go_type.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (r *Result) postgresType(col *Column, settings config.CombinedSettings) str
103103
}
104104
return "sql.NullString"
105105

106-
case "bool", "pg_catalog.bool":
106+
case "boolean", "bool", "pg_catalog.bool":
107107
if notNull {
108108
return "bool"
109109
}

internal/postgresql/cmd/gen/main.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"go/format"
8+
"log"
9+
"os"
10+
"strings"
11+
"text/template"
12+
13+
pgx "github.com/jackc/pgx/v4"
14+
15+
"github.com/kyleconroy/sqlc/internal/sql/ast"
16+
"github.com/kyleconroy/sqlc/internal/sql/catalog"
17+
)
18+
19+
// https://stackoverflow.com/questions/25308765/postgresql-how-can-i-inspect-which-arguments-to-a-procedure-have-a-default-valu
20+
const catalogFuncs = `
21+
SELECT p.proname as name,
22+
format_type(p.prorettype, NULL),
23+
array(select format_type(unnest(p.proargtypes), NULL)),
24+
p.proargnames,
25+
p.proargnames[p.pronargs-p.pronargdefaults+1:p.pronargs]
26+
FROM pg_catalog.pg_proc p
27+
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
28+
WHERE n.nspname OPERATOR(pg_catalog.~) '^(pg_catalog)$'
29+
AND p.proargmodes IS NULL
30+
AND pg_function_is_visible(p.oid)
31+
ORDER BY 1;
32+
`
33+
34+
const catalogTmpl = `
35+
package postgresql
36+
37+
import (
38+
"github.com/kyleconroy/sqlc/internal/sql/ast"
39+
"github.com/kyleconroy/sqlc/internal/sql/catalog"
40+
)
41+
42+
func pgCatalog() *catalog.Schema {
43+
s := &catalog.Schema{Name: "pg_catalog"}
44+
s.Funcs = []*catalog.Function{
45+
{{- range .}}
46+
{
47+
Name: "{{.Name}}",
48+
Args: []*catalog.Argument{
49+
{{range .Args}}{
50+
{{- if .Name}}
51+
Name: "{{.Name}}",
52+
{{- end}}
53+
{{- if .HasDefault}}
54+
HasDefault: true,
55+
{{- end}}
56+
Type: &ast.TypeName{Name: "{{.Type.Name}}"},
57+
},
58+
{{end}}
59+
},
60+
ReturnType: &ast.TypeName{Name: "{{.ReturnType.Name}}"},
61+
},
62+
{{- end}}
63+
}
64+
return s
65+
}
66+
`
67+
68+
func main() {
69+
if err := run(context.Background()); err != nil {
70+
log.Fatal(err)
71+
}
72+
}
73+
74+
type Proc struct {
75+
Name string
76+
ReturnType string
77+
ArgTypes []string
78+
ArgNames []string
79+
HasDefault []string
80+
}
81+
82+
func clean(arg string) string {
83+
arg = strings.TrimSpace(arg)
84+
arg = strings.Replace(arg, "\"any\"", "any", -1)
85+
arg = strings.Replace(arg, "\"char\"", "char", -1)
86+
arg = strings.Replace(arg, "\"timestamp\"", "char", -1)
87+
return arg
88+
}
89+
90+
func (p Proc) Func() catalog.Function {
91+
return catalog.Function{
92+
Name: p.Name,
93+
Args: p.Args(),
94+
ReturnType: &ast.TypeName{Name: clean(p.ReturnType)},
95+
}
96+
}
97+
98+
func (p Proc) Args() []*catalog.Argument {
99+
defaults := map[string]bool{}
100+
var args []*catalog.Argument
101+
if len(p.ArgTypes) == 0 {
102+
return args
103+
}
104+
for _, name := range p.HasDefault {
105+
defaults[name] = true
106+
}
107+
for i, arg := range p.ArgTypes {
108+
var name string
109+
if i < len(p.ArgNames) {
110+
name = p.ArgNames[i]
111+
}
112+
args = append(args, &catalog.Argument{
113+
Name: name,
114+
HasDefault: defaults[name],
115+
Type: &ast.TypeName{Name: clean(arg)},
116+
})
117+
}
118+
return args
119+
}
120+
121+
func run(ctx context.Context) error {
122+
tmpl, err := template.New("").Parse(catalogTmpl)
123+
if err != nil {
124+
return err
125+
}
126+
conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
127+
if err != nil {
128+
return err
129+
}
130+
defer conn.Close(ctx)
131+
132+
rows, err := conn.Query(ctx, catalogFuncs)
133+
if err != nil {
134+
return err
135+
}
136+
137+
defer rows.Close()
138+
139+
// Iterate through the result set
140+
var funcs []catalog.Function
141+
for rows.Next() {
142+
var p Proc
143+
err = rows.Scan(
144+
&p.Name,
145+
&p.ReturnType,
146+
&p.ArgTypes,
147+
&p.ArgNames,
148+
&p.HasDefault,
149+
)
150+
if err != nil {
151+
return err
152+
}
153+
154+
// TODO: Filter these out in SQL
155+
if strings.HasPrefix(p.ReturnType, "SETOF") {
156+
continue
157+
}
158+
159+
// The internal pseudo-type is used to declare functions that are meant
160+
// only to be called internally by the database system, and not by
161+
// direct invocation in an SQL query. If a function has at least one
162+
// internal-type argument then it cannot be called from SQL. To
163+
// preserve the type safety of this restriction it is important to
164+
// follow this coding rule: do not create any function that is declared
165+
// to return internal unless it has at least one internal argument
166+
//
167+
// https://www.postgresql.org/docs/current/datatype-pseudo.html
168+
for i := range p.ArgTypes {
169+
if p.ArgTypes[i] == "internal" {
170+
continue
171+
}
172+
}
173+
174+
funcs = append(funcs, p.Func())
175+
}
176+
177+
if rows.Err() != nil {
178+
return err
179+
}
180+
181+
out := bytes.NewBuffer([]byte{})
182+
if err := tmpl.Execute(out, funcs); err != nil {
183+
return err
184+
}
185+
code, err := format.Source(out.Bytes())
186+
if err != nil {
187+
return err
188+
}
189+
_, err = fmt.Fprintf(os.Stdout, string(code))
190+
return err
191+
}

0 commit comments

Comments
 (0)