from datetime import datetime, timedelta, timezone
from typing import Optional, List
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session, joinedload
import uuid

from app.config.settings import settings
from app.models.models import User, Role
from app.schemas.user import UserCreate, UserUpdate, TokenData, RoleCreate
from app.utils.cache import cache

# Password hashing context. pbkdf2_sha256 is used instead of bcrypt so that
# passwords are not subject to bcrypt's 72-byte input limit.
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")

# Earlier versions of this module hashed only the first 72 characters of a
# password. Kept so hashes created by that code still verify.
_LEGACY_PASSWORD_MAX_CHARS = 72


class UserService:
    """User, role and token service."""

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Check a plaintext password against a stored hash.

        Args:
            plain_password: Password supplied by the user.
            hashed_password: Hash stored for the user.

        Returns:
            True if the password matches, otherwise False.
        """
        if pwd_context.verify(plain_password, hashed_password):
            return True
        # Legacy hashes were computed over only the first 72 characters;
        # retry with the same truncation so those accounts can still log in.
        if len(plain_password) > _LEGACY_PASSWORD_MAX_CHARS:
            return pwd_context.verify(
                plain_password[:_LEGACY_PASSWORD_MAX_CHARS], hashed_password
            )
        return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Return the hash of the full password.

        The password is no longer truncated: pbkdf2_sha256 has no input
        length limit, and truncating here while verify_password checks the
        full value made passwords longer than 72 characters unverifiable.
        """
        return pwd_context.hash(password)

    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to encode; copied, never mutated.
            expires_delta: Token lifetime. When None, uses
                settings.ACCESS_TOKEN_EXPIRE_MINUTES. An explicit
                timedelta(0) is honoured rather than replaced by the default.

        Returns:
            The encoded JWT string.
        """
        to_encode = data.copy()
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        expire = UserService._utc_now() + expires_delta
        to_encode.update({"exp": expire})
        return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def logout_user(token: str) -> None:
        """Log a user out by adding the token to the blacklist.

        The blacklist entry is stored with a TTL equal to the token's
        remaining lifetime. Invalid or already-expired tokens are ignored
        on purpose (best effort): there is nothing to revoke.
        """
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            return
        exp = payload.get("exp")
        if not exp:
            return
        # Use an aware UTC clock: .timestamp() on a naive utcnow() value is
        # interpreted as local time and skews the TTL by the server's offset.
        remaining_time = int(exp) - int(UserService._utc_now().timestamp())
        if remaining_time > 0:
            cache.set(f"blacklist:{token}", "1", expire=remaining_time)

    @staticmethod
    def is_token_blacklisted(token: str) -> bool:
        """Return True if the token has been blacklisted (logged out)."""
        return bool(cache.exists(f"blacklist:{token}"))

    @staticmethod
    def get_current_user(db: Session, token: str) -> Optional[User]:
        """Resolve the user a token belongs to.

        Returns:
            The user (with role eagerly loaded), or None when the token is
            blacklisted, invalid/expired, lacks a "sub" claim, or refers to
            an unknown username.
        """
        if UserService.is_token_blacklisted(token):
            return None
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            return None
        username: Optional[str] = payload.get("sub")
        if username is None:
            return None
        token_data = TokenData(username=username, user_id=payload.get("user_id"))
        return UserService.get_user_by_username(db, username=token_data.username)

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Fetch a user by username, with the role eagerly loaded."""
        return db.query(User).options(joinedload(User.role)).filter(User.username == username).first()

    @staticmethod
    def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """Fetch a user by ID, with the role eagerly loaded."""
        return db.query(User).options(joinedload(User.role)).filter(User.id == user_id).first()

    @staticmethod
    def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """Fetch a user by email, with the role eagerly loaded (consistent with other lookups)."""
        return db.query(User).options(joinedload(User.role)).filter(User.email == email).first()

    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """Return a page of users (offset ``skip``, at most ``limit``), roles eagerly loaded."""
        return db.query(User).options(joinedload(User.role)).offset(skip).limit(limit).all()

    @staticmethod
    def create_user(db: Session, user: UserCreate) -> User:
        """Create, commit and return a new user with a hashed password."""
        # Short random ID: "user-" + first 8 hex chars of a UUID4.
        user_id = f"user-{uuid.uuid4().hex[:8]}"
        db_user = User(
            id=user_id,
            username=user.username,
            email=user.email,
            password_hash=UserService.get_password_hash(user.password),
            role_id=user.role_id,
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_update`` to a user.

        A supplied password is re-hashed into ``password_hash``. An explicit
        ``password=None`` is ignored rather than crashing the hasher.

        Returns:
            The updated user, or None if no user has ``user_id``.
        """
        db_user = UserService.get_user_by_id(db, user_id)
        if not db_user:
            return None

        update_data = user_update.dict(exclude_unset=True)

        if "password" in update_data:
            new_password = update_data.pop("password")
            if new_password is not None:
                update_data["password_hash"] = UserService.get_password_hash(new_password)

        for field, value in update_data.items():
            setattr(db_user, field, value)

        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
        """Return the user if the username/password pair is valid, else None."""
        user = UserService.get_user_by_username(db, username)
        if not user:
            # Spend comparable hashing time so response timing does not
            # reveal whether the username exists.
            pwd_context.dummy_verify()
            return None
        if not UserService.verify_password(password, user.password_hash):
            return None
        return user

    @staticmethod
    def create_role(db: Session, role: RoleCreate) -> Role:
        """Create, commit and return a new role."""
        # Short random ID: "role-" + first 8 hex chars of a UUID4.
        role_id = f"role-{uuid.uuid4().hex[:8]}"
        db_role = Role(
            id=role_id,
            name=role.name,
            description=role.description,
        )
        db.add(db_role)
        db.commit()
        db.refresh(db_role)
        return db_role

    @staticmethod
    def get_role_by_id(db: Session, role_id: str) -> Optional[Role]:
        """Fetch a role by ID."""
        return db.query(Role).filter(Role.id == role_id).first()

    @staticmethod
    def get_role_by_name(db: Session, role_name: str) -> Optional[Role]:
        """Fetch a role by name."""
        return db.query(Role).filter(Role.name == role_name).first()

    @staticmethod
    def get_roles(db: Session, skip: int = 0, limit: int = 100) -> List[Role]:
        """Return a page of roles (offset ``skip``, at most ``limit``)."""
        return db.query(Role).offset(skip).limit(limit).all()

    @staticmethod
    def init_default_roles(db: Session) -> None:
        """Create the default "admin" and "user" roles if they do not exist."""
        if not UserService.get_role_by_name(db, "admin"):
            UserService.create_role(
                db,
                RoleCreate(
                    name="admin",
                    description="系统管理员,拥有所有权限",
                ),
            )

        if not UserService.get_role_by_name(db, "user"):
            UserService.create_role(
                db,
                RoleCreate(
                    name="user",
                    description="普通用户,拥有基本权限",
                ),
            )