feat(core): 初始化 Key-IP Sentinel 服务与部署骨架

- 搭建 FastAPI、Redis、PostgreSQL、Nginx 与 Docker Compose 基础结构
- 实现反向代理、首用绑定、拦截告警、归档任务和管理接口
- 提供 Vue3 管理后台初版,以及 uv/requirements 双依赖配置
This commit is contained in:
2026-03-04 00:18:33 +08:00
commit ab1bd90c65
50 changed files with 5645 additions and 0 deletions

49
app/api/auth.py Normal file
View File

@@ -0,0 +1,49 @@
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Request, status
from redis.asyncio import Redis
from app.config import Settings
from app.core.ip_utils import extract_client_ip
from app.core.security import (
clear_login_failures,
create_admin_jwt,
ensure_login_allowed,
register_login_failure,
verify_admin_password,
)
from app.dependencies import get_redis, get_settings
from app.schemas.auth import LoginRequest, TokenResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/api", tags=["auth"])
@router.post("/login", response_model=TokenResponse)
async def login(
payload: LoginRequest,
request: Request,
settings: Settings = Depends(get_settings),
redis: Redis | None = Depends(get_redis),
) -> TokenResponse:
if redis is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Login service is unavailable because Redis is offline.",
)
client_ip = extract_client_ip(request, settings)
await ensure_login_allowed(redis, client_ip, settings)
if not verify_admin_password(payload.password, settings):
await register_login_failure(redis, client_ip, settings)
logger.warning("Admin login failed.", extra={"client_ip": client_ip})
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid admin password.")
await clear_login_failures(redis, client_ip)
token, expires_in = create_admin_jwt(settings)
logger.info("Admin login succeeded.", extra={"client_ip": client_ip})
return TokenResponse(access_token=token, expires_in=expires_in)

153
app/api/bindings.py Normal file
View File

@@ -0,0 +1,153 @@
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import String, cast, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import Settings
from app.core.ip_utils import extract_client_ip
from app.dependencies import get_binding_service, get_db_session, get_settings, require_admin
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
from app.schemas.binding import (
BindingActionRequest,
BindingIPUpdateRequest,
BindingItem,
BindingListResponse,
)
from app.services.binding_service import BindingService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/api/bindings", tags=["bindings"], dependencies=[Depends(require_admin)])
def to_binding_item(binding: TokenBinding, binding_service: BindingService) -> BindingItem:
return BindingItem(
id=binding.id,
token_display=binding.token_display,
bound_ip=str(binding.bound_ip),
status=binding.status,
status_label=binding_service.status_label(binding.status),
first_used_at=binding.first_used_at,
last_used_at=binding.last_used_at,
created_at=binding.created_at,
)
async def get_binding_or_404(session: AsyncSession, binding_id: int) -> TokenBinding:
binding = await session.get(TokenBinding, binding_id)
if binding is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Binding was not found.")
return binding
def log_admin_action(request: Request, settings: Settings, action: str, binding_id: int) -> None:
logger.info(
"Admin binding action.",
extra={
"client_ip": extract_client_ip(request, settings),
"action": action,
"binding_id": binding_id,
},
)
@router.get("", response_model=BindingListResponse)
async def list_bindings(
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=200),
token_suffix: str | None = Query(default=None),
ip: str | None = Query(default=None),
status_filter: int | None = Query(default=None, alias="status"),
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
) -> BindingListResponse:
statement = select(TokenBinding)
if token_suffix:
statement = statement.where(TokenBinding.token_display.ilike(f"%{token_suffix}%"))
if ip:
statement = statement.where(cast(TokenBinding.bound_ip, String).ilike(f"%{ip}%"))
if status_filter in {STATUS_ACTIVE, STATUS_BANNED}:
statement = statement.where(TokenBinding.status == status_filter)
total_result = await session.execute(select(func.count()).select_from(statement.subquery()))
total = int(total_result.scalar_one())
bindings = (
await session.scalars(
statement.order_by(TokenBinding.last_used_at.desc()).offset((page - 1) * page_size).limit(page_size)
)
).all()
return BindingListResponse(
items=[to_binding_item(item, binding_service) for item in bindings],
total=total,
page=page,
page_size=page_size,
)
@router.post("/unbind")
async def unbind_token(
payload: BindingActionRequest,
request: Request,
settings: Settings = Depends(get_settings),
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
):
binding = await get_binding_or_404(session, payload.id)
token_hash = binding.token_hash
await session.delete(binding)
await session.commit()
await binding_service.invalidate_binding_cache(token_hash)
log_admin_action(request, settings, "unbind", payload.id)
return {"success": True}
@router.put("/ip")
async def update_bound_ip(
payload: BindingIPUpdateRequest,
request: Request,
settings: Settings = Depends(get_settings),
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
):
binding = await get_binding_or_404(session, payload.id)
binding.bound_ip = payload.bound_ip
await session.commit()
await binding_service.sync_binding_cache(binding.token_hash, str(binding.bound_ip), binding.status)
log_admin_action(request, settings, "update_ip", payload.id)
return {"success": True}
@router.post("/ban")
async def ban_token(
payload: BindingActionRequest,
request: Request,
settings: Settings = Depends(get_settings),
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
):
binding = await get_binding_or_404(session, payload.id)
binding.status = STATUS_BANNED
await session.commit()
await binding_service.sync_binding_cache(binding.token_hash, str(binding.bound_ip), binding.status)
log_admin_action(request, settings, "ban", payload.id)
return {"success": True}
@router.post("/unban")
async def unban_token(
payload: BindingActionRequest,
request: Request,
settings: Settings = Depends(get_settings),
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
):
binding = await get_binding_or_404(session, payload.id)
binding.status = STATUS_ACTIVE
await session.commit()
await binding_service.sync_binding_cache(binding.token_hash, str(binding.bound_ip), binding.status)
log_admin_action(request, settings, "unban", payload.id)
return {"success": True}

