Source code for etl_lib.core.ETLContext

import logging
from typing import Any, Dict, List, NamedTuple

from neo4j.exceptions import Neo4jError

# Optional dependency: the Graph Data Science client is only used when it is installed.
try:
    from graphdatascience import GraphDataScience

    # The import succeeded, so the flag must be True; otherwise code guarded by
    # `gds_available` would never run even with the package installed.
    gds_available = True
except ImportError:
    gds_available = False
    logging.info("Graph Data Science not installed, skipping")
    GraphDataScience = None  # keep the name defined so later references do not fail

from neo4j import GraphDatabase, Session, WRITE_ACCESS, SummaryCounters

try:
    from sqlalchemy import create_engine
    from sqlalchemy.engine import Engine

    sqlalchemy_available = True
except ImportError:
    sqlalchemy_available = False
    logging.info("SQL Alchemy not installed, skipping")
    create_engine = None  # this and next line needed to prevent PyCharm warning
    Engine = None

from etl_lib.core.ProgressReporter import get_reporter


class QueryResult(NamedTuple):
    """
    Result of a query against the neo4j database.

    Attributes:
        data: Data as returned from the query.
        summery: Counters as reported by neo4j. Contains entries such as
            `nodes_created`, `nodes_deleted`, etc.
    """

    data: List[Any]
    summery: Dict[str, int]
def append_results(r1: QueryResult, r2: QueryResult) -> QueryResult:
    """
    Combine two QueryResult objects into a new one.

    The data lists are concatenated (`r1` first) and the summaries are merged;
    counters present in both summaries are added together.

    Args:
        r1: The first QueryResult object.
        r2: The second QueryResult object to append.

    Returns:
        A new QueryResult object with combined data and summed summary counts.
    """
    # Every counter of r2, increased by r1's value for the same key (0 if r1 lacks it).
    summed = {
        name: r1.summery.get(name, 0) + amount
        for name, amount in r2.summery.items()
    }
    # r1's keys come first and keep their order; keys only found in r2 follow.
    return QueryResult(data=r1.data + r2.data, summery={**r1.summery, **summed})
class Neo4jContext:
    """
    Holds the connection to the neo4j database and provides facilities to execute queries.
    """

    def __init__(self, env_vars: dict):
        """
        Create a new Neo4j context and connect to the database.

        Reads the following env_vars keys:

        - `NEO4J_URI`,
        - `NEO4J_USERNAME`,
        - `NEO4J_PASSWORD`,
        - `NEO4J_DATABASE`.

        Args:
            env_vars: dict holding the connection settings listed above.

        Raises:
            KeyError: if one of the keys above is missing from `env_vars`.
        """
        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        self.uri = env_vars["NEO4J_URI"]
        self.auth = (env_vars["NEO4J_USERNAME"], env_vars["NEO4J_PASSWORD"])
        self.database = env_vars["NEO4J_DATABASE"]
        self.__neo4j_connect()

    def query_database(self, session: Session, query, **kwargs) -> QueryResult:
        """
        Executes Cypher and returns (records, counters) with retryable write semantics.
        Accepts either a single query string or a list of queries.
        Does not work with CALL {} IN TRANSACTION queries.

        Args:
            session: Neo4j session the query is executed in.
            query: a single query, or a list of queries that are executed one
                after the other with the same parameters.
            **kwargs: passed to the query as parameters.

        Returns:
            The records and counters of the query. For a list of queries the
            results are combined via `append_results`; an empty list yields `None`.

        Raises:
            Neo4jError: logged and re-raised when the execution fails.
        """
        if isinstance(query, list):
            results = None
            for single in query:
                part = self.query_database(session, single, **kwargs)
                results = part if results is None else append_results(results, part)
            return results

        def _tx(tx, q, params):
            res = tx.run(q, **params)
            # Read all records before consume(), which returns the result summary.
            records = list(res)
            counters = res.consume().counters
            return records, counters

        try:
            records, counters = session.execute_write(_tx, query, kwargs)
        except Neo4jError as e:
            self.logger.error(e)
            raise
        return QueryResult(records, self.__counters_2_dict(counters))

    @staticmethod
    def __counters_2_dict(counters: SummaryCounters):
        """Copy the counters reported by neo4j into a plain dict keyed by counter name."""
        counter_names = (
            "constraints_added",
            "constraints_removed",
            "indexes_added",
            "indexes_removed",
            "labels_added",
            "labels_removed",
            "nodes_created",
            "nodes_deleted",
            "properties_set",
            "relationships_created",
            "relationships_deleted",
        )
        return {name: getattr(counters, name) for name in counter_names}

    def session(self, database=None):
        """
        Create a new Neo4j session in write mode, caller is responsible to close the session.

        Args:
            database: name of the database to use for this session. If not provided,
                the database name provided during construction will be used.

        Returns:
            newly created Neo4j session.
        """
        target = self.database if database is None else database
        return self.driver.session(database=target, default_access_mode=WRITE_ACCESS)

    def __neo4j_connect(self):
        """Create the driver, verify that the instance is reachable and log the connection."""
        self.driver = GraphDatabase.driver(uri=self.uri, auth=self.auth, notifications_min_severity="OFF")
        try:
            self.driver.verify_connectivity()
        except Exception:
            # The constructor is about to fail, so nobody else could close the
            # driver afterwards; release it here before propagating the error.
            self.driver.close()
            raise
        self.logger.info(
            f"driver connected to instance at {self.uri} with username {self.auth[0]} and database {self.database}")
