Commit 42982da

Improve collection during repr and repr_html (#1036)
* Improve table readout of a dataframe in Jupyter notebooks by making the table scrollable and displaying the first record batch up to 2 MB
* Add an option to display only a portion of a cell's data, with a button the user can click to toggle showing more or less
* We cannot expect that the first non-empty batch is sufficient for our 2 MB limit, so switch over to collecting until we run out or use up the size
* Update the Python unit test to allow the additional formatting data to exist and only check the table contents
* Combine collection for repr and repr_html into one function
* Small clippy suggestion
* Collect was occurring twice on repr
* Switch to execute_stream_partitioned
1 parent b8dd97b commit 42982da
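
For context on what the change produces, here is a minimal usage sketch, assuming the `datafusion` Python package and its `SessionContext.from_pydict` constructor (the column values match the updated unit test below):

```python
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]})

# In a Jupyter notebook, displaying `df` calls _repr_html_(), which now renders
# a scrollable table, collecting record batches only until a 2 MB budget (or the
# row limit) is used up, and appends a truncation notice if more data remains.
html = df._repr_html_()

# Plain repr() now collects at most 10 rows through the same helper, instead of
# collecting twice as it did before this commit.
print(repr(df))
```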

File tree

- python/tests/test_dataframe.py
- src/dataframe.rs
- src/utils.rs

3 files changed: +225 −40 lines changed

python/tests/test_dataframe.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
+import re
 from typing import Any

 import pyarrow as pa
@@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame:
 def test_dataframe_repr_html(df) -> None:
     output = df._repr_html_()

-    ref_html = """<table border='1'>
-<tr><th>a</td><th>b</td><th>c</td></tr>
-<tr><td>1</td><td>4</td><td>8</td></tr>
-<tr><td>2</td><td>5</td><td>5</td></tr>
-<tr><td>3</td><td>6</td><td>8</td></tr>
-</table>
-"""
+    # Since we've added a fair bit of processing to the html output, let's just verify
+    # the values we are expecting in the table exist. Use regex and ignore everything
+    # between the <th></th> and <td></td>. We also don't want the closing > on the
+    # td and th segments because that is where the formatting data is written.

-    # Ignore whitespace just to make this test look cleaner
-    assert output.replace(" ", "") == ref_html.replace(" ", "")
+    headers = ["a", "b", "c"]
+    headers = [f"<th(.*?)>{v}</th>" for v in headers]
+    header_pattern = "(.*?)".join(headers)
+    assert len(re.findall(header_pattern, output, re.DOTALL)) == 1
+
+    body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]]
+    body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner]
+    body_pattern = "(.*?)".join(body_lines)
+    assert len(re.findall(body_pattern, output, re.DOTALL)) == 1
```
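
The reason each pattern should match exactly once: the lazy `(.*?)` groups absorb the style attributes inside the tags and whatever markup sits between consecutive cells, while `re.DOTALL` lets them span newlines. A minimal, self-contained sketch of the same technique (the HTML fragment below is hypothetical, standing in for the generated output):

```python
import re

# Hypothetical fragment of the generated HTML: cells now carry style attributes,
# so the test matches "<th ...>value</th>" with lazy wildcards in between.
output = (
    "<table><thead><tr>"
    "<th style='background-color: #f2f2f2;'>a</th>"
    "<th style='background-color: #f2f2f2;'>b</th>"
    "</tr></thead><tbody>"
    "<tr><td style='padding: 8px;'>1</td><td style='padding: 8px;'>4</td></tr>"
    "</tbody></table>"
)

headers = [f"<th(.*?)>{v}</th>" for v in ["a", "b"]]
header_pattern = "(.*?)".join(headers)

# re.DOTALL lets "." span newlines in the real multi-line output.
assert len(re.findall(header_pattern, output, re.DOTALL)) == 1
```

Counting matches with `re.findall(...) == 1` rather than a bare `re.search` additionally asserts that the expected sequence of values appears exactly once.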

src/dataframe.rs

Lines changed: 210 additions & 30 deletions
```diff
@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
 use datafusion::config::{CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;
+use datafusion::error::DataFusionError;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::prelude::*;
+use futures::{StreamExt, TryStreamExt};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
@@ -70,6 +72,9 @@ impl PyTableProvider {
         PyTable::new(table_provider)
     }
 }
+const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
+const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
+const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;

 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -111,56 +116,151 @@ impl PyDataFrame {
     }

     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
-        let batches_as_string = pretty::pretty_format_batches(&batches);
-        match batches_as_string {
-            Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
-            Err(err) => Ok(format!("Error: {:?}", err.to_string())),
+        let (batches, has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
+        )?;
+        if batches.is_empty() {
+            // This should not be reached, but do it for safety since we index into the vector below
+            return Ok("No data to display".to_string());
         }
-    }

-    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
-        let mut html_str = "<table border='1'>\n".to_string();
+        let batches_as_displ =
+            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
+
+        let additional_str = match has_more {
+            true => "\nData truncated.",
+            false => "",
+        };

-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
+        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
+    }

+    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
+        let (batches, has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(
+                self.df.as_ref().clone(),
+                MIN_TABLE_ROWS_TO_DISPLAY,
+                usize::MAX,
+            ),
+        )?;
         if batches.is_empty() {
-            html_str.push_str("</table>\n");
-            return Ok(html_str);
+            // This should not be reached, but do it for safety since we index into the vector below
+            return Ok("No data to display".to_string());
         }

+        let table_uuid = uuid::Uuid::new_v4().to_string();
+
+        let mut html_str = "
+        <style>
+            .expandable-container {
+                display: inline-block;
+                max-width: 200px;
+            }
+            .expandable {
+                white-space: nowrap;
+                overflow: hidden;
+                text-overflow: ellipsis;
+                display: block;
+            }
+            .full-text {
+                display: none;
+                white-space: normal;
+            }
+            .expand-btn {
+                cursor: pointer;
+                color: blue;
+                text-decoration: underline;
+                border: none;
+                background: none;
+                font-size: inherit;
+                display: block;
+                margin-top: 5px;
+            }
+        </style>
+
+        <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
+        <table style=\"border-collapse: collapse; min-width: 100%\">
+            <thead>\n".to_string();
+
         let schema = batches[0].schema();

         let mut header = Vec::new();
         for field in schema.fields() {
-            header.push(format!("<th>{}</td>", field.name()));
+            header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
         }
         let header_str = header.join("");
-        html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
-
-        for batch in batches {
-            let formatters = batch
-                .columns()
-                .iter()
-                .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
-                .map(|c| {
-                    c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
-                })
-                .collect::<Result<Vec<_>, _>>()?;
-
-            for row in 0..batch.num_rows() {
+        html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));
+
+        let batch_formatters = batches
+            .iter()
+            .map(|batch| {
+                batch
+                    .columns()
+                    .iter()
+                    .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
+                    .map(|c| {
+                        c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
+                    })
+                    .collect::<Result<Vec<_>, _>>()
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let rows_per_batch = batches.iter().map(|batch| batch.num_rows());
+
+        // We need to build up row by row for html
+        let mut table_row = 0;
+        for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) {
+            for batch_row in 0..num_rows_in_batch {
+                table_row += 1;
                 let mut cells = Vec::new();
-                for formatter in &formatters {
-                    cells.push(format!("<td>{}</td>", formatter.value(row)));
+                for (col, formatter) in batch_formatter.iter().enumerate() {
+                    let cell_data = formatter.value(batch_row).to_string();
+                    // From testing, primitive data types do not typically get larger than 21 characters
+                    if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
+                        let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
+                        cells.push(format!("
+                            <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
+                                <div class=\"expandable-container\">
+                                    <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span>
+                                    <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span>
+                                    <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button>
+                                </div>
+                            </td>"));
+                    } else {
+                        cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row)));
+                    }
                 }
                 let row_str = cells.join("");
                 html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
             }
         }
+        html_str.push_str("</tbody></table></div>\n");
+
+        html_str.push_str("
+            <script>
+            function toggleDataFrameCellText(table_uuid, row, col) {
+                var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col);
+                var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col);
+                var button = event.target;
+
+                if (fullText.style.display === \"none\") {
+                    shortText.style.display = \"none\";
+                    fullText.style.display = \"inline\";
+                    button.textContent = \"(less)\";
+                } else {
+                    shortText.style.display = \"inline\";
+                    fullText.style.display = \"none\";
+                    button.textContent = \"...\";
+                }
+            }
+            </script>
+        ");

-        html_str.push_str("</table>\n");
+        if has_more {
+            html_str.push_str("Data truncated due to size.");
+        }

         Ok(html_str)
     }
@@ -771,3 +871,83 @@ fn record_batch_into_schema(

     RecordBatch::try_new(schema, data_arrays)
 }
+
+/// This is a helper function to return the record batches from executing a DataFrame.
+/// It additionally returns a bool, which indicates if there are more record batches available.
+/// We do this so we can determine if we should indicate to the user that the data has been
+/// truncated. This collects until we have achieved both of these two conditions
+///
+/// - We have collected our minimum number of rows
+/// - We have reached our limit, either data size or maximum number of rows
+///
+/// Otherwise it will return when the stream is exhausted. If you want a specific number of
+/// rows, set min_rows == max_rows.
+async fn collect_record_batches_to_display(
+    df: DataFrame,
+    min_rows: usize,
+    max_rows: usize,
+) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
+    let partitioned_stream = df.execute_stream_partitioned().await?;
+    let mut stream = futures::stream::iter(partitioned_stream).flatten();
+    let mut size_estimate_so_far = 0;
+    let mut rows_so_far = 0;
+    let mut record_batches = Vec::default();
+    let mut has_more = false;
+
+    while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
+        || rows_so_far < min_rows
+    {
+        let mut rb = match stream.next().await {
+            None => {
+                break;
+            }
+            Some(Ok(r)) => r,
+            Some(Err(e)) => return Err(e),
+        };
+
+        let mut rows_in_rb = rb.num_rows();
+        if rows_in_rb > 0 {
+            size_estimate_so_far += rb.get_array_memory_size();
+
+            if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
+                let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
+                let total_rows = rows_in_rb + rows_so_far;
+
+                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
+                if reduced_row_num < min_rows {
+                    reduced_row_num = min_rows.min(total_rows);
+                }
+
+                let limited_rows_this_rb = reduced_row_num - rows_so_far;
+                if limited_rows_this_rb < rows_in_rb {
+                    rows_in_rb = limited_rows_this_rb;
+                    rb = rb.slice(0, limited_rows_this_rb);
+                    has_more = true;
+                }
+            }
+
+            if rows_in_rb + rows_so_far > max_rows {
+                rb = rb.slice(0, max_rows - rows_so_far);
+                has_more = true;
+            }
+
+            rows_so_far += rb.num_rows();
+            record_batches.push(rb);
+        }
+    }
+
+    if record_batches.is_empty() {
+        return Ok((Vec::default(), false));
+    }
+
+    if !has_more {
+        // Data was not already truncated, so check to see if more record batches remain
+        has_more = match stream.try_next().await {
+            Ok(None) => false, // reached end
+            Ok(Some(_)) => true,
+            Err(_) => false, // Stream disconnected
+        };
+    }
+
+    Ok((record_batches, has_more))
+}
```
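
Two pieces of logic above are worth illustrating outside the Rust string templates. First, the cell-minimize branch: any cell longer than `MAX_LENGTH_CELL_WITHOUT_MINIMIZE` (25) characters is emitted as a truncated span plus a hidden full-text span and a toggle button. A Python sketch of just that branching, with the inline styles omitted for brevity (the ids and class names mirror the Rust template):

```python
MAX_LENGTH_CELL_WITHOUT_MINIMIZE = 25  # mirrors the Rust constant

def render_cell(cell_data: str, table_uuid: str, row: int, col: int) -> str:
    # Short cells are emitted directly; long cells get a collapsed/expanded pair
    # whose visibility the toggleDataFrameCellText() script flips.
    if len(cell_data) <= MAX_LENGTH_CELL_WITHOUT_MINIMIZE:
        return f"<td>{cell_data}</td>"
    short = cell_data[:MAX_LENGTH_CELL_WITHOUT_MINIMIZE]
    return (
        "<td><div class='expandable-container'>"
        f"<span class='expandable' id='{table_uuid}-min-text-{row}-{col}'>{short}</span>"
        f"<span class='full-text' id='{table_uuid}-full-text-{row}-{col}'>{cell_data}</span>"
        "<button class='expand-btn' "
        f"onclick=\"toggleDataFrameCellText('{table_uuid}',{row},{col})\">...</button>"
        "</div></td>"
    )
```

Second, the budgeted collection in `collect_record_batches_to_display`: keep pulling batches until the 2 MB size estimate or `max_rows` is hit (while always reaching `min_rows` when data is available), and when a batch overshoots the byte budget, scale the total row count down proportionally and slice. A minimal Python analog, assuming pyarrow's `RecordBatch.nbytes` and `slice` as stand-ins for the Rust `get_array_memory_size` and `RecordBatch::slice`:

```python
import pyarrow as pa

MAX_TABLE_BYTES_TO_DISPLAY = 2 * 1024 * 1024  # 2 MB, mirrors the Rust constant

def collect_batches_to_display(batches, min_rows, max_rows):
    """Collect until the byte budget or max_rows is hit, but always reach
    min_rows if data is available. Returns (batches, has_more)."""
    collected, size_so_far, rows_so_far, has_more = [], 0, 0, False
    it = iter(batches)
    while (size_so_far < MAX_TABLE_BYTES_TO_DISPLAY and rows_so_far < max_rows) \
            or rows_so_far < min_rows:
        rb = next(it, None)
        if rb is None:
            break
        rows_in_rb = rb.num_rows
        if rows_in_rb == 0:
            continue
        size_so_far += rb.nbytes
        if size_so_far > MAX_TABLE_BYTES_TO_DISPLAY:
            # Scale the row count down proportionally to the byte overshoot,
            # but never below min_rows (capped at the rows actually seen).
            ratio = MAX_TABLE_BYTES_TO_DISPLAY / size_so_far
            total_rows = rows_in_rb + rows_so_far
            reduced = max(round(total_rows * ratio), min(min_rows, total_rows))
            keep = reduced - rows_so_far
            if keep < rows_in_rb:
                rb, has_more = rb.slice(0, keep), True
        if rb.num_rows + rows_so_far > max_rows:
            rb, has_more = rb.slice(0, max_rows - rows_so_far), True
        rows_so_far += rb.num_rows
        collected.append(rb)
    if not has_more:
        has_more = next(it, None) is not None  # anything left unread?
    return collected, has_more

# Example: 10k rows of int64 (~80 KB) is under the 2 MB budget,
# so truncation here comes from max_rows, not from size.
batch = pa.RecordBatch.from_pydict({"a": list(range(10_000))})
out, more = collect_batches_to_display([batch], min_rows=20, max_rows=50)
print(sum(b.num_rows for b in out), more)  # -> 50 True
```

The final `next(it, None)` probe mirrors the `stream.try_next()` check in the Rust code: even when nothing was sliced, peeking one batch ahead is what lets the repr distinguish "all data shown" from "data truncated".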

src/utils.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
 #[inline]
 pub(crate) fn get_global_ctx() -> &'static SessionContext {
     static CTX: OnceLock<SessionContext> = OnceLock::new();
-    CTX.get_or_init(|| SessionContext::new())
+    CTX.get_or_init(SessionContext::new)
 }

 /// Utility to collect rust futures with GIL released
```

This is the "small clippy suggestion" from the commit message: clippy's `redundant_closure` lint prefers passing the function itself to `get_or_init` rather than wrapping it in a closure.
