# school_compare/backend/migration.py
"""
Database migration logic for importing CSV data.
Used by both CLI script and automatic startup migration.
"""
import re
from pathlib import Path
from typing import Dict, Optional
import numpy as np
import pandas as pd
import requests
from .config import settings
from .database import Base, engine, get_db_session
from .models import School, SchoolResult
from .schemas import (
COLUMN_MAPPINGS,
LA_CODE_TO_NAME,
NULL_VALUES,
SCHOOL_TYPE_MAP,
)
def parse_numeric(value) -> Optional[float]:
    """Parse a numeric value from CSV data.

    Handles missing values (NaN/None), sentinel strings listed in
    NULL_VALUES (e.g. suppressed results), percentage suffixes, and
    thousands separators.

    Returns:
        The parsed float, or None when the value is missing, a known
        null sentinel, or not parseable as a number.
    """
    if pd.isna(value):
        return None
    # pd.isna() above already returned for NaN, so a raw int/float here
    # is always safe to convert directly.
    if isinstance(value, (int, float)):
        return float(value)
    str_val = str(value).strip().upper()
    if str_val in NULL_VALUES or str_val == "":
        return None
    # Strip display formatting: "%" suffixes and thousands separators
    # (e.g. "1,234") so formatted figures parse instead of becoming None.
    str_val = str_val.replace("%", "").replace(",", "")
    try:
        return float(str_val)
    except ValueError:
        return None
def extract_year_from_folder(folder_name: str) -> Optional[int]:
    """Derive the academic year from a data folder name.

    A range such as '2023-2024' yields the second year; a lone
    four-digit year yields that year; anything else yields None.
    """
    range_match = re.search(r"(\d{4})-(\d{4})", folder_name)
    if range_match is not None:
        return int(range_match.group(2))
    single_match = re.search(r"\d{4}", folder_name)
    return int(single_match.group(0)) if single_match else None
def geocode_postcodes_bulk(postcodes: list) -> Dict[str, tuple]:
    """
    Geocode postcodes in bulk using the postcodes.io API.

    Returns dict of postcode -> (latitude, longitude). Postcodes that are
    missing, too short, or fail to resolve are omitted. Network failures
    are logged and skipped (best-effort), never raised.
    """
    results = {}
    # Normalize, filter out non-strings/short values, and de-duplicate.
    valid_postcodes = [
        p.strip().upper()
        for p in postcodes
        if p and isinstance(p, str) and len(p.strip()) >= 5
    ]
    valid_postcodes = list(set(valid_postcodes))
    if not valid_postcodes:
        return results
    batch_size = 100  # postcodes.io bulk endpoint accepts up to 100 per request
    total_batches = (len(valid_postcodes) + batch_size - 1) // batch_size
    for i, batch_start in enumerate(range(0, len(valid_postcodes), batch_size)):
        batch = valid_postcodes[batch_start : batch_start + batch_size]
        print(
            f" Geocoding batch {i + 1}/{total_batches} ({len(batch)} postcodes)..."
        )
        try:
            response = requests.post(
                "https://api.postcodes.io/postcodes",
                json={"postcodes": batch},
                timeout=30,
            )
            if response.status_code == 200:
                data = response.json()
                for item in data.get("result", []):
                    if item and item.get("result"):
                        pc = item["query"].upper()
                        lat = item["result"].get("latitude")
                        lon = item["result"].get("longitude")
                        # Explicit None checks: a coordinate of exactly 0.0
                        # is falsy but still a valid latitude/longitude.
                        if lat is not None and lon is not None:
                            results[pc] = (lat, lon)
        except Exception as e:
            # Best-effort: one failed batch must not abort the migration.
            print(f" Warning: Geocoding batch failed: {e}")
    return results
