Source code for etl_lib.task.data_loading.CSVLoad2Neo4jTask

import abc
import logging
from pathlib import Path
from typing import Type

from pydantic import BaseModel

from etl_lib.core.ETLContext import ETLContext
from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
from etl_lib.core.Task import Task, TaskReturn
from etl_lib.core.ValidationBatchProcessor import ValidationBatchProcessor
from etl_lib.data_sink.CypherBatchSink import CypherBatchSink
from etl_lib.data_source.CSVBatchSource import CSVBatchSource


class CSVLoad2Neo4jTask(Task):
    '''
    Loads the specified CSV file into Neo4j.

    Uses BatchProcessors to read, validate, and write to Neo4j.
    The validation step uses Pydantic, so a Pydantic model must be provided.
    Rows that fail validation are written to an error file. The location of the
    error file is determined as follows:
    If the context env vars hold an entry `ETL_ERROR_PATH`, the file is placed there,
    named after the provided file with `.error.json` appended.
    If `ETL_ERROR_PATH` is not set, the file is placed in the same directory as the CSV file.

    Example usage (from the GTFS demo):

    .. code-block:: python

        class LoadStopsTask(CSVLoad2Neo4jTask):

            class Stop(BaseModel):
                id: str = Field(alias="stop_id")
                name: str = Field(alias="stop_name")
                latitude: float = Field(alias="stop_lat")
                longitude: float = Field(alias="stop_lon")
                platform_code: Optional[str] = None
                parent_station: Optional[str] = None
                type: Optional[str] = Field(alias="location_type", default=None)
                timezone: Optional[str] = Field(alias="stop_timezone", default=None)
                code: Optional[str] = Field(alias="stop_code", default=None)

            def __init__(self, context: ETLContext, file: Path):
                super().__init__(context, LoadStopsTask.Stop, file)

            def task_name(self) -> str:
                return f"{self.__class__.__name__}('{self.file}')"

            def _query(self):
                return """
                    UNWIND $batch AS row
                    MERGE (s:Stop {id: row.id})
                    SET s.name = row.name,
                        s.location = point({latitude: row.latitude, longitude: row.longitude}),
                        s.platformCode = row.platform_code,
                        s.parentStation = row.parent_station,
                        s.type = row.type,
                        s.timezone = row.timezone,
                        s.code = row.code
                    """
    '''

    def __init__(self, context: ETLContext, model: Type[BaseModel], file: Path, batch_size: int = 5000):
        super().__init__(context)
        self.batch_size = batch_size
        self.model = model
        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        self.file = file

    def run_internal(self, **kwargs) -> TaskReturn:
        # Resolve the error-file location: under ETL_ERROR_PATH if set, otherwise next to the CSV.
        error_path = self.context.env("ETL_ERROR_PATH")
        if error_path is None:
            error_file = self.file.with_suffix(".error.json")
        else:
            error_file = error_path / self.file.with_name(self.file.stem + ".error.json").name

        # Assemble the batch pipeline: CSV source -> Pydantic validation -> Cypher sink.
        csv = CSVBatchSource(self.file, self.context, self)
        validator = ValidationBatchProcessor(self.context, self, csv, self.model, error_file)
        cypher = CypherBatchSink(self.context, self, validator, self._query())
        end = ClosedLoopBatchProcessor(self.context, self, cypher)

        # Drive the pipeline with the configured batch size and report its statistics.
        result = next(end.get_batch(self.batch_size))
        return TaskReturn(True, result.statistics)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.file})"

    @abc.abstractmethod
    def _query(self):
        # Subclasses return the Cypher statement that consumes the $batch parameter.
        pass
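
A standalone sketch of the error-file resolution performed in `run_internal`, for readers who want to predict where rejected rows end up. The concrete paths below are illustrative assumptions, not values taken from the library; only the pathlib calls mirror the code above:

    from pathlib import Path

    csv_file = Path("/data/gtfs/stops.txt")

    # Case 1: ETL_ERROR_PATH is set (here assumed to be /var/etl/errors).
    error_root = Path("/var/etl/errors")
    error_file = error_root / csv_file.with_name(csv_file.stem + ".error.json").name
    print(error_file)                          # /var/etl/errors/stops.error.json

    # Case 2: ETL_ERROR_PATH is not set; the error file sits next to the CSV.
    print(csv_file.with_suffix(".error.json"))  # /data/gtfs/stops.error.json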