# Source code for etl_lib.core.ParallelBatchProcessor

import queue
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Generator, List

from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
from etl_lib.core.utils import merge_summery


class ParallelBatchResult(BatchResults):
    """
    Represents one *wave* produced by the splitter.

    `chunk` is a list of bucket-batches. Each sub-list is processed by one
    worker instance.
    """
class ParallelBatchProcessor(BatchProcessor):
    """
    BatchProcessor that runs a worker over the bucket-batches of each
    ParallelBatchResult in parallel threads, while prefetching the next
    ParallelBatchResult from its predecessor.

    Note:
        - The predecessor must emit `ParallelBatchResult` instances (waves).
        - This processor collects the BatchResults from all workers for one
          wave and merges them into one BatchResults.
        - The returned BatchResults will not obey the max_batch_size from
          get_batch() because it represents the full wave.

    Args:
        context: ETL context.
        worker_factory: A zero-arg callable that returns a new BatchProcessor
            each time it's called.
        task: optional Task for reporting.
        predecessor: upstream BatchProcessor that must emit ParallelBatchResult.
            See :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
        max_workers: number of parallel threads for bucket processing.
        prefetch: number of waves to prefetch.

    Behavior:
        - For every wave, spins up `max_workers` threads.
        - Each thread processes one bucket-batch using a fresh worker from
          `worker_factory()`.
        - Collects and merges worker results in a fail-fast manner.
    """

    def __init__(
        self,
        context,
        worker_factory: Callable[[], BatchProcessor],
        task=None,
        predecessor=None,
        max_workers: int = 4,
        prefetch: int = 4,
    ):
        super().__init__(context, task, predecessor)
        self.worker_factory = worker_factory
        self.max_workers = max_workers
        # NOTE(review): queue.Queue(0) is *unbounded*, so prefetch=0 disables
        # back-pressure entirely rather than disabling prefetch — confirm that
        # is the intended semantics before relying on it.
        self.prefetch = prefetch

    def _process_wave(self, wave: ParallelBatchResult) -> BatchResults:
        """
        Process one wave: run one worker per bucket-batch and merge their
        BatchResults.

        Statistics: `wave.statistics` is used as the initial merged stats,
        then merged with each worker's stats.

        Raises:
            Exception: re-raises the first worker failure (fail-fast); any
                not-yet-started bucket-batches are cancelled.
        """
        merged_stats = dict(wave.statistics or {})
        merged_chunk = []
        total = 0
        self.logger.debug(f"Processing wave with {len(wave.chunk)} buckets")
        with ThreadPoolExecutor(
            max_workers=self.max_workers, thread_name_prefix="PBP_worker_"
        ) as pool:
            futures = [
                pool.submit(self._process_bucket_batch, bucket_batch)
                for bucket_batch in wave.chunk
            ]
            try:
                for f in as_completed(futures):
                    out = f.result()
                    merged_stats = merge_summery(merged_stats, out.statistics or {})
                    total += out.batch_size
                    # Workers normally return list chunks; wrap a scalar chunk
                    # defensively so extend() never iterates a non-list.
                    merged_chunk.extend(
                        out.chunk if isinstance(out.chunk, list) else [out.chunk]
                    )
            except Exception:
                self.logger.exception("bucket processing failed")
                # shutdown(cancel_futures=True) already cancels every pending
                # future, so the previous per-future cancel() loop was
                # redundant and has been removed. Threads that are already
                # running cannot be interrupted and will finish on their own.
                pool.shutdown(cancel_futures=True)
                raise
        self.logger.debug(f"Finished wave with stats={merged_stats}")
        return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)

    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        """
        Pull waves from the predecessor (prefetching up to `prefetch` ahead),
        process each wave's buckets in parallel, and yield one flattened
        BatchResults per wave.

        Raises:
            BaseException: re-raises whatever the upstream producer raised,
                after the queue has drained to the sentinel.
        """
        wave_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
        SENTINEL = object()
        exc: BaseException | None = None

        def producer():
            # Runs in a daemon thread; fills wave_queue ahead of consumption.
            nonlocal exc
            try:
                for wave in self.predecessor.get_batch(max_batch_size):
                    self.logger.debug(
                        f"adding wave stats={wave.statistics} buckets={len(wave.chunk)} to queue size={wave_queue.qsize()}"
                    )
                    wave_queue.put(wave)
            except BaseException as e:
                exc = e
            finally:
                # Always unblock the consumer, even after an upstream failure.
                wave_queue.put(SENTINEL)

        threading.Thread(target=producer, daemon=True, name="prefetcher").start()

        while True:
            wave = wave_queue.get()
            if wave is SENTINEL:
                if exc is not None:
                    # BUG FIX: this code runs in the consumer thread, where
                    # sys.exc_info() is empty — exc_info=True logged no
                    # traceback. Passing the captured exception instance
                    # preserves the producer's traceback in the log.
                    self.logger.error("Upstream producer failed", exc_info=exc)
                    raise exc
                break
            yield self._process_wave(wave)

    class SingleBatchWrapper(BatchProcessor):
        """
        Simple BatchProcessor that returns exactly one batch (the bucket-batch
        passed in via init). Used as predecessor for the per-bucket worker.
        """

        def __init__(self, context, batch: List[Any]):
            super().__init__(context=context, predecessor=None)
            self._batch = batch

        def get_batch(self, max_size: int) -> Generator[BatchResults, None, None]:
            # max_size is intentionally ignored: the single wrapped batch is
            # emitted whole, whatever its length.
            yield BatchResults(
                chunk=self._batch,
                statistics={},
                batch_size=len(self._batch),
            )

    def _process_bucket_batch(self, bucket_batch):
        """
        Process one bucket-batch by running a fresh worker over it.

        A new worker from `worker_factory()` is wired to a SingleBatchWrapper
        predecessor, so its get_batch() yields exactly one BatchResults, which
        is returned as-is.
        """
        self.logger.debug(f"Processing batch w/ size {len(bucket_batch)}")
        wrapper = self.SingleBatchWrapper(self.context, bucket_batch)
        worker = self.worker_factory()
        worker.predecessor = wrapper
        # Only the first (and only) batch is consumed from the worker.
        result = next(worker.get_batch(len(bucket_batch)))
        self.logger.debug(f"Finished bucket batch stats={result.statistics}")
        return result