
# @pytest.mark.asyncio
# async def test_bulk_delete_own_documents(cape):
#     """Test that users can bulk delete their own documents using delete statement"""
#     async with cape.get_session(AuthContext(subject="alice")) as session:
#         # Delete all documents owned by alice
#         stmt = delete(SecureDocument).where(SecureDocument.owner_id == "alice")
#         await session.execute(stmt)
#         await session.commit()
        
#         # Verify alice's documents are deleted
#         results = await session.execute(
#             select(SecureDocument).where(SecureDocument.owner_id == "alice")
#         )
#         alice_docs = results.scalars().all()
#         assert len(alice_docs) == 0
        
#         # Verify bob's documents still exist
#         results = await session.execute(
#             select(SecureDocument).where(SecureDocument.owner_id == "bob")
#         )
#         bob_docs = results.scalars().all()
#         assert len(bob_docs) == 2

# @pytest.mark.asyncio
# async def test_bulk_delete_permission_denied(cape):
#     """Test that users cannot bulk delete documents they don't own"""
#     async with cape.get_session(AuthContext(subject="bob")) as session:
#         # Attempt to delete alice's documents
#         stmt = delete(SecureDocument).where(SecureDocument.owner_id == "alice")
#         await session.execute(stmt)
#         await session.commit()
        
#         # Verify alice's documents were not deleted
#         results = await session.execute(
#             select(SecureDocument).where(SecureDocument.owner_id == "alice")
#         )
#         alice_docs = results.scalars().all()
#         assert len(alice_docs) == 2

@pytest.mark.asyncio
async def test_bulk_insert_with_statement(cape):
    """A multi-row ``insert(...).values([...])`` stamps every row with the caller as owner."""
    alice = AuthContext(subject="alice", context={"org_id": "org1"})
    # Rows deliberately omit owner_id; ownership must come from the auth context.
    rows = [
        {"title": f"Bulk Doc {n}", "content": f"Bulk Content {n}", "org_id": "org1"}
        for n in range(3)
    ]

    async with cape.get_session(alice) as session:
        await session.execute(insert(SecureDocument).values(rows))
        await session.commit()

        # Read the rows back and confirm ownership/org were applied to each one.
        query = select(SecureDocument).where(SecureDocument.title.like("Bulk Doc%"))
        inserted = (await session.execute(query)).scalars().all()

        assert len(inserted) == 3
        for document in inserted:
            assert document.owner_id == "alice"
            assert document.org_id == "org1"
            assert "Bulk Content" in document.content

# @pytest.mark.asyncio
# async def test_bulk_insert_permission_denied(cape):
#     """Test bulk insert fails when trying to set unauthorized ownership"""
#     async with cape.get_session(AuthContext(subject="alice", context={"org_id": "org1"})) as session:
#         with pytest.raises(PermissionDeniedError):
#             stmt = insert(SecureDocument).values([
#                 {
#                     "title": f"Bulk Doc {i}",
#                     "content": f"Bulk Content {i}",
#                     "owner_id": "bob",  # Trying to create docs owned by bob
#                     "org_id": "org1"
#                 }
#                 for i in range(3)
#             ])
#             await session.execute(stmt)
#             await session.commit()

# @pytest.mark.asyncio
# async def test_complex_update_query(cape):
#     """Test complex update query with multiple conditions"""
#     async with cape.get_session(AuthContext(subject="alice", context={"org_id": "org1"})) as session:
#         # Update documents that match multiple conditions
#         stmt = (
#             update(SecureDocument)
#             .where(
#                 SecureDocument.owner_id == "alice",
#                 SecureDocument.org_id == "org1",
#                 SecureDocument.title.like("Doc%")
#             )
#             .values(content="Complex update")
#         )
#         await session.execute(stmt)
#         await session.commit()
        
#         # Verify only matching documents were updated
#         results = await session.execute(
#             select(SecureDocument)
#             .where(SecureDocument.content == "Complex update")
#         )
#         updated_docs = results.scalars().all()
        
#         assert len(updated_docs) == 1  # Should only update Doc 1
#         assert all(doc.owner_id == "alice" for doc in updated_docs)
#         assert all(doc.org_id == "org1" for doc in updated_docs)

# @pytest.mark.asyncio
# async def test_join_query_permissions(cape):
#     """Test that permissions are properly enforced in join queries"""
#     # Create a related model for testing joins
#     class RelatedDoc(SQLModel, table=True):
#         __table_args__ = {'extend_existing': True}
#         id: Optional[int] = Field(default=None, primary_key=True)
#         doc_id: int = Field(foreign_key="securedocument.id")
#         note: str
    
#     # Set up permissions for RelatedDoc
#     cape.permission_required(RelatedDoc, role="*", actions=["read"], context_fields=["doc_id"])
    
#     async with cape.get_session(AuthContext(subject="alice", context={"org_id": "org1"})) as session:
#         # Create some related documents
#         doc = await session.execute(
#             select(SecureDocument).where(SecureDocument.owner_id == "alice").limit(1)
#         )
#         doc = doc.scalar_one()
        
#         related = RelatedDoc(doc_id=doc.id, note="Test note")
#         session.add(related)
#         await session.commit()
        
#         # Test join query
#         results = await session.execute(
#             select(SecureDocument, RelatedDoc)
#             .join(RelatedDoc, SecureDocument.id == RelatedDoc.doc_id)
#         )
#         joined_results = results.all()
        
#         assert len(joined_results) == 1
#         assert joined_results[0][0].owner_id == "alice"
#         assert joined_results[0][1].note == "Test note"


# Ways to define policy

    def policy(self, policy_config: dict):
        """
        Return a class decorator that registers declarative row-level-security policies.

        ``policy_config`` maps an action name (e.g. ``"read"``, ``"delete"``) to a
        dict of options. For every action one ``RLSConfig`` is built and passed to
        ``self.row_level_security.register_model``:

        - ``"role"``: role the rule applies to (default ``"*"``).
        - ``"owner_field"``: forwarded as ``owner_field`` (default ``None``).
        - ``"context_fields"`` or, for compatibility, ``"context_field"``: a single
          field name or a list of field names, forwarded as ``context_fields``
          (default ``[]``). If both keys are present, ``"context_fields"`` wins.

        The decorated class itself is returned unchanged.

        Example:
        @cape.policy({
            "read": {"role": "*"},
            "create": {"role": "*"},
            "update": {"role": "*"},
            "delete": {
                "role": "*",
                "context_fields": ["owner_id"]
            },
        })
        class Todo(SQLModel, table=True):
            ...
        """
        def decorator(cls: Type[SQLModel]) -> Type[SQLModel]:
            for action, config in policy_config.items():
                role = config.get("role", "*")
                owner_field = config.get("owner_field")

                # Accept both spellings: ``context_fields`` matches the RLSConfig
                # argument name, ``context_field`` is the key this decorator
                # originally documented. Previously the plural form was silently ignored.
                raw_fields = config.get("context_fields", config.get("context_field"))
                if raw_fields is None:
                    context_fields = []
                elif isinstance(raw_fields, str):
                    # Wrap a bare field name; presumably RLSConfig iterates this
                    # value, and a str would be walked character by character.
                    context_fields = [raw_fields]
                else:
                    # Copy so later mutation of the caller's dict can't leak
                    # into the registered config.
                    context_fields = list(raw_fields)

                rls_config = RLSConfig(
                    model=cls,
                    action=action,
                    role=role,
                    owner_field=owner_field,
                    context_fields=context_fields
                )

                self.row_level_security.register_model(rls_config)

            return cls

        return decorator


# Would it be possible to rely on AST to parse the policy?

import json
from datetime import datetime
from typing import Any, Dict, Optional, Type
from uuid import UUID, uuid4

from sqlalchemy import JSON, Column, event
from sqlalchemy.orm import Session
from sqlmodel import Field, SQLModel

class AuditLog(SQLModel, table=True):
    """One recorded change to a row of an audited model.

    Rows of this table are written by the listeners registered in
    ``enable_audit_trail``.
    """
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    table: str  # ``__tablename__`` of the audited model
    record_id: str  # str() of the audited row's ``id``
    user_id: Optional[str]  # user id from the session's auth context, if any
    action: str  # "INSERT", "UPDATE", "DELETE"
    timestamp: datetime = Field(default_factory=datetime.utcnow)  # naive UTC time
    # SQLModel has no default column type for dict-typed fields (class creation
    # fails with "<class 'dict'> has no matching SQLAlchemy type"), so the
    # before/after snapshots are mapped explicitly onto JSON columns.
    old_values: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
    new_values: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))

    class Config:
        json_encoders = {
            datetime: lambda v: v.isoformat(),
            UUID: lambda v: str(v)
        }

def _get_auth_user_id(session: Session) -> Optional[str]:
    """Return the ``id`` of the auth context stored in ``session.info``.

    Returns ``None`` when no (or a falsy) ``"auth_context"`` entry is present.

    NOTE(review): the tests build ``AuthContext(subject=...)``; confirm that
    ``AuthContext`` really exposes an ``id`` attribute rather than only ``subject``.
    """
    context = session.info.get("auth_context")
    if not context:
        return None
    return context.id

def _serialize_model(model: SQLModel) -> dict:
    """Snapshot ``model.__dict__`` as a new plain dict for audit storage.

    SQLAlchemy's ``_sa_instance_state`` bookkeeping entry is left out, and
    ``datetime`` / ``UUID`` values are replaced by their ``str()`` form; all
    other values are copied through unchanged.

    NOTE(review): ``__dict__`` may also hold loaded relationship objects, which are
    copied as-is — confirm that is acceptable for JSON storage.
    """
    snapshot = {}
    for name, raw in model.__dict__.items():
        if name == '_sa_instance_state':
            continue
        snapshot[name] = str(raw) if isinstance(raw, (datetime, UUID)) else raw
    return snapshot

def enable_audit_trail(model_class: Type[SQLModel]):
    """Record an ``AuditLog`` row for every ORM insert, update and delete of ``model_class``.

    Listeners are attached to the mapper-level ``after_insert``, ``after_update``
    and ``after_delete`` events. SQLAlchemy does not support ``Session.add()``
    from inside these flush events; emitting SQL on the ``connection`` they
    receive is the supported way to write extra rows. The audit row is
    therefore inserted on the same connection as the change it records.

    Only changes flushed through the ORM unit of work fire these events; bulk
    ``insert()`` / ``update()`` / ``delete()`` statements are not audited.

    NOTE(review): ``record_id`` is taken from ``target.id``, so audited models are
    assumed to expose an ``id`` attribute.
    """
    from sqlalchemy import inspect as sa_inspect

    def _audit_value(value):
        # Same conversion as ``_serialize_model``: datetime/UUID become strings so
        # the value can be stored in the JSON audit columns.
        return str(value) if isinstance(value, (datetime, UUID)) else value

    def _write_audit(connection, session, target, action, old_values=None, new_values=None):
        """Insert one AuditLog row for ``target`` on the flushing connection."""
        audit = AuditLog(
            table=model_class.__tablename__,
            record_id=str(target.id),
            user_id=_get_auth_user_id(session),
            action=action,
            old_values=old_values,
            new_values=new_values
        )
        # Copy the populated columns (including the generated id/timestamp) into a
        # Core INSERT instead of adding ``audit`` to the session mid-flush.
        row = {column.key: getattr(audit, column.key) for column in AuditLog.__table__.columns}
        connection.execute(AuditLog.__table__.insert().values(row))

    @event.listens_for(model_class, 'after_insert')
    def after_insert(mapper, connection, target):
        session = Session.object_session(target)
        if not session:
            return

        _write_audit(connection, session, target, "INSERT", new_values=_serialize_model(target))

    @event.listens_for(model_class, 'after_update')
    def after_update(mapper, connection, target):
        session = Session.object_session(target)
        if not session:
            return

        # Collect before/after values per changed column. Relationship attributes
        # are skipped: their history holds model instances, which cannot be stored
        # in the JSON audit columns.
        state = sa_inspect(target)
        changes = {}
        old_values = {}
        for attr in mapper.column_attrs:
            hist = state.attrs[attr.key].history
            if hist.has_changes():
                changes[attr.key] = _audit_value(hist.added[0]) if hist.added else None
                old_values[attr.key] = _audit_value(hist.deleted[0]) if hist.deleted else None

        if changes:  # Only create audit log if there were actual changes
            _write_audit(
                connection, session, target, "UPDATE",
                old_values=old_values,
                new_values=changes
            )

    @event.listens_for(model_class, 'after_delete')
    def after_delete(mapper, connection, target):
        session = Session.object_session(target)
        if not session:
            return

        _write_audit(connection, session, target, "DELETE", old_values=_serialize_model(target))


# MARKDOWN
#### Audit Trail
CapeBase provides built-in audit logging for the models you opt in with `enable_audit_trail`. Every insert, update, and delete of such a model is recorded with:
- Who made the change (user ID)
- What changed (before/after states)
- When it happened (timestamp)
- Action type (`INSERT`, `UPDATE`, or `DELETE`)

```python
from capebase import enable_audit_trail

# Enable audit trail for specific models
enable_audit_trail(Todo)

# Access audit logs
@app.get("/audit/todos/{todo_id}")
async def get_todo_history(
    todo_id: int,
    session=Depends(cape.get_db_dependency())
) -> list[AuditLog]:
    """Get audit history for a specific todo"""
    return await cape.get_audit_trail(Todo, todo_id, session)
```

Example audit log entry:
```json
{
    "id": "550e8400-e29b-41d4-a716-446655440000",
    "table": "todo",
    "record_id": "123",
    "user_id": "user_abc",
    "action": "UPDATE",
    "timestamp": "2024-03-20T14:28:23.382000",
    "old_values": {
        "is_complete": false
    },
    "new_values": {
        "is_complete": true
    }
}
```

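The audit trail takes the user ID from the session's authentication context (`session.info["auth_context"]`) and stores each change as a row in a separate `AuditLog` table, so the history can be retained for compliance and security review. Only changes made through the ORM unit of work are recorded: objects added, modified, or deleted on a session and then flushed. Bulk `insert()`, `update()`, and `delete()` statements executed directly do not fire the per-object events and are not audited.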

    def _notify_change_from_statement(self, conn, clauseelement, multiparams, params, execution_options, result=None):
        """Publish a ``ModelChange`` for each row returned by an INSERT/UPDATE/DELETE.

        The parameter list mirrors SQLAlchemy's ``ConnectionEvents.after_execute``
        hook (registration is not visible here — presumably it is hooked there).
        Rows are only available when the statement used RETURNING. Each returned
        row is validated into the registered model whose ``__tablename__`` matches
        the statement's table, and ``self.notification_engine.notify`` is scheduled
        via ``self._add_task``. Notification is best-effort: failures are logged,
        never raised.

        Note: ``result.all()`` consumes the result's rows.
        """
        if not result:
            return

        # Map the statement class to an event name. ``isinstance`` also matches the
        # ``Annotated*`` subclasses SQLAlchemy produces; the class-name check is a fallback.
        statement_classes = {
            (Insert, "sqlalchemy.sql.annotation.AnnotatedInsert"): "INSERT",
            (Update, "sqlalchemy.sql.annotation.AnnotatedUpdate"): "UPDATE",
            (Delete, "sqlalchemy.sql.annotation.AnnotatedDelete"): "DELETE"
        }

        event_type = None
        for (base_class, annotated_name), event_name in statement_classes.items():
            if isinstance(clauseelement, base_class) or clauseelement.__class__.__name__ == annotated_name:
                event_type = event_name
                break

        if not event_type:
            return

        # Statements without RETURNING produce no rows; calling ``all()`` on them
        # raises and used to surface as a spurious error log.
        if not getattr(result, "returns_rows", True):
            return

        try:
            rows = result.all()
            logger.debug("Success: %s", clauseelement.__class__.__name__)
            if not rows:
                return

            # The target model depends only on the statement, not on the row.
            table_name = clauseelement.table.name
            model_config = next(
                (config for config in self.model_registry.values()
                 if config.model.__tablename__ == table_name),
                None
            )
            if model_config is None:
                return

            for row in rows:
                # Create model instance from Row
                model_instance = model_config.model.model_validate(row)

                change = ModelChange(
                    table=table_name,
                    event=event_type,
                    payload=model_instance,
                    timestamp=datetime.now(),
                )
                self._add_task(self.notification_engine.notify(change))
        except Exception:
            # Best-effort: never break the triggering statement, but keep the traceback.
            logger.exception("Error processing returned rows: clauseelement=%s", clauseelement)
