"""
SchoolCompare.co.uk API

Serves primary school (KS2) performance data for comparing schools.
Uses real data from UK Government Compare School Performance downloads.
"""

import asyncio
import re
import secrets
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Query, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware

from .config import settings
from .data_loader import (
    clear_cache,
    load_school_data,
    geocode_single_postcode,
    get_supplementary_data,
)
from .data_loader import get_data_info as get_db_info
from .database import check_and_migrate_if_needed
from .migration import run_full_migration
from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS
from .utils import clean_for_json


# =============================================================================
# SECURITY MIDDLEWARE & HELPERS
# =============================================================================

# Rate limiter, keyed on the client's remote address.
limiter = Limiter(key_func=get_remote_address)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses."""

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Prevent clickjacking
        response.headers["X-Frame-Options"] = "DENY"
        # Prevent MIME type sniffing
        response.headers["X-Content-Type-Options"] = "nosniff"
        # XSS Protection (legacy browsers)
        response.headers["X-XSS-Protection"] = "1; mode=block"
        # Referrer policy
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Permissions policy (restrict browser features)
        response.headers["Permissions-Policy"] = (
            "geolocation=(), microphone=(), camera=(), payment=()"
        )
        # Content Security Policy
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://unpkg.com https://analytics.schoolcompare.co.uk; "
            "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net https://unpkg.com; "
            "font-src 'self' https://fonts.gstatic.com; "
            "img-src 'self' data: https://*.tile.openstreetmap.org https://unpkg.com; "
            "connect-src 'self' https://cdn.jsdelivr.net https://*.tile.openstreetmap.org https://unpkg.com https://analytics.schoolcompare.co.uk; "
            "frame-ancestors 'none'; "
            "base-uri 'self'; "
            "form-action 'self' https://formsubmit.co;"
        )
        # HSTS. Sent unconditionally; browsers ignore it on plain-HTTP
        # responses (RFC 6797), so it only takes effect behind HTTPS.
        response.headers["Strict-Transport-Security"] = (
            "max-age=31536000; includeSubDomains"
        )

        return response


class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Limit request body size to prevent DoS attacks.

    Only the declared ``Content-Length`` header is checked. Requests without
    that header are passed through unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # A malformed header would otherwise raise here and surface as
                # a 500; reject it as a client error instead.
                return Response(content="Invalid Content-Length", status_code=400)
            if declared_size > settings.max_request_size:
                return Response(
                    content="Request too large",
                    status_code=413,
                )
        return await call_next(request)


def verify_admin_api_key(x_api_key: str = Header(None)) -> bool:
    """Verify admin API key for protected endpoints.

    Compares the ``X-API-Key`` request header with ``settings.admin_api_key``.

    Returns:
        True when the key matches.

    Raises:
        HTTPException: 401 if the header is missing, no key is configured,
            or the key does not match.
    """
    expected = settings.admin_api_key
    # compare_digest avoids leaking, via response timing, how much of the key
    # was correct. Comparing UTF-8 bytes avoids the TypeError compare_digest
    # raises for non-ASCII str arguments.
    if (
        not x_api_key
        or not expected
        or not secrets.compare_digest(
            x_api_key.encode("utf-8"), expected.encode("utf-8")
        )
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
    return True


# Input validation helpers

# Anything other than word characters, whitespace, hyphen, apostrophe, comma
# and full stop is stripped from free-text search input.
_UNSAFE_SEARCH_CHARS = re.compile(r"[^\w\s\-\',\.]")

# UK postcode shape: outward code, optional whitespace, inward code.
_UK_POSTCODE_PATTERN = re.compile(r"^[A-Z]{1,2}[0-9][A-Z0-9]?\s*[0-9][A-Z]{2}$")


def sanitize_search_input(value: Optional[str], max_length: int = 100) -> Optional[str]:
    """Sanitize search input to prevent injection attacks.

    Strips surrounding whitespace, truncates to ``max_length`` characters and
    removes characters outside a small allow-list.

    Returns:
        The cleaned string, or None if the input was None or nothing remains.
    """
    if value is None:
        return None

    # Strip whitespace and limit length
    value = value.strip()[:max_length]

    # Remove potentially dangerous characters (allow alphanumeric, spaces, common punctuation)
    value = _UNSAFE_SEARCH_CHARS.sub("", value)

    return value if value else None


def validate_postcode(postcode: Optional[str]) -> Optional[str]:
    """Validate and normalize UK postcode format.

    Returns:
        The stripped, upper-cased postcode if it matches the UK postcode
        shape, otherwise None.
    """
    if not postcode:
        return None

    postcode = postcode.strip().upper()

    if not _UK_POSTCODE_PATTERN.match(postcode):
        return None

    return postcode


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan - startup and shutdown events."""
    # Startup: check schema version and migrate if needed
    print("Starting up: Checking database schema...")
    check_and_migrate_if_needed()

    print("Loading school data from database...")
    df = load_school_data()
    if df.empty:
        print("Warning: No data in database. Check CSV files in data/ folder.")
    else:
        print(f"Data loaded successfully: {len(df)} records.")

    yield  # Application runs here

    # Shutdown: cleanup if needed
    print("Shutting down...")


