Skip to content

Commit 6cb1c42

Browse files
Merge pull request #67 from bdonlan/kms-cache-restore
Restore KMS caching logic
2 parents 0e15a35 + 2a6e6e4 commit 6cb1c42

13 files changed

+178
-35
lines changed

CHANGELOG.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
# Changelog
22

3-
## 1.3.5
3+
## 1.3.6
4+
5+
### Minor Changes
46

57
(nothing yet)
68

9+
## 1.3.5
10+
11+
### Minor Changes
12+
13+
* Restored the KMS client cache with a fix for the memory leak.
14+
* When using a master key provider that can only service a subset of regions
15+
(e.g. using the deprecated constructors), and requesting a master key from a
16+
region not servicable by that MKP, the exception will now be thrown on first
17+
use of the MK, rather than at getMasterKey time.
18+
719
## 1.3.4
820

921
### Minor Changes

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ You can get the latest release from Maven:
4545
<dependency>
4646
<groupId>com.amazonaws</groupId>
4747
<artifactId>aws-encryption-sdk-java</artifactId>
48-
<version>1.3.4</version>
48+
<version>1.3.5</version>
4949
</dependency>
5050
```
5151

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
<groupId>com.amazonaws</groupId>
66
<artifactId>aws-encryption-sdk-java</artifactId>
7-
<version>1.3.5-SNAPSHOT</version>
7+
<version>1.3.6-SNAPSHOT</version>
88
<packaging>jar</packaging>
99

1010
<name>aws-encryption-sdk-java</name>

src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Collection;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.function.Supplier;
2425

2526
import com.amazonaws.AmazonServiceException;
2627
import com.amazonaws.AmazonWebServiceRequest;
@@ -48,7 +49,7 @@
4849
* {@link AwsCrypto}.
4950
*/
5051
public final class KmsMasterKey extends MasterKey<KmsMasterKey> implements KmsMethods {
51-
private final AWSKMS kms_;
52+
private final Supplier<AWSKMS> kms_;
5253
private final MasterKeyProvider<KmsMasterKey> sourceProvider_;
5354
private final String id_;
5455
private final List<String> grantTokens_ = new ArrayList<>();
@@ -77,12 +78,12 @@ public static KmsMasterKey getInstance(final AWSCredentialsProvider creds, final
7778
return new KmsMasterKeyProvider(creds, keyId).getMasterKey(keyId);
7879
}
7980

80-
static KmsMasterKey getInstance(final AWSKMS kms, final String id,
81+
static KmsMasterKey getInstance(final Supplier<AWSKMS> kms, final String id,
8182
final MasterKeyProvider<KmsMasterKey> provider) {
8283
return new KmsMasterKey(kms, id, provider);
8384
}
8485

85-
private KmsMasterKey(final AWSKMS kms, final String id, final MasterKeyProvider<KmsMasterKey> provider) {
86+
private KmsMasterKey(final Supplier<AWSKMS> kms, final String id, final MasterKeyProvider<KmsMasterKey> provider) {
8687
kms_ = kms;
8788
id_ = id;
8889
sourceProvider_ = provider;
@@ -101,7 +102,7 @@ public String getKeyId() {
101102
@Override
102103
public DataKey<KmsMasterKey> generateDataKey(final CryptoAlgorithm algorithm,
103104
final Map<String, String> encryptionContext) {
104-
final GenerateDataKeyResult gdkResult = kms_.generateDataKey(updateUserAgent(
105+
final GenerateDataKeyResult gdkResult = kms_.get().generateDataKey(updateUserAgent(
105106
new GenerateDataKeyRequest()
106107
.withKeyId(getKeyId())
107108
.withNumberOfBytes(algorithm.getDataKeyLength())
@@ -145,7 +146,7 @@ public DataKey<KmsMasterKey> encryptDataKey(final CryptoAlgorithm algorithm,
145146
throw new IllegalArgumentException("Only RAW encoded keys are supported");
146147
}
147148
try {
148-
final EncryptResult encryptResult = kms_.encrypt(updateUserAgent(
149+
final EncryptResult encryptResult = kms_.get().encrypt(updateUserAgent(
149150
new EncryptRequest()
150151
.withKeyId(id_)
151152
.withPlaintext(ByteBuffer.wrap(key.getEncoded()))
@@ -167,7 +168,7 @@ public DataKey<KmsMasterKey> decryptDataKey(final CryptoAlgorithm algorithm,
167168
final List<Exception> exceptions = new ArrayList<>();
168169
for (final EncryptedDataKey edk : encryptedDataKeys) {
169170
try {
170-
final DecryptResult decryptResult = kms_.decrypt(updateUserAgent(
171+
final DecryptResult decryptResult = kms_.get().decrypt(updateUserAgent(
171172
new DecryptRequest()
172173
.withCiphertextBlob(ByteBuffer.wrap(edk.getEncryptedDataKey()))
173174
.withEncryptionContext(encryptionContext)

src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@
2525
import java.util.Map;
2626
import java.util.Objects;
2727
import java.util.concurrent.ConcurrentHashMap;
28+
import java.util.function.Supplier;
2829

30+
import com.amazonaws.AmazonServiceException;
2931
import com.amazonaws.ClientConfiguration;
32+
import com.amazonaws.Request;
33+
import com.amazonaws.Response;
3034
import com.amazonaws.auth.AWSCredentials;
3135
import com.amazonaws.auth.AWSCredentialsProvider;
3236
import com.amazonaws.auth.AWSStaticCredentialsProvider;
@@ -71,12 +75,16 @@ public interface RegionalClientSupplier {
7175
AWSKMS getClient(String regionName);
7276
}
7377

74-
public static final class Builder implements Cloneable {
78+
public static class Builder implements Cloneable {
7579
private String defaultRegion_ = null;
7680
private RegionalClientSupplier regionalClientSupplier_ = null;
7781
private AWSKMSClientBuilder templateBuilder_ = null;
7882
private List<String> keyIds_ = new ArrayList<>();
7983

84+
Builder() {
85+
// Default access: Don't allow outside classes to extend this class
86+
}
87+
8088
public Builder clone() {
8189
try {
8290
Builder cloned = (Builder) super.clone();
@@ -259,11 +267,68 @@ private RegionalClientSupplier clientFactory() {
259267
AWSKMSClientBuilder builder = templateBuilder_ != null ? cloneClientBuilder(templateBuilder_)
260268
: AWSKMSClientBuilder.standard();
261269

270+
ConcurrentHashMap<String, AWSKMS> clientCache = new ConcurrentHashMap<>();
271+
snoopClientCache(clientCache);
272+
262273
return region -> {
263-
// Clone yet again as we're going to change the region field.
264-
return cloneClientBuilder(builder).withRegion(region).build();
274+
AWSKMS kms = clientCache.get(region);
275+
276+
if (kms != null) return kms;
277+
278+
// We can't just use computeIfAbsent as we need to avoid leaking KMS clients if we're asked to decrypt
279+
// an EDK with a bogus region in its ARN. So we'll install a request handler to identify the first
280+
// successful call, and cache it when we see that.
281+
SuccessfulRequestCacher cacher = new SuccessfulRequestCacher(clientCache, region);
282+
ArrayList<RequestHandler2> handlers = new ArrayList<>();
283+
if (builder.getRequestHandlers() != null) {
284+
handlers.addAll(builder.getRequestHandlers());
285+
}
286+
handlers.add(cacher);
287+
288+
kms = cloneClientBuilder(builder)
289+
.withRegion(region)
290+
.withRequestHandlers(handlers.toArray(new RequestHandler2[handlers.size()]))
291+
.build();
292+
cacher.client_ = kms;
293+
294+
return kms;
265295
};
266296
}
297+
298+
protected void snoopClientCache(ConcurrentHashMap<String, AWSKMS> map) {
299+
// no-op - this is a test hook
300+
}
301+
}
302+
303+
private static class SuccessfulRequestCacher extends RequestHandler2 {
304+
private final ConcurrentHashMap<String, AWSKMS> cache_;
305+
private final String region_;
306+
private AWSKMS client_;
307+
308+
volatile boolean ranBefore_ = false;
309+
310+
private SuccessfulRequestCacher(
311+
final ConcurrentHashMap<String, AWSKMS> cache,
312+
final String region
313+
) {
314+
this.region_ = region;
315+
this.cache_ = cache;
316+
}
317+
318+
@Override public void afterResponse(final Request<?> request, final Response<?> response) {
319+
if (ranBefore_) return;
320+
ranBefore_ = true;
321+
322+
cache_.putIfAbsent(region_, client_);
323+
}
324+
325+
@Override public void afterError(final Request<?> request, final Response<?> response, final Exception e) {
326+
if (ranBefore_) return;
327+
if (e instanceof AmazonServiceException) {
328+
ranBefore_ = true;
329+
cache_.putIfAbsent(region_, client_);
330+
}
331+
}
267332
}
268333

269334
public static Builder builder() {
@@ -453,12 +518,17 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro
453518
regionName = defaultRegion_;
454519
}
455520

456-
AWSKMS kms = regionalClientSupplier_.getClient(regionName);
457-
if (kms == null) {
458-
throw new AwsCryptoException("Can't use keys from region " + regionName);
459-
}
521+
String regionName_ = regionName;
522+
523+
Supplier<AWSKMS> kmsSupplier = () -> {
524+
AWSKMS kms = regionalClientSupplier_.getClient(regionName_);
525+
if (kms == null) {
526+
throw new AwsCryptoException("Can't use keys from region " + regionName_);
527+
}
528+
return kms;
529+
};
460530

461-
final KmsMasterKey result = KmsMasterKey.getInstance(kms, keyId, this);
531+
final KmsMasterKey result = KmsMasterKey.getInstance(kmsSupplier, keyId, this);
462532
result.setGrantTokens(grantTokens_);
463533
return result;
464534
}

src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import com.amazonaws.encryptionsdk.model.CipherFrameHeadersTest;
2424
import com.amazonaws.encryptionsdk.model.KeyBlobTest;
2525
import com.amazonaws.encryptionsdk.multi.MultipleMasterKeyTest;
26-
import com.amazonaws.services.kms.KMSProviderBuilderMockTests;
26+
import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderMockTests;
2727

2828
@RunWith(Suite.class)
2929
@Suite.SuiteClasses({

src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import org.junit.runner.RunWith;
44
import org.junit.runners.Suite;
55

6-
import com.amazonaws.services.kms.KMSProviderBuilderIntegrationTests;
7-
import com.amazonaws.services.kms.XCompatKmsDecryptTest;
6+
import com.amazonaws.encryptionsdk.kms.KMSProviderBuilderIntegrationTests;
7+
import com.amazonaws.encryptionsdk.kms.XCompatKmsDecryptTest;
88

99
@RunWith(Suite.class)
1010
@Suite.SuiteClasses({

src/test/java/com/amazonaws/services/kms/KMSProviderBuilderIntegrationTests.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderIntegrationTests.java

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
import static org.junit.Assert.assertEquals;
4+
import static org.junit.Assert.assertFalse;
5+
import static org.junit.Assert.assertNotNull;
46
import static org.junit.Assert.assertTrue;
57
import static org.junit.Assert.fail;
68
import static org.mockito.ArgumentMatchers.any;
@@ -10,29 +12,91 @@
1012
import static org.mockito.Mockito.spy;
1113
import static org.mockito.Mockito.verify;
1214

15+
import java.nio.charset.StandardCharsets;
1316
import java.util.Arrays;
17+
import java.util.Collections;
18+
import java.util.HashMap;
19+
import java.util.concurrent.ConcurrentHashMap;
20+
import java.util.concurrent.atomic.AtomicReference;
1421

1522
import org.junit.Test;
1623
import org.mockito.ArgumentCaptor;
1724

1825
import com.amazonaws.AbortedException;
19-
import com.amazonaws.AmazonWebServiceRequest;
2026
import com.amazonaws.ClientConfiguration;
2127
import com.amazonaws.Request;
2228
import com.amazonaws.auth.AWSCredentials;
2329
import com.amazonaws.auth.AWSCredentialsProvider;
2430
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
2531
import com.amazonaws.client.builder.AwsClientBuilder;
2632
import com.amazonaws.encryptionsdk.AwsCrypto;
33+
import com.amazonaws.encryptionsdk.CryptoAlgorithm;
2734
import com.amazonaws.encryptionsdk.CryptoResult;
2835
import com.amazonaws.encryptionsdk.MasterKeyProvider;
2936
import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
3037
import com.amazonaws.encryptionsdk.internal.VersionInfo;
31-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
38+
import com.amazonaws.encryptionsdk.model.KeyBlob;
3239
import com.amazonaws.handlers.RequestHandler2;
3340
import com.amazonaws.http.exception.HttpRequestTimeoutException;
41+
import com.amazonaws.services.kms.AWSKMS;
42+
import com.amazonaws.services.kms.AWSKMSClientBuilder;
3443

3544
public class KMSProviderBuilderIntegrationTests {
45+
@Test
46+
public void whenBogusRegionsDecrypted_doesNotLeakClients() throws Exception {
47+
AtomicReference<ConcurrentHashMap<String, AWSKMS>> kmsCache = new AtomicReference<>();
48+
49+
KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() {
50+
@Override protected void snoopClientCache(
51+
final ConcurrentHashMap<String, AWSKMS> map
52+
) {
53+
kmsCache.set(map);
54+
}
55+
}).build();
56+
57+
try {
58+
mkp.decryptDataKey(
59+
CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_HKDF_SHA256,
60+
Collections.singleton(
61+
new KeyBlob("aws-kms",
62+
"arn:aws:kms:us-bogus-1:123456789010:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f"
63+
.getBytes(StandardCharsets.UTF_8),
64+
new byte[40]
65+
)
66+
),
67+
new HashMap<>()
68+
);
69+
fail("Expected CannotUnwrapDataKeyException");
70+
} catch (CannotUnwrapDataKeyException e) {
71+
// ok
72+
}
73+
74+
assertTrue(kmsCache.get().isEmpty());
75+
}
76+
77+
@Test
78+
public void whenOperationSuccessful_clientIsCached() {
79+
AtomicReference<ConcurrentHashMap<String, AWSKMS>> kmsCache = new AtomicReference<>();
80+
81+
KmsMasterKeyProvider mkp = (new KmsMasterKeyProvider.Builder() {
82+
@Override protected void snoopClientCache(
83+
final ConcurrentHashMap<String, AWSKMS> map
84+
) {
85+
kmsCache.set(map);
86+
}
87+
}).withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0])
88+
.build();
89+
90+
new AwsCrypto().encryptData(mkp, new byte[1]);
91+
92+
AWSKMS kms = kmsCache.get().get("us-west-2");
93+
assertNotNull(kms);
94+
95+
new AwsCrypto().encryptData(mkp, new byte[1]);
96+
97+
// Cache entry should stay the same
98+
assertEquals(kms, kmsCache.get().get("us-west-2"));
99+
}
36100

37101
@Test
38102
public void whenConstructedWithoutArguments_canUseMultipleRegions() throws Exception {
@@ -75,7 +139,7 @@ public void whenHandlerConfigured_handlerIsInvoked() throws Exception {
75139
KmsMasterKeyProvider.builder()
76140
.withClientBuilder(
77141
AWSKMSClientBuilder.standard()
78-
.withRequestHandlers(handler)
142+
.withRequestHandlers(handler)
79143
)
80144
.withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0])
81145
.build();

src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSProviderBuilderMockTests.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider;
44
import static com.amazonaws.regions.Region.getRegion;
5-
import static com.amazonaws.regions.Regions.DEFAULT_REGION;
65
import static com.amazonaws.regions.Regions.fromName;
76
import static java.util.Collections.singletonList;
87
import static org.junit.Assert.assertEquals;
@@ -30,8 +29,6 @@
3029
import com.amazonaws.encryptionsdk.AwsCrypto;
3130
import com.amazonaws.encryptionsdk.MasterKeyProvider;
3231
import com.amazonaws.encryptionsdk.internal.VersionInfo;
33-
import com.amazonaws.encryptionsdk.kms.KmsMasterKey;
34-
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
3532
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier;
3633
import com.amazonaws.regions.Region;
3734
import com.amazonaws.regions.Regions;

src/test/java/com/amazonaws/services/kms/KMSTestFixtures.java renamed to src/test/java/com/amazonaws/encryptionsdk/kms/KMSTestFixtures.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.amazonaws.services.kms;
1+
package com.amazonaws.encryptionsdk.kms;
22

33
final class KMSTestFixtures {
44
private KMSTestFixtures() {

0 commit comments

Comments
 (0)