first commit
This commit is contained in:
238
backend/app/services/user.py
Normal file
238
backend/app/services/user.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.models import User, APIKey
|
||||
from app.schemas.user import UserCreate, UserUpdate, TokenData, APIKeyCreate
|
||||
from app.utils.cache import cache
|
||||
|
||||
# 密码加密上下文,使用pbkdf2_sha256方案,避免bcrypt的密码长度限制
|
||||
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户服务类"""
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
@staticmethod
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希值"""
|
||||
# 确保密码长度不超过72个字节
|
||||
password = password[:72]
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def logout_user(token: str) -> None:
|
||||
"""用户登出,将令牌加入黑名单"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
exp = payload.get("exp")
|
||||
if exp:
|
||||
# 计算令牌剩余有效期
|
||||
remaining_time = exp - int(datetime.utcnow().timestamp())
|
||||
if remaining_time > 0:
|
||||
# 将令牌加入黑名单,设置与令牌剩余有效期相同的过期时间
|
||||
cache.set(f"blacklist:{token}", "1", expire=remaining_time)
|
||||
except JWTError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def is_token_blacklisted(token: str) -> bool:
|
||||
"""检查令牌是否在黑名单中"""
|
||||
return cache.exists(f"blacklist:{token}")
|
||||
|
||||
@staticmethod
|
||||
def get_current_user(db: Session, token: str) -> Optional[User]:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 检查令牌是否在黑名单中
|
||||
if UserService.is_token_blacklisted(token):
|
||||
return None
|
||||
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
return None
|
||||
token_data = TokenData(username=username, user_id=payload.get("user_id"))
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
user = UserService.get_user_by_username(db, username=token_data.username)
|
||||
if user is None:
|
||||
return None
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
|
||||
"""通过用户名获取用户"""
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
|
||||
"""通过ID获取用户"""
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
"""通过邮箱获取用户"""
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
@staticmethod
|
||||
def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""获取用户列表"""
|
||||
return db.query(User).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user: UserCreate) -> User:
|
||||
"""创建用户"""
|
||||
# 生成唯一ID
|
||||
user_id = f"user-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建用户实例
|
||||
db_user = User(
|
||||
id=user_id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
password_hash=UserService.get_password_hash(user.password),
|
||||
role=user.role
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
@staticmethod
|
||||
def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
|
||||
"""更新用户"""
|
||||
# 获取用户
|
||||
db_user = UserService.get_user_by_id(db, user_id)
|
||||
if not db_user:
|
||||
return None
|
||||
|
||||
# 更新用户信息
|
||||
update_data = user_update.dict(exclude_unset=True)
|
||||
|
||||
# 如果更新密码,需要重新哈希
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = UserService.get_password_hash(update_data.pop("password"))
|
||||
|
||||
# 应用更新
|
||||
for field, value in update_data.items():
|
||||
setattr(db_user, field, value)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
return db_user
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
|
||||
"""认证用户"""
|
||||
user = UserService.get_user_by_username(db, username)
|
||||
if not user:
|
||||
return None
|
||||
if not UserService.verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
class APIKeyService:
|
||||
"""API密钥服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(db: Session, api_key_create: APIKeyCreate) -> APIKey:
|
||||
"""创建API密钥"""
|
||||
# 生成唯一ID和密钥
|
||||
api_key_id = f"key-{uuid.uuid4().hex[:8]}"
|
||||
api_key_value = f"sk_{uuid.uuid4().hex}"
|
||||
|
||||
# 创建API密钥实例
|
||||
db_api_key = APIKey(
|
||||
id=api_key_id,
|
||||
user_id=api_key_create.user_id,
|
||||
name=api_key_create.name,
|
||||
key=api_key_value,
|
||||
expires_at=api_key_create.expires_at
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_api_key)
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
|
||||
return db_api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""通过ID获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.id == api_key_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_value(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""通过值获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.key == api_key_value).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_keys_by_user_id(db: Session, user_id: str) -> List[APIKey]:
|
||||
"""通过用户ID获取API密钥列表"""
|
||||
return db.query(APIKey).filter(APIKey.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def revoke_api_key(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""撤销API密钥"""
|
||||
# 获取API密钥
|
||||
db_api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
|
||||
if not db_api_key:
|
||||
return None
|
||||
|
||||
# 更新状态为已撤销
|
||||
db_api_key.status = "revoked"
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
|
||||
return db_api_key
|
||||
|
||||
@staticmethod
|
||||
def validate_api_key(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""验证API密钥"""
|
||||
# 获取API密钥
|
||||
api_key = APIKeyService.get_api_key_by_value(db, api_key_value)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# 检查状态
|
||||
if api_key.status != "active":
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if api_key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
return api_key
|
||||
Reference in New Issue
Block a user