Skip to content

Commit 0761f1a

Browse files
committed
Corrected sync path and tests
1 parent 9af1b54 commit 0761f1a

File tree

3 files changed

+38
-33
lines changed

3 files changed

+38
-33
lines changed

src/MongoDB.Driver.Encryption/CryptContext.cs

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,8 @@ public Binary FinalizeForEncryption()
156156
}
157157

158158
/// <summary>
159-
/// Gets a collection of KMS message requests to make
159+
/// Gets the next KMS message request
160160
/// </summary>
161-
/// <returns>Collection of KMS Messages</returns>
162-
public KmsRequestCollection GetKmsMessageRequests()
163-
{
164-
var requests = new List<KmsRequest>();
165-
for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle))
166-
{
167-
requests.Add(new KmsRequest(request));
168-
}
169-
170-
return new KmsRequestCollection(requests, this);
171-
}
172-
173-
//TODO I think we should remove the previous method and use this
174161
public KmsRequest GetNextKmsMessageRequest()
175162
{
176163
var request = Library.mongocrypt_ctx_next_kms_ctx(_handle);

src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,11 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
211211

212212
private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
213213
{
214-
var requests = context.GetKmsMessageRequests();
215-
foreach (var request in requests)
214+
while (context.GetNextKmsMessageRequest() is { } request)
216215
{
217216
SendKmsRequest(request, cancellationToken);
218217
}
219-
requests.MarkDone();
218+
context.MarkKmsDone();
220219
}
221220

222221
private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
@@ -277,25 +276,46 @@ private static byte[] ProcessReadyState(CryptContext context)
277276

278277
private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
279278
{
280-
var endpoint = CreateKmsEndPoint(request.Endpoint);
281-
282-
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
283-
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
284-
using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation))
285-
using (var binary = request.GetMessage())
279+
try
286280
{
281+
var endpoint = CreateKmsEndPoint(request.Endpoint);
282+
283+
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
284+
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
285+
using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation);
286+
287+
var sleepMs = request.Sleep;
288+
if (sleepMs > 0)
289+
{
290+
Thread.Sleep(sleepMs);
291+
}
292+
293+
using var binary = request.GetMessage();
287294
var requestBytes = binary.ToArray();
288295
sslStream.Write(requestBytes, 0, requestBytes.Length);
289296

290297
while (request.BytesNeeded > 0)
291298
{
292299
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
293300
var count = sslStream.Read(buffer, 0, buffer.Length);
301+
302+
if (count == 0)
303+
{
304+
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
305+
}
306+
294307
var responseBytes = new byte[count];
295308
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
296309
request.Feed(responseBytes);
297310
}
298311
}
312+
catch (Exception ex) when (ex is IOException or SocketException)
313+
{
314+
if (!request.Fail())
315+
{
316+
throw;
317+
}
318+
}
299319
}
300320

301321
private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
@@ -307,16 +327,15 @@ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken can
307327
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
308328
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
309329
using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false);
310-
using var binary = request.GetMessage();
311-
312-
var requestBytes = binary.ToArray();
313330

314331
var sleepMs = request.Sleep;
315332
if (sleepMs > 0)
316333
{
317334
await Task.Delay(sleepMs, cancellation).ConfigureAwait(false);
318335
}
319336

337+
using var binary = request.GetMessage();
338+
var requestBytes = binary.ToArray();
320339
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);
321340

322341
while (request.BytesNeeded > 0)

tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ public void TestGetKmsProviderName(string kmsName)
433433
using (var cryptClient = CryptClientFactory.Create(cryptOptions))
434434
using (var context = cryptClient.StartCreateDataKeyContext(keyId))
435435
{
436-
var request = context.GetKmsMessageRequests().Single();
436+
var request = context.GetNextKmsMessageRequest();
437437
request.KmsProvider.Should().Be(kmsName);
438438
}
439439
}
@@ -634,22 +634,21 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs
634634

635635
case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
636636
{
637-
var requests = context.GetKmsMessageRequests();
638-
foreach (var req in requests)
637+
while (context.GetNextKmsMessageRequest() is { } request)
639638
{
640-
using var binary = req.GetMessage();
639+
using var binary = request.GetMessage();
641640
_output.WriteLine("Key Document: " + binary);
642641
var postRequest = binary.ToString();
643642
// TODO: add different hosts handling
644643
postRequest.Should().Contain("Host:kms.us-east-1.amazonaws.com"); // only AWS
645644

646645
var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt");
647646
_output.WriteLine("Reply: " + reply);
648-
req.Feed(Encoding.UTF8.GetBytes(reply));
649-
req.BytesNeeded.Should().Be(0);
647+
request.Feed(Encoding.UTF8.GetBytes(reply));
648+
request.BytesNeeded.Should().Be(0);
650649
}
651650

652-
requests.MarkDone();
651+
context.MarkKmsDone();
653652
return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null);
654653
}
655654

0 commit comments

Comments
 (0)