"""Source module for :class:`etl_lib.data_source.ParquetBatchSource`."""

import logging
from pathlib import Path
from typing import Generator, Optional

try:
    import pyarrow.parquet as pq
except ImportError:
    pq = None

from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.Task import Task


class ParquetBatchSource(BatchProcessor):
    """
    BatchProcessor that reads a Parquet file using pyarrow.

    The returned batch of rows will have an additional `_row` column,
    containing the source row of the data, starting with 0.
    """

    def __init__(self, file: Path, context: ETLContext, task: Optional[Task] = None, **kwargs):
        """
        Constructs a new ParquetBatchSource.

        Args:
            file: Path to the Parquet file.
            context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
            task: Optional :class:`etl_lib.core.Task.Task` this processor belongs to.
            kwargs: Will be passed on to the `pyarrow.parquet.ParquetFile.iter_batches`
                method.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        super().__init__(context, task)
        # Fail fast at construction time rather than on first read.
        if pq is None:
            raise ImportError("pyarrow is required for ParquetBatchSource. Install with 'pip install .[parquet]'")
        self.file = file
        self.kwargs = kwargs
        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")

    @staticmethod
    def get_total_rows(file: Path) -> int:
        """
        Return the total number of rows in the given Parquet file.

        Reads only the file metadata; no row data is loaded.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        if pq is None:
            raise ImportError("pyarrow is required. Install with 'pip install .[parquet]'")
        return pq.ParquetFile(file).metadata.num_rows

    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        """
        Yield batches of at most `max_batch_size` rows read from the Parquet file.

        Each yielded row dict gains a `_row` key holding the 0-based index of
        the row within the source file.

        Args:
            max_batch_size: Maximum number of rows per yielded batch.

        Yields:
            :class:`etl_lib.core.BatchProcessor.BatchResults` with the rows,
            a `parquet_rows_read` statistic, and the batch size.
        """
        # Use the context manager so the underlying file handle is closed
        # deterministically when iteration finishes (the previous version
        # left closing to garbage collection).
        with pq.ParquetFile(self.file) as parquet_file:
            row_counter = 0
            for batch in parquet_file.iter_batches(batch_size=max_batch_size, **self.kwargs):
                rows = batch.to_pylist()
                for offset, row in enumerate(rows):
                    row["_row"] = row_counter + offset
                batch_len = len(rows)
                row_counter += batch_len
                yield BatchResults(
                    chunk=rows,
                    statistics={"parquet_rows_read": batch_len},
                    batch_size=batch_len,
                )