使用 PyArrow 处理 RecordBatch数据的数据源编写器的基类。
与 Spark 对象的迭代器DataSourceWriter不同Row,此类在写入数据时针对箭头格式进行优化。 与本机支持 Arrow 的系统或库交互时,它可以提供更好的性能。 实现此类并返回一个实例 DataSource.writer() ,以便使用 Arrow 使数据源可写。
Syntax
from pyspark.sql.datasource import DataSourceArrowWriter
class MyDataSourceArrowWriter(DataSourceArrowWriter):
def write(self, iterator):
...
方法
| 方法 | 说明 |
|---|---|
write(iterator) |
将 PyArrow RecordBatch 对象的迭代器写入接收器。 在每个执行程序上调用一次。 返回一个 WriterCommitMessage或 None 如果没有提交消息。 此方法是抽象的,必须实现。 |
commit(messages) |
使用从所有执行程序收集的提交消息列表提交写入作业。 成功运行所有任务时,在驱动程序上调用。 从 DataSourceWriter 继承。 |
abort(messages) |
使用从所有执行程序收集的提交消息列表中止写入作业。 当一个或多个任务失败时,在驱动程序上调用。 从 DataSourceWriter 继承。 |
备注
- 驱动程序从所有执行程序收集提交消息,并将其传递给
commit()所有任务是否成功,或传递到abort()任何任务失败。 - 如果写入任务失败,则其提交消息将
None位于传递给或commit()传递abort()的列表中。
示例
实现基于箭头的编写器,用于计算所有批处理中的行数:
from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage
@dataclass
class MyCommitMessage(WriterCommitMessage):
num_rows: int
class MyDataSourceArrowWriter(DataSourceArrowWriter):
def write(self, iterator):
total_rows = 0
for batch in iterator:
total_rows += len(batch)
return MyCommitMessage(num_rows=total_rows)
def commit(self, messages):
total = sum(m.num_rows for m in messages if m is not None)
print(f"Committed {total} rows")
def abort(self, messages):
print("Write job failed, performing cleanup")