@@ -190,7 +190,7 @@ impl<W: Write> BufWriter<W> {
190
190
/// data. Writes as much as possible without exceeding capacity. Returns
191
191
/// the number of bytes written.
192
192
pub ( super ) fn write_to_buf ( & mut self , buf : & [ u8 ] ) -> usize {
193
- let available = self . buf . capacity ( ) - self . buf . len ( ) ;
193
+ let available = self . spare_capacity ( ) ;
194
194
let amt_to_buffer = available. min ( buf. len ( ) ) ;
195
195
196
196
// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
@@ -353,7 +353,7 @@ impl<W: Write> BufWriter<W> {
353
353
// or their write patterns are somewhat pathological.
354
354
#[ inline( never) ]
355
355
fn write_cold ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
356
- if self . buf . len ( ) + buf . len ( ) > self . buf . capacity ( ) {
356
+ if buf. len ( ) > self . spare_capacity ( ) {
357
357
self . flush_buf ( ) ?;
358
358
}
359
359
@@ -371,7 +371,7 @@ impl<W: Write> BufWriter<W> {
371
371
372
372
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
373
373
// we entered this else block because `buf.len() < self.buf.capacity()`.
374
- // Therefore, `self. buf.len() + buf.len () <= self.buf.capacity ()`.
374
+ // Therefore, `buf.len() <= self. buf.capacity () - self.buf.len ()`.
375
375
unsafe {
376
376
self . write_to_buffer_unchecked ( buf) ;
377
377
}
@@ -391,7 +391,8 @@ impl<W: Write> BufWriter<W> {
391
391
// by calling `self.get_mut().write_all()` directly, which avoids
392
392
// round trips through the buffer in the event of a series of partial
393
393
// writes in some circumstances.
394
- if self . buf . len ( ) + buf. len ( ) > self . buf . capacity ( ) {
394
+
395
+ if buf. len ( ) > self . spare_capacity ( ) {
395
396
self . flush_buf ( ) ?;
396
397
}
397
398
@@ -409,7 +410,7 @@ impl<W: Write> BufWriter<W> {
409
410
410
411
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
411
412
// we entered this else block because `buf.len() < self.buf.capacity()`.
412
- // Therefore, `self. buf.len() + buf.len () <= self.buf.capacity ()`.
413
+ // Therefore, `buf.len() <= self. buf.capacity () - self.buf.len ()`.
413
414
unsafe {
414
415
self . write_to_buffer_unchecked ( buf) ;
415
416
}
@@ -418,18 +419,23 @@ impl<W: Write> BufWriter<W> {
418
419
}
419
420
}
420
421
421
- // SAFETY: Requires `self. buf.len() + buf.len () <= self.buf.capacity ()`,
422
+ // SAFETY: Requires `buf.len() <= self. buf.capacity () - self.buf.len ()`,
422
423
// i.e., that input buffer length is less than or equal to spare capacity.
423
424
#[ inline( always) ]
424
425
unsafe fn write_to_buffer_unchecked ( & mut self , buf : & [ u8 ] ) {
425
- debug_assert ! ( self . buf. len( ) + buf . len ( ) <= self . buf . capacity ( ) ) ;
426
+ debug_assert ! ( buf. len( ) <= self . spare_capacity ( ) ) ;
426
427
let old_len = self . buf . len ( ) ;
427
428
let buf_len = buf. len ( ) ;
428
429
let src = buf. as_ptr ( ) ;
429
430
let dst = self . buf . as_mut_ptr ( ) . add ( old_len) ;
430
431
ptr:: copy_nonoverlapping ( src, dst, buf_len) ;
431
432
self . buf . set_len ( old_len + buf_len) ;
432
433
}
434
+
435
+ #[ inline]
436
+ fn spare_capacity ( & self ) -> usize {
437
+ self . buf . capacity ( ) - self . buf . len ( )
438
+ }
433
439
}
434
440
435
441
#[ unstable( feature = "bufwriter_into_raw_parts" , issue = "80690" ) ]
@@ -505,7 +511,7 @@ impl<W: Write> Write for BufWriter<W> {
505
511
fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
506
512
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
507
513
// See `write_cold` for details.
508
- if self . buf . len ( ) + buf . len ( ) < self . buf . capacity ( ) {
514
+ if buf. len ( ) < self . spare_capacity ( ) {
509
515
// SAFETY: safe by above conditional.
510
516
unsafe {
511
517
self . write_to_buffer_unchecked ( buf) ;
@@ -521,7 +527,7 @@ impl<W: Write> Write for BufWriter<W> {
521
527
fn write_all ( & mut self , buf : & [ u8 ] ) -> io:: Result < ( ) > {
522
528
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
523
529
// See `write_all_cold` for details.
524
- if self . buf . len ( ) + buf . len ( ) < self . buf . capacity ( ) {
530
+ if buf. len ( ) < self . spare_capacity ( ) {
525
531
// SAFETY: safe by above conditional.
526
532
unsafe {
527
533
self . write_to_buffer_unchecked ( buf) ;
@@ -537,31 +543,46 @@ impl<W: Write> Write for BufWriter<W> {
537
543
// FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
538
544
// to `write` and `write_all`. The performance benefits can be significant. See #79930.
539
545
if self . get_ref ( ) . is_write_vectored ( ) {
540
- let total_len = bufs. iter ( ) . map ( |b| b. len ( ) ) . sum :: < usize > ( ) ;
541
- if self . buf . len ( ) + total_len > self . buf . capacity ( ) {
546
+ // We have to handle the possibility that the total length of the buffers overflows
547
+ // `usize` (even though this can only happen if multiple `IoSlice`s reference the
548
+ // same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
549
+ // computation overflows, then surely the input cannot fit in our buffer, so we forward
550
+ // to the inner writer's `write_vectored` method to let it handle it appropriately.
551
+ let saturated_total_len =
552
+ bufs. iter ( ) . fold ( 0usize , |acc, b| acc. saturating_add ( b. len ( ) ) ) ;
553
+
554
+ if saturated_total_len > self . spare_capacity ( ) {
555
+ // Flush if the total length of the input exceeds our buffer's spare capacity.
556
+ // If we would have overflowed, this condition also holds, and we need to flush.
542
557
self . flush_buf ( ) ?;
543
558
}
544
- if total_len >= self . buf . capacity ( ) {
559
+
560
+ if saturated_total_len >= self . buf . capacity ( ) {
561
+ // Forward to our inner writer if the total length of the input is greater than or
562
+ // equal to our buffer capacity. If we would have overflowed, this condition also
563
+ // holds, and we punt to the inner writer.
545
564
self . panicked = true ;
546
565
let r = self . get_mut ( ) . write_vectored ( bufs) ;
547
566
self . panicked = false ;
548
567
r
549
568
} else {
569
+ // `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
570
+
550
571
// SAFETY: We checked whether or not the spare capacity was large enough above. If
551
572
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
552
573
// room for any input <= the buffer size, which includes this input.
553
574
unsafe {
554
575
bufs. iter ( ) . for_each ( |b| self . write_to_buffer_unchecked ( b) ) ;
555
576
} ;
556
577
557
- Ok ( total_len )
578
+ Ok ( saturated_total_len )
558
579
}
559
580
} else {
560
581
let mut iter = bufs. iter ( ) ;
561
582
let mut total_written = if let Some ( buf) = iter. by_ref ( ) . find ( |& buf| !buf. is_empty ( ) ) {
562
583
// This is the first non-empty slice to write, so if it does
563
584
// not fit in the buffer, we still get to flush and proceed.
564
- if self . buf . len ( ) + buf . len ( ) > self . buf . capacity ( ) {
585
+ if buf. len ( ) > self . spare_capacity ( ) {
565
586
self . flush_buf ( ) ?;
566
587
}
567
588
if buf. len ( ) >= self . buf . capacity ( ) {
@@ -586,12 +607,15 @@ impl<W: Write> Write for BufWriter<W> {
586
607
} ;
587
608
debug_assert ! ( total_written != 0 ) ;
588
609
for buf in iter {
589
- if self . buf . len ( ) + buf . len ( ) <= self . buf . capacity ( ) {
610
+ if buf. len ( ) <= self . spare_capacity ( ) {
590
611
// SAFETY: safe by above conditional.
591
612
unsafe {
592
613
self . write_to_buffer_unchecked ( buf) ;
593
614
}
594
615
616
+ // This cannot overflow `usize`. If we are here, we've written all of the bytes
617
+ // so far to our buffer, and we've ensured that we never exceed the buffer's
618
+ // capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`.
595
619
total_written += buf. len ( ) ;
596
620
} else {
597
621
break ;
0 commit comments