Skip to content

Commit 5a2a35a

Browse files
committed
compile: make closures compile properly
1 parent 329523c commit 5a2a35a

File tree

5 files changed

+428
-63
lines changed

5 files changed

+428
-63
lines changed

compile/compile.go

Lines changed: 203 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// compile python code
2-
32
package compile
43

54
import (
@@ -40,12 +39,26 @@ func (ls loopstack) Top() *loop {
4039
return &ls[len(ls)-1]
4140
}
4241

42+
type compilerScopeType uint8
43+
44+
const (
45+
compilerScopeModule compilerScopeType = iota
46+
compilerScopeClass
47+
compilerScopeFunction
48+
compilerScopeLambda
49+
compilerScopeComprehension
50+
)
51+
4352
// State for the compiler
4453
type compiler struct {
45-
Code *py.Code // code being built up
46-
OpCodes Instructions
47-
loops loopstack
48-
SymTable *symtable.SymTable
54+
Code *py.Code // code being built up
55+
OpCodes Instructions
56+
loops loopstack
57+
SymTable *symtable.SymTable
58+
scopeType compilerScopeType
59+
qualname string
60+
parent *compiler
61+
depth int
4962
}
5063

5164
// Set in py to avoid circular import
@@ -66,7 +79,7 @@ func init() {
6679
// the effects of any future statements in effect in the code calling
6780
// compile; if absent or zero these statements do influence the compilation,
6881
// in addition to any features explicitly specified.
69-
func Compile(str, filename, mode string, flags int, dont_inherit bool) (py.Object, error) {
82+
func Compile(str, filename, mode string, futureFlags int, dont_inherit bool) (py.Object, error) {
7083
// Parse Ast
7184
Ast, err := parser.ParseString(str, mode)
7285
if err != nil {
@@ -77,11 +90,27 @@ func Compile(str, filename, mode string, flags int, dont_inherit bool) (py.Objec
7790
if err != nil {
7891
return nil, err
7992
}
80-
return CompileAst(Ast, filename, flags, dont_inherit, SymTable)
93+
c := newCompiler(nil, compilerScopeModule)
94+
return c.compileAst(Ast, filename, futureFlags, dont_inherit, SymTable)
95+
}
96+
97+
// Make a new compiler
98+
func newCompiler(parent *compiler, scopeType compilerScopeType) *compiler {
99+
c := &compiler{
100+
// Code: code,
101+
// SymTable: SymTable,
102+
parent: parent,
103+
scopeType: scopeType,
104+
depth: 1,
105+
}
106+
if parent != nil {
107+
c.depth = parent.depth + 1
108+
}
109+
return c
81110
}
82111

83112
// As Compile but takes an Ast
84-
func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymTable *symtable.SymTable) (code *py.Code, err error) {
113+
func (c *compiler) compileAst(Ast ast.Ast, filename string, futureFlags int, dont_inherit bool, SymTable *symtable.SymTable) (code *py.Code, err error) {
85114
defer func() {
86115
if r := recover(); r != nil {
87116
err = py.MakeException(r)
@@ -90,14 +119,19 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
90119
//fmt.Println(ast.Dump(Ast))
91120
code = &py.Code{
92121
Filename: filename,
93-
Firstlineno: 1, // FIXME
94-
Name: "<module>", // FIXME
95-
Flags: int32(flags | py.CO_NOFREE), // FIXME
96-
}
97-
c := &compiler{
98-
Code: code,
99-
SymTable: SymTable,
122+
Firstlineno: 1, // FIXME
123+
Name: "<module>", // FIXME
124+
// Argcount: int32(len(node.Args.Args)),
125+
// Name: string(node.Name),
126+
// Kwonlyargcount: int32(len(node.Args.Kwonlyargs)),
127+
// Nlocals: int32(len(SymTable.Varnames)),
100128
}
129+
c.Code = code
130+
c.SymTable = SymTable
131+
code.Varnames = append(code.Varnames, SymTable.Varnames...)
132+
code.Cellvars = SymTable.Find(symtable.ScopeCell, 0)
133+
code.Freevars = SymTable.Find(symtable.ScopeFree, symtable.DefFreeClass)
134+
code.Flags = c.codeFlags(SymTable) | int32(futureFlags&py.CO_COMPILER_FLAGS_MASK)
101135
valueOnStack := false
102136
switch node := Ast.(type) {
103137
case *ast.Module:
@@ -112,10 +146,13 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
112146
case ast.Expr:
113147
// Make None the first constant as lambda can't have a docstring
114148
c.Const(py.None)
115-
c.Code.Name = "<lambda>"
149+
code.Name = "<lambda>"
150+
c.setQualname() // FIXME is this in the right place!
116151
c.Expr(node)
117152
valueOnStack = true
118153
case *ast.FunctionDef:
154+
code.Name = string(node.Name)
155+
c.setQualname() // FIXME is this in the right place!
119156
c.Stmts(c.docString(node.Body))
120157
default:
121158
panic(py.ExceptionNewf(py.SyntaxError, "Unknown ModuleBase: %v", Ast))
@@ -130,6 +167,7 @@ func CompileAst(Ast ast.Ast, filename string, flags int, dont_inherit bool, SymT
130167
}
131168
code.Code = c.OpCodes.Assemble()
132169
code.Stacksize = int32(c.OpCodes.StackDepth())
170+
code.Nlocals = int32(len(code.Varnames))
133171
return code, nil
134172
}
135173

@@ -168,14 +206,23 @@ func (c *compiler) LoadConst(obj py.Object) {
168206
c.OpArg(vm.LOAD_CONST, c.Const(obj))
169207
}
170208

171-
// Returns the index into the slice provided, updating the slice if necessary
172-
func (c *compiler) Index(Id string, Names *[]string) uint32 {
209+
// Finds the Id in the slice provided, returning -1 if not found
210+
func (c *compiler) FindId(Id string, Names []string) int {
173211
// FIXME back this with a dict to stop O(N**2) behaviour on lots of vars
174-
for i, s := range *Names {
212+
for i, s := range Names {
175213
if Id == s {
176-
return uint32(i)
214+
return i
177215
}
178216
}
217+
return -1
218+
}
219+
220+
// Returns the index into the slice provided, updating the slice if necessary
221+
func (c *compiler) Index(Id string, Names *[]string) uint32 {
222+
i := c.FindId(Id, *Names)
223+
if i >= 0 {
224+
return uint32(i)
225+
}
179226
*Names = append(*Names, Id)
180227
return uint32(len(*Names) - 1)
181228
}
@@ -234,6 +281,131 @@ func (c *compiler) Stmts(stmts []ast.Stmt) {
234281
}
235282
}
236283

284+
/* The test for LOCAL must come before the test for FREE in order to
285+
handle classes where name is both local and free. The local var is
286+
a method and the free var is a free var referenced within a method.
287+
*/
288+
func (c *compiler) getRefType(name string) symtable.Scope {
289+
if c.scopeType == compilerScopeClass && name == "__class__" {
290+
return symtable.ScopeCell
291+
}
292+
scope := c.SymTable.GetScope(name)
293+
if scope == symtable.ScopeInvalid {
294+
panic(fmt.Sprintf("compile: getRefType: unknown scope for %s in %s\nsymbols: %s\nlocals: %s\nglobals: %s", name, c.Code.Name, c.SymTable.Symbols, c.Code.Varnames, c.Code.Names))
295+
}
296+
return scope
297+
}
298+
299+
// makeClosure constructs the function or closure for a func/class/lambda etc
300+
func (c *compiler) makeClosure(code *py.Code, args uint32, child *compiler) {
301+
free := uint32(len(code.Freevars))
302+
qualname := child.qualname
303+
if qualname == "" {
304+
qualname = c.qualname
305+
}
306+
307+
if free == 0 {
308+
c.LoadConst(code)
309+
c.LoadConst(py.String(qualname))
310+
c.OpArg(vm.MAKE_FUNCTION, args)
311+
return
312+
}
313+
for i := range code.Freevars {
314+
/* Bypass com_addop_varname because it will generate
315+
LOAD_DEREF but LOAD_CLOSURE is needed.
316+
*/
317+
name := code.Freevars[i]
318+
319+
/* Special case: If a class contains a method with a
320+
free variable that has the same name as a method,
321+
the name will be considered free *and* local in the
322+
class. It should be handled by the closure, as
323+
well as by the normal name loookup logic.
324+
*/
325+
reftype := c.getRefType(name)
326+
arg := 0
327+
if reftype == symtable.ScopeCell {
328+
arg = c.FindId(name, c.Code.Cellvars)
329+
} else { /* (reftype == FREE) */
330+
arg = c.FindId(name, c.Code.Freevars)
331+
}
332+
if arg < 0 {
333+
panic(fmt.Sprintf("compile: makeClosure: lookup %q in %q %v %v\nfreevars of %q: %v\n", name, c.SymTable.Name, reftype, arg, code.Name, code.Freevars))
334+
}
335+
c.OpArg(vm.LOAD_CLOSURE, uint32(arg))
336+
}
337+
c.OpArg(vm.BUILD_TUPLE, free)
338+
c.LoadConst(code)
339+
c.LoadConst(py.String(qualname))
340+
c.OpArg(vm.MAKE_CLOSURE, args)
341+
}
342+
343+
// Compute the flags for the current Code
344+
func (c *compiler) codeFlags(st *symtable.SymTable) (flags int32) {
345+
if st.Type == symtable.FunctionBlock {
346+
flags |= py.CO_NEWLOCALS
347+
if st.Unoptimized == 0 {
348+
flags |= py.CO_OPTIMIZED
349+
}
350+
if st.Nested {
351+
flags |= py.CO_NESTED
352+
}
353+
if st.Generator {
354+
flags |= py.CO_GENERATOR
355+
}
356+
if st.Varargs {
357+
flags |= py.CO_VARARGS
358+
}
359+
if st.Varkeywords {
360+
flags |= py.CO_VARKEYWORDS
361+
}
362+
}
363+
364+
/* (Only) inherit compilerflags in PyCF_MASK */
365+
flags |= c.Code.Flags & py.CO_COMPILER_FLAGS_MASK
366+
367+
if len(c.Code.Freevars) == 0 && len(c.Code.Cellvars) == 0 {
368+
flags |= py.CO_NOFREE
369+
}
370+
371+
return flags
372+
}
373+
374+
// Sets the qualname
375+
func (c *compiler) setQualname() {
376+
var base string
377+
if c.depth > 1 {
378+
force_global := false
379+
parent := c.parent
380+
if parent == nil {
381+
panic("compile: setQualname: expecting a parent")
382+
}
383+
if c.scopeType == compilerScopeFunction || c.scopeType == compilerScopeClass {
384+
// FIXME mangled = _Py_Mangle(parent.u_private, u.u_name)
385+
mangled := c.Code.Name
386+
scope := parent.SymTable.GetScope(mangled)
387+
if scope == symtable.ScopeGlobalImplicit {
388+
panic("compile: setQualname: not expecting scopeGlobalImplicit")
389+
}
390+
if scope == symtable.ScopeGlobalExplicit {
391+
force_global = true
392+
}
393+
}
394+
if !force_global {
395+
if parent.scopeType == compilerScopeFunction || parent.scopeType == compilerScopeLambda {
396+
base = parent.qualname + ".<locals>"
397+
} else {
398+
base = parent.qualname
399+
}
400+
}
401+
}
402+
if base != "" {
403+
c.qualname = base + "." + c.Code.Name
404+
} else {
405+
c.qualname = c.Code.Name
406+
}
407+
}
408+
237409
// Compile statement
238410
func (c *compiler) Stmt(stmt ast.Stmt) {
239411
switch node := stmt.(type) {
@@ -247,34 +419,15 @@ func (c *compiler) Stmt(stmt ast.Stmt) {
247419
if newSymTable == nil {
248420
panic("No symtable found for function")
249421
}
250-
code, err := CompileAst(node, c.Code.Filename, int(c.Code.Flags)|py.CO_OPTIMIZED|py.CO_NEWLOCALS, false, newSymTable) // FIXME pass on compile args
422+
newC := newCompiler(c, compilerScopeFunction)
423+
code, err := newC.compileAst(node, c.Code.Filename, 0, false, newSymTable)
251424
if err != nil {
252425
panic(err)
253426
}
427+
// FIXME need these set in code before we compile - (pass in node?)
254428
code.Argcount = int32(len(node.Args.Args))
255429
code.Name = string(node.Name)
256430
code.Kwonlyargcount = int32(len(node.Args.Kwonlyargs))
257-
code.Nlocals = code.Kwonlyargcount + int32(len(node.Args.Args))
258-
if code.Kwonlyargcount > 0 {
259-
code.Flags |= py.CO_VARARGS
260-
}
261-
262-
// Arguments
263-
for _, arg := range node.Args.Args {
264-
c.Index(string(arg.Arg), &code.Varnames)
265-
}
266-
for _, arg := range node.Args.Kwonlyargs {
267-
c.Index(string(arg.Arg), &code.Varnames)
268-
}
269-
if node.Args.Vararg != nil {
270-
code.Nlocals++
271-
c.Index(string(node.Args.Vararg.Arg), &code.Varnames)
272-
}
273-
if node.Args.Kwarg != nil {
274-
code.Nlocals++
275-
c.Index(string(node.Args.Kwarg.Arg), &code.Varnames)
276-
code.Flags |= py.CO_VARKEYWORDS
277-
}
278431

279432
// Defaults
280433
posdefaults := uint32(len(node.Args.Defaults))
@@ -316,10 +469,10 @@ func (c *compiler) Stmt(stmt ast.Stmt) {
316469
c.LoadConst(annotations)
317470
}
318471

319-
c.LoadConst(code)
320-
c.LoadConst(py.String(node.Name))
321-
c.OpArg(vm.MAKE_FUNCTION, posdefaults+(kwdefaults<<8)+(num_annotations<<16))
322-
c.OpArg(vm.STORE_NAME, c.Name(node.Name))
472+
args := uint32(posdefaults + (kwdefaults << 8) + (num_annotations << 16))
473+
c.makeClosure(code, args, newC)
474+
c.NameOp(string(node.Name), ast.Store)
475+
323476
case *ast.ClassDef:
324477
// Name Identifier
325478
// Bases []Expr
@@ -750,16 +903,15 @@ func (c *compiler) Expr(expr ast.Expr) {
750903
if newSymTable == nil {
751904
panic("No symtable found for lambda")
752905
}
753-
code, err := CompileAst(node.Body, c.Code.Filename, int(c.Code.Flags)|py.CO_OPTIMIZED|py.CO_NEWLOCALS, false, newSymTable) // FIXME pass on compile args
906+
newC := newCompiler(c, compilerScopeLambda)
907+
code, err := newC.compileAst(node.Body, c.Code.Filename, 0, false, newSymTable)
754908
if err != nil {
755909
panic(err)
756910
}
757911

758912
code.Argcount = int32(len(node.Args.Args))
759-
c.LoadConst(code)
760-
c.LoadConst(py.String("<lambda>"))
761-
// FIXME node.Args
762-
c.OpArg(vm.MAKE_FUNCTION, 0)
913+
// FIXME node.Args - more work on lambda needed
914+
c.makeClosure(code, 0, newC)
763915
case *ast.IfExp:
764916
// Test Expr
765917
// Body Expr

0 commit comments

Comments
 (0)