Introducing PostgreSQL for persistence
Some checks failed
Build and Push Docker Image / build-and-push (push) Failing after 32s

This commit is contained in:
Tudor Sitaru
2026-01-06 17:15:43 +00:00
parent bd3640d50f
commit 52fbade30c
6 changed files with 492 additions and 213 deletions

View File

@@ -8,8 +8,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
# Install curl for healthcheck
RUN apt-get update && apt-get install -y --no-install-recommends curl \
# Install curl for healthcheck and libpq for PostgreSQL
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
libpq5 \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
@@ -21,6 +23,7 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY backend/ ./backend/
COPY frontend/ ./frontend/
COPY scripts/ ./scripts/
COPY data/ ./data/
# Expose the application port

View File

@@ -14,17 +14,27 @@ from typing import Optional
from .config import settings
from .schemas import METRIC_DEFINITIONS, RANKING_COLUMNS, SCHOOL_COLUMNS
from .data_loader import load_school_data, clear_cache, geocode_single_postcode, geocode_postcodes_bulk, haversine_distance
from .data_loader import (
load_school_data, clear_cache, geocode_single_postcode,
geocode_postcodes_bulk, haversine_distance, get_data_info as get_db_info
)
from .database import init_db
from .utils import clean_for_json
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan - startup and shutdown events."""
# Startup: pre-load data
print("Starting up: Loading school data...")
load_school_data()
print("Data loaded successfully.")
# Startup: initialize database and pre-load data
print("Starting up: Initializing database...")
init_db() # Ensure tables exist
print("Loading school data from database...")
df = load_school_data()
if df.empty:
print("Warning: No data in database. Run the migration script to import data.")
else:
print("Data loaded successfully.")
yield # Application runs here
@@ -325,13 +335,24 @@ async def get_rankings(
@app.get("/api/data-info")
async def get_data_info():
"""Get information about loaded data."""
# Get info directly from database
db_info = get_db_info()
if db_info["total_schools"] == 0:
return {
"status": "no_data",
"message": "No data in database. Run the migration script: python scripts/migrate_csv_to_db.py",
"data_source": "PostgreSQL",
}
# Also get DataFrame-based stats for backwards compatibility
df = load_school_data()
if df.empty:
return {
"status": "no_data",
"message": "No data files found in data folder. Please download KS2 data from the government website.",
"data_folder": str(settings.data_dir),
"message": "No data available",
"data_source": "PostgreSQL",
}
years = [int(y) for y in sorted(df["year"].unique())]
@@ -340,6 +361,7 @@ async def get_data_info():
return {
"status": "loaded",
"data_source": "PostgreSQL",
"total_records": int(len(df)),
"unique_schools": int(df["urn"].nunique()),
"years_available": years,

View File

@@ -4,7 +4,7 @@ Loads from environment variables and .env file.
"""
from pathlib import Path
from typing import List
from typing import List, Optional
from pydantic_settings import BaseSettings
import os
@@ -20,6 +20,9 @@ class Settings(BaseSettings):
host: str = "0.0.0.0"
port: int = 80
# Database
database_url: str = "postgresql://schoolcompare:schoolcompare@localhost:5432/schoolcompare"
# CORS
allowed_origins: List[str] = ["https://schoolcompare.co.uk", "http://localhost:8000", "http://localhost:3000"]

View File

