Files
school_compare/backend/app.py
Tudor 6597ee40fb
All checks were successful
Build and Push Docker Image / build-and-push (push) Successful in 58s
bug fixing
2026-01-10 11:40:02 +00:00

630 lines
21 KiB
Python

"""
SchoolCompare.co.uk API
Serves primary school (KS2) performance data for comparing schools.
Uses real data from UK Government Compare School Performance downloads.
"""
import re
import secrets
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Query, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware

from .config import settings
from .data_loader import (
    clear_cache,
    load_school_data,
    geocode_single_postcode,
)
from .data_loader import get_data_info as get_db_info
from .database import init_db
from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS
from .utils import clean_for_json
# =============================================================================
# SECURITY MIDDLEWARE & HELPERS
# =============================================================================
# Rate limiter used by the @limiter.limit(...) decorators on the API endpoints
# and registered on app.state.limiter. Requests are bucketed per client using
# slowapi's get_remote_address key function.
# NOTE(review): behind a reverse proxy the "remote address" may be the proxy's
# IP for every client unless proxy headers are honoured — confirm deployment.
limiter = Limiter(key_func=get_remote_address)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of security-related HTTP headers onto every response."""

    # (header name, value) pairs, applied in this order by item assignment, so
    # each replaces any same-named header the route may already have set.
    _SECURITY_HEADERS = (
        # Forbid framing of the site (clickjacking defence).
        ("X-Frame-Options", "DENY"),
        # Stop browsers guessing a different content type than declared.
        ("X-Content-Type-Options", "nosniff"),
        # Reflected-XSS filter for legacy browsers.
        ("X-XSS-Protection", "1; mode=block"),
        # Send only the origin on cross-origin navigations.
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # Switch off browser features the site never uses.
        (
            "Permissions-Policy",
            "geolocation=(), microphone=(), camera=(), payment=()",
        ),
        # Whitelist of script/style/font/image/connect sources the frontend
        # needs (CDNs and OpenStreetMap tiles); everything else is same-origin.
        (
            "Content-Security-Policy",
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://unpkg.com; "
            "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net https://unpkg.com; "
            "font-src 'self' https://fonts.gstatic.com; "
            "img-src 'self' data: https://*.tile.openstreetmap.org https://unpkg.com; "
            "connect-src 'self' https://cdn.jsdelivr.net https://*.tile.openstreetmap.org https://unpkg.com; "
            "frame-ancestors 'none'; "
            "base-uri 'self'; "
            "form-action 'self';",
        ),
        # HSTS is sent unconditionally; browsers only honour it over HTTPS.
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
    )

    async def dispatch(self, request: Request, call_next):
        """Let the request through, then decorate whatever response comes back."""
        outgoing = await call_next(request)
        for header_name, header_value in self._SECURITY_HEADERS:
            outgoing.headers[header_name] = header_value
        return outgoing
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared body size exceeds ``settings.max_request_size``.

    Only the ``Content-Length`` header is inspected. Requests that do not
    declare a length (e.g. chunked uploads) are passed through unchecked.
    """

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # A non-numeric header used to raise here and surface as a 500;
                # it is a client error, so answer 400 instead.
                return Response(
                    content="Invalid Content-Length header",
                    status_code=400,
                )
            if declared_size > settings.max_request_size:
                return Response(
                    content="Request too large",
                    status_code=413,
                )
        return await call_next(request)