109
app/api/dashboard.py Normal file
View File

@@ -0,0 +1,109 @@
from __future__ import annotations
from datetime import UTC, datetime, time, timedelta
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_binding_service, get_db_session, require_admin
from app.models.intercept_log import InterceptLog
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
from app.schemas.log import InterceptLogItem
from app.services.binding_service import BindingService
router = APIRouter(prefix="/admin/api", tags=["dashboard"], dependencies=[Depends(require_admin)])
class MetricSummary(BaseModel):
total: int
allowed: int
intercepted: int
class BindingSummary(BaseModel):
active: int
banned: int
class TrendPoint(BaseModel):
date: str
total: int
allowed: int
intercepted: int
class DashboardResponse(BaseModel):
today: MetricSummary
bindings: BindingSummary
trend: list[TrendPoint]
recent_intercepts: list[InterceptLogItem]
async def build_trend(
session: AsyncSession,
binding_service: BindingService,
) -> list[TrendPoint]:
series = await binding_service.get_metrics_window(days=7)
start_day = datetime.combine(datetime.now(UTC).date() - timedelta(days=6), time.min, tzinfo=UTC)
intercept_counts_result = await session.execute(
select(func.date(InterceptLog.intercepted_at), func.count())
.where(InterceptLog.intercepted_at >= start_day)
.group_by(func.date(InterceptLog.intercepted_at))
)
db_intercept_counts = {
row[0].isoformat(): int(row[1])
for row in intercept_counts_result.all()
}
trend: list[TrendPoint] = []
for item in series:
day = str(item["date"])
allowed = int(item["allowed"])
intercepted = max(int(item["intercepted"]), db_intercept_counts.get(day, 0))
total = max(int(item["total"]), allowed + intercepted)
trend.append(TrendPoint(date=day, total=total, allowed=allowed, intercepted=intercepted))
return trend
async def build_recent_intercepts(session: AsyncSession) -> list[InterceptLogItem]:
recent_logs = (
await session.scalars(select(InterceptLog).order_by(InterceptLog.intercepted_at.desc()).limit(10))
).all()
return [
InterceptLogItem(
id=item.id,
token_display=item.token_display,
bound_ip=str(item.bound_ip),
attempt_ip=str(item.attempt_ip),
alerted=item.alerted,
intercepted_at=item.intercepted_at,
)
for item in recent_logs
]
@router.get("/dashboard", response_model=DashboardResponse)
async def get_dashboard(
session: AsyncSession = Depends(get_db_session),
binding_service: BindingService = Depends(get_binding_service),
) -> DashboardResponse:
trend = await build_trend(session, binding_service)
active_count = await session.scalar(
select(func.count()).select_from(TokenBinding).where(TokenBinding.status == STATUS_ACTIVE)
)
banned_count = await session.scalar(
select(func.count()).select_from(TokenBinding).where(TokenBinding.status == STATUS_BANNED)
)
recent_intercepts = await build_recent_intercepts(session)
today = trend[-1] if trend else TrendPoint(date=datetime.now(UTC).date().isoformat(), total=0, allowed=0, intercepted=0)
return DashboardResponse(
today=MetricSummary(total=today.total, allowed=today.allowed, intercepted=today.intercepted),
bindings=BindingSummary(active=int(active_count or 0), banned=int(banned_count or 0)),
trend=trend,
recent_intercepts=recent_intercepts,
)

107
app/api/logs.py Normal file
View File

