from __future__ import annotations import logging from urllib.parse import urlsplit import httpx from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, StreamingResponse from app.config import Settings from app.core.ip_utils import extract_client_ip from app.core.security import extract_request_token from app.dependencies import get_alert_service, get_binding_service, get_settings from app.services.alert_service import AlertService from app.services.binding_service import BindingService logger = logging.getLogger(__name__) router = APIRouter() CONTENT_LENGTH_HEADER = "content-length" def build_upstream_headers(request: Request, downstream_url: str) -> list[tuple[str, str]]: downstream_host = urlsplit(downstream_url).netloc headers: list[tuple[str, str]] = [] for header_name, header_value in request.headers.items(): if header_name.lower() == "host": continue headers.append((header_name, header_value)) headers.append(("host", downstream_host)) return headers def build_upstream_url(settings: Settings, request: Request) -> str: return f"{settings.downstream_url}{request.url.path}" def apply_downstream_headers(response: StreamingResponse, upstream_response: httpx.Response) -> None: for header_name, header_value in upstream_response.headers.multi_items(): if header_name.lower() == CONTENT_LENGTH_HEADER: continue response.headers.append(header_name, header_value) @router.api_route("/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"], include_in_schema=False) @router.api_route( "/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"], include_in_schema=False, ) async def reverse_proxy( request: Request, path: str = "", settings: Settings = Depends(get_settings), binding_service: BindingService = Depends(get_binding_service), alert_service: AlertService = Depends(get_alert_service), ): client_ip = extract_client_ip(request, settings) token, token_source = extract_request_token(request.headers) if token: binding_result = await binding_service.evaluate_token_binding(token, client_ip) if binding_result.allowed: await binding_service.increment_request_metric("allowed") else: await binding_service.increment_request_metric("intercepted" if binding_result.should_alert else None) if binding_result.should_alert and binding_result.token_hash and binding_result.token_display and binding_result.bound_ip: await alert_service.handle_intercept( token_hash=binding_result.token_hash, token_display=binding_result.token_display, bound_ip=binding_result.bound_ip, attempt_ip=client_ip, ) return JSONResponse( status_code=binding_result.status_code, content={"detail": binding_result.detail}, ) logger.debug("Token binding check passed.", extra={"client_ip": client_ip, "token_source": token_source}) else: await binding_service.increment_request_metric("allowed") downstream_client: httpx.AsyncClient = request.app.state.downstream_client upstream_url = build_upstream_url(settings, request) upstream_headers = build_upstream_headers(request, settings.downstream_url) try: upstream_request = downstream_client.build_request( request.method, upstream_url, params=request.query_params.multi_items(), headers=upstream_headers, content=request.stream(), ) upstream_response = await downstream_client.send(upstream_request, stream=True) except httpx.HTTPError as exc: logger.exception("Failed to reach downstream service.") return JSONResponse(status_code=502, content={"detail": f"Downstream request failed: {exc!s}"}) async def stream_response(): try: async for chunk in upstream_response.aiter_raw(): yield chunk finally: await upstream_response.aclose() response = StreamingResponse( stream_response(), status_code=upstream_response.status_code, media_type=None, ) apply_downstream_headers(response, upstream_response) return response