Source code for etl_lib.data_source.SQLBatchSource

import time
from typing import Generator, Callable, Optional, List, Dict

from psycopg2 import OperationalError as PsycopgOperationalError
from sqlalchemy import text
from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError

from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.Task import Task


class SQLBatchSource(BatchProcessor):
    """
    BatchProcessor that reads rows from a SQL database, paging through the
    result set with LIMIT/OFFSET.
    """
    def __init__(
        self,
        context: ETLContext,
        task: Task,
        query: str,
        record_transformer: Optional[Callable[[dict], dict]] = None,
        **kwargs,
    ):
        """
        Constructs a new SQLBatchSource.

        Args:
            context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
            task: :class:`etl_lib.core.Task.Task` instance owning this BatchProcessor.
            query: SQL query to execute.
            record_transformer: Optional function to transform each row (dict format).
            kwargs: Additional keyword arguments, bound as named parameters of the query.
        """
        super().__init__(context, task)
        self.query = query.strip().rstrip(";")
        self.record_transformer = record_transformer
        self.kwargs = kwargs
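
    # A minimal usage sketch (hypothetical: `ctx` and `task` stand in for real
    # ETLContext and Task instances from the surrounding pipeline). Extra
    # keyword arguments become named query parameters:
    #
    #     source = SQLBatchSource(
    #         ctx,
    #         task,
    #         query="SELECT id, email FROM users WHERE active = :active",
    #         record_transformer=lambda row: {**row, "email": row["email"].lower()},
    #         active=True,
    #     )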
    def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
        """
        Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.

        Each page is executed in its own transaction. On transient disconnects
        or DB errors, the fetch is retried with exponential backoff, up to 5
        attempts in total.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before starting this page.

        Returns:
            A list of row dicts (after applying record_transformer, if any),
            or None if no rows are returned.

        Raises:
            Exception: Re-raises the last caught error on final failure.
        """
        paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
        params = {**self.kwargs, "limit": limit, "offset": offset}

        max_retries = 5
        backoff = 2.0

        for attempt in range(1, max_retries + 1):
            try:
                # Each page gets a fresh connection and transaction, so a retry
                # after a disconnect starts from a clean state.
                with self.context.sql.engine.connect() as conn:
                    with conn.begin():
                        rows = conn.execute(text(paged_sql), params).mappings().all()
                        result = [
                            self.record_transformer(dict(r))
                            if self.record_transformer
                            else dict(r)
                            for r in rows
                        ]
                        return result if result else None
            except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
                if attempt == max_retries:
                    self.logger.error(
                        f"Page fetch failed after {max_retries} attempts "
                        f"(limit={limit}, offset={offset}): {err}"
                    )
                    raise
                self.logger.warning(
                    f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
                    f"retrying in {backoff:.1f}s"
                )
                time.sleep(backoff)
                backoff *= 2
        return None
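
    # For illustration: with the query from the sketch above and a limit of
    # 1000, the first page executes the statement below (values are bound as
    # parameters, never interpolated into the SQL string):
    #
    #     SELECT id, email FROM users WHERE active = :active
    #     LIMIT :limit OFFSET :offset
    #     -- params: {"active": True, "limit": 1000, "offset": 0}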
    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        """
        Yield successive batches until the query is exhausted.

        Calls :meth:`_fetch_page` repeatedly, advancing the offset by the
        number of rows returned. Stops when no more rows are returned.

        Args:
            max_batch_size: Upper limit on rows per batch.

        Yields:
            BatchResults for each non-empty page.
        """
        offset = 0
        while True:
            chunk = self._fetch_page(max_batch_size, offset)
            if not chunk:
                break
            yield BatchResults(
                chunk=chunk,
                statistics={"sql_rows_read": len(chunk)},
                batch_size=len(chunk),
            )
            offset += len(chunk)
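

# Usage sketch for the generator (hypothetical: `ctx`, `task`, and `process`
# are placeholders for real pipeline objects and a downstream consumer):
#
#     source = SQLBatchSource(ctx, task, "SELECT id, email FROM users")
#     rows_read = 0
#     for batch in source.get_batch(max_batch_size=500):
#         process(batch.chunk)          # batch.chunk is a list of row dicts
#         rows_read += batch.batch_size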