# sentinel/app/services/binding_service.py
from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Callable

from redis.asyncio import Redis
from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import (
    BINDING_MODE_ALL,
    BINDING_MODE_MULTIPLE,
    BINDING_MODE_SINGLE,
    STATUS_ACTIVE,
    STATUS_BANNED,
    TokenBinding,
)
# Module-level logger, namespaced to this module's dotted path.
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class BindingRecord:
    """Normalized view of a token binding, hydrated from the DB row or the Redis cache."""

    id: int  # DB primary key; 0 for records hydrated from the cache
    token_hash: str  # HMAC-based hash of the raw token (never the token itself)
    token_display: str  # masked token for display; empty when cache-sourced
    bound_ip: str  # IP captured when the binding was created
    binding_mode: str  # one of BINDING_MODE_SINGLE / BINDING_MODE_MULTIPLE / BINDING_MODE_ALL
    allowed_ips: list[str]  # IPs/networks the token may be used from
    status: int  # STATUS_ACTIVE or STATUS_BANNED
    ip_matched: bool  # whether the requesting client IP satisfied the binding rule
@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of a binding evaluation, carrying an HTTP-style verdict."""

    allowed: bool  # True when the request may proceed
    status_code: int  # 200, 403, 429 or 503 depending on the branch taken
    detail: str  # human-readable reason for the verdict
    token_hash: str | None = None  # set whenever the token could be hashed
    token_display: str | None = None  # masked token, when available
    bound_ip: str | None = None  # the binding's bound-IP value, when known
    should_alert: bool = False  # set on denials that warrant an alert
    newly_bound: bool = False  # True only for a successful first-use bind
class InMemoryRateLimiter:
    """Best-effort per-second rate limiter guarded by an asyncio lock.

    Counts how many callers were admitted during the current one-second
    window (derived from ``time.monotonic``) and rejects further callers
    once ``max_per_second`` admissions have been handed out.
    """

    def __init__(self, max_per_second: int) -> None:
        # Clamp to at least one permit per second so the limiter can never
        # reject every single caller.
        self.max_per_second = max(1, max_per_second)
        self._window = int(time.monotonic())
        self._count = 0
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Return True when the caller may proceed within the current window."""
        async with self._lock:
            now_window = int(time.monotonic())
            if now_window != self._window:
                # A new second has begun: reset the admission counter.
                self._window = now_window
                self._count = 0
            if self._count < self.max_per_second:
                self._count += 1
                return True
            return False
class BindingService:
def __init__(
    self,
    settings: Settings,
    session_factory: async_sessionmaker[AsyncSession],
    redis: Redis | None,
    runtime_settings_getter: Callable[[], RuntimeSettings],
) -> None:
    """Wire the service to its settings, DB session factory and optional Redis.

    runtime_settings_getter is invoked on each backend failure to read the
    current failsafe mode — presumably so runtime configuration changes are
    picked up without re-instantiating the service (confirm with callers).
    """
    self.settings = settings
    self.session_factory = session_factory
    # redis may be None; every cache path in this class degrades gracefully.
    self.redis = redis
    self.runtime_settings_getter = runtime_settings_getter
    # Bounded queue of token hashes awaiting a batched last_used_at write.
    self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
    self._flush_task: asyncio.Task[None] | None = None
    self._stop_event = asyncio.Event()
    # Caps PostgreSQL fallback traffic while Redis is down; sized to half
    # the downstream connection budget.
    self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)
async def start(self) -> None:
if self._flush_task is None:
self._stop_event.clear()
self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")
async def stop(self) -> None:
self._stop_event.set()
if self._flush_task is not None:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
await self.flush_last_used_updates()
def status_label(self, status_code: int) -> str:
    """Map a numeric binding status onto its human-readable label."""
    if status_code == STATUS_ACTIVE:
        return "Active"
    return "Banned"
def cache_key(self, token_hash: str) -> str:
return f"sentinel:binding:{token_hash}"
def metrics_key(self, target_date: date) -> str:
return f"sentinel:metrics:{target_date.isoformat()}"
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
    """Render the bound-IP column: 'ALL', '-', a comma list, or a single IP."""
    if binding_mode == BINDING_MODE_ALL:
        return "ALL"
    if not allowed_ips:
        # Nothing to show for an empty rule set.
        return "-"
    # Multiple-IP mode lists every allowed entry; any other mode shows the first.
    if binding_mode == BINDING_MODE_MULTIPLE:
        return ", ".join(allowed_ips)
    return allowed_ips[0]
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
    """Return True when *client_ip* satisfies the binding rule."""
    if binding_mode == BINDING_MODE_ALL:
        # Wildcard mode: every client is acceptable.
        return True
    for network in allowed_ips:
        if is_ip_in_network(client_ip, network):
            return True
    return False
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
    """Convert an ORM row into a BindingRecord, evaluating the IP match."""
    ips = [str(entry) for entry in binding.allowed_ips]
    # Fall back to single mode when the column is empty/None.
    mode = binding.binding_mode or BINDING_MODE_SINGLE
    matched = self.is_client_allowed(client_ip, mode, ips)
    return BindingRecord(
        id=binding.id,
        token_hash=binding.token_hash,
        token_display=binding.token_display,
        bound_ip=binding.bound_ip,
        binding_mode=mode,
        allowed_ips=ips,
        status=binding.status,
        ip_matched=matched,
    )
