From 24ab4593f302d7ecb7811572d2c7419a27799161 Mon Sep 17 00:00:00 2001 From: Tudor Date: Wed, 7 Jan 2026 16:20:49 +0000 Subject: [PATCH] security improvements --- .env | 45 ++++++++++ .env.example | 45 ++++++++++ backend/app.py | 203 +++++++++++++++++++++++++++++++++++++++----- backend/config.py | 26 ++++-- frontend/index.html | 4 +- frontend/styles.css | 2 + requirements.txt | 2 + 7 files changed, 295 insertions(+), 32 deletions(-) create mode 100644 .env create mode 100644 .env.example diff --git a/.env b/.env new file mode 100644 index 0000000..7a8c4a2 --- /dev/null +++ b/.env @@ -0,0 +1,45 @@ +# SchoolCompare Environment Configuration +# Copy this file to .env and update the values + +# ============================================================================= +# DATABASE +# ============================================================================= +# PostgreSQL connection string +DATABASE_URL=postgresql://schoolcompare:CHANGE_THIS_PASSWORD@localhost:5432/schoolcompare + +# ============================================================================= +# SERVER +# ============================================================================= +# Set to False in production +DEBUG=False + +# Server host and port +HOST=0.0.0.0 +PORT=80 + +# ============================================================================= +# CORS +# ============================================================================= +# Comma-separated list of allowed origins +# In production, only include your actual domain +ALLOWED_ORIGINS=["https://schoolcompare.co.uk"] + +# ============================================================================= +# SECURITY +# ============================================================================= +# Admin API key for protected endpoints (e.g., /api/admin/reload) +# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))" +ADMIN_API_KEY=CHANGE_THIS_TO_A_SECURE_RANDOM_KEY + +# Rate limiting (requests per minute per IP) +RATE_LIMIT_PER_MINUTE=60 +RATE_LIMIT_BURST=10 + +# Maximum request body size in bytes (default 1MB) +MAX_REQUEST_SIZE=1048576 + +# ============================================================================= +# API +# ============================================================================= +DEFAULT_PAGE_SIZE=50 +MAX_PAGE_SIZE=100 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..7a8c4a2 --- /dev/null +++ b/.env.example @@ -0,0 +1,45 @@ +# SchoolCompare Environment Configuration +# Copy this file to .env and update the values + +# ============================================================================= +# DATABASE +# ============================================================================= +# PostgreSQL connection string +DATABASE_URL=postgresql://schoolcompare:CHANGE_THIS_PASSWORD@localhost:5432/schoolcompare + +# ============================================================================= +# SERVER +# ============================================================================= +# Set to False in production +DEBUG=False + +# Server host and port +HOST=0.0.0.0 +PORT=80 + +# ============================================================================= +# CORS +# ============================================================================= +# Comma-separated list of allowed origins +# In production, only include your actual domain +ALLOWED_ORIGINS=["https://schoolcompare.co.uk"] + +# ============================================================================= +# SECURITY +# ============================================================================= +# Admin API key for protected endpoints (e.g., /api/admin/reload) +# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))" +ADMIN_API_KEY=CHANGE_THIS_TO_A_SECURE_RANDOM_KEY + +# Rate limiting (requests per minute per IP) +RATE_LIMIT_PER_MINUTE=60 +RATE_LIMIT_BURST=10 + +# Maximum request body size in bytes (default 1MB) +MAX_REQUEST_SIZE=1048576 + +# ============================================================================= +# API +# ============================================================================= +DEFAULT_PAGE_SIZE=50 +MAX_PAGE_SIZE=100 diff --git a/backend/app.py b/backend/app.py index cddd625..e3b8c3e 100644 --- a/backend/app.py +++ b/backend/app.py @@ -4,14 +4,19 @@ Serves primary school (KS2) performance data for comparing schools. Uses real data from UK Government Compare School Performance downloads. """ +import re from contextlib import asynccontextmanager from typing import Optional import pandas as pd -from fastapi import FastAPI, HTTPException, Query +from fastapi import FastAPI, HTTPException, Query, Request, Depends, Header from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, Response from fastapi.staticfiles import StaticFiles +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from starlette.middleware.base import BaseHTTPMiddleware from .config import settings from .data_loader import ( @@ -27,6 +32,107 @@ from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS from .utils import clean_for_json +# ============================================================================= +# SECURITY MIDDLEWARE & HELPERS +# ============================================================================= + +# Rate limiter +limiter = Limiter(key_func=get_remote_address) + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Add security headers to all responses.""" + + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + + # Prevent clickjacking + response.headers["X-Frame-Options"] = "DENY" + + # Prevent MIME type sniffing + response.headers["X-Content-Type-Options"] = "nosniff" + + # XSS Protection (legacy browsers) + response.headers["X-XSS-Protection"] = "1; mode=block" + + # Referrer policy + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + + # Permissions policy (restrict browser features) + response.headers["Permissions-Policy"] = ( + "geolocation=(), microphone=(), camera=(), payment=()" + ) + + # Content Security Policy + response.headers["Content-Security-Policy"] = ( + "default-src 'self'; " + "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " + "font-src 'self' https://fonts.gstatic.com; " + "img-src 'self' data:; " + "connect-src 'self'; " + "frame-ancestors 'none'; " + "base-uri 'self'; " + "form-action 'self';" + ) + + # HSTS (only enable if using HTTPS in production) + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) + + return response + + +class RequestSizeLimitMiddleware(BaseHTTPMiddleware): + """Limit request body size to prevent DoS attacks.""" + + async def dispatch(self, request: Request, call_next): + content_length = request.headers.get("content-length") + if content_length: + if int(content_length) > settings.max_request_size: + return Response( + content="Request too large", + status_code=413, + ) + return await call_next(request) + + +def verify_admin_api_key(x_api_key: str = Header(None)) -> bool: + """Verify admin API key for protected endpoints.""" + if not x_api_key or x_api_key != settings.admin_api_key: + raise HTTPException( + status_code=401, + detail="Invalid or missing API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + return True + + +# Input validation helpers +def sanitize_search_input(value: Optional[str], max_length: int = 100) -> Optional[str]: + """Sanitize search input to prevent injection attacks.""" + if value is None: + return None + # Strip whitespace and limit length + value = value.strip()[:max_length] + # Remove potentially dangerous characters (allow alphanumeric, spaces, common punctuation) + value = re.sub(r"[^\w\s\-\',\.]", "", value) + return value if value else None + + +def validate_postcode(postcode: Optional[str]) -> Optional[str]: + """Validate and normalize UK postcode format.""" + if not postcode: + return None + postcode = postcode.strip().upper() + # UK postcode pattern + pattern = r"^[A-Z]{1,2}[0-9][A-Z0-9]?\s*[0-9][A-Z]{2}$" + if not re.match(pattern, postcode): + return None + return postcode + + @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan - startup and shutdown events.""" @@ -52,15 +158,27 @@ app = FastAPI( description="API for comparing primary school (KS2) performance data - schoolcompare.co.uk", version="2.0.0", lifespan=lifespan, + # Disable docs in production for security + docs_url="/docs" if settings.debug else None, + redoc_url="/redoc" if settings.debug else None, + openapi_url="/openapi.json" if settings.debug else None, ) -# CORS middleware with configurable origins +# Add rate limiter +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# Security middleware (order matters - these run in reverse order) +app.add_middleware(SecurityHeadersMiddleware) +app.add_middleware(RequestSizeLimitMiddleware) + +# CORS middleware - restricted for production app.add_middleware( CORSMiddleware, allow_origins=settings.allowed_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_credentials=False, # Don't allow credentials unless needed + allow_methods=["GET", "POST"], # Only allow needed methods + allow_headers=["Content-Type", "X-API-Key"], # Only allow needed headers ) @@ -83,15 +201,17 @@ async def serve_rankings(): @app.get("/api/schools") +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") async def get_schools( - search: Optional[str] = Query(None, description="Search by school name"), + request: Request, + search: Optional[str] = Query(None, description="Search by school name", max_length=100), local_authority: Optional[str] = Query( - None, description="Filter by local authority" + None, description="Filter by local authority", max_length=100 ), - school_type: Optional[str] = Query(None, description="Filter by school type"), - postcode: Optional[str] = Query(None, description="Search near postcode"), + school_type: Optional[str] = Query(None, description="Filter by school type", max_length=100), + postcode: Optional[str] = Query(None, description="Search near postcode", max_length=10), radius: float = Query(5.0, ge=0.1, le=50, description="Search radius in miles"), - page: int = Query(1, ge=1, description="Page number"), + page: int = Query(1, ge=1, le=1000, description="Page number"), page_size: int = Query(None, ge=1, le=100, description="Results per page"), ): """ @@ -100,6 +220,12 @@ async def get_schools( Returns paginated results with total count for efficient loading. Supports location-based search using postcode. """ + # Sanitize inputs + search = sanitize_search_input(search) + local_authority = sanitize_search_input(local_authority) + school_type = sanitize_search_input(school_type) + postcode = validate_postcode(postcode) + df = load_school_data() if df.empty: @@ -216,8 +342,13 @@ async def get_schools( @app.get("/api/schools/{urn}") -async def get_school_details(urn: int): +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") +async def get_school_details(request: Request, urn: int): """Get detailed KS2 data for a specific primary school across all years.""" + # Validate URN range (UK school URNs are 6 digits) + if not (100000 <= urn <= 999999): + raise HTTPException(status_code=400, detail="Invalid URN format") + df = load_school_data() if df.empty: @@ -248,7 +379,11 @@ async def get_school_details(urn: int): @app.get("/api/compare") -async def compare_schools(urns: str = Query(..., description="Comma-separated URNs")): +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") +async def compare_schools( + request: Request, + urns: str = Query(..., description="Comma-separated URNs", max_length=100) +): """Compare multiple primary schools side by side.""" df = load_school_data() @@ -257,6 +392,13 @@ async def compare_schools(urns: str = Query(..., description="Comma-separated UR try: urn_list = [int(u.strip()) for u in urns.split(",")] + # Limit number of schools to compare + if len(urn_list) > 10: + raise HTTPException(status_code=400, detail="Maximum 10 schools can be compared") + # Validate URN format + for urn in urn_list: + if not (100000 <= urn <= 999999): + raise HTTPException(status_code=400, detail="Invalid URN format") except ValueError: raise HTTPException(status_code=400, detail="Invalid URN format") @@ -284,7 +426,8 @@ async def compare_schools(urns: str = Query(..., description="Comma-separated UR @app.get("/api/filters") -async def get_filter_options(): +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") +async def get_filter_options(request: Request): """Get available filter options (local authorities, school types, years).""" df = load_school_data() @@ -303,7 +446,8 @@ async def get_filter_options(): @app.get("/api/metrics") -async def get_available_metrics(): +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") +async def get_available_metrics(request: Request): """ Get list of available KS2 performance metrics for primary schools. @@ -321,17 +465,26 @@ async def get_available_metrics(): @app.get("/api/rankings") +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") async def get_rankings( - metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by"), + request: Request, + metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by", max_length=50), year: Optional[int] = Query( - None, description="Specific year (defaults to most recent)" + None, description="Specific year (defaults to most recent)", ge=2000, le=2100 ), limit: int = Query(20, ge=1, le=100, description="Number of schools to return"), local_authority: Optional[str] = Query( - None, description="Filter by local authority" + None, description="Filter by local authority", max_length=100 ), ): """Get primary school rankings by a specific KS2 metric.""" + # Sanitize local authority input + local_authority = sanitize_search_input(local_authority) + + # Validate metric name (only allow alphanumeric and underscore) + if not re.match(r"^[a-z0-9_]+$", metric): + raise HTTPException(status_code=400, detail="Invalid metric name") + df = load_school_data() if df.empty: @@ -372,7 +525,8 @@ async def get_rankings( @app.get("/api/data-info") -async def get_data_info(): +@limiter.limit(f"{settings.rate_limit_per_minute}/minute") +async def get_data_info(request: Request): """Get information about loaded data.""" # Get info directly from database db_info = get_db_info() @@ -416,8 +570,15 @@ async def get_data_info(): @app.post("/api/admin/reload") -async def reload_data(): - """Admin endpoint to force data reload (useful after data updates).""" +@limiter.limit("5/minute") +async def reload_data( + request: Request, + _: bool = Depends(verify_admin_api_key) +): + """ + Admin endpoint to force data reload (useful after data updates). + Requires X-API-Key header with valid admin API key. + """ clear_cache() load_school_data() return {"status": "reloaded"} diff --git a/backend/config.py b/backend/config.py index 255c62d..04cefa4 100644 --- a/backend/config.py +++ b/backend/config.py @@ -3,33 +3,41 @@ Application configuration using pydantic-settings. Loads from environment variables and .env file. """ +import secrets from pathlib import Path from typing import List, Optional from pydantic_settings import BaseSettings -import os +from pydantic import Field class Settings(BaseSettings): """Application settings loaded from environment.""" - + # Paths data_dir: Path = Path(__file__).parent.parent / "data" frontend_dir: Path = Path(__file__).parent.parent / "frontend" - + # Server host: str = "0.0.0.0" port: int = 80 - + debug: bool = False # Set to False in production + # Database database_url: str = "postgresql://schoolcompare:schoolcompare@localhost:5432/schoolcompare" - - # CORS - allowed_origins: List[str] = ["https://schoolcompare.co.uk", "http://localhost:8000", "http://localhost:3000"] - + + # CORS - Production should only allow the actual domain + allowed_origins: List[str] = ["https://schoolcompare.co.uk"] + # API default_page_size: int = 50 max_page_size: int = 100 - + + # Security + admin_api_key: str = Field(default_factory=lambda: secrets.token_urlsafe(32)) + rate_limit_per_minute: int = 60 # Requests per minute per IP + rate_limit_burst: int = 10 # Allow burst of requests + max_request_size: int = 1024 * 1024 # 1MB max request size + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/frontend/index.html b/frontend/index.html index 15eb338..a8f7e2c 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -15,7 +15,7 @@
-