def load_csv_data(data_dir: Path) -> pd.DataFrame:
    """Load all CSV data from data directory.

    Scans each year folder (e.g. '2023-2024') under ``data_dir`` for the
    KS2 results file ``england_ks2final.csv``, normalizes column names via
    COLUMN_MAPPINGS, derives local-authority and school-type columns, and
    concatenates everything into one DataFrame tagged with a ``year``
    column. Returns an empty DataFrame if nothing could be loaded.
    """
    all_data = []
    # Folders are processed in sorted order; the folder name encodes the year.
    for folder in sorted(data_dir.iterdir()):
        if not folder.is_dir():
            continue
        year = extract_year_from_folder(folder.name)
        if not year:
            continue
        # Specifically look for the KS2 results file
        ks2_file = folder / "england_ks2final.csv"
        if not ks2_file.exists():
            continue
        csv_file = ks2_file
        print(f" Loading {csv_file.name} (year {year})...")
        try:
            # latin-1 decoding (presumably the source files are not UTF-8 —
            # confirm against the raw downloads); low_memory=False avoids
            # chunked dtype guessing on wide files.
            df = pd.read_csv(csv_file, encoding="latin-1", low_memory=False)
        except Exception as e:
            # Skip unreadable files rather than aborting the whole load.
            print(f" Error loading {csv_file}: {e}")
            continue
        # Rename columns
        df.rename(columns=COLUMN_MAPPINGS, inplace=True)
        df["year"] = year
        # Handle local authority name: prefer an explicit LA-name column
        # (guarding against one already renamed to "local_authority"),
        # otherwise fall back to mapping the numeric LEA code to a name.
        la_name_cols = ["LANAME", "LA (name)", "LA_NAME", "LA NAME"]
        la_name_col = next((c for c in la_name_cols if c in df.columns), None)
        if la_name_col and la_name_col != "local_authority":
            df["local_authority"] = df[la_name_col]
        elif "LEA" in df.columns:
            df["local_authority_code"] = pd.to_numeric(df["LEA"], errors="coerce")
            # Unknown codes fall back to the raw LEA value as a string.
            df["local_authority"] = (
                df["local_authority_code"]
                .map(LA_CODE_TO_NAME)
                .fillna(df["LEA"].astype(str))
            )
        # Store LEA code
        if "LEA" in df.columns:
            df["local_authority_code"] = pd.to_numeric(df["LEA"], errors="coerce")
        # Map school type; codes with no mapping keep the raw code value.
        if "school_type_code" in df.columns:
            df["school_type"] = (
                df["school_type_code"]
                .map(SCHOOL_TYPE_MAP)
                .fillna(df["school_type_code"])
            )
        # Create combined address (ensure the parts exist so .get works uniformly)
        addr_parts = ["address1", "address2", "town", "postcode"]
        for col in addr_parts:
            if col not in df.columns:
                df[col] = None
        # Join only the non-empty, non-NaN parts with ", ".
        df["address"] = df.apply(
            lambda r: ", ".join(
                str(v)
                for v in [
                    r.get("address1"),
                    r.get("address2"),
                    r.get("town"),
                    r.get("postcode"),
                ]
                if pd.notna(v) and str(v).strip()
            ),
            axis=1,
        )
        all_data.append(df)
        print(f" Loaded {len(df)} records")
    if all_data:
        result = pd.concat(all_data, ignore_index=True)
        print(f"\nTotal records loaded: {len(result)}")
        print(f"Unique schools: {result['urn'].nunique()}")
        print(f"Years: {sorted(result['year'].unique())}")
        return result
    return pd.DataFrame()
