Source code for etl_lib.task.data_loading.CSVLoad2Neo4jTask

import abc
import logging
from pathlib import Path
from typing import Type

from pydantic import BaseModel

from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
from etl_lib.core.Task import Task, TaskReturn
from etl_lib.core.ValidationBatchProcessor import ValidationBatchProcessor
from etl_lib.data_sink.CypherBatchSink import CypherBatchSink
from etl_lib.data_source.CSVBatchSource import CSVBatchSource


class CSVLoad2Neo4jTask(Task):
    '''
    Loads the specified CSV file into Neo4j.

    Uses BatchProcessors to read, optionally validate, and write to Neo4j.
    The validation step uses Pydantic and is only enabled when a model is
    provided. Rows that fail validation are written to an error file.

    The location of the error file is determined as follows:

    - If the context env vars hold an entry `ETL_ERROR_PATH`, the file is
      placed there, named after the CSV file with its suffix replaced by
      `.error.json`.
    - If `ETL_ERROR_PATH` is not set, the file is placed in the same
      directory as the CSV file.

    Example usage (from the GTFS demo):

    .. code-block:: python

        class LoadStopsTask(CSVLoad2Neo4jTask):
            class Stop(BaseModel):
                id: str = Field(alias="stop_id")
                name: str = Field(alias="stop_name")
                latitude: float = Field(alias="stop_lat")
                longitude: float = Field(alias="stop_lon")
                platform_code: Optional[str] = None
                parent_station: Optional[str] = None
                type: Optional[str] = Field(alias="location_type", default=None)
                timezone: Optional[str] = Field(alias="stop_timezone", default=None)
                code: Optional[str] = Field(alias="stop_code", default=None)

            def __init__(self, context: ETLContext, file: Path):
                super().__init__(context, file, model=LoadStopsTask.Stop)

            def task_name(self) -> str:
                return f"{self.__class__.__name__}('{self.file}')"

            def _query(self):
                return """
                    UNWIND $batch AS row
                    MERGE (s:Stop {id: row.id})
                    SET s.name = row.name,
                        s.location = point({latitude: row.latitude, longitude: row.longitude}),
                        s.platformCode = row.platform_code,
                        s.parentStation = row.parent_station,
                        s.type = row.type,
                        s.timezone = row.timezone,
                        s.code = row.code
                    """
    '''
    def __init__(self, context: ETLContext, file: Path, model: Type[BaseModel] | None = None,
                 batch_size: int = 5000):
        '''
        :param context: `ETLContext` instance used by this task.
        :param file: Path of the CSV file to load.
        :param model: Optional Pydantic model. If provided, rows are validated
            against it before being written to Neo4j.
        :param batch_size: Number of rows processed per batch.
        '''
        super().__init__(context)
        self.batch_size = batch_size
        self.model = model
        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        self.file = file
    def run_internal(self, **kwargs) -> TaskReturn:
        # Source: stream the CSV file in batches.
        csv = CSVBatchSource(self.file, self.context, self)
        predecessor = csv
        if self.model is not None:
            # Validation is enabled; resolve where rejected rows are written.
            error_path = self.context.env("ETL_ERROR_PATH")
            if error_path is None:
                error_file = self.file.with_suffix(".error.json")
            else:
                error_file = error_path / self.file.with_name(self.file.stem + ".error.json").name
            predecessor = ValidationBatchProcessor(self.context, self, csv, self.model, error_file)
        # Sink: write each (validated) batch to Neo4j using the subclass-provided query.
        cypher = CypherBatchSink(self.context, self, predecessor, self._query())
        # Drive the source-to-sink pipeline to completion and collect statistics.
        end = ClosedLoopBatchProcessor(self.context, self, cypher)
        result = next(end.get_batch(self.batch_size))
        return TaskReturn(True, result.statistics)
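    # Worked example of the error-file resolution above (paths illustrative):
    # with ETL_ERROR_PATH=/tmp/errors and file=/data/stops.csv, rejected rows
    # go to /tmp/errors/stops.error.json; with ETL_ERROR_PATH unset, they go
    # to /data/stops.error.json, next to the CSV file.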
    def __repr__(self):
        return f"{self.__class__.__name__}({self.file})"

    @abc.abstractmethod
    def _query(self) -> str:
        '''Returns the Cypher query used to write each batch to Neo4j.

        Must be implemented by subclasses; see the class docstring for an example.
        '''
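
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): a minimal concrete subclass.
# The `Person` model, its fields, and the Cypher statement are assumptions
# invented for this example; real subclasses map their own CSV columns.
class ExamplePersonLoadTask(CSVLoad2Neo4jTask):
    class Person(BaseModel):
        id: str
        name: str

    def __init__(self, context: ETLContext, file: Path):
        # Passing a model enables the Pydantic validation step; omitting it
        # would skip validation entirely.
        super().__init__(context, file, model=ExamplePersonLoadTask.Person)

    def _query(self) -> str:
        # The batch is exposed to the query as the $batch parameter,
        # matching the docstring example above.
        return """
        UNWIND $batch AS row
        MERGE (p:Person {id: row.id})
        SET p.name = row.name
        """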