diff --git a/internal/compiler/go_type.go b/internal/compiler/go_type.go index e2426e9258..38db8e40e8 100644 --- a/internal/compiler/go_type.go +++ b/internal/compiler/go_type.go @@ -33,6 +33,23 @@ func (r *Result) goInnerType(col *Column, settings config.CombinedSettings) stri } } + // TODO: Extend the engine interface to handle types + switch settings.Package.Engine { + case config.EngineMySQL, config.EngineXDolphin: + return r.mysqlType(col, settings) + case config.EnginePostgreSQL: + return r.postgresType(col, settings) + case config.EngineXLemon: + return r.postgresType(col, settings) + default: + return "interface{}" + } +} + +func (r *Result) postgresType(col *Column, settings config.CombinedSettings) string { + columnType := col.DataType + notNull := col.NotNull || col.IsArray + switch columnType { case "serial", "pg_catalog.serial4": if notNull { @@ -186,3 +203,65 @@ func (r *Result) goInnerType(col *Column, settings config.CombinedSettings) stri return "interface{}" } } + +func (r *Result) mysqlType(col *Column, settings config.CombinedSettings) string { + columnType := col.DataType + notNull := col.NotNull || col.IsArray + + switch columnType { + + case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": + if notNull { + return "string" + } + return "sql.NullString" + + case "int", "integer", "smallint", "mediumint", "year": + if notNull { + return "int32" + } + return "sql.NullInt32" + + case "bigint": + if notNull { + return "int64" + } + return "sql.NullInt64" + + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + return "[]byte" + + case "double", "double precision", "real": + if notNull { + return "float64" + } + return "sql.NullFloat64" + + case "decimal", "dec", "fixed": + if notNull { + return "string" + } + return "sql.NullString" + + case "enum": + // TODO: Proper Enum support + return "string" + + case "date", "timestamp", "datetime", "time": + if notNull { + return "time.Time" + } + return "sql.NullTime" + + case "boolean", "bool", "tinyint": + if notNull { + return "bool" + } + return "sql.NullBool" + + default: + log.Printf("unknown MySQL type: %s\n", columnType) + return "interface{}" + + } +} diff --git a/internal/dolphin/convert.go b/internal/dolphin/convert.go index 0b528a4f7d..970b89b527 100644 --- a/internal/dolphin/convert.go +++ b/internal/dolphin/convert.go @@ -62,6 +62,15 @@ func convertCreateTableStmt(n *pcast.CreateTableStmt) ast.Node { IfNotExists: n.IfNotExists, } for _, def := range n.Cols { + var vals *ast.List + if len(def.Tp.Elems) > 0 { + vals = &ast.List{} + for i := range def.Tp.Elems { + vals.Items = append(vals.Items, &ast.String{ + Str: def.Tp.Elems[i], + }) + } + } create.Cols = append(create.Cols, &ast.ColumnDef{ Colname: def.Name.String(), TypeName: &ast.TypeName{Name: types.TypeStr(def.Tp.Tp)}, @@ -79,30 +88,64 @@ func convertDropTableStmt(n *pcast.DropTableStmt) ast.Node { return drop } -func convertSelectStmt(n *pcast.SelectStmt) ast.Node { +func convertFieldList(n *pcast.FieldList) *ast.List { + fields := make([]ast.Node, len(n.Fields)) + for i := range n.Fields { + fields[i] = convertSelectField(n.Fields[i]) + } + return &ast.List{Items: fields} +} + +func convertSelectField(n *pcast.SelectField) *pg.ResTarget { + var val ast.Node + if n.WildCard != nil { + val = convertWildCardField(n.WildCard) + } else { + val = convert(n.Expr) + } + var name *string + if n.AsName.O != "" { + name = &n.AsName.O + } + return &pg.ResTarget{ + // TODO: Populate Indirection field + Name: name, + Val: val, + Location: n.Offset, + } +} + +func convertSelectStmt(n *pcast.SelectStmt) *pg.SelectStmt { + return &pg.SelectStmt{ + TargetList: convertFieldList(n.Fields), + FromClause: convertTableRefsClause(n.From), + } +} + +func convertTableRefsClause(n *pcast.TableRefsClause) *ast.List { var tables []ast.Node - visit(n.From, func(n pcast.Node) { + visit(n, func(n pcast.Node) { name, ok := n.(*pcast.TableName) if !ok { return } - tables = append(tables, parseTableName(name)) - }) - var cols []ast.Node - visit(n.Fields, func(n pcast.Node) { - col, ok := n.(*pcast.ColumnName) - if !ok { - return - } - cols = append(cols, &ast.ResTarget{ - Val: &ast.ColumnRef{ - Name: col.Name.String(), - }, + schema := name.Schema.String() + rel := name.Name.String() + tables = append(tables, &pg.RangeVar{ + Schemaname: &schema, + Relname: &rel, }) }) - return &pg.SelectStmt{ - FromClause: &ast.List{Items: tables}, - TargetList: &ast.List{Items: cols}, + return &ast.List{Items: tables} +} + +func convertWildCardField(n *pcast.WildCardField) *pg.ColumnRef { + return &pg.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{ + &pg.A_Star{}, + }, + }, } } diff --git a/internal/dolphin/parse.go b/internal/dolphin/parse.go index b4487ee719..5a4a48801a 100644 --- a/internal/dolphin/parse.go +++ b/internal/dolphin/parse.go @@ -73,7 +73,7 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { Raw: &ast.RawStmt{ Stmt: out, StmtLocation: loc, - StmtLen: len(text), + StmtLen: len(text) - 1, // Subtract one to remove semicolon }, }) } diff --git a/internal/endtoend/testdata/dolphin_select_star/endtoend.json b/internal/endtoend/testdata/dolphin_select_star/endtoend.json new file mode 100644 index 0000000000..e9665ca4a3 --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/endtoend.json @@ -0,0 +1,3 @@ +{ + "experimental_parser_only": true +} diff --git a/internal/endtoend/testdata/dolphin_select_star/go/db.go b/internal/endtoend/testdata/dolphin_select_star/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/dolphin_select_star/go/models.go b/internal/endtoend/testdata/dolphin_select_star/go/models.go new file mode 100644 index 0000000000..3c82a24d7c --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type User struct { + ID int32 `json:"id"` + FirstName string `json:"first_name"` + LastName sql.NullString `json:"last_name"` + Age int32 `json:"age"` +} diff --git a/internal/endtoend/testdata/dolphin_select_star/go/query.sql.go b/internal/endtoend/testdata/dolphin_select_star/go/query.sql.go new file mode 100644 index 0000000000..062f587b58 --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/go/query.sql.go @@ -0,0 +1,40 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT id, first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Age, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/dolphin_select_star/query.sql b/internal/endtoend/testdata/dolphin_select_star/query.sql new file mode 100644 index 0000000000..e2f85e2a9a --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/query.sql @@ -0,0 +1,2 @@ +/* name: GetAll :many */ +SELECT * FROM users; diff --git a/internal/endtoend/testdata/dolphin_select_star/schema.sql b/internal/endtoend/testdata/dolphin_select_star/schema.sql new file mode 100644 index 0000000000..00fb1b51f3 --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + id integer NOT NULL AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) NOT NULL, + last_name varchar(255), + age integer NOT NULL +) ENGINE=InnoDB; diff --git a/internal/endtoend/testdata/dolphin_select_star/sqlc.json b/internal/endtoend/testdata/dolphin_select_star/sqlc.json new file mode 100644 index 0000000000..4060b6889b --- /dev/null +++ b/internal/endtoend/testdata/dolphin_select_star/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "_dolphin", + "emit_json_tags": true + } + ] +} diff --git a/internal/source/code.go b/internal/source/code.go index 3fa1a14a15..b84324b55f 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -81,14 +81,24 @@ func StripComments(sql string) (string, []string, error) { s := bufio.NewScanner(strings.NewReader(strings.TrimSpace(sql))) var lines, comments []string for s.Scan() { - if strings.HasPrefix(s.Text(), "-- name:") { + t := s.Text() + if strings.HasPrefix(t, "-- name:") { continue } - if strings.HasPrefix(s.Text(), "--") { - comments = append(comments, strings.TrimPrefix(s.Text(), "--")) + if strings.HasPrefix(t, "/* name:") && strings.HasSuffix(t, "*/") { continue } - lines = append(lines, s.Text()) + if strings.HasPrefix(t, "--") { + comments = append(comments, strings.TrimPrefix(t, "--")) + continue + } + if strings.HasPrefix(t, "/*") && strings.HasSuffix(t, "*/") { + t = strings.TrimPrefix(t, "/*") + t = strings.TrimSuffix(t, "*/") + comments = append(comments, t) + continue + } + lines = append(lines, t) } return strings.Join(lines, "\n"), comments, s.Err() } diff --git a/internal/sql/ast/column_def.go b/internal/sql/ast/column_def.go index db6a78e779..5496d0c2b5 100644 --- a/internal/sql/ast/column_def.go +++ b/internal/sql/ast/column_def.go @@ -5,6 +5,7 @@ type ColumnDef struct { TypeName *TypeName IsNotNull bool IsArray bool + Vals *List } func (n *ColumnDef) Pos() int {