Skip to content

Commit 6da56ff

Browse files
authored
parser: Return error if missing RETURNING (#131)
* parser: Return error if missing RETURNING If a query specifies a return value (:one or :many), require that statements have a RETURNING clause. Fixes #121
1 parent 0f39706 commit 6da56ff

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

internal/dinosql/parser.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,30 @@ func parseMetadata(t string) (string, string, error) {
339339
return "", "", nil
340340
}
341341

342+
func validateCmd(n nodes.Node, name, cmd string) error {
343+
// TODO: Convert cmd to an enum
344+
if !(cmd == ":many" || cmd == ":one") {
345+
return nil
346+
}
347+
var list nodes.List
348+
switch stmt := n.(type) {
349+
case nodes.SelectStmt:
350+
return nil
351+
case nodes.DeleteStmt:
352+
list = stmt.ReturningList
353+
case nodes.InsertStmt:
354+
list = stmt.ReturningList
355+
case nodes.UpdateStmt:
356+
list = stmt.ReturningList
357+
default:
358+
return nil
359+
}
360+
if len(list.Items) == 0 {
361+
return fmt.Errorf("query %q specifies parameter %q without containing a RETURNING clause", name, cmd)
362+
}
363+
return nil
364+
}
365+
342366
func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) {
343367
if err := validateParamRef(stmt); err != nil {
344368
return nil, err
@@ -366,7 +390,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
366390
if err != nil {
367391
return nil, err
368392
}
369-
393+
if err := validateCmd(raw.Stmt, name, cmd); err != nil {
394+
return nil, err
395+
}
370396
rvs := rangeVars(raw.Stmt)
371397
refs := findParameters(raw.Stmt)
372398
params, err := resolveCatalogRefs(c, rvs, refs)

internal/dinosql/query_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,27 +785,55 @@ func TestStarWalker(t *testing.T) {
785785
func TestInvalidQueries(t *testing.T) {
786786
for i, tc := range []struct {
787787
stmt string
788+
msg string
788789
}{
789790
{
790791
`
791792
CREATE TABLE foo (id text not null);
792793
-- name: ListFoos
793794
SELECT id FROM foo;
794795
`,
796+
"invalid query comment: -- name: ListFoos",
795797
},
796798
{
797799
`
798800
CREATE TABLE foo (id text not null);
799801
-- name: ListFoos :one :many
800802
SELECT id FROM foo;
801803
`,
804+
"invalid query comment: -- name: ListFoos :one :many",
802805
},
803806
{
804807
`
805808
CREATE TABLE foo (id text not null);
806809
-- name: ListFoos :two
807810
SELECT id FROM foo;
808811
`,
812+
"invalid query type: :two",
813+
},
814+
{
815+
`
816+
CREATE TABLE foo (id text not null);
817+
-- name: DeleteFoo :one
818+
DELETE FROM foo WHERE id = $1;
819+
`,
820+
`query "DeleteFoo" specifies parameter ":one" without containing a RETURNING clause`,
821+
},
822+
{
823+
`
824+
CREATE TABLE foo (id text not null);
825+
-- name: UpdateFoo :one
826+
UPDATE foo SET id = $2 WHERE id = $1;
827+
`,
828+
`query "UpdateFoo" specifies parameter ":one" without containing a RETURNING clause`,
829+
},
830+
{
831+
`
832+
CREATE TABLE foo (id text not null);
833+
-- name: InsertFoo :one
834+
INSERT INTO foo (id) VALUES ($1);
835+
`,
836+
`query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause`,
809837
},
810838
} {
811839
test := tc
@@ -814,6 +842,9 @@ func TestInvalidQueries(t *testing.T) {
814842
if err == nil {
815843
t.Errorf("expected err, got nil")
816844
}
845+
if diff := cmp.Diff(test.msg, err.Error()); diff != "" {
846+
t.Errorf("error message differs: \n%s", diff)
847+
}
817848
})
818849
}
819850
}

0 commit comments

Comments
 (0)