Skip to content

Commit aa1017f

Browse files
committed
Optimize constraint initialization
Use a single traversal and a todo list instead of multiple traversals
1 parent 0a4cb68 commit aa1017f

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -229,23 +229,32 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
229229
* and to handle them separately is for efficiency, so that type expressions
230230
* used as bounds become smaller.
231231
*
232+
* TODO: try to do without stripping? It would mean it is more efficient
233+
* to pull out full bounds from a constraint.
234+
*
232235
* @param isUpper If true, `bound` is an upper bound, else a lower bound.
233236
*/
234-
private def stripParams(tp: Type, paramBuf: mutable.ListBuffer[TypeParamRef],
237+
private def stripParams(
238+
tp: Type,
239+
todos: mutable.ListBuffer[(OrderingConstraint, TypeParamRef) => OrderingConstraint],
235240
isUpper: Boolean)(using Context): Type = tp match {
236241
case param: TypeParamRef if contains(param) =>
237-
if (!paramBuf.contains(param)) paramBuf += param
242+
todos += (if isUpper then order(_, _, param) else order(_, param, _))
238243
NoType
244+
case tp: TypeBounds =>
245+
val lo1 = stripParams(tp.lo, todos, !isUpper).orElse(defn.NothingType)
246+
val hi1 = stripParams(tp.hi, todos, isUpper).orElse(defn.AnyKindType)
247+
tp.derivedTypeBounds(lo1, hi1)
239248
case tp: AndType if isUpper =>
240-
val tp1 = stripParams(tp.tp1, paramBuf, isUpper)
241-
val tp2 = stripParams(tp.tp2, paramBuf, isUpper)
249+
val tp1 = stripParams(tp.tp1, todos, isUpper)
250+
val tp2 = stripParams(tp.tp2, todos, isUpper)
242251
if (tp1.exists)
243252
if (tp2.exists) tp.derivedAndType(tp1, tp2)
244253
else tp1
245254
else tp2
246255
case tp: OrType if !isUpper =>
247-
val tp1 = stripParams(tp.tp1, paramBuf, isUpper)
248-
val tp2 = stripParams(tp.tp2, paramBuf, isUpper)
256+
val tp1 = stripParams(tp.tp1, todos, isUpper)
257+
val tp2 = stripParams(tp.tp2, todos, isUpper)
249258
if (tp1.exists)
250259
if (tp2.exists) tp.derivedOrType(tp1, tp2)
251260
else tp1
@@ -254,17 +263,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
254263
tp
255264
}
256265

257-
/** The bound type `tp` without clearly dependent parameters.
258-
* A top or bottom type if type consists only of dependent parameters.
259-
* TODO: try to do without normalization? It would mean it is more efficient
260-
* to pull out full bounds from a constraint.
261-
* @param isUpper If true, `bound` is an upper bound, else a lower bound.
262-
*/
263-
private def normalizedType(tp: Type, paramBuf: mutable.ListBuffer[TypeParamRef],
264-
isUpper: Boolean)(using Context): Type =
265-
stripParams(tp, paramBuf, isUpper)
266-
.orElse(if (isUpper) defn.AnyKindType else defn.NothingType)
267-
268266
def add(poly: TypeLambda, tvars: List[TypeVar])(using Context): This = {
269267
assert(!contains(poly))
270268
val nparams = poly.paramNames.length
@@ -280,18 +278,15 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
280278
*/
281279
private def init(poly: TypeLambda)(using Context): This = {
282280
var current = this
283-
val loBuf, hiBuf = new mutable.ListBuffer[TypeParamRef]
281+
val todos = new mutable.ListBuffer[(OrderingConstraint, TypeParamRef) => OrderingConstraint]
284282
var i = 0
285283
while (i < poly.paramNames.length) {
286284
val param = poly.paramRefs(i)
287-
val bounds = nonParamBounds(param)
288-
val lo = normalizedType(bounds.lo, loBuf, isUpper = false)
289-
val hi = normalizedType(bounds.hi, hiBuf, isUpper = true)
290-
current = updateEntry(current, param, bounds.derivedTypeBounds(lo, hi))
291-
current = loBuf.foldLeft(current)(order(_, _, param))
292-
current = hiBuf.foldLeft(current)(order(_, param, _))
293-
loBuf.clear()
294-
hiBuf.clear()
285+
val stripped = stripParams(nonParamBounds(param), todos, isUpper = true)
286+
current = updateEntry(current, param, stripped)
287+
while todos.nonEmpty do
288+
current = todos.head(current, param)
289+
todos.dropInPlace(1)
295290
i += 1
296291
}
297292
current.checkNonCyclic()

0 commit comments

Comments
 (0)