diff --git a/internal_admin/admin/filters.py b/internal_admin/admin/filters.py index b780877..36d2c00 100644 --- a/internal_admin/admin/filters.py +++ b/internal_admin/admin/filters.py @@ -9,7 +9,9 @@ from typing import Any from sqlalchemy import Boolean, Date, DateTime +from sqlalchemy.inspection import inspect as sa_inspect from sqlalchemy.orm import Session +from sqlalchemy.sql.sqltypes import TypeDecorator from .model_admin import ModelAdmin @@ -204,9 +206,44 @@ def get_choices(self, session: Session, model_class: type[Any]) -> list[tuple[An if not hasattr(model_class, self.field_name): return [] - # This is simplified - in practice you'd need proper relationship introspection - # For now, return empty choices - return [] + try: + relationship = self._find_relationship(model_class) + if relationship is None: + return [] + + related_model = relationship.mapper.class_ + related_mapper = sa_inspect(related_model) + pk_attr = related_mapper.primary_key[0].key + + label_attr = self.display_field if hasattr(related_model, self.display_field) else None + if label_attr is None: + for candidate in ("display_name", "name", "title", "username", "email"): + if hasattr(related_model, candidate): + label_attr = candidate + break + + query = session.query(related_model) + if label_attr: + query = query.order_by(getattr(related_model, label_attr).asc()) + else: + query = query.order_by(getattr(related_model, pk_attr).asc()) + + rows = query.limit(200).all() + choices = [] + for row in rows: + value = getattr(row, pk_attr) + if label_attr: + label_value = getattr(row, label_attr, None) + display = str(label_value) if label_value not in (None, "") else str(value) + else: + display = str(row) + if display.startswith("<") and " object at " in display: + display = str(value) + choices.append((value, display)) + + return choices + except Exception: + return [] def apply_filter(self, query: Any, value: Any) -> Any: """Apply foreign key filter.""" @@ -216,8 +253,27 @@ def apply_filter(self, query: Any, value: Any) -> Any: model_class = query.column_descriptions[0]['type'] field = getattr(model_class, self.field_name) + column = model_class.__table__.columns.get(self.field_name) + if column is not None: + column_type = type(column.type) + if isinstance(column.type, TypeDecorator): + column_type = type(column.type.impl) + try: + if column_type.__name__ in {"Integer", "BigInteger", "SmallInteger"}: + value = int(value) + except (TypeError, ValueError): + return query + return query.filter(field == value) + def _find_relationship(self, model_class: type[Any]) -> Any | None: + mapper = sa_inspect(model_class) + for relationship in mapper.relationships: + for local_column in relationship.local_columns: + if local_column.key == self.field_name: + return relationship + return None + class FilterManager: """ diff --git a/internal_admin/admin/form_engine.py b/internal_admin/admin/form_engine.py index 2858579..1f3d628 100644 --- a/internal_admin/admin/form_engine.py +++ b/internal_admin/admin/form_engine.py @@ -10,9 +10,11 @@ from typing import Any from sqlalchemy import Boolean, Column, Date, DateTime, Float, Integer, String, Text +from sqlalchemy.inspection import inspect as sa_inspect from sqlalchemy.orm import Session from sqlalchemy.sql.sqltypes import TypeDecorator +from ..registry import get_registry from .model_admin import ModelAdmin @@ -53,6 +55,7 @@ def __init__(self, model_admin: ModelAdmin) -> None: self.model_admin = model_admin self.model = model_admin.model self._type_mapping = self._get_type_mapping() + self._foreign_key_choice_limit = 200 def generate_form_fields(self, session: Session, instance: Any | None = None) -> list[FormField]: """ @@ -198,24 +201,64 @@ def _get_foreign_key_choices(self, column: Column, session: Session) -> list[tup Returns: List of (value, label) tuples """ - choices = [("", "-- Select --")] + related_model = self._get_related_model_for_column(column) + if related_model is None: + return [] - # Get the referenced table and model - list(column.foreign_keys)[0] - - # Find the model class for the referenced table - # This is a simplified approach - in practice, you might need - # a more sophisticated model registry lookup try: - # Try to find model class by table name - # This requires models to be registered or discoverable + mapper = sa_inspect(related_model) + pk_attr = mapper.primary_key[0].key - # For now, skip foreign key choices - can be implemented later - # when we have better model discovery - return choices + label_attr = self._resolve_related_label_attr(related_model) + query = session.query(related_model) + if label_attr and hasattr(related_model, label_attr): + query = query.order_by(getattr(related_model, label_attr).asc()) + else: + query = query.order_by(getattr(related_model, pk_attr).asc()) + rows = query.limit(self._foreign_key_choice_limit).all() + return [ + (getattr(row, pk_attr), self._get_related_display_value(row, label_attr, pk_attr)) + for row in rows + ] except Exception: - return choices + return [] + + def _get_related_model_for_column(self, column: Column) -> type[Any] | None: + relationships = sa_inspect(self.model).relationships + for relationship in relationships: + if column in relationship.local_columns: + return relationship.mapper.class_ + + try: + foreign_key = next(iter(column.foreign_keys)) + except StopIteration: + return None + + referenced_table = foreign_key.column.table + for model_class in get_registry().get_registered_models().keys(): + if getattr(model_class, "__table__", None) is referenced_table: + return model_class + + return None + + def _resolve_related_label_attr(self, related_model: type[Any]) -> str | None: + preferred = ("display_name", "name", "title", "username", "email") + for attr_name in preferred: + if hasattr(related_model, attr_name): + return attr_name + return None + + def _get_related_display_value(self, row: Any, label_attr: str | None, pk_attr: str) -> str: + if label_attr: + value = getattr(row, label_attr, None) + if value not in (None, ""): + return str(value) + + value = str(row) + if value.startswith("<") and " object at " in value: + return str(getattr(row, pk_attr)) + return value def validate_form_data(self, form_data: dict[str, Any]) -> dict[str, Any]: """ diff --git a/internal_admin/admin/query_engine.py b/internal_admin/admin/query_engine.py index c16443b..97eab39 100644 --- a/internal_admin/admin/query_engine.py +++ b/internal_admin/admin/query_engine.py @@ -8,8 +8,10 @@ from typing import Any from sqlalchemy import or_ +from sqlalchemy import Boolean, Date, DateTime, Float, Integer from sqlalchemy.orm import Query, Session from sqlalchemy.orm.strategy_options import selectinload +from sqlalchemy.sql.sqltypes import TypeDecorator from .model_admin import ModelAdmin @@ -199,11 +201,17 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query: continue field = getattr(self.model, field_name) + column = self.model.__table__.columns.get(field_name) # Skip empty values if value is None or value == "": continue + try: + value = self._coerce_filter_value(column, value) + except ValueError: + continue + # Handle different filter types if isinstance(value, (list, tuple)): # Multiple values - use IN clause @@ -217,6 +225,42 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query: return query + def _coerce_filter_value(self, column: Any, value: Any) -> Any: + if column is None: + return value + + if isinstance(value, (list, tuple)): + return [self._coerce_filter_value(column, item) for item in value] + + column_type = type(column.type) + if isinstance(column.type, TypeDecorator): + column_type = type(column.type.impl) + + if column_type == Boolean: + if isinstance(value, bool): + return value + normalized = str(value).strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + raise ValueError("Invalid boolean filter value") + + if column_type == Integer: + return int(value) + if column_type == Float: + return float(value) + if column_type == Date: + if isinstance(value, str): + return Date().python_type.fromisoformat(value) + return value + if column_type == DateTime: + if isinstance(value, str): + return DateTime().python_type.fromisoformat(value.replace("T", " ")) + return value + + return value + def _apply_ordering(self, query: Query, ordering: list[str] | None) -> Query: """ Apply ordering to query. diff --git a/internal_admin/templates/admin/form.html b/internal_admin/templates/admin/form.html index 65ff16e..9f10095 100644 --- a/internal_admin/templates/admin/form.html +++ b/internal_admin/templates/admin/form.html @@ -85,7 +85,7 @@
{{ "Basic Information" if is_create else "Edit Detai {% if field.choices %} {% for value, display in field.choices %} {% endfor %} diff --git a/tests/conftest.py b/tests/conftest.py index 2d72455..fde9cc7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,9 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from sqlalchemy import create_engine, Column, Integer, String, Boolean +from sqlalchemy import create_engine, Column, Integer, String, Boolean, ForeignKey from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import sessionmaker, Session, relationship from internal_admin import AdminSite, AdminConfig, ModelAdmin from internal_admin.auth.models import AdminUser @@ -30,6 +30,10 @@ class TestUser(Base): is_active = Column(Boolean, default=True) is_superuser = Column(Boolean, default=False) + @property + def display_name(self) -> str: + return self.username or f"User {self.id}" + class TestModel(Base): """Simple test model for admin testing.""" @@ -41,6 +45,31 @@ class TestModel(Base): is_active = Column(Boolean, default=True) +class TestCategory(Base): + """Related model for foreign key tests.""" + __tablename__ = "test_categories" + + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + + products = relationship("TestProduct", back_populates="category") + + def __str__(self) -> str: + return self.name + + +class TestProduct(Base): + """Model containing a foreign key for admin tests.""" + __tablename__ = "test_products" + + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + category_id = Column(Integer, ForeignKey("test_categories.id"), nullable=False) + is_active = Column(Boolean, default=True) + + category = relationship("TestCategory", back_populates="products") + + class TestModelAdmin(ModelAdmin): """Test ModelAdmin configuration.""" list_display = ["id", "name", "is_active"] @@ -48,6 +77,19 @@ class TestModelAdmin(ModelAdmin): list_filter = ["is_active"] +class TestCategoryAdmin(ModelAdmin): + """Admin configuration for category model.""" + list_display = ["id", "name"] + search_fields = ["name"] + + +class TestProductAdmin(ModelAdmin): + """Admin configuration for product model.""" + list_display = ["id", "name", "category_id", "is_active"] + search_fields = ["name"] + list_filter = ["category_id", "is_active"] + + @pytest.fixture(scope="session") def test_db() -> Generator[str, None, None]: """Create a temporary test database.""" @@ -87,6 +129,8 @@ def admin_site(admin_config: AdminConfig) -> AdminSite: # Create fresh AdminSite after clearing registry site = AdminSite(admin_config) site.register(TestModel, TestModelAdmin) + site.register(TestCategory, TestCategoryAdmin) + site.register(TestProduct, TestProductAdmin) return site @@ -150,6 +194,29 @@ def test_objects(db_session: Session) -> list[TestModel]: return objects +@pytest.fixture +def fk_objects(db_session: Session) -> dict[str, Any]: + """Create related objects for foreign key tests.""" + categories = [ + TestCategory(name="Hardware"), + TestCategory(name="Software"), + ] + db_session.add_all(categories) + db_session.flush() + + products = [ + TestProduct(name="Keyboard", category_id=categories[0].id, is_active=True), + TestProduct(name="IDE License", category_id=categories[1].id, is_active=True), + ] + db_session.add_all(products) + db_session.commit() + + return { + "categories": categories, + "products": products, + } + + @pytest.fixture def authenticated_client(client: TestClient, test_user: TestUser) -> TestClient: """Create authenticated test client.""" diff --git a/tests/test_admin.py b/tests/test_admin.py index d27c286..007cf56 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -10,35 +10,33 @@ class TestAdminSite: """Test AdminSite functionality.""" - + def test_admin_site_creation(self, admin_config: AdminConfig): """Test AdminSite can be created.""" site = AdminSite(admin_config) assert site.config == admin_config assert not site._initialized - + def test_model_registration(self, admin_site: AdminSite): """Test model registration works.""" from tests.conftest import TestModel, TestUser assert admin_site.is_registered(TestModel) assert not admin_site.is_registered(TestUser) - + def test_get_registered_models(self, admin_site: AdminSite): """Test getting registered models.""" models = admin_site.get_registered_models() assert len(models) == 1 - # Check that TestModel is registered (by class name since import paths may differ) model_names = [model.__name__ for model in models.keys()] assert "TestModel" in model_names class TestAdminConfig: """Test AdminConfig functionality.""" - + def test_config_validation(self, test_db: str): """Test config validation.""" from tests.conftest import TestUser - # Valid config should work config = AdminConfig( database_url=test_db, secret_key="test-key", @@ -47,7 +45,7 @@ def test_config_validation(self, test_db: str): assert config.database_url == test_db assert config.is_sqlite assert not config.is_postgresql - + def test_invalid_config(self): """Test invalid config raises errors.""" from tests.conftest import TestUser @@ -57,7 +55,7 @@ def test_invalid_config(self): secret_key="test-key", user_model=TestUser ) - + with pytest.raises(ValueError, match="secret_key is required"): AdminConfig( database_url="sqlite:///test.db", @@ -68,47 +66,46 @@ def test_invalid_config(self): class TestModelAdmin: """Test ModelAdmin functionality.""" - + def test_model_admin_creation(self): """Test ModelAdmin can be created.""" from tests.conftest import TestModel admin = ModelAdmin(TestModel) assert admin.model == TestModel - + def test_list_display(self): """Test list_display configuration.""" from tests.conftest import TestModel admin = ModelAdmin(TestModel) - - # Should return default columns if not configured + display_fields = admin.get_list_display() assert "id" in display_fields assert "name" in display_fields - + def test_search_fields(self): """Test search_fields configuration.""" from tests.conftest import TestModel admin = ModelAdmin(TestModel) admin.search_fields = ["name"] - + search_fields = admin.get_search_fields() assert search_fields == ["name"] class TestAdminRoutes: """Test admin route functionality.""" - + def test_dashboard_requires_auth(self, client: TestClient): """Test dashboard requires authentication.""" response = client.get("/admin/") assert response.status_code == 401 - + def test_login_page(self, client: TestClient): """Test login page is accessible.""" response = client.get("/admin/login") assert response.status_code == 200 assert "login" in response.text.lower() - + def test_model_list_requires_auth(self, client: TestClient): """Test model list requires authentication.""" response = client.get("/admin/testmodel/") @@ -117,24 +114,24 @@ def test_model_list_requires_auth(self, client: TestClient): class TestAuthentication: """Test authentication functionality.""" - + def test_successful_login(self, client: TestClient, test_user): """Test successful login.""" response = client.post("/admin/login", data={ "username": test_user.username, "password": "testpass123" }) - assert response.status_code == 302 # Redirect after login - + assert response.status_code == 302 + def test_invalid_login(self, client: TestClient): """Test invalid login credentials.""" response = client.post("/admin/login", data={ "username": "nonexistent", "password": "wrongpass" }) - assert response.status_code == 302 # Redirect to login with error + assert response.status_code == 302 assert "error=invalid_credentials" in response.headers["location"] - + def test_authenticated_dashboard_access(self, authenticated_client: TestClient): """Test authenticated users can access dashboard.""" response = authenticated_client.get("/admin/") @@ -144,21 +141,21 @@ def test_authenticated_dashboard_access(self, authenticated_client: TestClient): class TestCRUDOperations: """Test CRUD operations via admin interface.""" - + def test_model_list_view(self, authenticated_client: TestClient, test_objects): """Test model list view.""" response = authenticated_client.get("/admin/testmodel/") assert response.status_code == 200 assert "Test 1" in response.text assert "Test 2" in response.text - + def test_create_form_view(self, authenticated_client: TestClient): """Test create form view.""" response = authenticated_client.get("/admin/testmodel/create/") assert response.status_code == 200 assert "form" in response.text.lower() assert "name" in response.text.lower() - + def test_create_object(self, authenticated_client: TestClient): """Test creating new object.""" response = authenticated_client.post("/admin/testmodel/create/", data={ @@ -166,14 +163,14 @@ def test_create_object(self, authenticated_client: TestClient): "description": "Created via test", "is_active": True }) - assert response.status_code == 302 # Redirect after create - + assert response.status_code == 302 + def test_edit_form_view(self, authenticated_client: TestClient, test_objects): """Test edit form view.""" response = authenticated_client.get(f"/admin/testmodel/{test_objects[0].id}/") assert response.status_code == 200 assert test_objects[0].name in response.text - + def test_delete_confirmation(self, authenticated_client: TestClient, test_objects): """Test delete confirmation page.""" response = authenticated_client.get(f"/admin/testmodel/{test_objects[0].id}/delete/") @@ -184,29 +181,64 @@ def test_delete_confirmation(self, authenticated_client: TestClient, test_object class TestSearchAndFiltering: """Test search and filtering functionality.""" - + def test_search(self, authenticated_client: TestClient, test_objects): """Test search functionality.""" response = authenticated_client.get("/admin/testmodel/?search=Test 1") assert response.status_code == 200 assert "Test 1" in response.text - + def test_filter(self, authenticated_client: TestClient, test_objects): """Test filtering functionality.""" response = authenticated_client.get("/admin/testmodel/?is_active=true") assert response.status_code == 200 - # Should show active objects but not inactive ones class TestPermissions: """Test permission system.""" - + def test_superuser_permissions(self, authenticated_client: TestClient): """Test superuser has all permissions.""" - # Should be able to access model list response = authenticated_client.get("/admin/testmodel/") assert response.status_code == 200 - - # Should be able to access create form + response = authenticated_client.get("/admin/testmodel/create/") - assert response.status_code == 200 \ No newline at end of file + assert response.status_code == 200 + + +class TestForeignKeys: + """Test foreign key functionality in forms and filters.""" + + def test_fk_field_renders_select_choices(self, authenticated_client: TestClient, fk_objects): + """Create form should render related model choices for FK fields.""" + response = authenticated_client.get("/admin/testproduct/create/") + assert response.status_code == 200 + assert "category_id" in response.text + assert "Hardware" in response.text + assert "Software" in response.text + + def test_fk_create_submission_persists_relation(self, authenticated_client: TestClient, db_session, fk_objects): + """Submitting FK value should be validated and persisted correctly.""" + from tests.conftest import TestProduct + + hardware_id = fk_objects["categories"][0].id + response = authenticated_client.post("/admin/testproduct/create/", data={ + "name": "Mouse", + "category_id": str(hardware_id), + "is_active": "true", + }) + + assert response.status_code == 302 + + created = db_session.query(TestProduct).filter(TestProduct.name == "Mouse").first() + assert created is not None + assert created.category_id == hardware_id + + def test_fk_filter_applies_correctly(self, authenticated_client: TestClient, fk_objects): + """FK list filter should match rows by related ID.""" + software_id = fk_objects["categories"][1].id + response = authenticated_client.get(f"/admin/testproduct/?category_id={software_id}") + + assert response.status_code == 200 + assert "IDE License" in response.text + assert "Keyboard" not in response.text