diff --git a/internal/codegen/golang/enum.go b/internal/codegen/golang/enum.go index c6a0d1ccbf..c7385bcb77 100644 --- a/internal/codegen/golang/enum.go +++ b/internal/codegen/golang/enum.go @@ -3,6 +3,8 @@ package golang import ( "regexp" "strings" + + "github.com/kyleconroy/sqlc/internal/codegen/sdk" ) var IdentPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") @@ -33,7 +35,7 @@ func EnumValueName(value string) string { id = strings.Replace(id, "/", "_", -1) id = IdentPattern.ReplaceAllString(id, "") for _, part := range strings.Split(id, "_") { - name += strings.Title(part) + name += sdk.Title(part) } return name } diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 6b22c8798b..298c850f87 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -5,6 +5,7 @@ import ( "sort" "strings" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) @@ -72,7 +73,7 @@ func toCamelInitCase(name string, initUpper bool) string { if p == "id" { out += "ID" } else { - out += strings.Title(p) + out += sdk.Title(p) } } return out diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index 76d7271774..ad37109bdb 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -87,9 +87,11 @@ func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string { for _, enum := range schema.Enums { if enum.Name == columnType { if schema.Name == req.Catalog.DefaultSchema { - return StructName(enum.Name, req.Settings) + return StructName(enum.Name, req.Settings.Rename[enum.Name]) } - return StructName(schema.Name+"_"+enum.Name, req.Settings) + + schemaEnumName := schema.Name + "_" + enum.Name + return StructName(schemaEnumName, req.Settings.Rename[schemaEnumName]) } } } diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index afa5a21aa7..5df26fd5b6 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -281,9 +281,11 @@ func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { for _, enum := range schema.Enums { if rel.Name == enum.Name && rel.Schema == schema.Name { if schema.Name == req.Catalog.DefaultSchema { - return StructName(enum.Name, req.Settings) + return StructName(enum.Name, req.Settings.Rename[enum.Name]) } - return StructName(schema.Name+"_"+enum.Name, req.Settings) + + schemaEnumName := schema.Name + "_" + enum.Name + return StructName(schemaEnumName, req.Settings.Rename[schemaEnumName]) } } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 1feb800b22..935d2ef402 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -24,7 +24,7 @@ func buildEnums(req *plugin.CodeGenRequest) []Enum { enumName = schema.Name + "_" + enum.Name } e := Enum{ - Name: StructName(enumName, req.Settings), + Name: StructName(enumName, req.Settings.Rename[enumName]), Comment: enum.Comment, } seen := make(map[string]struct{}, len(enum.Vals)) @@ -33,8 +33,11 @@ func buildEnums(req *plugin.CodeGenRequest) []Enum { if _, found := seen[value]; found || value == "" { value = fmt.Sprintf("value_%d", i) } + + nameWithValue := enumName + "_" + value + e.Constants = append(e.Constants, Constant{ - Name: StructName(enumName+"_"+value, req.Settings), + Name: StructName(nameWithValue, req.Settings.Rename[nameWithValue]), Value: v, Type: e.Name, }) @@ -68,7 +71,7 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct { } s := Struct{ Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, - Name: StructName(structName, req.Settings), + Name: StructName(structName, req.Settings.Rename[structName]), Comment: table.Comment, } for _, column := range table.Columns { @@ -80,7 +83,7 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct { tags["json:"] = JSONTagName(column.Name, req.Settings) } s.Fields = append(s.Fields, Field{ - Name: StructName(column.Name, req.Settings), + Name: StructName(column.Name, req.Settings.Rename[column.Name]), Type: goType(req, column), Tags: tags, Comment: column.Comment, @@ -122,7 +125,7 @@ func argName(name string) string { } else if p == "id" { out += "ID" } else { - out += strings.Title(p) + out += sdk.Title(p) } } return out @@ -207,7 +210,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) same := true for i, f := range s.Fields { c := query.Columns[i] - sameName := f.Name == StructName(columnName(c, i), req.Settings) + colName := columnName(c, i) + sameName := f.Name == StructName(colName, req.Settings.Rename[colName]) sameType := f.Type == goType(req, c) sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { @@ -266,7 +270,7 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColumn for i, c := range columns { colName := columnName(c.Column, i) tagName := colName - fieldName := StructName(colName, req.Settings) + fieldName := StructName(colName, req.Settings.Rename[colName]) baseFieldName := fieldName // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be // reused. diff --git a/internal/codegen/golang/struct.go b/internal/codegen/golang/struct.go index f72a228ae3..fdd8037803 100644 --- a/internal/codegen/golang/struct.go +++ b/internal/codegen/golang/struct.go @@ -5,6 +5,7 @@ import ( "unicode" "unicode/utf8" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/plugin" ) @@ -15,16 +16,20 @@ type Struct struct { Comment string } -func StructName(name string, settings *plugin.Settings) string { - if rename := settings.Rename[name]; rename != "" { +// StructName constructs a valid camel case value from a snake case +func StructName(name, rename string) string { + + if rename != "" { return rename } + out := "" + for _, p := range strings.Split(name, "_") { if p == "id" { out += "ID" } else { - out += strings.Title(p) + out += sdk.Title(p) } } diff --git a/internal/codegen/kotlin/gen.go b/internal/codegen/kotlin/gen.go index 275b76c2b2..00875f870a 100644 --- a/internal/codegen/kotlin/gen.go +++ b/internal/codegen/kotlin/gen.go @@ -252,7 +252,7 @@ func dataClassName(name string, settings *plugin.Settings) string { } out := "" for _, p := range strings.Split(name, "_") { - out += strings.Title(p) + out += sdk.Title(p) } return out } @@ -409,7 +409,7 @@ func ktArgName(name string) string { if i == 0 { out += strings.ToLower(p) } else { - out += strings.Title(p) + out += sdk.Title(p) } } return out @@ -456,7 +456,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) gq := Query{ Cmd: query.Cmd, - ClassName: strings.Title(query.Name), + ClassName: sdk.Title(query.Name), ConstantName: sdk.LowerTitle(query.Name), FieldName: sdk.LowerTitle(query.Name) + "Stmt", MethodName: sdk.LowerTitle(query.Name), diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index 6bdb5491b1..07a13d78bf 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -216,7 +216,7 @@ func modelName(name string, settings *plugin.Settings) string { } out := "" for _, p := range strings.Split(name, "_") { - out += strings.Title(p) + out += sdk.Title(p) } return out } diff --git a/internal/codegen/sdk/utils.go b/internal/codegen/sdk/utils.go index 1dffda9e7e..860902a965 100644 --- a/internal/codegen/sdk/utils.go +++ b/internal/codegen/sdk/utils.go @@ -3,6 +3,9 @@ package sdk import ( "strings" "unicode" + + "golang.org/x/text/cases" + "golang.org/x/text/language" ) func LowerTitle(s string) string { @@ -16,7 +19,19 @@ func LowerTitle(s string) string { } func Title(s string) string { - return strings.Title(s) + + if s == "" { + return s + } + + // If the first character is a digit return s + // + // When a string starts with a digit cases.Title skips all the digits and title case + // the first character it finds. + if unicode.IsDigit(rune(s[0])) { + return s + } + return cases.Title(language.English, cases.NoLower).String(s) } // Go string literals cannot contain backtick. If a string contains diff --git a/internal/codegen/sdk/utils_test.go b/internal/codegen/sdk/utils_test.go index e16244883a..c415d306c9 100644 --- a/internal/codegen/sdk/utils_test.go +++ b/internal/codegen/sdk/utils_test.go @@ -6,23 +6,94 @@ import ( func TestLowerTitle(t *testing.T) { - // empty - if LowerTitle("") != "" { - t.Fatal("expected empty title to remain empty") + testCases := []struct { + name string + value string + out string + err string + }{ + { + name: "Empty", + value: "", + out: "", + err: "expected empty title to remain empty", + }, + { + name: "All Lowercase", + value: "lowercase", + out: "lowercase", + err: "expected no changes when input is all lowercase", + }, + { + name: "All Uppercase", + value: "UPPERCASE", + out: "uPPERCASE", + err: "expected first rune to be lower when input is all uppercase", + }, + { + name: "Title Case", + value: "Title Case", + out: "title Case", + err: "expected first rune to be lower when input is Title Case", + }, } - // all lowercase - if LowerTitle("lowercase") != "lowercase" { - t.Fatal("expected no changes when input is all lowercase") + for i := range testCases { + + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + out := LowerTitle(tc.value) + if out != tc.out { + t.Fatal(tc.err) + } + }) } +} - // all uppercase - if LowerTitle("UPPERCASE") != "uPPERCASE" { - t.Fatal("expected first rune to be lower when input is all uppercase") +func TestTitle(t *testing.T) { + + testCases := []struct { + name string + value string + out string + err string + }{ + { + name: "Empty", + value: "", + out: "", + err: "expected empty title to remain empty", + }, + { + name: "Lowercase", + value: "lowercase", + out: "Lowercase", + err: "expected frist rune to be uppercase", + }, + { + name: "CamelCase", + value: "camelCase", + out: "CamelCase", + err: "expected only first rune to be converted to uppercase", + }, + { + name: "Digit Prefix", + value: "1title", + out: "1title", + err: "expected 1title to remain 1title", + }, } - // Title Case - if LowerTitle("Title Case") != "title Case" { - t.Fatal("expected first rune to be lower when input is Title Case") + for i := range testCases { + + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + out := Title(tc.value) + if out != tc.out { + t.Fatal(tc.err) + } + }) } } diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 57a8e48c59..3a09318b17 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -9,6 +9,7 @@ import ( "regexp" "strings" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/migrations" "github.com/kyleconroy/sqlc/internal/multierr" @@ -32,7 +33,7 @@ func structName(name string) string { if p == "id" { out += "ID" } else { - out += strings.Title(p) + out += sdk.Title(p) } } return out @@ -47,7 +48,7 @@ func enumValueName(value string) string { id = strings.Replace(id, "/", "_", -1) id = identPattern.ReplaceAllString(id, "") for _, part := range strings.Split(id, "_") { - name += strings.Title(part) + name += sdk.Title(part) } return name } diff --git a/internal/tools/sqlc-pg-gen/main.go b/internal/tools/sqlc-pg-gen/main.go index 990920d1d0..b7f5623045 100644 --- a/internal/tools/sqlc-pg-gen/main.go +++ b/internal/tools/sqlc-pg-gen/main.go @@ -13,6 +13,7 @@ import ( pgx "github.com/jackc/pgx/v4" + "github.com/kyleconroy/sqlc/internal/codegen/sdk" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/catalog" ) @@ -262,7 +263,7 @@ func run(ctx context.Context) error { var funcName string for _, part := range strings.Split(name, "_") { - funcName += strings.Title(part) + funcName += sdk.Title(part) } _, err := conn.Exec(ctx, fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS \"%s\"", extension))