Skip to content

Commit 895f4e1

Browse files
authored
fix: Validate sqlc function arguments (#1633)
* fix: validate sqlc.* function call arg count * remove nil/zero length check, skip bad FuncCalls in visitor * add integration test for zero argument sqlc function call
1 parent dfe4386 commit 895f4e1

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ select id, first_name from users where id = sqlc.argh(target_id);
99
-- name: TooManyArgs :one
1010
select id, first_name from users where id = sqlc.arg('foo', 'bar');
1111

12+
-- name: TooFewArgs :one
13+
select id, first_name from users where id = sqlc.arg();
14+
1215
-- name: InvalidArgFunc :one
1316
select id, first_name from users where id = sqlc.arg(sqlc.arg(target_id));
1417

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# package querytest
22
query.sql:7:1: function "sqlc.argh" does not exist
33
query.sql:10:45: expected 1 parameter to sqlc.arg; got 2
4-
query.sql:13:54: Invalid argument to sqlc.arg()
4+
query.sql:13:45: expected 1 parameter to sqlc.arg; got 0
55
query.sql:16:54: Invalid argument to sqlc.arg()
6+
query.sql:19:54: Invalid argument to sqlc.arg()

internal/endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ select id, first_name from users where id = sqlc.argh(target_id);
99
-- name: TooManyArgs :one
1010
select id, first_name from users where id = sqlc.arg('foo', 'bar');
1111

12+
-- name: TooFewArgs :one
13+
select id, first_name from users where id = sqlc.arg();
14+
1215
-- name: InvalidArgFunc :one
1316
select id, first_name from users where id = sqlc.arg(sqlc.arg(target_id));
1417

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# package querytest
22
query.sql:7:1: function "sqlc.argh" does not exist
33
query.sql:10:45: expected 1 parameter to sqlc.arg; got 2
4-
query.sql:13:54: Invalid argument to sqlc.arg()
4+
query.sql:13:45: expected 1 parameter to sqlc.arg; got 0
55
query.sql:16:54: Invalid argument to sqlc.arg()
6+
query.sql:19:54: Invalid argument to sqlc.arg()

internal/sql/validate/func_call.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor {
3838
v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name)
3939
return nil
4040
}
41-
if call.Args == nil || len(call.Args.Items) == 0 {
42-
return v
43-
}
44-
if len(call.Args.Items) > 1 {
41+
42+
if len(call.Args.Items) != 1 {
4543
v.err = &sqlerr.Error{
4644
Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(call.Args.Items)),
4745
Location: call.Pos(),

internal/sql/validate/param_style.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@ import (
1414
func ParamStyle(n ast.Node) error {
1515
namedFunc := astutils.Search(n, named.IsParamFunc)
1616
for _, f := range namedFunc.Items {
17-
fc, ok := f.(*ast.FuncCall)
18-
if ok {
19-
switch val := fc.Args.Items[0].(type) {
17+
if fc, ok := f.(*ast.FuncCall); ok {
18+
args := fc.Args.Items
19+
20+
if len(args) == 0 {
21+
continue
22+
}
23+
24+
switch val := args[0].(type) {
2025
case *ast.FuncCall:
2126
return &sqlerr.Error{
2227
Code: "", // TODO: Pick a new error code

0 commit comments

Comments
 (0)