Partilhar via


DataSourceArrowWriter

Uma classe base para escritores de fontes de dados que processam dados usando o RecordBatcharquivo do PyArrow.

Ao contrário DataSourceWriterde , que trabalha com um iterador de objetos Spark Row , esta classe está otimizada para o formato Arrow ao escrever dados. Pode oferecer melhor desempenho ao interagir com sistemas ou bibliotecas que suportam nativamente o Arrow. Implemente esta classe e retorne uma instância de DataSource.writer() para tornar uma fonte de dados gravável usando o Arrow.

Sintaxe

from pyspark.sql.datasource import DataSourceArrowWriter

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        ...

Methods

Método Descrição
write(iterator) Escreve um iterador de objetos PyArrow RecordBatch no lava-loiça. Chamado uma vez a cada executor. Devolve uma WriterCommitMessage, ou None se não houver mensagem de commit. Este método é abstrato e deve ser implementado.
commit(messages) Faz commit na tarefa de escrita usando uma lista de mensagens de commit recolhidas de todos os executores. Invocado no driver quando todas as tarefas correm com sucesso. Herdado de DataSourceWriter.
abort(messages) Aborta o trabalho de escrita usando uma lista de mensagens de commit recolhidas de todos os executores. Invocado no driver quando uma ou mais tarefas falhavam. Herdado de DataSourceWriter.

Notes

  • O driver recolhe mensagens de commit de todos os executores e passa-as se commit() todas as tarefas têm sucesso, ou se abort() alguma falha.
  • Se uma tarefa de escrita falhar, a sua mensagem de commit estará None na lista passada para commit() ou abort().

Exemplos

Implemente um escritor baseado em Arrow que conte as linhas em todos os lotes:

from dataclasses import dataclass
from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, WriterCommitMessage

@dataclass
class MyCommitMessage(WriterCommitMessage):
    num_rows: int

class MyDataSourceArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        total_rows = 0
        for batch in iterator:
            total_rows += len(batch)
        return MyCommitMessage(num_rows=total_rows)

    def commit(self, messages):
        total = sum(m.num_rows for m in messages if m is not None)
        print(f"Committed {total} rows")

    def abort(self, messages):
        print("Write job failed, performing cleanup")