diff --git a/internal_admin/admin/filters.py b/internal_admin/admin/filters.py
index b780877..36d2c00 100644
--- a/internal_admin/admin/filters.py
+++ b/internal_admin/admin/filters.py
@@ -9,7 +9,9 @@
from typing import Any
from sqlalchemy import Boolean, Date, DateTime
+from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import Session
+from sqlalchemy.sql.sqltypes import TypeDecorator
from .model_admin import ModelAdmin
@@ -204,9 +206,44 @@ def get_choices(self, session: Session, model_class: type[Any]) -> list[tuple[An
if not hasattr(model_class, self.field_name):
return []
- # This is simplified - in practice you'd need proper relationship introspection
- # For now, return empty choices
- return []
+ try:
+ relationship = self._find_relationship(model_class)
+ if relationship is None:
+ return []
+
+ related_model = relationship.mapper.class_
+ related_mapper = sa_inspect(related_model)
+ pk_attr = related_mapper.primary_key[0].key
+
+ label_attr = self.display_field if hasattr(related_model, self.display_field) else None
+ if label_attr is None:
+ for candidate in ("display_name", "name", "title", "username", "email"):
+ if hasattr(related_model, candidate):
+ label_attr = candidate
+ break
+
+ query = session.query(related_model)
+ if label_attr:
+ query = query.order_by(getattr(related_model, label_attr).asc())
+ else:
+ query = query.order_by(getattr(related_model, pk_attr).asc())
+
+ rows = query.limit(200).all()
+ choices = []
+ for row in rows:
+ value = getattr(row, pk_attr)
+ if label_attr:
+ label_value = getattr(row, label_attr, None)
+ display = str(label_value) if label_value not in (None, "") else str(value)
+ else:
+ display = str(row)
+ if display.startswith("<") and " object at " in display:
+ display = str(value)
+ choices.append((value, display))
+
+ return choices
+ except Exception:
+ return []
def apply_filter(self, query: Any, value: Any) -> Any:
"""Apply foreign key filter."""
@@ -216,8 +253,27 @@ def apply_filter(self, query: Any, value: Any) -> Any:
model_class = query.column_descriptions[0]['type']
field = getattr(model_class, self.field_name)
+ column = model_class.__table__.columns.get(self.field_name)
+ if column is not None:
+ column_type = type(column.type)
+ if isinstance(column.type, TypeDecorator):
+ column_type = type(column.type.impl)
+ try:
+ if column_type.__name__ in {"Integer", "BigInteger", "SmallInteger"}:
+ value = int(value)
+ except (TypeError, ValueError):
+ return query
+
return query.filter(field == value)
+ def _find_relationship(self, model_class: type[Any]) -> Any | None:
+ mapper = sa_inspect(model_class)
+ for relationship in mapper.relationships:
+ for local_column in relationship.local_columns:
+ if local_column.key == self.field_name:
+ return relationship
+ return None
+
class FilterManager:
"""
diff --git a/internal_admin/admin/form_engine.py b/internal_admin/admin/form_engine.py
index 2858579..1f3d628 100644
--- a/internal_admin/admin/form_engine.py
+++ b/internal_admin/admin/form_engine.py
@@ -10,9 +10,11 @@
from typing import Any
from sqlalchemy import Boolean, Column, Date, DateTime, Float, Integer, String, Text
+from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import Session
from sqlalchemy.sql.sqltypes import TypeDecorator
+from ..registry import get_registry
from .model_admin import ModelAdmin
@@ -53,6 +55,7 @@ def __init__(self, model_admin: ModelAdmin) -> None:
self.model_admin = model_admin
self.model = model_admin.model
self._type_mapping = self._get_type_mapping()
+ self._foreign_key_choice_limit = 200
def generate_form_fields(self, session: Session, instance: Any | None = None) -> list[FormField]:
"""
@@ -198,24 +201,64 @@ def _get_foreign_key_choices(self, column: Column, session: Session) -> list[tup
Returns:
List of (value, label) tuples
"""
- choices = [("", "-- Select --")]
+ related_model = self._get_related_model_for_column(column)
+ if related_model is None:
+ return []
- # Get the referenced table and model
- list(column.foreign_keys)[0]
-
- # Find the model class for the referenced table
- # This is a simplified approach - in practice, you might need
- # a more sophisticated model registry lookup
try:
- # Try to find model class by table name
- # This requires models to be registered or discoverable
+ mapper = sa_inspect(related_model)
+ pk_attr = mapper.primary_key[0].key
- # For now, skip foreign key choices - can be implemented later
- # when we have better model discovery
- return choices
+ label_attr = self._resolve_related_label_attr(related_model)
+ query = session.query(related_model)
+ if label_attr and hasattr(related_model, label_attr):
+ query = query.order_by(getattr(related_model, label_attr).asc())
+ else:
+ query = query.order_by(getattr(related_model, pk_attr).asc())
+ rows = query.limit(self._foreign_key_choice_limit).all()
+ return [
+ (getattr(row, pk_attr), self._get_related_display_value(row, label_attr, pk_attr))
+ for row in rows
+ ]
except Exception:
- return choices
+ return []
+
+ def _get_related_model_for_column(self, column: Column) -> type[Any] | None:
+ relationships = sa_inspect(self.model).relationships
+ for relationship in relationships:
+ if column in relationship.local_columns:
+ return relationship.mapper.class_
+
+ try:
+ foreign_key = next(iter(column.foreign_keys))
+ except StopIteration:
+ return None
+
+ referenced_table = foreign_key.column.table
+ for model_class in get_registry().get_registered_models().keys():
+ if getattr(model_class, "__table__", None) is referenced_table:
+ return model_class
+
+ return None
+
+ def _resolve_related_label_attr(self, related_model: type[Any]) -> str | None:
+ preferred = ("display_name", "name", "title", "username", "email")
+ for attr_name in preferred:
+ if hasattr(related_model, attr_name):
+ return attr_name
+ return None
+
+ def _get_related_display_value(self, row: Any, label_attr: str | None, pk_attr: str) -> str:
+ if label_attr:
+ value = getattr(row, label_attr, None)
+ if value not in (None, ""):
+ return str(value)
+
+ value = str(row)
+ if value.startswith("<") and " object at " in value:
+ return str(getattr(row, pk_attr))
+ return value
def validate_form_data(self, form_data: dict[str, Any]) -> dict[str, Any]:
"""
diff --git a/internal_admin/admin/query_engine.py b/internal_admin/admin/query_engine.py
index c16443b..97eab39 100644
--- a/internal_admin/admin/query_engine.py
+++ b/internal_admin/admin/query_engine.py
@@ -8,8 +8,10 @@
from typing import Any
from sqlalchemy import or_
+from sqlalchemy import Boolean, Date, DateTime, Float, Integer
from sqlalchemy.orm import Query, Session
from sqlalchemy.orm.strategy_options import selectinload
+from sqlalchemy.sql.sqltypes import TypeDecorator
from .model_admin import ModelAdmin
@@ -199,11 +201,17 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query:
continue
field = getattr(self.model, field_name)
+ column = self.model.__table__.columns.get(field_name)
# Skip empty values
if value is None or value == "":
continue
+ try:
+ value = self._coerce_filter_value(column, value)
+ except ValueError:
+ continue
+
# Handle different filter types
if isinstance(value, (list, tuple)):
# Multiple values - use IN clause
@@ -217,6 +225,42 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query:
return query
+ def _coerce_filter_value(self, column: Any, value: Any) -> Any:
+ if column is None:
+ return value
+
+ if isinstance(value, (list, tuple)):
+ return [self._coerce_filter_value(column, item) for item in value]
+
+ column_type = type(column.type)
+ if isinstance(column.type, TypeDecorator):
+ column_type = type(column.type.impl)
+
+ if column_type == Boolean:
+ if isinstance(value, bool):
+ return value
+ normalized = str(value).strip().lower()
+ if normalized in {"true", "1", "yes", "on"}:
+ return True
+ if normalized in {"false", "0", "no", "off"}:
+ return False
+ raise ValueError("Invalid boolean filter value")
+
+ if column_type == Integer:
+ return int(value)
+ if column_type == Float:
+ return float(value)
+ if column_type == Date:
+ if isinstance(value, str):
+ return Date().python_type.fromisoformat(value)
+ return value
+ if column_type == DateTime:
+ if isinstance(value, str):
+ return DateTime().python_type.fromisoformat(value.replace("T", " "))
+ return value
+
+ return value
+
def _apply_ordering(self, query: Query, ordering: list[str] | None) -> Query:
"""
Apply ordering to query.
diff --git a/internal_admin/templates/admin/form.html b/internal_admin/templates/admin/form.html
index 65ff16e..9f10095 100644
--- a/internal_admin/templates/admin/form.html
+++ b/internal_admin/templates/admin/form.html
@@ -85,7 +85,7 @@
{{ "Basic Information" if is_create else "Edit Detai
{% if field.choices %}
{% for value, display in field.choices %}
{% endfor %}
diff --git a/tests/conftest.py b/tests/conftest.py
index 2d72455..fde9cc7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,9 +8,9 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
-from sqlalchemy import create_engine, Column, Integer, String, Boolean
+from sqlalchemy import create_engine, Column, Integer, String, Boolean, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, Session
+from sqlalchemy.orm import sessionmaker, Session, relationship
from internal_admin import AdminSite, AdminConfig, ModelAdmin
from internal_admin.auth.models import AdminUser
@@ -30,6 +30,10 @@ class TestUser(Base):
is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False)
+ @property
+ def display_name(self) -> str:
+ return self.username or f"User {self.id}"
+
class TestModel(Base):
"""Simple test model for admin testing."""
@@ -41,6 +45,31 @@ class TestModel(Base):
is_active = Column(Boolean, default=True)
+class TestCategory(Base):
+ """Related model for foreign key tests."""
+ __tablename__ = "test_categories"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ products = relationship("TestProduct", back_populates="category")
+
+ def __str__(self) -> str:
+ return self.name
+
+
+class TestProduct(Base):
+ """Model containing a foreign key for admin tests."""
+ __tablename__ = "test_products"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+ category_id = Column(Integer, ForeignKey("test_categories.id"), nullable=False)
+ is_active = Column(Boolean, default=True)
+
+ category = relationship("TestCategory", back_populates="products")
+
+
class TestModelAdmin(ModelAdmin):
"""Test ModelAdmin configuration."""
list_display = ["id", "name", "is_active"]
@@ -48,6 +77,19 @@ class TestModelAdmin(ModelAdmin):
list_filter = ["is_active"]
+class TestCategoryAdmin(ModelAdmin):
+ """Admin configuration for category model."""
+ list_display = ["id", "name"]
+ search_fields = ["name"]
+
+
+class TestProductAdmin(ModelAdmin):
+ """Admin configuration for product model."""
+ list_display = ["id", "name", "category_id", "is_active"]
+ search_fields = ["name"]
+ list_filter = ["category_id", "is_active"]
+
+
@pytest.fixture(scope="session")
def test_db() -> Generator[str, None, None]:
"""Create a temporary test database."""
@@ -87,6 +129,8 @@ def admin_site(admin_config: AdminConfig) -> AdminSite:
# Create fresh AdminSite after clearing registry
site = AdminSite(admin_config)
site.register(TestModel, TestModelAdmin)
+ site.register(TestCategory, TestCategoryAdmin)
+ site.register(TestProduct, TestProductAdmin)
return site
@@ -150,6 +194,29 @@ def test_objects(db_session: Session) -> list[TestModel]:
return objects
+@pytest.fixture
+def fk_objects(db_session: Session) -> dict[str, Any]:
+ """Create related objects for foreign key tests."""
+ categories = [
+ TestCategory(name="Hardware"),
+ TestCategory(name="Software"),
+ ]
+ db_session.add_all(categories)
+ db_session.flush()
+
+ products = [
+ TestProduct(name="Keyboard", category_id=categories[0].id, is_active=True),
+ TestProduct(name="IDE License", category_id=categories[1].id, is_active=True),
+ ]
+ db_session.add_all(products)
+ db_session.commit()
+
+ return {
+ "categories": categories,
+ "products": products,
+ }
+
+
@pytest.fixture
def authenticated_client(client: TestClient, test_user: TestUser) -> TestClient:
"""Create authenticated test client."""
diff --git a/tests/test_admin.py b/tests/test_admin.py
index d27c286..007cf56 100644
--- a/tests/test_admin.py
+++ b/tests/test_admin.py
@@ -10,35 +10,33 @@
class TestAdminSite:
"""Test AdminSite functionality."""
-
+
def test_admin_site_creation(self, admin_config: AdminConfig):
"""Test AdminSite can be created."""
site = AdminSite(admin_config)
assert site.config == admin_config
assert not site._initialized
-
+
def test_model_registration(self, admin_site: AdminSite):
"""Test model registration works."""
from tests.conftest import TestModel, TestUser
assert admin_site.is_registered(TestModel)
assert not admin_site.is_registered(TestUser)
-
+
def test_get_registered_models(self, admin_site: AdminSite):
"""Test getting registered models."""
models = admin_site.get_registered_models()
assert len(models) == 1
- # Check that TestModel is registered (by class name since import paths may differ)
model_names = [model.__name__ for model in models.keys()]
assert "TestModel" in model_names
class TestAdminConfig:
"""Test AdminConfig functionality."""
-
+
def test_config_validation(self, test_db: str):
"""Test config validation."""
from tests.conftest import TestUser
- # Valid config should work
config = AdminConfig(
database_url=test_db,
secret_key="test-key",
@@ -47,7 +45,7 @@ def test_config_validation(self, test_db: str):
assert config.database_url == test_db
assert config.is_sqlite
assert not config.is_postgresql
-
+
def test_invalid_config(self):
"""Test invalid config raises errors."""
from tests.conftest import TestUser
@@ -57,7 +55,7 @@ def test_invalid_config(self):
secret_key="test-key",
user_model=TestUser
)
-
+
with pytest.raises(ValueError, match="secret_key is required"):
AdminConfig(
database_url="sqlite:///test.db",
@@ -68,47 +66,46 @@ def test_invalid_config(self):
class TestModelAdmin:
"""Test ModelAdmin functionality."""
-
+
def test_model_admin_creation(self):
"""Test ModelAdmin can be created."""
from tests.conftest import TestModel
admin = ModelAdmin(TestModel)
assert admin.model == TestModel
-
+
def test_list_display(self):
"""Test list_display configuration."""
from tests.conftest import TestModel
admin = ModelAdmin(TestModel)
-
- # Should return default columns if not configured
+
display_fields = admin.get_list_display()
assert "id" in display_fields
assert "name" in display_fields
-
+
def test_search_fields(self):
"""Test search_fields configuration."""
from tests.conftest import TestModel
admin = ModelAdmin(TestModel)
admin.search_fields = ["name"]
-
+
search_fields = admin.get_search_fields()
assert search_fields == ["name"]
class TestAdminRoutes:
"""Test admin route functionality."""
-
+
def test_dashboard_requires_auth(self, client: TestClient):
"""Test dashboard requires authentication."""
response = client.get("/admin/")
assert response.status_code == 401
-
+
def test_login_page(self, client: TestClient):
"""Test login page is accessible."""
response = client.get("/admin/login")
assert response.status_code == 200
assert "login" in response.text.lower()
-
+
def test_model_list_requires_auth(self, client: TestClient):
"""Test model list requires authentication."""
response = client.get("/admin/testmodel/")
@@ -117,24 +114,24 @@ def test_model_list_requires_auth(self, client: TestClient):
class TestAuthentication:
"""Test authentication functionality."""
-
+
def test_successful_login(self, client: TestClient, test_user):
"""Test successful login."""
response = client.post("/admin/login", data={
"username": test_user.username,
"password": "testpass123"
})
- assert response.status_code == 302 # Redirect after login
-
+ assert response.status_code == 302
+
def test_invalid_login(self, client: TestClient):
"""Test invalid login credentials."""
response = client.post("/admin/login", data={
"username": "nonexistent",
"password": "wrongpass"
})
- assert response.status_code == 302 # Redirect to login with error
+ assert response.status_code == 302
assert "error=invalid_credentials" in response.headers["location"]
-
+
def test_authenticated_dashboard_access(self, authenticated_client: TestClient):
"""Test authenticated users can access dashboard."""
response = authenticated_client.get("/admin/")
@@ -144,21 +141,21 @@ def test_authenticated_dashboard_access(self, authenticated_client: TestClient):
class TestCRUDOperations:
"""Test CRUD operations via admin interface."""
-
+
def test_model_list_view(self, authenticated_client: TestClient, test_objects):
"""Test model list view."""
response = authenticated_client.get("/admin/testmodel/")
assert response.status_code == 200
assert "Test 1" in response.text
assert "Test 2" in response.text
-
+
def test_create_form_view(self, authenticated_client: TestClient):
"""Test create form view."""
response = authenticated_client.get("/admin/testmodel/create/")
assert response.status_code == 200
assert "form" in response.text.lower()
assert "name" in response.text.lower()
-
+
def test_create_object(self, authenticated_client: TestClient):
"""Test creating new object."""
response = authenticated_client.post("/admin/testmodel/create/", data={
@@ -166,14 +163,14 @@ def test_create_object(self, authenticated_client: TestClient):
"description": "Created via test",
"is_active": True
})
- assert response.status_code == 302 # Redirect after create
-
+ assert response.status_code == 302
+
def test_edit_form_view(self, authenticated_client: TestClient, test_objects):
"""Test edit form view."""
response = authenticated_client.get(f"/admin/testmodel/{test_objects[0].id}/")
assert response.status_code == 200
assert test_objects[0].name in response.text
-
+
def test_delete_confirmation(self, authenticated_client: TestClient, test_objects):
"""Test delete confirmation page."""
response = authenticated_client.get(f"/admin/testmodel/{test_objects[0].id}/delete/")
@@ -184,29 +181,64 @@ def test_delete_confirmation(self, authenticated_client: TestClient, test_object
class TestSearchAndFiltering:
"""Test search and filtering functionality."""
-
+
def test_search(self, authenticated_client: TestClient, test_objects):
"""Test search functionality."""
response = authenticated_client.get("/admin/testmodel/?search=Test 1")
assert response.status_code == 200
assert "Test 1" in response.text
-
+
def test_filter(self, authenticated_client: TestClient, test_objects):
"""Test filtering functionality."""
response = authenticated_client.get("/admin/testmodel/?is_active=true")
assert response.status_code == 200
- # Should show active objects but not inactive ones
class TestPermissions:
"""Test permission system."""
-
+
def test_superuser_permissions(self, authenticated_client: TestClient):
"""Test superuser has all permissions."""
- # Should be able to access model list
response = authenticated_client.get("/admin/testmodel/")
assert response.status_code == 200
-
- # Should be able to access create form
+
response = authenticated_client.get("/admin/testmodel/create/")
- assert response.status_code == 200
\ No newline at end of file
+ assert response.status_code == 200
+
+
+class TestForeignKeys:
+ """Test foreign key functionality in forms and filters."""
+
+ def test_fk_field_renders_select_choices(self, authenticated_client: TestClient, fk_objects):
+ """Create form should render related model choices for FK fields."""
+ response = authenticated_client.get("/admin/testproduct/create/")
+ assert response.status_code == 200
+ assert "category_id" in response.text
+ assert "Hardware" in response.text
+ assert "Software" in response.text
+
+ def test_fk_create_submission_persists_relation(self, authenticated_client: TestClient, db_session, fk_objects):
+ """Submitting FK value should be validated and persisted correctly."""
+ from tests.conftest import TestProduct
+
+ hardware_id = fk_objects["categories"][0].id
+ response = authenticated_client.post("/admin/testproduct/create/", data={
+ "name": "Mouse",
+ "category_id": str(hardware_id),
+ "is_active": "true",
+ })
+
+ assert response.status_code == 302
+
+ created = db_session.query(TestProduct).filter(TestProduct.name == "Mouse").first()
+ assert created is not None
+ assert created.category_id == hardware_id
+
+ def test_fk_filter_applies_correctly(self, authenticated_client: TestClient, fk_objects):
+ """FK list filter should match rows by related ID."""
+ software_id = fk_objects["categories"][1].id
+ response = authenticated_client.get(f"/admin/testproduct/?category_id={software_id}")
+
+ assert response.status_code == 200
+ assert "IDE License" in response.text
+ assert "Keyboard" not in response.text