仓库模块完了
This commit is contained in:
Binary file not shown.
@@ -8,7 +8,7 @@ import logging
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.models import User, Algorithm, APIKey
|
||||
from app.models.models import User, Algorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -209,26 +209,7 @@ class PermissionManager:
|
||||
logger.error(f"Error getting algorithm permissions: {str(e)}")
|
||||
return []
|
||||
|
||||
def check_api_key_access(self, db: Session, api_key_value: str, algorithm_id: str) -> bool:
|
||||
"""检查API密钥对算法的访问权限"""
|
||||
try:
|
||||
# 通过API密钥查找用户
|
||||
api_key = db.query(APIKey).filter(
|
||||
APIKey.key == api_key_value,
|
||||
APIKey.status == "active"
|
||||
).first()
|
||||
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
# 检查用户对算法的访问权限
|
||||
return self.check_algorithm_access(
|
||||
db, api_key.user_id, algorithm_id, PermissionType.EXECUTE
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key access: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_user_algorithm_operation(self, db: Session, user_id: str, algorithm_id: str,
|
||||
operation: str) -> bool:
|
||||
|
||||
@@ -2,12 +2,12 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
import uuid
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.models.models import User, APIKey
|
||||
from app.schemas.user import UserCreate, UserUpdate, TokenData, APIKeyCreate
|
||||
from app.models.models import User, Role
|
||||
from app.schemas.user import UserCreate, UserUpdate, TokenData, RoleCreate
|
||||
from app.utils.cache import cache
|
||||
|
||||
# 密码加密上下文,使用pbkdf2_sha256方案,避免bcrypt的密码长度限制
|
||||
@@ -86,12 +86,12 @@ class UserService:
|
||||
@staticmethod
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
|
||||
"""通过用户名获取用户"""
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
return db.query(User).options(joinedload(User.role)).filter(User.username == username).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
|
||||
"""通过ID获取用户"""
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
return db.query(User).options(joinedload(User.role)).filter(User.id == user_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
@@ -101,7 +101,7 @@ class UserService:
|
||||
@staticmethod
|
||||
def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""获取用户列表"""
|
||||
return db.query(User).offset(skip).limit(limit).all()
|
||||
return db.query(User).options(joinedload(User.role)).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user: UserCreate) -> User:
|
||||
@@ -115,7 +115,7 @@ class UserService:
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
password_hash=UserService.get_password_hash(user.password),
|
||||
role=user.role
|
||||
role_id=user.role_id
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
@@ -159,80 +159,64 @@ class UserService:
|
||||
if not UserService.verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
class APIKeyService:
|
||||
"""API密钥服务类"""
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(db: Session, api_key_create: APIKeyCreate) -> APIKey:
|
||||
"""创建API密钥"""
|
||||
# 生成唯一ID和密钥
|
||||
api_key_id = f"key-{uuid.uuid4().hex[:8]}"
|
||||
api_key_value = f"sk_{uuid.uuid4().hex}"
|
||||
def create_role(db: Session, role: RoleCreate) -> Role:
|
||||
"""创建角色"""
|
||||
# 生成唯一ID
|
||||
role_id = f"role-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 创建API密钥实例
|
||||
db_api_key = APIKey(
|
||||
id=api_key_id,
|
||||
user_id=api_key_create.user_id,
|
||||
name=api_key_create.name,
|
||||
key=api_key_value,
|
||||
expires_at=api_key_create.expires_at
|
||||
# 创建角色实例
|
||||
db_role = Role(
|
||||
id=role_id,
|
||||
name=role.name,
|
||||
description=role.description
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.add(db_api_key)
|
||||
db.add(db_role)
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
db.refresh(db_role)
|
||||
|
||||
return db_api_key
|
||||
return db_role
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""通过ID获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.id == api_key_id).first()
|
||||
def get_role_by_id(db: Session, role_id: str) -> Optional[Role]:
|
||||
"""通过ID获取角色"""
|
||||
return db.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_value(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""通过值获取API密钥"""
|
||||
return db.query(APIKey).filter(APIKey.key == api_key_value).first()
|
||||
def get_role_by_name(db: Session, role_name: str) -> Optional[Role]:
|
||||
"""通过名称获取角色"""
|
||||
return db.query(Role).filter(Role.name == role_name).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_keys_by_user_id(db: Session, user_id: str) -> List[APIKey]:
|
||||
"""通过用户ID获取API密钥列表"""
|
||||
return db.query(APIKey).filter(APIKey.user_id == user_id).all()
|
||||
def get_roles(db: Session, skip: int = 0, limit: int = 100) -> List[Role]:
|
||||
"""获取角色列表"""
|
||||
return db.query(Role).offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def revoke_api_key(db: Session, api_key_id: str) -> Optional[APIKey]:
|
||||
"""撤销API密钥"""
|
||||
# 获取API密钥
|
||||
db_api_key = APIKeyService.get_api_key_by_id(db, api_key_id)
|
||||
if not db_api_key:
|
||||
return None
|
||||
def init_default_roles(db: Session) -> None:
|
||||
"""初始化默认角色"""
|
||||
# 检查是否已存在默认角色
|
||||
admin_role = UserService.get_role_by_name(db, "admin")
|
||||
user_role = UserService.get_role_by_name(db, "user")
|
||||
|
||||
# 更新状态为已撤销
|
||||
db_api_key.status = "revoked"
|
||||
# 创建管理员角色
|
||||
if not admin_role:
|
||||
admin_role = RoleCreate(
|
||||
name="admin",
|
||||
description="系统管理员,拥有所有权限"
|
||||
)
|
||||
UserService.create_role(db, admin_role)
|
||||
|
||||
# 保存到数据库
|
||||
db.commit()
|
||||
db.refresh(db_api_key)
|
||||
|
||||
return db_api_key
|
||||
|
||||
@staticmethod
|
||||
def validate_api_key(db: Session, api_key_value: str) -> Optional[APIKey]:
|
||||
"""验证API密钥"""
|
||||
# 获取API密钥
|
||||
api_key = APIKeyService.get_api_key_by_value(db, api_key_value)
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# 检查状态
|
||||
if api_key.status != "active":
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
if api_key.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
return api_key
|
||||
# 创建普通用户角色
|
||||
if not user_role:
|
||||
user_role = RoleCreate(
|
||||
name="user",
|
||||
description="普通用户,拥有基本权限"
|
||||
)
|
||||
UserService.create_role(db, user_role)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user