Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,7 @@
- NaT further downstream

# 0.3.22
- made the base class from_dict constructor

# 0.3.23
- optimized bulk load via dynamic index/FK management & added FTS support with generated columns
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "orm-loader"
version = "0.3.22"
version = "0.3.23"
description = "Generic base classes to handle ORM functionality for multiple downstream datamodels"
readme = "README.md"
authors = [
Expand Down
80 changes: 57 additions & 23 deletions src/orm_loader/helpers/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,69 @@

logger = get_logger(__name__)

def disable_fk_check(session: Session) -> str | int:
    """Disable foreign-key enforcement and return the previous state.

    The returned value is intended to be passed to ``restore_fk_check`` so the
    database is put back exactly as it was found.

    Parameters
    ----------
    session : Session
        Active SQLAlchemy session bound to a PostgreSQL or SQLite engine.

    Returns
    -------
    str | int
        Prior state: ``session_replication_role`` (str) on PostgreSQL, the
        ``PRAGMA foreign_keys`` value (int) on SQLite.

    Raises
    ------
    NotImplementedError
        If the bound dialect is neither PostgreSQL nor SQLite.
    RuntimeError
        If the previous state could not be read back from the database.
    """
    engine = session.get_bind()
    dialect = engine.dialect.name
    previous_state = None

    if dialect == "postgresql":
        previous_state = session.execute(text("SHOW session_replication_role")).scalar()
        # 'replica' makes triggers/FK constraints inert for this session.
        session.execute(text("SET session_replication_role = 'replica'"))
    elif dialect == "sqlite":
        previous_state = session.execute(text("PRAGMA foreign_keys")).scalar()
        session.execute(text("PRAGMA foreign_keys = OFF"))
    else:
        raise NotImplementedError(f"FK disable not implemented for {dialect}")

    logger.info("Disabled foreign key checks for bulk load.")
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a None/unexpected state flow silently into
    # restore_fk_check later.
    if not isinstance(previous_state, (str, int)):
        raise RuntimeError(
            f"Expected previous FK state to be str or int, got {previous_state!r}"
        )
    return previous_state

def enable_fk_check(session: Session) -> str | int:
    """Explicitly enable foreign-key enforcement and return the previous state.

    Counterpart to ``disable_fk_check``; the returned value can be handed to
    ``restore_fk_check`` to undo this call.

    Parameters
    ----------
    session : Session
        Active SQLAlchemy session bound to a PostgreSQL or SQLite engine.

    Returns
    -------
    str | int
        Prior state: ``session_replication_role`` (str) on PostgreSQL, the
        ``PRAGMA foreign_keys`` value (int) on SQLite.

    Raises
    ------
    NotImplementedError
        If the bound dialect is neither PostgreSQL nor SQLite.
    RuntimeError
        If the previous state could not be read back from the database.
    """
    engine = session.get_bind()
    dialect = engine.dialect.name
    previous_state = None

    if dialect == "postgresql":
        previous_state = session.execute(text("SHOW session_replication_role")).scalar()
        # 'origin' is the default role in which triggers and FK checks fire.
        session.execute(text("SET session_replication_role = 'origin'"))
    elif dialect == "sqlite":
        previous_state = session.execute(text("PRAGMA foreign_keys")).scalar()
        session.execute(text("PRAGMA foreign_keys = ON"))
    else:
        raise NotImplementedError(f"FK enable not implemented for {dialect}")

    logger.info("Explicitly re-enabled foreign key checks.")
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O` and must not be used for runtime validation.
    if not isinstance(previous_state, (str, int)):
        raise RuntimeError(
            f"Expected previous FK state to be str or int, got {previous_state!r}"
        )
    return previous_state

def restore_fk_check(session: Session, previous_state: str | int):
    """Restore FK enforcement to a previously captured state.

    Parameters
    ----------
    session : Session
        Active SQLAlchemy session bound to a PostgreSQL or SQLite engine.
    previous_state : str | int
        A value previously returned by ``disable_fk_check`` or
        ``enable_fk_check``.

    Raises
    ------
    ValueError
        If ``previous_state`` is not a legal value for the dialect. This also
        guards the f-string interpolation below against arbitrary SQL.
    NotImplementedError
        If the bound dialect is neither PostgreSQL nor SQLite.
    """
    engine = session.get_bind()
    dialect = engine.dialect.name

    if dialect == "postgresql":
        # session_replication_role only accepts these three values; validating
        # here means the interpolated statement can never carry injected SQL.
        if previous_state not in ("origin", "replica", "local"):
            raise ValueError(f"Invalid session_replication_role: {previous_state!r}")
        session.execute(text(f"SET session_replication_role = '{previous_state}'"))
    elif dialect == "sqlite":
        # PRAGMA foreign_keys reports 0/1; accept the common boolean spellings.
        if str(previous_state) not in ("0", "1", "ON", "OFF"):
            raise ValueError(f"Invalid PRAGMA foreign_keys state: {previous_state!r}")
        session.execute(text(f"PRAGMA foreign_keys = {previous_state}"))
    else:
        raise NotImplementedError(f"FK restore not implemented for {dialect}")

    logger.info(f"Restored foreign key checks to state: {previous_state}")

@contextmanager
def bulk_load_context(
session: Session,
*,
disable_fk: bool = True,
no_autoflush: bool = True,
):
engine = session.get_bind()
dialect = engine.dialect.name
fk_disabled = False

previous_fk_state = None
try:
if disable_fk:
if dialect == "postgresql":
session.execute(text(
"SET session_replication_role = replica"
))
fk_disabled = True
elif dialect == "sqlite":
session.execute(text("PRAGMA foreign_keys = OFF"))
fk_disabled = True

logger.info("Disabled foreign key checks for bulk load")
previous_fk_state = disable_fk_check(session)

if no_autoflush:
with session.no_autoflush:
Expand All @@ -41,15 +81,9 @@ def bulk_load_context(
raise

finally:
if fk_disabled:
if dialect == "postgresql":
session.execute(text(
"SET session_replication_role = DEFAULT"
))
elif dialect == "sqlite":
session.execute(text("PRAGMA foreign_keys = ON"))

logger.info("Re-enabled foreign key checks after bulk load")
if previous_fk_state is not None:
restore_fk_check(session, previous_fk_state)


@contextmanager
def engine_with_replica_role(engine: Engine):
Expand Down
1 change: 0 additions & 1 deletion src/orm_loader/helpers/null_handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import pandas as pd
from typing import Any

Expand Down
9 changes: 3 additions & 6 deletions src/orm_loader/loaders/loader_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def orm_file_load(cls, ctx: LoaderContext) -> int:
If chunksize is None, pandas returns a single DataFrame, which we
normalise to a one-element iterator for unified processing.
"""
total = 0

delimiter = infer_delim(ctx.path)
encoding = infer_encoding(ctx.path)['encoding']
Expand All @@ -144,11 +143,10 @@ def orm_file_load(cls, ctx: LoaderContext) -> int:
logger.info(f"Loading with chunksize '{ctx.chunksize}' for file {ctx.path.name}")
chunks = (reader,) if isinstance(reader, pd.DataFrame) else reader

i = 0
for chunk in chunks:
total = 0
for i, chunk in enumerate(chunks):
logger.debug(f"Processing chunk {i} with {len(chunk)} rows for {ctx.tableclass.__tablename__}")
chunk = _normalise_columns(chunk)
logger.info(f"Processing chunk {i} with {len(chunk)} rows for {ctx.tableclass.__tablename__}")
i += 1
if ctx.dedupe:
chunk = cls.dedupe(chunk, ctx)
if ctx.normalise:
Expand All @@ -158,7 +156,6 @@ def orm_file_load(cls, ctx: LoaderContext) -> int:
session=ctx.session,
dataframe=chunk
)

return total

class ParquetLoader(LoaderInterface):
Expand Down
109 changes: 94 additions & 15 deletions src/orm_loader/tables/loadable_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from typing import Type, ClassVar, Optional
from pathlib import Path
from contextlib import contextmanager

from .orm_table import ORMTableBase
from .typing import CSVTableProtocol
from ..loaders.loader_interface import LoaderInterface, LoaderContext, PandasLoader, ParquetLoader
from ..loaders.loading_helpers import quick_load_pg
from ..helpers.bulk import restore_fk_check, disable_fk_check

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,20 +94,25 @@ def create_staging_table(
If the database dialect is unsupported.
"""
table = cls.__table__
session.execute(sa.text(f"""
DROP TABLE IF EXISTS "{cls.staging_tablename()}";
"""))
session.execute(sa.text(f"""DROP TABLE IF EXISTS "{cls.staging_tablename()}";"""))

if session.bind is None:
raise RuntimeError("Session is not bound to an engine")

dialect = session.bind.dialect.name

if dialect == "postgresql":
logger.info("Disabling indices on staging table for performance")
session.execute(sa.text(f'''
CREATE UNLOGGED TABLE "{cls.staging_tablename()}"
(LIKE "{table.name}" INCLUDING ALL);
(LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS);
'''))

# Need to drop the columns we are not going to load into, otherwise the COPY will fail
computed_cols = [c.name for c in table.columns if c.computed is not None]
for col in computed_cols:
session.execute(sa.text(f'ALTER TABLE "{cls.staging_tablename()}" DROP COLUMN "{col}";'))

elif dialect == "sqlite":

metadata = sa.MetaData()
Expand Down Expand Up @@ -146,6 +153,58 @@ def create_staging_table(

session.commit()

    @classmethod
    @contextmanager
    def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session):
        """
        Temporarily drops non-primary key indices before a bulk operation
        and recreates them afterwards to prevent write amplification.

        Flow:
          1. Drop every index declared on ``cls.__table__`` that actually
             exists in the database, then commit.
          2. Disable FK checks (capturing the prior state) for the duration
             of the wrapped block.
          3. On success, commit the wrapped block's work; on failure, roll
             back and re-raise.
          4. Always restore the captured FK state and recreate any declared
             index that is missing from the database, committing (or rolling
             back) each recreation individually so one failure does not block
             the rest.

        Parameters
        ----------
        session : so.Session
            Active session bound to the engine that owns the table.

        Raises
        ------
        Exception
            Whatever the wrapped block raised, after rollback and cleanup.
        """
        # Indices declared on the ORM table (may or may not exist on disk).
        indices = list(cls.__table__.indexes)
        inspector = sa.inspect(session.bind)
        assert inspector is not None, "Failed to create inspector for index management"

        if indices:
            # Only drop indices that actually exist in the database;
            # DropIndex on a missing index would error.
            existing_in_db = {idx['name'] for idx in inspector.get_indexes(cls.__tablename__)}
            to_drop = [i for i in indices if i.name in existing_in_db]

            if to_drop:
                logger.info(f"Dropping {len(to_drop)} active indices...")
                for idx in to_drop:
                    session.execute(sa.schema.DropIndex(idx))
                session.commit()

        # session.commit() above restores the original state of the session. We need that one after we are done
        previous_fk_state = disable_fk_check(session)

        try:
            yield
            session.commit()

        except Exception as e:
            session.rollback()
            logger.error(f"Table `{cls.__tablename__}`: Merge operation failed - {e}")
            raise
        finally:
            # Restore FK checks first so the rebuilt state matches what we
            # captured before the bulk operation began.
            restore_fk_check(session, previous_fk_state)

            if indices:
                logger.info(f"Table `{cls.__tablename__}`: Verifying/Rebuilding indices.")
                inspector.clear_cache() # Required to ensure we get the current state of the database after potential changes
                existing_idx_names = {idx['name'] for idx in inspector.get_indexes(cls.__tablename__)}

                for idx in indices:
                    if idx.name not in existing_idx_names:
                        try:
                            logger.info(f"Restoring missing index: {idx.name}")
                            session.execute(sa.schema.CreateIndex(idx))
                            # Commit per index so one failed CreateIndex does
                            # not discard the others.
                            session.commit()
                        except Exception as e:
                            session.rollback()
                            logger.error(f"Failed to restore {idx.name}: {e}")
                    else:
                        logger.debug(f"Index {idx.name} actually exists on disk. Skipping.")


@classmethod
def get_staging_table(
Expand Down Expand Up @@ -174,7 +233,7 @@ def get_staging_table(
staging_name = cls.staging_tablename()

if not inspector.has_table(staging_name):
logger.warning(f"Staging table {staging_name} does not exist; recreating",)
logger.debug(f"Staging table {staging_name} does not exist; recreating",)
cls.create_staging_table(session)

return sa.Table(
Expand Down Expand Up @@ -225,8 +284,10 @@ def load_staging(
)
return total
except Exception as e:
loader_context.session.rollback()
logger.warning(f"COPY failed for {cls.staging_tablename()}: {e}")
logger.info('Falling back to ORM-based load functionality')

total = cls.orm_staging_load(
loader=loader,
loader_context=loader_context
Expand Down Expand Up @@ -318,7 +379,7 @@ def load_csv(
Number of rows loaded.
"""

logger.debug(f"Loading CSV for {cls.__tablename__} via staging table from {path}")
logger.debug(f"Table `{cls.__tablename__}`: Loading CSV from {path}")

if path.stem.lower() != cls.__tablename__:
raise ValueError(
Expand All @@ -337,11 +398,19 @@ def load_csv(

if loader is None:
loader = cls._select_loader(path)

# Load to staging (Indices are already excluded via updated create_staging_table)
logger.info(f"Table `{cls.__tablename__}`: Loading data into unlogged staging table")
total = cls.load_staging(loader=loader, loader_context=loader_context)

cls.merge_from_staging(session, merge_strategy=merge_strategy)
# Merge staging to target (Wrapped in our index dropper!)
logger.info(f"Table `{cls.__tablename__}`: Merging staging data into target table")
with cls.manage_indices(session):
cls.merge_from_staging(session, merge_strategy=merge_strategy)

cls.drop_staging_table(session)

logger.info(f"Table `{cls.__tablename__}`: Successfully finished ingestion. Total rows: {total}")
return total


Expand Down Expand Up @@ -402,9 +471,13 @@ def _merge_insert(
"""
Insert all rows from the staging table into the target table.
"""
# Get all columns that are NOT computed
insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None]
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)

session.execute(sa.text(f"""
INSERT INTO "{target}"
SELECT * FROM "{staging}";
INSERT INTO "{target}" ({cols_str})
SELECT {cols_str} FROM "{staging}";
"""))


Expand All @@ -420,18 +493,23 @@ def _merge_upsert(
"""
Merge staging data using an upsert strategy.
"""

# Get all columns that are NOT computed
insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None]
cols_str = ", ".join(f'"{c}"' for c in insertable_cols)

if dialect == "postgresql":
# INSERT … ON CONFLICT DO NOTHING
session.execute(sa.text(f"""
INSERT INTO "{target}"
SELECT * FROM "{staging}"
INSERT INTO "{target}" ({cols_str})
SELECT {cols_str} FROM "{staging}"
ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)}) DO NOTHING;
"""))

elif dialect == "sqlite":
session.execute(sa.text(f"""
INSERT OR IGNORE INTO "{target}"
SELECT * FROM "{staging}";
INSERT OR IGNORE INTO "{target}" ({cols_str})
SELECT {cols_str} FROM "{staging}";
"""))

else:
Expand Down Expand Up @@ -507,5 +585,6 @@ def csv_columns(cls) -> dict[str, sa.ColumnElement]:
dict[str, sqlalchemy.ColumnElement]
Mapping of input column names to SQLAlchemy columns.
"""
return cls.model_columns()

cols = cls.model_columns()
computed_names = {c.name for c in cls.__table__.columns if c.computed is not None} # type: ignore
return {k: v for k, v in cols.items() if k not in computed_names}
5 changes: 5 additions & 0 deletions src/orm_loader/tables/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sqlalchemy.orm as so
import sqlalchemy as sa
from pathlib import Path
from contextlib import AbstractContextManager
if TYPE_CHECKING:
from ..loaders import LoaderContext, LoaderInterface

Expand Down Expand Up @@ -101,6 +102,10 @@ def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols:

@classmethod
def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ...

    @classmethod
    def manage_indices(cls, session: so.Session) -> AbstractContextManager[None]:
        """Protocol stub: context manager that drops the table's non-PK
        indices for the duration of the block and rebuilds them afterwards."""
        ...


@runtime_checkable
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
import sqlalchemy as sa
import sqlalchemy.orm as so
import os
from orm_loader.tables import CSVLoadableTableInterface
import time

from tests.models import Base
Expand Down
Loading
Loading