Files
school_compare/backend/app.py
Tudor ca351e9d73 feat: migrate backend to marts schema, update EES tap for verified datasets
Pipeline:
- EES tap: split KS4 into performance + info streams, fix admissions filename
  (SchoolLevel keyword match), fix census filename (yearly suffix), remove
  phonics (no school-level data on EES), change endswith → in for matching
- stg_ees_ks4: rewrite to filter long-format data and extract Attainment 8,
  Progress 8, EBacc, English/Maths metrics; join KS4 info for context
- stg_ees_admissions: map real CSV columns (total_number_places_offered, etc.)
- stg_ees_census: update source reference, stub with TODO for data columns
- Remove stg_ees_phonics, fact_phonics (no school-level EES data)
- Add ees_ks4_performance + ees_ks4_info sources, remove ees_ks4 + ees_phonics
- Update int_ks4_with_lineage + fact_ks4_performance with new KS4 columns
- Annual EES DAG: remove stg_ees_phonics+ from selector

Backend:
- models.py: replace all models to point at marts.* tables with schema='marts'
  (DimSchool, DimLocation, KS2Performance, FactOfstedInspection, etc.)
- data_loader.py: rewrite load_school_data_as_dataframe() using raw SQL joining
  dim_school + dim_location + fact_ks2_performance; update get_supplementary_data()
- database.py: remove migration machinery, keep only connection setup
- app.py: remove check_and_migrate_if_needed, remove /api/admin/reimport-ks2
  endpoints (pipeline handles all imports)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 09:29:27 +00:00

665 lines
22 KiB
Python

"""
SchoolCompare.co.uk API
Serves primary school (KS2) performance data for comparing schools.
Uses real data from UK Government Compare School Performance downloads.
"""
import asyncio
import logging
import re
import secrets
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Query, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware

from .config import settings
from .data_loader import (
    clear_cache,
    load_school_data,
    geocode_single_postcode,
    get_supplementary_data,
)
from .data_loader import get_data_info as get_db_info
from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS
from .utils import clean_for_json
# =============================================================================
# SECURITY MIDDLEWARE & HELPERS
# =============================================================================
# Shared slowapi rate limiter. Requests are bucketed by the key returned from
# get_remote_address; the per-route limits are declared with @limiter.limit on
# the individual endpoints below.
limiter = Limiter(key_func=get_remote_address)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security-related headers to every response."""

    # Content Security Policy: same-origin by default, with explicit
    # allowances for the CDNs, fonts, map tiles, analytics and form endpoints
    # listed below.
    _CONTENT_SECURITY_POLICY = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://unpkg.com https://analytics.schoolcompare.co.uk; "
        "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://cdn.jsdelivr.net https://unpkg.com; "
        "font-src 'self' https://fonts.gstatic.com; "
        "img-src 'self' data: https://*.tile.openstreetmap.org https://unpkg.com; "
        "connect-src 'self' https://cdn.jsdelivr.net https://*.tile.openstreetmap.org https://unpkg.com https://analytics.schoolcompare.co.uk; "
        "frame-ancestors 'none'; "
        "base-uri 'self'; "
        "form-action 'self' https://formsubmit.co;"
    )

    # Header name -> value, applied in this order to each outgoing response.
    _SECURITY_HEADERS = {
        # Prevent clickjacking
        "X-Frame-Options": "DENY",
        # Prevent MIME type sniffing
        "X-Content-Type-Options": "nosniff",
        # XSS protection (legacy browsers)
        "X-XSS-Protection": "1; mode=block",
        # Referrer policy
        "Referrer-Policy": "strict-origin-when-cross-origin",
        # Permissions policy (restrict browser features)
        "Permissions-Policy": "geolocation=(), microphone=(), camera=(), payment=()",
        "Content-Security-Policy": _CONTENT_SECURITY_POLICY,
        # HSTS - note this is sent unconditionally, not only over HTTPS
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then stamp the headers onto its response."""
        response = await call_next(request)
        for header_name, header_value in self._SECURITY_HEADERS.items():
            response.headers[header_name] = header_value
        return response
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared body size exceeds the configured limit.

    Only the ``Content-Length`` header is inspected; requests that do not
    declare a length are passed through unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        """Validate ``Content-Length`` and forward the request if acceptable.

        Returns:
            A 400 response when the header is not an integer, a 413 response
            when it exceeds ``settings.max_request_size``, otherwise the
            downstream handler's response.
        """
        declared_length = request.headers.get("content-length")
        if declared_length:
            try:
                body_size = int(declared_length)
            except ValueError:
                # A malformed header used to escape as an unhandled ValueError
                # (HTTP 500); answer with a client error instead.
                return Response(
                    content="Invalid Content-Length header",
                    status_code=400,
                )
            if body_size > settings.max_request_size:
                return Response(
                    content="Request too large",
                    status_code=413,
                )
        return await call_next(request)