def gds(neo4j_context) -> GraphDataScience:
    """
    Build a GraphDataScience client from an existing Neo4j context.

    Args:
        neo4j_context: Neo4j context containing driver and database name.

    Returns:
        gds client.
    """
    connection = {
        "driver": neo4j_context.driver,
        "database": neo4j_context.database,
    }
    return GraphDataScience.from_neo4j_driver(**connection)
if sqlalchemy_available:

    class SQLContext:
        """Holds the SQLAlchemy engine used to talk to a SQL database."""

        def __init__(self, database_url: str, pool_size: int = 10, max_overflow: int = 20):
            """
            Initializes the SQL context with an SQLAlchemy engine.

            Args:
                database_url (str): SQLAlchemy connection URL.
                pool_size (int): Number of connections to maintain in the pool.
                max_overflow (int): Additional connections allowed beyond pool_size.
            """
            # Turn on TCP keepalives on the client socket.
            # NOTE(review): these look like libpq-style options - confirm that the
            # database drivers in use accept them.
            keepalive_args = {
                "keepalives": 1,
                "keepalives_idle": 60,  # after 60s of idle
                "keepalives_interval": 10,  # probe every 10s
                "keepalives_count": 5,  # give up after 5 failed probes
            }
            self.engine: Engine = create_engine(
                database_url,
                pool_pre_ping=True,
                pool_size=pool_size,
                max_overflow=max_overflow,
                pool_recycle=1800,  # recycle connections older than 30m
                connect_args=keepalive_args,
            )
class ETLContext:
    """
    General context information.

    Will be passed to all :class:`~etl_lib.core.Task.Task` to provide access to
    environment variables and functionally deemed general enough that all parts
    of the ETL pipeline would need it.
    """

    def __init__(self, env_vars: dict):
        """
        Create a new ETLContext.

        The context created will contain an :class:`~Neo4jContext` and a
        :class:`~etl_lib.core.ProgressReporter.ProgressReporter`.
        See there for keys used from the provided `env_vars` dict.

        The `sql` attribute is only set when `SQLALCHEMY_URI` is present in
        `env_vars` and SQLAlchemy is installed; the `gds` attribute is only set
        when the Graph Data Science client is installed.

        Args:
            env_vars: Environment variables. Stored internally and can be
                accessed via :func:`~env` .
        """
        logger_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        self.logger = logging.getLogger(logger_name)
        self.neo4j = Neo4jContext(env_vars)
        self.__env_vars = env_vars
        self.reporter = get_reporter(self)

        # Optional integrations, created only when configured and installed.
        sql_uri = self.env("SQLALCHEMY_URI")
        if sqlalchemy_available and sql_uri is not None:
            self.sql = SQLContext(sql_uri)
        if gds_available:
            self.gds = gds(self.neo4j)

    def env(self, key: str) -> Any:
        """
        Look up an entry of the `env_vars` dict.

        Args:
            key: name of the entry to read.

        Returns:
            value of the entry, or None if the key is not in the dict.
        """
        return self.__env_vars.get(key)