@@ -1,24 +1,19 @@
"""
Data loading module with optimized pandas operations.
Uses vectorized operations instead of .apply() for performance.
Data loading module that queries from PostgreSQL database.
Provides efficient queries with caching and lazy loading.
"""
import pandas as pd
import numpy as np
from pathlib import Path
from functools import lru_cache
import re
from typing import Optional, Dict, Tuple, List
import requests
from typing import Dict, List, Optional, Tuple
from sqlalchemy import select, func, and_, or_
from sqlalchemy.orm import joinedload, Session
from .config import settings
from .schemas import (
COLUMN_MAPPINGS,
NUMERIC_COLUMNS,
SCHOOL_TYPE_MAP,
NULL_VALUES,
LA_CODE_TO_NAME,
)
from .database import SessionLocal, get_db_session
from .models import School, SchoolResult
# Cache for postcode geocoding
_postcode_cache: Dict[str, Tuple[float, float]] = {}
@@ -31,17 +26,25 @@ def geocode_postcodes_bulk(postcodes: list) -> Dict[str, Tuple[float, float]]:
"""
results = {}
# Remove invalid postcodes and deduplicate
valid_postcodes = [p.strip().upper() for p in postcodes if p and isinstance(p, str) and len(p.strip()) >= 5]
valid_postcodes = list(set(valid_postcodes))
# Check cache first
uncached = []
for pc in postcodes:
if pc and isinstance(pc, str):
pc_upper = pc.strip().upper()
if pc_upper in _postcode_cache:
results[pc_upper] = _postcode_cache[pc_upper]
elif len(pc_upper) >= 5:
uncached.append(pc_upper)
if not valid_postcodes:
if not uncached:
return results
uncached = list(set(uncached))
# postcodes.io allows max 100 postcodes per request
batch_size = 100
for i in range(0, len(valid_postcodes), batch_size):
batch = valid_postcodes[i:i + batch_size]
for i in range(0, len(uncached), batch_size):
batch = uncached[i:i + batch_size]
try:
response = requests.post(
'https://api.postcodes.io/postcodes',
@@ -57,6 +60,7 @@ def geocode_postcodes_bulk(postcodes: list) -> Dict[str, Tuple[float, float]]:
lon = item['result'].get('longitude')
if lat and lon:
results[pc] = (lat, lon)
_postcode_cache[pc] = (lat, lon)
except Exception as e:
print(f" Warning: Geocoding batch failed: {e}")
@@ -93,189 +97,6 @@ def geocode_single_postcode(postcode: str) -> Optional[Tuple[float, float]]:
return None
def extract_year_from_folder(folder_name: str) -> Optional[int]:
"""Extract the end year from folder name like '2023-2024' -> 2024."""
match = re.search(r'(\d{4})-(\d{4})', folder_name)
if match:
return int(match.group(2))
return None
def parse_numeric_vectorized(series: pd.Series) -> pd.Series:
"""
Vectorized numeric parsing - much faster than .apply().
Handles SUPP, NE, NA, NP, %, etc.
"""
# Convert to string first
str_series = series.astype(str)
# Replace null values with NaN
for null_val in NULL_VALUES:
str_series = str_series.replace(null_val, np.nan)
# Remove % signs
str_series = str_series.str.rstrip('%')
# Convert to numeric
return pd.to_numeric(str_series, errors='coerce')
def create_address_vectorized(df: pd.DataFrame) -> pd.Series:
"""
Vectorized address creation - much faster than .apply().
"""
parts = []
if 'address1' in df.columns:
parts.append(df['address1'].fillna('').astype(str))
if 'town' in df.columns:
parts.append(df['town'].fillna('').astype(str))
if 'postcode' in df.columns:
parts.append(df['postcode'].fillna('').astype(str))
if not parts:
return pd.Series([''] * len(df), index=df.index)
# Combine parts with comma separator, filtering empty strings
result = pd.Series([''] * len(df), index=df.index)
for i, row_idx in enumerate(df.index):
row_parts = [p.iloc[i] if hasattr(p, 'iloc') else p[i] for p in parts]
row_parts = [p for p in row_parts if p and p.strip()]
result.iloc[i] = ', '.join(row_parts)
return result
def create_address_fast(df: pd.DataFrame) -> pd.Series:
"""
Fast vectorized address creation using string concatenation.
"""
addr1 = df.get('address1', pd.Series([''] * len(df))).fillna('').astype(str)
town = df.get('town', pd.Series([''] * len(df))).fillna('').astype(str)
postcode = df.get('postcode', pd.Series([''] * len(df))).fillna('').astype(str)
# Build address with proper separators
result = addr1.str.strip()
# Add town if not empty
town_mask = town.str.strip() != ''
result = result.where(~town_mask, result + ', ' + town.str.strip())
# Add postcode if not empty
postcode_mask = postcode.str.strip() != ''
result = result.where(~postcode_mask, result + ', ' + postcode.str.strip())
# Clean up leading commas
result = result.str.lstrip(', ')
return result
def load_year_data(year_folder: Path, year: int) -> Optional[pd.DataFrame]:
"""Load and process data for a single year."""
ks2_file = year_folder / "england_ks2final.csv"
if not ks2_file.exists():
return None
try:
print(f"Loading data from {ks2_file}")
df = pd.read_csv(ks2_file, low_memory=False)
# Handle column types
if 'LEA' in df.columns and df['LEA'].dtype == 'object':
df['LEA'] = pd.to_numeric(df['LEA'], errors='coerce')
if 'URN' in df.columns and df['URN'].dtype == 'object':
df['URN'] = pd.to_numeric(df['URN'], errors='coerce')
# Filter to schools only (RECTYPE == 1 means school level data)
if 'RECTYPE' in df.columns:
df = df[df['RECTYPE'] == 1].copy()
# Add year and local authority name
df['year'] = year
# Try different column names for LA name
la_name_cols = ['LANAME', 'LA (name)', 'LA_NAME', 'LA NAME']
la_col_found = None
for col in la_name_cols:
if col in df.columns:
la_col_found = col
break
if la_col_found:
df['local_authority'] = df[la_col_found]
elif 'LEA' in df.columns:
# Map LEA codes to names using our mapping
df['local_authority'] = df['LEA'].map(LA_CODE_TO_NAME).fillna(df['LEA'].astype(str))
# Rename columns using mapping
rename_dict = {k: v for k, v in COLUMN_MAPPINGS.items() if k in df.columns}
df = df.rename(columns=rename_dict)
# Create address field (vectorized)
df['address'] = create_address_fast(df)
# Map school type codes to names (vectorized)
if 'school_type_code' in df.columns:
df['school_type'] = df['school_type_code'].map(SCHOOL_TYPE_MAP).fillna('Other')
# Parse numeric columns (vectorized - much faster than .apply())
for col in NUMERIC_COLUMNS:
if col in df.columns:
df[col] = parse_numeric_vectorized(df[col])
# Initialize lat/long columns
df['latitude'] = None
df['longitude'] = None
print(f" Loaded {len(df)} schools for year {year}")
return df
except Exception as e:
print(f"Error loading {ks2_file}: {e}")
return None
@lru_cache(maxsize=1)
def load_school_data() -> pd.DataFrame:
"""
Load and combine all school data from CSV files in year folders.
Uses lru_cache for singleton-like behavior.
"""
all_data = []
data_dir = settings.data_dir
if data_dir.exists():
for year_folder in data_dir.iterdir():
if year_folder.is_dir() and re.match(r'\d{4}-\d{4}', year_folder.name):
year = extract_year_from_folder(year_folder.name)
if year is None:
continue
df = load_year_data(year_folder, year)
if df is not None:
all_data.append(df)
if all_data:
result = pd.concat(all_data, ignore_index=True)
print(f"\nTotal records loaded: {len(result)}")
print(f"Unique schools: {result['urn'].nunique()}")
print(f"Years: {sorted(result['year'].unique())}")
# Note: Geocoding is done lazily when location search is used
# This keeps startup fast
return result
else:
print("No data files found. Creating empty DataFrame.")
return pd.DataFrame()
def clear_cache():
"""Clear the data cache to force reload."""
load_school_data.cache_clear()
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
"""
Calculate the great circle distance between two points on Earth (in miles).
@@ -296,3 +117,402 @@ def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> fl
return c * r
# =============================================================================
# DATABASE QUERY FUNCTIONS
# =============================================================================
def get_db():
    """Create and return a fresh database session from the session factory."""
    session = SessionLocal()
    return session
