From b4d0b94dcd4cb0d42f669018c53ebd4aa7c07d42 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 26 Sep 2021 15:10:23 -0700 Subject: [PATCH] fix(engine/mysql): Case-insensitive identifiers --- .../identifier_case_sensitivity/db/db.go | 29 +++++++ .../identifier_case_sensitivity/db/models.go | 13 ++++ .../db/query.sql.go | 76 +++++++++++++++++++ .../identifier_case_sensitivity/query.sql | 24 ++++++ .../identifier_case_sensitivity/sqlc.json | 11 +++ internal/engine/dolphin/convert.go | 39 ++++++---- internal/engine/dolphin/utils.go | 8 +- 7 files changed, 180 insertions(+), 20 deletions(-) create mode 100644 internal/endtoend/testdata/identifier_case_sensitivity/db/db.go create mode 100644 internal/endtoend/testdata/identifier_case_sensitivity/db/models.go create mode 100644 internal/endtoend/testdata/identifier_case_sensitivity/db/query.sql.go create mode 100644 internal/endtoend/testdata/identifier_case_sensitivity/query.sql create mode 100644 internal/endtoend/testdata/identifier_case_sensitivity/sqlc.json diff --git a/internal/endtoend/testdata/identifier_case_sensitivity/db/db.go b/internal/endtoend/testdata/identifier_case_sensitivity/db/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/identifier_case_sensitivity/db/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/identifier_case_sensitivity/db/models.go b/internal/endtoend/testdata/identifier_case_sensitivity/db/models.go new file mode 100644 index 0000000000..3bc48237ef --- /dev/null +++ b/internal/endtoend/testdata/identifier_case_sensitivity/db/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "database/sql" +) + +type Author struct { + ID int64 + Name string + Bio sql.NullString +} diff --git a/internal/endtoend/testdata/identifier_case_sensitivity/db/query.sql.go b/internal/endtoend/testdata/identifier_case_sensitivity/db/query.sql.go new file mode 100644 index 0000000000..0d05dd0c07 --- /dev/null +++ b/internal/endtoend/testdata/identifier_case_sensitivity/db/query.sql.go @@ -0,0 +1,76 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" + "database/sql" +) + +const createAuthor = `-- name: CreateAuthor :execresult +INSERT INTO Authors ( + Name, Bio +) VALUES ( + ?, ? +) +` + +type CreateAuthorParams struct { + Name string + Bio sql.NullString +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (sql.Result, error) { + return q.db.ExecContext(ctx, createAuthor, arg.Name, arg.Bio) +} + +const deleteAuthor = `-- name: DeleteAuthor :exec +DELETE FROM Authors +WHERE ID = ? +` + +func (q *Queries) DeleteAuthor(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAuthor, id) + return err +} + +const getAuthor = `-- name: GetAuthor :one +SELECT id, name, bio FROM Authors +WHERE ID = ? LIMIT 1 +` + +func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM Authors +ORDER BY Name +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/identifier_case_sensitivity/query.sql b/internal/endtoend/testdata/identifier_case_sensitivity/query.sql new file mode 100644 index 0000000000..ac19cf948a --- /dev/null +++ b/internal/endtoend/testdata/identifier_case_sensitivity/query.sql @@ -0,0 +1,24 @@ +CREATE TABLE Authors ( + ID BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, + Name text NOT NULL, + Bio text +); + +-- name: GetAuthor :one +SELECT * FROM Authors +WHERE ID = ? LIMIT 1; + +-- name: ListAuthors :many +SELECT * FROM Authors +ORDER BY Name; + +-- name: CreateAuthor :execresult +INSERT INTO Authors ( + Name, Bio +) VALUES ( + ?, ? +); + +-- name: DeleteAuthor :exec +DELETE FROM Authors +WHERE ID = ?; diff --git a/internal/endtoend/testdata/identifier_case_sensitivity/sqlc.json b/internal/endtoend/testdata/identifier_case_sensitivity/sqlc.json new file mode 100644 index 0000000000..72d8821559 --- /dev/null +++ b/internal/endtoend/testdata/identifier_case_sensitivity/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "db", + "engine": "mysql", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 0b7124c754..f6e5853969 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -25,6 +25,14 @@ func todo(n pcast.Node) *ast.TODO { return &ast.TODO{} } +func identifier(id string) string { + return strings.ToLower(id) +} + +func NewIdentifer(t string) *ast.String { + return &ast.String{Str: identifier(t)} +} + func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { alt := &ast.AlterTableStmt{ Table: parseTableName(n.Table), @@ -119,7 +127,7 @@ func (c *cc) convertAlterTableStmt(n *pcast.AlterTableStmt) ast.Node { } func (c *cc) convertAssignment(n *pcast.Assignment) *ast.ResTarget { - name := n.Column.Name.String() + name := identifier(n.Column.Name.String()) return &ast.ResTarget{ Name: &name, Val: c.convert(n.Expr), @@ -259,12 +267,12 @@ func (c *cc) convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { var items []ast.Node if schema := n.Name.Schema.String(); schema != "" { - items = append(items, &ast.String{Str: schema}) + items = append(items, NewIdentifer(schema)) } if table := n.Name.Table.String(); table != "" { - items = append(items, &ast.String{Str: table}) + items = append(items, NewIdentifer(table)) } - items = append(items, &ast.String{Str: n.Name.Name.String()}) + items = append(items, NewIdentifer(n.Name.Name.String())) return &ast.ColumnRef{ Fields: &ast.List{ Items: items, @@ -275,7 +283,7 @@ func (c *cc) convertColumnNameExpr(n *pcast.ColumnNameExpr) *ast.ColumnRef { func (c *cc) convertColumnNames(cols []*pcast.ColumnName) *ast.List { list := &ast.List{Items: []ast.Node{}} for i := range cols { - name := cols[i].Name.String() + name := identifier(cols[i].Name.String()) list.Items = append(list.Items, &ast.ResTarget{ Name: &name, }) @@ -344,9 +352,9 @@ func (c *cc) convertFuncCallExpr(n *pcast.FuncCallExpr) ast.Node { // TODO: Deprecate the usage of Funcname items := []ast.Node{} if schema != "" { - items = append(items, &ast.String{Str: schema}) + items = append(items, NewIdentifer(schema)) } - items = append(items, &ast.String{Str: name}) + items = append(items, NewIdentifer(name)) args := &ast.List{} for _, arg := range n.Args { @@ -432,7 +440,8 @@ func (c *cc) convertSelectField(n *pcast.SelectField) *ast.ResTarget { } var name *string if n.AsName.O != "" { - name = &n.AsName.O + asname := identifier(n.AsName.O) + name = &asname } return &ast.ResTarget{ // TODO: Populate Indirection field @@ -479,7 +488,7 @@ func (c *cc) convertCommonTableExpression(n *pcast.CommonTableExpression) *ast.C columns := &ast.List{} for _, col := range n.ColNameList { - columns.Items = append(columns.Items, &ast.String{Str: col.String()}) + columns.Items = append(columns.Items, NewIdentifer(col.String())) } return &ast.CommonTableExpr{ @@ -556,7 +565,7 @@ func (c *cc) convertValueExpr(n *driver.ValueExpr) *ast.A_Const { func (c *cc) convertWildCardField(n *pcast.WildCardField) *ast.ColumnRef { items := []ast.Node{} if t := n.Table.String(); t != "" { - items = append(items, &ast.String{Str: t}) + items = append(items, NewIdentifer(t)) } items = append(items, &ast.A_Star{}) @@ -579,9 +588,7 @@ func (c *cc) convertAggregateFuncExpr(n *pcast.AggregateFuncExpr) *ast.FuncCall }, Funcname: &ast.List{ Items: []ast.Node{ - &ast.String{ - Str: name, - }, + NewIdentifer(name), }, }, Args: &ast.List{}, @@ -740,7 +747,7 @@ func (c *cc) convertDropDatabaseStmt(n *pcast.DropDatabaseStmt) ast.Node { return &ast.DropSchemaStmt{ MissingOk: !n.IfExists, Schemas: []*ast.String{ - {Str: n.Name}, + NewIdentifer(n.Name), }, } } @@ -1138,8 +1145,8 @@ func (c *cc) convertSplitRegionStmt(n *pcast.SplitRegionStmt) ast.Node { } func (c *cc) convertTableName(n *pcast.TableName) *ast.RangeVar { - schema := n.Schema.String() - rel := n.Name.String() + schema := identifier(n.Schema.String()) + rel := identifier(n.Name.String()) return &ast.RangeVar{ Schemaname: &schema, Relname: &rel, diff --git a/internal/engine/dolphin/utils.go b/internal/engine/dolphin/utils.go index 01cc2fa9ff..ce866c0eda 100644 --- a/internal/engine/dolphin/utils.go +++ b/internal/engine/dolphin/utils.go @@ -66,8 +66,8 @@ func text(nodes []pcast.Node) []string { func parseTableName(n *pcast.TableName) *ast.TableName { return &ast.TableName{ - Schema: n.Schema.String(), - Name: n.Name.String(), + Schema: identifier(n.Schema.String()), + Name: identifier(n.Name.String()), } } @@ -76,9 +76,9 @@ func toList(node pcast.Node) *ast.List { switch n := node.(type) { case *pcast.TableName: if schema := n.Schema.String(); schema != "" { - items = append(items, &ast.String{Str: schema}) + items = append(items, NewIdentifer(schema)) } - items = append(items, &ast.String{Str: n.Name.String()}) + items = append(items, NewIdentifer(n.Name.String())) default: return nil }