@@ -0,0 +1,107 @@
from __future__ import annotations
import csv
import io
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import String, cast, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.dependencies import get_db_session, require_admin
from app.models.intercept_log import InterceptLog
from app.schemas.log import InterceptLogItem, LogListResponse
router = APIRouter(prefix="/admin/api/logs", tags=["logs"], dependencies=[Depends(require_admin)])
def apply_log_filters(
statement,
token: str | None,
attempt_ip: str | None,
start_time: datetime | None,
end_time: datetime | None,
):
if token:
statement = statement.where(InterceptLog.token_display.ilike(f"%{token}%"))
if attempt_ip:
statement = statement.where(cast(InterceptLog.attempt_ip, String).ilike(f"%{attempt_ip}%"))
if start_time:
statement = statement.where(InterceptLog.intercepted_at >= start_time)
if end_time:
statement = statement.where(InterceptLog.intercepted_at <= end_time)
return statement
@router.get("", response_model=LogListResponse)
async def list_logs(
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=200),
token: str | None = Query(default=None),
attempt_ip: str | None = Query(default=None),
start_time: datetime | None = Query(default=None),
end_time: datetime | None = Query(default=None),
session: AsyncSession = Depends(get_db_session),
) -> LogListResponse:
statement = apply_log_filters(select(InterceptLog), token, attempt_ip, start_time, end_time)
total_result = await session.execute(select(func.count()).select_from(statement.subquery()))
total = int(total_result.scalar_one())
logs = (
await session.scalars(
statement.order_by(InterceptLog.intercepted_at.desc()).offset((page - 1) * page_size).limit(page_size)
)
).all()
return LogListResponse(
items=[
InterceptLogItem(
id=item.id,
token_display=item.token_display,
bound_ip=str(item.bound_ip),
attempt_ip=str(item.attempt_ip),
alerted=item.alerted,
intercepted_at=item.intercepted_at,
)
for item in logs
],
total=total,
page=page,
page_size=page_size,
)
@router.get("/export")
async def export_logs(
token: str | None = Query(default=None),
attempt_ip: str | None = Query(default=None),
start_time: datetime | None = Query(default=None),
end_time: datetime | None = Query(default=None),
session: AsyncSession = Depends(get_db_session),
):
statement = apply_log_filters(select(InterceptLog), token, attempt_ip, start_time, end_time).order_by(
InterceptLog.intercepted_at.desc()
)
logs = (await session.scalars(statement)).all()
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["id", "token_display", "bound_ip", "attempt_ip", "alerted", "intercepted_at"])
for item in logs:
writer.writerow(
[
item.id,
item.token_display,
str(item.bound_ip),
str(item.attempt_ip),
item.alerted,
item.intercepted_at.isoformat(),
]
)
filename = f"sentinel-logs-{datetime.utcnow().strftime('%Y%m%d%H%M%S')}.csv"
return StreamingResponse(
iter([buffer.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)

77
app/api/settings.py Normal file
View File

@@ -0,0 +1,77 @@
from __future__ import annotations
import logging
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings
from app.core.ip_utils import extract_client_ip
from app.dependencies import get_redis, get_runtime_settings, get_settings, require_admin
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/api/settings", tags=["settings"], dependencies=[Depends(require_admin)])
class SettingsResponse(BaseModel):
alert_webhook_url: str | None = None
alert_threshold_count: int = Field(ge=1)
alert_threshold_seconds: int = Field(ge=1)
archive_days: int = Field(ge=1)
failsafe_mode: Literal["open", "closed"]
class SettingsUpdateRequest(SettingsResponse):
pass
def serialize_runtime_settings(runtime_settings: RuntimeSettings) -> dict[str, str]:
return {
"alert_webhook_url": runtime_settings.alert_webhook_url or "",
"alert_threshold_count": str(runtime_settings.alert_threshold_count),
"alert_threshold_seconds": str(runtime_settings.alert_threshold_seconds),
"archive_days": str(runtime_settings.archive_days),
"failsafe_mode": runtime_settings.failsafe_mode,
}
@router.get("", response_model=SettingsResponse)
async def get_runtime_config(
runtime_settings: RuntimeSettings = Depends(get_runtime_settings),
) -> SettingsResponse:
return SettingsResponse(**runtime_settings.model_dump())
@router.put("", response_model=SettingsResponse)
async def update_runtime_config(
payload: SettingsUpdateRequest,
request: Request,
settings: Settings = Depends(get_settings),
redis: Redis | None = Depends(get_redis),
):
if redis is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Settings persistence is unavailable because Redis is offline.",
)
updated = RuntimeSettings(**payload.model_dump())
try:
await redis.hset(RUNTIME_SETTINGS_REDIS_KEY, mapping=serialize_runtime_settings(updated))
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Failed to persist runtime settings.",
) from exc
async with request.app.state.runtime_settings_lock:
request.app.state.runtime_settings = updated
logger.info(
"Runtime settings updated.",
extra={"client_ip": extract_client_ip(request, settings)},
)
return SettingsResponse(**updated.model_dump())

98
app/config.py Normal file
View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from functools import cached_property
from ipaddress import ip_network
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
RUNTIME_SETTINGS_REDIS_KEY = "sentinel:settings"
class RuntimeSettings(BaseModel):
alert_webhook_url: str | None = None
alert_threshold_count: int = Field(default=5, ge=1)
alert_threshold_seconds: int = Field(default=300, ge=1)
archive_days: int = Field(default=90, ge=1)
failsafe_mode: Literal["open", "closed"] = "closed"
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
case_sensitive=False,
)
downstream_url: str = Field(alias="DOWNSTREAM_URL")
redis_addr: str = Field(alias="REDIS_ADDR")
redis_password: str = Field(default="", alias="REDIS_PASSWORD")
pg_dsn: str = Field(alias="PG_DSN")
sentinel_hmac_secret: str = Field(alias="SENTINEL_HMAC_SECRET", min_length=32)
admin_password: str = Field(alias="ADMIN_PASSWORD", min_length=8)
admin_jwt_secret: str = Field(alias="ADMIN_JWT_SECRET", min_length=16)
trusted_proxy_ips: tuple[str, ...] = Field(default_factory=tuple, alias="TRUSTED_PROXY_IPS")
sentinel_failsafe_mode: Literal["open", "closed"] = Field(
default="closed",
alias="SENTINEL_FAILSAFE_MODE",
)
app_port: int = Field(default=7000, alias="APP_PORT")
alert_webhook_url: str | None = Field(default=None, alias="ALERT_WEBHOOK_URL")
alert_threshold_count: int = Field(default=5, alias="ALERT_THRESHOLD_COUNT", ge=1)
alert_threshold_seconds: int = Field(default=300, alias="ALERT_THRESHOLD_SECONDS", ge=1)
archive_days: int = Field(default=90, alias="ARCHIVE_DAYS", ge=1)
redis_binding_ttl_seconds: int = 604800
downstream_max_connections: int = 512
downstream_max_keepalive_connections: int = 128
last_used_flush_interval_seconds: int = 5
last_used_queue_size: int = 10000
login_lockout_threshold: int = 5
login_lockout_seconds: int = 900
admin_jwt_expire_hours: int = 8
archive_job_interval_minutes: int = 60
archive_batch_size: int = 500
metrics_ttl_days: int = 30
webhook_timeout_seconds: int = 5
@field_validator("downstream_url")
@classmethod
def normalize_downstream_url(cls, value: str) -> str:
return value.rstrip("/")
@field_validator("trusted_proxy_ips", mode="before")
@classmethod
def split_proxy_ips(cls, value: object) -> tuple[str, ...]:
if value is None:
return tuple()
if isinstance(value, str):
parts = [item.strip() for item in value.split(",")]
return tuple(item for item in parts if item)
if isinstance(value, (list, tuple, set)):
return tuple(str(item).strip() for item in value if str(item).strip())
return (str(value).strip(),)
@cached_property
def trusted_proxy_networks(self):
return tuple(ip_network(item, strict=False) for item in self.trusted_proxy_ips)
def build_runtime_settings(self) -> RuntimeSettings:
return RuntimeSettings(
alert_webhook_url=self.alert_webhook_url or None,
alert_threshold_count=self.alert_threshold_count,
alert_threshold_seconds=self.alert_threshold_seconds,
archive_days=self.archive_days,
failsafe_mode=self.sentinel_failsafe_mode,
)
_settings: Settings | None = None
def get_settings() -> Settings:
global _settings
if _settings is None:
_settings = Settings()
return _settings

