diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index f85075cd2de8..80416cee000a 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -437,13 +437,13 @@ object desugar { private def toDefParam(tparam: TypeDef, keepAnnotations: Boolean): TypeDef = { var mods = tparam.rawMods if (!keepAnnotations) mods = mods.withAnnotations(Nil) - tparam.withMods(mods & (EmptyFlags | Sealed) | Param) + tparam.withMods(mods & EmptyFlags | Param) } private def toDefParam(vparam: ValDef, keepAnnotations: Boolean, keepDefault: Boolean): ValDef = { var mods = vparam.rawMods if (!keepAnnotations) mods = mods.withAnnotations(Nil) val hasDefault = if keepDefault then HasDefault else EmptyFlags - vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param) + vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault | Tracked) | Param) } def mkApply(fn: Tree, paramss: List[ParamClause])(using Context): Tree = @@ -529,7 +529,7 @@ object desugar { // but not on the constructor parameters. The reverse is true for // annotations on class _value_ parameters. val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = false)) - val constrVparamss = + def defVparamss = if (originalVparamss.isEmpty) { // ensure parameter list is non-empty if (isCaseClass) report.error(CaseClassMissingParamList(cdef), namePos) @@ -540,6 +540,7 @@ object desugar { ListOfNil } else originalVparamss.nestedMap(toDefParam(_, keepAnnotations = true, keepDefault = true)) + val constrVparamss = defVparamss val derivedTparams = constrTparams.zipWithConserve(impliedTparams)((tparam, impliedParam) => derivedTypeParam(tparam).withAnnotations(impliedParam.mods.annotations)) @@ -614,6 +615,11 @@ object desugar { case _ => false } + def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match { + case PostfixOp(_, Ident(tpnme.raw.STAR)) => true + case _ => false + } + def appliedRef(tycon: Tree, tparams: List[TypeDef] = constrTparams, widenHK: Boolean = false) = { val targs = for (tparam <- tparams) yield { val targ = refOfDef(tparam) @@ -630,10 +636,13 @@ object desugar { appliedTypeTree(tycon, targs) } - def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match { - case PostfixOp(_, Ident(tpnme.raw.STAR)) => true - case _ => false - } + def addParamRefinements(core: Tree, paramss: List[List[ValDef]]): Tree = + val refinements = + for params <- paramss; param <- params; if param.mods.is(Tracked) yield + ValDef(param.name, SingletonTypeTree(TermRefTree().watching(param)), EmptyTree) + .withSpan(param.span) + if refinements.isEmpty then core + else RefinedTypeTree(core, refinements).showing(i"refined result: $result", Printers.desugar) // a reference to the class type bound by `cdef`, with type parameters coming from the constructor val classTypeRef = appliedRef(classTycon) @@ -854,18 +863,17 @@ object desugar { Nil } else { - val defParamss = constrVparamss match { + val defParamss = defVparamss match case Nil :: paramss => paramss // drop leading () that got inserted by class // TODO: drop this once we do not silently insert empty class parameters anymore case paramss => paramss - } val finalFlag = if ctx.settings.YcompileScala2Library.value then EmptyFlags else Final // implicit wrapper is typechecked in same scope as constructor, so // we can reuse the constructor parameters; no derived params are needed. DefDef( className.toTermName, joinParams(constrTparams, defParamss), - classTypeRef, creatorExpr) + addParamRefinements(classTypeRef, defParamss), creatorExpr) .withMods(companionMods | mods.flags.toTermFlags & (GivenOrImplicit | Inline) | finalFlag) .withSpan(cdef.span) :: Nil } @@ -894,7 +902,9 @@ object desugar { } if mods.isAllOf(Given | Inline | Transparent) then report.error("inline given instances cannot be trasparent", cdef) - val classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods + var classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods + if vparamAccessors.exists(_.mods.is(Tracked)) then + classMods |= Dependent cpy.TypeDef(cdef: TypeDef)( name = className, rhs = cpy.Template(impl)(constr, parents1, clsDerived, self1, diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 817ff5c6c9fa..7e003da1556d 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -231,6 +231,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case class Infix()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Infix) + case class Tracked()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Tracked) + /** Used under pureFunctions to mark impure function types `A => B` in `FunctionWithMods` */ case class Impure()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Impure) } diff --git a/compiler/src/dotty/tools/dotc/core/Flags.scala b/compiler/src/dotty/tools/dotc/core/Flags.scala index 6ae9541a327f..aa6f52949eaa 100644 --- a/compiler/src/dotty/tools/dotc/core/Flags.scala +++ b/compiler/src/dotty/tools/dotc/core/Flags.scala @@ -242,7 +242,7 @@ object Flags { val (AccessorOrSealed @ _, Accessor @ _, Sealed @ _) = newFlags(11, "", "sealed") /** A mutable var, an open class */ - val (MutableOrOpen @ __, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open") + val (MutableOrOpen @ _, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open") /** Symbol is local to current class (i.e. private[this] or protected[this] * pre: Private or Protected are also set @@ -377,6 +377,8 @@ object Flags { /** Symbol cannot be found as a member during typer */ val (Invisible @ _, _, _) = newFlags(45, "") + val (Tracked @ _, _, Dependent @ _) = newFlags(46, "tracked", "dependent") + // ------------ Flags following this one are not pickled ---------------------------------- /** Symbol is not a member of its owner */ @@ -452,7 +454,7 @@ object Flags { CommonSourceModifierFlags.toTypeFlags | Abstract | Sealed | Opaque | Open val TermSourceModifierFlags: FlagSet = - CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy + CommonSourceModifierFlags.toTermFlags | Inline | AbsOverride | Lazy | Tracked /** Flags representing modifiers that can appear in trees */ val ModifierFlags: FlagSet = @@ -477,7 +479,7 @@ object Flags { */ val AfterLoadFlags: FlagSet = commonFlags( FromStartFlags, AccessFlags, Final, AccessorOrSealed, - Abstract, LazyOrTrait, SelfName, JavaDefined, JavaAnnotation, Transparent) + Abstract, LazyOrTrait, SelfName, JavaDefined, JavaAnnotation, Transparent, Tracked) /** A value that's unstable unless complemented with a Stable flag */ val UnstableValueFlags: FlagSet = Mutable | Method diff --git a/compiler/src/dotty/tools/dotc/core/NamerOps.scala b/compiler/src/dotty/tools/dotc/core/NamerOps.scala index ea0cbfbd0c07..86b71fb473bd 100644 --- a/compiler/src/dotty/tools/dotc/core/NamerOps.scala +++ b/compiler/src/dotty/tools/dotc/core/NamerOps.scala @@ -5,6 +5,7 @@ package core import Contexts.*, Symbols.*, Types.*, Flags.*, Scopes.*, Decorators.*, Names.*, NameOps.* import SymDenotations.{LazyType, SymDenotation}, StdNames.nme import TypeApplications.EtaExpansion +import collection.mutable /** Operations that are shared between Namer and TreeUnpickler */ object NamerOps: @@ -14,9 +15,56 @@ object NamerOps: * @param ctor the constructor */ def effectiveResultType(ctor: Symbol, paramss: List[List[Symbol]])(using Context): Type = - paramss match - case TypeSymbols(tparams) :: _ => ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef)) - case _ => ctor.owner.typeRef + var resType = paramss match + case TypeSymbols(tparams) :: _ => + ctor.owner.typeRef.appliedTo(tparams.map(_.typeRef)) + case _ => + ctor.owner.typeRef + for params <- paramss; param <- params do + if param.is(Tracked) then + resType = RefinedType(resType, param.name, param.termRef) + resType + + /** Split dependent class refinements off parent type and add them to `refinements` */ + extension (tp: Type) + def separateRefinements(refinements: mutable.LinkedHashMap[Name, Type])(using Context): Type = + tp match + case RefinedType(tp1, rname, rinfo) => + try tp1.separateRefinements(refinements) + finally + refinements(rname) = refinements.get(rname) match + case Some(tp) => tp & rinfo + case None => rinfo + case tp => tp + + /** Add all parent `refinements` to the result type of the info of the dependent + * class constructor `constr`. Parent refinements refer to parameter accessors + * in the current class. These have to be mapped to the paramRefs of the + * constructor info. + */ + def integrateParentRefinements( + constr: Symbol, refinements: mutable.LinkedHashMap[Name, Type])(using Context): Unit = + + /** @param info the (remaining part) of the constructor info + * @param nameToParamRef the map from parameter names to paramRefs of + * previously encountered parts of `info`. + */ + def recur(info: Type, nameToParamRef: mutable.Map[Name, Type]): Type = info match + case info: MethodOrPoly => + info.derivedLambdaType(resType = + recur(info.resType, nameToParamRef ++= info.paramNames.zip(info.paramRefs))) + case _ => + val mapParams = new TypeMap: + def apply(t: Type) = t match + case t: TermRef if t.symbol.is(ParamAccessor) && t.symbol.owner == constr.owner => + nameToParamRef(t.name) + case _ => + mapOver(t) + refinements.foldLeft(info): (info, refinement) => + val (rname, rinfo) = refinement + RefinedType(info, rname, mapParams(rinfo)) + constr.info = recur(constr.info, mutable.Map()) + end integrateParentRefinements /** If isConstructor, make sure it has at least one non-implicit parameter list * This is done by adding a () in front of a leading old style implicit parameter, diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 4e3596ea8814..a32ed7712b80 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -10,6 +10,7 @@ import Contexts.ctx import dotty.tools.dotc.reporting.trace import config.Feature.migrateTo3 import config.Printers.* +import transform.TypeUtils.stripRefinement trait PatternTypeConstrainer { self: TypeComparer => @@ -88,11 +89,6 @@ trait PatternTypeConstrainer { self: TypeComparer => } } - def stripRefinement(tp: Type): Type = tp match { - case tp: RefinedOrRecType => stripRefinement(tp.parent) - case tp => tp - } - def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { val patCls = pat.classSymbol val scrCls = scrut.classSymbol diff --git a/compiler/src/dotty/tools/dotc/core/Scopes.scala b/compiler/src/dotty/tools/dotc/core/Scopes.scala index 7df5a7fa3c09..6529bde4aec1 100644 --- a/compiler/src/dotty/tools/dotc/core/Scopes.scala +++ b/compiler/src/dotty/tools/dotc/core/Scopes.scala @@ -17,6 +17,7 @@ import Denotations.* import printing.Texts.* import printing.Printer import SymDenotations.NoDenotation +import util.common.alwaysFalse import collection.mutable import scala.compiletime.uninitialized @@ -94,15 +95,13 @@ object Scopes { def foreach[U](f: Symbol => U)(using Context): Unit = toList.foreach(f) /** Selects all Symbols of this Scope which satisfy a predicate. */ - def filter(p: Symbol => Boolean)(using Context): List[Symbol] = { + def filter(p: Symbol => Boolean, stopAt: Symbol => Boolean = alwaysFalse)(using Context): List[Symbol] = { ensureComplete() var syms: List[Symbol] = Nil var e = lastEntry - while ((e != null) && e.owner == this) { - val sym = e.sym - if (p(sym)) syms = sym :: syms + while e != null && e.owner == this && !stopAt(e.sym) do + if p(e.sym) then syms = e.sym :: syms e = e.prev - } syms } diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 253a45ffd7a8..d8586ac8f13b 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -623,6 +623,7 @@ object StdNames { val toString_ : N = "toString" val toTypeConstructor: N = "toTypeConstructor" val tpe : N = "tpe" + val tracked: N = "tracked" val transparent : N = "transparent" val tree : N = "tree" val true_ : N = "true" diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index e18e1463f3ae..2084e4fc2997 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -2380,7 +2380,7 @@ object SymDenotations { * Both getters and setters are returned in this list. */ def paramAccessors(using Context): List[Symbol] = - unforcedDecls.filter(_.is(ParamAccessor)) + unforcedDecls.filter(_.is(ParamAccessor))//, stopAt = sym => sym.is(Method, butNot = ParamAccessor)) /** The term parameter getters of this class. */ def paramGetters(using Context): List[Symbol] = diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 2e4fe9967d6a..3ee31caf2404 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -776,6 +776,7 @@ class TreePickler(pickler: TastyPickler) { if (flags.is(Exported)) writeModTag(EXPORTED) if (flags.is(Given)) writeModTag(GIVEN) if (flags.is(Implicit)) writeModTag(IMPLICIT) + if (flags.is(Tracked)) writeModTag(TRACKED) if (isTerm) { if (flags.is(Lazy, butNot = Module)) writeModTag(LAZY) if (flags.is(AbsOverride)) { writeModTag(ABSTRACT); writeModTag(OVERRIDE) } diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index c366146a789e..4f35816c2cf8 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -733,6 +733,7 @@ class TreeUnpickler(reader: TastyReader, case INVISIBLE => addFlag(Invisible) case TRANSPARENT => addFlag(Transparent) case INFIX => addFlag(Infix) + case TRACKED => addFlag(Tracked) case PRIVATEqualified => readByte() privateWithin = readWithin @@ -1011,12 +1012,20 @@ class TreeUnpickler(reader: TastyReader, * but skip constructor arguments. Return any trees that were partially * parsed in this way as InferredTypeTrees. */ - def readParents(withArgs: Boolean)(using Context): List[Tree] = + def readParents(cls: ClassSymbol, withArgs: Boolean)(using Context): List[Tree] = collectWhile(nextByte != SELFDEF && nextByte != DEFDEF) { nextUnsharedTag match case APPLY | TYPEAPPLY | BLOCK => - if withArgs then readTree() - else InferredTypeTree().withType(readParentType()) + if withArgs then + readTree() + else if cls.is(Dependent) then + val parentReader = fork + val parentCoreType = readParentType() + if parentCoreType.dealias.typeSymbol.is(Dependent) + then parentReader.readTree() // read the whole tree since we need to see the refinement + else InferredTypeTree().withType(parentCoreType) + else + InferredTypeTree().withType(readParentType()) case _ => readTpt() } @@ -1042,9 +1051,10 @@ class TreeUnpickler(reader: TastyReader, while (bodyIndexer.reader.nextByte != DEFDEF) bodyIndexer.skipTree() bodyIndexer.indexStats(end) } - val parentReader = fork - val parents = readParents(withArgs = false)(using parentCtx) - val parentTypes = parents.map(_.tpe.dealias) + val parentsReader = fork + val parents = readParents(cls, withArgs = false)(using parentCtx) + val parentRefinements = mutable.LinkedHashMap[Name, Type]() + val parentTypes = parents.map(_.tpe.dealias.separateRefinements(parentRefinements)) val self = if (nextByte == SELFDEF) { readByte() @@ -1057,11 +1067,13 @@ class TreeUnpickler(reader: TastyReader, selfInfo = if (self.isEmpty) NoType else self.tpt.tpe ).integrateOpaqueMembers val constr = readIndexedDef().asInstanceOf[DefDef] + if parentRefinements.nonEmpty then + integrateParentRefinements(constr.symbol, parentRefinements) val mappedParents: LazyTreeList = if parents.exists(_.isInstanceOf[InferredTypeTree]) then // parents were not read fully, will need to be read again later on demand - new LazyReader(parentReader, localDummy, ctx.mode, ctx.source, - _.readParents(withArgs = true) + new LazyReader(parentsReader, localDummy, ctx.mode, ctx.source, + _.readParents(cls, withArgs = true) .map(_.changeOwner(localDummy, constr.symbol))) else parents diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index a2cc6499e843..d5f98260451e 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -53,6 +53,7 @@ object Parsers { enum ParamOwner: case Class // class or trait or enum case CaseClass // case class or enum case + case ImplicitClass // implicit class case Type // type alias or abstract type case TypeParam // type parameter case Def // method @@ -60,8 +61,9 @@ object Parsers { case ExtensionPrefix // extension clause, up to and including extension parameter case ExtensionFollow // extension clause, following extension parameter - def isClass = // owner is a class - this == Class || this == CaseClass + def isClass = this match // owner is a class + case Class | CaseClass | ImplicitClass | Given => true + case _ => false def takesOnlyUsingClauses = // only using clauses allowed for this owner this == Given || this == ExtensionFollow def acceptsVariance = @@ -3100,6 +3102,7 @@ object Parsers { case nme.open => Mod.Open() case nme.transparent => Mod.Transparent() case nme.infix => Mod.Infix() + case nme.tracked => Mod.Tracked() } } @@ -3166,6 +3169,7 @@ object Parsers { * | AccessModifier * | override * | opaque + * | tracked * LocalModifier ::= abstract | final | sealed | open | implicit | lazy | erased | inline | transparent */ def modifiers(allowed: BitSet = modifierTokens, start: Modifiers = Modifiers()): Modifiers = { @@ -3283,7 +3287,7 @@ object Parsers { val isAbstractOwner = paramOwner == ParamOwner.Type || paramOwner == ParamOwner.TypeParam val start = in.offset var mods = annotsAsMods() | Param - if paramOwner == ParamOwner.Class || paramOwner == ParamOwner.CaseClass then + if paramOwner.isClass then mods |= PrivateLocal if isIdent(nme.raw.PLUS) && checkVarianceOK() then mods |= Covariant @@ -3359,6 +3363,8 @@ object Parsers { mods = addFlag(modifiers(start = mods), ParamAccessor) mods = if in.token == VAL then + if !mods.is(Private) && paramOwner != ParamOwner.ImplicitClass then + mods |= Tracked in.nextToken() mods else if in.token == VAR then @@ -3427,7 +3433,8 @@ object Parsers { val isParams = !impliedMods.is(Given) || startParamTokens.contains(in.token) - || isIdent && (in.name == nme.inline || in.lookahead.isColon) + || isIdent + && (in.name == nme.inline || in.name == nme.tracked || in.lookahead.isColon) (mods, isParams) (if isParams then commaSeparated(() => param()) else contextTypes(paramOwner, numLeadParams, impliedMods)) match { @@ -3895,7 +3902,10 @@ object Parsers { } def classDefRest(start: Offset, mods: Modifiers, name: TypeName): TypeDef = - val constr = classConstr(if mods.is(Case) then ParamOwner.CaseClass else ParamOwner.Class) + val constr = classConstr( + if mods.is(Case) then ParamOwner.CaseClass + else if mods.is(Implicit) then ParamOwner.ImplicitClass + else ParamOwner.Class) val templ = templateOpt(constr) finalizeDef(TypeDef(name, templ), mods, start) @@ -4001,6 +4011,15 @@ object Parsers { val nameStart = in.offset val name = if isIdent && followingIsGivenSig() then ident() else EmptyTermName + def adjustDefParams(paramss: List[ParamClause]): List[ParamClause] = + paramss.nestedMap: param => + if !param.mods.isAllOf(PrivateLocal) then + syntaxError(em"method parameter ${param.name} may not be `a val`", param.span) + if param.mods.is(Tracked) then + syntaxError(em"method parameter ${param.name} may not be `tracked`", param.span) + param.withMods(param.mods &~ (AccessFlags | ParamAccessor | Tracked | Mutable) | Param) + .asInstanceOf[List[ParamClause]] + val gdef = val tparams = typeParamClauseOpt(ParamOwner.Given) newLineOpt() @@ -4022,16 +4041,17 @@ object Parsers { mods1 |= Lazy ValDef(name, parents.head, subExpr()) else - DefDef(name, joinParams(tparams, vparamss), parents.head, subExpr()) + DefDef(name, adjustDefParams(joinParams(tparams, vparamss)), parents.head, subExpr()) else if (isStatSep || isStatSeqEnd) && parentsIsType then if name.isEmpty then syntaxError(em"anonymous given cannot be abstract") - DefDef(name, joinParams(tparams, vparamss), parents.head, EmptyTree) + DefDef(name, adjustDefParams(joinParams(tparams, vparamss)), parents.head, EmptyTree) else - val tparams1 = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal)) - val vparamss1 = vparamss.map(_.map(vparam => - vparam.withMods(vparam.mods &~ Param | ParamAccessor | Protected))) - val constr = makeConstructor(tparams1, vparamss1) + val vparamss1 = vparamss.nestedMap: vparam => + if vparam.mods.is(Private) + then vparam.withMods(vparam.mods &~ PrivateLocal | Protected) + else vparam + val constr = makeConstructor(tparams, vparamss1) val templ = if isStatSep || isStatSeqEnd then Template(constr, parents, Nil, EmptyValDef, Nil) else withTemplate(constr, parents) diff --git a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala index fbf4e8d701dd..a6992cd5a676 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala @@ -294,7 +294,7 @@ object Tokens extends TokensCommon { final val closingParens = BitSet(RPAREN, RBRACKET, RBRACE) - final val softModifierNames = Set(nme.inline, nme.into, nme.opaque, nme.open, nme.transparent, nme.infix) + final val softModifierNames = Set(nme.inline, nme.into, nme.opaque, nme.open, nme.transparent, nme.infix, nme.tracked) def showTokenDetailed(token: Int): String = debugString(token) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 7fed5bc97f35..ced8743fa6a7 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -111,7 +111,7 @@ class PlainPrinter(_ctx: Context) extends Printer { protected def refinementNameString(tp: RefinedType): String = nameString(tp.refinedName) /** String representation of a refinement */ - protected def toTextRefinement(rt: RefinedType): Text = + def toTextRefinement(rt: RefinedType): Text = val keyword = rt.refinedInfo match { case _: ExprType | _: MethodOrPoly => "def " case _: TypeBounds => "type " diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index 8687925ed5fb..297dc31ea94a 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -4,7 +4,7 @@ package printing import core.* import Texts.*, ast.Trees.* -import Types.{Type, SingletonType, LambdaParam, NamedType}, +import Types.{Type, SingletonType, LambdaParam, NamedType, RefinedType}, Symbols.Symbol, Scopes.Scope, Constants.Constant, Names.Name, Denotations._, Annotations.Annotation, Contexts.Context import typer.Implicits.* @@ -104,6 +104,9 @@ abstract class Printer { /** Textual representation of a prefix of some reference, ending in `.` or `#` */ def toTextPrefixOf(tp: NamedType): Text + /** textual representation of a refinement, with no enclosing {...} */ + def toTextRefinement(rt: RefinedType): Text + /** Textual representation of a reference in a capture set */ def toTextCaptureRef(tp: Type): Text diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 23fcc80d3f22..9e21a55b8281 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -18,6 +18,7 @@ import config.Printers.typr import config.Feature import util.SrcPos import reporting.* +import transform.TypeUtils.stripRefinement import NameKinds.WildcardParamName object PostTyper { @@ -332,11 +333,11 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => case Select(nu: New, nme.CONSTRUCTOR) if isCheckable(nu) => // need to check instantiability here, because the type of the New itself // might be a type constructor. - ctx.typer.checkClassType(tree.tpe, tree.srcPos, traitReq = false, stablePrefixReq = true) + ctx.typer.checkClassType(tree.tpe, tree.srcPos, traitReq = false, stablePrefixReq = true, refinementOK = true) if !nu.tpe.isLambdaSub then // Check the constructor type as well; it could be an illegal singleton type // which would not be reflected as `tree.tpe` - ctx.typer.checkClassType(nu.tpe, tree.srcPos, traitReq = false, stablePrefixReq = false) + ctx.typer.checkClassType(nu.tpe, tree.srcPos, traitReq = false, stablePrefixReq = false, refinementOK = true) Checking.checkInstantiable(tree.tpe, nu.tpe, nu.srcPos) withNoCheckNews(nu :: Nil)(app1) case _ => @@ -411,8 +412,12 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => // Constructor parameters are in scope when typing a parent. // While they can safely appear in a parent tree, to preserve // soundness we need to ensure they don't appear in a parent - // type (#16270). - val illegalRefs = parent.tpe.namedPartsWith(p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym)) + // type (#16270). We can strip any refinement of a parent type since + // these refinements are split off from the parent type constructor + // application `parent` in Namer and don't show up as parent types + // of the class. + val illegalRefs = parent.tpe.stripRefinement.namedPartsWith: + p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym) if illegalRefs.nonEmpty then report.error( em"The type of a class parent cannot refer to constructor parameters, but ${parent.tpe} refers to ${illegalRefs.map(_.name.show).mkString(",")}", parent.srcPos) diff --git a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala index 90f6e2795f12..419fb4b9153e 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -7,11 +7,11 @@ import TypeErasure.ErasedValueType import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.* import Names.Name -object TypeUtils { +object TypeUtils: /** A decorator that provides methods on types * that are needed in the transformer pipeline. */ - extension (self: Type) { + extension (self: Type) def isErasedValueType(using Context): Boolean = self.isInstanceOf[ErasedValueType] @@ -104,5 +104,11 @@ object TypeUtils { case _ => val cls = self.underlyingClassRef(refinementOK = false).typeSymbol cls.isTransparentClass && (!traitOnly || cls.is(Trait)) - } -} + + /** Strip all outer refinements off this type */ + def stripRefinement: Type = self match + case self: RefinedOrRecType => self.parent.stripRefinement + case seld => self + +end TypeUtils + diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 90c26e279d01..a2248fa7d219 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -198,7 +198,7 @@ object Checking { * and that the instance conforms to the self type of the created class. */ def checkInstantiable(tp: Type, srcTp: Type, pos: SrcPos)(using Context): Unit = - tp.underlyingClassRef(refinementOK = false) match + tp.underlyingClassRef(refinementOK = true) match case tref: TypeRef => val cls = tref.symbol if (cls.isOneOf(AbstractOrTrait)) { @@ -1021,8 +1021,8 @@ trait Checking { * check that class prefix is stable. * @return `tp` itself if it is a class or trait ref, ObjectType if not. */ - def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean)(using Context): Type = - tp.underlyingClassRef(refinementOK = false) match { + def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean, refinementOK: Boolean = false)(using Context): Type = + tp.underlyingClassRef(refinementOK) match case tref: TypeRef => if (traitReq && !tref.symbol.is(Trait)) report.error(TraitIsExpected(tref.symbol), pos) if (stablePrefixReq && ctx.phase <= refchecksPhase) checkStable(tref.prefix, pos, "class prefix") @@ -1030,7 +1030,6 @@ trait Checking { case _ => report.error(NotClassType(tp), pos) defn.ObjectType - } /** If `sym` is an old-style implicit conversion, check that implicit conversions are enabled. * @pre sym.is(GivenOrImplicit) @@ -1198,6 +1197,20 @@ trait Checking { } } + /** Check that all refinements in class parent come from tracked parameters */ + def checkOnlyDependentRefinements(cls: ClassSymbol, parent: Tree)(using Context): Unit = + def recur(ptype: Type): Unit = ptype.dealias match + case rt @ RefinedType(ptype1, rname, rinfo) => + val ok = rname.isTermName && ptype1.nonPrivateMember(rname).hasAltWith(_.symbol.is(Tracked)) + if !ok then + report.error( + em"""Illegal refinement { ${ctx.printer.toTextRefinement(rt).show} } in parent type of $cls; + |only val refinements of tracked parameters are allowed.""", + parent.srcPos) + recur(ptype1) + case _ => + recur(parent.tpe) + /** Check that `tpt` does not define a higher-kinded type */ def checkSimpleKinded(tpt: Tree)(using Context): Tree = if (!tpt.tpe.hasSimpleKind && !ctx.isJava) @@ -1593,6 +1606,7 @@ trait ReChecking extends Checking { override def checkCanThrow(tp: Type, span: Span)(using Context): Tree = EmptyTree override def checkCatch(pat: Tree, guard: Tree)(using Context): Unit = () override def checkNoContextFunctionType(tree: Tree)(using Context): Unit = () + override def checkOnlyDependentRefinements(cls: ClassSymbol, parent: Tree)(using Context): Unit = () override def checkFeature(name: TermName, description: => String, featureUseSite: Symbol, pos: SrcPos)(using Context): Unit = () } @@ -1601,7 +1615,7 @@ trait NoChecking extends ReChecking { override def checkNonCyclic(sym: Symbol, info: TypeBounds, reportErrors: Boolean)(using Context): Type = info override def checkNonCyclicInherited(joint: Type, parents: List[Type], decls: Scope, pos: SrcPos)(using Context): Unit = () override def checkStable(tp: Type, pos: SrcPos, kind: String)(using Context): Unit = () - override def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean)(using Context): Type = tp + override def checkClassType(tp: Type, pos: SrcPos, traitReq: Boolean, stablePrefixReq: Boolean, refinementOK: Boolean)(using Context): Type = tp override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = () override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = () override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index cca26abdd1ec..5f71c2d0e7d6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1500,8 +1500,12 @@ class Namer { typer: Typer => core match case Select(New(tpt), nme.CONSTRUCTOR) => val targs1 = targs map (typedAheadType(_)) - val ptype = typedAheadType(tpt).tpe appliedTo targs1.tpes - if (ptype.typeParams.isEmpty) ptype + val ptype = typedAheadType(tpt).tpe.appliedTo(targs1.tpes) + if ptype.typeParams.isEmpty + //&& !ptype.dealias.typeSymbol.primaryConstructor.info.finalResultType.isInstanceOf[RefinedType] + && !ptype.dealias.typeSymbol.is(Dependent) + then + ptype else if (denot.is(ModuleClass) && denot.sourceModule.isOneOf(GivenOrImplicit)) missingType(denot.symbol, "parent ")(using creationContext) @@ -1539,7 +1543,7 @@ class Namer { typer: Typer => if (cls.isRefinementClass) ptype else { val pt = checkClassType(ptype, parent.srcPos, - traitReq = parent ne parents.head, stablePrefixReq = true) + traitReq = parent ne parents.head, stablePrefixReq = true, refinementOK = true) if (pt.derivesFrom(cls)) { val addendum = parent match { case Select(qual: Super, _) if Feature.migrateTo3 => @@ -1605,14 +1609,18 @@ class Namer { typer: Typer => completeConstructor(denot) denot.info = tempInfo.nn - val parentTypes = defn.adjustForTuple(cls, cls.typeParams, - defn.adjustForBoxedUnit(cls, - addUsingTraits( - ensureFirstIsClass(cls, parents.map(checkedParentType(_))) - ) - ) - ) - typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %") + /** The refinements coming from all parent class constructor applications */ + val parentRefinements = mutable.LinkedHashMap[Name, Type]() + + val parentTypes = + defn.adjustForTuple(cls, cls.typeParams, + defn.adjustForBoxedUnit(cls, + addUsingTraits( + ensureFirstIsClass(cls, parents.map(checkedParentType(_))) + ))).map(_.separateRefinements(parentRefinements)) + + typr.println(i"completing $denot, parents = $parents%, %, stripped parent types = $parentTypes%, %") + typr.println(i"constr type = ${cls.primaryConstructor.infoOrCompleter}, refinements = ${parentRefinements.toList}") if (impl.derived.nonEmpty) { val (derivingClass, derivePos) = original.removeAttachment(desugar.DerivingCompanion) match { @@ -1627,6 +1635,9 @@ class Namer { typer: Typer => denot.info = tempInfo.nn.finalized(parentTypes) tempInfo = null // The temporary info can now be garbage-collected + if parentRefinements.nonEmpty then + integrateParentRefinements(cls.primaryConstructor, parentRefinements) + cls.setFlag(Dependent) Checking.checkWellFormed(cls) if (isDerivedValueClass(cls)) cls.setFlag(Final) cls.info = avoidPrivateLeaks(cls) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 3d61ac0a51f3..4d92f37103c8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2704,6 +2704,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer ensureAccessible(constr.termRef, superAccess = true, tree.srcPos) else checkParentCall(result, cls) + if !cls.isRefinementClass then + checkOnlyDependentRefinements(cls, parent) if cls is Case then checkCaseInheritance(psym, cls, tree.srcPos) result @@ -4360,7 +4362,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName) case _ => errorTree(tree, em"cannot convert from $tree to an instance creation expression") - val tycon = tree.tpe.widen.finalResultType.underlyingClassRef(refinementOK = false) + val tycon = tree.tpe.widen.finalResultType.underlyingClassRef(refinementOK = true) typed( untpd.Select( untpd.New(untpd.TypedSplice(tpt.withType(tycon))), diff --git a/compiler/test/dotc/pos-test-pickling.blacklist b/compiler/test/dotc/pos-test-pickling.blacklist index eb4b861eb324..11ce6fc7ed02 100644 --- a/compiler/test/dotc/pos-test-pickling.blacklist +++ b/compiler/test/dotc/pos-test-pickling.blacklist @@ -111,3 +111,6 @@ java-inherited-type1 # recursion limit exceeded i7445b.scala + +# alias types at different levels of dereferencing +parsercombinators-givens.scala \ No newline at end of file diff --git a/tasty/src/dotty/tools/tasty/TastyFormat.scala b/tasty/src/dotty/tools/tasty/TastyFormat.scala index 7e412a5e67a7..bf788b0467cd 100644 --- a/tasty/src/dotty/tools/tasty/TastyFormat.scala +++ b/tasty/src/dotty/tools/tasty/TastyFormat.scala @@ -219,6 +219,7 @@ Standard-Section: "ASTs" TopLevelStat* EXPORTED -- An export forwarder OPEN -- an open class INVISIBLE -- invisible during typechecking + TRACKED -- a tracked class parameter / a dependent class Annotation Variance = STABLE -- invariant @@ -485,6 +486,7 @@ object TastyFormat { final val INVISIBLE = 44 final val EMPTYCLAUSE = 45 final val SPLITCLAUSE = 46 + final val TRACKED = 47 // Cat. 2: tag Nat @@ -662,7 +664,8 @@ object TastyFormat { | INVISIBLE | ANNOTATION | PRIVATEqualified - | PROTECTEDqualified => true + | PROTECTEDqualified + | TRACKED => true case _ => false } diff --git a/tests/neg/i0248-inherit-refined.check b/tests/neg/i0248-inherit-refined.check new file mode 100644 index 000000000000..c964bf37aaa4 --- /dev/null +++ b/tests/neg/i0248-inherit-refined.check @@ -0,0 +1,18 @@ +-- Error: tests/neg/i0248-inherit-refined.scala:4:18 ------------------------------------------------------------------- +4 | class B extends X // error + | ^ + | Illegal refinement { type T = Int } in parent type of class B; + | only val refinements of tracked parameters are allowed. +-- [E170] Type Error: tests/neg/i0248-inherit-refined.scala:6:18 ------------------------------------------------------- +6 | class C extends Y // error + | ^ + | test.A & test.B is not a class type +-- [E170] Type Error: tests/neg/i0248-inherit-refined.scala:8:18 ------------------------------------------------------- +8 | class D extends Z // error + | ^ + | test.A | test.B is not a class type +-- Error: tests/neg/i0248-inherit-refined.scala:9:28 ------------------------------------------------------------------- +9 | abstract class E extends ({ val x: Int }) // error + | ^^^^^^^^^^^^^^ + | Illegal refinement { val x: Int } in parent type of class E; + | only val refinements of tracked parameters are allowed. diff --git a/tests/neg/i3964.scala b/tests/neg/i3964.scala new file mode 100644 index 000000000000..8670b6067979 --- /dev/null +++ b/tests/neg/i3964.scala @@ -0,0 +1,16 @@ +trait Animal +class Dog extends Animal +class Cat extends Animal + +object Test1: + + abstract class Bar { val x: Animal } + val bar: Bar { val x: Cat } = new Bar { val x = new Cat } // error, but should work + + trait Foo { val x: Animal } + val foo: Foo { val x: Cat } = new Foo { val x = new Cat } // error, but should work + +object Test3: + trait Vec(tracked val size: Int) + class Vec8 extends Vec(8): + val s: 8 = size // error, but should work \ No newline at end of file diff --git a/tests/pos/depclass-1.scala b/tests/pos/depclass-1.scala new file mode 100644 index 000000000000..a49520d58d09 --- /dev/null +++ b/tests/pos/depclass-1.scala @@ -0,0 +1,8 @@ +class A(tracked val source: String) + +class B(source: String) extends A(source) + +class C(source: String) extends B(source) + +val x = C("hello") +val _: A{ val source: "hello" } = x \ No newline at end of file diff --git a/tests/pos/i3920.scala b/tests/pos/i3920.scala new file mode 100644 index 000000000000..acd413ad3c07 --- /dev/null +++ b/tests/pos/i3920.scala @@ -0,0 +1,32 @@ + +trait Ordering { + type T + def compare(t1:T, t2: T): Int +} + +class SetFunctor(tracked val ord: Ordering) { + type Set = List[ord.T] + def empty: Set = Nil + + implicit class helper(s: Set) { + def add(x: ord.T): Set = x :: remove(x) + def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0) + def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0) + } +} + +object Test { + val orderInt = new Ordering { + type T = Int + def compare(t1: T, t2: T): Int = t1 - t2 + } + + val IntSet = new SetFunctor(orderInt) + import IntSet._ + + def main(args: Array[String]) = { + val set = IntSet.empty.add(6).add(8).add(23) + assert(!set.member(7)) + assert(set.member(8)) + } +} \ No newline at end of file diff --git a/tests/pos/i3964.scala b/tests/pos/i3964.scala new file mode 100644 index 000000000000..457a8189b22d --- /dev/null +++ b/tests/pos/i3964.scala @@ -0,0 +1,31 @@ +trait Animal +class Dog extends Animal +class Cat extends Animal + +object Test2: + class Bar(tracked val x: Animal) + val b = new Bar(new Cat) + val bar: Bar { val x: Cat } = new Bar(new Cat) // ok + + trait Foo(tracked val x: Animal) + val foo: Foo { val x: Cat } = new Foo(new Cat) {} // ok + +object Test3: + trait Vec(tracked val size: Int) + class Vec8 extends Vec(8) + + abstract class Lst(tracked val size: Int) + class Lst8 extends Lst(8) + + val v8a: Vec { val size: 8 } = new Vec8 + val v8b: Vec { val size: 8 } = new Vec(8) {} + + val l8a: Lst { val size: 8 } = new Lst8 + val l8b: Lst { val size: 8 } = new Lst(8) {} + + class VecN(tracked val n: Int) extends Vec(n) + class Vec9 extends VecN(9) + val v9a = VecN(9) + val _: Vec { val size: 9 } = v9a + val v9b = Vec9() + val _: Vec { val size: 9 } = v9b diff --git a/tests/pos/i3964a/Defs_1.scala b/tests/pos/i3964a/Defs_1.scala new file mode 100644 index 000000000000..80a81a31ca6d --- /dev/null +++ b/tests/pos/i3964a/Defs_1.scala @@ -0,0 +1,16 @@ +trait Animal +class Dog extends Animal +class Cat extends Animal + +object Test2: + class Bar(tracked val x: Animal) + val b = new Bar(new Cat) + val bar: Bar { val x: Cat } = new Bar(new Cat) // ok + + trait Foo(tracked val x: Animal) + val foo: Foo { val x: Cat } = new Foo(new Cat) {} // ok + +trait Vec(tracked val size: Int) +class Vec8 extends Vec(8) + +abstract class Lst(tracked val size: Int) \ No newline at end of file diff --git a/tests/pos/i3964a/Uses_2.scala b/tests/pos/i3964a/Uses_2.scala new file mode 100644 index 000000000000..65612b60cb0e --- /dev/null +++ b/tests/pos/i3964a/Uses_2.scala @@ -0,0 +1,14 @@ +class Lst8 extends Lst(8) + +val v8a: Vec { val size: 8 } = new Vec8 +val v8b: Vec { val size: 8 } = new Vec(8) {} + +val l8a: Lst { val size: 8 } = new Lst8 +val l8b: Lst { val size: 8 } = new Lst(8) {} + +class VecN(tracked val n: Int) extends Vec(n) +class Vec9 extends VecN(9) +val v9a = VecN(9) +val _: Vec { val size: 9 } = v9a +val v9b = Vec9() +val _: Vec { val size: 9 } = v9b diff --git a/tests/pos/parsercombinators-expanded.scala b/tests/pos/parsercombinators-expanded.scala new file mode 100644 index 000000000000..e766048f6984 --- /dev/null +++ b/tests/pos/parsercombinators-expanded.scala @@ -0,0 +1,62 @@ +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +object test: + + class apply[C, E] extends Combinator[Apply[C, E]]: + type Context = C + type Element = E + extension(self: Apply[C, E]) + def parse(context: C): Option[E] = self.action(context) + + def apply[C, E]: apply[C, E] = new apply[C, E] + + class combine[A, B]( + tracked val f: Combinator[A], + tracked val s: Combinator[B] { type Context = f.Context} + ) extends Combinator[Combine[A, B]]: + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) + def parse(context: Context): Option[Element] = ??? + + def combine[A, B]( + _f: Combinator[A], + _s: Combinator[B] { type Context = _f.Context} + ) = new combine[A, B](_f, _s) + // cast is needed since the type of new combine[A, B](_f, _s) + // drops the required refinement. + + extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + + @main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val c = combine( + apply[mutable.ListBuffer[Int], Int], + apply[mutable.ListBuffer[Int], Int] + ) + val r = c.parse(m)(stream) // was type mismatch, now OK + val rc: Option[(Int, Int)] = r + } diff --git a/tests/pos/parsercombinators-givens-2.scala b/tests/pos/parsercombinators-givens-2.scala new file mode 100644 index 000000000000..19a15cfd490a --- /dev/null +++ b/tests/pos/parsercombinators-givens-2.scala @@ -0,0 +1,50 @@ +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +given apply[C, E]: Combinator[Apply[C, E]] with { + type Context = C + type Element = E + extension(self: Apply[C, E]) { + def parse(context: C): Option[E] = self.action(context) + } +} + +given combine[A, B, C](using + f: Combinator[A] { type Context = C }, + s: Combinator[B] { type Context = C } +): Combinator[Combine[A, B]] with { + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) { + def parse(context: Context): Option[Element] = ??? + } +} + +extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + +@main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val r = m.parse(stream) // works, but Element type is not resolved correctly +} diff --git a/tests/pos/parsercombinators-givens.scala b/tests/pos/parsercombinators-givens.scala new file mode 100644 index 000000000000..78d985ada9a8 --- /dev/null +++ b/tests/pos/parsercombinators-givens.scala @@ -0,0 +1,52 @@ +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +given apply[C, E]: Combinator[Apply[C, E]] with { + type Context = C + type Element = E + extension(self: Apply[C, E]) { + def parse(context: C): Option[E] = self.action(context) + } +} + +given combine[A, B](using + tracked val f: Combinator[A], + tracked val s: Combinator[B] { type Context = f.Context } +): Combinator[Combine[A, B]] with { + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) { + def parse(context: Context): Option[Element] = ??? + } +} + +extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + +@main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context` + val rc: Option[(Int, Int)] = r + // it would be great if this worked +}