app = FastAPI(
    title="SchoolCompare API",
    description="API for comparing primary school (KS2) performance data - schoolcompare.co.uk",
    version="2.0.0",
    lifespan=lifespan,
    # Disable docs in production for security
    docs_url="/docs" if settings.debug else None,
    redoc_url="/redoc" if settings.debug else None,
    openapi_url="/openapi.json" if settings.debug else None,
)

# Add rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Security middleware. Each add_middleware call wraps the existing stack, so
# the middleware added last (CORS, below) sees incoming requests first.
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestSizeLimitMiddleware)

# CORS middleware - restricted for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials=False,  # Don't allow credentials unless needed
    allow_methods=["GET", "POST"],  # Only allow needed methods
    allow_headers=["Content-Type", "X-API-Key"],  # Only allow needed headers
)


@app.get("/")
async def root():
    """Serve the frontend."""
    return FileResponse(settings.frontend_dir / "index.html")
@app.get("/compare")
async def serve_compare():
    """Serve the frontend for /compare route (SPA routing)."""
    return FileResponse(settings.frontend_dir / "index.html")


@app.get("/rankings")
async def serve_rankings():
    """Serve the frontend for /rankings route (SPA routing)."""
    return FileResponse(settings.frontend_dir / "index.html")


@app.get("/api/config")
async def get_config():
    """Return public configuration for the frontend."""
    return {"ga_measurement_id": settings.ga_measurement_id}


# Earth's mean radius in miles, used by the haversine distance calculation.
_EARTH_RADIUS_MILES = 3959


def _haversine_miles(lat1, lon1, lat2, lon2):
    """Return great-circle distance(s) in miles from (lat1, lon1) to (lat2, lon2).

    All arguments are in degrees; lat2/lon2 may be NumPy arrays, in which case
    an array of distances is returned (NaN where a coordinate is NaN).
    """
    lat1_rad = np.radians(lat1)
    lat2_rad = np.radians(lat2)
    dlat = np.radians(lat2 - lat1)
    dlon = np.radians(lon2 - lon1)
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon / 2) ** 2
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return _EARTH_RADIUS_MILES * c


def _json_scalar(value):
    """Convert a pandas/NumPy scalar into a JSON-safe Python value.

    Missing values (None, NaN, NaT, pd.NA) become None and NumPy scalars are
    unwrapped with ``.item()``. Starlette renders JSON with allow_nan=False,
    and FastAPI's encoder does not know NumPy integer types, so raw values
    taken from a DataFrame row could otherwise fail to serialise.
    """
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # pd.isna on a list-like returns an array whose truth value is ambiguous.
        return value
    if isinstance(value, np.generic):
        return value.item()
    return value


