# Source code for etl_lib.task.data_loading.ParallelParquetLoad2Neo4jTask

import abc
from pathlib import Path
from typing import Optional, Type

from pydantic import BaseModel

from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.ParallelBatchProcessor import ParallelBatchProcessor
from etl_lib.core.SplittingBatchProcessor import (SplittingBatchProcessor,
                                                  dict_id_extractor)
from etl_lib.core.Task import Task, TaskReturn
from etl_lib.core.ValidationBatchProcessor import ValidationBatchProcessor
from etl_lib.data_sink.CypherBatchSink import CypherBatchSink
from etl_lib.data_source.ParquetBatchSource import ParquetBatchSource


class ParallelParquetLoad2Neo4jTask(Task):
    """
    Parallel Parquet → Neo4j load using the mix-and-batch strategy.

    Rows are streamed from a Parquet file, optionally validated against a
    Pydantic model, split into partitions so that concurrent writers do not
    collide, and written to Neo4j by a pool of Cypher sink workers.
    Subclasses supply the Cypher statement via :meth:`_query`.
    """

    def __init__(self, context: ETLContext, file: Path,
                 model: Optional[Type[BaseModel]] = None,
                 error_file: Optional[Path] = None,
                 table_size: int = 10, batch_size: int = 5000,
                 max_workers: Optional[int] = None, prefetch: int = 4,
                 **parquet_reader_kwargs):
        """
        :param context: shared ETL context.
        :param file: Parquet file to load.
        :param model: optional Pydantic model each row is validated against.
        :param error_file: destination for rows failing validation; required
            whenever ``model`` is given.
        :param table_size: number of partitions produced by the splitter.
        :param batch_size: number of rows requested per batch from the pipeline.
        :param max_workers: size of the worker pool; defaults to ``table_size``
            when not provided.
        :param prefetch: number of batches pre-fetched for the workers.
        :param parquet_reader_kwargs: forwarded to :class:`ParquetBatchSource`.
        :raises ValueError: if ``model`` is given without an ``error_file``.
        """
        super().__init__(context)
        self.file = file
        self.model = model
        if model is not None and error_file is None:
            # a validation model without somewhere to put rejects is a config error
            raise ValueError('you must provide error file if the model is specified')
        self.error_file = error_file
        self.table_size = table_size
        self.batch_size = batch_size
        # default to one worker per partition
        self.max_workers = max_workers or table_size
        self.prefetch = prefetch
        self.parquet_reader_kwargs = parquet_reader_kwargs

    def run_internal(self) -> TaskReturn:
        """Build the source → (validate) → split → parallel-sink pipeline and drain it once."""
        expected = ParquetBatchSource.get_total_rows(self.file)

        # Head of the pipeline: the Parquet reader, optionally wrapped in validation.
        stage = ParquetBatchSource(self.file, self.context, self,
                                   **self.parquet_reader_kwargs)
        if self.model is not None:
            stage = ValidationBatchProcessor(self.context, self, stage,
                                             self.model, self.error_file)

        # Partition rows by id so parallel writers never contend on the same keys.
        stage = SplittingBatchProcessor(
            context=self.context,
            task=self,
            predecessor=stage,
            table_size=self.table_size,
            id_extractor=self._id_extractor(),
        )

        # Fan the partitions out to a pool of Cypher sink workers.
        stage = ParallelBatchProcessor(
            context=self.context,
            task=self,
            predecessor=stage,
            worker_factory=lambda: CypherBatchSink(self.context, self, None,
                                                   self._query()),
            max_workers=self.max_workers,
            prefetch=self.prefetch,
        )

        # Close the loop against the expected row count and pull one result.
        closing = ClosedLoopBatchProcessor(self.context, self, stage,
                                           expected_rows=expected)
        outcome = next(closing.get_batch(self.batch_size))
        return TaskReturn(True, outcome.statistics)

    def _id_extractor(self):
        """Return the extractor used to route each row to its partition."""
        return dict_id_extractor()

    @abc.abstractmethod
    def _query(self):
        """Cypher statement executed by each sink worker; must be provided by subclasses."""
        pass