From 21ef4093c944f27a21595a3d0c4a53a8dd59c207 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Mon, 18 Nov 2019 12:16:31 -0800 Subject: [PATCH 1/2] parser: Return error if missing RETURNING If a query specifies a return value (:one or :many), require that statements have a RETURNING clause. Fixes #121 --- internal/dinosql/parser.go | 28 +++++++++++++++++++++++++++- internal/dinosql/query_test.go | 7 +++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 9230137ec9..ae157f9519 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -339,6 +339,30 @@ func parseMetadata(t string) (string, string, error) { return "", "", nil } +func validateCmd(n nodes.Node, name, cmd string) error { + // TODO: Convert cmd to an enum + if !(cmd == ":many" || cmd == ":one") { + return nil + } + var list nodes.List + switch stmt := n.(type) { + case nodes.SelectStmt: + return nil + case nodes.DeleteStmt: + list = stmt.ReturningList + case nodes.InsertStmt: + list = stmt.ReturningList + case nodes.UpdateStmt: + list = stmt.ReturningList + default: + return nil + } + if len(list.Items) == 0 { + return fmt.Errorf("query %q specifies parameter %q without containing a RETURNING clause", name, cmd) + } + return nil +} + func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { if err := validateParamRef(stmt); err != nil { return nil, err @@ -366,7 +390,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - + if err := validateCmd(raw.Stmt, name, cmd); err != nil { + return nil, err + } rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) params, err := resolveCatalogRefs(c, rvs, refs) diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index f92d35906a..0b6f07cb53 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -807,6 +807,13 @@ func TestInvalidQueries(t *testing.T) { SELECT id FROM foo; `, }, + { + ` + CREATE TABLE foo (id text not null); + -- name: DeleteFoo :one + DELETE FROM foo WHERE id = $1; + `, + }, } { test := tc t.Run(strconv.Itoa(i), func(t *testing.T) { From 53ee2b725bad601d2663fa0da4da2e56894e05ac Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Mon, 18 Nov 2019 12:23:31 -0800 Subject: [PATCH 2/2] Add tests for update and insert --- internal/dinosql/query_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index 0b6f07cb53..59d865868c 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -785,6 +785,7 @@ func TestStarWalker(t *testing.T) { func TestInvalidQueries(t *testing.T) { for i, tc := range []struct { stmt string + msg string }{ { ` @@ -792,6 +793,7 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos SELECT id FROM foo; `, + "invalid query comment: -- name: ListFoos", }, { ` @@ -799,6 +801,7 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos :one :many SELECT id FROM foo; `, + "invalid query comment: -- name: ListFoos :one :many", }, { ` @@ -806,6 +809,7 @@ func TestInvalidQueries(t *testing.T) { -- name: ListFoos :two SELECT id FROM foo; `, + "invalid query type: :two", }, { ` @@ -813,6 +817,23 @@ func TestInvalidQueries(t *testing.T) { -- name: DeleteFoo :one DELETE FROM foo WHERE id = $1; `, + `query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause`, + }, + { + ` + CREATE TABLE foo (id text not null); + -- name: UpdateFoo :one + UPDATE foo SET id = $2 WHERE id = $1; + `, + `query "UpdateFoo" specifies parameter ":one" without containing a RETURNING clause`, + }, + { + ` + CREATE TABLE foo (id text not null); + -- name: InsertFoo :one + INSERT INTO foo (id) VALUES ($1); + `, + `query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause`, }, } { test := tc @@ -821,6 +842,9 @@ func TestInvalidQueries(t *testing.T) { if err == nil { t.Errorf("expected err, got nil") } + if diff := cmp.Diff(test.msg, err.Error()); diff != "" { + t.Errorf("error message differs: \n%s", diff) + } }) } }