Add multi-IP binding modes and deployment guide
This commit is contained in:
@@ -5,18 +5,25 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, timedelta
|
||||
from datetime import date, timedelta
|
||||
from typing import Callable
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, select, text, update
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.core.ip_utils import is_ip_in_network
|
||||
from app.core.security import hash_token, mask_token
|
||||
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||||
from app.models.token_binding import (
|
||||
BINDING_MODE_ALL,
|
||||
BINDING_MODE_MULTIPLE,
|
||||
BINDING_MODE_SINGLE,
|
||||
STATUS_ACTIVE,
|
||||
STATUS_BANNED,
|
||||
TokenBinding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +34,8 @@ class BindingRecord:
|
||||
token_hash: str
|
||||
token_display: str
|
||||
bound_ip: str
|
||||
binding_mode: str
|
||||
allowed_ips: list[str]
|
||||
status: int
|
||||
ip_matched: bool
|
||||
|
||||
@@ -104,42 +113,101 @@ class BindingService:
|
||||
def metrics_key(self, target_date: date) -> str:
|
||||
return f"sentinel:metrics:{target_date.isoformat()}"
|
||||
|
||||
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
|
||||
if binding_mode == BINDING_MODE_ALL:
|
||||
return "ALL"
|
||||
if not allowed_ips:
|
||||
return "-"
|
||||
if binding_mode == BINDING_MODE_MULTIPLE:
|
||||
return ", ".join(allowed_ips)
|
||||
return allowed_ips[0]
|
||||
|
||||
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
|
||||
if binding_mode == BINDING_MODE_ALL:
|
||||
return True
|
||||
return any(is_ip_in_network(client_ip, item) for item in allowed_ips)
|
||||
|
||||
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
|
||||
allowed_ips = [str(item) for item in binding.allowed_ips]
|
||||
binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
|
||||
return BindingRecord(
|
||||
id=binding.id,
|
||||
token_hash=binding.token_hash,
|
||||
token_display=binding.token_display,
|
||||
bound_ip=binding.bound_ip,
|
||||
binding_mode=binding_mode,
|
||||
allowed_ips=allowed_ips,
|
||||
status=binding.status,
|
||||
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
||||
)
|
||||
|
||||
def denied_result(
|
||||
self,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
bound_ip: str,
|
||||
detail: str,
|
||||
*,
|
||||
should_alert: bool = True,
|
||||
status_code: int = 403,
|
||||
) -> BindingCheckResult:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=status_code,
|
||||
detail=detail,
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=bound_ip,
|
||||
should_alert=should_alert,
|
||||
)
|
||||
|
||||
def allowed_result(
|
||||
self,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
bound_ip: str,
|
||||
detail: str,
|
||||
*,
|
||||
newly_bound: bool = False,
|
||||
) -> BindingCheckResult:
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail=detail,
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=bound_ip,
|
||||
newly_bound=newly_bound,
|
||||
)
|
||||
|
||||
def evaluate_existing_record(
|
||||
self,
|
||||
record: BindingRecord,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
detail: str,
|
||||
) -> BindingCheckResult:
|
||||
if record.status == STATUS_BANNED:
|
||||
return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
|
||||
if record.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
|
||||
return self.denied_result(
|
||||
token_hash,
|
||||
token_display,
|
||||
record.bound_ip,
|
||||
"Client IP does not match the allowed binding rule.",
|
||||
)
|
||||
|
||||
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
|
||||
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
|
||||
token_display = mask_token(token)
|
||||
|
||||
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
|
||||
cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
|
||||
if cache_hit is not None:
|
||||
if cache_hit.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if is_ip_in_network(client_ip, cache_hit.bound_ip):
|
||||
if cache_hit.ip_matched:
|
||||
await self._touch_cache(token_hash)
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from cache.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")
|
||||
|
||||
if not cache_available:
|
||||
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
|
||||
@@ -159,36 +227,8 @@ class BindingService:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
|
||||
if record is not None:
|
||||
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
|
||||
if record.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if record.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from PostgreSQL.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
|
||||
return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")
|
||||
|
||||
try:
|
||||
created = await self._create_binding(token_hash, token_display, client_ip)
|
||||
@@ -202,52 +242,42 @@ class BindingService:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
if existing is None:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
|
||||
if existing.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if existing.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed after concurrent bind resolution.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
await self.sync_binding_cache(
|
||||
existing.token_hash,
|
||||
existing.bound_ip,
|
||||
existing.binding_mode,
|
||||
existing.allowed_ips,
|
||||
existing.status,
|
||||
)
|
||||
return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")
|
||||
|
||||
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="First-use bind created.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=created.bound_ip,
|
||||
newly_bound=True,
|
||||
await self.sync_binding_cache(
|
||||
created.token_hash,
|
||||
created.bound_ip,
|
||||
created.binding_mode,
|
||||
created.allowed_ips,
|
||||
created.status,
|
||||
)
|
||||
return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
|
||||
|
||||
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
|
||||
async def sync_binding_cache(
|
||||
self,
|
||||
token_hash: str,
|
||||
bound_ip: str,
|
||||
binding_mode: str,
|
||||
allowed_ips: list[str],
|
||||
status_code: int,
|
||||
) -> None:
|
||||
if self.redis is None:
|
||||
return
|
||||
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
|
||||
payload = json.dumps(
|
||||
{
|
||||
"bound_ip": bound_ip,
|
||||
"binding_mode": binding_mode,
|
||||
"allowed_ips": allowed_ips,
|
||||
"status": status_code,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
|
||||
except Exception:
|
||||
@@ -336,7 +366,7 @@ class BindingService:
|
||||
)
|
||||
return series
|
||||
|
||||
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
|
||||
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
|
||||
if self.redis is None:
|
||||
return None, False
|
||||
try:
|
||||
@@ -348,14 +378,18 @@ class BindingService:
|
||||
return None, True
|
||||
|
||||
data = json.loads(raw)
|
||||
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
|
||||
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
|
||||
return (
|
||||
BindingRecord(
|
||||
id=0,
|
||||
token_hash=token_hash,
|
||||
token_display="",
|
||||
bound_ip=data["bound_ip"],
|
||||
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
|
||||
binding_mode=binding_mode,
|
||||
allowed_ips=allowed_ips,
|
||||
status=int(data["status"]),
|
||||
ip_matched=False,
|
||||
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
||||
),
|
||||
True,
|
||||
)
|
||||
@@ -369,69 +403,33 @@ class BindingService:
|
||||
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
|
||||
|
||||
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
|
||||
query = text(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
token_hash,
|
||||
token_display,
|
||||
bound_ip::text AS bound_ip,
|
||||
status,
|
||||
CAST(:client_ip AS inet) << bound_ip AS ip_matched
|
||||
FROM token_bindings
|
||||
WHERE token_hash = :token_hash
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
|
||||
row = result.mappings().first()
|
||||
if row is None:
|
||||
binding = await session.scalar(select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1))
|
||||
if binding is None:
|
||||
return None
|
||||
return BindingRecord(
|
||||
id=int(row["id"]),
|
||||
token_hash=str(row["token_hash"]),
|
||||
token_display=str(row["token_display"]),
|
||||
bound_ip=str(row["bound_ip"]),
|
||||
status=int(row["status"]),
|
||||
ip_matched=bool(row["ip_matched"]),
|
||||
)
|
||||
return self.to_binding_record(binding, client_ip)
|
||||
|
||||
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
|
||||
statement = text(
|
||||
"""
|
||||
INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
|
||||
VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
|
||||
ON CONFLICT (token_hash) DO NOTHING
|
||||
RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
|
||||
"""
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
try:
|
||||
result = await session.execute(
|
||||
statement,
|
||||
{
|
||||
"token_hash": token_hash,
|
||||
"token_display": token_display,
|
||||
"bound_ip": client_ip,
|
||||
"status": STATUS_ACTIVE,
|
||||
},
|
||||
binding = TokenBinding(
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=client_ip,
|
||||
binding_mode=BINDING_MODE_SINGLE,
|
||||
allowed_ips=[client_ip],
|
||||
status=STATUS_ACTIVE,
|
||||
)
|
||||
row = result.mappings().first()
|
||||
session.add(binding)
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
except SQLAlchemyError:
|
||||
await session.refresh(binding)
|
||||
except SQLAlchemyError as exc:
|
||||
await session.rollback()
|
||||
if "duplicate key" in str(exc).lower() or "unique" in str(exc).lower():
|
||||
return None
|
||||
raise
|
||||
if row is None:
|
||||
return None
|
||||
return BindingRecord(
|
||||
id=int(row["id"]),
|
||||
token_hash=str(row["token_hash"]),
|
||||
token_display=str(row["token_display"]),
|
||||
bound_ip=str(row["bound_ip"]),
|
||||
status=int(row["status"]),
|
||||
ip_matched=True,
|
||||
)
|
||||
return self.to_binding_record(binding, client_ip)
|
||||
|
||||
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
|
||||
runtime_settings = self.runtime_settings_getter()
|
||||
|
||||
Reference in New Issue
Block a user