diff --git a/fastapi_filter_sqlalchemy/filter_sqlalchemy.py b/fastapi_filter_sqlalchemy/filter_sqlalchemy.py index 38e0efe..ba14aa5 100644 --- a/fastapi_filter_sqlalchemy/filter_sqlalchemy.py +++ b/fastapi_filter_sqlalchemy/filter_sqlalchemy.py @@ -4,7 +4,7 @@ from warnings import warn from pydantic import ValidationInfo, field_validator -from sqlalchemy import String, cast, func, or_ +from sqlalchemy import Integer, String, cast, func, or_ from sqlalchemy.orm import Query from sqlalchemy.sql.selectable import Select @@ -92,6 +92,8 @@ class Constants: prefix: str original_filter: type["Filter"] ordering_fk_fields_mapping: dict = {} + ordering_convert_str_to_int_fields: list[str] = [] + ordering_lower_case_fields: list[str] = [] class Direction(str, Enum): asc = "asc" @@ -209,11 +211,19 @@ def sort(self, query: Query | Select): direction = Filter.Direction.desc field_name = field_name.replace("-", "").replace("+", "") - order_by_field = getattr(self.Constants.model, field_name) + if field_name in self.Constants.ordering_convert_str_to_int_fields: + order_by_field = cast(order_by_field, Integer) + query = query.add_columns( + cast(getattr(self.Constants.model, field_name), Integer).label(f"{field_name}_integer_value") + ) + if field_name in self.Constants.ordering_lower_case_fields: + order_by_field = func.lower(order_by_field) + query = query.add_columns( + func.lower(getattr(self.Constants.model, field_name)).label(f"{field_name}_lower_case") + ) if field_name in self.Constants.ordering_fk_fields_mapping: model = self.Constants.ordering_fk_fields_mapping[field_name] - order_by_field = getattr(model, base_field_name) query = query.order_by(getattr(order_by_field, direction)()) diff --git a/pyproject.toml b/pyproject.toml index 73aab41..e3a19e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ ignore_missing_imports = true [tool.poetry] name = "fastapi-filter-sqlalchemy" -version = "0.0.3" +version = "0.0.4" description = "FastAPI filter SQLAlchemy" authors = ["Sergey V. Elfimov "] packages = [{include = "fastapi_filter_sqlalchemy"}] diff --git a/tests/conftest.py b/tests/conftest.py index 7403272..fb22efd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,6 +123,7 @@ class User(Base): # type: ignore[misc, valid-type] updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) name = Column(String) age = Column(Integer, nullable=False) + code = Column(String) address_id = Column(Integer, ForeignKey("addresses.id")) address: Mapped[Address] = relationship(Address, backref="users", lazy="joined") # type: ignore[valid-type] favorite_sports: Mapped[Sport] = relationship( # type: ignore[valid-type] @@ -178,24 +179,28 @@ async def users(session, User, Address): name=None, age=21, created_at=datetime(2021, 12, 1), + code="21", ), User( name="Mr Praline", age=33, created_at=datetime(2021, 12, 1), address=Address(street="22 rue Bellier", city="Nantes", country="France"), + code="11", ), User( name="The colonel", age=90, created_at=datetime(2021, 12, 2), address=Address(street="Wrench", city="Bathroom", country="Clue"), + code="11222", ), User( name="Mr Creosote", age=21, created_at=datetime(2021, 12, 3), address=Address(city="Nantes", country="France"), + code="321", ), User( name="Rabbit of Caerbannog", @@ -208,6 +213,7 @@ async def users(session, User, Address): age=50, created_at=datetime(2021, 12, 4), address=Address(street="4567 avenue", city="Denver", country="United States"), + code="89", ), ] session.add_all(user_instances) @@ -349,6 +355,8 @@ class Constants(Filter.Constants): # type: ignore[name-defined] model = User search_model_fields = ["name"] search_field_name = "search" + ordering_lower_case_fields = ["name"] + ordering_convert_str_to_int_fields = ["code"] def get_custom_filter(self, query, value): return query.filter( diff --git a/tests/test_order_by.py b/tests/test_order_by.py index 0adc0b4..db04be0 100644 --- a/tests/test_order_by.py +++ b/tests/test_order_by.py @@ -22,6 +22,12 @@ lambda previous_user, user: (previous_user.age < user.age) or (previous_user.age == user.age and previous_user.created_at >= user.created_at), ], + [ + ["code"], + lambda previous_user, user: int(previous_user.code) <= int(user.code) + if previous_user.code and user.code + else True, + ], ], ) @pytest.mark.asyncio