def verify_admin_api_key(x_api_key: str = Header(None)) -> bool:
    """Verify the admin API key supplied in the ``X-API-Key`` header.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header, or ``None``
            when the header is absent.

    Returns:
        ``True`` when the supplied key matches ``settings.admin_api_key``.

    Raises:
        HTTPException: 401 when the key is missing, empty, or does not match.
    """
    expected_key = settings.admin_api_key
    # Compare in constant time so response timing does not leak how much of
    # the key matched. Encoding to bytes lets compare_digest accept non-ASCII
    # input instead of raising TypeError.
    key_is_valid = (
        bool(x_api_key)
        and isinstance(expected_key, str)
        and secrets.compare_digest(
            x_api_key.encode("utf-8"), expected_key.encode("utf-8")
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
    return True
# Input validation helpers
def sanitize_search_input(value: Optional[str], max_length: int = 100) -> Optional[str]:
    """Sanitize free-text search input.

    Args:
        value: Raw user input, or ``None``.
        max_length: Maximum number of characters kept (applied after the
            initial whitespace strip, before character filtering).

    Returns:
        The cleaned string, or ``None`` when the input is ``None`` or nothing
        meaningful remains after cleaning.
    """
    if value is None:
        return None
    # Strip whitespace and limit length
    value = value.strip()[:max_length]
    # Keep only word characters, whitespace and a few punctuation marks
    # (hyphen, apostrophe, comma, full stop); drop everything else.
    value = re.sub(r"[^\w\s\-\',\.]", "", value)
    # Removing characters can expose leading/trailing whitespace, or leave
    # nothing but whitespace (e.g. "@ @" -> " "); strip again so such input
    # is treated as empty rather than as a search for a space.
    value = value.strip()
    return value if value else None
def validate_postcode(postcode: Optional[str]) -> Optional[str]:
    """Check a UK postcode and return it trimmed and upper-cased.

    Returns ``None`` for missing/empty input and for anything that does not
    match the UK postcode shape. Internal spacing is left as supplied.
    """
    if not postcode:
        return None
    candidate = postcode.strip().upper()
    # Outward code (1-2 letters, digit, optional letter/digit), optional
    # whitespace, then inward code (digit + 2 letters).
    uk_postcode_pattern = r"^[A-Z]{1,2}[0-9][A-Z0-9]?\s*[0-9][A-Z]{2}$"
    return candidate if re.match(uk_postcode_pattern, candidate) else None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup calls load_school_data() once and reports how many records came
    back (or warns when the result is empty).
    """
    print("Loading school data from marts...")
    school_frame = load_school_data()
    if not school_frame.empty:
        print(f"Data loaded successfully: {len(school_frame)} records.")
    else:
        print("Warning: No data in marts. Run the annual EES pipeline to populate KS2 data.")
    # Control returns here when the application is shutting down.
    yield
    print("Shutting down...")
# The FastAPI application. Startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="SchoolCompare API",
    description="API for comparing primary school (KS2) performance data - schoolcompare.co.uk",
    version="2.0.0",
    lifespan=lifespan,
    # Disable docs in production for security: the interactive docs and the
    # OpenAPI schema are only exposed when settings.debug is truthy.
    docs_url="/docs" if settings.debug else None,
    redoc_url="/redoc" if settings.debug else None,
    openapi_url="/openapi.json" if settings.debug else None,
)
# Add rate limiter: slowapi looks the limiter up on app.state, and the handler
# below turns RateLimitExceeded into an HTTP response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Security middleware (order matters - these run in reverse order of
# registration, so do not reorder these calls casually)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestSizeLimitMiddleware)
# CORS middleware - restricted for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials=False,  # Don't allow credentials unless needed
    allow_methods=["GET", "POST"],  # Only allow needed methods
    allow_headers=["Content-Type", "X-API-Key"],  # Only allow needed headers
)
@app.get("/")
async def root():
    """Return the frontend's index.html for the site root."""
    index_html = settings.frontend_dir / "index.html"
    return FileResponse(index_html)
@app.get("/compare")
async def serve_compare():
    """Return index.html for /compare so the SPA can handle the route client-side."""
    index_html = settings.frontend_dir / "index.html"
    return FileResponse(index_html)
@app.get("/rankings")
async def serve_rankings():
    """Return index.html for /rankings so the SPA can handle the route client-side."""
    index_html = settings.frontend_dir / "index.html"
    return FileResponse(index_html)
@app.get("/api/config")
async def get_config():
    """Expose the public, non-secret settings the frontend needs."""
    public_config = {"ga_measurement_id": settings.ga_measurement_id}
    return public_config
@app.get("/api/schools")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_schools(
    request: Request,
    search: Optional[str] = Query(None, description="Search by school name", max_length=100),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
    school_type: Optional[str] = Query(None, description="Filter by school type", max_length=100),
    postcode: Optional[str] = Query(None, description="Search near postcode", max_length=10),
    radius: float = Query(5.0, ge=0.1, le=50, description="Search radius in miles"),
    page: int = Query(1, ge=1, le=1000, description="Page number"),
    page_size: Optional[int] = Query(None, ge=1, le=100, description="Results per page"),
):
    """
    Get list of unique primary schools with pagination.

    Each school appears once, using its most recent year of data, together
    with the previous year's ``rwm_expected_pct`` (when available) so the
    frontend can show a trend. Results can be narrowed by name/address text,
    local authority, school type, and distance from a postcode.

    Args:
        request: Incoming request (required by the rate limiter).
        search: Text matched literally, case-insensitively, against the
            school name and (if present) address.
        local_authority: Exact, case-insensitive local authority filter.
        school_type: Exact, case-insensitive school type filter.
        postcode: UK postcode to search around; ignored when it is invalid
            or cannot be geocoded.
        radius: Search radius in miles around ``postcode``.
        page: 1-based page number.
        page_size: Results per page; falls back to
            ``settings.default_page_size`` when omitted.

    Returns:
        A dict with ``schools``, ``total``, ``page``, ``page_size``,
        ``total_pages`` and ``location_info`` (``None`` unless a postcode
        search was applied).
    """
    # Sanitize inputs
    search = sanitize_search_input(search)
    local_authority = sanitize_search_input(local_authority)
    school_type = sanitize_search_input(school_type)
    postcode = validate_postcode(postcode)

    df = load_school_data()
    if df.empty:
        # Same keys as the populated response so clients see one shape.
        return {
            "schools": [],
            "total": 0,
            "page": page,
            "page_size": 0,
            "total_pages": 0,
            "location_info": None,
        }

    # Use configured default if not specified
    if page_size is None:
        page_size = settings.default_page_size

    # Get unique schools (latest year data for each)
    latest_year = df.groupby("urn")["year"].max().reset_index()
    df_latest = df.merge(latest_year, on=["urn", "year"])

    # Calculate trend by comparing to previous year:
    # take the second-latest year's row for each school.
    df_sorted = df.sort_values(["urn", "year"], ascending=[True, False])
    df_prev = df_sorted.groupby("urn").nth(1).reset_index()
    if not df_prev.empty and "rwm_expected_pct" in df_prev.columns:
        prev_rwm = df_prev[["urn", "rwm_expected_pct"]].rename(
            columns={"rwm_expected_pct": "prev_rwm_expected_pct"}
        )
        df_latest = df_latest.merge(prev_rwm, on="urn", how="left")

    # Include key result metrics for display on cards
    location_cols = ["latitude", "longitude"]
    result_cols = [
        "year",
        "rwm_expected_pct",
        "rwm_high_pct",
        "prev_rwm_expected_pct",
        "reading_expected_pct",
        "writing_expected_pct",
        "maths_expected_pct",
        "total_pupils",
    ]
    available_cols = [
        c
        for c in SCHOOL_COLUMNS + location_cols + result_cols
        if c in df_latest.columns
    ]
    schools_df = df_latest[available_cols].drop_duplicates(subset=["urn"])

    # Location-based search (uses the latitude/longitude already in the data)
    search_coords = None
    if postcode:
        coords = geocode_single_postcode(postcode)
        if coords:
            search_coords = coords
            schools_df = schools_df.copy()
            lat1, lon1 = search_coords

            # Handle potential duplicate columns by taking first occurrence
            lat_col = schools_df.loc[:, "latitude"]
            lon_col = schools_df.loc[:, "longitude"]
            if isinstance(lat_col, pd.DataFrame):
                lat_col = lat_col.iloc[:, 0]
            if isinstance(lon_col, pd.DataFrame):
                lon_col = lon_col.iloc[:, 0]
            lat2 = lat_col.values
            lon2 = lon_col.values

            # Vectorized haversine formula
            R = 3959  # Earth's radius in miles
            lat1_rad = np.radians(lat1)
            lat2_rad = np.radians(lat2)
            dlat = np.radians(lat2 - lat1)
            dlon = np.radians(lon2 - lon1)
            a = np.sin(dlat / 2) ** 2 + np.cos(lat1_rad) * np.cos(lat2_rad) * np.sin(dlon / 2) ** 2
            # Guard against floating-point drift just outside [0, 1], which
            # would make the square roots below produce NaN.
            a = np.clip(a, 0.0, 1.0)
            c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
            distances = R * c

            # Schools without coordinates get an infinite distance so the
            # radius filter always excludes them.
            has_coords = ~(pd.isna(lat_col) | pd.isna(lon_col))
            distances = np.where(has_coords.values, distances, float("inf"))
            schools_df["distance"] = distances
            schools_df = schools_df[schools_df["distance"] <= radius]
            schools_df = schools_df.sort_values("distance")

    # Apply filters
    if search:
        search_lower = search.lower()
        # regex=False: the sanitized term may still contain "." which would
        # otherwise act as a wildcard (e.g. "St." matching "Sta").
        mask = (
            schools_df["school_name"]
            .str.lower()
            .str.contains(search_lower, na=False, regex=False)
        )
        if "address" in schools_df.columns:
            mask = mask | schools_df["address"].str.lower().str.contains(
                search_lower, na=False, regex=False
            )
        schools_df = schools_df[mask]
    if local_authority:
        schools_df = schools_df[
            schools_df["local_authority"].str.lower() == local_authority.lower()
        ]
    if school_type:
        schools_df = schools_df[
            schools_df["school_type"].str.lower() == school_type.lower()
        ]

    # Pagination
    total = len(schools_df)
    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    schools_df = schools_df.iloc[start_idx:end_idx]

    return {
        "schools": clean_for_json(schools_df),
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size if page_size > 0 else 0,
        "location_info": {
            "postcode": postcode,
            "radius": radius * 1.60934,  # Convert miles to km for frontend display
            "coordinates": [search_coords[0], search_coords[1]],
        }
        if search_coords
        else None,
    }
def _json_safe_scalar(value):
    """Return ``value`` as a plain Python scalar, mapping missing values to ``None``."""
    if isinstance(value, np.generic):
        # NumPy scalars (e.g. np.int64) are not JSON-serialisable as-is.
        value = value.item()
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


@app.get("/api/schools/{urn}")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_school_details(request: Request, urn: int):
    """Get detailed KS2 data for a specific primary school across all years.

    Args:
        request: Incoming request (required by the rate limiter).
        urn: The school's 6-digit Unique Reference Number.

    Returns:
        A dict with ``school_info`` (taken from the most recent year's row),
        ``yearly_data`` (all rows, oldest first) and one key per supplementary
        dataset, each ``None`` when that data is unavailable.

    Raises:
        HTTPException: 400 for an out-of-range URN; 404 when no data is
            loaded or the school is unknown.
    """
    # Validate URN range (UK school URNs are 6 digits)
    if not (100000 <= urn <= 999999):
        raise HTTPException(status_code=400, detail="Invalid URN format")

    df = load_school_data()
    if df.empty:
        raise HTTPException(status_code=404, detail="No data available")

    school_data = df[df["urn"] == urn]
    if school_data.empty:
        raise HTTPException(status_code=404, detail="School not found")

    # Sort by year so the last row is the latest info for the school
    school_data = school_data.sort_values("year")
    latest = school_data.iloc[-1]

    # Fetch supplementary data (Ofsted, Parent View, admissions, etc.).
    # This is best-effort: a failure is logged and the core response is
    # still returned with the supplementary keys set to None.
    from .database import SessionLocal

    supplementary = {}
    try:
        db = SessionLocal()
        try:
            supplementary = get_supplementary_data(db, urn) or {}
        finally:
            # Always release the session, even when the lookup raises.
            db.close()
    except Exception:
        logging.getLogger(__name__).exception(
            "Failed to load supplementary data for URN %s", urn
        )

    def field(name, default=None):
        # Values read straight from the row can be NaN or NumPy scalars,
        # which the JSON response cannot encode; normalise them first.
        return _json_safe_scalar(latest.get(name, default))

    return {
        "school_info": {
            "urn": urn,
            "school_name": field("school_name", ""),
            "local_authority": field("local_authority", ""),
            "school_type": field("school_type", ""),
            "address": field("address", ""),
            "religious_denomination": field("religious_denomination", ""),
            "age_range": field("age_range", ""),
            "latitude": field("latitude"),
            "longitude": field("longitude"),
            "phase": "Primary",
            # GIAS fields
            "website": field("website"),
            "headteacher_name": field("headteacher_name"),
            "capacity": field("capacity"),
            "trust_name": field("trust_name"),
            "gender": field("gender"),
        },
        "yearly_data": clean_for_json(school_data),
        # Supplementary data (null if not yet populated)
        "ofsted": supplementary.get("ofsted"),
        "parent_view": supplementary.get("parent_view"),
        "census": supplementary.get("census"),
        "admissions": supplementary.get("admissions"),
        "sen_detail": supplementary.get("sen_detail"),
        "phonics": supplementary.get("phonics"),
        "deprivation": supplementary.get("deprivation"),
        "finance": supplementary.get("finance"),
    }
@app.get("/api/compare")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def compare_schools(
    request: Request,
    urns: str = Query(..., description="Comma-separated URNs", max_length=100)
):
    """Return per-school info and yearly data for up to 10 comma-separated URNs.

    Responds 404 when no data is loaded or none of the URNs match, and 400
    when a URN is not an integer, is outside the 6-digit range, or more than
    10 are supplied. URNs with no matching rows are omitted from the result.
    """
    all_schools = load_school_data()
    if all_schools.empty:
        raise HTTPException(status_code=404, detail="No data available")

    # Parse first; any non-integer piece (including an empty one) is rejected.
    try:
        urn_list = [int(piece.strip()) for piece in urns.split(",")]
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid URN format")

    # Limit number of schools to compare
    if len(urn_list) > 10:
        raise HTTPException(status_code=400, detail="Maximum 10 schools can be compared")

    # Validate URN format
    if any(not (100000 <= candidate <= 999999) for candidate in urn_list):
        raise HTTPException(status_code=400, detail="Invalid URN format")

    matching_rows = all_schools[all_schools["urn"].isin(urn_list)]
    if matching_rows.empty:
        raise HTTPException(status_code=404, detail="No schools found")

    comparison = {}
    for school_urn in urn_list:
        history = matching_rows[matching_rows["urn"] == school_urn].sort_values("year")
        if history.empty:
            continue
        # Rows are sorted by year, so the last one is the most recent.
        most_recent = history.iloc[-1]
        comparison[str(school_urn)] = {
            "school_info": {
                "urn": school_urn,
                "school_name": most_recent.get("school_name", ""),
                "local_authority": most_recent.get("local_authority", ""),
                "address": most_recent.get("address", ""),
            },
            "yearly_data": clean_for_json(history),
        }
    return {"comparison": comparison}
@app.get("/api/filters")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_filter_options(request: Request):
    """Return the sorted distinct values available for each filter.

    The response always has ``local_authorities``, ``school_types`` and
    ``years``; each is an empty list when no data is loaded.
    """
    # Response key -> DataFrame column the values are drawn from.
    column_for_key = {
        "local_authorities": "local_authority",
        "school_types": "school_type",
        "years": "year",
    }
    schools = load_school_data()
    if schools.empty:
        return {key: [] for key in column_for_key}
    return {
        key: sorted(schools[column].dropna().unique().tolist())
        for key, column in column_for_key.items()
    }
@app.get("/api/metrics")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_available_metrics(request: Request):
    """
    List the KS2 metric definitions the frontend can offer.

    Every entry of METRIC_DEFINITIONS is returned when no data is loaded;
    otherwise only those whose key is a column of the loaded data. Each item
    is the definition's fields plus a ``key`` entry.
    """
    schools = load_school_data()
    nothing_loaded = schools.empty
    metrics = [
        {"key": metric_key, **definition}
        for metric_key, definition in METRIC_DEFINITIONS.items()
        if nothing_loaded or metric_key in schools.columns
    ]
    return {"metrics": metrics}
@app.get("/api/rankings")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_rankings(
    request: Request,
    metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by", max_length=50),
    year: Optional[int] = Query(
        None, description="Specific year (defaults to most recent)", ge=2000, le=2100
    ),
    limit: int = Query(20, ge=1, le=100, description="Number of schools to return"),
    local_authority: Optional[str] = Query(
        None, description="Filter by local authority", max_length=100
    ),
):
    """Rank schools by one metric for a single year, highest value first.

    ``total`` counts every school with a value for the metric after the year
    and local-authority filters; ``rankings`` holds at most ``limit`` of them.
    Responds 400 when the metric name is malformed or not a data column.
    """
    local_authority = sanitize_search_input(local_authority)

    # Metric names may only contain lowercase letters, digits and underscores.
    if re.match(r"^[a-z0-9_]+$", metric) is None:
        raise HTTPException(status_code=400, detail="Invalid metric name")

    schools = load_school_data()
    if schools.empty:
        return {"metric": metric, "year": None, "rankings": [], "total": 0}
    if metric not in schools.columns:
        raise HTTPException(status_code=400, detail=f"Metric '{metric}' not available")

    # Use the requested year, or fall back to the most recent one in the data.
    target_year = year if year else schools["year"].max()
    schools = schools[schools["year"] == target_year]

    if local_authority:
        in_authority = schools["local_authority"].str.lower() == local_authority.lower()
        schools = schools[in_authority]

    # Rows with no value for this metric cannot be ranked.
    rankable = schools.dropna(subset=[metric])
    total = len(rankable)

    # Higher is better for both progress scores and percentages.
    top = rankable.sort_values(metric, ascending=False).head(limit)

    # Return only the ranking fields that actually exist in the data.
    top = top[[column for column in RANKING_COLUMNS if column in top.columns]]

    ranking_year = None if top.empty else int(top["year"].iloc[0])
    return {
        "metric": metric,
        "year": ranking_year,
        "rankings": clean_for_json(top),
        "total": total,
    }
@app.get("/api/data-info")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_data_info(request: Request):
    """Summarise the loaded dataset: record counts, years and local authorities.

    Returns a ``no_data`` status payload when either the database summary
    reports zero schools or the loaded DataFrame is empty.
    """
    # Ask the data layer first; it reports the school count directly.
    if get_db_info()["total_schools"] == 0:
        return {
            "status": "no_data",
            "message": "No data in marts. Run the annual EES pipeline to load KS2 data.",
            "data_source": "PostgreSQL",
        }

    # DataFrame-based stats, kept for backwards compatibility.
    schools = load_school_data()
    if schools.empty:
        return {
            "status": "no_data",
            "message": "No data available",
            "data_source": "PostgreSQL",
        }

    distinct_urns_by_year = schools.groupby("year")["urn"].nunique().to_dict()
    rows_by_authority = schools["local_authority"].value_counts().to_dict()

    return {
        "status": "loaded",
        "data_source": "PostgreSQL",
        "total_records": int(len(schools)),
        "unique_schools": int(schools["urn"].nunique()),
        "years_available": [int(y) for y in sorted(schools["year"].unique())],
        # Keys are stringified so the mapping serialises as a JSON object.
        "schools_per_year": {
            str(int(data_year)): int(count)
            for data_year, count in distinct_urns_by_year.items()
        },
        "local_authorities": {
            str(authority): int(count)
            for authority, count in rows_by_authority.items()
        },
    }
@app.post("/api/admin/reload")
@limiter.limit("5/minute")
async def reload_data(request: Request, _: bool = Depends(verify_admin_api_key)):
    """
    Force a data reload (admin only; useful after data updates).

    The verify_admin_api_key dependency rejects the request unless a valid
    X-API-Key header is supplied.
    """
    # Drop the cached data, then load it again straight away.
    clear_cache()
    load_school_data()
    return dict(status="reloaded")
# =============================================================================
# SEO FILES
# =============================================================================
@app.get("/favicon.svg")
async def favicon():
    """Return the site's SVG favicon from the frontend directory."""
    icon_path = settings.frontend_dir / "favicon.svg"
    return FileResponse(icon_path, media_type="image/svg+xml")
@app.get("/robots.txt")
async def robots_txt():
    """Return robots.txt from the frontend directory for search engine crawlers."""
    robots_path = settings.frontend_dir / "robots.txt"
    return FileResponse(robots_path, media_type="text/plain")
@app.get("/sitemap.xml")
async def sitemap_xml():
    """Return sitemap.xml from the frontend directory for search engine indexing."""
    sitemap_path = settings.frontend_dir / "sitemap.xml"
    return FileResponse(sitemap_path, media_type="application/xml")
# Serve the frontend directory under /static. Registered after every route
# above (must stay last to avoid catching API calls), and only when the
# directory actually exists.
if settings.frontend_dir.exists():
    app.mount(
        path="/static",
        app=StaticFiles(directory=settings.frontend_dir),
        name="static",
    )
if __name__ == "__main__":
    # Direct execution: start a uvicorn server on the configured host/port.
    from uvicorn import run as run_server

    run_server(app, host=settings.host, port=settings.port)