""" Database connection setup using SQLAlchemy. """ from datetime import datetime from typing import Optional from sqlalchemy import create_engine, inspect from sqlalchemy.orm import sessionmaker, declarative_base from contextlib import contextmanager from .config import settings # Create engine engine = create_engine( settings.database_url, pool_size=10, max_overflow=20, pool_pre_ping=True, # Verify connections before use echo=False, # Set to True for SQL debugging ) # Session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Base class for models Base = declarative_base() def get_db(): """ Dependency for FastAPI routes to get a database session. """ db = SessionLocal() try: yield db finally: db.close() @contextmanager def get_db_session(): """ Context manager for database sessions. Use in non-FastAPI contexts (scripts, etc). """ db = SessionLocal() try: yield db db.commit() except Exception: db.rollback() raise finally: db.close() def init_db(): """ Initialize database - create all tables. """ Base.metadata.create_all(bind=engine) def drop_db(): """ Drop all tables - use with caution! """ Base.metadata.drop_all(bind=engine) def get_db_schema_version() -> Optional[int]: """ Get the current schema version from the database. Returns None if table doesn't exist or no version is set. """ from .models import SchemaVersion # Import here to avoid circular imports # Check if schema_version table exists inspector = inspect(engine) if "schema_version" not in inspector.get_table_names(): return None try: with get_db_session() as db: row = db.query(SchemaVersion).first() return row.version if row else None except Exception: return None def set_db_schema_version(version: int): """ Set/update the schema version in the database. Creates the row if it doesn't exist. """ from .models import SchemaVersion with get_db_session() as db: row = db.query(SchemaVersion).first() if row: row.version = version row.migrated_at = datetime.utcnow() else: db.add(SchemaVersion(id=1, version=version, migrated_at=datetime.utcnow())) def check_and_migrate_if_needed(): """ Check schema version and run migration if needed. Called during application startup. """ from .version import SCHEMA_VERSION from .migration import run_full_migration db_version = get_db_schema_version() if db_version == SCHEMA_VERSION: print(f"Schema version {SCHEMA_VERSION} matches. Fast startup.") # Still ensure tables exist (they should if version matches) init_db() return if db_version is None: print(f"No schema version found. Running initial migration (v{SCHEMA_VERSION})...") else: print(f"Schema mismatch: DB has v{db_version}, code expects v{SCHEMA_VERSION}") print("Running full migration...") try: success = run_full_migration(geocode=False) if success: # Ensure schema_version table exists before writing init_db() set_db_schema_version(SCHEMA_VERSION) print(f"Migration complete. Schema version set to {SCHEMA_VERSION}") else: print("Warning: Migration completed but no data was imported.") init_db() set_db_schema_version(SCHEMA_VERSION) except Exception as e: print(f"FATAL: Migration failed: {e}") print("Application cannot start. Please check database and CSV files.") raise