@app.get("/api/schools")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_schools(
    request: Request,
    search: Optional[str] = Query(None, description="Search by school name", max_length=100),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
    school_type: Optional[str] = Query(None, description="Filter by school type", max_length=100),
    postcode: Optional[str] = Query(None, description="Search near postcode", max_length=10),
    radius: float = Query(5.0, ge=0.1, le=50, description="Search radius in miles"),
    page: int = Query(1, ge=1, le=1000, description="Page number"),
    page_size: Optional[int] = Query(None, ge=1, le=100, description="Results per page"),
):
    """
    Get list of unique primary schools with pagination.

    Returns paginated results with total count for efficient loading.
    Supports location-based search using postcode: when the postcode
    validates and geocodes, only schools within ``radius`` miles are
    returned, nearest first.
    """
    # Sanitize inputs
    search = sanitize_search_input(search)
    local_authority = sanitize_search_input(local_authority)
    school_type = sanitize_search_input(school_type)
    postcode = validate_postcode(postcode)

    df = load_school_data()

    if df.empty:
        # Same keys as the normal response so clients can rely on the shape.
        return {
            "schools": [],
            "total": 0,
            "page": page,
            "page_size": 0,
            "total_pages": 0,
            "location_info": None,
        }

    # Use configured default if not specified
    if page_size is None:
        page_size = settings.default_page_size

    # Get unique schools (latest year data for each)
    latest_year = df.groupby("urn")["year"].max().reset_index()
    df_latest = df.merge(latest_year, on=["urn", "year"])

    # Trend: attach the second-latest year's RWM % (years sorted descending,
    # so nth(1) is the previous year).
    df_sorted = df.sort_values(["urn", "year"], ascending=[True, False])
    df_prev = df_sorted.groupby("urn").nth(1).reset_index()
    if not df_prev.empty and "rwm_expected_pct" in df_prev.columns:
        prev_rwm = df_prev[["urn", "rwm_expected_pct"]].rename(
            columns={"rwm_expected_pct": "prev_rwm_expected_pct"}
        )
        df_latest = df_latest.merge(prev_rwm, on="urn", how="left")

    # Include key result metrics for display on cards
    location_cols = ["latitude", "longitude"]
    result_cols = [
        "year",
        "rwm_expected_pct",
        "rwm_high_pct",
        "prev_rwm_expected_pct",
        "reading_expected_pct",
        "writing_expected_pct",
        "maths_expected_pct",
        "total_pupils",
    ]
    # dict.fromkeys de-duplicates while keeping order, so a column listed in
    # SCHOOL_COLUMNS and again here is selected once (no duplicate columns).
    wanted_cols = dict.fromkeys(SCHOOL_COLUMNS + location_cols + result_cols)
    available_cols = [c for c in wanted_cols if c in df_latest.columns]
    schools_df = df_latest[available_cols].drop_duplicates(subset=["urn"])

    # Location-based search (uses pre-geocoded data from database)
    search_coords = None
    if postcode:
        coords = geocode_single_postcode(postcode)
        if coords:
            search_coords = coords
            lat1, lon1 = search_coords
            schools_df = schools_df.copy()

            # Coerce to float so missing/non-numeric coordinates become NaN.
            lat2 = pd.to_numeric(schools_df["latitude"], errors="coerce").to_numpy(
                dtype=float, na_value=np.nan
            )
            lon2 = pd.to_numeric(schools_df["longitude"], errors="coerce").to_numpy(
                dtype=float, na_value=np.nan
            )
            distances = _haversine_miles(lat1, lon1, lat2, lon2)

            # Schools without coordinates are treated as infinitely far away.
            has_coords = ~(np.isnan(lat2) | np.isnan(lon2))
            schools_df["distance"] = np.where(has_coords, distances, float("inf"))
            schools_df = schools_df[schools_df["distance"] <= radius]
            schools_df = schools_df.sort_values("distance")

    # Apply filters
    if search:
        search_lower = search.lower()
        # regex=False: the sanitizer keeps ".", which would otherwise be a
        # regex wildcard (e.g. "st." matching "sto").
        mask = schools_df["school_name"].str.lower().str.contains(
            search_lower, na=False, regex=False
        )
        if "address" in schools_df.columns:
            mask = mask | schools_df["address"].str.lower().str.contains(
                search_lower, na=False, regex=False
            )
        schools_df = schools_df[mask]

    if local_authority:
        schools_df = schools_df[
            schools_df["local_authority"].str.lower() == local_authority.lower()
        ]

    if school_type:
        schools_df = schools_df[
            schools_df["school_type"].str.lower() == school_type.lower()
        ]

    # Pagination
    total = len(schools_df)
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    schools_df = schools_df.iloc[start_idx:end_idx]

    return {
        "schools": clean_for_json(schools_df),
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size if page_size > 0 else 0,
        "location_info": {
            "postcode": postcode,
            "radius": radius * 1.60934,  # Convert miles to km for frontend display
            "coordinates": [search_coords[0], search_coords[1]],
        }
        if search_coords
        else None,
    }