def get_available_years(db: Optional[Session] = None) -> List[int]:
    """Return the distinct result years present in the database, ascending.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.

    Returns:
        Sorted list of years (ints) found in school_results.
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        rows = db.query(SchoolResult.year).distinct().order_by(SchoolResult.year).all()
        return [row[0] for row in rows]
    finally:
        # Only close sessions we opened; caller-provided sessions stay open.
        if owns_session:
            db.close()
def get_available_local_authorities(db: Optional[Session] = None) -> List[str]:
    """Return the distinct, non-null local authority names, sorted.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.

    Returns:
        Sorted list of local authority names (empty strings excluded).
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        rows = db.query(School.local_authority)\
            .filter(School.local_authority.isnot(None))\
            .distinct()\
            .order_by(School.local_authority)\
            .all()
        # The extra truthiness check also drops empty strings, not just NULLs.
        return [row[0] for row in rows if row[0]]
    finally:
        if owns_session:
            db.close()
def get_available_school_types(db: Optional[Session] = None) -> List[str]:
    """Return the distinct, non-null school types, sorted.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.

    Returns:
        Sorted list of school type names (empty strings excluded).
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        rows = db.query(School.school_type)\
            .filter(School.school_type.isnot(None))\
            .distinct()\
            .order_by(School.school_type)\
            .all()
        # The extra truthiness check also drops empty strings, not just NULLs.
        return [row[0] for row in rows if row[0]]
    finally:
        if owns_session:
            db.close()
def get_schools_count(db: Optional[Session] = None) -> int:
    """Return the total number of schools stored in the database.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        return db.query(School).count()
    finally:
        # Only close sessions we opened; caller-provided sessions stay open.
        if owns_session:
            db.close()
