diff --git a/CHANGELOG.md b/CHANGELOG.md index a4d9cf7..063b0fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,4 +93,7 @@ - NaT further downstream # 0.3.22 -- made the base class from_dict constructor \ No newline at end of file +- made the base class from_dict constructor + +# 0.3.23 +- optimized bulk load via dynamic index/FK management & added FTS support with generated columns \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1c9fec5..3772c5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "orm-loader" -version = "0.3.22" +version = "0.3.23" description = "Generic base classes to handle ORM functionality for multiple downstream datamodels" readme = "README.md" authors = [ diff --git a/src/orm_loader/helpers/bulk.py b/src/orm_loader/helpers/bulk.py index 14d5147..4c3a40a 100644 --- a/src/orm_loader/helpers/bulk.py +++ b/src/orm_loader/helpers/bulk.py @@ -6,6 +6,58 @@ logger = get_logger(__name__) +def disable_fk_check(session: Session) -> str | int: + """Disables FK checks and returns the previous state.""" + engine = session.get_bind() + dialect = engine.dialect.name + previous_state = None + + if dialect == "postgresql": + previous_state = session.execute(text("SHOW session_replication_role")).scalar() + session.execute(text("SET session_replication_role = 'replica'")) + elif dialect == "sqlite": + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = OFF")) + else: + raise NotImplementedError(f"FK disable not implemented for {dialect}") + + logger.info("Disabled foreign key checks for bulk load.") + assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + return previous_state + +def enable_fk_check(session: Session) -> str | int: + """Explicitly enables FK checks and returns the previous state.""" + engine = session.get_bind() + dialect = engine.dialect.name + previous_state = None + + if dialect == "postgresql": + previous_state = session.execute(text("SHOW session_replication_role")).scalar() + session.execute(text("SET session_replication_role = 'origin'")) + elif dialect == "sqlite": + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = ON")) + else: + raise NotImplementedError(f"FK enable not implemented for {dialect}") + + logger.info("Explicitly re-enabled foreign key checks.") + assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + return previous_state + +def restore_fk_check(session: Session, previous_state: str | int): + """Restores FK checks to a specifically provided previous state.""" + engine = session.get_bind() + dialect = engine.dialect.name + + if dialect == "postgresql": + session.execute(text(f"SET session_replication_role = '{previous_state}'")) + elif dialect == "sqlite": + session.execute(text(f"PRAGMA foreign_keys = {previous_state}")) + else: + raise NotImplementedError(f"FK restore not implemented for {dialect}") + + logger.info(f"Restored foreign key checks to state: {previous_state}") + @contextmanager def bulk_load_context( session: Session, @@ -13,22 +65,10 @@ def bulk_load_context( disable_fk: bool = True, no_autoflush: bool = True, ): - engine = session.get_bind() - dialect = engine.dialect.name - fk_disabled = False - + previous_fk_state = None try: if disable_fk: - if dialect == "postgresql": - session.execute(text( - "SET session_replication_role = replica" - )) - fk_disabled = True - elif dialect == "sqlite": - session.execute(text("PRAGMA foreign_keys = OFF")) - fk_disabled = True - - logger.info("Disabled foreign key checks for bulk load") + previous_fk_state = disable_fk_check(session) if no_autoflush: with session.no_autoflush: @@ -41,15 +81,9 @@ def bulk_load_context( raise finally: - if fk_disabled: - if dialect == "postgresql": - session.execute(text( - "SET session_replication_role = DEFAULT" - )) - elif dialect == "sqlite": - session.execute(text("PRAGMA foreign_keys = ON")) - - logger.info("Re-enabled foreign key checks after bulk load") + if previous_fk_state is not None: + restore_fk_check(session, previous_fk_state) + @contextmanager def engine_with_replica_role(engine: Engine): diff --git a/src/orm_loader/helpers/null_handlers.py b/src/orm_loader/helpers/null_handlers.py index 007a28c..09a5464 100644 --- a/src/orm_loader/helpers/null_handlers.py +++ b/src/orm_loader/helpers/null_handlers.py @@ -1,4 +1,3 @@ -import math import pandas as pd from typing import Any diff --git a/src/orm_loader/loaders/loader_interface.py b/src/orm_loader/loaders/loader_interface.py index aaaf2e5..825b816 100644 --- a/src/orm_loader/loaders/loader_interface.py +++ b/src/orm_loader/loaders/loader_interface.py @@ -122,7 +122,6 @@ def orm_file_load(cls, ctx: LoaderContext) -> int: If chunksize is None, pandas returns a single DataFrame, which we normalise to a one-element iterator for unified processing. """ - total = 0 delimiter = infer_delim(ctx.path) encoding = infer_encoding(ctx.path)['encoding'] @@ -144,11 +143,10 @@ def orm_file_load(cls, ctx: LoaderContext) -> int: logger.info(f"Loading with chunksize '{ctx.chunksize}' for file {ctx.path.name}") chunks = (reader,) if isinstance(reader, pd.DataFrame) else reader - i = 0 - for chunk in chunks: + total = 0 + for i, chunk in enumerate(chunks): + logger.debug(f"Processing chunk {i} with {len(chunk)} rows for {ctx.tableclass.__tablename__}") chunk = _normalise_columns(chunk) - logger.info(f"Processing chunk {i} with {len(chunk)} rows for {ctx.tableclass.__tablename__}") - i += 1 if ctx.dedupe: chunk = cls.dedupe(chunk, ctx) if ctx.normalise: @@ -158,7 +156,6 @@ def orm_file_load(cls, ctx: LoaderContext) -> int: session=ctx.session, dataframe=chunk ) - return total class ParquetLoader(LoaderInterface): diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index c7b9d14..d352344 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -4,11 +4,13 @@ from typing import Type, ClassVar, Optional from pathlib import Path +from contextlib import contextmanager from .orm_table import ORMTableBase from .typing import CSVTableProtocol from ..loaders.loader_interface import LoaderInterface, LoaderContext, PandasLoader, ParquetLoader from ..loaders.loading_helpers import quick_load_pg +from ..helpers.bulk import restore_fk_check, disable_fk_check logger = logging.getLogger(__name__) @@ -92,9 +94,7 @@ def create_staging_table( If the database dialect is unsupported. """ table = cls.__table__ - session.execute(sa.text(f""" - DROP TABLE IF EXISTS "{cls.staging_tablename()}"; - """)) + session.execute(sa.text(f"""DROP TABLE IF EXISTS "{cls.staging_tablename()}";""")) if session.bind is None: raise RuntimeError("Session is not bound to an engine") @@ -102,10 +102,17 @@ def create_staging_table( dialect = session.bind.dialect.name if dialect == "postgresql": + logger.info("Disabling indices on staging table for performance") session.execute(sa.text(f''' CREATE UNLOGGED TABLE "{cls.staging_tablename()}" - (LIKE "{table.name}" INCLUDING ALL); + (LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS); ''')) + + # Need to drop the columns we are not going to load into, otherwise the COPY will fail + computed_cols = [c.name for c in table.columns if c.computed is not None] + for col in computed_cols: + session.execute(sa.text(f'ALTER TABLE "{cls.staging_tablename()}" DROP COLUMN "{col}";')) + elif dialect == "sqlite": metadata = sa.MetaData() @@ -146,6 +153,58 @@ def create_staging_table( session.commit() + @classmethod + @contextmanager + def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session): + """ + Temporarily drops non-primary key indices before a bulk operation + and recreates them afterwards to prevent write amplification. + """ + indices = list(cls.__table__.indexes) + inspector = sa.inspect(session.bind) + assert inspector is not None, "Failed to create inspector for index management" + + if indices: + existing_in_db = {idx['name'] for idx in inspector.get_indexes(cls.__tablename__)} + to_drop = [i for i in indices if i.name in existing_in_db] + + if to_drop: + logger.info(f"Dropping {len(to_drop)} active indices...") + for idx in to_drop: + session.execute(sa.schema.DropIndex(idx)) + session.commit() + + # session.commit() above restores the original state of the session. We need that one after we are done + previous_fk_state = disable_fk_check(session) + + try: + yield + session.commit() + + except Exception as e: + session.rollback() + logger.error(f"Table `{cls.__tablename__}`: Merge operation failed - {e}") + raise + finally: + restore_fk_check(session, previous_fk_state) + + if indices: + logger.info(f"Table `{cls.__tablename__}`: Verifying/Rebuilding indices.") + inspector.clear_cache() # Required to ensure we get the current state of the database after potential changes + existing_idx_names = {idx['name'] for idx in inspector.get_indexes(cls.__tablename__)} + + for idx in indices: + if idx.name not in existing_idx_names: + try: + logger.info(f"Restoring missing index: {idx.name}") + session.execute(sa.schema.CreateIndex(idx)) + session.commit() + except Exception as e: + session.rollback() + logger.error(f"Failed to restore {idx.name}: {e}") + else: + logger.debug(f"Index {idx.name} actually exists on disk. Skipping.") + @classmethod def get_staging_table( @@ -174,7 +233,7 @@ def get_staging_table( staging_name = cls.staging_tablename() if not inspector.has_table(staging_name): - logger.warning(f"Staging table {staging_name} does not exist; recreating",) + logger.debug(f"Staging table {staging_name} does not exist; recreating",) cls.create_staging_table(session) return sa.Table( @@ -225,8 +284,10 @@ def load_staging( ) return total except Exception as e: + loader_context.session.rollback() logger.warning(f"COPY failed for {cls.staging_tablename()}: {e}") logger.info('Falling back to ORM-based load functionality') + total = cls.orm_staging_load( loader=loader, loader_context=loader_context @@ -318,7 +379,7 @@ def load_csv( Number of rows loaded. """ - logger.debug(f"Loading CSV for {cls.__tablename__} via staging table from {path}") + logger.debug(f"Table `{cls.__tablename__}`: Loading CSV from {path}") if path.stem.lower() != cls.__tablename__: raise ValueError( @@ -337,11 +398,19 @@ def load_csv( if loader is None: loader = cls._select_loader(path) + + # Load to staging (Indices are already excluded via updated create_staging_table) + logger.info(f"Table `{cls.__tablename__}`: Loading data into unlogged staging table") total = cls.load_staging(loader=loader, loader_context=loader_context) - cls.merge_from_staging(session, merge_strategy=merge_strategy) + # Merge staging to target (Wrapped in our index dropper!) + logger.info(f"Table `{cls.__tablename__}`: Merging staging data into target table") + with cls.manage_indices(session): + cls.merge_from_staging(session, merge_strategy=merge_strategy) + cls.drop_staging_table(session) + logger.info(f"Table `{cls.__tablename__}`: Successfully finished ingestion. Total rows: {total}") return total @@ -402,9 +471,13 @@ def _merge_insert( """ Insert all rows from the staging table into the target table. """ + # Get all columns that are NOT computed + insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute(sa.text(f""" - INSERT INTO "{target}" - SELECT * FROM "{staging}"; + INSERT INTO "{target}" ({cols_str}) + SELECT {cols_str} FROM "{staging}"; """)) @@ -420,18 +493,23 @@ def _merge_upsert( """ Merge staging data using an upsert strategy. """ + + # Get all columns that are NOT computed + insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + if dialect == "postgresql": # INSERT … ON CONFLICT DO NOTHING session.execute(sa.text(f""" - INSERT INTO "{target}" - SELECT * FROM "{staging}" + INSERT INTO "{target}" ({cols_str}) + SELECT {cols_str} FROM "{staging}" ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)}) DO NOTHING; """)) elif dialect == "sqlite": session.execute(sa.text(f""" - INSERT OR IGNORE INTO "{target}" - SELECT * FROM "{staging}"; + INSERT OR IGNORE INTO "{target}" ({cols_str}) + SELECT {cols_str} FROM "{staging}"; """)) else: @@ -507,5 +585,6 @@ def csv_columns(cls) -> dict[str, sa.ColumnElement]: dict[str, sqlalchemy.ColumnElement] Mapping of input column names to SQLAlchemy columns. """ - return cls.model_columns() - \ No newline at end of file + cols = cls.model_columns() + computed_names = {c.name for c in cls.__table__.columns if c.computed is not None} # type: ignore + return {k: v for k, v in cols.items() if k not in computed_names} \ No newline at end of file diff --git a/src/orm_loader/tables/typing.py b/src/orm_loader/tables/typing.py index 0be646e..b38357b 100644 --- a/src/orm_loader/tables/typing.py +++ b/src/orm_loader/tables/typing.py @@ -2,6 +2,7 @@ import sqlalchemy.orm as so import sqlalchemy as sa from pathlib import Path +from contextlib import AbstractContextManager if TYPE_CHECKING: from ..loaders import LoaderContext, LoaderInterface @@ -101,6 +102,10 @@ def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols: @classmethod def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ... + + @classmethod + def manage_indices(cls, session: so.Session) -> AbstractContextManager[None]: + ... @runtime_checkable diff --git a/tests/conftest.py b/tests/conftest.py index d2735c6..e12e9a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,6 @@ import pytest import sqlalchemy as sa import sqlalchemy.orm as so -import os -from orm_loader.tables import CSVLoadableTableInterface import time from tests.models import Base diff --git a/tests/loaders/test_loader_e2e.py b/tests/loaders/test_loader_e2e.py index 9567b3d..697d601 100644 --- a/tests/loaders/test_loader_e2e.py +++ b/tests/loaders/test_loader_e2e.py @@ -379,7 +379,7 @@ class TextTable(Base, CSVLoadableTableInterface): '1\t"hello\nworld"\n' ) - inserted = TextTable.load_csv( # type: ignore + TextTable.load_csv( # type: ignore session, csv, loader=PandasLoader(), @@ -406,7 +406,7 @@ class TextTable2(Base, CSVLoadableTableInterface): '1\t"foo\tbar"\n' ) - inserted = TextTable2.load_csv( # type: ignore + TextTable2.load_csv( # type: ignore session, csv, loader=PandasLoader(), diff --git a/tests/loaders/test_pandas_loader.py b/tests/loaders/test_pandas_loader.py index 0312890..aee46d0 100644 --- a/tests/loaders/test_pandas_loader.py +++ b/tests/loaders/test_pandas_loader.py @@ -1,6 +1,5 @@ from pathlib import Path import sqlalchemy as sa -import sqlalchemy.orm as so from orm_loader.loaders.data_classes import LoaderContext import pandas as pd from orm_loader.loaders.loader_interface import PandasLoader diff --git a/tests/loaders/test_pg_loader.py b/tests/loaders/test_pg_loader.py index 1ed0f89..32228fd 100644 --- a/tests/loaders/test_pg_loader.py +++ b/tests/loaders/test_pg_loader.py @@ -1,17 +1,10 @@ import sqlalchemy as sa -import sqlalchemy.orm as so -from sqlalchemy.orm import Session -from pathlib import Path import pandas as pd import pytest -from orm_loader.loaders.data_classes import _clean_nulls -from orm_loader.tables.loadable_table import CSVLoadableTableInterface -from orm_loader.loaders.loader_interface import PandasLoader from orm_loader.loaders.loading_helpers import infer_encoding, infer_delim, check_line_ending, quick_load_pg -from tests.models import Base, SimpleTable, RequiredTable, CompositeTable +from tests.models import SimpleTable -import numpy as np @pytest.mark.postgres def test_copy_and_orm_path_equivalence(pg_session, tmp_path): @@ -24,7 +17,7 @@ def test_copy_and_orm_path_equivalence(pg_session, tmp_path): ] ).to_csv(csv, index=False, sep="\t") - inserted = SimpleTable.load_csv(pg_session, csv) + SimpleTable.load_csv(pg_session, csv) pg_session.commit() rows = pg_session.execute(sa.select(SimpleTable).order_by(SimpleTable.id)).scalars().all() @@ -298,7 +291,7 @@ def test_copy_fails_with_raw_carriage_returns_but_succeeds_after_normalisation(p try: quick_load_pg(path=csv, session=pg_session, tablename="test_table") pg_session.commit() - except Exception as e: + except Exception: failed = True # Clean up table for second attempt diff --git a/tests/pg_db.py b/tests/pg_db.py index f116bdf..e383f9d 100644 --- a/tests/pg_db.py +++ b/tests/pg_db.py @@ -1,9 +1,7 @@ -import os import time import pytest import sqlalchemy as sa from sqlalchemy.orm import sessionmaker -from sqlalchemy_utils import create_database, database_exists, drop_database from tests.models import Base