Skip to content

Commit e41fe16

Browse files
author
Baroukh Ovadia
committed
Allow for mixing parameter styles
1 parent ba125cc commit e41fe16

File tree

9 files changed

+154
-45
lines changed

9 files changed

+154
-45
lines changed

internal/compiler/parse.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
3434
if o.Debug.DumpAST {
3535
debug.Dump(stmt)
3636
}
37-
if err := validate.ParamStyle(stmt); err != nil {
38-
return nil, err
39-
}
40-
if err := validate.ParamRef(stmt); err != nil {
37+
lastNumber, err := validate.ParamRef(stmt)
38+
if err != nil {
4139
return nil, err
4240
}
4341
raw, ok := stmt.(*ast.RawStmt)
@@ -75,7 +73,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
7573
return nil, err
7674
}
7775

78-
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw)
76+
raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, lastNumber)
7977
rvs := rangeVars(raw.Stmt)
8078
refs := findParameters(raw.Stmt)
8179
if o.UsePositionalParameters {

internal/endtoend/testdata/mix/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix/go/models.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/mix/go/test.sql.go

Lines changed: 75 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"name": "querytest",
7+
"schema": "test.sql",
8+
"queries": "test.sql"
9+
}
10+
]
11+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
CREATE TABLE bar (
2+
id serial not null,
3+
name text not null,
4+
phone text not null
5+
);
6+
7+
-- name: CountOne :one
8+
SELECT count(1) FROM bar WHERE id = $2 AND phone < @phone_param and name <> $1;
9+
10+
-- name: CountTwo :one
11+
SELECT count(1) FROM bar WHERE id = sqlc.arg(id_param) AND phone < @phone_param and name <> $1;
12+
13+
-- name: CountThree :one
14+
SELECT count(1) FROM bar WHERE id > sqlc.arg(id_param) AND name = $1;
15+
16+
-- name: CountFour :one
17+
SELECT count(1) FROM bar WHERE id > $2 AND phone <> @phone_param AND name <> $1;

internal/sql/rewrite/parameters.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ func isNamedParamSignCast(node ast.Node) bool {
4141
return astutils.Join(expr.Name, ".") == "@" && cast
4242
}
4343

44-
func NamedParameters(engine config.Engine, raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) {
44+
func NamedParameters(engine config.Engine, raw *ast.RawStmt, argn int) (*ast.RawStmt, map[int]string, []source.Edit) {
4545
foundFunc := astutils.Search(raw, named.IsParamFunc)
4646
foundSign := astutils.Search(raw, named.IsParamSign)
4747
if len(foundFunc.Items)+len(foundSign.Items) == 0 {
4848
return raw, map[int]string{}, nil
4949
}
5050

5151
args := map[string]int{}
52-
argn := 0
5352
var edits []source.Edit
5453
node := astutils.Apply(raw, func(cr *astutils.Cursor) bool {
5554
node := cr.Node()

internal/sql/validate/param_ref.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
99
)
1010

11-
func ParamRef(n ast.Node) error {
11+
func ParamRef(n ast.Node) (int, error) {
1212
var allrefs []*ast.ParamRef
1313

1414
// Find all parameter references
@@ -23,14 +23,17 @@ func ParamRef(n ast.Node) error {
2323
for _, r := range allrefs {
2424
seen[r.Number] = struct{}{}
2525
}
26-
26+
var max int
2727
for i := 1; i <= len(seen); i += 1 {
28+
if i > max {
29+
max = i
30+
}
2831
if _, ok := seen[i]; !ok {
29-
return &sqlerr.Error{
32+
return 0, &sqlerr.Error{
3033
Code: "42P18",
3134
Message: fmt.Sprintf("could not determine data type of parameter $%d", i),
3235
}
3336
}
3437
}
35-
return nil
38+
return max, nil
3639
}

internal/sql/validate/param_style.go

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)