diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index 9da593b974..ea71bc2348 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -18,7 +18,7 @@ import ( "github.com/jinzhu/inflection" ) -var identPattern = regexp.MustCompile("[^a-zA-Z0-9]+") +var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") type GoConstant struct { Name string @@ -388,6 +388,18 @@ func (r Result) QueryImports(filename string) [][]string { return [][]string{stds, pkgs} } +func enumValueName(value string) string { + name := "" + id := strings.Replace(value, "-", "_", -1) + id = strings.Replace(id, ":", "_", -1) + id = strings.Replace(id, "/", "_", -1) + id = identPattern.ReplaceAllString(id, "") + for _, part := range strings.Split(id, "_") { + name += strings.Title(part) + } + return name +} + func (r Result) Enums() []GoEnum { var enums []GoEnum for name, schema := range r.Catalog.Schemas { @@ -406,16 +418,8 @@ func (r Result) Enums() []GoEnum { Comment: enum.Comment, } for _, v := range enum.Vals { - name := "" - id := strings.Replace(v, "-", "_", -1) - id = strings.Replace(id, ":", "_", -1) - id = strings.Replace(id, "/", "_", -1) - id = identPattern.ReplaceAllString(id, "") - for _, part := range strings.Split(id, "_") { - name += strings.Title(part) - } e.Constants = append(e.Constants, GoConstant{ - Name: e.Name + name, + Name: e.Name + enumValueName(v), Value: v, Type: e.Name, }) diff --git a/internal/dinosql/gen_test.go b/internal/dinosql/gen_test.go index 3920ee8441..2383fbe590 100644 --- a/internal/dinosql/gen_test.go +++ b/internal/dinosql/gen_test.go @@ -143,3 +143,27 @@ func TestNullInnerType(t *testing.T) { }) } } + +func TestEnumValueName(t *testing.T) { + values := map[string]string{ + // Valid separators + "foo-bar": "FooBar", + "foo_bar": "FooBar", + "foo:bar": "FooBar", + "foo/bar": "FooBar", + // Strip unknown characters + "foo@bar": "Foobar", + "foo+bar": "Foobar", + "foo!bar": "Foobar", + } + for k, v := range values { + input := k + expected := v + t.Run(k+"-"+v, func(t *testing.T) { + actual := enumValueName(k) + if actual != expected { + t.Errorf("expected name for %s to be %s, not %s", input, expected, actual) + } + }) + } +}