35
app/core/ip_utils.py Normal file
View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from ipaddress import ip_address, ip_network
from fastapi import Request
from app.config import Settings
def is_ip_in_network(candidate_ip: str, network_value: str) -> bool:
return ip_address(candidate_ip) in ip_network(network_value, strict=False)
def is_trusted_proxy(source_ip: str, settings: Settings) -> bool:
try:
parsed_ip = ip_address(source_ip)
except ValueError:
return False
return any(parsed_ip in network for network in settings.trusted_proxy_networks)
def extract_client_ip(request: Request, settings: Settings) -> str:
client_host = request.client.host if request.client else "127.0.0.1"
if not is_trusted_proxy(client_host, settings):
return client_host
real_ip = request.headers.get("x-real-ip")
if not real_ip:
return client_host
try:
ip_address(real_ip)
except ValueError:
return client_host
return real_ip

104
app/core/security.py Normal file
View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import hashlib
import hmac
from datetime import UTC, datetime, timedelta
from fastapi import HTTPException, status
from jose import JWTError, jwt
from redis.asyncio import Redis
from app.config import Settings
ALGORITHM = "HS256"
def mask_token(token: str) -> str:
if not token:
return "unknown"
if len(token) <= 8:
return f"{token[:2]}...{token[-2:]}"
return f"{token[:4]}...{token[-4:]}"[:20]
def hash_token(token: str, secret: str) -> str:
return hmac.new(secret.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()
def extract_bearer_token(authorization: str | None) -> str | None:
if not authorization:
return None
scheme, _, token = authorization.partition(" ")
if scheme.lower() != "bearer" or not token:
return None
return token.strip()
def verify_admin_password(password: str, settings: Settings) -> bool:
return hmac.compare_digest(password, settings.admin_password)
def create_admin_jwt(settings: Settings) -> tuple[str, int]:
expires_in = settings.admin_jwt_expire_hours * 3600
now = datetime.now(UTC)
payload = {
"sub": "admin",
"iat": int(now.timestamp()),
"exp": int((now + timedelta(seconds=expires_in)).timestamp()),
}
token = jwt.encode(payload, settings.admin_jwt_secret, algorithm=ALGORITHM)
return token, expires_in
def decode_admin_jwt(token: str, settings: Settings) -> dict:
try:
payload = jwt.decode(token, settings.admin_jwt_secret, algorithms=[ALGORITHM])
except JWTError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired admin token.",
) from exc
if payload.get("sub") != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid admin token subject.",
)
return payload
def login_failure_key(client_ip: str) -> str:
return f"sentinel:login:fail:{client_ip}"
async def ensure_login_allowed(redis: Redis, client_ip: str, settings: Settings) -> None:
try:
current = await redis.get(login_failure_key(client_ip))
if current is None:
return
if int(current) >= settings.login_lockout_threshold:
ttl = await redis.ttl(login_failure_key(client_ip))
retry_after = max(ttl, 0)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Too many failed login attempts. Retry after {retry_after} seconds.",
)
except HTTPException:
raise
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Login lock service is unavailable.",
) from exc
async def register_login_failure(redis: Redis, client_ip: str, settings: Settings) -> None:
key = login_failure_key(client_ip)
async with redis.pipeline(transaction=True) as pipeline:
pipeline.incr(key)
pipeline.expire(key, settings.login_lockout_seconds)
await pipeline.execute()
async def clear_login_failures(redis: Redis, client_ip: str) -> None:
await redis.delete(login_failure_key(client_ip))

53
app/dependencies.py Normal file
View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from fastapi import Depends, HTTPException, Request, status
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import RuntimeSettings, Settings
from app.core.security import decode_admin_jwt, extract_bearer_token
from app.services.alert_service import AlertService
from app.services.archive_service import ArchiveService
from app.services.binding_service import BindingService
def get_settings(request: Request) -> Settings:
return request.app.state.settings
def get_redis(request: Request) -> Redis | None:
return request.app.state.redis
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
session_factory = request.app.state.session_factory
async with session_factory() as session:
yield session
def get_binding_service(request: Request) -> BindingService:
return request.app.state.binding_service
def get_alert_service(request: Request) -> AlertService:
return request.app.state.alert_service
def get_archive_service(request: Request) -> ArchiveService:
return request.app.state.archive_service
def get_runtime_settings(request: Request) -> RuntimeSettings:
return request.app.state.runtime_settings
async def require_admin(request: Request, settings: Settings = Depends(get_settings)) -> dict:
token = extract_bearer_token(request.headers.get("authorization"))
if token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing admin bearer token.",
)
return decode_admin_jwt(token, settings)

193
app/main.py Normal file
View File

