Skip to content

Commit 2acf800

Browse files
committed
Combining collection for repr and repr_html into one function
1 parent 216ee03 commit 2acf800

File tree

1 file changed

+80
-27
lines changed

1 file changed

+80
-27
lines changed

src/dataframe.rs

Lines changed: 80 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -116,21 +116,37 @@ impl PyDataFrame {
116116
}
117117

118118
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
119+
let (batches, has_more) = wait_for_future(
120+
py,
121+
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
122+
)?;
123+
if batches.is_empty() {
124+
// This should not be reached, but do it for safety since we index into the vector below
125+
return Ok("No data to display".to_string());
126+
}
127+
119128
let df = self.df.as_ref().clone().limit(0, Some(10))?;
120129
let batches = wait_for_future(py, df.collect())?;
121-
let batches_as_string = pretty::pretty_format_batches(&batches);
122-
match batches_as_string {
123-
Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
124-
Err(err) => Ok(format!("Error: {:?}", err.to_string())),
125-
}
130+
let batches_as_displ =
131+
pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
132+
133+
let additional_str = match has_more {
134+
true => "\nData truncated.",
135+
false => "",
136+
};
137+
138+
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
126139
}
127140

128141
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
129-
let (batches, mut has_more) =
130-
wait_for_future(py, get_first_few_record_batches(self.df.as_ref().clone()))?;
131-
let Some(batches) = batches else {
132-
return Ok("No data to display".to_string());
133-
};
142+
let (batches, mut has_more) = wait_for_future(
143+
py,
144+
collect_record_batches_to_display(
145+
self.df.as_ref().clone(),
146+
MIN_TABLE_ROWS_TO_DISPLAY,
147+
usize::MAX,
148+
),
149+
)?;
134150
if batches.is_empty() {
135151
// This should not be reached, but do it for safety since we index into the vector below
136152
return Ok("No data to display".to_string());
@@ -200,10 +216,6 @@ impl PyDataFrame {
200216
let rows_per_batch = batches.iter().map(|batch| batch.num_rows());
201217
let total_rows = rows_per_batch.clone().sum();
202218

203-
// let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| {
204-
// (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows())
205-
// });
206-
207219
let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY {
208220
true => {
209221
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32;
@@ -887,37 +899,78 @@ fn record_batch_into_schema(
887899
/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
888900
/// It additionally returns a bool, which indicates if there are more record batches available.
889901
/// We do this so we can determine if we should indicate to the user that the data has been
890-
/// truncated.
891-
async fn get_first_few_record_batches(
902+
/// truncated. This collects until we have achieved both of these conditions
903+
///
904+
/// - We have collected our minimum number of rows
905+
/// - We have reached our limit, either data size or maximum number of rows
906+
///
907+
/// Otherwise it will return when the stream has exhausted. If you want a specific number of
908+
/// rows, set min_rows == max_rows.
909+
async fn collect_record_batches_to_display(
892910
df: DataFrame,
893-
) -> Result<(Option<Vec<RecordBatch>>, bool), DataFusionError> {
911+
min_rows: usize,
912+
max_rows: usize,
913+
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
894914
let mut stream = df.execute_stream().await?;
895915
let mut size_estimate_so_far = 0;
916+
let mut rows_so_far = 0;
896917
let mut record_batches = Vec::default();
897-
while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY {
898-
let rb = match stream.next().await {
918+
let mut has_more = false;
919+
920+
while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
921+
|| rows_so_far < min_rows
922+
{
923+
let mut rb = match stream.next().await {
899924
None => {
900925
break;
901926
}
902927
Some(Ok(r)) => r,
903928
Some(Err(e)) => return Err(e),
904929
};
905930

906-
if rb.num_rows() > 0 {
931+
let mut rows_in_rb = rb.num_rows();
932+
if rows_in_rb > 0 {
907933
size_estimate_so_far += rb.get_array_memory_size();
934+
935+
if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
936+
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
937+
let total_rows = rows_in_rb + rows_so_far;
938+
939+
let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
940+
if reduced_row_num < min_rows {
941+
reduced_row_num = min_rows.min(total_rows);
942+
}
943+
944+
let limited_rows_this_rb = reduced_row_num - rows_so_far;
945+
if limited_rows_this_rb < rows_in_rb {
946+
rows_in_rb = limited_rows_this_rb;
947+
rb = rb.slice(0, limited_rows_this_rb);
948+
has_more = true;
949+
}
950+
}
951+
952+
if rows_in_rb + rows_so_far > max_rows {
953+
rb = rb.slice(0, max_rows - rows_so_far);
954+
has_more = true;
955+
}
956+
957+
rows_so_far += rb.num_rows();
908958
record_batches.push(rb);
909959
}
910960
}
911961

912962
if record_batches.is_empty() {
913-
return Ok((None, false));
963+
return Ok((Vec::default(), false));
914964
}
915965

916-
let has_more = match stream.try_next().await {
917-
Ok(None) => false, // reached end
918-
Ok(Some(_)) => true,
919-
Err(_) => false, // Stream disconnected
920-
};
966+
if !has_more {
967+
// Data was not already truncated, so check to see if more record batches remain
968+
has_more = match stream.try_next().await {
969+
Ok(None) => false, // reached end
970+
Ok(Some(_)) => true,
971+
Err(_) => false, // Stream disconnected
972+
};
973+
}
921974

922-
Ok((Some(record_batches), has_more))
975+
Ok((record_batches, has_more))
923976
}

0 commit comments

Comments
 (0)