"""Database engine/session lifecycle and startup schema compatibility patches."""

from __future__ import annotations

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

from app.config import Settings


class Base(DeclarativeBase):
    """Declarative base class for this application's ORM models."""


# Module-level singletons populated by init_db() and cleared by close_db().
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def init_db(settings: Settings) -> None:
    """Create the shared async engine and session factory.

    Calling this again after a successful initialization is a no-op, so the
    settings from the first call stay in effect until close_db() resets the
    module state.

    Args:
        settings: Application settings; only ``settings.pg_dsn`` is read.
    """
    global _engine, _session_factory
    if _engine is not None and _session_factory is not None:
        return
    _engine = create_async_engine(
        settings.pg_dsn,
        pool_pre_ping=True,  # test pooled connections before handing them out
        pool_size=20,
        max_overflow=40,
    )
    # expire_on_commit=False: ORM objects keep their loaded attribute values
    # after commit instead of being expired.
    _session_factory = async_sessionmaker(_engine, expire_on_commit=False)


def get_engine() -> AsyncEngine:
    """Return the engine created by init_db().

    Raises:
        RuntimeError: If init_db() has not been called (or close_db() ran since).
    """
    if _engine is None:
        raise RuntimeError("Database engine has not been initialized.")
    return _engine


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the session factory created by init_db().

    Raises:
        RuntimeError: If init_db() has not been called (or close_db() ran since).
    """
    if _session_factory is None:
        raise RuntimeError("Database session factory has not been initialized.")
    return _session_factory


# Reports whether <table>.<column> is already of type TEXT. to_regclass()
# resolves the table name through search_path, exactly like the unqualified
# table names in the DDL below. Yields no row when the table/column is absent.
_COLUMN_IS_TEXT_SQL = text(
    """
    SELECT a.atttypid = 'text'::regtype
    FROM pg_attribute AS a
    WHERE a.attrelid = to_regclass(:table_name)
      AND a.attname = :column_name
      AND NOT a.attisdropped
    """
)

# Idempotent column additions, backfills, constraints and the bound_ip index
# for token_bindings. Executed in this order after the bound_ip type changes.
_TOKEN_BINDINGS_STATEMENTS: tuple[str, ...] = (
    "ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS binding_mode VARCHAR(16) DEFAULT 'single'",
    "ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS allowed_ips JSONB DEFAULT '[]'::jsonb",
    "UPDATE token_bindings SET binding_mode = 'single' WHERE binding_mode IS NULL OR binding_mode = ''",
    # Seed allowed_ips from bound_ip. A NULL bound_ip must not become the
    # JSON array [null]; such rows get an empty array instead.
    """
    UPDATE token_bindings
    SET allowed_ips = CASE
        WHEN bound_ip IS NULL THEN '[]'::jsonb
        ELSE jsonb_build_array(bound_ip)
    END
    WHERE allowed_ips IS NULL
       OR (allowed_ips = '[]'::jsonb AND bound_ip IS NOT NULL)
    """,
    "ALTER TABLE token_bindings ALTER COLUMN binding_mode SET NOT NULL",
    "ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET NOT NULL",
    "ALTER TABLE token_bindings ALTER COLUMN binding_mode SET DEFAULT 'single'",
    "ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET DEFAULT '[]'::jsonb",
    "CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
)


async def _column_is_text(
    connection: AsyncConnection, table_name: str, column_name: str
) -> bool:
    """Return True only if ``table_name.column_name`` exists and is TEXT.

    A missing table or column yields False, so callers fall through to the
    ALTER statement and surface the same error the unconditional migration
    would have raised.
    """
    result = await connection.execute(
        _COLUMN_IS_TEXT_SQL,
        {"table_name": table_name, "column_name": column_name},
    )
    return result.scalar() is True


async def ensure_schema_compatibility() -> None:
    """Patch an existing database schema to the shape the models expect.

    All changes run in one transaction opened with ``engine.begin()``, which
    commits on success and rolls back if any statement fails.

    The bound_ip type conversions (and the index drop that precedes the
    token_bindings one) only run when the column is not already TEXT, so a
    database that has already been migrated does not get its
    idx_token_bindings_ip index dropped and rebuilt on every startup. The
    remaining statements are written to be re-runnable.

    Raises:
        RuntimeError: If init_db() has not been called.
    """
    engine = get_engine()
    async with engine.begin() as connection:
        if not await _column_is_text(connection, "token_bindings", "bound_ip"):
            # NOTE(review): the index is dropped before the type change,
            # presumably because an index built for the old column type
            # should not carry over to TEXT; it is recreated at the end of
            # _TOKEN_BINDINGS_STATEMENTS.
            await connection.execute(text("DROP INDEX IF EXISTS idx_token_bindings_ip"))
            await connection.execute(
                text("ALTER TABLE token_bindings ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text")
            )
        if not await _column_is_text(connection, "intercept_logs", "bound_ip"):
            await connection.execute(
                text("ALTER TABLE intercept_logs ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text")
            )
        for statement in _TOKEN_BINDINGS_STATEMENTS:
            await connection.execute(text(statement))


async def close_db() -> None:
    """Dispose of the engine's connection pool and reset module state.

    Safe to call when init_db() was never called; it then does nothing.
    """
    global _engine, _session_factory
    if _engine is not None:
        await _engine.dispose()
        _engine = None
        _session_factory = None