@app.get("/api/schools/{urn}")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_school_details(request: Request, urn: int):
    """Get detailed KS2 data for a specific primary school across all years.

    Raises:
        HTTPException: 400 for a URN outside 100000-999999, 404 when there is
            no data at all or no rows for this URN.
    """
    # Validate URN range (UK school URNs are 6 digits)
    if not (100000 <= urn <= 999999):
        raise HTTPException(status_code=400, detail="Invalid URN format")

    df = load_school_data()

    if df.empty:
        raise HTTPException(status_code=404, detail="No data available")

    school_data = df[df["urn"] == urn]

    if school_data.empty:
        raise HTTPException(status_code=404, detail="School not found")

    # Sort by year
    school_data = school_data.sort_values("year")

    # Get latest info for the school
    latest = school_data.iloc[-1]

    # Fetch supplementary data (Ofsted, Parent View, admissions, etc.)
    from .database import SessionLocal

    supplementary = {}
    try:
        db = SessionLocal()
        try:
            supplementary = get_supplementary_data(db, urn) or {}
        finally:
            # Always release the session, even if the lookup fails.
            db.close()
    except Exception as exc:
        # Supplementary data is optional: log and still return the KS2 data.
        print(f"Warning: could not load supplementary data for URN {urn}: {exc}")

    return {
        "school_info": {
            "urn": urn,
            "school_name": _json_scalar(latest.get("school_name", "")),
            "local_authority": _json_scalar(latest.get("local_authority", "")),
            "school_type": _json_scalar(latest.get("school_type", "")),
            "address": _json_scalar(latest.get("address", "")),
            "religious_denomination": _json_scalar(latest.get("religious_denomination", "")),
            "age_range": _json_scalar(latest.get("age_range", "")),
            "latitude": _json_scalar(latest.get("latitude")),
            "longitude": _json_scalar(latest.get("longitude")),
            "phase": "Primary",
            # GIAS fields
            "website": _json_scalar(latest.get("website")),
            "headteacher_name": _json_scalar(latest.get("headteacher_name")),
            "capacity": _json_scalar(latest.get("capacity")),
            "trust_name": _json_scalar(latest.get("trust_name")),
            "gender": _json_scalar(latest.get("gender")),
        },
        "yearly_data": clean_for_json(school_data),
        # Supplementary data (null if not yet populated by Kestra)
        "ofsted": supplementary.get("ofsted"),
        "parent_view": supplementary.get("parent_view"),
        "census": supplementary.get("census"),
        "admissions": supplementary.get("admissions"),
        "sen_detail": supplementary.get("sen_detail"),
        "phonics": supplementary.get("phonics"),
        "deprivation": supplementary.get("deprivation"),
        "finance": supplementary.get("finance"),
    }


