From 8c5d4ee572f56176ed06a26f091461a9703d71eb Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Tue, 12 May 2015 15:26:59 +1000 Subject: [PATCH 1/2] Add a generic deserializer for Java/Scala 2.12 lambdas Java support serialization of lambdas by using the serialization proxy pattern. Deserialization of a lambda uses `LambdaMetafactory` to create a new anonymous subclass. More details of the scheme are documented: https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html From those docs: > SerializedLambda has a readResolve method that looks for a > (possibly private) static method called $deserializeLambda$ > in the capturing class, invokes that with itself as the first > argument, and returns the result. Lambda classes implementing > $deserializeLambda$ are responsible for validating that the > properties of the SerializedLambda are consistent with a lambda > actually captured by that class. The Java compiler generates code in `$deserializeLambda$` that switches on the implementation method name and signature to locate an invokedynamic instruction generated for the particular lambda expression. Then, the `SerializedLambda` is further unpacked, validating that this implementation method still represents the same functional interface as it did when it was serialized. (The source may have been recompiled in the interim.) In Java, serializable lambda expressions are the exception rather than the rule. In Scala, however, the serializability of `FunctionN` means that we would end up generating a large amount of code to support deserialization. Instead, we are pursuing an alternative approach in which the `$deserializeLambda$` method is a simple forwarder to the generic deserializer added here. This is capable of deserializing lambdas created by the Java compiler, although this is not its intended use case. The enclosed tests use Java lambdas. This generic deserializer also works by calling `LambdaMetafactory`, but it does so explicitly, rather than implicitly during linkage of the `invokedynamic` instruction. We have to mimic the caching property of `invokedynamic` instruction to ensure we reuse the classes when constructing. The cache here uses weak references to keys and values to avoid retention of `Class` or `ClassLoader` instances. If the name or signature of the implementation method has changed, we fail during deserialization with an `IllegalArgumentError.` However, we do not fail fast in a few cases that Java would, as we cannot reflect on the "current" functional interface supported by this implementation method. We just instantiate using the "previous" functional interface class/method. This might: 1. fail inside `LambdaMetafactory` if the new implementation method is not compatible with the old functional interface. 2. pass through `LambdaMetafactory` by chance, but fail when instantiating the class in other cases. For example: ``` % tail sandbox/test{1,2}.scala ==> sandbox/test1.scala <== class C { def test: (String => String) = { val s: String = "" (t) => s + t } } ==> sandbox/test2.scala <== class C { def test: (String, String) => String = { (s, t) => s + t } } % (for i in 1 2; do scalac -Ydelambdafy:method -Xprint:delambdafy sandbox/test$i.scala 2>&1 ; done) | grep 'def $anon' final private[this] def $anonfun$1(t: String, s$1: String): String = s$1.+(t); final private[this] def $anonfun$1(s: String, t: String): String = s.+(t); ``` 3. Silently create an instance of the old functional interface. For example, imagine switching from `FuncInterface1` to `FuncInterface2` where these were identical other than the name. I don't believe that these are showstoppers. --- build.sbt | 4 +- .../java8/runtime/LambdaDeserializer.scala | 112 ++++++++++++++++++ .../java8/runtime/LambdaDeserializerTest.java | 108 +++++++++++++++++ 3 files changed, 222 insertions(+), 2 deletions(-) create mode 100644 src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala create mode 100644 src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java diff --git a/build.sbt b/build.sbt index 4b3a450..6e85837 100644 --- a/build.sbt +++ b/build.sbt @@ -2,9 +2,9 @@ import com.typesafe.tools.mima.plugin.{MimaPlugin, MimaKeys} scalaModuleSettings -scalaVersion := "2.11.5" +scalaVersion := "2.11.6" -snapshotScalaBinaryVersion := "2.11.5" +snapshotScalaBinaryVersion := "2.11.6" organization := "org.scala-lang.modules" diff --git a/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala new file mode 100644 index 0000000..714e2f3 --- /dev/null +++ b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala @@ -0,0 +1,112 @@ +package scala.compat.java8.runtime + +import java.lang.invoke._ +import java.lang.ref.WeakReference + +/** + * This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12 + * compiler will add to classes hosting lambdas. + * + * It is intended to be consumed directly. + */ +object LambdaDeserializer { + private final case class CacheKey(implClass: Class[_], implMethodName: String, implMethodSignature: String) + private val cache = new java.util.WeakHashMap[CacheKey, WeakReference[CallSite]]() + + /** + * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class + * and instantiating this class with the captured arguments. + * + * A cache is employed to ensure that subsequent deserialization of the same lambda expression + * is cheap, it amounts to a reflective call to the constructor of the previously created class. + * However, deserialization of the same lambda expression is not guaranteed to use the same class, + * concurrent deserialization of the same lambda expression may spin up more than one class. + * + * This cache is weak in keys and values to avoid retention of the enclosing class (and its classloader) + * of deserialized lambdas. + * + * Assumptions: + * - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are + * not stored in `SerializedLambda`, so we can't reconstitute them. + * - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored. + * + * Note: The Java compiler + * + * @param lookup The factory for method handles. Must have access to the implementation method, the + * functional interface class, and `java.io.Serializable` or `scala.Serializable` as + * required. + * @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve` + * member of the anonymous class created by `LambdaMetaFactory`. + * @return An instance of the functional interface + */ + def deserializeLambda(lookup: MethodHandles.Lookup, serialized: SerializedLambda): AnyRef = { + def slashDot(name: String) = name.replaceAll("/", ".") + val loader = lookup.lookupClass().getClassLoader + val implClass = loader.loadClass(slashDot(serialized.getImplClass)) + + def makeCallSite: CallSite = { + import serialized._ + def parseDescriptor(s: String) = + MethodType.fromMethodDescriptorString(s, loader) + + val funcInterfacesSignature = parseDescriptor(getFunctionalInterfaceMethodSignature) + val methodType: MethodType = funcInterfacesSignature + val instantiated = parseDescriptor(getInstantiatedMethodType) + val implMethodSig = parseDescriptor(getImplMethodSignature) + + val from = implMethodSig.parameterCount() - funcInterfacesSignature.parameterCount() + val to = implMethodSig.parameterCount() + val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass)) + var invokedType: MethodType = + implMethodSig.dropParameterTypes(from, to) + .changeReturnType(functionalInterfaceClass) + + val implMethod: MethodHandle = try { + getImplMethodKind match { + case MethodHandleInfo.REF_invokeStatic => + lookup.findStatic(implClass, getImplMethodName, implMethodSig) + case MethodHandleInfo.REF_invokeVirtual => + invokedType = invokedType.insertParameterTypes(0, implClass) + lookup.findVirtual(implClass, getImplMethodName, implMethodSig) + case MethodHandleInfo.REF_invokeSpecial => + invokedType = invokedType.insertParameterTypes(0, implClass) + lookup.findSpecial(implClass, getImplMethodName, implMethodSig, implClass) + } + } catch { + case e: ReflectiveOperationException => + throw new IllegalArgumentException("Illegal lambda deserialization", e) + } + val FLAG_SERIALIZABLE = 1 + val FLAG_MARKERS = 2 + val flags: Int = FLAG_SERIALIZABLE | FLAG_MARKERS + val markerInterface: Class[_] = if (functionalInterfaceClass.getName.startsWith("scala.Function")) + loader.loadClass("scala.Serializable") + else + loader.loadClass("java.io.Serializable") + + LambdaMetafactory.altMetafactory( + lookup, getFunctionalInterfaceMethodName, invokedType, + + /* samMethodType = */ funcInterfacesSignature, + /* implMethod = */ implMethod, + /* instantiatedMethodType = */ instantiated, + /* flags = */ flags.asInstanceOf[AnyRef], + /* markerInterfaceCount = */ 1.asInstanceOf[AnyRef], + /* markerInterfaces[0] = */ markerInterface, + /* bridgeCount = */ 0.asInstanceOf[AnyRef] + ) + } + + val key = new CacheKey(implClass, serialized.getImplMethodName, serialized.getImplMethodSignature) + val callSiteRef: WeakReference[CallSite] = cache.get(key) + var site = if (callSiteRef == null) null else callSiteRef.get() + if (site == null) { + site = makeCallSite + cache.put(key, new WeakReference(site)) + } + + val factory = site.getTarget + val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n)) + factory.invokeWithArguments(captures: _*) + } +} diff --git a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000..d3f707f --- /dev/null +++ b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java @@ -0,0 +1,108 @@ +package scala.compat.java8.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Method; +import java.util.Arrays; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1 f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + F1 f2 = reconstitute(f1); + Assert.assertEquals(f1.apply(true), f2.apply(true)); + } + + @Test + public void serializationStatic() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + F1 f2 = reconstitute(f1); + Assert.assertEquals(f1.apply(true), f2.apply(true)); + } + + @Test + public void implMethodNameChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda serialized) { + try { + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), serialized); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private F1 reconstitute(F1 f1) { + try { + return (F1) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), writeReplace(f1)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private SerializedLambda writeReplace(F1 f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1 extends Serializable { + B apply(A a); +} + +class LambdaHost { + public F1 lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + public F1 lambdaBackedByStaticImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + public static MethodHandles.Lookup lookup() { return MethodHandles.lookup(); } +} From 599f3c28d1fc138fbeb35777d87905345f846128 Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Wed, 13 May 2015 10:33:35 +0200 Subject: [PATCH 2/2] Update scala modules sbt plugin to 1.0.3 --- build.sbt | 14 +------------- project/plugins.sbt | 4 +--- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/build.sbt b/build.sbt index 6e85837..7c21bf0 100644 --- a/build.sbt +++ b/build.sbt @@ -1,11 +1,7 @@ -import com.typesafe.tools.mima.plugin.{MimaPlugin, MimaKeys} - scalaModuleSettings scalaVersion := "2.11.6" -snapshotScalaBinaryVersion := "2.11.6" - organization := "org.scala-lang.modules" name := "scala-java8-compat" @@ -23,15 +19,7 @@ libraryDependencies += "junit" % "junit" % "4.11" % "test" libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test" -MimaPlugin.mimaDefaultSettings - -MimaKeys.previousArtifact := None - -// run mima during tests -test in Test := { - MimaKeys.reportBinaryIssues.value - (test in Test).value -} +mimaPreviousVersion := None testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a") diff --git a/project/plugins.sbt b/project/plugins.sbt index 5f604e6..25f3373 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1 @@ -addSbtPlugin("org.scala-lang.modules" % "scala-module-plugin" % "1.0.2") - -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") +addSbtPlugin("org.scala-lang.modules" % "scala-module-plugin" % "1.0.3")