Files
SupervisorAI/repositories/video_check_repository.py

105 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
视频检查任务数据访问层
"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any

from sqlalchemy import and_
from sqlalchemy.orm import Session

from models.video_check_task import SurVideoCheckTask, SurVideo, SurConfigBase
from config import settings
logger = logging.getLogger(__name__)


class VideoCheckTaskRepository:
    """Data-access object for video check tasks, videos and group configs.

    All queries run on the SQLAlchemy ``Session`` passed to the constructor.
    Read methods only query; ``update_task_status`` commits on success and
    rolls back on failure. The repository never closes the session.
    """

    # Task status codes used by the queries below (per the original
    # docstrings: status=0 means pending, status=1 means processing).
    STATUS_PENDING = 0
    STATUS_PROCESSING = 1

    def __init__(self, session: Session):
        """Bind the repository to an existing SQLAlchemy session.

        Args:
            session: Session used for every query and commit.
        """
        self.session = session

    def get_pending_tasks(self) -> List[SurVideoCheckTask]:
        """Return all tasks whose status is STATUS_PENDING (0)."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.status == self.STATUS_PENDING
        ).all()

    def get_by_id(self, task_id: int) -> Optional[SurVideoCheckTask]:
        """Return the task whose ``id`` equals ``task_id``, or None if absent."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.id == task_id
        ).first()

    def update_task_status(self, task_id: int, status: int,
                           start_time=None, finish_time=None,
                           progress=None, result=None, result_data=None,
                           feature_data=None) -> bool:
        """Set a task's status plus any other supplied fields, then commit.

        ``status`` is always written. Each optional field is written only
        when it is not None; passing None leaves the stored value unchanged.

        Args:
            task_id: ``id`` of the task to update.
            status: New status code (e.g. STATUS_PENDING, STATUS_PROCESSING).
            start_time, finish_time, progress, result, result_data,
            feature_data: Optional column values to overwrite.

        Returns:
            True if the task exists and the commit succeeded. False if no
            task has that id, or if any exception occurred. In the exception
            case the error is logged and the session is rolled back.
        """
        try:
            task = self.get_by_id(task_id)
            if task is None:
                return False
            task.status = status
            optional_fields = {
                "start_time": start_time,
                "finish_time": finish_time,
                "progress": progress,
                "result": result,
                "result_data": result_data,
                "feature_data": feature_data,
            }
            for column, value in optional_fields.items():
                # `is not None` (not truthiness) for every field, so falsy but
                # meaningful values (e.g. progress=0) are still persisted.
                if value is not None:
                    setattr(task, column, value)
            self.session.commit()
            return True
        except Exception:
            # Callers rely on the bool result, so keep returning False, but
            # record the cause instead of silently swallowing it.
            logger.exception(
                "Failed to update status of video check task %s", task_id
            )
            self.session.rollback()
            return False

    def get_config_by_group_id(self, group_id: int, *,
                               config_type=None) -> Dict[str, str]:
        """Return the ``config_key -> config_value`` mapping for a group.

        Args:
            group_id: Config group to look up.
            config_type: Config type to filter on. Defaults to
                ``settings.SUR_CONFIG_TYPE_FACE``, the only type this method
                previously supported.

        Returns:
            Mapping of config key to value; empty if no rows match. If several
            rows share a key, the last one returned by the query wins.
        """
        if config_type is None:
            config_type = settings.SUR_CONFIG_TYPE_FACE
        configs = self.session.query(SurConfigBase).filter(
            and_(
                SurConfigBase.config_type == config_type,
                SurConfigBase.group_id == group_id,
            )
        ).all()
        return {config.config_key: config.config_value for config in configs}

    def get_video_by_id(self, video_id: int) -> Optional[SurVideo]:
        """Return the video whose ``id`` equals ``video_id``, or None."""
        return self.session.query(SurVideo).filter(
            SurVideo.id == video_id
        ).first()

    def get_videos_by_ids(self, video_ids: List[int]) -> List[SurVideo]:
        """Return all videos whose ``id`` is in ``video_ids``.

        Returns an empty list without querying when ``video_ids`` is empty.
        """
        if not video_ids:
            # An empty IN () can never match; skip the database round trip
            # (some SQLAlchemy versions also warn about empty IN predicates).
            return []
        return self.session.query(SurVideo).filter(
            SurVideo.id.in_(video_ids)
        ).all()

    def get_processing_tasks(self) -> List[SurVideoCheckTask]:
        """Return all tasks whose status is STATUS_PROCESSING (1)."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.status == self.STATUS_PROCESSING
        ).all()

    def get_timeout_tasks(self, timeout_hours: float) -> List[SurVideoCheckTask]:
        """Return processing tasks that started more than ``timeout_hours`` ago.

        Args:
            timeout_hours: Age threshold in hours; fractional values allowed.

        Returns:
            Tasks with status STATUS_PROCESSING and
            ``start_time < now - timeout_hours``.
        """
        # NOTE(review): compares against naive local time (datetime.now());
        # confirm start_time is stored as naive local time too.
        # NOTE(review): rows whose start_time is NULL never satisfy `<`, so a
        # processing task that was never given a start_time is not reported
        # here — confirm that is intended.
        timeout_time = datetime.now() - timedelta(hours=timeout_hours)
        return self.session.query(SurVideoCheckTask).filter(
            and_(
                SurVideoCheckTask.status == self.STATUS_PROCESSING,
                SurVideoCheckTask.start_time < timeout_time,
            )
        ).all()