diff --git a/build.sbt b/build.sbt index 4b3a450..7c21bf0 100644 --- a/build.sbt +++ b/build.sbt @@ -1,10 +1,6 @@ -import com.typesafe.tools.mima.plugin.{MimaPlugin, MimaKeys} - scalaModuleSettings -scalaVersion := "2.11.5" - -snapshotScalaBinaryVersion := "2.11.5" +scalaVersion := "2.11.6" organization := "org.scala-lang.modules" @@ -23,15 +19,7 @@ libraryDependencies += "junit" % "junit" % "4.11" % "test" libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test" -MimaPlugin.mimaDefaultSettings - -MimaKeys.previousArtifact := None - -// run mima during tests -test in Test := { - MimaKeys.reportBinaryIssues.value - (test in Test).value -} +mimaPreviousVersion := None testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a") diff --git a/project/plugins.sbt b/project/plugins.sbt index 5f604e6..25f3373 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1 @@ -addSbtPlugin("org.scala-lang.modules" % "scala-module-plugin" % "1.0.2") - -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") +addSbtPlugin("org.scala-lang.modules" % "scala-module-plugin" % "1.0.3") diff --git a/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala new file mode 100644 index 0000000..714e2f3 --- /dev/null +++ b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala @@ -0,0 +1,112 @@ +package scala.compat.java8.runtime + +import java.lang.invoke._ +import java.lang.ref.WeakReference + +/** + * This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12 + * compiler will add to classes hosting lambdas. + * + * It is intended to be consumed directly. + */ +object LambdaDeserializer { + private final case class CacheKey(implClass: Class[_], implMethodName: String, implMethodSignature: String) + private val cache = new java.util.WeakHashMap[CacheKey, WeakReference[CallSite]]() + + /** + * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class + * and instantiating this class with the captured arguments. + * + * A cache is employed to ensure that subsequent deserialization of the same lambda expression + * is cheap, it amounts to a reflective call to the constructor of the previously created class. + * However, deserialization of the same lambda expression is not guaranteed to use the same class, + * concurrent deserialization of the same lambda expression may spin up more than one class. + * + * This cache is weak in keys and values to avoid retention of the enclosing class (and its classloader) + * of deserialized lambdas. + * + * Assumptions: + * - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are + * not stored in `SerializedLambda`, so we can't reconstitute them. + * - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored. + * + * Note: The Java compiler + * + * @param lookup The factory for method handles. Must have access to the implementation method, the + * functional interface class, and `java.io.Serializable` or `scala.Serializable` as + * required. + * @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve` + * member of the anonymous class created by `LambdaMetaFactory`. + * @return An instance of the functional interface + */ + def deserializeLambda(lookup: MethodHandles.Lookup, serialized: SerializedLambda): AnyRef = { + def slashDot(name: String) = name.replaceAll("/", ".") + val loader = lookup.lookupClass().getClassLoader + val implClass = loader.loadClass(slashDot(serialized.getImplClass)) + + def makeCallSite: CallSite = { + import serialized._ + def parseDescriptor(s: String) = + MethodType.fromMethodDescriptorString(s, loader) + + val funcInterfacesSignature = parseDescriptor(getFunctionalInterfaceMethodSignature) + val methodType: MethodType = funcInterfacesSignature + val instantiated = parseDescriptor(getInstantiatedMethodType) + val implMethodSig = parseDescriptor(getImplMethodSignature) + + val from = implMethodSig.parameterCount() - funcInterfacesSignature.parameterCount() + val to = implMethodSig.parameterCount() + val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass)) + var invokedType: MethodType = + implMethodSig.dropParameterTypes(from, to) + .changeReturnType(functionalInterfaceClass) + + val implMethod: MethodHandle = try { + getImplMethodKind match { + case MethodHandleInfo.REF_invokeStatic => + lookup.findStatic(implClass, getImplMethodName, implMethodSig) + case MethodHandleInfo.REF_invokeVirtual => + invokedType = invokedType.insertParameterTypes(0, implClass) + lookup.findVirtual(implClass, getImplMethodName, implMethodSig) + case MethodHandleInfo.REF_invokeSpecial => + invokedType = invokedType.insertParameterTypes(0, implClass) + lookup.findSpecial(implClass, getImplMethodName, implMethodSig, implClass) + } + } catch { + case e: ReflectiveOperationException => + throw new IllegalArgumentException("Illegal lambda deserialization", e) + } + val FLAG_SERIALIZABLE = 1 + val FLAG_MARKERS = 2 + val flags: Int = FLAG_SERIALIZABLE | FLAG_MARKERS + val markerInterface: Class[_] = if (functionalInterfaceClass.getName.startsWith("scala.Function")) + loader.loadClass("scala.Serializable") + else + loader.loadClass("java.io.Serializable") + + LambdaMetafactory.altMetafactory( + lookup, getFunctionalInterfaceMethodName, invokedType, + + /* samMethodType = */ funcInterfacesSignature, + /* implMethod = */ implMethod, + /* instantiatedMethodType = */ instantiated, + /* flags = */ flags.asInstanceOf[AnyRef], + /* markerInterfaceCount = */ 1.asInstanceOf[AnyRef], + /* markerInterfaces[0] = */ markerInterface, + /* bridgeCount = */ 0.asInstanceOf[AnyRef] + ) + } + + val key = new CacheKey(implClass, serialized.getImplMethodName, serialized.getImplMethodSignature) + val callSiteRef: WeakReference[CallSite] = cache.get(key) + var site = if (callSiteRef == null) null else callSiteRef.get() + if (site == null) { + site = makeCallSite + cache.put(key, new WeakReference(site)) + } + + val factory = site.getTarget + val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n)) + factory.invokeWithArguments(captures: _*) + } +} diff --git a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000..d3f707f --- /dev/null +++ b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java @@ -0,0 +1,108 @@ +package scala.compat.java8.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Method; +import java.util.Arrays; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1 f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + F1 f2 = reconstitute(f1); + Assert.assertEquals(f1.apply(true), f2.apply(true)); + } + + @Test + public void serializationStatic() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + F1 f2 = reconstitute(f1); + Assert.assertEquals(f1.apply(true), f2.apply(true)); + } + + @Test + public void implMethodNameChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda serialized) { + try { + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), serialized); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private F1 reconstitute(F1 f1) { + try { + return (F1) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), writeReplace(f1)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private SerializedLambda writeReplace(F1 f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1 extends Serializable { + B apply(A a); +} + +class LambdaHost { + public F1 lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + public F1 lambdaBackedByStaticImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + public static MethodHandles.Lookup lookup() { return MethodHandles.lookup(); } +}