security improvements
All checks were successful
Build and Push Docker Image / build-and-push (push) Successful in 1m10s

This commit is contained in:
Tudor
2026-01-07 16:20:49 +00:00
parent 9af8d471a6
commit 24ab4593f3
7 changed files with 295 additions and 32 deletions

45
.env Normal file
View File

@@ -0,0 +1,45 @@
# SchoolCompare Environment Configuration
# Copy this file to .env and update the values
# =============================================================================
# DATABASE
# =============================================================================
# PostgreSQL connection string
DATABASE_URL=postgresql://schoolcompare:CHANGE_THIS_PASSWORD@localhost:5432/schoolcompare
# =============================================================================
# SERVER
# =============================================================================
# Set to False in production
DEBUG=False
# Server host and port
HOST=0.0.0.0
PORT=80
# =============================================================================
# CORS
# =============================================================================
# JSON array of allowed origins (list settings are parsed as JSON, e.g. ["https://a.example","https://b.example"])
# In production, only include your actual domain
ALLOWED_ORIGINS=["https://schoolcompare.co.uk"]
# =============================================================================
# SECURITY
# =============================================================================
# Admin API key for protected endpoints (e.g., /api/admin/reload)
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
ADMIN_API_KEY=CHANGE_THIS_TO_A_SECURE_RANDOM_KEY
# Rate limiting (requests per minute per IP)
RATE_LIMIT_PER_MINUTE=60
RATE_LIMIT_BURST=10
# Maximum request body size in bytes (default 1MB)
MAX_REQUEST_SIZE=1048576
# =============================================================================
# API
# =============================================================================
DEFAULT_PAGE_SIZE=50
MAX_PAGE_SIZE=100

45
.env.example Normal file
View File

@@ -0,0 +1,45 @@
# SchoolCompare Environment Configuration
# Copy this file to .env and update the values
# =============================================================================
# DATABASE
# =============================================================================
# PostgreSQL connection string
DATABASE_URL=postgresql://schoolcompare:CHANGE_THIS_PASSWORD@localhost:5432/schoolcompare
# =============================================================================
# SERVER
# =============================================================================
# Set to False in production
DEBUG=False
# Server host and port
HOST=0.0.0.0
PORT=80
# =============================================================================
# CORS
# =============================================================================
# JSON array of allowed origins (list settings are parsed as JSON, e.g. ["https://a.example","https://b.example"])
# In production, only include your actual domain
ALLOWED_ORIGINS=["https://schoolcompare.co.uk"]
# =============================================================================
# SECURITY
# =============================================================================
# Admin API key for protected endpoints (e.g., /api/admin/reload)
# Generate a secure random key: python -c "import secrets; print(secrets.token_urlsafe(32))"
ADMIN_API_KEY=CHANGE_THIS_TO_A_SECURE_RANDOM_KEY
# Rate limiting (requests per minute per IP)
RATE_LIMIT_PER_MINUTE=60
RATE_LIMIT_BURST=10
# Maximum request body size in bytes (default 1MB)
MAX_REQUEST_SIZE=1048576
# =============================================================================
# API
# =============================================================================
DEFAULT_PAGE_SIZE=50
MAX_PAGE_SIZE=100

View File

