@@ -211,12 +211,11 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
211
211
212
212
private void ProcessNeedKmsState ( CryptContext context , CancellationToken cancellationToken )
213
213
{
214
- var requests = context . GetKmsMessageRequests ( ) ;
215
- foreach ( var request in requests )
214
+ while ( context . GetNextKmsMessageRequest ( ) is { } request )
216
215
{
217
216
SendKmsRequest ( request , cancellationToken ) ;
218
217
}
219
- requests . MarkDone ( ) ;
218
+ context . MarkKmsDone ( ) ;
220
219
}
221
220
222
221
private async Task ProcessNeedKmsStateAsync ( CryptContext context , CancellationToken cancellationToken )
@@ -277,25 +276,46 @@ private static byte[] ProcessReadyState(CryptContext context)
277
276
278
277
private void SendKmsRequest ( KmsRequest request , CancellationToken cancellation )
279
278
{
280
- var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
281
-
282
- var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
283
- var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
284
- using ( var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) )
285
- using ( var binary = request . GetMessage ( ) )
279
+ try
286
280
{
281
+ var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
282
+
283
+ var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
284
+ var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
285
+ using var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) ;
286
+
287
+ var sleepMs = request . Sleep ;
288
+ if ( sleepMs > 0 )
289
+ {
290
+ Thread . Sleep ( sleepMs ) ;
291
+ }
292
+
293
+ using var binary = request . GetMessage ( ) ;
287
294
var requestBytes = binary . ToArray ( ) ;
288
295
sslStream . Write ( requestBytes , 0 , requestBytes . Length ) ;
289
296
290
297
while ( request . BytesNeeded > 0 )
291
298
{
292
299
var buffer = new byte [ request . BytesNeeded ] ; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
293
300
var count = sslStream . Read ( buffer , 0 , buffer . Length ) ;
301
+
302
+ if ( count == 0 )
303
+ {
304
+ throw new IOException ( "Unexpected end of stream. No data was read from the SSL stream." ) ;
305
+ }
306
+
294
307
var responseBytes = new byte [ count ] ;
295
308
Buffer . BlockCopy ( buffer , 0 , responseBytes , 0 , count ) ;
296
309
request . Feed ( responseBytes ) ;
297
310
}
298
311
}
312
+ catch ( Exception ex ) when ( ex is IOException or SocketException )
313
+ {
314
+ if ( ! request . Fail ( ) )
315
+ {
316
+ throw ;
317
+ }
318
+ }
299
319
}
300
320
301
321
private async Task SendKmsRequestAsync ( KmsRequest request , CancellationToken cancellation )
@@ -307,16 +327,15 @@ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken can
307
327
var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
308
328
var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
309
329
using var sslStream = await sslStreamFactory . CreateStreamAsync ( endpoint , cancellation ) . ConfigureAwait ( false ) ;
310
- using var binary = request . GetMessage ( ) ;
311
-
312
- var requestBytes = binary . ToArray ( ) ;
313
330
314
331
var sleepMs = request . Sleep ;
315
332
if ( sleepMs > 0 )
316
333
{
317
334
await Task . Delay ( sleepMs , cancellation ) . ConfigureAwait ( false ) ;
318
335
}
319
336
337
+ using var binary = request . GetMessage ( ) ;
338
+ var requestBytes = binary . ToArray ( ) ;
320
339
await sslStream . WriteAsync ( requestBytes , 0 , requestBytes . Length ) . ConfigureAwait ( false ) ;
321
340
322
341
while ( request . BytesNeeded > 0 )
0 commit comments