Skip to content

Commit b9942ea

Browse files
Runtime hooks (#412)
* Add crac context implementation to RIC * Add restore endpoint and runtime hooks execution logic --------- Co-authored-by: Anton Stepanov <110172761+anton-stepanof@users.noreply.github.com>
1 parent db9ef92 commit b9942ea

File tree

15 files changed

+959
-281
lines changed

15 files changed

+959
-281
lines changed

aws-lambda-java-runtime-interface-client/pom.xml

Lines changed: 282 additions & 269 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
/**
 * Checked exception raised when notifying resources of an upcoming checkpoint fails.
 *
 * <p>The only constructor takes no arguments, so an instance carries neither a detail message
 * nor a cause; callers that need to report individual failures attach them with
 * {@link Throwable#addSuppressed(Throwable)}.
 */
public class CheckpointException extends Exception {

    private static final long serialVersionUID = -4956873658083157585L;

    /**
     * Creates an exception with a {@code null} detail message and no cause.
     */
    public CheckpointException() {
        // The superclass no-arg constructor is invoked implicitly.
    }
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
public abstract class Context<R extends Resource> implements Resource {
8+
9+
protected Context() {
10+
}
11+
12+
@Override
13+
public abstract void beforeCheckpoint(Context<? extends Resource> context)
14+
throws CheckpointException;
15+
16+
@Override
17+
public abstract void afterRestore(Context<? extends Resource> context)
18+
throws RestoreException;
19+
20+
public abstract void register(R resource);
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
import java.util.ArrayList;
8+
import java.util.Collections;
9+
import java.util.List;
10+
import java.util.Map;
11+
import java.util.WeakHashMap;
12+
import java.util.stream.Collectors;
13+
14+
15+
/**
16+
* Spec reference: https://crac.github.io/openjdk-builds/javadoc/api/java.base/jdk/crac/package-summary.html
17+
*/
18+
19+
public class ContextImpl extends Context<Resource> {
20+
21+
private volatile long order = -1L;
22+
private final WeakHashMap<Resource, Long> checkpointQueue = new WeakHashMap<>();
23+
24+
@Override
25+
public synchronized void beforeCheckpoint(Context<? extends Resource> context) throws CheckpointException {
26+
27+
List<Throwable> exceptionsThrown = new ArrayList<>();
28+
for (Resource resource : getCheckpointQueueReverseOrderOfRegistration()) {
29+
try {
30+
resource.beforeCheckpoint(this);
31+
} catch (CheckpointException e) {
32+
Collections.addAll(exceptionsThrown, e.getSuppressed());
33+
} catch (Exception e) {
34+
exceptionsThrown.add(e);
35+
}
36+
}
37+
38+
if (!exceptionsThrown.isEmpty()) {
39+
CheckpointException checkpointException = new CheckpointException();
40+
for (Throwable t : exceptionsThrown) {
41+
checkpointException.addSuppressed(t);
42+
}
43+
throw checkpointException;
44+
}
45+
}
46+
47+
@Override
48+
public synchronized void afterRestore(Context<? extends Resource> context) throws RestoreException {
49+
50+
List<Throwable> exceptionsThrown = new ArrayList<>();
51+
for (Resource resource : getCheckpointQueueForwardOrderOfRegistration()) {
52+
try {
53+
resource.afterRestore(this);
54+
} catch (RestoreException e) {
55+
Collections.addAll(exceptionsThrown, e.getSuppressed());
56+
} catch (Exception e) {
57+
exceptionsThrown.add(e);
58+
}
59+
}
60+
61+
if (!exceptionsThrown.isEmpty()) {
62+
RestoreException restoreException = new RestoreException();
63+
for (Throwable t : exceptionsThrown) {
64+
restoreException.addSuppressed(t);
65+
}
66+
throw restoreException;
67+
}
68+
}
69+
70+
@Override
71+
public synchronized void register(Resource resource) {
72+
checkpointQueue.put(resource, ++order);
73+
}
74+
75+
private List<Resource> getCheckpointQueueReverseOrderOfRegistration() {
76+
return checkpointQueue.entrySet()
77+
.stream()
78+
.sorted((r1, r2) -> (int) (r2.getValue() - r1.getValue()))
79+
.map(Map.Entry::getKey)
80+
.collect(Collectors.toList());
81+
}
82+
83+
private List<Resource> getCheckpointQueueForwardOrderOfRegistration() {
84+
return checkpointQueue.entrySet()
85+
.stream()
86+
.sorted((r1, r2) -> (int) (r1.getValue() - r2.getValue()))
87+
.map(Map.Entry::getKey)
88+
.collect(Collectors.toList());
89+
}
90+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
/**
8+
* Provides the global context for registering resources.
9+
*/
10+
public final class Core {
11+
12+
private Core() {
13+
}
14+
15+
private static Context<Resource> globalContext = new ContextImpl();
16+
17+
public static Context<Resource> getGlobalContext() {
18+
return globalContext;
19+
}
20+
21+
public static void checkpointRestore() {
22+
throw new UnsupportedOperationException();
23+
}
24+
25+
static void resetGlobalContext() {
26+
globalContext = new ContextImpl();
27+
}
28+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
public interface Resource {
8+
void afterRestore(Context<? extends Resource> context) throws Exception;
9+
10+
void beforeCheckpoint(Context<? extends Resource> context) throws Exception;
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*/
4+
5+
package com.amazonaws.services.lambda.crac;
6+
7+
/**
 * Checked exception raised when notifying resources of a restore fails.
 *
 * <p>The only constructor takes no arguments, so an instance carries neither a detail message
 * nor a cause; callers that need to report individual failures attach them with
 * {@link Throwable#addSuppressed(Throwable)}.
 */
public class RestoreException extends Exception {

    private static final long serialVersionUID = -823900409868237860L;

    /**
     * Creates an exception with a {@code null} detail message and no cause.
     */
    public RestoreException() {
        // The superclass no-arg constructor is invoked implicitly.
    }
}

aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/AWSLambda.java

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//
66
package com.amazonaws.services.lambda.runtime.api.client;
77

8+
import com.amazonaws.services.lambda.crac.Core;
89
import com.amazonaws.services.lambda.runtime.LambdaLogger;
910
import com.amazonaws.services.lambda.runtime.api.client.LambdaRequestHandler.UserFaultHandler;
1011
import com.amazonaws.services.lambda.runtime.api.client.logging.FramedTelemetryLogSink;
@@ -32,6 +33,7 @@
3233
import java.security.Security;
3334
import java.util.Properties;
3435

36+
3537
/**
3638
* The entrypoint of this class is {@link AWSLambda#startRuntime}. It performs two main tasks:
3739
*
@@ -59,6 +61,10 @@ public class AWSLambda {
5961
// https://github.com/aws/aws-xray-sdk-java/blob/2f467e50db61abb2ed2bd630efc21bddeabd64d9/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/contexts/LambdaSegmentContext.java#L39-L40
6062
private static final String LAMBDA_TRACE_HEADER_PROP = "com.amazonaws.xray.traceHeader";
6163

64+
private static final String INIT_TYPE_SNAP_START = "snap-start";
65+
66+
private static final String AWS_LAMBDA_INITIALIZATION_TYPE = System.getenv(ReservedRuntimeEnvironmentVariables.AWS_LAMBDA_INITIALIZATION_TYPE);
67+
6268
static {
6369
// Override the disabledAlgorithms setting to match configuration for openjdk8-u181.
6470
// This is to keep DES ciphers around while we deploying security updates.
@@ -211,14 +217,13 @@ private static void startRuntime(String handler, LambdaLogger lambdaLogger) thro
211217
requestHandler = findRequestHandler(handler, customerClassLoader);
212218
} catch (UserFault userFault) {
213219
lambdaLogger.log(userFault.reportableError());
214-
ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
215-
Failure failure = new Failure(userFault);
216-
GsonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
217-
runtimeClient.postInitError(payload.toByteArray(), failure.getErrorType());
220+
reportInitError(new Failure(userFault), runtimeClient);
218221
System.exit(1);
219222
return;
220223
}
221-
224+
if (INIT_TYPE_SNAP_START.equals(AWS_LAMBDA_INITIALIZATION_TYPE)) {
225+
onInitComplete(runtimeClient, lambdaLogger);
226+
}
222227
boolean shouldExit = false;
223228
while (!shouldExit) {
224229
UserFault userFault = null;
@@ -260,6 +265,49 @@ private static void startRuntime(String handler, LambdaLogger lambdaLogger) thro
260265
}
261266
}
262267

268+
static void onInitComplete(final LambdaRuntimeClient runtimeClient, final LambdaLogger lambdaLogger) throws IOException {
269+
try {
270+
Core.getGlobalContext().beforeCheckpoint(null);
271+
// Blocking call to RAPID /restore/next API, will return after taking snapshot.
272+
// This will also be the 'entrypoint' when resuming from snapshots.
273+
runtimeClient.getRestoreNext();
274+
} catch (Exception e1) {
275+
logExceptionCloudWatch(lambdaLogger, e1);
276+
reportInitError(new Failure(e1), runtimeClient);
277+
System.exit(64);
278+
}
279+
try {
280+
Core.getGlobalContext().afterRestore(null);
281+
} catch (Exception restoreExc) {
282+
logExceptionCloudWatch(lambdaLogger, restoreExc);
283+
Failure errorPayload = new Failure(restoreExc);
284+
reportRestoreError(errorPayload, runtimeClient);
285+
System.exit(64);
286+
}
287+
}
288+
289+
private static void logExceptionCloudWatch(LambdaLogger lambdaLogger, Exception exc) {
290+
UserFault.filterStackTrace(exc);
291+
UserFault userFault = UserFault.makeUserFault(exc, true);
292+
lambdaLogger.log(userFault.reportableError());
293+
}
294+
295+
static void reportInitError(final Failure failure,
296+
final LambdaRuntimeClient runtimeClient) throws IOException {
297+
298+
ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
299+
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
300+
runtimeClient.postInitError(payload.toByteArray(), failure.getErrorType());
301+
}
302+
303+
static int reportRestoreError(final Failure failure,
304+
final LambdaRuntimeClient runtimeClient) throws IOException {
305+
306+
ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
307+
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
308+
return runtimeClient.postRestoreError(payload.toByteArray(), failure.getErrorType());
309+
}
310+
263311
private static PojoSerializer<XRayErrorCause> xRayErrorCauseSerializer;
264312

265313
/**

aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/Failure.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,8 @@ public int compare(Class o1, Class o2) {
8181
public String getErrorType() {
8282
return errorType;
8383
}
84+
85+
public String getErrorMessage() {
86+
return errorMessage;
87+
}
8488
}

aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/ReservedRuntimeEnvironmentVariables.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ public interface ReservedRuntimeEnvironmentVariables {
8787
*/
8888
String AWS_LAMBDA_RUNTIME_API = "AWS_LAMBDA_RUNTIME_API";
8989

90+
91+
/**
92+
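* The initialization type of the function; the runtime runs its checkpoint/restore hooks when the value is {@code snap-start}.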
93+
*/
94+
String AWS_LAMBDA_INITIALIZATION_TYPE = "AWS_LAMBDA_INITIALIZATION_TYPE";
95+
9096
/**
9197
* The path to your Lambda function code.
9298
*/

aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/UserFault.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ public static <T extends Throwable> T filterStackTrace(T t) {
8080
if(cause != null) {
8181
filterStackTrace(cause);
8282
}
83+
84+
Throwable[] suppressedExceptions = t.getSuppressed();
85+
for(Throwable suppressed: suppressedExceptions) {
86+
filterStackTrace(suppressed);
87+
}
88+
8389
return t;
8490
}
8591

0 commit comments

Comments
 (0)