@@ -0,0 +1,193 @@
from __future__ import annotations
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from datetime import UTC, datetime
import httpx
from fastapi import FastAPI
from redis.asyncio import Redis
from redis.asyncio import from_url as redis_from_url
from app.api import auth, bindings, dashboard, logs, settings as settings_api
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
from app.models import intercept_log, token_binding # noqa: F401
from app.models.db import close_db, get_session_factory, init_db
from app.proxy.handler import router as proxy_router
from app.services.alert_service import AlertService
from app.services.archive_service import ArchiveService
from app.services.binding_service import BindingService
class JsonFormatter(logging.Formatter):
reserved = {
"args",
"asctime",
"created",
"exc_info",
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
def format(self, record: logging.LogRecord) -> str:
payload = {
"timestamp": datetime.now(UTC).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
for key, value in record.__dict__.items():
if key in self.reserved or key.startswith("_"):
continue
payload[key] = value
if record.exc_info:
payload["exception"] = self.formatException(record.exc_info)
return json.dumps(payload, default=str)
def configure_logging() -> None:
root_logger = logging.getLogger()
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
configure_logging()
logger = logging.getLogger(__name__)
async def load_runtime_settings(redis: Redis | None, settings: Settings) -> RuntimeSettings:
runtime_settings = settings.build_runtime_settings()
if redis is None:
return runtime_settings
try:
raw = await redis.hgetall(RUNTIME_SETTINGS_REDIS_KEY)
except Exception:
logger.warning("Failed to load runtime settings from Redis; using environment defaults.")
return runtime_settings
if not raw:
return runtime_settings
return RuntimeSettings(
alert_webhook_url=raw.get("alert_webhook_url") or None,
alert_threshold_count=int(raw.get("alert_threshold_count", runtime_settings.alert_threshold_count)),
alert_threshold_seconds=int(raw.get("alert_threshold_seconds", runtime_settings.alert_threshold_seconds)),
archive_days=int(raw.get("archive_days", runtime_settings.archive_days)),
failsafe_mode=raw.get("failsafe_mode", runtime_settings.failsafe_mode),
)
@asynccontextmanager
async def lifespan(app: FastAPI):
settings = get_settings()
init_db(settings)
session_factory = get_session_factory()
redis: Redis | None = redis_from_url(
settings.redis_addr,
password=settings.redis_password or None,
encoding="utf-8",
decode_responses=True,
)
try:
await redis.ping()
except Exception:
logger.warning("Redis is unavailable at startup; continuing in degraded mode.")
try:
await redis.aclose()
except Exception:
pass
redis = None
downstream_client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0),
limits=httpx.Limits(
max_connections=settings.downstream_max_connections,
max_keepalive_connections=settings.downstream_max_keepalive_connections,
),
follow_redirects=False,
)
webhook_client = httpx.AsyncClient(timeout=httpx.Timeout(settings.webhook_timeout_seconds))
runtime_settings = await load_runtime_settings(redis, settings)
app.state.settings = settings
app.state.redis = redis
app.state.session_factory = session_factory
app.state.downstream_client = downstream_client
app.state.webhook_client = webhook_client
app.state.runtime_settings = runtime_settings
app.state.runtime_settings_lock = asyncio.Lock()
binding_service = BindingService(
settings=settings,
session_factory=session_factory,
redis=redis,
runtime_settings_getter=lambda: app.state.runtime_settings,
)
alert_service = AlertService(
settings=settings,
session_factory=session_factory,
redis=redis,
http_client=webhook_client,
runtime_settings_getter=lambda: app.state.runtime_settings,
)
archive_service = ArchiveService(
settings=settings,
session_factory=session_factory,
binding_service=binding_service,
runtime_settings_getter=lambda: app.state.runtime_settings,
)
app.state.binding_service = binding_service
app.state.alert_service = alert_service
app.state.archive_service = archive_service
await binding_service.start()
await archive_service.start()
logger.info("Application started.")
try:
yield
finally:
await archive_service.stop()
await binding_service.stop()
await downstream_client.aclose()
await webhook_client.aclose()
if redis is not None:
await redis.aclose()
await close_db()
logger.info("Application stopped.")
app = FastAPI(title="Key-IP Sentinel", lifespan=lifespan)
app.include_router(auth.router)
app.include_router(dashboard.router)
app.include_router(bindings.router)
app.include_router(logs.router)
app.include_router(settings_api.router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok"}
app.include_router(proxy_router)

48
app/models/db.py Normal file
View File

@@ -0,0 +1,48 @@
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import Settings
class Base(DeclarativeBase):
pass
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def init_db(settings: Settings) -> None:
global _engine, _session_factory
if _engine is not None and _session_factory is not None:
return
_engine = create_async_engine(
settings.pg_dsn,
pool_pre_ping=True,
pool_size=20,
max_overflow=40,
)
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
def get_engine() -> AsyncEngine:
if _engine is None:
raise RuntimeError("Database engine has not been initialized.")
return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
if _session_factory is None:
raise RuntimeError("Database session factory has not been initialized.")
return _session_factory
async def close_db() -> None:
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
_engine = None
_session_factory = None

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Index, String, func, text
from sqlalchemy.dialects.postgresql import CIDR, INET
from sqlalchemy.orm import Mapped, mapped_column
from app.models.db import Base
class InterceptLog(Base):
__tablename__ = "intercept_logs"
__table_args__ = (
Index("idx_intercept_logs_hash", "token_hash"),
Index("idx_intercept_logs_time", text("intercepted_at DESC")),
)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
token_hash: Mapped[str] = mapped_column(String(64), nullable=False)
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
attempt_ip: Mapped[str] = mapped_column(INET, nullable=False)
alerted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("FALSE"))
intercepted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, Index, SmallInteger, String, func, text
from sqlalchemy.dialects.postgresql import CIDR
from sqlalchemy.orm import Mapped, mapped_column
from app.models.db import Base
STATUS_ACTIVE = 1
STATUS_BANNED = 2
class TokenBinding(Base):
__tablename__ = "token_bindings"
__table_args__ = (
Index("idx_token_bindings_hash", "token_hash"),
Index("idx_token_bindings_ip", "bound_ip", postgresql_using="gist", postgresql_ops={"bound_ip": "inet_ops"}),
)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
status: Mapped[int] = mapped_column(
SmallInteger,
nullable=False,
default=STATUS_ACTIVE,
server_default=text("1"),
)
first_used_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
last_used_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)

