# sentinel/app/services/binding_service.py
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, date, timedelta
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import func, select, text, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class BindingRecord:
    """A token-to-IP binding row, hydrated either from Redis or PostgreSQL."""

    id: int  # primary key; 0 when hydrated from the cache (row id not cached)
    token_hash: str  # HMAC of the raw token (see hash_token)
    token_display: str  # masked token for logs; "" when hydrated from the cache
    bound_ip: str  # CIDR text the token is bound to
    status: int  # STATUS_ACTIVE or STATUS_BANNED
    ip_matched: bool  # DB-computed containment result; always False when from cache
@dataclass(slots=True)
class BindingCheckResult:
    """Verdict of a token-binding check, expressed in HTTP-style terms."""

    allowed: bool  # True when the request may proceed
    status_code: int  # 200 / 403 / 429 / 503 depending on the outcome
    detail: str  # human-readable reason for the verdict
    token_hash: str | None = None
    token_display: str | None = None
    bound_ip: str | None = None
    should_alert: bool = False  # True for ban hits and CIDR mismatches
    newly_bound: bool = False  # True when this request created the first-use bind
class InMemoryRateLimiter:
    """Process-local fixed-window limiter: at most N permits per wall-clock second.

    Windows are whole seconds of the monotonic clock, so the limit cannot be
    gamed by system-clock changes.
    """

    def __init__(self, max_per_second: int) -> None:
        # Clamp to at least one permit so the limiter can never block everything.
        self.max_per_second = max(1, max_per_second)
        self._window = int(time.monotonic())
        self._count = 0
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Return True when one more request fits into the current second."""
        async with self._lock:
            now_window = int(time.monotonic())
            if now_window != self._window:
                # A new second started: reset the counter.
                self._window = now_window
                self._count = 0
            if self._count < self.max_per_second:
                self._count += 1
                return True
            return False
class BindingService:
    """Token-to-IP binding enforcement: first-use bind, ban checks, cache, metrics.

    Reads through a Redis cache with PostgreSQL as the source of truth, and
    degrades to rate-limited direct DB lookups when Redis is unavailable.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        self.settings = settings
        self.session_factory = session_factory
        # None disables caching entirely; all code paths then fall back to PostgreSQL.
        self.redis = redis
        # Callable so runtime-tunable settings (failsafe mode) are re-read per use.
        self.runtime_settings_getter = runtime_settings_getter
        # Bounded queue of token hashes awaiting a batched last_used_at write.
        self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
        self._flush_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Caps DB traffic while Redis is degraded; half the downstream pool is a
        # heuristic budget — NOTE(review): confirm the halving is intentional.
        self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)
