from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Callable

from redis.asyncio import Redis
from sqlalchemy import func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import (
    BINDING_MODE_ALL,
    BINDING_MODE_MULTIPLE,
    BINDING_MODE_SINGLE,
    STATUS_ACTIVE,
    STATUS_BANNED,
    TokenBinding,
)

logger = logging.getLogger(__name__)


@dataclass(slots=True)
class BindingRecord:
    """A token binding as loaded from cache or database for one client IP."""

    id: int  # database primary key; 0 when the record was built from the cache
    token_hash: str
    token_display: str  # empty string when the record was built from the cache
    bound_ip: str
    binding_mode: str
    allowed_ips: list[str]
    status: int
    ip_matched: bool  # result of is_client_allowed() for the requesting client IP


@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of evaluating a token against its binding."""

    allowed: bool
    status_code: int
    detail: str
    token_hash: str | None = None
    token_display: str | None = None
    bound_ip: str | None = None
    should_alert: bool = False
    newly_bound: bool = False


class InMemoryRateLimiter:
    """Fixed-window counter allowing at most ``max_per_second`` calls per window.

    The window is the integer part of ``time.monotonic()``, so the counter
    resets whenever that integer changes.
    """

    def __init__(self, max_per_second: int) -> None:
        # Clamp to at least one so a zero/negative setting never blocks everything.
        self.max_per_second = max(1, max_per_second)
        self._window = int(time.monotonic())
        self._count = 0
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Return True and consume one slot, or False if the window is full."""
        async with self._lock:
            current_window = int(time.monotonic())
            if current_window != self._window:
                self._window = current_window
                self._count = 0
            if self._count >= self.max_per_second:
                return False
            self._count += 1
            return True


