diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index 95b5c457a8..775f1c50e8 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -9,17 +9,21 @@ import ( ) // Given an AST node, return the string representation of names -func flatten(root nodes.Node) string { +func flatten(root nodes.Node) (string, bool) { sw := &stringWalker{} ast.Walk(sw, root) - return sw.String + return sw.String, sw.IsConst } type stringWalker struct { String string + IsConst bool } func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { + if _, ok := node.(nodes.A_Const); ok { + s.IsConst = true + } if n, ok := node.(nodes.String); ok { s.String += n.Str } @@ -61,7 +65,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ case isNamedParamFunc(node): fun := node.(nodes.FuncCall) - param := flatten(fun.Args) + param, isConst := flatten(fun.Args) if num, ok := args[param]; ok { cr.Replace(nodes.ParamRef{ Number: num, @@ -76,9 +80,15 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ }) } // TODO: This code assumes that sqlc.arg(name) is on a single line + var old string + if isConst { + old = fmt.Sprintf("sqlc.arg('%s')", param) + } else { + old = fmt.Sprintf("sqlc.arg(%s)", param) + } edits = append(edits, edit{ Location: fun.Location - raw.StmtLocation, - Old: fmt.Sprintf("sqlc.arg(%s)", param), + Old: old, New: fmt.Sprintf("$%d", args[param]), }) return false @@ -86,7 +96,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ case isNamedParamSignCast(node): expr := node.(nodes.A_Expr) cast := expr.Rexpr.(nodes.TypeCast) - param := flatten(cast.Arg) + param, _ := flatten(cast.Arg) if num, ok := args[param]; ok { cast.Arg = nodes.ParamRef{ Number: num, @@ -112,7 +122,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ case isNamedParamSign(node): expr := node.(nodes.A_Expr) - param := flatten(expr.Rexpr) + param, _ := flatten(expr.Rexpr) if num, ok := args[param]; ok { cr.Replace(nodes.ParamRef{ Number: num, diff --git a/internal/endtoend/testdata/named_param/query.sql b/internal/endtoend/testdata/named_param/query.sql index 76c4e1a4d4..0d4488d0e0 100644 --- a/internal/endtoend/testdata/named_param/query.sql +++ b/internal/endtoend/testdata/named_param/query.sql @@ -1,7 +1,7 @@ CREATE TABLE foo (name text not null, bio text not null); -- name: FuncParams :many -SELECT name FROM foo WHERE name = sqlc.arg(slug) AND sqlc.arg(filter)::bool; +SELECT name FROM foo WHERE name = sqlc.arg('slug') AND sqlc.arg(filter)::bool; -- name: AtParams :many SELECT name FROM foo WHERE name = @slug AND @filter::bool; diff --git a/placeholder.go b/placeholder.go new file mode 100644 index 0000000000..d38639dfc9 --- /dev/null +++ b/placeholder.go @@ -0,0 +1,5 @@ +package sqlc +// This is a dummy file that allows SQLC to be "installed" as a module and locked using +// go.mod and then run using "go run github.com/kyleconroy/sqlc" + +type Placeholder struct{}