@@ -4174,37 +4174,66 @@ object Types {
4174
4174
4175
4175
def tryCompiletimeConstantFold (using Context ): Type = tycon match {
4176
4176
case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
4177
- def constValue (tp : Type ): Option [Any ] = tp.dealias match {
4177
+ extension (tp : Type ) def fixForEvaluation : Type =
4178
+ tp.normalized.dealias match {
4179
+ case tp : TermRef => tp.underlying
4180
+ case tp => tp
4181
+ }
4182
+
4183
+ def constValue (tp : Type ): Option [Any ] = tp.fixForEvaluation match {
4178
4184
case ConstantType (Constant (n)) => Some (n)
4179
4185
case _ => None
4180
4186
}
4181
4187
4182
- def boolValue (tp : Type ): Option [Boolean ] = tp.dealias match {
4188
+ def boolValue (tp : Type ): Option [Boolean ] = tp.fixForEvaluation match {
4183
4189
case ConstantType (Constant (n : Boolean )) => Some (n)
4184
4190
case _ => None
4185
4191
}
4186
4192
4187
- def intValue (tp : Type ): Option [Int ] = tp.dealias match {
4193
+ def intValue (tp : Type ): Option [Int ] = tp.fixForEvaluation match {
4188
4194
case ConstantType (Constant (n : Int )) => Some (n)
4189
4195
case _ => None
4190
4196
}
4191
4197
4192
- def stringValue (tp : Type ): Option [String ] = tp.dealias match {
4193
- case ConstantType (Constant (n : String )) => Some (n)
4198
+ def longValue (tp : Type ): Option [Long ] = tp.fixForEvaluation match {
4199
+ case ConstantType (Constant (n : Long )) => Some (n)
4194
4200
case _ => None
4195
4201
}
4196
4202
4203
+ def stringValue (tp : Type ): Option [String ] = tp.fixForEvaluation match {
4204
+ case ConstantType (Constant (n : String )) => Some (n)
4205
+ case _ => None
4206
+ }
4207
+ def isConst : Option [Type ] = args.head.fixForEvaluation match {
4208
+ case ConstantType (_) => Some (ConstantType (Constant (true )))
4209
+ case _ => Some (ConstantType (Constant (false )))
4210
+ }
4197
4211
def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
4198
4212
4199
4213
def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
4200
- extractor(args.head.normalized ).map(a => ConstantType (Constant (op(a))))
4214
+ extractor(args.head).map(a => ConstantType (Constant (op(a))))
4201
4215
4202
4216
def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
4217
+ constantFold2AB(extractor, extractor, op)
4218
+
4219
+ def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
4203
4220
for {
4204
- a <- extractor (args.head.normalized )
4205
- b <- extractor (args.tail.head.normalized )
4221
+ a <- extractorA (args.head)
4222
+ b <- extractorB (args.last )
4206
4223
} yield ConstantType (Constant (op(a, b)))
4207
4224
4225
+ def constantFold3 [TA , TB , TC ](
4226
+ extractorA : Type => Option [TA ],
4227
+ extractorB : Type => Option [TB ],
4228
+ extractorC : Type => Option [TC ],
4229
+ op : (TA , TB , TC ) => Any
4230
+ ): Option [Type ] =
4231
+ for {
4232
+ a <- extractorA(args.head)
4233
+ b <- extractorB(args(1 ))
4234
+ c <- extractorC(args.last)
4235
+ } yield ConstantType (Constant (op(a, b, c)))
4236
+
4208
4237
trace(i " compiletime constant fold $this" , typr, show = true ) {
4209
4238
val name = tycon.symbol.name
4210
4239
val owner = tycon.symbol.owner
@@ -4216,10 +4245,13 @@ object Types {
4216
4245
} else if (owner == defn.CompiletimeOpsAnyModuleClass ) name match {
4217
4246
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
4218
4247
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4248
+ case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4249
+ case tpnme.IsConst if nArgs == 1 => isConst
4219
4250
case _ => None
4220
4251
} else if (owner == defn.CompiletimeOpsIntModuleClass ) name match {
4221
4252
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
4222
4253
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => - x)
4254
+ // ToString is deprecated for ops.int, and moved to ops.any
4223
4255
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
4224
4256
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
4225
4257
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
@@ -4244,9 +4276,43 @@ object Types {
4244
4276
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
4245
4277
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
4246
4278
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4279
+ case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4280
+ case _ => None
4281
+ } else if (owner == defn.CompiletimeOpsLongModuleClass ) name match {
4282
+ case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4283
+ case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => - x)
4284
+ case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4285
+ case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4286
+ case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4287
+ case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4288
+ case (_, 0L ) => throw new TypeError (" Division by 0" )
4289
+ case (a, b) => a / b
4290
+ })
4291
+ case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4292
+ case (_, 0L ) => throw new TypeError (" Modulo by 0" )
4293
+ case (a, b) => a % b
4294
+ })
4295
+ case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4296
+ case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4297
+ case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4298
+ case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4299
+ case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4300
+ case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4301
+ case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4302
+ case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4303
+ case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4304
+ case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4305
+ case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4306
+ case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4307
+ case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4308
+ constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
4247
4309
case _ => None
4248
4310
} else if (owner == defn.CompiletimeOpsStringModuleClass ) name match {
4249
4311
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4312
+ case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4313
+ case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4314
+ case tpnme.Substring if nArgs == 3 =>
4315
+ constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
4250
4316
case _ => None
4251
4317
} else if (owner == defn.CompiletimeOpsBooleanModuleClass ) name match {
4252
4318
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => ! x)
0 commit comments