from datetime import datetime, timedelta, timezone
from typing import Optional, List
import secrets
import uuid

from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from app.config.settings import settings
from app.models.models import User, APIKey
from app.schemas.user import UserCreate, UserUpdate, TokenData, APIKeyCreate
from app.utils.cache import cache

# Password hashing context. pbkdf2_sha256 is used instead of bcrypt because
# bcrypt only considers the first 72 bytes of its input; pbkdf2_sha256 has no
# such limit, so passwords can be hashed in full.
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")


class UserService:
    """User service: password hashing, JWT issuing/revocation, and user CRUD."""

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Check ``plain_password`` against a stored hash.

        Args:
            plain_password: Password as supplied by the caller.
            hashed_password: Hash previously produced by ``get_password_hash``.

        Returns:
            True if the password matches the hash, otherwise False.
        """
        return pwd_context.verify(plain_password, hashed_password)

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Hash ``password`` for storage.

        The full password is hashed. Truncating here while ``verify_password``
        checks the untruncated input would make long passwords impossible to
        log in with.

        NOTE(review): hashes written by an earlier version that truncated the
        input to 72 characters will not verify for passwords longer than that;
        such accounts need a password reset.

        Args:
            password: Plain-text password.

        Returns:
            The passlib hash string.
        """
        return pwd_context.hash(password)

    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to encode (e.g. ``sub``). The dict is copied, not mutated.
            expires_delta: Token lifetime. When None, the lifetime is
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

        Returns:
            The encoded JWT, signed with ``settings.SECRET_KEY`` using
            ``settings.ALGORITHM``.
        """
        to_encode = data.copy()
        # Explicit `is None`: a zero timedelta is falsy but is still an explicit
        # lifetime, not a request for the default.
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        # An aware UTC datetime is converted to a correct epoch "exp" by jose.
        to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
        return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def logout_user(token: str) -> None:
        """Log a user out by adding their token to the blacklist.

        The blacklist entry lives only as long as the token's remaining validity.
        Tokens that fail to decode (invalid signature, malformed, already
        expired) are ignored, because ``get_current_user`` rejects them anyway.

        Args:
            token: The raw JWT string.
        """
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            # Best-effort: nothing to revoke for a token that is not valid.
            return
        exp = payload.get("exp")
        if exp:
            # Use an aware "now": .timestamp() on a naive utcnow() would be
            # interpreted as local time and skew the result by the UTC offset.
            remaining_time = exp - int(datetime.now(timezone.utc).timestamp())
            if remaining_time > 0:
                # Blacklist entry expires together with the token itself.
                cache.set(f"blacklist:{token}", "1", expire=remaining_time)

    @staticmethod
    def is_token_blacklisted(token: str) -> bool:
        """Return True if ``token`` has been blacklisted by ``logout_user``."""
        # bool(): the cache backend's exists() may return a count, not a bool.
        return bool(cache.exists(f"blacklist:{token}"))

    @staticmethod
    def get_current_user(db: Session, token: str) -> Optional[User]:
        """Resolve the user a JWT access token belongs to.

        Args:
            db: Database session.
            token: The raw JWT string.

        Returns:
            The matching ``User``. Returns None if the token is blacklisted, fails
            to decode, has no ``sub`` claim, or names a user that does not exist.
        """
        try:
            # A logged-out token still decodes successfully until it expires,
            # so the blacklist must be consulted first.
            if UserService.is_token_blacklisted(token):
                return None
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
            username: Optional[str] = payload.get("sub")
            if username is None:
                return None
            token_data = TokenData(username=username, user_id=payload.get("user_id"))
        except JWTError:
            return None
        return UserService.get_user_by_username(db, username=token_data.username)

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Return the user with ``username``, or None."""
        return db.query(User).filter(User.username == username).first()

    @staticmethod
    def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """Return the user with primary key ``user_id``, or None."""
        return db.query(User).filter(User.id == user_id).first()

    @staticmethod
    def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """Return the user with ``email``, or None."""
        return db.query(User).filter(User.email == email).first()

    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """Return a page of users (``offset(skip).limit(limit)``)."""
        return db.query(User).offset(skip).limit(limit).all()

    @staticmethod
    def create_user(db: Session, user: UserCreate) -> User:
        """Create, persist, and return a new user.

        The password is hashed before storage. The session is committed and the
        instance refreshed from the database.
        """
        # NOTE(review): 8 hex chars is only 32 random bits; ID collisions (which
        # would fail the commit) become plausible as the user count grows.
        user_id = f"user-{uuid.uuid4().hex[:8]}"
        db_user = User(
            id=user_id,
            username=user.username,
            email=user.email,
            password_hash=UserService.get_password_hash(user.password),
            role=user.role,
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_update`` to a user.

        A supplied ``password`` is hashed into ``password_hash``. An explicit
        null password leaves the stored hash unchanged.

        Returns:
            The updated user, or None if no user has ``user_id``.
        """
        db_user = UserService.get_user_by_id(db, user_id)
        if not db_user:
            return None

        # Only fields the client actually sent are applied.
        update_data = user_update.dict(exclude_unset=True)

        # The model stores a hash, not the plain password.
        if "password" in update_data:
            new_password = update_data.pop("password")
            if new_password is not None:
                update_data["password_hash"] = UserService.get_password_hash(new_password)

        for field, value in update_data.items():
            setattr(db_user, field, value)

        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
        """Return the user if ``username``/``password`` are valid, else None."""
        user = UserService.get_user_by_username(db, username)
        if user is None:
            # Perform a dummy hash check so that response time does not reveal
            # whether the username exists.
            pwd_context.dummy_verify()
            return None
        if not UserService.verify_password(password, user.password_hash):
            return None
        return user


class APIKeyService:
    """API key service: issuing, lookup, revocation, and validation."""

    @staticmethod
    def create_api_key(db: Session, api_key_create: APIKeyCreate) -> APIKey:
        """Create, persist, and return a new API key for a user.

        The key value has the form ``sk_`` followed by 32 hex characters drawn
        from a cryptographically secure source.
        """
        api_key_id = f"key-{uuid.uuid4().hex[:8]}"
        api_key_value = f"sk_{secrets.token_hex(16)}"
        db_api_key = APIKey(
            id=api_key_id,
            user_id=api_key_create.user_id,
            name=api_key_create.name,
            key=api_key_value,
            expires_at=api_key_create.expires_at,
        )
        db.add(db_api_key)
        db.commit()
        db.refresh(db_api_key)
        return db_api_key

    @staticmethod
    def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[APIKey]:
        """Return the API key with primary key ``api_key_id``, or None."""
        return db.query(APIKey).filter(APIKey.id == api_key_id).first()

    @staticmethod
    def get_api_key_by_value(db: Session, api_key_value: str) -> Optional[APIKey]:
        """Return the API key whose secret value is ``api_key_value``, or None."""
        return db.query(APIKey).filter(APIKey.key == api_key_value).first()

    @staticmethod
    def get_api_keys_by_user_id(db: Session, user_id: str) -> List[APIKey]:
        """Return all API keys owned by ``user_id``."""
        return db.query(APIKey).filter(APIKey.user_id == user_id).all()

    @staticmethod
    def revoke_api_key(db: Session, api_key_id: str) -> Optional[APIKey]:
        """Mark an API key as ``"revoked"``.

        Returns:
            The updated key, or None if no key has ``api_key_id``.
        """
        db_api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
        if not db_api_key:
            return None
        db_api_key.status = "revoked"
        db.commit()
        db.refresh(db_api_key)
        return db_api_key

    @staticmethod
    def validate_api_key(db: Session, api_key_value: str) -> Optional[APIKey]:
        """Return the API key if it exists, is ``"active"``, and has not expired.

        A key whose ``expires_at`` is None is treated as non-expiring instead of
        raising TypeError on the comparison.

        Returns:
            The valid ``APIKey``, or None.
        """
        api_key = APIKeyService.get_api_key_by_value(db, api_key_value)
        if not api_key or api_key.status != "active":
            return None

        expires_at = api_key.expires_at
        if expires_at is not None:
            # Compare like with like: naive vs aware datetimes cannot be ordered.
            now = datetime.now(timezone.utc)
            if expires_at.utcoffset() is None:
                # NOTE(review): naive expiries are assumed to be UTC; confirm
                # against how expires_at is stored.
                now = now.replace(tzinfo=None)
            if expires_at < now:
                return None
        return api_key