Add multi-IP binding modes and deployment guide

This commit is contained in:
2026-03-04 15:30:13 +08:00
parent 4348ee799b
commit eed1acd454
12 changed files with 509 additions and 217 deletions

View File

@@ -5,18 +5,25 @@ import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, date, timedelta
from datetime import date, timedelta
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import func, select, text, update
from sqlalchemy import func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
from app.models.token_binding import (
BINDING_MODE_ALL,
BINDING_MODE_MULTIPLE,
BINDING_MODE_SINGLE,
STATUS_ACTIVE,
STATUS_BANNED,
TokenBinding,
)
logger = logging.getLogger(__name__)
@@ -27,6 +34,8 @@ class BindingRecord:
token_hash: str
token_display: str
bound_ip: str
binding_mode: str
allowed_ips: list[str]
status: int
ip_matched: bool
@@ -104,42 +113,101 @@ class BindingService:
def metrics_key(self, target_date: date) -> str:
return f"sentinel:metrics:{target_date.isoformat()}"
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
if binding_mode == BINDING_MODE_ALL:
return "ALL"
if not allowed_ips:
return "-"
if binding_mode == BINDING_MODE_MULTIPLE:
return ", ".join(allowed_ips)
return allowed_ips[0]
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
if binding_mode == BINDING_MODE_ALL:
return True
return any(is_ip_in_network(client_ip, item) for item in allowed_ips)
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
allowed_ips = [str(item) for item in binding.allowed_ips]
binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
return BindingRecord(
id=binding.id,
token_hash=binding.token_hash,
token_display=binding.token_display,
bound_ip=binding.bound_ip,
binding_mode=binding_mode,
allowed_ips=allowed_ips,
status=binding.status,
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
)
def denied_result(
self,
token_hash: str,
token_display: str,
bound_ip: str,
detail: str,
*,
should_alert: bool = True,
status_code: int = 403,
) -> BindingCheckResult:
return BindingCheckResult(
allowed=False,
status_code=status_code,
detail=detail,
token_hash=token_hash,
token_display=token_display,
bound_ip=bound_ip,
should_alert=should_alert,
)
def allowed_result(
self,
token_hash: str,
token_display: str,
bound_ip: str,
detail: str,
*,
newly_bound: bool = False,
) -> BindingCheckResult:
return BindingCheckResult(
allowed=True,
status_code=200,
detail=detail,
token_hash=token_hash,
token_display=token_display,
bound_ip=bound_ip,
newly_bound=newly_bound,
)
def evaluate_existing_record(
self,
record: BindingRecord,
token_hash: str,
token_display: str,
detail: str,
) -> BindingCheckResult:
if record.status == STATUS_BANNED:
return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
if record.ip_matched:
self.record_last_used(token_hash)
return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
return self.denied_result(
token_hash,
token_display,
record.bound_ip,
"Client IP does not match the allowed binding rule.",
)
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
token_display = mask_token(token)
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
if cache_hit is not None:
if cache_hit.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
if is_ip_in_network(client_ip, cache_hit.bound_ip):
if cache_hit.ip_matched:
await self._touch_cache(token_hash)
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from cache.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")
if not cache_available:
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
@@ -159,36 +227,8 @@ class BindingService:
return self._handle_backend_failure(token_hash, token_display)
if record is not None:
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
if record.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
if record.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from PostgreSQL.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")
try:
created = await self._create_binding(token_hash, token_display, client_ip)
@@ -202,52 +242,42 @@ class BindingService:
return self._handle_backend_failure(token_hash, token_display)
if existing is None:
return self._handle_backend_failure(token_hash, token_display)
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
if existing.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
)
if existing.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed after concurrent bind resolution.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
await self.sync_binding_cache(
existing.token_hash,
existing.bound_ip,
existing.binding_mode,
existing.allowed_ips,
existing.status,
)
return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="First-use bind created.",
token_hash=token_hash,
token_display=token_display,
bound_ip=created.bound_ip,
newly_bound=True,
await self.sync_binding_cache(
created.token_hash,
created.bound_ip,
created.binding_mode,
created.allowed_ips,
created.status,
)
return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
async def sync_binding_cache(
self,
token_hash: str,
bound_ip: str,
binding_mode: str,
allowed_ips: list[str],
status_code: int,
) -> None:
if self.redis is None:
return
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
payload = json.dumps(
{
"bound_ip": bound_ip,
"binding_mode": binding_mode,
"allowed_ips": allowed_ips,
"status": status_code,
}
)
try:
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
except Exception:
@@ -336,7 +366,7 @@ class BindingService:
)
return series
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
if self.redis is None:
return None, False
try:
@@ -348,14 +378,18 @@ class BindingService:
return None, True
data = json.loads(raw)
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
return (
BindingRecord(
id=0,
token_hash=token_hash,
token_display="",
bound_ip=data["bound_ip"],
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
binding_mode=binding_mode,
allowed_ips=allowed_ips,
status=int(data["status"]),
ip_matched=False,
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
),
True,
)
@@ -369,69 +403,33 @@ class BindingService:
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
query = text(
"""
SELECT
id,
token_hash,
token_display,
bound_ip::text AS bound_ip,
status,
CAST(:client_ip AS inet) << bound_ip AS ip_matched
FROM token_bindings
WHERE token_hash = :token_hash
LIMIT 1
"""
)
async with self.session_factory() as session:
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
row = result.mappings().first()
if row is None:
binding = await session.scalar(select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1))
if binding is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=bool(row["ip_matched"]),
)
return self.to_binding_record(binding, client_ip)
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
statement = text(
"""
INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
ON CONFLICT (token_hash) DO NOTHING
RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
"""
)
async with self.session_factory() as session:
try:
result = await session.execute(
statement,
{
"token_hash": token_hash,
"token_display": token_display,
"bound_ip": client_ip,
"status": STATUS_ACTIVE,
},
binding = TokenBinding(
token_hash=token_hash,
token_display=token_display,
bound_ip=client_ip,
binding_mode=BINDING_MODE_SINGLE,
allowed_ips=[client_ip],
status=STATUS_ACTIVE,
)
row = result.mappings().first()
session.add(binding)
await session.flush()
await session.commit()
except SQLAlchemyError:
await session.refresh(binding)
except SQLAlchemyError as exc:
await session.rollback()
if "duplicate key" in str(exc).lower() or "unique" in str(exc).lower():
return None
raise
if row is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=True,
)
return self.to_binding_record(binding, client_ip)
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
runtime_settings = self.runtime_settings_getter()