def get_schools(
    db: Session,
    search: Optional[str] = None,
    local_authority: Optional[str] = None,
    school_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[School], int]:
    """
    Get paginated list of schools with optional filters.

    The search term is matched case-insensitively against name, postcode
    and town; local authority and school type are exact (case-insensitive)
    matches. Returns (schools, total_count) where total_count is the
    pre-pagination match count.
    """
    q = db.query(School)

    if search:
        pattern = f"%{search.lower()}%"
        q = q.filter(
            or_(
                func.lower(School.school_name).like(pattern),
                func.lower(School.postcode).like(pattern),
                func.lower(School.town).like(pattern),
            )
        )
    if local_authority:
        q = q.filter(func.lower(School.local_authority) == local_authority.lower())
    if school_type:
        q = q.filter(func.lower(School.school_type) == school_type.lower())

    # Count before slicing so the caller can compute total pages.
    total_count = q.count()

    start = (page - 1) * page_size
    matched = (
        q.order_by(School.school_name)
        .offset(start)
        .limit(page_size)
        .all()
    )
    return matched, total_count
def get_schools_near_location(
    db: Session,
    latitude: float,
    longitude: float,
    radius_miles: float = 5.0,
    search: Optional[str] = None,
    local_authority: Optional[str] = None,
    school_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[Tuple[School, float]], int]:
    """
    Get schools near a location, sorted by distance.

    Distance filtering is done in Python (haversine) after fetching every
    geocoded school matching the text filters, so pagination happens in
    memory. Returns (list of (school, distance_miles) tuples, total count).
    """
    # Only schools that have been geocoded can be distance-filtered.
    query = db.query(School).filter(
        School.latitude.isnot(None),
        School.longitude.isnot(None)
    )
    # Apply text filters
    if search:
        search_lower = f"%{search.lower()}%"
        query = query.filter(
            or_(
                func.lower(School.school_name).like(search_lower),
                func.lower(School.postcode).like(search_lower),
                func.lower(School.town).like(search_lower),
            )
        )
    if local_authority:
        query = query.filter(func.lower(School.local_authority) == local_authority.lower())
    if school_type:
        query = query.filter(func.lower(School.school_type) == school_type.lower())

    # Compute the distance for every matching school.
    schools_with_distance = []
    for school in query.all():
        # `is not None` rather than truthiness: a latitude/longitude of
        # exactly 0.0 is a valid coordinate and must not be skipped.
        if school.latitude is not None and school.longitude is not None:
            dist = haversine_distance(latitude, longitude, school.latitude, school.longitude)
            if dist <= radius_miles:
                schools_with_distance.append((school, dist))

    # Nearest first.
    schools_with_distance.sort(key=lambda pair: pair[1])
    total = len(schools_with_distance)

    # Paginate in memory.
    offset = (page - 1) * page_size
    return schools_with_distance[offset:offset + page_size], total