@@ -4,14 +4,19 @@ Serves primary school (KS2) performance data for comparing schools.
Uses real data from UK Government Compare School Performance downloads.
"""
import re
from contextlib import asynccontextmanager
from typing import Optional
import pandas as pd
from fastapi import FastAPI, HTTPException, Query
from fastapi import FastAPI, HTTPException, Query, Request, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware
from .config import settings
from .data_loader import (
@@ -27,6 +32,107 @@ from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS
from .utils import clean_for_json
# =============================================================================
# SECURITY MIDDLEWARE & HELPERS
# =============================================================================
# Rate limiter
limiter = Limiter(key_func=get_remote_address)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of browser security headers onto every response."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then overwrite the security headers.

        Any value the handler already set for one of these header names is
        replaced by the value defined here.
        """
        security_headers = {
            # Forbid rendering this site inside frames (clickjacking defence).
            "X-Frame-Options": "DENY",
            # Stop browsers from guessing content types.
            "X-Content-Type-Options": "nosniff",
            # Legacy XSS filter switch for older browsers.
            "X-XSS-Protection": "1; mode=block",
            # Limit how much referrer information leaves the origin.
            "Referrer-Policy": "strict-origin-when-cross-origin",
            # Switch off browser features the site does not use.
            "Permissions-Policy": (
                "geolocation=(), microphone=(), camera=(), payment=()"
            ),
            # Restrict where scripts, styles, fonts, images and XHR may load from.
            "Content-Security-Policy": (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
                "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "
                "font-src 'self' https://fonts.gstatic.com; "
                "img-src 'self' data:; "
                "connect-src 'self'; "
                "frame-ancestors 'none'; "
                "base-uri 'self'; "
                "form-action 'self';"
            ),
            # NOTE(review): HSTS is sent unconditionally, although the original
            # comment said it should only be enabled when serving over HTTPS.
            "Strict-Transport-Security": (
                "max-age=31536000; includeSubDomains"
            ),
        }

        response = await call_next(request)
        for header_name, header_value in security_headers.items():
            response.headers[header_name] = header_value
        return response
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared body size exceeds the configured limit.

    Only the ``Content-Length`` header is inspected; the body itself is not
    read or counted here, so a request that omits the header passes through
    this middleware unchecked.
    """

    async def dispatch(self, request: Request, call_next):
        """Validate ``Content-Length`` and short-circuit oversized requests.

        Returns:
            A 400 response when the header is not a non-negative integer,
            a 413 response when it exceeds ``settings.max_request_size``,
            otherwise the downstream handler's response.
        """
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # A malformed header previously raised here and surfaced as a
                # 500; it is a client error, so answer with 400 instead.
                declared_size = -1
            if declared_size < 0:
                return Response(
                    content="Invalid Content-Length header",
                    status_code=400,
                )
            if declared_size > settings.max_request_size:
                return Response(
                    content="Request too large",
                    status_code=413,
                )
        return await call_next(request)
def verify_admin_api_key(x_api_key: str = Header(None)) -> bool:
    """Check the ``X-API-Key`` request header against the configured admin key.

    Intended for use as a FastAPI dependency on protected endpoints.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` when absent.

    Returns:
        ``True`` when the supplied key matches ``settings.admin_api_key``.

    Raises:
        HTTPException: 401 when the header is missing, empty, or wrong.
    """
    # Imported locally so this function does not depend on a module-level
    # import being present.
    import secrets

    key_matches = bool(x_api_key) and secrets.compare_digest(
        # Compare as bytes: compare_digest rejects non-ASCII str arguments,
        # and a constant-time comparison avoids leaking how many leading
        # characters of the key were guessed correctly.
        x_api_key.encode("utf-8"),
        settings.admin_api_key.encode("utf-8"),
    )
    if not key_matches:
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
    return True
# Input validation helpers
def sanitize_search_input(value: Optional[str], max_length: int = 100) -> Optional[str]:
    """Trim, truncate and character-filter a free-text search parameter.

    Args:
        value: Raw user input, or ``None`` when the parameter was not supplied.
        max_length: Maximum number of characters kept after the initial trim.

    Returns:
        The cleaned string, or ``None`` when the input is ``None`` or nothing
        but whitespace remains after filtering.
    """
    if value is None:
        return None
    # Trim first so leading padding does not eat into the length budget.
    truncated = value.strip()[:max_length]
    # Keep word characters, whitespace, hyphen, apostrophe, comma and full stop.
    filtered = re.sub(r"[^\w\s\-\',\.]", "", truncated)
    # Dropping characters or truncating can expose new edge whitespace
    # (e.g. "school @" -> "school "), so trim again before the emptiness check.
    cleaned = filtered.strip()
    return cleaned or None
def validate_postcode(postcode: Optional[str]) -> Optional[str]:
    """Return the trimmed, upper-cased postcode if it looks like a UK postcode.

    Internal spacing is left as supplied. ``None`` is returned for missing,
    empty, or non-matching input.
    """
    if postcode:
        candidate = postcode.strip().upper()
        # Outward code, optional whitespace, then the inward code.
        uk_format = r"^[A-Z]{1,2}[0-9][A-Z0-9]?\s*[0-9][A-Z]{2}$"
        if re.match(uk_format, candidate) is not None:
            return candidate
    return None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan - startup and shutdown events."""
@@ -52,15 +158,27 @@ app = FastAPI(
description="API for comparing primary school (KS2) performance data - schoolcompare.co.uk",
version="2.0.0",
lifespan=lifespan,
# Disable docs in production for security
docs_url="/docs" if settings.debug else None,
redoc_url="/redoc" if settings.debug else None,
openapi_url="/openapi.json" if settings.debug else None,
)
# CORS middleware with configurable origins
# Add rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Security middleware (order matters - these run in reverse order)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestSizeLimitMiddleware)
# CORS middleware - restricted for production
app.add_middleware(
CORSMiddleware,
allow_origins=settings.allowed_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=False, # Don't allow credentials unless needed
allow_methods=["GET", "POST"], # Only allow needed methods
allow_headers=["Content-Type", "X-API-Key"], # Only allow needed headers
)
@@ -83,15 +201,17 @@ async def serve_rankings():
@app.get("/api/schools")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_schools(
search: Optional[str] = Query(None, description="Search by school name"),
request: Request,
search: Optional[str] = Query(None, description="Search by school name", max_length=100),
local_authority: Optional[str] = Query(
None, description="Filter by local authority"
None, description="Filter by local authority", max_length=100
),
school_type: Optional[str] = Query(None, description="Filter by school type"),
postcode: Optional[str] = Query(None, description="Search near postcode"),
school_type: Optional[str] = Query(None, description="Filter by school type", max_length=100),
postcode: Optional[str] = Query(None, description="Search near postcode", max_length=10),
radius: float = Query(5.0, ge=0.1, le=50, description="Search radius in miles"),
page: int = Query(1, ge=1, description="Page number"),
page: int = Query(1, ge=1, le=1000, description="Page number"),
page_size: int = Query(None, ge=1, le=100, description="Results per page"),
):
"""
@@ -100,6 +220,12 @@ async def get_schools(
Returns paginated results with total count for efficient loading.
Supports location-based search using postcode.
"""
# Sanitize inputs
search = sanitize_search_input(search)
local_authority = sanitize_search_input(local_authority)
school_type = sanitize_search_input(school_type)
postcode = validate_postcode(postcode)
df = load_school_data()
if df.empty:
@@ -216,8 +342,13 @@ async def get_schools(
@app.get("/api/schools/{urn}")
async def get_school_details(urn: int):
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_school_details(request: Request, urn: int):
"""Get detailed KS2 data for a specific primary school across all years."""
# Validate URN range (UK school URNs are 6 digits)
if not (100000 <= urn <= 999999):
raise HTTPException(status_code=400, detail="Invalid URN format")
df = load_school_data()
if df.empty:
@@ -248,7 +379,11 @@ async def get_school_details(urn: int):
@app.get("/api/compare")
async def compare_schools(urns: str = Query(..., description="Comma-separated URNs")):
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def compare_schools(
request: Request,
urns: str = Query(..., description="Comma-separated URNs", max_length=100)
):
"""Compare multiple primary schools side by side."""
df = load_school_data()
@@ -257,6 +392,13 @@ async def compare_schools(urns: str = Query(..., description="Comma-separated UR
try:
urn_list = [int(u.strip()) for u in urns.split(",")]
# Limit number of schools to compare
if len(urn_list) > 10:
raise HTTPException(status_code=400, detail="Maximum 10 schools can be compared")
# Validate URN format
for urn in urn_list:
if not (100000 <= urn <= 999999):
raise HTTPException(status_code=400, detail="Invalid URN format")
except ValueError:
raise HTTPException(status_code=400, detail="Invalid URN format")
@@ -284,7 +426,8 @@ async def compare_schools(urns: str = Query(..., description="Comma-separated UR
@app.get("/api/filters")
async def get_filter_options():
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_filter_options(request: Request):
"""Get available filter options (local authorities, school types, years)."""
df = load_school_data()
@@ -303,7 +446,8 @@ async def get_filter_options():
@app.get("/api/metrics")
async def get_available_metrics():
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_available_metrics(request: Request):
"""
Get list of available KS2 performance metrics for primary schools.
@@ -321,17 +465,26 @@ async def get_available_metrics():
@app.get("/api/rankings")
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_rankings(
metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by"),
request: Request,
metric: str = Query("rwm_expected_pct", description="KS2 metric to rank by", max_length=50),
year: Optional[int] = Query(
None, description="Specific year (defaults to most recent)"
None, description="Specific year (defaults to most recent)", ge=2000, le=2100
),
limit: int = Query(20, ge=1, le=100, description="Number of schools to return"),
local_authority: Optional[str] = Query(
None, description="Filter by local authority"
None, description="Filter by local authority", max_length=100
),
):
"""Get primary school rankings by a specific KS2 metric."""
# Sanitize local authority input
local_authority = sanitize_search_input(local_authority)
# Validate metric name (only allow alphanumeric and underscore)
if not re.match(r"^[a-z0-9_]+$", metric):
raise HTTPException(status_code=400, detail="Invalid metric name")
df = load_school_data()
if df.empty:
@@ -372,7 +525,8 @@ async def get_rankings(
@app.get("/api/data-info")
async def get_data_info():
@limiter.limit(f"{settings.rate_limit_per_minute}/minute")
async def get_data_info(request: Request):
"""Get information about loaded data."""
# Get info directly from database
db_info = get_db_info()
@@ -416,8 +570,15 @@ async def get_data_info():
@app.post("/api/admin/reload")
async def reload_data():
"""Admin endpoint to force data reload (useful after data updates)."""
@limiter.limit("5/minute")
async def reload_data(
request: Request,
_: bool = Depends(verify_admin_api_key)
):
"""
Admin endpoint to force data reload (useful after data updates).
Requires X-API-Key header with valid admin API key.
"""
clear_cache()
load_school_data()
return {"status": "reloaded"}

View File

@@ -3,33 +3,41 @@ Application configuration using pydantic-settings.
Loads from environment variables and .env file.
"""
import secrets
from pathlib import Path
from typing import List, Optional
from pydantic_settings import BaseSettings
import os
from pydantic import Field
class Settings(BaseSettings):
"""Application settings loaded from environment."""
# Paths
data_dir: Path = Path(__file__).parent.parent / "data"
frontend_dir: Path = Path(__file__).parent.parent / "frontend"
# Server
host: str = "0.0.0.0"
port: int = 80
debug: bool = False # Set to False in production
# Database
database_url: str = "postgresql://schoolcompare:schoolcompare@localhost:5432/schoolcompare"
# CORS
allowed_origins: List[str] = ["https://schoolcompare.co.uk", "http://localhost:8000", "http://localhost:3000"]
# CORS - Production should only allow the actual domain
allowed_origins: List[str] = ["https://schoolcompare.co.uk"]
# API
default_page_size: int = 50
max_page_size: int = 100
# Security
admin_api_key: str = Field(default_factory=lambda: secrets.token_urlsafe(32))
rate_limit_per_minute: int = 60 # Requests per minute per IP
rate_limit_burst: int = 10 # Allow burst of requests
max_request_size: int = 1024 * 1024 # 1MB max request size
class Config:
env_file = ".env"
env_file_encoding = "utf-8"

View File

@@ -15,7 +15,7 @@
<header class="header">
<div class="header-content">
<div class="logo">
<a href="/" class="logo">
<div class="logo-icon">
<svg viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="20" cy="20" r="18" stroke="currentColor" stroke-width="2"/>
@@ -27,7 +27,7 @@
<span class="logo-title">SchoolCompare</span>
<span class="logo-subtitle">schoolcompare.co.uk</span>
</div>
</div>
</a>
<nav class="nav">
<a href="/" class="nav-link active" data-view="dashboard">Dashboard</a>
<a href="/compare" class="nav-link" data-view="compare">Compare</a>

View File

@@ -96,6 +96,8 @@ body {
display: flex;
align-items: center;
gap: 0.75rem;
text-decoration: none;
color: inherit;
}
.logo-icon {

View File

@@ -8,4 +8,6 @@ requests==2.31.0
sqlalchemy==2.0.25
psycopg2-binary==2.9.9
alembic==1.13.1
slowapi==0.1.9
secure==0.3.0