first commit

This commit is contained in:
2026-02-08 14:42:58 +08:00
commit 20e1deae21
8197 changed files with 2264639 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
from datetime import datetime, timedelta
from typing import Optional, List
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
import uuid
from app.config.settings import settings
from app.models.models import User, APIKey
from app.schemas.user import UserCreate, UserUpdate, TokenData, APIKeyCreate
from app.utils.cache import cache
# 密码加密上下文使用pbkdf2_sha256方案避免bcrypt的密码长度限制
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
class UserService:
"""用户服务类"""
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
@staticmethod
def get_password_hash(password: str) -> str:
"""获取密码哈希值"""
# 确保密码长度不超过72个字节
password = password[:72]
return pwd_context.hash(password)
@staticmethod
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
@staticmethod
def logout_user(token: str) -> None:
"""用户登出,将令牌加入黑名单"""
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
exp = payload.get("exp")
if exp:
# 计算令牌剩余有效期
remaining_time = exp - int(datetime.utcnow().timestamp())
if remaining_time > 0:
# 将令牌加入黑名单,设置与令牌剩余有效期相同的过期时间
cache.set(f"blacklist:{token}", "1", expire=remaining_time)
except JWTError:
pass
@staticmethod
def is_token_blacklisted(token: str) -> bool:
"""检查令牌是否在黑名单中"""
return cache.exists(f"blacklist:{token}")
@staticmethod
def get_current_user(db: Session, token: str) -> Optional[User]:
"""获取当前用户"""
try:
# 检查令牌是否在黑名单中
if UserService.is_token_blacklisted(token):
return None
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
username: str = payload.get("sub")
if username is None:
return None
token_data = TokenData(username=username, user_id=payload.get("user_id"))
except JWTError:
return None
user = UserService.get_user_by_username(db, username=token_data.username)
if user is None:
return None
return user
@staticmethod
def get_user_by_username(db: Session, username: str) -> Optional[User]:
"""通过用户名获取用户"""
return db.query(User).filter(User.username == username).first()
@staticmethod
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
"""通过ID获取用户"""
return db.query(User).filter(User.id == user_id).first()
@staticmethod
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""通过邮箱获取用户"""
return db.query(User).filter(User.email == email).first()
@staticmethod
def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
"""获取用户列表"""
return db.query(User).offset(skip).limit(limit).all()
@staticmethod
def create_user(db: Session, user: UserCreate) -> User:
"""创建用户"""
# 生成唯一ID
user_id = f"user-{uuid.uuid4().hex[:8]}"
# 创建用户实例
db_user = User(
id=user_id,
username=user.username,
email=user.email,
password_hash=UserService.get_password_hash(user.password),
role=user.role
)
# 保存到数据库
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
@staticmethod
def update_user(db: Session, user_id: str, user_update: UserUpdate) -> Optional[User]:
"""更新用户"""
# 获取用户
db_user = UserService.get_user_by_id(db, user_id)
if not db_user:
return None
# 更新用户信息
update_data = user_update.dict(exclude_unset=True)
# 如果更新密码,需要重新哈希
if "password" in update_data:
update_data["password_hash"] = UserService.get_password_hash(update_data.pop("password"))
# 应用更新
for field, value in update_data.items():
setattr(db_user, field, value)
# 保存到数据库
db.commit()
db.refresh(db_user)
return db_user
@staticmethod
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""认证用户"""
user = UserService.get_user_by_username(db, username)
if not user:
return None
if not UserService.verify_password(password, user.password_hash):
return None
return user
class APIKeyService:
"""API密钥服务类"""
@staticmethod
def create_api_key(db: Session, api_key_create: APIKeyCreate) -> APIKey:
"""创建API密钥"""
# 生成唯一ID和密钥
api_key_id = f"key-{uuid.uuid4().hex[:8]}"
api_key_value = f"sk_{uuid.uuid4().hex}"
# 创建API密钥实例
db_api_key = APIKey(
id=api_key_id,
user_id=api_key_create.user_id,
name=api_key_create.name,
key=api_key_value,
expires_at=api_key_create.expires_at
)
# 保存到数据库
db.add(db_api_key)
db.commit()
db.refresh(db_api_key)
return db_api_key
@staticmethod
def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[APIKey]:
"""通过ID获取API密钥"""
return db.query(APIKey).filter(APIKey.id == api_key_id).first()
@staticmethod
def get_api_key_by_value(db: Session, api_key_value: str) -> Optional[APIKey]:
"""通过值获取API密钥"""
return db.query(APIKey).filter(APIKey.key == api_key_value).first()
@staticmethod
def get_api_keys_by_user_id(db: Session, user_id: str) -> List[APIKey]:
"""通过用户ID获取API密钥列表"""
return db.query(APIKey).filter(APIKey.user_id == user_id).all()
@staticmethod
def revoke_api_key(db: Session, api_key_id: str) -> Optional[APIKey]:
"""撤销API密钥"""
# 获取API密钥
db_api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
if not db_api_key:
return None
# 更新状态为已撤销
db_api_key.status = "revoked"
# 保存到数据库
db.commit()
db.refresh(db_api_key)
return db_api_key
@staticmethod
def validate_api_key(db: Session, api_key_value: str) -> Optional[APIKey]:
"""验证API密钥"""
# 获取API密钥
api_key = APIKeyService.get_api_key_by_value(db, api_key_value)
if not api_key:
return None
# 检查状态
if api_key.status != "active":
return None
# 检查是否过期
if api_key.expires_at < datetime.utcnow():
return None
return api_key