def get_school_by_urn(db: Session, urn: int) -> Optional[School]:
    """Look up one school by its unique reference number (URN), or None."""
    matches = db.query(School).filter(School.urn == urn)
    return matches.first()
def get_school_results(
    db: Session,
    urn: int,
    years: Optional[List[int]] = None
) -> List[SchoolResult]:
    """Return every result row for the school with this URN, oldest first.

    When a list of years is given, only those years are returned.
    """
    q = (
        db.query(SchoolResult)
        .join(School)
        .filter(School.urn == urn)
        .order_by(SchoolResult.year)
    )
    if years:
        q = q.filter(SchoolResult.year.in_(years))
    return q.all()
def get_rankings(
    db: Session,
    metric: str,
    year: int,
    local_authority: Optional[str] = None,
    limit: int = 20,
    ascending: bool = False,
) -> List[Tuple[School, SchoolResult]]:
    """
    Get school rankings for a specific metric and year.

    Args:
        db: Active database session.
        metric: Name of a SchoolResult column to rank by.
        year: Academic year to rank within.
        local_authority: Optional case-insensitive LA name filter.
        limit: Maximum number of rows returned.
        ascending: Sort direction; False ranks highest values first.

    Returns:
        List of (school, result) tuples; empty list for an unknown metric.
    """
    # Only accept real mapped columns. A bare getattr() would also match
    # relationships/methods on the model, which later fail with an
    # AttributeError (HTTP 500) instead of the documented empty result.
    if metric not in SchoolResult.__table__.columns:
        return []
    metric_column = getattr(SchoolResult, metric)

    # Build the query
    query = db.query(School, SchoolResult)\
        .join(SchoolResult)\
        .filter(SchoolResult.year == year)

    # Filter by local authority
    if local_authority:
        query = query.filter(func.lower(School.local_authority) == local_authority.lower())

    # Exclude schools with no value for this metric, then order.
    query = query.filter(metric_column.isnot(None))
    query = query.order_by(metric_column.asc() if ascending else metric_column.desc())
    return query.limit(limit).all()
