Skip to content

Commit 72aecbf

Browse files
committed
BufWriter: handle possibility of overflow
1 parent 5fd9372 commit 72aecbf

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

library/std/src/io/buffered/bufwriter.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ impl<W: Write> BufWriter<W> {
190190
/// data. Writes as much as possible without exceeding capacity. Returns
191191
/// the number of bytes written.
192192
pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
193-
let available = self.buf.capacity() - self.buf.len();
193+
let available = self.spare_capacity();
194194
let amt_to_buffer = available.min(buf.len());
195195

196196
// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
@@ -353,7 +353,7 @@ impl<W: Write> BufWriter<W> {
353353
// or their write patterns are somewhat pathological.
354354
#[inline(never)]
355355
fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
356-
if self.buf.len() + buf.len() > self.buf.capacity() {
356+
if buf.len() > self.spare_capacity() {
357357
self.flush_buf()?;
358358
}
359359

@@ -371,7 +371,7 @@ impl<W: Write> BufWriter<W> {
371371

372372
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
373373
// we entered this else block because `buf.len() < self.buf.capacity()`.
374-
// Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
374+
// Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
375375
unsafe {
376376
self.write_to_buffer_unchecked(buf);
377377
}
@@ -391,7 +391,8 @@ impl<W: Write> BufWriter<W> {
391391
// by calling `self.get_mut().write_all()` directly, which avoids
392392
// round trips through the buffer in the event of a series of partial
393393
// writes in some circumstances.
394-
if self.buf.len() + buf.len() > self.buf.capacity() {
394+
395+
if buf.len() > self.spare_capacity() {
395396
self.flush_buf()?;
396397
}
397398

@@ -409,7 +410,7 @@ impl<W: Write> BufWriter<W> {
409410

410411
// SAFETY: We just called `self.flush_buf()`, so `self.buf.len()` is 0, and
411412
// we entered this else block because `buf.len() < self.buf.capacity()`.
412-
// Therefore, `self.buf.len() + buf.len() <= self.buf.capacity()`.
413+
// Therefore, `buf.len() <= self.buf.capacity() - self.buf.len()`.
413414
unsafe {
414415
self.write_to_buffer_unchecked(buf);
415416
}
@@ -418,18 +419,23 @@ impl<W: Write> BufWriter<W> {
418419
}
419420
}
420421

421-
// SAFETY: Requires `self.buf.len() + buf.len() <= self.buf.capacity()`,
422+
// SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`,
422423
// i.e., that input buffer length is less than or equal to spare capacity.
423424
#[inline(always)]
424425
unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
425-
debug_assert!(self.buf.len() + buf.len() <= self.buf.capacity());
426+
debug_assert!(buf.len() <= self.spare_capacity());
426427
let old_len = self.buf.len();
427428
let buf_len = buf.len();
428429
let src = buf.as_ptr();
429430
let dst = self.buf.as_mut_ptr().add(old_len);
430431
ptr::copy_nonoverlapping(src, dst, buf_len);
431432
self.buf.set_len(old_len + buf_len);
432433
}
434+
435+
#[inline]
436+
fn spare_capacity(&self) -> usize {
437+
self.buf.capacity() - self.buf.len()
438+
}
433439
}
434440

435441
#[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")]
@@ -505,7 +511,7 @@ impl<W: Write> Write for BufWriter<W> {
505511
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
506512
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
507513
// See `write_cold` for details.
508-
if self.buf.len() + buf.len() < self.buf.capacity() {
514+
if buf.len() < self.spare_capacity() {
509515
// SAFETY: safe by above conditional.
510516
unsafe {
511517
self.write_to_buffer_unchecked(buf);
@@ -521,7 +527,7 @@ impl<W: Write> Write for BufWriter<W> {
521527
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
522528
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
523529
// See `write_all_cold` for details.
524-
if self.buf.len() + buf.len() < self.buf.capacity() {
530+
if buf.len() < self.spare_capacity() {
525531
// SAFETY: safe by above conditional.
526532
unsafe {
527533
self.write_to_buffer_unchecked(buf);
@@ -537,31 +543,46 @@ impl<W: Write> Write for BufWriter<W> {
537543
// FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
538544
// to `write` and `write_all`. The performance benefits can be significant. See #79930.
539545
if self.get_ref().is_write_vectored() {
540-
let total_len = bufs.iter().map(|b| b.len()).sum::<usize>();
541-
if self.buf.len() + total_len > self.buf.capacity() {
546+
// We have to handle the possibility that the total length of the buffers overflows
547+
// `usize` (even though this can only happen if multiple `IoSlice`s reference the
548+
// same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
549+
// computation overflows, then surely the input cannot fit in our buffer, so we forward
550+
// to the inner writer's `write_vectored` method to let it handle it appropriately.
551+
let saturated_total_len =
552+
bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
553+
554+
if saturated_total_len > self.spare_capacity() {
555+
// Flush if the total length of the input exceeds our buffer's spare capacity.
556+
// If we would have overflowed, this condition also holds, and we need to flush.
542557
self.flush_buf()?;
543558
}
544-
if total_len >= self.buf.capacity() {
559+
560+
if saturated_total_len >= self.buf.capacity() {
561+
// Forward to our inner writer if the total length of the input is greater than or
562+
// equal to our buffer capacity. If we would have overflowed, this condition also
563+
// holds, and we punt to the inner writer.
545564
self.panicked = true;
546565
let r = self.get_mut().write_vectored(bufs);
547566
self.panicked = false;
548567
r
549568
} else {
569+
// `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
570+
550571
// SAFETY: We checked whether or not the spare capacity was large enough above. If
551572
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
552573
// room for any input <= the buffer size, which includes this input.
553574
unsafe {
554575
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
555576
};
556577

557-
Ok(total_len)
578+
Ok(saturated_total_len)
558579
}
559580
} else {
560581
let mut iter = bufs.iter();
561582
let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
562583
// This is the first non-empty slice to write, so if it does
563584
// not fit in the buffer, we still get to flush and proceed.
564-
if self.buf.len() + buf.len() > self.buf.capacity() {
585+
if buf.len() > self.spare_capacity() {
565586
self.flush_buf()?;
566587
}
567588
if buf.len() >= self.buf.capacity() {
@@ -586,12 +607,15 @@ impl<W: Write> Write for BufWriter<W> {
586607
};
587608
debug_assert!(total_written != 0);
588609
for buf in iter {
589-
if self.buf.len() + buf.len() <= self.buf.capacity() {
610+
if buf.len() <= self.spare_capacity() {
590611
// SAFETY: safe by above conditional.
591612
unsafe {
592613
self.write_to_buffer_unchecked(buf);
593614
}
594615

616+
// This cannot overflow `usize`. If we are here, we've written all of the bytes
617+
// so far to our buffer, and we've ensured that we never exceed the buffer's
618+
// capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`.
595619
total_written += buf.len();
596620
} else {
597621
break;

0 commit comments

Comments
 (0)