105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
"""
Data access layer for video check tasks.
"""
|
||
|
||
from typing import List, Optional, Dict, Any
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import and_
|
||
|
||
from models.video_check_task import SurVideoCheckTask, SurVideo, SurConfigBase
|
||
from config import settings
|
||
|
||
|
||
class VideoCheckTaskRepository:
    """Data access class for video check tasks.

    Wraps a SQLAlchemy ``Session`` and exposes the queries and the single
    update operation used on ``SurVideoCheckTask``, ``SurVideo`` and
    ``SurConfigBase`` rows.
    """

    # Task status values used by the queries below.
    _STATUS_PENDING = 0
    _STATUS_PROCESSING = 1

    def __init__(self, session: Session):
        """Store the session that every query of this repository runs on."""
        self.session = session

    def get_pending_tasks(self) -> List[SurVideoCheckTask]:
        """Return all tasks waiting to be processed (status == 0)."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.status == self._STATUS_PENDING
        ).all()

    def get_by_id(self, task_id: int) -> Optional[SurVideoCheckTask]:
        """Return the task with the given ID, or ``None`` if no row matches."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.id == task_id
        ).first()

    def update_task_status(self, task_id: int, status: int,
                           start_time=None, finish_time=None,
                           progress=None, result=None, result_data=None,
                           feature_data=None) -> bool:
        """Update a task's status and, optionally, its other fields.

        ``status`` is always written. Each of the remaining keyword
        arguments is written to the attribute of the same name only when it
        is not ``None``; ``None`` means "leave the stored value untouched".

        Returns:
            ``True`` when the change was committed. ``False`` when the task
            does not exist, or when any exception was raised; in the latter
            case the session is rolled back and the exception is logged.
        """
        # Function-scope import keeps this class self-contained, in the same
        # way get_timeout_tasks imports datetime locally.
        import logging

        try:
            task = self.get_by_id(task_id)
            if not task:
                return False

            task.status = status

            optional_fields = {
                "start_time": start_time,
                "finish_time": finish_time,
                "progress": progress,
                "result": result,
                "result_data": result_data,
                "feature_data": feature_data,
            }
            # Use one rule for all optional fields: only ``None`` means
            # "not provided", so falsy values such as 0 are still stored.
            for field_name, value in optional_fields.items():
                if value is not None:
                    setattr(task, field_name, value)

            self.session.commit()
            return True
        except Exception:
            # Best-effort contract: callers get False instead of an
            # exception, but the failure must not disappear silently.
            self.session.rollback()
            logging.getLogger(__name__).exception(
                "Failed to update video check task %s", task_id
            )
            return False

    def get_config_by_group_id(self, group_id: int) -> Dict[str, str]:
        """Return the configuration of a group as a key-to-value mapping.

        Only rows whose ``config_type`` equals
        ``settings.SUR_CONFIG_TYPE_FACE`` and whose ``group_id`` matches are
        included. If several rows share a ``config_key``, the last one
        returned by the query wins.
        """
        configs = self.session.query(SurConfigBase).filter(
            and_(
                SurConfigBase.config_type == settings.SUR_CONFIG_TYPE_FACE,
                SurConfigBase.group_id == group_id,
            )
        ).all()

        return {config.config_key: config.config_value for config in configs}

    def get_video_by_id(self, video_id: int) -> Optional[SurVideo]:
        """Return the video with the given ID, or ``None`` if no row matches."""
        return self.session.query(SurVideo).filter(
            SurVideo.id == video_id
        ).first()

    def get_videos_by_ids(self, video_ids: List[int]) -> List[SurVideo]:
        """Return the videos whose ID is contained in ``video_ids``.

        An empty (or ``None``) ID list yields an empty result without
        querying the database.
        """
        if not video_ids:
            # Nothing can match; skip the pointless empty IN () query.
            return []

        return self.session.query(SurVideo).filter(
            SurVideo.id.in_(video_ids)
        ).all()

    def get_processing_tasks(self) -> List[SurVideoCheckTask]:
        """Return all tasks currently being processed (status == 1)."""
        return self.session.query(SurVideoCheckTask).filter(
            SurVideoCheckTask.status == self._STATUS_PROCESSING
        ).all()

    def get_timeout_tasks(self, timeout_hours: int) -> List[SurVideoCheckTask]:
        """Return processing tasks that started more than ``timeout_hours`` ago.

        A task is selected when its status is 1 and its ``start_time`` is
        earlier than the current time minus ``timeout_hours`` hours.
        """
        from datetime import datetime, timedelta

        # NOTE(review): datetime.now() is naive local time; this presumably
        # matches how start_time is stored - confirm against the writer.
        timeout_time = datetime.now() - timedelta(hours=timeout_hours)

        return self.session.query(SurVideoCheckTask).filter(
            and_(
                SurVideoCheckTask.status == self._STATUS_PROCESSING,
                SurVideoCheckTask.start_time < timeout_time,
            )
        ).all()