Source code for etl_lib.core.ValidationBatchProcessor

import json
from pathlib import Path
from typing import Type, Generator

from pydantic import BaseModel, ValidationError

from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.Task import Task
from etl_lib.core.utils import merge_summery


class ValidationBatchProcessor(BatchProcessor):
    """
    Batch processor for validation, using Pydantic.

    Each row received from the predecessor is validated against the configured
    model. Rows that pass are forwarded downstream; rows that fail are appended,
    together with their validation errors, to the error file.
    """

    def __init__(self, context: ETLContext, task: Task, predecessor, model: Type[BaseModel] | None,
                 error_file: Path | None):
        """
        Constructs a new ValidationBatchProcessor.

        The :py:class:`etl_lib.core.BatchProcessor.BatchResults` returned from the :py:func:`~get_batch` of this
        implementation will contain the following additional entries:

        - `valid_rows`: Number of valid rows.
        - `invalid_rows`: Number of invalid rows.

        Args:
            context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
            task: :py:class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
            predecessor: BatchProcessor which :py:func:`~get_batch` function will be called to receive batches to
                process.
            model: Pydantic model class used to validate each row in the batch. Optional.
            error_file: Path to the file that will receive each row that did not pass validation.
                Required if `model` is provided.

        Raises:
            ValueError: If `model` is provided but `error_file` is not.
        """
        super().__init__(context, task, predecessor)
        if model is not None and error_file is None:
            raise ValueError('you must provide error file if the model is specified')
        self.error_file = error_file
        self.model = model

    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        """
        Pulls batches from the predecessor and yields them with validation statistics.

        Without a model every row is passed through unchanged and counted as valid.
        With a model, each row is validated; valid rows are replaced by the JSON
        round-tripped dump of the model instance, invalid rows are appended to the
        error file and dropped from the yielded chunk.

        Args:
            max_batch_size: Batch size forwarded to the predecessor's `get_batch`.

        Yields:
            One `BatchResults` per predecessor batch. `statistics` is the predecessor's
            statistics combined (via `merge_summery`) with `valid_rows` and
            `invalid_rows`; `batch_size` is the size of the incoming chunk, i.e. it
            includes the rows that failed validation.
        """
        assert self.predecessor is not None

        if self.model is None:
            # No model configured: pass everything through and report it as valid.
            for batch in self.predecessor.get_batch(max_batch_size):
                yield BatchResults(
                    chunk=batch.chunk,
                    statistics=merge_summery(batch.statistics, {
                        "valid_rows": len(batch.chunk),
                        "invalid_rows": 0
                    }),
                    batch_size=len(batch.chunk)
                )
            return

        for batch in self.predecessor.get_batch(max_batch_size):
            valid_rows = []
            invalid_rows = []

            for row in batch.chunk:
                try:
                    # Validate and transform the row
                    validated_row = json.loads(self.model(**row).model_dump_json())
                    valid_rows.append(validated_row)
                except ValidationError as e:
                    # Collect invalid rows with errors
                    invalid_rows.append({"row": row, "errors": e.errors()})

            # Write invalid rows to the error file
            if invalid_rows:
                self._write_invalid_rows(invalid_rows)

            # Yield BatchResults with statistics
            yield BatchResults(
                chunk=valid_rows,
                statistics=merge_summery(batch.statistics, {
                    "valid_rows": len(valid_rows),
                    "invalid_rows": len(invalid_rows)
                }),
                batch_size=len(batch.chunk)
            )

    def _write_invalid_rows(self, invalid_rows: list[dict]) -> None:
        """
        Appends the rejected rows to the error file, one JSON document per line.

        Args:
            invalid_rows: Dicts with the keys `row` (the original row) and `errors`
                (the list returned by `ValidationError.errors()`).
        """
        assert self.error_file is not None
        # Explicit encoding keeps the file independent of the platform locale.
        with open(self.error_file, "a", encoding="utf-8") as f:
            for invalid in invalid_rows:
                # the following is needed as ValueError (contained in 'ctx') is not json serializable
                serializable = {
                    "row": invalid["row"],
                    "errors": [{k: v for k, v in e.items() if k != "ctx"} for e in invalid["errors"]],
                }
                # default=str: a rejected row (or the 'input' echoed in an error entry) may hold
                # values json cannot encode natively (datetime, Decimal, bytes, ...). Writing the
                # error log must not abort the pipeline, so such values are stringified instead.
                f.write(f"{json.dumps(serializable, default=str)}\n")