diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 9230137ec9..ae157f9519 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -339,6 +339,30 @@ func parseMetadata(t string) (string, string, error) { return "", "", nil } +func validateCmd(n nodes.Node, name, cmd string) error { + // TODO: Convert cmd to an enum + if !(cmd == ":many" || cmd == ":one") { + return nil + } + var list nodes.List + switch stmt := n.(type) { + case nodes.SelectStmt: + return nil + case nodes.DeleteStmt: + list = stmt.ReturningList + case nodes.InsertStmt: + list = stmt.ReturningList + case nodes.UpdateStmt: + list = stmt.ReturningList + default: + return nil + } + if len(list.Items) == 0 { + return fmt.Errorf("query %q specifies parameter %q without containing a RETURNING clause", name, cmd) + } + return nil +} + func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { if err := validateParamRef(stmt); err != nil { return nil, err @@ -366,7 +390,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - + if err := validateCmd(raw.Stmt, name, cmd); err != nil { + return nil, err + } rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) params, err := resolveCatalogRefs(c, rvs, refs) diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index f92d35906a..59d865868c 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -785,6 +785,7 @@ func TestStarWalker(t *testing.T) { func TestInvalidQueries(t *testing.T) { for i, tc := range []struct { stmt string + msg string }{ { ` @@ -792,6 +793,7 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos SELECT id FROM foo; `, + "invalid query comment: -- name: ListFoos", }, { ` @@ -799,6 +801,7 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos :one :many SELECT id FROM foo; `, + "invalid query comment: -- name: ListFoos :one :many", }, { ` @@ -806,6 +809,31 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos :two SELECT id FROM foo; `, + "invalid query type: :two", + }, + { + ` + CREATE TABLE foo (id text not null); + -- name: DeleteFoo :one + DELETE FROM foo WHERE id = $1; + `, + `query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause`, + }, + { + ` + CREATE TABLE foo (id text not null); + -- name: UpdateFoo :one + UPDATE foo SET id = $2 WHERE id = $1; + `, + `query "UpdateFoo" specifies parameter ":one" without containing a RETURNING clause`, + }, + { + ` + CREATE TABLE foo (id text not null); + -- name: InsertFoo :one + INSERT INTO foo (id) VALUES ($1); + `, + `query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause`, }, } { test := tc @@ -814,6 +842,9 @@ func TestInvalidQueries(t *testing.T) { if err == nil { t.Errorf("expected err, got nil") } + if diff := cmp.Diff(test.msg, err.Error()); diff != "" { + t.Errorf("error message differs: \n%s", diff) + } }) } }