def migrate_data(df: pd.DataFrame, geocode: bool = False, geocode_cache: dict = None):
    """Migrate DataFrame data to database.

    Creates one School row per unique URN (using the most recent year's
    metadata) and one SchoolResult row per (school, year) record.

    Args:
        df: Combined KS2 results as produced by load_csv_data().
        geocode: When True, resolve postcodes to coordinates via postcodes.io.
        geocode_cache: Optional mapping of URN -> (latitude, longitude)
            preserved from a previous import. Cached schools are not
            re-geocoded, and their coordinates are restored as-is.
    """
    if geocode_cache is None:
        geocode_cache = {}
    # Clean URN column - convert to integer, drop invalid values
    df = df.copy()
    df["urn"] = pd.to_numeric(df["urn"], errors="coerce")
    df = df.dropna(subset=["urn"])
    df["urn"] = df["urn"].astype(int)
    # Group by URN to get unique schools (use latest year's data)
    school_data = (
        df.sort_values("year", ascending=False).groupby("urn").first().reset_index()
    )
    print(f"\nMigrating {len(school_data)} unique schools...")
    # Geocode postcodes that aren't already covered by the URN cache.
    # NOTE: geocode_cache is keyed by URN while fresh API results are keyed
    # by postcode — keep the two dicts separate so lookups never mix key
    # types (merging them loses the restored coordinates entirely).
    geocoded = {}  # postcode -> (latitude, longitude), fetched this run
    if geocode and "postcode" in df.columns:
        # Postcodes of schools whose coordinates are already cached by URN.
        cached_postcodes = {
            str(row.get("postcode", "")).strip().upper()
            for _, row in school_data.iterrows()
            if int(float(str(row.get("urn", 0) or 0))) in geocode_cache
        }
        postcodes_needed = [
            p for p in df["postcode"].dropna().unique()
            if str(p).strip().upper() not in cached_postcodes
        ]
        if postcodes_needed:
            print(f"\nGeocoding {len(postcodes_needed)} postcodes ({len(geocode_cache)} restored from cache)...")
            fresh = geocode_postcodes_bulk(postcodes_needed)
            geocoded.update(fresh)
            print(f" Successfully geocoded {len(fresh)} new postcodes")
        else:
            print(f"\nAll {len(geocode_cache)} postcodes restored from cache, skipping geocoding.")
    with get_db_session() as db:
        # Create schools
        urn_to_school_id = {}
        schools_created = 0
        for _, row in school_data.iterrows():
            # Safely parse URN - handle None, NaN, whitespace, and invalid values
            urn_val = row.get("urn")
            urn = None
            if pd.notna(urn_val):
                try:
                    urn_str = str(urn_val).strip()
                    if urn_str:
                        urn = int(float(urn_str))  # Handle "12345.0" format
                except (ValueError, TypeError):
                    pass
            if not urn:
                continue
            # Skip if we've already added this URN (handles duplicates in source data)
            if urn in urn_to_school_id:
                continue
            # Resolve coordinates: prefer values restored from a previous
            # import (URN-keyed cache), then fall back to postcodes geocoded
            # in this run (postcode-keyed).
            postcode = row.get("postcode")
            lat, lon = None, None
            if urn in geocode_cache:
                lat, lon = geocode_cache[urn]
            elif postcode and pd.notna(postcode):
                coords = geocoded.get(str(postcode).strip().upper())
                if coords:
                    lat, lon = coords
            # Safely parse local_authority_code
            la_code = None
            la_code_val = row.get("local_authority_code")
            if pd.notna(la_code_val):
                try:
                    la_code_str = str(la_code_val).strip()
                    if la_code_str:
                        la_code = int(float(la_code_str))
                except (ValueError, TypeError):
                    pass
            school = School(
                urn=urn,
                school_name=row.get("school_name")
                if pd.notna(row.get("school_name"))
                else "Unknown",
                local_authority=row.get("local_authority")
                if pd.notna(row.get("local_authority"))
                else None,
                local_authority_code=la_code,
                school_type=row.get("school_type")
                if pd.notna(row.get("school_type"))
                else None,
                school_type_code=row.get("school_type_code")
                if pd.notna(row.get("school_type_code"))
                else None,
                religious_denomination=row.get("religious_denomination")
                if pd.notna(row.get("religious_denomination"))
                else None,
                age_range=row.get("age_range")
                if pd.notna(row.get("age_range"))
                else None,
                address1=row.get("address1") if pd.notna(row.get("address1")) else None,
                address2=row.get("address2") if pd.notna(row.get("address2")) else None,
                town=row.get("town") if pd.notna(row.get("town")) else None,
                postcode=row.get("postcode") if pd.notna(row.get("postcode")) else None,
                latitude=lat,
                longitude=lon,
            )
            db.add(school)
            db.flush()  # Get the ID
            urn_to_school_id[urn] = school.id
            schools_created += 1
            if schools_created % 1000 == 0:
                print(f" Created {schools_created} schools...")
        print(f" Created {schools_created} schools")
        # Create results
        print(f"\nMigrating {len(df)} yearly results...")
        results_created = 0
        for _, row in df.iterrows():
            # Safely parse URN
            urn_val = row.get("urn")
            urn = None
            if pd.notna(urn_val):
                try:
                    urn_str = str(urn_val).strip()
                    if urn_str:
                        urn = int(float(urn_str))
                except (ValueError, TypeError):
                    pass
            # A result row is only usable if it maps to a created school.
            if not urn or urn not in urn_to_school_id:
                continue
            school_id = urn_to_school_id[urn]
            # Safely parse year
            year_val = row.get("year")
            year = None
            if pd.notna(year_val):
                try:
                    year = int(float(str(year_val).strip()))
                except (ValueError, TypeError):
                    pass
            if not year:
                continue
            result = SchoolResult(
                school_id=school_id,
                year=year,
                total_pupils=parse_numeric(row.get("total_pupils")),
                eligible_pupils=parse_numeric(row.get("eligible_pupils")),
                # Expected Standard
                rwm_expected_pct=parse_numeric(row.get("rwm_expected_pct")),
                reading_expected_pct=parse_numeric(row.get("reading_expected_pct")),
                writing_expected_pct=parse_numeric(row.get("writing_expected_pct")),
                maths_expected_pct=parse_numeric(row.get("maths_expected_pct")),
                gps_expected_pct=parse_numeric(row.get("gps_expected_pct")),
                science_expected_pct=parse_numeric(row.get("science_expected_pct")),
                # Higher Standard
                rwm_high_pct=parse_numeric(row.get("rwm_high_pct")),
                reading_high_pct=parse_numeric(row.get("reading_high_pct")),
                writing_high_pct=parse_numeric(row.get("writing_high_pct")),
                maths_high_pct=parse_numeric(row.get("maths_high_pct")),
                gps_high_pct=parse_numeric(row.get("gps_high_pct")),
                # Progress
                reading_progress=parse_numeric(row.get("reading_progress")),
                writing_progress=parse_numeric(row.get("writing_progress")),
                maths_progress=parse_numeric(row.get("maths_progress")),
                # Averages
                reading_avg_score=parse_numeric(row.get("reading_avg_score")),
                maths_avg_score=parse_numeric(row.get("maths_avg_score")),
                gps_avg_score=parse_numeric(row.get("gps_avg_score")),
                # Context
                disadvantaged_pct=parse_numeric(row.get("disadvantaged_pct")),
                eal_pct=parse_numeric(row.get("eal_pct")),
                sen_support_pct=parse_numeric(row.get("sen_support_pct")),
                sen_ehcp_pct=parse_numeric(row.get("sen_ehcp_pct")),
                stability_pct=parse_numeric(row.get("stability_pct")),
                # Absence
                reading_absence_pct=parse_numeric(row.get("reading_absence_pct")),
                gps_absence_pct=parse_numeric(row.get("gps_absence_pct")),
                maths_absence_pct=parse_numeric(row.get("maths_absence_pct")),
                writing_absence_pct=parse_numeric(row.get("writing_absence_pct")),
                science_absence_pct=parse_numeric(row.get("science_absence_pct")),
                # Gender
                rwm_expected_boys_pct=parse_numeric(row.get("rwm_expected_boys_pct")),
                rwm_expected_girls_pct=parse_numeric(row.get("rwm_expected_girls_pct")),
                rwm_high_boys_pct=parse_numeric(row.get("rwm_high_boys_pct")),
                rwm_high_girls_pct=parse_numeric(row.get("rwm_high_girls_pct")),
                # Disadvantaged
                rwm_expected_disadvantaged_pct=parse_numeric(
                    row.get("rwm_expected_disadvantaged_pct")
                ),
                rwm_expected_non_disadvantaged_pct=parse_numeric(
                    row.get("rwm_expected_non_disadvantaged_pct")
                ),
                disadvantaged_gap=parse_numeric(row.get("disadvantaged_gap")),
                # 3-Year
                rwm_expected_3yr_pct=parse_numeric(row.get("rwm_expected_3yr_pct")),
                reading_avg_3yr=parse_numeric(row.get("reading_avg_3yr")),
                maths_avg_3yr=parse_numeric(row.get("maths_avg_3yr")),
            )
            db.add(result)
            results_created += 1
            # Periodic flush keeps the session's pending-object set bounded.
            if results_created % 10000 == 0:
                print(f" Created {results_created} results...")
                db.flush()
        print(f" Created {results_created} results")
        # Commit all changes
        db.commit()
    print("\nMigration complete!")
