From 61759f704f4659266c333d1284d5635199157350 Mon Sep 17 00:00:00 2001 From: Ryan Berger Date: Thu, 19 May 2022 16:35:04 -0600 Subject: [PATCH 1/3] fix: validate sqlc.* function call arg count --- internal/sql/validate/param_style.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index 5e89601e03..eeb1b78c2c 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -1,6 +1,8 @@ package validate import ( + "fmt" + "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/named" @@ -14,9 +16,19 @@ import ( func ParamStyle(n ast.Node) error { namedFunc := astutils.Search(n, named.IsParamFunc) for _, f := range namedFunc.Items { - fc, ok := f.(*ast.FuncCall) - if ok { - switch val := fc.Args.Items[0].(type) { + if fc, ok := f.(*ast.FuncCall); ok { + args := fc.Args.Items + + if len(args) != 1 { + return &sqlerr.Error{ + Code: "", // TODO: Pick a new error code + Message: fmt.Sprintf( + "sqlc.arg() requires one argument, %d provided", + len(args)), + } + } + + switch val := args[0].(type) { case *ast.FuncCall: return &sqlerr.Error{ Code: "", // TODO: Pick a new error code From 4ed34908ca41e0a963c65c0910fb235a8b7dee35 Mon Sep 17 00:00:00 2001 From: Ryan Berger Date: Mon, 23 May 2022 07:50:53 -0600 Subject: [PATCH 2/3] remove nil/zero length check, skip bad FuncCalls in visitor --- internal/sql/validate/func_call.go | 6 ++---- internal/sql/validate/param_style.go | 11 ++--------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 5fbac048d2..8ba6d6359d 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -38,10 +38,8 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } - if call.Args == nil || len(call.Args.Items) == 0 { - return v - } - if len(call.Args.Items) > 1 { + + if len(call.Args.Items) != 1 { v.err = &sqlerr.Error{ Message: fmt.Sprintf("expected 1 parameter to sqlc.arg; got %d", len(call.Args.Items)), Location: call.Pos(), diff --git a/internal/sql/validate/param_style.go b/internal/sql/validate/param_style.go index eeb1b78c2c..48008122e8 100644 --- a/internal/sql/validate/param_style.go +++ b/internal/sql/validate/param_style.go @@ -1,8 +1,6 @@ package validate import ( - "fmt" - "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/named" @@ -19,13 +17,8 @@ func ParamStyle(n ast.Node) error { if fc, ok := f.(*ast.FuncCall); ok { args := fc.Args.Items - if len(args) != 1 { - return &sqlerr.Error{ - Code: "", // TODO: Pick a new error code - Message: fmt.Sprintf( - "sqlc.arg() requires one argument, %d provided", - len(args)), - } + if len(args) == 0 { + continue } switch val := args[0].(type) { From 8cf27ffca7db3166125f414ef420657dc04102a5 Mon Sep 17 00:00:00 2001 From: Ryan Berger Date: Mon, 23 May 2022 08:05:41 -0600 Subject: [PATCH 3/3] add integration test for zero argument sqlc function call --- internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql | 3 +++ internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt | 3 ++- .../endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql | 3 +++ .../endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt | 3 ++- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql index ae901531d2..3e46d7204b 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql +++ b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/query.sql @@ -9,6 +9,9 @@ select id, first_name from users where id = sqlc.argh(target_id); -- name: TooManyArgs :one select id, first_name from users where id = sqlc.arg('foo', 'bar'); +-- name: TooFewArgs :one +select id, first_name from users where id = sqlc.arg(); + -- name: InvalidArgFunc :one select id, first_name from users where id = sqlc.arg(sqlc.arg(target_id)); diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt index 3f07cbb5ef..8009988505 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/mysql/stderr.txt @@ -1,5 +1,6 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:13:45: expected 1 parameter to sqlc.arg; got 0 query.sql:16:54: Invalid argument to sqlc.arg() +query.sql:19:54: Invalid argument to sqlc.arg() diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql index 6d3044085c..fe25398cce 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql +++ b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/query.sql @@ -9,6 +9,9 @@ select id, first_name from users where id = sqlc.argh(target_id); -- name: TooManyArgs :one select id, first_name from users where id = sqlc.arg('foo', 'bar'); +-- name: TooFewArgs :one +select id, first_name from users where id = sqlc.arg(); + -- name: InvalidArgFunc :one select id, first_name from users where id = sqlc.arg(sqlc.arg(target_id)); diff --git a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt index 3f07cbb5ef..8009988505 100644 --- a/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt +++ b/internal/endtoend/testdata/sqlc_arg_invalid/postgresql/stderr.txt @@ -1,5 +1,6 @@ # package querytest query.sql:7:1: function "sqlc.argh" does not exist query.sql:10:45: expected 1 parameter to sqlc.arg; got 2 -query.sql:13:54: Invalid argument to sqlc.arg() +query.sql:13:45: expected 1 parameter to sqlc.arg; got 0 query.sql:16:54: Invalid argument to sqlc.arg() +query.sql:19:54: Invalid argument to sqlc.arg()