111
app/proxy/handler.py Normal file
View File

@@ -0,0 +1,111 @@
from __future__ import annotations
import logging
from urllib.parse import urlsplit
import httpx
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from app.config import Settings
from app.core.ip_utils import extract_client_ip
from app.core.security import extract_bearer_token
from app.dependencies import get_alert_service, get_binding_service, get_settings
from app.services.alert_service import AlertService
from app.services.binding_service import BindingService
logger = logging.getLogger(__name__)
router = APIRouter()
CONTENT_LENGTH_HEADER = "content-length"
def build_upstream_headers(request: Request, downstream_url: str) -> list[tuple[str, str]]:
downstream_host = urlsplit(downstream_url).netloc
headers: list[tuple[str, str]] = []
for header_name, header_value in request.headers.items():
if header_name.lower() == "host":
continue
headers.append((header_name, header_value))
headers.append(("host", downstream_host))
return headers
def build_upstream_url(settings: Settings, request: Request) -> str:
return f"{settings.downstream_url}{request.url.path}"
def apply_downstream_headers(response: StreamingResponse, upstream_response: httpx.Response) -> None:
for header_name, header_value in upstream_response.headers.multi_items():
if header_name.lower() == CONTENT_LENGTH_HEADER:
continue
response.headers.append(header_name, header_value)
@router.api_route("/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"], include_in_schema=False)
@router.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"],
include_in_schema=False,
)
async def reverse_proxy(
request: Request,
path: str = "",
settings: Settings = Depends(get_settings),
binding_service: BindingService = Depends(get_binding_service),
alert_service: AlertService = Depends(get_alert_service),
):
client_ip = extract_client_ip(request, settings)
token = extract_bearer_token(request.headers.get("authorization"))
if token:
binding_result = await binding_service.evaluate_token_binding(token, client_ip)
if binding_result.allowed:
await binding_service.increment_request_metric("allowed")
else:
await binding_service.increment_request_metric("intercepted" if binding_result.should_alert else None)
if binding_result.should_alert and binding_result.token_hash and binding_result.token_display and binding_result.bound_ip:
await alert_service.handle_intercept(
token_hash=binding_result.token_hash,
token_display=binding_result.token_display,
bound_ip=binding_result.bound_ip,
attempt_ip=client_ip,
)
return JSONResponse(
status_code=binding_result.status_code,
content={"detail": binding_result.detail},
)
else:
await binding_service.increment_request_metric("allowed")
downstream_client: httpx.AsyncClient = request.app.state.downstream_client
upstream_url = build_upstream_url(settings, request)
upstream_headers = build_upstream_headers(request, settings.downstream_url)
try:
upstream_request = downstream_client.build_request(
request.method,
upstream_url,
params=request.query_params.multi_items(),
headers=upstream_headers,
content=request.stream(),
)
upstream_response = await downstream_client.send(upstream_request, stream=True)
except httpx.HTTPError as exc:
logger.exception("Failed to reach downstream service.")
return JSONResponse(status_code=502, content={"detail": f"Downstream request failed: {exc!s}"})
async def stream_response():
try:
async for chunk in upstream_response.aiter_raw():
yield chunk
finally:
await upstream_response.aclose()
response = StreamingResponse(
stream_response(),
status_code=upstream_response.status_code,
media_type=None,
)
apply_downstream_headers(response, upstream_response)
return response

13
app/schemas/auth.py Normal file
View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from pydantic import BaseModel, Field
class LoginRequest(BaseModel):
password: str = Field(min_length=1, max_length=256)
class TokenResponse(BaseModel):
access_token: str
token_type: str = "bearer"
expires_in: int