def get_data_info(db: Optional[Session] = None) -> dict:
    """Return summary information about the data in the database.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.

    Returns:
        Dict with school/result counts, available years, local authority
        count, and a "data_source" label.
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        school_count = db.query(School).count()
        result_count = db.query(SchoolResult).count()
        # Reuse the shared session for the helper queries.
        years = get_available_years(db)
        local_authorities = get_available_local_authorities(db)
        return {
            "total_schools": school_count,
            "total_results": result_count,
            "years_available": years,
            "local_authorities_count": len(local_authorities),
            "data_source": "PostgreSQL",
        }
    finally:
        # Only close sessions we opened; caller-provided sessions stay open.
        if owns_session:
            db.close()
def school_to_dict(school: School, include_results: bool = False) -> dict:
    """Serialize a School model into a plain dictionary.

    When include_results is True and the school has result rows, they are
    serialized under the "results" key as well.
    """
    payload = {
        field: getattr(school, field)
        for field in (
            "urn", "school_name", "local_authority", "school_type",
            "address", "town", "postcode", "latitude", "longitude",
        )
    }
    if include_results and school.results:
        payload["results"] = [result_to_dict(entry) for entry in school.results]
    return payload
def result_to_dict(result: SchoolResult) -> dict:
    """Serialize a SchoolResult model into a plain dictionary.

    Every dictionary key matches the model attribute of the same name, so
    the mapping is driven by an ordered tuple of field names grouped by
    metric family.
    """
    field_names = (
        "year", "total_pupils", "eligible_pupils",
        # Expected Standard
        "rwm_expected_pct", "reading_expected_pct", "writing_expected_pct",
        "maths_expected_pct", "gps_expected_pct", "science_expected_pct",
        # Higher Standard
        "rwm_high_pct", "reading_high_pct", "writing_high_pct",
        "maths_high_pct", "gps_high_pct",
        # Progress
        "reading_progress", "writing_progress", "maths_progress",
        # Averages
        "reading_avg_score", "maths_avg_score", "gps_avg_score",
        # Context
        "disadvantaged_pct", "eal_pct", "sen_support_pct",
        "sen_ehcp_pct", "stability_pct",
        # Gender
        "rwm_expected_boys_pct", "rwm_expected_girls_pct",
        "rwm_high_boys_pct", "rwm_high_girls_pct",
        # Disadvantaged
        "rwm_expected_disadvantaged_pct", "rwm_expected_non_disadvantaged_pct",
        "disadvantaged_gap",
        # 3-Year
        "rwm_expected_3yr_pct", "reading_avg_3yr", "maths_avg_3yr",
    )
    return {name: getattr(result, name) for name in field_names}
# =============================================================================
# LEGACY COMPATIBILITY - DataFrame-based functions
# =============================================================================
def load_school_data_as_dataframe(db: Optional[Session] = None) -> pd.DataFrame:
    """
    Load all school data as a pandas DataFrame.
    For compatibility with existing code that expects DataFrames.

    Each row is one (school, year) result; schools without result rows
    contribute nothing. Returns an empty DataFrame when there is no data.

    Args:
        db: Optional existing session; when omitted, a short-lived session
            is opened and closed by this function.
    """
    owns_session = db is None
    if db is None:
        db = get_db()
    try:
        # joinedload avoids an N+1 query per school for its results.
        schools = db.query(School).options(joinedload(School.results)).all()
        # Reuse the canonical serializers so DataFrame columns stay in
        # sync with the API dictionaries instead of duplicating the field
        # list here.
        rows = [
            {**school_to_dict(school), **result_to_dict(result)}
            for school in schools
            for result in school.results
        ]
        return pd.DataFrame(rows) if rows else pd.DataFrame()
    finally:
        # Only close sessions we opened; caller-provided sessions stay open.
        if owns_session:
            db.close()
# Cache for DataFrame (legacy compatibility)
_df_cache: Optional[pd.DataFrame] = None
def load_school_data() -> pd.DataFrame:
    """
    Legacy function to load school data as DataFrame.

    The first call populates a module-level cache; later calls return the
    cached DataFrame until clear_cache() resets it.
    """
    global _df_cache
    if _df_cache is None:
        print("Loading school data from database...")
        df = load_school_data_as_dataframe()
        if df.empty:
            print("No data found in database")
        else:
            print(f"Total records loaded: {len(df)}")
            print(f"Unique schools: {df['urn'].nunique()}")
            print(f"Years: {sorted(df['year'].unique())}")
        _df_cache = df
    return _df_cache
def clear_cache():
    """Clear all caches.

    Drops the cached DataFrame so the next load_school_data() call
    re-queries the database.
    """
    global _df_cache
    _df_cache = None

View File

@@ -1,16 +1,44 @@
services:
db:
image: postgres:16-alpine
container_name: schoolcompare_db
environment:
POSTGRES_USER: schoolcompare
POSTGRES_PASSWORD: schoolcompare
POSTGRES_DB: schoolcompare
volumes:
- postgres_data:/var/lib/postgresql/data
ports:
- "5432:5432"
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U schoolcompare"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
app:
build: .
container_name: schoolcompare_app
ports:
- "80:80"
environment:
DATABASE_URL: postgresql://schoolcompare:schoolcompare@db:5432/schoolcompare
volumes:
# Mount data directory for easy updates without rebuilding
# Mount data directory for migrations
- ./data:/app/data:ro
depends_on:
db:
condition: service_healthy
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:80/api/data-info"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
start_period: 30s
volumes:
postgres_data:

View File

@@ -5,4 +5,7 @@ python-multipart==0.0.6
aiofiles==23.2.1
pydantic-settings==2.1.0
requests==2.31.0
sqlalchemy==2.0.25
psycopg2-binary==2.9.9
alembic==1.13.1