kangda_robotic_dog/机器狗后台服务/app/crud/base.py

88 lines
3.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from app.core.database import Base
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)

# ``Generic[ModelType, CreateSchemaType, UpdateSchemaType]`` (used by CRUDBase
# below) declares a generic class parameterized by three type variables:
#   - ModelType:        the SQLAlchemy model class (bound to the declarative ``Base``)
#   - CreateSchemaType: the Pydantic schema accepted by ``create``
#   - UpdateSchemaType: the Pydantic schema accepted by ``update``
# Subclasses bind concrete types, for example
# ``CRUDBase[SomeModel, SomeCreateSchema, SomeUpdateSchema]``. That way the
# CRUD logic is written once and reused, while the signatures of methods such
# as get / create / update stay type-consistent for each model.
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic async CRUD helper bound to one SQLAlchemy model class.

    Parameterize it with concrete types, e.g.
    ``CRUDBase[SomeModel, SomeCreateSchema, SomeUpdateSchema]``, to obtain
    ``get`` / ``get_multi`` / ``create`` / ``update`` / ``remove`` for that model.

    Every method takes the ``AsyncSession`` explicitly. ``create``, ``update``
    and ``remove`` commit the session before returning. The model is expected
    to expose an ``id`` attribute (used by ``get``, ``get_multi`` and ``remove``).
    """

    def __init__(self, model: Type[ModelType]):
        """Create a CRUD object that operates on *model*.

        :param model: the SQLAlchemy model class.
        """
        self.model = model

    # Fetch a single object by id
    async def get(self, db: AsyncSession, id: Any) -> Optional[ModelType]:
        """Return the object whose ``id`` equals *id*, or ``None`` if absent.

        ``scalar_one_or_none`` raises if more than one row matches.
        """
        query = select(self.model).where(self.model.id == id)
        result = await db.execute(query)
        return result.scalar_one_or_none()

    # Paginated listing
    async def get_multi(
        self, db: AsyncSession, *, skip: int = 0, limit: int = 100
    ) -> List[ModelType]:
        """Return up to *limit* objects after skipping *skip*, ordered by ``id``.

        OFFSET/LIMIT without ORDER BY returns an unspecified subset of rows, so
        consecutive pages could overlap or miss rows. Ordering by ``id`` makes
        the pagination deterministic.
        """
        query = (
            select(self.model)
            .order_by(self.model.id)
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(query)
        return list(result.scalars().all())

    async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a new object built from *obj_in*, commit, and return it refreshed.

        :param obj_in: the create schema; its fields become model constructor
            keyword arguments.
        """
        # NOTE(review): jsonable_encoder converts values such as datetime/UUID/
        # Decimal into JSON-compatible str/float; confirm the model columns
        # accept these converted values.
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)
        db.add(db_obj)
        await db.commit()
        # Reload the row from the database so database-generated values are visible.
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]]
    ) -> ModelType:
        """Apply *obj_in* to *db_obj*, commit, and return the refreshed object.

        :param db_obj: the persistent model instance to modify.
        :param obj_in: either a plain dict of new values, or an update schema of
            which only explicitly set fields are applied (``exclude_unset=True``).
            Keys that are not mapped attributes of the model are ignored.
        """
        from sqlalchemy import inspect as sa_inspect

        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
            dump = getattr(obj_in, "model_dump", None) or obj_in.dict
            update_data = dump(exclude_unset=True)

        # Take the updatable field names from the mapper rather than serializing
        # db_obj: the instance __dict__ only holds currently *loaded* attributes,
        # so expired ones (e.g. after a commit with expire_on_commit=True) would
        # otherwise be silently skipped.
        mapped_fields = set(sa_inspect(db_obj).mapper.attrs.keys())
        for field, value in update_data.items():
            if field in mapped_fields:
                setattr(db_obj, field, value)

        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def remove(self, db: AsyncSession, *, id: Any) -> Optional[ModelType]:
        """Delete the object with the given *id*, commit, and return it.

        Returns ``None`` (mirroring ``get``) without touching the session when
        no such object exists, instead of passing ``None`` to ``db.delete``.
        """
        obj = await self.get(db=db, id=id)
        if obj is None:
            return None
        await db.delete(obj)
        await db.commit()
        return obj