From 4af56db9b103252c48150c9e5e8307f6608f68af Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 1 Mar 2022 21:57:47 -0800 Subject: [PATCH] feat(sdk): Add the plugin SDK package Collect a few commonly used functions and put them into one package. Move the utils.go file from the codegen package here as well. --- internal/codegen/golang/compat.go | 16 ------ internal/codegen/golang/gen.go | 8 +-- internal/codegen/golang/go_type.go | 50 ++--------------- internal/codegen/golang/mysql_type.go | 11 +--- internal/codegen/golang/postgresql_type.go | 31 +++++++++-- internal/codegen/golang/result.go | 10 ++-- internal/codegen/golang/sqlite_type.go | 3 +- internal/codegen/kotlin/gen.go | 29 ++++------ internal/codegen/kotlin/mysql_type.go | 11 +--- internal/codegen/kotlin/postgresql_type.go | 3 +- internal/codegen/python/gen.go | 63 ++-------------------- internal/codegen/python/postgresql_type.go | 11 +--- internal/codegen/sdk/sdk.go | 59 ++++++++++++++++++++ internal/codegen/{ => sdk}/utils.go | 2 +- internal/codegen/{ => sdk}/utils_test.go | 2 +- 15 files changed, 126 insertions(+), 183 deletions(-) delete mode 100644 internal/codegen/golang/compat.go create mode 100644 internal/codegen/sdk/sdk.go rename internal/codegen/{ => sdk}/utils.go (97%) rename internal/codegen/{ => sdk}/utils_test.go (97%) diff --git a/internal/codegen/golang/compat.go b/internal/codegen/golang/compat.go deleted file mode 100644 index ae2c902f57..0000000000 --- a/internal/codegen/golang/compat.go +++ /dev/null @@ -1,16 +0,0 @@ -package golang - -import ( - "github.com/kyleconroy/sqlc/internal/plugin" -) - -func sameTableName(tableID, f *plugin.Identifier, defaultSchema string) bool { - if tableID == nil { - return false - } - schema := tableID.Schema - if tableID.Schema == "" { - schema = defaultSchema - } - return tableID.Catalog == f.Catalog && schema == f.Schema && tableID.Name == f.Name -} diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index b938675afd..a93490d12e 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -9,7 +9,7 @@ import ( "strings" "text/template" - "github.com/kyleconroy/sqlc/internal/codegen" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/plugin" ) @@ -58,9 +58,9 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie } funcMap := template.FuncMap{ - "lowerTitle": codegen.LowerTitle, - "comment": codegen.DoubleSlashComment, - "escape": codegen.EscapeBacktick, + "lowerTitle": sdk.LowerTitle, + "comment": sdk.DoubleSlashComment, + "escape": sdk.EscapeBacktick, "imports": i.Imports, "hasPrefix": strings.HasPrefix, } diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 5ea995c84d..2832abbd9c 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -1,60 +1,18 @@ package golang import ( - "github.com/kyleconroy/sqlc/internal/pattern" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) -// XXX: These are copied from python codegen. -func matchString(pat, target string) bool { - matcher, err := pattern.MatchCompile(pat) - if err != nil { - panic(err) - } - return matcher.MatchString(target) -} - -func matches(o *plugin.Override, n *plugin.Identifier, defaultSchema string) bool { - if n == nil { - return false - } - - schema := n.Schema - if n.Schema == "" { - schema = defaultSchema - } - - if o.Table.Catalog != "" && !matchString(o.Table.Catalog, n.Catalog) { - return false - } - - if o.Table.Schema == "" && schema != "" { - return false - } - - if o.Table.Schema != "" && !matchString(o.Table.Schema, schema) { - return false - } - - if o.Table.Name == "" && n.Name != "" { - return false - } - - if o.Table.Name != "" && !matchString(o.Table.Name, n.Name) { - return false - } - - return true -} - func goType(req *plugin.CodeGenRequest, col *plugin.Column) string { // Check if the column's type has been overridden for _, oride := range req.Settings.Overrides { if oride.GoType.TypeName == "" { continue } - sameTable := matches(oride, col.Table, req.Catalog.DefaultSchema) - if oride.Column != "" && matchString(oride.ColumnName, col.Name) && sameTable { + sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) + if oride.Column != "" && sdk.MatchString(oride.ColumnName, col.Name) && sameTable { return oride.GoType.TypeName } } @@ -66,7 +24,7 @@ func goType(req *plugin.CodeGenRequest, col *plugin.Column) string { } func goInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray // package overrides have a higher precedence diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index af72ba849b..76d7271774 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -3,20 +3,13 @@ package golang import ( "log" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/debug" "github.com/kyleconroy/sqlc/internal/plugin" ) -func dataType(n *plugin.Identifier) string { - if n.Schema != "" { - return n.Schema + "." + n.Name - } else { - return n.Name - } -} - func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray switch columnType { diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 1a4b995e6e..afa5a21aa7 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -1,15 +1,40 @@ package golang import ( + "fmt" "log" + "strings" - "github.com/kyleconroy/sqlc/internal/compiler" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/debug" "github.com/kyleconroy/sqlc/internal/plugin" ) +func parseIdentifierString(name string) (*plugin.Identifier, error) { + parts := strings.Split(name, ".") + switch len(parts) { + case 1: + return &plugin.Identifier{ + Name: parts[0], + }, nil + case 2: + return &plugin.Identifier{ + Schema: parts[0], + Name: parts[1], + }, nil + case 3: + return &plugin.Identifier{ + Catalog: parts[0], + Schema: parts[1], + Name: parts[2], + }, nil + default: + return nil, fmt.Errorf("invalid name: %s", name) + } +} + func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray driver := parseDriver(req.Settings) @@ -239,7 +264,7 @@ func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { return "interface{}" default: - rel, err := compiler.ParseRelationString(columnType) + rel, err := parseIdentifierString(columnType) if err != nil { // TODO: Should this actually return an error here? return "interface{}" diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 4a4b293a4d..1feb800b22 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -5,7 +5,7 @@ import ( "sort" "strings" - "github.com/kyleconroy/sqlc/internal/codegen" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/inflection" "github.com/kyleconroy/sqlc/internal/plugin" ) @@ -140,15 +140,15 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) var constantName string if req.Settings.Go.EmitExportedQueries { - constantName = codegen.Title(query.Name) + constantName = sdk.Title(query.Name) } else { - constantName = codegen.LowerTitle(query.Name) + constantName = sdk.LowerTitle(query.Name) } gq := Query{ Cmd: query.Cmd, ConstantName: constantName, - FieldName: codegen.LowerTitle(query.Name) + "Stmt", + FieldName: sdk.LowerTitle(query.Name) + "Stmt", MethodName: query.Name, SourceName: query.Filename, SQL: query.Text, @@ -209,7 +209,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) c := query.Columns[i] sameName := f.Name == StructName(columnName(c, i), req.Settings) sameType := f.Type == goType(req, c) - sameTable := sameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) + sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false } diff --git a/internal/codegen/golang/sqlite_type.go b/internal/codegen/golang/sqlite_type.go index 1a3f8e3465..f26e533522 100644 --- a/internal/codegen/golang/sqlite_type.go +++ b/internal/codegen/golang/sqlite_type.go @@ -4,11 +4,12 @@ import ( "log" "strings" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) func sqliteType(req *plugin.CodeGenRequest, col *plugin.Column) string { - dt := strings.ToLower(dataType(col.Type)) + dt := strings.ToLower(sdk.DataType(col.Type)) notNull := col.NotNull || col.IsArray switch dt { diff --git a/internal/codegen/kotlin/gen.go b/internal/codegen/kotlin/gen.go index 234cde4ebf..dc128a5e7f 100644 --- a/internal/codegen/kotlin/gen.go +++ b/internal/codegen/kotlin/gen.go @@ -10,23 +10,12 @@ import ( "strings" "text/template" - "github.com/kyleconroy/sqlc/internal/codegen" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/inflection" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/plugin" ) -func sameTableName(n, f *plugin.Identifier) bool { - if n == nil { - return false - } - schema := n.Schema - if n.Schema == "" { - schema = "public" - } - return n.Catalog == n.Catalog && schema == f.Schema && n.Name == f.Name -} - var ktIdentPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") type Constant struct { @@ -269,7 +258,7 @@ func dataClassName(name string, settings *plugin.Settings) string { } func memberName(name string, settings *plugin.Settings) string { - return codegen.LowerTitle(dataClassName(name, settings)) + return sdk.LowerTitle(dataClassName(name, settings)) } func buildDataClasses(req *plugin.CodeGenRequest) []Struct { @@ -365,7 +354,7 @@ func makeType(req *plugin.CodeGenRequest, col *plugin.Column) ktType { IsEnum: isEnum, IsArray: col.IsArray, IsNull: !col.NotNull, - DataType: dataType(col.Type), + DataType: sdk.DataType(col.Type), Engine: req.Settings.Engine, } } @@ -468,9 +457,9 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) gq := Query{ Cmd: query.Cmd, ClassName: strings.Title(query.Name), - ConstantName: codegen.LowerTitle(query.Name), - FieldName: codegen.LowerTitle(query.Name) + "Stmt", - MethodName: codegen.LowerTitle(query.Name), + ConstantName: sdk.LowerTitle(query.Name), + FieldName: sdk.LowerTitle(query.Name) + "Stmt", + MethodName: sdk.LowerTitle(query.Name), SourceName: query.Filename, SQL: jdbcSQL(query.Text, req.Settings.Engine), Comments: query.Comments, @@ -507,7 +496,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) c := query.Columns[i] sameName := f.Name == memberName(ktColumnName(c, i), req.Settings) sameType := f.Type == makeType(req, c) - sameTable := sameTableName(c.Table, &s.Table) + sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false @@ -779,8 +768,8 @@ func Generate(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) { } funcMap := template.FuncMap{ - "lowerTitle": codegen.LowerTitle, - "comment": codegen.DoubleSlashComment, + "lowerTitle": sdk.LowerTitle, + "comment": sdk.DoubleSlashComment, "imports": i.Imports, "offset": Offset, } diff --git a/internal/codegen/kotlin/mysql_type.go b/internal/codegen/kotlin/mysql_type.go index 69d9dc3d86..1bc50df177 100644 --- a/internal/codegen/kotlin/mysql_type.go +++ b/internal/codegen/kotlin/mysql_type.go @@ -3,20 +3,13 @@ package kotlin import ( "log" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/debug" "github.com/kyleconroy/sqlc/internal/plugin" ) -func dataType(n *plugin.Identifier) string { - if n.Schema != "" { - return n.Schema + "." + n.Name - } else { - return n.Name - } -} - func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) (string, bool) { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) switch columnType { diff --git a/internal/codegen/kotlin/postgresql_type.go b/internal/codegen/kotlin/postgresql_type.go index b2469a7a49..d199ea5d7f 100644 --- a/internal/codegen/kotlin/postgresql_type.go +++ b/internal/codegen/kotlin/postgresql_type.go @@ -3,11 +3,12 @@ package kotlin import ( "log" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) (string, bool) { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) switch columnType { case "serial", "pg_catalog.serial4": diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index 51e15045a7..6e394413d2 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -8,10 +8,9 @@ import ( "sort" "strings" - "github.com/kyleconroy/sqlc/internal/codegen" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/inflection" "github.com/kyleconroy/sqlc/internal/metadata" - "github.com/kyleconroy/sqlc/internal/pattern" "github.com/kyleconroy/sqlc/internal/plugin" pyast "github.com/kyleconroy/sqlc/internal/python/ast" "github.com/kyleconroy/sqlc/internal/python/poet" @@ -192,8 +191,8 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string { if !pyTypeIsSet(oride.PythonType) { continue } - sameTable := matches(oride, col.Table, req.Catalog.DefaultSchema) - if oride.Column != "" && matchString(oride.ColumnName, col.Name) && sameTable { + sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) + if oride.Column != "" && sdk.MatchString(oride.ColumnName, col.Name) && sameTable { return pyTypeString(oride.PythonType) } if oride.DbType != "" && oride.DbType == col.DataType && oride.Nullable != (col.NotNull || col.IsArray) { @@ -210,47 +209,6 @@ func pyInnerType(req *plugin.CodeGenRequest, col *plugin.Column) string { } } -func matchString(pat, target string) bool { - matcher, err := pattern.MatchCompile(pat) - if err != nil { - panic(err) - } - return matcher.MatchString(target) -} - -func matches(o *plugin.Override, n *plugin.Identifier, defaultSchema string) bool { - if n == nil { - return false - } - - schema := n.Schema - if n.Schema == "" { - schema = defaultSchema - } - - if o.Table.Catalog != "" && !matchString(o.Table.Catalog, n.Catalog) { - return false - } - - if o.Table.Schema == "" && schema != "" { - return false - } - - if o.Table.Schema != "" && !matchString(o.Table.Schema, schema) { - return false - } - - if o.Table.Name == "" && n.Name != "" { - return false - } - - if o.Table.Name != "" && !matchString(o.Table.Name, n.Name) { - return false - } - - return true -} - func modelName(name string, settings *plugin.Settings) string { if rename := settings.Rename[name]; rename != "" { return rename @@ -403,17 +361,6 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn return &gs } -func sameTableName(tableID, f *plugin.Identifier, defaultSchema string) bool { - if tableID == nil { - return false - } - schema := tableID.Schema - if tableID.Schema == "" { - schema = defaultSchema - } - return tableID.Catalog == f.Catalog && schema == f.Schema && tableID.Name == f.Name -} - var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`) // Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN" @@ -445,7 +392,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) Cmd: query.Cmd, Comments: query.Comments, MethodName: methodName, - FieldName: codegen.LowerTitle(query.Name) + "Stmt", + FieldName: sdk.LowerTitle(query.Name) + "Stmt", ConstantName: strings.ToUpper(methodName), SQL: sqlalchemySQL(query.Text, req.Settings.Engine), SourceName: query.Filename, @@ -498,7 +445,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) trimmedPyType.InnerType = strings.TrimPrefix(trimmedPyType.InnerType, "models.") sameName := f.Name == columnName(c, i) sameType := f.Type == trimmedPyType - sameTable := sameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) + sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false } diff --git a/internal/codegen/python/postgresql_type.go b/internal/codegen/python/postgresql_type.go index 9e81c13ad5..12c99f1de3 100644 --- a/internal/codegen/python/postgresql_type.go +++ b/internal/codegen/python/postgresql_type.go @@ -3,19 +3,12 @@ package python import ( "log" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) -func dataType(n *plugin.Identifier) string { - if n.Schema != "" { - return n.Schema + "." + n.Name - } else { - return n.Name - } -} - func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { - columnType := dataType(col.Type) + columnType := sdk.DataType(col.Type) switch columnType { case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": diff --git a/internal/codegen/sdk/sdk.go b/internal/codegen/sdk/sdk.go new file mode 100644 index 0000000000..2f97da1bb8 --- /dev/null +++ b/internal/codegen/sdk/sdk.go @@ -0,0 +1,59 @@ +package sdk + +import ( + "github.com/kyleconroy/sqlc/internal/pattern" + "github.com/kyleconroy/sqlc/internal/plugin" +) + +func DataType(n *plugin.Identifier) string { + if n.Schema != "" { + return n.Schema + "." + n.Name + } else { + return n.Name + } +} + +func MatchString(pat, target string) bool { + matcher, err := pattern.MatchCompile(pat) + if err != nil { + panic(err) + } + return matcher.MatchString(target) +} + +func Matches(o *plugin.Override, n *plugin.Identifier, defaultSchema string) bool { + if n == nil { + return false + } + schema := n.Schema + if n.Schema == "" { + schema = defaultSchema + } + if o.Table.Catalog != "" && !MatchString(o.Table.Catalog, n.Catalog) { + return false + } + if o.Table.Schema == "" && schema != "" { + return false + } + if o.Table.Schema != "" && !MatchString(o.Table.Schema, schema) { + return false + } + if o.Table.Name == "" && n.Name != "" { + return false + } + if o.Table.Name != "" && !MatchString(o.Table.Name, n.Name) { + return false + } + return true +} + +func SameTableName(tableID, f *plugin.Identifier, defaultSchema string) bool { + if tableID == nil { + return false + } + schema := tableID.Schema + if tableID.Schema == "" { + schema = defaultSchema + } + return tableID.Catalog == f.Catalog && schema == f.Schema && tableID.Name == f.Name +} diff --git a/internal/codegen/utils.go b/internal/codegen/sdk/utils.go similarity index 97% rename from internal/codegen/utils.go rename to internal/codegen/sdk/utils.go index 8a4845b009..1dffda9e7e 100644 --- a/internal/codegen/utils.go +++ b/internal/codegen/sdk/utils.go @@ -1,4 +1,4 @@ -package codegen +package sdk import ( "strings" diff --git a/internal/codegen/utils_test.go b/internal/codegen/sdk/utils_test.go similarity index 97% rename from internal/codegen/utils_test.go rename to internal/codegen/sdk/utils_test.go index f4def83436..e16244883a 100644 --- a/internal/codegen/utils_test.go +++ b/internal/codegen/sdk/utils_test.go @@ -1,4 +1,4 @@ -package codegen +package sdk import ( "testing"