Skip to content

Commit a912944

Browse files
committed
Add a generic deserializer for Java/Scala 2.12 lambdas
Java support serialization of lambdas by using the serialization proxy pattern. Deserialization of a lambda uses `LambdaMetafactory` to create a new anonymous subclass. More details of the scheme are documented: https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html From those docs: > SerializedLambda has a readResolve method that looks for a > (possibly private) static method called $deserializeLambda$ > in the capturing class, invokes that with itself as the first > argument, and returns the result. Lambda classes implementing > $deserializeLambda$ are responsible for validating that the > properties of the SerializedLambda are consistent with a lambda > actually captured by that class. The Java compiler generates code in `$deserializeLambda$` that switches on the implementation method name and signature to locate an invokedynamic instruction generated for the particular lambda expression. Then, the `SerializedLambda` is further unpacked, validating that this implementation method still represents the same functional interface as it did when it was serialized. (The source may have been recompiled in the interim.) In Java, serializable lambda expressions are the exception rather than the rule. In Scala, however, the serializability of `FunctionN` means that we would end up generating a large amount of code to support deserialization. Instead, we are pursuing an alternative approach in which the `$deserializeLambda$` method is a simple forwarder to the generic deserializer added here. This is capable of deserializing lambdas created by the Java compiler, although this is not its intended use case. The enclosed tests use Java lambdas. This generic deserializer also works by calling `LambdaMetafactory`, but it does so explicitly, rather than implicitly during linkage of the `invokedynamic` instruction. We have to mimic the caching property of `invokedynamic` instruction to ensure we reuse the classes when constructing. I originally tried using a central cache, but wasn't able to come up with a scheme to avoid potential classloader memory leaks. Instead, I now allow the caller to provide a cache. The scala compiler will host an instance of this cache in each class that hosts a lambda. This is analagous the the `MethodCache` used by reflective calls. If the name or signature of the implementation method has changed, we fail during deserialization with an `IllegalArgumentError.` However, we do not fail fast in a few cases that Java would, as we cannot reflect on the "current" functional interface supported by this implementation method. We just instantiate using the "previous" functional interface class/method. This might: 1. fail inside `LambdaMetafactory` if the new implementation method is not compatible with the old functional interface. 2. pass through `LambdaMetafactory` by chance, but fail when instantiating the class in other cases. For example: ``` % tail sandbox/test{1,2}.scala ==> sandbox/test1.scala <== class C { def test: (String => String) = { val s: String = "" (t) => s + t } } ==> sandbox/test2.scala <== class C { def test: (String, String) => String = { (s, t) => s + t } } % (for i in 1 2; do scalac -Ydelambdafy:method -Xprint:delambdafy sandbox/test$i.scala 2>&1 ; done) | grep 'def $anon' final <static> <artifact> private[this] def $anonfun$1(t: String, s$1: String): String = s$1.+(t); final <static> <artifact> private[this] def $anonfun$1(s: String, t: String): String = s.+(t); ``` 3. Silently create an instance of the old functional interface. For example, imagine switching from `FuncInterface1` to `FuncInterface2` where these were identical other than the name. I don't believe that these are showstoppers. Failing test case demonstrating overly weak cache
1 parent 82eba69 commit a912944

File tree

2 files changed

+313
-0
lines changed

2 files changed