async def start(self) -> None:
if self._flush_task is None:
self._stop_event.clear()
self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")
async def stop(self) -> None:
self._stop_event.set()
if self._flush_task is not None:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
await self.flush_last_used_updates()
def status_label(self, status_code: int) -> str:
return "Active" if status_code == STATUS_ACTIVE else "Banned"
def cache_key(self, token_hash: str) -> str:
return f"sentinel:binding:{token_hash}"
def metrics_key(self, target_date: date) -> str:
return f"sentinel:metrics:{target_date.isoformat()}"
    async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
        """Decide whether *token* used from *client_ip* is allowed.

        Flow: Redis cache -> (rate-limited when Redis is down) PostgreSQL
        lookup -> first-use bind creation, with a re-read to resolve concurrent
        first binds. Returns a BindingCheckResult carrying an HTTP-style verdict.
        """
        token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
        token_display = mask_token(token)
        # cache_hit is None on a miss; cache_available is False when Redis
        # itself is absent or failing.
        cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
        if cache_hit is not None:
            if cache_hit.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                    should_alert=True,
                )
            if is_ip_in_network(client_ip, cache_hit.bound_ip):
                # Happy path: extend the cache TTL and defer the last_used write.
                await self._touch_cache(token_hash)
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed from cache.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=cache_hit.bound_ip,
                should_alert=True,
            )
        if not cache_available:
            # Redis is degraded: shield PostgreSQL behind the in-process limiter.
            logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
            if not await self._redis_degraded_limiter.allow():
                logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
                return BindingCheckResult(
                    allowed=False,
                    status_code=429,
                    detail="Redis degraded mode rate limit reached.",
                    token_hash=token_hash,
                    token_display=token_display,
                )
        try:
            record = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if record is not None:
            # Re-populate the cache so subsequent requests skip the DB.
            await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
            if record.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=record.bound_ip,
                    should_alert=True,
                )
            if record.ip_matched:
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed from PostgreSQL.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=record.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=record.bound_ip,
                should_alert=True,
            )
        # No row exists: attempt the first-use bind.
        try:
            created = await self._create_binding(token_hash, token_display, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if created is None:
            # Lost a first-bind race (ON CONFLICT DO NOTHING returned no row):
            # re-read the winner's row and evaluate against it.
            try:
                existing = await self._load_binding_from_db(token_hash, client_ip)
            except SQLAlchemyError:
                return self._handle_backend_failure(token_hash, token_display)
            if existing is None:
                # Insert conflicted yet the row is gone — treat as a backend failure.
                return self._handle_backend_failure(token_hash, token_display)
            await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
            if existing.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=existing.bound_ip,
                    should_alert=True,
                )
            if existing.ip_matched:
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed after concurrent bind resolution.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=existing.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=existing.bound_ip,
                should_alert=True,
            )
        await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail="First-use bind created.",
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=created.bound_ip,
            newly_bound=True,
        )
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
if self.redis is None:
return
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
try:
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})
async def invalidate_binding_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.delete(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})
async def invalidate_many(self, token_hashes: list[str]) -> None:
if self.redis is None or not token_hashes:
return
keys = [self.cache_key(item) for item in token_hashes]
try:
await self.redis.delete(*keys)
except Exception:
logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})
def record_last_used(self, token_hash: str) -> None:
try:
self.last_used_queue.put_nowait(token_hash)
except asyncio.QueueFull:
logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})
async def flush_last_used_updates(self) -> None:
token_hashes: set[str] = set()
while True:
try:
token_hashes.add(self.last_used_queue.get_nowait())
except asyncio.QueueEmpty:
break
if not token_hashes:
return
async with self.session_factory() as session:
try:
stmt = (
update(TokenBinding)
.where(TokenBinding.token_hash.in_(token_hashes))
.values(last_used_at=func.now())
)
await session.execute(stmt)
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})
async def increment_request_metric(self, outcome: str | None) -> None:
if self.redis is None:
return
key = self.metrics_key(date.today())
ttl = self.settings.metrics_ttl_days * 86400
try:
async with self.redis.pipeline(transaction=True) as pipeline:
pipeline.hincrby(key, "total", 1)
if outcome in {"allowed", "intercepted"}:
pipeline.hincrby(key, outcome, 1)
pipeline.expire(key, ttl)
await pipeline.execute()
except Exception:
logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})
async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
if self.redis is None:
return [
{"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
for offset in range(days - 1, -1, -1)
]
series: list[dict[str, int | str]] = []
for offset in range(days - 1, -1, -1):
target = date.today() - timedelta(days=offset)
raw = await self.redis.hgetall(self.metrics_key(target))
series.append(
{
"date": target.isoformat(),
"allowed": int(raw.get("allowed", 0)),
"intercepted": int(raw.get("intercepted", 0)),
"total": int(raw.get("total", 0)),
}
)
return series
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
if self.redis is None:
return None, False
try:
raw = await self.redis.get(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
return None, False
if raw is None:
return None, True
data = json.loads(raw)
return (
BindingRecord(
id=0,
token_hash=token_hash,
token_display="",
bound_ip=data["bound_ip"],
status=int(data["status"]),
ip_matched=False,
),
True,
)
async def _touch_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
query = text(
"""
SELECT
id,
token_hash,
token_display,
bound_ip::text AS bound_ip,
status,
CAST(:client_ip AS inet) << bound_ip AS ip_matched
FROM token_bindings
WHERE token_hash = :token_hash
LIMIT 1
"""
)
async with self.session_factory() as session:
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
row = result.mappings().first()
if row is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=bool(row["ip_matched"]),
)
    async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
        """Attempt the first-use bind of *token_hash* to *client_ip*.

        Uses INSERT ... ON CONFLICT DO NOTHING so concurrent first requests
        race safely: returns the created BindingRecord, or None when another
        writer won the race (no row returned). The bare IP is cast to cidr,
        yielding a host network (e.g. /32 for IPv4). Raises SQLAlchemyError on
        backend failure after rolling back.
        """
        statement = text(
            """
            INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
            VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
            ON CONFLICT (token_hash) DO NOTHING
            RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
            """
        )
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    statement,
                    {
                        "token_hash": token_hash,
                        "token_display": token_display,
                        "bound_ip": client_ip,
                        "status": STATUS_ACTIVE,
                    },
                )
                # Read the RETURNING row before commit; absent row == lost race.
                row = result.mappings().first()
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                raise  # caller maps this onto the failsafe policy
        if row is None:
            return None
        return BindingRecord(
            id=int(row["id"]),
            token_hash=str(row["token_hash"]),
            token_display=str(row["token_display"]),
            bound_ip=str(row["bound_ip"]),
            status=int(row["status"]),
            # The row was just created from this client's IP, so it matches by construction.
            ip_matched=True,
        )
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
runtime_settings = self.runtime_settings_getter()
logger.exception(
"Binding storage backend failed.",
extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
)
if runtime_settings.failsafe_mode == "open":
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed by failsafe mode.",
token_hash=token_hash,
token_display=token_display,
)
return BindingCheckResult(
allowed=False,
status_code=503,
detail="Binding backend unavailable and failsafe mode is closed.",
token_hash=token_hash,
token_display=token_display,
)
async def _flush_loop(self) -> None:
try:
while not self._stop_event.is_set():
await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
await self.flush_last_used_updates()
except asyncio.CancelledError:
raise