Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions fastapi_filter_sqlalchemy/filter_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from warnings import warn

from pydantic import ValidationInfo, field_validator
from sqlalchemy import String, cast, func, or_
from sqlalchemy import Integer, String, cast, func, or_
from sqlalchemy.orm import Query
from sqlalchemy.sql.selectable import Select

Expand Down Expand Up @@ -92,6 +92,8 @@ class Constants:
prefix: str
original_filter: type["Filter"]
ordering_fk_fields_mapping: dict = {}
ordering_convert_str_to_int_fields: list[str] = []
ordering_lower_case_fields: list[str] = []

class Direction(str, Enum):
asc = "asc"
Expand Down Expand Up @@ -209,11 +211,19 @@ def sort(self, query: Query | Select):
direction = Filter.Direction.desc

field_name = field_name.replace("-", "").replace("+", "")

order_by_field = getattr(self.Constants.model, field_name)
if field_name in self.Constants.ordering_convert_str_to_int_fields:
order_by_field = cast(order_by_field, Integer)
query = query.add_columns(
cast(getattr(self.Constants.model, field_name), Integer).label(f"{field_name}_integer_value")
)
if field_name in self.Constants.ordering_lower_case_fields:
order_by_field = func.lower(order_by_field)
query = query.add_columns(
func.lower(getattr(self.Constants.model, field_name)).label(f"{field_name}_lower_case")
)
if field_name in self.Constants.ordering_fk_fields_mapping:
model = self.Constants.ordering_fk_fields_mapping[field_name]

order_by_field = getattr(model, base_field_name)
query = query.order_by(getattr(order_by_field, direction)())

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ignore_missing_imports = true

[tool.poetry]
name = "fastapi-filter-sqlalchemy"
version = "0.0.3"
version = "0.0.4"
description = "FastAPI filter SQLAlchemy"
authors = ["Sergey V. Elfimov <elfimovserg@gmail.com>"]
packages = [{include = "fastapi_filter_sqlalchemy"}]
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class User(Base): # type: ignore[misc, valid-type]
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
name = Column(String)
age = Column(Integer, nullable=False)
code = Column(String)
address_id = Column(Integer, ForeignKey("addresses.id"))
address: Mapped[Address] = relationship(Address, backref="users", lazy="joined") # type: ignore[valid-type]
favorite_sports: Mapped[Sport] = relationship( # type: ignore[valid-type]
Expand Down Expand Up @@ -178,24 +179,28 @@ async def users(session, User, Address):
name=None,
age=21,
created_at=datetime(2021, 12, 1),
code="21",
),
User(
name="Mr Praline",
age=33,
created_at=datetime(2021, 12, 1),
address=Address(street="22 rue Bellier", city="Nantes", country="France"),
code="11",
),
User(
name="The colonel",
age=90,
created_at=datetime(2021, 12, 2),
address=Address(street="Wrench", city="Bathroom", country="Clue"),
code="11222",
),
User(
name="Mr Creosote",
age=21,
created_at=datetime(2021, 12, 3),
address=Address(city="Nantes", country="France"),
code="321",
),
User(
name="Rabbit of Caerbannog",
Expand All @@ -208,6 +213,7 @@ async def users(session, User, Address):
age=50,
created_at=datetime(2021, 12, 4),
address=Address(street="4567 avenue", city="Denver", country="United States"),
code="89",
),
]
session.add_all(user_instances)
Expand Down Expand Up @@ -349,6 +355,8 @@ class Constants(Filter.Constants): # type: ignore[name-defined]
model = User
search_model_fields = ["name"]
search_field_name = "search"
ordering_lower_case_fields = ["name"]
ordering_convert_str_to_int_fields = ["code"]

def get_custom_filter(self, query, value):
return query.filter(
Expand Down
6 changes: 6 additions & 0 deletions tests/test_order_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
lambda previous_user, user: (previous_user.age < user.age)
or (previous_user.age == user.age and previous_user.created_at >= user.created_at),
],
[
["code"],
lambda previous_user, user: int(previous_user.code) <= int(user.code)
if previous_user.code and user.code
else True,
],
],
)
@pytest.mark.asyncio
Expand Down
Loading