+313
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package scala.compat.java8.runtime
2+
3+
import java.lang.invoke._
4+
5+
/**
6+
* This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12
7+
* compiler will add to classes hosting lambdas.
8+
*
9+
* It is not intended to be consumed directly.
10+
*/
11+
object LambdaDeserializer {
12+
/**
13+
* Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
14+
* and instantiating this class with the captured arguments.
15+
*
16+
* A cache may be provided to ensure that subsequent deserialization of the same lambda expression
17+
* is cheap, it amounts to a reflective call to the constructor of the previously created class.
18+
* However, deserialization of the same lambda expression is not guaranteed to use the same class,
19+
* concurrent deserialization of the same lambda expression may spin up more than one class.
20+
*
21+
* Assumptions:
22+
* - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
23+
* not stored in `SerializedLambda`, so we can't reconstitute them.
24+
* - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
25+
*
26+
* @param lookup The factory for method handles. Must have access to the implementation method, the
27+
* functional interface class, and `java.io.Serializable` or `scala.Serializable` as
28+
* required.
29+
* @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null`
30+
* @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve`
31+
* member of the anonymous class created by `LambdaMetaFactory`.
32+
* @return An instance of the functional interface
33+
*/
34+
def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
35+
def slashDot(name: String) = name.replaceAll("/", ".")
36+
val loader = lookup.lookupClass().getClassLoader
37+
val implClass = loader.loadClass(slashDot(serialized.getImplClass))
38+
39+
def makeCallSite: CallSite = {
40+
import serialized._
41+
def parseDescriptor(s: String) =
42+
MethodType.fromMethodDescriptorString(s, loader)
43+
44+
val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
45+
val instantiated = parseDescriptor(getInstantiatedMethodType)
46+
val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))
47+
48+
val implMethodSig = parseDescriptor(getImplMethodSignature)
49+
// Construct the invoked type from the impl method type. This is the type of a factory
50+
// that will be generated by the meta-factory. It is a method type, with param types
51+
// coming form the types of the captures, and return type being the functional interface.
52+
val invokedType: MethodType = {
53+
// 1. Add receiver for non-static impl methods
54+
val withReceiver = getImplMethodKind match {
55+
case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial =>
56+
implMethodSig
57+
case _ =>
58+
implMethodSig.insertParameterTypes(0, implClass)
59+
}
60+
// 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter,
61+
// such as in `Function<Object, String> s = Object::toString`
62+
val lambdaArity = funcInterfaceSignature.parameterCount()
63+
val from = withReceiver.parameterCount() - lambdaArity
64+
val to = withReceiver.parameterCount()
65+
66+
// 3. Drop the lambda return type and replace with the functional interface.
67+
withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass)
68+
}
69+
70+
// Lookup the implementation method
71+
val implMethod: MethodHandle = try {
72+
findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
73+
} catch {
74+
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
75+
}
76+
77+
val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
78+
val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function")
79+
val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable)
80+
81+
LambdaMetafactory.altMetafactory(
82+
lookup, getFunctionalInterfaceMethodName, invokedType,
83+
84+
/* samMethodType = */ funcInterfaceSignature,
85+
/* implMethod = */ implMethod,
86+
/* instantiatedMethodType = */ instantiated,
87+
/* flags = */ flags.asInstanceOf[AnyRef],
88+
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
89+
/* markerInterfaces[0] = */ markerInterface,
90+
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
91+
)
92+
}
93+
94+
val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
95+
val factory: MethodHandle = if (cache == null) {
96+
makeCallSite.getTarget
97+
} else cache.get(key) match {
98+
case null =>
99+
val callSite = makeCallSite
100+
val temp = callSite.getTarget
101+
cache.put(key, temp)
102+
temp
103+
case target => target
104+
}
105+
106+
val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
107+
factory.invokeWithArguments(captures: _*)
108+
}
109+
110+
private val ScalaSerializable = "scala.Serializable"
111+
112+
private val JavaIOSerializable = {
113+
// We could actually omit this marker interface as LambdaMetaFactory will add it if
114+
// the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code
115+
// is cleaner if we uniformly add a single marker, so I'm leaving it in place.
116+
"java.io.Serializable"
117+
}
118+
119+
private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
120+
name: String, signature: MethodType): MethodHandle = {
121+
kind match {
122+
case MethodHandleInfo.REF_invokeStatic =>
123+
lookup.findStatic(owner, name, signature)
124+
case MethodHandleInfo.REF_newInvokeSpecial =>
125+
lookup.findConstructor(owner, signature)
126+
case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
127+
lookup.findVirtual(owner, name, signature)
128+
case MethodHandleInfo.REF_invokeSpecial =>
129+
lookup.findSpecial(owner, name, signature, owner)
130+
}
131+
}
132+
}
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package scala.compat.java8.runtime;
2+
3+
import org.junit.Assert;
4+
import org.junit.Test;
5+
6+
import java.io.Serializable;
7+
import java.lang.invoke.MethodHandle;
8+
import java.lang.invoke.MethodHandles;
9+
import java.lang.invoke.SerializedLambda;
10+
import java.lang.reflect.Method;
11+
import java.util.Arrays;
12+
import java.util.HashMap;
13+
14+
public final class LambdaDeserializerTest {
15+
private LambdaHost lambdaHost = new LambdaHost();
16+
17+
@Test
18+
public void serializationPrivate() {
19+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
20+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
21+
}
22+
23+
@Test
24+
public void serializationStatic() {
25+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
26+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
27+
}
28+
29+
@Test
30+
public void serializationVirtualMethodReference() {
31+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
32+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
33+
}
34+
35+
@Test
36+
public void serializationInterfaceMethodReference() {
37+
F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
38+
I i = new I() {
39+
};
40+
Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
41+
}
42+
43+
@Test
44+
public void serializationStaticMethodReference() {
45+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
46+
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
47+
}
48+
49+
@Test
50+
public void serializationNewInvokeSpecial() {
51+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
52+
Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
53+
}
54+
55+
@Test
56+
public void uncached() {
57+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
58+
F0<Object> reconstituted1 = reconstitute(f1);
59+
F0<Object> reconstituted2 = reconstitute(f1);
60+
Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass());
61+
}
62+
63+
@Test
64+
public void cached() {
65+
HashMap<String, MethodHandle> cache = new HashMap<>();
66+
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
67+
F0<Object> reconstituted1 = reconstitute(f1, cache);
68+
F0<Object> reconstituted2 = reconstitute(f1, cache);
69+
Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass());
70+
}
71+
72+
@Test
73+
public void implMethodNameChanged() {
74+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
75+
SerializedLambda sl = writeReplace(f1);
76+
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
77+
}
78+
79+
@Test
80+
public void implMethodSignatureChanged() {
81+
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
82+
SerializedLambda sl = writeReplace(f1);
83+
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
84+
}
85+
86+
private void checkIllegalAccess(SerializedLambda serialized) {
87+
try {
88+
LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
89+
throw new AssertionError();
90+
} catch (IllegalArgumentException iae) {
91+
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
92+
Assert.fail("Unexpected message: " + iae.getMessage());
93+
}
94+
}
95+
}
96+
97+
private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
98+
Object[] captures = new Object[sl.getCapturedArgCount()];
99+
for (int i = 0; i < captures.length; i++) {
100+
captures[i] = sl.getCapturedArg(i);
101+
}
102+
return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
103+
sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
104+
sl.getInstantiatedMethodType(), captures);
105+
}
106+
107+
private Class<?> loadClass(String className) {
108+
try {
109+
return Class.forName(className.replace('/', '.'));
110+
} catch (ClassNotFoundException e) {
111+
throw new RuntimeException(e);
112+
}
113+
}
114+
private <A, B> A reconstitute(A f1) {
115+
return reconstitute(f1, null);
116+
}
117+
118+
@SuppressWarnings("unchecked")
119+
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
120+
try {
121+
return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
122+
} catch (Exception e) {
123+
throw new RuntimeException(e);
124+
}
125+
}
126+
127+
private <A> SerializedLambda writeReplace(A f1) {
128+
try {
129+
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
130+
writeReplace.setAccessible(true);
131+
return (SerializedLambda) writeReplace.invoke(f1);
132+
} catch (Exception e) {
133+
throw new RuntimeException(e);
134+
}
135+
}
136+
}
137+
138+
139+
interface F1<A, B> extends Serializable {
140+
B apply(A a);
141+
}
142+
143+
interface F0<A> extends Serializable {
144+
A apply();
145+
}
146+
147+
class LambdaHost {
148+
public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
149+
int local = 42;
150+
return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
151+
}
152+
153+
@SuppressWarnings("Convert2MethodRef")
154+
public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
155+
return (b) -> String.valueOf(b);
156+
}
157+
158+
public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
159+
return String::valueOf;
160+
}
161+
162+
public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
163+
return Object::toString;
164+
}
165+
166+
public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
167+
return I::i;
168+
}
169+
170+
public F0<Object> lambdaBackedByConstructorCall() {
171+
return String::new;
172+
}
173+
174+
public static MethodHandles.Lookup lookup() {
175+
return MethodHandles.lookup();
176+
}
177+
}
178+
179+
interface I {
180+
default String i() { return "i"; };
181+
}

0 commit comments

Comments
 (0)