91 lines
2.6 KiB
Python
91 lines
2.6 KiB
Python
from sqlalchemy import create_engine, inspect
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker
|
|
from mooc.core.config import settings
|
|
from typing import Generator, Set
|
|
import importlib
|
|
import pkgutil
|
|
from pathlib import Path
|
|
import logging
|
|
|
|
# Module-level logger (no handlers or levels are configured here).
logger = logging.getLogger(__name__)

# Shared SQLAlchemy engine built from application settings.
# pool_pre_ping makes the pool test each connection on checkout, so stale
# connections are transparently replaced; echo toggles SQL statement logging.
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URI,
    echo=settings.SQLALCHEMY_ECHO,
    pool_pre_ping=True,
)

# Session factory bound to the engine above; sessions neither autocommit
# nor autoflush.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class; model classes subclass it so their tables are
# registered on Base.metadata.
Base = declarative_base()
|
|
|
|
def get_existing_tables() -> Set[str]:
    """Return the names of the tables that currently exist in the database.

    A fresh SQLAlchemy inspector is created on the module-level ``engine``
    for each call, so the live schema is queried every time. Only the
    default schema is listed (``get_table_names`` is called without a
    ``schema`` argument).
    """
    table_names = inspect(engine).get_table_names()
    return set(table_names)
|
|
|
|
def import_models() -> None:
    """Import every module of the ``mooc.models`` package, then verify them.

    Importing each model module executes its class definitions, which is what
    registers the model classes on ``Base.metadata``. Once all modules are
    imported, ``mooc.models.verify_all_models`` is invoked.
    """
    models_dir = Path(__file__).parent.parent / "models"

    # pkgutil lists only the directory's direct children; subpackages are
    # imported as packages but not recursed into.
    discovered = (info.name for info in pkgutil.iter_modules([str(models_dir)]))
    for module_name in discovered:
        importlib.import_module("mooc.models." + module_name)

    # Validate the models immediately after importing them.
    from mooc.models import verify_all_models

    verify_all_models()
|
|
|
|
def create_missing_tables() -> None:
    """Create tables declared on ``Base.metadata`` that are absent from the database.

    Tables that already exist are left untouched; no ALTER or migration is
    performed. All missing tables are created in a single ``create_all``
    call so SQLAlchemy orders them by foreign-key dependency. Creating them
    one by one in set-iteration order could attempt to create a child table
    before the table it references.
    """
    existing_tables = get_existing_tables()
    declared_tables = Base.metadata.tables
    missing_tables = set(declared_tables) - existing_tables

    if not missing_tables:
        logger.info("All tables already exist")
        return

    logger.info("Creating missing tables: %s", missing_tables)
    # create_all() sorts the given tables by dependency before issuing CREATE.
    # checkfirst=True skips any table that turns out to exist already, e.g. a
    # schema-qualified metadata key that get_table_names() reports without the
    # schema prefix.
    Base.metadata.create_all(
        bind=engine,
        tables=[declared_tables[name] for name in missing_tables],
        checkfirst=True,
    )
|
|
|
|
def init_db() -> None:
    """Initialize the database.

    1. Import all model modules and verify them (``import_models``).
    2. Create any tables that do not exist yet (``create_missing_tables``).

    Raises:
        Exception: any error raised by the steps above is logged together
            with its traceback and then re-raised unchanged.
    """
    try:
        # Make sure every model is imported (thus registered) and verified.
        import_models()
        logger.info("All models imported successfully")

        # Create only the tables that are missing.
        create_missing_tables()
        logger.info("Database initialization completed successfully")

        # Log all registered table names (debugging aid).
        from mooc.models import get_all_table_names

        logger.debug("Registered tables: %s", get_all_table_names())
    except Exception as e:
        # logger.exception also records the traceback, which logger.error
        # would drop; the original exception still propagates to the caller.
        logger.exception("Database initialization failed: %s", e)
        raise
|
|
|
|
def get_db() -> Generator:
    """Dependency that yields a database session and closes it afterwards.

    The caller receives one session from the generator. The session is
    closed in ``finally`` when the generator is resumed, closed, or
    garbage-collected, including when the consumer raises.
    """
    # Create the session before entering ``try``: if SessionLocal() raised
    # inside the try, ``finally`` would reference an unbound ``db`` and an
    # UnboundLocalError would mask the real error.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()