class BindingService:
    """Evaluates token-to-IP bindings using a Redis cache backed by PostgreSQL.

    Also batches ``last_used_at`` updates through an in-memory queue that a
    background task flushes periodically, and maintains per-day request
    metrics in Redis.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        self.settings = settings
        self.session_factory = session_factory
        self.redis = redis
        self.runtime_settings_getter = runtime_settings_getter
        self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
        self._flush_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Throttles PostgreSQL fallback traffic while the cache is unavailable.
        self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)

    async def start(self) -> None:
        """Start the background ``last_used_at`` flush task if not already running."""
        if self._flush_task is None:
            self._stop_event.clear()
            self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")

    async def stop(self) -> None:
        """Cancel the flush task and perform one final flush of queued updates."""
        self._stop_event.set()
        if self._flush_task is not None:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # A task that already died must not prevent the final flush below.
                logger.exception("last_used flush task terminated with an error.")
            self._flush_task = None
        await self.flush_last_used_updates()

    def status_label(self, status_code: int) -> str:
        """Return "Active" for STATUS_ACTIVE and "Banned" for any other value."""
        return "Active" if status_code == STATUS_ACTIVE else "Banned"

    def cache_key(self, token_hash: str) -> str:
        """Return the Redis key holding the cached binding for ``token_hash``."""
        return f"sentinel:binding:{token_hash}"

    def metrics_key(self, target_date: date) -> str:
        """Return the Redis key of the metrics hash for ``target_date``."""
        return f"sentinel:metrics:{target_date.isoformat()}"

    def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
        """Return the display string for a binding's IP rule.

        "ALL" for the all-IPs mode, "-" when no IPs are configured, a
        comma-separated list for multiple mode, otherwise the first entry.
        """
        if binding_mode == BINDING_MODE_ALL:
            return "ALL"
        if not allowed_ips:
            return "-"
        if binding_mode == BINDING_MODE_MULTIPLE:
            return ", ".join(allowed_ips)
        return allowed_ips[0]

    def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
        """Return True if the mode is ALL or ``client_ip`` matches any allowed entry."""
        if binding_mode == BINDING_MODE_ALL:
            return True
        return any(is_ip_in_network(client_ip, item) for item in allowed_ips)

    def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
        """Convert an ORM row into a BindingRecord evaluated against ``client_ip``."""
        # A missing allowed_ips value is treated as an empty rule list, mirroring
        # the fallback already applied to a missing binding_mode.
        allowed_ips = [str(item) for item in (binding.allowed_ips or [])]
        binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
        return BindingRecord(
            id=binding.id,
            token_hash=binding.token_hash,
            token_display=binding.token_display,
            bound_ip=binding.bound_ip,
            binding_mode=binding_mode,
            allowed_ips=allowed_ips,
            status=binding.status,
            ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
        )

    def denied_result(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        detail: str,
        *,
        should_alert: bool = True,
        status_code: int = 403,
    ) -> BindingCheckResult:
        """Build a denied BindingCheckResult (403 and alerting by default)."""
        return BindingCheckResult(
            allowed=False,
            status_code=status_code,
            detail=detail,
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=bound_ip,
            should_alert=should_alert,
        )

    def allowed_result(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        detail: str,
        *,
        newly_bound: bool = False,
    ) -> BindingCheckResult:
        """Build an allowed BindingCheckResult with status code 200."""
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail=detail,
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=bound_ip,
            newly_bound=newly_bound,
        )

    def evaluate_existing_record(
        self,
        record: BindingRecord,
        token_hash: str,
        token_display: str,
        detail: str,
    ) -> BindingCheckResult:
        """Decide allow/deny for an existing binding.

        Banned tokens are denied first; a matching IP is allowed (and queues a
        ``last_used_at`` update) with ``detail`` as the message; anything else
        is denied as an IP mismatch.
        """
        if record.status == STATUS_BANNED:
            return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
        if record.ip_matched:
            self.record_last_used(token_hash)
            return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
        return self.denied_result(
            token_hash,
            token_display,
            record.bound_ip,
            "Client IP does not match the allowed binding rule.",
        )

    async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
        """Evaluate ``token`` for ``client_ip``, binding it on first use.

        Lookup order: Redis cache, then PostgreSQL (rate limited while the
        cache is unavailable), then creation of a new single-IP binding.
        Database errors are mapped through ``_handle_backend_failure``.
        """
        token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
        token_display = mask_token(token)

        cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
        if cache_hit is not None:
            if cache_hit.ip_matched:
                await self._touch_cache(token_hash)
            return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")

        if not cache_available:
            logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
            if not await self._redis_degraded_limiter.allow():
                logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
                return BindingCheckResult(
                    allowed=False,
                    status_code=429,
                    detail="Redis degraded mode rate limit reached.",
                    token_hash=token_hash,
                    token_display=token_display,
                )

        try:
            record = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if record is not None:
            await self._sync_record_cache(record)
            return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")

        try:
            created = await self._create_binding(token_hash, token_display, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)

        if created is None:
            # _create_binding reported a uniqueness conflict: another request
            # bound the token first, so evaluate against the row that won.
            try:
                existing = await self._load_binding_from_db(token_hash, client_ip)
            except SQLAlchemyError:
                return self._handle_backend_failure(token_hash, token_display)
            if existing is None:
                return self._handle_backend_failure(token_hash, token_display)
            await self._sync_record_cache(existing)
            return self.evaluate_existing_record(
                existing,
                token_hash,
                token_display,
                "Allowed after concurrent bind resolution.",
            )

        await self._sync_record_cache(created)
        return self.allowed_result(
            token_hash,
            token_display,
            created.bound_ip,
            "First-use bind created.",
            newly_bound=True,
        )

    async def _sync_record_cache(self, record: BindingRecord) -> None:
        """Write ``record`` to the binding cache via ``sync_binding_cache``."""
        await self.sync_binding_cache(
            record.token_hash,
            record.bound_ip,
            record.binding_mode,
            record.allowed_ips,
            record.status,
        )

    async def sync_binding_cache(
        self,
        token_hash: str,
        bound_ip: str,
        binding_mode: str,
        allowed_ips: list[str],
        status_code: int,
    ) -> None:
        """Store the binding as JSON in Redis with the configured TTL (best effort)."""
        if self.redis is None:
            return
        payload = json.dumps(
            {
                "bound_ip": bound_ip,
                "binding_mode": binding_mode,
                "allowed_ips": allowed_ips,
                "status": status_code,
            }
        )
        try:
            await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
        except Exception:
            logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})

    async def invalidate_binding_cache(self, token_hash: str) -> None:
        """Delete the cached binding for ``token_hash`` (best effort)."""
        if self.redis is None:
            return
        try:
            await self.redis.delete(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})

    async def invalidate_many(self, token_hashes: list[str]) -> None:
        """Delete the cached bindings for all ``token_hashes`` in one call (best effort)."""
        if self.redis is None or not token_hashes:
            return
        keys = [self.cache_key(item) for item in token_hashes]
        try:
            await self.redis.delete(*keys)
        except Exception:
            logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})

    def record_last_used(self, token_hash: str) -> None:
        """Queue a ``last_used_at`` update; the update is dropped if the queue is full."""
        try:
            self.last_used_queue.put_nowait(token_hash)
        except asyncio.QueueFull:
            logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})

    async def flush_last_used_updates(self) -> None:
        """Drain the queue and set ``last_used_at`` for all queued tokens in one UPDATE."""
        token_hashes: set[str] = set()
        while True:
            try:
                token_hashes.add(self.last_used_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        if not token_hashes:
            return

        async with self.session_factory() as session:
            try:
                stmt = (
                    update(TokenBinding)
                    # A sorted list gives the IN clause a stable parameter order.
                    .where(TokenBinding.token_hash.in_(sorted(token_hashes)))
                    .values(last_used_at=func.now())
                )
                await session.execute(stmt)
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})

    async def increment_request_metric(self, outcome: str | None) -> None:
        """Increment today's "total" counter and, when recognised, the ``outcome`` counter.

        Only "allowed" and "intercepted" outcomes get their own counter. The
        hash TTL is refreshed on every call. Redis failures are logged and ignored.
        """
        if self.redis is None:
            return
        key = self.metrics_key(date.today())
        ttl = self.settings.metrics_ttl_days * 86400
        try:
            async with self.redis.pipeline(transaction=True) as pipeline:
                pipeline.hincrby(key, "total", 1)
                if outcome in {"allowed", "intercepted"}:
                    pipeline.hincrby(key, outcome, 1)
                pipeline.expire(key, ttl)
                await pipeline.execute()
        except Exception:
            logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})

    async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
        """Return per-day request counters for the last ``days`` days, oldest first.

        Each entry has "date", "allowed", "intercepted" and "total". Days with
        no data, and all days when Redis is absent or failing, report zeros.
        """
        # Read the date once so the window cannot shift if midnight passes mid-loop.
        today = date.today()
        redis_ok = self.redis is not None
        series: list[dict[str, int | str]] = []
        for offset in range(days - 1, -1, -1):
            target = today - timedelta(days=offset)
            raw: dict = {}
            if redis_ok:
                try:
                    raw = await self.redis.hgetall(self.metrics_key(target))
                except Exception:
                    # Degrade like the other Redis helpers, and stop issuing
                    # further reads against a backend that just failed.
                    redis_ok = False
                    logger.warning(
                        "Failed to read request metrics.",
                        extra={"metrics_date": target.isoformat()},
                    )
            series.append(
                {
                    "date": target.isoformat(),
                    "allowed": self._metric_value(raw, "allowed"),
                    "intercepted": self._metric_value(raw, "intercepted"),
                    "total": self._metric_value(raw, "total"),
                }
            )
        return series

    @staticmethod
    def _metric_value(raw: dict, field: str) -> int:
        """Read an integer counter from a Redis hash reply, defaulting to 0.

        Looks the field up under both its ``str`` and ``bytes`` form because a
        redis-py client created without ``decode_responses=True`` returns
        bytes keys and values.
        """
        value = raw.get(field)
        if value is None:
            value = raw.get(field.encode(), 0)
        return int(value)

    async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
        """Load a binding from Redis.

        Returns ``(record, cache_available)``. ``cache_available`` is False
        when Redis is not configured or the read failed. A missing or
        malformed entry yields ``(None, True)``.
        """
        if self.redis is None:
            return None, False
        try:
            raw = await self.redis.get(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
            return None, False
        if raw is None:
            return None, True

        try:
            data = json.loads(raw)
            allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
            binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
            bound_ip = str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips)))
            status = int(data["status"])
        except (ValueError, KeyError, TypeError, AttributeError):
            # A corrupt entry is reported as a miss so the caller reloads from
            # PostgreSQL, which then overwrites the entry via sync_binding_cache.
            logger.warning("Discarding malformed binding cache entry.", extra={"token_hash": token_hash})
            return None, True

        record = BindingRecord(
            id=0,
            token_hash=token_hash,
            token_display="",
            bound_ip=bound_ip,
            binding_mode=binding_mode,
            allowed_ips=allowed_ips,
            status=status,
            ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
        )
        return record, True

    async def _touch_cache(self, token_hash: str) -> None:
        """Reset the cached binding's TTL to the configured value (best effort)."""
        if self.redis is None:
            return
        try:
            await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
        except Exception:
            logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})

    async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
        """Load the binding row for ``token_hash``, or None if no row exists."""
        async with self.session_factory() as session:
            binding = await session.scalar(
                select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1)
            )
            if binding is None:
                return None
            return self.to_binding_record(binding, client_ip)

    async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
        """Insert a new single-IP binding for ``client_ip``.

        Returns None when the insert fails with an error whose message
        mentions "duplicate key" or "unique" (treated as a concurrent bind).
        Any other SQLAlchemyError is re-raised after rollback.
        """
        async with self.session_factory() as session:
            try:
                binding = TokenBinding(
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=client_ip,
                    binding_mode=BINDING_MODE_SINGLE,
                    allowed_ips=[client_ip],
                    status=STATUS_ACTIVE,
                )
                session.add(binding)
                await session.flush()
                await session.commit()
                await session.refresh(binding)
            except SQLAlchemyError as exc:
                await session.rollback()
                message = str(exc).lower()
                if "duplicate key" in message or "unique" in message:
                    return None
                raise
            return self.to_binding_record(binding, client_ip)

    def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
        """Map a storage failure to a result according to the runtime failsafe mode.

        "open" allows the request with status 200; any other mode denies it
        with status 503.
        """
        runtime_settings = self.runtime_settings_getter()
        logger.exception(
            "Binding storage backend failed.",
            extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
        )
        if runtime_settings.failsafe_mode == "open":
            return BindingCheckResult(
                allowed=True,
                status_code=200,
                detail="Allowed by failsafe mode.",
                token_hash=token_hash,
                token_display=token_display,
            )
        return BindingCheckResult(
            allowed=False,
            status_code=503,
            detail="Binding backend unavailable and failsafe mode is closed.",
            token_hash=token_hash,
            token_display=token_display,
        )

    async def _flush_loop(self) -> None:
        """Flush queued ``last_used_at`` updates at the configured interval until stopped."""
        while not self._stop_event.is_set():
            await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
            try:
                await self.flush_last_used_updates()
            except Exception:
                # CancelledError derives from BaseException and still propagates;
                # any other failure must not end the background task for good.
                logger.exception("Unexpected error while flushing last_used_at updates; will retry.")