130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
from typing import List, Optional, Dict, Any
|
|
from sqlalchemy import select, and_, or_, join
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
from app.crud.base import CRUDBase
|
|
from app.models.models import Event, Image, Temperature
|
|
from app.schemas.event import EventUpdate, EventQuery, TestEvent
|
|
|
|
class CRUDEvent(CRUDBase[Event, EventUpdate, EventUpdate]):
    """CRUD operations for ``Event`` rows and their related images/temperatures."""

    async def get_by_id(self, db: AsyncSession, *, event_id: str) -> Optional[Event]:
        """Fetch a single event by its ``eventId``.

        The ``images`` and ``temperatures`` relationships are eagerly loaded
        with ``selectinload`` as part of this query.

        Args:
            db: Async database session.
            event_id: Value matched against ``Event.eventId``.

        Returns:
            The matching event, or ``None`` if no row matches.
        """
        stmt = (
            select(Event)
            .options(
                selectinload(Event.images),
                selectinload(Event.temperatures),
            )
            .where(Event.eventId == event_id)
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_multi_with_query(
        self,
        db: AsyncSession,
        *,
        query: EventQuery,
    ) -> List[Event]:
        """Return a page of events filtered by the optional criteria in ``query``.

        Each filter is applied only when the corresponding field of ``query``
        is truthy:
          - ``start_time`` -> ``Event.insDate >= start_time``
          - ``end_time``   -> ``Event.insDate <= end_time``
          - ``etypeName``  -> exact match on ``Event.etypeName``
          - ``area``       -> exact match on ``Event.area``
        Pagination uses ``query.skip`` as the offset and ``query.limit`` as
        the page size. ``images`` and ``temperatures`` are eagerly loaded.

        Args:
            db: Async database session.
            query: Filter and pagination parameters.

        Returns:
            The matching events, newest ``insDate`` first.
        """
        conditions = []
        if query.start_time:
            conditions.append(Event.insDate >= query.start_time)
        if query.end_time:
            conditions.append(Event.insDate <= query.end_time)
        if query.etypeName:
            conditions.append(Event.etypeName == query.etypeName)
        if query.area:
            conditions.append(Event.area == query.area)

        stmt = (
            select(Event)
            .options(
                selectinload(Event.images),
                selectinload(Event.temperatures),
            )
            # OFFSET/LIMIT without ORDER BY leaves row order up to the
            # database, so pages could overlap or skip rows. Sort
            # deterministically: newest first, with eventId as a tie-breaker.
            .order_by(Event.insDate.desc(), Event.eventId)
        )
        if conditions:
            stmt = stmt.where(and_(*conditions))
        stmt = stmt.offset(query.skip).limit(query.limit)

        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def update_event(
        self,
        db: AsyncSession,
        *,
        event_id: str,
        obj_in: EventUpdate,
    ) -> Optional[Event]:
        """Apply the fields supplied in ``obj_in`` to an event and commit.

        Only fields the caller explicitly set on ``obj_in`` are written.
        Fields left at their schema default are not copied, so a partial
        update does not overwrite stored values with defaults.

        Args:
            db: Async database session.
            event_id: ``eventId`` of the event to update.
            obj_in: Update payload.

        Returns:
            The refreshed event, or ``None`` if no event matches ``event_id``.
        """
        event = await self.get_by_id(db, event_id=event_id)
        if not event:
            return None

        update_data = obj_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(event, field, value)

        # ``event`` was loaded through this same session, so it is already
        # tracked and no ``db.add`` is required before committing.
        await db.commit()
        # NOTE(review): refresh() reloads column attributes; confirm that the
        # images/temperatures relationships are still usable afterwards if
        # callers serialize them (depends on the model's loader config).
        await db.refresh(event)
        return event

    async def delete_event(
        self,
        db: AsyncSession,
        *,
        event_id: str,
    ) -> Optional[Event]:
        """Delete the event identified by ``event_id`` and commit.

        Whether related images/temperatures are also removed depends on the
        relationship cascade settings in the models (not visible here).

        Args:
            db: Async database session.
            event_id: ``eventId`` of the event to delete.

        Returns:
            The deleted event instance (as loaded before deletion), or
            ``None`` if no event matches ``event_id``.
        """
        event = await self.get_by_id(db, event_id=event_id)
        if not event:
            return None

        await db.delete(event)
        await db.commit()
        return event

    async def get_test(
        self,
        db: AsyncSession,
        *,
        etype_name: str = "日常巡检",
    ) -> List[TestEvent]:
        """Return flattened event/image/temperature rows for one event type.

        Builds ``Event LEFT OUTER JOIN Image LEFT OUTER JOIN Temperature`` and
        converts each joined row into a ``TestEvent`` with the fields:
        eventId, number, name, imageUrl, localPath, temperature,
        confidence, createTime.

        Because both joins are outer joins, an event without images (or an
        image without temperature readings) still yields a row. In that row
        the image/temperature fields are ``None``.
        NOTE(review): confirm those ``TestEvent`` fields are Optional.

        Args:
            db: Async database session.
            etype_name: Only events whose ``etypeName`` equals this value are
                returned. The default is the event type that was previously
                hard-coded in this query.

        Returns:
            One ``TestEvent`` per joined row. The return annotation must match
            the schema the rows are converted into.
        """
        stmt = (
            select(
                Event.eventId,
                Event.number,
                Event.name,
                Image.imageUrl,
                Image.localPath,
                Temperature.temperature,
                Temperature.confidence,
                Temperature.createTime,
            )
            .select_from(Event)
            .outerjoin(Image, Event.eventId == Image.eventId)
            .outerjoin(Temperature, Image.imageId == Temperature.imageId)
            .where(Event.etypeName == etype_name)
        )

        result = await db.execute(stmt)
        # mappings() yields dict-like rows keyed by the selected column names,
        # which unpack directly into the TestEvent schema.
        rows = result.mappings().all()
        return [TestEvent(**row) for row in rows]
|
|
|
|
|
|
# Module-level CRUDEvent instance bound to the Event model, shared by callers
# that import this module.
event = CRUDEvent(Event)