def verify_admin_api_key(x_api_key: Optional[str] = Header(None)) -> bool:
    """FastAPI dependency that authorises admin-only endpoints.

    FastAPI maps the ``x_api_key`` parameter to the ``X-API-Key`` request
    header. Its value is compared with ``settings.admin_api_key``.

    Returns:
        True when the supplied key matches the configured one.

    Raises:
        HTTPException: 401 when the header is missing or empty, when no admin
            key is configured, or when the key does not match.
    """
    expected = settings.admin_api_key
    # - An unset/empty configured key must never authorise anyone.
    # - compare_digest runs in constant time, so response timing does not leak
    #   how much of the key a caller guessed correctly. Comparing UTF-8 bytes
    #   avoids compare_digest's TypeError on non-ASCII str input.
    if (
        not x_api_key
        or not expected
        or not secrets.compare_digest(
            x_api_key.encode("utf-8"), str(expected).encode("utf-8")
        )
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
    return True
# Input validation helpers
def sanitize_search_input(value: Optional[str], max_length: int = 100) -> Optional[str]:
    """Normalise free-text filter input before it is matched against school data.

    The input is stripped of surrounding whitespace and truncated to
    ``max_length`` characters. Every character other than word characters,
    whitespace, hyphens, apostrophes, commas and full stops is then removed,
    and the result is stripped again.

    Args:
        value: Raw user input, or None.
        max_length: Maximum number of characters kept. Truncation is applied
            before character filtering.

    Returns:
        The cleaned string, or None if the input was None or nothing remains
        after cleaning.
    """
    if value is None:
        return None
    # Strip whitespace and limit length
    value = value.strip()[:max_length]
    # Remove potentially dangerous characters (allow alphanumeric, spaces, common punctuation)
    value = re.sub(r"[^\w\s\-\',\.]", "", value)
    # Truncation or character removal can expose whitespace at the edges
    # (e.g. "Camden !" -> "Camden "), which would break exact-match filters.
    value = value.strip()
    return value if value else None
# UK postcode: outward code (e.g. "SW1A"), optional whitespace, inward code
# (e.g. "1AA"). Both parts are captured so the result can be re-spaced.
_POSTCODE_RE = re.compile(r"([A-Z]{1,2}[0-9][A-Z0-9]?)\s*([0-9][A-Z]{2})")


def validate_postcode(postcode: Optional[str]) -> Optional[str]:
    """Validate a UK postcode and normalise it to canonical form.

    Matching is case-insensitive and tolerates missing or repeated whitespace
    between the outward and inward parts. The result is upper-case with
    exactly one separating space, e.g. ``" sw1a1aa "`` -> ``"SW1A 1AA"``.

    Returns:
        The normalised postcode, or None for empty or malformed input.
    """
    if not postcode:
        return None
    match = _POSTCODE_RE.fullmatch(postcode.strip().upper())
    if match is None:
        return None
    outward, inward = match.groups()
    return f"{outward} {inward}"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start-up/shutdown hook for the FastAPI app.

    Before the app starts serving, this hook:
    - ensures the database tables exist via ``init_db()``;
    - loads the school data once via ``load_school_data()``, printing a
      warning if it comes back empty.

    A message is printed when the app shuts down.
    """
    # ---- start-up ----------------------------------------------------------
    print("Starting up: Initializing database...")
    init_db()  # Ensure tables exist

    print("Loading school data from database...")
    school_df = load_school_data()
    no_rows = school_df.empty
    print(
        "Warning: No data in database. Run the migration script to import data."
        if no_rows
        else "Data loaded successfully."
    )

    yield  # requests are served while suspended here

    # ---- shutdown ----------------------------------------------------------
    print("Shutting down...")
app = FastAPI(
    title="SchoolCompare API",
    description="API for comparing primary school (KS2) performance data - schoolcompare.co.uk",
    version="2.0.0",
    lifespan=lifespan,
    # Disable docs in production for security
    docs_url="/docs" if settings.debug else None,
    redoc_url="/redoc" if settings.debug else None,
    openapi_url="/openapi.json" if settings.debug else None,
)

# Rate limiting: slowapi reads the limiter from app.state and turns
# RateLimitExceeded into an HTTP error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Middleware order matters: each add_middleware() call wraps everything added
# before it, so the LAST one added is the OUTERMOST layer.
# Resulting stack (outer -> inner):
#     CORS -> SecurityHeaders -> RequestSizeLimit -> routes
# SecurityHeaders must sit outside RequestSizeLimit so that the 413/400
# responses produced by the size check also carry the security headers.
app.add_middleware(RequestSizeLimitMiddleware)
app.add_middleware(SecurityHeadersMiddleware)

# CORS middleware - restricted for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials=False,  # Don't allow credentials unless needed
    allow_methods=["GET", "POST"],  # Only allow needed methods
    allow_headers=["Content-Type", "X-API-Key"],  # Only allow needed headers
)
def _spa_index() -> FileResponse:
    """Return the single-page-app shell (index.html) from the frontend directory."""
    index_path = settings.frontend_dir / "index.html"
    return FileResponse(index_path)


@app.get("/")
async def root():
    """Serve the frontend entry page."""
    return _spa_index()


@app.get("/compare")
async def serve_compare():
    """Serve the SPA shell for /compare; the client-side router renders the view."""
    return _spa_index()


@app.get("/rankings")
async def serve_rankings():
    """Serve the SPA shell for /rankings; the client-side router renders the view."""
    return _spa_index()
# Location columns carried into the list view for distance search and map pins.
_SCHOOL_LOCATION_COLUMNS = ["latitude", "longitude"]

# Key result metrics displayed on each school card in the list view.
_SCHOOL_CARD_RESULT_COLUMNS = [
    "year",
    "rwm_expected_pct",
    "rwm_high_pct",
    "prev_rwm_expected_pct",
    "reading_expected_pct",
    "writing_expected_pct",
    "maths_expected_pct",
    "total_pupils",
]

# Earth's radius in miles, used by the haversine distance.
_EARTH_RADIUS_MILES = 3959


def _latest_rows_with_trend(df: pd.DataFrame) -> pd.DataFrame:
    """Return the rows for each school's most recent year, plus a trend column.

    When ``rwm_expected_pct`` exists and at least one school has an earlier
    year, a ``prev_rwm_expected_pct`` column is left-merged in. It holds the
    school's value from its previous distinct year (NaN where there is none).
    """
    latest_year = df.groupby("urn")["year"].max().reset_index()
    df_latest = df.merge(latest_year, on=["urn", "year"])

    if "rwm_expected_pct" not in df.columns:
        return df_latest

    # One row per (school, year), newest first, so position 1 within each school
    # is its previous year. De-duplicating first stops a repeated latest-year
    # row from being mistaken for the previous year.
    per_year = (
        df.dropna(subset=["year"])
        .sort_values(["urn", "year"], ascending=[True, False])
        .drop_duplicates(subset=["urn", "year"])
    )
    position = per_year.groupby("urn").cumcount()
    prev_rwm = per_year.loc[position == 1, ["urn", "rwm_expected_pct"]].rename(
        columns={"rwm_expected_pct": "prev_rwm_expected_pct"}
    )
    if prev_rwm.empty:
        return df_latest
    return df_latest.merge(prev_rwm, on="urn", how="left")


def _coordinate_values(df: pd.DataFrame, column: str) -> np.ndarray:
    """Return ``df[column]`` as a float array (NaN where missing or non-numeric)."""
    if column not in df.columns:
        return np.full(len(df), np.nan)
    values = df[column]
    if isinstance(values, pd.DataFrame):
        # Duplicate column labels: use the first occurrence.
        values = values.iloc[:, 0]
    # to_numeric also copes with object columns (e.g. Decimal from the DB),
    # which np.radians cannot process directly.
    return pd.to_numeric(values, errors="coerce").to_numpy(dtype=float, na_value=np.nan)


def _haversine_miles(
    lat: float, lon: float, lats: np.ndarray, lons: np.ndarray
) -> np.ndarray:
    """Great-circle distance in miles from (lat, lon) to each (lats[i], lons[i])."""
    lat_rad = np.radians(lat)
    lats_rad = np.radians(lats)
    dlat = np.radians(lats - lat)
    dlon = np.radians(lons - lon)
    a = np.sin(dlat / 2) ** 2 + np.cos(lat_rad) * np.cos(lats_rad) * np.sin(dlon / 2) ** 2
    # Rounding can push `a` fractionally above 1, which would make sqrt(1 - a) NaN.
    a = np.clip(a, 0.0, 1.0)
    return _EARTH_RADIUS_MILES * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))


def _filter_by_distance(schools_df: pd.DataFrame, coords, radius: float) -> pd.DataFrame:
    """Keep schools within ``radius`` miles of ``coords`` (lat, lon), nearest first.

    Adds a ``distance`` column in miles. Schools without usable coordinates
    (including when the location columns are absent) are treated as infinitely
    far away, so they are dropped.
    """
    origin_lat, origin_lon = map(float, coords)
    lats = _coordinate_values(schools_df, "latitude")
    lons = _coordinate_values(schools_df, "longitude")

    distances = _haversine_miles(origin_lat, origin_lon, lats, lons)
    has_coords = np.isfinite(lats) & np.isfinite(lons)

    result = schools_df.copy()
    result["distance"] = np.where(has_coords, distances, np.inf)
    result = result[result["distance"] <= radius]
    return result.sort_values("distance")


def _filter_by_text(
    schools_df: pd.DataFrame,
    search: Optional[str],
    local_authority: Optional[str],
    school_type: Optional[str],
) -> pd.DataFrame:
    """Apply the (already sanitised) name/address search and exact-match filters."""
    if search:
        needle = search.lower()
        # regex=False: the sanitiser keeps ".", which would otherwise act as a
        # regex wildcard instead of a literal full stop.
        mask = schools_df["school_name"].str.lower().str.contains(
            needle, na=False, regex=False
        )
        if "address" in schools_df.columns:
            mask = mask | schools_df["address"].str.lower().str.contains(
                needle, na=False, regex=False
            )
        schools_df = schools_df[mask]
    if local_authority:
        schools_df = schools_df[
            schools_df["local_authority"].str.lower() == local_authority.lower()
        ]
    if school_type:
        schools_df = schools_df[
            schools_df["school_type"].str.lower() == school_type.lower()
        ]
    return schools_df


@app.get("/api/schools")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_schools(
    request: Request,
    search: Optional[str] = Query(None, description="Search by school name", max_length=100),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
    school_type: Optional[str] = Query(None, description="Filter by school type", max_length=100),
    postcode: Optional[str] = Query(None, description="Search near postcode", max_length=10),
    radius: float = Query(5.0, ge=0.1, le=50, description="Search radius in miles"),
    page: int = Query(1, ge=1, le=1000, description="Page number"),
    page_size: Optional[int] = Query(None, ge=1, le=100, description="Results per page"),
):
    """
    Get list of unique primary schools with pagination.

    Each school appears once, using its most recent year's row, enriched with
    ``prev_rwm_expected_pct`` for trend display. When a valid postcode is given
    and can be geocoded, results are limited to ``radius`` miles and ordered
    nearest first. Name/address search and local-authority/school-type filters
    are applied next, then the requested page is cut.

    Returns a dict with ``schools``, ``total``, ``page``, ``page_size``,
    ``total_pages`` and ``search_location`` (None unless a distance search ran).
    """
    # Sanitize inputs. An invalid postcode becomes None, so no location search runs.
    search = sanitize_search_input(search)
    local_authority = sanitize_search_input(local_authority)
    school_type = sanitize_search_input(school_type)
    postcode = validate_postcode(postcode)

    df = load_school_data()
    if df.empty:
        return {
            "schools": [],
            "total": 0,
            "page": page,
            "page_size": 0,
            "total_pages": 0,
            "search_location": None,
        }

    # Use configured default if not specified
    if page_size is None:
        page_size = settings.default_page_size

    df_latest = _latest_rows_with_trend(df)

    # dict.fromkeys de-duplicates while keeping order. SCHOOL_COLUMNS may overlap
    # the location/result lists, and selecting a label twice would produce
    # duplicate columns (breaking the .str/column access below).
    wanted_cols = dict.fromkeys(
        [*SCHOOL_COLUMNS, *_SCHOOL_LOCATION_COLUMNS, *_SCHOOL_CARD_RESULT_COLUMNS]
    )
    available_cols = [c for c in wanted_cols if c in df_latest.columns]
    schools_df = df_latest[available_cols].drop_duplicates(subset=["urn"])

    # Location-based search (uses pre-geocoded data from database)
    search_coords = None
    if postcode:
        coords = geocode_single_postcode(postcode)
        if coords:
            search_coords = coords
            schools_df = _filter_by_distance(schools_df, search_coords, radius)

    schools_df = _filter_by_text(schools_df, search, local_authority, school_type)

    # Pagination
    total = len(schools_df)
    start_idx = (page - 1) * page_size
    page_df = schools_df.iloc[start_idx : start_idx + page_size]

    return {
        "schools": clean_for_json(page_df),
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size if page_size > 0 else 0,
        "search_location": {"postcode": postcode, "radius": radius}
        if search_coords
        else None,
    }
def _row_value(row: pd.Series, key: str, default=None):
    """Read ``key`` from a DataFrame row as a JSON-safe value.

    Missing keys and missing values (NaN/None/NA) both yield ``default``.
    FastAPI's default JSON response refuses to serialise NaN, so a school
    lacking, say, coordinates or an address would otherwise fail with a 500.
    NumPy scalars are unwrapped to the equivalent Python builtin.
    """
    value = row.get(key, default)
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return default
    if isinstance(value, np.generic):
        return value.item()
    return value


def _parse_urn_list(urns: str) -> list:
    """Parse a comma-separated URN list for /api/compare.

    Blank entries (e.g. from a trailing comma) are ignored, and repeated URNs
    collapse to one, keeping first-seen order.

    Raises:
        HTTPException: 400 if no URN is given, an entry is not an integer,
            more than 10 distinct URNs are given, or a URN is not six digits.
    """
    parts = [part.strip() for part in urns.split(",")]
    try:
        urn_list = list(dict.fromkeys(int(part) for part in parts if part))
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid URN format") from None
    if not urn_list:
        raise HTTPException(status_code=400, detail="Invalid URN format")
    # Limit number of schools to compare
    if len(urn_list) > 10:
        raise HTTPException(status_code=400, detail="Maximum 10 schools can be compared")
    # Validate URN format
    if any(not (100000 <= urn <= 999999) for urn in urn_list):
        raise HTTPException(status_code=400, detail="Invalid URN format")
    return urn_list


@app.get("/api/schools/{urn}")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_school_details(request: Request, urn: int):
    """Get detailed KS2 data for a specific primary school across all years.

    ``school_info`` is taken from the school's most recent year. Missing values
    there are returned as "" for text fields and None for coordinates.
    ``yearly_data`` holds every year, oldest first.

    Raises:
        HTTPException: 400 for a URN outside the six-digit range; 404 when no
            data is loaded or the school is not found.
    """
    # Validate URN range (UK school URNs are 6 digits)
    if not (100000 <= urn <= 999999):
        raise HTTPException(status_code=400, detail="Invalid URN format")

    df = load_school_data()
    if df.empty:
        raise HTTPException(status_code=404, detail="No data available")

    school_data = df[df["urn"] == urn]
    if school_data.empty:
        raise HTTPException(status_code=404, detail="School not found")

    school_data = school_data.sort_values("year")
    latest = school_data.iloc[-1]

    return {
        "school_info": {
            "urn": urn,
            "school_name": _row_value(latest, "school_name", ""),
            "local_authority": _row_value(latest, "local_authority", ""),
            "school_type": _row_value(latest, "school_type", ""),
            "address": _row_value(latest, "address", ""),
            "religious_denomination": _row_value(latest, "religious_denomination", ""),
            "age_range": _row_value(latest, "age_range", ""),
            "latitude": _row_value(latest, "latitude"),
            "longitude": _row_value(latest, "longitude"),
            "phase": "Primary",
        },
        "yearly_data": clean_for_json(school_data),
    }


@app.get("/api/compare")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def compare_schools(
    request: Request,
    urns: str = Query(..., description="Comma-separated URNs", max_length=100),
):
    """Compare multiple primary schools side by side.

    Input is validated before any data is loaded, as in get_school_details.
    The response maps each found URN (as a string) to its latest
    ``school_info`` and all of its ``yearly_data``. URNs with no data are
    omitted.

    Raises:
        HTTPException: 400 for malformed input (see ``_parse_urn_list``); 404
            when no data is loaded or none of the schools are found.
    """
    urn_list = _parse_urn_list(urns)

    df = load_school_data()
    if df.empty:
        raise HTTPException(status_code=404, detail="No data available")

    comparison_data = df[df["urn"].isin(urn_list)]
    if comparison_data.empty:
        raise HTTPException(status_code=404, detail="No schools found")

    result = {}
    for urn in urn_list:
        school_data = comparison_data[comparison_data["urn"] == urn].sort_values("year")
        if school_data.empty:
            continue
        latest = school_data.iloc[-1]
        result[str(urn)] = {
            "school_info": {
                "urn": urn,
                "school_name": _row_value(latest, "school_name", ""),
                "local_authority": _row_value(latest, "local_authority", ""),
                "address": _row_value(latest, "address", ""),
            },
            "yearly_data": clean_for_json(school_data),
        }
    return {"comparison": result}
@app.get("/api/filters")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_filter_options(request: Request):
    """List the values the UI can filter schools by.

    Returns ``local_authorities``, ``school_types`` and ``years``. Each is the
    sorted list of distinct non-null values of that column, or an empty list
    when no data is loaded.
    """
    df = load_school_data()

    def distinct_sorted(column: str) -> list:
        if df.empty:
            return []
        return sorted(df[column].dropna().unique().tolist())

    return {
        "local_authorities": distinct_sorted("local_authority"),
        "school_types": distinct_sorted("school_type"),
        "years": distinct_sorted("year"),
    }
@app.get("/api/metrics")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_available_metrics(request: Request):
    """
    List the KS2 performance metrics for primary schools.

    METRIC_DEFINITIONS is the single source of truth for metric definitions;
    the frontend should consume this endpoint rather than duplicate them.
    A metric is included when its column exists in the loaded data. When no
    data is loaded at all, every defined metric is included.
    """
    df = load_school_data()
    include_all = df.empty
    metrics = [
        {"key": key, **info}
        for key, info in METRIC_DEFINITIONS.items()
        if include_all or key in df.columns
    ]
    return {"metrics": metrics}
@app.get("/api/rankings")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_rankings(
    request: Request,
    metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by", max_length=50),
    year: Optional[int] = Query(
        None, description="Specific year (defaults to most recent)", ge=2000, le=2100
    ),
    limit: int = Query(20, ge=1, le=100, description="Number of schools to return"),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
):
    """Get primary school rankings by a specific KS2 metric.

    Schools are ranked within one year: ``year``, or the most recent year in
    the data when omitted. They can optionally be restricted to one local
    authority. Ranking is by ``metric``, descending, and rows with no value
    for the metric are excluded.

    Returns a dict with:
    - ``metric``;
    - ``year``: the ranked year, or None when no school qualifies;
    - ``rankings``: at most ``limit`` rows, restricted to RANKING_COLUMNS;
    - ``total``: the number of qualifying schools before the limit.

    Raises:
        HTTPException: 400 if ``metric`` is not a lower-case identifier or is
            not a column of the loaded data.
    """
    # Sanitize local authority input
    local_authority = sanitize_search_input(local_authority)

    # Validate metric name. fullmatch rather than match + "$", because "$" also
    # matches just before a trailing newline.
    if not re.fullmatch(r"[a-z0-9_]+", metric):
        raise HTTPException(status_code=400, detail="Invalid metric name")

    df = load_school_data()
    if df.empty:
        return {"metric": metric, "year": None, "rankings": [], "total": 0}

    if metric not in df.columns:
        raise HTTPException(status_code=400, detail=f"Metric '{metric}' not available")

    # Filter by year (most recent when not specified)
    target_year = year if year is not None else df["year"].max()
    df = df[df["year"] == target_year]

    # Filter by local authority if specified
    if local_authority:
        df = df[df["local_authority"].str.lower() == local_authority.lower()]

    # Sort and rank (exclude rows with no data for this metric)
    df = df.dropna(subset=[metric])
    total = len(df)
    # For progress scores, higher is better. For percentages, higher is also better.
    top = df.sort_values(metric, ascending=False).head(limit)

    # Return only relevant fields for rankings
    available_cols = [c for c in RANKING_COLUMNS if c in top.columns]

    return {
        "metric": metric,
        # Every remaining row has target_year. Reading it from there (rather
        # than from the projected frame) avoids a KeyError if RANKING_COLUMNS
        # does not include "year".
        "year": int(target_year) if not top.empty else None,
        "rankings": clean_for_json(top[available_cols]),
        "total": total,
    }
@app.get("/api/data-info")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_data_info(request: Request):
    """Get information about loaded data.

    Returns ``status: "no_data"`` with a hint when the database summary (from
    ``get_db_info``) reports zero schools or the loaded DataFrame is empty.
    Otherwise it returns:
    - total records and unique schools;
    - the available years;
    - the number of schools per year;
    - record counts per local authority.
    """
    # Get info directly from database
    db_info = get_db_info()
    if db_info["total_schools"] == 0:
        return {
            "status": "no_data",
            "message": "No data in database. Run the migration script: python scripts/migrate_csv_to_db.py",
            "data_source": "PostgreSQL",
        }

    # Also get DataFrame-based stats for backwards compatibility
    df = load_school_data()
    if df.empty:
        return {
            "status": "no_data",
            "message": "No data available",
            "data_source": "PostgreSQL",
        }

    # dropna(): int(NaN) raises, so one row with a missing year would turn this
    # endpoint into a 500. groupby() and value_counts() below already skip NaN
    # keys by default.
    years = sorted(int(y) for y in df["year"].dropna().unique())
    schools_per_year = {
        str(int(year)): int(count)
        for year, count in df.groupby("year")["urn"].nunique().items()
    }
    la_counts = {
        str(authority): int(count)
        for authority, count in df["local_authority"].value_counts().items()
    }

    return {
        "status": "loaded",
        "data_source": "PostgreSQL",
        "total_records": int(len(df)),
        "unique_schools": int(df["urn"].nunique()),
        "years_available": years,
        "schools_per_year": schools_per_year,
        "local_authorities": la_counts,
    }
@app.post("/api/admin/reload")
@limiter.limit("5/minute")
async def reload_data(
    request: Request,
    _: bool = Depends(verify_admin_api_key),
):
    """
    Admin-only: discard the loaded school data and load it again.

    Intended for use after the underlying data has been updated. Access needs
    a valid admin key in the X-API-Key header (checked by the
    ``verify_admin_api_key`` dependency). Limited to 5 calls per minute.
    """
    clear_cache()  # presumably drops whatever load_school_data() has retained
    load_school_data()  # load again now rather than on the next request
    outcome = "reloaded"
    return {"status": outcome}
# =============================================================================
# SEO FILES
# =============================================================================


def _frontend_file(filename: str, media_type: str) -> FileResponse:
    """Return ``filename`` from the frontend directory with an explicit media type."""
    return FileResponse(settings.frontend_dir / filename, media_type=media_type)


@app.get("/favicon.svg")
async def favicon():
    """Serve the SVG favicon."""
    return _frontend_file("favicon.svg", "image/svg+xml")


@app.get("/robots.txt")
async def robots_txt():
    """Serve robots.txt, the crawl rules for search-engine bots."""
    return _frontend_file("robots.txt", "text/plain")


@app.get("/sitemap.xml")
async def sitemap_xml():
    """Serve sitemap.xml so search engines can discover the site's pages."""
    return _frontend_file("sitemap.xml", "application/xml")
# Static assets under /static. Mounted only when the frontend directory exists,
# and deliberately registered after all route definitions above.
if settings.frontend_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=settings.frontend_dir),
        name="static",
    )

if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn.
    from uvicorn import run as run_server

    run_server(app, host=settings.host, port=settings.port)