223 lines
7.6 KiB
Python
223 lines
7.6 KiB
Python
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session, joinedload

from app.config.settings import settings
from app.models.models import Role, User
from app.schemas.user import RoleCreate, TokenData, UserCreate, UserUpdate
from app.utils.cache import cache

# Password hashing context shared by the whole module.
# pbkdf2_sha256 is used instead of bcrypt to avoid bcrypt's password-length limit.
# deprecated="auto" marks every scheme other than the default as deprecated.
pwd_context = CryptContext(
    schemes=["pbkdf2_sha256"],
    deprecated="auto",
)


class UserService:
    """User service: password hashing, JWT access tokens, token blacklisting,
    and user/role persistence helpers."""

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``.

        A stored hash that passlib cannot parse or identify is reported as a
        mismatch (False) rather than raised, so a corrupt or legacy hash makes
        login fail cleanly instead of producing a server error.
        """
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib raises ValueError for malformed / unidentifiable hashes;
            # such a hash can never match any password.
            return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Return the hash of ``password`` using the module's ``pwd_context``.

        The password is hashed in full. It must NOT be truncated here:
        ``verify_password`` checks the full password, so hashing only a prefix
        makes every password longer than the prefix fail to verify. The
        pbkdf2_sha256 scheme has no 72-byte input limit (that limit is bcrypt's).
        """
        return pwd_context.hash(password)

    @staticmethod
    def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed; the dict is copied, not modified.
            expires_delta: Token lifetime. ``None`` means use
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit zero
                lifetime is honoured rather than replaced by the default.

        Returns:
            The encoded token, signed with ``settings.SECRET_KEY`` using
            ``settings.ALGORITHM``.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

        to_encode = data.copy()
        # Timezone-aware UTC avoids the deprecated, naive datetime.utcnow().
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
        return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    @staticmethod
    def logout_user(token: str) -> None:
        """Log a user out by adding ``token`` to the blacklist.

        The blacklist entry is stored with a cache expiry equal to the token's
        remaining lifetime, so it disappears once the token would have expired
        anyway. Tokens that fail to decode (invalid signature, already expired,
        malformed) are ignored: ``get_current_user`` rejects them regardless.
        """
        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            return

        exp = payload.get("exp")
        if not exp:
            return

        # ``exp`` is a POSIX timestamp. Compare it with an aware UTC "now":
        # datetime.utcnow().timestamp() would interpret the naive UTC value as
        # local time and skew the result by the host's UTC offset.
        remaining_time = exp - int(datetime.now(timezone.utc).timestamp())
        if remaining_time > 0:
            cache.set(f"blacklist:{token}", "1", expire=remaining_time)

    @staticmethod
    def is_token_blacklisted(token: str) -> bool:
        """Return True if ``token`` has been blacklisted by ``logout_user``."""
        # bool() normalises backends whose ``exists`` returns a count.
        return bool(cache.exists(f"blacklist:{token}"))

    @staticmethod
    def get_current_user(db: Session, token: str) -> Optional[User]:
        """Resolve the user that owns an access token.

        Returns None when the token is blacklisted, fails JWT verification,
        has no ``sub`` claim, or names a user that no longer exists.
        """
        if UserService.is_token_blacklisted(token):
            return None

        try:
            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        except JWTError:
            return None

        username = payload.get("sub")
        if username is None:
            return None

        # Validate the claims through the TokenData schema before using them.
        token_data = TokenData(username=username, user_id=payload.get("user_id"))
        return UserService.get_user_by_username(db, username=token_data.username)

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Return the user with ``username`` (role eager-loaded), or None."""
        return db.query(User).options(joinedload(User.role)).filter(User.username == username).first()

    @staticmethod
    def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """Return the user with primary key ``user_id`` (role eager-loaded), or None."""
        return db.query(User).options(joinedload(User.role)).filter(User.id == user_id).first()

    @staticmethod
    def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """Return the user with ``email`` (role eager-loaded, like the other
        user getters), or None."""
        return db.query(User).options(joinedload(User.role)).filter(User.email == email).first()

    @staticmethod
    def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
        """Return up to ``limit`` users after skipping ``skip`` (roles eager-loaded)."""
        return db.query(User).options(joinedload(User.role)).offset(skip).limit(limit).all()

    @staticmethod
    def _commit_and_refresh(db: Session, instance: object) -> None:
        """Commit ``db`` and reload ``instance``; roll back if the commit fails.

        A failed flush/commit leaves a SQLAlchemy session unusable until it is
        rolled back, so roll back here before re-raising to the caller.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
        db.refresh(instance)

    @staticmethod
    def create_user(db: Session, user: UserCreate) -> User:
        """Create, persist and return a new user with a hashed password."""
        # NOTE(review): only 8 hex chars (32 bits) of the UUID are kept, so ID
        # collisions become plausible at large user counts; presumably a
        # collision surfaces as a primary-key error on commit. Widening it
        # depends on the ``id`` column size - confirm before changing.
        user_id = f"user-{uuid.uuid4().hex[:8]}"

        db_user = User(
            id=user_id,
            username=user.username,
            email=user.email,
            password_hash=UserService.get_password_hash(user.password),
            role_id=user.role_id,
        )

        db.add(db_user)
        UserService._commit_and_refresh(db, db_user)
        return db_user

    @staticmethod
    def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
        """Apply the explicitly-set fields of ``user_update`` to a user.

        A supplied password is re-hashed into ``password_hash``; an explicit
        null password is ignored rather than hashed (which would crash).

        Returns:
            The updated user, or None if no user has ``user_id``.
        """
        db_user = UserService.get_user_by_id(db, user_id)
        if db_user is None:
            return None

        # Only fields the client actually sent.
        update_data = user_update.dict(exclude_unset=True)

        # Never write the plaintext password onto the model.
        new_password = update_data.pop("password", None)
        if new_password is not None:
            update_data["password_hash"] = UserService.get_password_hash(new_password)

        for field, value in update_data.items():
            setattr(db_user, field, value)

        UserService._commit_and_refresh(db, db_user)
        return db_user

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
        """Return the user if ``username``/``password`` are valid, else None."""
        user = UserService.get_user_by_username(db, username)
        if user is None:
            # Spend comparable hashing time for unknown usernames so response
            # timing does not reveal which usernames exist.
            pwd_context.dummy_verify()
            return None
        if not UserService.verify_password(password, user.password_hash):
            return None
        return user

    @staticmethod
    def create_role(db: Session, role: RoleCreate) -> Role:
        """Create, persist and return a new role."""
        # Same 32-bit short-ID scheme as create_user.
        role_id = f"role-{uuid.uuid4().hex[:8]}"

        db_role = Role(
            id=role_id,
            name=role.name,
            description=role.description,
        )

        db.add(db_role)
        UserService._commit_and_refresh(db, db_role)
        return db_role

    @staticmethod
    def get_role_by_id(db: Session, role_id: str) -> Optional[Role]:
        """Return the role with primary key ``role_id``, or None."""
        return db.query(Role).filter(Role.id == role_id).first()

    @staticmethod
    def get_role_by_name(db: Session, role_name: str) -> Optional[Role]:
        """Return the role named ``role_name``, or None."""
        return db.query(Role).filter(Role.name == role_name).first()

    @staticmethod
    def get_roles(db: Session, skip: int = 0, limit: int = 100) -> List[Role]:
        """Return up to ``limit`` roles after skipping ``skip``."""
        return db.query(Role).offset(skip).limit(limit).all()

    @staticmethod
    def init_default_roles(db: Session) -> None:
        """Create the default "admin" and "user" roles if they do not exist."""
        default_roles = (
            ("admin", "系统管理员,拥有所有权限"),
            ("user", "普通用户,拥有基本权限"),
        )
        for role_name, role_description in default_roles:
            if UserService.get_role_by_name(db, role_name) is None:
                UserService.create_role(
                    db, RoleCreate(name=role_name, description=role_description)
                )