Skip to content

API Reference

admin

Admin module for Luthien Control.

auth

Authentication logic for admin users.

AdminAuthService

Source code in luthien_control/admin/auth.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class AdminAuthService:
    """Service for admin authentication operations.

    Delegates to the admin user and admin session CRUD helpers to bootstrap a
    default account, check credentials, and manage login sessions.
    """

    async def ensure_default_admin(self, db: AsyncSession) -> None:
        """Create a default superuser when no admin users exist.

        The credentials are read from the ``ADMIN_USERNAME`` and
        ``ADMIN_PASSWORD`` environment variables, falling back to ``admin``
        and ``changeme``. Nothing happens if at least one admin already exists.
        """
        existing_admins = await admin_user_crud.list_all(db)
        if not existing_admins:
            # No admins at all: bootstrap one from the environment.
            default_username = os.getenv("ADMIN_USERNAME", "admin")
            default_password = os.getenv("ADMIN_PASSWORD", "changeme")

            logger.warning(f"Creating default admin user '{default_username}'. Please change the password immediately!")

            await admin_user_crud.create(db, username=default_username, password=default_password, is_superuser=True)

    async def authenticate(self, db: AsyncSession, username: str, password: str) -> Optional[AdminUser]:
        """Check a username/password pair.

        Expired sessions are purged first. Returns whatever
        ``admin_user_crud.verify_password`` returns for the credentials.
        """
        await admin_session_crud.cleanup_expired_sessions(db)
        verified_user = await admin_user_crud.verify_password(db, username, password)
        return verified_user

    async def create_session(self, db: AsyncSession, admin_user: AdminUser) -> AdminSession:
        """Open a session for an already-authenticated admin.

        The lifetime in hours is taken from ``ADMIN_SESSION_HOURS`` (default 24).

        Raises:
            ValueError: If the user has no ID, or the environment value is not
                an integer.
        """
        lifetime_hours = int(os.getenv("ADMIN_SESSION_HOURS", "24"))
        user_id = admin_user.id
        if user_id is None:
            raise ValueError("Admin user ID is None")
        new_session = await admin_session_crud.create_session(db, user_id, hours=lifetime_hours)
        return new_session

    async def get_user_from_session(self, db: AsyncSession, session_token: str) -> Optional[AdminUser]:
        """Resolve a session token to its active admin user.

        Returns None when the token has no valid session, or when the owning
        user is missing or inactive.
        """
        session = await admin_session_crud.get_valid_session(db, session_token)
        if session:
            owner = await admin_user_crud.get_by_id(db, session.admin_user_id)
            if owner and owner.is_active:
                return owner
        return None

    async def logout(self, db: AsyncSession, session_token: str) -> bool:
        """Delete the session for the token; returns the CRUD helper's result."""
        was_deleted = await admin_session_crud.delete_session(db, session_token)
        return was_deleted

Service for admin authentication operations.

authenticate(db, username, password) async
Source code in luthien_control/admin/auth.py
38
39
40
41
42
43
44
async def authenticate(self, db: AsyncSession, username: str, password: str) -> Optional[AdminUser]:
    """Check a username/password pair.

    Expired sessions are purged first. Returns whatever
    ``admin_user_crud.verify_password`` returns for the credentials.
    """
    await admin_session_crud.cleanup_expired_sessions(db)
    verified_user = await admin_user_crud.verify_password(db, username, password)
    return verified_user

Authenticate admin user.

create_session(db, admin_user) async
Source code in luthien_control/admin/auth.py
46
47
48
49
50
51
async def create_session(self, db: AsyncSession, admin_user: AdminUser) -> AdminSession:
    """Open a session for an already-authenticated admin.

    The lifetime in hours is taken from ``ADMIN_SESSION_HOURS`` (default 24).

    Raises:
        ValueError: If the user has no ID, or the environment value is not
            an integer.
    """
    lifetime_hours = int(os.getenv("ADMIN_SESSION_HOURS", "24"))
    user_id = admin_user.id
    if user_id is None:
        raise ValueError("Admin user ID is None")
    new_session = await admin_session_crud.create_session(db, user_id, hours=lifetime_hours)
    return new_session

Create a new session for authenticated user.

ensure_default_admin(db) async
Source code in luthien_control/admin/auth.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
async def ensure_default_admin(self, db: AsyncSession) -> None:
    """Create a default superuser when no admin users exist.

    The credentials are read from the ``ADMIN_USERNAME`` and
    ``ADMIN_PASSWORD`` environment variables, falling back to ``admin``
    and ``changeme``. Nothing happens if at least one admin already exists.
    """
    existing_admins = await admin_user_crud.list_all(db)
    if not existing_admins:
        # No admins at all: bootstrap one from the environment.
        default_username = os.getenv("ADMIN_USERNAME", "admin")
        default_password = os.getenv("ADMIN_PASSWORD", "changeme")

        logger.warning(f"Creating default admin user '{default_username}'. Please change the password immediately!")

        await admin_user_crud.create(db, username=default_username, password=default_password, is_superuser=True)

Ensure a default admin user exists.

get_user_from_session(db, session_token) async
Source code in luthien_control/admin/auth.py
53
54
55
56
57
58
59
60
61
62
63
async def get_user_from_session(self, db: AsyncSession, session_token: str) -> Optional[AdminUser]:
    """Resolve a session token to its active admin user.

    Returns None when the token has no valid session, or when the owning
    user is missing or inactive.
    """
    session = await admin_session_crud.get_valid_session(db, session_token)
    if session:
        owner = await admin_user_crud.get_by_id(db, session.admin_user_id)
        if owner and owner.is_active:
            return owner
    return None

Get admin user from session token.

logout(db, session_token) async
Source code in luthien_control/admin/auth.py
65
66
67
async def logout(self, db: AsyncSession, session_token: str) -> bool:
    """Delete the session for the token; returns the CRUD helper's result."""
    was_deleted = await admin_session_crud.delete_session(db, session_token)
    return was_deleted

Logout by deleting session.

crud

Admin CRUD operations.

admin_user

CRUD operations for admin users.

AdminSessionCRUD
Source code in luthien_control/admin/crud/admin_user.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class AdminSessionCRUD:
    """CRUD operations for admin sessions.

    Expiry timestamps are stored and compared as naive datetimes derived from
    the current UTC time.
    """

    async def create_session(self, db: AsyncSession, admin_user_id: int, hours: int = 24) -> AdminSession:
        """Persist a new session for the user, expiring ``hours`` from now.

        The token is a random URL-safe string. The stored row is refreshed
        before being returned.
        """
        token = secrets.token_urlsafe(32)
        now_naive_utc = datetime.now(timezone.utc).replace(tzinfo=None)

        new_session = AdminSession(
            session_token=token,
            admin_user_id=admin_user_id,
            expires_at=now_naive_utc + timedelta(hours=hours),
        )

        db.add(new_session)
        await db.commit()
        await db.refresh(new_session)
        return new_session

    async def get_valid_session(self, db: AsyncSession, session_token: str) -> Optional[AdminSession]:
        """Return the session for the token if it has not expired, else None."""
        now_naive_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        token_matches = AdminSession.session_token == session_token  # type: ignore
        not_expired = AdminSession.expires_at > now_naive_utc  # type: ignore
        query = select(AdminSession).where(and_(token_matches, not_expired))
        rows = await db.execute(query)
        return rows.scalar_one_or_none()

    async def delete_session(self, db: AsyncSession, session_token: str) -> bool:
        """Delete the session with this token (logout).

        Returns True when a session was found and removed, False otherwise.
        """
        query = select(AdminSession).where(AdminSession.session_token == session_token)  # type: ignore
        found = (await db.execute(query)).scalar_one_or_none()

        if not found:
            return False

        await db.delete(found)
        await db.commit()
        return True

    async def cleanup_expired_sessions(self, db: AsyncSession) -> int:
        """Delete every session whose expiry is at or before now.

        Returns the number of sessions removed. Commits only when at least
        one session was deleted.
        """
        cutoff = datetime.now(timezone.utc).replace(tzinfo=None)
        query = select(AdminSession).where(AdminSession.expires_at <= cutoff)  # type: ignore
        rows = await db.execute(query)
        stale_sessions = list(rows.scalars().all())

        if not stale_sessions:
            return 0

        for stale in stale_sessions:
            await db.delete(stale)
        await db.commit()

        return len(stale_sessions)

CRUD operations for admin sessions.

cleanup_expired_sessions(db) async
Source code in luthien_control/admin/crud/admin_user.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
async def cleanup_expired_sessions(self, db: AsyncSession) -> int:
    """Delete every session whose expiry is at or before now.

    Returns the number of sessions removed. Commits only when at least
    one session was deleted.
    """
    cutoff = datetime.now(timezone.utc).replace(tzinfo=None)
    query = select(AdminSession).where(AdminSession.expires_at <= cutoff)  # type: ignore
    rows = await db.execute(query)
    stale_sessions = list(rows.scalars().all())

    if not stale_sessions:
        return 0

    for stale in stale_sessions:
        await db.delete(stale)
    await db.commit()

    return len(stale_sessions)

Clean up expired sessions.

create_session(db, admin_user_id, hours=24) async
Source code in luthien_control/admin/crud/admin_user.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
async def create_session(self, db: AsyncSession, admin_user_id: int, hours: int = 24) -> AdminSession:
    """Persist a new session for the user, expiring ``hours`` from now.

    The token is a random URL-safe string and the expiry is stored as a
    naive datetime derived from the current UTC time. The stored row is
    refreshed before being returned.
    """
    token = secrets.token_urlsafe(32)
    now_naive_utc = datetime.now(timezone.utc).replace(tzinfo=None)

    new_session = AdminSession(
        session_token=token,
        admin_user_id=admin_user_id,
        expires_at=now_naive_utc + timedelta(hours=hours),
    )

    db.add(new_session)
    await db.commit()
    await db.refresh(new_session)
    return new_session

Create a new admin session.

delete_session(db, session_token) async
Source code in luthien_control/admin/crud/admin_user.py
 99
100
101
102
103
104
105
106
107
108
109
async def delete_session(self, db: AsyncSession, session_token: str) -> bool:
    """Delete the session with this token (logout).

    Returns True when a session was found and removed, False otherwise.
    """
    query = select(AdminSession).where(AdminSession.session_token == session_token)  # type: ignore
    found = (await db.execute(query)).scalar_one_or_none()

    if not found:
        return False

    await db.delete(found)
    await db.commit()
    return True

Delete a session (logout).

get_valid_session(db, session_token) async
Source code in luthien_control/admin/crud/admin_user.py
87
88
89
90
91
92
93
94
95
96
97
async def get_valid_session(self, db: AsyncSession, session_token: str) -> Optional[AdminSession]:
    """Return the session for the token if it has not expired, else None.

    The expiry is compared against a naive datetime derived from the
    current UTC time.
    """
    now_naive_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    token_matches = AdminSession.session_token == session_token  # type: ignore
    not_expired = AdminSession.expires_at > now_naive_utc  # type: ignore
    query = select(AdminSession).where(and_(token_matches, not_expired))
    rows = await db.execute(query)
    return rows.scalar_one_or_none()

Get a valid (non-expired) session by token.

AdminUserCRUD
Source code in luthien_control/admin/crud/admin_user.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class AdminUserCRUD:
    """CRUD operations for admin users.

    Passwords are stored as bcrypt hashes, never in plain text.
    """

    async def get_by_username(self, db: AsyncSession, username: str) -> Optional[AdminUser]:
        """Return the admin user with this username, or None if there is none."""
        query = select(AdminUser).where(AdminUser.username == username)  # type: ignore
        rows = await db.execute(query)
        return rows.scalar_one_or_none()

    async def get_by_id(self, db: AsyncSession, user_id: int) -> Optional[AdminUser]:
        """Return the admin user with this ID, or None if there is none."""
        query = select(AdminUser).where(AdminUser.id == user_id)  # type: ignore
        rows = await db.execute(query)
        return rows.scalar_one_or_none()

    async def create(
        self,
        db: AsyncSession,
        username: str,
        password: str,
        is_superuser: bool = False,
    ) -> AdminUser:
        """Store a new admin user with a freshly salted bcrypt password hash.

        The stored row is refreshed before being returned.
        """
        raw_password = password.encode("utf-8")
        hashed_text = bcrypt.hashpw(raw_password, bcrypt.gensalt()).decode("utf-8")

        new_user = AdminUser(
            username=username,
            password_hash=hashed_text,
            is_superuser=is_superuser,
        )

        db.add(new_user)
        await db.commit()
        await db.refresh(new_user)
        return new_user

    async def verify_password(self, db: AsyncSession, username: str, password: str) -> Optional[AdminUser]:
        """Check credentials and record the login time on success.

        Returns the user when it exists, is active, and the password matches
        the stored hash. Returns None in every other case.
        """
        user = await self.get_by_username(db, username)
        if user and user.is_active:
            candidate = password.encode("utf-8")
            stored_hash = user.password_hash.encode("utf-8")
            if bcrypt.checkpw(candidate, stored_hash):
                # Stamp the successful login as a naive datetime derived from UTC.
                user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
                await db.commit()
                return user

        return None

    async def list_all(self, db: AsyncSession) -> List[AdminUser]:
        """Return every admin user, ordered by ``created_at``."""
        query = select(AdminUser).order_by(AdminUser.created_at)  # type: ignore
        rows = await db.execute(query)
        return list(rows.scalars().all())

CRUD operations for admin users.

create(db, username, password, is_superuser=False) async
Source code in luthien_control/admin/crud/admin_user.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
async def create(
    self,
    db: AsyncSession,
    username: str,
    password: str,
    is_superuser: bool = False,
) -> AdminUser:
    """Store a new admin user with a freshly salted bcrypt password hash.

    The stored row is refreshed before being returned.
    """
    raw_password = password.encode("utf-8")
    hashed_text = bcrypt.hashpw(raw_password, bcrypt.gensalt()).decode("utf-8")

    new_user = AdminUser(
        username=username,
        password_hash=hashed_text,
        is_superuser=is_superuser,
    )

    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user

Create a new admin user.

get_by_id(db, user_id) async
Source code in luthien_control/admin/crud/admin_user.py
22
23
24
25
async def get_by_id(self, db: AsyncSession, user_id: int) -> Optional[AdminUser]:
    """Return the admin user with this ID, or None if there is none."""
    query = select(AdminUser).where(AdminUser.id == user_id)  # type: ignore
    rows = await db.execute(query)
    return rows.scalar_one_or_none()

Get admin user by ID.

get_by_username(db, username) async
Source code in luthien_control/admin/crud/admin_user.py
17
18
19
20
async def get_by_username(self, db: AsyncSession, username: str) -> Optional[AdminUser]:
    """Return the admin user with this username, or None if there is none."""
    query = select(AdminUser).where(AdminUser.username == username)  # type: ignore
    rows = await db.execute(query)
    return rows.scalar_one_or_none()

Get admin user by username.

list_all(db) async
Source code in luthien_control/admin/crud/admin_user.py
62
63
64
65
async def list_all(self, db: AsyncSession) -> List[AdminUser]:
    """Return every admin user, ordered by ``created_at``."""
    query = select(AdminUser).order_by(AdminUser.created_at)  # type: ignore
    rows = await db.execute(query)
    return list(rows.scalars().all())

List all admin users.

verify_password(db, username, password) async
Source code in luthien_control/admin/crud/admin_user.py
48
49
50
51
52
53
54
55
56
57
58
59
60
async def verify_password(self, db: AsyncSession, username: str, password: str) -> Optional[AdminUser]:
    """Check credentials and record the login time on success.

    Returns the user when it exists, is active, and the password matches
    the stored bcrypt hash. Returns None in every other case.
    """
    user = await self.get_by_username(db, username)
    if user and user.is_active:
        candidate = password.encode("utf-8")
        stored_hash = user.password_hash.encode("utf-8")
        if bcrypt.checkpw(candidate, stored_hash):
            # Stamp the successful login as a naive datetime derived from UTC.
            user.last_login = datetime.now(timezone.utc).replace(tzinfo=None)
            await db.commit()
            return user

    return None

Verify username and password.

dependencies

Dependencies for admin authentication.

CSRFProtection

Source code in luthien_control/admin/dependencies.py
34
35
36
37
38
39
40
41
42
43
44
class CSRFProtection:
    """CSRF protection for forms: hands out random tokens."""

    def __init__(self):
        # Key name associated with the CSRF token.
        self.token_name = "csrf_token"

    async def generate_token(self) -> str:
        """Return a new random URL-safe token built from 32 random bytes."""
        from secrets import token_urlsafe

        fresh_token = token_urlsafe(32)
        return fresh_token

CSRF protection for forms.

generate_token() async
Source code in luthien_control/admin/dependencies.py
40
41
42
43
44
async def generate_token(self) -> str:
    """Return a new random URL-safe token built from 32 random bytes."""
    from secrets import token_urlsafe

    fresh_token = token_urlsafe(32)
    return fresh_token

Generate CSRF token.

get_current_admin(session_token=None, db=Depends(get_db_session)) async

Source code in luthien_control/admin/dependencies.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
async def get_current_admin(
    session_token: Annotated[Optional[str], Cookie()] = None,
    db: AsyncSession = Depends(get_db_session),
) -> AdminUser:
    """Resolve the admin user behind the ``session_token`` cookie.

    Raises:
        HTTPException: 401 when the cookie is absent or empty, or when it
            does not resolve to a user.
    """
    if session_token:
        user = await admin_auth_service.get_user_from_session(db, session_token)
        if user:
            return user
        failure_detail = "Invalid or expired session"
    else:
        failure_detail = "Not authenticated"

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=failure_detail,
    )

Get current authenticated admin user from session cookie.

router

Admin router for authentication and policy management.

admin_home(request, current_admin, db=Depends(get_db_session)) async

Source code in luthien_control/admin/router.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
@router.get("/", response_class=HTMLResponse)
async def admin_home(
    request: Request,
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
    db: AsyncSession = Depends(get_db_session),
) -> HTMLResponse:
    """Render the admin dashboard with total and active policy counts."""
    all_policies = await list_policies(db, active_only=False)
    # Count active policies without building an intermediate list.
    active_count = sum(1 for policy in all_policies if policy.is_active)

    context = {
        "current_admin": current_admin,
        "total_policies": len(all_policies),
        "active_policies": active_count,
    }
    return templates.TemplateResponse(request, "dashboard.html", context)

Admin dashboard hub.

create_policy_handler(request, name, type, config, current_admin, db=Depends(get_db_session), description=None, is_active=False, csrf_token='') async

Source code in luthien_control/admin/router.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
@router.post("/policies/new", response_model=None)
async def create_policy_handler(
    request: Request,
    name: Annotated[str, Form()],
    type: Annotated[str, Form()],
    config: Annotated[str, Form()],
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
    db: AsyncSession = Depends(get_db_session),
    description: Annotated[Optional[str], Form()] = None,
    is_active: Annotated[bool, Form()] = False,
    csrf_token: Annotated[str, Form(alias="csrf_token")] = "",
):
    """Handle submission of the new-policy form.

    The form's CSRF token must equal the ``csrf_token`` cookie, otherwise a
    400 ``HTTPException`` is raised. If the config is not valid JSON, or
    building or saving the policy fails, the form is re-rendered with status
    400, the submitted values, an error message, and a fresh CSRF token. On
    success the client is redirected (303) to ``/admin/policies`` and the
    CSRF cookie is cleared.
    """

    async def _render_form_error(message: str):
        """Re-render the form with the submitted values and a new CSRF token."""
        new_csrf = await csrf_protection.generate_token()
        submitted_values = {
            "name": name,
            "type": type,
            "config": config,
            "description": description,
            "is_active": is_active,
        }
        error_page = templates.TemplateResponse(
            request,
            "policy_new.html",
            {
                "current_admin": current_admin,
                "csrf_token": new_csrf,
                "error": message,
                "form_data": submitted_values,
            },
            status_code=400,
        )
        error_page.set_cookie(
            key="csrf_token",
            value=new_csrf,
            httponly=True,
            secure=request.url.scheme == "https",
            samesite="strict",
        )
        return error_page

    # The cookie copy and the form copy of the token must agree.
    cookie_csrf = request.cookies.get("csrf_token")
    csrf_ok = bool(cookie_csrf) and cookie_csrf == csrf_token
    if not csrf_ok:
        raise HTTPException(status_code=400, detail="Invalid request")

    # The config arrives as text and must parse as JSON.
    try:
        config_dict = json.loads(config)
    except json.JSONDecodeError as e:
        return await _render_form_error(f"Invalid JSON: {str(e)}")

    try:
        new_policy = ControlPolicy(
            name=name,
            type=type,
            config=config_dict,
            description=description,
            is_active=is_active,
        )
        await save_policy_to_db(db, new_policy)
    except Exception as e:
        return await _render_form_error(f"Creation failed: {str(e)}")

    redirect = RedirectResponse(url="/admin/policies", status_code=303)
    redirect.delete_cookie(key="csrf_token")
    return redirect

Handle new policy creation.

edit_policy_page(request, policy_name, current_admin, db=Depends(get_db_session)) async

Source code in luthien_control/admin/router.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
@router.get("/policies/{policy_name}/edit", response_class=HTMLResponse)
async def edit_policy_page(
    request: Request,
    policy_name: str,
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
    db: AsyncSession = Depends(get_db_session),
) -> HTMLResponse:
    """Render the edit form for one policy and issue a CSRF token for it.

    Raises:
        HTTPException: 404 when no policy has the given name.
    """
    policy = await get_policy_by_name(db, policy_name)
    if not policy:
        raise HTTPException(status_code=404, detail="Policy not found")

    token = await csrf_protection.generate_token()
    context = {
        "current_admin": current_admin,
        "policy": policy,
        "csrf_token": token,
        # Pretty-printed so the config is readable in the form.
        "config_json": json.dumps(policy.config, indent=2),
        "error": None,
    }
    page = templates.TemplateResponse(request, "policy_edit.html", context)
    # The same token is placed in the page and in a cookie.
    page.set_cookie(
        key="csrf_token",
        value=token,
        httponly=True,
        secure=request.url.scheme == "https",
        samesite="strict",
    )
    return page

Display policy edit page.

login(request, response, username, password, csrf_token, db=Depends(get_db_session)) async

Source code in luthien_control/admin/router.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@router.post("/login", response_model=None)
async def login(
    request: Request,
    response: Response,
    username: Annotated[str, Form()],
    password: Annotated[str, Form()],
    csrf_token: Annotated[str, Form(alias="csrf_token")],
    db: AsyncSession = Depends(get_db_session),
):
    """Handle login form submission.

    The form's CSRF token must equal the ``csrf_token`` cookie. When it does
    not, the login form is re-rendered with status 400. When the credentials
    are rejected, it is re-rendered with status 401. Both error pages carry a
    fresh CSRF token in the form and in the cookie, so the next submission
    can pass the CSRF check.

    On success a session is created, the client is redirected (303) to
    ``/admin/policies`` with a ``session_token`` cookie, and the CSRF cookie
    is cleared.
    """

    async def _login_error(message: str, status_code: int):
        """Re-render the login form with a new CSRF token and matching cookie."""
        new_csrf = await csrf_protection.generate_token()
        error_page = templates.TemplateResponse(
            request,
            "login.html",
            {
                "csrf_token": new_csrf,
                "error": message,
            },
            status_code=status_code,
        )
        # The cookie must be refreshed along with the form token; otherwise
        # the retried form can never match the cookie.
        error_page.set_cookie(
            key="csrf_token",
            value=new_csrf,
            httponly=True,
            secure=request.url.scheme == "https",
            samesite="strict",
        )
        return error_page

    # Validate CSRF token: the cookie copy and the form copy must agree.
    cookie_csrf = request.cookies.get("csrf_token")
    if not cookie_csrf or cookie_csrf != csrf_token:
        return await _login_error("Invalid request. Please try again.", 400)

    # Authenticate user
    user = await admin_auth_service.authenticate(db, username, password)
    if not user:
        return await _login_error("Invalid username or password", 401)

    # Create session
    session = await admin_auth_service.create_session(db, user)

    redirect = RedirectResponse(url="/admin/policies", status_code=303)
    # NOTE(review): the cookie lifetime is fixed at 24 hours, while the
    # server-side session lifetime is read from ADMIN_SESSION_HOURS in
    # AdminAuthService.create_session; confirm whether these should match.
    redirect.set_cookie(
        key="session_token",
        value=session.session_token,
        httponly=True,
        secure=request.url.scheme == "https",
        samesite="strict",
        max_age=86400,  # 24 hours
    )
    redirect.delete_cookie(key="csrf_token")

    return redirect

Handle login form submission.

login_page(request) async

Source code in luthien_control/admin/router.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request) -> HTMLResponse:
    """Render the login form and issue a CSRF token for it."""
    token = await csrf_protection.generate_token()
    context = {"csrf_token": token, "error": None}
    page = templates.TemplateResponse(request, "login.html", context)
    # The same token is placed in the page and in a cookie.
    page.set_cookie(
        key="csrf_token",
        value=token,
        httponly=True,
        secure=request.url.scheme == "https",
        samesite="strict",
    )
    return page

Display login page.

logout(request, db=Depends(get_db_session)) async

Source code in luthien_control/admin/router.py
113
114
115
116
117
118
119
120
121
122
123
124
125
@router.get("/logout")
async def logout(
    request: Request,
    db: AsyncSession = Depends(get_db_session),
):
    """End the current session, if any, and send the client to the login page.

    The ``session_token`` cookie is cleared on the redirect whether or not a
    session was found.
    """
    if session_token := request.cookies.get("session_token"):
        await admin_auth_service.logout(db, session_token)

    to_login = RedirectResponse(url="/admin/login", status_code=303)
    to_login.delete_cookie(key="session_token")
    return to_login

Logout and redirect to login page.

new_policy_page(request, current_admin) async

Source code in luthien_control/admin/router.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
@router.get("/policies/new", response_class=HTMLResponse)
async def new_policy_page(
    request: Request,
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
) -> HTMLResponse:
    """Render the blank new-policy form and issue a CSRF token for it."""
    token = await csrf_protection.generate_token()
    context = {
        "current_admin": current_admin,
        "csrf_token": token,
        "error": None,
    }
    page = templates.TemplateResponse(request, "policy_new.html", context)
    # The same token is placed in the page and in a cookie.
    page.set_cookie(
        key="csrf_token",
        value=token,
        httponly=True,
        secure=request.url.scheme == "https",
        samesite="strict",
    )
    return page

Display new policy creation page.

policies_list(request, current_admin, db=Depends(get_db_session)) async

Source code in luthien_control/admin/router.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@router.get("/policies", response_class=HTMLResponse)
async def policies_list(
    request: Request,
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
    db: AsyncSession = Depends(get_db_session),
) -> HTMLResponse:
    """Render the list of all control policies, active or not."""
    all_policies = await list_policies(db, active_only=False)
    context = {
        "current_admin": current_admin,
        "policies": all_policies,
    }
    return templates.TemplateResponse(request, "policies.html", context)

List all control policies.

update_policy_handler(request, policy_name, config, current_admin, db=Depends(get_db_session), description=None, is_active=False, csrf_token='') async

Source code in luthien_control/admin/router.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
@router.post("/policies/{policy_name}/edit", response_model=None)
async def update_policy_handler(
    request: Request,
    policy_name: str,
    config: Annotated[str, Form()],
    current_admin: Annotated[AdminUser, Depends(get_current_admin)],
    db: AsyncSession = Depends(get_db_session),
    description: Annotated[Optional[str], Form()] = None,
    is_active: Annotated[bool, Form()] = False,
    csrf_token: Annotated[str, Form(alias="csrf_token")] = "",
):
    """Handle submission of the policy edit form.

    The form's CSRF token must equal the ``csrf_token`` cookie (400
    otherwise) and the policy must exist (404 otherwise). If the config is
    not valid JSON, or saving fails, the edit form is re-rendered with status
    400, the submitted config text, an error message, and a fresh CSRF token.
    On success the client is redirected (303) to ``/admin/policies`` and the
    CSRF cookie is cleared.
    """

    async def _render_edit_error(message: str):
        """Re-render the edit form with the submitted config and a new CSRF token."""
        new_csrf = await csrf_protection.generate_token()
        error_page = templates.TemplateResponse(
            request,
            "policy_edit.html",
            {
                "current_admin": current_admin,
                "policy": policy,
                "csrf_token": new_csrf,
                "config_json": config,
                "error": message,
            },
            status_code=400,
        )
        error_page.set_cookie(
            key="csrf_token",
            value=new_csrf,
            httponly=True,
            secure=request.url.scheme == "https",
            samesite="strict",
        )
        return error_page

    # The cookie copy and the form copy of the token must agree.
    cookie_csrf = request.cookies.get("csrf_token")
    csrf_ok = bool(cookie_csrf) and cookie_csrf == csrf_token
    if not csrf_ok:
        raise HTTPException(status_code=400, detail="Invalid request")

    policy = await get_policy_by_name(db, policy_name)
    if not policy:
        raise HTTPException(status_code=404, detail="Policy not found")

    # The config arrives as text and must parse as JSON.
    try:
        config_dict = json.loads(config)
    except json.JSONDecodeError as e:
        return await _render_edit_error(f"Invalid JSON: {str(e)}")

    try:
        policy.config = config_dict
        # A missing description field leaves the stored description unchanged.
        if description is not None:
            policy.description = description
        policy.is_active = is_active

        db.add(policy)
        await db.commit()
        await db.refresh(policy)
    except Exception as e:
        return await _render_edit_error(f"Update failed: {str(e)}")

    redirect = RedirectResponse(url="/admin/policies", status_code=303)
    redirect.delete_cookie(key="csrf_token")
    return redirect

Handle policy update form submission.

api

openai_chat_completions

Choice

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
68
69
70
71
72
73
74
class Choice(DeepEventedModel):
    """A single choice in a chat completion response.

    Every field declares a default.
    """

    # Defaults to 0; presumably the position of this choice in the response (confirm).
    index: int = Field(default=0)
    # Built per instance through the Message default factory.
    message: Message = Field(default_factory=Message)
    # Optional; None unless set.
    finish_reason: Optional[str] = Field(default=None)
    # Optional log-probability details; None unless set.
    logprobs: Optional[LogProbs] = Field(default=None)

A single choice in a chat completion response.

CompletionTokensDetails

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
84
85
86
87
88
89
90
class CompletionTokensDetails(DeepEventedModel):
    """Details about completion token usage."""

    # All counters start at zero when not supplied.
    reasoning_tokens: int = 0
    audio_tokens: int = 0
    accepted_prediction_tokens: int = 0
    rejected_prediction_tokens: int = 0

Details about completion token usage.

ContentPartImage

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
120
121
122
123
124
class ContentPartImage(DeepEventedModel):
    """An image content part."""

    # Declared frozen; defaults to the literal tag "image_url".
    type: str = Field(default="image_url", frozen=True)
    # Required: no default is provided.
    image_url: ImageUrl

An image content part.

ContentPartText

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
113
114
115
116
117
class ContentPartText(DeepEventedModel):
    """A text content part."""

    # Declared frozen; defaults to the literal tag "text".
    type: str = Field(default="text", frozen=True)
    # Required: no default is provided.
    text: str

A text content part.

FunctionDefinition

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
137
138
139
140
141
142
class FunctionDefinition(DeepEventedModel):
    """The definition of a function that can be called by the model."""

    # Required: no default is provided.
    name: str
    description: Optional[str] = None
    # Typed Optional, but the default is a new EDict built per instance (not None).
    parameters: Optional[EDict[str, Any]] = Field(default_factory=EDict)

The definition of a function that can be called by the model.

ImageUrl

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
106
107
108
109
110
class ImageUrl(DeepEventedModel):
    """The image URL details."""

    # Required: no default is provided.
    url: str = Field()
    # Restricted to the three listed literals; "auto" when omitted.
    detail: Literal["auto", "low", "high"] = Field(default="auto")

The image URL details.

LogProbs

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
61
62
63
64
65
class LogProbs(DeepEventedModel):
    """Log probability information for the choice."""

    # Both fields are typed Optional, but when omitted each one defaults to a
    # new ``EList[EDict]()`` built per instance by the default_factory (not None).
    content: Optional[EList[EDict]] = Field(default_factory=lambda: EList[EDict]())
    refusal: Optional[EList[EDict]] = Field(default_factory=lambda: EList[EDict]())

Log probability information for the choice.

Message

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
49
50
51
52
53
54
55
56
57
58
class Message(DeepEventedModel):
    """A message in a chat completion."""

    content: Optional[str] = None
    refusal: Optional[str] = None
    # default_factory=str yields the empty string when no role is supplied.
    role: str = Field(default_factory=str)
    # Container defaults are built per instance so they are never shared.
    annotations: EList[Annotation] = Field(default_factory=lambda: EList[Annotation]())
    audio: Optional[Audio] = None
    function_call: Optional[FunctionCall] = None
    # Typed Optional, but defaults to a new EList[ToolCall]() rather than None.
    tool_calls: Optional[EList[ToolCall]] = Field(default_factory=lambda: EList[ToolCall]())

A message in a chat completion.

OpenAIChatCompletionsRequest

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/request.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class OpenAIChatCompletionsRequest(DeepEventedModel):
    """Request context for OpenAI chat completions.

    Based on the OpenAI API reference:
    https://platform.openai.com/docs/api-reference/chat/create?lang=python
    (retrieved 2025-06-16)

    This model is evented and will emit a `changed` signal on any modification.
    """

    # Required fields: no default is provided.
    messages: EList[Message]
    model: str

    # Everything below is optional and defaults to None when omitted.
    audio: Optional[Audio] = None
    frequency_penalty: Optional[float] = None
    function_call: Optional[RequestFunctionCallSpec] = None  # deprecated
    functions: Optional[EList[FunctionDefinition]] = None  # deprecated
    logit_bias: Optional[EDict[str, float]] = None
    logprobs: Optional[bool] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None  # deprecated
    metadata: Optional[EDict[str, str]] = None
    modalities: Optional[EList[str]] = None
    n: Optional[int] = None
    parallel_tool_calls: Optional[bool] = None
    prediction: Optional[Prediction] = None
    presence_penalty: Optional[float] = None
    reasoning_effort: Optional[str] = None  # "low", "medium", "high"
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = None
    service_tier: Optional[str] = None
    stop: Optional[str | EList[str]] = None
    store: Optional[bool] = None
    stream: Optional[bool] = None
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    tool_choice: Optional[ToolChoice] = None
    tools: Optional[EList[ToolDefinition]] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
    user: Optional[str] = None
    web_search_options: Optional[WebSearchOptions] = None

Request context for OpenAI chat completions.

Based on the OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create?lang=python (retrieved 2025-06-16)

This model is evented and will emit a changed signal on any modification.

OpenAIChatCompletionsResponse

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/response.py
14
15
16
17
18
19
20
21
22
23
24
class OpenAIChatCompletionsResponse(DeepEventedModel):
    """The response for a chat completion."""

    # Defaults to a new EList[Choice]() built per instance.
    choices: EList[Choice] = Field(default_factory=lambda: EList[Choice]())
    # created, id and model are declared with Field() and no default, i.e. required.
    created: int = Field()
    id: str = Field()
    model: str = Field()
    object: str = Field(default="chat.completion")
    service_tier: Optional[str] = Field(default=None)
    system_fingerprint: Optional[str] = Field(default=None)
    # Defaults to a new Usage() built per instance.
    usage: Usage = Field(default_factory=Usage)

The response for a chat completion.

PromptTokensDetails

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
77
78
79
80
81
class PromptTokensDetails(DeepEventedModel):
    """Details about prompt token usage."""

    # Both counters start at zero when not supplied.
    cached_tokens: int = 0
    audio_tokens: int = 0

Details about prompt token usage.

ResponseFormat

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
127
128
129
130
131
132
133
134
class ResponseFormat(DeepEventedModel):
    """An object specifying the format that the model must output.

    See https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses
    """

    # Restricted to the three listed literals; "text" when omitted.
    type: Literal["text", "json_object", "json_schema"] = "text"
    # NOTE(review): values are annotated as `Type`; confirm this is the intended
    # value type for a JSON schema mapping.
    json_schema: Optional[EDict[str, Type]] = None

An object specifying the format that the model must output.

See https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses

ToolChoice

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
158
159
160
161
162
class ToolChoice(DeepEventedModel):
    """A specific tool choice."""

    # Declared frozen; defaults to the literal tag "function".
    type: str = Field(default="function", frozen=True)
    # Required: no default is provided.
    function: ToolChoiceFunction

A specific tool choice.

ToolChoiceFunction

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
152
153
154
155
class ToolChoiceFunction(DeepEventedModel):
    """The function to call in a tool choice."""

    # Required: no default is provided.
    name: str

The function to call in a tool choice.

ToolDefinition

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
145
146
147
148
149
class ToolDefinition(DeepEventedModel):
    """A tool that can be used by the model."""

    # Declared frozen; defaults to the literal tag "function".
    type: str = Field(default="function", frozen=True)
    # Required: no default is provided.
    function: FunctionDefinition

A tool that can be used by the model.

Usage

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
 93
 94
 95
 96
 97
 98
 99
100
class Usage(DeepEventedModel):
    """Token usage statistics for the chat completion request."""

    # Counters start at zero when not supplied.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Typed Optional, but each defaults to a newly built details object (not None).
    prompt_tokens_details: Optional[PromptTokensDetails] = Field(default_factory=PromptTokensDetails)
    completion_tokens_details: Optional[CompletionTokensDetails] = Field(default_factory=CompletionTokensDetails)

Token usage statistics for the chat completion request.

datatypes

Choice

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
68
69
70
71
72
73
74
class Choice(DeepEventedModel):
    """A single choice in a chat completion response."""

    # Every field carries a default, so ``Choice()`` can be built with no arguments.
    index: int = 0
    # A new Message is constructed per instance rather than shared between choices.
    message: Message = Field(default_factory=Message)
    finish_reason: Optional[str] = None
    logprobs: Optional[LogProbs] = None

A single choice in a chat completion response.

CompletionTokensDetails

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
84
85
86
87
88
89
90
class CompletionTokensDetails(DeepEventedModel):
    """Details about completion token usage."""

    # All counters start at zero when not supplied.
    reasoning_tokens: int = 0
    audio_tokens: int = 0
    accepted_prediction_tokens: int = 0
    rejected_prediction_tokens: int = 0

Details about completion token usage.

ContentPartImage

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
120
121
122
123
124
class ContentPartImage(DeepEventedModel):
    """An image content part."""

    # Declared frozen; defaults to the literal tag "image_url".
    type: str = Field(default="image_url", frozen=True)
    # Required: no default is provided.
    image_url: ImageUrl

An image content part.

ContentPartText

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
113
114
115
116
117
class ContentPartText(DeepEventedModel):
    """A text content part."""

    # Declared frozen; defaults to the literal tag "text".
    type: str = Field(default="text", frozen=True)
    # Required: no default is provided.
    text: str

A text content part.

FunctionDefinition

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
137
138
139
140
141
142
class FunctionDefinition(DeepEventedModel):
    """The definition of a function that can be called by the model."""

    # Required: no default is provided.
    name: str
    description: Optional[str] = None
    # Typed Optional, but the default is a new EDict built per instance (not None).
    parameters: Optional[EDict[str, Any]] = Field(default_factory=EDict)

The definition of a function that can be called by the model.

ImageUrl

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
106
107
108
109
110
class ImageUrl(DeepEventedModel):
    """The image URL details."""

    # Required: no default is provided.
    url: str = Field()
    # Restricted to the three listed literals; "auto" when omitted.
    detail: Literal["auto", "low", "high"] = Field(default="auto")

The image URL details.

LogProbs

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
61
62
63
64
65
class LogProbs(DeepEventedModel):
    """Log probability information for the choice."""

    # Both fields are typed Optional, but when omitted each one defaults to a
    # new ``EList[EDict]()`` built per instance by the default_factory (not None).
    content: Optional[EList[EDict]] = Field(default_factory=lambda: EList[EDict]())
    refusal: Optional[EList[EDict]] = Field(default_factory=lambda: EList[EDict]())

Log probability information for the choice.

Message

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
49
50
51
52
53
54
55
56
57
58
class Message(DeepEventedModel):
    """A message in a chat completion."""

    content: Optional[str] = None
    refusal: Optional[str] = None
    # default_factory=str yields the empty string when no role is supplied.
    role: str = Field(default_factory=str)
    # Container defaults are built per instance so they are never shared.
    annotations: EList[Annotation] = Field(default_factory=lambda: EList[Annotation]())
    audio: Optional[Audio] = None
    function_call: Optional[FunctionCall] = None
    # Typed Optional, but defaults to a new EList[ToolCall]() rather than None.
    tool_calls: Optional[EList[ToolCall]] = Field(default_factory=lambda: EList[ToolCall]())

A message in a chat completion.

PromptTokensDetails

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
77
78
79
80
81
class PromptTokensDetails(DeepEventedModel):
    """Details about prompt token usage."""

    # Both counters start at zero when not supplied.
    cached_tokens: int = 0
    audio_tokens: int = 0

Details about prompt token usage.

RequestFunctionCall

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
165
166
167
168
class RequestFunctionCall(DeepEventedModel):
    """A function call in a request."""

    # Required: no default is provided.
    name: str

A function call in a request.

ResponseFormat

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
127
128
129
130
131
132
133
134
class ResponseFormat(DeepEventedModel):
    """An object specifying the format that the model must output.

    See https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses
    """

    # Restricted to the three listed literals; "text" when omitted.
    type: Literal["text", "json_object", "json_schema"] = "text"
    # NOTE(review): values are annotated as `Type`; confirm this is the intended
    # value type for a JSON schema mapping.
    json_schema: Optional[EDict[str, Type]] = None

An object specifying the format that the model must output.

See https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses

ToolChoice

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
158
159
160
161
162
class ToolChoice(DeepEventedModel):
    """A specific tool choice."""

    # Declared frozen; defaults to the literal tag "function".
    type: str = Field(default="function", frozen=True)
    # Required: no default is provided.
    function: ToolChoiceFunction

A specific tool choice.

ToolChoiceFunction

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
152
153
154
155
class ToolChoiceFunction(DeepEventedModel):
    """The function to call in a tool choice."""

    # Required: no default is provided.
    name: str

The function to call in a tool choice.

ToolDefinition

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
145
146
147
148
149
class ToolDefinition(DeepEventedModel):
    """A tool that can be used by the model."""

    # Declared frozen; defaults to the literal tag "function".
    type: str = Field(default="function", frozen=True)
    # Required: no default is provided.
    function: FunctionDefinition

A tool that can be used by the model.

Usage

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/datatypes.py
 93
 94
 95
 96
 97
 98
 99
100
class Usage(DeepEventedModel):
    """Token usage statistics for the chat completion request."""

    # Counters start at zero when not supplied.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Typed Optional, but each defaults to a newly built details object (not None).
    prompt_tokens_details: Optional[PromptTokensDetails] = Field(default_factory=PromptTokensDetails)
    completion_tokens_details: Optional[CompletionTokensDetails] = Field(default_factory=CompletionTokensDetails)

Token usage statistics for the chat completion request.

request

OpenAIChatCompletionsRequest

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/request.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class OpenAIChatCompletionsRequest(DeepEventedModel):
    """Request context for OpenAI chat completions.

    Based on the OpenAI API reference:
    https://platform.openai.com/docs/api-reference/chat/create?lang=python
    (retrieved 2025-06-16)

    This model is evented and will emit a `changed` signal on any modification.
    """

    # Required fields: no default is provided.
    messages: EList[Message]
    model: str

    # Everything below is optional and defaults to None when omitted.
    audio: Optional[Audio] = None
    frequency_penalty: Optional[float] = None
    function_call: Optional[RequestFunctionCallSpec] = None  # deprecated
    functions: Optional[EList[FunctionDefinition]] = None  # deprecated
    logit_bias: Optional[EDict[str, float]] = None
    logprobs: Optional[bool] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None  # deprecated
    metadata: Optional[EDict[str, str]] = None
    modalities: Optional[EList[str]] = None
    n: Optional[int] = None
    parallel_tool_calls: Optional[bool] = None
    prediction: Optional[Prediction] = None
    presence_penalty: Optional[float] = None
    reasoning_effort: Optional[str] = None  # "low", "medium", "high"
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = None
    service_tier: Optional[str] = None
    stop: Optional[str | EList[str]] = None
    store: Optional[bool] = None
    stream: Optional[bool] = None
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = None
    tool_choice: Optional[ToolChoice] = None
    tools: Optional[EList[ToolDefinition]] = None
    top_logprobs: Optional[int] = None
    top_p: Optional[float] = None
    user: Optional[str] = None
    web_search_options: Optional[WebSearchOptions] = None

Request context for OpenAI chat completions.

Based on the OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create?lang=python (retrieved 2025-06-16)

This model is evented and will emit a changed signal on any modification.

response

OpenAIChatCompletionsResponse

Bases: DeepEventedModel

Source code in luthien_control/api/openai_chat_completions/response.py
14
15
16
17
18
19
20
21
22
23
24
class OpenAIChatCompletionsResponse(DeepEventedModel):
    """The response for a chat completion."""

    # Defaults to a new EList[Choice]() built per instance.
    choices: EList[Choice] = Field(default_factory=lambda: EList[Choice]())
    # created, id and model are declared with Field() and no default, i.e. required.
    created: int = Field()
    id: str = Field()
    model: str = Field()
    object: str = Field(default="chat.completion")
    service_tier: Optional[str] = Field(default=None)
    system_fingerprint: Optional[str] = Field(default=None)
    # Defaults to a new Usage() built per instance.
    usage: Usage = Field(default_factory=Usage)

The response for a chat completion.

control_policy

add_api_key_header

AddApiKeyHeaderPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/add_api_key_header.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class AddApiKeyHeaderPolicy(ControlPolicy):
    """Adds the configured OpenAI API key to the request Authorization header.

    This policy reads the API key from the application settings and adds it
    to the request. It has no policy-specific configuration beyond its name.
    """

    name: Optional[str] = Field(default="AddApiKeyHeaderPolicy")

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Sets the API key on the transaction's request.

        Reads OpenAI API key from settings via the container.
        Requires the DependencyContainer and AsyncSession in signature for interface compliance,
        but session is not directly used in this policy's logic.

        Args:
            transaction: The current transaction.
            container: The application dependency container.
            session: An active SQLAlchemy AsyncSession (unused).

        Returns:
            The potentially modified transaction.

        Raises:
            NoRequestError: If the request is not found in the transaction.
            ApiKeyNotFoundError: If the OpenAI API key is not configured.
        """
        request = transaction.request
        if request is None:
            raise NoRequestError("No request in transaction.")

        # An empty or missing key is treated as "not configured".
        configured_key = container.settings.get_openai_api_key()
        if not configured_key:
            raise ApiKeyNotFoundError(f"OpenAI API key not configured ({self.name}).")

        self.logger.info(f"Setting API key from settings ({self.name}).")
        request.api_key = configured_key
        return transaction

Adds the configured OpenAI API key to the request Authorization header.

This policy reads the API key from the application settings and adds it to the request. It has no policy-specific configuration beyond its name.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/add_api_key_header.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Sets the API key on the transaction's request.

    Reads OpenAI API key from settings via the container.
    Requires the DependencyContainer and AsyncSession in signature for interface compliance,
    but session is not directly used in this policy's logic.

    Args:
        transaction: The current transaction.
        container: The application dependency container.
        session: An active SQLAlchemy AsyncSession (unused).

    Returns:
        The potentially modified transaction.

    Raises:
        NoRequestError: If the request is not found in the transaction.
        ApiKeyNotFoundError: If the OpenAI API key is not configured.
    """
    request = transaction.request
    if request is None:
        raise NoRequestError("No request in transaction.")

    # An empty or missing key is treated as "not configured".
    configured_key = container.settings.get_openai_api_key()
    if not configured_key:
        raise ApiKeyNotFoundError(f"OpenAI API key not configured ({self.name}).")

    self.logger.info(f"Setting API key from settings ({self.name}).")
    request.api_key = configured_key
    return transaction

Sets the API key on the transaction's request.

Reads OpenAI API key from settings via the container. Requires the DependencyContainer and AsyncSession in signature for interface compliance, but session is not directly used in this policy's logic.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container.

required
session AsyncSession

An active SQLAlchemy AsyncSession (unused).

required

Returns:

Type Description
Transaction

The potentially modified transaction.

add_api_key_header_from_env

Add an API key header, where the key is sourced from a configured environment variable.

This policy is used to add an API key to the request Authorization header. The API key is read from an environment variable whose name is configured when the policy is instantiated.

AddApiKeyHeaderFromEnvPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/add_api_key_header_from_env.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class AddApiKeyHeaderFromEnvPolicy(ControlPolicy):
    """Adds an API key to the request Authorization header from an environment variable.

    The API key is read from an environment variable whose name is configured
    when the policy is instantiated. This allows different API keys to be used
    based on deployment environment.
    """

    api_key_env_var_name: str = Field(...)

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Sets the API key on the transaction's request.

        The API key is read from the environment variable specified by self.api_key_env_var_name.
        Requires DependencyContainer and AsyncSession for interface compliance, but they are not
        directly used in this policy's primary logic beyond what ControlPolicy might require.

        Args:
            transaction: The current transaction.
            container: The application dependency container (unused).
            session: An active SQLAlchemy AsyncSession (unused).

        Returns:
            The potentially modified transaction.

        Raises:
            NoRequestError: If the request is not found in the transaction.
            ApiKeyNotFoundError: If the configured environment variable is not set or is empty.
        """
        request = transaction.request
        if request is None:
            raise NoRequestError("No request in transaction.")

        env_value = os.environ.get(self.api_key_env_var_name)
        if env_value:
            self.logger.info(f"Setting API key from env var '{self.api_key_env_var_name}' ({self.name}).")
            request.api_key = env_value
            return transaction

        # Unset and empty are handled alike: log the failure, then raise with the same text.
        error_message = (
            f"API key not found. Environment variable '{self.api_key_env_var_name}' is not set or is empty."
        )
        tagged_message = f"{error_message} ({self.name})"
        self.logger.error(tagged_message)
        raise ApiKeyNotFoundError(tagged_message)

Adds an API key to the request Authorization header from an environment variable.

The API key is read from an environment variable whose name is configured when the policy is instantiated. This allows different API keys to be used based on deployment environment.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/add_api_key_header_from_env.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Sets the API key on the transaction's request.

    The API key is read from the environment variable specified by self.api_key_env_var_name.
    Requires DependencyContainer and AsyncSession for interface compliance, but they are not
    directly used in this policy's primary logic beyond what ControlPolicy might require.

    Args:
        transaction: The current transaction.
        container: The application dependency container (unused).
        session: An active SQLAlchemy AsyncSession (unused).

    Returns:
        The potentially modified transaction.

    Raises:
        NoRequestError: If the request is not found in the transaction.
        ApiKeyNotFoundError: If the configured environment variable is not set or is empty.
    """
    request = transaction.request
    if request is None:
        raise NoRequestError("No request in transaction.")

    env_value = os.environ.get(self.api_key_env_var_name)
    if env_value:
        self.logger.info(f"Setting API key from env var '{self.api_key_env_var_name}' ({self.name}).")
        request.api_key = env_value
        return transaction

    # Unset and empty are handled alike: log the failure, then raise with the same text.
    error_message = (
        f"API key not found. Environment variable '{self.api_key_env_var_name}' is not set or is empty."
    )
    tagged_message = f"{error_message} ({self.name})"
    self.logger.error(tagged_message)
    raise ApiKeyNotFoundError(tagged_message)

Sets the API key on the transaction's request.

The API key is read from the environment variable specified by self.api_key_env_var_name. Requires DependencyContainer and AsyncSession for interface compliance, but they are not directly used in this policy's primary logic beyond what ControlPolicy might require.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container (unused).

required
session AsyncSession

An active SQLAlchemy AsyncSession (unused).

required

Returns:

Type Description
Transaction

The potentially modified transaction.

backend_call_policy

BackendCallPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/backend_call_policy.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class BackendCallPolicy(ControlPolicy):
    """
    This policy makes a backend LLM call.
    """

    name: Optional[str] = Field(default="BackendCallPolicy")
    backend_call_spec: BackendCallSpec = Field(...)

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Point the request at the configured backend, send it, and store the response.

        Args:
            transaction: The current transaction; its request is updated in place and
                its response receives the backend payload.
            container: The dependency container, used to create the OpenAI client.
            session: The database session (not used in this method).

        Returns:
            The same transaction, with the response payload and endpoint filled in.

        Raises:
            Any exception raised by the backend call is logged and re-raised unchanged.
        """
        spec = self.backend_call_spec

        # The key from the environment only overrides the request's key when it is non-empty.
        env_api_key = os.environ.get(spec.api_key_env_var)
        request = transaction.request
        if env_api_key:
            request.api_key = env_api_key
        request.api_endpoint = spec.api_endpoint

        # Update the request payload with all arguments from backend_call_spec.request_args
        # Use model_validate to properly handle nested pydantic models and EventedList/EventedDict
        if spec.request_args:
            merged_payload = request.payload.model_dump()
            merged_payload.update(spec.request_args)
            request.payload = request.payload.__class__.model_validate(merged_payload)

        # Set the model if specified
        if spec.model:
            request.payload.model = spec.model

        client = container.create_openai_client(request.api_endpoint, env_api_key or request.api_key)
        try:
            response_payload = await client.chat.completions.create(**request.payload.model_dump())
            transaction.response.payload = response_payload
            transaction.response.api_endpoint = request.api_endpoint
        except openai.APITimeoutError as err:
            self.logger.error(f"Timeout error during backend request: {err} ({self.name})")
            raise
        except openai.APIConnectionError as err:
            self.logger.error(f"Connection error during backend request: {err} ({self.name})")
            raise
        except openai.APIError as err:
            self.logger.error(f"OpenAI API error during backend request: {err} ({self.name})")
            raise
        except Exception as err:
            # Unexpected failures are logged with the traceback before propagating.
            self.logger.exception(f"Unexpected error during backend request: {err} ({self.name})")
            raise
        return transaction

This policy makes a backend LLM call.

branching_policy

BranchingPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/branching_policy.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class BranchingPolicy(ControlPolicy):
    """
    A Control Policy that conditionally applies different policies based on transaction evaluation.

    This policy evaluates conditions in order and applies the policy associated with the first
    matching condition. If no conditions match, it applies the default policy (if configured).
    """

    name: Optional[str] = Field(default="BranchingPolicy")
    # Excluded from the pydantic dump (exclude=True): serialize() and from_serialized()
    # handle this field by hand, encoding each Condition key as a JSON string.
    cond_to_policy_map: OrderedDict[Condition, ControlPolicy] = Field(default_factory=OrderedDict, exclude=True)
    # Policy applied when no condition matches; None means "leave the transaction untouched".
    default_policy: Optional[ControlPolicy] = Field(default=None)

    @field_validator("cond_to_policy_map", mode="before")
    @classmethod
    def validate_cond_to_policy_map(cls, value):
        """Validate and convert condition-to-policy mapping.

        Returns an OrderedDict unchanged, wraps a plain dict in an OrderedDict,
        and rejects anything else.

        Raises:
            ValueError: If ``value`` is neither a dict nor an OrderedDict.
        """
        if isinstance(value, OrderedDict):
            return value
        if isinstance(value, dict):
            return OrderedDict(value)
        raise ValueError("cond_to_policy_map must be a dict or OrderedDict")

    @field_validator("default_policy", mode="before")
    @classmethod
    def validate_default_policy(cls, value):
        """Validate default policy field (currently a pass-through that returns the value unchanged)."""
        return value

    async def apply(
        self, transaction: Transaction, container: DependencyContainer, session: AsyncSession
    ) -> Transaction:
        """
        Apply the first policy that matches the condition. If no condition matches, apply the default policy (if set).

        Args:
            transaction: The transaction to apply the policy to.
            container: The dependency container.
            session: The database session.

        Returns:
            The potentially modified transaction.
        """
        # Conditions are checked in map order; only the first match is applied.
        for cond, policy in self.cond_to_policy_map.items():
            if cond.evaluate(transaction):
                return await policy.apply(transaction, container, session)
        if self.default_policy:
            return await self.default_policy.apply(transaction, container, session)
        # No match and no default: the transaction passes through unchanged.
        return transaction

    def serialize(self) -> SerializableDict:
        """Override serialize to handle complex condition-to-policy mapping.

        Each condition key is written as the JSON string of its serialized form and
        each policy as its own serialized dict. ``default_policy`` is always present
        in the output (``None`` when unset).
        """
        data = super().serialize()
        data["cond_to_policy_map"] = {
            json.dumps(cond.serialize()): policy.serialize() for cond, policy in self.cond_to_policy_map.items()
        }
        if self.default_policy:
            data["default_policy"] = self.default_policy.serialize()
        else:
            data["default_policy"] = None
        return data

    @classmethod
    def from_serialized(cls, config: SerializableDict) -> "BranchingPolicy":
        """Custom from_serialized to handle JSON-serialized condition keys.

        Args:
            config: Serialized policy config. ``cond_to_policy_map`` (optional) maps
                JSON-encoded condition configs to policy configs; ``default_policy``
                (optional) is a policy config or None. The input dict is not mutated.

        Returns:
            The reconstructed BranchingPolicy.

        Raises:
            TypeError: If the map, one of its keys/values, a decoded condition, or
                the default policy config has the wrong type.
            ValueError: If a condition key is not valid JSON.
        """
        # Work on a shallow copy so the pops below do not mutate the caller's dict.
        config_copy = dict(config)

        cond_to_policy_map = OrderedDict()
        serialized_cond_map = config_copy.pop("cond_to_policy_map", None)
        if serialized_cond_map is not None:
            if not isinstance(serialized_cond_map, dict):
                raise TypeError(
                    f"Expected 'cond_to_policy_map' to be a dict in BranchingPolicy config, "
                    f"got {type(serialized_cond_map)}"
                )

            for cond_json_str, policy_config in serialized_cond_map.items():
                if not isinstance(cond_json_str, str):
                    raise TypeError(
                        f"Condition key in 'cond_to_policy_map' must be a JSON string, got {type(cond_json_str)}"
                    )

                if not isinstance(policy_config, dict):
                    raise TypeError(
                        f"Policy config for condition '{cond_json_str}' must be a dict, got {type(policy_config)}"
                    )

                try:
                    condition_serializable_dict = json.loads(cond_json_str)
                except json.JSONDecodeError as e:
                    # Chain the decode error so the original parse position stays in the traceback.
                    raise ValueError(f"Failed to parse condition JSON string '{cond_json_str}': {e}") from e

                if not isinstance(condition_serializable_dict, dict):
                    raise TypeError(
                        f"Deserialized condition config for '{cond_json_str}' must be a dict, "
                        f"got {type(condition_serializable_dict)}"
                    )

                condition = Condition.from_serialized(condition_serializable_dict)
                policy = ControlPolicy.from_serialized(policy_config)
                cond_to_policy_map[condition] = policy

        default_policy = None
        default_policy_serializable = config_copy.pop("default_policy", None)
        if default_policy_serializable is not None:
            if not isinstance(default_policy_serializable, dict):
                raise TypeError(
                    f"Expected 'default_policy' config to be a dict, got {type(default_policy_serializable)}"
                )
            default_policy = ControlPolicy.from_serialized(default_policy_serializable)

        # Build the instance from the remaining keys, then attach the parts rebuilt by hand.
        instance = super().from_serialized(config_copy)

        instance.cond_to_policy_map = cond_to_policy_map
        instance.default_policy = default_policy

        return instance

A Control Policy that conditionally applies different policies based on transaction evaluation.

This policy evaluates conditions in order and applies the policy associated with the first matching condition. If no conditions match, it applies the default policy (if configured).

apply(transaction, container, session) async
Source code in luthien_control/control_policy/branching_policy.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
async def apply(
    self, transaction: Transaction, container: DependencyContainer, session: AsyncSession
) -> Transaction:
    """
    Run the policy of the first matching branch, falling back to the default policy.

    Branches are checked in the order they appear in ``cond_to_policy_map``. When
    nothing matches and no default policy is set, the transaction is returned as is.

    Args:
        transaction: The transaction to apply the policy to.
        container: The dependency container.
        session: The database session.

    Returns:
        The potentially modified transaction.
    """
    for condition, branch_policy in self.cond_to_policy_map.items():
        if not condition.evaluate(transaction):
            continue
        # First match wins; later branches are never evaluated.
        return await branch_policy.apply(transaction, container, session)

    fallback = self.default_policy
    if not fallback:
        return transaction
    return await fallback.apply(transaction, container, session)

Apply the first policy that matches the condition. If no condition matches, apply the default policy (if set).

Parameters:

Name Type Description Default
transaction Transaction

The transaction to apply the policy to.

required
container DependencyContainer

The dependency container.

required
session AsyncSession

The database session.

required

Returns:

Type Description
Transaction

The potentially modified transaction.

from_serialized(config) classmethod
Source code in luthien_control/control_policy/branching_policy.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def from_serialized(cls, config: SerializableDict) -> "BranchingPolicy":
    """Custom from_serialized to handle JSON-serialized condition keys.

    Args:
        config: Serialized policy config. ``cond_to_policy_map`` (optional) maps
            JSON-encoded condition configs to policy configs; ``default_policy``
            (optional) is a policy config or None. The input dict is not mutated.

    Returns:
        The reconstructed BranchingPolicy.

    Raises:
        TypeError: If the map, one of its keys/values, a decoded condition, or
            the default policy config has the wrong type.
        ValueError: If a condition key is not valid JSON.
    """
    # Work on a shallow copy so the pops below do not mutate the caller's dict.
    config_copy = dict(config)

    cond_to_policy_map = OrderedDict()
    serialized_cond_map = config_copy.pop("cond_to_policy_map", None)
    if serialized_cond_map is not None:
        if not isinstance(serialized_cond_map, dict):
            raise TypeError(
                f"Expected 'cond_to_policy_map' to be a dict in BranchingPolicy config, "
                f"got {type(serialized_cond_map)}"
            )

        for cond_json_str, policy_config in serialized_cond_map.items():
            if not isinstance(cond_json_str, str):
                raise TypeError(
                    f"Condition key in 'cond_to_policy_map' must be a JSON string, got {type(cond_json_str)}"
                )

            if not isinstance(policy_config, dict):
                raise TypeError(
                    f"Policy config for condition '{cond_json_str}' must be a dict, got {type(policy_config)}"
                )

            try:
                condition_serializable_dict = json.loads(cond_json_str)
            except json.JSONDecodeError as e:
                # Chain the decode error so the original parse position stays in the traceback.
                raise ValueError(f"Failed to parse condition JSON string '{cond_json_str}': {e}") from e

            if not isinstance(condition_serializable_dict, dict):
                raise TypeError(
                    f"Deserialized condition config for '{cond_json_str}' must be a dict, "
                    f"got {type(condition_serializable_dict)}"
                )

            condition = Condition.from_serialized(condition_serializable_dict)
            policy = ControlPolicy.from_serialized(policy_config)
            cond_to_policy_map[condition] = policy

    default_policy = None
    default_policy_serializable = config_copy.pop("default_policy", None)
    if default_policy_serializable is not None:
        if not isinstance(default_policy_serializable, dict):
            raise TypeError(
                f"Expected 'default_policy' config to be a dict, got {type(default_policy_serializable)}"
            )
        default_policy = ControlPolicy.from_serialized(default_policy_serializable)

    # Build the instance from the remaining keys, then attach the parts rebuilt by hand.
    instance = super().from_serialized(config_copy)

    instance.cond_to_policy_map = cond_to_policy_map
    instance.default_policy = default_policy

    return instance

Custom from_serialized to handle JSON-serialized condition keys.

serialize()
Source code in luthien_control/control_policy/branching_policy.py
67
68
69
70
71
72
73
74
75
76
77
def serialize(self) -> SerializableDict:
    """Serialize the policy, encoding the condition-to-policy map by hand.

    Condition keys become JSON strings of their serialized form; policies become
    their serialized dicts. ``default_policy`` is always emitted (None when unset).
    """
    serialized = super().serialize()

    branch_map = {}
    for condition, branch_policy in self.cond_to_policy_map.items():
        # Build the key before the value, matching dict-comprehension evaluation order.
        condition_key = json.dumps(condition.serialize())
        branch_map[condition_key] = branch_policy.serialize()
    serialized["cond_to_policy_map"] = branch_map

    fallback = self.default_policy
    serialized["default_policy"] = fallback.serialize() if fallback else None
    return serialized

Override serialize to handle complex condition-to-policy mapping.

validate_cond_to_policy_map(value) classmethod
Source code in luthien_control/control_policy/branching_policy.py
30
31
32
33
34
35
36
37
38
@field_validator("cond_to_policy_map", mode="before")
@classmethod
def validate_cond_to_policy_map(cls, value):
    """Coerce the condition-to-policy mapping to an OrderedDict.

    An OrderedDict is returned as is, a plain dict is wrapped in a new
    OrderedDict, and any other input raises ValueError.
    """
    # OrderedDict is a dict subclass, so this single guard rejects every non-mapping input.
    if not isinstance(value, dict):
        raise ValueError("cond_to_policy_map must be a dict or OrderedDict")
    return value if isinstance(value, OrderedDict) else OrderedDict(value)

Validate and convert condition-to-policy mapping.

validate_default_policy(value) classmethod
Source code in luthien_control/control_policy/branching_policy.py
40
41
42
43
44
@field_validator("default_policy", mode="before")
@classmethod
def validate_default_policy(cls, value):
    """Validate default policy field.

    Currently a pass-through: the incoming value is returned unchanged, so no
    checking or conversion happens in this validator.
    """
    return value

Validate default policy field.

client_api_key_auth

ClientApiKeyAuthPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/client_api_key_auth.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class ClientApiKeyAuthPolicy(ControlPolicy):
    """Verifies the client API key from the transaction's request.

    This policy authenticates clients by checking their API key against
    the database. It ensures the key exists and is active.

    Attributes:
        name (str): The name of this policy instance.
    """

    name: Optional[str] = Field(default="ClientApiKeyAuthPolicy")

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Verifies the API key from the transaction's request.
        Requires the DependencyContainer and an active SQLAlchemy AsyncSession.

        Args:
            transaction: The current transaction.
            container: The application dependency container (not used by this policy).
            session: An active SQLAlchemy AsyncSession.

        Returns:
            The unmodified transaction if authentication is successful.

        Raises:
            NoRequestError: If transaction.request is None.
            ClientAuthenticationNotFoundError: If the request carries no API key.
            ClientAuthenticationError: If the key lookup fails or the key is inactive.
        """
        if transaction.request is None:
            raise NoRequestError("No request in transaction for API key auth.")

        api_key_value = transaction.request.api_key

        if not api_key_value:
            self.logger.warning("Missing API key in transaction request.")
            raise ClientAuthenticationNotFoundError(detail="Not authenticated: Missing API key.")

        try:
            db_key = await get_api_key_by_value(session, api_key_value)
        except LuthienDBQueryError as exc:
            # Only a short prefix of the key is logged, never the full value.
            self.logger.warning(
                f"Invalid API key provided (key starts with: {api_key_value[:4]}...) ({self.__class__.__name__})."
            )
            # Chain the lookup failure so the underlying DB error stays attached as the cause.
            raise ClientAuthenticationError(detail="Invalid API Key") from exc

        if not db_key.is_active:
            self.logger.warning(
                f"Inactive API key provided (Name: {db_key.name}, ID: {db_key.id}). ({self.__class__.__name__})."
            )
            raise ClientAuthenticationError(detail="Inactive API Key")

        self.logger.info(
            f"Client API key authenticated successfully "
            f"(Name: {db_key.name}, ID: {db_key.id}). ({self.__class__.__name__})."
        )

        return transaction

Verifies the client API key from the transaction's request.

This policy authenticates clients by checking their API key against the database. It ensures the key exists and is active.

Attributes:

Name Type Description
name str

The name of this policy instance.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/client_api_key_auth.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Verifies the API key from the transaction's request.
    Requires the DependencyContainer and an active SQLAlchemy AsyncSession.

    Args:
        transaction: The current transaction.
        container: The application dependency container (not used by this policy).
        session: An active SQLAlchemy AsyncSession.

    Returns:
        The unmodified transaction if authentication is successful.

    Raises:
        NoRequestError: If transaction.request is None.
        ClientAuthenticationNotFoundError: If the request carries no API key.
        ClientAuthenticationError: If the key lookup fails or the key is inactive.
    """
    if transaction.request is None:
        raise NoRequestError("No request in transaction for API key auth.")

    api_key_value = transaction.request.api_key

    if not api_key_value:
        self.logger.warning("Missing API key in transaction request.")
        raise ClientAuthenticationNotFoundError(detail="Not authenticated: Missing API key.")

    try:
        db_key = await get_api_key_by_value(session, api_key_value)
    except LuthienDBQueryError as exc:
        # Only a short prefix of the key is logged, never the full value.
        self.logger.warning(
            f"Invalid API key provided (key starts with: {api_key_value[:4]}...) ({self.__class__.__name__})."
        )
        # Chain the lookup failure so the underlying DB error stays attached as the cause.
        raise ClientAuthenticationError(detail="Invalid API Key") from exc

    if not db_key.is_active:
        self.logger.warning(
            f"Inactive API key provided (Name: {db_key.name}, ID: {db_key.id}). ({self.__class__.__name__})."
        )
        raise ClientAuthenticationError(detail="Inactive API Key")

    self.logger.info(
        f"Client API key authenticated successfully "
        f"(Name: {db_key.name}, ID: {db_key.id}). ({self.__class__.__name__})."
    )

    return transaction

Verifies the API key from the transaction's request. Requires the DependencyContainer and an active SQLAlchemy AsyncSession.

Raises:

Type Description
NoRequestError

If transaction.request is None.

ClientAuthenticationError

If the key is missing, invalid, or inactive.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container.

required
session AsyncSession

An active SQLAlchemy AsyncSession.

required

Returns:

Type Description
Transaction

The unmodified transaction if authentication is successful.

conditions

AllCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/all_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class AllCondition(Condition):
    """Condition that holds only when every nested condition holds."""

    type: Literal["all"] = "all"
    conditions: List[Condition] = Field(...)

    @field_serializer("conditions")
    def serialize_conditions(self, value: List[Condition]) -> List[dict]:
        """Custom serializer for conditions field."""
        return [condition.serialize() for condition in value]

    @field_validator("conditions", mode="before")
    @classmethod
    def validate_conditions(cls, value):
        """Custom validator to deserialize conditions from dicts.

        Dict items are turned into Condition objects via ``Condition.from_serialized``.
        Every other item is passed through untouched so the field's own
        ``List[Condition]`` validation can accept or reject it, instead of the item
        being silently dropped from the list.
        """
        if not isinstance(value, list):
            return value
        return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]

    def evaluate(self, transaction: Transaction) -> bool:
        """Return True when all nested conditions evaluate to True (True for an empty list)."""
        return all(condition.evaluate(transaction) for condition in self.conditions)
serialize_conditions(value)
Source code in luthien_control/control_policy/conditions/all_cond.py
13
14
15
16
@field_serializer("conditions")
def serialize_conditions(self, value: List[Condition]) -> List[dict]:
    """Serialize each nested condition through its own ``serialize()`` method."""
    serialized_items = []
    for nested in value:
        serialized_items.append(nested.serialize())
    return serialized_items

Custom serializer for conditions field.

validate_conditions(value) classmethod
Source code in luthien_control/control_policy/conditions/all_cond.py
18
19
20
21
22
23
24
25
26
27
28
29
30
@field_validator("conditions", mode="before")
@classmethod
def validate_conditions(cls, value):
    """Custom validator to deserialize conditions from dicts.

    Dict items are turned into Condition objects via ``Condition.from_serialized``.
    Every other item is passed through untouched so the field's own
    ``List[Condition]`` validation can accept or reject it, instead of the item
    being silently dropped from the list.
    """
    if not isinstance(value, list):
        return value
    return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]

Custom validator to deserialize conditions from dicts.

AnyCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/any_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class AnyCondition(Condition):
    """Condition that holds when at least one nested condition holds."""

    type: Literal["any"] = "any"
    conditions: List[Condition] = Field(...)

    @field_serializer("conditions")
    def serialize_conditions(self, value: List[Condition]) -> List[dict]:
        """Custom serializer for conditions field."""
        return [condition.serialize() for condition in value]

    @field_validator("conditions", mode="before")
    @classmethod
    def validate_conditions(cls, value):
        """Custom validator to deserialize conditions from dicts.

        Dict items are turned into Condition objects via ``Condition.from_serialized``.
        Every other item is passed through untouched so the field's own
        ``List[Condition]`` validation can accept or reject it, instead of the item
        being silently dropped from the list.
        """
        if not isinstance(value, list):
            return value
        return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]

    def evaluate(self, transaction: Transaction) -> bool:
        """Return True when any nested condition evaluates to True (False for an empty list)."""
        if not self.conditions:
            return False
        return any(condition.evaluate(transaction) for condition in self.conditions)
serialize_conditions(value)
Source code in luthien_control/control_policy/conditions/any_cond.py
13
14
15
16
@field_serializer("conditions")
def serialize_conditions(self, value: List[Condition]) -> List[dict]:
    """Serialize each nested condition through its own ``serialize()`` method."""
    serialized_items = []
    for nested in value:
        serialized_items.append(nested.serialize())
    return serialized_items

Custom serializer for conditions field.

validate_conditions(value) classmethod
Source code in luthien_control/control_policy/conditions/any_cond.py
18
19
20
21
22
23
24
25
26
27
28
29
30
@field_validator("conditions", mode="before")
@classmethod
def validate_conditions(cls, value):
    """Custom validator to deserialize conditions from dicts.

    Dict items are turned into Condition objects via ``Condition.from_serialized``.
    Every other item is passed through untouched so the field's own
    ``List[Condition]`` validation can accept or reject it, instead of the item
    being silently dropped from the list.
    """
    if not isinstance(value, list):
        return value
    return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]

Custom validator to deserialize conditions from dicts.

Condition

Bases: BaseModel, ABC

Source code in luthien_control/control_policy/conditions/condition.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class Condition(BaseModel, abc.ABC):
    """
    Abstract base class for conditions in control policies.

    Conditions are used to evaluate whether a policy should be applied based on
    the current transaction.
    """

    type: Any  # Allow any string type including Literal types

    @abc.abstractmethod
    def evaluate(self, transaction: Transaction) -> bool:
        """Return whether this condition holds for the given transaction."""
        pass

    def serialize(self) -> SerializableDict:
        """Serialize using Pydantic model_dump through SerializableDict validation."""
        data = safe_model_dump(self)
        # Always write the type tag explicitly; from_serialized dispatches on it.
        data["type"] = self.type
        return data

    @classmethod
    def from_serialized(cls, serialized: SerializableDict) -> "Condition":
        """Construct a condition from a serialized configuration.

        This method acts as a dispatcher. It looks up the concrete condition class
        based on the 'type' field in the config and validates the config against it.

        Args:
            serialized: The condition-specific configuration dictionary. It must contain
                        a 'type' key that maps to a registered condition type.

        Returns:
            An instance of the concrete condition class.

        Raises:
            ValueError: If the 'type' key is missing in config or the type is not registered.
        """
        # Moved import inside the method to break circular dependency
        from luthien_control.control_policy.conditions.registry import NAME_TO_CONDITION_CLASS

        condition_type = serialized.get("type")
        if condition_type is None:
            raise ValueError("Condition config is missing the required 'type' key.")

        condition_type_name_val = str(condition_type)
        try:
            target_condition_class = NAME_TO_CONDITION_CLASS[condition_type_name_val]
        except KeyError as exc:
            # Surface the documented ValueError rather than a bare registry KeyError.
            raise ValueError(f"Unknown condition type '{condition_type_name_val}'.") from exc

        return safe_model_validate(target_condition_class, serialized)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.serialize()})"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: object) -> bool:
        # Equal only to an instance of the same class (or a subclass) with identical serialized form.
        if not isinstance(other, self.__class__):
            return False
        return self.serialize() == other.serialize()

    model_config = ConfigDict(arbitrary_types_allowed=True)

Abstract base class for conditions in control policies.

Conditions are used to evaluate whether a policy should be applied based on the current transaction.

from_serialized(serialized) classmethod
Source code in luthien_control/control_policy/conditions/condition.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@classmethod
def from_serialized(cls, serialized: SerializableDict) -> "Condition":
    """Construct a condition from a serialized configuration.

    This method acts as a dispatcher. It looks up the concrete condition class
    based on the 'type' field in the config and validates the config against it.

    Args:
        serialized: The condition-specific configuration dictionary. It must contain
                    a 'type' key that maps to a registered condition type.

    Returns:
        An instance of the concrete condition class.

    Raises:
        ValueError: If the 'type' key is missing in config or the type is not registered.
    """
    # Moved import inside the method to break circular dependency
    from luthien_control.control_policy.conditions.registry import NAME_TO_CONDITION_CLASS

    condition_type = serialized.get("type")
    if condition_type is None:
        raise ValueError("Condition config is missing the required 'type' key.")

    condition_type_name_val = str(condition_type)
    try:
        target_condition_class = NAME_TO_CONDITION_CLASS[condition_type_name_val]
    except KeyError as exc:
        # Surface the documented ValueError rather than a bare registry KeyError.
        raise ValueError(f"Unknown condition type '{condition_type_name_val}'.") from exc

    return safe_model_validate(target_condition_class, serialized)

Construct a condition from a serialized configuration.

This method acts as a dispatcher. It looks up the concrete condition class based on the 'type' field in the config and validates the config against that class.

Parameters:

Name Type Description Default
serialized SerializableDict

The condition-specific configuration dictionary. It must contain a 'type' key that maps to a registered condition type.

required

Returns:

Type Description
Condition

An instance of the concrete condition class.

Raises:

Type Description
ValueError

If the 'type' key is missing in config or the type is not registered.

serialize()
Source code in luthien_control/control_policy/conditions/condition.py
24
25
26
27
28
def serialize(self) -> SerializableDict:
    """Dump the model with ``safe_model_dump`` and stamp the ``type`` tag onto the result."""
    payload = safe_model_dump(self)
    # Set the tag explicitly so it is always present in the returned dict.
    payload["type"] = self.type
    return payload

Serialize using Pydantic model_dump through SerializableDict validation.

ContainsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
226
227
228
229
230
231
232
class ContainsCondition(ComparisonCondition):
    """
    Condition to check if the left value contains the right value.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["contains"] = Field(default="contains")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = contains

Condition to check if the left value contains the right value.

EqualsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
class EqualsCondition(ComparisonCondition):
    """
    Condition to check if two values are equal.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.

    Examples:
        # Traditional
        EqualsCondition(path("request.payload.model"), "gpt-4o")

        # Dynamic
        EqualsCondition(path("request.payload.model"), path("data.preferred_model"))

        # Static vs dynamic
        EqualsCondition("gpt-4o", path("request.payload.model"))
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["equals"] = Field(default="equals")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = equals

Condition to check if two values are equal.

Examples:

Traditional

EqualsCondition(path("request.payload.model"), "gpt-4o")

Dynamic

EqualsCondition(path("request.payload.model"), path("data.preferred_model"))

Static vs dynamic

EqualsCondition("gpt-4o", path("request.payload.model"))

GreaterThanCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
253
254
255
256
257
258
259
class GreaterThanCondition(ComparisonCondition):
    """
    Condition to check if the left value is greater than the right value.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["greater_than"] = Field(default="greater_than")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = greater_than

Condition to check if the left value is greater than the right value.

GreaterThanOrEqualCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
262
263
264
265
266
267
268
class GreaterThanOrEqualCondition(ComparisonCondition):
    """
    Condition to check if the left value is greater than or equal to the right value.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["greater_than_or_equal"] = Field(default="greater_than_or_equal")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = greater_than_or_equal

Condition to check if the left value is greater than or equal to the right value.

LessThanCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
235
236
237
238
239
240
241
class LessThanCondition(ComparisonCondition):
    """
    Condition to check if the left value is less than the right value.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["less_than"] = Field(default="less_than")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = less_than

Condition to check if the left value is less than the right value.

LessThanOrEqualCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
244
245
246
247
248
249
250
class LessThanOrEqualCondition(ComparisonCondition):
    """
    Condition to check if the left value is less than or equal to the right value.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["less_than_or_equal"] = Field(default="less_than_or_equal")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = less_than_or_equal

Condition to check if the left value is less than or equal to the right value.

NotCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/not_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class NotCondition(Condition):
    """Condition that negates the result of a single wrapped condition."""

    type: Literal["not"] = "not"
    cond: Condition = Field(...)

    @field_serializer("cond")
    def serialize_cond(self, value: Condition) -> dict:
        """Serialize the wrapped condition through its own ``serialize()`` method."""
        serialized_inner = value.serialize()
        return serialized_inner

    @field_validator("cond", mode="before")
    @classmethod
    def validate_cond(cls, value):
        """Build the wrapped condition from a dict; pass any other value through unchanged."""
        if not isinstance(value, dict):
            return value
        return Condition.from_serialized(value)

    def evaluate(self, transaction: Transaction) -> bool:
        """Return the logical negation of the wrapped condition's result."""
        inner_result = self.cond.evaluate(transaction)
        return not inner_result

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(value={self.cond!r})"
serialize_cond(value)
Source code in luthien_control/control_policy/conditions/not_cond.py
13
14
15
16
@field_serializer("cond")
def serialize_cond(self, value: Condition) -> dict:
    """Serialize the wrapped condition through its own ``serialize()`` method."""
    serialized_inner = value.serialize()
    return serialized_inner

Custom serializer for cond field.

validate_cond(value) classmethod
Source code in luthien_control/control_policy/conditions/not_cond.py
18
19
20
21
22
23
24
@field_validator("cond", mode="before")
@classmethod
def validate_cond(cls, value):
    """Build the wrapped condition from a dict; pass any other value through unchanged."""
    if not isinstance(value, dict):
        return value
    return Condition.from_serialized(value)

Custom validator to deserialize condition from dict.

NotEqualsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
217
218
219
220
221
222
223
class NotEqualsCondition(ComparisonCondition):
    """
    Condition to check if two values are NOT equal.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["not_equals"] = Field(default="not_equals")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = not_equals

Condition to check if two values are NOT equal.

RegexMatchCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
271
272
273
274
275
276
277
class RegexMatchCondition(ComparisonCondition):
    """
    Condition to check if the left value matches a regex pattern.

    Only the ``type`` tag and the ``comparator`` are defined here; everything
    else is inherited from ComparisonCondition.
    """

    # Type tag carried in the serialized form of this condition.
    type: Literal["regex_match"] = Field(default="regex_match")
    # NOTE(review): presumably the callable the base class applies to (left, right) — confirm in ComparisonCondition.
    comparator = regex_match

Condition to check if the left value matches a regex pattern.

path(transaction_path)

Source code in luthien_control/control_policy/conditions/value_resolvers.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def path(transaction_path: str) -> TransactionPath:
    """
    Shorthand for building a TransactionPath from a path string.

    Args:
        transaction_path: The path to the value in the transaction

    Returns:
        A TransactionPath whose ``path`` field is ``transaction_path``

    Example:
        path("request.payload.model")
    """
    resolver = TransactionPath(path=transaction_path)
    return resolver

Convenience function to create a TransactionPath.

Parameters:

Name Type Description Default
transaction_path str

The path to the value in the transaction

required

Returns:

Type Description
TransactionPath

A TransactionPath instance

Example

path("request.payload.model")

all_cond

AllCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/all_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class AllCondition(Condition):
    """Composite condition that holds only when every child condition holds.

    Note that an empty ``conditions`` list evaluates to ``True``, because
    ``all`` of an empty iterable is ``True``.
    """

    type: Literal["all"] = "all"
    conditions: List[Condition] = Field(...)

    @field_serializer("conditions")
    def serialize_conditions(self, value: List[Condition]) -> List[dict]:
        """Custom serializer for conditions field."""
        return [condition.serialize() for condition in value]

    @field_validator("conditions", mode="before")
    @classmethod
    def validate_conditions(cls, value):
        """Custom validator to deserialize conditions from dicts.

        Dict items are rebuilt with ``Condition.from_serialized``. Every other
        item is passed through unchanged and left to the field's
        ``List[Condition]`` type validation. Items of an unexpected type are
        deliberately NOT dropped: dropping them could silently turn a
        malformed list into an empty (always-true) condition.
        """
        if isinstance(value, list):
            return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]
        return value

    def evaluate(self, transaction: Transaction) -> bool:
        """Return True when every child condition evaluates truthy."""
        return all(condition.evaluate(transaction) for condition in self.conditions)
serialize_conditions(value)
Source code in luthien_control/control_policy/conditions/all_cond.py
13
14
15
16
@field_serializer("conditions")
def serialize_conditions(self, value: List[Condition]) -> List[dict]:
    """Serialize each child condition via its own ``serialize()``, in order."""
    serialized = []
    for condition in value:
        serialized.append(condition.serialize())
    return serialized

Custom serializer for conditions field.

validate_conditions(value) classmethod
Source code in luthien_control/control_policy/conditions/all_cond.py
18
19
20
21
22
23
24
25
26
27
28
29
30
@field_validator("conditions", mode="before")
@classmethod
def validate_conditions(cls, value):
    """Custom validator to deserialize conditions from dicts.

    Dict items are rebuilt with ``Condition.from_serialized``. Every other
    item is passed through unchanged and left to the field's own type
    validation. Items of an unexpected type are deliberately NOT dropped:
    dropping them could silently turn a malformed list into an empty
    condition list.
    """
    if isinstance(value, list):
        return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]
    return value

Custom validator to deserialize conditions from dicts.

any_cond

AnyCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/any_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class AnyCondition(Condition):
    """Composite condition that holds when at least one child condition holds.

    An empty ``conditions`` list evaluates to ``False``.
    """

    type: Literal["any"] = "any"
    conditions: List[Condition] = Field(...)

    @field_serializer("conditions")
    def serialize_conditions(self, value: List[Condition]) -> List[dict]:
        """Custom serializer for conditions field."""
        return [condition.serialize() for condition in value]

    @field_validator("conditions", mode="before")
    @classmethod
    def validate_conditions(cls, value):
        """Custom validator to deserialize conditions from dicts.

        Dict items are rebuilt with ``Condition.from_serialized``. Every other
        item is passed through unchanged and left to the field's
        ``List[Condition]`` type validation, rather than being silently
        dropped from the list.
        """
        if isinstance(value, list):
            return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]
        return value

    def evaluate(self, transaction: Transaction) -> bool:
        """Return True when at least one child condition evaluates truthy."""
        # any() over an empty list is already False, so no separate empty check is needed.
        return any(condition.evaluate(transaction) for condition in self.conditions)
serialize_conditions(value)
Source code in luthien_control/control_policy/conditions/any_cond.py
13
14
15
16
@field_serializer("conditions")
def serialize_conditions(self, value: List[Condition]) -> List[dict]:
    """Serialize each child condition via its own ``serialize()``, in order."""
    serialized = []
    for condition in value:
        serialized.append(condition.serialize())
    return serialized

Custom serializer for conditions field.

validate_conditions(value) classmethod
Source code in luthien_control/control_policy/conditions/any_cond.py
18
19
20
21
22
23
24
25
26
27
28
29
30
@field_validator("conditions", mode="before")
@classmethod
def validate_conditions(cls, value):
    """Custom validator to deserialize conditions from dicts.

    Dict items are rebuilt with ``Condition.from_serialized``. Every other
    item is passed through unchanged and left to the field's own type
    validation, rather than being silently dropped from the list.
    """
    if isinstance(value, list):
        return [Condition.from_serialized(item) if isinstance(item, dict) else item for item in value]
    return value

Custom validator to deserialize conditions from dicts.

comparison_conditions

Comparison conditions for control policies.

This module implements comparison-based conditions (equals, contains, greater than, etc.) using a clean ValueResolver pattern for flexible value resolution.

Pyright Type Checker Suppression

The # pyright: reportCallIssue=false comment at the top of this file suppresses type checker warnings for positional argument usage in comparison condition constructors.

Why This Is Necessary

All comparison conditions inherit from Pydantic's BaseModel, which enforces keyword-only constructors. However, we provide a more natural API with positional arguments:

# Natural, concise syntax (what we want)
EqualsCondition(path("request.payload.model"), "gpt-4o")

# Verbose but type-safe (what Pydantic expects)
EqualsCondition(left=path("request.payload.model"), right="gpt-4o")

Our custom __init__ methods handle both patterns correctly at runtime, but static analysis tools like pyright cannot see through the Pydantic inheritance to understand this flexibility.

Safety Considerations

This suppression is safe because:

1. We only suppress reportCallIssue (constructor signature mismatches).
2. Our overload definitions provide proper type hints.
3. Runtime behavior is thoroughly tested.
4. Other type checking (return types, field access, etc.) remains active.

For Users of This Module

When using these comparison conditions in your code, you may encounter pyright warnings. See the ComparisonCondition class documentation for guidance on suppressing these warnings appropriately in your own files.

ComparisonCondition

Bases: Condition, ABC

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class ComparisonCondition(Condition, ABC):
    """
    Clean comparison condition that uses ValueResolver objects for flexible value resolution.

    This approach eliminates the need for is_dynamic_* flags by using explicit types.

    ## Constructor Usage

    This class supports both positional and keyword argument patterns:

    ### Positional Arguments (Recommended for brevity)
    ```python
    EqualsCondition(path("request.payload.model"), "gpt-4o")
    EqualsCondition(path("left_path"), path("right_path"))  # Dynamic comparison
    EqualsCondition("static_left", "static_right")          # Static comparison
    ```

    ### Keyword Arguments (Explicit, type-safe)
    ```python
    EqualsCondition(left=path("request.payload.model"), right="gpt-4o")
    EqualsCondition(left=path("left_path"), right=path("right_path"))
    ```

    ## Pyright Type Checker Warning

    **Important**: When using positional arguments, pyright may show this error:
    ```
    Expected 0 positional arguments (reportCallIssue)
    ```

    This is a known issue due to the underlying Pydantic BaseModel inheritance. The code
    works correctly at runtime, but pyright's static analysis doesn't recognize our
    custom `__init__` override.

    ### How to Suppress the Warning

    Add this comment to suppress the specific error on individual calls:
    ```python
    condition = EqualsCondition(path("test"), "value")  # pyright: ignore[reportCallIssue]
    ```

    Or add this at the top of your file to suppress all such errors in that file:
    ```python
    # pyright: reportCallIssue=false
    ```

    ### When to Use Each Approach

    - **Use positional**: For concise, readable condition creation in tests and simple cases
    - **Use keywords**: When you need full type safety or when working in strict typing environments
    - **Suppress warnings**: When you prefer the positional syntax and understand the trade-off
    """

    # Comparator bound by each concrete subclass; a class-level attribute, not a model field.
    comparator: ClassVar[Comparator]

    # Resolvers producing the two operands at evaluation time.
    left: ValueResolver
    right: ValueResolver
    # Comparator name; populated through the "comparator" alias.
    comparator_name: str = Field(alias="comparator")

    @overload
    def __init__(self, left: Union[Any, ValueResolver], right: Union[Any, ValueResolver]) -> None: ...

    @overload
    def __init__(self, *, left: ValueResolver, right: ValueResolver, comparator: str, **kwargs: Any) -> None: ...

    def __init__(
        self,
        left: Union[Any, ValueResolver, None] = None,
        right: Union[Any, ValueResolver, None] = None,
        *,
        # Pydantic keyword-only arguments
        comparator: Union[str, None] = None,
        **kwargs,
    ):
        """Initialize with both positional and keyword argument support.

        When both operands are supplied they are passed through
        ``auto_resolve_value`` and forwarded, together with the comparator
        name, to the parent initializer. If no comparator name is given it is
        looked up in ``COMPARATOR_TO_NAME`` from the class-level comparator.
        """
        # Handle positional arguments
        # NOTE(review): None doubles as the "not supplied" marker, so if either
        # operand is None this branch is skipped and neither `left`, `right`, nor an
        # explicitly passed `comparator` is forwarded to the parent initializer.
        if left is not None and right is not None:
            kwargs["left"] = auto_resolve_value(left)
            kwargs["right"] = auto_resolve_value(right)
            if comparator is None:
                kwargs["comparator"] = COMPARATOR_TO_NAME[type(self).comparator]
            else:
                kwargs["comparator"] = comparator

        super().__init__(**kwargs)

    @field_serializer("left", "right")
    def serialize_value_resolver(self, value: ValueResolver) -> dict:
        """Serialize a left/right resolver through its own ``serialize()``."""
        return value.serialize()

    @field_validator("left", "right", mode="before")
    @classmethod
    def validate_value_resolver(cls, value: StaticValue | ValueResolver):
        """Normalize an incoming left/right value before field validation.

        A ``StaticValue`` whose payload is a dict containing a ``"type"`` key is
        handed to ``create_value_resolver``; every other input goes through
        ``auto_resolve_value``.
        """
        if isinstance(value, StaticValue) and isinstance(value.value, dict) and "type" in value.value:
            return create_value_resolver(value.value)
        else:
            return auto_resolve_value(value)

    def evaluate(self, transaction: Transaction) -> bool:
        """Resolve both operands against the transaction and apply the class comparator."""
        left_value = self.left.resolve(transaction)
        right_value = self.right.resolve(transaction)
        return type(self).comparator.evaluate(left_value, right_value)

    def __repr__(self) -> str:
        """Render as ``<ClassName>(<left repr>, <right repr>)``."""
        return f"{type(self).__name__}({self.left!r}, {self.right!r})"

    @classmethod
    def from_legacy_format(cls, key: str, value: Any) -> "ComparisonCondition":
        """
        Create a condition from legacy ComparisonCondition format.

        Args:
            key: Transaction path (e.g., "request.payload.model")
            value: Static value to compare against

        Returns:
            A ComparisonCondition instance
        """
        # `key` becomes the left operand via path(); `value` is the right operand.
        return cls(path(key), value)

Clean comparison condition that uses ValueResolver objects for flexible value resolution.

This approach eliminates the need for is_dynamic_* flags by using explicit types.

Constructor Usage

This class supports both positional and keyword argument patterns:

EqualsCondition(path("request.payload.model"), "gpt-4o")
EqualsCondition(path("left_path"), path("right_path"))  # Dynamic comparison
EqualsCondition("static_left", "static_right")          # Static comparison
Keyword Arguments (Explicit, type-safe)
EqualsCondition(left=path("request.payload.model"), right="gpt-4o")
EqualsCondition(left=path("left_path"), right=path("right_path"))
Pyright Type Checker Warning

Important: When using positional arguments, pyright may show this error:

Expected 0 positional arguments (reportCallIssue)

This is a known issue due to the underlying Pydantic BaseModel inheritance. The code works correctly at runtime, but pyright's static analysis doesn't recognize our custom __init__ override.

How to Suppress the Warning

Add this comment to suppress the specific error on individual calls:

condition = EqualsCondition(path("test"), "value")  # pyright: ignore[reportCallIssue]

Or add this at the top of your file to suppress all such errors in that file:

# pyright: reportCallIssue=false
When to Use Each Approach
  • Use positional: For concise, readable condition creation in tests and simple cases
  • Use keywords: When you need full type safety or when working in strict typing environments
  • Suppress warnings: When you prefer the positional syntax and understand the trade-off
__init__(left=None, right=None, *, comparator=None, **kwargs)
__init__(
    left: Union[Any, ValueResolver],
    right: Union[Any, ValueResolver],
) -> None
__init__(
    *,
    left: ValueResolver,
    right: ValueResolver,
    comparator: str,
    **kwargs: Any,
) -> None
Source code in luthien_control/control_policy/conditions/comparison_conditions.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def __init__(
    self,
    left: Union[Any, ValueResolver, None] = None,
    right: Union[Any, ValueResolver, None] = None,
    *,
    # Pydantic keyword-only arguments
    comparator: Union[str, None] = None,
    **kwargs,
):
    """Accept operands positionally or by keyword, then defer to the parent initializer.

    Only when both ``left`` and ``right`` are non-None are they (and the
    comparator name) added to the keyword arguments passed upward.
    """
    operands_given = left is not None and right is not None
    if operands_given:
        kwargs["left"] = auto_resolve_value(left)
        kwargs["right"] = auto_resolve_value(right)
        # Fall back to the name registered for the class-level comparator.
        kwargs["comparator"] = (
            comparator if comparator is not None else COMPARATOR_TO_NAME[type(self).comparator]
        )

    super().__init__(**kwargs)

Initialize with both positional and keyword argument support.

evaluate(transaction)
Source code in luthien_control/control_policy/conditions/comparison_conditions.py
174
175
176
177
178
def evaluate(self, transaction: Transaction) -> bool:
    """Resolve both operands against the transaction and apply the class comparator."""
    # Left is resolved before right, matching the operand order passed to the comparator.
    operands = (self.left.resolve(transaction), self.right.resolve(transaction))
    return type(self).comparator.evaluate(*operands)

Evaluate the condition against the transaction.

from_legacy_format(key, value) classmethod
Source code in luthien_control/control_policy/conditions/comparison_conditions.py
183
184
185
186
187
188
189
190
191
192
193
194
195
@classmethod
def from_legacy_format(cls, key: str, value: Any) -> "ComparisonCondition":
    """
    Build a condition from the legacy (key, value) form.

    Args:
        key: Transaction path (e.g., "request.payload.model")
        value: Static value to compare against

    Returns:
        A ComparisonCondition instance
    """
    left_operand = path(key)
    return cls(left_operand, value)

Create a condition from legacy ComparisonCondition format.

Parameters:

Name Type Description Default
key str

Transaction path (e.g., "request.payload.model")

required
value Any

Static value to compare against

required

Returns:

Type Description
ComparisonCondition

A ComparisonCondition instance

serialize_value_resolver(value)
Source code in luthien_control/control_policy/conditions/comparison_conditions.py
160
161
162
163
@field_serializer("left", "right")
def serialize_value_resolver(self, value: ValueResolver) -> dict:
    """Serialize a left/right resolver through its own ``serialize()``."""
    serialized = value.serialize()
    return serialized

Custom serializer for ValueResolver fields.

validate_value_resolver(value) classmethod
Source code in luthien_control/control_policy/conditions/comparison_conditions.py
165
166
167
168
169
170
171
172
@field_validator("left", "right", mode="before")
@classmethod
def validate_value_resolver(cls, value: StaticValue | ValueResolver):
    """Normalize an incoming left/right value before field validation.

    A ``StaticValue`` whose payload is a dict containing a ``"type"`` key is
    handed to ``create_value_resolver``; every other input goes through
    ``auto_resolve_value``.
    """
    wraps_serialized_resolver = (
        isinstance(value, StaticValue) and isinstance(value.value, dict) and "type" in value.value
    )
    if not wraps_serialized_resolver:
        return auto_resolve_value(value)
    return create_value_resolver(value.value)

Custom validator to deserialize ValueResolver from dict.

ContainsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
226
227
228
229
230
231
232
class ContainsCondition(ComparisonCondition):
    """
    Condition to check if the left value contains the right value.

    Binds the ``contains`` comparator and uses ``"contains"`` as its
    ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["contains"] = Field(default="contains")
    # Comparator bound at class level for this condition type.
    comparator = contains

Condition to check if the left value contains the right value.

EqualsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
class EqualsCondition(ComparisonCondition):
    """
    Condition to check if two values are equal.

    Binds the ``equals`` comparator and uses ``"equals"`` as its ``type`` tag.

    Examples:
        # Traditional
        EqualsCondition(path("request.payload.model"), "gpt-4o")

        # Dynamic
        EqualsCondition(path("request.payload.model"), path("data.preferred_model"))

        # Static vs dynamic
        EqualsCondition("gpt-4o", path("request.payload.model"))
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["equals"] = Field(default="equals")
    # Comparator bound at class level for this condition type.
    comparator = equals

Condition to check if two values are equal.

Examples:

Traditional

EqualsCondition(path("request.payload.model"), "gpt-4o")

Dynamic

EqualsCondition(path("request.payload.model"), path("data.preferred_model"))

Static vs dynamic

EqualsCondition("gpt-4o", path("request.payload.model"))

GreaterThanCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
253
254
255
256
257
258
259
class GreaterThanCondition(ComparisonCondition):
    """
    Condition to check if the left value is greater than the right value.

    Binds the ``greater_than`` comparator and uses ``"greater_than"`` as its
    ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["greater_than"] = Field(default="greater_than")
    # Comparator bound at class level for this condition type.
    comparator = greater_than

Condition to check if the left value is greater than the right value.

GreaterThanOrEqualCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
262
263
264
265
266
267
268
class GreaterThanOrEqualCondition(ComparisonCondition):
    """
    Condition to check if the left value is greater than or equal to the right value.

    Binds the ``greater_than_or_equal`` comparator and uses
    ``"greater_than_or_equal"`` as its ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["greater_than_or_equal"] = Field(default="greater_than_or_equal")
    # Comparator bound at class level for this condition type.
    comparator = greater_than_or_equal

Condition to check if the left value is greater than or equal to the right value.

LessThanCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
235
236
237
238
239
240
241
class LessThanCondition(ComparisonCondition):
    """
    Condition to check if the left value is less than the right value.

    Binds the ``less_than`` comparator and uses ``"less_than"`` as its
    ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["less_than"] = Field(default="less_than")
    # Comparator bound at class level for this condition type.
    comparator = less_than

Condition to check if the left value is less than the right value.

LessThanOrEqualCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
244
245
246
247
248
249
250
class LessThanOrEqualCondition(ComparisonCondition):
    """
    Condition to check if the left value is less than or equal to the right value.

    Binds the ``less_than_or_equal`` comparator and uses
    ``"less_than_or_equal"`` as its ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["less_than_or_equal"] = Field(default="less_than_or_equal")
    # Comparator bound at class level for this condition type.
    comparator = less_than_or_equal

Condition to check if the left value is less than or equal to the right value.

NotEqualsCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
217
218
219
220
221
222
223
class NotEqualsCondition(ComparisonCondition):
    """
    Condition to check if two values are NOT equal.

    Binds the ``not_equals`` comparator and uses ``"not_equals"`` as its
    ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["not_equals"] = Field(default="not_equals")
    # Comparator bound at class level for this condition type.
    comparator = not_equals

Condition to check if two values are NOT equal.

RegexMatchCondition

Bases: ComparisonCondition

Source code in luthien_control/control_policy/conditions/comparison_conditions.py
271
272
273
274
275
276
277
class RegexMatchCondition(ComparisonCondition):
    """
    Condition to check if the left value matches a regex pattern.

    Binds the ``regex_match`` comparator and uses ``"regex_match"`` as its
    ``type`` tag.
    """

    # Type tag for this condition; the default is the only allowed literal.
    type: Literal["regex_match"] = Field(default="regex_match")
    # Comparator bound at class level for this condition type.
    comparator = regex_match

Condition to check if the left value matches a regex pattern.

condition

Condition

Bases: BaseModel, ABC

Source code in luthien_control/control_policy/conditions/condition.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class Condition(BaseModel, abc.ABC):
    """
    Abstract base class for conditions in control policies.

    Conditions are used to evaluate whether a policy should be applied based on
    the current transaction. Concrete subclasses implement ``evaluate``; the
    ``type`` field identifies the concrete class in serialized form.
    """

    type: Any  # Allow any string type including Literal types

    @abc.abstractmethod
    def evaluate(self, transaction: Transaction) -> bool:
        """Return whether this condition holds for ``transaction``."""

    def serialize(self) -> SerializableDict:
        """Serialize using Pydantic model_dump through SerializableDict validation."""
        data = safe_model_dump(self)
        # Always write the type tag explicitly so the output can be dispatched on it.
        data["type"] = self.type
        return data

    @classmethod
    def from_serialized(cls, serialized: SerializableDict) -> "Condition":
        """Construct a condition from a serialized configuration.

        This method acts as a dispatcher. It looks up the concrete condition class
        based on the 'type' field in the config and validates the config against
        that class with ``safe_model_validate``.

        Args:
            serialized: The condition-specific configuration dictionary. It must contain
                        a 'type' key that maps to a registered condition type.

        Returns:
            An instance of the concrete condition class.

        Raises:
            NOTE(review): a missing 'type' is looked up as the string "None", and the
            registry is accessed with a plain subscript, so an unknown or missing type
            surfaces as whatever ``NAME_TO_CONDITION_CLASS`` raises on a failed lookup
            (``KeyError`` if it is a dict) - confirm against the registry.
        """
        # Moved import inside the method to break circular dependency
        from luthien_control.control_policy.conditions.registry import NAME_TO_CONDITION_CLASS

        condition_type_name_val = str(serialized.get("type"))

        target_condition_class = NAME_TO_CONDITION_CLASS[condition_type_name_val]

        return safe_model_validate(target_condition_class, serialized)

    def __repr__(self) -> str:
        """Render as ``<ClassName>(<serialized dict>)``."""
        # The previous format string ended in an unmatched ")"; name the class and
        # balance the parentheses, matching the repr style of the subclasses.
        return f"{type(self).__name__}({self.serialize()})"

    def __hash__(self) -> int:
        """Hash the string form of the condition."""
        return hash(str(self))

    def __eq__(self, other: object) -> bool:
        """Two conditions are equal when ``other`` is an instance of this class and both serialize identically."""
        if not isinstance(other, self.__class__):
            return False
        return self.serialize() == other.serialize()

    model_config = ConfigDict(arbitrary_types_allowed=True)

Abstract base class for conditions in control policies.

Conditions are used to evaluate whether a policy should be applied based on the current transaction.

from_serialized(serialized) classmethod
Source code in luthien_control/control_policy/conditions/condition.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@classmethod
def from_serialized(cls, serialized: SerializableDict) -> "Condition":
    """Construct a condition from a serialized configuration.

    Acts as a dispatcher: the 'type' entry of ``serialized`` selects the
    concrete condition class from the registry, and the configuration is then
    validated against that class with ``safe_model_validate``.

    Args:
        serialized: The condition-specific configuration dictionary. It must contain
                    a 'type' key that maps to a registered condition type.

    Returns:
        An instance of the concrete condition class.

    Raises:
        NOTE(review): a missing 'type' is looked up as the string "None", and the
        registry is accessed with a plain subscript, so an unknown or missing type
        surfaces as whatever ``NAME_TO_CONDITION_CLASS`` raises on a failed lookup
        (``KeyError`` if it is a dict) - confirm against the registry.
    """
    # Imported lazily to break a circular dependency with the registry module.
    from luthien_control.control_policy.conditions.registry import NAME_TO_CONDITION_CLASS

    type_name = str(serialized.get("type"))
    return safe_model_validate(NAME_TO_CONDITION_CLASS[type_name], serialized)

Construct a condition from a serialized configuration.

This method acts as a dispatcher. It looks up the concrete condition class based on the 'type' field in the config and validates the config against that class.

Parameters:

Name Type Description Default
serialized SerializableDict

The condition-specific configuration dictionary. It must contain a 'type' key that maps to a registered condition type.

required

Returns:

Type Description
Condition

An instance of the concrete condition class.

Raises:

Type Description
ValueError

If the 'type' key is missing in config or the type is not registered.

serialize()
Source code in luthien_control/control_policy/conditions/condition.py
24
25
26
27
28
def serialize(self) -> SerializableDict:
    """Dump the model with ``safe_model_dump`` and stamp the ``type`` tag onto the result."""
    payload = safe_model_dump(self)
    # Written explicitly so the tag is always present in the output.
    payload["type"] = self.type
    return payload

Serialize using Pydantic model_dump through SerializableDict validation.

not_cond

NotCondition

Bases: Condition

Source code in luthien_control/control_policy/conditions/not_cond.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class NotCondition(Condition):
    """Condition that inverts the outcome of a single wrapped condition."""

    type: Literal["not"] = "not"
    cond: Condition = Field(...)

    @field_serializer("cond")
    def serialize_cond(self, value: Condition) -> dict:
        """Serialize the wrapped condition through its own ``serialize()``."""
        serialized = value.serialize()
        return serialized

    @field_validator("cond", mode="before")
    @classmethod
    def validate_cond(cls, value):
        """Rebuild the wrapped condition when it arrives as a serialized dict."""
        if not isinstance(value, dict):
            return value
        return Condition.from_serialized(value)

    def evaluate(self, transaction: Transaction) -> bool:
        """Return the logical negation of the wrapped condition's result."""
        inner_result = self.cond.evaluate(transaction)
        return not inner_result

    def __repr__(self) -> str:
        """Render as ``<ClassName>(value=<repr of the wrapped condition>)``."""
        class_name = type(self).__name__
        return f"{class_name}(value={self.cond!r})"
serialize_cond(value)
Source code in luthien_control/control_policy/conditions/not_cond.py
13
14
15
16
@field_serializer("cond")
def serialize_cond(self, value: Condition) -> dict:
    """Serialize the wrapped condition through its own ``serialize()``."""
    serialized = value.serialize()
    return serialized

Custom serializer for cond field.

validate_cond(value) classmethod
Source code in luthien_control/control_policy/conditions/not_cond.py
18
19
20
21
22
23
24
@field_validator("cond", mode="before")
@classmethod
def validate_cond(cls, value):
    """Rebuild the wrapped condition when it arrives as a serialized dict.

    Anything that is not a dict is returned untouched.
    """
    if not isinstance(value, dict):
        return value
    return Condition.from_serialized(value)

Custom validator to deserialize condition from dict.

util

get_transaction_value(transaction, path)
Source code in luthien_control/control_policy/conditions/util.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def get_transaction_value(transaction: Transaction, path: str) -> Any:
    """Get a value from the transaction using a dotted path.

    The first component is read as an attribute of ``transaction``. Each later
    component is tried, in order, as a mapping key, as an attribute, and finally
    as an integer index.

    Args:
        transaction: The transaction.
        path: The path to the value e.g. "request.payload.model", "response.payload.choices", "data.user_id".

    Returns:
        The value at the path.

    Raises:
        ValueError: If the path has fewer than two components.
        AttributeError: If the first component is not an attribute of the
            transaction, or a later component cannot be accessed by key,
            attribute, or index.
    """
    components = path.split(".")
    if len(components) < 2:
        raise ValueError("Path must contain at least two components")

    root, *rest = components
    current: Any = getattr(transaction, root)
    for key in rest:
        # Try dict-like access first (includes EventedDict)
        if hasattr(current, "__getitem__") and (isinstance(current, dict) or hasattr(current, "keys")):
            try:
                current = current[key]
                continue
            except (KeyError, TypeError):
                pass

        # Try attribute access
        if hasattr(current, key):
            current = getattr(current, key)
            continue

        # Try accessing as index for list-like objects. KeyError is included so that
        # a numeric component on a mapping without that key (e.g. "data.0") is
        # reported like every other inaccessible component instead of leaking.
        try:
            current = current[int(key)]
        except (ValueError, TypeError, IndexError, KeyError) as exc:
            raise AttributeError(f"Cannot access '{key}' on {type(current).__name__}") from exc
    return current

Get a value from the transaction using a path.

Parameters:

Name Type Description Default
transaction Transaction

The transaction.

required
path str

The path to the value e.g. "request.payload.model", "response.payload.choices", "data.user_id".

required

Returns:

Type Description
Any

The value at the path.

Raises:

Type Description
ValueError

If the path has fewer than two dot-separated components. (A component that cannot be accessed raises AttributeError instead.)

value_resolvers

StaticValue

Bases: ValueResolver

Source code in luthien_control/control_policy/conditions/value_resolvers.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class StaticValue(ValueResolver):
    """
    Resolver wrapping a fixed value that is returned regardless of the transaction.
    """

    type: Literal["static"] = Field(default="static")
    value: Any = Field(...)

    def resolve(self, transaction: Transaction) -> Any:
        """Return the wrapped value; ``transaction`` is not consulted."""
        stored = self.value
        return stored

    def __repr__(self) -> str:
        """Render as ``StaticValue(value=<repr of the wrapped value>)``."""
        return f"StaticValue(value={self.value!r})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another StaticValue wrapping an equal value."""
        if not isinstance(other, StaticValue):
            return False
        return self.value == other.value

A static value that doesn't depend on the transaction.

__eq__(other)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
66
67
68
def __eq__(self, other: object) -> bool:
    """Equal only to another StaticValue wrapping an equal value."""
    if not isinstance(other, StaticValue):
        return False
    return self.value == other.value

Check equality with another StaticValue.

resolve(transaction)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
59
60
61
def resolve(self, transaction: Transaction) -> Any:
    """Return the wrapped value; ``transaction`` is not consulted."""
    stored = self.value
    return stored

Return the static value.

TransactionPath

Bases: ValueResolver

Source code in luthien_control/control_policy/conditions/value_resolvers.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class TransactionPath(ValueResolver):
    """
    A value resolver that extracts a value from a transaction using a path.

    Resolution is best-effort: a path that cannot be followed resolves to
    ``None`` instead of raising.
    """

    type: Literal["transaction_path"] = Field(default="transaction_path")
    path: str = Field(...)

    def resolve(self, transaction: Transaction) -> Any:
        """Resolve the value from the transaction using the path.

        Returns:
            The value found at ``self.path``, or ``None`` when the lookup raises
            ``AttributeError``, ``ValueError`` or ``KeyError``.
        """
        try:
            return get_transaction_value(transaction, self.path)
        except (AttributeError, ValueError, KeyError):
            # KeyError is absorbed too: a failed key lookup during path traversal
            # is just another form of "path not present".
            return None

    def __repr__(self) -> str:
        """Render as ``TransactionPath(path=<repr of the path>)``."""
        return f"TransactionPath(path={self.path!r})"

    def __eq__(self, other: object) -> bool:
        """Check equality with another TransactionPath."""
        return isinstance(other, TransactionPath) and self.path == other.path

A value resolver that extracts a value from a transaction using a path.

__eq__(other)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
89
90
91
def __eq__(self, other: object) -> bool:
    """Report whether ``other`` is a TransactionPath with the same path."""
    if not isinstance(other, TransactionPath):
        # Objects of any other type never compare equal.
        return False
    return self.path == other.path

Check equality with another TransactionPath.

resolve(transaction)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
79
80
81
82
83
84
def resolve(self, transaction: Transaction) -> Any:
    """Look up ``self.path`` in the transaction; a failed lookup yields None."""
    try:
        found = get_transaction_value(transaction, self.path)
    except (AttributeError, ValueError):
        # A lookup that fails with one of these errors is reported as "no value".
        found = None
    return found

Resolve the value from the transaction using the path.

ValueResolver

Bases: BaseModel, ABC

Source code in luthien_control/control_policy/conditions/value_resolvers.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class ValueResolver(BaseModel, ABC):
    """
    Abstract base for objects that resolve a value from a transaction.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Typed as Any so subclasses may redeclare it, e.g. as a Literal string.
    type: Any = Field(default="")

    @abstractmethod
    def resolve(self, transaction: Transaction) -> Any:
        """
        Produce a value for the given transaction.

        Args:
            transaction: The transaction to resolve the value from

        Returns:
            The resolved value
        """

    def serialize(self) -> SerializableDict:
        """Serialize this resolver by delegating to ``safe_model_dump``."""
        dumped = safe_model_dump(self)
        return dumped

    @classmethod
    def from_serialized(cls, serialized: SerializableDict) -> "ValueResolver":
        """
        Build a resolver of this class from its serialized form.

        Args:
            serialized: The serialized representation

        Returns:
            A ValueResolver instance
        """
        instance = safe_model_validate(cls, serialized)
        return instance

Abstract base class for resolving values from transactions.

from_serialized(serialized) classmethod
Source code in luthien_control/control_policy/conditions/value_resolvers.py
35
36
37
38
39
40
41
42
43
44
45
46
@classmethod
def from_serialized(cls, serialized: SerializableDict) -> "ValueResolver":
    """
    Build a resolver of this class from its serialized form.

    Args:
        serialized: The serialized representation

    Returns:
        A ValueResolver instance
    """
    instance = safe_model_validate(cls, serialized)
    return instance

Create a value resolver from a serialized dictionary.

Parameters:

Name Type Description Default
serialized SerializableDict

The serialized representation

required

Returns:

Type Description
ValueResolver

A ValueResolver instance

resolve(transaction) abstractmethod
Source code in luthien_control/control_policy/conditions/value_resolvers.py
18
19
20
21
22
23
24
25
26
27
28
29
@abstractmethod
def resolve(self, transaction: Transaction) -> Any:
    """
    Produce a value for the given transaction.

    Args:
        transaction: The transaction to resolve the value from

    Returns:
        The resolved value
    """

Resolve and return a value from the transaction.

Parameters:

Name Type Description Default
transaction Transaction

The transaction to resolve the value from

required

Returns:

Type Description
Any

The resolved value

serialize()
Source code in luthien_control/control_policy/conditions/value_resolvers.py
31
32
33
def serialize(self) -> SerializableDict:
    """Serialize this resolver by delegating to ``safe_model_dump``."""
    dumped = safe_model_dump(self)
    return dumped

Serialize using Pydantic model_dump through SerializableDict validation.

auto_resolve_value(value)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def auto_resolve_value(value: Any) -> ValueResolver:
    """
    Wrap a plain value in a resolver, leaving existing resolvers untouched.

    Args:
        value: Either a static value or a ValueResolver instance

    Returns:
        ``value`` itself when it already is a ValueResolver, otherwise a
        StaticValue wrapping it
    """
    already_resolver = isinstance(value, ValueResolver)
    return value if already_resolver else StaticValue(value=value)

Automatically convert a value to an appropriate ValueResolver.

Parameters:

Name Type Description Default
value Any

Either a static value or a ValueResolver instance

required

Returns:

Type Description
ValueResolver

A ValueResolver instance

create_value_resolver(serialized)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def create_value_resolver(serialized: SerializableDict) -> ValueResolver:
    """
    Create a value resolver from serialized data.

    The "type" entry of ``serialized`` selects the resolver class from
    VALUE_RESOLVER_REGISTRY; the whole mapping is then validated into that class.

    Args:
        serialized: The serialized value resolver data

    Returns:
        A ValueResolver instance

    Raises:
        ValueError: If the resolver type is unknown, missing, or cannot be used
            as a registry key (for example a list or dict)
    """
    # .get() never raises for a missing "type"; the lookup below then runs with None.
    resolver_type = serialized.get("type")
    try:
        resolver_class = VALUE_RESOLVER_REGISTRY[resolver_type]
    except (KeyError, TypeError):
        # KeyError: the type is not registered. TypeError: the value is unhashable
        # and so cannot be looked up at all. Both are reported as the same ValueError.
        raise ValueError(f"Unknown value resolver type: {resolver_type}") from None
    return safe_model_validate(resolver_class, serialized)

Create a value resolver from serialized data.

Parameters:

Name Type Description Default
serialized SerializableDict

The serialized value resolver data

required

Returns:

Type Description
ValueResolver

A ValueResolver instance

Raises:

Type Description
ValueError

If the resolver type is unknown

KeyError

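Documented for a missing type field; however, the implementation reads the type with `serialized.get("type")`, so a missing type is actually reported as the ValueError above.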

path(transaction_path)
Source code in luthien_control/control_policy/conditions/value_resolvers.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def path(transaction_path: str) -> TransactionPath:
    """
    Shorthand for building a TransactionPath from a path string.

    Args:
        transaction_path: The path to the value in the transaction

    Returns:
        A TransactionPath instance

    Example:
        path("request.payload.model")
    """
    resolver = TransactionPath(path=transaction_path)
    return resolver

Convenience function to create a TransactionPath.

Parameters:

Name Type Description Default
transaction_path str

The path to the value in the transaction

required

Returns:

Type Description
TransactionPath

A TransactionPath instance

Example

path("request.payload.model")

control_policy

ControlPolicy

Bases: BaseModel, ABC

Source code in luthien_control/control_policy/control_policy.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class ControlPolicy(BaseModel, abc.ABC):
    """Abstract Base Class defining the interface for a processing step.

    Attributes:
        name (Optional[str]): An optional name for the policy instance.
            Subclasses are expected to set this, often in their `__init__` method.
            It's used for logging and identification purposes.
    """

    name: Optional[str] = Field(default=None)
    # Policy type name; __init__ fills it in from the registry when the caller omits it.
    type: str = Field(default="")
    # exclude=True keeps the logger out of model dumps.
    logger: logging.Logger = Field(default_factory=lambda: logging.getLogger(__name__), exclude=True)

    @classmethod
    def get_policy_type_name(cls) -> str:
        """Get the canonical policy type name for serialization.

        By default, this looks up the class in the registry to get its registered name.
        Subclasses can override this if they need custom behavior.

        Returns:
            The policy type name used in serialization.

        Raises:
            ValueError: If the class is not present in POLICY_CLASS_TO_NAME.
        """
        # Import here to avoid circular imports
        from luthien_control.control_policy.registry import POLICY_CLASS_TO_NAME

        policy_type = POLICY_CLASS_TO_NAME.get(cls)
        if policy_type is None:
            raise ValueError(f"{cls.__name__} is not registered in POLICY_CLASS_TO_NAME registry")
        return policy_type

    def __init__(self, **data: Any) -> None:
        """Initializes the ControlPolicy.

        When the caller does not pass a ``type``, it is filled in from
        `get_policy_type_name()` before the data is handed to the BaseModel
        constructor.

        Args:
            **data: Arbitrary keyword arguments that subclasses might use.
        """
        # Only look the type up when it is absent, so an explicit "type" never
        # triggers the registry lookup (which raises for unregistered classes).
        if "type" not in data:
            data["type"] = self.get_policy_type_name()
        super().__init__(**data)

    @abc.abstractmethod
    async def apply(
        self,
        transaction: "Transaction",
        container: "DependencyContainer",
        session: "AsyncSession",
    ) -> "Transaction":
        """
        Apply the policy to the transaction using provided dependencies.

        Args:
            transaction: The current transaction.
            container: The dependency injection container.
            session: The database session for the current request. We include this separately because
                it's request-scoped rather than application-scoped.

        Returns:
            The potentially modified transaction.

        Raises:
            Exception: Processors may raise exceptions to halt the processing flow.
        """
        raise NotImplementedError

    def serialize(self) -> SerializableDict:
        """Serialize using Pydantic model_dump through SerializableDict validation."""
        # exclude_none=True drops fields whose value is None from the dump.
        data = self.model_dump(mode="python", by_alias=True, exclude_none=True)
        from luthien_control.control_policy.serialization import SerializableDictAdapter

        return SerializableDictAdapter.validate_python(data)

    # construct from serialization
    @classmethod
    def from_serialized(cls: Type[PolicyT], config: SerializableDict) -> PolicyT:
        """
        Construct a policy from a serialized configuration.

        This method acts as a dispatcher. It looks up the concrete policy class
        based on the 'type' field in the config and validates the config into that
        class with `safe_model_validate`. When the config has no 'type' and this is
        called on a subclass, the subclass's registered type name is used instead.

        Args:
            config: The policy-specific configuration dictionary. It must contain a 'type' key
                    that maps to a registered policy type, unless this is called on a
                    registered subclass. The dictionary itself is not modified.

        Returns:
            An instance of the concrete policy class.

        Raises:
            ValueError: If no 'type' can be determined or the type is not registered.
        """
        # Import inside the method to break circular dependency
        from luthien_control.control_policy.registry import POLICY_NAME_TO_CLASS

        # Work on a shallow copy so the inferred type is never written into the caller's dict.
        config_copy = dict(config)

        policy_type_name_val = config_copy.get("type")

        # No usable type in the config: fall back to the calling subclass's registered name.
        if not policy_type_name_val and cls != ControlPolicy:
            try:
                inferred_type = cls.get_policy_type_name()
                config_copy["type"] = inferred_type
                policy_type_name_val = inferred_type
            except ValueError:
                # Unregistered subclass: leave the type unset so the check below reports it.
                pass

        if not policy_type_name_val:
            raise ValueError("Policy configuration must include a 'type' field")

        target_policy_class = POLICY_NAME_TO_CLASS.get(str(policy_type_name_val))
        if not target_policy_class:
            raise ValueError(
                f"Unknown policy type '{policy_type_name_val}'. Ensure it is registered in POLICY_NAME_TO_CLASS."
            )

        return cast(PolicyT, safe_model_validate(target_policy_class, config_copy))

    model_config = ConfigDict(arbitrary_types_allowed=True)

Abstract Base Class defining the interface for a processing step.

Attributes:

Name Type Description
name Optional[str]

An optional name for the policy instance. Subclasses are expected to set this, often in their __init__ method. It's used for logging and identification purposes.

__init__(**data)
Source code in luthien_control/control_policy/control_policy.py
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(self, **data: Any) -> None:
    """Set up the policy, defaulting ``type`` to the registered type name.

    When the caller passes no ``type``, it is taken from
    `get_policy_type_name()` before the data reaches the parent constructor.

    Args:
        **data: Arbitrary keyword arguments that subclasses might use.
    """
    type_missing = "type" not in data
    if type_missing:
        # Looked up only when absent, so an explicit type skips the registry.
        data["type"] = self.get_policy_type_name()
    super().__init__(**data)

Initializes the ControlPolicy.

This is an abstract base class, and this constructor typically handles common initialization or can be overridden by subclasses.

Parameters:

Name Type Description Default
**data Any

Arbitrary keyword arguments that subclasses might use.

{}
apply(transaction, container, session) abstractmethod async
Source code in luthien_control/control_policy/control_policy.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@abc.abstractmethod
async def apply(
    self,
    transaction: "Transaction",
    container: "DependencyContainer",
    session: "AsyncSession",
) -> "Transaction":
    """
    Run this policy against a transaction; this base version always raises.

    Args:
        transaction: The current transaction.
        container: The dependency injection container.
        session: The database session for the current request. We include this separately because
            it's request-scoped rather than application-scoped.

    Returns:
        The potentially modified transaction.

    Raises:
        Exception: Processors may raise exceptions to halt the processing flow.
    """
    raise NotImplementedError()

Apply the policy to the transaction using provided dependencies.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The dependency injection container.

required
session AsyncSession

The database session for the current request. We include this separately because it's request-scoped rather than application-scoped.

required

Returns:

Type Description
Transaction

The potentially modified transaction.

Raises:

Type Description
Exception

Processors may raise exceptions to halt the processing flow.

from_serialized(config) classmethod
Source code in luthien_control/control_policy/control_policy.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
@classmethod
def from_serialized(cls: Type[PolicyT], config: SerializableDict) -> PolicyT:
    """
    Construct a policy from a serialized configuration.

    This method acts as a dispatcher. It looks up the concrete policy class
    based on the 'type' field in the config and validates the config into that
    class with `safe_model_validate`. When the config has no 'type' and this is
    called on a subclass, the subclass's registered type name is used instead.

    Args:
        config: The policy-specific configuration dictionary. It must contain a 'type' key
                that maps to a registered policy type, unless this is called on a
                registered subclass. The dictionary itself is not modified.

    Returns:
        An instance of the concrete policy class.

    Raises:
        ValueError: If no 'type' can be determined or the type is not registered.
    """
    # Import inside the method to break circular dependency
    from luthien_control.control_policy.registry import POLICY_NAME_TO_CLASS

    # Work on a shallow copy so the inferred type is never written into the caller's dict.
    config_copy = dict(config)

    policy_type_name_val = config_copy.get("type")

    # No usable type in the config: fall back to the calling subclass's registered name.
    # Classes are compared by identity; the base class itself has nothing to infer.
    if not policy_type_name_val and cls is not ControlPolicy:
        try:
            inferred_type = cls.get_policy_type_name()
        except ValueError:
            # Unregistered subclass: leave the type unset so the check below reports it.
            pass
        else:
            config_copy["type"] = inferred_type
            policy_type_name_val = inferred_type

    if not policy_type_name_val:
        raise ValueError("Policy configuration must include a 'type' field")

    target_policy_class = POLICY_NAME_TO_CLASS.get(str(policy_type_name_val))
    if not target_policy_class:
        raise ValueError(
            f"Unknown policy type '{policy_type_name_val}'. Ensure it is registered in POLICY_NAME_TO_CLASS."
        )

    return cast(PolicyT, safe_model_validate(target_policy_class, config_copy))

Construct a policy from a serialized configuration and optional dependencies.

This method acts as a dispatcher. It looks up the concrete policy class based on the 'type' field in the config and instantiates it by validating the config against that class (via safe_model_validate).

Parameters:

Name Type Description Default
config SerializableDict

The policy-specific configuration dictionary. It must contain a 'type' key that maps to a registered policy type.

required
**kwargs

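Additional dependencies needed for instantiation, passed to the concrete policy's from_serialized method. Note: this entry comes from the docstring only; the signature `from_serialized(cls, config)` accepts no extra keyword arguments.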

required

Returns:

Type Description
PolicyT

An instance of the concrete policy class.

Raises:

Type Description
ValueError

If the 'type' key is missing in config or the type is not registered.

get_policy_type_name() classmethod
Source code in luthien_control/control_policy/control_policy.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@classmethod
def get_policy_type_name(cls) -> str:
    """Return the name this policy class is registered under.

    The name is read from the POLICY_CLASS_TO_NAME registry. Subclasses can
    override this if they need custom behavior.

    Returns:
        The policy type name used in serialization.

    Raises:
        ValueError: If the class has no entry in POLICY_CLASS_TO_NAME.
    """
    # Deferred import: importing the registry at module load would be circular.
    from luthien_control.control_policy.registry import POLICY_CLASS_TO_NAME

    registered_name = POLICY_CLASS_TO_NAME.get(cls)
    if registered_name is not None:
        return registered_name
    raise ValueError(f"{cls.__name__} is not registered in POLICY_CLASS_TO_NAME registry")

Get the canonical policy type name for serialization.

By default, this looks up the class in the registry to get its registered name. Subclasses can override this if they need custom behavior.

Returns:

Type Description
str

The policy type name used in serialization.

serialize()
Source code in luthien_control/control_policy/control_policy.py
86
87
88
89
90
91
def serialize(self) -> SerializableDict:
    """Dump the model and pass the result through ``SerializableDictAdapter``."""
    # exclude_none=True drops fields whose value is None from the dump.
    dumped = self.model_dump(mode="python", by_alias=True, exclude_none=True)
    from luthien_control.control_policy.serialization import SerializableDictAdapter

    validated = SerializableDictAdapter.validate_python(dumped)
    return validated

Serialize using Pydantic model_dump through SerializableDict validation.

exceptions

ApiKeyNotFoundError

Bases: ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
62
63
64
65
66
67
68
69
70
71
72
73
class ApiKeyNotFoundError(ControlPolicyError):
    """Raised when the API key cannot be found in the settings."""

    def __init__(self, detail: str, status_code: int = 500):
        """Build the error from a message and an HTTP status code.

        Args:
            detail (str): A detailed error message explaining the missing API key.
            status_code (int): The HTTP status code to associate with this error.
                               Defaults to 500 (Internal Server Error).
        """
        # detail is passed twice: positionally so it becomes the exception message,
        # and by keyword so the parent stores it as the ``detail`` attribute.
        super().__init__(
            detail,
            status_code=status_code,
            detail=detail,
        )

Exception raised when the API key is not found in the settings.

__init__(detail, status_code=500)
Source code in luthien_control/control_policy/exceptions.py
65
66
67
68
69
70
71
72
73
def __init__(self, detail: str, status_code: int = 500):
    """Build the error from a message and an HTTP status code.

    Args:
        detail (str): A detailed error message explaining the missing API key.
        status_code (int): The HTTP status code to associate with this error.
                           Defaults to 500 (Internal Server Error).
    """
    # detail is passed twice: positionally so it becomes the exception message,
    # and by keyword so the parent stores it as the ``detail`` attribute.
    super().__init__(
        detail,
        status_code=status_code,
        detail=detail,
    )

Initializes the ApiKeyNotFoundError.

Parameters:

Name Type Description Default
detail str

A detailed error message explaining the missing API key.

required
status_code int

The HTTP status code to associate with this error. Defaults to 500 (Internal Server Error).

500

ClientAuthenticationError

Bases: ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
82
83
84
85
86
87
88
89
90
91
92
93
94
class ClientAuthenticationError(ControlPolicyError):
    """Raised when authenticating a client API key fails."""

    def __init__(self, detail: str, status_code: int = 401):
        """Build the error from a message and an HTTP status code.

        Args:
            detail (str): A detailed error message explaining the authentication failure.
            status_code (int): The HTTP status code to associate with this error.
                               Defaults to 401 (Unauthorized).
        """
        # detail is passed twice: positionally so it becomes the exception message,
        # and by keyword so the parent stores it as the ``detail`` attribute.
        super().__init__(
            detail,
            status_code=status_code,
            detail=detail,
        )

Exception raised when client API key authentication fails.

__init__(detail, status_code=401)
Source code in luthien_control/control_policy/exceptions.py
85
86
87
88
89
90
91
92
93
94
def __init__(self, detail: str, status_code: int = 401):
    """Build the error from a message and an HTTP status code.

    Args:
        detail (str): A detailed error message explaining the authentication failure.
        status_code (int): The HTTP status code to associate with this error.
                           Defaults to 401 (Unauthorized).
    """
    # detail is passed twice: positionally so it becomes the exception message,
    # and by keyword so the parent stores it as the ``detail`` attribute.
    super().__init__(
        detail,
        status_code=status_code,
        detail=detail,
    )

Initializes the ClientAuthenticationError.

Parameters:

Name Type Description Default
detail str

A detailed error message explaining the authentication failure.

required
status_code int

The HTTP status code to associate with this error. Defaults to 401 (Unauthorized).

401

ClientAuthenticationNotFoundError

Bases: ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
class ClientAuthenticationNotFoundError(ControlPolicyError):
    """Raised when the request carries no client API key."""

    def __init__(self, detail: str, status_code: int = 401):
        """Build the error from a message and an HTTP status code.

        Args:
            detail (str): A detailed error message explaining why the key was not found.
            status_code (int): The HTTP status code to associate with this error.
                               Defaults to 401 (Unauthorized).
        """
        # detail is passed twice: positionally so it becomes the exception message,
        # and by keyword so the parent stores it as the ``detail`` attribute.
        super().__init__(
            detail,
            status_code=status_code,
            detail=detail,
        )

Exception raised when the client API key is not found in the request.

__init__(detail, status_code=401)
Source code in luthien_control/control_policy/exceptions.py
100
101
102
103
104
105
106
107
108
109
def __init__(self, detail: str, status_code: int = 401):
    """Build the error from a message and an HTTP status code.

    Args:
        detail (str): A detailed error message explaining why the key was not found.
        status_code (int): The HTTP status code to associate with this error.
                           Defaults to 401 (Unauthorized).
    """
    # detail is passed twice: positionally so it becomes the exception message,
    # and by keyword so the parent stores it as the ``detail`` attribute.
    super().__init__(
        detail,
        status_code=status_code,
        detail=detail,
    )

Initializes the ClientAuthenticationNotFoundError.

Parameters:

Name Type Description Default
detail str

A detailed error message explaining why the key was not found.

required
status_code int

The HTTP status code to associate with this error. Defaults to 401 (Unauthorized).

401

ControlPolicyError

Bases: LuthienException

Source code in luthien_control/control_policy/exceptions.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ControlPolicyError(LuthienException):
    """Common base class for every control policy error.

    Attributes:
        policy_name (Optional[str]): The name of the policy where the error
            occurred, if specified.
        status_code (Optional[int]): An HTTP status code associated with this
            error, if specified.
        detail (Optional[str]): A detailed error message. When no (non-empty)
            ``detail`` is given, the first positional argument is used; when
            there are no positional arguments either, it is None.
    """

    def __init__(
        self, *args, policy_name: str | None = None, status_code: int | None = None, detail: str | None = None
    ):
        """Store the error metadata and forward ``args`` to the parent class.

        Args:
            *args: Arguments passed to the base Exception class.
            policy_name (Optional[str]): The name of the policy where the error occurred.
            status_code (Optional[int]): An HTTP status code associated with this error.
            detail (Optional[str]): A detailed error message. If not provided and `args`
                                    is not empty, the first argument in `args` is used.
        """
        super().__init__(*args)
        self.policy_name = policy_name
        self.status_code = status_code
        # A falsy detail (None or "") falls back to the first positional argument.
        if detail:
            self.detail = detail
        elif args:
            self.detail = args[0]
        else:
            self.detail = None

Base exception for all control policy errors.

Attributes:

Name Type Description
policy_name Optional[str]

The name of the policy where the error occurred, if specified.

status_code Optional[int]

An HTTP status code associated with this error, if specified.

detail Optional[str]

A detailed error message. If not provided directly during initialization but other arguments are, the first positional argument is used as the detail.

__init__(*args, policy_name=None, status_code=None, detail=None)
Source code in luthien_control/control_policy/exceptions.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(
    self, *args, policy_name: str | None = None, status_code: int | None = None, detail: str | None = None
):
    """Store the error metadata and forward ``args`` to the parent class.

    Args:
        *args: Arguments passed to the base Exception class.
        policy_name (Optional[str]): The name of the policy where the error occurred.
        status_code (Optional[int]): An HTTP status code associated with this error.
        detail (Optional[str]): A detailed error message. If not provided and `args`
                                is not empty, the first argument in `args` is used.
    """
    super().__init__(*args)
    self.policy_name = policy_name
    self.status_code = status_code
    # A falsy detail (None or "") falls back to the first positional argument.
    if detail:
        self.detail = detail
    elif args:
        self.detail = args[0]
    else:
        self.detail = None

Initializes the ControlPolicyError.

Parameters:

Name Type Description Default
*args

Arguments passed to the base Exception class.

()
policy_name Optional[str]

The name of the policy where the error occurred.

None
status_code Optional[int]

An HTTP status code associated with this error.

None
detail Optional[str]

A detailed error message. If not provided and args is not empty, the first argument in args is used.

None

LeakedApiKeyError

Bases: ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
112
113
114
115
116
117
118
119
120
121
122
123
124
class LeakedApiKeyError(ControlPolicyError):
    """Raised when a leaked API key has been detected."""

    def __init__(self, detail: str, status_code: int = 403):
        """Build the error from a message and an HTTP status code.

        Args:
            detail (str): A detailed error message explaining the leaked key detection.
            status_code (int): The HTTP status code to associate with this error.
                               Defaults to 403 (Forbidden).
        """
        # detail is passed twice: positionally so it becomes the exception message,
        # and by keyword so the parent stores it as the ``detail`` attribute.
        super().__init__(
            detail,
            status_code=status_code,
            detail=detail,
        )

Exception raised when a leaked API key is detected.

__init__(detail, status_code=403)
Source code in luthien_control/control_policy/exceptions.py
115
116
117
118
119
120
121
122
123
124
def __init__(self, detail: str, status_code: int = 403):
    """Build the error from a message and an HTTP status code.

    Args:
        detail (str): A detailed error message explaining the leaked key detection.
        status_code (int): The HTTP status code to associate with this error.
                           Defaults to 403 (Forbidden).
    """
    # detail is passed twice: positionally so it becomes the exception message,
    # and by keyword so the parent stores it as the ``detail`` attribute.
    super().__init__(
        detail,
        status_code=status_code,
        detail=detail,
    )

Initializes the LeakedApiKeyError.

Parameters:

Name Type Description Default
detail str

A detailed error message explaining the leaked key detection.

required
status_code int

The HTTP status code to associate with this error. Defaults to 403 (Forbidden).

403

NoRequestError

Bases: ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
76
77
78
79
class NoRequestError(ControlPolicyError):
    """Raised when no request object is present in the context."""

Exception raised when the request object is not found in the context.

PolicyLoadError

Bases: ValueError, ControlPolicyError

Source code in luthien_control/control_policy/exceptions.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class PolicyLoadError(ValueError, ControlPolicyError):
    """Error raised when a policy cannot be loaded or instantiated."""

    # ValueError conveys the meaning (a bad value/config);
    # ControlPolicyError places it in the control policy error category.
    def __init__(
        self, *args, policy_name: str | None = None, status_code: int | None = None, detail: str | None = None
    ):
        """Store the error metadata via ControlPolicyError's initializer.

        Args:
            *args: Arguments passed to the base Exception class.
            policy_name (Optional[str]): The name of the policy that failed to load.
            status_code (Optional[int]): An HTTP status code associated with this error.
            detail (Optional[str]): A detailed error message. If not provided and `args`
                                    is not empty, the first argument in `args` is used.
        """
        # Called explicitly rather than through super(): ValueError is listed first
        # among the bases, and the built-in exception initializer takes no keyword arguments.
        ControlPolicyError.__init__(
            self,
            *args,
            policy_name=policy_name,
            status_code=status_code,
            detail=detail,
        )

Custom exception for errors during policy loading/instantiation.

__init__(*args, policy_name=None, status_code=None, detail=None)
Source code in luthien_control/control_policy/exceptions.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def __init__(
    self,
    *args,
    policy_name: str | None = None,
    status_code: int | None = None,
    detail: str | None = None,
):
    """Initializes the PolicyLoadError.

    Args:
        *args: Arguments passed to the base Exception class.
        policy_name (Optional[str]): The name of the policy that failed to load.
        status_code (Optional[int]): An HTTP status code associated with this error.
        detail (Optional[str]): A detailed error message. If not provided and `args`
                                is not empty, the first argument in `args` is used.
    """
    metadata = {
        "policy_name": policy_name,
        "status_code": status_code,
        "detail": detail,
    }
    # Invoke ControlPolicyError.__init__ by name (rather than via super()) so
    # that it is the initializer that receives the keyword arguments.
    ControlPolicyError.__init__(self, *args, **metadata)

Initializes the PolicyLoadError.

Parameters:

Name Type Description Default
*args

Arguments passed to the base Exception class.

()
policy_name Optional[str]

The name of the policy that failed to load.

None
status_code Optional[int]

An HTTP status code associated with this error.

None
detail Optional[str]

A detailed error message. If not provided and args is not empty, the first argument in args is used.

None

leaked_api_key_detection

Control Policy for detecting leaked API keys in LLM message content.

This policy inspects the 'messages' field in request bodies to prevent sensitive API keys from being sent to language models.

LeakedApiKeyDetectionPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/leaked_api_key_detection.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class LeakedApiKeyDetectionPolicy(ControlPolicy):
    """Detects API keys that might be leaked in message content sent to LLMs.

    This policy scans message content for patterns matching common API key formats
    to prevent accidental exposure of sensitive credentials to language models.
    Only messages whose ``content`` is a plain string are inspected.
    """

    # Common API key patterns
    DEFAULT_PATTERNS: ClassVar[List[str]] = [
        r"sk-[a-zA-Z0-9]{48}",  # OpenAI API key pattern
        r"xoxb-[a-zA-Z0-9\-]{50,}",  # Slack bot token pattern
        r"github_pat_[a-zA-Z0-9]{22}_[a-zA-Z0-9]{59}",  # GitHub PAT pattern
    ]

    name: Optional[str] = Field(default_factory=lambda: "LeakedApiKeyDetectionPolicy")
    # Hand each instance its own copy of the defaults: returning the class-level
    # list itself would let a mutation of one instance's `patterns` silently
    # change the defaults seen by every other instance.
    patterns: List[str] = Field(default_factory=lambda: list(LeakedApiKeyDetectionPolicy.DEFAULT_PATTERNS))
    # Derived from `patterns` by `compile_patterns`; excluded from serialization.
    compiled_patterns: List[re.Pattern] = Field(default_factory=list, exclude=True)

    @field_validator("patterns", mode="before")
    @classmethod
    def validate_patterns(cls, value):
        """Handle patterns validation and fallback to defaults for empty lists.

        Args:
            value: The raw ``patterns`` input, before field validation.

        Returns:
            The default patterns when ``value`` is ``None`` or an empty list;
            otherwise ``value`` unchanged.
        """
        if value is None or (isinstance(value, list) and not value):
            return cls.DEFAULT_PATTERNS
        return value

    @model_validator(mode="after")
    def compile_patterns(self):
        """Compile regex patterns after validation.

        Returns:
            This instance, with ``compiled_patterns`` rebuilt from ``patterns``.
        """
        self.compiled_patterns = [re.compile(pattern) for pattern in self.patterns]
        return self

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Checks message content for potentially leaked API keys.

        Args:
            transaction: The current transaction.
            container: The application dependency container (not used here).
            session: An active SQLAlchemy AsyncSession (not used here).

        Returns:
            The transaction, unchanged, when no configured pattern matches.

        Raises:
            NoRequestError: If the request is not found in the transaction.
            LeakedApiKeyError: If a potential API key is detected in message content.
        """
        if transaction.request is None:
            raise NoRequestError("No request in transaction.")

        self.logger.info(f"Checking for leaked API keys in message content ({self.name}).")

        if hasattr(transaction.request.payload, "messages"):
            messages = transaction.request.payload.messages

            # Inspect each message's content; non-string content is skipped.
            for message in messages:
                if hasattr(message, "content") and isinstance(message.content, str):
                    content = message.content
                    if self._check_text(content):
                        error_message = (
                            "Potential API key detected in message content. For security, the request has been blocked."
                        )
                        self.logger.warning(f"{error_message} ({self.name})")
                        raise LeakedApiKeyError(detail=error_message)

        return transaction

    def _check_text(self, text: str) -> bool:
        """
        Checks if the given text contains any patterns matching potential API keys.

        Args:
            text: The text to check.

        Returns:
            True if a potential API key is found, False otherwise.
        """
        return any(pattern.search(text) for pattern in self.compiled_patterns)

Detects API keys that might be leaked in message content sent to LLMs.

This policy scans message content for patterns matching common API key formats to prevent accidental exposure of sensitive credentials to language models.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/leaked_api_key_detection.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Scan the request's string message bodies for text that looks like an API key.

    Args:
        transaction: The current transaction.
        container: The application dependency container.
        session: An active SQLAlchemy AsyncSession.

    Returns:
        The transaction, returned as-is when nothing suspicious is found.

    Raises:
        NoRequestError: If the request is not found in the transaction.
        LeakedApiKeyError: If a potential API key is detected in message content.
    """
    request = transaction.request
    if request is None:
        raise NoRequestError("No request in transaction.")

    self.logger.info(f"Checking for leaked API keys in message content ({self.name}).")

    payload = request.payload
    if not hasattr(payload, "messages"):
        # Nothing to inspect on payloads that carry no messages.
        return transaction

    for message in payload.messages:
        content = getattr(message, "content", None)
        # Only plain-string content is checked; anything else is skipped.
        if not isinstance(content, str) or not self._check_text(content):
            continue
        error_message = (
            "Potential API key detected in message content. For security, the request has been blocked."
        )
        self.logger.warning(f"{error_message} ({self.name})")
        raise LeakedApiKeyError(detail=error_message)

    return transaction

Checks message content for potentially leaked API keys.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container.

required
session AsyncSession

An active SQLAlchemy AsyncSession.

required

Returns:

Type Description
Transaction

The transaction, unchanged, when no potential API key is found. (When one is found, a LeakedApiKeyError is raised rather than an error response being set.)

Raises:

Type Description
NoRequestError

If the request is not found in the transaction.

LeakedApiKeyError

If a potential API key is detected in message content.

compile_patterns()
Source code in luthien_control/control_policy/leaked_api_key_detection.py
46
47
48
49
50
@model_validator(mode="after")
def compile_patterns(self):
    """Rebuild ``compiled_patterns`` by compiling each entry of ``patterns``, then return the instance."""
    compiled = []
    for raw_pattern in self.patterns:
        compiled.append(re.compile(raw_pattern))
    self.compiled_patterns = compiled
    return self

Compile regex patterns after validation.

validate_patterns(value) classmethod
Source code in luthien_control/control_policy/leaked_api_key_detection.py
38
39
40
41
42
43
44
@field_validator("patterns", mode="before")
@classmethod
def validate_patterns(cls, value):
    """Substitute the default patterns when the input is ``None`` or an empty list."""
    if value is None:
        return cls.DEFAULT_PATTERNS
    if isinstance(value, list) and not value:
        return cls.DEFAULT_PATTERNS
    # Anything else is passed through untouched for regular field validation.
    return value

Handle patterns validation and fallback to defaults for empty lists.

loader

load_policy(serialized_policy)

Source code in luthien_control/control_policy/loader.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def load_policy(serialized_policy: SerializedPolicy) -> "ControlPolicy":
    """
    Instantiate the ControlPolicy described by a SerializedPolicy.

    The policy class is looked up in the policy registry by
    ``serialized_policy.type`` and built by calling its ``from_serialized``
    with ``serialized_policy.config``.

    Args:
        serialized_policy: A SerializedPolicy object.

    Returns:
        An instantiated ControlPolicy object.

    Raises:
        PolicyLoadError: If ``type`` is not a string, ``config`` is not a dict,
                         the policy type is not in the registry, or building
                         the policy raises any exception (the original
                         exception is chained as the cause).
    """
    # Deferred to call time: importing the registry at module import would be circular.
    from .registry import POLICY_NAME_TO_CLASS  # noqa: F401

    log = logging.getLogger(__name__)

    policy_type, policy_config = serialized_policy.type, serialized_policy.config

    if not isinstance(policy_type, str):
        raise PolicyLoadError(f"Policy 'type' must be a string, got: {type(policy_type)}")
    if not isinstance(policy_config, dict):
        raise PolicyLoadError(f"Policy 'config' must be a dictionary, got: {type(policy_config)}")

    policy_class = POLICY_NAME_TO_CLASS.get(policy_type)
    if policy_class is None:
        # An unregistered type is reported together with the known type names.
        unknown_type_message = (
            f"Unknown policy type: '{policy_type}'. Available policies: {list(POLICY_NAME_TO_CLASS.keys())}"
        )
        raise PolicyLoadError(unknown_type_message)

    try:
        instance = policy_class.from_serialized(policy_config)
        log.info(f"Successfully loaded policy: {getattr(instance, 'name', policy_type)}")
    except Exception as exc:
        failure = f"Error instantiating policy '{policy_type}': {exc}"
        log.error(failure, exc_info=True)
        raise PolicyLoadError(failure) from exc
    else:
        return instance

Loads a ControlPolicy instance from a SerializedPolicy object, which carries the policy's type name and its config dictionary.

Parameters:

Name Type Description Default
serialized_policy SerializedPolicy

A SerializedPolicy object.

required

Returns:

Type Description
ControlPolicy

An instantiated ControlPolicy object.

Raises:

Type Description
PolicyLoadError

If the policy name is unknown, data is missing/malformed, or a required dependency is not provided.

Exception

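Not raised directly: any exception from the policy's from_serialized method (for example, when the config is invalid) is caught and re-raised as a PolicyLoadError, with the original exception chained as its cause.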

load_policy_from_file(filepath)

Source code in luthien_control/control_policy/loader.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def load_policy_from_file(filepath: str) -> "ControlPolicy":
    """Load a policy configuration from a JSON file and instantiate it using the control_policy loader.

    The file must contain a JSON object with a string ``type`` field and an
    object ``config`` field.

    Args:
        filepath: Path to the JSON policy file.

    Returns:
        The ControlPolicy built by ``load_policy`` from the file's contents.

    Raises:
        PolicyLoadError: If the top-level JSON value is not an object, ``type``
            is not a string, or ``config`` is not an object; also whatever
            ``load_policy`` raises.

    Errors from opening the file or parsing the JSON are not caught here and
    propagate to the caller.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # encoding, so the same policy file loads identically on every host.
    with open(filepath, "r", encoding="utf-8") as f:
        raw_policy_data = json.load(f)

    if not isinstance(raw_policy_data, dict):
        raise PolicyLoadError(f"Policy data loaded from {filepath} must be a dictionary, got {type(raw_policy_data)}")

    policy_type = raw_policy_data.get("type")
    policy_config = raw_policy_data.get("config")

    # Missing fields come back as None from .get() and fail these type checks.
    if not isinstance(policy_type, str):
        raise PolicyLoadError(
            f"Policy file {filepath} must contain a 'type' field as a string. Got: {type(policy_type)}"
        )
    if not isinstance(policy_config, dict):
        raise PolicyLoadError(
            f"Policy file {filepath} must contain a 'config' field as a dictionary. Got: {type(policy_config)}"
        )

    serialized_policy_obj = SerializedPolicy(type=policy_type, config=policy_config)
    return load_policy(serialized_policy_obj)

Load a policy configuration from a file and instantiate it using the control_policy loader.

model_name_replacement

ModelNameReplacementPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/model_name_replacement.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class ModelNameReplacementPolicy(ControlPolicy):
    """Rewrites the request's model name according to a configured mapping.

    Clients can send a placeholder model name that is swapped for the real one
    before the request goes to the backend. This helps with clients such as
    Cursor, which assume that model strings matching known models must be
    routed through specific endpoints.
    """

    name: Optional[str] = Field(default="ModelNameReplacementPolicy")
    # Maps the model name a client sends to the model name to use instead.
    model_mapping: Dict[str, str] = Field(default_factory=dict)

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Swap the payload's model name for its mapped replacement, if one is configured.

        Args:
            transaction: The current transaction.
            container: The application dependency container.
            session: An active SQLAlchemy AsyncSession (unused).

        Returns:
            The potentially modified transaction.

        Raises:
            NoRequestError: If no request is found in the transaction.
        """
        request = transaction.request
        if request is None:
            raise NoRequestError("No request in transaction.")

        payload = request.payload
        if not hasattr(payload, "model"):
            # Payloads without a model attribute pass through untouched.
            return transaction

        requested_model = payload.model
        if requested_model in self.model_mapping:
            replacement = self.model_mapping[requested_model]
            self.logger.info(f"Replacing model name: {requested_model} -> {replacement}")
            payload.model = replacement

        return transaction

Replaces model names in requests based on a configured mapping.

This policy allows clients to use fake model names that will be replaced with real model names before the request is sent to the backend. This is useful for services like Cursor that assume model strings that match known models must route through specific endpoints.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/model_name_replacement.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Swap the payload's model name for its mapped replacement, if one is configured.

    Args:
        transaction: The current transaction.
        container: The application dependency container.
        session: An active SQLAlchemy AsyncSession (unused).

    Returns:
        The potentially modified transaction.

    Raises:
        NoRequestError: If no request is found in the transaction.
    """
    request = transaction.request
    if request is None:
        raise NoRequestError("No request in transaction.")

    payload = request.payload
    if not hasattr(payload, "model"):
        # Payloads without a model attribute pass through untouched.
        return transaction

    requested_model = payload.model
    if requested_model in self.model_mapping:
        replacement = self.model_mapping[requested_model]
        self.logger.info(f"Replacing model name: {requested_model} -> {replacement}")
        payload.model = replacement

    return transaction

Replaces the model name in the request payload based on the configured mapping.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container.

required
session AsyncSession

An active SQLAlchemy AsyncSession (unused).

required

Returns:

Type Description
Transaction

The potentially modified transaction.

Raises:

Type Description
NoRequestError

If no request is found in the transaction.

noop_policy

NoopPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/noop_policy.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class NoopPolicy(ControlPolicy):
    """A pass-through policy.

    The most minimal policy implementation: the transaction is handed back
    untouched, and the only configuration is the policy's name.
    """

    name: Optional[str] = Field(default="NoopPolicy")

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """Return the given transaction as-is; ``container`` and ``session`` are not used."""
        return transaction

A policy that does nothing.

This is the simplest possible policy implementation. It passes through the transaction unchanged and has no policy-specific configuration beyond its name.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/noop_policy.py
21
22
23
24
25
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """Return the given transaction as-is; ``container`` and ``session`` are not used."""
    return transaction

Simply returns the transaction unchanged.

send_backend_request

SendBackendRequestPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/send_backend_request.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class SendBackendRequestPolicy(ControlPolicy):
    """
    Policy responsible for sending the chat completions request to the OpenAI-compatible backend
    using the OpenAI SDK and storing the structured response.

    Attributes:
        name (str): The name of this policy instance, used for logging and
            identification. It defaults to the class name if not provided
            during initialization.
        logger (logging.Logger): The logger instance for this policy.
    """

    name: Optional[str] = Field(default="SendBackendRequestPolicy")

    def _create_debug_info(
        self, backend_url: str, request_payload: Any, error: Exception, api_key: str = ""
    ) -> Dict[str, Any]:
        """Create debug information for backend request failures.

        Args:
            backend_url: The backend URL the request was sent to.
            request_payload: The request payload; ``model`` and ``messages`` are
                read with safe fallbacks if missing.
            error: The exception raised by the backend call.
            api_key: The API key used. Only a shortened identifier is ever
                included, and only when the backend answered with a 404.

        Returns:
            A dict describing the request and the error, plus backend response
            details when the error carries them.
        """
        # `or []` keeps this helper from raising TypeError (and masking the real
        # backend error) when the payload's `messages` attribute exists but is None.
        message_count = len(getattr(request_payload, "messages", None) or [])
        debug_info: Dict[str, Any] = {
            "backend_url": backend_url,
            "request_model": getattr(request_payload, "model", "unknown"),
            "request_messages_count": message_count,
            "error_type": error.__class__.__name__,
            "error_message": str(error),
        }

        # Add OpenAI-specific error details if available
        response = getattr(error, "response", None)
        if response is not None:
            status_code = getattr(response, "status_code", None)
            debug_info["backend_response"] = {
                "status_code": status_code,
                "headers": dict(getattr(response, "headers", {})),
            }
            # Try to get response body if available
            if hasattr(response, "text"):
                debug_info["backend_response"]["body"] = getattr(response, "text", "")

            # For 404 errors, include identifying characters from the API key
            if status_code == 404 and api_key:
                debug_info["api_key_identifier"] = self._get_api_key_identifier(api_key)

        error_body = getattr(error, "body", None)
        if error_body is not None:
            debug_info["backend_error_body"] = error_body

        return debug_info

    def _get_api_key_identifier(self, api_key: str) -> str:
        """Get identifying characters from API key for debugging.

        Args:
            api_key: The API key to abbreviate.

        Returns:
            ``"empty"`` for an empty key; ``"***"`` for keys of 6 characters or
            fewer; the first 4 and last 2 characters for keys up to 12
            characters; otherwise the first 8 and last 4 characters.
        """
        if not api_key:
            return "empty"
        if len(api_key) <= 6:
            # First 4 + last 2 characters would reproduce the entire key here,
            # so reveal nothing.
            return "***"
        if len(api_key) <= 12:
            return f"{api_key[:4]}...{api_key[-2:]}"
        return f"{api_key[:8]}...{api_key[-4:]}"

    def _attach_debug_info(self, error: Exception, backend_url: str, request_payload: Any, api_key: str) -> None:
        """Store debug information on ``error`` (as ``debug_info``) for potential dev mode access."""
        error.debug_info = self._create_debug_info(backend_url, request_payload, error, api_key)  # type: ignore

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Sends the chat completions request to the OpenAI-compatible backend using the OpenAI SDK.

        This policy uses the OpenAI SDK to send the structured chat completions request
        from transaction.request.payload to the backend API endpoint. The response
        is stored as a structured OpenAIChatCompletionsResponse in transaction.response.payload.

        Args:
            transaction: The current transaction, containing the request payload to be sent.
            container: The application dependency container, used to create the OpenAI client.
            session: An active SQLAlchemy AsyncSession. (Unused by this policy but required by the interface).

        Returns:
            The Transaction, updated with transaction.response.payload containing the
            OpenAIChatCompletionsResponse from the backend and
            transaction.response.api_endpoint set to the backend URL.

        Raises:
            ValueError: If backend URL or API key is not configured.
            openai.APIError: For API-related errors from the OpenAI backend.
            openai.APITimeoutError: If the request to the backend times out.
            openai.APIConnectionError: For network-related issues during the backend request.
            Exception: For any other unexpected errors during request execution.

        Every exception raised by the backend call is re-raised unchanged after a
        ``debug_info`` attribute has been attached to it.
        """
        # NOTE(review): unlike sibling policies, transaction.request is not checked
        # for None here; presumably it is guaranteed by the caller. Confirm.
        # Create OpenAI client for the backend request
        backend_url = transaction.request.api_endpoint
        api_key = transaction.request.api_key

        if not backend_url:
            raise ValueError("Backend URL is not configured")
        if not api_key:
            raise ValueError("OpenAI API key is not configured")

        self.logger.info(f"Creating OpenAI client with backend URL: '{backend_url}' ({self.name})")
        openai_client = container.create_openai_client(backend_url, api_key)

        # Get the structured request payload
        request_payload = transaction.request.payload

        self.logger.info(
            f"Sending chat completions request to backend with model '{request_payload.model}' "
            f"and {len(request_payload.messages)} messages. ({self.name}); "
            f"Target url: {backend_url}"
        )

        try:
            # Send request using OpenAI SDK
            # Dump the Pydantic request model to keyword arguments for the SDK call
            request_dict = request_payload.model_dump()
            # Remove any None values to avoid issues with the OpenAI SDK
            request_dict = {k: v for k, v in request_dict.items() if v is not None}

            backend_response = await openai_client.chat.completions.create(**request_dict)

            # Convert OpenAI SDK response to our structured response model
            response_payload = OpenAIChatCompletionsResponse.model_validate(backend_response.model_dump())

            # Store the structured response in the transaction
            transaction.response.payload = response_payload
            transaction.response.api_endpoint = backend_url

            self.logger.info(
                f"Received backend response with {len(response_payload.choices)} choices "
                f"and usage: {response_payload.usage}. ({self.name})"
            )

        # The more specific handlers are listed before the generic openai.APIError
        # handler so that each failure kind gets its own log message.
        except openai.APITimeoutError as e:
            self.logger.error(f"Timeout error during backend request: {e} ({self.name})")
            self._attach_debug_info(e, backend_url, request_payload, api_key)
            raise
        except openai.APIConnectionError as e:
            self.logger.error(f"Connection error during backend request: {e} ({self.name})")
            self._attach_debug_info(e, backend_url, request_payload, api_key)
            raise
        except openai.APIError as e:
            self.logger.error(f"OpenAI API error during backend request: {e} ({self.name})")
            self._attach_debug_info(e, backend_url, request_payload, api_key)
            raise
        except Exception as e:
            self.logger.exception(f"Unexpected error during backend request: {e} ({self.name})")
            self._attach_debug_info(e, backend_url, request_payload, api_key)
            raise

        return transaction

Policy responsible for sending the chat completions request to the OpenAI-compatible backend using the OpenAI SDK and storing the structured response.

Attributes:

Name Type Description
name str

The name of this policy instance, used for logging and identification. It defaults to the class name if not provided during initialization.

logger Logger

The logger instance for this policy.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/send_backend_request.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Forward the transaction's chat completions request to the backend and record the reply.

    The request payload is dumped to a dict, top-level entries whose value is
    None are dropped, and the remainder is passed as keyword arguments to
    ``chat.completions.create`` on a client obtained from
    ``container.create_openai_client``. The SDK response is dumped and
    re-validated as an OpenAIChatCompletionsResponse, then stored on
    ``transaction.response`` together with the backend URL.

    Args:
        transaction: The current transaction; its request supplies the backend URL
            (``api_endpoint``), the API key and the payload to send.
        container: The application dependency container, used to create the OpenAI client.
        session: An active SQLAlchemy AsyncSession. (Unused by this policy but required by the interface).

    Returns:
        The same Transaction, with ``response.payload`` holding the validated
        backend response and ``response.api_endpoint`` set to the backend URL.

    Raises:
        ValueError: If backend URL or API key is not configured.
        openai.APITimeoutError: If the request to the backend times out.
        openai.APIConnectionError: For network-related issues during the backend request.
        openai.APIError: For API-related errors from the OpenAI backend.
        Exception: For any other unexpected errors during request execution.
    """
    target_url = transaction.request.api_endpoint
    credential = transaction.request.api_key

    # Both values are mandatory; fail before any client is created.
    if not target_url:
        raise ValueError("Backend URL is not configured")
    if not credential:
        raise ValueError("OpenAI API key is not configured")

    self.logger.info(f"Creating OpenAI client with backend URL: '{target_url}' ({self.name})")
    sdk_client = container.create_openai_client(target_url, credential)

    outbound = transaction.request.payload

    self.logger.info(
        f"Sending chat completions request to backend with model '{outbound.model}' "
        f"and {len(outbound.messages)} messages. ({self.name}); "
        f"Target url: {target_url}"
    )

    try:
        # The SDK takes keyword arguments, so dump the model and leave out
        # unset (None) fields instead of sending them explicitly.
        call_kwargs = {field: value for field, value in outbound.model_dump().items() if value is not None}

        sdk_reply = await sdk_client.chat.completions.create(**call_kwargs)

        # Round-trip through a dict to obtain our own structured response model.
        parsed_reply = OpenAIChatCompletionsResponse.model_validate(sdk_reply.model_dump())

        transaction.response.payload = parsed_reply
        transaction.response.api_endpoint = target_url

        self.logger.info(
            f"Received backend response with {len(parsed_reply.choices)} choices "
            f"and usage: {parsed_reply.usage}. ({self.name})"
        )
    except Exception as exc:
        # The checks are ordered; the first matching type decides the log message.
        # Only the fallback branch logs with a traceback (logger.exception).
        if isinstance(exc, openai.APITimeoutError):
            self.logger.error(f"Timeout error during backend request: {exc} ({self.name})")
        elif isinstance(exc, openai.APIConnectionError):
            self.logger.error(f"Connection error during backend request: {exc} ({self.name})")
        elif isinstance(exc, openai.APIError):
            self.logger.error(f"OpenAI API error during backend request: {exc} ({self.name})")
        else:
            self.logger.exception(f"Unexpected error during backend request: {exc} ({self.name})")
        # Store debug information for potential dev mode access
        exc.debug_info = self._create_debug_info(target_url, outbound, exc, credential)  # type: ignore
        raise

    return transaction

Sends the chat completions request to the OpenAI-compatible backend using the OpenAI SDK.

This policy uses the OpenAI SDK to send the structured chat completions request from transaction.request.payload to the backend API endpoint. The response is stored as a structured OpenAIChatCompletionsResponse in transaction.response.payload.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction, containing the request payload to be sent.

required
container DependencyContainer

The application dependency container, providing settings and OpenAI client.

required
session AsyncSession

An active SQLAlchemy AsyncSession. (Unused by this policy but required by the interface).

required

Returns:

Type Description
Transaction

The Transaction, updated with transaction.response.payload containing the OpenAIChatCompletionsResponse from the backend.

Raises:

Type Description
ValueError

If backend URL or API key is not configured.

APIError

For API-related errors from the OpenAI backend.

APITimeoutError

If the request to the backend times out.

APIConnectionError

For network-related issues during the backend request.

Exception

For any other unexpected errors during request execution.

serial_policy

SerialPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/serial_policy.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class SerialPolicy(ControlPolicy):
    """
    A Control Policy that applies an ordered sequence of other policies.

    Policies are applied sequentially. If any policy raises an exception,
    the execution stops, and the exception propagates.

    Attributes:
        policies (Sequence[ControlPolicy]): The ordered sequence of ControlPolicy
            instances that this policy will apply.
        logger (logging.Logger): The logger instance for this policy.
        name (str): The name of this policy instance, used for logging and
            identification.
    """

    name: Optional[str] = Field(default="SerialPolicy")
    policies: Sequence[ControlPolicy] = Field(...)

    def __init__(self, **data):
        """Initialize the policy and warn when it is created without member policies."""
        super().__init__(**data)
        if not self.policies:
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f"Initializing SerialPolicy '{self.name}' with an empty policy list.")

    async def apply(
        self,
        transaction: Transaction,
        container: DependencyContainer,
        session: AsyncSession,
    ) -> Transaction:
        """
        Applies the contained policies sequentially to the transaction.
        Requires the DependencyContainer and an active SQLAlchemy AsyncSession.

        Args:
            transaction: The current transaction.
            container: The application dependency container.
            session: An active SQLAlchemy AsyncSession, passed to member policies.

        Returns:
            The transaction after all contained policies have been applied.

        Raises:
            Exception: Propagates any exception raised by a contained policy.
        """
        self.logger.debug(f"Entering SerialPolicy: {self.name}")
        current_transaction = transaction
        for i, policy in enumerate(self.policies):
            member_policy_name = getattr(policy, "name", policy.__class__.__name__)  # Get policy name if available
            self.logger.info(f"Applying policy {i + 1}/{len(self.policies)} in {self.name}: {member_policy_name}")
            try:
                # Each policy receives the transaction returned by the previous one.
                current_transaction = await policy.apply(current_transaction, container=container, session=session)
            except Exception as e:
                self.logger.error(
                    f"Error applying policy {member_policy_name} within {self.name}: {e}",
                    exc_info=True,
                )
                raise  # Re-raise the exception to halt processing
        self.logger.debug(f"Exiting SerialPolicy: {self.name}")
        return current_transaction

    def __repr__(self) -> str:
        """Provides a developer-friendly representation listing each member as ``name <ClassName>``."""
        # Fall back to the class name when a member has no ``name`` attribute,
        # the same way ``apply`` resolves member names for logging.
        policy_reprs = [
            f"{getattr(p, 'name', p.__class__.__name__)} <{p.__class__.__name__}>" for p in self.policies
        ]
        policy_list_str = ", ".join(policy_reprs)
        return f"<{self.name}(policies=[{policy_list_str}])>"

    @classmethod
    def from_serialized(cls, config: SerializableDict) -> "SerialPolicy":
        """
        Constructs a SerialPolicy from serialized data, loading member policies.

        The input ``config`` and its nested member dictionaries are not modified.

        Args:
            config: The serialized configuration dictionary. Expects a 'policies' key
                    containing an iterable of dictionaries, each with 'type' and 'config'.
                    A top-level 'name' on a member (legacy format) is used as the
                    member's config name when the config has none.

        Returns:
            An instance of SerialPolicy.

        Raises:
            PolicyLoadError: If 'policies' key is missing, not an iterable, contains a
                             non-dictionary item, or if loading a member policy fails.
        """
        member_policy_data_list_val = config.get("policies")

        if member_policy_data_list_val is None:
            raise PolicyLoadError("SerialPolicy config missing 'policies' list (key not found).")
        if not isinstance(member_policy_data_list_val, Iterable):
            raise PolicyLoadError(
                f"SerialPolicy 'policies' must be an iterable. Got {type(member_policy_data_list_val)}"
            )

        instantiated_policies = []

        for i, member_data in enumerate(member_policy_data_list_val):
            if not isinstance(member_data, dict):
                raise PolicyLoadError(
                    f"Item at index {i} in SerialPolicy 'policies' is not a dictionary. Got {type(member_data)}"
                )

            try:
                # Import load_policy to properly handle member policy loading
                from luthien_control.control_policy.loader import load_policy

                # Get the type and config from member_data
                member_type = member_data.get("type")
                member_config = member_data.get("config", {})

                # These errors are raised inside the try, so they are re-wrapped
                # below with the member's index and name.
                if not isinstance(member_type, str):
                    raise PolicyLoadError(
                        f"Member policy at index {i} must have a 'type' field as string. Got: {type(member_type)}"
                    )
                if not isinstance(member_config, dict):
                    raise PolicyLoadError(
                        f"Member policy at index {i} must have a 'config' field as dict. Got: {type(member_config)}"
                    )

                # If name is at the top level (legacy format), carry it into the config.
                # Build a new dict rather than writing into the caller's data.
                if "name" in member_data and "name" not in member_config:
                    member_config = {**member_config, "name": member_data["name"]}

                # Create SerializedPolicy object from member_data
                serialized_member = SerializedPolicy(type=member_type, config=member_config)
                member_policy = load_policy(serialized_member)
                instantiated_policies.append(member_policy)
            except PolicyLoadError as e:
                raise PolicyLoadError(
                    f"Failed to load member policy at index {i} "
                    f"(name: {member_data.get('name', 'unknown')}) "
                    f"within SerialPolicy: {e}"
                ) from e
            except Exception as e:
                raise PolicyLoadError(
                    f"Unexpected error loading member policy at index {i} "
                    f"(name: {member_data.get('name', 'unknown')}) "
                    f"within SerialPolicy: {e}"
                ) from e

        # Every config key other than 'policies' is forwarded to the constructor.
        return cls(policies=instantiated_policies, **{k: v for k, v in config.items() if k != "policies"})

A Control Policy that applies an ordered sequence of other policies.

Policies are applied sequentially. If any policy raises an exception, the execution stops, and the exception propagates.

Attributes:

Name Type Description
policies Sequence[ControlPolicy]

The ordered sequence of ControlPolicy instances that this policy will apply.

logger Logger

The logger instance for this policy.

name str

The name of this policy instance, used for logging and identification.

__repr__()
Source code in luthien_control/control_policy/serial_policy.py
78
79
80
81
82
83
def __repr__(self) -> str:
    """Provides a developer-friendly representation listing each member as ``name <ClassName>``."""
    # Fall back to the class name when a member has no ``name`` attribute,
    # the same way ``apply`` resolves member names for logging.
    policy_reprs = [
        f"{getattr(p, 'name', p.__class__.__name__)} <{p.__class__.__name__}>" for p in self.policies
    ]
    policy_list_str = ", ".join(policy_reprs)
    return f"<{self.name}(policies=[{policy_list_str}])>"

Provides a developer-friendly representation.

apply(transaction, container, session) async
Source code in luthien_control/control_policy/serial_policy.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
async def apply(
    self,
    transaction: Transaction,
    container: DependencyContainer,
    session: AsyncSession,
) -> Transaction:
    """
    Run every member policy, in order, feeding each one the previous result.

    Args:
        transaction: The transaction handed to the first member policy.
        container: The application dependency container, forwarded to each member.
        session: An active SQLAlchemy AsyncSession, forwarded to each member.

    Returns:
        The transaction returned by the last member policy, or the input
        transaction unchanged when there are no member policies.

    Raises:
        Exception: Whatever a member policy raises; it is logged and re-raised,
            and the remaining members are not applied.
    """
    self.logger.debug(f"Entering SerialPolicy: {self.name}")
    result = transaction
    for position, member in enumerate(self.policies, start=1):
        # Members without a ``name`` attribute are identified by class name.
        label = getattr(member, "name", member.__class__.__name__)
        self.logger.info(f"Applying policy {position}/{len(self.policies)} in {self.name}: {label}")
        try:
            result = await member.apply(result, container=container, session=session)
        except Exception as failure:
            self.logger.error(
                f"Error applying policy {label} within {self.name}: {failure}",
                exc_info=True,
            )
            # Stop the chain: later members must not see a half-processed transaction.
            raise
    self.logger.debug(f"Exiting SerialPolicy: {self.name}")
    return result

Applies the contained policies sequentially to the transaction. Requires the DependencyContainer and an active SQLAlchemy AsyncSession.

Parameters:

Name Type Description Default
transaction Transaction

The current transaction.

required
container DependencyContainer

The application dependency container.

required
session AsyncSession

An active SQLAlchemy AsyncSession, passed to member policies.

required

Returns:

Type Description
Transaction

The transaction after all contained policies have been applied.

Raises:

Type Description
Exception

Propagates any exception raised by a contained policy.

from_serialized(config) classmethod
Source code in luthien_control/control_policy/serial_policy.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@classmethod
def from_serialized(cls, config: SerializableDict) -> "SerialPolicy":
    """
    Constructs a SerialPolicy from serialized data, loading member policies.

    The input ``config`` and its nested member dictionaries are not modified.

    Args:
        config: The serialized configuration dictionary. Expects a 'policies' key
                containing an iterable of dictionaries, each with 'type' and 'config'.
                A top-level 'name' on a member (legacy format) is used as the
                member's config name when the config has none.

    Returns:
        An instance of SerialPolicy.

    Raises:
        PolicyLoadError: If 'policies' key is missing, not an iterable, contains a
                         non-dictionary item, or if loading a member policy fails.
    """
    member_policy_data_list_val = config.get("policies")

    if member_policy_data_list_val is None:
        raise PolicyLoadError("SerialPolicy config missing 'policies' list (key not found).")
    if not isinstance(member_policy_data_list_val, Iterable):
        raise PolicyLoadError(
            f"SerialPolicy 'policies' must be an iterable. Got {type(member_policy_data_list_val)}"
        )

    instantiated_policies = []

    for i, member_data in enumerate(member_policy_data_list_val):
        if not isinstance(member_data, dict):
            raise PolicyLoadError(
                f"Item at index {i} in SerialPolicy 'policies' is not a dictionary. Got {type(member_data)}"
            )

        try:
            # Import load_policy to properly handle member policy loading
            from luthien_control.control_policy.loader import load_policy

            # Get the type and config from member_data
            member_type = member_data.get("type")
            member_config = member_data.get("config", {})

            # These errors are raised inside the try, so they are re-wrapped
            # below with the member's index and name.
            if not isinstance(member_type, str):
                raise PolicyLoadError(
                    f"Member policy at index {i} must have a 'type' field as string. Got: {type(member_type)}"
                )
            if not isinstance(member_config, dict):
                raise PolicyLoadError(
                    f"Member policy at index {i} must have a 'config' field as dict. Got: {type(member_config)}"
                )

            # If name is at the top level (legacy format), carry it into the config.
            # Build a new dict rather than writing into the caller's data.
            if "name" in member_data and "name" not in member_config:
                member_config = {**member_config, "name": member_data["name"]}

            # Create SerializedPolicy object from member_data
            serialized_member = SerializedPolicy(type=member_type, config=member_config)
            member_policy = load_policy(serialized_member)
            instantiated_policies.append(member_policy)
        except PolicyLoadError as e:
            raise PolicyLoadError(
                f"Failed to load member policy at index {i} "
                f"(name: {member_data.get('name', 'unknown')}) "
                f"within SerialPolicy: {e}"
            ) from e
        except Exception as e:
            raise PolicyLoadError(
                f"Unexpected error loading member policy at index {i} "
                f"(name: {member_data.get('name', 'unknown')}) "
                f"within SerialPolicy: {e}"
            ) from e

    # Every config key other than 'policies' is forwarded to the constructor.
    return cls(policies=instantiated_policies, **{k: v for k, v in config.items() if k != "policies"})

Constructs a SerialPolicy from serialized data, loading member policies.

Parameters:

Name Type Description Default
config SerializableDict

The serialized configuration dictionary. Expects a 'policies' key containing a list of dictionaries, each with 'type' and 'config'.

required

Returns:

Type Description
SerialPolicy

An instance of SerialPolicy.

Raises:

Type Description
PolicyLoadError

If the 'policies' key is missing or is not an iterable, if an item in it is not a dictionary, or if loading a member policy fails.

serialization

SerializedPolicy dataclass

Source code in luthien_control/control_policy/serialization.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@dataclass
class SerializedPolicy:
    """Represents the serialized form of a ControlPolicy.

    This structure is used to store and transfer policy configurations.
    The 'type' field identifies the specific policy class, and the 'config'
    field contains the parameters needed to reconstruct that policy instance.

    Being a plain dataclass, it performs no validation of either field itself.

    Attributes:
        type (str): The registered name of the policy type (e.g., "AddApiKeyHeader").
        config (SerializableDict): A dictionary containing the configuration
                                   parameters for the policy instance.
    """

    # Registered policy type name used to select the class to instantiate.
    type: str
    # Constructor/configuration parameters for that policy instance.
    config: SerializableDict

Represents the serialized form of a ControlPolicy.

This structure is used to store and transfer policy configurations. The 'type' field identifies the specific policy class, and the 'config' field contains the parameters needed to reconstruct that policy instance.

Attributes:

Name Type Description
type str

The registered name of the policy type (e.g., "AddApiKeyHeader").

config SerializableDict

A dictionary containing the configuration parameters for the policy instance.

safe_model_dump(model)

Source code in luthien_control/control_policy/serialization.py
17
18
19
20
def safe_model_dump(model: BaseModel) -> SerializableDict:
    """Dump a Pydantic model and validate the result as a SerializableDict.

    The model is dumped in ``python`` mode using field aliases, and the resulting
    dict is returned as validated by ``SerializableDictAdapter``.
    """
    return SerializableDictAdapter.validate_python(model.model_dump(mode="python", by_alias=True))

Safely dump a Pydantic model through SerializableDict validation.

safe_model_validate(model_class, data)

Source code in luthien_control/control_policy/serialization.py
26
27
28
29
def safe_model_validate(model_class: type[T], data: SerializableDict) -> T:
    """Build a ``model_class`` instance from data that first passes SerializableDict validation.

    ``data`` is validated by ``SerializableDictAdapter`` and the validated value is
    then given to ``model_class.model_validate`` with ``from_attributes=True``.
    """
    return model_class.model_validate(
        SerializableDictAdapter.validate_python(data),
        from_attributes=True,
    )

Safely validate data through SerializableDict before creating model.

set_backend_policy

SetBackendPolicy

Bases: ControlPolicy

Source code in luthien_control/control_policy/set_backend_policy.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class SetBackendPolicy(ControlPolicy):
    """Policy that points the transaction's request at a configured backend URL.

    Attributes:
        name: Name of this policy instance; defaults to "SetBackendPolicy".
        backend_url: URL written to ``transaction.request.api_endpoint``. When it
            is None the transaction is returned untouched.
    """

    name: Optional[str] = Field(default="SetBackendPolicy")
    backend_url: Optional[str] = Field(default=None)

    async def apply(
        self, transaction: Transaction, container: DependencyContainer, session: AsyncSession
    ) -> Transaction:
        """Overwrite the request's ``api_endpoint`` with ``backend_url`` when one is configured.

        ``container`` and ``session`` are accepted for interface compatibility and not used.
        """
        configured_url = self.backend_url
        if configured_url is None:
            return transaction

        # Only the base URL is stored here; per the original author's note, the
        # OpenAI client is expected to append the endpoint path (e.g. "chat/completions").
        transaction.request.api_endpoint = configured_url
        return transaction

    def _get_policy_specific_config(self) -> SerializableDict:
        """Return policy-specific configuration for backward compatibility with tests."""
        return SerializableDict(backend_url=self.backend_url)

A policy that sets the backend URL for the transaction.

core

dependencies

get_db_session(dependencies=Depends(get_dependencies)) async

Source code in luthien_control/core/dependencies.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
async def get_db_session(
    dependencies: DependencyContainer = Depends(get_dependencies),
) -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields an async database session from the container's factory.

    Raises:
        HTTPException: 500 when the container has no ``db_session_factory``.
    """
    make_session = dependencies.db_session_factory
    if make_session is None:
        # This shouldn't happen if the container is initialized correctly
        logger.critical("DB Session Factory not found in DependencyContainer.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error: Database session factory not available.",
        )

    async with make_session() as session:
        try:
            yield session
        except Exception:
            # Roll back explicitly when the consumer fails, then let the error propagate.
            # Commit/close is left to the session's context manager.
            # NOTE(review): confirm the factory's context manager actually commits.
            await session.rollback()
            raise

FastAPI dependency to get an async database session using the container's factory.

get_dependencies(request)

Source code in luthien_control/core/dependencies.py
21
22
23
24
25
26
27
28
29
30
31
32
33
def get_dependencies(request: Request) -> DependencyContainer:
    """Return the DependencyContainer stored on ``request.app.state.dependencies``.

    Raises:
        HTTPException: 500 when the attribute is missing or None.
    """
    container: DependencyContainer | None = getattr(request.app.state, "dependencies", None)
    if container is not None:
        return container

    logger.critical(
        "DependencyContainer not found in application state. "
        "This indicates a critical setup error in the application lifespan."
    )
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Internal server error: Application dependencies not initialized.",
    )

Dependency to retrieve the DependencyContainer from application state.

get_main_control_policy(dependencies=Depends(get_dependencies)) async

Source code in luthien_control/core/dependencies.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
async def get_main_control_policy(
    dependencies: DependencyContainer = Depends(get_dependencies),
) -> ControlPolicy:
    """
    Dependency to load and provide the main ControlPolicy instance.

    When the settings provide a policy file path, the policy is loaded from that
    file. Otherwise the top-level policy name from the settings is looked up via
    load_policy_from_db, which receives the whole DependencyContainer.

    Args:
        dependencies: The application DependencyContainer.

    Returns:
        The loaded main ControlPolicy.

    Raises:
        HTTPException: With status 500 when no top-level policy name is configured,
            when the named policy cannot be loaded (falsy result), or when loading
            raises; in the last case the original error is chained as the cause.
    """
    settings = dependencies.settings
    policy_filepath = settings.get_policy_filepath()
    if policy_filepath:
        # A configured policy file takes precedence over the database lookup.
        logger.info(f"Loading main control policy from file: {policy_filepath}")
        return load_policy_from_file(policy_filepath)

    top_level_policy_name = settings.get_top_level_policy_name()
    if not top_level_policy_name:
        logger.error("TOP_LEVEL_POLICY_NAME is not configured in settings.")
        raise HTTPException(status_code=500, detail="Internal server error: Control policy name not configured.")
    try:
        # The container is passed through as-is; no session is opened here.
        main_policy = await load_policy_from_db(
            name=top_level_policy_name,
            container=dependencies,
        )

        if not main_policy:
            logger.error(f"Main control policy '{top_level_policy_name}' could not be loaded (not found or inactive).")
            raise HTTPException(
                status_code=500,
                detail=f"Internal server error: Main control policy '{top_level_policy_name}' not found or inactive.",
            )

        return main_policy

    except PolicyLoadError as e:
        logger.exception(f"Failed to load main control policy '{top_level_policy_name}': {e}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: Could not load main control policy. {e}"
        ) from e
    except HTTPException:
        # Let HTTPExceptions raised inside the try (e.g. the "not found" case above)
        # through untouched instead of re-wrapping them in the generic handler.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error loading main control policy '{top_level_policy_name}': {e}")
        raise HTTPException(
            status_code=500, detail="Internal server error: Unexpected issue loading main control policy."
        ) from e

Dependency to load and provide the main ControlPolicy instance.

Uses the DependencyContainer to read the settings; when the policy is loaded from the database, the container itself is passed to load_policy_from_db.

initialize_app_dependencies(app_settings) async

Source code in luthien_control/core/dependencies.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
async def initialize_app_dependencies(app_settings: Settings) -> DependencyContainer:
    """Build the DependencyContainer shared by the whole application.

    Three steps run in order: create the shared HTTP client, create the main
    database engine, then wrap everything in a DependencyContainer. If the
    database step or the container step fails, the HTTP client created here
    is closed before the error is re-raised as a RuntimeError.

    Args:
        app_settings: The application settings instance stored on the container.

    Returns:
        A DependencyContainer holding the settings, the HTTP client and the
        database session factory.

    Raises:
        RuntimeError: If creating the database engine or instantiating the
            container fails; the original exception is chained as the cause.
    """
    logger.info("Initializing core application dependencies...")

    # Shared outbound client: 5s default/connect/write timeouts, 60s read timeout.
    client = httpx.AsyncClient(timeout=httpx.Timeout(5.0, connect=5.0, read=60.0, write=5.0))
    logger.info("HTTP Client initialized for DependencyContainer.")

    try:
        logger.info("Attempting to create main DB engine and session factory for DependencyContainer...")
        # The returned engine is not kept locally; per the original author's note it
        # is configured from the global settings instance and managed globally.
        await create_db_engine()
        logger.info("Main DB engine successfully created for DependencyContainer.")
        # The container receives the session factory exposed by the async database module.
        session_factory = db_get_session
        logger.info("DB Session Factory reference obtained for DependencyContainer.")
    except Exception as db_exc:
        logger.critical(f"Failed to initialize database for DependencyContainer due to exception: {db_exc}")
        # Only the HTTP client is released here. The engine may be missing or
        # half-initialised, so its cleanup is left to the caller
        # (NOTE(review): presumably the lifespan handler — confirm).
        await client.aclose()
        logger.info("HTTP client closed due to DB initialization failure.")
        raise RuntimeError(f"Failed to initialize database for DependencyContainer: {db_exc}") from db_exc

    try:
        container = DependencyContainer(
            settings=app_settings,
            http_client=client,
            db_session_factory=session_factory,
        )
        logger.info("Dependency Container created successfully.")
        return container
    except Exception as container_exc:
        logger.critical(f"Failed to create Dependency Container instance: {container_exc}", exc_info=True)
        # Release what this function created. The DB engine is intentionally not
        # closed here (NOTE(review): expected to be closed by the global
        # close_db_engine during shutdown — confirm).
        await client.aclose()
        logger.info("HTTP client closed due to Dependency Container instantiation failure.")
        raise RuntimeError(f"Failed to create Dependency Container instance: {container_exc}") from container_exc

Initialize and configure core application dependencies.

This function sets up essential services required by the application, including an HTTP client and a database connection pool. It encapsulates the creation and configuration of these dependencies into a DependencyContainer instance.

Parameters:

Name Type Description Default
app_settings Settings

The application settings instance.

required

Returns:

Type Description
DependencyContainer

A DependencyContainer instance populated with initialized dependencies.

Raises:

Type Description
RuntimeError

If initialization of the HTTP client or database engine fails.

dependency_container

DependencyContainer

Source code in luthien_control/core/dependency_container.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class DependencyContainer:
    """Bundle of the shared dependencies used across the application.

    Keeping every shared dependency on one object lets callers receive them by
    injection and lets tests substitute mocks in a single place.
    """

    def __init__(
        self,
        settings: Settings,
        http_client: httpx.AsyncClient,
        db_session_factory: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        """
        Store the supplied dependencies as public attributes.

        Args:
            settings: Application settings.
            http_client: Shared asynchronous HTTP client.
            db_session_factory: Zero-argument callable returning an async context
                                manager that yields an SQLAlchemy AsyncSession.
        """
        self.db_session_factory = db_session_factory
        self.http_client = http_client
        self.settings = settings

    def create_openai_client(self, base_url: str, api_key: str) -> openai.AsyncOpenAI:
        """
        Build an OpenAI client bound to the given backend URL and API key.

        The factory lives on the container so that all external dependencies are
        created in one place, which keeps them easy to mock and easy to audit.

        Args:
            base_url: The base URL for the OpenAI-compatible API endpoint.
            api_key: The API key for authentication.

        Returns:
            A configured openai.AsyncOpenAI instance.

        Raises:
            ValueError: If base_url is empty or does not start with
                'http://' or 'https://'.
        """
        if not base_url:
            raise ValueError("Base URL cannot be empty")

        has_http_scheme = base_url.startswith("http://") or base_url.startswith("https://")
        if not has_http_scheme:
            raise ValueError(f"Base URL must start with 'http://' or 'https://': {base_url}")

        client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
        return client

Holds shared dependencies for the application.

This class is responsible for holding all shared dependencies for the application. It is used to inject dependencies into the application and to make it easier to mock dependencies for testing.

__init__(settings, http_client, db_session_factory)
Source code in luthien_control/core/dependency_container.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    settings: Settings,
    http_client: httpx.AsyncClient,
    db_session_factory: Callable[[], AsyncContextManager[AsyncSession]],
) -> None:
    """
    Store the supplied dependencies as public attributes.

    Args:
        settings: Application settings.
        http_client: Shared asynchronous HTTP client.
        db_session_factory: Zero-argument callable returning an async context
                            manager that yields an SQLAlchemy AsyncSession.
    """
    self.db_session_factory = db_session_factory
    self.http_client = http_client
    self.settings = settings

Initializes the container.

Parameters:

Name Type Description Default
settings Settings

Application settings.

required
http_client AsyncClient

Shared asynchronous HTTP client.

required
db_session_factory Callable[[], AsyncContextManager[AsyncSession]]

A factory function that returns an async context manager yielding an SQLAlchemy AsyncSession.

required
create_openai_client(base_url, api_key)
Source code in luthien_control/core/dependency_container.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def create_openai_client(self, base_url: str, api_key: str) -> openai.AsyncOpenAI:
    """
    Build an OpenAI client bound to the given backend URL and API key.

    The factory lives on the container so that all external dependencies are
    created in one place, which keeps them easy to mock and easy to audit.

    Args:
        base_url: The base URL for the OpenAI-compatible API endpoint.
        api_key: The API key for authentication.

    Returns:
        A configured openai.AsyncOpenAI instance.

    Raises:
        ValueError: If base_url is empty or does not start with
            'http://' or 'https://'.
    """
    if not base_url:
        raise ValueError("Base URL cannot be empty")

    has_http_scheme = base_url.startswith("http://") or base_url.startswith("https://")
    if not has_http_scheme:
        raise ValueError(f"Base URL must start with 'http://' or 'https://': {base_url}")

    client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
    return client

Creates an OpenAI client for the specified backend URL and API key.

We include this factory here for the sake of consistency with other external dependencies. By maintaining all external dependencies in one place, we can easily mock them for testing and keep track of which parts of the application have external dependencies.

Parameters:

Name Type Description Default
base_url str

The base URL for the OpenAI-compatible API endpoint.

required
api_key str

The API key for authentication.

required

Returns:

Type Description
AsyncOpenAI

A configured `openai.AsyncOpenAI` client instance.

Raises:

Type Description
ValueError

If the base_url is missing or doesn't have a valid protocol.

generic_events

Generic event system with type-safe event dispatching.

Event

Bases: Generic[T]

Source code in luthien_control/core/generic_events.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class Event(Generic[T]):
    """A generic event that maintains a named registry of typed event listeners to be dispatched on demand.

    Listeners are kept in a dict keyed by name, so registering a second listener
    under an existing name replaces the first. Each listener is invoked as
    ``listener(event_type, data)``.

    Typical usage:
        start_policy_event = LuthienEventType("start_policy")
        event: Event[dict] = Event[dict](start_policy_event)
        data = {"foo": "bar"}
        def listener_1(event_type, data):
            print(f"Listener 1 received event: {event_type} {data}")
        def listener_2(event_type, data):
            print(f"Listener 2 received event: {event_type} {data}")

        event.register("listener_1", listener_1)
        event.register("listener_2", listener_2)
        event.dispatch(data)  # Listener 1 and 2 will receive the event
        event.unregister("listener_1")
        event.dispatch(data)  # Listener 2 will receive the event

    Type Parameters:
        T: The type of data that will be passed to event listeners
    """

    def __init__(self, event_type: str) -> None:
        """Initialize an event with no registered listeners.

        Args:
            event_type: The type of event this observer is for; it is passed as
                the first argument to every listener on dispatch.
        """
        self._event_type = event_type
        self._listeners: Dict[str, EventListener[T]] = {}

    def register(self, name: str, listener: EventListener[T]) -> None:
        """Register a named observer, replacing any listener already stored under that name.

        Args:
            name: Unique identifier for this listener
            listener: Callable invoked as ``listener(event_type, data)`` with data of type T
        """
        self._listeners[name] = listener

    def unregister(self, name: str) -> None:
        """Remove a registered observer by name.

        Args:
            name: The name of the listener to remove

        Raises:
            KeyError: If no listener with the given name exists
        """
        del self._listeners[name]

    def dispatch(self, data: T) -> None:
        """Dispatch the event to all registered observers.

        An exception raised by one listener is logged and does not prevent the
        remaining listeners from being called.

        Listeners may call ``register``/``unregister`` while handling the event:
        the set of listeners is fixed when dispatch starts, so a listener removed
        mid-dispatch still receives this event and a listener added mid-dispatch
        first runs on the next dispatch.

        Args:
            data: The data of type T to pass to all listeners
        """
        # Iterate over a snapshot. Mutating the dict while iterating it directly
        # raises RuntimeError from the loop itself, outside the per-listener
        # try block, which would abort the whole dispatch.
        for name, listener in list(self._listeners.items()):
            try:
                listener(self._event_type, data)
            except Exception as e:
                logging.exception(f"Error dispatching event to listener {name}: {e}")

    @property
    def listener_count(self) -> int:
        """Return the number of registered listeners."""
        return len(self._listeners)

    def get_listeners(self) -> Dict[str, EventListener[T]]:
        """Return a *copy* of the registered listeners dictionary.

        Returns:
            A shallow copy of the listeners registry; changing it does not
            affect the event.
        """
        return self._listeners.copy()

A generic event that maintains a named registry of typed event listeners to be dispatched on demand.

Typical usage

    start_policy_event = LuthienEventType("start_policy")
    event: Event[dict] = Event[dict](start_policy_event)
    data = {"foo": "bar"}

    def listener_1(event_type, data):
        print(f"Listener 1 received event: {event_type} {data}")

    def listener_2(event_type, data):
        print(f"Listener 2 received event: {event_type} {data}")

    event.register("listener_1", listener_1)
    event.register("listener_2", listener_2)
    event.dispatch(data)  # Listener 1 and 2 will receive the event
    event.unregister("listener_1")
    event.dispatch(data)  # Listener 2 will receive the event

Type Parameters

T: The type of data that will be passed to event listeners

property

Return the number of registered listeners.

listener_count property

Return the number of registered listeners.

__init__(event_type)
Source code in luthien_control/core/generic_events.py
38
39
40
41
42
43
44
45
def __init__(self, event_type: str) -> None:
    """Create an event that starts out with an empty listener registry.

    Args:
        event_type: The type of event this observer is for
    """
    # Registry of listeners keyed by the name they were registered under.
    self._listeners: Dict[str, EventListener[T]] = dict()
    self._event_type = event_type

Initialize an event with no registered listeners.

Parameters:

Name Type Description Default
event_type str

The type of event this observer is for

required
dispatch(data)
Source code in luthien_control/core/generic_events.py
67
68
69
70
71
72
73
74
75
76
77
def dispatch(self, data: T) -> None:
    """Dispatch the event to all registered observers.

    Each listener is called as ``listener(event_type, data)``. An exception
    raised by one listener is logged and does not prevent the remaining
    listeners from being called.

    Listeners may register or unregister listeners while handling the event:
    the set of listeners is fixed when dispatch starts, so a listener removed
    mid-dispatch still receives this event and a listener added mid-dispatch
    first runs on the next dispatch.

    Args:
        data: The data of type T to pass to all listeners
    """
    # Iterate over a snapshot. Mutating the dict while iterating it directly
    # raises RuntimeError from the loop itself, outside the per-listener try
    # block, which would abort the whole dispatch.
    for name, listener in list(self._listeners.items()):
        try:
            listener(self._event_type, data)
        except Exception as e:
            logging.exception(f"Error dispatching event to listener {name}: {e}")

Dispatch the event to all registered observers.

Parameters:

Name Type Description Default
data T

The data of type T to pass to all listeners

required
get_listeners()
Source code in luthien_control/core/generic_events.py
84
85
86
87
88
89
90
def get_listeners(self) -> Dict[str, EventListener[T]]:
    """Give callers a snapshot of the listener registry.

    Returns:
        A new dictionary with the same name-to-listener entries; changing it
        does not affect the registry held by the event.
    """
    snapshot = dict(self._listeners)
    return snapshot

Return a copy of the registered listeners dictionary.

Returns:

Type Description
Dict[str, EventListener[T]]

A copy of the listeners registry

register(name, listener)
Source code in luthien_control/core/generic_events.py
47
48
49
50
51
52
53
54
def register(self, name: str, listener: EventListener[T]) -> None:
    """Store a listener in the registry under the given name.

    An entry already stored under the same name is overwritten.

    Args:
        name: Unique identifier for this listener
        listener: Callable that accepts an argument of type T
    """
    registry = self._listeners
    registry[name] = listener

Register a named observer.

Parameters:

Name Type Description Default
name str

Unique identifier for this listener

required
listener EventListener[T]

Callable that accepts an argument of type T

required
unregister(name)
Source code in luthien_control/core/generic_events.py
56
57
58
59
60
61
62
63
64
65
def unregister(self, name: str) -> None:
    """Drop the listener stored under the given name.

    Args:
        name: The name of the listener to remove

    Raises:
        KeyError: If no listener with the given name exists
    """
    # pop() without a default raises KeyError for an unknown name.
    self._listeners.pop(name)

Remove a registered observer by name.

Parameters:

Name Type Description Default
name str

The name of the listener to remove

required

Raises:

Type Description
KeyError

If no listener with the given name exists

logging

setup_logging()

Source code in luthien_control/core/logging.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def setup_logging():
    """
    Configure the root logger for the application.

    The level name comes from Settings.get_log_level (presumably backed by the
    LOG_LEVEL environment variable, as the warning text suggests). A name that
    is not in VALID_LOG_LEVELS triggers a warning on stderr and a fallback to
    DEFAULT_LOG_LEVEL.

    Existing root handlers are discarded and replaced by a stderr handler using
    LOG_FORMAT. When LOKI_URL is set and a Loki handler can be built, that
    handler is added as well. Every logger named in NOISY_LIBRARIES is set to
    WARNING.
    """
    log_level_name = Settings().get_log_level(default=DEFAULT_LOG_LEVEL)
    if log_level_name not in VALID_LOG_LEVELS:
        # Logging is not configured yet, so report the problem with print().
        print(
            f"WARNING: Invalid LOG_LEVEL '{log_level_name}'. "
            f"Defaulting to {DEFAULT_LOG_LEVEL}. "
            f"Valid levels are: {', '.join(VALID_LOG_LEVELS)}",
            file=sys.stderr,
        )
        log_level_name = DEFAULT_LOG_LEVEL

    # Start from a clean root logger at the requested level.
    root = logging.getLogger()
    root.setLevel(logging.getLevelName(log_level_name))
    root.handlers.clear()

    # Console output always goes to stderr.
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(logging.Formatter(LOG_FORMAT))
    root.addHandler(stderr_handler)

    module_logger = logging.getLogger(__name__)

    # Optional Loki handler; skipped when no URL is set or no handler is returned.
    loki_url = os.getenv("LOKI_URL")
    loki_handler = _get_loki_handler(loki_url) if loki_url else None
    if loki_handler:
        root.addHandler(loki_handler)
        module_logger.info(f"Loki logging configured for {loki_url}")

    # Keep chatty third-party loggers at WARNING.
    for lib_name in NOISY_LIBRARIES:
        logging.getLogger(lib_name).setLevel(logging.WARNING)

    # Confirms that setup finished, which helps when debugging configuration issues.
    module_logger.info(f"Logging configured with level {log_level_name}.")

Configures logging for the application.

Reads the desired log level from the LOG_LEVEL environment variable. Defaults to INFO if not set or invalid. Sets a standard format and directs logs to stderr. Sets louder libraries to WARNING level. Optionally configures Loki handler if LOKI_URL is set.

request

Request

Bases: DeepEventedModel

Source code in luthien_control/core/request.py
 7
 8
 9
10
11
12
class Request(DeepEventedModel):
    """A request to the Luthien Control API.

    Declares three fields, each through ``Field()`` with no default supplied.
    """

    # Request body, typed as an OpenAI chat-completions request.
    payload: OpenAIChatCompletionsRequest = Field()
    # NOTE(review): presumably the backend API endpoint this request targets — confirm with callers.
    api_endpoint: str = Field()
    # NOTE(review): presumably the credential sent with the request — confirm; avoid logging this value.
    api_key: str = Field()

A request to the Luthien Control API.

response

Response

Bases: DeepEventedModel

Source code in luthien_control/core/response.py
 9
10
11
12
13
class Response(DeepEventedModel):
    """A response from the Luthien Control API.

    Both fields are optional and default to None.
    """

    # Response body, typed as an OpenAI chat-completions response; None until one is set.
    payload: Optional[OpenAIChatCompletionsResponse] = Field(default=None)
    # NOTE(review): presumably the backend endpoint that produced the response — confirm with callers.
    api_endpoint: Optional[str] = Field(default=None)

A response from the Luthien Control API.

tracked_context

TrackedContext module with explicit mutation API and event tracking.

TrackedContext

Source code in luthien_control/core/tracked_context/tracked_context.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class TrackedContext:
    """Transaction context with explicit mutation API and event tracking.

    Holds one optional httpx request, one optional httpx response and a
    free-form key/value store for a single transaction. Each mutating method
    (update_request, update_response, set_data) dispatches a
    MutationEventPayload on ``self.events.mutation`` describing the change.
    """

    def __init__(self, transaction_id: Optional[uuid.UUID] = None):
        """Initialize tracked context.

        Args:
            transaction_id: Identifier for this transaction; a random UUID4 is
                generated when none is supplied.
        """
        self._transaction_id = transaction_id or uuid.uuid4()
        self._request: Optional[httpx.Request] = None
        self._response: Optional[httpx.Response] = None
        self._data: Dict[str, Any] = {}
        self.events = ContextEvents()

    @property
    def transaction_id(self) -> uuid.UUID:
        """Get transaction ID."""
        return self._transaction_id

    def update_request(
        self,
        method: Optional[str] = None,
        url: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        content: Optional[bytes] = None,
        from_scratch: bool = False,
        preserve_existing_headers: bool = True,
    ) -> httpx.Request:
        """Create or set the request.

        A new httpx.Request is built when ``from_scratch`` is true or no request
        is tracked yet; otherwise only the arguments that are not None are
        applied to the existing request. A mutation event with operation
        "set_request" is dispatched in both cases.

        Args:
            method: HTTP method; required (non-empty) when a new request is built.
            url: Request URL; required (non-empty) when a new request is built.
            headers: Headers for the new request, or headers to apply to the
                existing one.
            content: Request body bytes.
            from_scratch: Build a new request even if one is already tracked.
            preserve_existing_headers: Forwarded to ``_update_headers`` when
                updating an existing request.

        Returns:
            The tracked request object itself (not a copy).

        Raises:
            ValueError: If a new request must be built and method or url is missing.
        """
        differences = {}
        if from_scratch or self._request is None:
            if not all([method, url]):
                raise ValueError("Attempted to create new request, but method and url are required")
            method = str(method)
            url = str(url)
            self._request = httpx.Request(method=method, url=url, headers=headers, content=content)
            # A brand-new request has no previous state, so every "old" value is None.
            differences = {
                k: {"old": None, "new": getattr(self._request, k)} for k in ["method", "url", "headers", "content"]
            }
        else:
            # In each branch the old value is recorded before the request is mutated.
            if method is not None:
                differences["method"] = {"old": self._request.method, "new": method}
                self._request.method = method
            if url is not None:
                differences["url"] = {"old": self._request.url, "new": url}
                self._request.url = httpx.URL(url)
            if headers is not None:
                # NOTE(review): _update_headers is defined elsewhere; presumably it applies
                # the headers and returns a description of what changed — confirm.
                header_diffs = _update_headers(self._request, headers, preserve_existing_headers)
                differences["headers"] = header_diffs
            if content is not None:
                differences["content"] = {"old": self._request.content, "new": content}
                # NOTE(review): writes httpx's private _content attribute directly; whether the
                # request stream and Content-Length header stay in sync is not visible here — confirm.
                self._request._content = content

        self.events.mutation.dispatch(
            MutationEventPayload(
                transaction_id=self._transaction_id,
                operation="set_request",
                details=differences,
            )
        )
        return self._request

    @property
    def request(self) -> Optional[httpx.Request]:
        """Get a copy of the tracked request."""
        # NOTE(review): `copy` is presumably copy.copy (a shallow copy) — confirm against the module imports.
        return copy(self._request)

    def update_response(
        self,
        status_code: Optional[int] = None,
        content: Optional[bytes] = None,
        headers: Optional[Dict[str, str]] = None,
        from_scratch: bool = False,
        preserve_existing_headers: bool = True,
    ) -> httpx.Response:
        """Update the response.

        A new httpx.Response is built when ``from_scratch`` is true or no
        response is tracked yet; otherwise only the arguments that are not None
        are applied to the existing response. A mutation event with operation
        "set_response" is dispatched in both cases.

        Args:
            status_code: HTTP status code; required (non-zero) when a new
                response is built.
            content: Response body bytes.
            headers: Headers for the new response, or headers to apply to the
                existing one.
            from_scratch: Build a new response even if one is already tracked.
            preserve_existing_headers: Forwarded to ``_update_headers`` when
                updating an existing response.

        Returns:
            The tracked response object itself (not a copy).

        Raises:
            ValueError: If a new response must be built and status_code is missing.
        """
        differences = {}
        if from_scratch or self._response is None:
            if not status_code:
                raise ValueError("Attempted to create new response, but status_code is required")
            status_code = int(status_code)
            self._response = httpx.Response(
                status_code=status_code,
                headers=headers,
                content=content,
            )
            # A brand-new response has no previous state, so every "old" value is None.
            differences = {
                k: {"old": None, "new": getattr(self._response, k)} for k in ["status_code", "headers", "content"]
            }
        else:
            # In each branch the old value is recorded before the response is mutated.
            if status_code is not None:
                differences["status_code"] = {"old": self._response.status_code, "new": status_code}
                self._response.status_code = status_code
            if headers is not None:
                # NOTE(review): _update_headers is defined elsewhere; presumably it applies
                # the headers and returns a description of what changed — confirm.
                differences["headers"] = _update_headers(self._response, headers, preserve_existing_headers)
            if content is not None:
                differences["content"] = {"old": self._response.content, "new": content}
                # NOTE(review): writes httpx's private _content attribute directly; whether derived
                # state (decoded text, Content-Length header) stays in sync is not visible here — confirm.
                self._response._content = content

        self.events.mutation.dispatch(
            MutationEventPayload(
                transaction_id=self._transaction_id,
                operation="set_response",
                details=differences,
            )
        )
        return self._response

    @property
    def response(self) -> Optional[httpx.Response]:
        """Get a copy of the tracked response."""
        # NOTE(review): `copy` is presumably copy.copy (a shallow copy) — confirm against the module imports.
        return copy(self._response)

    def get_data(self, key: str, default: Any = None) -> Any:
        """Get data value.

        Args:
            key: Key to look up in the transaction data store.
            default: Value returned when the key is absent.

        Returns:
            The stored value, or ``default`` if the key is not present.
        """
        return self._data.get(key, default)

    def set_data(self, key: str, value: Any) -> None:
        """Set data value.

        Stores the value and dispatches a mutation event with operation
        "set_data" carrying the key, the previous value and the new value.

        Args:
            key: Key to store the value under.
            value: Value to store.
        """
        # The previous value is reported as None both when the key was absent and when it held None.
        old_value = self._data.get(key)
        self._data[key] = value
        self.events.mutation.dispatch(
            MutationEventPayload(
                transaction_id=self._transaction_id,
                operation="set_data",
                details={"key": key, "old_value": old_value, "new_value": value},
            )
        )

    def get_all_data(self) -> Dict[str, Any]:
        """Get copy of all data.

        Returns:
            A shallow copy of the data store; the values themselves are shared.
        """
        return self._data.copy()

Transaction context with explicit mutation API and event tracking.

property

Get a copy of the tracked request.

request property

Get a copy of the tracked request.

property

Get a copy of the tracked response.

response property

Get a copy of the tracked response.

property

Get transaction ID.

transaction_id property

Get transaction ID.

__init__(transaction_id=None)
Source code in luthien_control/core/tracked_context/tracked_context.py
46
47
48
49
50
51
52
def __init__(self, transaction_id: Optional[uuid.UUID] = None):
    """Set up an empty context for one transaction.

    Args:
        transaction_id: Identifier for this transaction; a random UUID4 is
            generated when none is supplied.
    """
    self._transaction_id = transaction_id if transaction_id else uuid.uuid4()
    # No request or response is tracked until update_request / update_response is called.
    self._request: Optional[httpx.Request] = None
    self._response: Optional[httpx.Response] = None
    # Free-form per-transaction key/value store.
    self._data: Dict[str, Any] = dict()
    self.events = ContextEvents()

Initialize tracked context.

get_all_data()
Source code in luthien_control/core/tracked_context/tracked_context.py
169
170
171
def get_all_data(self) -> Dict[str, Any]:
    """Return a shallow copy of the whole data store."""
    snapshot = dict(self._data)
    return snapshot

Get copy of all data.

get_data(key, default=None)
Source code in luthien_control/core/tracked_context/tracked_context.py
153
154
155
def get_data(self, key: str, default: Any = None) -> Any:
    """Look up a value in the data store, returning ``default`` when the key is absent."""
    try:
        return self._data[key]
    except KeyError:
        return default

Get data value.

set_data(key, value)
Source code in luthien_control/core/tracked_context/tracked_context.py
157
158
159
160
161
162
163
164
165
166
167
def set_data(self, key: str, value: Any) -> None:
    """Store a value and announce the change.

    Dispatches a mutation event with operation "set_data" carrying the key,
    the previous value and the new value. The previous value is reported as
    None both when the key was absent and when it held None.
    """
    previous = self._data.get(key)
    self._data[key] = value

    change = {"key": key, "old_value": previous, "new_value": value}
    payload = MutationEventPayload(
        transaction_id=self._transaction_id,
        operation="set_data",
        details=change,
    )
    self.events.mutation.dispatch(payload)

Set data value.

update_request(method=None, url=None, headers=None, content=None, from_scratch=False, preserve_existing_headers=True)
Source code in luthien_control/core/tracked_context/tracked_context.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def update_request(
    self,
    method: Optional[str] = None,
    url: Optional[str] = None,
    headers: Optional[Dict[str, str]] = None,
    content: Optional[bytes] = None,
    from_scratch: bool = False,
    preserve_existing_headers: bool = True,
) -> httpx.Request:
    """Create or set the request.

    A new httpx.Request is built when ``from_scratch`` is true or no request is
    tracked yet; otherwise only the arguments that are not None are applied to
    the existing request. A mutation event with operation "set_request" is
    dispatched in both cases.

    Args:
        method: HTTP method; required (non-empty) when a new request is built.
        url: Request URL; required (non-empty) when a new request is built.
        headers: Headers for the new request, or headers to apply to the
            existing one.
        content: Request body bytes.
        from_scratch: Build a new request even if one is already tracked.
        preserve_existing_headers: Forwarded to ``_update_headers`` when
            updating an existing request.

    Returns:
        The tracked request object itself (not a copy).

    Raises:
        ValueError: If a new request must be built and method or url is missing.
    """
    differences = {}
    if from_scratch or self._request is None:
        if not all([method, url]):
            raise ValueError("Attempted to create new request, but method and url are required")
        method = str(method)
        url = str(url)
        self._request = httpx.Request(method=method, url=url, headers=headers, content=content)
        # A brand-new request has no previous state, so every "old" value is None.
        differences = {
            k: {"old": None, "new": getattr(self._request, k)} for k in ["method", "url", "headers", "content"]
        }
    else:
        # In each branch the old value is recorded before the request is mutated.
        if method is not None:
            differences["method"] = {"old": self._request.method, "new": method}
            self._request.method = method
        if url is not None:
            differences["url"] = {"old": self._request.url, "new": url}
            self._request.url = httpx.URL(url)
        if headers is not None:
            # NOTE(review): _update_headers is defined elsewhere; presumably it applies
            # the headers and returns a description of what changed — confirm.
            header_diffs = _update_headers(self._request, headers, preserve_existing_headers)
            differences["headers"] = header_diffs
        if content is not None:
            differences["content"] = {"old": self._request.content, "new": content}
            # NOTE(review): writes httpx's private _content attribute directly; whether the
            # request stream and Content-Length header stay in sync is not visible here — confirm.
            self._request._content = content

    self.events.mutation.dispatch(
        MutationEventPayload(
            transaction_id=self._transaction_id,
            operation="set_request",
            details=differences,
        )
    )
    return self._request

Create or set the request.

update_response(status_code=None, content=None, headers=None, from_scratch=False, preserve_existing_headers=True)
Source code in luthien_control/core/tracked_context/tracked_context.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def update_response(
    self,
    status_code: Optional[int] = None,
    content: Optional[bytes] = None,
    headers: Optional[Dict[str, str]] = None,
    from_scratch: bool = False,
    preserve_existing_headers: bool = True,
) -> httpx.Response:
    """Update the response.

    A new httpx.Response is built when ``from_scratch`` is true or no response
    is tracked yet; otherwise only the arguments that are not None are applied
    to the existing response. A mutation event with operation "set_response"
    is dispatched in both cases.

    Args:
        status_code: HTTP status code; required (non-zero) when a new response
            is built.
        content: Response body bytes.
        headers: Headers for the new response, or headers to apply to the
            existing one.
        from_scratch: Build a new response even if one is already tracked.
        preserve_existing_headers: Forwarded to ``_update_headers`` when
            updating an existing response.

    Returns:
        The tracked response object itself (not a copy).

    Raises:
        ValueError: If a new response must be built and status_code is missing.
    """
    differences = {}
    if from_scratch or self._response is None:
        if not status_code:
            raise ValueError("Attempted to create new response, but status_code is required")
        status_code = int(status_code)
        self._response = httpx.Response(
            status_code=status_code,
            headers=headers,
            content=content,
        )
        # A brand-new response has no previous state, so every "old" value is None.
        differences = {
            k: {"old": None, "new": getattr(self._response, k)} for k in ["status_code", "headers", "content"]
        }
    else:
        # In each branch the old value is recorded before the response is mutated.
        if status_code is not None:
            differences["status_code"] = {"old": self._response.status_code, "new": status_code}
            self._response.status_code = status_code
        if headers is not None:
            # NOTE(review): _update_headers is defined elsewhere; presumably it applies
            # the headers and returns a description of what changed — confirm.
            differences["headers"] = _update_headers(self._response, headers, preserve_existing_headers)
        if content is not None:
            differences["content"] = {"old": self._response.content, "new": content}
            # NOTE(review): writes httpx's private _content attribute directly; whether derived
            # state (decoded text, Content-Length header) stays in sync is not visible here — confirm.
            self._response._content = content

    self.events.mutation.dispatch(
        MutationEventPayload(
            transaction_id=self._transaction_id,
            operation="set_response",
            details=differences,
        )
    )
    return self._response

Update the response.

get_tx_value(tracked_context, path)

Source code in luthien_control/core/tracked_context/util.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def get_tx_value(tracked_context: TrackedContext, path: str) -> Any:
    """Get a value from the tracked context using a dotted path.

    The first path component selects the root ("request", "response" or "data").
    Every following component is looked up as a dict key when the current value
    is a dict, and as an attribute otherwise. A ``bytes`` value that still has
    components to resolve is first parsed as JSON.

    Args:
        tracked_context: The tracked context.
        path: The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

    Returns:
        The value at the path.

    Raises:
        ValueError: If the path has fewer than two components, the first component
            is not "request", "response" or "data", the selected request/response
            is None, or a ``bytes`` value on the path cannot be decoded as JSON.
        KeyError: If a dict on the path has no entry for a component.
        AttributeError: If a non-dict object on the path has no such attribute.
    """
    root, *segments = path.split(".")
    if not segments:
        raise ValueError("Path must contain at least two components")

    # The request/response properties return a fresh copy on every access,
    # so each one is read exactly once.
    if root == "request":
        current: Any = tracked_context.request
        if current is None:
            raise ValueError("Request is None in tracked context")
    elif root == "response":
        current = tracked_context.response
        if current is None:
            raise ValueError("Response is None in tracked context")
    elif root == "data":
        current = tracked_context.get_all_data()
    else:
        raise ValueError(f"Invalid path segment: {root}")

    for next_segment in segments:
        if isinstance(current, bytes):
            try:
                current = json.loads(current)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                # Both are ValueError subclasses; re-raise with the path for better diagnostics.
                raise ValueError(f"Failed to decode JSON content for path '{path}' at segment '{next_segment}'") from e

        current = current[next_segment] if isinstance(current, dict) else getattr(current, next_segment)
    return current

Get a value from the tracked context using a path.

Parameters:

Name Type Description Default
tracked_context TrackedContext

The tracked context.

required
path str

The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

required

Returns:

Type Description
Any

The value at the path.

Raises:

Type Description
ValueError

If the path is invalid or the value cannot be accessed.

tracked_context

TrackedContext with explicit mutation API and event tracking.

MutationEventPayload dataclass
Source code in luthien_control/core/tracked_context/tracked_context.py
30
31
32
33
34
35
36
@dataclass
class MutationEventPayload:
    """Record of an explicit mutation.

    Attributes:
        transaction_id: Identifier of the transaction the mutation belongs to.
        operation: Name of the mutation that was performed.
        details: Operation-specific description of what changed.
    """

    transaction_id: Optional[uuid.UUID]
    operation: str  # e.g., "set_request", "set_response", "set_data"
    details: Dict[str, Any]

Record of an explicit mutation.

TrackedContext
Source code in luthien_control/core/tracked_context/tracked_context.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class TrackedContext:
    """Transaction context with explicit mutation API and event tracking.

    The request, response and data store are only changed through the
    ``update_request``, ``update_response`` and ``set_data`` methods; each of
    them dispatches a ``MutationEventPayload`` on ``self.events.mutation``.
    The read accessors hand out copies.
    """

    def __init__(self, transaction_id: Optional[uuid.UUID] = None):
        """Initialize tracked context, generating a transaction ID when none is given."""
        self._transaction_id = transaction_id if transaction_id else uuid.uuid4()
        self._request: Optional[httpx.Request] = None
        self._response: Optional[httpx.Response] = None
        self._data: Dict[str, Any] = {}
        self.events = ContextEvents()

    @property
    def transaction_id(self) -> uuid.UUID:
        """Get transaction ID."""
        return self._transaction_id

    def update_request(
        self,
        method: Optional[str] = None,
        url: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        content: Optional[bytes] = None,
        from_scratch: bool = False,
        preserve_existing_headers: bool = True,
    ) -> httpx.Request:
        """Create the tracked request or patch the existing one, then emit a mutation event.

        Raises:
            ValueError: If a request has to be created and ``method`` or ``url`` is missing.
        """
        changes = {}
        current = self._request
        if from_scratch or current is None:
            if not (method and url):
                raise ValueError("Attempted to create new request, but method and url are required")
            created = httpx.Request(method=str(method), url=str(url), headers=headers, content=content)
            self._request = created
            # A freshly built request is reported with "old" set to None for every field.
            for attr in ("method", "url", "headers", "content"):
                changes[attr] = {"old": None, "new": getattr(created, attr)}
        else:
            # Patch in place; each old value is captured before it is overwritten.
            if method is not None:
                changes["method"] = {"old": current.method, "new": method}
                current.method = method
            if url is not None:
                changes["url"] = {"old": current.url, "new": url}
                current.url = httpx.URL(url)
            if headers is not None:
                changes["headers"] = _update_headers(current, headers, preserve_existing_headers)
            if content is not None:
                changes["content"] = {"old": current.content, "new": content}
                # NOTE(review): writes httpx's private ``_content`` directly; headers are not touched here.
                current._content = content

        payload = MutationEventPayload(
            transaction_id=self._transaction_id,
            operation="set_request",
            details=changes,
        )
        self.events.mutation.dispatch(payload)
        return self._request

    @property
    def request(self) -> Optional[httpx.Request]:
        """Get a copy of the tracked request."""
        return copy(self._request)

    def update_response(
        self,
        status_code: Optional[int] = None,
        content: Optional[bytes] = None,
        headers: Optional[Dict[str, str]] = None,
        from_scratch: bool = False,
        preserve_existing_headers: bool = True,
    ) -> httpx.Response:
        """Create the tracked response or patch the existing one, then emit a mutation event.

        Raises:
            ValueError: If a response has to be created and ``status_code`` is missing or falsy.
        """
        changes = {}
        current = self._response
        if from_scratch or current is None:
            if not status_code:
                raise ValueError("Attempted to create new response, but status_code is required")
            created = httpx.Response(status_code=int(status_code), headers=headers, content=content)
            self._response = created
            for attr in ("status_code", "headers", "content"):
                changes[attr] = {"old": None, "new": getattr(created, attr)}
        else:
            if status_code is not None:
                changes["status_code"] = {"old": current.status_code, "new": status_code}
                current.status_code = status_code
            if headers is not None:
                changes["headers"] = _update_headers(current, headers, preserve_existing_headers)
            if content is not None:
                changes["content"] = {"old": current.content, "new": content}
                current._content = content

        payload = MutationEventPayload(
            transaction_id=self._transaction_id,
            operation="set_response",
            details=changes,
        )
        self.events.mutation.dispatch(payload)
        return self._response

    @property
    def response(self) -> Optional[httpx.Response]:
        """Get a copy of the tracked response."""
        return copy(self._response)

    def get_data(self, key: str, default: Any = None) -> Any:
        """Get data value, or ``default`` when the key is absent."""
        store = self._data
        if key in store:
            return store[key]
        return default

    def set_data(self, key: str, value: Any) -> None:
        """Set data value and dispatch a "set_data" mutation event."""
        previous = self._data.get(key)
        self._data[key] = value
        payload = MutationEventPayload(
            transaction_id=self._transaction_id,
            operation="set_data",
            details={"key": key, "old_value": previous, "new_value": value},
        )
        self.events.mutation.dispatch(payload)

    def get_all_data(self) -> Dict[str, Any]:
        """Get copy of all data."""
        return dict(self._data)

Transaction context with explicit mutation API and event tracking.

property

Get a copy of the tracked request.

request property

Get a copy of the tracked request.

property

Get a copy of the tracked response.

response property

Get a copy of the tracked response.

property

Get transaction ID.

transaction_id property

Get transaction ID.

__init__(transaction_id=None)
Source code in luthien_control/core/tracked_context/tracked_context.py
46
47
48
49
50
51
52
def __init__(self, transaction_id: Optional[uuid.UUID] = None):
    """Initialize tracked context.

    Args:
        transaction_id: Identifier to use; a random UUID4 is generated when omitted.
    """
    # Request and response start out unset; the data store starts out empty.
    self._request: Optional[httpx.Request] = None
    self._response: Optional[httpx.Response] = None
    self._data: Dict[str, Any] = {}
    self._transaction_id = transaction_id if transaction_id else uuid.uuid4()
    self.events = ContextEvents()

Initialize tracked context.

get_all_data()
Source code in luthien_control/core/tracked_context/tracked_context.py
169
170
171
def get_all_data(self) -> Dict[str, Any]:
    """Return a shallow copy of the whole data store."""
    return dict(self._data)

Get copy of all data.

get_data(key, default=None)
Source code in luthien_control/core/tracked_context/tracked_context.py
153
154
155
def get_data(self, key: str, default: Any = None) -> Any:
    """Return the stored value for ``key``, or ``default`` when the key is absent."""
    store = self._data
    if key in store:
        return store[key]
    return default

Get data value.

set_data(key, value)
Source code in luthien_control/core/tracked_context/tracked_context.py
157
158
159
160
161
162
163
164
165
166
167
def set_data(self, key: str, value: Any) -> None:
    """Store ``value`` under ``key`` and dispatch a "set_data" mutation event.

    The event details carry the key, the previous value (None when the key
    was not present) and the new value.
    """
    previous = self._data.get(key)
    self._data[key] = value
    payload = MutationEventPayload(
        transaction_id=self._transaction_id,
        operation="set_data",
        details={"key": key, "old_value": previous, "new_value": value},
    )
    self.events.mutation.dispatch(payload)

Set data value.

update_request(method=None, url=None, headers=None, content=None, from_scratch=False, preserve_existing_headers=True)
Source code in luthien_control/core/tracked_context/tracked_context.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def update_request(
    self,
    method: Optional[str] = None,
    url: Optional[str] = None,
    headers: Optional[Dict[str, str]] = None,
    content: Optional[bytes] = None,
    from_scratch: bool = False,
    preserve_existing_headers: bool = True,
) -> httpx.Request:
    """Create the tracked request or patch the existing one, then emit a mutation event.

    Args:
        method: HTTP method. Required when a request is being created.
        url: Target URL. Required when a request is being created.
        headers: Headers for the new request, or header changes for the existing one.
        content: New body bytes.
        from_scratch: When True, build a brand-new request even if one already exists.
        preserve_existing_headers: Forwarded to ``_update_headers`` when patching headers.

    Returns:
        The tracked request object (not a copy).

    Raises:
        ValueError: If a request has to be created and ``method`` or ``url`` is missing.
    """
    changes = {}
    current = self._request
    if from_scratch or current is None:
        if not (method and url):
            raise ValueError("Attempted to create new request, but method and url are required")
        created = httpx.Request(method=str(method), url=str(url), headers=headers, content=content)
        self._request = created
        # A freshly built request is reported with "old" set to None for every field.
        for attr in ("method", "url", "headers", "content"):
            changes[attr] = {"old": None, "new": getattr(created, attr)}
    else:
        # Patch in place; each old value is captured before it is overwritten.
        if method is not None:
            changes["method"] = {"old": current.method, "new": method}
            current.method = method
        if url is not None:
            changes["url"] = {"old": current.url, "new": url}
            current.url = httpx.URL(url)
        if headers is not None:
            changes["headers"] = _update_headers(current, headers, preserve_existing_headers)
        if content is not None:
            changes["content"] = {"old": current.content, "new": content}
            # NOTE(review): writes httpx's private ``_content`` directly; headers are not touched here.
            current._content = content

    payload = MutationEventPayload(
        transaction_id=self._transaction_id,
        operation="set_request",
        details=changes,
    )
    self.events.mutation.dispatch(payload)
    return self._request

Create or set the request.

update_response(status_code=None, content=None, headers=None, from_scratch=False, preserve_existing_headers=True)
Source code in luthien_control/core/tracked_context/tracked_context.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def update_response(
    self,
    status_code: Optional[int] = None,
    content: Optional[bytes] = None,
    headers: Optional[Dict[str, str]] = None,
    from_scratch: bool = False,
    preserve_existing_headers: bool = True,
) -> httpx.Response:
    """Create the tracked response or patch the existing one, then emit a mutation event.

    Args:
        status_code: New status code. Required (and truthy) when a response is being created.
        content: New body bytes.
        headers: Headers for the new response, or header changes for the existing one.
        from_scratch: When True, build a brand-new response even if one already exists.
        preserve_existing_headers: Forwarded to ``_update_headers`` when patching headers.

    Returns:
        The tracked response object (not a copy).

    Raises:
        ValueError: If a response has to be created and ``status_code`` is missing or falsy.
    """
    changes = {}
    current = self._response
    if from_scratch or current is None:
        if not status_code:
            raise ValueError("Attempted to create new response, but status_code is required")
        created = httpx.Response(status_code=int(status_code), headers=headers, content=content)
        self._response = created
        # A freshly built response is reported with "old" set to None for every field.
        for attr in ("status_code", "headers", "content"):
            changes[attr] = {"old": None, "new": getattr(created, attr)}
    else:
        # Patch in place; each old value is captured before it is overwritten.
        if status_code is not None:
            changes["status_code"] = {"old": current.status_code, "new": status_code}
            current.status_code = status_code
        if headers is not None:
            changes["headers"] = _update_headers(current, headers, preserve_existing_headers)
        if content is not None:
            changes["content"] = {"old": current.content, "new": content}
            # NOTE(review): writes httpx's private ``_content`` directly; headers are not touched here.
            current._content = content

    payload = MutationEventPayload(
        transaction_id=self._transaction_id,
        operation="set_response",
        details=changes,
    )
    self.events.mutation.dispatch(payload)
    return self._response

Update the response.

util

Utilities for working with TrackedContext.

get_tx_value(tracked_context, path)
Source code in luthien_control/core/tracked_context/util.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def get_tx_value(tracked_context: TrackedContext, path: str) -> Any:
    """Get a value from the tracked context using a dotted path.

    The first path component selects the root ("request", "response" or "data").
    Every following component is looked up as a dict key when the current value
    is a dict, and as an attribute otherwise. A ``bytes`` value that still has
    components to resolve is first parsed as JSON.

    Args:
        tracked_context: The tracked context.
        path: The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

    Returns:
        The value at the path.

    Raises:
        ValueError: If the path has fewer than two components, the first component
            is not "request", "response" or "data", the selected request/response
            is None, or a ``bytes`` value on the path cannot be decoded as JSON.
        KeyError: If a dict on the path has no entry for a component.
        AttributeError: If a non-dict object on the path has no such attribute.
    """
    root, *segments = path.split(".")
    if not segments:
        raise ValueError("Path must contain at least two components")

    # The request/response properties return a fresh copy on every access,
    # so each one is read exactly once.
    if root == "request":
        current: Any = tracked_context.request
        if current is None:
            raise ValueError("Request is None in tracked context")
    elif root == "response":
        current = tracked_context.response
        if current is None:
            raise ValueError("Response is None in tracked context")
    elif root == "data":
        current = tracked_context.get_all_data()
    else:
        raise ValueError(f"Invalid path segment: {root}")

    for next_segment in segments:
        if isinstance(current, bytes):
            try:
                current = json.loads(current)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                # Both are ValueError subclasses; re-raise with the path for better diagnostics.
                raise ValueError(f"Failed to decode JSON content for path '{path}' at segment '{next_segment}'") from e

        current = current[next_segment] if isinstance(current, dict) else getattr(current, next_segment)
    return current

Get a value from the tracked context using a path.

Parameters:

Name Type Description Default
tracked_context TrackedContext

The tracked context.

required
path str

The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

required

Returns:

Type Description
Any

The value at the path.

Raises:

Type Description
ValueError

If the path is invalid or the value cannot be accessed.

transaction

Transaction

Bases: DeepEventedModel

Source code in luthien_control/core/transaction.py
12
13
14
15
16
17
18
class Transaction(DeepEventedModel):
    """A transaction between the Luthien Control API and the client.

    Attributes:
        transaction_id: Identifier of the transaction; a fresh ``uuid4`` by default.
        request: The request belonging to this transaction (no default is declared).
        response: The response belonging to this transaction (no default is declared).
        data: Free-form key/value store; a new ``EventedDict`` by default.
    """

    transaction_id: UUID = Field(default_factory=uuid4)
    request: Request = Field()
    response: Response = Field()
    # default_factory gives every instance its own EventedDict.
    data: EventedDict[str, Any] = Field(default_factory=EventedDict)

A transaction between the Luthien Control API and the client.

transaction_context

TransactionContext dataclass

Source code in luthien_control/core/transaction_context.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
@dataclass
class TransactionContext:
    """Holds the state for a single transaction through the proxy.

    Attributes:
        transaction_id: A unique identifier for the transaction; a fresh
            ``uuid4`` is generated when none is supplied.
        request: The incoming HTTP request object, or None until one is set.
        response: The outgoing HTTP response object, or None until one is set.
        data: A general-purpose dictionary for policies to store and share
            information related to this transaction.
    """

    # Core Identifiers and State
    transaction_id: uuid.UUID = field(default_factory=uuid.uuid4)
    request: Optional[Request] = None
    response: Optional[Response] = None

    # General purpose data store for policies to share information
    # (default_factory gives every instance its own dict).
    data: Dict[str, Any] = field(default_factory=dict)

Holds the state for a single transaction through the proxy.

Attributes:

Name Type Description
transaction_id UUID

A unique identifier for the transaction.

request Optional[Request]

The incoming HTTP request object.

response Optional[Response]

The outgoing HTTP response object.

data Dict[str, Any]

A general-purpose dictionary for policies to store and share information related to this transaction.

get_tx_value(transaction_context, path)

Source code in luthien_control/core/transaction_context.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def get_tx_value(transaction_context: TransactionContext, path: str) -> Any:
    """Get a value from the transaction context using a dotted path.

    The first path component is read as an attribute of the context. Every
    following component is looked up as a dict key when the current value is a
    dict, and as an attribute otherwise. A ``bytes`` value that still has
    components to resolve is first parsed as JSON.

    Args:
        transaction_context: The transaction context.
        path: The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

    Returns:
        The value at the path.

    Raises:
        ValueError: If the path has fewer than two components, or a ``bytes``
            value on the path is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
        KeyError: If a dict on the path has no entry for a component.
        AttributeError: If an object on the path has no such attribute.
    """
    root, *segments = path.split(".")
    if not segments:
        raise ValueError("Path must contain at least two components")

    current: Any = getattr(transaction_context, root)
    for segment in segments:
        # Bytes with path components still to resolve are treated as JSON content.
        if isinstance(current, bytes):
            current = json.loads(current)

        current = current[segment] if isinstance(current, dict) else getattr(current, segment)
    return current

Get a value from the transaction context using a path.

Parameters:

Name Type Description Default
transaction_context TransactionContext

The transaction context.

required
path str

The path to the value e.g. "request.headers.user-agent", "response.status_code", "data.user_id".

required

Returns:

Type Description
Any

The value at the path.

Raises:

Type Description
ValueError

If the path is invalid or the value cannot be accessed.

TypeError

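If the transaction_id is not a UUID (as stated in the docstring; note that the implementation listed above contains no explicit TypeError raise).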

custom_openapi_schema

create_custom_openapi(app)

Source code in luthien_control/custom_openapi_schema.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def create_custom_openapi(app: FastAPI):
    """
    Generate a custom OpenAPI schema for the FastAPI application.

    This function retrieves the default schema and modifies it, specifically
    to set `allowReserved=True` for the `full_path` path parameter used
    in proxy routes. This is necessary for correctly handling URLs containing
    reserved characters within that path segment.

    The result is cached on ``app.openapi_schema``; later calls return the
    cached schema unchanged.

    Args:
        app: The FastAPI application instance.

    Returns:
        The modified OpenAPI schema dictionary.
    """
    # Check if schema already exists to avoid redundant generation
    if app.openapi_schema:
        return app.openapi_schema

    logger.debug("Generating custom OpenAPI schema.")
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Modify the schema for the path parameter
    paths = openapi_schema.get("paths", {})
    logger.debug(f"Found {len(paths)} paths in schema. Searching for '{{full_path}}'.")
    for path_key, path_item in paths.items():
        if "{full_path}" not in path_key:
            continue

        logger.debug(f"Processing path: {path_key}")
        # path_item contains methods like 'get', 'post', etc.
        for method, method_item in path_item.items():
            # A path item may also hold non-operation entries (e.g. 'summary',
            # path-level 'parameters') whose values are not dicts; calling
            # .get() on those would raise AttributeError, so skip them.
            if not isinstance(method_item, dict):
                logger.debug(f"Skipping non-operation entry in {path_key} -> {method}")
                continue

            # Ensure 'parameters' exists and is a list
            parameters = method_item.get("parameters", [])
            if not isinstance(parameters, list):
                logger.warning(f"Unexpected 'parameters' format in {path_key} -> {method}. Skipping.")
                continue

            found_param = False
            for param in parameters:
                # Ensure param is a dictionary and has 'name' and 'in' keys
                if not isinstance(param, dict) or "name" not in param or "in" not in param:
                    logger.warning(
                        f"Malformed parameter definition in {path_key} -> {method}. Skipping param: {param}"
                    )
                    continue

                if param["name"] == "full_path" and param["in"] == "path":
                    param["allowReserved"] = True
                    found_param = True
                    logger.info(f"Set allowReserved=true for 'full_path' parameter in {path_key} -> {method}")
                    # Assuming only one 'full_path' param per method
                    break  # No need to check other params for this method
            if not found_param:
                logger.debug(f"No 'full_path' path parameter found in {path_key} -> {method}")

    # Cache the generated schema in the app instance
    app.openapi_schema = openapi_schema
    logger.debug("Custom OpenAPI schema generation complete.")
    return app.openapi_schema

Generate a custom OpenAPI schema for the FastAPI application.

This function retrieves the default schema and modifies it, specifically to set allowReserved=True for the full_path path parameter used in proxy routes. This is necessary for correctly handling URLs containing reserved characters within that path segment.

Parameters:

Name Type Description Default
app FastAPI

The FastAPI application instance.

required

Returns:

Type Description

The modified OpenAPI schema dictionary.

db

Database models and session management.

ControlPolicy

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class ControlPolicy(SQLModel, table=True):
    """Database model for storing control policy configurations.

    Attributes:
        id: Primary key; None until a value is assigned.
        name: Unique name used for lookup.
        type: Type of policy, used for instantiation.
        config: Policy configuration stored in a JSON column.
        is_active: Whether the policy is active (defaults to True).
        description: Optional free-text description.
        created_at: Creation timestamp, UTC with the tzinfo stripped.
        updated_at: Last-update timestamp, UTC with the tzinfo stripped.
    """

    # The docstring has to be the first statement of the class body; placed
    # after ``__tablename__`` it is just a string expression and the class
    # ends up without a ``__doc__``.
    __tablename__ = "policies"  # type: ignore (again, shut up pyright)

    # Primary key
    id: Optional[int] = Field(default=None, primary_key=True)

    # --- Core Fields ---
    name: str = Field(index=True, unique=True)  # Unique name used for lookup
    type: str = Field()  # Type of policy, used for instantiation
    config: dict[str, Any] = Field(default={}, sa_column=Column(JSON))
    is_active: bool = Field(default=True, index=True)
    description: Optional[str] = Field(default=None)

    # --- Timestamps ---
    # Generated in UTC, then stored without tzinfo.
    created_at: dt.datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=False
    )
    updated_at: dt.datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=False
    )

    def __init__(self, **data: Any):
        """Create a policy, filling in any timestamp the caller did not provide."""
        # Ensure timestamps are set on creation if not provided
        if "created_at" not in data:
            data["created_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        if "updated_at" not in data:
            data["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        super().__init__(**data)

    @model_validator(mode="before")
    @classmethod
    def validate_timestamps(cls, values):
        """Ensure updated_at is always set/updated.

        When the incoming ``values`` is a dict, its ``updated_at`` entry is
        overwritten in place with the current UTC time (tzinfo stripped).
        Any other input is returned untouched.
        """
        if isinstance(values, dict):
            values["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        return values

class-attribute instance-attribute

Database model for storing control policy configurations.

__tablename__ = 'policies' class-attribute instance-attribute

Database model for storing control policy configurations.

validate_timestamps(values) classmethod

Source code in luthien_control/db/sqlmodel_models.py
78
79
80
81
82
83
84
@model_validator(mode="before")
@classmethod
def validate_timestamps(cls, values):
    """Refresh ``updated_at`` before validation.

    A dict input gets its ``updated_at`` entry overwritten in place with the
    current UTC time (tzinfo stripped). Any other input is returned untouched.
    """
    if not isinstance(values, dict):
        return values
    now_utc = datetime.now(timezone.utc)
    values["updated_at"] = now_utc.replace(tzinfo=None)
    return values

Ensure updated_at is always set/updated.

LuthienLog

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class LuthienLog(SQLModel, table=True):
    """
    Represents a log entry in the Luthien logging system using SQLModel.

    Attributes:
        id: Unique identifier for the log entry (primary key).
        transaction_id: Identifier to group related log entries.
        datetime: Timestamp indicating when the log entry was generated.
            NOTE(review): the field is typed ``NaiveDatetime``, which suggests
            timezone-naive values rather than timezone-aware ones - confirm.
        data: JSON blob containing the primary logged data.
        datatype: String identifier for the nature and schema of 'data'.
        notes: JSON blob for additional contextual information.
    """

    __tablename__ = "luthien_log"  # type: ignore (shut up pyright)

    id: Optional[int] = Field(default=None, primary_key=True, index=True)
    transaction_id: str = Field(index=True, nullable=False)
    # Defaults to NaiveDatetime.now() at construction time.
    datetime: NaiveDatetime = Field(
        default_factory=NaiveDatetime.now,
        nullable=False,
        index=True,
    )
    data: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JsonBOrJson))
    datatype: str = Field(index=True, nullable=False)
    notes: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JsonBOrJson))

    def __init__(self, **data: Any) -> None:
        """Override init to ensure datetime is converted to NaiveDatetime.

        Only a ``datetime`` keyword that is a plain ``datetime`` instance (and
        not already a ``NaiveDatetime``) is wrapped; every other keyword is
        passed through unchanged.
        """
        if "datetime" in data:
            dt_value = data["datetime"]
            if isinstance(dt_value, datetime) and not isinstance(dt_value, NaiveDatetime):
                data["datetime"] = NaiveDatetime(dt_value)
        super().__init__(**data)

    # Explicit index definitions are currently disabled; the columns above
    # declare index=True on the same three fields instead.
    # __table_args__ = (
    #     Index("ix_sqlmodel_luthien_log_transaction_id", "transaction_id"),
    #     Index("ix_sqlmodel_luthien_log_datetime", "datetime"),
    #     Index("ix_sqlmodel_luthien_log_datatype", "datatype"),
    #     {"extend_existing": True},
    # )

    # __repr__ is not automatically generated by SQLModel like Pydantic models,
    # but you can add one if desired.
    def __repr__(self) -> str:
        """Return a compact summary; the ``data`` and ``notes`` payloads are left out."""
        return (
            f"<LuthienLog(id={self.id}, "
            f"transaction_id='{self.transaction_id}', "
            f"datetime='{self.datetime}', "
            f"datatype='{self.datatype}')>"
        )

Represents a log entry in the Luthien logging system using SQLModel.

Attributes:

Name Type Description
id Optional[int]

Unique identifier for the log entry (primary key).

transaction_id str

Identifier to group related log entries.

datetime NaiveDatetime

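Timestamp indicating when the log entry was generated. The docstring describes it as timezone-aware, but the field is typed NaiveDatetime, so the stored value is presumably timezone-naive.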

data Optional[dict[str, Any]]

JSON blob containing the primary logged data.

datatype str

String identifier for the nature and schema of 'data'.

notes Optional[dict[str, Any]]

JSON blob for additional contextual information.

__init__(**data)

Source code in luthien_control/db/sqlmodel_models.py
113
114
115
116
117
118
119
def __init__(self, **data: Any) -> None:
    """Wrap a plain ``datetime`` keyword in ``NaiveDatetime`` before normal initialisation.

    A missing ``datetime`` keyword, a non-datetime value, or a value that is
    already a ``NaiveDatetime`` is passed through unchanged.
    """
    supplied = data.get("datetime")
    needs_wrapping = isinstance(supplied, datetime) and not isinstance(supplied, NaiveDatetime)
    if needs_wrapping:
        data["datetime"] = NaiveDatetime(supplied)
    super().__init__(**data)

Override init to ensure datetime is converted to NaiveDatetime.

client_api_key_crud

create_api_key(session, api_key) async

Source code in luthien_control/db/client_api_key_crud.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
async def create_api_key(session: AsyncSession, api_key: ClientApiKey) -> ClientApiKey:
    """Create a new API key in the database.

    The key is added to the session, committed, and refreshed before it is
    returned. On any failure the session is rolled back and the error is
    re-raised as one of the Luthien DB exceptions listed below, chained to
    the original exception.

    Args:
        session: The database session
        api_key: The API key to create

    Returns:
        The created API key with updated ID

    Raises:
        LuthienDBIntegrityError: If a constraint violation occurs
        LuthienDBTransactionError: If the transaction fails
        LuthienDBOperationError: For other database errors
    """
    try:
        session.add(api_key)
        await session.commit()
        # Reload the row after the commit so the returned object reflects the stored state.
        await session.refresh(api_key)
        logger.info(f"Successfully created API key with ID: {api_key.id}")
        return api_key
    # Handled ahead of the generic SQLAlchemyError clause so that constraint
    # violations get their own exception type.
    except IntegrityError as ie:
        await session.rollback()
        logger.error(f"Integrity error creating API key: {ie}")
        raise LuthienDBIntegrityError(f"Could not create API key due to constraint violation: {ie}", ie) from ie
    except SQLAlchemyError as sqla_err:
        await session.rollback()
        logger.error(f"SQLAlchemy error creating API key: {sqla_err}")
        raise LuthienDBTransactionError(f"Database transaction failed while creating API key: {sqla_err}") from sqla_err
    # Catch-all for anything that is not a SQLAlchemy error.
    except Exception as e:
        await session.rollback()
        logger.error(f"Unexpected error creating API key: {e}")
        raise LuthienDBOperationError(f"Unexpected error during API key creation: {e}") from e

Create a new API key in the database.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
api_key ClientApiKey

The API key to create

required

Returns:

Type Description
ClientApiKey

The created API key with updated ID

Raises:

Type Description
LuthienDBIntegrityError

If a constraint violation occurs

LuthienDBTransactionError

If the transaction fails

LuthienDBOperationError

For other database errors

get_api_key_by_value(session, key_value) async

Source code in luthien_control/db/client_api_key_crud.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
async def get_api_key_by_value(session: AsyncSession, key_value: str) -> ClientApiKey:
    """Get an active API key by its value.

    Only rows whose ``key_value`` matches and whose ``is_active`` flag is set
    are considered.

    Args:
        session: The database session
        key_value: The value of the API key to retrieve

    Returns:
        The matching active API key

    Raises:
        LuthienDBQueryError: If no active key matches or if the query execution fails
        LuthienDBOperationError: For unexpected errors during lookup
    """
    try:
        stmt = select(ClientApiKey).where(
            ClientApiKey.key_value == key_value,  # type: ignore[arg-type]
            ClientApiKey.is_active,  # type: ignore[arg-type]
        )
        result = await session.execute(stmt)
        api_key = result.scalar_one_or_none()
    except SQLAlchemyError as sqla_err:
        logger.error(f"SQLAlchemy error fetching API key by value: {sqla_err}", exc_info=True)
        raise LuthienDBQueryError(f"Database query failed while fetching API key: {sqla_err}") from sqla_err
    except Exception as e:
        logger.error(f"Unexpected error fetching API key by value: {e}", exc_info=True)
        raise LuthienDBOperationError(f"Unexpected error during API key lookup: {e}") from e

    if not api_key:
        # The submitted value is a credential (possibly a deactivated or
        # mistyped real key), so it is deliberately kept out of the exception
        # text, which may end up in logs or error responses.
        raise LuthienDBQueryError("Active API key with the provided value not found")

    return api_key

Get an active API key by its value.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
key_value str

The value of the API key to retrieve

required

Returns:

Type Description
ClientApiKey

The API key

Raises:

Type Description
LuthienDBQueryError

If the API key is not found or if the query execution fails

LuthienDBOperationError

For unexpected errors during lookup

list_api_keys(session, active_only=False) async

Source code in luthien_control/db/client_api_key_crud.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
async def list_api_keys(session: AsyncSession, active_only: bool = False) -> List[ClientApiKey]:
    """Return the stored API keys, optionally restricted to active ones.

    Args:
        session: The database session
        active_only: When True, restrict the result to keys whose ``is_active`` flag is set

    Returns:
        A list with every matching API key

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For any other unexpected error
    """
    try:
        # Start from the unfiltered query and narrow it only when asked to.
        query = select(ClientApiKey)
        if active_only:
            query = query.where(ClientApiKey.is_active)  # type: ignore[arg-type]

        rows = await session.execute(query)
        return list(rows.scalars().all())
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error listing API keys: {db_err}")
        raise LuthienDBQueryError(f"Database query failed while listing API keys: {db_err}") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error listing API keys: {unexpected}")
        raise LuthienDBOperationError(f"Unexpected error during API key listing: {unexpected}") from unexpected

Get a list of all API keys.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
active_only bool

If True, only return active API keys

False

Returns:

Type Description
List[ClientApiKey]

A list of API keys

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

update_api_key(session, key_id, api_key_update) async

Source code in luthien_control/db/client_api_key_crud.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
async def update_api_key(session: AsyncSession, key_id: int, api_key_update: ClientApiKey) -> ClientApiKey:
    """Overwrite the mutable fields of an existing API key.

    The key is looked up by primary key, then its ``name``, ``is_active`` and
    ``metadata_`` are replaced with the values from ``api_key_update`` and the
    change is committed.

    Args:
        session: The database session
        key_id: The ID of the API key to update
        api_key_update: Object carrying the new field values

    Returns:
        The stored API key after commit and refresh

    Raises:
        LuthienDBQueryError: If no API key has the given ID
        LuthienDBIntegrityError: If a constraint violation occurs
        LuthienDBTransactionError: If the lookup or the commit raises a SQLAlchemy error
        LuthienDBOperationError: For other unexpected errors
    """
    # Phase 1: locate the row. Nothing has been modified yet, so no rollback is needed.
    try:
        lookup = select(ClientApiKey).where(ClientApiKey.id == key_id)  # type: ignore[arg-type]
        rows = await session.execute(lookup)
        existing_key = rows.scalar_one_or_none()
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error updating API key: {db_err}")
        raise LuthienDBTransactionError(f"Database transaction failed while updating API key: {db_err}") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error updating API key: {unexpected}")
        raise LuthienDBOperationError(f"Unexpected error during API key update: {unexpected}") from unexpected

    if not existing_key:
        raise LuthienDBQueryError(f"API key with ID {key_id} not found")

    # Phase 2: apply the new values and commit; every failure rolls back first.
    try:
        for field_name in ("name", "is_active", "metadata_"):
            setattr(existing_key, field_name, getattr(api_key_update, field_name))

        await session.commit()
        await session.refresh(existing_key)
        logger.info(f"Successfully updated API key with ID: {existing_key.id}")
        return existing_key
    except IntegrityError as integrity_err:
        await session.rollback()
        logger.error(f"Integrity error updating API key: {integrity_err}")
        detail = f"Could not update API key due to constraint violation: {integrity_err}"
        raise LuthienDBIntegrityError(detail, integrity_err) from integrity_err
    except SQLAlchemyError as db_err:
        await session.rollback()
        logger.error(f"SQLAlchemy error updating API key: {db_err}")
        raise LuthienDBTransactionError(f"Database transaction failed while updating API key: {db_err}") from db_err
    except Exception as unexpected:
        await session.rollback()
        logger.error(f"Unexpected error updating API key: {unexpected}")
        raise LuthienDBOperationError(f"Unexpected error during API key update: {unexpected}") from unexpected

Update an existing API key.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
key_id int

The ID of the API key to update

required
api_key_update ClientApiKey

The updated API key data

required

Returns:

Type Description
ClientApiKey

The updated API key

Raises:

Type Description
LuthienDBQueryError

If the API key is not found

LuthienDBIntegrityError

If a constraint violation occurs

LuthienDBTransactionError

If the transaction fails

LuthienDBOperationError

For other database errors

control_policy_crud

get_policy_by_name(session, name) async

Source code in luthien_control/db/control_policy_crud.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
async def get_policy_by_name(session: AsyncSession, name: str) -> DBControlPolicy:
    """Fetch the active policy that has the given name.

    Args:
        session: The database session
        name: The name of the policy to retrieve

    Returns:
        The matching policy whose ``is_active`` flag is set

    Raises:
        LuthienDBQueryError: If no such policy exists or if the query execution fails
        LuthienDBOperationError: For unexpected errors during lookup
    """
    try:
        query = select(DBControlPolicy).where(
            DBControlPolicy.name == name,  # type: ignore[arg-type]
            DBControlPolicy.is_active,  # type: ignore[arg-type]
        )
        rows = await session.execute(query)
        found = rows.scalar_one_or_none()
        if found:
            return found
        raise LuthienDBQueryError(f"Policy with name '{name}' not found")
    except LuthienDBQueryError:
        # Let the not-found error through untouched instead of re-wrapping it below.
        raise
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error fetching policy by name '{name}': {db_err}", exc_info=True)
        raise LuthienDBQueryError(f"Database query failed while fetching policy '{name}': {db_err}") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error fetching policy by name '{name}': {unexpected}", exc_info=True)
        raise LuthienDBOperationError(f"Unexpected error during policy lookup: {unexpected}") from unexpected

Get a policy by its name.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
name str

The name of the policy to retrieve

required

Returns:

Type Description
ControlPolicy

The policy

Raises:

Type Description
LuthienDBQueryError

If the policy is not found or if the query execution fails

LuthienDBOperationError

For unexpected errors during lookup

get_policy_config_by_name(session, name) async

Source code in luthien_control/db/control_policy_crud.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
async def get_policy_config_by_name(session: AsyncSession, name: str) -> DBControlPolicy:
    """Fetch a policy by name without filtering on its active status.

    Args:
        session: The database session
        name: The name of the policy to retrieve

    Returns:
        The matching policy, active or not

    Raises:
        LuthienDBQueryError: If no such policy exists or if the query execution fails
        LuthienDBOperationError: For unexpected errors during lookup
    """
    try:
        query = select(DBControlPolicy).where(DBControlPolicy.name == name)  # type: ignore[arg-type]
        rows = await session.execute(query)
        found = rows.scalar_one_or_none()
        if found:
            return found
        raise LuthienDBQueryError(f"Policy with name '{name}' not found")
    except LuthienDBQueryError:
        # Let the not-found error through untouched instead of re-wrapping it below.
        raise
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error fetching policy configuration by name '{name}': {db_err}", exc_info=True)
        raise LuthienDBQueryError(f"Database query failed while fetching policy config '{name}'") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error fetching policy configuration by name '{name}': {unexpected}", exc_info=True)
        raise LuthienDBOperationError(f"Unexpected error during policy config lookup: {unexpected}") from unexpected

Get a policy configuration by its name, regardless of its active status.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
name str

The name of the policy to retrieve

required

Returns:

Type Description
ControlPolicy

The policy

Raises:

Type Description
LuthienDBQueryError

If the policy is not found or if the query execution fails

LuthienDBOperationError

For unexpected errors during lookup

list_policies(session, active_only=False) async

Source code in luthien_control/db/control_policy_crud.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
async def list_policies(session: AsyncSession, active_only: bool = False) -> List[DBControlPolicy]:
    """Return the stored policies, optionally restricted to active ones.

    Args:
        session: The database session
        active_only: When True, restrict the result to policies whose ``is_active`` flag is set

    Returns:
        A list with every matching policy

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For any other unexpected error
    """
    try:
        # Start from the unfiltered query and narrow it only when asked to.
        query = select(DBControlPolicy)
        if active_only:
            query = query.where(DBControlPolicy.is_active)  # type: ignore[arg-type]

        rows = await session.execute(query)
        return list(rows.scalars().all())
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error listing policies: {db_err}")
        raise LuthienDBQueryError(f"Database query failed while listing policies: {db_err}") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error listing policies: {unexpected}")
        raise LuthienDBOperationError(f"Unexpected error during policy listing: {unexpected}") from unexpected

Get a list of all policies.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
active_only bool

If True, only return active policies

False

Returns:

Type Description
List[ControlPolicy]

A list of policies

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

load_policy_from_db(name, container) async

Source code in luthien_control/db/control_policy_crud.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
async def load_policy_from_db(
    name: str,
    container: "DependencyContainer",
) -> "ABCControlPolicy":
    """Look up an active policy record by name and instantiate it via ``load_policy``.

    A session is opened from ``container.db_session_factory`` only for the
    lookup. The record's ``type`` and ``config`` are then wrapped in a
    ``SerializedPolicy`` and handed to the policy loader.

    Args:
        name: The name of the policy to load
        container: The dependency container providing access to the database session

    Returns:
        The instantiated policy

    Raises:
        LuthienDBQueryError: If the database query fails or policy is not found
        LuthienDBOperationError: If the policy cannot be instantiated or other database operation errors occur
    """
    try:
        async with container.db_session_factory() as session:
            db_policy = await get_policy_by_name(session, name)

        # The loader uses ``type`` to find the class; a missing config becomes an empty dict.
        serialized = SerializedPolicy(type=db_policy.type, config=db_policy.config or {})

        try:
            loaded = load_policy(serialized)
            logger.info(f"Successfully loaded and instantiated policy '{db_policy.name}' from database.")
            return loaded
        except PolicyLoadError as load_err:
            logger.error(f"Failed to load policy '{name}' from database: {load_err}")
            raise LuthienDBOperationError(
                f"Failed to instantiate policy '{name}' from database configuration: {load_err}"
            ) from load_err
        except Exception as unexpected:
            logger.exception(f"Unexpected error loading policy '{name}' from database: {unexpected}")
            raise LuthienDBOperationError(
                f"Unexpected error during policy instantiation for '{name}': {unexpected}"
            ) from unexpected
    except (LuthienDBQueryError, LuthienDBOperationError):
        # Already Luthien-specific errors: propagate without re-wrapping.
        raise
    except Exception as unexpected:
        logger.exception(f"Unexpected error during policy loading process for '{name}': {unexpected}")
        raise LuthienDBOperationError(
            f"Unexpected error during policy loading process for '{name}': {unexpected}"
        ) from unexpected

Load a policy configuration from the database and instantiate it using the control_policy loader.

Parameters:

Name Type Description Default
name str

The name of the policy to load

required
container DependencyContainer

The dependency container providing access to the database session

required

Returns:

Type Description
ControlPolicy

The instantiated policy

Raises:

Type Description
LuthienDBQueryError

If the database query fails or policy is not found

LuthienDBOperationError

If the policy cannot be instantiated or other database operation errors occur

save_policy_to_db(session, policy) async

Source code in luthien_control/db/control_policy_crud.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
async def save_policy_to_db(session: AsyncSession, policy: DBControlPolicy) -> DBControlPolicy:
    """Persist a new policy and return it refreshed from the database.

    The policy is added to the session, committed, and refreshed. On any
    failure the session is rolled back before a Luthien-specific exception is
    raised.

    Args:
        session: The database session
        policy: The policy instance to store

    Returns:
        The same ``policy`` instance after commit and refresh

    Raises:
        LuthienDBIntegrityError: When an ``IntegrityError`` (constraint violation) occurs
        LuthienDBTransactionError: When any other ``SQLAlchemyError`` occurs
        LuthienDBOperationError: When any other unexpected exception occurs
    """
    try:
        session.add(policy)
        await session.commit()
        await session.refresh(policy)
        logger.info(f"Successfully created policy with ID: {policy.id}")
        return policy
    except IntegrityError as integrity_err:
        # IntegrityError is a SQLAlchemyError subclass, so it must be handled first.
        await session.rollback()
        logger.error(f"Integrity error creating policy: {integrity_err}")
        detail = f"Could not create policy due to constraint violation: {integrity_err}"
        raise LuthienDBIntegrityError(detail, integrity_err) from integrity_err
    except SQLAlchemyError as db_err:
        await session.rollback()
        logger.error(f"SQLAlchemy error creating policy: {db_err}")
        detail = f"Database transaction failed while creating policy: {db_err}"
        raise LuthienDBTransactionError(detail) from db_err
    except Exception as unexpected:
        await session.rollback()
        logger.error(f"Error creating policy: {unexpected}")
        detail = f"Unexpected error during policy creation: {unexpected}"
        raise LuthienDBOperationError(detail) from unexpected

Create a new policy in the database.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
policy ControlPolicy

The policy to create

required

Returns:

Type Description
ControlPolicy

The created policy with updated ID

Raises:

Type Description
LuthienDBIntegrityError

If a constraint violation occurs

LuthienDBTransactionError

If the transaction fails

LuthienDBOperationError

For other database errors

update_policy(session, policy_id, policy_update) async

Source code in luthien_control/db/control_policy_crud.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
async def update_policy(session: AsyncSession, policy_id: int, policy_update: DBControlPolicy) -> DBControlPolicy:
    """Overwrite the mutable fields of an existing policy.

    The policy is looked up by primary key, then its ``name``, ``config``,
    ``is_active`` and ``description`` are replaced with the values from
    ``policy_update`` and the change is committed.

    Args:
        session: The database session
        policy_id: The ID of the policy to update
        policy_update: Object carrying the new field values

    Returns:
        The stored policy after commit and refresh

    Raises:
        LuthienDBQueryError: If no policy has the given ID
        LuthienDBIntegrityError: If a constraint violation occurs
        LuthienDBTransactionError: If the transaction fails
        LuthienDBOperationError: For other database errors
    """
    try:
        lookup = select(DBControlPolicy).where(DBControlPolicy.id == policy_id)  # type: ignore[arg-type]
        rows = await session.execute(lookup)
        existing = rows.scalar_one_or_none()

        if not existing:
            raise LuthienDBQueryError(f"Policy with ID {policy_id} not found")

        # Copy over the updatable fields, in the same order as they are declared here.
        for field_name in ("name", "config", "is_active", "description"):
            setattr(existing, field_name, getattr(policy_update, field_name))

        await session.commit()
        await session.refresh(existing)
        logger.info(f"Successfully updated policy with ID: {existing.id}")
        return existing
    except LuthienDBQueryError:
        # Not-found is raised before any modification, so it passes through without a rollback.
        raise
    except IntegrityError as integrity_err:
        await session.rollback()
        logger.error(f"Integrity error updating policy: {integrity_err}")
        detail = f"Could not update policy due to constraint violation: {integrity_err}"
        raise LuthienDBIntegrityError(detail, integrity_err) from integrity_err
    except SQLAlchemyError as db_err:
        await session.rollback()
        logger.error(f"SQLAlchemy error updating policy: {db_err}")
        raise LuthienDBTransactionError(f"Database transaction failed while updating policy: {db_err}") from db_err
    except Exception as unexpected:
        await session.rollback()
        logger.error(f"Error updating policy: {unexpected}")
        raise LuthienDBOperationError(f"Unexpected error during policy update: {unexpected}") from unexpected

Update an existing policy.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
policy_id int

The ID of the policy to update

required
policy_update ControlPolicy

The updated policy data

required

Returns:

Type Description
ControlPolicy

The updated policy

Raises:

Type Description
LuthienDBQueryError

If the policy is not found

LuthienDBIntegrityError

If a constraint violation occurs

LuthienDBTransactionError

If the transaction fails

LuthienDBOperationError

For other database errors

database_async

close_db_engine() async

Source code in luthien_control/db/database_async.py
116
117
118
119
120
121
122
123
124
125
126
127
128
async def close_db_engine() -> None:
    """Dispose of the module-level database engine, if one exists.

    Errors raised while disposing are logged and swallowed; the module-level
    engine reference is reset to ``None`` either way.
    """
    global _db_engine
    if not _db_engine:
        logger.info("Database engine was already None or not initialized during shutdown.")
        return

    try:
        await _db_engine.dispose()
        logger.info("Database engine closed successfully.")
    except Exception as dispose_err:
        logger.error(f"Error closing database engine: {dispose_err}", exc_info=True)
    finally:
        # Drop the reference even when dispose() failed.
        _db_engine = None

Closes the database engine.

create_db_engine() async

Source code in luthien_control/db/database_async.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
async def create_db_engine() -> AsyncEngine:
    """Creates the asyncpg engine for the application DB.

    If an engine has already been created it is returned as-is. Otherwise the
    engine and the module-level session factory are built from the configured
    database URL and pool sizes.

    Returns:
        The asyncpg engine for the application DB.

    Raises:
        LuthienDBConfigurationError: If the database configuration is invalid.
            Not raised directly here; presumably comes from ``_get_db_url()`` (confirm).
        LuthienDBConnectionError: If reading the pool settings or building the
            engine/session factory fails for any reason.
    """
    global _db_engine, _db_session_factory
    if _db_engine:
        logger.debug("Database engine already initialized.")
        return _db_engine

    logger.info("Attempting to create database engine...")

    # Resolved outside the try block so its own errors are not re-wrapped below.
    db_url = _get_db_url()

    try:
        # Get and validate pool sizes
        pool_min_size = settings.get_main_db_pool_min_size()
        pool_max_size = settings.get_main_db_pool_max_size()

        _db_engine = create_async_engine(
            db_url,
            echo=False,  # Set to True for debugging SQL queries
            pool_pre_ping=True,
            pool_size=pool_min_size,
            # The pool may grow from the minimum size up to the configured maximum.
            max_overflow=pool_max_size - pool_min_size,
        )

        _db_session_factory = async_sessionmaker(
            _db_engine,
            expire_on_commit=False,
            class_=AsyncSession,
        )

        logger.info("Database engine created successfully.")
        return _db_engine
    except Exception as e:
        # Mask the password so credentials never reach the error message.
        masked_url = _mask_password(db_url)
        # Chain explicitly so the root cause is reported as the direct cause.
        raise LuthienDBConnectionError(f"Failed to create database engine using URL ({masked_url}): {e}") from e

Creates the asyncpg engine for the application DB.

Returns:

Type Description
AsyncEngine

The asyncpg engine for the application DB.

Raises:

Type Description
LuthienDBConfigurationError

If the database configuration is invalid.

LuthienDBConnectionError

If the database connection fails.

get_db_session() async

Source code in luthien_control/db/database_async.py
131
132
133
134
135
136
137
138
139
140
141
142
@contextlib.asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a SQLAlchemy async session from the module-level session factory.

    Used as an async context manager. If the body raises, the session is
    rolled back and the exception is re-raised.

    Raises:
        RuntimeError: If the session factory has not been initialized.
    """
    factory = _db_session_factory
    if factory is None:
        raise RuntimeError("Database session factory has not been initialized")

    async with factory() as db_session:
        try:
            yield db_session
        except Exception:
            # Undo any pending work before letting the caller see the error.
            await db_session.rollback()
            raise

Get a SQLAlchemy async session for the database as a context manager.

exceptions

Database-specific exceptions for the Luthien Control project.

LuthienDBIntegrityError

Bases: LuthienDBOperationError

Source code in luthien_control/db/exceptions.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class LuthienDBIntegrityError(LuthienDBOperationError):
    """Exception raised when a database integrity constraint is violated.

    This exception wraps SQLAlchemy's IntegrityError and provides a more
    specific error type for the Luthien Control project. The wrapped error
    is kept on ``original_error`` so callers can inspect it.
    """

    def __init__(self, message: str, original_error: Optional[IntegrityError] = None):
        """Initialize the exception.

        Args:
            message: A descriptive error message, forwarded to the base exception
            original_error: The original IntegrityError that was raised, or None
                when no underlying error is available
        """
        super().__init__(message)
        # Preserved so handlers can reach the underlying SQLAlchemy error.
        self.original_error = original_error

Exception raised when a database integrity constraint is violated.

This exception wraps SQLAlchemy's IntegrityError and provides a more specific error type for the Luthien Control project.

__init__(message, original_error=None)
Source code in luthien_control/db/exceptions.py
45
46
47
48
49
50
51
52
53
def __init__(self, message: str, original_error: Optional[IntegrityError] = None):
    """Initialize the exception.

    Args:
        message: A descriptive error message, forwarded to the base exception
        original_error: The original IntegrityError that was raised, or None
            when no underlying error is available
    """
    super().__init__(message)
    # Preserved so handlers can reach the underlying SQLAlchemy error.
    self.original_error = original_error

Initialize the exception.

Parameters:

Name Type Description Default
message str

A descriptive error message

required
original_error Optional[IntegrityError]

The original IntegrityError that was raised

None

LuthienDBOperationError

Bases: LuthienDBException

Source code in luthien_control/db/exceptions.py
10
11
12
13
14
15
16
17
class LuthienDBOperationError(LuthienDBException):
    """Generic failure of a database operation.

    Raised when a database operation fails for any reason, and used as the
    parent class of the more specific operation errors (query, transaction,
    integrity).
    """

Base exception for database operation errors.

This exception is raised when a database operation fails for any reason. It serves as a base class for more specific database operation errors.

LuthienDBQueryError

Bases: LuthienDBOperationError

Source code in luthien_control/db/exceptions.py
20
21
22
23
24
25
26
class LuthienDBQueryError(LuthienDBOperationError):
    """Failure of a database query.

    Raised when a SELECT query fails to execute properly. The CRUD helpers
    also raise it when a lookup finds no matching row.
    """

Exception raised when a database query fails.

This exception is raised when a SELECT query fails to execute properly.

LuthienDBTransactionError

Bases: LuthienDBOperationError

Source code in luthien_control/db/exceptions.py
29
30
31
32
33
34
35
class LuthienDBTransactionError(LuthienDBOperationError):
    """Failure of a database transaction.

    Raised when a transaction step (commit, rollback) fails.
    """

Exception raised when a database transaction fails.

This exception is raised when a transaction (commit, rollback) fails.

luthien_log_crud

count_logs(session, transaction_id=None, datatype=None, start_datetime=None, end_datetime=None) async

Source code in luthien_control/db/luthien_log_crud.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
async def count_logs(
    session: AsyncSession,
    transaction_id: Optional[str] = None,
    datatype: Optional[str] = None,
    start_datetime: Optional[datetime] = None,
    end_datetime: Optional[datetime] = None,
) -> int:
    """Count logs with optional filtering.

    The count is computed by the database with ``COUNT(*)``, so no log rows
    or ids are transferred to the application.

    Args:
        session: The database session
        transaction_id: Optional filter by transaction ID
        datatype: Optional filter by datatype
        start_datetime: Optional filter for logs at or after this datetime
        end_datetime: Optional filter for logs at or before this datetime

    Returns:
        The count of matching logs

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For unexpected errors
    """
    # Function-scope import: ``func`` is only needed here to build COUNT(*).
    from sqlalchemy import func

    try:
        stmt = select(func.count()).select_from(LuthienLog)

        # Apply filters; falsy values (None, empty string) mean "no filter".
        if transaction_id:
            stmt = stmt.where(col(LuthienLog.transaction_id) == transaction_id)
        if datatype:
            stmt = stmt.where(col(LuthienLog.datatype) == datatype)
        if start_datetime:
            stmt = stmt.where(col(LuthienLog.datetime) >= start_datetime)
        if end_datetime:
            stmt = stmt.where(col(LuthienLog.datetime) <= end_datetime)

        result = await session.execute(stmt)
        # An aggregate without GROUP BY always yields exactly one row.
        return result.scalar_one()
    except SQLAlchemyError as sqla_err:
        logger.error(f"SQLAlchemy error counting logs: {sqla_err}")
        raise LuthienDBQueryError(f"Database query failed while counting logs: {sqla_err}") from sqla_err
    except Exception as e:
        logger.error(f"Unexpected error counting logs: {e}")
        raise LuthienDBOperationError(f"Unexpected error during log counting: {e}") from e

Count logs with optional filtering.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
transaction_id Optional[str]

Optional filter by transaction ID

None
datatype Optional[str]

Optional filter by datatype

None
start_datetime Optional[datetime]

Optional filter for logs after this datetime

None
end_datetime Optional[datetime]

Optional filter for logs before this datetime

None

Returns:

Type Description
int

The count of matching logs

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

LuthienDBOperationError

For unexpected errors

get_log_by_id(session, log_id) async

Source code in luthien_control/db/luthien_log_crud.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
async def get_log_by_id(session: AsyncSession, log_id: int) -> LuthienLog:
    """Fetch a single log entry by its primary key.

    Args:
        session: The database session
        log_id: The ID of the log to retrieve

    Returns:
        The matching log entry

    Raises:
        LuthienDBQueryError: If no log has that ID or if the query execution fails
        LuthienDBOperationError: For unexpected errors during lookup
    """
    try:
        query = select(LuthienLog).where(col(LuthienLog.id) == log_id)
        rows = await session.execute(query)
        found = rows.scalar_one_or_none()
    except SQLAlchemyError as db_err:
        logger.error(f"SQLAlchemy error fetching log by ID: {db_err}", exc_info=True)
        raise LuthienDBQueryError(f"Database query failed while fetching log: {db_err}") from db_err
    except Exception as unexpected:
        logger.error(f"Unexpected error fetching log by ID: {unexpected}", exc_info=True)
        raise LuthienDBOperationError(f"Unexpected error during log lookup: {unexpected}") from unexpected

    if found:
        return found

    # The query succeeded but matched nothing.
    raise LuthienDBQueryError(f"Log with ID {log_id} not found")

Get a specific log by its ID.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
log_id int

The ID of the log to retrieve

required

Returns:

Type Description
LuthienLog

The log entry

Raises:

Type Description
LuthienDBQueryError

If the log is not found or if the query execution fails

LuthienDBOperationError

For unexpected errors during lookup

get_unique_datatypes(session) async

Source code in luthien_control/db/luthien_log_crud.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
async def get_unique_datatypes(session: AsyncSession) -> List[str]:
    """Return the distinct ``datatype`` values present in the logs, sorted.

    Args:
        session: The database session

    Returns:
        A list of unique datatype values, ordered by datatype

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For unexpected errors
    """
    try:
        datatype_column = col(LuthienLog.datatype)
        query = select(datatype_column).distinct().order_by(datatype_column)
        rows = await session.execute(query)
        datatypes = rows.scalars().all()
        return list(datatypes)
    except SQLAlchemyError as sqla_err:
        logger.error(f"SQLAlchemy error fetching unique datatypes: {sqla_err}")
        raise LuthienDBQueryError(f"Database query failed while fetching datatypes: {sqla_err}") from sqla_err
    except Exception as e:
        logger.error(f"Unexpected error fetching unique datatypes: {e}")
        raise LuthienDBOperationError(f"Unexpected error during datatype lookup: {e}") from e

Get a list of unique datatypes from the logs.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required

Returns:

Type Description
List[str]

A list of unique datatype values

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

LuthienDBOperationError

For unexpected errors

get_unique_transaction_ids(session, limit=100) async

Source code in luthien_control/db/luthien_log_crud.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
async def get_unique_transaction_ids(session: AsyncSession, limit: int = 100) -> List[str]:
    """Get a list of unique transaction IDs from recent logs.

    Transactions are ranked by the timestamp of their newest log entry, so the
    first element is the transaction that logged most recently.

    Args:
        session: The database session
        limit: Maximum number of transaction IDs to return (default: 100)

    Returns:
        A list of unique transaction ID values, most recently active first

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For unexpected errors
    """
    # Imported locally so this function does not depend on the module's import block.
    from sqlalchemy import func

    try:
        transaction_column = col(LuthienLog.transaction_id)
        # One row per transaction, ordered by the most recent datetime seen for it.
        # Ordering by the ID itself would sort lexicographically, which says
        # nothing about recency.
        stmt = (
            select(transaction_column)
            .group_by(transaction_column)
            .order_by(desc(func.max(col(LuthienLog.datetime))))
            .limit(limit)
        )
        result = await session.execute(stmt)
        return list(result.scalars().all())
    except SQLAlchemyError as sqla_err:
        logger.error(f"SQLAlchemy error fetching unique transaction IDs: {sqla_err}")
        raise LuthienDBQueryError(f"Database query failed while fetching transaction IDs: {sqla_err}") from sqla_err
    except Exception as e:
        logger.error(f"Unexpected error fetching unique transaction IDs: {e}")
        raise LuthienDBOperationError(f"Unexpected error during transaction ID lookup: {e}") from e

Get a list of unique transaction IDs from recent logs.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
limit int

Maximum number of transaction IDs to return (default: 100)

100

Returns:

Type Description
List[str]

A list of unique transaction ID values

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

LuthienDBOperationError

For unexpected errors

list_logs(session, transaction_id=None, datatype=None, limit=100, offset=0, start_datetime=None, end_datetime=None) async

Source code in luthien_control/db/luthien_log_crud.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
async def list_logs(
    session: AsyncSession,
    transaction_id: Optional[str] = None,
    datatype: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    start_datetime: Optional[datetime] = None,
    end_datetime: Optional[datetime] = None,
) -> List[LuthienLog]:
    """Return a page of logs, newest first, narrowed by any supplied filters.

    Args:
        session: The database session
        transaction_id: Optional filter by transaction ID
        datatype: Optional filter by datatype
        limit: Maximum number of logs to return (default: 100)
        offset: Number of logs to skip (default: 0)
        start_datetime: Optional inclusive lower bound on the log datetime
        end_datetime: Optional inclusive upper bound on the log datetime

    Returns:
        A list of LuthienLog entries

    Raises:
        LuthienDBQueryError: If the query execution fails
        LuthienDBOperationError: For unexpected errors
    """
    try:
        # Collect only the filters the caller actually supplied.
        criteria = []
        if transaction_id:
            criteria.append(col(LuthienLog.transaction_id) == transaction_id)
        if datatype:
            criteria.append(col(LuthienLog.datatype) == datatype)
        if start_datetime:
            criteria.append(col(LuthienLog.datetime) >= start_datetime)
        if end_datetime:
            criteria.append(col(LuthienLog.datetime) <= end_datetime)

        query = select(LuthienLog)
        if criteria:
            query = query.where(*criteria)

        # Newest first, then pagination.
        query = query.order_by(desc(col(LuthienLog.datetime))).limit(limit).offset(offset)

        rows = await session.execute(query)
        return list(rows.scalars().all())
    except SQLAlchemyError as sqla_err:
        logger.error(f"SQLAlchemy error listing logs: {sqla_err}")
        raise LuthienDBQueryError(f"Database query failed while listing logs: {sqla_err}") from sqla_err
    except Exception as e:
        logger.error(f"Unexpected error listing logs: {e}")
        raise LuthienDBOperationError(f"Unexpected error during log listing: {e}") from e

Get a list of logs with optional filtering.

Parameters:

Name Type Description Default
session AsyncSession

The database session

required
transaction_id Optional[str]

Optional filter by transaction ID

None
datatype Optional[str]

Optional filter by datatype

None
limit int

Maximum number of logs to return (default: 100)

100
offset int

Number of logs to skip (default: 0)

0
start_datetime Optional[datetime]

Optional filter for logs after this datetime

None
end_datetime Optional[datetime]

Optional filter for logs before this datetime

None

Returns:

Type Description
List[LuthienLog]

A list of LuthienLog entries

Raises:

Type Description
LuthienDBQueryError

If the query execution fails

LuthienDBOperationError

For unexpected errors

naive_datetime

NaiveDatetime

Bases: datetime

Source code in luthien_control/db/naive_datetime.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class NaiveDatetime(datetime):
    """A ``datetime`` subclass whose instances never carry timezone info.

    When constructed from an existing ``datetime``, an aware value is converted
    to UTC and its ``tzinfo`` dropped; a naive value is copied unchanged.
    """

    def __new__(cls, *args, **kwargs):
        source = args[0] if args else None

        # Anything other than "datetime as first positional argument" goes
        # through the regular datetime constructor untouched.
        if not isinstance(source, datetime):
            return super().__new__(cls, *args, **kwargs)

        # Aware input: shift to UTC first, then strip the tzinfo.
        if source.tzinfo is not None:
            source = source.astimezone(timezone.utc).replace(tzinfo=None)

        return super().__new__(
            cls,
            source.year,
            source.month,
            source.day,
            source.hour,
            source.minute,
            source.second,
            source.microsecond,
        )

    @classmethod
    def now(cls, tz: Optional[tzinfo] = None) -> "NaiveDatetime":
        """Return the current UTC time as a naive NaiveDatetime.

        The ``tz`` argument is accepted for signature compatibility with
        ``datetime.now`` but is ignored: the result is always naive UTC.
        """
        current_utc = datetime.now(timezone.utc)
        return cls(current_utc)

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> core_schema.CoreSchema:
        """Build the Pydantic schema: a datetime schema preceded by naive conversion."""
        datetime_schema = core_schema.datetime_schema()
        return core_schema.with_info_before_validator_function(
            cls._convert_to_naive,
            datetime_schema,
        )

    @classmethod
    def _convert_to_naive(cls, value: Any, info: Any) -> Any:
        """Pre-validation hook: wrap datetimes in this class, pass anything else through."""
        if not isinstance(value, datetime):
            return value  # Non-datetime input is left for Pydantic to parse.
        return cls(value)  # Routed through __new__, which strips tzinfo.

A datetime that automatically strips timezone info.

__get_pydantic_core_schema__(source_type, handler) classmethod
Source code in luthien_control/db/naive_datetime.py
28
29
30
31
32
33
34
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> core_schema.CoreSchema:
    """Build the Pydantic schema: a datetime schema preceded by naive conversion."""
    datetime_schema = core_schema.datetime_schema()
    return core_schema.with_info_before_validator_function(
        cls._convert_to_naive,
        datetime_schema,
    )

Pydantic schema for NaiveDatetime.

now(tz=None) classmethod
Source code in luthien_control/db/naive_datetime.py
22
23
24
25
26
@classmethod
def now(cls, tz: Optional[tzinfo] = None) -> "NaiveDatetime":
    """Return the current UTC time wrapped in ``cls``.

    The ``tz`` argument is accepted for signature compatibility with
    ``datetime.now`` but is ignored: the current time is always taken in UTC.
    """
    current_utc = datetime.now(timezone.utc)
    return cls(current_utc)

Create a NaiveDatetime representing the current UTC time (naive).

sqlmodel_models

AdminSession

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
class AdminSession(SQLModel, table=True):
    """Admin session model for managing active sessions.

    Each row ties an opaque ``session_token`` to an admin user and records when
    the session was created and when it expires.
    """

    __tablename__ = "admin_sessions"  # type: ignore

    id: Optional[int] = Field(default=None, primary_key=True)
    session_token: str = Field(sa_column=Column(String(255), unique=True, nullable=False))
    admin_user_id: int = Field(
        foreign_key="admin_users.id",
        description="Reference to admin user",
    )
    expires_at: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=False))
    # The column is timezone-aware, so the Python-side default must be aware too.
    # datetime.utcnow() returns a naive value (and is deprecated since Python 3.12).
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    )

    __table_args__ = (
        Index("idx_session_token", "session_token"),
        Index("idx_session_expires", "expires_at"),
    )

Admin session model for managing active sessions.

AdminUser

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class AdminUser(SQLModel, table=True):
    """Admin user model for authentication and authorization.

    Stores the login name, the password hash, and the active/superuser flags,
    along with last-login and creation/update timestamps.
    """

    __tablename__ = "admin_users"  # type: ignore

    id: Optional[int] = Field(default=None, primary_key=True)
    username: str = Field(sa_column=Column(String(50), unique=True, nullable=False))
    password_hash: str = Field(sa_column=Column(String(255), nullable=False))
    is_active: bool = Field(default=True)
    is_superuser: bool = Field(default=False)
    last_login: Optional[datetime] = Field(default=None, sa_column=Column(DateTime(timezone=True)))
    # Both timestamp columns are timezone-aware, so the Python-side defaults must
    # be aware as well. datetime.utcnow() returns a naive value (and is
    # deprecated since Python 3.12).
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column=Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()),
    )

    __table_args__ = (
        Index("idx_admin_username", "username"),
        Index("idx_admin_active", "is_active"),
    )

Admin user model for authentication and authorization.

ControlPolicy

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class ControlPolicy(SQLModel, table=True):
    """Database model for storing control policy configurations.

    Timestamps are stored as naive UTC datetimes.
    """

    # The docstring must be the first statement in the class body; placed after
    # ``__tablename__`` it is just a discarded string expression.
    __tablename__ = "policies"  # type: ignore (again, shut up pyright)

    # Primary key
    id: Optional[int] = Field(default=None, primary_key=True)

    # --- Core Fields ---
    name: str = Field(index=True, unique=True)  # Unique name used for lookup
    type: str = Field()  # Type of policy, used for instantiation
    # default_factory avoids declaring a shared mutable ``{}`` as the default.
    config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
    is_active: bool = Field(default=True, index=True)
    description: Optional[str] = Field(default=None)

    # --- Timestamps ---
    created_at: dt.datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=False
    )
    updated_at: dt.datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=False
    )

    def __init__(self, **data: Any):
        """Create a policy, filling in any timestamp the caller did not provide."""
        # Ensure timestamps are set on creation if not provided
        if "created_at" not in data:
            data["created_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        if "updated_at" not in data:
            data["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        super().__init__(**data)

    @model_validator(mode="before")
    @classmethod
    def validate_timestamps(cls, values):
        """Ensure updated_at is always set/updated.

        Dict input is copied before ``updated_at`` is stamped, so the caller's
        own dict is not mutated. Non-dict input is returned unchanged.
        """
        if isinstance(values, dict):
            values = dict(values)
            values["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
        return values
class-attribute instance-attribute

Database model for storing control policy configurations.

__tablename__ = 'policies' class-attribute instance-attribute

Database model for storing control policy configurations.

validate_timestamps(values) classmethod
Source code in luthien_control/db/sqlmodel_models.py
78
79
80
81
82
83
84
@model_validator(mode="before")
@classmethod
def validate_timestamps(cls, values):
    """Ensure updated_at is always set/updated.

    Dict input is copied before ``updated_at`` is stamped, so the caller's own
    dict is not mutated. Non-dict input is returned unchanged.
    """
    if isinstance(values, dict):
        values = dict(values)
        values["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None)
    return values

Ensure updated_at is always set/updated.

JsonBOrJson

Bases: TypeDecorator

Source code in luthien_control/db/sqlmodel_models.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class JsonBOrJson(types.TypeDecorator):
    """
    JSON column type that resolves to JSONB on PostgreSQL and plain JSON elsewhere.

    This is mostly a hack for unit testing, as SQLite does not support JSONB.
    """

    impl = JSON  # Default implementation if dialect-specific is not found
    cache_ok = True  # Safe to cache this type decorator

    def load_dialect_impl(self, dialect):
        """Return the dialect's type descriptor for JSONB (PostgreSQL) or JSON (anything else)."""
        json_type = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(json_type)

Represents a JSON type that uses JSONB for PostgreSQL and JSON for other dialects (like SQLite).

This is mostly a hack for unit testing, as SQLite does not support JSONB.

LuthienLog

Bases: SQLModel

Source code in luthien_control/db/sqlmodel_models.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class LuthienLog(SQLModel, table=True):
    """
    A single entry in the Luthien logging system, mapped with SQLModel.

    Attributes:
        id: Unique identifier for the log entry (primary key).
        transaction_id: Identifier to group related log entries.
        datetime: Timestamp of when the log entry was generated, held as a
            NaiveDatetime (timezone info stripped, UTC).
        data: JSON blob containing the primary logged data.
        datatype: String identifier for the nature and schema of 'data'.
        notes: JSON blob for additional contextual information.
    """

    __tablename__ = "luthien_log"  # type: ignore (shut up pyright)

    id: Optional[int] = Field(default=None, primary_key=True, index=True)
    transaction_id: str = Field(index=True, nullable=False)
    datetime: NaiveDatetime = Field(
        default_factory=NaiveDatetime.now,
        nullable=False,
        index=True,
    )
    data: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JsonBOrJson))
    datatype: str = Field(index=True, nullable=False)
    notes: Optional[dict[str, Any]] = Field(default=None, sa_column=Column(JsonBOrJson))

    def __init__(self, **data: Any) -> None:
        """Coerce a plain ``datetime`` passed as ``datetime=`` into NaiveDatetime before init."""
        supplied = data.get("datetime")
        needs_conversion = isinstance(supplied, datetime) and not isinstance(supplied, NaiveDatetime)
        if needs_conversion:
            data["datetime"] = NaiveDatetime(supplied)
        super().__init__(**data)

    # __table_args__ = (
    #     Index("ix_sqlmodel_luthien_log_transaction_id", "transaction_id"),
    #     Index("ix_sqlmodel_luthien_log_datetime", "datetime"),
    #     Index("ix_sqlmodel_luthien_log_datatype", "datatype"),
    #     {"extend_existing": True},
    # )

    # Explicit __repr__ listing the identifying fields; the data/notes blobs are
    # left out.
    def __repr__(self) -> str:
        shown = (
            f"id={self.id}",
            f"transaction_id='{self.transaction_id}'",
            f"datetime='{self.datetime}'",
            f"datatype='{self.datatype}'",
        )
        return f"<LuthienLog({', '.join(shown)})>"

Represents a log entry in the Luthien logging system using SQLModel.

Attributes:

Name Type Description
id Optional[int]

Unique identifier for the log entry (primary key).

transaction_id str

Identifier to group related log entries.

datetime NaiveDatetime

Timestamp indicating when the log entry was generated (stored as a naive UTC datetime; any timezone info is stripped).

data Optional[dict[str, Any]]

JSON blob containing the primary logged data.

datatype str

String identifier for the nature and schema of 'data'.

notes Optional[dict[str, Any]]

JSON blob for additional contextual information.

__init__(**data)
Source code in luthien_control/db/sqlmodel_models.py
113
114
115
116
117
118
119
def __init__(self, **data: Any) -> None:
    """Coerce a plain ``datetime`` passed as ``datetime=`` into NaiveDatetime before init."""
    supplied = data.get("datetime")
    needs_conversion = isinstance(supplied, datetime) and not isinstance(supplied, NaiveDatetime)
    if needs_conversion:
        data["datetime"] = NaiveDatetime(supplied)
    super().__init__(**data)

Override init to ensure datetime is converted to NaiveDatetime.

exceptions

LuthienDBConfigurationError

Bases: LuthienDBException

Source code in luthien_control/exceptions.py
13
14
15
16
class LuthienDBConfigurationError(LuthienDBException):
    """Raised when the database configuration is invalid or lacks required variables."""

Exception raised when a database configuration is invalid or missing required variables.

LuthienDBConnectionError

Bases: LuthienDBException

Source code in luthien_control/exceptions.py
19
20
21
22
class LuthienDBConnectionError(LuthienDBException):
    """Raised when connecting to the database fails."""

Exception raised when a connection to the database fails.

LuthienDBException

Bases: LuthienException

Source code in luthien_control/exceptions.py
 7
 8
 9
10
class LuthienDBException(LuthienException):
    """Common base class for every Luthien database-related error."""

Base exception for all Luthien DB related errors.

LuthienException

Bases: Exception

Source code in luthien_control/exceptions.py
1
2
3
4
class LuthienException(Exception):
    """Common base class for every Luthien error."""

Base exception for all Luthien errors.

logs

router

get_datatypes(session=Depends(get_db_session)) async

Source code in luthien_control/logs/router.py
151
152
153
154
155
156
157
158
159
160
161
162
163
@router.get("/admin/logs-api/metadata/datatypes")
async def get_datatypes(
    session: AsyncSession = Depends(get_db_session),
) -> List[str]:
    """Return the unique datatypes found in the logs.

    Any failure, whether a database error or something unexpected, is logged
    and reported to the client as HTTP 500.
    """
    try:
        datatypes = await get_unique_datatypes(session=session)
    except (LuthienDBQueryError, LuthienDBOperationError) as db_err:
        logger.error(f"Database error getting datatypes: {db_err}")
        raise HTTPException(status_code=500, detail="Failed to retrieve datatypes from database")
    except Exception as e:
        logger.error(f"Unexpected error getting datatypes: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")
    return datatypes

Get list of unique datatypes from logs.

get_log(log_id, session=Depends(get_db_session)) async

Source code in luthien_control/logs/router.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@router.get("/admin/logs-api/logs/{log_id}")
async def get_log(
    log_id: int,
    session: AsyncSession = Depends(get_db_session),
) -> Dict[str, Any]:
    """Return one log entry, looked up by ID, as a JSON-serialisable dict.

    A ``LuthienDBQueryError`` from the lookup is reported as HTTP 404; every
    other failure is logged and reported as HTTP 500.
    """
    try:
        entry = await get_log_by_id(session=session, log_id=log_id)

        timestamp = entry.datetime.isoformat() if entry.datetime else None
        payload = {
            "id": entry.id,
            "transaction_id": entry.transaction_id,
            "datetime": timestamp,
            "datatype": entry.datatype,
            "data": entry.data,
            "notes": entry.notes,
        }
        return payload

    except LuthienDBQueryError:
        raise HTTPException(status_code=404, detail=f"Log with ID {log_id} not found")
    except LuthienDBOperationError as db_err:
        logger.error(f"Database error getting log {log_id}: {db_err}")
        raise HTTPException(status_code=500, detail="Failed to retrieve log from database")
    except Exception as e:
        logger.error(f"Unexpected error getting log {log_id}: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")

Get a specific log by ID.

get_logs(session=Depends(get_db_session), transaction_id=Query(None, description='Filter by transaction ID'), datatype=Query(None, description='Filter by datatype'), limit=Query(100, ge=1, le=1000, description='Maximum number of logs to return'), offset=Query(0, ge=0, description='Number of logs to skip'), start_datetime=Query(None, description='Start datetime (ISO format)'), end_datetime=Query(None, description='End datetime (ISO format)')) async

Source code in luthien_control/logs/router.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def _parse_iso_datetime(raw_value: Optional[str], field_name: str) -> Optional[datetime]:
    """Parse an ISO-8601 query value into a naive UTC datetime.

    Args:
        raw_value: The raw query string value, or None/empty when not supplied.
        field_name: Name of the query parameter, used in the error message.

    Returns:
        None when no value was supplied, otherwise the parsed datetime with any
        timezone offset converted to UTC and then removed.

    Raises:
        HTTPException: 400 if the value is not valid ISO format.
    """
    from datetime import timezone  # local import keeps the helper self-contained

    if not raw_value:
        return None
    try:
        parsed = datetime.fromisoformat(raw_value.replace("Z", "+00:00"))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid {field_name} format: {e}") from e

    # LuthienLog.datetime is stored as naive UTC (NaiveDatetime), so an aware
    # bound such as "...Z" must be normalised the same way before it is compared
    # against the column.
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


@router.get("/admin/logs-api/logs")
async def get_logs(
    session: AsyncSession = Depends(get_db_session),
    transaction_id: Optional[str] = Query(None, description="Filter by transaction ID"),
    datatype: Optional[str] = Query(None, description="Filter by datatype"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of logs to return"),
    offset: int = Query(0, ge=0, description="Number of logs to skip"),
    start_datetime: Optional[str] = Query(None, description="Start datetime (ISO format)"),
    end_datetime: Optional[str] = Query(None, description="End datetime (ISO format)"),
) -> Dict[str, Any]:
    """Get logs with optional filtering and pagination.

    Datetime bounds given with a timezone offset (including a trailing ``Z``)
    are converted to naive UTC before querying; bounds without an offset are
    used as given.

    Returns:
        Dictionary containing logs, pagination info, and metadata

    Raises:
        HTTPException: 400 for an unparseable datetime bound; 500 for database
            or unexpected errors.
    """
    try:
        # Parse datetime strings if provided
        start_dt = _parse_iso_datetime(start_datetime, "start_datetime")
        end_dt = _parse_iso_datetime(end_datetime, "end_datetime")

        # Get logs and total count (same filters for both so pagination is consistent)
        logs = await list_logs(
            session=session,
            transaction_id=transaction_id,
            datatype=datatype,
            limit=limit,
            offset=offset,
            start_datetime=start_dt,
            end_datetime=end_dt,
        )

        total_count = await count_logs(
            session=session,
            transaction_id=transaction_id,
            datatype=datatype,
            start_datetime=start_dt,
            end_datetime=end_dt,
        )

        # Convert logs to dict format for JSON response
        logs_data = [
            {
                "id": log.id,
                "transaction_id": log.transaction_id,
                "datetime": log.datetime.isoformat() if log.datetime else None,
                "datatype": log.datatype,
                "data": log.data,
                "notes": log.notes,
            }
            for log in logs
        ]

        return {
            "logs": logs_data,
            "pagination": {
                "limit": limit,
                "offset": offset,
                "total": total_count,
                "has_next": offset + limit < total_count,
                "has_prev": offset > 0,
            },
            # Echo the filters exactly as the client sent them.
            "filters": {
                "transaction_id": transaction_id,
                "datatype": datatype,
                "start_datetime": start_datetime,
                "end_datetime": end_datetime,
            },
        }

    except HTTPException:
        # Re-raise HTTPExceptions (like 400 Bad Request) without modification
        raise
    except (LuthienDBQueryError, LuthienDBOperationError) as db_err:
        logger.error(f"Database error getting logs: {db_err}")
        raise HTTPException(status_code=500, detail="Failed to retrieve logs from database")
    except Exception as e:
        logger.error(f"Unexpected error getting logs: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")

Get logs with optional filtering and pagination.

Returns:

Type Description
Dict[str, Any]

Dictionary containing logs, pagination info, and metadata

get_transaction_ids(limit=Query(100, ge=1, le=500, description='Maximum number of transaction IDs to return'), session=Depends(get_db_session)) async

Source code in luthien_control/logs/router.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@router.get("/admin/logs-api/metadata/transaction-ids")
async def get_transaction_ids(
    limit: int = Query(100, ge=1, le=500, description="Maximum number of transaction IDs to return"),
    session: AsyncSession = Depends(get_db_session),
) -> List[str]:
    """Return up to ``limit`` unique transaction IDs from recent logs.

    Any failure, whether a database error or something unexpected, is logged
    and reported to the client as HTTP 500.
    """
    try:
        transaction_ids = await get_unique_transaction_ids(session=session, limit=limit)
    except (LuthienDBQueryError, LuthienDBOperationError) as db_err:
        logger.error(f"Database error getting transaction IDs: {db_err}")
        raise HTTPException(status_code=500, detail="Failed to retrieve transaction IDs from database")
    except Exception as e:
        logger.error(f"Unexpected error getting transaction IDs: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")
    return transaction_ids

Get list of unique transaction IDs from recent logs.

logs_ui(request) async

Source code in luthien_control/logs/router.py
26
27
28
29
@router.get("/admin/logs", response_class=HTMLResponse)
async def logs_ui(request: Request):
    """Render the ``logs.html`` template for the logs exploration UI."""
    template_context = {}  # the template is rendered with an empty context
    return templates.TemplateResponse(request, "logs.html", template_context)

Serve the logs exploration UI.

main

health_check() async

Source code in luthien_control/main.py
102
103
104
105
106
107
108
109
110
111
112
@app.get("/health", tags=["General"], status_code=200)
async def health_check():
    """Report that the application is up.

    A lightweight endpoint for verifying that the application is running and
    responsive.

    Returns:
        A dictionary with a single ``status`` key set to ``"ok"``.
    """
    status_payload = {"status": "ok"}
    return status_payload

Perform a basic health check.

This endpoint can be used to verify that the application is running and responsive.

Returns:

Type Description

A dictionary indicating the application status.

lifespan(app) async

Source code in luthien_control/main.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the lifespan of the application resources.

    This asynchronous context manager handles the startup and shutdown events
    of the FastAPI application. It initializes dependencies on startup
    and ensures they are properly cleaned up on shutdown.

    Args:
        app: The FastAPI application instance.

    Yields:
        None: After startup procedures are complete, allowing the application to run.

    Raises:
        RuntimeError: If critical application dependencies fail to initialize during startup.
    """
    logger.info("Application startup sequence initiated.")

    # Startup: Load Settings
    app_settings = Settings()
    logger.info("Settings loaded.")

    # Startup: Initialize Application Dependencies via helper
    # This variable will hold the container if successfully created.
    initialized_dependencies: DependencyContainer | None = None
    try:
        initialized_dependencies = await initialize_app_dependencies(app_settings)
        app.state.dependencies = initialized_dependencies
        logger.info("Core application dependencies initialized and stored in app state.")

        # Ensure default admin user exists (only the first yielded session is needed)
        async for db in get_db_session(initialized_dependencies):
            await admin_auth_service.ensure_default_admin(db)
            break

    except Exception as init_exc:
        # _initialize_app_dependencies is responsible for cleaning up resources it
        # attempted to create (like its own http_client) if it fails internally.
        # The main concern here is logging and ensuring the app doesn't start.
        logger.critical(f"Fatal error during application dependency initialization: {init_exc}", exc_info=True)
        # If _initialize_app_dependencies failed before creating db_engine, close_db_engine is safe.
        # If it failed *after* db_engine creation but before container, db_engine might be open.
        # The helper itself doesn't call close_db_engine(); it expects lifespan to do so.
        # Global close_db_engine handles if engine was never set or already closed.
        await close_db_engine()
        logger.info("DB Engine closed due to dependency initialization failure during startup.")
        # Re-raise to prevent application from starting up in a bad state.
        raise RuntimeError(
            f"Application startup failed due to dependency initialization error: {init_exc}"
        ) from init_exc

    yield  # Application runs here

    # Shutdown: Clean up resources
    logger.info("Application shutdown sequence initiated.")

    # Close main DB engine (handles its own check if already closed or never initialized)
    await close_db_engine()
    logger.info("Main DB Engine closed.")

    # Shutdown: Close the HTTP client via the container if available.
    # The container is typed as optional, so guard against both a missing
    # container and a missing client instead of dereferencing unconditionally.
    http_client = getattr(initialized_dependencies, "http_client", None)
    if http_client is not None:
        await http_client.aclose()
        logger.info("HTTP Client from DependencyContainer closed.")
    else:
        logger.warning("No HTTP Client available in DependencyContainer during shutdown; nothing to close.")

    logger.info("Application shutdown complete.")

Manage the lifespan of the application resources.

This asynchronous context manager handles the startup and shutdown events of the FastAPI application. It initializes dependencies on startup and ensures they are properly cleaned up on shutdown.

Parameters:

Name Type Description Default
app FastAPI

The FastAPI application instance.

required

Yields:

Name Type Description
None

After startup procedures are complete, allowing the application to run.

Raises:

Type Description
RuntimeError

If critical application dependencies fail to initialize during startup.

read_root() async

Source code in luthien_control/main.py
123
124
125
126
127
128
129
130
@app.get("/")
async def read_root():
    """Provide a simple root endpoint.

    Returns:
        A welcome message indicating the proxy is running.
    """
    # Static payload: nothing here depends on the request.
    welcome = {"message": "Luthien Control Proxy is running."}
    return welcome

Provide a simple root endpoint.

Returns:

Type Description

A welcome message indicating the proxy is running.

proxy

debugging

Enhanced debugging utilities for the proxy pipeline.

DebugLoggingMiddleware

Bases: BaseHTTPMiddleware

Source code in luthien_control/proxy/debugging.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class DebugLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to add detailed request/response logging for debugging.

    For every request it logs the method, path, headers and query parameters
    at DEBUG level. For POST/PUT/PATCH it also logs the parsed JSON body, or
    only the body length when the body is not JSON. After the downstream
    handler returns it logs the status code and duration at INFO level and
    adds ``X-Request-ID`` and ``X-Processing-Time`` headers to the response.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the incoming request, run the downstream handler, log the result.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request and
                returns the downstream response.

        Returns:
            The downstream response with the two debug headers added.
        """
        start_time = time.time()
        # Falls back to the literal "no-id" when the client sent no request id.
        request_id = request.headers.get("x-request-id", "no-id")

        # NOTE(review): headers and JSON bodies are logged verbatim, so
        # credentials such as the Authorization header end up in the debug
        # log. Confirm this middleware is only enabled where that is acceptable.
        request_body = None
        if request.method in ["POST", "PUT", "PATCH"]:
            try:
                request_body = await request.body()

                # Reconstruct request for downstream processing.
                # NOTE(review): this replay callable returns the same
                # http.request message on every call; confirm that matches
                # what the Starlette version in use expects.
                async def receive():
                    return {"type": "http.request", "body": request_body}

                request._receive = receive

                # Try to parse JSON for logging
                try:
                    parsed_body = json.loads(request_body) if request_body else None
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # json.loads decodes bytes before parsing, so a binary
                    # body raises UnicodeDecodeError rather than
                    # JSONDecodeError. Both mean "not JSON", not a read failure.
                    logger.debug(
                        f"[{request_id}] Incoming {request.method} request (non-JSON body)",
                        extra={
                            "path": request.url.path,
                            "headers": dict(request.headers),
                            "body_length": len(request_body) if request_body else 0,
                            "query_params": dict(request.query_params),
                        },
                    )
                else:
                    logger.debug(
                        f"[{request_id}] Incoming {request.method} request",
                        extra={
                            "path": request.url.path,
                            "headers": dict(request.headers),
                            "body": parsed_body,
                            "query_params": dict(request.query_params),
                        },
                    )
            except Exception as e:
                # Best effort: a logging failure must not block the request.
                logger.error(f"[{request_id}] Error reading request body: {e}")
        else:
            logger.debug(
                f"[{request_id}] Incoming {request.method} request",
                extra={
                    "path": request.url.path,
                    "headers": dict(request.headers),
                    "query_params": dict(request.query_params),
                },
            )

        # Process request
        response = await call_next(request)

        # Calculate duration
        duration = time.time() - start_time

        # Log response details
        logger.info(
            f"[{request_id}] Request completed",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status_code": response.status_code,
                "duration_seconds": duration,
            },
        )

        # Add debug headers
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Processing-Time"] = f"{duration:.3f}s"

        return response

Middleware to add detailed request/response logging for debugging.

create_debug_response(status_code, message, transaction_id, details=None, include_debug_info=True)

Source code in luthien_control/proxy/debugging.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def create_debug_response(
    status_code: int,
    message: str,
    transaction_id: str,
    details: Optional[Dict[str, Any]] = None,
    include_debug_info: bool = True,
) -> Dict[str, Any]:
    """Create a detailed error response for debugging."""
    response = {
        "detail": message,
        "transaction_id": transaction_id,
    }

    if include_debug_info and details:
        response["debug"] = str({"timestamp": datetime.now(UTC).isoformat(), **details})

    return response

Create a detailed error response for debugging.

log_policy_execution(transaction_id, policy_name, status, duration=None, error=None, details=None)

Source code in luthien_control/proxy/debugging.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def log_policy_execution(
    transaction_id: str,
    policy_name: str,
    status: str,
    duration: Optional[float] = None,
    error: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Emit one structured log record describing a policy run.

    Args:
        transaction_id: Transaction identifier, used in the message prefix.
        policy_name: Name of the policy that ran.
        status: Outcome label. ``"error"`` is logged at ERROR level, anything
            else at INFO level.
        duration: Elapsed seconds; stored as a string when not None.
        error: Error text; included only when truthy.
        details: Extra fields merged last, so they can override the others.
    """
    fields = {
        "transaction_id": transaction_id,
        "policy_name": policy_name,
        "status": status,
        **({} if duration is None else {"duration_seconds": str(duration)}),
        **({"error": error} if error else {}),
        **(details or {}),
    }

    failed = status == "error"
    outcome = "failed" if failed else status
    emit = logger.error if failed else logger.info
    emit(f"[{transaction_id}] Policy {policy_name} {outcome}", extra=fields)

Log policy execution details.

log_transaction_state(transaction_id, stage, details)

Source code in luthien_control/proxy/debugging.py
92
93
94
95
96
97
def log_transaction_state(transaction_id: str, stage: str, details: Dict[str, Any]) -> None:
    """Write a DEBUG record capturing the transaction at a processing stage.

    Args:
        transaction_id: Transaction identifier, used in the message prefix.
        stage: Name of the processing stage being recorded.
        details: Extra fields merged after ``stage`` and ``timestamp``, so
            they override those two keys on collision.
    """
    snapshot = {"stage": stage, "timestamp": datetime.now(UTC).isoformat()}
    snapshot.update(details)
    logger.debug(f"[{transaction_id}] Transaction state at {stage}", extra=snapshot)

Log transaction state at various stages of processing.

orchestration

run_policy_flow(request, main_policy, dependencies, session) async

Source code in luthien_control/proxy/orchestration.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
async def run_policy_flow(
    request: fastapi.Request,
    main_policy: ControlPolicy,
    dependencies: DependencyContainer,
    session: AsyncSession,
) -> fastapi.Response:
    """
    Orchestrates the execution of the main ControlPolicy using injected dependencies.

    Errors raised while the policy is applied (or while its result is converted
    into a response) are caught here and turned into JSON error responses: a
    ControlPolicyError yields the error's own status code (400 when it has none),
    any other exception yields a 500. Reading the request body and building the
    transaction happen before that error handling, so failures there propagate
    to the caller.

    Args:
        request: The incoming FastAPI request.
        main_policy: The main policy instance to execute.
        dependencies: The application's dependency container.
        session: The database session for this request.

    Returns:
        The final FastAPI response.
    """
    # 1. Initialize Context
    body = await request.body()
    url = request.path_params["full_path"]
    # NOTE(review): str.replace removes every occurrence of the exact,
    # case-sensitive text "Bearer ", not just a leading scheme prefix.
    api_key = request.headers.get("authorization", "").replace("Bearer ", "")
    transaction = _initialize_transaction(body, url, api_key)

    # Log initial transaction state
    log_transaction_state(
        str(transaction.transaction_id),
        "initialization",
        {
            "url": url,
            "method": request.method,
            "has_api_key": bool(api_key),
            "body_length": len(body) if body else 0,
            "headers_count": len(request.headers),
        },
    )

    # 2. Apply the main policy
    # Stays None until just before apply() is called, so the error handlers
    # below can tell whether a duration is available.
    policy_start_time = None
    try:
        logger.info(
            "Applying control policy",
            extra={
                "transaction_id": str(transaction.transaction_id),
                "policy_name": main_policy.name,
                "url": url,
                "method": request.method,
            },
        )
        policy_start_time = time.time()
        transaction = await main_policy.apply(transaction=transaction, container=dependencies, session=session)

        # Log successful policy execution
        log_policy_execution(
            str(transaction.transaction_id),
            main_policy.name or "unknown",
            "completed",
            duration=time.time() - policy_start_time if policy_start_time else None,
            details={"has_response": transaction.response.payload is not None},
        )

        logger.info(
            "Policy execution complete",
            extra={
                "transaction_id": str(transaction.transaction_id),
                "policy_name": main_policy.name,
                "duration_seconds": time.time() - policy_start_time if policy_start_time else None,
            },
        )
        if transaction.response.payload is not None:
            final_response = openai_chat_completions_response_to_fastapi_response(transaction.response.payload)
        else:
            # The policy finished without producing a payload: report a 500.
            final_response = JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal Server Error: No response payload",
                    "transaction_id": str(transaction.transaction_id),
                    "policy_name": main_policy.name,
                },
            )

    except ControlPolicyError as e:
        # Log policy error
        policy_duration = time.time() - policy_start_time if policy_start_time else None
        log_policy_execution(
            str(transaction.transaction_id),
            main_policy.name or "unknown",
            "error",
            duration=policy_duration,
            error=str(e),
            details={
                "error_type": e.__class__.__name__,
                "policy_name": getattr(e, "policy_name", "unknown"),
            },
        )

        logger.warning(
            f"Control policy error - transaction {transaction.transaction_id}",
            extra={
                "transaction_id": str(transaction.transaction_id),
                "error": str(e),
                "error_type": e.__class__.__name__,
                "policy_name": getattr(e, "policy_name", "unknown"),
            },
        )
        # Directly build a JSONResponse for policy errors
        policy_name_for_error = getattr(e, "policy_name", "unknown")
        status_code = getattr(e, "status_code", None) or status.HTTP_400_BAD_REQUEST  # Use 400 if None or not specified
        error_detail = getattr(e, "detail", str(e))  # Use str(e) if no detail attribute

        # Check if we're in dev mode and if the exception has debug info
        settings = Settings()
        debug_details = None

        if settings.dev_mode():
            # Check if the ControlPolicyError itself has debug info
            if hasattr(e, "debug_info"):
                debug_details = e.debug_info  # type: ignore
            # Check if the underlying exception (__cause__) has debug info.
            # Every exception has a __cause__ attribute (None when unset), so
            # the second hasattr is the check that actually decides.
            elif hasattr(e, "__cause__") and hasattr(e.__cause__, "debug_info"):
                debug_details = e.__cause__.debug_info  # type: ignore

        # Use create_debug_response to generate the response
        response_content = create_debug_response(
            status_code=status_code,
            message=f"Policy error in '{policy_name_for_error}': {error_detail}",
            transaction_id=str(transaction.transaction_id),
            details=debug_details,
            include_debug_info=settings.dev_mode(),
        )

        final_response = JSONResponse(
            status_code=status_code,
            content=response_content,
        )

    except Exception as e:
        # Log unexpected error
        policy_duration = time.time() - policy_start_time if policy_start_time else None
        log_policy_execution(
            str(transaction.transaction_id),
            main_policy.name or "unknown",
            "error",
            duration=policy_duration,
            error=str(e),
            details={
                "error_type": e.__class__.__name__,
                "unexpected": True,
            },
        )

        # Handle unexpected errors raised inside the try block above
        # (policy execution or building the response from its payload).
        logger.exception(
            f"Unhandled exception during policy flow - transaction {transaction.transaction_id}",
            extra={
                "transaction_id": str(transaction.transaction_id),
                "error": str(e),
                "error_type": e.__class__.__name__,
            },
        )
        # Try to build an error response using the builder
        policy_name_for_error = getattr(main_policy, "name", main_policy.__class__.__name__)

        # Check if we're in dev mode and if the exception has debug info
        settings = Settings()
        debug_details = None

        if settings.dev_mode() and hasattr(e, "debug_info"):
            debug_details = e.debug_info  # type: ignore

        # Use create_debug_response to generate the response.
        # The client-facing message is generic; the exception text is only logged.
        response_content = create_debug_response(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Internal Server Error",
            transaction_id=str(transaction.transaction_id),
            details=debug_details,
            include_debug_info=settings.dev_mode(),
        )

        # Add policy name to the response
        response_content["policy_name"] = policy_name_for_error

        final_response = JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=response_content,
        )

    return final_response

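Orchestrates the execution of the main ControlPolicy using injected dependencies. Errors raised while the policy is applied are caught and converted into JSON error responses: a ControlPolicyError yields the error's own status code (400 when it has none), and any other exception yields a 500.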

Parameters:

Name Type Description Default
request Request

The incoming FastAPI request.

required
main_policy ControlPolicy

The main policy instance to execute.

required
dependencies DependencyContainer

The application's dependency container.

required
session AsyncSession

The database session for this request.

required

Returns:

Type Description
Response

The final FastAPI response.

server

api_proxy_endpoint(request, full_path=default_path, dependencies=Depends(get_dependencies), main_policy=Depends(get_main_control_policy), session=Depends(get_db_session), payload=default_payload, token=Security(http_bearer_auth)) async

Source code in luthien_control/proxy/server.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@router.post(
    "/api/{full_path:path}",
)
async def api_proxy_endpoint(
    request: Request,
    full_path: str = default_path,
    # --- Core Dependencies ---
    dependencies: DependencyContainer = Depends(get_dependencies),
    main_policy: ControlPolicy = Depends(get_main_control_policy),
    session: AsyncSession = Depends(get_db_session),
    # --- Swagger UI Enhancements ---
    # The 'payload' and 'token' parameters enhance the Swagger UI:
    # - 'payload' (dict[str, Any], optional): Provides a schema for the request body.
    #   Actual body content is read directly from the 'request' object.
    # - 'token' (Optional[str]): Enables the 'Authorize' button (Bearer token).
    #   Actual token validation is handled by the policy flow.
    payload: dict[str, Any] = default_payload,
    token: Optional[str] = Security(http_bearer_auth),
):
    """
    Main API proxy endpoint using the policy orchestration flow.
    Handles requests starting with /api/.
    Uses Dependency Injection Container and provides a DB session.

    **Authentication Note:** This endpoint uses Bearer Token authentication
    (Authorization: Bearer <token>). However, the requirement for a valid token
    depends on whether the currently configured control policy includes client
    authentication (e.g., ClientApiKeyAuthPolicy). If the policy does not require
    authentication, the token field can be left blank.
    """
    # Only the raw request and the injected dependencies are forwarded;
    # full_path, payload and token are not passed to the handler.
    return await _handle_api_request(request, main_policy, dependencies, session)

Main API proxy endpoint using the policy orchestration flow. Handles requests starting with /api/. Uses Dependency Injection Container and provides a DB session.

Authentication Note: This endpoint uses Bearer Token authentication (Authorization: Bearer <token>). However, the requirement for a valid token depends on whether the currently configured control policy includes client authentication (e.g., ClientApiKeyAuthPolicy). If the policy does not require authentication, the token field can be left blank.

api_proxy_get_endpoint(request, full_path=default_path, dependencies=Depends(get_dependencies), main_policy=Depends(get_main_control_policy), session=Depends(get_db_session), token=Security(http_bearer_auth)) async

Source code in luthien_control/proxy/server.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@router.get(
    "/api/{full_path:path}",
)
async def api_proxy_get_endpoint(
    request: Request,
    full_path: str = default_path,
    # --- Core Dependencies ---
    dependencies: DependencyContainer = Depends(get_dependencies),
    main_policy: ControlPolicy = Depends(get_main_control_policy),
    session: AsyncSession = Depends(get_db_session),
    # --- Swagger UI Enhancements ---
    # - 'token' (Optional[str]): Enables the 'Authorize' button (Bearer token).
    #   Actual token validation is handled by the policy flow.
    token: Optional[str] = Security(http_bearer_auth),
):
    """
    Main API proxy endpoint for GET requests using the policy orchestration flow.
    Handles GET requests starting with /api/.
    Uses Dependency Injection Container and provides a DB session.

    **Authentication Note:** This endpoint uses Bearer Token authentication
    (Authorization: Bearer <token>). However, the requirement for a valid token
    depends on whether the currently configured control policy includes client
    authentication (e.g., ClientApiKeyAuthPolicy). If the policy does not require
    authentication, the token field can be left blank.
    """
    # Only the raw request and the injected dependencies are forwarded;
    # full_path and token are not passed to the handler.
    return await _handle_api_request(request, main_policy, dependencies, session)

Main API proxy endpoint for GET requests using the policy orchestration flow. Handles GET requests starting with /api/. Uses Dependency Injection Container and provides a DB session.

Authentication Note: This endpoint uses Bearer Token authentication (Authorization: Bearer <token>). However, the requirement for a valid token depends on whether the currently configured control policy includes client authentication (e.g., ClientApiKeyAuthPolicy). If the policy does not require authentication, the token field can be left blank.

api_proxy_options_handler(full_path=default_path) async

Source code in luthien_control/proxy/server.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
@router.options("/api/{full_path:path}")
async def api_proxy_options_handler(
    full_path: str = default_path,  # Keep for path consistency, though not used in this simple handler
):
    """
    Handles OPTIONS requests for the API proxy endpoint, indicating allowed methods.
    """
    # The same method list is advertised in both Allow and the CORS header.
    permitted_methods = "GET, POST, OPTIONS"
    logger.info(f"Explicit OPTIONS request received for /api/{full_path}")
    return Response(
        status_code=200,
        headers={
            "Allow": permitted_methods,
            "Access-Control-Allow-Origin": "*",  # Allow any origin
            "Access-Control-Allow-Methods": permitted_methods,
            "Access-Control-Allow-Headers": "Authorization, Content-Type",
        },
    )

Handles OPTIONS requests for the API proxy endpoint, indicating allowed methods.

settings

Settings

Source code in luthien_control/settings.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class Settings:
    """Application configuration settings loaded from environment variables."""

    # --- Core Settings ---
    BACKEND_URL: Optional[str] = None
    # Comma-separated list of control policies for the beta framework

    # --- Database Settings ---
    DB_SERVER: str = "localhost"
    DB_USER: Optional[str] = None
    DB_PASSWORD: Optional[str] = None
    DB_NAME: Optional[str] = None
    DB_HOST: Optional[str] = None
    DB_PORT: Optional[int] = 5432

    # --- OpenAI Settings ---
    OPENAI_API_KEY: Optional[str] = None

    # --- Helper Methods using os.getenv ---
    def get_backend_url(self) -> Optional[str]:
        """Returns the backend URL as a string, if set."""
        url = os.getenv("BACKEND_URL")
        if url:
            # Basic validation (can be enhanced)
            parsed = urlparse(url)
            if not all([parsed.scheme, parsed.netloc]):
                raise ValueError(f"Invalid BACKEND_URL format: {url}")
        return url

    def get_database_url(self) -> Optional[str]:
        """Returns the primary DATABASE_URL, if set."""
        return os.getenv("DATABASE_URL")

    def get_openai_api_key(self) -> str | None:
        """Returns the OpenAI API key, if set."""
        return os.getenv("OPENAI_API_KEY")

    def get_top_level_policy_name(self) -> str:
        """Returns the name of the top-level policy instance to load."""
        return os.getenv("TOP_LEVEL_POLICY_NAME", "root")

    def get_policy_filepath(self) -> str | None:
        """Returns the path to the policy file, if set."""
        return os.getenv("POLICY_FILEPATH")

    # --- Database settings Getters using os.getenv ---
    def get_postgres_user(self) -> str | None:
        return os.getenv("DB_USER")

    def get_postgres_password(self) -> str | None:
        return os.getenv("DB_PASSWORD")

    def get_postgres_db(self) -> str | None:
        return os.getenv("DB_NAME")

    def get_postgres_host(self) -> str | None:
        return os.getenv("DB_HOST")

    def get_postgres_port(self) -> int | None:
        """Returns the PostgreSQL port as an integer, or None if not set."""
        port_str = os.getenv("DB_PORT")
        if port_str is None:
            return None
        try:
            return int(port_str)
        except ValueError:
            raise ValueError("DB_PORT environment variable must be an integer.")

    # --- DB Pool Size Getters ---
    def get_main_db_pool_min_size(self) -> int:
        """Returns the minimum pool size for the main DB."""
        try:
            return int(os.getenv("MAIN_DB_POOL_MIN_SIZE", "1"))
        except ValueError:
            raise ValueError("MAIN_DB_POOL_MIN_SIZE environment variable must be an integer.")

    def get_main_db_pool_max_size(self) -> int:
        """Returns the maximum pool size for the main DB."""
        try:
            return int(os.getenv("MAIN_DB_POOL_MAX_SIZE", "10"))
        except ValueError:
            raise ValueError("MAIN_DB_POOL_MAX_SIZE environment variable must be an integer.")

    # --- Logging Settings --- #
    def get_log_level(self, default: str = "INFO") -> str:
        """Gets the configured log level, defaulting if not set."""
        return os.getenv("LOG_LEVEL", default).upper()

    # uvicorn
    def get_app_host(self, default: str = "0.0.0.0") -> str:
        """Gets the configured app host, defaulting if not set."""
        return os.getenv("LUTHIEN_CONTROL_HOST", default)

    def get_app_port(self, default: int = 8000) -> int:
        """Gets the configured app port, defaulting if not set."""
        return int(os.getenv("LUTHIEN_CONTROL_PORT", default))

    def get_app_reload(self, default: bool = False) -> bool:
        """Gets the configured app reload, defaulting if not set."""
        reload = os.getenv("LUTHIEN_CONTROL_RELOAD")
        if reload is None:
            return default
        elif reload.lower() == "true":
            return True
        elif reload.lower() == "false":
            return False
        else:
            raise ValueError(f"LUTHIEN_CONTROL_RELOAD environment variable must be 'true' or 'false' (got {reload}).")

    # get_log_level is reused

    # --- Database DSN Helper Properties using Getters ---
    @property
    def admin_dsn(self) -> str:
        """DSN for connecting to the default 'postgres' db for admin tasks.
        Raises ValueError if required DB settings are missing.
        """
        user = self.get_postgres_user()
        password = self.get_postgres_password()
        host = self.get_postgres_host()
        port = self.get_postgres_port()

        if not all([user, password, host, port]):
            missing = [
                name
                for name, val in [("USER", user), ("PASSWORD", password), ("HOST", host), ("PORT", port)]
                if not val
            ]
            raise ValueError(f"Missing required database settings ({', '.join(missing)}) for admin_dsn")

        return f"postgresql://{user}:{password}@{host}:{port}/postgres"

    @property
    def base_dsn(self) -> str:
        """Base DSN without a specific database name.
        Raises ValueError if required DB settings are missing.
        """
        user = self.get_postgres_user()
        password = self.get_postgres_password()
        host = self.get_postgres_host()
        port = self.get_postgres_port()

        if not all([user, password, host, port]):
            missing = [
                name
                for name, val in [("USER", user), ("PASSWORD", password), ("HOST", host), ("PORT", port)]
                if not val
            ]
            raise ValueError(f"Missing required database settings ({', '.join(missing)}) for base_dsn")

        return f"postgresql://{user}:{password}@{host}:{port}"

    def get_db_dsn(self, db_name: str | None = None) -> str:
        """Returns the DSN for a specific database name, or the default DB_NAME.
        Raises ValueError if required DB settings or the target db_name are missing.
        """
        target_db = db_name or self.get_postgres_db()
        if not target_db:
            raise ValueError("Missing target database name (either provide db_name or set DB_NAME env var)")
        base = self.base_dsn  # Use property
        return f"{base}/{target_db}"

    def get_run_mode(self) -> str:
        """Returns the run mode, defaulting to 'prod' if not set."""
        return os.getenv("RUN_MODE", "prod")

    def dev_mode(self) -> bool:
        """Returns True if the run mode is 'dev', False otherwise."""
        return self.get_run_mode() == "dev"

Application configuration settings loaded from environment variables.

property

DSN for connecting to the default 'postgres' db for admin tasks. Raises ValueError if required DB settings are missing.

admin_dsn property

DSN for connecting to the default 'postgres' db for admin tasks. Raises ValueError if required DB settings are missing.

property

Base DSN without a specific database name. Raises ValueError if required DB settings are missing.

base_dsn property

Base DSN without a specific database name. Raises ValueError if required DB settings are missing.

dev_mode()

Source code in luthien_control/settings.py
177
178
179
def dev_mode(self) -> bool:
    """Report whether the application runs in development mode.

    Returns:
        True when ``get_run_mode()`` returns exactly ``"dev"``, else False.
    """
    run_mode = self.get_run_mode()
    return run_mode == "dev"

Returns True if the run mode is 'dev', False otherwise.

get_app_host(default='0.0.0.0')

Source code in luthien_control/settings.py
100
101
102
def get_app_host(self, default: str = "0.0.0.0") -> str:
    """Return the host the app should bind to.

    Args:
        default: Value returned when LUTHIEN_CONTROL_HOST is not set.

    Returns:
        The LUTHIEN_CONTROL_HOST environment variable, or ``default``.
    """
    host = os.environ.get("LUTHIEN_CONTROL_HOST", default)
    return host

Gets the configured app host, defaulting if not set.

get_app_port(default=8000)

Source code in luthien_control/settings.py
104
105
106
def get_app_port(self, default: int = 8000) -> int:
    """Gets the configured app port, defaulting if not set.

    Args:
        default: Port returned when LUTHIEN_CONTROL_PORT is not set.

    Returns:
        The port as an integer.

    Raises:
        ValueError: If LUTHIEN_CONTROL_PORT is set but is not an integer.
    """
    port_str = os.getenv("LUTHIEN_CONTROL_PORT")
    if port_str is None:
        return int(default)
    try:
        return int(port_str)
    except ValueError as err:
        # Name the offending variable instead of leaking the bare
        # "invalid literal for int()" error.
        raise ValueError("LUTHIEN_CONTROL_PORT environment variable must be an integer.") from err

Gets the configured app port, defaulting if not set.

get_app_reload(default=False)

Source code in luthien_control/settings.py
108
109
110
111
112
113
114
115
116
117
118
def get_app_reload(self, default: bool = False) -> bool:
    """Return whether auto-reload is enabled for the app.

    Args:
        default: Value returned when LUTHIEN_CONTROL_RELOAD is not set.

    Returns:
        True for 'true', False for 'false' (compared case-insensitively).

    Raises:
        ValueError: If the variable is set to any other value.
    """
    raw = os.getenv("LUTHIEN_CONTROL_RELOAD")
    if raw is None:
        return default

    recognised = {"true": True, "false": False}
    lowered = raw.lower()
    if lowered not in recognised:
        raise ValueError(f"LUTHIEN_CONTROL_RELOAD environment variable must be 'true' or 'false' (got {raw}).")
    return recognised[lowered]

Gets the configured app reload, defaulting if not set.

get_backend_url()

Source code in luthien_control/settings.py
30
31
32
33
34
35
36
37
38
def get_backend_url(self) -> Optional[str]:
    """Return the BACKEND_URL environment variable after a basic sanity check.

    Returns:
        The configured URL, or the unset/empty value unchanged.

    Raises:
        ValueError: If the value is non-empty but has no scheme or no
            network location.
    """
    configured = os.getenv("BACKEND_URL")
    if not configured:
        # Unset or empty: nothing to validate.
        return configured

    pieces = urlparse(configured)
    if not (pieces.scheme and pieces.netloc):
        raise ValueError(f"Invalid BACKEND_URL format: {configured}")
    return configured

Returns the backend URL as a string, if set.

get_database_url()

Source code in luthien_control/settings.py
40
41
42
def get_database_url(self) -> Optional[str]:
    """Return the DATABASE_URL environment variable, or None when unset."""
    database_url = os.environ.get("DATABASE_URL")
    return database_url

Returns the primary DATABASE_URL, if set.

get_db_dsn(db_name=None)

Source code in luthien_control/settings.py
163
164
165
166
167
168
169
170
171
def get_db_dsn(self, db_name: str | None = None) -> str:
    """Returns the DSN for a specific database name, or the default DB_NAME.
    Raises ValueError if required DB settings or the target db_name are missing.
    """
    target_db = db_name or self.get_postgres_db()
    if not target_db:
        raise ValueError("Missing target database name (either provide db_name or set DB_NAME env var)")
    base = self.base_dsn  # Use property
    return f"{base}/{target_db}"

Returns the DSN for a specific database name, or the default DB_NAME. Raises ValueError if required DB settings or the target db_name are missing.

get_log_level(default='INFO')

Source code in luthien_control/settings.py
95
96
97
def get_log_level(self, default: str = "INFO") -> str:
    """Return the log level in upper case.

    Args:
        default: Level used when LOG_LEVEL is not set.

    Returns:
        The LOG_LEVEL environment variable, or ``default``, upper-cased.
    """
    level = os.environ.get("LOG_LEVEL", default)
    return level.upper()

Gets the configured log level, defaulting if not set.

get_main_db_pool_max_size()

Source code in luthien_control/settings.py
87
88
89
90
91
92
def get_main_db_pool_max_size(self) -> int:
    """Returns the maximum pool size for the main DB.

    Returns:
        MAIN_DB_POOL_MAX_SIZE as an integer, or 10 when it is not set.

    Raises:
        ValueError: If the variable is set but is not an integer; the
            original conversion error is attached as the cause.
    """
    try:
        return int(os.getenv("MAIN_DB_POOL_MAX_SIZE", "10"))
    except ValueError as err:
        raise ValueError("MAIN_DB_POOL_MAX_SIZE environment variable must be an integer.") from err

Returns the maximum pool size for the main DB.

get_main_db_pool_min_size()

Source code in luthien_control/settings.py
80
81
82
83
84
85
def get_main_db_pool_min_size(self) -> int:
    """Returns the minimum pool size for the main DB.

    Reads the ``MAIN_DB_POOL_MIN_SIZE`` environment variable, falling back to
    ``1`` when it is not set.

    Raises:
        ValueError: If the variable holds something ``int()`` cannot parse.
            The underlying conversion error is attached as ``__cause__``.
    """
    raw_size = os.getenv("MAIN_DB_POOL_MIN_SIZE", "1")
    try:
        return int(raw_size)
    except ValueError as err:
        # Chain explicitly so the traceback reports the failed conversion as
        # the cause instead of "another exception occurred during handling".
        raise ValueError("MAIN_DB_POOL_MIN_SIZE environment variable must be an integer.") from err

Returns the minimum pool size for the main DB.

get_openai_api_key()

Source code in luthien_control/settings.py
44
45
46
def get_openai_api_key(self) -> str | None:
    """Returns the OpenAI API key, if set."""
    return os.getenv("OPENAI_API_KEY")

Returns the OpenAI API key, if set.

get_policy_filepath()

Source code in luthien_control/settings.py
52
53
54
def get_policy_filepath(self) -> str | None:
    """Returns the path to the policy file, if set."""
    return os.getenv("POLICY_FILEPATH")

Returns the path to the policy file, if set.

get_postgres_port()

Source code in luthien_control/settings.py
69
70
71
72
73
74
75
76
77
def get_postgres_port(self) -> int | None:
    """Returns the PostgreSQL port as an integer, or None if not set."""
    port_str = os.getenv("DB_PORT")
    if port_str is None:
        return None
    try:
        return int(port_str)
    except ValueError:
        raise ValueError("DB_PORT environment variable must be an integer.")

Returns the PostgreSQL port as an integer, or None if not set.

get_run_mode()

Source code in luthien_control/settings.py
173
174
175
def get_run_mode(self) -> str:
    """Look up the run mode.

    Returns:
        The value of the ``RUN_MODE`` environment variable, or ``"prod"``
        when that variable is not set.
    """
    return os.environ.get("RUN_MODE", "prod")

Returns the run mode, defaulting to 'prod' if not set.

get_top_level_policy_name()

Source code in luthien_control/settings.py
48
49
50
def get_top_level_policy_name(self) -> str:
    """Look up the name of the top-level policy instance to load.

    Returns:
        The value of the ``TOP_LEVEL_POLICY_NAME`` environment variable, or
        ``"root"`` when that variable is not set.
    """
    return os.environ.get("TOP_LEVEL_POLICY_NAME", "root")

Returns the name of the top-level policy instance to load.

utils

DeepEventedModel

Bases: EventedModel

Source code in luthien_control/utils/deep_evented_model.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class DeepEventedModel(EventedModel):
    """A Pydantic EventedModel that emits a single `changed` signal on any change.

    This includes changes to top-level fields as well as changes within
    nested evented containers (like EventedList, EventedDict) or other
    DeepEventedModel instances.

    Children are hooked up when the model is constructed and whenever a model
    field is reassigned. Items inserted into or removed from a tracked
    EventedList, and values added to or removed from a tracked EventedDict,
    are connected and disconnected as they come and go.

    NOTE(review): items *replaced* in place inside a container (for example
    ``lst[i] = x``) are not re-tracked by the handlers below -- confirm
    whether that case needs handling.

    Attributes:
        changed: A signal that is emitted with no arguments when any value
                 in the model or its nested evented children changes.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    changed: ClassVar[Signal] = Signal()

    def __init__(self, **kwargs: Any) -> None:
        """Build the model, then wire its own and its children's events to `changed`."""
        super().__init__(**kwargs)
        # Connect our master `changed` signal to the base model's event group.
        # This handles all top-level field assignments.
        self.events.connect(self.changed)
        # Connect to the event groups of any initial child objects.
        self._connect_children()

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute, moving child subscriptions from the old value to the new one.

        Only names listed in the class's `model_fields` trigger the
        disconnect/connect handling; every other attribute is set normally.
        """
        is_field = name in self.__class__.model_fields

        # Before the attribute is set, we must disconnect from the old child object.
        if is_field:
            # The field may not have a value yet, hence the `None` default.
            old_value = getattr(self, name, None)
            self._disconnect_child(old_value)

        super().__setattr__(name, value)

        # After the attribute is set, we connect to the new child object.
        if is_field:
            # Re-read the attribute rather than using `value` directly, so we
            # connect to whatever object was actually stored.
            new_value = getattr(self, name)
            self._connect_child(new_value)
            # The base EventedModel handles emitting the field-specific signal,
            # which is already piped to our `changed` signal.

    def _connect_child(self, child: Any) -> None:
        """If `child` is an evented object, connect its events to our signal.

        Containers are connected recursively: every current item of an
        EventedList / value of an EventedDict is connected as well, and the
        container's membership events are subscribed so later additions and
        removals are tracked. Non-evented values are ignored.
        """
        if isinstance(child, DeepEventedModel):
            child.changed.connect(self.changed)
        elif isinstance(child, EventedList):
            child.events.connect(self.changed)
            child.events.inserted.connect(self._on_item_inserted)
            child.events.removed.connect(self._on_item_removed)
            for item in child:
                self._connect_child(item)
        elif isinstance(child, EventedDict):
            child.events.connect(self.changed)
            child.events.added.connect(self._on_item_added)
            # Mirror the list branch: without this, a value removed from the
            # dict would stay connected to `changed` indefinitely.
            # NOTE(review): relies on the dict's event group exposing
            # `removed(key, value)` next to `added` (as psygnal's EventedDict
            # does) -- confirm against the container implementation in use.
            child.events.removed.connect(self._on_dict_item_removed)
            for item in child.values():
                self._connect_child(item)
        elif self._is_evented(child):
            child.events.connect(self.changed)

    def _disconnect_child(self, child: Any) -> None:
        """If `child` is an evented object, disconnect its events.

        Exact inverse of `_connect_child`, including the recursive walk over
        container contents.
        """
        if isinstance(child, DeepEventedModel):
            child.changed.disconnect(self.changed)
        elif isinstance(child, EventedList):
            child.events.disconnect(self.changed)
            child.events.inserted.disconnect(self._on_item_inserted)
            child.events.removed.disconnect(self._on_item_removed)
            for item in child:
                self._disconnect_child(item)
        elif isinstance(child, EventedDict):
            child.events.disconnect(self.changed)
            child.events.added.disconnect(self._on_item_added)
            child.events.removed.disconnect(self._on_dict_item_removed)
            for item in child.values():
                self._disconnect_child(item)
        elif self._is_evented(child):
            child.events.disconnect(self.changed)

    def _on_item_inserted(self, index: int, value: Any) -> None:
        """Start tracking an item inserted into a tracked EventedList."""
        self._connect_child(value)

    def _on_item_removed(self, index: int, value: Any) -> None:
        """Stop tracking an item removed from a tracked EventedList."""
        self._disconnect_child(value)

    def _on_item_added(self, key: str, value: Any) -> None:
        """Start tracking a value added to a tracked EventedDict."""
        self._connect_child(value)

    def _on_dict_item_removed(self, key: str, value: Any) -> None:
        """Stop tracking a value removed from a tracked EventedDict."""
        self._disconnect_child(value)

    def _connect_children(self) -> None:
        """Connect to the events of all evented children in the model."""
        for name in self.__class__.model_fields:
            child = getattr(self, name)
            self._connect_child(child)

    def _is_evented(self, obj: Any) -> bool:
        """Check if an object has a connectable `events` signal group."""
        events = getattr(obj, "events", None)
        return events is not None and callable(getattr(events, "connect", None))

    @model_serializer(mode="wrap")
    def _serialize_model(self, serializer, info):
        """Custom model serializer that converts EventedList and EventedDict to regular containers.

        When no field currently holds an EventedList or EventedDict, the
        wrapped default `serializer` is used unchanged. Otherwise a plain dict
        is built field by field: evented containers are shallow-copied into
        `list` / `dict`, and every other field value is passed through as-is.
        """
        field_names = self.__class__.model_fields

        # First, check if this model actually has any EventedList or EventedDict fields
        has_evented_containers = any(
            isinstance(getattr(self, field_name), (EventedList, EventedDict)) for field_name in field_names
        )

        # If no evented containers, use default serialization
        if not has_evented_containers:
            return serializer(self)

        # Otherwise, handle evented containers specially
        data = {}
        for field_name in field_names:
            value = getattr(self, field_name)
            if isinstance(value, EventedList):
                data[field_name] = list(value)
            elif isinstance(value, EventedDict):
                data[field_name] = dict(value)
            else:
                # Other values are stored unchanged (the wrapped serializer is
                # not invoked on this path).
                data[field_name] = value
        return data

A Pydantic EventedModel that emits a single changed signal on any change.

This includes changes to top-level fields as well as changes within nested evented containers (like EventedList, EventedDict) or other DeepEventedModel instances.

Attributes:

Name Type Description
changed Signal

A signal that is emitted with no arguments when any value in the model or its nested evented children changes.

backend_call_spec

BackendCallSpec

Bases: BaseModel

Source code in luthien_control/utils/backend_call_spec.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class BackendCallSpec(BaseModel):
    """
    A specification for a backend LLM call.

    Every field has a default, so the spec can be built with no arguments.
    """

    # Model identifier sent to the backend.
    model: str = "gpt-4o-mini"
    # Base URL of the backend API.
    api_endpoint: str = "https://api.openai.com/v1"
    # Presumably the name of the environment variable that holds the API key
    # (not the key itself) -- verify against the code that consumes this spec.
    api_key_env_var: str = "OPENAI_API_KEY"
    # A fresh dict is created per instance via `default_factory`.
    request_args: dict[str, Any] = Field(
        description="Arguments to be passed to OpenAIChatCompletionsRequest.",
        default_factory=dict,
    )

A specification for a backend LLM call.

deep_evented_model

DeepEventedModel

Bases: EventedModel

Source code in luthien_control/utils/deep_evented_model.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class DeepEventedModel(EventedModel):
    """A Pydantic EventedModel that emits a single `changed` signal on any change.

    This includes changes to top-level fields as well as changes within
    nested evented containers (like EventedList, EventedDict) or other
    DeepEventedModel instances.

    Children are hooked up when the model is constructed and whenever a model
    field is reassigned. Items inserted into or removed from a tracked
    EventedList, and values added to or removed from a tracked EventedDict,
    are connected and disconnected as they come and go.

    NOTE(review): items *replaced* in place inside a container (for example
    ``lst[i] = x``) are not re-tracked by the handlers below -- confirm
    whether that case needs handling.

    Attributes:
        changed: A signal that is emitted with no arguments when any value
                 in the model or its nested evented children changes.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    changed: ClassVar[Signal] = Signal()

    def __init__(self, **kwargs: Any) -> None:
        """Build the model, then wire its own and its children's events to `changed`."""
        super().__init__(**kwargs)
        # Connect our master `changed` signal to the base model's event group.
        # This handles all top-level field assignments.
        self.events.connect(self.changed)
        # Connect to the event groups of any initial child objects.
        self._connect_children()

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute, moving child subscriptions from the old value to the new one.

        Only names listed in the class's `model_fields` trigger the
        disconnect/connect handling; every other attribute is set normally.
        """
        is_field = name in self.__class__.model_fields

        # Before the attribute is set, we must disconnect from the old child object.
        if is_field:
            # The field may not have a value yet, hence the `None` default.
            old_value = getattr(self, name, None)
            self._disconnect_child(old_value)

        super().__setattr__(name, value)

        # After the attribute is set, we connect to the new child object.
        if is_field:
            # Re-read the attribute rather than using `value` directly, so we
            # connect to whatever object was actually stored.
            new_value = getattr(self, name)
            self._connect_child(new_value)
            # The base EventedModel handles emitting the field-specific signal,
            # which is already piped to our `changed` signal.

    def _connect_child(self, child: Any) -> None:
        """If `child` is an evented object, connect its events to our signal.

        Containers are connected recursively: every current item of an
        EventedList / value of an EventedDict is connected as well, and the
        container's membership events are subscribed so later additions and
        removals are tracked. Non-evented values are ignored.
        """
        if isinstance(child, DeepEventedModel):
            child.changed.connect(self.changed)
        elif isinstance(child, EventedList):
            child.events.connect(self.changed)
            child.events.inserted.connect(self._on_item_inserted)
            child.events.removed.connect(self._on_item_removed)
            for item in child:
                self._connect_child(item)
        elif isinstance(child, EventedDict):
            child.events.connect(self.changed)
            child.events.added.connect(self._on_item_added)
            # Mirror the list branch: without this, a value removed from the
            # dict would stay connected to `changed` indefinitely.
            # NOTE(review): relies on the dict's event group exposing
            # `removed(key, value)` next to `added` (as psygnal's EventedDict
            # does) -- confirm against the container implementation in use.
            child.events.removed.connect(self._on_dict_item_removed)
            for item in child.values():
                self._connect_child(item)
        elif self._is_evented(child):
            child.events.connect(self.changed)

    def _disconnect_child(self, child: Any) -> None:
        """If `child` is an evented object, disconnect its events.

        Exact inverse of `_connect_child`, including the recursive walk over
        container contents.
        """
        if isinstance(child, DeepEventedModel):
            child.changed.disconnect(self.changed)
        elif isinstance(child, EventedList):
            child.events.disconnect(self.changed)
            child.events.inserted.disconnect(self._on_item_inserted)
            child.events.removed.disconnect(self._on_item_removed)
            for item in child:
                self._disconnect_child(item)
        elif isinstance(child, EventedDict):
            child.events.disconnect(self.changed)
            child.events.added.disconnect(self._on_item_added)
            child.events.removed.disconnect(self._on_dict_item_removed)
            for item in child.values():
                self._disconnect_child(item)
        elif self._is_evented(child):
            child.events.disconnect(self.changed)

    def _on_item_inserted(self, index: int, value: Any) -> None:
        """Start tracking an item inserted into a tracked EventedList."""
        self._connect_child(value)

    def _on_item_removed(self, index: int, value: Any) -> None:
        """Stop tracking an item removed from a tracked EventedList."""
        self._disconnect_child(value)

    def _on_item_added(self, key: str, value: Any) -> None:
        """Start tracking a value added to a tracked EventedDict."""
        self._connect_child(value)

    def _on_dict_item_removed(self, key: str, value: Any) -> None:
        """Stop tracking a value removed from a tracked EventedDict."""
        self._disconnect_child(value)

    def _connect_children(self) -> None:
        """Connect to the events of all evented children in the model."""
        for name in self.__class__.model_fields:
            child = getattr(self, name)
            self._connect_child(child)

    def _is_evented(self, obj: Any) -> bool:
        """Check if an object has a connectable `events` signal group."""
        events = getattr(obj, "events", None)
        return events is not None and callable(getattr(events, "connect", None))

    @model_serializer(mode="wrap")
    def _serialize_model(self, serializer, info):
        """Custom model serializer that converts EventedList and EventedDict to regular containers.

        When no field currently holds an EventedList or EventedDict, the
        wrapped default `serializer` is used unchanged. Otherwise a plain dict
        is built field by field: evented containers are shallow-copied into
        `list` / `dict`, and every other field value is passed through as-is.
        """
        field_names = self.__class__.model_fields

        # First, check if this model actually has any EventedList or EventedDict fields
        has_evented_containers = any(
            isinstance(getattr(self, field_name), (EventedList, EventedDict)) for field_name in field_names
        )

        # If no evented containers, use default serialization
        if not has_evented_containers:
            return serializer(self)

        # Otherwise, handle evented containers specially
        data = {}
        for field_name in field_names:
            value = getattr(self, field_name)
            if isinstance(value, EventedList):
                data[field_name] = list(value)
            elif isinstance(value, EventedDict):
                data[field_name] = dict(value)
            else:
                # Other values are stored unchanged (the wrapped serializer is
                # not invoked on this path).
                data[field_name] = value
        return data

A Pydantic EventedModel that emits a single changed signal on any change.

This includes changes to top-level fields as well as changes within nested evented containers (like EventedList, EventedDict) or other DeepEventedModel instances.

Attributes:

Name Type Description
changed Signal

A signal that is emitted with no arguments when any value in the model or its nested evented children changes.