def denied_result(
    self,
    token_hash: str,
    token_display: str,
    bound_ip: str,
    detail: str,
    *,
    should_alert: bool = True,
    status_code: int = 403,
) -> BindingCheckResult:
    """Build a rejection result; alerting defaults to on for denials."""
    result = BindingCheckResult(
        allowed=False,
        status_code=status_code,
        detail=detail,
        token_hash=token_hash,
        token_display=token_display,
        bound_ip=bound_ip,
    )
    result.should_alert = should_alert
    return result
def allowed_result(
    self,
    token_hash: str,
    token_display: str,
    bound_ip: str,
    detail: str,
    *,
    newly_bound: bool = False,
) -> BindingCheckResult:
    """Build a 200 success result for an accepted request."""
    fields = {
        "allowed": True,
        "status_code": 200,
        "detail": detail,
        "token_hash": token_hash,
        "token_display": token_display,
        "bound_ip": bound_ip,
        "newly_bound": newly_bound,
    }
    return BindingCheckResult(**fields)
def evaluate_existing_record(
    self,
    record: BindingRecord,
    token_hash: str,
    token_display: str,
    detail: str,
) -> BindingCheckResult:
    """Decide the outcome for a token that already has a binding row."""
    if record.status == STATUS_BANNED:
        return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
    if not record.ip_matched:
        return self.denied_result(
            token_hash,
            token_display,
            record.bound_ip,
            "Client IP does not match the allowed binding rule.",
        )
    # Successful hit: enqueue a lazy last_used_at update before answering.
    self.record_last_used(token_hash)
    return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
    """Decide whether *token* presented from *client_ip* may proceed.

    Resolution order: Redis cache -> PostgreSQL row -> first-use bind.
    When Redis is unreachable the DB fallback is throttled by the
    degraded-mode limiter; when the DB fails too, the failsafe policy
    (open/closed) decides the outcome.
    """
    token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
    token_display = mask_token(token)
    # 1) Cache lookup. cache_available is False only when Redis itself is down.
    cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
    if cache_hit is not None:
        if cache_hit.ip_matched:
            # Sliding TTL: keep hot, valid tokens cached.
            await self._touch_cache(token_hash)
        return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")
    if not cache_available:
        logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
        # Shield PostgreSQL from the full request rate while degraded.
        if not await self._redis_degraded_limiter.allow():
            logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
            return BindingCheckResult(
                allowed=False,
                status_code=429,
                detail="Redis degraded mode rate limit reached.",
                token_hash=token_hash,
                token_display=token_display,
            )
    # 2) Authoritative database lookup.
    try:
        record = await self._load_binding_from_db(token_hash, client_ip)
    except SQLAlchemyError:
        return self._handle_backend_failure(token_hash, token_display)
    if record is not None:
        # Repopulate the cache before answering.
        await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
        return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")
    # 3) First use of this token: try to bind it to the caller's IP.
    try:
        created = await self._create_binding(token_hash, token_display, client_ip)
    except SQLAlchemyError:
        return self._handle_backend_failure(token_hash, token_display)
    if created is None:
        # A concurrent request won the insert race; re-read the winner's row.
        try:
            existing = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if existing is None:
            # Insert failed yet no row exists: treat as a backend failure.
            return self._handle_backend_failure(token_hash, token_display)
        await self.sync_binding_cache(
            existing.token_hash,
            existing.bound_ip,
            existing.binding_mode,
            existing.allowed_ips,
            existing.status,
        )
        return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")
    await self.sync_binding_cache(
        created.token_hash,
        created.bound_ip,
        created.binding_mode,
        created.allowed_ips,
        created.status,
    )
    return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