def run_full_migration(geocode: bool = False) -> bool:
    """
    Run a complete migration: drop the core KS2 tables and reimport from CSV.

    Returns True if successful, False if no data found.
    Raises exception on error.
    """
    # Normal local import instead of repeated __import__("sqlalchemy") calls.
    from sqlalchemy import inspect, text

    # Snapshot the table names once; nothing below changes the schema until
    # the drop/create step itself.
    existing_tables = set(inspect(engine).get_table_names())
    # Preserve existing geocoding so a reimport doesn't throw away coordinates
    # that took a long time to compute. Keyed by school URN.
    geocode_cache: dict[int, tuple[float, float]] = {}
    if "schools" in existing_tables:
        try:
            with get_db_session() as db:
                rows = db.execute(
                    text(
                        "SELECT urn, latitude, longitude FROM schools "
                        "WHERE latitude IS NOT NULL AND longitude IS NOT NULL"
                    )
                ).fetchall()
                geocode_cache = {r.urn: (r.latitude, r.longitude) for r in rows}
                print(f" Saved {len(geocode_cache)} existing geocoded coordinates.")
        except Exception as e:
            # Best-effort: losing the cache only means re-geocoding later.
            print(f" Warning: could not save geocode cache: {e}")
    # Only drop the core KS2 tables — leave supplementary tables (ofsted, census,
    # finance, etc.) intact so a reimport doesn't wipe integrator-populated data.
    # schema_version is NOT dropped: it persists so restarts don't re-trigger migration.
    ks2_tables = ["school_results", "schools"]
    print(f"Dropping core tables: {ks2_tables} ...")
    for tname in ks2_tables:
        if tname in existing_tables:
            Base.metadata.tables[tname].drop(bind=engine)
    print("Creating all tables...")
    Base.metadata.create_all(bind=engine)
    print("\nLoading CSV data...")
    df = load_csv_data(settings.data_dir)
    if df.empty:
        print("Warning: No CSV data found to migrate!")
        return False
    migrate_data(df, geocode=geocode, geocode_cache=geocode_cache)
    return True