diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 9d6c3020d406..9f427496ac0d 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -69,9 +69,41 @@ object Objects: // ----------------------------- abstract domain ----------------------------- + /** Syntax for the data structure abstraction used in abstract domain: + * + * ve ::= ObjectRef(class) // global object + * | OfClass(class, vs[outer], ctor, args, env) // instance of a class + * | OfArray(object[owner], regions) + * | Fun(..., env) // value elements that can be contained in ValueSet + * vs ::= ValueSet(ve) // set of abstract values + * Bottom ::= ValueSet(Empty) + * val ::= ve | Cold | vs // all possible abstract values in domain + * Ref ::= ObjectRef | OfClass // values that represent a reference to some (global or instance) object + * ThisValue ::= Ref | Cold // possible values for 'this' + * + * refMap = Ref -> ( valsMap, varsMap, outersMap ) // refMap stores field informations of an object or instance + * valsMap = valsym -> val // maps immutable fields to their values + * varsMap = valsym -> addr // each mutable field has an abstract address + * outersMap = class -> val // maps outer objects to their values + * + * arrayMap = OfArray -> addr // an array has one address that stores the join value of every element + * + * heap = addr -> val // heap is mutable + * + * env = (valsMap, Option[env]) // stores local variables in the residing method, and possibly outer environments + * + * addr ::= localVarAddr(regions, valsym, owner) + * | fieldVarAddr(regions, valsym, owner) // independent of OfClass/ObjectRef + * | arrayAddr(regions, owner) // independent of array element type + * + * regions ::= List(sourcePosition) + */ + sealed abstract class Value: def show(using Context): String + /** ValueElement are elements that can be contained in a RefSet */ + sealed abstract class ValueElement extends Value /** * A reference caches the values for outers and immutable fields. @@ -80,7 +112,7 @@ object Objects: valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr], outersMap: mutable.Map[ClassSymbol, Value]) - extends Value: + extends ValueElement: protected val vals: mutable.Map[Symbol, Value] = valsMap protected val vars: mutable.Map[Symbol, Heap.Addr] = varsMap protected val outers: mutable.Map[ClassSymbol, Value] = outersMap @@ -164,8 +196,7 @@ object Objects: * * @param owner The static object whose initialization creates the array. */ - case class OfArray(owner: ClassSymbol, regions: Regions.Data)(using @constructorOnly ctx: Context) - extends Ref(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty, outersMap = mutable.Map.empty): + case class OfArray(owner: ClassSymbol, regions: Regions.Data)(using @constructorOnly ctx: Context) extends ValueElement: val klass: ClassSymbol = defn.ArrayClass val addr: Heap.Addr = Heap.arrayAddr(regions, owner) def show(using Context) = "OfArray(owner = " + owner.show + ")" @@ -173,7 +204,7 @@ object Objects: /** * Represents a lambda expression */ - case class Fun(code: Tree, thisV: Value, klass: ClassSymbol, env: Env.Data) extends Value: + case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement: def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")" /** @@ -181,15 +212,20 @@ object Objects: * * It comes from `if` expressions. */ - case class RefSet(refs: ListSet[Value]) extends Value: - assert(refs.forall(!_.isInstanceOf[RefSet])) - def show(using Context) = refs.map(_.show).mkString("[", ",", "]") + case class ValueSet(values: ListSet[ValueElement]) extends Value: + def show(using Context) = values.map(_.show).mkString("[", ",", "]") - /** A cold alias which should not be used during initialization. */ + /** A cold alias which should not be used during initialization. + * + * Cold is not ValueElement since RefSet containing Cold is equivalent to Cold + */ case object Cold extends Value: def show(using Context) = "Cold" - val Bottom = RefSet(ListSet.empty) + val Bottom = ValueSet(ListSet.empty) + + /** Possible types for 'this' */ + type ThisValue = Ref | Cold.type /** Checking state */ object State: @@ -243,7 +279,7 @@ object Objects: obj end doCheckObject - def checkObjectAccess(clazz: ClassSymbol)(using data: Data, ctx: Context, pendingTrace: Trace): Value = + def checkObjectAccess(clazz: ClassSymbol)(using data: Data, ctx: Context, pendingTrace: Trace): ObjectRef = val index = data.checkingObjects.indexOf(ObjectRef(clazz)) if index != -1 then @@ -390,16 +426,20 @@ object Objects: * * @return the environment and value for `this` owned by the given method. */ - def resolveEnv(meth: Symbol, thisV: Value, env: Data)(using Context): Option[(Value, Data)] = log("Resolving env for " + meth.show + ", this = " + thisV.show + ", env = " + env.show, printer) { + def resolveEnv(meth: Symbol, thisV: ThisValue, env: Data)(using Context): Option[(ThisValue, Data)] = log("Resolving env for " + meth.show + ", this = " + thisV.show + ", env = " + env.show, printer) { env match case localEnv: LocalEnv => if localEnv.meth == meth then Some(thisV -> env) else resolveEnv(meth, thisV, localEnv.outer) case NoEnv => - // TODO: handle RefSet thisV match case ref: OfClass => - resolveEnv(meth, ref.outer, ref.env) + ref.outer match + case outer : ThisValue => + resolveEnv(meth, outer, ref.env) + case _ => + // TODO: properly handle the case where ref.outer is ValueSet + None case _ => None } @@ -473,7 +513,7 @@ object Objects: val config = Config(thisV, summon[Env.Data], Heap.getHeapData()) super.get(config, expr).map(_.value) - def cachedEval(thisV: Value, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data): Value = + def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data): Value = val config = Config(thisV, summon[Env.Data], Heap.getHeapData()) val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, Heap.getHeapData())) { expr => Res(fun(expr), Heap.getHeapData()) @@ -530,34 +570,37 @@ object Objects: extension (a: Value) def join(b: Value): Value = (a, b) match - case (Cold, b) => Cold - case (a, Cold) => Cold - case (Bottom, b) => b - case (a, Bottom) => a - case (RefSet(refs1), RefSet(refs2)) => RefSet(refs1 ++ refs2) - case (a, RefSet(refs)) => RefSet(refs + a) - case (RefSet(refs), b) => RefSet(refs + b) - case (a, b) => RefSet(ListSet(a, b)) + case (Cold, _) => Cold + case (_, Cold) => Cold + case (Bottom, b) => b + case (a, Bottom) => a + case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1 ++ values2) + case (a : ValueElement, ValueSet(values)) => ValueSet(values + a) + case (ValueSet(values), b : ValueElement) => ValueSet(values + b) + case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b)) def widen(height: Int)(using Context): Value = if height == 0 then Cold else a match - case Bottom => Bottom + case Bottom => Bottom + + case ValueSet(values) => + values.map(ref => ref.widen(height)).join - case RefSet(refs) => - refs.map(ref => ref.widen(height)).join + case Fun(code, thisV, klass, env) => + Fun(code, thisV.widenRefOrCold(height), klass, env.widen(height)) - case Fun(code, thisV, klass, env) => - Fun(code, thisV.widen(height), klass, env.widen(height)) + case ref @ OfClass(klass, outer, _, args, env) => + val outer2 = outer.widen(height - 1) + val args2 = args.map(_.widen(height - 1)) + val env2 = env.widen(height - 1) + ref.widenedCopy(outer2, args2, env2) - case ref @ OfClass(klass, outer, _, args, env) => - val outer2 = outer.widen(height - 1) - val args2 = args.map(_.widen(height - 1)) - val env2 = env.widen(height - 1) - ref.widenedCopy(outer2, args2, env2) + case _ => a - case _ => a + extension (value: Ref | Cold.type) + def widenRefOrCold(height : Int)(using Context) : Ref | Cold.type = value.widen(height).asInstanceOf[ThisValue] extension (values: Iterable[Value]) def join: Value = if values.isEmpty then Bottom else values.reduce { (v1, v2) => v1.join(v2) } @@ -620,7 +663,7 @@ object Objects: val ddef = target.defTree.asInstanceOf[DefDef] val meth = ddef.symbol - val (thisV, outerEnv) = + val (thisV : ThisValue, outerEnv) = if meth.owner.isClass then (ref, Env.NoEnv) else @@ -629,7 +672,6 @@ object Objects: val env2 = Env.of(ddef, args.map(_.value), outerEnv) extendTrace(ddef) { given Env.Data = env2 - // eval(ddef.rhs, ref, cls, cacheResult = true) cache.cachedEval(ref, ddef.rhs, cacheResult = true) { expr => Returns.installHandler(meth) val res = cases(expr, thisV, cls) @@ -665,19 +707,19 @@ object Objects: given Env.Data = env extendTrace(code) { eval(code, thisV, klass, cacheResult = true) } - case RefSet(vs) => + case ValueSet(vs) => vs.map(v => call(v, meth, args, receiver, superType)).join } /** Handle constructor calls `(args)`. * - * @param thisV The value for the receiver. + * @param value The value for the receiver. * @param ctor The symbol of the target method. * @param args Arguments of the constructor call (all parameter blocks flatten to a list). */ - def callConstructor(thisV: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { + def callConstructor(value: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { - thisV match + value match case ref: Ref => if ctor.hasSource then val cls = ctor.owner.enclosingClass.asClass @@ -689,13 +731,17 @@ object Objects: val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template] extendTrace(cls.defTree) { eval(tpl, ref, cls, cacheResult = true) } else - extendTrace(ddef) { eval(ddef.rhs, ref, cls, cacheResult = true) } + extendTrace(ddef) { // The return values for secondary constructors can be ignored + Returns.installHandler(ctor) + eval(ddef.rhs, ref, cls, cacheResult = true) + Returns.popHandler(ctor) + } else // no source code available Bottom case _ => - report.warning("[Internal error] unexpected constructor call, meth = " + ctor + ", this = " + thisV + Trace.show, Trace.position) + report.warning("[Internal error] unexpected constructor call, meth = " + ctor + ", this = " + value + Trace.show, Trace.position) Bottom } @@ -706,8 +752,8 @@ object Objects: * @param receiver The type of the receiver. * @param needResolve Whether the target of the selection needs resolution? */ - def select(thisV: Value, field: Symbol, receiver: Type, needResolve: Boolean = true): Contextual[Value] = log("select " + field.show + ", this = " + thisV.show, printer, (_: Value).show) { - thisV match + def select(value: Value, field: Symbol, receiver: Type, needResolve: Boolean = true): Contextual[Value] = log("select " + field.show + ", this = " + value.show, printer, (_: Value).show) { + value match case Cold => report.warning("Using cold alias", Trace.position) Bottom @@ -755,12 +801,16 @@ object Objects: report.warning("[Internal error] unexpected tree in selecting a function, fun = " + fun.code.show + Trace.show, fun.code) Bottom + case arr: OfArray => + report.warning("[Internal error] unexpected tree in selecting an array, array = " + arr.show + Trace.show, Trace.position) + Bottom + case Bottom => if field.isStaticObject then ObjectRef(field.moduleClass.asClass) else Bottom - case RefSet(refs) => - refs.map(ref => select(ref, field, receiver)).join + case ValueSet(values) => + values.map(ref => select(ref, field, receiver)).join } /** Handle assignment `lhs.f = rhs`. @@ -775,13 +825,16 @@ object Objects: case fun: Fun => report.warning("[Internal error] unexpected tree in assignment, fun = " + fun.code.show + Trace.show, Trace.position) + case arr: OfArray => + report.warning("[Internal error] unexpected tree in assignment, array = " + arr.show + Trace.show, Trace.position) + case Cold => report.warning("Assigning to cold aliases is forbidden. Calling trace:\n" + Trace.show, Trace.position) case Bottom => - case RefSet(refs) => - refs.foreach(ref => assign(ref, field, rhs, rhsTyp)) + case ValueSet(values) => + values.foreach(ref => assign(ref, field, rhs, rhsTyp)) case ref: Ref => if ref.hasVar(field) then @@ -811,9 +864,7 @@ object Objects: report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position) Bottom - case value: (Bottom.type | ObjectRef | OfClass | Cold.type) => - // The outer can be a bottom value for top-level classes. - + case outer: (Ref | Cold.type | Bottom.type) => if klass == defn.ArrayClass then val arr = OfArray(State.currentObject, summon[Regions.Data]) Heap.write(arr.addr, Bottom) @@ -821,18 +872,26 @@ object Objects: else // Widen the outer to finitize the domain. Arguments already widened in `evalArgs`. val (outerWidened, envWidened) = - if klass.owner.isClass then - (outer.widen(1), Env.NoEnv) - else - // klass.enclosingMethod returns its primary constructor - Env.resolveEnv(klass.owner.enclosingMethod, outer, summon[Env.Data]).getOrElse(Cold -> Env.NoEnv) + outer match + case _ : Bottom.type => // For top-level classes + (Bottom, Env.NoEnv) + case thisV : (Ref | Cold.type) => + if klass.owner.isClass then + if klass.owner.is(Flags.Package) then + report.warning("[Internal error] top-level class should have `Bottom` as outer, class = " + klass.show + ", outer = " + outer.show + ", " + Trace.show, Trace.position) + (Bottom, Env.NoEnv) + else + (thisV.widenRefOrCold(1), Env.NoEnv) + else + // klass.enclosingMethod returns its primary constructor + Env.resolveEnv(klass.owner.enclosingMethod, thisV, summon[Env.Data]).getOrElse(Cold -> Env.NoEnv) val instance = OfClass(klass, outerWidened, ctor, args.map(_.value), envWidened) callConstructor(instance, ctor, args) instance - case RefSet(refs) => - refs.map(ref => instantiate(ref, klass, ctor, args)).join + case ValueSet(values) => + values.map(ref => instantiate(ref, klass, ctor, args)).join } /** Handle local variable definition, `val x = e` or `var x = e`. @@ -854,7 +913,7 @@ object Objects: * @param thisV The value for `this` where the variable is used. * @param sym The symbol of the variable. */ - def readLocal(thisV: Value, sym: Symbol): Contextual[Value] = log("reading local " + sym.show, printer, (_: Value).show) { + def readLocal(thisV: ThisValue, sym: Symbol): Contextual[Value] = log("reading local " + sym.show, printer, (_: Value).show) { def isByNameParam(sym: Symbol) = sym.is(Flags.Param) && sym.info.isInstanceOf[ExprType] Env.resolveEnv(sym.enclosingMethod, thisV, summon[Env.Data]) match case Some(thisV -> env) => @@ -884,7 +943,7 @@ object Objects: case Cold => report.warning("Calling cold by-name alias. Call trace: \n" + Trace.show, Trace.position) Bottom - case _: RefSet | _: Ref => + case _: ValueSet | _: Ref | _: OfArray => report.warning("[Internal error] Unexpected by-name value " + value.show + ". Calling trace:\n" + Trace.show, Trace.position) Bottom else @@ -904,7 +963,7 @@ object Objects: * @param sym The symbol of the variable. * @param value The value of the rhs of the assignment. */ - def writeLocal(thisV: Value, sym: Symbol, value: Value): Contextual[Value] = log("write local " + sym.show + " with " + value.show, printer, (_: Value).show) { + def writeLocal(thisV: ThisValue, sym: Symbol, value: Value): Contextual[Value] = log("write local " + sym.show + " with " + value.show, printer, (_: Value).show) { assert(sym.is(Flags.Mutable), "Writing to immutable variable " + sym.show) Env.resolveEnv(sym.enclosingMethod, thisV, summon[Env.Data]) match @@ -928,7 +987,7 @@ object Objects: // -------------------------------- algorithm -------------------------------- /** Check an individual object */ - private def accessObject(classSym: ClassSymbol)(using Context, State.Data, Trace): Value = log("accessing " + classSym.show, printer, (_: Value).show) { + private def accessObject(classSym: ClassSymbol)(using Context, State.Data, Trace): ObjectRef = log("accessing " + classSym.show, printer, (_: Value).show) { if classSym.hasSource then State.checkObjectAccess(classSym) else @@ -965,13 +1024,13 @@ object Objects: * @param klass The enclosing class where the expression is located. * @param cacheResult It is used to reduce the size of the cache. */ - def eval(expr: Tree, thisV: Value, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", regions = " + Regions.show + " in " + klass.show, printer, (_: Value).show) { + def eval(expr: Tree, thisV: ThisValue, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + ", regions = " + Regions.show + " in " + klass.show, printer, (_: Value).show) { cache.cachedEval(thisV, expr, cacheResult) { expr => cases(expr, thisV, klass) } } /** Evaluate a list of expressions */ - def evalExprs(exprs: List[Tree], thisV: Value, klass: ClassSymbol): Contextual[List[Value]] = + def evalExprs(exprs: List[Tree], thisV: ThisValue, klass: ClassSymbol): Contextual[List[Value]] = exprs.map { expr => eval(expr, thisV, klass) } /** Handles the evaluation of different expressions @@ -982,7 +1041,7 @@ object Objects: * @param thisV The value for `C.this` where `C` is represented by the parameter `klass`. * @param klass The enclosing class where the expression `expr` is located. */ - def cases(expr: Tree, thisV: Value, klass: ClassSymbol): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) { + def cases(expr: Tree, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) { val trace2 = trace.add(expr) expr match @@ -1182,7 +1241,7 @@ object Objects: * @param thisV The value for `C.this` where `C` is represented by `klass`. * @param klass The enclosing class where the type `tp` is located. */ - def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] = + def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = // expected member types for `unapplySeq` def lengthType = ExprType(defn.IntType) def lengthCompareType = MethodType(List(defn.IntType), defn.IntType) @@ -1372,7 +1431,7 @@ object Objects: * Object access elission happens when the object access is used as a prefix * in `new o.C` and `C` does not need an outer. */ - def evalType(tp: Type, thisV: Value, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) { + def evalType(tp: Type, thisV: ThisValue, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) { tp match case _: ConstantType => Bottom @@ -1422,7 +1481,7 @@ object Objects: } /** Evaluate arguments of methods and constructors */ - def evalArgs(args: List[Arg], thisV: Value, klass: ClassSymbol): Contextual[List[ArgInfo]] = + def evalArgs(args: List[Arg], thisV: ThisValue, klass: ClassSymbol): Contextual[List[ArgInfo]] = val argInfos = new mutable.ArrayBuffer[ArgInfo] args.foreach { arg => val res = @@ -1458,7 +1517,7 @@ object Objects: * @param thisV The value of the current object to be initialized. * @param klass The class to which the template belongs. */ - def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer, (_: Value).show) { + def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Ref] = log("init " + klass.show, printer, (_: Value).show) { val paramsMap = tpl.constr.termParamss.flatten.map { vdef => vdef.name -> Env.valValue(vdef.symbol) }.toMap @@ -1609,9 +1668,9 @@ object Objects: Bottom else resolveThis(target, ref.outerValue(klass), outerCls) - case RefSet(refs) => - refs.map(ref => resolveThis(target, ref, klass)).join - case fun: Fun => + case ValueSet(values) => + values.map(ref => resolveThis(target, ref, klass)).join + case _: Fun | _ : OfArray => report.warning("[Internal error] unexpected thisV = " + thisV + ", target = " + target.show + ", klass = " + klass.show + Trace.show, Trace.position) Bottom } @@ -1622,7 +1681,7 @@ object Objects: * @param thisV The value for `C.this` where `C` is represented by the parameter `klass`. * @param klass The enclosing class where the type `tref` is located. */ - def outerValue(tref: TypeRef, thisV: Value, klass: ClassSymbol): Contextual[Value] = + def outerValue(tref: TypeRef, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = val cls = tref.classSymbol.asClass if tref.prefix == NoPrefix then val enclosing = cls.owner.lexicallyEnclosingClass.asClass diff --git a/tests/init-global/pos/secondary-constructor-return.scala b/tests/init-global/pos/secondary-constructor-return.scala new file mode 100644 index 000000000000..c4a0c1f95001 --- /dev/null +++ b/tests/init-global/pos/secondary-constructor-return.scala @@ -0,0 +1,12 @@ +class Foo (var x: Int) { + def this(a : Int, b : Int) = { + this(a + b) + return + } + val y = x +} + +object A { + val a = new Foo(2, 3) + val b = a.y +} \ No newline at end of file