Program/mooc/crud/crud_base.py
?..濡.. 8451ad034c 1.统一CRUD操作
2.完成登录部分接口
3.暂时挂载本地图片链接作为头像存储
2025-03-04 20:36:52 +08:00

132 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import Column
from mooc.db.database import Base
# Type parameters for CRUDBase: the ORM model it manages, plus the pydantic
# schemas accepted by create() and update() respectively.
ModelType = TypeVar("ModelType", bound="Base")  # model class bound to the project's Base
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)  # payload for create()
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)  # payload for update()


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic create/read/update/delete helpers for a single SQLAlchemy model.

    The only state held is ``self.model``; every method receives the
    ``Session`` it should operate on.
    """

    def __init__(self, model: Type[ModelType]):
        """
        CRUD base class.

        Args:
            model: SQLAlchemy model class that all queries target.
        """
        self.model = model

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _lookup_column(self) -> Any:
        """Return the model attribute that ``get`` compares the key against."""
        # Original precedence: the ``id`` attribute, otherwise ``acid``.
        for name in ("id", "acid"):
            if hasattr(self.model, name):
                return getattr(self.model, name)
        # Neither naming convention is present: fall back to the mapped
        # primary key instead of failing with an opaque AttributeError.
        from sqlalchemy import inspect as sa_inspect

        primary_key = sa_inspect(self.model).primary_key
        if len(primary_key) != 1:
            raise AttributeError(
                f"{self.model.__name__} has no 'id' or 'acid' attribute and no "
                "single-column primary key to look records up by"
            )
        return primary_key[0]

    def _filtered_query(self, db: Session, filters: Dict[str, Any]):
        """Build a query on ``self.model`` AND-ing ``field == value`` per filter item."""
        query = db.query(self.model)
        for field, value in filters.items():
            # getattr raises AttributeError for a field the model does not define.
            query = query.filter(getattr(self.model, field) == value)
        return query

    @staticmethod
    def _commit(db: Session, db_obj: Any) -> None:
        """Commit ``db`` and refresh ``db_obj``; on any failure roll back and re-raise."""
        try:
            db.commit()
            db.refresh(db_obj)
        except Exception:
            # Roll back so the session can be reused after a failed commit.
            db.rollback()
            raise

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """
        Fetch a single record by key.

        The key is compared against the model's ``id`` attribute if it has
        one, otherwise ``acid``, otherwise its single-column primary key.

        Args:
            db: database session.
            id: key value to match.

        Returns:
            The first matching record, or ``None``.
        """
        return db.query(self.model).filter(self._lookup_column() == id).first()

    def get_by_field(
        self,
        db: Session,
        field: str,
        value: Any
    ) -> Optional[ModelType]:
        """
        Fetch a single record where ``field == value``.

        Args:
            db: database session.
            field: model attribute name.
            value: value the attribute must equal.

        Returns:
            The first matching record, or ``None``.
        """
        return self._filtered_query(db, {field: value}).first()

    def get_by_fields(
        self,
        db: Session,
        filters: Dict[str, Any]
    ) -> Optional[ModelType]:
        """
        Fetch a single record matching every ``field == value`` pair (AND).

        Args:
            db: database session.
            filters: mapping of attribute name to required value,
                e.g. ``{"field1": value1, "field2": value2}``.

        Returns:
            The first matching record, or ``None``.
        """
        return self._filtered_query(db, filters).first()

    def get_multi_by_fields(
        self,
        db: Session,
        filters: Dict[str, Any],
        *,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelType]:
        """
        Fetch records matching every ``field == value`` pair (AND), paginated.

        Args:
            db: database session.
            filters: mapping of attribute name to required value,
                e.g. ``{"field1": value1, "field2": value2}``.
            skip: number of records to skip.
            limit: maximum number of records to return.
        """
        return self._filtered_query(db, filters).offset(skip).limit(limit).all()

    def get_multi(
        self, db: Session, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """
        Fetch up to ``limit`` records after skipping ``skip``.

        No ORDER BY is applied, so page contents follow whatever order the
        database returns.
        """
        return db.query(self.model).offset(skip).limit(limit).all()

    # ------------------------------------------------------------------
    # Write
    # ------------------------------------------------------------------

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """
        Insert a row built from ``obj_in``, commit, and return the refreshed instance.

        Args:
            db: database session.
            obj_in: schema whose encoded fields become model constructor kwargs.

        Raises:
            Whatever ``commit``/``refresh`` raises, after the session is rolled back.
        """
        # NOTE(review): jsonable_encoder converts values to JSON-compatible
        # types (e.g. datetime -> ISO string) before they reach the model;
        # kept for compatibility with existing callers.
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)
        db.add(db_obj)
        self._commit(db, db_obj)
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]]
    ) -> ModelType:
        """
        Apply ``obj_in`` to ``db_obj``, commit, and return the refreshed instance.

        Args:
            db: database session.
            db_obj: persistent instance to modify.
            obj_in: either a plain dict of new values, or a schema of which
                only explicitly-set fields are applied.

        Keys that are not mapped attributes of the model are ignored.

        Raises:
            Whatever ``commit``/``refresh`` raises, after the session is rolled back.
        """
        from sqlalchemy import inspect as sa_inspect

        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            # Pydantic v2 renamed ``dict()`` to ``model_dump()``; support both.
            dump = getattr(obj_in, "model_dump", None) or obj_in.dict
            update_data = dump(exclude_unset=True)

        # Take assignable names from the mapper, not from the instance's
        # loaded state: expired/unloaded attributes are missing from the
        # instance __dict__ and would otherwise be silently skipped.
        mapped_fields = set(sa_inspect(db_obj).mapper.attrs.keys())
        for field, value in update_data.items():
            if field in mapped_fields:
                setattr(db_obj, field, value)
        db.add(db_obj)
        self._commit(db, db_obj)
        return db_obj

    def delete(self, db: Session, *, id: Any) -> Optional[ModelType]:
        """
        Delete the record whose primary key is ``id`` and commit.

        Args:
            db: database session.
            id: primary-key value.

        Returns:
            The deleted instance, or ``None`` if no record matched (nothing
            is committed in that case).

        Raises:
            Whatever ``delete``/``commit`` raises, after the session is rolled back.
        """
        # Session.get is the supported replacement for the legacy Query.get;
        # both look the row up by primary key.
        obj = db.get(self.model, id)
        if obj is None:
            return None
        try:
            db.delete(obj)
            db.commit()
        except Exception:
            db.rollback()
            raise
        return obj