Skip to content

Commit f57a807

Browse files
authored
Merge pull request #23 from haya14busa/sum-type
Make node types more strict
2 parents 84056cc + 253f161 commit f57a807

File tree

2 files changed

+102
-40
lines changed

2 files changed

+102
-40
lines changed

ast/node.go

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@ type Node interface {
1414
// Statement is the interface for statement (Ex command or Comment).
1515
type Statement interface {
1616
Node
17+
stmtNode()
1718
}
1819

1920
// ExCommand is the interface for Ex-command.
2021
type ExCommand interface {
21-
Node
22+
Statement
2223
Cmd() Cmd
2324
}
2425

2526
// Expr is the interface for expression.
2627
type Expr interface {
2728
Node
29+
exprNode()
2830
}
2931

3032
// File node represents a Vim script source file.
@@ -39,6 +41,7 @@ func (f *File) Pos() Pos { return f.Start }
3941

4042
// vimlparser: COMMENT .str
4143
type Comment struct {
44+
Statement
4245
Quote Pos // position of `"` starting the comment
4346
Text string // comment text (excluding '\n')
4447
}
@@ -480,7 +483,7 @@ type CurlyName struct {
480483
func (c *CurlyName) Pos() Pos { return c.CurlyName }
481484

482485
type CurlyNamePart interface {
483-
Node
486+
Expr
484487
IsCurlyExpr() bool
485488
}
486489

@@ -522,3 +525,55 @@ type LambdaExpr struct {
522525
}
523526

524527
func (i *LambdaExpr) Pos() Pos { return i.Lcurlybrace }
528+
529+
// stmtNode() ensures that only ExComamnd and Comment nodes can be assigned to
530+
// an Statement.
531+
//
532+
func (*Break) stmtNode() {}
533+
func (*Catch) stmtNode() {}
534+
func (*Continue) stmtNode() {}
535+
func (DelFunction) stmtNode() {}
536+
func (*EchoCmd) stmtNode() {}
537+
func (*Echohl) stmtNode() {}
538+
func (*Else) stmtNode() {}
539+
func (*ElseIf) stmtNode() {}
540+
func (*EndFor) stmtNode() {}
541+
func (EndFunction) stmtNode() {}
542+
func (*EndIf) stmtNode() {}
543+
func (*EndTry) stmtNode() {}
544+
func (*EndWhile) stmtNode() {}
545+
func (*ExCall) stmtNode() {}
546+
func (Excmd) stmtNode() {}
547+
func (*Execute) stmtNode() {}
548+
func (*Finally) stmtNode() {}
549+
func (*For) stmtNode() {}
550+
func (Function) stmtNode() {}
551+
func (*If) stmtNode() {}
552+
func (*Let) stmtNode() {}
553+
func (*LockVar) stmtNode() {}
554+
func (*Return) stmtNode() {}
555+
func (*Throw) stmtNode() {}
556+
func (*Try) stmtNode() {}
557+
func (*UnLet) stmtNode() {}
558+
func (*UnLockVar) stmtNode() {}
559+
func (*While) stmtNode() {}
560+
561+
func (*Comment) stmtNode() {}
562+
563+
// exprNode() ensures that only expression nodes can be assigned to an Expr.
564+
//
565+
func (*TernaryExpr) exprNode() {}
566+
func (*BinaryExpr) exprNode() {}
567+
func (*UnaryExpr) exprNode() {}
568+
func (*SubscriptExpr) exprNode() {}
569+
func (*SliceExpr) exprNode() {}
570+
func (*CallExpr) exprNode() {}
571+
func (*DotExpr) exprNode() {}
572+
func (*BasicLit) exprNode() {}
573+
func (*List) exprNode() {}
574+
func (*Dict) exprNode() {}
575+
func (*CurlyName) exprNode() {}
576+
func (*CurlyNameLit) exprNode() {}
577+
func (*CurlyNameExpr) exprNode() {}
578+
func (*Ident) exprNode() {}
579+
func (*LambdaExpr) exprNode() {}

go/export.go

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ import (
77
"github.com/haya14busa/go-vimlparser/token"
88
)
99

10-
func (self *VimLParser) Parse(reader *StringReader, filename string) ast.Node {
11-
return newAstNode(self.parse(reader), filename)
10+
// Parse parses Vim script in reader and returns Node.
11+
func (p *VimLParser) Parse(reader *StringReader, filename string) ast.Node {
12+
return newAstNode(p.parse(reader), filename)
1213
}
1314

14-
func (self *ExprParser) Parse() ast.Node {
15-
return newAstNode(self.parse(), "")
15+
// Parse parses Vim script expression.
16+
func (p *ExprParser) Parse() ast.Expr {
17+
return newExprNode(p.parse(), "")
1618
}
1719

1820
// ----
@@ -65,7 +67,7 @@ func newAstNode(n *VimNode, filename string) ast.Node {
6567
Func: pos,
6668
ExArg: newExArg(*n.ea, filename),
6769
Body: newBody(*n, filename),
68-
Name: newAstNode(n.left, filename),
70+
Name: newExprNode(n.left, filename),
6971
Params: newIdents(*n, filename),
7072
Attr: attr,
7173
EndFunction: newAstNode(n.endfunction, filename).(*ast.EndFunction),
@@ -81,14 +83,14 @@ func newAstNode(n *VimNode, filename string) ast.Node {
8183
return &ast.DelFunction{
8284
DelFunc: pos,
8385
ExArg: newExArg(*n.ea, filename),
84-
Name: newAstNode(n.left, filename),
86+
Name: newExprNode(n.left, filename),
8587
}
8688

8789
case NODE_RETURN:
8890
return &ast.Return{
8991
Return: pos,
9092
ExArg: newExArg(*n.ea, filename),
91-
Result: newAstNode(n.left, filename),
93+
Result: newExprNode(n.left, filename),
9294
}
9395

9496
case NODE_EXCALL:
@@ -103,10 +105,10 @@ func newAstNode(n *VimNode, filename string) ast.Node {
103105
Let: pos,
104106
ExArg: newExArg(*n.ea, filename),
105107
Op: n.op,
106-
Left: newAstNode(n.left, filename),
108+
Left: newExprNode(n.left, filename),
107109
List: newList(*n, filename),
108-
Rest: newAstNode(n.rest, filename),
109-
Right: newAstNode(n.right, filename),
110+
Rest: newExprNode(n.rest, filename),
111+
Right: newExprNode(n.right, filename),
110112
}
111113

112114
case NODE_UNLET:
@@ -150,7 +152,7 @@ func newAstNode(n *VimNode, filename string) ast.Node {
150152
If: pos,
151153
ExArg: newExArg(*n.ea, filename),
152154
Body: newBody(*n, filename),
153-
Condition: newAstNode(n.cond, filename),
155+
Condition: newExprNode(n.cond, filename),
154156
ElseIf: elifs,
155157
Else: els,
156158
EndIf: newAstNode(n.endif, filename).(*ast.EndIf),
@@ -161,7 +163,7 @@ func newAstNode(n *VimNode, filename string) ast.Node {
161163
ElseIf: pos,
162164
ExArg: newExArg(*n.ea, filename),
163165
Body: newBody(*n, filename),
164-
Condition: newAstNode(n.cond, filename),
166+
Condition: newExprNode(n.cond, filename),
165167
}
166168

167169
case NODE_ELSE:
@@ -182,7 +184,7 @@ func newAstNode(n *VimNode, filename string) ast.Node {
182184
While: pos,
183185
ExArg: newExArg(*n.ea, filename),
184186
Body: newBody(*n, filename),
185-
Condition: newAstNode(n.cond, filename),
187+
Condition: newExprNode(n.cond, filename),
186188
EndWhile: newAstNode(n.endwhile, filename).(*ast.EndWhile),
187189
}
188190

@@ -197,10 +199,10 @@ func newAstNode(n *VimNode, filename string) ast.Node {
197199
For: pos,
198200
ExArg: newExArg(*n.ea, filename),
199201
Body: newBody(*n, filename),
200-
Left: newAstNode(n.left, filename),
202+
Left: newExprNode(n.left, filename),
201203
List: newList(*n, filename),
202-
Rest: newAstNode(n.rest, filename),
203-
Right: newAstNode(n.right, filename),
204+
Rest: newExprNode(n.rest, filename),
205+
Right: newExprNode(n.right, filename),
204206
EndFor: newAstNode(n.endfor, filename).(*ast.EndFor),
205207
}
206208

@@ -270,7 +272,7 @@ func newAstNode(n *VimNode, filename string) ast.Node {
270272
return &ast.Throw{
271273
Throw: pos,
272274
ExArg: newExArg(*n.ea, filename),
273-
Expr: newAstNode(n.left, filename),
275+
Expr: newExprNode(n.left, filename),
274276
}
275277

276278
case NODE_ECHO, NODE_ECHON, NODE_ECHOMSG, NODE_ECHOERR:
@@ -298,9 +300,9 @@ func newAstNode(n *VimNode, filename string) ast.Node {
298300
case NODE_TERNARY:
299301
return &ast.TernaryExpr{
300302
Ternary: pos,
301-
Condition: newAstNode(n.cond, filename),
302-
Left: newAstNode(n.left, filename),
303-
Right: newAstNode(n.right, filename),
303+
Condition: newExprNode(n.cond, filename),
304+
Left: newExprNode(n.left, filename),
305+
Right: newExprNode(n.right, filename),
304306
}
305307

306308
case NODE_OR, NODE_AND, NODE_EQUAL, NODE_EQUALCI, NODE_EQUALCS,
@@ -313,44 +315,44 @@ func newAstNode(n *VimNode, filename string) ast.Node {
313315
NODE_ISNOTCI, NODE_ISNOTCS, NODE_ADD, NODE_SUBTRACT, NODE_CONCAT,
314316
NODE_MULTIPLY, NODE_DIVIDE, NODE_REMAINDER:
315317
return &ast.BinaryExpr{
316-
Left: newAstNode(n.left, filename),
318+
Left: newExprNode(n.left, filename),
317319
OpPos: pos,
318320
Op: opToken(n.type_),
319-
Right: newAstNode(n.right, filename),
321+
Right: newExprNode(n.right, filename),
320322
}
321323

322324
case NODE_NOT, NODE_MINUS, NODE_PLUS:
323325
return &ast.UnaryExpr{
324326
OpPos: pos,
325327
Op: opToken(n.type_),
326-
X: newAstNode(n.left, filename),
328+
X: newExprNode(n.left, filename),
327329
}
328330

329331
case NODE_SUBSCRIPT:
330332
return &ast.SubscriptExpr{
331333
Lbrack: pos,
332-
Left: newAstNode(n.left, filename),
333-
Right: newAstNode(n.right, filename),
334+
Left: newExprNode(n.left, filename),
335+
Right: newExprNode(n.right, filename),
334336
}
335337

336338
case NODE_SLICE:
337339
return &ast.SliceExpr{
338340
Lbrack: pos,
339-
X: newAstNode(n.left, filename),
340-
Low: newAstNode(n.rlist[0], filename),
341-
High: newAstNode(n.rlist[1], filename),
341+
X: newExprNode(n.left, filename),
342+
Low: newExprNode(n.rlist[0], filename),
343+
High: newExprNode(n.rlist[1], filename),
342344
}
343345

344346
case NODE_CALL:
345347
return &ast.CallExpr{
346348
Lparen: pos,
347-
Fun: newAstNode(n.left, filename),
349+
Fun: newExprNode(n.left, filename),
348350
Args: newRlist(*n, filename),
349351
}
350352

351353
case NODE_DOT:
352354
return &ast.DotExpr{
353-
Left: newAstNode(n.left, filename),
355+
Left: newExprNode(n.left, filename),
354356
Dot: pos,
355357
Right: newAstNode(n.right, filename).(*ast.Ident),
356358
}
@@ -378,8 +380,8 @@ func newAstNode(n *VimNode, filename string) ast.Node {
378380
kvs := make([]ast.KeyValue, 0, len(entries))
379381
for _, nn := range entries {
380382
kv := nn.([]interface{})
381-
k := newAstNode(kv[0].(*VimNode), filename)
382-
v := newAstNode(kv[1].(*VimNode), filename)
383+
k := newExprNode(kv[0].(*VimNode), filename)
384+
v := newExprNode(kv[1].(*VimNode), filename)
383385
kvs = append(kvs, ast.KeyValue{Key: k, Value: v})
384386
}
385387
return &ast.Dict{
@@ -434,20 +436,25 @@ func newAstNode(n *VimNode, filename string) ast.Node {
434436
n := n.value.(*VimNode)
435437
return &ast.CurlyNameExpr{
436438
CurlyNameExpr: pos,
437-
Value: newAstNode(n, filename),
439+
Value: newExprNode(n, filename),
438440
}
439441

440442
case NODE_LAMBDA:
441443
return &ast.LambdaExpr{
442444
Lcurlybrace: pos,
443445
Params: newIdents(*n, filename),
444-
Expr: newAstNode(n.left, filename),
446+
Expr: newExprNode(n.left, filename),
445447
}
446448

447449
}
448450
panic(fmt.Errorf("Unknown node type: %v, node: %v", n.type_, n))
449451
}
450452

453+
func newExprNode(n *VimNode, filename string) ast.Expr {
454+
node, _ := newAstNode(n, filename).(ast.Expr)
455+
return node
456+
}
457+
451458
func newPos(p *pos, filename string) *ast.Pos {
452459
if p == nil {
453460
return nil
@@ -508,7 +515,7 @@ func newBody(n VimNode, filename string) []ast.Statement {
508515
}
509516
for _, node := range n.body {
510517
if node != nil { // conservative
511-
body = append(body, newAstNode(node, filename))
518+
body = append(body, newAstNode(node, filename).(ast.Statement))
512519
}
513520
}
514521
return body
@@ -534,7 +541,7 @@ func newRlist(n VimNode, filename string) []ast.Expr {
534541
}
535542
for _, node := range n.rlist {
536543
if node != nil { // conservative
537-
exprs = append(exprs, newAstNode(node, filename))
544+
exprs = append(exprs, newExprNode(node, filename))
538545
}
539546
}
540547
return exprs
@@ -547,7 +554,7 @@ func newList(n VimNode, filename string) []ast.Expr {
547554
}
548555
for _, node := range n.list {
549556
if node != nil { // conservative
550-
list = append(list, newAstNode(node, filename))
557+
list = append(list, newExprNode(node, filename))
551558
}
552559
}
553560
return list
@@ -557,7 +564,7 @@ func newValues(n VimNode, filename string) []ast.Expr {
557564
var values []ast.Expr
558565
for _, v := range n.value.([]interface{}) {
559566
n := v.(*VimNode)
560-
values = append(values, newAstNode(n, filename))
567+
values = append(values, newExprNode(n, filename))
561568
}
562569
return values
563570
}

0 commit comments

Comments
 (0)