Skip to content

Commit a7552bc

Browse files
adonovangopherbot
authored andcommitted
go/ast/inspector: add PreorderSeq and All iterators
This CL adds two functions that return go1.23 iterators over selected nodes: - inspect.PreorderSeq(types...), which is like Preorder but honors the break statement; - All[N](inspect), which iterates over nodes of a single type, determined by the type parameter. + Tests, benchmark. Fixes golang/go#67795 Change-Id: I77817f2e595846cf3ce29dc347ae895e927fc805 Reviewed-on: https://go-review.googlesource.com/c/tools/+/616218 Auto-Submit: Alan Donovan <adonovan@google.com> Commit-Queue: Alan Donovan <adonovan@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Robert Findley <rfindley@google.com>
1 parent d2e4621 commit a7552bc

File tree

4 files changed

+179
-1
lines changed

4 files changed

+179
-1
lines changed

go/ast/inspector/inspector.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) {
7373
// check, Preorder is almost twice as fast as Nodes. The two
7474
// features seem to contribute similar slowdowns (~1.4x each).
7575

76+
// This function is equivalent to the PreorderSeq call below,
77+
// but to avoid the additional dynamic call (which adds 13-35%
78+
// to the benchmarks), we expand it out.
79+
//
80+
// in.PreorderSeq(types...)(func(n ast.Node) bool {
81+
// f(n)
82+
// return true
83+
// })
84+
7685
mask := maskOf(types)
7786
for i := 0; i < len(in.events); {
7887
ev := in.events[i]

go/ast/inspector/inspector_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ func TestInspectPruning(t *testing.T) {
160160
compare(t, nodesA, nodesB)
161161
}
162162

163-
func compare(t *testing.T, nodesA, nodesB []ast.Node) {
163+
// compare calls t.Error if !slices.Equal(nodesA, nodesB).
164+
func compare[N comparable](t *testing.T, nodesA, nodesB []N) {
164165
if len(nodesA) != len(nodesB) {
165166
t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
166167
} else {

go/ast/inspector/iter.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.23
6+
7+
package inspector
8+
9+
import (
10+
"go/ast"
11+
"iter"
12+
)
13+
14+
// PreorderSeq returns an iterator that visits all the
15+
// nodes of the files supplied to New in depth-first order.
16+
// It visits each node n before n's children.
17+
// The complete traversal sequence is determined by ast.Inspect.
18+
//
19+
// The types argument, if non-empty, enables type-based
20+
// filtering of events: only nodes whose type matches an
21+
// element of the types slice are included in the sequence.
22+
func (in *Inspector) PreorderSeq(types ...ast.Node) iter.Seq[ast.Node] {
23+
24+
// This implementation is identical to Preorder,
25+
// except that it supports breaking out of the loop.
26+
27+
return func(yield func(ast.Node) bool) {
28+
mask := maskOf(types)
29+
for i := 0; i < len(in.events); {
30+
ev := in.events[i]
31+
if ev.index > i {
32+
// push
33+
if ev.typ&mask != 0 {
34+
if !yield(ev.node) {
35+
break
36+
}
37+
}
38+
pop := ev.index
39+
if in.events[pop].typ&mask == 0 {
40+
// Subtrees do not contain types: skip them and pop.
41+
i = pop + 1
42+
continue
43+
}
44+
}
45+
i++
46+
}
47+
}
48+
}
49+
50+
// All[N] returns an iterator over all the nodes of type N.
51+
// N must be a pointer-to-struct type that implements ast.Node.
52+
//
53+
// Example:
54+
//
55+
// for call := range All[*ast.CallExpr](in) { ... }
56+
func All[N interface {
57+
*S
58+
ast.Node
59+
}, S any](in *Inspector) iter.Seq[N] {
60+
61+
// To avoid additional dynamic call overheads,
62+
// we duplicate rather than call the logic of PreorderSeq.
63+
64+
mask := typeOf((N)(nil))
65+
return func(yield func(N) bool) {
66+
for i := 0; i < len(in.events); {
67+
ev := in.events[i]
68+
if ev.index > i {
69+
// push
70+
if ev.typ&mask != 0 {
71+
if !yield(ev.node.(N)) {
72+
break
73+
}
74+
}
75+
pop := ev.index
76+
if in.events[pop].typ&mask == 0 {
77+
// Subtrees do not contain types: skip them and pop.
78+
i = pop + 1
79+
continue
80+
}
81+
}
82+
i++
83+
}
84+
}
85+
}

go/ast/inspector/iter_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.23
6+
7+
package inspector_test
8+
9+
import (
10+
"go/ast"
11+
"iter"
12+
"slices"
13+
"testing"
14+
15+
"golang.org/x/tools/go/ast/inspector"
16+
)
17+
18+
// TestPreorderSeq checks PreorderSeq against Preorder.
19+
func TestPreorderSeq(t *testing.T) {
20+
inspect := inspector.New(netFiles)
21+
22+
nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)}
23+
24+
// reference implementation
25+
var want []ast.Node
26+
inspect.Preorder(nodeFilter, func(n ast.Node) {
27+
want = append(want, n)
28+
})
29+
30+
// Check entire sequence.
31+
got := slices.Collect(inspect.PreorderSeq(nodeFilter...))
32+
compare(t, got, want)
33+
34+
// Check that break works.
35+
got = firstN(10, inspect.PreorderSeq(nodeFilter...))
36+
compare(t, got, want[:10])
37+
}
38+
39+
// TestAll checks All against Preorder.
40+
func TestAll(t *testing.T) {
41+
inspect := inspector.New(netFiles)
42+
43+
// reference implementation
44+
var want []*ast.CallExpr
45+
inspect.Preorder([]ast.Node{(*ast.CallExpr)(nil)}, func(n ast.Node) {
46+
want = append(want, n.(*ast.CallExpr))
47+
})
48+
49+
// Check entire sequence.
50+
got := slices.Collect(inspector.All[*ast.CallExpr](inspect))
51+
compare(t, got, want)
52+
53+
// Check that break works.
54+
got = firstN(10, inspector.All[*ast.CallExpr](inspect))
55+
compare(t, got, want[:10])
56+
}
57+
58+
// firstN(n, seq), returns a slice of up to n elements of seq.
59+
func firstN[T any](n int, seq iter.Seq[T]) (res []T) {
60+
for x := range seq {
61+
res = append(res, x)
62+
if len(res) == n {
63+
break
64+
}
65+
}
66+
return res
67+
}
68+
69+
// BenchmarkAllCalls is like BenchmarkInspectCalls,
70+
// but using the single-type filtering iterator, All.
71+
// (The iterator adds about 5-15%.)
72+
func BenchmarkAllCalls(b *testing.B) {
73+
inspect := inspector.New(netFiles)
74+
b.ResetTimer()
75+
76+
// Measure marginal cost of traversal.
77+
var ncalls int
78+
for range b.N {
79+
for range inspector.All[*ast.CallExpr](inspect) {
80+
ncalls++
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)