Program/mooc/db/database.py

91 lines
2.6 KiB
Python
Raw Normal View History

2025-01-04 01:13:47 +08:00
from sqlalchemy import create_engine, inspect
2024-12-31 22:27:04 +08:00
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from mooc.core.config import settings
2025-01-04 01:13:47 +08:00
from typing import Generator, Set
import importlib
import pkgutil
from pathlib import Path
import logging
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Application-wide engine built from settings: pooled connections are
# pre-pinged before use, and SQL echoing is controlled by SQLALCHEMY_ECHO.
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URI,
    echo=settings.SQLALCHEMY_ECHO,
    pool_pre_ping=True,
)

# Session factory bound to `engine`; autoflush and autocommit are both off,
# so flushing/committing is left to the code using the session.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base for the ORM models; model classes register their tables
# on Base.metadata.
Base = declarative_base()
def get_existing_tables() -> Set[str]:
    """Return the set of table names the database currently reports.

    Runs SQLAlchemy's inspector against the module-level ``engine`` and calls
    ``get_table_names()`` without a schema argument.
    """
    table_names = inspect(engine).get_table_names()
    return set(table_names)
def import_models() -> None:
    """Import every top-level module of ``mooc.models`` and verify the models.

    Importing each model module is what registers its model classes on
    ``Base.metadata``; once all modules are imported, ``verify_all_models()``
    from the ``mooc.models`` package is run.

    Raises:
        Any exception raised while importing a model module or by
        ``verify_all_models()`` propagates unchanged.
    """
    # Resolve the package first and enumerate its own ``__path__`` with a
    # dotted prefix, so the directory scanned and the module names imported
    # cannot drift apart (and namespace / zipped packages also work).
    models_pkg = importlib.import_module("mooc.models")
    prefix = f"{models_pkg.__name__}."
    for module_info in pkgutil.iter_modules(models_pkg.__path__, prefix):
        importlib.import_module(module_info.name)

    # Validate immediately after everything has been imported.
    from mooc.models import verify_all_models

    verify_all_models()
def create_missing_tables() -> None:
    """Create the tables registered on ``Base.metadata`` that the database lacks.

    A table counts as missing when its metadata key is not among the names
    returned by ``get_existing_tables()``. Only CREATE TABLE is issued; tables
    that already exist are left as they are.
    """
    existing_tables = get_existing_tables()
    missing_tables = [
        table
        for key, table in Base.metadata.tables.items()
        if key not in existing_tables
    ]

    if not missing_tables:
        logger.info("All tables already exist")
        return

    logger.info(
        "Creating missing tables: %s",
        sorted(table.key for table in missing_tables),
    )
    # create_all() orders the given tables by foreign-key dependency, so a new
    # table that references another new table is created after it (creating
    # them one by one in arbitrary set order can fail on the REFERENCES).
    # checkfirst=True additionally skips tables that exist but were not listed
    # by the inspector, e.g. schema-qualified keys outside the default schema.
    Base.metadata.create_all(bind=engine, tables=missing_tables, checkfirst=True)
def init_db() -> None:
    """Initialise the database.

    Steps:
        1. Import and verify all model modules (``import_models``).
        2. Create any registered tables that are missing
           (``create_missing_tables``).

    Raises:
        Re-raises any exception from the steps above after logging it,
        including its traceback.
    """
    try:
        # Make sure every model is imported and verified.
        import_models()
        logger.info("All models imported successfully")

        # Create whatever tables are missing.
        create_missing_tables()
        logger.info("Database initialization completed successfully")

        # Debug aid: log every table name the models package reports.
        from mooc.models import get_all_table_names

        logger.debug("Registered tables: %s", get_all_table_names())
    except Exception as exc:
        # logger.exception records the traceback, which a plain error message
        # built from str(exc) would discard.
        logger.exception("Database initialization failed: %s", exc)
        raise
def get_db() -> Generator:
    """
    Dependency that yields one database session and always closes it.

    A new session is obtained from ``SessionLocal`` and yielded to the caller;
    it is closed when the generator finishes or is closed.

    Raises:
        Whatever ``SessionLocal()`` raises if a session cannot be created.
    """
    # Create the session before entering ``try``: if the factory itself
    # raises, there is nothing to close, and a ``finally`` touching an unbound
    # ``db`` would mask the real error with UnboundLocalError.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()