@app.get("/api/compare")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def compare_schools(
    request: Request,
    urns: str = Query(..., description="Comma-separated URNs", max_length=100),
):
    """Compare multiple primary schools side by side.

    Raises:
        HTTPException: 400 for unparsable/out-of-range URNs or more than 10
            URNs; 404 when there is no data or none of the URNs are found.
    """
    df = load_school_data()

    if df.empty:
        raise HTTPException(status_code=404, detail="No data available")

    try:
        urn_list = [int(u.strip()) for u in urns.split(",")]
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid URN format") from None

    # Limit number of schools to compare
    if len(urn_list) > 10:
        raise HTTPException(status_code=400, detail="Maximum 10 schools can be compared")

    # Validate URN format
    if any(not (100000 <= urn <= 999999) for urn in urn_list):
        raise HTTPException(status_code=400, detail="Invalid URN format")

    comparison_data = df[df["urn"].isin(urn_list)]

    if comparison_data.empty:
        raise HTTPException(status_code=404, detail="No schools found")

    result = {}
    for urn in urn_list:
        school_data = comparison_data[comparison_data["urn"] == urn].sort_values("year")
        if school_data.empty:
            continue
        latest = school_data.iloc[-1]
        result[str(urn)] = {
            "school_info": {
                "urn": urn,
                "school_name": _json_scalar(latest.get("school_name", "")),
                "local_authority": _json_scalar(latest.get("local_authority", "")),
                "address": _json_scalar(latest.get("address", "")),
            },
            "yearly_data": clean_for_json(school_data),
        }

    return {"comparison": result}


@app.get("/api/filters")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_filter_options(request: Request):
    """Get available filter options (local authorities, school types, years)."""
    df = load_school_data()

    if df.empty:
        return {
            "local_authorities": [],
            "school_types": [],
            "years": [],
        }

    return {
        "local_authorities": sorted(df["local_authority"].dropna().unique().tolist()),
        "school_types": sorted(df["school_type"].dropna().unique().tolist()),
        "years": sorted(df["year"].dropna().unique().tolist()),
    }


@app.get("/api/metrics")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_available_metrics(request: Request):
    """
    Get list of available KS2 performance metrics for primary schools.

    This is the single source of truth for metric definitions.
    Frontend should consume this to avoid duplication. When data is loaded,
    only metrics present as DataFrame columns are listed; with no data, all
    defined metrics are returned.
    """
    df = load_school_data()

    available = []
    for key, info in METRIC_DEFINITIONS.items():
        if df.empty or key in df.columns:
            available.append({"key": key, **info})

    return {"metrics": available}


@app.get("/api/rankings")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_rankings(
    request: Request,
    metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by", max_length=50),
    year: Optional[int] = Query(
        None, description="Specific year (defaults to most recent)", ge=2000, le=2100
    ),
    limit: int = Query(20, ge=1, le=100, description="Number of schools to return"),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
):
    """Get primary school rankings by a specific KS2 metric.

    Raises:
        HTTPException: 400 for a malformed metric name or a metric that is
            not a column of the loaded data.
    """
    # Sanitize local authority input
    local_authority = sanitize_search_input(local_authority)

    # Validate metric name (only allow alphanumeric and underscore)
    if not re.match(r"^[a-z0-9_]+$", metric):
        raise HTTPException(status_code=400, detail="Invalid metric name")

    df = load_school_data()

    if df.empty:
        return {"metric": metric, "year": None, "rankings": [], "total": 0}

    if metric not in df.columns:
        raise HTTPException(status_code=400, detail=f"Metric '{metric}' not available")

    # Filter by year (defaults to the most recent year in the data)
    if year:
        df = df[df["year"] == year]
    else:
        df = df[df["year"] == df["year"].max()]

    # Filter by local authority if specified
    if local_authority:
        df = df[df["local_authority"].str.lower() == local_authority.lower()]

    # Sort and rank (exclude rows with no data for this metric)
    df = df.dropna(subset=[metric])
    total = len(df)

    # For progress scores, higher is better. For percentages, higher is also better.
    df = df.sort_values(metric, ascending=False).head(limit)

    # Read the year before narrowing columns: RANKING_COLUMNS may not include it.
    ranked_year = int(df["year"].iloc[0]) if not df.empty else None

    # Return only relevant fields for rankings
    available_cols = [c for c in RANKING_COLUMNS if c in df.columns]
    df = df[available_cols]

    return {
        "metric": metric,
        "year": ranked_year,
        "rankings": clean_for_json(df),
        "total": total,
    }


