Skip to content

Commit 41261b5

Browse files
committed
fix(compiler): Fix column expansion to work with quoted non-keyword identifiers
close #2575
1 parent 320800a commit 41261b5

File tree

3 files changed

+44
-13
lines changed

3 files changed

+44
-13
lines changed

internal/compiler/expand.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ func (c *Compiler) expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, er
3838

3939
func (c *Compiler) quoteIdent(ident string) string {
4040
if c.parser.IsReservedKeyword(ident) {
41-
switch c.conf.Engine {
42-
case config.EngineMySQL:
43-
return "`" + ident + "`"
44-
default:
45-
return "\"" + ident + "\""
46-
}
41+
return c.quote(ident)
4742
}
4843
if c.conf.Engine == config.EnginePostgreSQL {
4944
// camelCase means the column is also camelCase
@@ -54,6 +49,15 @@ func (c *Compiler) quoteIdent(ident string) string {
5449
return ident
5550
}
5651

52+
func (c *Compiler) quote(x string) string {
53+
switch c.conf.Engine {
54+
case config.EngineMySQL:
55+
return "`" + x + "`"
56+
default:
57+
return "\"" + x + "\""
58+
}
59+
}
60+
5761
func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
5862
tables, err := c.sourceTables(qc, node)
5963
if err != nil {
@@ -132,16 +136,36 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node)
132136
for _, p := range parts {
133137
old = append(old, c.quoteIdent(p))
134138
}
135-
oldString := strings.Join(old, ".")
139+
140+
var oldString string
141+
var oldFunc func(string) int
136142

137143
// use the sqlc.embed string instead
138144
if embed, ok := qc.embeds.Find(ref); ok {
139145
oldString = embed.Orig()
146+
} else {
147+
oldFunc = func(s string) int {
148+
length := 0
149+
for i, o := range old {
150+
if hasSeparator := i > 0; hasSeparator {
151+
length++
152+
}
153+
if strings.HasPrefix(s[length:], o) {
154+
length += len(o)
155+
} else if quoted := c.quote(o); strings.HasPrefix(s[length:], quoted) {
156+
length += len(quoted)
157+
} else {
158+
length += len(o)
159+
}
160+
}
161+
return length
162+
}
140163
}
141164

142165
edits = append(edits, source.Edit{
143166
Location: res.Location - raw.StmtLocation,
144167
Old: oldString,
168+
OldFunc: oldFunc,
145169
New: strings.Join(cols, ", "),
146170
})
147171
}

internal/engine/sqlite/convert.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
767767
rel.Schemaname = &schemaName
768768
}
769769
if n.Table_alias() != nil {
770-
tableAlias := n.Table_alias().GetText()
770+
tableAlias := identifier(n.Table_alias().GetText())
771771
rel.Alias = &ast.Alias{
772772
Aliasname: &tableAlias,
773773
}
@@ -837,7 +837,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
837837
rv.Schemaname = &schema
838838
}
839839
if from.Table_alias() != nil {
840-
alias := from.Table_alias().GetText()
840+
alias := identifier(from.Table_alias().GetText())
841841
rv.Alias = &ast.Alias{Aliasname: &alias}
842842
}
843843
if from.Table_alias_fallback() != nil {
@@ -870,7 +870,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
870870
}
871871

872872
if from.Table_alias() != nil {
873-
alias := from.Table_alias().GetText()
873+
alias := identifier(from.Table_alias().GetText())
874874
rf.Alias = &ast.Alias{Aliasname: &alias}
875875
}
876876

@@ -881,7 +881,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
881881
}
882882

883883
if from.Table_alias() != nil {
884-
alias := from.Table_alias().GetText()
884+
alias := identifier(from.Table_alias().GetText())
885885
rs.Alias = &ast.Alias{Aliasname: &alias}
886886
}
887887

internal/source/code.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type Edit struct {
1212
Location int
1313
Old string
1414
New string
15+
OldFunc func(string) int
1516
}
1617

1718
func LineNumber(source string, head int) (int, int) {
@@ -63,8 +64,14 @@ func Mutate(raw string, a []Edit) (string, error) {
6364
if start > len(s) || start < 0 {
6465
return "", fmt.Errorf("edit start location is out of bounds")
6566
}
67+
var oldLen int
68+
if edit.OldFunc != nil {
69+
oldLen = edit.OldFunc(s[start:])
70+
} else {
71+
oldLen = len(edit.Old)
72+
}
6673

67-
stop := edit.Location + len(edit.Old)
74+
stop := edit.Location + oldLen
6875
if stop > len(s) {
6976
return "", fmt.Errorf("edit stop location is out of bounds")
7077
}
@@ -73,7 +80,7 @@ func Mutate(raw string, a []Edit) (string, error) {
7380
// this edit overlaps the previous one (and is therefore a developer error)
7481
if idx != 0 {
7582
prevEdit := a[idx-1]
76-
if prevEdit.Location < edit.Location+len(edit.Old) {
83+
if prevEdit.Location < edit.Location+oldLen {
7784
return "", fmt.Errorf("2 edits overlap")
7885
}
7986
}

0 commit comments

Comments
 (0)