async def sync_binding_cache(
self,
token_hash: str,
bound_ip: str,
binding_mode: str,
allowed_ips: list[str],
status_code: int,
) -> None:
if self.redis is None:
return
payload = json.dumps(
{
"bound_ip": bound_ip,
"binding_mode": binding_mode,
"allowed_ips": allowed_ips,
"status": status_code,
}
)
try:
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})
async def invalidate_binding_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.delete(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})
async def invalidate_many(self, token_hashes: list[str]) -> None:
if self.redis is None or not token_hashes:
return
keys = [self.cache_key(item) for item in token_hashes]
try:
await self.redis.delete(*keys)
except Exception:
logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})
def record_last_used(self, token_hash: str) -> None:
try:
self.last_used_queue.put_nowait(token_hash)
except asyncio.QueueFull:
logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})
async def flush_last_used_updates(self) -> None:
token_hashes: set[str] = set()
while True:
try:
token_hashes.add(self.last_used_queue.get_nowait())
except asyncio.QueueEmpty:
break
if not token_hashes:
return
async with self.session_factory() as session:
try:
stmt = (
update(TokenBinding)
.where(TokenBinding.token_hash.in_(token_hashes))
.values(last_used_at=func.now())
)
await session.execute(stmt)
await session.commit()
except SQLAlchemyError:
await session.rollback()
logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})
async def increment_request_metric(self, outcome: str | None) -> None:
if self.redis is None:
return
key = self.metrics_key(date.today())
ttl = self.settings.metrics_ttl_days * 86400
try:
async with self.redis.pipeline(transaction=True) as pipeline:
pipeline.hincrby(key, "total", 1)
if outcome in {"allowed", "intercepted"}:
pipeline.hincrby(key, outcome, 1)
pipeline.expire(key, ttl)
await pipeline.execute()
except Exception:
logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})
async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
if self.redis is None:
return [
{"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
for offset in range(days - 1, -1, -1)
]
series: list[dict[str, int | str]] = []
for offset in range(days - 1, -1, -1):
target = date.today() - timedelta(days=offset)
raw = await self.redis.hgetall(self.metrics_key(target))
series.append(
{
"date": target.isoformat(),
"allowed": int(raw.get("allowed", 0)),
"intercepted": int(raw.get("intercepted", 0)),
"total": int(raw.get("total", 0)),
}
)
return series
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
if self.redis is None:
return None, False
try:
raw = await self.redis.get(self.cache_key(token_hash))
except Exception:
logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
return None, False
if raw is None:
return None, True
data = json.loads(raw)
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
return (
BindingRecord(
id=0,
token_hash=token_hash,
token_display="",
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
binding_mode=binding_mode,
allowed_ips=allowed_ips,
status=int(data["status"]),
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
),
True,
)
async def _touch_cache(self, token_hash: str) -> None:
if self.redis is None:
return
try:
await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
except Exception:
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
    """Look up the binding row by token hash; None when no row exists."""
    stmt = select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1)
    async with self.session_factory() as session:
        row = await session.scalar(stmt)
        if row is None:
            return None
        # Convert while the session is still open so ORM attribute access
        # happens inside the session scope.
        return self.to_binding_record(row, client_ip)
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
    """Insert a first-use binding row tying the token to *client_ip*.

    Returns the new record, or None when a concurrent request already
    bound this token (unique-constraint violation); any other database
    error is re-raised after rollback.

    Fix: the duplicate-key case is now detected via ``IntegrityError``
    instead of substring-matching the exception text ("duplicate key" /
    "unique"), which was driver- and locale-dependent and could
    misclassify unrelated errors whose message mentioned "unique".
    """
    async with self.session_factory() as session:
        binding = TokenBinding(
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=client_ip,
            binding_mode=BINDING_MODE_SINGLE,
            allowed_ips=[client_ip],
            status=STATUS_ACTIVE,
        )
        try:
            session.add(binding)
            await session.flush()
            await session.commit()
            # Re-read server-generated columns (id, timestamps) post-commit.
            await session.refresh(binding)
        except IntegrityError:
            # Lost the first-use race: another request inserted the same
            # token_hash. Signal the caller to re-read the winner's row.
            await session.rollback()
            return None
        except SQLAlchemyError:
            await session.rollback()
            raise
        return self.to_binding_record(binding, client_ip)
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
    """Translate a storage failure into a failsafe-policy decision.

    Called from an ``except`` block, so ``logger.exception`` captures the
    active exception's traceback.
    """
    runtime = self.runtime_settings_getter()
    logger.exception(
        "Binding storage backend failed.",
        extra={"failsafe_mode": runtime.failsafe_mode, "token_hash": token_hash},
    )
    identity = {"token_hash": token_hash, "token_display": token_display}
    if runtime.failsafe_mode == "open":
        # Fail-open: availability over enforcement.
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail="Allowed by failsafe mode.",
            **identity,
        )
    return BindingCheckResult(
        allowed=False,
        status_code=503,
        detail="Binding backend unavailable and failsafe mode is closed.",
        **identity,
    )
async def _flush_loop(self) -> None:
    """Background task body: periodically flush queued last_used_at updates.

    Runs until the stop event is set or the task is cancelled; the
    CancelledError is re-raised so ``stop()`` observes a clean shutdown.
    """
    try:
        while not self._stop_event.is_set():
            # Sleep first so a freshly started service does not hit the
            # database immediately.
            await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
            await self.flush_last_used_updates()
    except asyncio.CancelledError:
        raise