@app.get("/api/data-info")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_data_info(request: Request):
    """Get information about loaded data."""
    # Get info directly from database
    db_info = get_db_info()

    if db_info["total_schools"] == 0:
        return {
            "status": "no_data",
            "message": "No data in database. Run the migration script: python scripts/migrate_csv_to_db.py",
            "data_source": "PostgreSQL",
        }

    # Also get DataFrame-based stats for backwards compatibility
    df = load_school_data()

    if df.empty:
        return {
            "status": "no_data",
            "message": "No data available",
            "data_source": "PostgreSQL",
        }

    years = [int(y) for y in sorted(df["year"].unique())]
    schools_per_year = {
        str(int(k)): int(v)
        for k, v in df.groupby("year")["urn"].nunique().to_dict().items()
    }
    la_counts = {
        str(k): int(v)
        for k, v in df["local_authority"].value_counts().to_dict().items()
    }

    return {
        "status": "loaded",
        "data_source": "PostgreSQL",
        "total_records": int(len(df)),
        "unique_schools": int(df["urn"].nunique()),
        "years_available": years,
        "schools_per_year": schools_per_year,
        "local_authorities": la_counts,
    }


@app.post("/api/admin/reload")
@limiter.limit("5/minute")
async def reload_data(
    request: Request,
    _: bool = Depends(verify_admin_api_key),
):
    """
    Admin endpoint to force data reload (useful after data updates).
    Requires X-API-Key header with valid admin API key.
    """
    clear_cache()
    load_school_data()
    return {"status": "reloaded"}


# Progress of the background KS2 re-import; replaced wholesale (never mutated
# in place) by the endpoint and by the worker thread.
_reimport_status: dict = {"running": False, "done": False, "error": None}


@app.post("/api/admin/reimport-ks2")
@limiter.limit("2/minute")
async def reimport_ks2(
    request: Request,
    geocode: bool = True,
    _: bool = Depends(verify_admin_api_key),
):
    """
    Start a full KS2 CSV migration in the background and return immediately.
    Poll GET /api/admin/reimport-ks2/status to check progress.
    Pass ?geocode=false to skip postcode → lat/lng resolution.
    Requires X-API-Key header with valid admin API key.
    """
    global _reimport_status
    if _reimport_status["running"]:
        return {"status": "already_running"}

    _reimport_status = {"running": True, "done": False, "error": None}

    def _run():
        global _reimport_status
        try:
            success = run_full_migration(geocode=geocode)
            if not success:
                _reimport_status = {"running": False, "done": False, "error": "No CSV data found"}
                return
            clear_cache()
            load_school_data()
            _reimport_status = {"running": False, "done": True, "error": None}
        except Exception as exc:
            _reimport_status = {"running": False, "done": False, "error": str(exc)}

    import threading

    threading.Thread(target=_run, daemon=True).start()
    return {"status": "started"}


@app.get("/api/admin/reimport-ks2/status")
async def reimport_ks2_status(
    request: Request,
    _: bool = Depends(verify_admin_api_key),
):
    """Poll this endpoint to check reimport progress.

    Raises:
        HTTPException: 500 carrying the error text if the last run failed.
    """
    s = _reimport_status
    if s["error"]:
        raise HTTPException(status_code=500, detail=s["error"])
    return {"running": s["running"], "done": s["done"]}


# =============================================================================
# SEO FILES
# =============================================================================


@app.get("/favicon.svg")
async def favicon():
    """Serve favicon."""
    return FileResponse(settings.frontend_dir / "favicon.svg", media_type="image/svg+xml")


@app.get("/robots.txt")
async def robots_txt():
    """Serve robots.txt for search engine crawlers."""
    return FileResponse(settings.frontend_dir / "robots.txt", media_type="text/plain")


@app.get("/sitemap.xml")
async def sitemap_xml():
    """Serve sitemap.xml for search engine indexing."""
    return FileResponse(settings.frontend_dir / "sitemap.xml", media_type="application/xml")


# Mount static files directly (must be after all routes to avoid catching API calls)
if settings.frontend_dir.exists():
    app.mount("/static", StaticFiles(directory=settings.frontend_dir), name="static")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=settings.host, port=settings.port)