Source code for etl_lib.core.ParallelBatchProcessor

import queue
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Generator, List

from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
from etl_lib.core.utils import merge_summery


class ParallelBatchResult(BatchResults):
    """
    Represents a batch split into parallelizable partitions.

    `chunk` is a list of lists; each sub-list is a partition.
    """

    pass
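# A minimal sketch (illustration only, not part of the module) of the shape a
# predecessor emits: three partitions of two rows each, which the processor
# below can run in parallel. The row values are hypothetical placeholders.
#
#   pbr = ParallelBatchResult(
#       chunk=[["row-1", "row-2"], ["row-3", "row-4"], ["row-5", "row-6"]],
#       statistics={"rows_read": 6},
#       batch_size=6,
#   )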
class ParallelBatchProcessor(BatchProcessor):
    """
    BatchProcessor that runs worker threads over partitions of batches.

    Receives a special BatchResult (:py:class:`ParallelBatchResult`) from the
    predecessor. All chunks in a ParallelBatchResult it receives can be
    processed in parallel. See :py:class:`etl_lib.core.SplittingBatchProcessor`
    on how to produce them.

    Prefetches the next ParallelBatchResults from its predecessor. The actual
    processing of the batches is deferred to the configured worker.

    Note:
        - The predecessor must emit `ParallelBatchResult` instances.

    Args:
        context: ETL context.
        worker_factory: A zero-arg callable that returns a new BatchProcessor
            each time it is called. This processor is responsible for the
            processing of the batches.
        task: Optional Task for reporting.
        predecessor: Upstream BatchProcessor that must emit ParallelBatchResult.
        max_workers: Number of parallel threads for partitions.
        prefetch: Number of ParallelBatchResults to prefetch from the predecessor.

    Behavior:
        - For every ParallelBatchResult, spins up `max_workers` threads.
        - Each thread calls its own worker from `worker_factory()`, with its
          partition wrapped by `SingleBatchWrapper`.
        - Collects and merges their BatchResults in a fail-fast manner: on the
          first exception, logs the error, cancels remaining threads, and
          raises an exception.
    """
    def __init__(
        self,
        context,
        worker_factory: Callable[[], BatchProcessor],
        task=None,
        predecessor=None,
        max_workers: int = 4,
        prefetch: int = 4,
    ):
        super().__init__(context, task, predecessor)
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        self.prefetch = prefetch
        self._batches_done = 0
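    # Note on the factory contract (sketch with hypothetical names; `ctx` and
    # `SomeWriter` are illustrative, not part of this module): pass a zero-arg
    # callable, not a shared instance, so each partition thread gets its own
    # worker and no worker state is shared across threads.
    #
    #   ParallelBatchProcessor(ctx, worker_factory=lambda: SomeWriter(ctx), ...)
    #   # not: worker_factory=SomeWriter(ctx); the factory is *called* once per
    #   # partition, so a ready instance would fail and would also be shared.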
    def _process_parallel(self, pbr: ParallelBatchResult) -> BatchResults:
        """
        Run one worker per partition in `pbr.chunk`, merge their outputs, and
        include upstream statistics from `pbr.statistics` so counters (e.g.,
        valid/invalid rows from validation) are preserved through the parallel
        stage.

        Progress reporting:
            - After each partition completes, report the batch count only.
        """
        merged_stats = dict(pbr.statistics or {})
        merged_chunk = []
        total = 0
        parts_total = len(pbr.chunk)
        partitions_done = 0
        self.logger.debug(f"Processing pbr of len {parts_total}")
        with ThreadPoolExecutor(
            max_workers=self.max_workers, thread_name_prefix='PBP_worker_'
        ) as pool:
            futures = [pool.submit(self._process_partition, part) for part in pbr.chunk]
            try:
                for f in as_completed(futures):
                    out = f.result()
                    # Merge into this PBR's cumulative result (returned downstream).
                    merged_stats = merge_summery(merged_stats, out.statistics or {})
                    total += out.batch_size
                    merged_chunk.extend(
                        out.chunk if isinstance(out.chunk, list) else [out.chunk]
                    )
                    partitions_done += 1
                    self.context.reporter.report_progress(
                        task=self.task,
                        batches=self._batches_done,
                        expected_batches=None,
                        stats={},
                    )
            except Exception as e:
                # Fail fast: cancel anything not yet running and stop the pool.
                for g in futures:
                    g.cancel()
                pool.shutdown(cancel_futures=True)
                raise RuntimeError("partition processing failed") from e
        self.logger.debug(f"Finished processing pbr with {merged_stats}")
        return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)
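    # Illustration of the statistics merge (an assumption about merge_summery,
    # whose implementation lives in etl_lib.core.utils and is not shown here):
    # if it sums numeric counters key-wise, partition results such as
    #   {"valid_rows": 10, "invalid_rows": 1}  and  {"valid_rows": 8}
    # would merge to {"valid_rows": 18, "invalid_rows": 1}.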
    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        """
        Pulls ParallelBatchResult batches from the predecessor, prefetching up
        to `prefetch` ahead, processes each batch's partitions in parallel
        threads, and yields a flattened BatchResults. The predecessor can run
        ahead while the current batch is processed.
        """
        pbr_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
        SENTINEL = object()
        exc: BaseException | None = None

        def producer():
            nonlocal exc
            try:
                for pbr in self.predecessor.get_batch(max_batch_size):
                    self.logger.debug(
                        f"adding pbr {pbr.statistics} / {len(pbr.chunk)} to queue of size {pbr_queue.qsize()}"
                    )
                    pbr_queue.put(pbr)
            except BaseException as e:
                exc = e
            finally:
                # Always signal completion so the consumer loop terminates.
                pbr_queue.put(SENTINEL)

        threading.Thread(target=producer, daemon=True, name='prefetcher').start()

        while True:
            pbr = pbr_queue.get()
            if pbr is SENTINEL:
                if exc is not None:
                    self.logger.error("Upstream producer failed", exc_info=True)
                    raise exc
                break
            result = self._process_parallel(pbr)
            # Count the completed batch so progress reports advance.
            self._batches_done += 1
            yield result
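    # Design note: the bounded queue (maxsize == prefetch) gives backpressure,
    # so the producer thread can run at most `prefetch` ParallelBatchResults
    # ahead of the consumer. The SENTINEL object marks end-of-stream and also
    # covers the failure case: the producer's exception is stashed in `exc`
    # and re-raised on the consumer side once the sentinel is seen.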
    class SingleBatchWrapper(BatchProcessor):
        """
        Simple BatchProcessor that returns the batch it receives via init.

        Will be used as the predecessor for the worker.
        """
        def __init__(self, context, batch: List[Any]):
            super().__init__(context=context, predecessor=None)
            self._batch = batch
        def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
            # Ignores max_batch_size; yields exactly one BatchResults
            # containing the whole batch.
            yield BatchResults(
                chunk=self._batch, statistics={}, batch_size=len(self._batch)
            )
    def _process_partition(self, partition):
        """
        Processes one partition of items by:

        1. Wrapping it in SingleBatchWrapper.
        2. Instantiating a fresh worker via worker_factory().
        3. Setting the worker's predecessor to the wrapper.
        4. Running exactly one batch and returning its BatchResults.

        Raises whatever exception the worker raises, allowing
        _process_parallel to handle fail-fast behavior.
        """
        self.logger.debug("Processing partition")
        wrapper = self.SingleBatchWrapper(self.context, partition)
        worker = self.worker_factory()
        worker.predecessor = wrapper
        result = next(worker.get_batch(len(partition)))
        self.logger.debug(f"finished processing partition with {result.statistics}")
        return result
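# End-to-end usage sketch (hypothetical names: `ctx`, `DummySource`, and
# `CypherWorker` are illustrative stand-ins, and the SplittingBatchProcessor
# signature is assumed, not confirmed by this module). A predecessor emitting
# ParallelBatchResult instances feeds the parallel stage, which is then drained:
#
#   source = DummySource(ctx)                                    # any BatchProcessor
#   splitter = SplittingBatchProcessor(ctx, predecessor=source)  # emits ParallelBatchResult
#   parallel = ParallelBatchProcessor(
#       ctx,
#       worker_factory=lambda: CypherWorker(ctx),
#       predecessor=splitter,
#       max_workers=4,
#       prefetch=2,
#   )
#   for batch_result in parallel.get_batch(1_000):
#       print(batch_result.statistics)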