Files
algorithm/backend/app/services/user.py
2026-02-08 20:06:35 +08:00

223 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, List

from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session, joinedload

from app.config.settings import settings
from app.models.models import User, Role
from app.schemas.user import UserCreate, UserUpdate, TokenData, RoleCreate
from app.utils.cache import cache
# Password-hashing context. The pbkdf2_sha256 scheme is used to avoid
# bcrypt's password length limit.
pwd_context = CryptContext(
    schemes=["pbkdf2_sha256"],
    deprecated="auto",
)
class UserService:
    """Static helpers for password hashing, JWT access tokens, token
    blacklisting, and user/role persistence."""

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True when ``plain_password`` matches ``hashed_password``."""
        return pwd_context.verify(plain_password, hashed_password)

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Return the hash of ``password`` produced by ``pwd_context``.

        The password is hashed in full. It must not be truncated here:
        ``verify_password`` checks the untruncated input, so hashing only a
        prefix would make long passwords impossible to verify.
        """
        return pwd_context.hash(password)

    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed. The dict is copied, not mutated.
            expires_delta: Token lifetime. When omitted (or falsy),
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes are used.

        Returns:
            The encoded token string.
        """
        lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        claims = data.copy()
        # Timezone-aware UTC "now"; replaces the deprecated datetime.utcnow().
        claims["exp"] = datetime.now(timezone.utc) + lifetime
        return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def logout_user(token: str) -> None:
        """Log a user out by blacklisting ``token`` for its remaining lifetime.

        A token that fails to decode is ignored: ``get_current_user`` rejects
        such a token as well, so there is nothing to revoke.
        """
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            return

        exp = payload.get("exp")
        if not exp:
            return

        # Use an aware UTC clock. Calling .timestamp() on a naive utcnow()
        # value treats it as local time and skews the result by the server's
        # UTC offset.
        now = int(datetime.now(timezone.utc).timestamp())
        remaining_time = exp - now
        if remaining_time > 0:
            # The blacklist entry expires together with the token itself.
            cache.set(f"blacklist:{token}", "1", expire=remaining_time)

    @staticmethod
    def is_token_blacklisted(token: str) -> bool:
        """Return True when ``token`` has a blacklist entry in the cache."""
        return bool(cache.exists(f"blacklist:{token}"))

    @staticmethod
    def get_current_user(db: Session, token: str) -> Optional[User]:
        """Resolve the user a token belongs to.

        Returns:
            The matching ``User``, or ``None`` when the token is blacklisted,
            cannot be decoded, has no ``sub`` claim, or names an unknown user.
        """
        try:
            # A logged-out token is rejected before it is decoded.
            if UserService.is_token_blacklisted(token):
                return None
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
            username: str = payload.get("sub")
            if username is None:
                return None
            token_data = TokenData(username=username, user_id=payload.get("user_id"))
        except JWTError:
            return None
        return UserService.get_user_by_username(db, username=token_data.username)

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Fetch a user (with its role eagerly loaded) by username."""
        query = db.query(User).options(joinedload(User.role))
        return query.filter(User.username == username).first()

    @staticmethod
    def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """Fetch a user (with its role eagerly loaded) by primary key."""
        query = db.query(User).options(joinedload(User.role))
        return query.filter(User.id == user_id).first()

    @staticmethod
    def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """Fetch a user by e-mail address."""
        return db.query(User).filter(User.email == email).first()

    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """Return up to ``limit`` users, skipping the first ``skip`` rows."""
        query = db.query(User).options(joinedload(User.role))
        return query.offset(skip).limit(limit).all()

    @staticmethod
    def _commit_and_refresh(db: Session, instance):
        """Commit the session and reload ``instance``.

        If the commit fails, the session is rolled back so it remains usable,
        and the original exception is re-raised unchanged.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
        db.refresh(instance)
        return instance

    @staticmethod
    def create_user(db: Session, user: UserCreate) -> User:
        """Create and persist a new user.

        The id is ``user-`` followed by the first 8 hex digits of a random
        UUID; the password is stored only as a hash.
        """
        db_user = User(
            id=f"user-{uuid.uuid4().hex[:8]}",
            username=user.username,
            email=user.email,
            password_hash=UserService.get_password_hash(user.password),
            role_id=user.role_id,
        )
        db.add(db_user)
        return UserService._commit_and_refresh(db, db_user)

    @staticmethod
    def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_update`` to a user.

        Returns:
            The updated user, or ``None`` when no user has ``user_id``.
        """
        db_user = UserService.get_user_by_id(db, user_id)
        if not db_user:
            return None

        changes = user_update.dict(exclude_unset=True)
        # A new plain-text password is never stored; it is replaced by its hash.
        if "password" in changes:
            changes["password_hash"] = UserService.get_password_hash(changes.pop("password"))

        for field, value in changes.items():
            setattr(db_user, field, value)
        return UserService._commit_and_refresh(db, db_user)

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
        """Return the user when username and password match, else ``None``."""
        user = UserService.get_user_by_username(db, username)
        if user and UserService.verify_password(password, user.password_hash):
            return user
        return None

    @staticmethod
    def create_role(db: Session, role: RoleCreate) -> Role:
        """Create and persist a new role.

        The id is ``role-`` followed by the first 8 hex digits of a random UUID.
        """
        db_role = Role(
            id=f"role-{uuid.uuid4().hex[:8]}",
            name=role.name,
            description=role.description,
        )
        db.add(db_role)
        return UserService._commit_and_refresh(db, db_role)

    @staticmethod
    def get_role_by_id(db: Session, role_id: str) -> Optional[Role]:
        """Fetch a role by primary key."""
        return db.query(Role).filter(Role.id == role_id).first()

    @staticmethod
    def get_role_by_name(db: Session, role_name: str) -> Optional[Role]:
        """Fetch a role by name."""
        return db.query(Role).filter(Role.name == role_name).first()

    @staticmethod
    def get_roles(db: Session, skip: int = 0, limit: int = 100) -> List[Role]:
        """Return up to ``limit`` roles, skipping the first ``skip`` rows."""
        return db.query(Role).offset(skip).limit(limit).all()

    @staticmethod
    def init_default_roles(db: Session) -> None:
        """Create the default ``admin`` and ``user`` roles if they are missing.

        Roles that already exist are left untouched.
        """
        default_roles = (
            ("admin", "系统管理员,拥有所有权限"),
            ("user", "普通用户,拥有基本权限"),
        )
        for name, description in default_roles:
            if not UserService.get_role_by_name(db, name):
                UserService.create_role(db, RoleCreate(name=name, description=description))