@@ -14,56 +14,63 @@ import StdNames._
14
14
import Contexts ._
15
15
import transform .TypeUtils ._
16
16
17
- object ConstFold {
17
+ object ConstFold :
18
18
19
19
import tpd ._
20
20
21
- /** If tree is a constant operation, replace with result. */
22
- def apply [T <: Tree ](tree : T )(using Context ): T = finish(tree) {
23
- tree match {
24
- case Apply (Select (xt, op), yt :: Nil ) =>
21
+ private val foldedBinops = Set [Name ](
22
+ nme.ZOR , nme.OR , nme.XOR , nme.ZAND , nme.AND , nme.EQ , nme.NE ,
23
+ nme.LT , nme.GT , nme.LE , nme.GE , nme.LSL , nme.LSR , nme.ASR ,
24
+ nme.ADD , nme.SUB , nme.MUL , nme.DIV , nme.MOD )
25
+
26
+ private val foldedUnops = Set [Name ](
27
+ nme.UNARY_! , nme.UNARY_~ , nme.UNARY_+ , nme.UNARY_- )
28
+
29
+ def Apply [T <: Apply ](tree : T )(using Context ): T =
30
+ tree.fun match
31
+ case Select (xt, op) if foldedBinops.contains(op) =>
25
32
xt.tpe.widenTermRefExpr.normalized match
26
33
case ConstantType (x) =>
27
- yt.tpe.widenTermRefExpr match
28
- case ConstantType (y) => foldBinop(op, x, y)
29
- case _ => null
30
- case _ => null
31
- case Select (xt, op) =>
32
- xt.tpe.widenTermRefExpr match {
33
- case ConstantType (x) => foldUnop(op, x)
34
- case _ => null
35
- }
36
- case TypeApply (_, List (targ)) if tree.symbol eq defn.Predef_classOf =>
37
- Constant (targ.tpe)
38
- case Apply (TypeApply (Select (qual, nme.getClass_), _), Nil )
39
- if qual.tpe.widen.isPrimitiveValueType =>
40
- Constant (qual.tpe.widen)
41
- case _ => null
42
- }
43
- }
34
+ tree.args match
35
+ case yt :: Nil =>
36
+ yt.tpe.widenTermRefExpr.normalized match
37
+ case ConstantType (y) => tree.withFoldedType(foldBinop(op, x, y))
38
+ case _ => tree
39
+ case _ => tree
40
+ case _ => tree
41
+ case TypeApply (Select (qual, nme.getClass_), _)
42
+ if qual.tpe.widen.isPrimitiveValueType && tree.args.isEmpty =>
43
+ tree.withFoldedType(Constant (qual.tpe.widen))
44
+ case _ =>
45
+ tree
46
+
47
+ def Select [T <: Select ](tree : T )(using Context ): T =
48
+ if foldedUnops.contains(tree.name) then
49
+ tree.qualifier.tpe.widenTermRefExpr.normalized match
50
+ case ConstantType (x) => tree.withFoldedType(foldUnop(tree.name, x))
51
+ case _ => tree
52
+ else tree
53
+
54
+ /** If tree is a constant operation, replace with result. */
55
+ def apply [T <: Tree ](tree : T )(using Context ): T = tree match
56
+ case tree : Apply => Apply (tree)
57
+ case tree : Select => Select (tree)
58
+ case TypeApply (_, targ :: Nil ) if tree.symbol eq defn.Predef_classOf =>
59
+ tree.withFoldedType(Constant (targ.tpe))
60
+ case _ => tree
44
61
45
62
/** If tree is a constant value that can be converted to type `pt`, perform
46
63
* the conversion.
47
64
*/
48
65
def apply [T <: Tree ](tree : T , pt : Type )(using Context ): T =
49
- finish(apply(tree)) {
50
- tree.tpe.widenTermRefExpr.normalized match {
51
- case ConstantType (x) => x convertTo pt
52
- case _ => null
53
- }
54
- }
55
-
56
- inline private def finish [T <: Tree ](tree : T )(compX : => Constant )(using Context ): T =
57
- try {
58
- val x = compX
59
- if (x ne null ) tree.withType(ConstantType (x)).asInstanceOf [T ]
60
- else tree
61
- }
62
- catch {
63
- case _ : ArithmeticException => tree // the code will crash at runtime,
64
- // but that is better than the
65
- // compiler itself crashing
66
- }
66
+ val tree1 = apply(tree)
67
+ tree.tpe.widenTermRefExpr.normalized match
68
+ case ConstantType (x) => tree1.withFoldedType(x.convertTo(pt))
69
+ case _ => tree1
70
+
71
+ extension [T <: Tree ](tree : T )(using Context )
72
+ private def withFoldedType (c : Constant | Null ): T =
73
+ if c == null then tree else tree.withType(ConstantType (c)).asInstanceOf [T ]
67
74
68
75
private def foldUnop (op : Name , x : Constant ): Constant = (op, x.tag) match {
69
76
case (nme.UNARY_! , BooleanTag ) => Constant (! x.booleanValue)
@@ -166,23 +173,22 @@ object ConstFold {
166
173
case _ => null
167
174
}
168
175
169
- private def foldBinop (op : Name , x : Constant , y : Constant ): Constant = {
176
+ private def foldBinop (op : Name , x : Constant , y : Constant ): Constant =
170
177
val optag =
171
178
if (x.tag == y.tag) x.tag
172
179
else if (x.isNumeric && y.isNumeric) math.max(x.tag, y.tag)
173
180
else NoTag
174
181
175
- try optag match {
176
- case BooleanTag => foldBooleanOp(op, x, y)
177
- case ByteTag | ShortTag | CharTag | IntTag => foldSubrangeOp(op, x, y)
178
- case LongTag => foldLongOp(op, x, y)
179
- case FloatTag => foldFloatOp(op, x, y)
180
- case DoubleTag => foldDoubleOp(op, x, y)
181
- case StringTag if op == nme.ADD => Constant (x.stringValue + y.stringValue)
182
- case _ => null
183
- }
184
- catch {
185
- case ex : ArithmeticException => null
186
- }
187
- }
188
- }
182
+ try optag match
183
+ case BooleanTag => foldBooleanOp(op, x, y)
184
+ case ByteTag | ShortTag | CharTag | IntTag => foldSubrangeOp(op, x, y)
185
+ case LongTag => foldLongOp(op, x, y)
186
+ case FloatTag => foldFloatOp(op, x, y)
187
+ case DoubleTag => foldDoubleOp(op, x, y)
188
+ case StringTag if op == nme.ADD => Constant (x.stringValue + y.stringValue)
189
+ case _ => null
190
+ catch case ex : ArithmeticException => null // the code will crash at runtime,
191
+ // but that is better than the
192
+ // compiler itself crashing
193
+ end foldBinop
194
+ end ConstFold
0 commit comments