42
app/schemas/binding.py Normal file
View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field, field_validator
class BindingItem(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
token_display: str
bound_ip: str
status: int
status_label: str
first_used_at: datetime
last_used_at: datetime
created_at: datetime
class BindingListResponse(BaseModel):
items: list[BindingItem]
total: int
page: int
page_size: int
class BindingActionRequest(BaseModel):
id: int = Field(gt=0)
class BindingIPUpdateRequest(BaseModel):
id: int = Field(gt=0)
bound_ip: str = Field(min_length=3, max_length=64)
@field_validator("bound_ip")
@classmethod
def validate_bound_ip(cls, value: str) -> str:
from ipaddress import ip_network
ip_network(value, strict=False)
return value

23
app/schemas/log.py Normal file
View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, ConfigDict
class InterceptLogItem(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
token_display: str
bound_ip: str
attempt_ip: str
alerted: bool
intercepted_at: datetime
class LogListResponse(BaseModel):
items: list[InterceptLogItem]
total: int
page: int
page_size: int

View File

@@ -0,0 +1,123 @@
from __future__ import annotations
import logging
from typing import Callable
import httpx
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.models.intercept_log import InterceptLog
logger = logging.getLogger(__name__)
class AlertService:
def __init__(
self,
settings: Settings,
session_factory: async_sessionmaker[AsyncSession],
redis: Redis | None,
http_client: httpx.AsyncClient,
runtime_settings_getter: Callable[[], RuntimeSettings],
) -> None:
self.settings = settings
self.session_factory = session_factory
self.redis = redis
self.http_client = http_client
self.runtime_settings_getter = runtime_settings_getter
def alert_key(self, token_hash: str) -> str:
return f"sentinel:alert:{token_hash}"
async def handle_intercept(
self,
token_hash: str,
token_display: str,
bound_ip: str,
attempt_ip: str,
) -> None:
await self._write_intercept_log(token_hash, token_display, bound_ip, attempt_ip)
runtime_settings = self.runtime_settings_getter()
if self.redis is None:
logger.warning("Redis is unavailable. Intercept alert counters are disabled.")
return
try:
async with self.redis.pipeline(transaction=True) as pipeline:
pipeline.incr(self.alert_key(token_hash))
pipeline.expire(self.alert_key(token_hash), runtime_settings.alert_threshold_seconds)
result = await pipeline.execute()
except Exception:
logger.warning("Failed to update intercept alert counter.", extra={"token_hash": token_hash})
return
count = int(result[0])
if count < runtime_settings.alert_threshold_count:
return
payload = {
"token_display": token_display,
"attempt_ip": attempt_ip,
"bound_ip": bound_ip,
"count": count,
}
if runtime_settings.alert_webhook_url:
try:
await self.http_client.post(runtime_settings.alert_webhook_url, json=payload)
except httpx.HTTPError:
logger.exception("Failed to deliver alert webhook.", extra={"token_hash": token_hash})
try:
await self.redis.delete(self.alert_key(token_hash))
except Exception:
logger.warning("Failed to clear intercept alert counter.", extra={"token_hash": token_hash})
await self._mark_alerted_records(token_hash, runtime_settings.alert_threshold_seconds)
async def _write_intercept_log(
self,
token_hash: str,
token_display: str,
bound_ip: str,
attempt_ip: str,
) -> None:
async with self.session_factory() as session:
try:
session.add(
InterceptLog(
token_hash=token_hash,
token_display=token_display,
bound_ip=bound_ip,
attempt_ip=attempt_ip,
alerted=False,
)
)
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to write intercept log.", extra={"token_hash": token_hash})
async def _mark_alerted_records(self, token_hash: str, threshold_seconds: int) -> None:
statement = text(
"""
UPDATE intercept_logs
SET alerted = TRUE
WHERE token_hash = :token_hash
AND intercepted_at >= NOW() - (:threshold_seconds || ' seconds')::interval
"""
)
async with self.session_factory() as session:
try:
await session.execute(
statement,
{"token_hash": token_hash, "threshold_seconds": threshold_seconds},
)
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to mark intercept logs as alerted.", extra={"token_hash": token_hash})

View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Callable
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from sqlalchemy import delete, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.models.token_binding import TokenBinding
from app.services.binding_service import BindingService
logger = logging.getLogger(__name__)
class ArchiveService:
def __init__(
self,
settings: Settings,
session_factory: async_sessionmaker[AsyncSession],
binding_service: BindingService,
runtime_settings_getter: Callable[[], RuntimeSettings],
) -> None:
self.settings = settings
self.session_factory = session_factory
self.binding_service = binding_service
self.runtime_settings_getter = runtime_settings_getter
self.scheduler = AsyncIOScheduler(timezone="UTC")
async def start(self) -> None:
if self.scheduler.running:
return
self.scheduler.add_job(
self.archive_inactive_bindings,
trigger="interval",
minutes=self.settings.archive_job_interval_minutes,
id="archive-inactive-bindings",
replace_existing=True,
max_instances=1,
coalesce=True,
)
self.scheduler.start()
async def stop(self) -> None:
if self.scheduler.running:
self.scheduler.shutdown(wait=False)
async def archive_inactive_bindings(self) -> int:
runtime_settings = self.runtime_settings_getter()
cutoff = datetime.now(UTC) - timedelta(days=runtime_settings.archive_days)
total_archived = 0
while True:
async with self.session_factory() as session:
try:
result = await session.execute(
select(TokenBinding.token_hash)
.where(TokenBinding.last_used_at < cutoff)
.order_by(TokenBinding.last_used_at.asc())
.limit(self.settings.archive_batch_size)
)
token_hashes = list(result.scalars())
if not token_hashes:
break
await session.execute(delete(TokenBinding).where(TokenBinding.token_hash.in_(token_hashes)))
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to archive inactive bindings.")
break
await self.binding_service.invalidate_many(token_hashes)
total_archived += len(token_hashes)
if len(token_hashes) < self.settings.archive_batch_size:
break
if total_archived:
logger.info("Archived inactive bindings.", extra={"count": total_archived})
return total_archived

View File

@@ -0,0 +1,464 @@
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, date, timedelta
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import func, select, text, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class BindingRecord:
id: int
token_hash: str
token_display: str
bound_ip: str
status: int
ip_matched: bool
@dataclass(slots=True)
class BindingCheckResult:
allowed: bool
status_code: int
detail: str
token_hash: str | None = None
token_display: str | None = None
bound_ip: str | None = None
should_alert: bool = False
newly_bound: bool = False
class InMemoryRateLimiter:
def __init__(self, max_per_second: int) -> None:
self.max_per_second = max(1, max_per_second)
self._window = int(time.monotonic())
self._count = 0
self._lock = asyncio.Lock()
async def allow(self) -> bool:
async with self._lock:
current_window = int(time.monotonic())
if current_window != self._window:
self._window = current_window
self._count = 0
if self._count >= self.max_per_second:
return False
self._count += 1
return True
class BindingService:
def __init__(
self,
settings: Settings,
session_factory: async_sessionmaker[AsyncSession],
redis: Redis | None,
runtime_settings_getter: Callable[[], RuntimeSettings],
) -> None:
self.settings = settings
self.session_factory = session_factory
self.redis = redis
self.runtime_settings_getter = runtime_settings_getter
self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
self._flush_task: asyncio.Task[None] | None = None
self._stop_event = asyncio.Event()
self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)
async def start(self) -> None:
if self._flush_task is None:
self._stop_event.clear()
self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")
async def stop(self) -> None:
self._stop_event.set()
if self._flush_task is not None:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
await self.flush_last_used_updates()
def status_label(self, status_code: int) -> str:
return "Active" if status_code == STATUS_ACTIVE else "Banned"
def cache_key(self, token_hash: str) -> str:
return f"sentinel:binding:{token_hash}"
def metrics_key(self, target_date: date) -> str:
return f"sentinel:metrics:{target_date.isoformat()}"
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
token_display = mask_token(token)
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
if cache_hit is not None:
if cache_hit.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
if is_ip_in_network(client_ip, cache_hit.bound_ip):
await self._touch_cache(token_hash)
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from cache.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
if not cache_available:
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
if not await self._redis_degraded_limiter.allow():
logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
return BindingCheckResult(
allowed=False,
status_code=429,
detail="Redis degraded mode rate limit reached.",
token_hash=token_hash,
token_display=token_display,
)
try:
record = await self._load_binding_from_db(token_hash, client_ip)
except SQLAlchemyError:
return self._handle_backend_failure(token_hash, token_display)
if record is not None:
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
if record.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
if record.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from PostgreSQL.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
try:
created = await self._create_binding(token_hash, token_display, client_ip)
except SQLAlchemyError:
return self._handle_backend_failure(token_hash, token_display)
if created is None:
try:
existing = await self._load_binding_from_db(token_hash, client_ip)
except SQLAlchemyError:
return self._handle_backend_failure(token_hash, token_display)
if existing is None:
return self._handle_backend_failure(token_hash, token_display)
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
if existing.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
)
if existing.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed after concurrent bind resolution.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
)
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="First-use bind created.",
token_hash=token_hash,
token_display=token_display,
bound_ip=created.bound_ip,
newly_bound=True,
)
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
if self.redis is None:
return
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
try:
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})
async def invalidate_binding_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.delete(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})
async def invalidate_many(self, token_hashes: list[str]) -> None:
if self.redis is None or not token_hashes:
return
keys = [self.cache_key(item) for item in token_hashes]
try:
await self.redis.delete(*keys)
except Exception:
logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})
def record_last_used(self, token_hash: str) -> None:
try:
self.last_used_queue.put_nowait(token_hash)
except asyncio.QueueFull:
logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})
async def flush_last_used_updates(self) -> None:
token_hashes: set[str] = set()
while True:
try:
token_hashes.add(self.last_used_queue.get_nowait())
except asyncio.QueueEmpty:
break
if not token_hashes:
return
async with self.session_factory() as session:
try:
stmt = (
update(TokenBinding)
.where(TokenBinding.token_hash.in_(token_hashes))
.values(last_used_at=func.now())
)
await session.execute(stmt)
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})
async def increment_request_metric(self, outcome: str | None) -> None:
if self.redis is None:
return
key = self.metrics_key(date.today())
ttl = self.settings.metrics_ttl_days * 86400
try:
async with self.redis.pipeline(transaction=True) as pipeline:
pipeline.hincrby(key, "total", 1)
if outcome in {"allowed", "intercepted"}:
pipeline.hincrby(key, outcome, 1)
pipeline.expire(key, ttl)
await pipeline.execute()
except Exception:
logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})
async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
if self.redis is None:
return [
{"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
for offset in range(days - 1, -1, -1)
]
series: list[dict[str, int | str]] = []
for offset in range(days - 1, -1, -1):
target = date.today() - timedelta(days=offset)
raw = await self.redis.hgetall(self.metrics_key(target))
series.append(
{
"date": target.isoformat(),
"allowed": int(raw.get("allowed", 0)),
"intercepted": int(raw.get("intercepted", 0)),
"total": int(raw.get("total", 0)),
}
)
return series
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
if self.redis is None:
return None, False
try:
raw = await self.redis.get(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
return None, False
if raw is None:
return None, True
data = json.loads(raw)
return (
BindingRecord(
id=0,
token_hash=token_hash,
token_display="",
bound_ip=data["bound_ip"],
status=int(data["status"]),
ip_matched=False,
),
True,
)
async def _touch_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
query = text(
"""
SELECT
id,
token_hash,
token_display,
bound_ip::text AS bound_ip,
status,
CAST(:client_ip AS inet) << bound_ip AS ip_matched
FROM token_bindings
WHERE token_hash = :token_hash
LIMIT 1
"""
)
async with self.session_factory() as session:
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
row = result.mappings().first()
if row is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=bool(row["ip_matched"]),
)
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
statement = text(
"""
INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
ON CONFLICT (token_hash) DO NOTHING
RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
"""
)
async with self.session_factory() as session:
try:
result = await session.execute(
statement,
{
"token_hash": token_hash,
"token_display": token_display,
"bound_ip": client_ip,
"status": STATUS_ACTIVE,
},
)
row = result.mappings().first()
await session.commit()
except SQLAlchemyError:
await session.rollback()
raise
if row is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=True,
)
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
runtime_settings = self.runtime_settings_getter()
logger.exception(
"Binding storage backend failed.",
extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
)
if runtime_settings.failsafe_mode == "open":
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed by failsafe mode.",
token_hash=token_hash,
token_display=token_display,
)
return BindingCheckResult(
allowed=False,
status_code=503,
detail="Binding backend unavailable and failsafe mode is closed.",
token_hash=token_hash,
token_display=token_display,
)
async def _flush_loop(self) -> None:
try:
while not self._stop_event.is_set():
await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
await self.flush_last_used_updates()
except asyncio.CancelledError:
raise