
Commit cd8592d

feat: Add generic interface for custom data sinks (#4244)
## Changes Made

1. Add a generic `DataSink` interface. Users can implement it to build custom data sinks with optional `.start()`, `.write()`, and `.finalize()` methods.
2. Add `DataFrame.write_sink()`, which takes a custom data sink and writes the dataframe through it.
3. Add `LanceDataSink` as an example implementation of `DataSink` for Lance datasets.
4. Modify `DataFrame.write_lance()` to use this new `LanceDataSink`.

## Checklist

- [x] Documented in API Docs (if applicable)
- [x] Documented in User Guide (if applicable)
- [x] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [x] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent ed37efe commit cd8592d

28 files changed: +630 −86 lines changed
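For readers skimming the diff, here is a rough sketch of what a user-defined sink could look like against the interface this commit adds. The method names (`name`, `schema`, `start`, `write`, `finalize`) and the `DataSink`/`WriteOutput`/`MicroPartition` types are taken from the diff below, and `Schema._from_field_name_and_types` mirrors what `LanceDataSink` does; the row-counting sink itself is a made-up illustration, not part of the commit.

```python
from typing import Iterator, List

import daft
from daft.datatype import DataType
from daft.io import DataSink, WriteOutput
from daft.recordbatch import MicroPartition
from daft.schema import Schema


class RowCountSink(DataSink[int]):
    """Toy sink that counts rows per input micropartition (illustrative only)."""

    def name(self) -> str:
        return "row_count_sink"

    def schema(self) -> Schema:
        # Schema of the result DataFrame returned by `finalize()`.
        return Schema._from_field_name_and_types([("rows_written", DataType.int64())])

    def start(self) -> None:
        # Optional: open connections, create target tables, etc.
        pass

    def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteOutput[int]]:
        # Runs per partition task; each yielded WriteOutput is collected and passed to finalize().
        for mp in micropartitions:
            table = mp.to_arrow()
            yield WriteOutput(output=table.num_rows, bytes_written=table.nbytes, rows_written=table.num_rows)

    def finalize(self, write_outputs: List[WriteOutput[int]]) -> MicroPartition:
        # Runs once after all writes; returns the data backing the result DataFrame.
        total = sum(out.output for out in write_outputs)
        return MicroPartition.from_pydict({"rows_written": [total]})


df = daft.from_pydict({"a": [1, 2, 3]})
df.write_sink(RowCountSink()).show()
```

The DataFrame returned by `write_sink` is built from the micropartition that `finalize()` returns, so `.show()` here would display a single `rows_written` row.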

daft/daft/__init__.pyi

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Iterator, Litera
 from daft.catalog import Catalog, Table
 from daft.dataframe.display import MermaidOptions
 from daft.execution import physical_plan
+from daft.io import DataSink
 from daft.io.scan import ScanOperator
 from daft.plan_scheduler.physical_plan_scheduler import PartitionT
 from daft.runners.partitioning import PartitionCacheEntry
@@ -1821,6 +1822,7 @@ class LogicalPlanBuilder:
         io_config: IOConfig | None = None,
         kwargs: dict[str, Any] | None = None,
     ) -> LogicalPlanBuilder: ...
+    def datasink_write(self, name: str, sink: DataSink) -> LogicalPlanBuilder: ...
     def schema(self) -> PySchema: ...
     def describe(self) -> LogicalPlanBuilder: ...
     def summarize(self) -> LogicalPlanBuilder: ...

daft/dataframe/dataframe.py

Lines changed: 34 additions & 70 deletions
@@ -54,7 +54,7 @@
     import ray
     import torch

-    from daft.io import DataCatalogTable
+    from daft.io import DataCatalogTable, DataSink
     from daft.unity_catalog import UnityCatalogTable

 if sys.version_info < (3, 10):
@@ -65,8 +65,8 @@
 from daft.schema import Schema

 UDFReturnType = TypeVar("UDFReturnType", covariant=True)
-
 T = TypeVar("T")
+R = TypeVar("R")
 P = ParamSpec("P")


@@ -1178,6 +1178,35 @@ def _create_metadata_param(metadata: Optional[Dict[str, str]]):

         return with_operations

+    @DataframePublicAPI
+    def write_sink(self, sink: "DataSink[T]") -> "DataFrame":
+        """Writes the DataFrame to the given DataSink.
+
+        Args:
+            sink: The DataSink to write to.
+
+        Returns:
+            DataFrame: A dataframe from the micropartition returned by the DataSink's `.finalize()` method.
+        """
+        sink.start()
+
+        builder = self._builder.write_datasink(sink.name(), sink)
+        write_df = DataFrame(builder)
+        write_df.collect()
+
+        results = write_df.to_pydict()
+        assert "write_results" in results
+        micropartition = sink.finalize(results["write_results"])
+        if micropartition.schema() != sink.schema():
+            raise ValueError(
+                f"Schema mismatch between the data sink's schema and the result's schema:\nSink schema:\n{sink.schema()}\nResult schema:\n{micropartition.schema()}"
+            )
+        # TODO(desmond): Connect the old and new logical plan builders so that a .explain() shows the
+        # plan from the source all the way to the sink to the sink's results. In theory we can do this
+        # for all other sinks too.
+        write_plan_builder = to_logical_plan_builder(micropartition)
+        return DataFrame(write_plan_builder)
+
     @DataframePublicAPI
     def write_lance(
         self,
@@ -1239,75 +1268,10 @@ def write_lance(
         <BLANKLINE>
         (Showing first 1 of 1 rows)
         """
-        from daft import from_pydict
-        from daft.io.object_store_options import io_config_to_storage_options
-
-        try:
-            import lance
-            import pyarrow as pa
+        from daft.dataframe.lance_data_sink import LanceDataSink

-        except ImportError:
-            raise ImportError("lance is not installed. Please install lance using `pip install daft[lance]`")
-
-        io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config
-
-        if isinstance(uri, (str, pathlib.Path)):
-            if isinstance(uri, str):
-                table_uri = uri
-            elif isinstance(uri, pathlib.Path):
-                table_uri = str(uri)
-        else:
-            table_uri = uri
-        pyarrow_schema = pa.schema((f.name, f.dtype.to_arrow_dtype()) for f in self.schema())
-
-        storage_options = io_config_to_storage_options(io_config, table_uri)
-
-        try:
-            table = lance.dataset(table_uri, storage_options=storage_options)
-
-        except ValueError:
-            table = None
-
-        version = 0
-        if table:
-            table_schema = table.schema
-            version = table.latest_version
-            if pyarrow_schema != table_schema and not (mode == "overwrite"):
-                raise ValueError(
-                    "Schema of data does not match table schema\n"
-                    f"Data schema:\n{pyarrow_schema}\nTable Schema:\n{table_schema}"
-                )
-
-        builder = self._builder.write_lance(
-            table_uri,
-            mode,
-            io_config=io_config,
-            kwargs=kwargs,
-        )
-        write_df = DataFrame(builder)
-        write_df.collect()
-
-        write_result = write_df.to_pydict()
-        assert "fragments" in write_result
-        fragments = write_result["fragments"]
-
-        if mode == "create" or mode == "overwrite":
-            operation = lance.LanceOperation.Overwrite(pyarrow_schema, fragments)
-        elif mode == "append":
-            operation = lance.LanceOperation.Append(fragments)
-
-        dataset = lance.LanceDataset.commit(table_uri, operation, read_version=version, storage_options=storage_options)
-        stats = dataset.stats.dataset_stats()
-
-        tbl = from_pydict(
-            {
-                "num_fragments": pa.array([stats["num_fragments"]], type=pa.int64()),
-                "num_deleted_rows": pa.array([stats["num_deleted_rows"]], type=pa.int64()),
-                "num_small_files": pa.array([stats["num_small_files"]], type=pa.int64()),
-                "version": pa.array([dataset.version], type=pa.int64()),
-            }
-        )
-        return tbl
+        sink = LanceDataSink(uri, self.schema(), mode, io_config, **kwargs)
+        return self.write_sink(sink)

     ###
     # DataFrame operations
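After this refactor, `write_lance` is a thin wrapper that constructs a `LanceDataSink` and hands it to `write_sink`. A minimal usage sketch, assuming a local path and the `daft[lance]` extra installed (the path and data are made up for illustration):

```python
import daft

df = daft.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Internally builds LanceDataSink(uri, df.schema(), mode, io_config) and calls df.write_sink(sink).
# The returned DataFrame carries the dataset stats: num_fragments, num_deleted_rows,
# num_small_files, and version.
stats = df.write_lance("/tmp/example_lance_dataset", mode="create")
stats.show()
```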

daft/dataframe/lance_data_sink.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+import pathlib
+from itertools import chain
+from typing import Iterator, List, Literal, Optional, Union
+
+import lance
+
+from daft.context import get_context
+from daft.daft import IOConfig
+from daft.datatype import DataType
+from daft.io import DataSink, WriteOutput
+from daft.recordbatch import MicroPartition
+from daft.schema import Schema
+
+
+class LanceDataSink(DataSink[list[lance.FragmentMetadata]]):
+    """WriteSink for writing data to a Lance dataset."""
+
+    def _import_lance(self):
+        try:
+            import lance
+
+            return lance
+        except ImportError:
+            raise ImportError("lance is not installed. Please install lance using `pip install daft[lance]`")
+
+    def __init__(
+        self,
+        uri: Union[str, pathlib.Path],
+        schema: Schema,
+        mode: Literal["create", "append", "overwrite"],
+        io_config: Optional[IOConfig] = None,
+        **kwargs,
+    ):
+        from daft.dependencies import pa
+        from daft.io.object_store_options import io_config_to_storage_options
+
+        lance = self._import_lance()
+
+        if not isinstance(uri, (str, pathlib.Path)):
+            raise TypeError(f"Expected URI to be str or pathlib.Path, got {type(uri)}")
+        self._table_uri = str(uri)
+        self._mode = mode
+        self._io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config
+        self._args = kwargs
+
+        self._storage_options = io_config_to_storage_options(self._io_config, self._table_uri)
+
+        self._pyarrow_schema = pa.schema((f.name, f.dtype.to_arrow_dtype()) for f in schema)
+
+        try:
+            table = lance.dataset(self._table_uri, storage_options=self._storage_options)
+
+        except ValueError:
+            table = None
+
+        self._version = 0
+        if table:
+            table_schema = table.schema
+            self._version = table.latest_version
+            if self._pyarrow_schema != table_schema and not (self._mode == "overwrite"):
+                raise ValueError(
+                    "Schema of data does not match table schema\n"
+                    f"Data schema:\n{self._pyarrow_schema}\nTable Schema:\n{table_schema}"
+                )
+
+        self._schema = Schema._from_field_name_and_types(
+            [
+                ("num_fragments", DataType.int64()),
+                ("num_deleted_rows", DataType.int64()),
+                ("num_small_files", DataType.int64()),
+                ("version", DataType.int64()),
+            ]
+        )
+
+    def schema(self) -> Schema:
+        return self._schema
+
+    def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteOutput[list[lance.FragmentMetadata]]]:
+        """Writes fragments from the given micropartitions."""
+        lance = self._import_lance()
+
+        for micropartition in micropartitions:
+            arrow_table = micropartition.to_arrow()
+            bytes_written = arrow_table.nbytes
+            rows_written = arrow_table.num_rows
+
+            fragments = lance.fragment.write_fragments(
+                arrow_table,
+                dataset_uri=self._table_uri,
+                mode=self._mode,
+                storage_options=self._storage_options,
+                **self._args,
+            )
+            yield WriteOutput(
+                output=fragments,
+                bytes_written=bytes_written,
+                rows_written=rows_written,
+            )
+
+    def finalize(self, write_outputs: List[WriteOutput[list[lance.FragmentMetadata]]]) -> MicroPartition:
+        """Commits the fragments to the Lance dataset. Returns a DataFrame with the stats of the dataset."""
+        from daft.dependencies import pa
+
+        lance = self._import_lance()
+
+        fragments = list(chain.from_iterable(write_output.output for write_output in write_outputs))
+
+        if self._mode == "create" or self._mode == "overwrite":
+            operation = lance.LanceOperation.Overwrite(self._pyarrow_schema, fragments)
+        elif self._mode == "append":
+            operation = lance.LanceOperation.Append(fragments)
+
+        dataset = lance.LanceDataset.commit(
+            self._table_uri, operation, read_version=self._version, storage_options=self._storage_options
+        )
+        stats = dataset.stats.dataset_stats()
+
+        tbl = MicroPartition.from_pydict(
+            {
+                "num_fragments": pa.array([stats["num_fragments"]], type=pa.int64()),
+                "num_deleted_rows": pa.array([stats["num_deleted_rows"]], type=pa.int64()),
+                "num_small_files": pa.array([stats["num_small_files"]], type=pa.int64()),
+                "version": pa.array([dataset.version], type=pa.int64()),
+            }
+        )
+        return tbl
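Since `LanceDataSink` is just another `DataSink`, it can also be driven through the generic `write_sink` entry point directly, which is effectively what the new `write_lance` body does. A sketch under the same assumptions as above (hypothetical path, `daft[lance]` installed):

```python
import daft
from daft.dataframe.lance_data_sink import LanceDataSink

df = daft.from_pydict({"a": [4, 5], "b": ["w", "v"]})

# Equivalent to df.write_lance(...): the sink captures the target URI, schema, and mode,
# writes Lance fragments in write(), and commits them against the latest version in finalize().
sink = LanceDataSink("/tmp/example_lance_dataset", df.schema(), "append")
stats = df.write_sink(sink)
stats.show()
```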

daft/execution/execution_step.py

Lines changed: 28 additions & 2 deletions
@@ -5,17 +5,18 @@
 from typing import TYPE_CHECKING, Generic, Protocol

 from daft.context import get_context
-from daft.daft import JoinSide, ResourceRequest
+from daft.daft import JoinSide, PyRecordBatch, ResourceRequest
 from daft.expressions import Expression, ExpressionsProjection, col
 from daft.filesystem import overwrite_files
-from daft.recordbatch import MicroPartition, recordbatch_io
+from daft.recordbatch import MicroPartition, RecordBatch, recordbatch_io
 from daft.runners.partitioning import (
     Boundaries,
     MaterializedResult,
     PartialPartitionMetadata,
     PartitionMetadata,
     PartitionT,
 )
+from daft.series import Series

 if TYPE_CHECKING:
     import pathlib
@@ -24,6 +25,7 @@
     from pyiceberg.table import TableProperties as IcebergTableProperties

     from daft.daft import FileFormat, IOConfig, JoinType, ScanTask
+    from daft.io import DataSink
     from daft.logical.map_partition_ops import MapPartitionOp
     from daft.logical.schema import Schema

@@ -578,6 +580,30 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition:
         )


+@dataclass(frozen=True)
+class DataSinkWrite(SingleOutputInstruction):
+    sink: DataSink
+
+    def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
+        result_field_name = "write_results"
+        results = list(self.sink.write(iter(inputs)))
+        results_series = Series.from_pylist(results, result_field_name, pyobj="force")
+        series_dict = {result_field_name: results_series._series}
+        rb = RecordBatch._from_pyrecordbatch(PyRecordBatch.from_pylist_series(series_dict))
+        mp = MicroPartition._from_record_batches([rb])
+        return [mp]
+
+    def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
+        # TODO(desmond): We can potentially do something more useful here. For now, copy the implementation for the other writers.
+        assert len(input_metadatas) == 1
+        return [
+            PartialPartitionMetadata(
+                num_rows=None,  # we can write more than 1 file per partition
+                size_bytes=None,
+            )
+        ]
+
+
 @dataclass(frozen=True)
 class Filter(SingleOutputInstruction):
     predicate: ExpressionsProjection

daft/execution/physical_plan.py

Lines changed: 12 additions & 0 deletions
@@ -60,6 +60,7 @@
     from pyiceberg.table import TableProperties as IcebergTableProperties

     from daft.daft import FileFormat, IOConfig, JoinType
+    from daft.io import DataSink
     from daft.logical.schema import Schema


@@ -224,6 +225,17 @@ def lance_write(
     )


+def data_sink_write(
+    child_plan: InProgressPhysicalPlan[PartitionT],
+    sink: DataSink,
+) -> InProgressPhysicalPlan[PartitionT]:
+    """Write the results of `child_plan` into a custom write sink described by `sink`."""
+    yield from (
+        step.add_instruction(execution_step.DataSinkWrite(sink)) if isinstance(step, PartitionTaskBuilder) else step
+        for step in child_plan
+    )
+
+
 def pipeline_instruction(
     child_plan: InProgressPhysicalPlan[PartitionT],
     pipeable_instruction: Instruction,

daft/execution/rust_physical_plan_shim.py

Lines changed: 8 additions & 0 deletions
@@ -23,6 +23,7 @@
     from pyiceberg.schema import Schema as IcebergSchema
     from pyiceberg.table import TableProperties as IcebergTableProperties

+    from daft.io import DataSink
    from daft.recordbatch import MicroPartition


@@ -415,3 +416,10 @@ def write_lance(
     kwargs: dict | None,
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
     return physical_plan.lance_write(input, path, mode, io_config, kwargs)
+
+
+def write_data_sink(
+    input: physical_plan.InProgressPhysicalPlan[PartitionT],
+    sink: DataSink,
+) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
+    return physical_plan.data_sink_write(input, sink)

0 commit comments
