first commit
88  backend/venv/lib/python3.9/site-packages/redis/__init__.py  Normal file
@@ -0,0 +1,88 @@
from redis import asyncio  # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.connection import (
    BlockingConnectionPool,
    Connection,
    ConnectionPool,
    SSLConnection,
    UnixDomainSocketConnection,
)
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ChildDeadlockedError,
    ConnectionError,
    CrossSlotTransactionError,
    DataError,
    InvalidPipelineStack,
    InvalidResponse,
    MaxConnectionsError,
    OutOfMemoryError,
    PubSubError,
    ReadOnlyError,
    RedisClusterException,
    RedisError,
    ResponseError,
    TimeoutError,
    WatchError,
)
from redis.sentinel import (
    Sentinel,
    SentinelConnectionPool,
    SentinelManagedConnection,
    SentinelManagedSSLConnection,
)
from redis.utils import from_url


def int_or_str(value):
    try:
        return int(value)
    except ValueError:
        return value


__version__ = "7.0.1"
VERSION = tuple(map(int_or_str, __version__.split(".")))


__all__ = [
    "AuthenticationError",
    "AuthenticationWrongNumberOfArgsError",
    "BlockingConnectionPool",
    "BusyLoadingError",
    "ChildDeadlockedError",
    "Connection",
    "ConnectionError",
    "ConnectionPool",
    "CredentialProvider",
    "CrossSlotTransactionError",
    "DataError",
    "from_url",
    "default_backoff",
    "InvalidPipelineStack",
    "InvalidResponse",
    "MaxConnectionsError",
    "OutOfMemoryError",
    "PubSubError",
    "ReadOnlyError",
    "Redis",
    "RedisCluster",
    "RedisClusterException",
    "RedisError",
    "ResponseError",
    "Sentinel",
    "SentinelConnectionPool",
    "SentinelManagedConnection",
    "SentinelManagedSSLConnection",
    "SSLConnection",
    "UsernamePasswordCredentialProvider",
    "StrictRedis",
    "TimeoutError",
    "UnixDomainSocketConnection",
    "WatchError",
]
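For context on the version handling above: int_or_str keeps non-numeric components of a version string intact when building VERSION. A minimal standalone sketch (the version strings below are illustrative, not taken from the commit):

    >>> def int_or_str(value):
    ...     try:
    ...         return int(value)
    ...     except ValueError:
    ...         return value
    >>> tuple(map(int_or_str, "7.0.1".split(".")))
    (7, 0, 1)
    >>> tuple(map(int_or_str, "7.0.0rc1".split(".")))
    (7, 0, '0rc1')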
@@ -0,0 +1,27 @@
from .base import (
    AsyncPushNotificationsParser,
    BaseParser,
    PushNotificationsParser,
    _AsyncRESPBase,
)
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
from .resp2 import _AsyncRESP2Parser, _RESP2Parser
from .resp3 import _AsyncRESP3Parser, _RESP3Parser

__all__ = [
    "AsyncCommandsParser",
    "_AsyncHiredisParser",
    "_AsyncRESPBase",
    "_AsyncRESP2Parser",
    "_AsyncRESP3Parser",
    "AsyncPushNotificationsParser",
    "CommandsParser",
    "Encoder",
    "BaseParser",
    "_HiredisParser",
    "_RESP2Parser",
    "_RESP3Parser",
    "PushNotificationsParser",
]
474  backend/venv/lib/python3.9/site-packages/redis/_parsers/base.py  Normal file
@@ -0,0 +1,474 @@
import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Awaitable, Callable, List, Optional, Protocol, Union

from redis.maint_notifications import (
    MaintenanceNotification,
    NodeFailedOverNotification,
    NodeFailingOverNotification,
    NodeMigratedNotification,
    NodeMigratingNotification,
    NodeMovingNotification,
)

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout

from ..exceptions import (
    AskError,
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ClusterCrossSlotError,
    ClusterDownError,
    ConnectionError,
    ExecAbortError,
    ExternalAuthProviderError,
    MasterDownError,
    ModuleError,
    MovedError,
    NoPermissionError,
    NoScriptError,
    OutOfMemoryError,
    ReadOnlyError,
    RedisError,
    ResponseError,
    TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer

MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module "
    "exports one or more module-side data "
    "types, can't unload"
)
# user send an AUTH cmd to a server without authorization configured
NO_AUTH_SET_ERROR = {
    # Redis >= 6.0
    "AUTH <password> called without any password "
    "configured for the default user. Are you sure "
    "your configuration is correct?": AuthenticationError,
    # Redis < 6.0
    "Client sent AUTH, but no password is set": AuthenticationError,
}

EXTERNAL_AUTH_PROVIDER_ERROR = {
    "problem with LDAP service": ExternalAuthProviderError,
}

logger = logging.getLogger(__name__)


class BaseParser(ABC):
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
            **EXTERNAL_AUTH_PROVIDER_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response):
        "Parse an error response"
        error_code = response.split(" ")[0]
        if error_code in cls.EXCEPTION_CLASSES:
            response = response[len(error_code) + 1 :]
            exception_class = cls.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                exception_class = exception_class.get(response, ResponseError)
            return exception_class(response)
        return ResponseError(response)

    def on_disconnect(self):
        raise NotImplementedError()

    def on_connect(self, connection):
        raise NotImplementedError()


class _RESPBase(BaseParser):
    """Base class for sync-based resp parsing"""

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        self._buffer = SocketBuffer(
            self._sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        return self._buffer and self._buffer.can_read(timeout)


class AsyncBaseParser(BaseParser):
    """Base parsing class for the python-backed async parser"""

    __slots__ = "_stream", "_read_size"

    def __init__(self, socket_read_size: int):
        self._stream: Optional[StreamReader] = None
        self._read_size = socket_read_size

    async def can_read_destructive(self) -> bool:
        raise NotImplementedError()

    async def read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        raise NotImplementedError()


class MaintenanceNotificationsParser:
    """Protocol defining maintenance push notification parsing functionality"""

    @staticmethod
    def parse_maintenance_start_msg(response, notification_type):
        # Expected message format is: <notification_type> <seq_number> <time>
        id = response[1]
        ttl = response[2]
        return notification_type(id, ttl)

    @staticmethod
    def parse_maintenance_completed_msg(response, notification_type):
        # Expected message format is: <notification_type> <seq_number>
        id = response[1]
        return notification_type(id)

    @staticmethod
    def parse_moving_msg(response):
        # Expected message format is: MOVING <seq_number> <time> <endpoint>
        id = response[1]
        ttl = response[2]
        if response[3] is None:
            host, port = None, None
        else:
            value = response[3]
            if isinstance(value, bytes):
                value = value.decode()
            host, port = value.split(":")
            port = int(port) if port is not None else None

        return NodeMovingNotification(id, host, port, ttl)


_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"

_MAINTENANCE_MESSAGES = (
    _MIGRATING_MESSAGE,
    _MIGRATED_MESSAGE,
    _FAILING_OVER_MESSAGE,
    _FAILED_OVER_MESSAGE,
)

MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
    str, tuple[type[MaintenanceNotification], Callable]
] = {
    _MIGRATING_MESSAGE: (
        NodeMigratingNotification,
        MaintenanceNotificationsParser.parse_maintenance_start_msg,
    ),
    _MIGRATED_MESSAGE: (
        NodeMigratedNotification,
        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
    ),
    _FAILING_OVER_MESSAGE: (
        NodeFailingOverNotification,
        MaintenanceNotificationsParser.parse_maintenance_start_msg,
    ),
    _FAILED_OVER_MESSAGE: (
        NodeFailedOverNotification,
        MaintenanceNotificationsParser.parse_maintenance_completed_msg,
    ),
    _MOVING_MESSAGE: (
        NodeMovingNotification,
        MaintenanceNotificationsParser.parse_moving_msg,
    ),
}


class PushNotificationsParser(Protocol):
    """Protocol defining RESP3-specific parsing functionality"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable] = None
    maintenance_push_handler_func: Optional[Callable] = None

    def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses"""
        raise NotImplementedError()

    def handle_push_response(self, response, **kwargs):
        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
        ):
            return self.pubsub_push_handler_func(response)

        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return self.invalidation_push_handler_func(response)

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]

                notification = parser_function(response)
                return self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]
                notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][0]
                notification = parser_function(response, notification_type)

                if notification is not None:
                    return self.maintenance_push_handler_func(notification)
        except Exception as e:
            logger.error(
                "Error handling {} message ({}): {}".format(msg_type, response, e)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        self.maintenance_push_handler_func = maintenance_push_handler_func


class AsyncPushNotificationsParser(Protocol):
    """Protocol defining async RESP3-specific parsing functionality"""

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

    async def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses asynchronously"""
        raise NotImplementedError()

    async def handle_push_response(self, response, **kwargs):
        """Handle push responses asynchronously"""

        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
        ):
            return await self.pubsub_push_handler_func(response)

        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return await self.invalidation_push_handler_func(response)

            if isinstance(msg_type, bytes):
                msg_type = msg_type.decode()

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]
                notification = parser_function(response)
                return await self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]
                notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][0]
                notification = parser_function(response, notification_type)

                if notification is not None:
                    return await self.maintenance_push_handler_func(notification)
        except Exception as e:
            logger.error(
                "Error handling {} message ({}): {}".format(msg_type, response, e)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Set the pubsub push handler function"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Set the invalidation push handler function"""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        self.maintenance_push_handler_func = maintenance_push_handler_func


class _AsyncRESPBase(AsyncBaseParser):
    """Base class for async resp parsing"""

    __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")

    def __init__(self, socket_read_size: int):
        super().__init__(socket_read_size)
        self.encoder: Optional[Encoder] = None
        self._buffer = b""
        self._chunks = []
        self._pos = 0

    def _clear(self):
        self._buffer = b""
        self._chunks.clear()

    def on_connect(self, connection):
        """Called when the stream connects"""
        self._stream = connection._reader
        if self._stream is None:
            raise RedisError("Buffer is closed.")
        self.encoder = connection.encoder
        self._clear()
        self._connected = True

    def on_disconnect(self):
        """Called when the stream disconnects"""
        self._connected = False

    async def can_read_destructive(self) -> bool:
        if not self._connected:
            raise RedisError("Buffer is closed.")
        if self._buffer:
            return True
        try:
            async with async_timeout(0):
                return self._stream.at_eof()
        except TimeoutError:
            return False

    async def _read(self, length: int) -> bytes:
        """
        Read `length` bytes of data. These are assumed to be followed
        by a '\r\n' terminator which is subsequently discarded.
        """
        want = length + 2
        end = self._pos + want
        if len(self._buffer) >= end:
            result = self._buffer[self._pos : end - 2]
        else:
            tail = self._buffer[self._pos :]
            try:
                data = await self._stream.readexactly(want - len(tail))
            except IncompleteReadError as error:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += want
        return result

    async def _readline(self) -> bytes:
        """
        read an unknown number of bytes up to the next '\r\n'
        line separator, which is discarded.
        """
        found = self._buffer.find(b"\r\n", self._pos)
        if found >= 0:
            result = self._buffer[self._pos : found]
        else:
            tail = self._buffer[self._pos :]
            data = await self._stream.readline()
            if not data.endswith(b"\r\n"):
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += len(result) + 2
        return result
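A quick illustration of the parse_error dispatch defined above: the first token of the server's error string selects the exception class, and for "ERR" the remainder is matched against message-specific entries. A sketch, assuming the vendored package above is importable:

    >>> from redis._parsers.base import BaseParser
    >>> type(BaseParser.parse_error("MOVED 3999 127.0.0.1:6381")).__name__
    'MovedError'
    >>> type(BaseParser.parse_error("ERR invalid password")).__name__
    'AuthenticationError'
    >>> type(BaseParser.parse_error("ERR some unknown error")).__name__
    'ResponseError'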
@@ -0,0 +1,281 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from redis.exceptions import RedisError, ResponseError
from redis.utils import str_if_bytes

if TYPE_CHECKING:
    from redis.asyncio.cluster import ClusterNode


class AbstractCommandsParser:
    def _get_pubsub_keys(self, *args):
        """
        Get the keys from pubsub command.
        Although PubSub commands have predetermined key locations, they are not
        supported in the 'COMMAND's output, so the key positions are hardcoded
        in this method
        """
        if len(args) < 2:
            # The command has no keys in it
            return None
        args = [str_if_bytes(arg) for arg in args]
        command = args[0].upper()
        keys = None
        if command == "PUBSUB":
            # the second argument is a part of the command name, e.g.
            # ['PUBSUB', 'NUMSUB', 'foo'].
            pubsub_type = args[1].upper()
            if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
                keys = args[2:]
        elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
            # format example:
            # SUBSCRIBE channel [channel ...]
            keys = list(args[1:])
        elif command in ["PUBLISH", "SPUBLISH"]:
            # format example:
            # PUBLISH channel message
            keys = [args[1]]
        return keys

    def parse_subcommand(self, command, **options):
        cmd_dict = {}
        cmd_name = str_if_bytes(command[0])
        cmd_dict["name"] = cmd_name
        cmd_dict["arity"] = int(command[1])
        cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
        cmd_dict["first_key_pos"] = command[3]
        cmd_dict["last_key_pos"] = command[4]
        cmd_dict["step_count"] = command[5]
        if len(command) > 7:
            cmd_dict["tips"] = command[7]
            cmd_dict["key_specifications"] = command[8]
            cmd_dict["subcommands"] = command[9]
        return cmd_dict


class CommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.
    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with
    'movablekeys', and these commands' keys are determined by the command
    'COMMAND GETKEYS'.
    """

    def __init__(self, redis_connection):
        self.commands = {}
        self.initialize(redis_connection)

    def initialize(self, r):
        commands = r.command()
        uppercase_commands = []
        for cmd in commands:
            if any(x.isupper() for x in cmd):
                uppercase_commands.append(cmd)
        for cmd in uppercase_commands:
            commands[cmd.lower()] = commands.pop(cmd)
        self.commands = commands

    # As soon as this PR is merged into Redis, we should reimplement
    # our logic to use COMMAND INFO changes to determine the key positions
    # https://github.com/redis/redis/pull/8324
    def get_keys(self, redis_conn, *args):
        """
        Get the keys from the passed command.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
         - issue: https://github.com/redis/redis/issues/9493
         - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = args[0].lower()
        if cmd_name not in self.commands:
            # try to split the command name and to take only the main command,
            # e.g. 'memory' for 'memory usage'
            cmd_name_split = cmd_name.split()
            cmd_name = cmd_name_split[0]
            if cmd_name in self.commands:
                # save the splitted command to args
                args = cmd_name_split + list(args[1:])
            else:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                self.initialize(redis_conn)
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )

        command = self.commands.get(cmd_name)
        if "movablekeys" in command["flags"]:
            keys = self._get_moveable_keys(redis_conn, *args)
        elif "pubsub" in command["flags"] or command["name"] == "pubsub":
            keys = self._get_pubsub_keys(*args)
        else:
            if (
                command["step_count"] == 0
                and command["first_key_pos"] == 0
                and command["last_key_pos"] == 0
            ):
                is_subcmd = False
                if "subcommands" in command:
                    subcmd_name = f"{cmd_name}|{args[1].lower()}"
                    for subcmd in command["subcommands"]:
                        if str_if_bytes(subcmd[0]) == subcmd_name:
                            command = self.parse_subcommand(subcmd)
                            is_subcmd = True

                # The command doesn't have keys in it
                if not is_subcmd:
                    return None
            last_key_pos = command["last_key_pos"]
            if last_key_pos < 0:
                last_key_pos = len(args) - abs(last_key_pos)
            keys_pos = list(
                range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
            )
            keys = [args[pos] for pos in keys_pos]

        return keys

    def _get_moveable_keys(self, redis_conn, *args):
        """
        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
         - issue: https://github.com/redis/redis/issues/9493
         - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        # The command name should be splitted into separate arguments,
        # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE']
        pieces = args[0].split() + list(args[1:])
        try:
            keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces)
        except ResponseError as e:
            message = e.__str__()
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            else:
                raise e
        return keys


class AsyncCommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.

    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with 'movablekeys',
    and these commands' keys are determined by the command 'COMMAND GETKEYS'.

    NOTE: Due to a bug in redis<7.0, this does not work properly
    for EVAL or EVALSHA when the `numkeys` arg is 0.
     - issue: https://github.com/redis/redis/issues/9493
     - fix: https://github.com/redis/redis/pull/9733

    So, don't use this with EVAL or EVALSHA.
    """

    __slots__ = ("commands", "node")

    def __init__(self) -> None:
        self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}

    async def initialize(self, node: Optional["ClusterNode"] = None) -> None:
        if node:
            self.node = node

        commands = await self.node.execute_command("COMMAND")
        self.commands = {cmd.lower(): command for cmd, command in commands.items()}

    # As soon as this PR is merged into Redis, we should reimplement
    # our logic to use COMMAND INFO changes to determine the key positions
    # https://github.com/redis/redis/pull/8324
    async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        """
        Get the keys from the passed command.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
         - issue: https://github.com/redis/redis/issues/9493
         - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = args[0].lower()
        if cmd_name not in self.commands:
            # try to split the command name and to take only the main command,
            # e.g. 'memory' for 'memory usage'
            cmd_name_split = cmd_name.split()
            cmd_name = cmd_name_split[0]
            if cmd_name in self.commands:
                # save the splitted command to args
                args = cmd_name_split + list(args[1:])
            else:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                await self.initialize()
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )

        command = self.commands.get(cmd_name)
        if "movablekeys" in command["flags"]:
            keys = await self._get_moveable_keys(*args)
        elif "pubsub" in command["flags"] or command["name"] == "pubsub":
            keys = self._get_pubsub_keys(*args)
        else:
            if (
                command["step_count"] == 0
                and command["first_key_pos"] == 0
                and command["last_key_pos"] == 0
            ):
                is_subcmd = False
                if "subcommands" in command:
                    subcmd_name = f"{cmd_name}|{args[1].lower()}"
                    for subcmd in command["subcommands"]:
                        if str_if_bytes(subcmd[0]) == subcmd_name:
                            command = self.parse_subcommand(subcmd)
                            is_subcmd = True

                # The command doesn't have keys in it
                if not is_subcmd:
                    return None
            last_key_pos = command["last_key_pos"]
            if last_key_pos < 0:
                last_key_pos = len(args) - abs(last_key_pos)
            keys_pos = list(
                range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
            )
            keys = [args[pos] for pos in keys_pos]

        return keys

    async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        try:
            keys = await self.node.execute_command("COMMAND GETKEYS", *args)
        except ResponseError as e:
            message = e.__str__()
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            else:
                raise e
        return keys
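The fixed-position branch of get_keys above reduces to range arithmetic over first_key_pos, last_key_pos, and step_count. A standalone sketch; the MSET metadata (1, -1, 2) is illustrative of typical COMMAND output, not read from this commit:

    >>> args = ("MSET", "k1", "v1", "k2", "v2")
    >>> first_key_pos, last_key_pos, step_count = 1, -1, 2
    >>> if last_key_pos < 0:
    ...     last_key_pos = len(args) - abs(last_key_pos)  # -> 4
    >>> [args[pos] for pos in range(first_key_pos, last_key_pos + 1, step_count)]
    ['k1', 'k2']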
@@ -0,0 +1,44 @@
from ..exceptions import DataError


class Encoder:
    "Encode strings to bytes-like and decode bytes-like to strings"

    __slots__ = "encoding", "encoding_errors", "decode_responses"

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        "Return a bytestring or bytes-like representation of the value"
        if isinstance(value, (bytes, memoryview)):
            return value
        elif isinstance(value, bool):
            # special case bool since it is a subclass of int
            raise DataError(
                "Invalid input of type: 'bool'. Convert to a "
                "bytes, string, int or float first."
            )
        elif isinstance(value, (int, float)):
            value = repr(value).encode()
        elif not isinstance(value, str):
            # a value we don't know how to deal with. throw an error
            typename = type(value).__name__
            raise DataError(
                f"Invalid input of type: '{typename}'. "
                f"Convert to a bytes, string, int or float first."
            )
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def decode(self, value, force=False):
        "Return a unicode string from the bytes-like representation"
        if self.decode_responses or force:
            if isinstance(value, memoryview):
                value = value.tobytes()
            if isinstance(value, bytes):
                value = value.decode(self.encoding, self.encoding_errors)
        return value
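A short usage sketch of the Encoder round-trip, assuming the vendored package above is importable (the utf-8/strict settings mirror redis-py's usual defaults):

    >>> from redis._parsers.encoders import Encoder
    >>> enc = Encoder(encoding="utf-8", encoding_errors="strict", decode_responses=True)
    >>> enc.encode("hello"), enc.encode(42), enc.decode(b"42")
    (b'hello', b'42', '42')
    >>> # enc.encode(True) raises DataError: bool is rejected explicitly
    >>> # because it is a subclass of int and would otherwise encode as b'True'.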
@@ -0,0 +1,941 @@
import datetime

from redis.utils import str_if_bytes


def timestamp_to_datetime(response):
    "Converts a unix timestamp to a Python datetime object"
    if not response:
        return None
    try:
        response = int(response)
    except ValueError:
        return None
    return datetime.datetime.fromtimestamp(response)


def parse_debug_object(response):
    "Parse the results of Redis's DEBUG OBJECT command into a Python dict"
    # The 'type' of the object is the first item in the response, but isn't
    # prefixed with a name
    response = str_if_bytes(response)
    response = "type:" + response
    response = dict(kv.split(":") for kv in response.split())

    # parse some expected int values from the string response
    # note: this cmd isn't spec'd so these may not appear in all redis versions
    int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle")
    for field in int_fields:
        if field in response:
            response[field] = int(response[field])

    return response


def parse_info(response):
    """Parse the result of Redis's INFO command into a Python dict"""
    info = {}
    response = str_if_bytes(response)

    def get_value(value):
        if "," not in value and "=" not in value:
            try:
                if "." in value:
                    return float(value)
                else:
                    return int(value)
            except ValueError:
                return value
        elif "=" not in value:
            return [get_value(v) for v in value.split(",") if v]
        else:
            sub_dict = {}
            for item in value.split(","):
                if not item:
                    continue
                if "=" in item:
                    k, v = item.rsplit("=", 1)
                    sub_dict[k] = get_value(v)
                else:
                    sub_dict[item] = True
            return sub_dict

    for line in response.splitlines():
        if line and not line.startswith("#"):
            if line.find(":") != -1:
                # Split, the info fields keys and values.
                # Note that the value may contain ':'. but the 'host:'
                # pseudo-command is the only case where the key contains ':'
                key, value = line.split(":", 1)
                if key == "cmdstat_host":
                    key, value = line.rsplit(":", 1)

                if key == "module":
                    # Hardcode a list for key 'modules' since there could be
                    # multiple lines that started with 'module'
                    info.setdefault("modules", []).append(get_value(value))
                else:
                    info[key] = get_value(value)
            else:
                # if the line isn't splittable, append it to the "__raw__" key
                info.setdefault("__raw__", []).append(line)

    return info
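parse_info coerces each "key:value" line: plain numbers become int or float, and comma/equals payloads become nested dicts. An illustration on a fabricated INFO fragment (any values below are made up):

    >>> sample = b"# Server\r\nredis_version:7.0.1\r\nkeyspace_hits:42\r\ndb0:keys=2,expires=0,avg_ttl=0\r\n"
    >>> info = parse_info(sample)
    >>> info["redis_version"], info["keyspace_hits"]
    ('7.0.1', 42)
    >>> info["db0"]
    {'keys': 2, 'expires': 0, 'avg_ttl': 0}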
def parse_memory_stats(response, **kwargs):
    """Parse the results of MEMORY STATS"""
    stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
    for key, value in stats.items():
        if key.startswith("db.") and isinstance(value, list):
            stats[key] = pairs_to_dict(
                value, decode_keys=True, decode_string_values=True
            )
    return stats


SENTINEL_STATE_TYPES = {
    "can-failover-its-master": int,
    "config-epoch": int,
    "down-after-milliseconds": int,
    "failover-timeout": int,
    "info-refresh": int,
    "last-hello-message": int,
    "last-ok-ping-reply": int,
    "last-ping-reply": int,
    "last-ping-sent": int,
    "master-link-down-time": int,
    "master-port": int,
    "num-other-sentinels": int,
    "num-slaves": int,
    "o-down-time": int,
    "pending-commands": int,
    "parallel-syncs": int,
    "port": int,
    "quorum": int,
    "role-reported-time": int,
    "s-down-time": int,
    "slave-priority": int,
    "slave-repl-offset": int,
    "voted-leader-epoch": int,
}


def parse_sentinel_state(item):
    result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES)
    flags = set(result["flags"].split(","))
    for name, flag in (
        ("is_master", "master"),
        ("is_slave", "slave"),
        ("is_sdown", "s_down"),
        ("is_odown", "o_down"),
        ("is_sentinel", "sentinel"),
        ("is_disconnected", "disconnected"),
        ("is_master_down", "master_down"),
    ):
        result[name] = flag in flags
    return result


def parse_sentinel_master(response):
    return parse_sentinel_state(map(str_if_bytes, response))


def parse_sentinel_state_resp3(response):
    result = {}
    for key in response:
        try:
            value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key]))
            result[str_if_bytes(key)] = value
        except Exception:
            result[str_if_bytes(key)] = response[str_if_bytes(key)]
    flags = set(result["flags"].split(","))
    result["flags"] = flags
    return result


def parse_sentinel_masters(response):
    result = {}
    for item in response:
        state = parse_sentinel_state(map(str_if_bytes, item))
        result[state["name"]] = state
    return result


def parse_sentinel_masters_resp3(response):
    return [parse_sentinel_state(master) for master in response]


def parse_sentinel_slaves_and_sentinels(response):
    return [parse_sentinel_state(map(str_if_bytes, item)) for item in response]


def parse_sentinel_slaves_and_sentinels_resp3(response):
    return [parse_sentinel_state_resp3(item) for item in response]


def parse_sentinel_get_master(response):
    return response and (response[0], int(response[1])) or None


def pairs_to_dict(response, decode_keys=False, decode_string_values=False):
    """Create a dict given a list of key/value pairs"""
    if response is None:
        return {}
    if decode_keys or decode_string_values:
        # the iter form is faster, but I don't know how to make that work
        # with a str_if_bytes() map
        keys = response[::2]
        if decode_keys:
            keys = map(str_if_bytes, keys)
        values = response[1::2]
        if decode_string_values:
            values = map(str_if_bytes, values)
        return dict(zip(keys, values))
    else:
        it = iter(response)
        return dict(zip(it, it))


def pairs_to_dict_typed(response, type_info):
    it = iter(response)
    result = {}
    for key, value in zip(it, it):
        if key in type_info:
            try:
                value = type_info[key](value)
            except Exception:
                # if for some reason the value can't be coerced, just use
                # the string value
                pass
        result[key] = value
    return result
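pairs_to_dict is the workhorse for flat "key value key value" replies (HGETALL, CONFIG GET, and MEMORY STATS above all route through it). An illustration:

    >>> resp = [b"name", b"alice", b"age", b"30"]
    >>> pairs_to_dict(resp)
    {b'name': b'alice', b'age': b'30'}
    >>> pairs_to_dict(resp, decode_keys=True)
    {'name': b'alice', 'age': b'30'}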
def zset_score_pairs(response, **options):
    """
    If ``withscores`` is specified in the options, return the response as
    a list of (value, score) pairs
    """
    if not response or not options.get("withscores"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    it = iter(response)
    return list(zip(it, map(score_cast_func, it)))


def zset_score_for_rank(response, **options):
    """
    If ``withscores`` is specified in the options, return the response as
    a [value, score] pair
    """
    if not response or not options.get("withscore"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    return [response[0], score_cast_func(response[1])]


def zset_score_pairs_resp3(response, **options):
    """
    If ``withscores`` is specified in the options, return the response as
    a list of [value, score] pairs
    """
    if not response or not options.get("withscores"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    return [[name, score_cast_func(val)] for name, val in response]


def zset_score_for_rank_resp3(response, **options):
    """
    If ``withscores`` is specified in the options, return the response as
    a [value, score] pair
    """
    if not response or not options.get("withscore"):
        return response
    score_cast_func = options.get("score_cast_func", float)
    return [response[0], score_cast_func(response[1])]


def sort_return_tuples(response, **options):
    """
    If ``groups`` is specified, return the response as a list of
    n-element tuples with n being the value found in options['groups']
    """
    if not response or not options.get("groups"):
        return response
    n = options["groups"]
    return list(zip(*[response[i::n] for i in range(n)]))


def parse_stream_list(response):
    if response is None:
        return None
    data = []
    for r in response:
        if r is not None:
            data.append((r[0], pairs_to_dict(r[1])))
        else:
            data.append((None, None))
    return data


def pairs_to_dict_with_str_keys(response):
    return pairs_to_dict(response, decode_keys=True)


def parse_list_of_dicts(response):
    return list(map(pairs_to_dict_with_str_keys, response))


def parse_xclaim(response, **options):
    if options.get("parse_justid", False):
        return response
    return parse_stream_list(response)


def parse_xautoclaim(response, **options):
    if options.get("parse_justid", False):
        return response[1]
    response[1] = parse_stream_list(response[1])
    return response


def parse_xinfo_stream(response, **options):
    if isinstance(response, list):
        data = pairs_to_dict(response, decode_keys=True)
    else:
        data = {str_if_bytes(k): v for k, v in response.items()}
    if not options.get("full", False):
        first = data.get("first-entry")
        if first is not None and first[0] is not None:
            data["first-entry"] = (first[0], pairs_to_dict(first[1]))
        last = data["last-entry"]
        if last is not None and last[0] is not None:
            data["last-entry"] = (last[0], pairs_to_dict(last[1]))
    else:
        data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
        if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
            data["groups"] = [
                pairs_to_dict(group, decode_keys=True) for group in data["groups"]
            ]
            for g in data["groups"]:
                if g["consumers"] and g["consumers"][0] is not None:
                    g["consumers"] = [
                        pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
                    ]
        else:
            data["groups"] = [
                {str_if_bytes(k): v for k, v in group.items()}
                for group in data["groups"]
            ]
    return data


def parse_xread(response):
    if response is None:
        return []
    return [[r[0], parse_stream_list(r[1])] for r in response]


def parse_xread_resp3(response):
    if response is None:
        return {}
    return {key: [parse_stream_list(value)] for key, value in response.items()}


def parse_xpending(response, **options):
    if options.get("parse_detail", False):
        return parse_xpending_range(response)
    consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []]
    return {
        "pending": response[0],
        "min": response[1],
        "max": response[2],
        "consumers": consumers,
    }


def parse_xpending_range(response):
    k = ("message_id", "consumer", "time_since_delivered", "times_delivered")
    return [dict(zip(k, r)) for r in response]
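zset_score_pairs only rewrites the reply when the caller passed WITHSCORES; otherwise it is a passthrough. An illustration:

    >>> resp = [b"a", b"1.5", b"b", b"2"]
    >>> zset_score_pairs(resp, withscores=True)
    [(b'a', 1.5), (b'b', 2.0)]
    >>> zset_score_pairs(resp) == resp
    True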
def float_or_none(response):
    if response is None:
        return None
    return float(response)


def bool_ok(response, **options):
    return str_if_bytes(response) == "OK"


def parse_zadd(response, **options):
    if response is None:
        return None
    if options.get("as_score"):
        return float(response)
    return int(response)


def parse_client_list(response, **options):
    clients = []
    for c in str_if_bytes(response).splitlines():
        client_dict = {}
        tokens = c.split(" ")
        last_key = None
        for token in tokens:
            if "=" in token:
                # Values might contain '='
                key, value = token.split("=", 1)
                client_dict[key] = value
                last_key = key
            else:
                # Values may include spaces. For instance, when running Redis
                # via a Unix socket (e.g. "/tmp/redis sock/redis.sock"), the
                # addr or laddr field will include a space.
                client_dict[last_key] += " " + token

        if client_dict:
            clients.append(client_dict)
    return clients


def parse_config_get(response, **options):
    response = [str_if_bytes(i) if i is not None else None for i in response]
    return response and pairs_to_dict(response) or {}


def parse_scan(response, **options):
    cursor, r = response
    return int(cursor), r


def parse_hscan(response, **options):
    cursor, r = response
    no_values = options.get("no_values", False)
    if no_values:
        payload = r or []
    else:
        payload = r and pairs_to_dict(r) or {}
    return int(cursor), payload


def parse_zscan(response, **options):
    score_cast_func = options.get("score_cast_func", float)
    cursor, r = response
    it = iter(r)
    return int(cursor), list(zip(it, map(score_cast_func, it)))


def parse_zmscore(response, **options):
    # zmscore: list of scores (double precision floating point number) or nil
    return [float(score) if score is not None else None for score in response]


def parse_slowlog_get(response, **options):
    space = " " if options.get("decode_responses", False) else b" "

    def parse_item(item):
        result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])}
        # Redis Enterprise injects another entry at index [3], which has
        # the complexity info (i.e. the value N in case the command has
        # an O(N) complexity) instead of the command.
        if isinstance(item[3], list):
            result["command"] = space.join(item[3])

            # These fields are optional, depends on environment.
            if len(item) >= 6:
                result["client_address"] = item[4]
                result["client_name"] = item[5]
        else:
            result["complexity"] = item[3]
            result["command"] = space.join(item[4])

            # These fields are optional, depends on environment.
            if len(item) >= 7:
                result["client_address"] = item[5]
                result["client_name"] = item[6]

        return result

    return [parse_item(item) for item in response]


def parse_stralgo(response, **options):
    """
    Parse the response from `STRALGO` command.
    Without modifiers the returned value is string.
    When LEN is given the command returns the length of the result
    (i.e integer).
    When IDX is given the command returns a dictionary with the LCS
    length and all the ranges in both the strings, start and end
    offset for each string, where there are matches.
    When WITHMATCHLEN is given, each array representing a match will
    also have the length of the match at the beginning of the array.
    """
    if options.get("len", False):
        return int(response)
    if options.get("idx", False):
        if options.get("withmatchlen", False):
            matches = [
                [(int(match[-1]))] + list(map(tuple, match[:-1]))
                for match in response[1]
            ]
        else:
            matches = [list(map(tuple, match)) for match in response[1]]
        return {
            str_if_bytes(response[0]): matches,
            str_if_bytes(response[2]): int(response[3]),
        }
    return str_if_bytes(response)


def parse_cluster_info(response, **options):
    response = str_if_bytes(response)
    return dict(line.split(":") for line in response.splitlines() if line)


def _parse_node_line(line):
    line_items = line.split(" ")
    node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
    ip = addr.split("@")[0]
    hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
    node_dict = {
        "node_id": node_id,
        "hostname": hostname,
        "flags": flags,
        "master_id": master_id,
        "last_ping_sent": ping,
        "last_pong_rcvd": pong,
        "epoch": epoch,
        "slots": [],
        "migrations": [],
        "connected": True if connected == "connected" else False,
    }
    if len(line_items) >= 9:
        slots, migrations = _parse_slots(line_items[8:])
        node_dict["slots"], node_dict["migrations"] = slots, migrations
    return ip, node_dict


def _parse_slots(slot_ranges):
    slots, migrations = [], []
    for s_range in slot_ranges:
        if "->-" in s_range:
            slot_id, dst_node_id = s_range[1:-1].split("->-", 1)
            migrations.append(
                {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"}
            )
        elif "-<-" in s_range:
            slot_id, src_node_id = s_range[1:-1].split("-<-", 1)
            migrations.append(
                {"slot": slot_id, "node_id": src_node_id, "state": "importing"}
            )
        else:
            s_range = [sl for sl in s_range.split("-")]
            slots.append(s_range)

    return slots, migrations
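_parse_slots separates owned slot ranges from in-flight migrations, which CLUSTER NODES marks with "->-" (migrating) and "-<-" (importing). An illustration with a made-up node id:

    >>> _parse_slots(["0-5460", "[1000->-abcd1234]"])
    ([['0', '5460']], [{'slot': '1000', 'node_id': 'abcd1234', 'state': 'migrating'}])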
def parse_cluster_nodes(response, **options):
    """
    @see: https://redis.io/commands/cluster-nodes  # string / bytes
    @see: https://redis.io/commands/cluster-replicas  # list of string / bytes
    """
    if isinstance(response, (str, bytes)):
        response = response.splitlines()
    return dict(_parse_node_line(str_if_bytes(node)) for node in response)


def parse_geosearch_generic(response, **options):
    """
    Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER'
    commands according to 'withdist', 'withhash' and 'withcoord' labels.
    """
    try:
        if options["store"] or options["store_dist"]:
            # `store` and `store_dist` cant be combined
            # with other command arguments.
            # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER'
            return response
    except KeyError:  # it means the command was sent via execute_command
        return response

    if not isinstance(response, list):
        response_list = [response]
    else:
        response_list = response

    if not options["withdist"] and not options["withcoord"] and not options["withhash"]:
        # just a bunch of places
        return response_list

    cast = {
        "withdist": float,
        "withcoord": lambda ll: (float(ll[0]), float(ll[1])),
        "withhash": int,
    }

    # zip all output results with each casting function to get
    # the properly native Python value.
    f = [lambda x: x]
    f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]]
    return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list]


def parse_command(response, **options):
    commands = {}
    for command in response:
        cmd_dict = {}
        cmd_name = str_if_bytes(command[0])
        cmd_dict["name"] = cmd_name
        cmd_dict["arity"] = int(command[1])
        cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
        cmd_dict["first_key_pos"] = command[3]
        cmd_dict["last_key_pos"] = command[4]
        cmd_dict["step_count"] = command[5]
        if len(command) > 7:
            cmd_dict["tips"] = command[7]
            cmd_dict["key_specifications"] = command[8]
            cmd_dict["subcommands"] = command[9]
        commands[cmd_name] = cmd_dict
    return commands


def parse_command_resp3(response, **options):
    commands = {}
    for command in response:
        cmd_dict = {}
        cmd_name = str_if_bytes(command[0])
        cmd_dict["name"] = cmd_name
        cmd_dict["arity"] = command[1]
        cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
        cmd_dict["first_key_pos"] = command[3]
        cmd_dict["last_key_pos"] = command[4]
        cmd_dict["step_count"] = command[5]
        cmd_dict["acl_categories"] = command[6]
        if len(command) > 7:
            cmd_dict["tips"] = command[7]
            cmd_dict["key_specifications"] = command[8]
            cmd_dict["subcommands"] = command[9]

        commands[cmd_name] = cmd_dict
    return commands


def parse_pubsub_numsub(response, **options):
    return list(zip(response[0::2], response[1::2]))


def parse_client_kill(response, **options):
    if isinstance(response, int):
        return response
    return str_if_bytes(response) == "OK"


def parse_acl_getuser(response, **options):
    if response is None:
        return None
    if isinstance(response, list):
        data = pairs_to_dict(response, decode_keys=True)
    else:
        data = {str_if_bytes(key): value for key, value in response.items()}

    # convert everything but user-defined data in 'keys' to native strings
    data["flags"] = list(map(str_if_bytes, data["flags"]))
    data["passwords"] = list(map(str_if_bytes, data["passwords"]))
    data["commands"] = str_if_bytes(data["commands"])
    if isinstance(data["keys"], str) or isinstance(data["keys"], bytes):
        data["keys"] = list(str_if_bytes(data["keys"]).split(" "))
    if data["keys"] == [""]:
        data["keys"] = []
    if "channels" in data:
        if isinstance(data["channels"], str) or isinstance(data["channels"], bytes):
            data["channels"] = list(str_if_bytes(data["channels"]).split(" "))
        if data["channels"] == [""]:
            data["channels"] = []
    if "selectors" in data:
        if data["selectors"] != [] and isinstance(data["selectors"][0], list):
            data["selectors"] = [
                list(map(str_if_bytes, selector)) for selector in data["selectors"]
            ]
        elif data["selectors"] != []:
            data["selectors"] = [
                {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
                for selector in data["selectors"]
            ]

    # split 'commands' into separate 'categories' and 'commands' lists
    commands, categories = [], []
    for command in data["commands"].split(" "):
        categories.append(command) if "@" in command else commands.append(command)

    data["commands"] = commands
    data["categories"] = categories
    data["enabled"] = "on" in data["flags"]
    return data


def parse_acl_log(response, **options):
    if response is None:
        return None
    if isinstance(response, list):
        data = []
        for log in response:
            log_data = pairs_to_dict(log, True, True)
            client_info = log_data.get("client-info", "")
            log_data["client-info"] = parse_client_info(client_info)

            # float() is lossy comparing to the "double" in C
            log_data["age-seconds"] = float(log_data["age-seconds"])
            data.append(log_data)
    else:
        data = bool_ok(response)
    return data


def parse_client_info(value):
    """
    Parsing client-info in ACL Log in following format.
    "key1=value1 key2=value2 key3=value3"
    """
    client_info = {}
    for info in str_if_bytes(value).strip().split():
        key, value = info.split("=")
        client_info[key] = value

    # Those fields are defined as int in networking.c
    for int_key in {
        "id",
        "age",
        "idle",
        "db",
        "sub",
        "psub",
        "multi",
        "qbuf",
        "qbuf-free",
        "obl",
        "argv-mem",
        "oll",
        "omem",
        "tot-mem",
    }:
        if int_key in client_info:
            client_info[int_key] = int(client_info[int_key])
    return client_info
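parse_client_info splits the space-separated "key=value" payload and coerces the fields Redis defines as integers. An illustration on a fabricated payload:

    >>> info = parse_client_info(b"id=3 addr=127.0.0.1:57275 name= db=0")
    >>> info["id"], info["db"], info["addr"]
    (3, 0, '127.0.0.1:57275')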
|
||||
|
||||
def parse_set_result(response, **options):
    """
    Handle the result of SET, whose GET argument is available since Redis 6.2.
    The result is parsed into:
    - bool, by default
    - the old string value, when the GET argument is used
    """
    if options.get("get"):
        # Redis will return a getCommand result.
        # See `setGenericCommand` in t_string.c
        return response
    return response and str_if_bytes(response) == "OK"


def string_keys_to_dict(key_string, callback):
    return dict.fromkeys(key_string.split(), callback)
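

# Editor's sketch (not part of redis-py): what string_keys_to_dict builds.
# Every command name in the space-separated string maps to the same callback.
def _demo_string_keys_to_dict():
    callbacks = string_keys_to_dict("EXPIRE PERSIST MOVE", bool)
    assert callbacks == {"EXPIRE": bool, "PERSIST": bool, "MOVE": bool}
    return callbacks

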
_RedisCallbacks = {
    **string_keys_to_dict(
        "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX "
        "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE",
        bool,
    ),
    **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float),
    **string_keys_to_dict(
        "ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE "
        "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH",
        bool_ok,
    ),
    **string_keys_to_dict("XREAD XREADGROUP", parse_xread),
    **string_keys_to_dict(
        "GEORADIUS GEORADIUSBYMEMBER GEOSEARCH",
        parse_geosearch_generic,
    ),
    **string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list),
    "ACL GETUSER": parse_acl_getuser,
    "ACL LOAD": bool_ok,
    "ACL LOG": parse_acl_log,
    "ACL SETUSER": bool_ok,
    "ACL SAVE": bool_ok,
    "CLIENT INFO": parse_client_info,
    "CLIENT KILL": parse_client_kill,
    "CLIENT LIST": parse_client_list,
    "CLIENT PAUSE": bool_ok,
    "CLIENT SETINFO": bool_ok,
    "CLIENT SETNAME": bool_ok,
    "CLIENT UNBLOCK": bool,
    "CLUSTER ADDSLOTS": bool_ok,
    "CLUSTER ADDSLOTSRANGE": bool_ok,
    "CLUSTER DELSLOTS": bool_ok,
    "CLUSTER DELSLOTSRANGE": bool_ok,
    "CLUSTER FAILOVER": bool_ok,
    "CLUSTER FORGET": bool_ok,
    "CLUSTER INFO": parse_cluster_info,
    "CLUSTER MEET": bool_ok,
    "CLUSTER NODES": parse_cluster_nodes,
    "CLUSTER REPLICAS": parse_cluster_nodes,
    "CLUSTER REPLICATE": bool_ok,
    "CLUSTER RESET": bool_ok,
    "CLUSTER SAVECONFIG": bool_ok,
    "CLUSTER SET-CONFIG-EPOCH": bool_ok,
    "CLUSTER SETSLOT": bool_ok,
    "CLUSTER SLAVES": parse_cluster_nodes,
    "COMMAND": parse_command,
    "CONFIG RESETSTAT": bool_ok,
    "CONFIG SET": bool_ok,
    "FUNCTION DELETE": bool_ok,
    "FUNCTION FLUSH": bool_ok,
    "FUNCTION RESTORE": bool_ok,
    "GEODIST": float_or_none,
    "HSCAN": parse_hscan,
    "INFO": parse_info,
    "LASTSAVE": timestamp_to_datetime,
    "MEMORY PURGE": bool_ok,
    "MODULE LOAD": bool,
    "MODULE UNLOAD": bool,
    "PING": lambda r: str_if_bytes(r) == "PONG",
    "PUBSUB NUMSUB": parse_pubsub_numsub,
    "PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
    "QUIT": bool_ok,
    "SET": parse_set_result,
    "SCAN": parse_scan,
    "SCRIPT EXISTS": lambda r: list(map(bool, r)),
    "SCRIPT FLUSH": bool_ok,
    "SCRIPT KILL": bool_ok,
    "SCRIPT LOAD": str_if_bytes,
    "SENTINEL CKQUORUM": bool_ok,
    "SENTINEL FAILOVER": bool_ok,
    "SENTINEL FLUSHCONFIG": bool_ok,
    "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master,
    "SENTINEL MONITOR": bool_ok,
    "SENTINEL RESET": bool_ok,
    "SENTINEL REMOVE": bool_ok,
    "SENTINEL SET": bool_ok,
    "SLOWLOG GET": parse_slowlog_get,
    "SLOWLOG RESET": bool_ok,
    "SORT": sort_return_tuples,
    "SSCAN": parse_scan,
    "TIME": lambda x: (int(x[0]), int(x[1])),
    "XAUTOCLAIM": parse_xautoclaim,
    "XCLAIM": parse_xclaim,
    "XGROUP CREATE": bool_ok,
    "XGROUP DESTROY": bool,
    "XGROUP SETID": bool_ok,
    "XINFO STREAM": parse_xinfo_stream,
    "XPENDING": parse_xpending,
    "ZSCAN": parse_zscan,
}


_RedisCallbacksRESP2 = {
    **string_keys_to_dict(
        "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
    ),
    **string_keys_to_dict(
        "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE "
        "ZREVRANGEBYSCORE ZUNION",
        zset_score_pairs,
    ),
    **string_keys_to_dict(
        "ZREVRANK ZRANK",
        zset_score_for_rank,
    ),
    **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none),
    **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True),
    **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None),
    **string_keys_to_dict(
        "BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None
    ),
    "ACL CAT": lambda r: list(map(str_if_bytes, r)),
    "ACL GENPASS": str_if_bytes,
    "ACL HELP": lambda r: list(map(str_if_bytes, r)),
    "ACL LIST": lambda r: list(map(str_if_bytes, r)),
    "ACL USERS": lambda r: list(map(str_if_bytes, r)),
    "ACL WHOAMI": str_if_bytes,
    "CLIENT GETNAME": str_if_bytes,
    "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)),
    "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)),
    "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)),
    "CONFIG GET": parse_config_get,
    "DEBUG OBJECT": parse_debug_object,
    "GEOHASH": lambda r: list(map(str_if_bytes, r)),
    "GEOPOS": lambda r: list(
        map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r)
    ),
    "HGETALL": lambda r: r and pairs_to_dict(r) or {},
    "MEMORY STATS": parse_memory_stats,
    "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r],
    "RESET": str_if_bytes,
    "SENTINEL MASTER": parse_sentinel_master,
    "SENTINEL MASTERS": parse_sentinel_masters,
    "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels,
    "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels,
    "STRALGO": parse_stralgo,
    "XINFO CONSUMERS": parse_list_of_dicts,
    "XINFO GROUPS": parse_list_of_dicts,
    "ZADD": parse_zadd,
    "ZMSCORE": parse_zmscore,
}


_RedisCallbacksRESP3 = {
    **string_keys_to_dict(
        "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
    ),
    **string_keys_to_dict(
        "ZRANGE ZINTER ZPOPMAX ZPOPMIN HGETALL XREADGROUP",
        lambda r, **kwargs: r,
    ),
    **string_keys_to_dict(
        "ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE ZUNION",
        zset_score_pairs_resp3,
    ),
    **string_keys_to_dict(
        "ZREVRANK ZRANK",
        zset_score_for_rank_resp3,
    ),
    **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
    "ACL LOG": lambda r: (
        [
            {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
            for x in r
        ]
        if isinstance(r, list)
        else bool_ok(r)
    ),
    "COMMAND": parse_command_resp3,
    "CONFIG GET": lambda r: {
        str_if_bytes(key) if key is not None else None: (
            str_if_bytes(value) if value is not None else None
        )
        for key, value in r.items()
    },
    "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
    "SENTINEL MASTER": parse_sentinel_state_resp3,
    "SENTINEL MASTERS": parse_sentinel_masters_resp3,
    "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
    "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
    "STRALGO": lambda r, **options: (
        {str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
        if isinstance(r, dict)
        else str_if_bytes(r)
    ),
    "XINFO CONSUMERS": lambda r: [
        {str_if_bytes(key): value for key, value in x.items()} for x in r
    ],
    "XINFO GROUPS": lambda r: [
        {str_if_bytes(key): value for key, value in d.items()} for d in r
    ],
}
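

# Editor's sketch (not part of redis-py): how a client might merge the base
# callbacks with the protocol-specific overrides above. The `protocol`
# argument is a hypothetical stand-in for the client's RESP version setting.
def _demo_response_callbacks(protocol: int) -> dict:
    callbacks = dict(_RedisCallbacks)
    if protocol == 3:
        callbacks.update(_RedisCallbacksRESP3)
    else:
        callbacks.update(_RedisCallbacksRESP2)
    return callbacks
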
@@ -0,0 +1,301 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union

# Prefer the stdlib asyncio.timeout on Python 3.11+; fall back to the
# async-timeout backport on older interpreters.
if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout

from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import (
    AsyncBaseParser,
    AsyncPushNotificationsParser,
    BaseParser,
    PushNotificationsParser,
)
from .socket import (
    NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
    NONBLOCKING_EXCEPTIONS,
    SENTINEL,
    SERVER_CLOSED_CONNECTION_ERROR,
)

# Used to signal that hiredis-py does not have enough data to parse.
# Using `False` or `None` is not reliable, given that the parser can
# return `False` or `None` for legitimate reasons from RESP payloads.
NOT_ENOUGH_DATA = object()
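

# Editor's sketch (not part of redis-py): why the sentinel is compared by
# identity. `None` and `False` are legitimate parse results (RESP null and
# boolean), so only a unique object() can unambiguously mean "incomplete".
def _demo_not_enough_data():
    parsed = None  # a valid result for a RESP null reply
    assert parsed is not NOT_ENOUGH_DATA
    assert False is not NOT_ENOUGH_DATA

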
class _HiredisReaderArgs(TypedDict, total=False):
    protocolError: Callable[[str], Exception]
    replyError: Callable[[str], Exception]
    encoding: Optional[str]
    errors: Optional[str]


class _HiredisParser(BaseParser, PushNotificationsParser):
    """Parser class for connections using Hiredis"""

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        self._buffer = bytearray(socket_read_size)
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def handle_pubsub_push_response(self, response):
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection, **kwargs):
        import hiredis

        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        kwargs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": connection.encoder.encoding_errors,
            "notEnoughData": NOT_ENOUGH_DATA,
        }

        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
        self._reader = hiredis.Reader(**kwargs)
        self._next_response = NOT_ENOUGH_DATA

        try:
            self._hiredis_PushNotificationType = hiredis.PushNotification
        except AttributeError:
            # hiredis < 3.2
            self._hiredis_PushNotificationType = None

    def on_disconnect(self):
        self._sock = None
        self._reader = None
        self._next_response = NOT_ENOUGH_DATA

    def can_read(self, timeout):
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is NOT_ENOUGH_DATA:
            self._next_response = self._reader.gets()
            if self._next_response is NOT_ENOUGH_DATA:
                return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
        return True

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        sock = self._sock
        custom_timeout = timeout is not SENTINEL
        try:
            if custom_timeout:
                sock.settimeout(timeout)
            bufflen = self._sock.recv_into(self._buffer)
            if bufflen == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, bufflen)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            if custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self, disable_decoding=False, push_request=False):
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not NOT_ENOUGH_DATA:
            response = self._next_response
            self._next_response = NOT_ENOUGH_DATA
            if self._hiredis_PushNotificationType is not None and isinstance(
                response, self._hiredis_PushNotificationType
            ):
                response = self.handle_push_response(response)

                # if this is a push request, return the push response
                if push_request:
                    return response

                return self.read_response(
                    disable_decoding=disable_decoding,
                    push_request=push_request,
                )
            return response

        if disable_decoding:
            response = self._reader.gets(False)
        else:
            response = self._reader.gets()

        while response is NOT_ENOUGH_DATA:
            self.read_from_socket()
            if disable_decoding:
                response = self._reader.gets(False)
            else:
                response = self._reader.gets()
        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif self._hiredis_PushNotificationType is not None and isinstance(
            response, self._hiredis_PushNotificationType
        ):
            response = self.handle_push_response(response)
            if push_request:
                return response
            return self.read_response(
                disable_decoding=disable_decoding,
                push_request=push_request,
            )
        elif (
            isinstance(response, list)
            and response
            and isinstance(response[0], ConnectionError)
        ):
            raise response[0]
        return response
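

# Editor's sketch (not part of redis-py): the feed/gets loop used above,
# shown against hiredis directly (assumes the hiredis package is installed).
# With default Reader settings the "not enough data" marker is False, which
# is exactly the ambiguity NOT_ENOUGH_DATA avoids.
def _demo_feed_gets():
    import hiredis

    reader = hiredis.Reader()
    reader.feed(b"$5\r\nhel")  # a partial bulk-string reply
    assert reader.gets() is False  # incomplete: the parser wants more bytes
    reader.feed(b"lo\r\n")  # the rest of the reply arrives
    assert reader.gets() == b"hello"

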
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
    """Async implementation of parser class for connections using Hiredis"""

    __slots__ = ("_reader",)

    def __init__(self, socket_read_size: int):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not available.")
        super().__init__(socket_read_size=socket_read_size)
        self._reader = None
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None
        self._hiredis_PushNotificationType = None

    async def handle_pubsub_push_response(self, response):
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def on_connect(self, connection):
        import hiredis

        self._stream = connection._reader
        kwargs: _HiredisReaderArgs = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "notEnoughData": NOT_ENOUGH_DATA,
        }
        if connection.encoder.decode_responses:
            kwargs["encoding"] = connection.encoder.encoding
            kwargs["errors"] = connection.encoder.encoding_errors

        self._reader = hiredis.Reader(**kwargs)
        self._connected = True

        # hiredis < 3.2 has no PushNotification type; getattr with a default
        # covers that case without needing a try/except.
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )

    def on_disconnect(self):
        self._connected = False

    async def can_read_destructive(self):
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._reader.gets() is not NOT_ENOUGH_DATA:
            return True
        try:
            async with async_timeout(0):
                return await self.read_from_socket()
        except asyncio.TimeoutError:
            return False

    async def read_from_socket(self):
        buffer = await self._stream.read(self._read_size)
        if not buffer or not isinstance(buffer, bytes):
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
        self._reader.feed(buffer)
        # data was read from the socket and added to the buffer.
        # return True to indicate that data was read.
        return True

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, List[EncodableT]]:
        # If `on_disconnect()` has been called, prohibit any more reads
        # even if they could happen because data might be present.
        # We still allow reads in progress to finish
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

        if disable_decoding:
            response = self._reader.gets(False)
        else:
            response = self._reader.gets()

        while response is NOT_ENOUGH_DATA:
            await self.read_from_socket()
            if disable_decoding:
                response = self._reader.gets(False)
            else:
                response = self._reader.gets()

        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif self._hiredis_PushNotificationType is not None and isinstance(
            response, self._hiredis_PushNotificationType
        ):
            response = await self.handle_push_response(response)
            if not push_request:
                return await self.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                return response
        elif (
            isinstance(response, list)
            and response
            and isinstance(response[0], ConnectionError)
        ):
            raise response[0]
        return response
132
backend/venv/lib/python3.9/site-packages/redis/_parsers/resp2.py
Normal file
132
backend/venv/lib/python3.9/site-packages/redis/_parsers/resp2.py
Normal file
@@ -0,0 +1,132 @@
from typing import Any, Union

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR


class _RESP2Parser(_RESPBase):
    """RESP2 protocol implementation"""

    def read_response(self, disable_decoding=False):
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(disable_decoding=disable_decoding)
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False):
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # int value
        elif byte == b":":
            return int(response)
        # bulk response
        elif byte == b"$" and response == b"-1":
            return None
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # multi-bulk response
        elif byte == b"*" and response == b"-1":
            return None
        elif byte == b"*":
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response
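

# Editor's sketch (not part of redis-py): the RESP2 type bytes dispatched on
# above, illustrated with raw wire payloads. For reference only.
_RESP2_WIRE_EXAMPLES = {
    b"+OK\r\n": "simple string",
    b"-ERR unknown command\r\n": "error",
    b":42\r\n": "integer",
    b"$5\r\nhello\r\n": "bulk string of 5 bytes",
    b"$-1\r\n": "null bulk string",
    b"*2\r\n:1\r\n:2\r\n": "array of two integers",
}

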
class _AsyncRESP2Parser(_AsyncRESPBase):
    """Async class for the RESP2 protocol"""

    async def read_response(self, disable_decoding: bool = False):
        if not self._connected:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        self._pos = 0
        response = await self._read_response(disable_decoding=disable_decoding)
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    async def _read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        raw = await self._readline()
        response: Any
        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte == b"-":
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                self._clear()  # Successful parse
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # int value
        elif byte == b":":
            return int(response)
        # bulk response
        elif byte == b"$" and response == b"-1":
            return None
        elif byte == b"$":
            response = await self._read(int(response))
        # multi-bulk response
        elif byte == b"*" and response == b"-1":
            return None
        elif byte == b"*":
            response = [
                (await self._read_response(disable_decoding))
                for _ in range(int(response))  # noqa
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response
263
backend/venv/lib/python3.9/site-packages/redis/_parsers/resp3.py
Normal file
263
backend/venv/lib/python3.9/site-packages/redis/_parsers/resp3.py
Normal file
@@ -0,0 +1,263 @@
from logging import getLogger
from typing import Any, Union

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
    AsyncPushNotificationsParser,
    PushNotificationsParser,
    _AsyncRESPBase,
    _RESPBase,
)
from .socket import SERVER_CLOSED_CONNECTION_ERROR


class _RESP3Parser(_RESPBase, PushNotificationsParser):
    """RESP3 protocol implementation"""

    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.node_moving_push_handler_func = None
        self.maintenance_push_handler_func = None
        self.invalidation_push_handler_func = None

    def handle_pubsub_push_response(self, response):
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    def read_response(self, disable_decoding=False, push_request=False):
        pos = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        except BaseException:
            if self._buffer:
                self._buffer.rewind(pos)
            raise
        else:
            self._buffer.purge()
            return result

    def _read_response(self, disable_decoding=False, push_request=False):
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte in (b"-", b"!"):
            if byte == b"!":
                response = self._buffer.read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = self._buffer.read(int(response))
        # verbatim string response
        elif byte == b"=":
            response = self._buffer.read(int(response))[4:]
        # array response
        elif byte == b"*":
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we return sets as lists, all the time, for predictability
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(response))
            ]
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse the stream:
            # the evaluation order of the key:val expression in a dict
            # comprehension only became defined (left-to-right) in Python 3.8.
            resp_dict = {}
            for _ in range(int(response)):
                key = self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        # push response
        elif byte == b">":
            response = [
                self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
                for _ in range(int(response))
            ]
            response = self.handle_push_response(response)

            # if this is a push request, return the push response
            if push_request:
                return response

            return self._read_response(
                disable_decoding=disable_decoding,
                push_request=push_request,
            )
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)

        return response
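

# Editor's sketch (not part of redis-py): type bytes RESP3 adds on top of
# RESP2, matching the branches handled above. For reference only.
_RESP3_WIRE_EXAMPLES = {
    b"_\r\n": "null",
    b"#t\r\n": "boolean true",
    b",3.14\r\n": "double",
    b"(3492890328409238509324850943850\r\n": "big number",
    b"=15\r\ntxt:Some string\r\n": "verbatim string (4-byte prefix stripped)",
    b"%1\r\n$3\r\nkey\r\n$5\r\nvalue\r\n": "map with one entry",
    b"~2\r\n:1\r\n:2\r\n": "set of two integers",
    b">2\r\n$7\r\nmessage\r\n$5\r\nhello\r\n": "push message",
}

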
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
    def __init__(self, socket_read_size):
        super().__init__(socket_read_size)
        self.pubsub_push_handler_func = self.handle_pubsub_push_response
        self.invalidation_push_handler_func = None

    async def handle_pubsub_push_response(self, response):
        logger = getLogger("push_response")
        logger.debug("Push response: %s", response)
        return response

    async def read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ):
        if self._chunks:
            # augment parsing buffer with previously read data
            self._buffer += b"".join(self._chunks)
            self._chunks.clear()
        self._pos = 0
        response = await self._read_response(
            disable_decoding=disable_decoding, push_request=push_request
        )
        # Successfully parsing a response allows us to clear our parsing buffer
        self._clear()
        return response

    async def _read_response(
        self, disable_decoding: bool = False, push_request: bool = False
    ) -> Union[EncodableT, ResponseError, None]:
        if not self._stream or not self.encoder:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        raw = await self._readline()
        response: Any
        byte, response = raw[:1], raw[1:]

        # server returned an error
        if byte in (b"-", b"!"):
            if byte == b"!":
                response = await self._read(int(response))
            response = response.decode("utf-8", errors="replace")
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                self._clear()  # Successful parse
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == b"+":
            pass
        # null value
        elif byte == b"_":
            return None
        # int and big int values
        elif byte in (b":", b"("):
            return int(response)
        # double value
        elif byte == b",":
            return float(response)
        # bool value
        elif byte == b"#":
            return response == b"t"
        # bulk response
        elif byte == b"$":
            response = await self._read(int(response))
        # verbatim string response
        elif byte == b"=":
            response = (await self._read(int(response)))[4:]
        # array response
        elif byte == b"*":
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
        # set response
        elif byte == b"~":
            # redis can return unhashable types (like dict) in a set,
            # so we always convert to a list, to have predictable return types
            response = [
                (await self._read_response(disable_decoding=disable_decoding))
                for _ in range(int(response))
            ]
        # map response
        elif byte == b"%":
            # We cannot use a dict-comprehension to parse the stream:
            # the evaluation order of the key:val expression in a dict
            # comprehension only became defined (left-to-right) in Python 3.8.
            resp_dict = {}
            for _ in range(int(response)):
                key = await self._read_response(disable_decoding=disable_decoding)
                resp_dict[key] = await self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            response = resp_dict
        # push response
        elif byte == b">":
            response = [
                (
                    await self._read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
                )
                for _ in range(int(response))
            ]
            response = await self.handle_push_response(response)
            if not push_request:
                return await self._read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                return response
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if isinstance(response, bytes) and disable_decoding is False:
            response = self.encoder.decode(response)
        return response
@@ -0,0 +1,162 @@
import errno
import io
import socket
from io import SEEK_END
from typing import Optional, Union

from ..exceptions import ConnectionError, TimeoutError
from ..utils import SSL_AVAILABLE

NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}

if SSL_AVAILABLE:
    import ssl

    if hasattr(ssl, "SSLWantReadError"):
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2
    else:
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
SENTINEL = object()

SYM_CRLF = b"\r\n"


class SocketBuffer:
    def __init__(
        self, socket: socket.socket, socket_read_size: int, socket_timeout: float
    ):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()

    def unread_bytes(self) -> int:
        """
        Remaining unread length of buffer
        """
        pos = self._buffer.tell()
        end = self._buffer.seek(0, SEEK_END)
        self._buffer.seek(pos)
        return end - pos

    def _read_from_socket(
        self,
        length: Optional[int] = None,
        timeout: Union[float, object] = SENTINEL,
        raise_on_timeout: Optional[bool] = True,
    ) -> bool:
        sock = self._sock
        socket_read_size = self.socket_read_size
        marker = 0
        custom_timeout = timeout is not SENTINEL

        buf = self._buffer
        current_pos = buf.tell()
        buf.seek(0, SEEK_END)
        if custom_timeout:
            sock.settimeout(timeout)
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                # an empty bytestring indicates the server shut down the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                marker += data_length

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            buf.seek(current_pos)
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout: float) -> bool:
        return bool(self.unread_bytes()) or self._read_from_socket(
            timeout=timeout, raise_on_timeout=False
        )

    def read(self, length: int) -> bytes:
        length = length + 2  # make sure to read the \r\n terminator
        # BytesIO will return less than requested if the buffer is short
        data = self._buffer.read(length)
        missing = length - len(data)
        if missing:
            # fill up the buffer and read the remainder
            self._read_from_socket(missing)
            data += self._buffer.read(missing)
        return data[:-2]

    def readline(self) -> bytes:
        buf = self._buffer
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            data += buf.readline()

        return data[:-2]

    def get_pos(self) -> int:
        """
        Get current read position
        """
        return self._buffer.tell()

    def rewind(self, pos: int) -> None:
        """
        Rewind the buffer to a specific position, to re-start reading
        """
        self._buffer.seek(pos)

    def purge(self) -> None:
        """
        After a successful read, purge the read part of buffer
        """
        unread = self.unread_bytes()

        # Only if we have read all of the buffer do we truncate, to
        # reduce the amount of memory thrashing. This heuristic
        # can be changed or removed later.
        if unread > 0:
            return

        # NOTE: `unread` is always 0 past this point, so the branch below is
        # currently dead code; it is kept in case the early-return heuristic
        # above is relaxed.
        if unread > 0:
            # move unread data to the front
            view = self._buffer.getbuffer()
            view[:unread] = view[-unread:]
        self._buffer.truncate(unread)
        self._buffer.seek(0)

    def close(self) -> None:
        try:
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None
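

# Editor's sketch (not part of redis-py): the CRLF framing that read() and
# readline() above rely on, demonstrated on a plain BytesIO. The trailing
# "\r\n" is consumed but stripped from what the parser sees.
def _demo_crlf_framing():
    buf = io.BytesIO(b"+OK\r\n$5\r\nhello\r\n")
    line = buf.readline()
    assert line == b"+OK\r\n" and line[:-2] == b"+OK"
    assert buf.readline()[:-2] == b"$5"
    assert buf.read(5 + 2)[:-2] == b"hello"  # length + 2, as in read() above
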
@@ -0,0 +1,64 @@
from redis.asyncio.client import Redis, StrictRedis
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import (
    BlockingConnectionPool,
    Connection,
    ConnectionPool,
    SSLConnection,
    UnixDomainSocketConnection,
)
from redis.asyncio.sentinel import (
    Sentinel,
    SentinelConnectionPool,
    SentinelManagedConnection,
    SentinelManagedSSLConnection,
)
from redis.asyncio.utils import from_url
from redis.backoff import default_backoff
from redis.exceptions import (
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ChildDeadlockedError,
    ConnectionError,
    DataError,
    InvalidResponse,
    OutOfMemoryError,
    PubSubError,
    ReadOnlyError,
    RedisError,
    ResponseError,
    TimeoutError,
    WatchError,
)

__all__ = [
    "AuthenticationError",
    "AuthenticationWrongNumberOfArgsError",
    "BlockingConnectionPool",
    "BusyLoadingError",
    "ChildDeadlockedError",
    "Connection",
    "ConnectionError",
    "ConnectionPool",
    "DataError",
    "default_backoff",
    "from_url",
    "InvalidResponse",
    "OutOfMemoryError",
    "PubSubError",
    "ReadOnlyError",
    "Redis",
    "RedisCluster",
    "RedisError",
    "ResponseError",
    "Sentinel",
    "SentinelConnectionPool",
    "SentinelManagedConnection",
    "SentinelManagedSSLConnection",
    "SSLConnection",
    "StrictRedis",
    "TimeoutError",
    "UnixDomainSocketConnection",
    "WatchError",
]
1675
backend/venv/lib/python3.9/site-packages/redis/asyncio/client.py
Normal file
1675
backend/venv/lib/python3.9/site-packages/redis/asyncio/client.py
Normal file
File diff suppressed because it is too large
2448
backend/venv/lib/python3.9/site-packages/redis/asyncio/cluster.py
Normal file
2448
backend/venv/lib/python3.9/site-packages/redis/asyncio/cluster.py
Normal file
File diff suppressed because it is too large
1399
backend/venv/lib/python3.9/site-packages/redis/asyncio/connection.py
Normal file
1399
backend/venv/lib/python3.9/site-packages/redis/asyncio/connection.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,265 @@
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Mapping, Optional, Union

from redis.http.http_client import HttpClient, HttpResponse

DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}


class AsyncHTTPClient(ABC):
    @abstractmethod
    async def get(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        """Invoke an HTTP GET request."""
        pass

    @abstractmethod
    async def delete(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        """Invoke an HTTP DELETE request."""
        pass

    @abstractmethod
    async def post(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        """Invoke an HTTP POST request."""
        pass

    @abstractmethod
    async def put(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        """Invoke an HTTP PUT request."""
        pass

    @abstractmethod
    async def patch(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        """Invoke an HTTP PATCH request."""
        pass

    @abstractmethod
    async def request(
        self,
        method: str,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Union[bytes, str]] = None,
        timeout: Optional[float] = None,
    ) -> HttpResponse:
        """Invoke an HTTP request with the given method."""
        pass


class AsyncHTTPClientWrapper(AsyncHTTPClient):
    """
    An async wrapper around a sync HTTP client, using thread pool execution.
    """

    def __init__(self, client: HttpClient, max_workers: int = 10) -> None:
        """
        Initialize a new async HTTP client wrapper.

        Args:
            client: Sync HTTP client instance to delegate requests to.
            max_workers: Maximum number of concurrent requests.

        The wrapped client supports both regular HTTPS with server
        verification and mutual TLS authentication. For server verification,
        provide CA certificate information via ca_file, ca_path or ca_data.
        For mutual TLS, additionally provide a client certificate and key via
        client_cert_file and client_key_file.
        """
        self.client = client
        self._executor = ThreadPoolExecutor(max_workers=max_workers)

    async def get(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, self.client.get, path, params, headers, timeout, expect_json
        )

    async def delete(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            self.client.delete,
            path,
            params,
            headers,
            timeout,
            expect_json,
        )

    async def post(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            self.client.post,
            path,
            json_body,
            data,
            params,
            headers,
            timeout,
            expect_json,
        )

    async def put(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            self.client.put,
            path,
            json_body,
            data,
            params,
            headers,
            timeout,
            expect_json,
        )

    async def patch(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            self.client.patch,
            path,
            json_body,
            data,
            params,
            headers,
            timeout,
            expect_json,
        )

    async def request(
        self,
        method: str,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Union[bytes, str]] = None,
        timeout: Optional[float] = None,
    ) -> HttpResponse:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            self.client.request,
            method,
            path,
            params,
            headers,
            body,
            timeout,
        )
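

# Editor's sketch (not part of redis-py): offloading a blocking HttpClient
# call to the wrapper's thread pool. The "/v1/status" path and the `client`
# argument are hypothetical; see redis.http.http_client.HttpClient for the
# real constructor arguments.
async def _demo_async_wrapper(client: HttpClient):
    wrapper = AsyncHTTPClientWrapper(client, max_workers=4)
    # runs client.get(...) in a worker thread without blocking the event loop
    return await wrapper.get("/v1/status", params={"verbose": True}, timeout=5.0)
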
334
backend/venv/lib/python3.9/site-packages/redis/asyncio/lock.py
Normal file
334
backend/venv/lib/python3.9/site-packages/redis/asyncio/lock.py
Normal file
@@ -0,0 +1,334 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union

from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number

if TYPE_CHECKING:
    from redis.asyncio import Redis, RedisCluster

logger = logging.getLogger(__name__)


class Lock:
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """

    lua_release = None
    lua_extend = None
    lua_reacquire = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # ARGV[3] - "0" if the additional time should be added to the lock's
    #           existing ttl or "1" if the existing ttl should be replaced
    # return 1 if the lock's time was extended, otherwise 0
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end

        local newttl = ARGV[2]
        if ARGV[3] == "0" then
            newttl = ARGV[2] + expiration
        end
        redis.call('pexpire', KEYS[1], newttl)
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - milliseconds
    # return 1 if the lock's time was reacquired, otherwise 0
    LUA_REACQUIRE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('pexpire', KEYS[1], ARGV[2])
        return 1
    """
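
    # Editor's note (not redis-py code): the scripts above exist because a
    # client-side check-then-act sequence would race. A naive release such as
    #
    #     stored = await redis.get(name)      # 1. check the token...
    #     if stored == token:
    #         await redis.delete(name)        # 2. ...then delete; the lock can
    #                                         #    change owners between 1 and 2
    #
    # can delete a lock that another client has since acquired. Running the
    # GET and the DEL inside a single Lua script makes the pair atomic.
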
    def __init__(
        self,
        redis: Union["Redis", "RedisCluster"],
        name: Union[str, bytes, memoryview],
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[Number] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock in seconds.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        the number of seconds to wait.

        ``sleep`` indicates the amount of time to sleep in seconds per loop
        iteration when the lock is in blocking mode and another client is
        currently holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        self.local = threading.local() if self.thread_local else SimpleNamespace()
        self.raise_on_release_error = raise_on_release_error
        self.local.token = None
        self.register_scripts()

    def register_scripts(self):
        cls = self.__class__
        client = self.redis
        if cls.lua_release is None:
            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
        if cls.lua_reacquire is None:
            cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

    async def __aenter__(self):
        if await self.acquire():
            return self
        raise LockError("Unable to acquire lock within the time specified")

    async def __aexit__(self, exc_type, exc_value, traceback):
        try:
            await self.release()
        except LockError:
            if self.raise_on_release_error:
                raise
            logger.warning(
                "Lock was unlocked or no longer owned when exiting context manager."
            )
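
    # Editor's sketch (not redis-py code): typical context-manager usage,
    # assuming a connected redis.asyncio.Redis instance named `client`:
    #
    #     async with Lock(client, "my-lock", timeout=5, blocking_timeout=3):
    #         ...  # critical section; the lock is released on exit
    #
    # __aenter__ raises LockError if the lock cannot be acquired in time.
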
    async def acquire(
        self,
        blocking: Optional[bool] = None,
        blocking_timeout: Optional[Number] = None,
        token: Optional[Union[str, bytes]] = None,
    ):
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock.

        ``token`` specifies the token value to be used. If provided, token
        must be a bytes object or a string that can be encoded to a bytes
        object with the default encoding. If a token isn't specified, a UUID
        will be generated.
        """
        sleep = self.sleep
        if token is None:
            token = uuid.uuid1().hex.encode()
        else:
            try:
                encoder = self.redis.connection_pool.get_encoder()
            except AttributeError:
                # Cluster
                encoder = self.redis.get_encoder()
            token = encoder.encode(token)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
        while True:
            if await self.do_acquire(token):
                self.local.token = token
                return True
            if not blocking:
                return False
            next_try_at = asyncio.get_running_loop().time() + sleep
            if stop_trying_at is not None and next_try_at > stop_trying_at:
                return False
            await asyncio.sleep(sleep)

async def do_acquire(self, token: Union[str, bytes]) -> bool:
|
||||
if self.timeout:
|
||||
# convert to milliseconds
|
||||
timeout = int(self.timeout * 1000)
|
||||
else:
|
||||
timeout = None
|
||||
if await self.redis.set(self.name, token, nx=True, px=timeout):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def locked(self) -> bool:
|
||||
"""
|
||||
Returns True if this key is locked by any process, otherwise False.
|
||||
"""
|
||||
return await self.redis.get(self.name) is not None
|
||||
|
||||
async def owned(self) -> bool:
|
||||
"""
|
||||
Returns True if this key is locked by this lock, otherwise False.
|
||||
"""
|
||||
stored_token = await self.redis.get(self.name)
|
||||
# need to always compare bytes to bytes
|
||||
# TODO: this can be simplified when the context manager is finished
|
||||
if stored_token and not isinstance(stored_token, bytes):
|
||||
try:
|
||||
encoder = self.redis.connection_pool.get_encoder()
|
||||
except AttributeError:
|
||||
# Cluster
|
||||
encoder = self.redis.get_encoder()
|
||||
stored_token = encoder.encode(stored_token)
|
||||
return self.local.token is not None and stored_token == self.local.token
|
||||
|
||||
def release(self) -> Awaitable[None]:
|
||||
"""Releases the already acquired lock"""
|
||||
expected_token = self.local.token
|
||||
if expected_token is None:
|
||||
raise LockError(
|
||||
"Cannot release a lock that's not owned or is already unlocked.",
|
||||
lock_name=self.name,
|
||||
)
|
||||
self.local.token = None
|
||||
return self.do_release(expected_token)
|
||||
|
||||
async def do_release(self, expected_token: bytes) -> None:
|
||||
if not bool(
|
||||
await self.lua_release(
|
||||
keys=[self.name], args=[expected_token], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
|
||||
|
||||
def extend(
|
||||
self, additional_time: Number, replace_ttl: bool = False
|
||||
) -> Awaitable[bool]:
|
||||
"""
|
||||
Adds more time to an already acquired lock.
|
||||
|
||||
``additional_time`` can be specified as an integer or a float, both
|
||||
representing the number of seconds to add.
|
||||
|
||||
``replace_ttl`` if False (the default), add `additional_time` to
|
||||
the lock's existing ttl. If True, replace the lock's ttl with
|
||||
`additional_time`.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot extend an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot extend a lock with no timeout")
|
||||
return self.do_extend(additional_time, replace_ttl)
|
||||
|
||||
async def do_extend(self, additional_time, replace_ttl) -> bool:
|
||||
additional_time = int(additional_time * 1000)
|
||||
if not bool(
|
||||
await self.lua_extend(
|
||||
keys=[self.name],
|
||||
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
|
||||
client=self.redis,
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
|
||||
return True
|
||||
|
||||
def reacquire(self) -> Awaitable[bool]:
|
||||
"""
|
||||
Resets a TTL of an already acquired lock back to a timeout value.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot reacquire an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot reacquire a lock with no timeout")
|
||||
return self.do_reacquire()
|
||||
|
||||
async def do_reacquire(self) -> bool:
|
||||
timeout = int(self.timeout * 1000)
|
||||
if not bool(
|
||||
await self.lua_reacquire(
|
||||
keys=[self.name], args=[self.local.token, timeout], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
|
||||
return True
|
||||
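
The Lock API above splits each operation into a thin sync wrapper (release, extend, reacquire) that validates local state and an awaitable do_* method that talks to Redis. A minimal usage sketch, assuming a reachable Redis server on localhost:6379 (host and key name are illustrative):

import asyncio

from redis.asyncio import Redis


async def main():
    client = Redis()  # assumes a local Redis server
    lock = client.lock("resource-lock", timeout=10, blocking_timeout=5)
    if await lock.acquire():
        try:
            ...  # critical section
            await lock.extend(5)  # add 5 more seconds to the existing TTL
        finally:
            await lock.release()
    await client.aclose()


asyncio.run(main())
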
@@ -0,0 +1,530 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union

from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.background import BackgroundScheduler
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import experimental

logger = logging.getLogger(__name__)


@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Client that operates on multiple logical Redis databases.
    Should be used in Active-Active database setups.
    """

    def __init__(self, config: MultiDbConfig):
        self._databases = config.databases()
        self._health_checks = config.default_health_checks()

        if config.health_checks is not None:
            self._health_checks.extend(config.health_checks)

        self._health_check_interval = config.health_check_interval
        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
            config.health_check_probes, config.health_check_delay
        )
        self._failure_detectors = config.default_failure_detectors()

        if config.failure_detectors is not None:
            self._failure_detectors.extend(config.failure_detectors)

        self._failover_strategy = (
            config.default_failover_strategy()
            if config.failover_strategy is None
            else config.failover_strategy
        )
        self._failover_strategy.set_databases(self._databases)
        self._auto_fallback_interval = config.auto_fallback_interval
        self._event_dispatcher = config.event_dispatcher
        self._command_retry = config.command_retry
        self._command_retry.update_supported_errors([ConnectionRefusedError])
        self.command_executor = DefaultCommandExecutor(
            failure_detectors=self._failure_detectors,
            databases=self._databases,
            command_retry=self._command_retry,
            failover_strategy=self._failover_strategy,
            failover_attempts=config.failover_attempts,
            failover_delay=config.failover_delay,
            event_dispatcher=self._event_dispatcher,
            auto_fallback_interval=self._auto_fallback_interval,
        )
        self.initialized = False
        self._hc_lock = asyncio.Lock()
        self._bg_scheduler = BackgroundScheduler()
        self._config = config
        self._recurring_hc_task = None
        self._hc_tasks = []
        self._half_open_state_task = None

    async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
        if not self.initialized:
            await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if self._recurring_hc_task:
            self._recurring_hc_task.cancel()
        if self._half_open_state_task:
            self._half_open_state_task.cancel()
        for hc_task in self._hc_tasks:
            hc_task.cancel()

    async def initialize(self):
        """
        Perform initialization of databases to define their initial state.
        """

        async def raise_exception_on_failed_hc(error):
            raise error

        # Initial health check of all databases to define the initial state
        await self._check_databases_health(on_error=raise_exception_on_failed_hc)

        # Start recurring health checks in the background.
        self._recurring_hc_task = asyncio.create_task(
            self._bg_scheduler.run_recurring_async(
                self._health_check_interval,
                self._check_databases_health,
            )
        )

        is_active_db_found = False

        for database, weight in self._databases:
            # Set an on-state-changed callback for each circuit.
            database.circuit.on_state_changed(self._on_circuit_state_change_callback)

            # Set states according to weights and circuit state
            if database.circuit.state == CBState.CLOSED and not is_active_db_found:
                await self.command_executor.set_active_database(database)
                is_active_db_found = True

        if not is_active_db_found:
            raise NoValidDatabaseException(
                "Initial connection failed - no active database found"
            )

        self.initialized = True

    def get_databases(self) -> Databases:
        """
        Returns a sorted (by weight) list of all databases.
        """
        return self._databases

    async def set_active_database(self, database: AsyncDatabase) -> None:
        """
        Promote one of the existing databases to become the active one.
        """
        exists = None

        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break

        if not exists:
            raise ValueError("Given database is not a member of database list")

        await self._check_db_health(database)

        if database.circuit.state == CBState.CLOSED:
            highest_weighted_db, _ = self._databases.get_top_n(1)[0]
            await self.command_executor.set_active_database(database)
            return

        raise NoValidDatabaseException(
            "Cannot set active database, database is unhealthy"
        )

    async def add_database(self, database: AsyncDatabase):
        """
        Adds a new database to the database list.
        """
        for existing_db, _ in self._databases:
            if existing_db == database:
                raise ValueError("Given database already exists")

        await self._check_db_health(database)

        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.add(database, database.weight)
        await self._change_active_database(database, highest_weighted_db)

    async def _change_active_database(
        self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
    ):
        if (
            new_database.weight > highest_weight_database.weight
            and new_database.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(new_database)

    async def remove_database(self, database: AsyncDatabase):
        """
        Removes a database from the database list.
        """
        weight = self._databases.remove(database)
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]

        if (
            highest_weight <= weight
            and highest_weighted_db.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(highest_weighted_db)

    async def update_database_weight(self, database: AsyncDatabase, weight: float):
        """
        Updates the weight of a database in the database list.
        """
        exists = None

        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break

        if not exists:
            raise ValueError("Given database is not a member of database list")

        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.update_weight(database, weight)
        database.weight = weight
        await self._change_active_database(database, highest_weighted_db)

    def add_failure_detector(self, failure_detector: AsyncFailureDetector):
        """
        Adds a new failure detector to the client.
        """
        self._failure_detectors.append(failure_detector)

    async def add_health_check(self, healthcheck: HealthCheck):
        """
        Adds a new health check to the client.
        """
        async with self._hc_lock:
            self._health_checks.append(healthcheck)

    async def execute_command(self, *args, **options):
        """
        Executes a single command and returns its result.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_command(*args, **options)

    def pipeline(self):
        """
        Enters pipeline mode of the client.
        """
        return Pipeline(self)

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Executes a callable as a transaction.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_transaction(
            func,
            *watches,
            shard_hint=shard_hint,
            value_from_callable=value_from_callable,
            watch_delay=watch_delay,
        )

    async def pubsub(self, **kwargs):
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        if not self.initialized:
            await self.initialize()

        return PubSub(self, **kwargs)

    async def _check_databases_health(
        self,
        on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
    ):
        """
        Runs health checks against all databases; also scheduled as a
        recurring background task.
        """
        try:
            self._hc_tasks = [
                asyncio.create_task(self._check_db_health(database))
                for database, _ in self._databases
            ]
            results = await asyncio.wait_for(
                asyncio.gather(
                    *self._hc_tasks,
                    return_exceptions=True,
                ),
                timeout=self._health_check_interval,
            )
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(
                "Health check execution exceeds health_check_interval"
            )

        for result in results:
            if isinstance(result, UnhealthyDatabaseException):
                unhealthy_db = result.database
                unhealthy_db.circuit.state = CBState.OPEN

                logger.exception(
                    "Health check failed due to an exception",
                    exc_info=result.original_exception,
                )

                if on_error:
                    # on_error is a coroutine function per the annotation,
                    # so its result must be awaited.
                    await on_error(result.original_exception)

    async def _check_db_health(self, database: AsyncDatabase) -> bool:
        """
        Runs health checks on the given database until the first failure.
        """
        # The health check will set up the circuit state.
        is_healthy = await self._health_check_policy.execute(
            self._health_checks, database
        )

        if not is_healthy:
            if database.circuit.state != CBState.OPEN:
                database.circuit.state = CBState.OPEN
            return is_healthy
        elif is_healthy and database.circuit.state != CBState.CLOSED:
            database.circuit.state = CBState.CLOSED

        return is_healthy

    def _on_circuit_state_change_callback(
        self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
    ):
        loop = asyncio.get_running_loop()

        if new_state == CBState.HALF_OPEN:
            self._half_open_state_task = asyncio.create_task(
                self._check_db_health(circuit.database)
            )
            return

        if old_state == CBState.CLOSED and new_state == CBState.OPEN:
            loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

    async def aclose(self):
        if self.command_executor.active_database:
            await self.command_executor.active_database.client.aclose()


def _half_open_circuit(circuit: CircuitBreaker):
    circuit.state = CBState.HALF_OPEN

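MultiDBClient pulls everything it needs from a single MultiDbConfig: databases, health checks, failure detectors, and the failover strategy. A minimal construction sketch, assuming the client lives at redis.asyncio.multidb.client and using two hypothetical endpoints:

import asyncio

from redis.asyncio.multidb.client import MultiDBClient  # assumed module path
from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig


async def main():
    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://primary:6379", weight=1.0),
            DatabaseConfig(from_url="redis://fallback:6379", weight=0.5),
        ]
    )
    # __aenter__ runs the initial health checks and picks the active database.
    async with MultiDBClient(config) as client:
        await client.set("key", "value")
        print(await client.get("key"))


asyncio.run(main())
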
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Pipeline implementation for multiple logical Redis databases.
    """

    def __init__(self, client: MultiDBClient):
        self._command_stack = []
        self._client = client

    async def __aenter__(self: "Pipeline") -> "Pipeline":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.reset()
        await self._client.__aexit__(exc_type, exc_value, traceback)

    def __await__(self):
        return self._async_self().__await__()

    async def _async_self(self):
        return self

    def __len__(self) -> int:
        return len(self._command_stack)

    def __bool__(self) -> bool:
        """Pipeline instances should always evaluate to True"""
        return True

    async def reset(self) -> None:
        self._command_stack = []

    async def aclose(self) -> None:
        """Close the pipeline"""
        await self.reset()

    def pipeline_execute_command(self, *args, **options) -> "Pipeline":
        """
        Stage a command to be executed when execute() is next called.

        Returns the current Pipeline object back so commands can be
        chained together, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self._command_stack.append((args, options))
        return self

    def execute_command(self, *args, **kwargs):
        """Adds a command to the stack"""
        return self.pipeline_execute_command(*args, **kwargs)

    async def execute(self) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        if not self._client.initialized:
            await self._client.initialize()

        try:
            return await self._client.command_executor.execute_pipeline(
                tuple(self._command_stack)
            )
        finally:
            await self.reset()

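The multi-database Pipeline only stages (args, options) tuples locally; nothing is sent until execute() hands the stack to the command executor, which replays it on the active database. A short sketch (client is assumed to be an initialized MultiDBClient):

async def queue_and_run(client: MultiDBClient):
    pipe = client.pipeline()
    # Commands are buffered locally and chainable; execute() sends them
    # to the currently active database and resets the stack.
    pipe.set("counter", 0).incr("counter").get("counter")
    return await pipe.execute()
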
class PubSub:
    """
    PubSub object for the multi-database client.
    """

    def __init__(self, client: MultiDBClient, **kwargs):
        """Initialize the PubSub object for a multi-database client.

        Args:
            client: MultiDBClient instance to use for pub/sub operations
            **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
        """
        self._client = client
        self._client.command_executor.pubsub(**kwargs)

    async def __aenter__(self) -> "PubSub":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.aclose()

    async def aclose(self):
        return await self._client.command_executor.execute_pubsub_method("aclose")

    @property
    def subscribed(self) -> bool:
        return self._client.command_executor.active_pubsub.subscribed

    async def execute_command(self, *args: EncodableT):
        return await self._client.command_executor.execute_pubsub_method(
            "execute_command", *args
        )

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        return await self._client.command_executor.execute_pubsub_method(
            "psubscribe", *args, **kwargs
        )

    async def punsubscribe(self, *args: ChannelT):
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        return await self._client.command_executor.execute_pubsub_method(
            "punsubscribe", *args
        )

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        return await self._client.command_executor.execute_pubsub_method(
            "subscribe", *args, **kwargs
        )

    async def unsubscribe(self, *args):
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels.
        """
        return await self._client.command_executor.execute_pubsub_method(
            "unsubscribe", *args
        )

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        return await self._client.command_executor.execute_pubsub_method(
            "get_message",
            ignore_subscribe_messages=ignore_subscribe_messages,
            timeout=timeout,
        )

    async def run(
        self,
        *,
        exception_handler=None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        return await self._client.command_executor.execute_pubsub_run(
            sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
        )

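Since handlers registered through subscribe() are invoked by run(), a typical pattern is to launch run() as a background task and cancel it on shutdown, as the docstring above shows. A hedged sketch (channel name and handler are illustrative):

import asyncio


async def listen_for_events(client: MultiDBClient):
    def handler(message):
        print(message["channel"], message["data"])

    pubsub = await client.pubsub()
    await pubsub.subscribe(**{"events": handler})
    task = asyncio.create_task(pubsub.run(poll_timeout=0.5))
    await asyncio.sleep(60)  # ... application work ...
    task.cancel()
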
@@ -0,0 +1,339 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import Any, Awaitable, Callable, List, Optional, Union

from redis.asyncio import RedisCluster
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
from redis.asyncio.multidb.event import (
    AsyncActiveDatabaseChanged,
    CloseConnectionOnActiveDatabaseChanged,
    RegisterCommandFailure,
    ResubscribeOnActiveDatabaseChanged,
)
from redis.asyncio.multidb.failover import (
    DEFAULT_FAILOVER_ATTEMPTS,
    DEFAULT_FAILOVER_DELAY,
    AsyncFailoverStrategy,
    DefaultFailoverStrategyExecutor,
    FailoverStrategyExecutor,
)
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.retry import Retry
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT


class AsyncCommandExecutor(CommandExecutor):
    @property
    @abstractmethod
    def databases(self) -> Databases:
        """Returns a list of databases."""
        pass

    @property
    @abstractmethod
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        """Returns a list of failure detectors."""
        pass

    @abstractmethod
    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        """Adds a new failure detector to the list of failure detectors."""
        pass

    @property
    @abstractmethod
    def active_database(self) -> Optional[AsyncDatabase]:
        """Returns the currently active database."""
        pass

    @abstractmethod
    async def set_active_database(self, database: AsyncDatabase) -> None:
        """Sets the currently active database."""
        pass

    @property
    @abstractmethod
    def active_pubsub(self) -> Optional[PubSub]:
        """Returns the currently active pubsub."""
        pass

    @active_pubsub.setter
    @abstractmethod
    def active_pubsub(self, pubsub: PubSub) -> None:
        """Sets the currently active pubsub."""
        pass

    @property
    @abstractmethod
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        """Returns the failover strategy executor."""
        pass

    @property
    @abstractmethod
    def command_retry(self) -> Retry:
        """Returns the command retry object."""
        pass

    @abstractmethod
    async def pubsub(self, **kwargs):
        """Initializes a PubSub object on the currently active database."""
        pass

    @abstractmethod
    async def execute_command(self, *args, **options):
        """Executes a command and returns the result."""
        pass

    @abstractmethod
    async def execute_pipeline(self, command_stack: tuple):
        """Executes a stack of commands in a pipeline."""
        pass

    @abstractmethod
    async def execute_transaction(
        self, transaction: Callable[[Pipeline], None], *watches, **options
    ):
        """Executes a transaction block wrapped in a callback."""
        pass

    @abstractmethod
    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Executes a given method on the active pub/sub."""
        pass

    @abstractmethod
    async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
        """Runs the pub/sub message loop."""
        pass


class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
    def __init__(
        self,
        failure_detectors: List[AsyncFailureDetector],
        databases: Databases,
        command_retry: Retry,
        failover_strategy: AsyncFailoverStrategy,
        event_dispatcher: EventDispatcherInterface,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
        auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
    ):
        """
        Initialize the DefaultCommandExecutor instance.

        Args:
            failure_detectors: List of failure detector instances to monitor database health
            databases: Collection of available databases to execute commands on
            command_retry: Retry policy for failed command execution
            failover_strategy: Strategy for handling database failover
            event_dispatcher: Interface for dispatching events
            failover_attempts: Number of failover attempts
            failover_delay: Delay between failover attempts
            auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
        """
        super().__init__(auto_fallback_interval)

        for fd in failure_detectors:
            fd.set_command_executor(command_executor=self)

        self._databases = databases
        self._failure_detectors = failure_detectors
        self._command_retry = command_retry
        self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
            failover_strategy, failover_attempts, failover_delay
        )
        self._event_dispatcher = event_dispatcher
        self._active_database: Optional[Database] = None
        self._active_pubsub: Optional[PubSub] = None
        self._active_pubsub_kwargs = {}
        self._setup_event_dispatcher()
        self._schedule_next_fallback()

    @property
    def databases(self) -> Databases:
        return self._databases

    @property
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        return self._failure_detectors

    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        self._failure_detectors.append(failure_detector)

    @property
    def active_database(self) -> Optional[AsyncDatabase]:
        return self._active_database

    async def set_active_database(self, database: AsyncDatabase) -> None:
        old_active = self._active_database
        self._active_database = database

        if old_active is not None and old_active is not database:
            await self._event_dispatcher.dispatch_async(
                AsyncActiveDatabaseChanged(
                    old_active,
                    self._active_database,
                    self,
                    **self._active_pubsub_kwargs,
                )
            )

    @property
    def active_pubsub(self) -> Optional[PubSub]:
        return self._active_pubsub

    @active_pubsub.setter
    def active_pubsub(self, pubsub: PubSub) -> None:
        self._active_pubsub = pubsub

    @property
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        return self._failover_strategy_executor

    @property
    def command_retry(self) -> Retry:
        return self._command_retry

    def pubsub(self, **kwargs):
        if self._active_pubsub is None:
            if isinstance(self._active_database.client, RedisCluster):
                raise ValueError("PubSub is not supported for RedisCluster")

            self._active_pubsub = self._active_database.client.pubsub(**kwargs)
            self._active_pubsub_kwargs = kwargs

    async def execute_command(self, *args, **options):
        async def callback():
            response = await self._active_database.client.execute_command(
                *args, **options
            )
            await self._register_command_execution(args)
            return response

        return await self._execute_with_failure_detection(callback, args)

    async def execute_pipeline(self, command_stack: tuple):
        async def callback():
            async with self._active_database.client.pipeline() as pipe:
                for command, options in command_stack:
                    pipe.execute_command(*command, **options)

                response = await pipe.execute()
                await self._register_command_execution(command_stack)
                return response

        return await self._execute_with_failure_detection(callback, command_stack)

    async def execute_transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        async def callback():
            response = await self._active_database.client.transaction(
                func,
                *watches,
                shard_hint=shard_hint,
                value_from_callable=value_from_callable,
                watch_delay=watch_delay,
            )
            await self._register_command_execution(())
            return response

        return await self._execute_with_failure_detection(callback)

    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        async def callback():
            method = getattr(self.active_pubsub, method_name)
            if iscoroutinefunction(method):
                response = await method(*args, **kwargs)
            else:
                response = method(*args, **kwargs)

            await self._register_command_execution(args)
            return response

        # Pass args as a single tuple; _execute_with_failure_detection expects
        # a cmds tuple, not unpacked positional arguments.
        return await self._execute_with_failure_detection(callback, args)

    async def execute_pubsub_run(
        self, sleep_time: float, exception_handler=None, pubsub=None
    ) -> Any:
        async def callback():
            return await self._active_pubsub.run(
                poll_timeout=sleep_time,
                exception_handler=exception_handler,
                pubsub=pubsub,
            )

        return await self._execute_with_failure_detection(callback)

    async def _execute_with_failure_detection(
        self, callback: Callable, cmds: tuple = ()
    ):
        """
        Executes a command execution callback with failure detection.
        """

        async def wrapper():
            # On each retry we need to check the active database, as it might have changed.
            await self._check_active_database()
            return await callback()

        return await self._command_retry.call_with_retry(
            lambda: wrapper(),
            lambda error: self._on_command_fail(error, *cmds),
        )

    async def _check_active_database(self):
        """
        Checks if the active database needs to be updated.
        """
        if (
            self._active_database is None
            or self._active_database.circuit.state != CBState.CLOSED
            or (
                self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
                and self._next_fallback_attempt <= datetime.now()
            )
        ):
            await self.set_active_database(
                await self._failover_strategy_executor.execute()
            )
            self._schedule_next_fallback()

    async def _on_command_fail(self, error, *args):
        await self._event_dispatcher.dispatch_async(
            AsyncOnCommandsFailEvent(args, error)
        )

    async def _register_command_execution(self, cmd: tuple):
        for detector in self._failure_detectors:
            await detector.register_command_execution(cmd)

    def _setup_event_dispatcher(self):
        """
        Registers the necessary listeners.
        """
        failure_listener = RegisterCommandFailure(self._failure_detectors)
        resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
        close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
        self._event_dispatcher.register_listeners(
            {
                AsyncOnCommandsFailEvent: [failure_listener],
                AsyncActiveDatabaseChanged: [
                    close_connection_listener,
                    resubscribe_listener,
                ],
            }
        )
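
_execute_with_failure_detection is the heart of the executor: every attempt re-resolves the active database before invoking the callback, and every failure flows to the failure detectors through the retry object's failure hook. The same call_with_retry pattern can be sketched in isolation (the helper and its arguments below are illustrative, not part of the library):

from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff


async def run_with_retry(fetch_target, do_work):
    retry = Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3)

    async def attempt():
        target = await fetch_target()  # re-resolve the target before every attempt
        return await do_work(target)

    async def on_failure(error):
        print("attempt failed:", error)  # a real hook would notify the detectors

    return await retry.call_with_retry(attempt, on_failure)
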
@@ -0,0 +1,210 @@
from dataclasses import dataclass, field
from typing import List, Optional, Type, Union

import pybreaker

from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio.multidb.database import Database, Databases
from redis.asyncio.multidb.failover import (
    DEFAULT_FAILOVER_ATTEMPTS,
    DEFAULT_FAILOVER_DELAY,
    AsyncFailoverStrategy,
    WeightBasedFailoverStrategy,
)
from redis.asyncio.multidb.failure_detector import (
    AsyncFailureDetector,
    FailureDetectorAsyncWrapper,
)
from redis.asyncio.multidb.healthcheck import (
    DEFAULT_HEALTH_CHECK_DELAY,
    DEFAULT_HEALTH_CHECK_INTERVAL,
    DEFAULT_HEALTH_CHECK_POLICY,
    DEFAULT_HEALTH_CHECK_PROBES,
    HealthCheck,
    HealthCheckPolicies,
    PingHealthCheck,
)
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
    DEFAULT_GRACE_PERIOD,
    CircuitBreaker,
    PBCircuitBreakerAdapter,
)
from redis.multidb.failure_detector import (
    DEFAULT_FAILURE_RATE_THRESHOLD,
    DEFAULT_FAILURES_DETECTION_WINDOW,
    DEFAULT_MIN_NUM_FAILURES,
    CommandFailureDetector,
)

DEFAULT_AUTO_FALLBACK_INTERVAL = 120


def default_event_dispatcher() -> EventDispatcherInterface:
    return EventDispatcher()


@dataclass
class DatabaseConfig:
    """
    Dataclass representing the configuration for a database connection.

    This class is used to store configuration settings for a database connection,
    including client options, connection sourcing details, circuit breaker settings,
    and cluster-specific properties. It provides a structure for defining these
    attributes and allows for the creation of customized configurations for various
    database setups.

    Attributes:
        weight (float): Weight of the database used to determine the active one.
        client_kwargs (dict): Additional parameters for the database client connection.
        from_url (Optional[str]): Redis URL to connect to the database.
        from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
        circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
        grace_period (float): Grace period after which to check whether the circuit could be closed again.
        health_check_url (Optional[str]): URL for health checks. The cluster FQDN is typically used
            on public Redis Enterprise endpoints.

    Methods:
        default_circuit_breaker:
            Generates and returns a default CircuitBreaker instance adapted for use.
    """

    weight: float = 1.0
    client_kwargs: dict = field(default_factory=dict)
    from_url: Optional[str] = None
    from_pool: Optional[ConnectionPool] = None
    circuit: Optional[CircuitBreaker] = None
    grace_period: float = DEFAULT_GRACE_PERIOD
    health_check_url: Optional[str] = None

    def default_circuit_breaker(self) -> CircuitBreaker:
        circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
        return PBCircuitBreakerAdapter(circuit_breaker)


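DatabaseConfig supports three connection sources, checked in order by MultiDbConfig.databases(): from_url, from_pool, or plain client_kwargs. A brief sketch with hypothetical endpoints:

from redis.asyncio import ConnectionPool

# Connect from a URL; extra client_kwargs are forwarded to the client class.
url_config = DatabaseConfig(
    weight=1.0,
    from_url="redis://db-a.example.com:6379",
    client_kwargs={"socket_timeout": 5},
)

# Or reuse a pre-configured connection pool.
pool_config = DatabaseConfig(
    weight=0.7,
    from_pool=ConnectionPool(host="db-b.example.com", port=6379),
    grace_period=10.0,
)
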
@dataclass
class MultiDbConfig:
    """
    Configuration class for managing multiple database connections in a resilient and fail-safe manner.

    Attributes:
        databases_config: A list of database configurations.
        client_class: The client class used to manage database connections.
        command_retry: Retry strategy for executing database commands.
        failure_detectors: Optional list of additional failure detectors for monitoring database failures.
        min_num_failures: Minimal count of failures required for failover.
        failure_rate_threshold: Percentage of failures required for failover.
        failures_detection_window: Time interval for tracking database failures.
        health_checks: Optional list of additional health checks performed on databases.
        health_check_interval: Time interval for executing health checks.
        health_check_probes: Number of attempts to evaluate the health of a database.
        health_check_delay: Delay between health check attempts.
        failover_strategy: Optional strategy for handling database failover scenarios.
        failover_attempts: Number of retries allowed for failover operations.
        failover_delay: Delay between failover attempts.
        auto_fallback_interval: Time interval to trigger automatic fallback.
        event_dispatcher: Interface for dispatching events related to database operations.

    Methods:
        databases:
            Retrieves a collection of database clients managed by weighted configurations.
            Initializes database clients based on the provided configuration and removes
            redundant retry objects from lower-level clients so they rely on the global
            retry logic.

        default_failure_detectors:
            Returns the default list of failure detectors used to monitor database failures.

        default_health_checks:
            Returns the default list of health checks used to monitor database health
            with specific retry and backoff strategies.

        default_failover_strategy:
            Provides the default failover strategy used for handling failover scenarios
            with defined retry and backoff configurations.
    """

    databases_config: List[DatabaseConfig]
    client_class: Type[Union[Redis, RedisCluster]] = Redis
    command_retry: Retry = Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
    )
    failure_detectors: Optional[List[AsyncFailureDetector]] = None
    min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
    failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
    failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
    health_checks: Optional[List[HealthCheck]] = None
    health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
    health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
    health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
    health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
    failover_strategy: Optional[AsyncFailoverStrategy] = None
    failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
    failover_delay: float = DEFAULT_FAILOVER_DELAY
    auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
    event_dispatcher: EventDispatcherInterface = field(
        default_factory=default_event_dispatcher
    )

    def databases(self) -> Databases:
        databases = WeightedList()

        for database_config in self.databases_config:
            # The retry object is not used in the lower-level clients, so we can safely remove it.
            # We rely on command_retry for global retries.
            database_config.client_kwargs.update(
                {"retry": Retry(retries=0, backoff=NoBackoff())}
            )

            if database_config.from_url:
                client = self.client_class.from_url(
                    database_config.from_url, **database_config.client_kwargs
                )
            elif database_config.from_pool:
                database_config.from_pool.set_retry(
                    Retry(retries=0, backoff=NoBackoff())
                )
                client = self.client_class.from_pool(
                    connection_pool=database_config.from_pool
                )
            else:
                client = self.client_class(**database_config.client_kwargs)

            circuit = (
                database_config.default_circuit_breaker()
                if database_config.circuit is None
                else database_config.circuit
            )
            databases.add(
                Database(
                    client=client,
                    circuit=circuit,
                    weight=database_config.weight,
                    health_check_url=database_config.health_check_url,
                ),
                database_config.weight,
            )

        return databases

    def default_failure_detectors(self) -> List[AsyncFailureDetector]:
        return [
            FailureDetectorAsyncWrapper(
                CommandFailureDetector(
                    min_num_failures=self.min_num_failures,
                    failure_rate_threshold=self.failure_rate_threshold,
                    failure_detection_window=self.failures_detection_window,
                )
            ),
        ]

    def default_health_checks(self) -> List[HealthCheck]:
        return [
            PingHealthCheck(),
        ]

    def default_failover_strategy(self) -> AsyncFailoverStrategy:
        return WeightBasedFailoverStrategy()
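
All failover and health check knobs live on MultiDbConfig, so tuning the probing behavior is just a matter of overriding the defaults. A hedged example (endpoint and values are illustrative):

from redis.asyncio.multidb.healthcheck import HealthCheckPolicies

config = MultiDbConfig(
    databases_config=[DatabaseConfig(from_url="redis://db-a:6379", weight=1.0)],
    health_check_probes=5,                                  # 5 probes per check
    health_check_policy=HealthCheckPolicies.HEALTHY_MAJORITY,
    auto_fallback_interval=60,                              # try falling back every minute
)
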
@@ -0,0 +1,69 @@
from abc import abstractmethod
from typing import Optional, Union

from redis.asyncio import Redis, RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import AbstractDatabase, BaseDatabase
from redis.typing import Number


class AsyncDatabase(AbstractDatabase):
    """Database with an underlying asynchronous redis client."""

    @property
    @abstractmethod
    def client(self) -> Union[Redis, RedisCluster]:
        """The underlying redis client."""
        pass

    @client.setter
    @abstractmethod
    def client(self, client: Union[Redis, RedisCluster]):
        """Set the underlying redis client."""
        pass

    @property
    @abstractmethod
    def circuit(self) -> CircuitBreaker:
        """Circuit breaker for the current database."""
        pass

    @circuit.setter
    @abstractmethod
    def circuit(self, circuit: CircuitBreaker):
        """Set the circuit breaker for the current database."""
        pass


Databases = WeightedList[tuple[AsyncDatabase, Number]]


class Database(BaseDatabase, AsyncDatabase):
    def __init__(
        self,
        client: Union[Redis, RedisCluster],
        circuit: CircuitBreaker,
        weight: float,
        health_check_url: Optional[str] = None,
    ):
        self._client = client
        self._cb = circuit
        self._cb.database = self
        super().__init__(weight, health_check_url)

    @property
    def client(self) -> Union[Redis, RedisCluster]:
        return self._client

    @client.setter
    def client(self, client: Union[Redis, RedisCluster]):
        self._client = client

    @property
    def circuit(self) -> CircuitBreaker:
        return self._cb

    @circuit.setter
    def circuit(self, circuit: CircuitBreaker):
        self._cb = circuit
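
Note that Database.__init__ wires the circuit breaker back to the database (self._cb.database = self), which is what lets the client's circuit state change callback re-check the right database. A construction sketch, assuming a local Redis server:

from redis.asyncio import Redis

cfg = DatabaseConfig(weight=1.0)  # hypothetical config
db = Database(
    client=Redis(),                       # assumes a local Redis server
    circuit=cfg.default_circuit_breaker(),
    weight=cfg.weight,
)
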
@@ -0,0 +1,84 @@
from typing import List

from redis.asyncio import Redis
from redis.asyncio.multidb.database import AsyncDatabase
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent


class AsyncActiveDatabaseChanged:
    """
    Event fired when the async client's active database has been changed.
    """

    def __init__(
        self,
        old_database: AsyncDatabase,
        new_database: AsyncDatabase,
        command_executor,
        **kwargs,
    ):
        self._old_database = old_database
        self._new_database = new_database
        self._command_executor = command_executor
        self._kwargs = kwargs

    @property
    def old_database(self) -> AsyncDatabase:
        return self._old_database

    @property
    def new_database(self) -> AsyncDatabase:
        return self._new_database

    @property
    def command_executor(self):
        return self._command_executor

    @property
    def kwargs(self):
        return self._kwargs


class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Re-subscribes the currently active pub/sub to the new active database.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        old_pubsub = event.command_executor.active_pubsub

        if old_pubsub is not None:
            # Re-assign old channels and patterns so they will be automatically subscribed on connection.
            new_pubsub = event.new_database.client.pubsub(**event.kwargs)
            new_pubsub.channels = old_pubsub.channels
            new_pubsub.patterns = old_pubsub.patterns
            await new_pubsub.on_connect(None)
            event.command_executor.active_pubsub = new_pubsub
            await old_pubsub.aclose()


class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Closes the connection to the old active database.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        await event.old_database.client.aclose()

        if isinstance(event.old_database.client, Redis):
            await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
            await event.old_database.client.connection_pool.disconnect()


class RegisterCommandFailure(AsyncEventListenerInterface):
    """
    Event listener that registers command failures and passes them to the failure detectors.
    """

    def __init__(self, failure_detectors: List[AsyncFailureDetector]):
        self._failure_detectors = failure_detectors

    async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
        for failure_detector in self._failure_detectors:
            await failure_detector.register_failure(event.exception, event.commands)
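
Listeners follow a small contract: implement AsyncEventListenerInterface.listen for the event type and register it with the event dispatcher. A hypothetical listener that merely logs failovers:

class LogActiveDatabaseChanged(AsyncEventListenerInterface):
    # Hypothetical listener: reports every active database switch.
    async def listen(self, event: AsyncActiveDatabaseChanged) -> None:
        print(f"active database changed: {event.old_database} -> {event.new_database}")

# Registration mirrors _setup_event_dispatcher in the command executor:
# dispatcher.register_listeners({AsyncActiveDatabaseChanged: [LogActiveDatabaseChanged()]})
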
@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod

from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import (
    NoValidDatabaseException,
    TemporaryUnavailableException,
)

DEFAULT_FAILOVER_ATTEMPTS = 10
DEFAULT_FAILOVER_DELAY = 12


class AsyncFailoverStrategy(ABC):
    @abstractmethod
    async def database(self) -> AsyncDatabase:
        """Select the database according to the strategy."""
        pass

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Set the databases the strategy operates on."""
        pass


class FailoverStrategyExecutor(ABC):
    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """The number of failover attempts."""
        pass

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """The delay between failover attempts."""
        pass

    @property
    @abstractmethod
    def strategy(self) -> AsyncFailoverStrategy:
        """The strategy to execute."""
        pass

    @abstractmethod
    async def execute(self) -> AsyncDatabase:
        """Execute the failover strategy."""
        pass


class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
    """
    Failover strategy based on database weights.
    """

    def __init__(self):
        self._databases = WeightedList()

    async def database(self) -> AsyncDatabase:
        for database, _ in self._databases:
            if database.circuit.state == CBState.CLOSED:
                return database

        raise NoValidDatabaseException("No valid database available for communication")

    def set_databases(self, databases: Databases) -> None:
        self._databases = databases


class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes the given failover strategy.
    """

    def __init__(
        self,
        strategy: AsyncFailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        self._next_attempt_ts: int = 0
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        return self._failover_delay

    @property
    def strategy(self) -> AsyncFailoverStrategy:
        return self._strategy

    async def execute(self) -> AsyncDatabase:
        try:
            database = await self._strategy.database()
            self._reset()
            return database
        except NoValidDatabaseException as e:
            if self._next_attempt_ts == 0:
                self._next_attempt_ts = time.time() + self._failover_delay
                self._failover_counter += 1
            elif time.time() >= self._next_attempt_ts:
                self._next_attempt_ts += self._failover_delay
                self._failover_counter += 1

            if self._failover_counter > self._failover_attempts:
                self._reset()
                raise e
            else:
                raise TemporaryUnavailableException(
                    "No database connections currently available. "
                    "This is a temporary condition - please retry the operation."
                )

    def _reset(self) -> None:
        self._next_attempt_ts = 0
        self._failover_counter = 0
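
DefaultFailoverStrategyExecutor converts "no healthy database right now" into TemporaryUnavailableException until failover_attempts windows of failover_delay seconds have elapsed, and only then re-raises NoValidDatabaseException. Callers can therefore treat the temporary error as retryable; a hedged sketch (helper name and sleep interval are illustrative):

import asyncio

from redis.multidb.exception import TemporaryUnavailableException


async def wait_for_database(executor: FailoverStrategyExecutor):
    # Hypothetical helper: keep polling while the outage is reported as temporary.
    while True:
        try:
            return await executor.execute()
        except TemporaryUnavailableException:
            await asyncio.sleep(1.0)
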
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod

from redis.multidb.failure_detector import FailureDetector


class AsyncFailureDetector(ABC):
    @abstractmethod
    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Register a failure that occurred during command execution."""
        pass

    @abstractmethod
    async def register_command_execution(self, cmd: tuple) -> None:
        """Register a command execution."""
        pass

    @abstractmethod
    def set_command_executor(self, command_executor) -> None:
        """Set the command executor for this failure detector."""
        pass


class FailureDetectorAsyncWrapper(AsyncFailureDetector):
    """
    Async wrapper for a synchronous failure detector.
    """

    def __init__(self, failure_detector: FailureDetector) -> None:
        self._failure_detector = failure_detector

    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        self._failure_detector.register_failure(exception, cmd)

    async def register_command_execution(self, cmd: tuple) -> None:
        self._failure_detector.register_command_execution(cmd)

    def set_command_executor(self, command_executor) -> None:
        self._failure_detector.set_command_executor(command_executor)
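
FailureDetectorAsyncWrapper simply bridges the synchronous FailureDetector API into the async interface, so the sync detector from redis.multidb can be reused unchanged. For example (threshold values are illustrative):

from redis.multidb.failure_detector import CommandFailureDetector

detector = FailureDetectorAsyncWrapper(
    CommandFailureDetector(min_num_failures=5, failure_rate_threshold=0.1)
)
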
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
|
||||
from redis.backoff import NoBackoff
|
||||
from redis.http.http_client import HttpClient
|
||||
from redis.multidb.exception import UnhealthyDatabaseException
|
||||
from redis.retry import Retry
|
||||
|
||||
DEFAULT_HEALTH_CHECK_PROBES = 3
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL = 5
|
||||
DEFAULT_HEALTH_CHECK_DELAY = 0.5
|
||||
DEFAULT_LAG_AWARE_TOLERANCE = 5000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthCheck(ABC):
|
||||
@abstractmethod
|
||||
async def check_health(self, database) -> bool:
|
||||
"""Function to determine the health status."""
|
||||
pass
|
||||
|
||||
|
||||
class HealthCheckPolicy(ABC):
|
||||
"""
|
||||
Health checks execution policy.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_probes(self) -> int:
|
||||
"""Number of probes to execute health checks."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_delay(self) -> float:
|
||||
"""Delay between health check probes."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
"""Execute health checks and return database health status."""
|
||||
pass
|
||||
|
||||
|
||||
class AbstractHealthCheckPolicy(HealthCheckPolicy):
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
if health_check_probes < 1:
|
||||
raise ValueError("health_check_probes must be greater than 0")
|
||||
self._health_check_probes = health_check_probes
|
||||
self._health_check_delay = health_check_delay
|
||||
|
||||
@property
|
||||
def health_check_probes(self) -> int:
|
||||
return self._health_check_probes
|
||||
|
||||
@property
|
||||
def health_check_delay(self) -> float:
|
||||
return self._health_check_delay
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if all health check probes are successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        for health_check in health_checks:
            for attempt in range(self.health_check_probes):
                try:
                    if not await health_check.check_health(database):
                        return False
                except Exception as e:
                    raise UnhealthyDatabaseException("Unhealthy database", database, e)

                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)
        return True


class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if a majority of health check probes are successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        for health_check in health_checks:
            if self.health_check_probes % 2 == 0:
                allowed_unsuccessful_probes = self.health_check_probes / 2
            else:
                allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2

            for attempt in range(self.health_check_probes):
                try:
                    if not await health_check.check_health(database):
                        allowed_unsuccessful_probes -= 1
                        if allowed_unsuccessful_probes <= 0:
                            return False
                except Exception as e:
                    allowed_unsuccessful_probes -= 1
                    if allowed_unsuccessful_probes <= 0:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, e
                        )

                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)
        return True

class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        is_healthy = False

        for health_check in health_checks:
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if await health_check.check_health(database):
                        is_healthy = True
                        break
                    else:
                        is_healthy = False
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

            if not is_healthy and not exception:
                return is_healthy
            elif not is_healthy and exception:
                raise exception

        return is_healthy

class HealthCheckPolicies(Enum):
    HEALTHY_ALL = HealthyAllPolicy
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    HEALTHY_ANY = HealthyAnyPolicy


DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL

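
# Illustrative sketch (not part of the library): how a policy drives repeated
# probes against a database object. ``_StubCheck`` and ``_example_policy_usage``
# are hypothetical names introduced here for illustration only; the function is
# never called at import time.
def _example_policy_usage():
    class _StubCheck(HealthCheck):
        async def check_health(self, database) -> bool:
            return True  # pretend the database always answers

    policy = HealthyAllPolicy(health_check_probes=3, health_check_delay=0.1)
    # Three probes, all successful -> the database is considered healthy.
    return asyncio.run(policy.execute([_StubCheck()], database=object()))
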
class PingHealthCheck(HealthCheck):
    """
    Health check based on the PING command.
    """

    async def check_health(self, database) -> bool:
        if isinstance(database.client, Redis):
            return await database.client.execute_command("PING")
        else:
            # For a cluster, check that all nodes are healthy.
            all_nodes = database.client.get_nodes()
            for node in all_nodes:
                if not await node.redis_connection.execute_command("PING"):
                    return False

            return True

class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verifies via the REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for the Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerated lag between databases in ms
                (default: DEFAULT_LAG_AWARE_TOLERANCE)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    async def check_health(self, database) -> bool:
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        self._http_client.client.base_url = base_url

        # Find the bdb matching the current database host
        matching_bdb = None
        for bdb in await self._http_client.get("/v1/bdbs"):
            for endpoint in bdb["endpoints"]:
                if endpoint["dns_name"] == db_host:
                    matching_bdb = bdb
                    break

                # In case the host was set as a public IP
                for addr in endpoint["addr"]:
                    if addr == db_host:
                        matching_bdb = bdb
                        break

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status is checked in the http client, otherwise HttpError will be raised
        return True
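

# Illustrative sketch (not part of the library): constructing the Enterprise
# lag-aware check. The credentials below are placeholders, and the ``database``
# argument to ``check_health`` must expose ``client`` and ``health_check_url``;
# the function is never called at import time.
def _example_lag_aware_check():
    check = LagAwareHealthCheck(
        rest_api_port=9443,
        lag_aware_tolerance=DEFAULT_LAG_AWARE_TOLERANCE,
        auth_basic=("admin@example.com", "secret"),  # hypothetical credentials
        verify_tls=True,
    )
    return check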
@@ -0,0 +1,58 @@
from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar

from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry

T = TypeVar("T")

if TYPE_CHECKING:
    from redis.backoff import AbstractBackoff


class Retry(AbstractRetry[RedisError]):
    __hash__ = AbstractRetry.__hash__

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[RedisError], ...] = (
            ConnectionError,
            TimeoutError,
        ),
    ):
        super().__init__(backoff, retries, supported_errors)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Retry):
            return NotImplemented

        return (
            self._backoff == other._backoff
            and self._retries == other._retries
            and set(self._supported_errors) == set(other._supported_errors)
        )

    async def call_with_retry(
        self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
    ) -> T:
        """
        Execute an operation that might fail and return its result, or
        raise the exception that was thrown, depending on the `Backoff` object.
        `do`: the operation to call. Expects no argument.
        `fail`: the failure handler, expects the last error that was thrown
        """
        self._backoff.reset()
        failures = 0
        while True:
            try:
                return await do()
            except self._supported_errors as error:
                failures += 1
                await fail(error)
                if self._retries >= 0 and failures > self._retries:
                    raise error
                backoff = self._backoff.compute(failures)
                if backoff > 0:
                    await sleep(backoff)
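

# Illustrative sketch (not part of the library): retrying a flaky coroutine
# with exponential backoff. ``_example_retry`` and ``_flaky`` are hypothetical
# names; the function is never called at import time.
def _example_retry():
    import asyncio

    from redis.backoff import ExponentialBackoff

    attempts = {"n": 0}

    async def _flaky():
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise ConnectionError("transient failure")
        return "ok"

    async def _on_fail(error):
        pass  # e.g. log the error or reset connection state

    retry = Retry(ExponentialBackoff(), retries=3)
    # Fails twice, backs off between attempts, then succeeds.
    return asyncio.run(retry.call_with_retry(_flaky, _on_fail))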
@@ -0,0 +1,404 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type

from redis.asyncio.client import Redis
from redis.asyncio.connection import (
    Connection,
    ConnectionPool,
    EncodableT,
    SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
    ConnectionError,
    ReadOnlyError,
    ResponseError,
    TimeoutError,
)


class MasterNotFoundError(ConnectionError):
    pass


class SlaveNotFoundError(ConnectionError):
    pass


class SentinelManagedConnection(Connection):
    def __init__(self, **kwargs):
        self.connection_pool = kwargs.pop("connection_pool")
        super().__init__(**kwargs)

    def __repr__(self):
        s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
        if self.host:
            host_info = f",host={self.host},port={self.port}"
            s += host_info
        return s + ">"

    async def connect_to(self, address):
        self.host, self.port = address
        await self.connect_check_health(
            check_health=self.connection_pool.check_connection,
            retry_socket_connect=False,
        )

    async def _connect_retry(self):
        if self._reader:
            return  # already connected
        if self.connection_pool.is_master:
            await self.connect_to(await self.connection_pool.get_master_address())
        else:
            async for slave in self.connection_pool.rotate_slaves():
                try:
                    return await self.connect_to(slave)
                except ConnectionError:
                    continue
            raise SlaveNotFoundError  # Should never be reached

    async def connect(self):
        return await self.retry.call_with_retry(
            self._connect_retry,
            lambda error: asyncio.sleep(0),
        )

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: Optional[bool] = True,
        push_request: Optional[bool] = False,
    ):
        try:
            return await super().read_response(
                disable_decoding=disable_decoding,
                timeout=timeout,
                disconnect_on_error=disconnect_on_error,
                push_request=push_request,
            )
        except ReadOnlyError:
            if self.connection_pool.is_master:
                # When talking to a master, a ReadOnlyError likely
                # indicates that the previous master that we're still connected
                # to has been demoted to a slave and there's a new master.
                # Calling disconnect will force the connection to re-query
                # sentinel during the next connect() attempt.
                await self.disconnect()
                raise ConnectionError("The previous master is now a slave")
            raise


class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
    pass


class SentinelConnectionPool(ConnectionPool):
    """
    Sentinel backed connection pool.

    If the ``check_connection`` flag is set to True, SentinelManagedConnection
    sends a PING command right after establishing the connection.
    """

    def __init__(self, service_name, sentinel_manager, **kwargs):
        kwargs["connection_class"] = kwargs.get(
            "connection_class",
            (
                SentinelManagedSSLConnection
                if kwargs.pop("ssl", False)
                else SentinelManagedConnection
            ),
        )
        self.is_master = kwargs.pop("is_master", True)
        self.check_connection = kwargs.pop("check_connection", False)
        super().__init__(**kwargs)
        self.connection_kwargs["connection_pool"] = weakref.proxy(self)
        self.service_name = service_name
        self.sentinel_manager = sentinel_manager
        self.master_address = None
        self.slave_rr_counter = None

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
        )

    def reset(self):
        super().reset()
        self.master_address = None
        self.slave_rr_counter = None

    def owns_connection(self, connection: Connection):
        check = not self.is_master or (
            self.is_master and self.master_address == (connection.host, connection.port)
        )
        return check and super().owns_connection(connection)

    async def get_master_address(self):
        master_address = await self.sentinel_manager.discover_master(self.service_name)
        if self.is_master:
            if self.master_address != master_address:
                self.master_address = master_address
                # disconnect any idle connections so that they reconnect
                # to the new master the next time that they are used.
                await self.disconnect(inuse_connections=False)
        return master_address

    async def rotate_slaves(self) -> AsyncIterator:
        """Round-robin slave balancer"""
        slaves = await self.sentinel_manager.discover_slaves(self.service_name)
        if slaves:
            if self.slave_rr_counter is None:
                self.slave_rr_counter = random.randint(0, len(slaves) - 1)
            for _ in range(len(slaves)):
                self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
                slave = slaves[self.slave_rr_counter]
                yield slave
        # Fallback to the master connection
        try:
            yield await self.get_master_address()
        except MasterNotFoundError:
            pass
        raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")

class Sentinel(AsyncSentinelCommands):
    """
    Redis Sentinel cluster client

    >>> from redis.sentinel import Sentinel
    >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
    >>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
    >>> await master.set('foo', 'bar')
    >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
    >>> await slave.get('foo')
    b'bar'

    ``sentinels`` is a list of sentinel nodes. Each node is represented by
    a pair (hostname, port).

    ``min_other_sentinels`` defines a minimum number of peers for a sentinel.
    When querying a sentinel, if it doesn't meet this threshold, responses
    from that sentinel won't be considered valid.

    ``sentinel_kwargs`` is a dictionary of connection arguments used when
    connecting to sentinel instances. Any argument that can be passed to
    a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
    not specified, any socket_timeout and socket_keepalive options specified
    in ``connection_kwargs`` will be used.

    ``connection_kwargs`` are keyword arguments that will be used when
    establishing a connection to a Redis server.
    """

    def __init__(
        self,
        sentinels,
        min_other_sentinels=0,
        sentinel_kwargs=None,
        force_master_ip=None,
        **connection_kwargs,
    ):
        # if sentinel_kwargs isn't defined, use the socket_* options from
        # connection_kwargs
        if sentinel_kwargs is None:
            sentinel_kwargs = {
                k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
            }
        self.sentinel_kwargs = sentinel_kwargs

        self.sentinels = [
            Redis(host=hostname, port=port, **self.sentinel_kwargs)
            for hostname, port in sentinels
        ]
        self.min_other_sentinels = min_other_sentinels
        self.connection_kwargs = connection_kwargs
        self._force_master_ip = force_master_ip

    async def execute_command(self, *args, **kwargs):
        """
        Execute a Sentinel command on sentinel nodes.
        once - If set to True, then execute the resulting command on a single
        node at random, rather than across the entire sentinel cluster.
        """
        once = bool(kwargs.pop("once", False))

        # Check if the command is supposed to return the original
        # responses instead of a boolean value.
        return_responses = bool(kwargs.pop("return_responses", False))

        if once:
            response = await random.choice(self.sentinels).execute_command(
                *args, **kwargs
            )
            if return_responses:
                return [response]
            else:
                return bool(response)

        tasks = [
            asyncio.Task(sentinel.execute_command(*args, **kwargs))
            for sentinel in self.sentinels
        ]
        responses = await asyncio.gather(*tasks)

        if return_responses:
            return responses

        return all(responses)

    def __repr__(self):
        sentinel_addresses = []
        for sentinel in self.sentinels:
            sentinel_addresses.append(
                f"{sentinel.connection_pool.connection_kwargs['host']}:"
                f"{sentinel.connection_pool.connection_kwargs['port']}"
            )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(sentinels=[{','.join(sentinel_addresses)}])>"
        )

    def check_master_state(self, state: dict, service_name: str) -> bool:
        if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
            return False
        # Check if our sentinel doesn't see other nodes
        if state["num-other-sentinels"] < self.min_other_sentinels:
            return False
        return True

    async def discover_master(self, service_name: str):
        """
        Asks sentinel servers for the Redis master's address corresponding
        to the service labeled ``service_name``.

        Returns a pair (address, port) or raises MasterNotFoundError if no
        master is found.
        """
        collected_errors = list()
        for sentinel_no, sentinel in enumerate(self.sentinels):
            try:
                masters = await sentinel.sentinel_masters()
            except (ConnectionError, TimeoutError) as e:
                collected_errors.append(f"{sentinel} - {e!r}")
                continue
            state = masters.get(service_name)
            if state and self.check_master_state(state, service_name):
                # Put this sentinel at the top of the list
                self.sentinels[0], self.sentinels[sentinel_no] = (
                    sentinel,
                    self.sentinels[0],
                )

                ip = (
                    self._force_master_ip
                    if self._force_master_ip is not None
                    else state["ip"]
                )
                return ip, state["port"]

        error_info = ""
        if len(collected_errors) > 0:
            error_info = f" : {', '.join(collected_errors)}"
        raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")

    def filter_slaves(
        self, slaves: Iterable[Mapping]
    ) -> Sequence[Tuple[EncodableT, EncodableT]]:
        """Remove slaves that are in an ODOWN or SDOWN state"""
        slaves_alive = []
        for slave in slaves:
            if slave["is_odown"] or slave["is_sdown"]:
                continue
            slaves_alive.append((slave["ip"], slave["port"]))
        return slaves_alive

    async def discover_slaves(
        self, service_name: str
    ) -> Sequence[Tuple[EncodableT, EncodableT]]:
        """Returns a list of alive slaves for service ``service_name``"""
        for sentinel in self.sentinels:
            try:
                slaves = await sentinel.sentinel_slaves(service_name)
            except (ConnectionError, ResponseError, TimeoutError):
                continue
            slaves = self.filter_slaves(slaves)
            if slaves:
                return slaves
        return []

    def master_for(
        self,
        service_name: str,
        redis_class: Type[Redis] = Redis,
        connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
        **kwargs,
    ):
        """
        Returns a redis client instance for the ``service_name`` master.
        The Sentinel client will detect failover and reconnect Redis clients
        automatically.

        A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
        used to retrieve the master's address before establishing a new
        connection.

        NOTE: If the master's address has changed, any cached connections to
        the old master are closed.

        By default clients will be a :py:class:`~redis.Redis` instance.
        Specify a different class to the ``redis_class`` argument if you
        desire something different.

        The ``connection_pool_class`` specifies the connection pool to
        use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
        will be used by default.

        All other keyword arguments are merged with any connection_kwargs
        passed to this class and passed to the connection pool as keyword
        arguments to be used to initialize Redis connections.
        """
        kwargs["is_master"] = True
        connection_kwargs = dict(self.connection_kwargs)
        connection_kwargs.update(kwargs)

        connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
        # The Redis object "owns" the pool
        return redis_class.from_pool(connection_pool)

    def slave_for(
        self,
        service_name: str,
        redis_class: Type[Redis] = Redis,
        connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
        **kwargs,
    ):
        """
        Returns a redis client instance for the ``service_name`` slave(s).

        A SentinelConnectionPool class is used to retrieve the slave's
        address before establishing a new connection.

        By default clients will be a :py:class:`~redis.Redis` instance.
        Specify a different class to the ``redis_class`` argument if you
        desire something different.

        The ``connection_pool_class`` specifies the connection pool to use.
        The SentinelConnectionPool will be used by default.

        All other keyword arguments are merged with any connection_kwargs
        passed to this class and passed to the connection pool as keyword
        arguments to be used to initialize Redis connections.
        """
        kwargs["is_master"] = False
        connection_kwargs = dict(self.connection_kwargs)
        connection_kwargs.update(kwargs)

        connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
        # The Redis object "owns" the pool
        return redis_class.from_pool(connection_pool)
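

# Illustrative sketch (not part of the library): discovering the current
# master address through Sentinel. Host/port values are placeholders; the
# function is never called at import time.
def _example_discover_master():
    async def _run():
        sentinel = Sentinel([("localhost", 26379)], socket_timeout=0.1)
        # Returns an (ip, port) pair, or raises MasterNotFoundError.
        return await sentinel.discover_master("mymaster")

    return asyncio.get_event_loop().run_until_complete(_run())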
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from redis.asyncio.client import Pipeline, Redis


def from_url(url, **kwargs):
    """
    Returns an active Redis client generated from the given database URL.

    Will attempt to extract the database id from the url path, if none is
    provided.
    """
    from redis.asyncio.client import Redis

    return Redis.from_url(url, **kwargs)


class pipeline:  # noqa: N801
    def __init__(self, redis_obj: "Redis"):
        self.p: "Pipeline" = redis_obj.pipeline()

    async def __aenter__(self) -> "Pipeline":
        return self.p

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.p.execute()
        del self.p
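

# Illustrative sketch (not part of the library): queueing commands inside the
# async pipeline context manager; ``execute()`` runs automatically on exit.
# The URL is a placeholder; the function is never called at import time.
def _example_pipeline():
    import asyncio

    async def _run():
        redis_obj = from_url("redis://localhost:6379/0")
        async with pipeline(redis_obj) as p:
            p.set("foo", "bar")
            p.get("foo")
        # Both commands were sent in one round trip on __aexit__.

    asyncio.run(_run())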
31
backend/venv/lib/python3.9/site-packages/redis/auth/err.py
Normal file
@@ -0,0 +1,31 @@
from typing import Iterable


class RequestTokenErr(Exception):
    """
    Represents an exception during token request.
    """

    def __init__(self, *args):
        super().__init__(*args)


class InvalidTokenSchemaErr(Exception):
    """
    Represents an exception related to an invalid token schema.
    """

    def __init__(self, missing_fields: Iterable[str] = ()):
        super().__init__(
            "Unexpected token schema. The following fields are missing: "
            + ", ".join(missing_fields)
        )


class TokenRenewalErr(Exception):
    """
    Represents an exception during the token renewal process.
    """

    def __init__(self, *args):
        super().__init__(*args)
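

# Illustrative sketch (not part of the library): the error message carries the
# missing field names. The function is never called at import time.
def _example_schema_error():
    try:
        raise InvalidTokenSchemaErr(missing_fields=["exp"])
    except InvalidTokenSchemaErr as err:
        # "Unexpected token schema. The following fields are missing: exp"
        return str(err)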
28
backend/venv/lib/python3.9/site-packages/redis/auth/idp.py
Normal file
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod

from redis.auth.token import TokenInterface

"""
This interface is the facade of an identity provider.
"""


class IdentityProviderInterface(ABC):
    """
    Receives a token from the identity provider.
    Receiving a token only works when authenticated.
    """

    @abstractmethod
    def request_token(self, force_refresh=False) -> TokenInterface:
        pass


class IdentityProviderConfigInterface(ABC):
    """
    Configuration class that provides a configured identity provider.
    """

    @abstractmethod
    def get_provider(self) -> IdentityProviderInterface:
        pass
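

# Illustrative sketch (not part of the library): a minimal provider returning
# a static token. ``_StaticProvider`` is a hypothetical name; nothing here is
# executed at import time.
def _example_static_provider():
    from redis.auth.token import SimpleToken

    class _StaticProvider(IdentityProviderInterface):
        def request_token(self, force_refresh=False) -> TokenInterface:
            # A token that never expires (expires_at_ms == -1).
            return SimpleToken("my-token", -1, 0, claims={})

    return _StaticProvider().request_token()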
130
backend/venv/lib/python3.9/site-packages/redis/auth/token.py
Normal file
@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone

from redis.auth.err import InvalidTokenSchemaErr


class TokenInterface(ABC):
    @abstractmethod
    def is_expired(self) -> bool:
        pass

    @abstractmethod
    def ttl(self) -> float:
        pass

    @abstractmethod
    def try_get(self, key: str) -> str:
        pass

    @abstractmethod
    def get_value(self) -> str:
        pass

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        pass

    @abstractmethod
    def get_received_at_ms(self) -> float:
        pass


class TokenResponse:
    def __init__(self, token: TokenInterface):
        self._token = token

    def get_token(self) -> TokenInterface:
        return self._token

    def get_ttl_ms(self) -> float:
        return self._token.get_expires_at_ms() - self._token.get_received_at_ms()


class SimpleToken(TokenInterface):
    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        self.value = value
        self.expires_at = expires_at_ms
        self.received_at = received_at_ms
        self.claims = claims

    def ttl(self) -> float:
        if self.expires_at == -1:
            return -1

        return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)

    def is_expired(self) -> bool:
        if self.expires_at == -1:
            return False

        return self.ttl() <= 0

    def try_get(self, key: str) -> str:
        return self.claims.get(key)

    def get_value(self) -> str:
        return self.value

    def get_expires_at_ms(self) -> float:
        return self.expires_at

    def get_received_at_ms(self) -> float:
        return self.received_at

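# Illustrative sketch (not part of the library): a token issued "now" with a
# 60-second lifetime. All values are placeholders; the function is never
# called at import time.
def _example_simple_token():
    now_ms = datetime.now(timezone.utc).timestamp() * 1000
    token = SimpleToken("tok", now_ms + 60_000, now_ms, claims={"sub": "user-1"})
    assert not token.is_expired()
    assert 0 < token.ttl() <= 60_000  # remaining lifetime in ms
    return token.try_get("sub")  # -> "user-1"

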
class JWToken(TokenInterface):
    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie
        self._value = token
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    def is_expired(self) -> bool:
        exp = self._decoded["exp"]
        if exp == -1:
            return False

        return (
            self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
        )

    def ttl(self) -> float:
        exp = self._decoded["exp"]
        if exp == -1:
            return -1

        return (
            self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
        )

    def try_get(self, key: str) -> str:
        return self._decoded.get(key)

    def get_value(self) -> str:
        return self._value

    def get_expires_at_ms(self) -> float:
        return float(self._decoded["exp"] * 1000)

    def get_received_at_ms(self) -> float:
        return datetime.now(timezone.utc).timestamp() * 1000

    def _validate_token(self):
        actual_fields = set(self._decoded.keys())

        if len(self.REQUIRED_FIELDS - actual_fields) != 0:
            raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)
@@ -0,0 +1,370 @@
import asyncio
import logging
import threading
from datetime import datetime, timezone
from time import sleep
from typing import Any, Awaitable, Callable, Union

from redis.auth.err import RequestTokenErr, TokenRenewalErr
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import TokenResponse

logger = logging.getLogger(__name__)


class CredentialsListener:
    """
    Listeners that will be notified on events related to credentials.
    Accepts callbacks and awaitable callbacks.
    """

    def __init__(self):
        self._on_next = None
        self._on_error = None

    @property
    def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
        return self._on_next

    @on_next.setter
    def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
        self._on_next = callback

    @property
    def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
        return self._on_error

    @on_error.setter
    def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
        self._on_error = callback


class RetryPolicy:
    def __init__(self, max_attempts: int, delay_in_ms: float):
        self.max_attempts = max_attempts
        self.delay_in_ms = delay_in_ms

    def get_max_attempts(self) -> int:
        """
        Number of retry attempts before an exception is thrown.

        :return: int
        """
        return self.max_attempts

    def get_delay_in_ms(self) -> float:
        """
        Delay between retries in milliseconds.

        :return: float
        """
        return self.delay_in_ms


class TokenManagerConfig:
    def __init__(
        self,
        expiration_refresh_ratio: float,
        lower_refresh_bound_millis: int,
        token_request_execution_timeout_in_ms: int,
        retry_policy: RetryPolicy,
    ):
        self._expiration_refresh_ratio = expiration_refresh_ratio
        self._lower_refresh_bound_millis = lower_refresh_bound_millis
        self._token_request_execution_timeout_in_ms = (
            token_request_execution_timeout_in_ms
        )
        self._retry_policy = retry_policy

    def get_expiration_refresh_ratio(self) -> float:
        """
        Represents the ratio of a token's lifetime at which a refresh should be triggered.  # noqa: E501
        For example, a value of 0.75 means the token should be refreshed
        when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).

        :return: float
        """

        return self._expiration_refresh_ratio

    def get_lower_refresh_bound_millis(self) -> int:
        """
        Represents the minimum time before token expiration
        to trigger a refresh, in milliseconds.
        This value sets a fixed lower bound for when a token refresh should occur,
        regardless of the token's total lifetime.
        If set to 0 there will be no lower bound and the refresh will be triggered
        based on the expirationRefreshRatio only.

        :return: int
        """
        return self._lower_refresh_bound_millis

    def get_token_request_execution_timeout_in_ms(self) -> int:
        """
        Represents the maximum time in milliseconds to wait
        for a token request to complete.

        :return: int
        """
        return self._token_request_execution_timeout_in_ms

    def get_retry_policy(self) -> RetryPolicy:
        """
        Represents the retry policy for token requests.

        :return: RetryPolicy
        """
        return self._retry_policy

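# Illustrative worked example (not part of the library): with a 60,000 ms
# token lifetime, ratio 0.75 and a 10,000 ms lower bound, the ratio-based
# refresh fires after 45,000 ms and the bound-based refresh at 50,000 ms;
# the manager takes the earlier of the two delays. The function below is a
# hypothetical sketch and is never called at import time.
def _example_refresh_delay():
    config = TokenManagerConfig(
        expiration_refresh_ratio=0.75,
        lower_refresh_bound_millis=10_000,
        token_request_execution_timeout_in_ms=1_000,
        retry_policy=RetryPolicy(max_attempts=3, delay_in_ms=100),
    )
    ttl_ms = 60_000
    refresh_after_ms = ttl_ms * config.get_expiration_refresh_ratio()
    bound_ms = ttl_ms - config.get_lower_refresh_bound_millis()
    return min(refresh_after_ms, bound_ms)  # -> 45000.0

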
class TokenManager:
    def __init__(
        self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
    ):
        self._idp = identity_provider
        self._config = config
        self._next_timer = None
        self._listener = None
        self._init_timer = None
        self._retries = 0

    def __del__(self):
        logger.info("Token manager is disposed")
        self.stop()

    def start(
        self,
        listener: CredentialsListener,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        self._listener = listener

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Run the loop in a separate thread to unblock the main thread.
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=_start_event_loop_in_thread, args=(loop,), daemon=True
            )
            thread.start()

        # Event to block for the initial execution.
        init_event = asyncio.Event()
        self._init_timer = loop.call_later(
            0, self._renew_token, skip_initial, init_event
        )
        logger.info("Token manager started")

        # Blocks in a thread-safe manner.
        asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
        return self.stop

    async def start_async(
        self,
        listener: CredentialsListener,
        block_for_initial: bool = False,
        initial_delay_in_ms: float = 0,
        skip_initial: bool = False,
    ) -> Callable[[], None]:
        self._listener = listener

        loop = asyncio.get_running_loop()
        init_event = asyncio.Event()

        # Wraps the async callback with an async wrapper to schedule it with
        # loop.call_later()
        wrapped = _async_to_sync_wrapper(
            loop, self._renew_token_async, skip_initial, init_event
        )
        self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
        logger.info("Token manager started")

        if block_for_initial:
            await init_event.wait()

        return self.stop

    def stop(self):
        if self._init_timer is not None:
            self._init_timer.cancel()
        if self._next_timer is not None:
            self._next_timer.cancel()

    def acquire_token(self, force_refresh=False) -> TokenResponse:
        try:
            token = self._idp.request_token(force_refresh)
        except RequestTokenErr as e:
            if self._retries < self._config.get_retry_policy().get_max_attempts():
                self._retries += 1
                sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
                return self.acquire_token(force_refresh)
            else:
                raise e

        self._retries = 0
        return TokenResponse(token)

    async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
        try:
            token = self._idp.request_token(force_refresh)
        except RequestTokenErr as e:
            if self._retries < self._config.get_retry_policy().get_max_attempts():
                self._retries += 1
                await asyncio.sleep(
                    self._config.get_retry_policy().get_delay_in_ms() / 1000
                )
                return await self.acquire_token_async(force_refresh)
            else:
                raise e

        self._retries = 0
        return TokenResponse(token)

    def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
        delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
        delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
        delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)

        return 0 if delay < 0 else delay / 1000

    def _delay_for_lower_refresh(self, expire_date: float):
        return (
            expire_date
            - self._config.get_lower_refresh_bound_millis()
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
        token_ttl = expire_date - issue_date
        refresh_before = token_ttl - (
            token_ttl * self._config.get_expiration_refresh_ratio()
        )

        return (
            expire_date
            - refresh_before
            - (datetime.now(timezone.utc).timestamp() * 1000)
        )

    def _renew_token(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Task to renew the token from the identity provider.
        Schedules renewal tasks based on the token TTL.
        """

        try:
            token_res = self.acquire_token(force_refresh=True)
            delay = self._calculate_renewal_delay(
                token_res.get_token().get_expires_at_ms(),
                token_res.get_token().get_received_at_ms(),
            )

            if token_res.get_token().is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    self._listener.on_next(token_res.get_token())
                except Exception as e:
                    raise TokenRenewalErr(e)

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            self._next_timer = loop.call_later(delay, self._renew_token)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            return token_res
        except Exception as e:
            if self._listener.on_error is None:
                raise e

            self._listener.on_error(e)
        finally:
            if init_event:
                init_event.set()

    async def _renew_token_async(
        self, skip_initial: bool = False, init_event: asyncio.Event = None
    ):
        """
        Async task to renew tokens from the identity provider.
        Schedules renewal tasks based on the token TTL.
        """

        try:
            token_res = await self.acquire_token_async(force_refresh=True)
            delay = self._calculate_renewal_delay(
                token_res.get_token().get_expires_at_ms(),
                token_res.get_token().get_received_at_ms(),
            )

            if token_res.get_token().is_expired():
                raise TokenRenewalErr("Requested token is expired")

            if self._listener.on_next is None:
                logger.warning(
                    "No registered callback for token renewal task. Renewal cancelled"
                )
                return

            if not skip_initial:
                try:
                    await self._listener.on_next(token_res.get_token())
                except Exception as e:
                    raise TokenRenewalErr(e)

            if delay <= 0:
                return

            loop = asyncio.get_running_loop()
            wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
            logger.info(f"Next token renewal scheduled in {delay} seconds")
            loop.call_later(delay, wrapped)
        except Exception as e:
            if self._listener.on_error is None:
                raise e

            await self._listener.on_error(e)
        finally:
            if init_event:
                init_event.set()


def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
    """
    Wraps an asynchronous function so it can be used with loop.call_later.

    :param loop: The event loop in which the coroutine will be executed.
    :param coro_func: The coroutine function to wrap.
    :param args: Positional arguments to pass to the coroutine function.
    :param kwargs: Keyword arguments to pass to the coroutine function.
    :return: A regular function suitable for loop.call_later.
    """

    def wrapped():
        # Schedule the coroutine in the event loop
        asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)

    return wrapped


def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
    """
    Starts an event loop in a thread.
    Used to be able to schedule tasks using loop.call_later.

    :param event_loop:
    :return:
    """
    asyncio.set_event_loop(event_loop)
    event_loop.run_forever()
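

# Illustrative sketch (not part of the library): wiring a TokenManager to a
# hypothetical provider. ``_StaticProvider`` mirrors the sketch in idp.py and
# the numbers are placeholders; the function is never called at import time.
def _example_token_manager():
    from redis.auth.token import SimpleToken

    class _StaticProvider(IdentityProviderInterface):
        def request_token(self, force_refresh=False):
            now_ms = datetime.now(timezone.utc).timestamp() * 1000
            return SimpleToken("tok", now_ms + 60_000, now_ms, claims={})

    config = TokenManagerConfig(
        expiration_refresh_ratio=0.75,
        lower_refresh_bound_millis=10_000,
        token_request_execution_timeout_in_ms=1_000,
        retry_policy=RetryPolicy(max_attempts=3, delay_in_ms=100),
    )
    manager = TokenManager(_StaticProvider(), config)
    listener = CredentialsListener()
    listener.on_next = lambda token: None  # e.g. re-authenticate connections
    stop = manager.start(listener)  # returns a callable that stops renewals
    stop()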
204
backend/venv/lib/python3.9/site-packages/redis/background.py
Normal file
@@ -0,0 +1,204 @@
import asyncio
import threading
from typing import Any, Callable, Coroutine


class BackgroundScheduler:
    """
    Schedules background task execution either in a separate thread or in the
    running event loop.
    """

    def __init__(self):
        self._next_timer = None
        self._event_loops = []
        self._lock = threading.Lock()
        self._stopped = False

    def __del__(self):
        self.stop()

    def stop(self):
        """
        Stop all scheduled tasks and clean up resources.
        """
        with self._lock:
            if self._stopped:
                return
            self._stopped = True

            if self._next_timer:
                self._next_timer.cancel()
                self._next_timer = None

            # Stop all event loops
            for loop in self._event_loops:
                if loop.is_running():
                    loop.call_soon_threadsafe(loop.stop)

            self._event_loops.clear()

    def run_once(self, delay: float, callback: Callable, *args):
        """
        Runs a callable task once after a certain delay in seconds.
        """
        with self._lock:
            if self._stopped:
                return

        # Run the loop in a separate thread to unblock the main thread.
        loop = asyncio.new_event_loop()

        with self._lock:
            self._event_loops.append(loop)

        thread = threading.Thread(
            target=_start_event_loop_in_thread,
            args=(loop, self._call_later, delay, callback, *args),
            daemon=True,
        )
        thread.start()

    def run_recurring(self, interval: float, callback: Callable, *args):
        """
        Runs a recurring callable task with a given interval in seconds.
        """
        with self._lock:
            if self._stopped:
                return

        # Run the loop in a separate thread to unblock the main thread.
        loop = asyncio.new_event_loop()

        with self._lock:
            self._event_loops.append(loop)

        thread = threading.Thread(
            target=_start_event_loop_in_thread,
            args=(loop, self._call_later_recurring, interval, callback, *args),
            daemon=True,
        )
        thread.start()

    async def run_recurring_async(
        self, interval: float, coro: Callable[..., Coroutine[Any, Any, Any]], *args
    ):
        """
        Runs a recurring coroutine with a given interval in seconds in the
        current event loop.
        To be used only from an async context. No additional threads are created.
        """
        with self._lock:
            if self._stopped:
                return

        loop = asyncio.get_running_loop()
        wrapped = _async_to_sync_wrapper(loop, coro, *args)

        def tick():
            with self._lock:
                if self._stopped:
                    return
                # Schedule the coroutine
                wrapped()
                # Schedule the next tick
                self._next_timer = loop.call_later(interval, tick)

        # Schedule the first tick
        self._next_timer = loop.call_later(interval, tick)

    def _call_later(
        self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args
    ):
        with self._lock:
            if self._stopped:
                return
            self._next_timer = loop.call_later(delay, callback, *args)

    def _call_later_recurring(
        self,
        loop: asyncio.AbstractEventLoop,
        interval: float,
        callback: Callable,
        *args,
    ):
        with self._lock:
            if self._stopped:
                return
        self._call_later(
            loop, interval, self._execute_recurring, loop, interval, callback, *args
        )

    def _execute_recurring(
        self,
        loop: asyncio.AbstractEventLoop,
        interval: float,
        callback: Callable,
        *args,
    ):
        """
        Executes a recurring callable task with a given interval in seconds.
        """
        with self._lock:
            if self._stopped:
                return

        try:
            callback(*args)
        except Exception:
            # Silently ignore exceptions during shutdown
            pass

        with self._lock:
            if self._stopped:
                return

        self._call_later(
            loop, interval, self._execute_recurring, loop, interval, callback, *args
        )


def _start_event_loop_in_thread(
    event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args
):
    """
    Starts an event loop in a thread and schedules a callback as soon as the
    event loop is ready.
    Used to be able to schedule tasks using loop.call_later.

    :param event_loop:
    :return:
    """
    asyncio.set_event_loop(event_loop)
    event_loop.call_soon(call_soon_cb, event_loop, *args)
    try:
        event_loop.run_forever()
    finally:
        try:
            # Clean up pending tasks
            pending = asyncio.all_tasks(event_loop)
            for task in pending:
                task.cancel()
            # Run the loop once more to process cancellations
            event_loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True)
            )
        except Exception:
            pass
        finally:
            event_loop.close()


def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
    """
    Wraps an asynchronous function so it can be used with loop.call_later.

    :param loop: The event loop in which the coroutine will be executed.
    :param coro_func: The coroutine function to wrap.
    :param args: Positional arguments to pass to the coroutine function.
    :param kwargs: Keyword arguments to pass to the coroutine function.
    :return: A regular function suitable for loop.call_later.
    """

    def wrapped():
        # Schedule the coroutine in the event loop
        asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)

    return wrapped
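

# Illustrative sketch (not part of the library): running a heartbeat every
# two seconds on a daemon thread. The callback is a placeholder; the function
# is never called at import time.
def _example_scheduler():
    import time

    scheduler = BackgroundScheduler()
    scheduler.run_recurring(2.0, lambda: print("heartbeat"))
    time.sleep(5)  # main thread keeps working; ~2 heartbeats fire meanwhile
    scheduler.stop()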
183
backend/venv/lib/python3.9/site-packages/redis/backoff.py
Normal file
@@ -0,0 +1,183 @@
import random
from abc import ABC, abstractmethod

# Maximum backoff between each retry in seconds
DEFAULT_CAP = 0.512
# Minimum backoff between each retry in seconds
DEFAULT_BASE = 0.008


class AbstractBackoff(ABC):
    """Backoff interface"""

    def reset(self):
        """
        Reset internal state before an operation.
        `reset` is called once at the beginning of
        every call to `Retry.call_with_retry`
        """
        pass

    @abstractmethod
    def compute(self, failures: int) -> float:
        """Compute backoff in seconds upon failure"""
        pass


class ConstantBackoff(AbstractBackoff):
    """Constant backoff upon failure"""

    def __init__(self, backoff: float) -> None:
        """`backoff`: backoff time in seconds"""
        self._backoff = backoff

    def __hash__(self) -> int:
        return hash((self._backoff,))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ConstantBackoff):
            return NotImplemented

        return self._backoff == other._backoff

    def compute(self, failures: int) -> float:
        return self._backoff


class NoBackoff(ConstantBackoff):
    """No backoff upon failure"""

    def __init__(self) -> None:
        super().__init__(0)


class ExponentialBackoff(AbstractBackoff):
    """Exponential backoff upon failure"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def __hash__(self) -> int:
        return hash((self._base, self._cap))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ExponentialBackoff):
            return NotImplemented

        return self._base == other._base and self._cap == other._cap

    def compute(self, failures: int) -> float:
        return min(self._cap, self._base * 2**failures)


class FullJitterBackoff(AbstractBackoff):
    """Full jitter backoff upon failure"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def __hash__(self) -> int:
        return hash((self._base, self._cap))

    def __eq__(self, other) -> bool:
        if not isinstance(other, FullJitterBackoff):
            return NotImplemented

        return self._base == other._base and self._cap == other._cap

    def compute(self, failures: int) -> float:
        return random.uniform(0, min(self._cap, self._base * 2**failures))


class EqualJitterBackoff(AbstractBackoff):
    """Equal jitter backoff upon failure"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def __hash__(self) -> int:
        return hash((self._base, self._cap))

    def __eq__(self, other) -> bool:
        if not isinstance(other, EqualJitterBackoff):
            return NotImplemented

        return self._base == other._base and self._cap == other._cap

    def compute(self, failures: int) -> float:
        temp = min(self._cap, self._base * 2**failures) / 2
        return temp + random.uniform(0, temp)


class DecorrelatedJitterBackoff(AbstractBackoff):
    """Decorrelated jitter backoff upon failure"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base
        self._previous_backoff = 0

    def __hash__(self) -> int:
        return hash((self._base, self._cap))

    def __eq__(self, other) -> bool:
        if not isinstance(other, DecorrelatedJitterBackoff):
            return NotImplemented

        return self._base == other._base and self._cap == other._cap

    def reset(self) -> None:
        self._previous_backoff = 0

    def compute(self, failures: int) -> float:
        max_backoff = max(self._base, self._previous_backoff * 3)
        temp = random.uniform(self._base, max_backoff)
        self._previous_backoff = min(self._cap, temp)
        return self._previous_backoff


class ExponentialWithJitterBackoff(AbstractBackoff):
    """Exponential backoff upon failure, with jitter"""

    def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
        """
        `cap`: maximum backoff time in seconds
        `base`: base backoff time in seconds
        """
        self._cap = cap
        self._base = base

    def __hash__(self) -> int:
        return hash((self._base, self._cap))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ExponentialWithJitterBackoff):
            return NotImplemented

        return self._base == other._base and self._cap == other._cap

    def compute(self, failures: int) -> float:
        return min(self._cap, random.random() * self._base * 2**failures)


def default_backoff():
    return EqualJitterBackoff()
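

# Illustrative worked example (not part of the library): with the default
# base of 0.008 s and cap of 0.512 s, plain exponential backoff yields
# 0.016, 0.032, 0.064, ... seconds and saturates at the cap from the
# seventh failure onward. Never called at import time.
def _example_backoff_values():
    backoff = ExponentialBackoff()
    # [0.016, 0.032, 0.064, 0.128, 0.256, 0.512, 0.512]
    return [backoff.compute(failures) for failures in range(1, 8)]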
402
backend/venv/lib/python3.9/site-packages/redis/cache.py
Normal file
@@ -0,0 +1,402 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union


class CacheEntryStatus(Enum):
    VALID = "VALID"
    IN_PROGRESS = "IN_PROGRESS"


class EvictionPolicyType(Enum):
    time_based = "time_based"
    frequency_based = "frequency_based"


@dataclass(frozen=True)
class CacheKey:
    command: str
    redis_keys: tuple


class CacheEntry:
    def __init__(
        self,
        cache_key: CacheKey,
        cache_value: bytes,
        status: CacheEntryStatus,
        connection_ref,
    ):
        self.cache_key = cache_key
        self.cache_value = cache_value
        self.status = status
        self.connection_ref = connection_ref

    def __hash__(self):
        return hash(
            (self.cache_key, self.cache_value, self.status, self.connection_ref)
        )

    def __eq__(self, other):
        return hash(self) == hash(other)


class EvictionPolicyInterface(ABC):
    @property
    @abstractmethod
    def cache(self):
        pass

    @cache.setter
    @abstractmethod
    def cache(self, value):
        pass

    @property
    @abstractmethod
    def type(self) -> EvictionPolicyType:
        pass

    @abstractmethod
    def evict_next(self) -> CacheKey:
        pass

    @abstractmethod
    def evict_many(self, count: int) -> List[CacheKey]:
        pass

    @abstractmethod
    def touch(self, cache_key: CacheKey) -> None:
        pass


class CacheConfigurationInterface(ABC):
    @abstractmethod
    def get_cache_class(self):
        pass

    @abstractmethod
    def get_max_size(self) -> int:
        pass

    @abstractmethod
    def get_eviction_policy(self):
        pass

    @abstractmethod
    def is_exceeds_max_size(self, count: int) -> bool:
        pass

    @abstractmethod
    def is_allowed_to_cache(self, command: str) -> bool:
        pass


class CacheInterface(ABC):
    @property
    @abstractmethod
    def collection(self) -> OrderedDict:
        pass

    @property
    @abstractmethod
    def config(self) -> CacheConfigurationInterface:
        pass

    @property
    @abstractmethod
    def eviction_policy(self) -> EvictionPolicyInterface:
        pass

    @property
    @abstractmethod
    def size(self) -> int:
        pass

    @abstractmethod
    def get(self, key: CacheKey) -> Union[CacheEntry, None]:
        pass

    @abstractmethod
    def set(self, entry: CacheEntry) -> bool:
        pass

    @abstractmethod
    def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
        pass

    @abstractmethod
    def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
        pass

    @abstractmethod
    def flush(self) -> int:
        pass

    @abstractmethod
    def is_cachable(self, key: CacheKey) -> bool:
        pass


class DefaultCache(CacheInterface):
    def __init__(
        self,
        cache_config: CacheConfigurationInterface,
    ) -> None:
        self._cache = OrderedDict()
        self._cache_config = cache_config
        self._eviction_policy = self._cache_config.get_eviction_policy().value()
        self._eviction_policy.cache = self

    @property
    def collection(self) -> OrderedDict:
        return self._cache

    @property
    def config(self) -> CacheConfigurationInterface:
        return self._cache_config

    @property
    def eviction_policy(self) -> EvictionPolicyInterface:
        return self._eviction_policy

    @property
    def size(self) -> int:
        return len(self._cache)

    def set(self, entry: CacheEntry) -> bool:
        if not self.is_cachable(entry.cache_key):
            return False

        self._cache[entry.cache_key] = entry
        self._eviction_policy.touch(entry.cache_key)

        if self._cache_config.is_exceeds_max_size(len(self._cache)):
            self._eviction_policy.evict_next()

        return True

    def get(self, key: CacheKey) -> Union[CacheEntry, None]:
        entry = self._cache.get(key, None)

        if entry is None:
            return None

        self._eviction_policy.touch(key)
        return entry

    def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
        response = []

        for key in cache_keys:
            if self.get(key) is not None:
                self._cache.pop(key)
                response.append(True)
            else:
                response.append(False)

        return response

    def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
        response = []
        keys_to_delete = []

        for redis_key in redis_keys:
            if isinstance(redis_key, bytes):
                redis_key = redis_key.decode()
            for cache_key in self._cache:
                if redis_key in cache_key.redis_keys:
                    keys_to_delete.append(cache_key)
                    response.append(True)

        for key in keys_to_delete:
            self._cache.pop(key)

        return response

    def flush(self) -> int:
        elem_count = len(self._cache)
        self._cache.clear()
        return elem_count

    def is_cachable(self, key: CacheKey) -> bool:
        return self._cache_config.is_allowed_to_cache(key.command)

class LRUPolicy(EvictionPolicyInterface):
|
||||
def __init__(self):
|
||||
self.cache = None
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._cache
|
||||
|
||||
@cache.setter
|
||||
def cache(self, cache: CacheInterface):
|
||||
self._cache = cache
|
||||
|
||||
@property
|
||||
def type(self) -> EvictionPolicyType:
|
||||
return EvictionPolicyType.time_based
|
||||
|
||||
def evict_next(self) -> CacheKey:
|
||||
self._assert_cache()
|
||||
popped_entry = self._cache.collection.popitem(last=False)
|
||||
return popped_entry[0]
|
||||
|
||||
def evict_many(self, count: int) -> List[CacheKey]:
|
||||
self._assert_cache()
|
||||
if count > len(self._cache.collection):
|
||||
raise ValueError("Evictions count is above cache size")
|
||||
|
||||
popped_keys = []
|
||||
|
||||
for _ in range(count):
|
||||
popped_entry = self._cache.collection.popitem(last=False)
|
||||
popped_keys.append(popped_entry[0])
|
||||
|
||||
return popped_keys
|
||||
|
||||
def touch(self, cache_key: CacheKey) -> None:
|
||||
self._assert_cache()
|
||||
|
||||
if self._cache.collection.get(cache_key) is None:
|
||||
raise ValueError("Given entry does not belong to the cache")
|
||||
|
||||
self._cache.collection.move_to_end(cache_key)
|
||||
|
||||
def _assert_cache(self):
|
||||
if self.cache is None or not isinstance(self.cache, CacheInterface):
|
||||
raise ValueError("Eviction policy should be associated with valid cache.")
|
||||
|
||||
|
||||
class EvictionPolicy(Enum):
|
||||
LRU = LRUPolicy
|
||||
|
||||
|
||||
class CacheConfig(CacheConfigurationInterface):
|
||||
DEFAULT_CACHE_CLASS = DefaultCache
|
||||
DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
|
||||
DEFAULT_MAX_SIZE = 10000
|
||||
|
||||
DEFAULT_ALLOW_LIST = [
|
||||
"BITCOUNT",
|
||||
"BITFIELD_RO",
|
||||
"BITPOS",
|
||||
"EXISTS",
|
||||
"GEODIST",
|
||||
"GEOHASH",
|
||||
"GEOPOS",
|
||||
"GEORADIUSBYMEMBER_RO",
|
||||
"GEORADIUS_RO",
|
||||
"GEOSEARCH",
|
||||
"GET",
|
||||
"GETBIT",
|
||||
"GETRANGE",
|
||||
"HEXISTS",
|
||||
"HGET",
|
||||
"HGETALL",
|
||||
"HKEYS",
|
||||
"HLEN",
|
||||
"HMGET",
|
||||
"HSTRLEN",
|
||||
"HVALS",
|
||||
"JSON.ARRINDEX",
|
||||
"JSON.ARRLEN",
|
||||
"JSON.GET",
|
||||
"JSON.MGET",
|
||||
"JSON.OBJKEYS",
|
||||
"JSON.OBJLEN",
|
||||
"JSON.RESP",
|
||||
"JSON.STRLEN",
|
||||
"JSON.TYPE",
|
||||
"LCS",
|
||||
"LINDEX",
|
||||
"LLEN",
|
||||
"LPOS",
|
||||
"LRANGE",
|
||||
"MGET",
|
||||
"SCARD",
|
||||
"SDIFF",
|
||||
"SINTER",
|
||||
"SINTERCARD",
|
||||
"SISMEMBER",
|
||||
"SMEMBERS",
|
||||
"SMISMEMBER",
|
||||
"SORT_RO",
|
||||
"STRLEN",
|
||||
"SUBSTR",
|
||||
"SUNION",
|
||||
"TS.GET",
|
||||
"TS.INFO",
|
||||
"TS.RANGE",
|
||||
"TS.REVRANGE",
|
||||
"TYPE",
|
||||
"XLEN",
|
||||
"XPENDING",
|
||||
"XRANGE",
|
||||
"XREAD",
|
||||
"XREVRANGE",
|
||||
"ZCARD",
|
||||
"ZCOUNT",
|
||||
"ZDIFF",
|
||||
"ZINTER",
|
||||
"ZINTERCARD",
|
||||
"ZLEXCOUNT",
|
||||
"ZMSCORE",
|
||||
"ZRANGE",
|
||||
"ZRANGEBYLEX",
|
||||
"ZRANGEBYSCORE",
|
||||
"ZRANK",
|
||||
"ZREVRANGE",
|
||||
"ZREVRANGEBYLEX",
|
||||
"ZREVRANGEBYSCORE",
|
||||
"ZREVRANK",
|
||||
"ZSCORE",
|
||||
"ZUNION",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = DEFAULT_MAX_SIZE,
|
||||
cache_class: Any = DEFAULT_CACHE_CLASS,
|
||||
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
|
||||
):
|
||||
self._cache_class = cache_class
|
||||
self._max_size = max_size
|
||||
self._eviction_policy = eviction_policy
|
||||
|
||||
def get_cache_class(self):
|
||||
return self._cache_class
|
||||
|
||||
def get_max_size(self) -> int:
|
||||
return self._max_size
|
||||
|
||||
def get_eviction_policy(self) -> EvictionPolicy:
|
||||
return self._eviction_policy
|
||||
|
||||
def is_exceeds_max_size(self, count: int) -> bool:
|
||||
return count > self._max_size
|
||||
|
||||
def is_allowed_to_cache(self, command: str) -> bool:
|
||||
return command in self.DEFAULT_ALLOW_LIST
|
||||
|
||||
|
||||
class CacheFactoryInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_cache(self) -> CacheInterface:
|
||||
pass
|
||||
|
||||
|
||||
class CacheFactory(CacheFactoryInterface):
|
||||
def __init__(self, cache_config: Optional[CacheConfig] = None):
|
||||
self._config = cache_config
|
||||
|
||||
if self._config is None:
|
||||
self._config = CacheConfig()
|
||||
|
||||
def get_cache(self) -> CacheInterface:
|
||||
cache_class = self._config.get_cache_class()
|
||||
return cache_class(cache_config=self._config)
|
||||
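A minimal usage sketch of the classes above in isolation (normally a CacheConfig is handed to a RESP3 Redis client, which manages the cache internally; the key and value below are made up for illustration):

from redis.cache import CacheConfig, CacheEntry, CacheEntryStatus, CacheFactory, CacheKey

cache = CacheFactory(CacheConfig(max_size=100)).get_cache()
key = CacheKey(command="GET", redis_keys=("user:1",))
cache.set(CacheEntry(key, b"alice", CacheEntryStatus.VALID, connection_ref=None))
print(cache.get(key).cache_value)               # b"alice" (key is now most-recently used)
print(cache.delete_by_redis_keys([b"user:1"]))  # [True]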
1712
backend/venv/lib/python3.9/site-packages/redis/client.py
Executable file
File diff suppressed because it is too large
3370
backend/venv/lib/python3.9/site-packages/redis/cluster.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,18 @@
from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands
from .core import AsyncCoreCommands, CoreCommands
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .sentinel import AsyncSentinelCommands, SentinelCommands

__all__ = [
    "AsyncCoreCommands",
    "AsyncRedisClusterCommands",
    "AsyncRedisModuleCommands",
    "AsyncSentinelCommands",
    "CoreCommands",
    "READ_COMMANDS",
    "RedisClusterCommands",
    "RedisModuleCommands",
    "SentinelCommands",
    "list_or_args",
]
@@ -0,0 +1,253 @@
from redis._parsers.helpers import bool_ok

from ..helpers import get_protocol_version, parse_to_list
from .commands import *  # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo


class AbstractBloom:
    """
    The client allows interaction with RedisBloom and the use of all its
    functionality.

    - BF for Bloom Filter
    - CF for Cuckoo Filter
    - CMS for Count-Min Sketch
    - TOPK for TopK Data Structure
    - TDIGEST for estimating rank statistics
    """

    @staticmethod
    def append_items(params, items):
        """Append ITEMS to params."""
        params.extend(["ITEMS"])
        params += items

    @staticmethod
    def append_error(params, error):
        """Append ERROR to params."""
        if error is not None:
            params.extend(["ERROR", error])

    @staticmethod
    def append_capacity(params, capacity):
        """Append CAPACITY to params."""
        if capacity is not None:
            params.extend(["CAPACITY", capacity])

    @staticmethod
    def append_expansion(params, expansion):
        """Append EXPANSION to params."""
        if expansion is not None:
            params.extend(["EXPANSION", expansion])

    @staticmethod
    def append_no_scale(params, noScale):
        """Append NONSCALING tag to params."""
        if noScale is not None:
            params.extend(["NONSCALING"])

    @staticmethod
    def append_weights(params, weights):
        """Append WEIGHTS to params."""
        if len(weights) > 0:
            params.append("WEIGHTS")
            params += weights

    @staticmethod
    def append_no_create(params, noCreate):
        """Append NOCREATE tag to params."""
        if noCreate is not None:
            params.extend(["NOCREATE"])

    @staticmethod
    def append_items_and_increments(params, items, increments):
        """Append pairs of items and increments to params."""
        for i in range(len(items)):
            params.append(items[i])
            params.append(increments[i])

    @staticmethod
    def append_values_and_weights(params, items, weights):
        """Append pairs of items and weights to params."""
        for i in range(len(items)):
            params.append(items[i])
            params.append(weights[i])

    @staticmethod
    def append_max_iterations(params, max_iterations):
        """Append MAXITERATIONS to params."""
        if max_iterations is not None:
            params.extend(["MAXITERATIONS", max_iterations])

    @staticmethod
    def append_bucket_size(params, bucket_size):
        """Append BUCKETSIZE to params."""
        if bucket_size is not None:
            params.extend(["BUCKETSIZE", bucket_size])


class CMSBloom(CMSCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client."""
        # Set the module commands' callbacks
        _MODULE_CALLBACKS = {
            CMS_INITBYDIM: bool_ok,
            CMS_INITBYPROB: bool_ok,
            # CMS_INCRBY: spaceHolder,
            # CMS_QUERY: spaceHolder,
            CMS_MERGE: bool_ok,
        }

        _RESP2_MODULE_CALLBACKS = {
            CMS_INFO: CMSInfo,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.commandmixin = CMSCommands
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in _MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)


class TOPKBloom(TOPKCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client."""
        # Set the module commands' callbacks
        _MODULE_CALLBACKS = {
            TOPK_RESERVE: bool_ok,
            # TOPK_QUERY: spaceHolder,
            # TOPK_COUNT: spaceHolder,
        }

        _RESP2_MODULE_CALLBACKS = {
            TOPK_ADD: parse_to_list,
            TOPK_INCRBY: parse_to_list,
            TOPK_INFO: TopKInfo,
            TOPK_LIST: parse_to_list,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.commandmixin = TOPKCommands
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in _MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)


class CFBloom(CFCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client."""
        # Set the module commands' callbacks
        _MODULE_CALLBACKS = {
            CF_RESERVE: bool_ok,
            # CF_ADD: spaceHolder,
            # CF_ADDNX: spaceHolder,
            # CF_INSERT: spaceHolder,
            # CF_INSERTNX: spaceHolder,
            # CF_EXISTS: spaceHolder,
            # CF_DEL: spaceHolder,
            # CF_COUNT: spaceHolder,
            # CF_SCANDUMP: spaceHolder,
            # CF_LOADCHUNK: spaceHolder,
        }

        _RESP2_MODULE_CALLBACKS = {
            CF_INFO: CFInfo,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.commandmixin = CFCommands
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in _MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)


class TDigestBloom(TDigestCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client."""
        # Set the module commands' callbacks
        _MODULE_CALLBACKS = {
            TDIGEST_CREATE: bool_ok,
            # TDIGEST_RESET: bool_ok,
            # TDIGEST_ADD: spaceHolder,
            # TDIGEST_MERGE: spaceHolder,
        }

        _RESP2_MODULE_CALLBACKS = {
            TDIGEST_BYRANK: parse_to_list,
            TDIGEST_BYREVRANK: parse_to_list,
            TDIGEST_CDF: parse_to_list,
            TDIGEST_INFO: TDigestInfo,
            TDIGEST_MIN: float,
            TDIGEST_MAX: float,
            TDIGEST_TRIMMED_MEAN: float,
            TDIGEST_QUANTILE: parse_to_list,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.commandmixin = TDigestCommands
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in _MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)


class BFBloom(BFCommands, AbstractBloom):
    def __init__(self, client, **kwargs):
        """Create a new RedisBloom client."""
        # Set the module commands' callbacks
        _MODULE_CALLBACKS = {
            BF_RESERVE: bool_ok,
            # BF_ADD: spaceHolder,
            # BF_MADD: spaceHolder,
            # BF_INSERT: spaceHolder,
            # BF_EXISTS: spaceHolder,
            # BF_MEXISTS: spaceHolder,
            # BF_SCANDUMP: spaceHolder,
            # BF_LOADCHUNK: spaceHolder,
            # BF_CARD: spaceHolder,
        }

        _RESP2_MODULE_CALLBACKS = {
            BF_INFO: BFInfo,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.commandmixin = BFCommands
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in _MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)
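These wrapper classes are normally obtained through the accessors on a Redis client (client.bf(), client.cf(), client.cms(), client.topk(), client.tdigest()) rather than constructed directly; a brief sketch, assuming a server with the Bloom module loaded:

import redis

r = redis.Redis(decode_responses=True)
bf = r.bf()  # a BFBloom bound to this client, with callbacks wired as above
bf.create("newsletter:seen", 0.001, 100_000)
bf.add("newsletter:seen", "user-42")
print(bf.exists("newsletter:seen", "user-42"))  # 1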
@@ -0,0 +1,538 @@
from redis.client import NEVER_DECODE
from redis.utils import deprecated_function

BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
BF_MADD = "BF.MADD"
BF_INSERT = "BF.INSERT"
BF_EXISTS = "BF.EXISTS"
BF_MEXISTS = "BF.MEXISTS"
BF_SCANDUMP = "BF.SCANDUMP"
BF_LOADCHUNK = "BF.LOADCHUNK"
BF_INFO = "BF.INFO"
BF_CARD = "BF.CARD"

CF_RESERVE = "CF.RESERVE"
CF_ADD = "CF.ADD"
CF_ADDNX = "CF.ADDNX"
CF_INSERT = "CF.INSERT"
CF_INSERTNX = "CF.INSERTNX"
CF_EXISTS = "CF.EXISTS"
CF_MEXISTS = "CF.MEXISTS"
CF_DEL = "CF.DEL"
CF_COUNT = "CF.COUNT"
CF_SCANDUMP = "CF.SCANDUMP"
CF_LOADCHUNK = "CF.LOADCHUNK"
CF_INFO = "CF.INFO"

CMS_INITBYDIM = "CMS.INITBYDIM"
CMS_INITBYPROB = "CMS.INITBYPROB"
CMS_INCRBY = "CMS.INCRBY"
CMS_QUERY = "CMS.QUERY"
CMS_MERGE = "CMS.MERGE"
CMS_INFO = "CMS.INFO"

TOPK_RESERVE = "TOPK.RESERVE"
TOPK_ADD = "TOPK.ADD"
TOPK_INCRBY = "TOPK.INCRBY"
TOPK_QUERY = "TOPK.QUERY"
TOPK_COUNT = "TOPK.COUNT"
TOPK_LIST = "TOPK.LIST"
TOPK_INFO = "TOPK.INFO"

TDIGEST_CREATE = "TDIGEST.CREATE"
TDIGEST_RESET = "TDIGEST.RESET"
TDIGEST_ADD = "TDIGEST.ADD"
TDIGEST_MERGE = "TDIGEST.MERGE"
TDIGEST_CDF = "TDIGEST.CDF"
TDIGEST_QUANTILE = "TDIGEST.QUANTILE"
TDIGEST_MIN = "TDIGEST.MIN"
TDIGEST_MAX = "TDIGEST.MAX"
TDIGEST_INFO = "TDIGEST.INFO"
TDIGEST_TRIMMED_MEAN = "TDIGEST.TRIMMED_MEAN"
TDIGEST_RANK = "TDIGEST.RANK"
TDIGEST_REVRANK = "TDIGEST.REVRANK"
TDIGEST_BYRANK = "TDIGEST.BYRANK"
TDIGEST_BYREVRANK = "TDIGEST.BYREVRANK"
class BFCommands:
    """Bloom Filter commands."""

    def create(self, key, errorRate, capacity, expansion=None, noScale=None):
        """
        Create a new Bloom Filter `key` with the desired probability of false
        positives `errorRate` and the expected number of entries to be inserted
        as `capacity`.
        Default expansion value is 2. By default, the filter is auto-scaling.
        For more information see `BF.RESERVE <https://redis.io/commands/bf.reserve>`_.
        """  # noqa
        params = [key, errorRate, capacity]
        self.append_expansion(params, expansion)
        self.append_no_scale(params, noScale)
        return self.execute_command(BF_RESERVE, *params)

    reserve = create

    def add(self, key, item):
        """
        Add an `item` to a Bloom Filter `key`.
        For more information see `BF.ADD <https://redis.io/commands/bf.add>`_.
        """  # noqa
        return self.execute_command(BF_ADD, key, item)

    def madd(self, key, *items):
        """
        Add multiple `items` to a Bloom Filter `key`.
        For more information see `BF.MADD <https://redis.io/commands/bf.madd>`_.
        """  # noqa
        return self.execute_command(BF_MADD, key, *items)

    def insert(
        self,
        key,
        items,
        capacity=None,
        error=None,
        noCreate=None,
        expansion=None,
        noScale=None,
    ):
        """
        Add multiple `items` to a Bloom Filter `key`.

        If `noCreate` remains `None` and `key` does not exist, a new Bloom Filter
        `key` will be created with the desired probability of false positives
        `error` and the expected number of entries to be inserted as `capacity`.
        For more information see `BF.INSERT <https://redis.io/commands/bf.insert>`_.
        """  # noqa
        params = [key]
        self.append_capacity(params, capacity)
        self.append_error(params, error)
        self.append_expansion(params, expansion)
        self.append_no_create(params, noCreate)
        self.append_no_scale(params, noScale)
        self.append_items(params, items)

        return self.execute_command(BF_INSERT, *params)

    def exists(self, key, item):
        """
        Check whether an `item` exists in Bloom Filter `key`.
        For more information see `BF.EXISTS <https://redis.io/commands/bf.exists>`_.
        """  # noqa
        return self.execute_command(BF_EXISTS, key, item)

    def mexists(self, key, *items):
        """
        Check whether `items` exist in Bloom Filter `key`.
        For more information see `BF.MEXISTS <https://redis.io/commands/bf.mexists>`_.
        """  # noqa
        return self.execute_command(BF_MEXISTS, key, *items)

    def scandump(self, key, iter):
        """
        Begin an incremental save of the bloom filter `key`.

        This is useful for large bloom filters which cannot fit into the normal SAVE and RESTORE model.
        The first time this command is called, the value of `iter` should be 0.
        This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
        For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
        """  # noqa
        params = [key, iter]
        options = {}
        options[NEVER_DECODE] = []
        return self.execute_command(BF_SCANDUMP, *params, **options)

    def loadchunk(self, key, iter, data):
        """
        Restore a filter previously saved using SCANDUMP.

        See the SCANDUMP command for example usage.
        This command will overwrite any bloom filter stored under `key`.
        Ensure that the bloom filter will not be modified between invocations.
        For more information see `BF.LOADCHUNK <https://redis.io/commands/bf.loadchunk>`_.
        """  # noqa
        return self.execute_command(BF_LOADCHUNK, key, iter, data)

    def info(self, key):
        """
        Return capacity, size, number of filters, number of items inserted, and expansion rate.
        For more information see `BF.INFO <https://redis.io/commands/bf.info>`_.
        """  # noqa
        return self.execute_command(BF_INFO, key)

    def card(self, key):
        """
        Return the cardinality of a Bloom filter `key` - the number of items that were
        added to it and detected as unique (items that caused at least one bit to be
        set in at least one sub-filter).
        For more information see `BF.CARD <https://redis.io/commands/bf.card>`_.
        """  # noqa
        return self.execute_command(BF_CARD, key)
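The SCANDUMP/LOADCHUNK pair above is meant to be driven as a loop; a minimal sketch, assuming `source` and `target` are two Redis clients and the key name is illustrative:

it, chunks = 0, []
while True:
    it, data = source.bf().scandump("bloom:key", it)
    if it == 0:
        break  # a (0, NULL) pair marks the end of the dump
    chunks.append((it, data))

for it, data in chunks:
    # overwrites whatever filter is stored under the key on the target
    target.bf().loadchunk("bloom:key", it, data)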
class CFCommands:
    """Cuckoo Filter commands."""

    def create(
        self, key, capacity, expansion=None, bucket_size=None, max_iterations=None
    ):
        """
        Create a new Cuckoo Filter `key` with an initial capacity of `capacity` items.
        For more information see `CF.RESERVE <https://redis.io/commands/cf.reserve>`_.
        """  # noqa
        params = [key, capacity]
        self.append_expansion(params, expansion)
        self.append_bucket_size(params, bucket_size)
        self.append_max_iterations(params, max_iterations)
        return self.execute_command(CF_RESERVE, *params)

    reserve = create

    def add(self, key, item):
        """
        Add an `item` to a Cuckoo Filter `key`.
        For more information see `CF.ADD <https://redis.io/commands/cf.add>`_.
        """  # noqa
        return self.execute_command(CF_ADD, key, item)

    def addnx(self, key, item):
        """
        Add an `item` to a Cuckoo Filter `key` only if the item does not yet exist.
        Command might be slower than `add`.
        For more information see `CF.ADDNX <https://redis.io/commands/cf.addnx>`_.
        """  # noqa
        return self.execute_command(CF_ADDNX, key, item)

    def insert(self, key, items, capacity=None, nocreate=None):
        """
        Add multiple `items` to a Cuckoo Filter `key`, allowing the filter
        to be created with a custom `capacity` if it does not yet exist.
        `items` must be provided as a list.
        For more information see `CF.INSERT <https://redis.io/commands/cf.insert>`_.
        """  # noqa
        params = [key]
        self.append_capacity(params, capacity)
        self.append_no_create(params, nocreate)
        self.append_items(params, items)
        return self.execute_command(CF_INSERT, *params)

    def insertnx(self, key, items, capacity=None, nocreate=None):
        """
        Add multiple `items` to a Cuckoo Filter `key` only if they do not exist yet,
        allowing the filter to be created with a custom `capacity` if it does not yet exist.
        `items` must be provided as a list.
        For more information see `CF.INSERTNX <https://redis.io/commands/cf.insertnx>`_.
        """  # noqa
        params = [key]
        self.append_capacity(params, capacity)
        self.append_no_create(params, nocreate)
        self.append_items(params, items)
        return self.execute_command(CF_INSERTNX, *params)

    def exists(self, key, item):
        """
        Check whether an `item` exists in Cuckoo Filter `key`.
        For more information see `CF.EXISTS <https://redis.io/commands/cf.exists>`_.
        """  # noqa
        return self.execute_command(CF_EXISTS, key, item)

    def mexists(self, key, *items):
        """
        Check whether `items` exist in Cuckoo Filter `key`.
        For more information see `CF.MEXISTS <https://redis.io/commands/cf.mexists>`_.
        """  # noqa
        return self.execute_command(CF_MEXISTS, key, *items)

    def delete(self, key, item):
        """
        Delete `item` from `key`.
        For more information see `CF.DEL <https://redis.io/commands/cf.del>`_.
        """  # noqa
        return self.execute_command(CF_DEL, key, item)

    def count(self, key, item):
        """
        Return the number of times an `item` may be in the `key`.
        For more information see `CF.COUNT <https://redis.io/commands/cf.count>`_.
        """  # noqa
        return self.execute_command(CF_COUNT, key, item)

    def scandump(self, key, iter):
        """
        Begin an incremental save of the Cuckoo filter `key`.

        This is useful for large Cuckoo filters which cannot fit into the normal
        SAVE and RESTORE model.
        The first time this command is called, the value of `iter` should be 0.
        This command will return successive (iter, data) pairs until
        (0, NULL) to indicate completion.
        For more information see `CF.SCANDUMP <https://redis.io/commands/cf.scandump>`_.
        """  # noqa
        return self.execute_command(CF_SCANDUMP, key, iter)

    def loadchunk(self, key, iter, data):
        """
        Restore a filter previously saved using SCANDUMP. See the SCANDUMP command for example usage.

        This command will overwrite any Cuckoo filter stored under `key`.
        Ensure that the Cuckoo filter will not be modified between invocations.
        For more information see `CF.LOADCHUNK <https://redis.io/commands/cf.loadchunk>`_.
        """  # noqa
        return self.execute_command(CF_LOADCHUNK, key, iter, data)

    def info(self, key):
        """
        Return size, number of buckets, number of filters, number of items inserted,
        number of items deleted, bucket size, expansion rate, and max iterations.
        For more information see `CF.INFO <https://redis.io/commands/cf.info>`_.
        """  # noqa
        return self.execute_command(CF_INFO, key)
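A short sketch of the add/addnx/count semantics above (key name illustrative, `r` a Redis client as in the earlier example):

cf = r.cf()
cf.create("cf:key", 1000)
cf.add("cf:key", "item")            # 1 - inserted
cf.add("cf:key", "item")            # 1 - cuckoo filters admit duplicates
cf.addnx("cf:key", "item")          # 0 - already present, nothing inserted
print(cf.count("cf:key", "item"))   # 2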
class TOPKCommands:
    """TOP-k Filter commands."""

    def reserve(self, key, k, width, depth, decay):
        """
        Create a new Top-K Filter `key` that tracks the `k` most frequent items,
        using a sketch of the given `width` and `depth` and the given counter
        `decay` rate.
        For more information see `TOPK.RESERVE <https://redis.io/commands/topk.reserve>`_.
        """  # noqa
        return self.execute_command(TOPK_RESERVE, key, k, width, depth, decay)

    def add(self, key, *items):
        """
        Add one `item` or more to a Top-K Filter `key`.
        For more information see `TOPK.ADD <https://redis.io/commands/topk.add>`_.
        """  # noqa
        return self.execute_command(TOPK_ADD, key, *items)

    def incrby(self, key, items, increments):
        """
        Add/increase `items` in a Top-K Sketch `key` by `increments`.
        Both `items` and `increments` are lists.
        For more information see `TOPK.INCRBY <https://redis.io/commands/topk.incrby>`_.

        Example:

        >>> client.topk().incrby('A', ['foo'], [1])
        """  # noqa
        params = [key]
        self.append_items_and_increments(params, items, increments)
        return self.execute_command(TOPK_INCRBY, *params)

    def query(self, key, *items):
        """
        Check whether one `item` or more is a Top-K item at `key`.
        For more information see `TOPK.QUERY <https://redis.io/commands/topk.query>`_.
        """  # noqa
        return self.execute_command(TOPK_QUERY, key, *items)

    @deprecated_function(version="4.4.0", reason="deprecated since redisbloom 2.4.0")
    def count(self, key, *items):
        """
        Return count for one `item` or more from `key`.
        For more information see `TOPK.COUNT <https://redis.io/commands/topk.count>`_.
        """  # noqa
        return self.execute_command(TOPK_COUNT, key, *items)

    def list(self, key, withcount=False):
        """
        Return the full list of items in the Top-K list of `key`.
        If `withcount` is set to True, return the full list of items
        with their probabilistic count in the Top-K list of `key`.
        For more information see `TOPK.LIST <https://redis.io/commands/topk.list>`_.
        """  # noqa
        params = [key]
        if withcount:
            params.append("WITHCOUNT")
        return self.execute_command(TOPK_LIST, *params)

    def info(self, key):
        """
        Return the k, width, depth and decay values of `key`.
        For more information see `TOPK.INFO <https://redis.io/commands/topk.info>`_.
        """  # noqa
        return self.execute_command(TOPK_INFO, key)
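A usage sketch for the Top-K commands above (items and sketch sizes are illustrative):

topk = r.topk()
topk.reserve("topk:tags", k=3, width=8, depth=7, decay=0.9)
for tag in ["a", "b", "a", "c", "a", "b", "d"]:
    topk.add("topk:tags", tag)
print(topk.list("topk:tags", withcount=True))  # e.g. ['a', 3, 'b', 2, 'c', 1]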
class TDigestCommands:
    def create(self, key, compression=100):
        """
        Allocate the memory and initialize the t-digest.
        For more information see `TDIGEST.CREATE <https://redis.io/commands/tdigest.create>`_.
        """  # noqa
        return self.execute_command(TDIGEST_CREATE, key, "COMPRESSION", compression)

    def reset(self, key):
        """
        Reset the sketch `key` to zero - empty out the sketch and re-initialize it.
        For more information see `TDIGEST.RESET <https://redis.io/commands/tdigest.reset>`_.
        """  # noqa
        return self.execute_command(TDIGEST_RESET, key)

    def add(self, key, values):
        """
        Add one or more observations to a t-digest sketch `key`.

        For more information see `TDIGEST.ADD <https://redis.io/commands/tdigest.add>`_.
        """  # noqa
        return self.execute_command(TDIGEST_ADD, key, *values)

    def merge(self, destination_key, num_keys, *keys, compression=None, override=False):
        """
        Merge all of the values from `keys` into the `destination_key` sketch.
        It is mandatory to provide `num_keys` before passing the input keys and
        the other (optional) arguments.
        If `destination_key` already exists, its values are merged with the input keys.
        If you wish to override the destination key contents, use the `override` parameter.

        For more information see `TDIGEST.MERGE <https://redis.io/commands/tdigest.merge>`_.
        """  # noqa
        params = [destination_key, num_keys, *keys]
        if compression is not None:
            params.extend(["COMPRESSION", compression])
        if override:
            params.append("OVERRIDE")
        return self.execute_command(TDIGEST_MERGE, *params)

    def min(self, key):
        """
        Return the minimum value from the sketch `key`. Will return DBL_MAX if the sketch is empty.
        For more information see `TDIGEST.MIN <https://redis.io/commands/tdigest.min>`_.
        """  # noqa
        return self.execute_command(TDIGEST_MIN, key)

    def max(self, key):
        """
        Return the maximum value from the sketch `key`. Will return DBL_MIN if the sketch is empty.
        For more information see `TDIGEST.MAX <https://redis.io/commands/tdigest.max>`_.
        """  # noqa
        return self.execute_command(TDIGEST_MAX, key)

    def quantile(self, key, quantile, *quantiles):
        """
        Return estimates of one or more cutoffs such that a specified fraction of the
        observations added to this t-digest would be less than or equal to each of the
        specified cutoffs. (Multiple quantiles can be returned with one call)
        For more information see `TDIGEST.QUANTILE <https://redis.io/commands/tdigest.quantile>`_.
        """  # noqa
        return self.execute_command(TDIGEST_QUANTILE, key, quantile, *quantiles)

    def cdf(self, key, value, *values):
        """
        Return the fraction of all points added which are <= each given `value`.
        For more information see `TDIGEST.CDF <https://redis.io/commands/tdigest.cdf>`_.
        """  # noqa
        return self.execute_command(TDIGEST_CDF, key, value, *values)

    def info(self, key):
        """
        Return Compression, Capacity, Merged Nodes, Unmerged Nodes, Merged Weight, Unmerged Weight
        and Total Compressions.
        For more information see `TDIGEST.INFO <https://redis.io/commands/tdigest.info>`_.
        """  # noqa
        return self.execute_command(TDIGEST_INFO, key)

    def trimmed_mean(self, key, low_cut_quantile, high_cut_quantile):
        """
        Return the mean value from the sketch, excluding observation values outside
        the low and high cutoff quantiles.
        For more information see `TDIGEST.TRIMMED_MEAN <https://redis.io/commands/tdigest.trimmed_mean>`_.
        """  # noqa
        return self.execute_command(
            TDIGEST_TRIMMED_MEAN, key, low_cut_quantile, high_cut_quantile
        )

    def rank(self, key, value, *values):
        """
        Retrieve the estimated rank of value (the number of observations in the sketch
        that are smaller than value + half the number of observations that are equal to value).

        For more information see `TDIGEST.RANK <https://redis.io/commands/tdigest.rank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_RANK, key, value, *values)

    def revrank(self, key, value, *values):
        """
        Retrieve the estimated rank of value (the number of observations in the sketch
        that are larger than value + half the number of observations that are equal to value).

        For more information see `TDIGEST.REVRANK <https://redis.io/commands/tdigest.revrank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_REVRANK, key, value, *values)

    def byrank(self, key, rank, *ranks):
        """
        Retrieve an estimation of the value with the given rank.

        For more information see `TDIGEST.BY_RANK <https://redis.io/commands/tdigest.by_rank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_BYRANK, key, rank, *ranks)

    def byrevrank(self, key, rank, *ranks):
        """
        Retrieve an estimation of the value with the given reverse rank.

        For more information see `TDIGEST.BY_REVRANK <https://redis.io/commands/tdigest.by_revrank>`_.
        """  # noqa
        return self.execute_command(TDIGEST_BYREVRANK, key, rank, *ranks)
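A sketch of the quantile/rank queries above (latency figures made up for illustration):

td = r.tdigest()
td.create("td:latency", compression=100)
td.add("td:latency", [12.0, 15.5, 9.8, 30.2, 11.1])
print(td.quantile("td:latency", 0.5, 0.99))  # estimated median and p99
print(td.rank("td:latency", 15.5))           # estimated rank of 15.5 among observations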
class CMSCommands:
    """Count-Min Sketch Commands"""

    def initbydim(self, key, width, depth):
        """
        Initialize a Count-Min Sketch `key` to the dimensions (`width`, `depth`) specified by the user.
        For more information see `CMS.INITBYDIM <https://redis.io/commands/cms.initbydim>`_.
        """  # noqa
        return self.execute_command(CMS_INITBYDIM, key, width, depth)

    def initbyprob(self, key, error, probability):
        """
        Initialize a Count-Min Sketch `key` to the characteristics (`error`, `probability`) specified by the user.
        For more information see `CMS.INITBYPROB <https://redis.io/commands/cms.initbyprob>`_.
        """  # noqa
        return self.execute_command(CMS_INITBYPROB, key, error, probability)

    def incrby(self, key, items, increments):
        """
        Add/increase `items` in a Count-Min Sketch `key` by `increments`.
        Both `items` and `increments` are lists.
        For more information see `CMS.INCRBY <https://redis.io/commands/cms.incrby>`_.

        Example:

        >>> client.cms().incrby('A', ['foo'], [1])
        """  # noqa
        params = [key]
        self.append_items_and_increments(params, items, increments)
        return self.execute_command(CMS_INCRBY, *params)

    def query(self, key, *items):
        """
        Return the count for an `item` from `key`. Multiple items can be queried with one call.
        For more information see `CMS.QUERY <https://redis.io/commands/cms.query>`_.
        """  # noqa
        return self.execute_command(CMS_QUERY, key, *items)

    def merge(self, destKey, numKeys, srcKeys, weights=[]):
        """
        Merge `numKeys` sketches, specified in `srcKeys`, into `destKey`.
        All sketches must have identical width and depth.
        `weights` can be used to multiply certain sketches. The default weight is 1.
        Both `srcKeys` and `weights` are lists.
        For more information see `CMS.MERGE <https://redis.io/commands/cms.merge>`_.
        """  # noqa
        params = [destKey, numKeys]
        params += srcKeys
        self.append_weights(params, weights)
        return self.execute_command(CMS_MERGE, *params)

    def info(self, key):
        """
        Return the width, depth and total count of the sketch.
        For more information see `CMS.INFO <https://redis.io/commands/cms.info>`_.
        """  # noqa
        return self.execute_command(CMS_INFO, key)
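A sketch for the Count-Min Sketch commands above (dimensions and keys illustrative):

cms = r.cms()
cms.initbydim("cms:clicks", 2000, 5)
cms.incrby("cms:clicks", ["page:home", "page:about"], [3, 1])
print(cms.query("cms:clicks", "page:home"))  # [3] - approximate, may overcount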
@@ -0,0 +1,120 @@
from ..helpers import nativestr


class BFInfo:
    capacity = None
    size = None
    filterNum = None
    insertedNum = None
    expansionRate = None

    def __init__(self, args):
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.capacity = response["Capacity"]
        self.size = response["Size"]
        self.filterNum = response["Number of filters"]
        self.insertedNum = response["Number of items inserted"]
        self.expansionRate = response["Expansion rate"]

    def get(self, item):
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)


class CFInfo:
    size = None
    bucketNum = None
    filterNum = None
    insertedNum = None
    deletedNum = None
    bucketSize = None
    expansionRate = None
    maxIteration = None

    def __init__(self, args):
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.size = response["Size"]
        self.bucketNum = response["Number of buckets"]
        self.filterNum = response["Number of filters"]
        self.insertedNum = response["Number of items inserted"]
        self.deletedNum = response["Number of items deleted"]
        self.bucketSize = response["Bucket size"]
        self.expansionRate = response["Expansion rate"]
        self.maxIteration = response["Max iterations"]

    def get(self, item):
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)


class CMSInfo:
    width = None
    depth = None
    count = None

    def __init__(self, args):
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.width = response["width"]
        self.depth = response["depth"]
        self.count = response["count"]

    def __getitem__(self, item):
        return getattr(self, item)


class TopKInfo:
    k = None
    width = None
    depth = None
    decay = None

    def __init__(self, args):
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.k = response["k"]
        self.width = response["width"]
        self.depth = response["depth"]
        self.decay = response["decay"]

    def __getitem__(self, item):
        return getattr(self, item)


class TDigestInfo:
    compression = None
    capacity = None
    merged_nodes = None
    unmerged_nodes = None
    merged_weight = None
    unmerged_weight = None
    total_compressions = None
    memory_usage = None

    def __init__(self, args):
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.compression = response["Compression"]
        self.capacity = response["Capacity"]
        self.merged_nodes = response["Merged nodes"]
        self.unmerged_nodes = response["Unmerged nodes"]
        self.merged_weight = response["Merged weight"]
        self.unmerged_weight = response["Unmerged weight"]
        self.total_compressions = response["Total compressions"]
        self.memory_usage = response["Memory usage"]

    def get(self, item):
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
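The dict(zip(args[::2], args[1::2])) idiom used throughout this file pairs a flat RESP reply of alternating field names and values into a mapping; a small illustration with a made-up BF.INFO reply:

raw = [b"Capacity", 100, b"Size", 240, b"Number of filters", 1,
       b"Number of items inserted", 0, b"Expansion rate", 2]
info = BFInfo(raw)
print(info.capacity, info["size"], info.get("missing"))  # 100 240 None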
@@ -0,0 +1,919 @@
import asyncio
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Mapping,
    NoReturn,
    Optional,
    Union,
)

from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import (
    AnyKeyT,
    ClusterCommandsProtocol,
    EncodableT,
    KeysT,
    KeyT,
    PatternT,
    ResponseT,
)

from .core import (
    ACLCommands,
    AsyncACLCommands,
    AsyncDataAccessCommands,
    AsyncFunctionCommands,
    AsyncManagementCommands,
    AsyncModuleCommands,
    AsyncScriptCommands,
    DataAccessCommands,
    FunctionCommands,
    ManagementCommands,
    ModuleCommands,
    PubSubCommands,
    ScriptCommands,
)
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands

if TYPE_CHECKING:
    from redis.asyncio.cluster import TargetNodesT

# Not complete, but covers the major ones
# https://redis.io/commands
READ_COMMANDS = frozenset(
    [
        "BITCOUNT",
        "BITPOS",
        "EVAL_RO",
        "EVALSHA_RO",
        "EXISTS",
        "GEODIST",
        "GEOHASH",
        "GEOPOS",
        "GEORADIUS",
        "GEORADIUSBYMEMBER",
        "GET",
        "GETBIT",
        "GETRANGE",
        "HEXISTS",
        "HGET",
        "HGETALL",
        "HKEYS",
        "HLEN",
        "HMGET",
        "HSTRLEN",
        "HVALS",
        "KEYS",
        "LINDEX",
        "LLEN",
        "LRANGE",
        "MGET",
        "PTTL",
        "RANDOMKEY",
        "SCARD",
        "SDIFF",
        "SINTER",
        "SISMEMBER",
        "SMEMBERS",
        "SRANDMEMBER",
        "STRLEN",
        "SUNION",
        "TTL",
        "ZCARD",
        "ZCOUNT",
        "ZRANGE",
        "ZSCORE",
    ]
)
class ClusterMultiKeyCommands(ClusterCommandsProtocol):
    """
    A class containing commands that handle more than one key
    """

    def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]:
        """Split keys into a dictionary that maps a slot to a list of keys."""

        slots_to_keys = {}
        for key in keys:
            slot = key_slot(self.encoder.encode(key))
            slots_to_keys.setdefault(slot, []).append(key)

        return slots_to_keys

    def _partition_pairs_by_slot(
        self, mapping: Mapping[AnyKeyT, EncodableT]
    ) -> Dict[int, List[EncodableT]]:
        """Split pairs into a dictionary that maps a slot to a list of pairs."""

        slots_to_pairs = {}
        for pair in mapping.items():
            slot = key_slot(self.encoder.encode(pair[0]))
            slots_to_pairs.setdefault(slot, []).extend(pair)

        return slots_to_pairs

    def _execute_pipeline_by_slot(
        self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
    ) -> List[Any]:
        read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
        pipe = self.pipeline()
        for slot, slot_args in slots_to_args.items():
            pipe.execute_command(
                command,
                *slot_args,
                target_nodes=[
                    self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
                ],
            )
        return pipe.execute()

    def _reorder_keys_by_command(
        self,
        keys: Iterable[KeyT],
        slots_to_args: Mapping[int, Iterable[EncodableT]],
        responses: Iterable[Any],
    ) -> List[Any]:
        results = {
            k: v
            for slot_values, response in zip(slots_to_args.values(), responses)
            for k, v in zip(slot_values, response)
        }
        return [results[key] for key in keys]

    def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
        """
        Splits the keys into different slots and then calls MGET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.

        Returns a list of values ordered identically to ``keys``

        For more information see https://redis.io/commands/mget
        """

        # Concatenate all keys into a list
        keys = list_or_args(keys, args)

        # Split keys into slots
        slots_to_keys = self._partition_keys_by_slot(keys)

        # Execute commands using a pipeline
        res = self._execute_pipeline_by_slot("MGET", slots_to_keys)

        # Reorder keys in the order the user provided & return
        return self._reorder_keys_by_command(keys, slots_to_keys, res)

    def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
        """
        Sets key/values based on a mapping. Mapping is a dictionary of
        key/value pairs. Both keys and values should be strings or types that
        can be cast to a string via str().

        Splits the keys into different slots and then calls MSET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.

        For more information see https://redis.io/commands/mset
        """

        # Partition the keys by slot
        slots_to_pairs = self._partition_pairs_by_slot(mapping)

        # Execute commands using a pipeline & return list of replies
        return self._execute_pipeline_by_slot("MSET", slots_to_pairs)

    def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
        """
        Runs the given command once for the keys
        of each slot. Returns the sum of the return values.
        """

        # Partition the keys by slot
        slots_to_keys = self._partition_keys_by_slot(keys)

        # Sum up the reply from each command
        return sum(self._execute_pipeline_by_slot(command, slots_to_keys))

    def exists(self, *keys: KeyT) -> ResponseT:
        """
        Returns the number of ``keys`` that exist in the
        whole cluster. The keys are first split up into slots
        and then an EXISTS command is sent for every slot

        For more information see https://redis.io/commands/exists
        """
        return self._split_command_across_slots("EXISTS", *keys)

    def delete(self, *keys: KeyT) -> ResponseT:
        """
        Deletes the given keys in the cluster.
        The keys are first split up into slots
        and then a DEL command is sent for every slot

        Non-existent keys are ignored.
        Returns the number of keys that were deleted.

        For more information see https://redis.io/commands/del
        """
        return self._split_command_across_slots("DEL", *keys)

    def touch(self, *keys: KeyT) -> ResponseT:
        """
        Updates the last access time of given keys across the
        cluster.

        The keys are first split up into slots
        and then a TOUCH command is sent for every slot

        Non-existent keys are ignored.
        Returns the number of keys that were touched.

        For more information see https://redis.io/commands/touch
        """
        return self._split_command_across_slots("TOUCH", *keys)

    def unlink(self, *keys: KeyT) -> ResponseT:
        """
        Remove the specified keys in a different thread.

        The keys are first split up into slots
        and then an UNLINK command is sent for every slot

        Non-existent keys are ignored.
        Returns the number of keys that were unlinked.

        For more information see https://redis.io/commands/unlink
        """
        return self._split_command_across_slots("UNLINK", *keys)
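The slot partitioning above is what makes the *_nonatomic variants work; the idea can be seen with key_slot directly (hash tags like {42} force keys into the same slot):

from redis.crc import key_slot

keys = [b"user:{42}:name", b"user:{42}:email", b"other"]
slots = {}
for k in keys:
    slots.setdefault(key_slot(k), []).append(k)
print(slots)  # the two {42} keys share one slot; "other" usually hashes elsewhere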
class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
    """
    A class containing commands that handle more than one key
    """

    async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
        """
        Splits the keys into different slots and then calls MGET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.

        Returns a list of values ordered identically to ``keys``

        For more information see https://redis.io/commands/mget
        """

        # Concatenate all keys into a list
        keys = list_or_args(keys, args)

        # Split keys into slots
        slots_to_keys = self._partition_keys_by_slot(keys)

        # Execute commands using a pipeline
        res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)

        # Reorder keys in the order the user provided & return
        return self._reorder_keys_by_command(keys, slots_to_keys, res)

    async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
        """
        Sets key/values based on a mapping. Mapping is a dictionary of
        key/value pairs. Both keys and values should be strings or types that
        can be cast to a string via str().

        Splits the keys into different slots and then calls MSET
        for the keys of every slot. This operation will not be atomic
        if keys belong to more than one slot.

        For more information see https://redis.io/commands/mset
        """

        # Partition the keys by slot
        slots_to_pairs = self._partition_pairs_by_slot(mapping)

        # Execute commands using a pipeline & return list of replies
        return await self._execute_pipeline_by_slot("MSET", slots_to_pairs)

    async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
        """
        Runs the given command once for the keys
        of each slot. Returns the sum of the return values.
        """

        # Partition the keys by slot
        slots_to_keys = self._partition_keys_by_slot(keys)

        # Sum up the reply from each command
        return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))

    async def _execute_pipeline_by_slot(
        self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
    ) -> List[Any]:
        if self._initialize:
            await self.initialize()
        read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
        pipe = self.pipeline()
        for slot, slot_args in slots_to_args.items():
            pipe.execute_command(
                command,
                *slot_args,
                target_nodes=[
                    self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
                ],
            )
        return await pipe.execute()
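A usage sketch for the async variants above (the host/port values are assumptions for a local cluster):

import asyncio

from redis.asyncio.cluster import RedisCluster

async def main():
    rc = RedisCluster(host="localhost", port=7000)
    await rc.mset_nonatomic({"a": "1", "b": "2", "c": "3"})
    print(await rc.mget_nonatomic("a", "b", "c"))  # values in the order requested
    await rc.aclose()

asyncio.run(main())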
class ClusterManagementCommands(ManagementCommands):
    """
    A class for Redis Cluster management commands

    The class inherits from Redis's core ManagementCommands class and makes
    the required adjustments to work with cluster mode
    """

    def slaveof(self, *args, **kwargs) -> NoReturn:
        """
        Make the server a replica of another instance, or promote it as master.

        For more information see https://redis.io/commands/slaveof
        """
        raise RedisClusterException("SLAVEOF is not supported in cluster mode")

    def replicaof(self, *args, **kwargs) -> NoReturn:
        """
        Make the server a replica of another instance, or promote it as master.

        For more information see https://redis.io/commands/replicaof
        """
        raise RedisClusterException("REPLICAOF is not supported in cluster mode")

    def swapdb(self, *args, **kwargs) -> NoReturn:
        """
        Swaps two Redis databases.

        For more information see https://redis.io/commands/swapdb
        """
        raise RedisClusterException("SWAPDB is not supported in cluster mode")

    def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT:
        """
        Returns the node's id.

        :target_node: 'ClusterNode'
            The node to execute the command on

        For more information check https://redis.io/commands/cluster-myid/
        """
        return self.execute_command("CLUSTER MYID", target_nodes=target_node)

    def cluster_addslots(
        self, target_node: "TargetNodesT", *slots: EncodableT
    ) -> ResponseT:
        """
        Assign new hash slots to receiving node. Sends to specified node.

        :target_node: 'ClusterNode'
            The node to execute the command on

        For more information see https://redis.io/commands/cluster-addslots
        """
        return self.execute_command(
            "CLUSTER ADDSLOTS", *slots, target_nodes=target_node
        )

    def cluster_addslotsrange(
        self, target_node: "TargetNodesT", *slots: EncodableT
    ) -> ResponseT:
        """
        Similar to the CLUSTER ADDSLOTS command.
        The difference between the two commands is that ADDSLOTS takes a list of slots
        to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges
        (specified by start and end slots) to assign to the node.

        :target_node: 'ClusterNode'
            The node to execute the command on

        For more information see https://redis.io/commands/cluster-addslotsrange
        """
        return self.execute_command(
            "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node
        )

    def cluster_countkeysinslot(self, slot_id: int) -> ResponseT:
        """
        Return the number of local keys in the specified hash slot
        Send to node based on specified slot_id

        For more information see https://redis.io/commands/cluster-countkeysinslot
        """
        return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id)

    def cluster_count_failure_report(self, node_id: str) -> ResponseT:
        """
        Return the number of failure reports active for a given node
        Sends to a random node

        For more information see https://redis.io/commands/cluster-count-failure-reports
        """
        return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id)

    def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
        """
        Set hash slots as unbound in the cluster.
        It determines by itself which node each slot is in and sends the
        command there

        Returns a list of the results for each processed slot.

        For more information see https://redis.io/commands/cluster-delslots
        """
        return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots]

    def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT:
        """
        Similar to the CLUSTER DELSLOTS command.
        The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove
        from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove
        from the node.

        For more information see https://redis.io/commands/cluster-delslotsrange
        """
        return self.execute_command("CLUSTER DELSLOTSRANGE", *slots)

    def cluster_failover(
        self, target_node: "TargetNodesT", option: Optional[str] = None
    ) -> ResponseT:
        """
        Forces a slave to perform a manual failover of its master
        Sends to specified node

        :target_node: 'ClusterNode'
            The node to execute the command on

        For more information see https://redis.io/commands/cluster-failover
        """
        if option:
            if option.upper() not in ["FORCE", "TAKEOVER"]:
                raise RedisError(
                    f"Invalid option for CLUSTER FAILOVER command: {option}"
                )
            else:
                return self.execute_command(
                    "CLUSTER FAILOVER", option, target_nodes=target_node
                )
        else:
            return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node)

    def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
        """
        Provides info about Redis Cluster node state.
        The command will be sent to a random node in the cluster if no target
        node is specified.

        For more information see https://redis.io/commands/cluster-info
        """
        return self.execute_command("CLUSTER INFO", target_nodes=target_nodes)

    def cluster_keyslot(self, key: str) -> ResponseT:
        """
        Returns the hash slot of the specified key
        Sends to random node in the cluster

        For more information see https://redis.io/commands/cluster-keyslot
        """
        return self.execute_command("CLUSTER KEYSLOT", key)

    def cluster_meet(
        self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Force a node cluster to handshake with another node.
        Sends to specified node.

        For more information see https://redis.io/commands/cluster-meet
        """
        return self.execute_command(
            "CLUSTER MEET", host, port, target_nodes=target_nodes
        )

    def cluster_nodes(self) -> ResponseT:
        """
        Get Cluster config for the node.
        Sends to random node in the cluster

        For more information see https://redis.io/commands/cluster-nodes
        """
        return self.execute_command("CLUSTER NODES")

    def cluster_replicate(
        self, target_nodes: "TargetNodesT", node_id: str
    ) -> ResponseT:
        """
        Reconfigure a node as a slave of the specified master node

        For more information see https://redis.io/commands/cluster-replicate
        """
        return self.execute_command(
            "CLUSTER REPLICATE", node_id, target_nodes=target_nodes
        )

    def cluster_reset(
        self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None
    ) -> ResponseT:
        """
        Reset a Redis Cluster node

        If 'soft' is True then it will send 'SOFT' argument
|
||||
If 'soft' is False then it will send 'HARD' argument
|
||||
|
||||
For more information see https://redis.io/commands/cluster-reset
|
||||
"""
|
||||
return self.execute_command(
|
||||
"CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes
|
||||
)
|
||||
|
||||
def cluster_save_config(
|
||||
self, target_nodes: Optional["TargetNodesT"] = None
|
||||
) -> ResponseT:
|
||||
"""
|
||||
Forces the node to save cluster state on disk
|
||||
|
||||
For more information see https://redis.io/commands/cluster-saveconfig
|
||||
"""
|
||||
return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes)
|
||||
|
||||
def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT:
|
||||
"""
|
||||
Returns the number of keys in the specified cluster slot
|
||||
|
||||
For more information see https://redis.io/commands/cluster-getkeysinslot
|
||||
"""
|
||||
return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys)
|
||||
|
||||
def cluster_set_config_epoch(
|
||||
self, epoch: int, target_nodes: Optional["TargetNodesT"] = None
|
||||
) -> ResponseT:
|
||||
"""
|
||||
Set the configuration epoch in a new node
|
||||
|
||||
For more information see https://redis.io/commands/cluster-set-config-epoch
|
||||
"""
|
||||
return self.execute_command(
|
||||
"CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes
|
||||
)
|
||||
|
||||
def cluster_setslot(
|
||||
self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str
|
||||
) -> ResponseT:
|
||||
"""
|
||||
Bind an hash slot to a specific node
|
||||
|
||||
:target_node: 'ClusterNode'
|
||||
The node to execute the command on
|
||||
|
||||
For more information see https://redis.io/commands/cluster-setslot
|
||||
"""
|
||||
if state.upper() in ("IMPORTING", "NODE", "MIGRATING"):
|
||||
return self.execute_command(
|
||||
"CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
|
||||
)
|
||||
elif state.upper() == "STABLE":
|
||||
raise RedisError('For "stable" state please use cluster_setslot_stable')
|
||||
else:
|
||||
raise RedisError(f"Invalid slot state: {state}")
|
||||
|
||||
def cluster_setslot_stable(self, slot_id: int) -> ResponseT:
|
||||
"""
|
||||
Clears migrating / importing state from the slot.
|
||||
It determines by it self what node the slot is in and sends it there.
|
||||
|
||||
For more information see https://redis.io/commands/cluster-setslot
|
||||
"""
|
||||
return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE")
|
||||
|
||||
def cluster_replicas(
|
||||
self, node_id: str, target_nodes: Optional["TargetNodesT"] = None
|
||||
) -> ResponseT:
|
||||
"""
|
||||
Provides a list of replica nodes replicating from the specified primary
|
||||
target node.
|
||||
|
||||
For more information see https://redis.io/commands/cluster-replicas
|
||||
"""
|
||||
return self.execute_command(
|
||||
"CLUSTER REPLICAS", node_id, target_nodes=target_nodes
|
||||
)
|
||||
|
||||
def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
|
||||
"""
|
||||
Get array of Cluster slot to node mappings
|
||||
|
||||
For more information see https://redis.io/commands/cluster-slots
|
||||
"""
|
||||
return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes)
|
||||
|
||||
def cluster_shards(self, target_nodes=None):
|
||||
"""
|
||||
Returns details about the shards of the cluster.
|
||||
|
||||
For more information see https://redis.io/commands/cluster-shards
|
||||
"""
|
||||
return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes)
|
||||
|
||||
def cluster_myshardid(self, target_nodes=None):
|
||||
"""
|
||||
Returns the shard ID of the node.
|
||||
|
||||
For more information see https://redis.io/commands/cluster-myshardid/
|
||||
"""
|
||||
return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes)
|
||||
|
||||
def cluster_links(self, target_node: "TargetNodesT") -> ResponseT:
|
||||
"""
|
||||
Each node in a Redis Cluster maintains a pair of long-lived TCP link with each
|
||||
peer in the cluster: One for sending outbound messages towards the peer and one
|
||||
for receiving inbound messages from the peer.
|
||||
|
||||
This command outputs information of all such peer links as an array.
|
||||
|
||||
For more information see https://redis.io/commands/cluster-links
|
||||
"""
|
||||
return self.execute_command("CLUSTER LINKS", target_nodes=target_node)
|
||||
|
||||
def cluster_flushslots(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
|
||||
raise NotImplementedError(
|
||||
"CLUSTER FLUSHSLOTS is intentionally not implemented in the client."
|
||||
)
|
||||
|
||||
def cluster_bumpepoch(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
|
||||
raise NotImplementedError(
|
||||
"CLUSTER BUMPEPOCH is intentionally not implemented in the client."
|
||||
)
|
||||
|
||||
def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
|
||||
"""
|
||||
Enables read queries.
|
||||
The command will be sent to the default cluster node if target_nodes is
|
||||
not specified.
|
||||
|
||||
For more information see https://redis.io/commands/readonly
|
||||
"""
|
||||
if target_nodes == "replicas" or target_nodes == "all":
|
||||
# read_from_replicas will only be enabled if the READONLY command
|
||||
# is sent to all replicas
|
||||
self.read_from_replicas = True
|
||||
return self.execute_command("READONLY", target_nodes=target_nodes)
|
||||
|
||||
def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
|
||||
"""
|
||||
Disables read queries.
|
||||
The command will be sent to the default cluster node if target_nodes is
|
||||
not specified.
|
||||
|
||||
For more information see https://redis.io/commands/readwrite
|
||||
"""
|
||||
# Reset read from replicas flag
|
||||
self.read_from_replicas = False
|
||||
return self.execute_command("READWRITE", target_nodes=target_nodes)
|
||||
|
||||
|
||||
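
# Usage sketch for the READONLY/READWRITE pair above, assuming a reachable
# cluster at localhost:16379 (hypothetical address and key):
from redis.cluster import RedisCluster

rc = RedisCluster(host="localhost", port=16379)
rc.readonly(target_nodes="replicas")   # enable reads from replica nodes
print(rc.get("user:1"))                # may now be served by a replica
rc.readwrite(target_nodes="replicas")  # route reads back to the primaries
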
class AsyncClusterManagementCommands(
    ClusterManagementCommands, AsyncManagementCommands
):
    """
    A class for Redis Cluster management commands

    The class inherits from Redis's core ManagementCommands class and makes the
    required adjustments to work with cluster mode
    """

    async def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
        """
        Set hash slots as unbound in the cluster.
        It determines by itself which node each slot is in and sends the
        command there.

        Returns a list of the results for each processed slot.

        For more information see https://redis.io/commands/cluster-delslots
        """
        return await asyncio.gather(
            *(
                asyncio.create_task(self.execute_command("CLUSTER DELSLOTS", slot))
                for slot in slots
            )
        )
class ClusterDataAccessCommands(DataAccessCommands):
    """
    A class for Redis Cluster Data Access Commands

    The class inherits from Redis's core DataAccessCommands class and makes the
    required adjustments to work with cluster mode
    """

    def stralgo(
        self,
        algo: Literal["LCS"],
        value1: KeyT,
        value2: KeyT,
        specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings",
        len: bool = False,
        idx: bool = False,
        minmatchlen: Optional[int] = None,
        withmatchlen: bool = False,
        **kwargs,
    ) -> ResponseT:
        """
        Implements complex algorithms that operate on strings.
        Right now the only algorithm implemented is the LCS algorithm
        (longest common substring). However, new algorithms could be
        implemented in the future.

        ``algo`` Right now must be LCS
        ``value1`` and ``value2`` Can be two strings or two keys
        ``specific_argument`` Specifies whether the arguments to the algorithm
        are keys or strings. strings is the default.
        ``len`` Returns just the length of the match.
        ``idx`` Returns the match positions in each string.
        ``minmatchlen`` Restrict the list of matches to the ones of a given
        minimal length. Can be provided only when ``idx`` is set to True.
        ``withmatchlen`` Returns the matches with the length of the match.
        Can be provided only when ``idx`` is set to True.

        For more information see https://redis.io/commands/stralgo
        """
        target_nodes = kwargs.pop("target_nodes", None)
        if specific_argument == "strings" and target_nodes is None:
            target_nodes = "default-node"
        kwargs.update({"target_nodes": target_nodes})
        return super().stralgo(
            algo,
            value1,
            value2,
            specific_argument,
            len,
            idx,
            minmatchlen,
            withmatchlen,
            **kwargs,
        )

    def scan_iter(
        self,
        match: Optional[PatternT] = None,
        count: Optional[int] = None,
        _type: Optional[str] = None,
        **kwargs,
    ) -> Iterator:
        # Do the first query with cursor=0 for all nodes
        cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs)
        yield from data

        cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
        if cursors:
            # Get nodes by name
            nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}

            # Iterate over each node until its cursor is 0
            kwargs.pop("target_nodes", None)
            while cursors:
                for name, cursor in cursors.items():
                    cur, data = self.scan(
                        cursor=cursor,
                        match=match,
                        count=count,
                        _type=_type,
                        target_nodes=nodes[name],
                        **kwargs,
                    )
                    yield from data
                    cursors[name] = cur[name]

                cursors = {
                    name: cursor for name, cursor in cursors.items() if cursor != 0
                }
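
# Usage sketch for scan_iter above: one cursor is kept per node and advanced
# until every node's cursor reaches 0. The address and pattern are
# hypothetical placeholders.
from redis.cluster import RedisCluster

rc = RedisCluster(host="localhost", port=16379)
for key in rc.scan_iter(match="session:*", count=500):
    print(key)  # keys arrive node by node, in no global order
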
class AsyncClusterDataAccessCommands(
    ClusterDataAccessCommands, AsyncDataAccessCommands
):
    """
    A class for Redis Cluster Data Access Commands

    The class inherits from Redis's core DataAccessCommands class and makes the
    required adjustments to work with cluster mode
    """

    async def scan_iter(
        self,
        match: Optional[PatternT] = None,
        count: Optional[int] = None,
        _type: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator:
        # Do the first query with cursor=0 for all nodes
        cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs)
        for value in data:
            yield value

        cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
        if cursors:
            # Get nodes by name
            nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}

            # Iterate over each node until its cursor is 0
            kwargs.pop("target_nodes", None)
            while cursors:
                for name, cursor in cursors.items():
                    cur, data = await self.scan(
                        cursor=cursor,
                        match=match,
                        count=count,
                        _type=_type,
                        target_nodes=nodes[name],
                        **kwargs,
                    )
                    for value in data:
                        yield value
                    cursors[name] = cur[name]

                cursors = {
                    name: cursor for name, cursor in cursors.items() if cursor != 0
                }


class RedisClusterCommands(
    ClusterMultiKeyCommands,
    ClusterManagementCommands,
    ACLCommands,
    PubSubCommands,
    ClusterDataAccessCommands,
    ScriptCommands,
    FunctionCommands,
    ModuleCommands,
    RedisModuleCommands,
):
    """
    A class for all Redis Cluster commands

    For key-based commands, the target node(s) will be internally determined
    by the keys' hash slot.
    Non-key-based commands can be executed with the 'target_nodes' argument to
    target specific nodes. By default, if target_nodes is not specified, the
    command will be executed on the default cluster node.

    :param target_nodes: can be one of the following:
        - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
        - 'ClusterNode'
        - 'list(ClusterNodes)'
        - 'dict(any:clusterNodes)'

    for example:
        r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
    """


class AsyncRedisClusterCommands(
    AsyncClusterMultiKeyCommands,
    AsyncClusterManagementCommands,
    AsyncACLCommands,
    AsyncClusterDataAccessCommands,
    AsyncScriptCommands,
    AsyncFunctionCommands,
    AsyncModuleCommands,
    AsyncRedisModuleCommands,
):
    """
    A class for all Redis Cluster commands

    For key-based commands, the target node(s) will be internally determined
    by the keys' hash slot.
    Non-key-based commands can be executed with the 'target_nodes' argument to
    target specific nodes. By default, if target_nodes is not specified, the
    command will be executed on the default cluster node.

    :param target_nodes: can be one of the following:
        - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
        - 'ClusterNode'
        - 'list(ClusterNodes)'
        - 'dict(any:clusterNodes)'

    for example:
        r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
    """
6769
backend/venv/lib/python3.9/site-packages/redis/commands/core.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,98 @@
import copy
import random
import string
from typing import List, Tuple

import redis
from redis.typing import KeysT, KeyT


def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]:
    # returns a single new list combining keys and args
    try:
        iter(keys)
        # a string or bytes instance can be iterated, but indicates
        # keys wasn't passed as a list
        if isinstance(keys, (bytes, str)):
            keys = [keys]
        else:
            keys = list(keys)
    except TypeError:
        keys = [keys]
    if args:
        keys.extend(args)
    return keys


def nativestr(x):
    """Return the decoded binary string, or a string, depending on type."""
    r = x.decode("utf-8", "replace") if isinstance(x, bytes) else x
    if r == "null":
        return
    return r


def delist(x):
    """Given a list of binaries, return the stringified version."""
    if x is None:
        return x
    return [nativestr(obj) for obj in x]


def parse_to_list(response):
    """Optimistically parse the response to a list."""
    res = []

    special_values = {"infinity", "nan", "-infinity"}

    if response is None:
        return res

    for item in response:
        if item is None:
            res.append(None)
            continue
        try:
            item_str = nativestr(item)
        except TypeError:
            res.append(None)
            continue

        if isinstance(item_str, str) and item_str.lower() in special_values:
            res.append(item_str)  # Keep as string
        else:
            try:
                res.append(int(item))
            except ValueError:
                try:
                    res.append(float(item))
                except ValueError:
                    res.append(item_str)

    return res
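
# Behavior sketch for parse_to_list above: numeric strings are narrowed to
# int or float, the special values stay as strings, and None passes through.
assert parse_to_list([b"42", b"3.14", b"nan", b"hello", None]) == [
    42,
    3.14,
    "nan",
    "hello",
    None,
]
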
def random_string(length=10):
    """
    Return a random string of ``length`` lowercase characters.
    """
    return "".join(  # nosec
        random.choice(string.ascii_lowercase) for x in range(length)
    )


def decode_dict_keys(obj):
    """Decode the keys of the given dictionary with utf-8."""
    newobj = copy.copy(obj)
    for k in obj.keys():
        if isinstance(k, bytes):
            newobj[k.decode("utf-8")] = newobj[k]
            newobj.pop(k)
    return newobj


def get_protocol_version(client):
    if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
        return client.connection_pool.connection_kwargs.get("protocol")
    elif isinstance(client, redis.cluster.AbstractRedisCluster):
        return client.nodes_manager.connection_kwargs.get("protocol")
@@ -0,0 +1,147 @@
from json import JSONDecodeError, JSONDecoder, JSONEncoder

import redis

from ..helpers import get_protocol_version, nativestr
from .commands import JSONCommands
from .decoders import bulk_of_jsons, decode_list


class JSON(JSONCommands):
    """
    Create a client for talking to json.

    :param decoder: An instance of json.JSONDecoder
    :param encoder: An instance of json.JSONEncoder
    """

    def __init__(
        self, client, version=None, decoder=JSONDecoder(), encoder=JSONEncoder()
    ):
        """
        Create a client for talking to json.

        :param decoder: An instance of json.JSONDecoder
        :param encoder: An instance of json.JSONEncoder
        """
        # Set the module commands' callbacks
        self._MODULE_CALLBACKS = {
            "JSON.ARRPOP": self._decode,
            "JSON.DEBUG": self._decode,
            "JSON.GET": self._decode,
            "JSON.MERGE": lambda r: r and nativestr(r) == "OK",
            "JSON.MGET": bulk_of_jsons(self._decode),
            "JSON.MSET": lambda r: r and nativestr(r) == "OK",
            "JSON.RESP": self._decode,
            "JSON.SET": lambda r: r and nativestr(r) == "OK",
            "JSON.TOGGLE": self._decode,
        }

        _RESP2_MODULE_CALLBACKS = {
            "JSON.ARRAPPEND": self._decode,
            "JSON.ARRINDEX": self._decode,
            "JSON.ARRINSERT": self._decode,
            "JSON.ARRLEN": self._decode,
            "JSON.ARRTRIM": self._decode,
            "JSON.CLEAR": int,
            "JSON.DEL": int,
            "JSON.FORGET": int,
            "JSON.GET": self._decode,
            "JSON.NUMINCRBY": self._decode,
            "JSON.NUMMULTBY": self._decode,
            "JSON.OBJKEYS": self._decode,
            "JSON.STRAPPEND": self._decode,
            "JSON.OBJLEN": self._decode,
            "JSON.STRLEN": self._decode,
            "JSON.TOGGLE": self._decode,
        }

        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.execute_command = client.execute_command
        self.MODULE_VERSION = version

        if get_protocol_version(self.client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for key, value in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(key, value)

        self.__encoder__ = encoder
        self.__decoder__ = decoder

    def _decode(self, obj):
        """Run the decoder over the given response object."""
        if obj is None:
            return obj

        try:
            x = self.__decoder__.decode(obj)
            if x is None:
                raise TypeError
            return x
        except TypeError:
            try:
                return self.__decoder__.decode(obj.decode())
            except AttributeError:
                return decode_list(obj)
        except (AttributeError, JSONDecodeError):
            return decode_list(obj)

    def _encode(self, obj):
        """Run the encoder over the given object."""
        return self.__encoder__.encode(obj)

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the JSON module, that can be used for executing
        JSON commands, as well as classic core commands.

        Usage example:

            r = redis.Redis()
            pipe = r.json().pipeline()
            pipe.jsonset('foo', '.', {'hello!': 'world'})
            pipe.jsonget('foo')
            pipe.jsonget('notakey')
        """
        if isinstance(self.client, redis.RedisCluster):
            p = ClusterPipeline(
                nodes_manager=self.client.nodes_manager,
                commands_parser=self.client.commands_parser,
                startup_nodes=self.client.nodes_manager.startup_nodes,
                result_callbacks=self.client.result_callbacks,
                cluster_response_callbacks=self.client.cluster_response_callbacks,
                cluster_error_retry_attempts=self.client.retry.get_retries(),
                read_from_replicas=self.client.read_from_replicas,
                reinitialize_steps=self.client.reinitialize_steps,
                lock=self.client._lock,
            )
        else:
            p = Pipeline(
                connection_pool=self.client.connection_pool,
                response_callbacks=self._MODULE_CALLBACKS,
                transaction=transaction,
                shard_hint=shard_hint,
            )

        p._encode = self._encode
        p._decode = self._decode
        return p


class ClusterPipeline(JSONCommands, redis.cluster.ClusterPipeline):
    """Cluster pipeline for the module."""


class Pipeline(JSONCommands, redis.client.Pipeline):
    """Pipeline for the module."""
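
# Usage sketch for the JSON client defined above, assuming a local Redis with
# the JSON module loaded (address, key, and values are placeholders):
import redis

r = redis.Redis(host="localhost", port=6379)
j = r.json()
j.set("doc:1", ".", {"name": "Ada", "langs": ["py", "ml"]})
print(j.get("doc:1"))           # -> {'name': 'Ada', 'langs': ['py', 'ml']}
print(j.get("doc:1", ".name"))  # -> 'Ada'
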
@@ -0,0 +1,5 @@
from typing import List, Mapping, Union

JsonType = Union[
    str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]
@@ -0,0 +1,431 @@
import os
from json import JSONDecodeError, loads
from typing import Dict, List, Optional, Tuple, Union

from redis.exceptions import DataError
from redis.utils import deprecated_function

from ._util import JsonType
from .decoders import decode_dict_keys
from .path import Path


class JSONCommands:
    """json commands."""

    def arrappend(
        self, name: str, path: Optional[str] = Path.root_path(), *args: JsonType
    ) -> List[Optional[int]]:
        """Append the objects ``args`` to the array under the
        ``path`` in key ``name``.

        For more information see `JSON.ARRAPPEND <https://redis.io/commands/json.arrappend>`_.
        """  # noqa
        pieces = [name, str(path)]
        for o in args:
            pieces.append(self._encode(o))
        return self.execute_command("JSON.ARRAPPEND", *pieces)

    def arrindex(
        self,
        name: str,
        path: str,
        scalar: int,
        start: Optional[int] = None,
        stop: Optional[int] = None,
    ) -> List[Optional[int]]:
        """
        Return the index of ``scalar`` in the JSON array under ``path`` at key
        ``name``.

        The search can be limited using the optional inclusive ``start``
        and exclusive ``stop`` indices.

        For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
        """  # noqa
        pieces = [name, str(path), self._encode(scalar)]
        if start is not None:
            pieces.append(start)
            if stop is not None:
                pieces.append(stop)

        return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])

    def arrinsert(
        self, name: str, path: str, index: int, *args: JsonType
    ) -> List[Optional[int]]:
        """Insert the objects ``args`` into the array at index ``index``
        under the ``path`` in key ``name``.

        For more information see `JSON.ARRINSERT <https://redis.io/commands/json.arrinsert>`_.
        """  # noqa
        pieces = [name, str(path), index]
        for o in args:
            pieces.append(self._encode(o))
        return self.execute_command("JSON.ARRINSERT", *pieces)

    def arrlen(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[int]]:
        """Return the length of the array JSON value under ``path``
        at key ``name``.

        For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
        """  # noqa
        return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])

    def arrpop(
        self,
        name: str,
        path: Optional[str] = Path.root_path(),
        index: Optional[int] = -1,
    ) -> List[Optional[str]]:
        """Pop the element at ``index`` in the array JSON value under
        ``path`` at key ``name``.

        For more information see `JSON.ARRPOP <https://redis.io/commands/json.arrpop>`_.
        """  # noqa
        return self.execute_command("JSON.ARRPOP", name, str(path), index)

    def arrtrim(
        self, name: str, path: str, start: int, stop: int
    ) -> List[Optional[int]]:
        """Trim the array JSON value under ``path`` at key ``name`` to the
        inclusive range given by ``start`` and ``stop``.

        For more information see `JSON.ARRTRIM <https://redis.io/commands/json.arrtrim>`_.
        """  # noqa
        return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop)

    def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]:
        """Get the type of the JSON value under ``path`` from key ``name``.

        For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
        """  # noqa
        return self.execute_command("JSON.TYPE", name, str(path), keys=[name])

    def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
        """Return the JSON value under ``path`` at key ``name``.

        For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
        """  # noqa
        return self.execute_command("JSON.RESP", name, str(path), keys=[name])

    def objkeys(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[List[str]]]:
        """Return the key names in the dictionary JSON value under ``path`` at
        key ``name``.

        For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
        """  # noqa
        return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])

    def objlen(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> List[Optional[int]]:
        """Return the length of the dictionary JSON value under ``path`` at key
        ``name``.

        For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
        """  # noqa
        return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])

    def numincrby(self, name: str, path: str, number: int) -> str:
        """Increment the numeric (integer or floating point) JSON value under
        ``path`` at key ``name`` by the provided ``number``.

        For more information see `JSON.NUMINCRBY <https://redis.io/commands/json.numincrby>`_.
        """  # noqa
        return self.execute_command(
            "JSON.NUMINCRBY", name, str(path), self._encode(number)
        )

    @deprecated_function(version="4.0.0", reason="deprecated since redisjson 1.0.0")
    def nummultby(self, name: str, path: str, number: int) -> str:
        """Multiply the numeric (integer or floating point) JSON value under
        ``path`` at key ``name`` with the provided ``number``.

        For more information see `JSON.NUMMULTBY <https://redis.io/commands/json.nummultby>`_.
        """  # noqa
        return self.execute_command(
            "JSON.NUMMULTBY", name, str(path), self._encode(number)
        )

    def clear(self, name: str, path: Optional[str] = Path.root_path()) -> int:
        """Empty arrays and objects (to have zero slots/keys without deleting the
        array/object).

        Return the count of cleared paths (ignoring non-array and non-object
        paths).

        For more information see `JSON.CLEAR <https://redis.io/commands/json.clear>`_.
        """  # noqa
        return self.execute_command("JSON.CLEAR", name, str(path))

    def delete(self, key: str, path: Optional[str] = Path.root_path()) -> int:
        """Delete the JSON value stored at key ``key`` under ``path``.

        For more information see `JSON.DEL <https://redis.io/commands/json.del>`_.
        """
        return self.execute_command("JSON.DEL", key, str(path))

    # forget is an alias for delete
    forget = delete

    def get(
        self, name: str, *args, no_escape: Optional[bool] = False
    ) -> Optional[List[JsonType]]:
        """
        Get the object stored as a JSON value at key ``name``.

        ``args`` is zero or more paths, and defaults to the root path.
        ``no_escape`` is a boolean flag that adds the no_escape option,
        returning non-ascii characters unescaped.

        For more information see `JSON.GET <https://redis.io/commands/json.get>`_.
        """  # noqa
        pieces = [name]
        if no_escape:
            pieces.append("noescape")

        if len(args) == 0:
            pieces.append(Path.root_path())
        else:
            for p in args:
                pieces.append(str(p))

        # Handle case where key doesn't exist. The JSONDecoder would raise a
        # TypeError exception since it can't decode None
        try:
            return self.execute_command("JSON.GET", *pieces, keys=[name])
        except TypeError:
            return None

    def mget(self, keys: List[str], path: str) -> List[JsonType]:
        """
        Get the objects stored as JSON values under ``path``. ``keys``
        is a list of one or more keys.

        For more information see `JSON.MGET <https://redis.io/commands/json.mget>`_.
        """  # noqa
        pieces = []
        pieces += keys
        pieces.append(str(path))
        return self.execute_command("JSON.MGET", *pieces, keys=keys)

    def set(
        self,
        name: str,
        path: str,
        obj: JsonType,
        nx: Optional[bool] = False,
        xx: Optional[bool] = False,
        decode_keys: Optional[bool] = False,
    ) -> Optional[str]:
        """
        Set the JSON value at key ``name`` under the ``path`` to ``obj``.

        ``nx`` if set to True, set ``value`` only if it does not exist.
        ``xx`` if set to True, set ``value`` only if it exists.
        ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
        with utf-8.

        For the purpose of using this within a pipeline, this command is also
        aliased to JSON.SET.

        For more information see `JSON.SET <https://redis.io/commands/json.set>`_.
        """
        if decode_keys:
            obj = decode_dict_keys(obj)

        pieces = [name, str(path), self._encode(obj)]

        # Handle existential modifiers
        if nx and xx:
            raise Exception(
                "nx and xx are mutually exclusive: use one, the "
                "other or neither - but not both"
            )
        elif nx:
            pieces.append("NX")
        elif xx:
            pieces.append("XX")
        return self.execute_command("JSON.SET", *pieces)
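    # Sketch of the existential modifiers handled by set() above, assuming a
    # JSON-enabled Redis and `j = redis.Redis().json()` (hypothetical key):
    #
    #   j.set("doc:1", ".", {"v": 1}, nx=True)  # True if doc:1 was absent, else None
    #   j.set("doc:1", ".", {"v": 2}, xx=True)  # True only if doc:1 already exists
    #   j.set("doc:1", ".", {"v": 3}, nx=True, xx=True)  # raises: mutually exclusive
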
    def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]:
        """
        Set JSON values for one or more keys, where ``triplets`` is a list of
        one or more (key, path, value) triplets.

        For the purpose of using this within a pipeline, this command is also
        aliased to JSON.MSET.

        For more information see `JSON.MSET <https://redis.io/commands/json.mset>`_.
        """
        pieces = []
        for triplet in triplets:
            pieces.extend([triplet[0], str(triplet[1]), self._encode(triplet[2])])
        return self.execute_command("JSON.MSET", *pieces)

    def merge(
        self,
        name: str,
        path: str,
        obj: JsonType,
        decode_keys: Optional[bool] = False,
    ) -> Optional[str]:
        """
        Merges a given JSON value into matching paths. Consequently, JSON values
        at matching paths are updated, deleted, or expanded with new children.

        ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
        with utf-8.

        For more information see `JSON.MERGE <https://redis.io/commands/json.merge>`_.
        """
        if decode_keys:
            obj = decode_dict_keys(obj)

        pieces = [name, str(path), self._encode(obj)]

        return self.execute_command("JSON.MERGE", *pieces)

    def set_file(
        self,
        name: str,
        path: str,
        file_name: str,
        nx: Optional[bool] = False,
        xx: Optional[bool] = False,
        decode_keys: Optional[bool] = False,
    ) -> Optional[str]:
        """
        Set the JSON value at key ``name`` under the ``path`` to the content
        of the json file ``file_name``.

        ``nx`` if set to True, set ``value`` only if it does not exist.
        ``xx`` if set to True, set ``value`` only if it exists.
        ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
        with utf-8.
        """
        with open(file_name) as fp:
            file_content = loads(fp.read())

        return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys)

    def set_path(
        self,
        json_path: str,
        root_folder: str,
        nx: Optional[bool] = False,
        xx: Optional[bool] = False,
        decode_keys: Optional[bool] = False,
    ) -> Dict[str, bool]:
        """
        Iterate over ``root_folder`` and set each JSON file to a value
        under ``json_path`` with the file name as the key.

        ``nx`` if set to True, set ``value`` only if it does not exist.
        ``xx`` if set to True, set ``value`` only if it exists.
        ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
        with utf-8.
        """
        set_files_result = {}
        for root, dirs, files in os.walk(root_folder):
            for file in files:
                file_path = os.path.join(root, file)
                try:
                    file_name = file_path.rsplit(".")[0]
                    self.set_file(
                        file_name,
                        json_path,
                        file_path,
                        nx=nx,
                        xx=xx,
                        decode_keys=decode_keys,
                    )
                    set_files_result[file_path] = True
                except JSONDecodeError:
                    set_files_result[file_path] = False

        return set_files_result

    def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
        """Return the length of the string JSON value under ``path`` at key
        ``name``.

        For more information see `JSON.STRLEN <https://redis.io/commands/json.strlen>`_.
        """  # noqa
        pieces = [name]
        if path is not None:
            pieces.append(str(path))
        return self.execute_command("JSON.STRLEN", *pieces, keys=[name])

    def toggle(
        self, name: str, path: Optional[str] = Path.root_path()
    ) -> Union[bool, List[Optional[int]]]:
        """Toggle the boolean value under ``path`` at key ``name``,
        returning the new value.

        For more information see `JSON.TOGGLE <https://redis.io/commands/json.toggle>`_.
        """  # noqa
        return self.execute_command("JSON.TOGGLE", name, str(path))

    def strappend(
        self, name: str, value: str, path: Optional[str] = Path.root_path()
    ) -> Union[int, List[Optional[int]]]:
        """Append to the string JSON value. If two options are specified after
        the key name, the path is determined to be the first. If a single
        option is passed, then the root path (i.e. Path.root_path()) is used.

        For more information see `JSON.STRAPPEND <https://redis.io/commands/json.strappend>`_.
        """  # noqa
        pieces = [name, str(path), self._encode(value)]
        return self.execute_command("JSON.STRAPPEND", *pieces)

    def debug(
        self,
        subcommand: str,
        key: Optional[str] = None,
        path: Optional[str] = Path.root_path(),
    ) -> Union[int, List[str]]:
        """Return the memory usage in bytes of a value under ``path`` from
        key ``key``.

        For more information see `JSON.DEBUG <https://redis.io/commands/json.debug>`_.
        """  # noqa
        valid_subcommands = ["MEMORY", "HELP"]
        if subcommand not in valid_subcommands:
            raise DataError(f"The only valid subcommands are {valid_subcommands}")
        pieces = [subcommand]
        if subcommand == "MEMORY":
            if key is None:
                raise DataError("No key specified")
            pieces.append(key)
            pieces.append(str(path))
        return self.execute_command("JSON.DEBUG", *pieces)

    @deprecated_function(
        version="4.0.0", reason="redisjson-py supported this, call get directly."
    )
    def jsonget(self, *args, **kwargs):
        return self.get(*args, **kwargs)

    @deprecated_function(
        version="4.0.0", reason="redisjson-py supported this, call mget directly."
    )
    def jsonmget(self, *args, **kwargs):
        return self.mget(*args, **kwargs)

    @deprecated_function(
        version="4.0.0", reason="redisjson-py supported this, call set directly."
    )
    def jsonset(self, *args, **kwargs):
        return self.set(*args, **kwargs)
@@ -0,0 +1,60 @@
import copy
import re

from ..helpers import nativestr


def bulk_of_jsons(d):
    """Replace serialized JSON values with objects in a
    bulk array response (list).
    """

    def _f(b):
        for index, item in enumerate(b):
            if item is not None:
                b[index] = d(item)
        return b

    return _f


def decode_dict_keys(obj):
    """Decode the keys of the given dictionary with utf-8."""
    newobj = copy.copy(obj)
    for k in obj.keys():
        if isinstance(k, bytes):
            newobj[k.decode("utf-8")] = newobj[k]
            newobj.pop(k)
    return newobj


def unstring(obj):
    """
    Attempt to parse a string into a native int or float.
    One can't simply call int/float in a try/catch because there is a
    semantic difference between (for example) 15.0 and 15.
    """
    floatreg = "^\\d+\\.\\d+$"
    match = re.findall(floatreg, obj)
    if match != []:
        return float(match[0])

    intreg = "^\\d+$"
    match = re.findall(intreg, obj)
    if match != []:
        return int(match[0])
    return obj


def decode_list(b):
    """
    Given a non-deserializable object, make a best effort to
    return a useful set of results.
    """
    if isinstance(b, list):
        return [nativestr(obj) for obj in b]
    elif isinstance(b, bytes):
        return unstring(nativestr(b))
    elif isinstance(b, str):
        return unstring(b)
    return b
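
# Behavior sketch for unstring/decode_list above: numeric strings are
# narrowed to int or float, everything else passes through unchanged.
assert unstring("15") == 15
assert unstring("15.0") == 15.0
assert unstring("abc") == "abc"
assert decode_list([b"1", b"x"]) == ["1", "x"]
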
@@ -0,0 +1,16 @@
class Path:
    """This class represents a path in a JSON value."""

    strPath = ""

    @staticmethod
    def root_path():
        """Return the root path's string representation."""
        return "."

    def __init__(self, path):
        """Make a new path based on the string representation in `path`."""
        self.strPath = path

    def __repr__(self):
        return self.strPath
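
# Usage sketch for Path above: wrap a JSONPath string, or use the root path.
p = Path(".name.first")
print(p)                 # .name.first  (repr is the raw path string)
print(Path.root_path())  # .
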
@@ -0,0 +1,101 @@
from __future__ import annotations

from json import JSONDecoder, JSONEncoder
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
    from .json import JSON
    from .search import AsyncSearch, Search
    from .timeseries import TimeSeries
    from .vectorset import VectorSet


class RedisModuleCommands:
    """This class contains the wrapper functions to bring supported redis
    modules into the command namespace.
    """

    def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
        """Access the json namespace, providing support for redis json."""

        from .json import JSON

        jj = JSON(client=self, encoder=encoder, decoder=decoder)
        return jj

    def ft(self, index_name="idx") -> Search:
        """Access the search namespace, providing support for redis search."""

        from .search import Search

        s = Search(client=self, index_name=index_name)
        return s

    def ts(self) -> TimeSeries:
        """Access the timeseries namespace, providing support for
        redis timeseries data.
        """

        from .timeseries import TimeSeries

        s = TimeSeries(client=self)
        return s

    def bf(self) -> BFBloom:
        """Access the bloom filter namespace."""

        from .bf import BFBloom

        bf = BFBloom(client=self)
        return bf

    def cf(self) -> CFBloom:
        """Access the cuckoo filter namespace."""

        from .bf import CFBloom

        cf = CFBloom(client=self)
        return cf

    def cms(self) -> CMSBloom:
        """Access the count-min sketch namespace."""

        from .bf import CMSBloom

        cms = CMSBloom(client=self)
        return cms

    def topk(self) -> TOPKBloom:
        """Access the top-k namespace."""

        from .bf import TOPKBloom

        topk = TOPKBloom(client=self)
        return topk

    def tdigest(self) -> TDigestBloom:
        """Access the t-digest namespace."""

        from .bf import TDigestBloom

        tdigest = TDigestBloom(client=self)
        return tdigest

    def vset(self) -> VectorSet:
        """Access the VectorSet commands namespace."""

        from .vectorset import VectorSet

        vset = VectorSet(client=self)
        return vset


class AsyncRedisModuleCommands(RedisModuleCommands):
    def ft(self, index_name="idx") -> AsyncSearch:
        """Access the search namespace, providing support for redis search."""

        from .search import AsyncSearch

        s = AsyncSearch(client=self, index_name=index_name)
        return s
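
# Usage sketch for the module accessors above, assuming a Redis server with
# the relevant modules loaded and an existing index "idx" (all names here are
# placeholders):
import redis

r = redis.Redis(host="localhost", port=6379)
r.json().set("doc:1", ".", {"hello": "world"})  # RedisJSON namespace
r.ft("idx").info()                              # RediSearch namespace
r.ts().create("sensor:1")                       # RedisTimeSeries namespace
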
@@ -0,0 +1,189 @@
from redis.client import Pipeline as RedisPipeline

from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
    AGGREGATE_CMD,
    CONFIG_CMD,
    INFO_CMD,
    PROFILE_CMD,
    SEARCH_CMD,
    SPELLCHECK_CMD,
    SYNDUMP_CMD,
    AsyncSearchCommands,
    SearchCommands,
)


class Search(SearchCommands):
    """
    Create a client for talking to search.
    It abstracts the API of the module and lets you just use the engine.
    """

    class BatchIndexer:
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        def __init__(self, client, chunk_size=1000):
            self.client = client
            self.execute_command = client.execute_command
            self._pipeline = client.pipeline(transaction=False, shard_hint=None)
            self.total = 0
            self.chunk_size = chunk_size
            self.current_chunk = 0

        def __del__(self):
            if self.current_chunk:
                self.commit()

        def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def add_document_hash(self, doc_id, score=1.0, replace=False):
            """
            Add a hash to the batch query
            """
            self.client._add_document_hash(
                doc_id, conn=self._pipeline, score=score, replace=replace
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            self._pipeline.execute()
            self.current_chunk = 0

    def __init__(self, client, index_name="idx"):
        """
        Create a new Client for the given index_name.
        The default name is `idx`

        If conn is not None, we employ an already existing redis connection
        """
        self._MODULE_CALLBACKS = {}
        self.client = client
        self.index_name = index_name
        self.execute_command = client.execute_command
        self._pipeline = client.pipeline
        self._RESP2_MODULE_CALLBACKS = {
            INFO_CMD: self._parse_info,
            SEARCH_CMD: self._parse_search,
            AGGREGATE_CMD: self._parse_aggregate,
            PROFILE_CMD: self._parse_profile,
            SPELLCHECK_CMD: self._parse_spellcheck,
            CONFIG_CMD: self._parse_config_get,
            SYNDUMP_CMD: self._parse_syndump,
        }

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = Pipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class AsyncSearch(Search, AsyncSearchCommands):
    class BatchIndexer(Search.BatchIndexer):
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        async def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                await self.commit()

        async def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            await self._pipeline.execute()
            self.current_chunk = 0

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = AsyncPipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class Pipeline(SearchCommands, RedisPipeline):
    """Pipeline for the module."""


class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
    """AsyncPipeline for the module."""
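
# Usage sketch for Search.BatchIndexer above: documents are buffered into a
# non-transactional pipeline and flushed every `chunk_size` additions. This
# assumes a RediSearch-enabled server and an already-created index "idx";
# the chunk size and field values are hypothetical.
import redis

search = redis.Redis(host="localhost", port=6379).ft("idx")
indexer = Search.BatchIndexer(search, chunk_size=2)
indexer.add_document("doc1", title="hello")
indexer.add_document("doc2", title="world")  # reaches chunk_size: auto-commit
indexer.commit()                             # flush any remainder manually
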
@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    if isinstance(s, str):
        return s
    elif isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    else:
        return s  # Not a string we care about
@@ -0,0 +1,399 @@
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT

FIELDNAME = object()


class Limit:
    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        if self.count:
            return ["LIMIT", str(self.offset), str(self.count)]
        else:
            return []


class Reducer:
    """
    Base reducer object for all reducers.

    See the `redisearch.reducers` module for the actual reducers.
    """

    NAME = None

    def __init__(self, *args: str) -> None:
        self._args: Tuple[str, ...] = args
        self._field: Optional[str] = None
        self._alias: Optional[str] = None

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
            special value `aggregation.FIELDNAME` then this reducer will be
            aliased using the same name as the field upon which it operates.
            Note that using `FIELDNAME` is only possible on reducers which
            operate on a single field value.

        This method returns the `Reducer` object, making it suitable for
        chaining.
        """
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            else:
                # Chop off the initial '@'
                alias = self._field[1:]
        self._alias = alias
        return self

    @property
    def args(self) -> Tuple[str, ...]:
        return self._args


class SortDirection:
    """
    This special class is used to indicate sort direction.
    """

    DIRSTRING: Optional[str] = None

    def __init__(self, field: str) -> None:
        self.field = field


class Asc(SortDirection):
    """
    Indicate that the given field should be sorted in ascending order
    """

    DIRSTRING = "ASC"


class Desc(SortDirection):
    """
    Indicate that the given field should be sorted in descending order
    """

    DIRSTRING = "DESC"


class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`)
        return the object itself, making them useful for chaining.
        """
        self._query: str = query
        self._aggregateplan: List[str] = []
        self._loadfields: List[str] = []
        self._loadall: bool = False
        self._max: int = 0
        self._with_schema: bool = False
        self._verbatim: bool = False
        self._cursor: List[str] = []
        self._dialect: int = DEFAULT_DIALECT
        self._add_scores: bool = False
        self._scorer: str = "TFIDF"

    def load(self, *fields: str) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If ``fields`` is not specified, all the fields will be
            loaded. Otherwise, fields should be given in the format of
            `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self, fields: Union[str, List[str]], *reducers: Reducer
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
            or a list of strings. In both cases, each field should be
            specified as `@field`.
        - **reducers**: One or more reducers. Reducers may be found in the
            `aggregation` module.
        """
        fields = [fields] if isinstance(fields, str) else fields

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result.

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
            the alias for the projection, and the value is the projection
            expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)

        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return

        Example of sorting the initial results:

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 10)\
            .group_by("@state", r.count())
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 1000)\
            .group_by("@state", r.count())\
            .limit(0, 10)
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries.

        ### Parameters

        - **fields**: The fields by which to sort. This can be either a single
            field or a list of fields. If you wish to specify order, you can
            use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
            used instead of `LIMIT` and is also faster.

        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        AggregateRequest()\
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
            .sort_by(Desc("@paid"), max=10)
        ```
        """
        fields_args = []
        for f in fields:
            if isinstance(f, (Asc, Desc)):
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args += [f]

        ret = ["SORTBY", str(len(fields_args))]
        ret.extend(fields_args)
        max = kwargs.get("max", 0)
        if max > 0:
            ret += ["MAX", str(max)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: A single expression, or a list of expressions, to
            filter by.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])

        return self
def with_schema(self) -> "AggregateRequest":
|
||||
"""
|
||||
If set, the `schema` property will contain a list of `[field, type]`
|
||||
entries in the result object.
|
||||
"""
|
||||
self._with_schema = True
|
||||
return self
|
||||
|
||||
def add_scores(self) -> "AggregateRequest":
|
||||
"""
|
||||
If set, includes the score as an ordinary field of the row.
|
||||
"""
|
||||
self._add_scores = True
|
||||
return self
|
||||
|
||||
def scorer(self, scorer: str) -> "AggregateRequest":
|
||||
"""
|
||||
Use a different scoring function to evaluate document relevance.
|
||||
Default is `TFIDF`.
|
||||
|
||||
:param scorer: The scoring function to use
|
||||
(e.g. `TFIDF.DOCNORM` or `BM25`)
|
||||
"""
|
||||
self._scorer = scorer
|
||||
return self
|
||||
|
||||
def verbatim(self) -> "AggregateRequest":
|
||||
self._verbatim = True
|
||||
return self
|
||||
|
||||
def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
|
||||
args = ["WITHCURSOR"]
|
||||
if count:
|
||||
args += ["COUNT", str(count)]
|
||||
if max_idle:
|
||||
args += ["MAXIDLE", str(max_idle * 1000)]
|
||||
self._cursor = args
|
||||
return self
|
||||
|
||||
def build_args(self) -> List[str]:
|
||||
# @foo:bar ...
|
||||
ret = [self._query]
|
||||
|
||||
if self._with_schema:
|
||||
ret.append("WITHSCHEMA")
|
||||
|
||||
if self._verbatim:
|
||||
ret.append("VERBATIM")
|
||||
|
||||
if self._scorer:
|
||||
ret.extend(["SCORER", self._scorer])
|
||||
|
||||
if self._add_scores:
|
||||
ret.append("ADDSCORES")
|
||||
|
||||
if self._cursor:
|
||||
ret += self._cursor
|
||||
|
||||
if self._loadall:
|
||||
ret.append("LOAD")
|
||||
ret.append("*")
|
||||
|
||||
elif self._loadfields:
|
||||
ret.append("LOAD")
|
||||
ret.append(str(len(self._loadfields)))
|
||||
ret.extend(self._loadfields)
|
||||
|
||||
if self._dialect:
|
||||
ret.extend(["DIALECT", str(self._dialect)])
|
||||
|
||||
ret.extend(self._aggregateplan)
|
||||
|
||||
return ret
|
||||
|
||||
def dialect(self, dialect: int) -> "AggregateRequest":
|
||||
"""
|
||||
Add a dialect field to the aggregate command.
|
||||
|
||||
- **dialect** - dialect version to execute the query under
|
||||
"""
|
||||
self._dialect = dialect
|
||||
return self
|
||||
|
||||
|
||||
class Cursor:
|
||||
def __init__(self, cid: int) -> None:
|
||||
self.cid = cid
|
||||
self.max_idle = 0
|
||||
self.count = 0
|
||||
|
||||
def build_args(self):
|
||||
args = [str(self.cid)]
|
||||
if self.max_idle:
|
||||
args += ["MAXIDLE", str(self.max_idle)]
|
||||
if self.count:
|
||||
args += ["COUNT", str(self.count)]
|
||||
return args
|
||||
|
||||
|
||||
class AggregateResult:
|
||||
def __init__(self, rows, cursor: Cursor, schema) -> None:
|
||||
self.rows = rows
|
||||
self.cursor = cursor
|
||||
self.schema = schema
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cid = self.cursor.cid if self.cursor else -1
|
||||
return (
|
||||
f"<{self.__class__.__name__} at 0x{id(self):x} "
|
||||
f"Rows={len(self.rows)}, Cursor={cid}>"
|
||||
)
|
||||
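A minimal sketch of how these pieces chain together, assuming an index named `idx` already exists; `ft()` and `aggregate()` come from the search client module (not part of this excerpt), and `Reducer.alias()` lives in the suppressed earlier part of `aggregation.py`:

```
import redis
from redis.commands.search import aggregation, reducers

r = redis.Redis()

# Each chained call appends to the internal aggregate plan.
req = (
    aggregation.AggregateRequest("*")
    .group_by("@state", reducers.count().alias("count"))
    .sort_by(aggregation.Desc("@count"), max=10)
)

# build_args() yields the raw FT.AGGREGATE arguments; note that the
# default scorer and dialect are always emitted:
req.build_args()
# ['*', 'SCORER', 'TFIDF', 'DIALECT', '2', 'GROUPBY', '1', '@state',
#  'REDUCE', 'COUNT', '0', 'AS', 'count',
#  'SORTBY', '2', '@count', 'DESC', 'MAX', '10']

res = r.ft("idx").aggregate(req)
```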
File diff suppressed because it is too large
@@ -0,0 +1,3 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
DEFAULT_DIALECT = 2
@@ -0,0 +1,17 @@
class Document:
    """
    Represents a single document in a result set
    """

    def __init__(self, id, payload=None, **fields):
        self.id = id
        self.payload = payload
        for k, v in fields.items():
            setattr(self, k, v)

    def __repr__(self):
        return f"Document {self.__dict__}"

    def __getitem__(self, item):
        value = getattr(self, item)
        return value
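A tiny illustration of the attribute/`__getitem__` duality (the values here are made up):

```
from redis.commands.search.document import Document

doc = Document("doc:1", title="hello", price="9.99")
doc.title     # 'hello'
doc["price"]  # '9.99' -- __getitem__ simply falls back to getattr
```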
@@ -0,0 +1,210 @@
from typing import List

from redis import DataError


class Field:
    """
    A class representing a field in a document.
    """

    NUMERIC = "NUMERIC"
    TEXT = "TEXT"
    WEIGHT = "WEIGHT"
    GEO = "GEO"
    TAG = "TAG"
    VECTOR = "VECTOR"
    SORTABLE = "SORTABLE"
    NOINDEX = "NOINDEX"
    AS = "AS"
    GEOSHAPE = "GEOSHAPE"
    INDEX_MISSING = "INDEXMISSING"
    INDEX_EMPTY = "INDEXEMPTY"

    def __init__(
        self,
        name: str,
        args: List[str] = None,
        sortable: bool = False,
        no_index: bool = False,
        index_missing: bool = False,
        index_empty: bool = False,
        as_name: str = None,
    ):
        """
        Create a new field object.

        Args:
            name: The name of the field.
            args: Additional arguments emitted as part of the field definition.
            sortable: If `True`, the field will be sortable.
            no_index: If `True`, the field will not be indexed.
            index_missing: If `True`, it will be possible to search for documents
                           that have this field missing.
            index_empty: If `True`, it will be possible to search for documents
                         that have this field empty.
            as_name: If provided, this alias will be used for the field.
        """
        if args is None:
            args = []
        self.name = name
        self.args = args
        self.args_suffix = list()
        self.as_name = as_name

        if no_index:
            self.args_suffix.append(Field.NOINDEX)
        if index_missing:
            self.args_suffix.append(Field.INDEX_MISSING)
        if index_empty:
            self.args_suffix.append(Field.INDEX_EMPTY)
        if sortable:
            self.args_suffix.append(Field.SORTABLE)

        if no_index and not sortable:
            raise ValueError("Non-Sortable non-Indexable fields are ignored")

    def append_arg(self, value):
        self.args.append(value)

    def redis_args(self):
        args = [self.name]
        if self.as_name:
            args += [self.AS, self.as_name]
        args += self.args
        args += self.args_suffix
        return args


class TextField(Field):
    """
    TextField is used to define a text field in a schema definition
    """

    NOSTEM = "NOSTEM"
    PHONETIC = "PHONETIC"

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        no_stem: bool = False,
        phonetic_matcher: str = None,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)

        if no_stem:
            Field.append_arg(self, self.NOSTEM)
        if phonetic_matcher and phonetic_matcher in [
            "dm:en",
            "dm:fr",
            "dm:pt",
            "dm:es",
        ]:
            Field.append_arg(self, self.PHONETIC)
            Field.append_arg(self, phonetic_matcher)
        if withsuffixtrie:
            Field.append_arg(self, "WITHSUFFIXTRIE")


class NumericField(Field):
    """
    NumericField is used to define a numeric field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)


class GeoShapeField(Field):
    """
    GeoShapeField is used to enable within/contain indexing/searching
    """

    SPHERICAL = "SPHERICAL"
    FLAT = "FLAT"

    def __init__(self, name: str, coord_system=None, **kwargs):
        args = [Field.GEOSHAPE]
        if coord_system:
            args.append(coord_system)
        Field.__init__(self, name, args=args, **kwargs)


class GeoField(Field):
    """
    GeoField is used to define a geo-indexing field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.GEO], **kwargs)


class TagField(Field):
    """
    TagField is a tag-indexing field with simpler compression and tokenization.
    See http://redisearch.io/Tags/
    """

    SEPARATOR = "SEPARATOR"
    CASESENSITIVE = "CASESENSITIVE"

    def __init__(
        self,
        name: str,
        separator: str = ",",
        case_sensitive: bool = False,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        args = [Field.TAG, self.SEPARATOR, separator]
        if case_sensitive:
            args.append(self.CASESENSITIVE)
        if withsuffixtrie:
            args.append("WITHSUFFIXTRIE")

        Field.__init__(self, name, args=args, **kwargs)


class VectorField(Field):
    """
    Allows vector similarity queries against the value in this attribute.
    See https://oss.redis.com/redisearch/Vectors/#vector_fields.
    """

    def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
        """
        Create a Vector Field. Note that a Vector Field cannot have the
        sortable or no_index tags, although it is also a Field.

        ``name`` is the name of the field.

        ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".

        ``attributes`` each algorithm can have specific attributes. Some of them
        are mandatory and some of them are optional. See
        https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
        for more information.
        """
        sort = kwargs.get("sortable", False)
        noindex = kwargs.get("no_index", False)

        if sort or noindex:
            raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")

        if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
            raise DataError(
                "Vector indexing supports three algorithms: "
                "'FLAT', 'HNSW', and 'SVS-VAMANA'."
            )

        attr_li = []

        for key, value in attributes.items():
            attr_li.extend([key, value])

        Field.__init__(
            self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs
        )
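A short sketch of how `redis_args()` flattens a field into `FT.CREATE` schema arguments; the field names and HNSW attributes below are illustrative:

```
from redis.commands.search.field import TagField, TextField, VectorField

title = TextField("title", weight=2.0, sortable=True)
genres = TagField("genres", separator="|")
embedding = VectorField(
    "embedding",
    "HNSW",
    {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"},
)

title.redis_args()
# ['title', 'TEXT', 'WEIGHT', 2.0, 'SORTABLE']
embedding.redis_args()
# ['embedding', 'VECTOR', 'HNSW', 6, 'TYPE', 'FLOAT32', 'DIM', 128,
#  'DISTANCE_METRIC', 'COSINE']
```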
@@ -0,0 +1,79 @@
from enum import Enum


class IndexType(Enum):
    """Enum of the currently supported index types."""

    HASH = 1
    JSON = 2


class IndexDefinition:
    """IndexDefinition is used to define an index definition for automatic
    indexing on Hash or JSON update."""

    def __init__(
        self,
        prefix=[],
        filter=None,
        language_field=None,
        language=None,
        score_field=None,
        score=1.0,
        payload_field=None,
        index_type=None,
    ):
        self.args = []
        self._append_index_type(index_type)
        self._append_prefix(prefix)
        self._append_filter(filter)
        self._append_language(language_field, language)
        self._append_score(score_field, score)
        self._append_payload(payload_field)

    def _append_index_type(self, index_type):
        """Append `ON HASH` or `ON JSON` according to the enum."""
        if index_type is IndexType.HASH:
            self.args.extend(["ON", "HASH"])
        elif index_type is IndexType.JSON:
            self.args.extend(["ON", "JSON"])
        elif index_type is not None:
            raise RuntimeError(f"index_type must be one of {list(IndexType)}")

    def _append_prefix(self, prefix):
        """Append PREFIX."""
        if len(prefix) > 0:
            self.args.append("PREFIX")
            self.args.append(len(prefix))
            for p in prefix:
                self.args.append(p)

    def _append_filter(self, filter):
        """Append FILTER."""
        if filter is not None:
            self.args.append("FILTER")
            self.args.append(filter)

    def _append_language(self, language_field, language):
        """Append LANGUAGE_FIELD and LANGUAGE."""
        if language_field is not None:
            self.args.append("LANGUAGE_FIELD")
            self.args.append(language_field)
        if language is not None:
            self.args.append("LANGUAGE")
            self.args.append(language)

    def _append_score(self, score_field, score):
        """Append SCORE_FIELD and SCORE."""
        if score_field is not None:
            self.args.append("SCORE_FIELD")
            self.args.append(score_field)
        if score is not None:
            self.args.append("SCORE")
            self.args.append(score)

    def _append_payload(self, payload_field):
        """Append PAYLOAD_FIELD."""
        if payload_field is not None:
            self.args.append("PAYLOAD_FIELD")
            self.args.append(payload_field)
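A sketch of what the builder emits; note that the default `score=1.0` means a `SCORE` clause is always appended. The definition is typically passed to `FT.CREATE` via the search client's `create_index()`, which is not part of this file:

```
from redis.commands.search.index_definition import IndexDefinition, IndexType

definition = IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH)
definition.args
# ['ON', 'HASH', 'PREFIX', 1, 'doc:', 'SCORE', 1.0]
```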
@@ -0,0 +1,14 @@
from typing import Any


class ProfileInformation:
    """
    Wrapper around FT.PROFILE response
    """

    def __init__(self, info: Any) -> None:
        self._info: Any = info

    @property
    def info(self) -> Any:
        return self._info
@@ -0,0 +1,381 @@
from typing import List, Optional, Tuple, Union

from redis.commands.search.dialect import DEFAULT_DIALECT


class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object so they can be chained,
    e.g. `Query("foo").verbatim().filter(...)`.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.
        The query string is set in the constructor, and other options have
        setter functions.
        """

        self._query_string: str = query_string
        self._offset: int = 0
        self._num: int = 10
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = list()
        self._ids: Optional[Tuple[str, ...]] = None
        self._slop: int = -1
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional[SortbyField] = None
        self._return_fields: List = []
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: int = DEFAULT_DIALECT

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        self._ids = ids
        return self

    def return_fields(self, *fields) -> "Query":
        """Add fields to return fields."""
        for field in fields:
            self.return_field(field)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            self._return_fields += ("AS", as_field)
        return self

    def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
        if not fields:
            return []
        return [fields] if isinstance(fields, str) else list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field that contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise, all results are summarized.

        Server-side defaults are used for each option (except `fields`)
        if not specified.

        - **fields** List of fields to summarize. All fields are summarized
        if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields

        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]

        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified, then only those mentioned fields are
        highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        args = ["HIGHLIGHT"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if tags:
            args += ["TAGS"] + list(tags)

        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        :param language: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non-matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """Override the timeout parameter of the module."""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document,
        i.e., for the query "hello world", we do not match "world hello".
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`; since Redis 8.0 the default is `BM25STD`.

        :param scorer: The scoring function to use
            (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[Union[str, int, float]]:
        """Format the redis arguments for this query and return them."""
        args: List[Union[str, int, float]] = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[Union[str, int, float]]:
        args: List[Union[str, int, float]] = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args.append("INFIELDS")
            args.append(len(self._fields))
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        if self._filters:
            for flt in self._filters:
                if not isinstance(flt, Filter):
                    raise AttributeError("Did not receive a Filter object.")
                args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args.append("INKEYS")
            args.append(len(self._ids))
            args += self._ids
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            args.append("RETURN")
            args.append(len(self._return_fields))
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]

        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e., use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: str) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: Each element should be a string, case sensitive field name
        from the defined schema.
        """
        self._fields = list(fields)
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently, only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
        corresponding field
        """

        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add an expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self


class Filter:
    def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
        self.args = [keyword, field] + list(args)


class NumericFilter(Filter):
    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        args = [
            minval if not minExclusive else f"({minval}",
            maxval if not maxExclusive else f"({maxval}",
        ]

        Filter.__init__(self, "FILTER", field, *args)


class GeoFilter(Filter):
    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)


class SortbyField:
    def __init__(self, field: str, asc=True) -> None:
        self.args = [field, "ASC" if asc else "DESC"]
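A quick sketch of the argument list the builder produces; an exclusive minimum bound is encoded with a leading `(`:

```
from redis.commands.search.query import NumericFilter, Query

q = (
    Query("search engine")
    .verbatim()
    .add_filter(NumericFilter("price", 10, 100, minExclusive=True))
    .paging(0, 5)
)
q.get_args()
# ['search engine', 'VERBATIM', 'FILTER', 'price', '(10', 100,
#  'DIALECT', 2, 'LIMIT', 0, 5]
```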
@@ -0,0 +1,317 @@
def tags(*t):
    """
    Indicate that the values should be matched to a tag field

    ### Parameters

    - **t**: Tags to search for
    """
    if not t:
        raise ValueError("At least one tag must be specified")
    return TagValue(*t)


def between(a, b, inclusive_min=True, inclusive_max=True):
    """
    Indicate that value is a numeric range
    """
    return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)


def equal(n):
    """
    Match a numeric value
    """
    return between(n, n)


def lt(n):
    """
    Match any value less than n
    """
    return between(None, n, inclusive_max=False)


def le(n):
    """
    Match any value less or equal to n
    """
    return between(None, n, inclusive_max=True)


def gt(n):
    """
    Match any value greater than n
    """
    return between(n, None, inclusive_min=False)


def ge(n):
    """
    Match any value greater or equal to n
    """
    return between(n, None, inclusive_min=True)


def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region
    """
    return GeoValue(lat, lon, radius, unit)


class Value:
    @property
    def combinable(self):
        """
        Whether this type of value may be combined with other values
        for the same field. This makes the filter potentially more efficient
        """
        return False

    @staticmethod
    def make_value(v):
        """
        Convert an object to a value, if it is not a value already
        """
        if isinstance(v, Value):
            return v
        return ScalarValue(v)

    def to_string(self):
        raise NotImplementedError()

    def __str__(self):
        return self.to_string()


class RangeValue(Value):
    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        if a is None:
            a = "-inf"
        if b is None:
            b = "inf"
        self.range = [str(a), str(b)]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        return "[{1}{0[0]} {2}{0[1]}]".format(
            self.range,
            "(" if not self.inclusive_min else "",
            "(" if not self.inclusive_max else "",
        )


class ScalarValue(Value):
    combinable = True

    def __init__(self, v):
        self.v = str(v)

    def to_string(self):
        return self.v


class TagValue(Value):
    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        return "{" + " | ".join(str(t) for t in self.tags) + "}"


class GeoValue(Value):
    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        return f"[{self.lon} {self.lat} {self.radius} {self.unit}]"


class Node:
    def __init__(self, *children, **kwparams):
        """
        Create a node

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
        `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
        type.

        The semantics of multiple conditions are dependent on the type of
        query. For an `intersection` node, this amounts to a logical AND,
        for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
        and the value should be a field value. This can be one of the
        following:

        - Simple string (for text field matches)
        - value returned by one of the helper functions
        - list of either a string or a value


        ### Examples

        Field `num` should be between 1 and 10
        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`

        ```
        union(name=("bob", "john"))
        ```

        Don't select documents whose country is Israel, Japan, or the US

        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """

        self.params = []

        kvparams = {}
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (str, int, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                curvals.extend(Value.make_value(subv) for subv in v)

        self.params += [Node.to_node(p) for p in children]

        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        if len(vals) == 1:
            return [BaseNode(f"@{key}:{vals[0].to_string()}")]
        if not vals[0].combinable:
            return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
        s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
        return [s]

    @classmethod
    def to_node(cls, obj):  # noqa
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"

    def _should_use_paren(self, optval):
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()


class BaseNode(Node):
    def __init__(self, s):
        super().__init__()
        self.s = str(s)

    def to_string(self, with_parens=None):
        return self.s


class IntersectNode(Node):
    """
    Create an intersection node. All children need to be satisfied in order for
    this node to evaluate as true
    """

    JOINSTR = " "


class UnionNode(Node):
    """
    Create a union node. Any of the children need to be satisfied in order for
    this node to evaluate as true
    """

    JOINSTR = "|"


class DisjunctNode(IntersectNode):
    """
    Create a disjunct node. In order for this node to be true, all of its
    children must evaluate to false
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(-" + ret + ")"
        else:
            return "-" + ret


class DistjunctUnion(DisjunctNode):
    """
    This node is true if *all* of its children are false. This is equivalent to
    ```
    disjunct(union(...))
    ```
    """

    JOINSTR = "|"


class OptionalNode(IntersectNode):
    """
    Create an optional node. If this node evaluates to true, then the document
    will be rated higher in score/rank.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(~" + ret + ")"
        else:
            return "~" + ret


def intersect(*args, **kwargs):
    return IntersectNode(*args, **kwargs)


def union(*args, **kwargs):
    return UnionNode(*args, **kwargs)


def disjunct(*args, **kwargs):
    return DisjunctNode(*args, **kwargs)


def disjunct_union(*args, **kwargs):
    return DistjunctUnion(*args, **kwargs)


def querystring(*args, **kwargs):
    return intersect(*args, **kwargs).to_string()
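A sketch of the strings these helpers render (field names are made up); combinable scalar values for one field collapse into a single clause, while ranges and tags stay separate:

```
from redis.commands.search.querystring import between, querystring, tags, union

querystring(title="hello", price=between(10, 20))
# '(@title:hello @price:[10 20])'

union(name=("bob", "john")).to_string()
# '@name:(bob|john)'

tags("sci-fi", "drama").to_string()
# '{sci-fi | drama}'
```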
@@ -0,0 +1,182 @@
from typing import Union

from .aggregation import Asc, Desc, Reducer, SortDirection


class FieldOnlyReducer(Reducer):
    """See https://redis.io/docs/interact/search-and-query/search/aggregations/"""

    def __init__(self, field: str) -> None:
        super().__init__(field)
        self._field = field


class count(Reducer):
    """
    Counts the number of results in the group
    """

    NAME = "COUNT"

    def __init__(self) -> None:
        super().__init__()


class sum(FieldOnlyReducer):
    """
    Calculates the sum of all the values in the given fields within the group
    """

    NAME = "SUM"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class min(FieldOnlyReducer):
    """
    Calculates the smallest value in the given field within the group
    """

    NAME = "MIN"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class max(FieldOnlyReducer):
    """
    Calculates the largest value in the given field within the group
    """

    NAME = "MAX"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class avg(FieldOnlyReducer):
    """
    Calculates the mean value in the given field within the group
    """

    NAME = "AVG"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class tolist(FieldOnlyReducer):
    """
    Returns all the matched properties in a list
    """

    NAME = "TOLIST"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinct(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field
    """

    NAME = "COUNT_DISTINCT"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinctish(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in the
    group for the given field. This uses a faster algorithm than
    `count_distinct` but is less accurate
    """

    NAME = "COUNT_DISTINCTISH"


class quantile(Reducer):
    """
    Return the value for the nth percentile within the range of values for the
    field within the group.
    """

    NAME = "QUANTILE"

    def __init__(self, field: str, pct: float) -> None:
        super().__init__(field, str(pct))
        self._field = field


class stddev(FieldOnlyReducer):
    """
    Return the standard deviation for the values within the group
    """

    NAME = "STDDEV"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class first_value(Reducer):
    """
    Selects the first value within the group according to sorting parameters
    """

    NAME = "FIRST_VALUE"

    def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
        """
        Selects the first value of the given field within the group.

        ### Parameters

        - **field**: Source field used for the value
        - **byfields**: How to sort the results. This can be either the
        *class* of `aggregation.Asc` or `aggregation.Desc`, in which
        case the field `field` is also used as the sort input.

        `byfields` can also be one or more *instances* of `Asc` or `Desc`
        indicating the sort order for these fields
        """

        fieldstrs = []
        if (
            len(byfields) == 1
            and isinstance(byfields[0], type)
            and issubclass(byfields[0], SortDirection)
        ):
            byfields = [byfields[0](field)]

        for f in byfields:
            fieldstrs += [f.field, f.DIRSTRING]

        args = [field]
        if fieldstrs:
            args += ["BY"] + fieldstrs
        super().__init__(*args)
        self._field = field


class random_sample(Reducer):
    """
    Returns a random sample of items from the dataset, from the given property
    """

    NAME = "RANDOM_SAMPLE"

    def __init__(self, field: str, size: int) -> None:
        """
        ### Parameters

        - **field**: Field to sample from
        - **size**: Return this many items (can be less)
        """
        args = [field, str(size)]
        super().__init__(*args)
        self._field = field
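A sketch of how `first_value` resolves its sort argument; the `Reducer` base class that stores `.args` lives in the suppressed part of `aggregation.py`, so the exact container type shown is an assumption:

```
from redis.commands.search import aggregation, reducers

# Passing the Desc *class* sorts by the source field itself...
reducers.first_value("@name", aggregation.Desc).args
# ('@name', 'BY', '@name', 'DESC')

# ...while an *instance* names a different sort field.
reducers.first_value("@name", aggregation.Asc("@age")).args
# ('@name', 'BY', '@age', 'ASC')
```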
@@ -0,0 +1,87 @@
from typing import Optional

from ._util import to_string
from .document import Document


class Result:
    """
    Represents the result of a search query, and has an array of Document
    objects
    """

    def __init__(
        self,
        res,
        hascontent,
        duration=0,
        has_payload=False,
        with_scores=False,
        field_encodings: Optional[dict] = None,
    ):
        """
        - duration: the execution time of the query
        - has_payload: whether the query has payloads
        - with_scores: whether the query has scores
        - field_encodings: a dictionary of field encodings if any is provided
        """

        self.total = res[0]
        self.duration = duration
        self.docs = []

        step = 1
        if hascontent:
            step = step + 1
        if has_payload:
            step = step + 1
        if with_scores:
            step = step + 1

        offset = 2 if with_scores else 1

        for i in range(1, len(res), step):
            id = to_string(res[i])
            payload = to_string(res[i + offset]) if has_payload else None
            fields_offset = offset + 1 if has_payload else offset
            score = float(res[i + 1]) if with_scores else None

            fields = {}
            if hascontent and res[i + fields_offset] is not None:
                keys = map(to_string, res[i + fields_offset][::2])
                values = res[i + fields_offset][1::2]

                for key, value in zip(keys, values):
                    if field_encodings is None or key not in field_encodings:
                        fields[key] = to_string(value)
                        continue

                    encoding = field_encodings[key]

                    # If the encoding is None, we don't need to decode the value
                    if encoding is None:
                        fields[key] = value
                    else:
                        fields[key] = to_string(value, encoding=encoding)

            try:
                del fields["id"]
            except KeyError:
                pass

            try:
                fields["json"] = fields["$"]
                del fields["$"]
            except KeyError:
                pass

            doc = (
                Document(id, score=score, payload=payload, **fields)
                if with_scores
                else Document(id, payload=payload, **fields)
            )
            self.docs.append(doc)

    def __repr__(self) -> str:
        return f"Result{{{self.total} total, docs: {self.docs}}}"
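A small sketch of the RESP2 reply shape this parser walks; with content but no scores or payloads, the step is 2 and the field list sits right after each id (the reply below is fabricated):

```
from redis.commands.search.result import Result

# [total, id1, fields1, id2, fields2, ...]
raw = [2, b"doc:1", [b"title", b"hello"], b"doc:2", [b"title", b"world"]]
res = Result(raw, hascontent=True)
res.total          # 2
res.docs[0].title  # 'hello'
```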
@@ -0,0 +1,55 @@
from typing import Optional

from ._util import to_string


class Suggestion:
    """
    Represents a single suggestion being sent or returned from the
    autocomplete server
    """

    def __init__(
        self, string: str, score: float = 1.0, payload: Optional[str] = None
    ) -> None:
        self.string = to_string(string)
        self.payload = to_string(payload)
        self.score = score

    def __repr__(self) -> str:
        return self.string


class SuggestionParser:
    """
    Internal class used to parse results from the `SUGGET` command.
    This needs to consume either 1, 2, or 3 values at a time from
    the return value depending on what objects were requested
    """

    def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
        self.with_scores = with_scores
        self.with_payloads = with_payloads

        if with_scores and with_payloads:
            self.sugsize = 3
            self._scoreidx = 1
            self._payloadidx = 2
        elif with_scores:
            self.sugsize = 2
            self._scoreidx = 1
        elif with_payloads:
            self.sugsize = 2
            self._payloadidx = 1
        else:
            self.sugsize = 1
            self._scoreidx = -1

        self._sugs = ret

    def __iter__(self):
        for i in range(0, len(self._sugs), self.sugsize):
            ss = self._sugs[i]
            score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
            payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
            yield Suggestion(ss, score, payload)
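A sketch of the consumption logic on a fabricated `FT.SUGGET ... WITHSCORES` reply, which interleaves strings and scores so the parser consumes two values per suggestion:

```
from redis.commands.search.suggestion import SuggestionParser

raw = [b"hello", b"0.7", b"help", b"0.59"]
parser = SuggestionParser(with_scores=True, with_payloads=False, ret=raw)
[(s.string, s.score) for s in parser]
# [('hello', 0.7), ('help', 0.59)]
```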
@@ -0,0 +1,129 @@
import warnings


class SentinelCommands:
    """
    A class containing the commands specific to redis sentinel. This class is
    to be used as a mixin.
    """

    def sentinel(self, *args):
        """Redis Sentinel's SENTINEL command."""
        warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))

    def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
        """
        Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL GET-MASTER-ADDR-BY-NAME",
            service_name,
            once=True,
            return_responses=return_responses,
        )

    def sentinel_master(self, service_name, return_responses=False):
        """
        Returns a dictionary containing the specified master's state, when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL MASTER", service_name, return_responses=return_responses
        )

    def sentinel_masters(self):
        """
        Returns a list of dictionaries containing each master's state.

        Important: This function is called by the Sentinel implementation and is
        called directly on the Redis standalone client for sentinels,
        so it doesn't support the "once" and "return_responses" options.
        """
        return self.execute_command("SENTINEL MASTERS")

    def sentinel_monitor(self, name, ip, port, quorum):
        """Add a new master to Sentinel to be monitored"""
        return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum)

    def sentinel_remove(self, name):
        """Remove a master from Sentinel's monitoring"""
        return self.execute_command("SENTINEL REMOVE", name)

    def sentinel_sentinels(self, service_name, return_responses=False):
        """
        Returns a list of sentinels for ``service_name``, when return_responses is True,
        otherwise returns a boolean value that indicates if the command was successful.
        """
        return self.execute_command(
            "SENTINEL SENTINELS", service_name, return_responses=return_responses
        )

    def sentinel_set(self, name, option, value):
        """Set Sentinel monitoring parameters for a given master"""
        return self.execute_command("SENTINEL SET", name, option, value)

    def sentinel_slaves(self, service_name):
        """
        Returns a list of slaves for ``service_name``

        Important: This function is called by the Sentinel implementation and is
        called directly on the Redis standalone client for sentinels,
        so it doesn't support the "once" and "return_responses" options.
        """
        return self.execute_command("SENTINEL SLAVES", service_name)

    def sentinel_reset(self, pattern):
        """
        This command will reset all the masters with matching name.
        The pattern argument is a glob-style pattern.

        The reset process clears any previous state in a master (including a
        failover in progress), and removes every slave and sentinel already
        discovered and associated with the master.
        """
        return self.execute_command("SENTINEL RESET", pattern, once=True)

    def sentinel_failover(self, new_master_name):
        """
        Force a failover as if the master was not reachable, and without
        asking for agreement to other Sentinels (however a new version of the
        configuration will be published so that the other Sentinels will
        update their configurations).
        """
        return self.execute_command("SENTINEL FAILOVER", new_master_name)

    def sentinel_ckquorum(self, new_master_name):
        """
        Check if the current Sentinel configuration is able to reach the
        quorum needed to failover a master, and the majority needed to
        authorize the failover.

        This command should be used in monitoring systems to check if a
        Sentinel deployment is ok.
        """
        return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True)

    def sentinel_flushconfig(self):
        """
        Force Sentinel to rewrite its configuration on disk, including the
        current Sentinel state.

        Normally Sentinel rewrites the configuration every time something
        changes in its state (in the context of the subset of the state which
        is persisted on disk across restart).
        However sometimes it is possible that the configuration file is lost
        because of operation errors, disk failures, package upgrade scripts or
        configuration managers. In those cases a way to force Sentinel to
        rewrite the configuration file is handy.

        This command works even if the previous configuration file is
        completely missing.
        """
        return self.execute_command("SENTINEL FLUSHCONFIG")


class AsyncSentinelCommands(SentinelCommands):
    async def sentinel(self, *args) -> None:
        """Redis Sentinel's SENTINEL command."""
        super().sentinel(*args)
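A hedged sketch of calling the mixin's methods: it assumes the mixin is part of the standard `Redis` client (as in upstream redis-py) and that the client is connected to a sentinel process rather than to the monitored master; the host, port, and service name below are hypothetical:

```
import redis

# Sentinel processes conventionally listen on port 26379.
sentinel_client = redis.Redis(host="localhost", port=26379)

sentinel_client.sentinel_get_master_addr_by_name(
    "mymaster", return_responses=True
)
# ('10.0.0.5', '6379')  -- (host, port) of the current master

sentinel_client.sentinel_ckquorum("mymaster")
```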
@@ -0,0 +1,108 @@
import redis
from redis._parsers.helpers import bool_ok

from ..helpers import get_protocol_version, parse_to_list
from .commands import (
    ALTER_CMD,
    CREATE_CMD,
    CREATERULE_CMD,
    DEL_CMD,
    DELETERULE_CMD,
    GET_CMD,
    INFO_CMD,
    MGET_CMD,
    MRANGE_CMD,
    MREVRANGE_CMD,
    QUERYINDEX_CMD,
    RANGE_CMD,
    REVRANGE_CMD,
    TimeSeriesCommands,
)
from .info import TSInfo
from .utils import parse_get, parse_m_get, parse_m_range, parse_range


class TimeSeries(TimeSeriesCommands):
    """
    This class wraps a redis-py client and implements RedisTimeSeries's
    commands (prefixed with "ts").
    The client allows to interact with RedisTimeSeries and use all of its
    functionality.
    """

    def __init__(self, client=None, **kwargs):
        """Create a new RedisTimeSeries client."""
        # Set the module commands' callbacks
        self._MODULE_CALLBACKS = {
            ALTER_CMD: bool_ok,
            CREATE_CMD: bool_ok,
            CREATERULE_CMD: bool_ok,
            DELETERULE_CMD: bool_ok,
        }

        _RESP2_MODULE_CALLBACKS = {
            DEL_CMD: int,
            GET_CMD: parse_get,
            INFO_CMD: TSInfo,
            MGET_CMD: parse_m_get,
            MRANGE_CMD: parse_m_range,
            MREVRANGE_CMD: parse_m_range,
            RANGE_CMD: parse_range,
            REVRANGE_CMD: parse_range,
            QUERYINDEX_CMD: parse_to_list,
        }
        _RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)

        for k, v in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the TimeSeries module, that can be used
        for executing only TimeSeries commands and core commands.

        Usage example:

            r = redis.Redis()
            pipe = r.ts().pipeline()
            for i in range(100):
                pipe.add("with_pipeline", i, 1.1 * i)
            pipe.execute()

        """
        if isinstance(self.client, redis.RedisCluster):
            p = ClusterPipeline(
                nodes_manager=self.client.nodes_manager,
                commands_parser=self.client.commands_parser,
                startup_nodes=self.client.nodes_manager.startup_nodes,
                result_callbacks=self.client.result_callbacks,
                cluster_response_callbacks=self.client.cluster_response_callbacks,
                cluster_error_retry_attempts=self.client.retry.get_retries(),
                read_from_replicas=self.client.read_from_replicas,
                reinitialize_steps=self.client.reinitialize_steps,
                lock=self.client._lock,
            )

        else:
            p = Pipeline(
                connection_pool=self.client.connection_pool,
                response_callbacks=self._MODULE_CALLBACKS,
                transaction=transaction,
                shard_hint=shard_hint,
            )
        return p


class ClusterPipeline(TimeSeriesCommands, redis.cluster.ClusterPipeline):
    """Cluster pipeline for the module."""


class Pipeline(TimeSeriesCommands, redis.client.Pipeline):
    """Pipeline for the module."""
File diff suppressed because it is too large
@@ -0,0 +1,91 @@
from ..helpers import nativestr
from .utils import list_to_dict


class TSInfo:
    """
    Hold information and statistics on the time-series.
    Can be created using the ``tsinfo`` command
    https://redis.io/docs/latest/commands/ts.info/
    """

    rules = []
    labels = []
    sourceKey = None
    chunk_count = None
    memory_usage = None
    total_samples = None
    retention_msecs = None
    last_time_stamp = None
    first_time_stamp = None

    max_samples_per_chunk = None
    chunk_size = None
    duplicate_policy = None

    def __init__(self, args):
        """
        Hold information and statistics on the time-series.

        The supported params that can be passed as args:

        rules:
            A list of compaction rules of the time series.
        sourceKey:
            Key name for the source time series in case the current series
            is a target of a rule.
        chunkCount:
            Number of memory chunks used for the time series.
        memoryUsage:
            Total number of bytes allocated for the time series.
        totalSamples:
            Total number of samples in the time series.
        labels:
            A list of label-value pairs that represent the metadata
            labels of the time series.
        retentionTime:
            Retention time, in milliseconds, for the time series.
        lastTimestamp:
            Last timestamp present in the time series.
        firstTimestamp:
            First timestamp present in the time series.
        maxSamplesPerChunk:
            Deprecated.
        chunkSize:
            Amount of memory, in bytes, allocated for data.
        duplicatePolicy:
            Policy that will define handling of duplicate samples.

        You can read more about it at
        https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
        """
        response = dict(zip(map(nativestr, args[::2]), args[1::2]))
        self.rules = response.get("rules")
        self.source_key = response.get("sourceKey")
        self.chunk_count = response.get("chunkCount")
        self.memory_usage = response.get("memoryUsage")
        self.total_samples = response.get("totalSamples")
        self.labels = list_to_dict(response.get("labels"))
        self.retention_msecs = response.get("retentionTime")
        self.last_timestamp = response.get("lastTimestamp")
        self.first_timestamp = response.get("firstTimestamp")
        if "maxSamplesPerChunk" in response:
            self.max_samples_per_chunk = response["maxSamplesPerChunk"]
            self.chunk_size = (
                self.max_samples_per_chunk * 16
            )  # backward compatible changes
        if "chunkSize" in response:
            self.chunk_size = response["chunkSize"]
        if "duplicatePolicy" in response:
            self.duplicate_policy = response["duplicatePolicy"]
            if isinstance(self.duplicate_policy, bytes):
                self.duplicate_policy = self.duplicate_policy.decode()

    def get(self, item):
        try:
            return self.__getitem__(item)
        except AttributeError:
            return None

    def __getitem__(self, item):
        return getattr(self, item)
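
A short sketch of how a TSInfo instance is typically consumed; it assumes a series named "sensor:1" already exists, and the printed values depend on the server.

import redis

r = redis.Redis(decode_responses=True)
info = r.ts().info("sensor:1")      # returns a TSInfo instance

print(info.total_samples)           # plain attribute access
print(info["retention_msecs"])      # __getitem__ delegates to getattr
print(info.get("no_such_field"))    # get() swallows AttributeError -> None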
@@ -0,0 +1,44 @@
from ..helpers import nativestr


def list_to_dict(aList):
    return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))}


def parse_range(response, **kwargs):
    """Parse range response. Used by TS.RANGE and TS.REVRANGE."""
    return [(r[0], float(r[1])) for r in response]


def parse_m_range(response):
    """Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE."""
    res = []
    for item in response:
        res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]})
    return sorted(res, key=lambda d: list(d.keys()))


def parse_get(response):
    """Parse get response. Used by TS.GET."""
    if not response:
        return None
    return int(response[0]), float(response[1])


def parse_m_get(response):
    """Parse multi get response. Used by TS.MGET."""
    res = []
    for item in response:
        if not item[2]:
            res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]})
        else:
            res.append(
                {
                    nativestr(item[0]): [
                        list_to_dict(item[1]),
                        int(item[2][0]),
                        float(item[2][1]),
                    ]
                }
            )
    return sorted(res, key=lambda d: list(d.keys()))
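
To illustrate the shapes these parsers produce, a couple of hand-written raw replies (not captured from a live server):

# parse_range turns [timestamp, value-bytes] pairs into (int, float) tuples
raw_range = [[1, b"1.5"], [2, b"3.0"]]
assert parse_range(raw_range) == [(1, 1.5), (2, 3.0)]

# parse_get returns a single (timestamp, value) tuple, or None for empty replies
assert parse_get([7, b"42.0"]) == (7, 42.0)
assert parse_get([]) is None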
@@ -0,0 +1,46 @@
import json

from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.utils import (
    parse_vemb_result,
    parse_vlinks_result,
    parse_vsim_result,
)

from ..helpers import get_protocol_version
from .commands import (
    VEMB_CMD,
    VGETATTR_CMD,
    VINFO_CMD,
    VLINKS_CMD,
    VSIM_CMD,
    VectorSetCommands,
)


class VectorSet(VectorSetCommands):
    def __init__(self, client, **kwargs):
        """Create a new VectorSet client."""
        # Set the module commands' callbacks
        self._MODULE_CALLBACKS = {
            VEMB_CMD: parse_vemb_result,
            VSIM_CMD: parse_vsim_result,
            VGETATTR_CMD: lambda r: r and json.loads(r) or None,
        }

        self._RESP2_MODULE_CALLBACKS = {
            VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
            VLINKS_CMD: parse_vlinks_result,
        }
        self._RESP3_MODULE_CALLBACKS = {}

        self.client = client
        self.execute_command = client.execute_command

        if get_protocol_version(self.client) in ["3", 3]:
            self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
        else:
            self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)

        for k, v in self._MODULE_CALLBACKS.items():
            self.client.set_response_callback(k, v)
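
A minimal wiring sketch; recent redis-py versions expose this class through Redis.vset(), so direct construction is rarely needed. The key and vector values here are illustrative.

import redis

r = redis.Redis(decode_responses=True)
vset = r.vset()                      # VectorSet(client=r) under the hood

vset.vadd("points", [0.1, 0.2, 0.3], "p1")
print(vset.vcard("points"))          # -> 1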
@@ -0,0 +1,392 @@
import json
from enum import Enum
from typing import Any, Awaitable, Dict, List, Optional, Union

from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version
from redis.exceptions import DataError
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number

VADD_CMD = "VADD"
VSIM_CMD = "VSIM"
VREM_CMD = "VREM"
VDIM_CMD = "VDIM"
VCARD_CMD = "VCARD"
VEMB_CMD = "VEMB"
VLINKS_CMD = "VLINKS"
VINFO_CMD = "VINFO"
VSETATTR_CMD = "VSETATTR"
VGETATTR_CMD = "VGETATTR"
VRANDMEMBER_CMD = "VRANDMEMBER"

# Return type for the VSIM command
VSimResult = Optional[
    List[
        Union[
            List[EncodableT], Dict[EncodableT, Number], Dict[EncodableT, Dict[str, Any]]
        ]
    ]
]


class QuantizationOptions(Enum):
    """Quantization options for the VADD command."""

    NOQUANT = "NOQUANT"
    BIN = "BIN"
    Q8 = "Q8"


class CallbacksOptions(Enum):
    """Options that can be set for the commands' callbacks"""

    RAW = "RAW"
    WITHSCORES = "WITHSCORES"
    WITHATTRIBS = "WITHATTRIBS"
    ALLOW_DECODING = "ALLOW_DECODING"
    RESP3 = "RESP3"


class VectorSetCommands(CommandsProtocol):
    """Redis VectorSet commands"""

    def vadd(
        self,
        key: KeyT,
        vector: Union[List[float], bytes],
        element: str,
        reduce_dim: Optional[int] = None,
        cas: Optional[bool] = False,
        quantization: Optional[QuantizationOptions] = None,
        ef: Optional[Number] = None,
        attributes: Optional[Union[dict, str]] = None,
        numlinks: Optional[int] = None,
    ) -> Union[Awaitable[int], int]:
        """
        Add vector ``vector`` for element ``element`` to a vector set ``key``.

        ``reduce_dim`` sets the dimensions to reduce the vector to.
        If not provided, the vector is not reduced.

        ``cas`` is a boolean flag that indicates whether to use CAS
        (check-and-set style) when adding the vector.
        If not provided, CAS is not used.

        ``quantization`` sets the quantization type to use.
        If not provided, int8 quantization is used.
        The options are:
        - NOQUANT: No quantization
        - BIN: Binary quantization
        - Q8: Signed 8-bit quantization

        ``ef`` sets the exploration factor to use.
        If not provided, the default exploration factor is used.

        ``attributes`` is a dictionary or JSON string that contains the
        attributes to set for the vector.
        If not provided, no attributes are set.

        ``numlinks`` sets the number of links to create for the vector.
        If not provided, the default number of links is used.

        For more information, see https://redis.io/commands/vadd.
        """
        if not vector or not element:
            raise DataError("Both vector and element must be provided")

        pieces = []
        if reduce_dim:
            pieces.extend(["REDUCE", reduce_dim])

        values_pieces = []
        if isinstance(vector, bytes):
            values_pieces.extend(["FP32", vector])
        else:
            values_pieces.extend(["VALUES", len(vector)])
            values_pieces.extend(vector)
        pieces.extend(values_pieces)

        pieces.append(element)

        if cas:
            pieces.append("CAS")

        if quantization:
            pieces.append(quantization.value)

        if ef:
            pieces.extend(["EF", ef])

        if attributes:
            if isinstance(attributes, dict):
                # transform attributes to a JSON string
                attributes_json = json.dumps(attributes)
            else:
                attributes_json = attributes
            pieces.extend(["SETATTR", attributes_json])

        if numlinks:
            pieces.extend(["M", numlinks])

        return self.execute_command(VADD_CMD, key, *pieces)

    def vsim(
        self,
        key: KeyT,
        input: Union[List[float], bytes, str],
        with_scores: Optional[bool] = False,
        with_attribs: Optional[bool] = False,
        count: Optional[int] = None,
        ef: Optional[Number] = None,
        filter: Optional[str] = None,
        filter_ef: Optional[str] = None,
        truth: Optional[bool] = False,
        no_thread: Optional[bool] = False,
        epsilon: Optional[Number] = None,
    ) -> Union[Awaitable[VSimResult], VSimResult]:
        """
        Compare a vector or element ``input`` with the other vectors
        in a vector set ``key``.

        ``with_scores`` sets whether similarity scores should be returned
        for each element in the result.

        ``with_attribs`` sets whether the results should be returned with the
        attributes of the elements, or None when no attributes are present.

        ``count`` sets the number of results to return.

        ``ef`` sets the exploration factor.

        ``filter`` sets the filter that should be applied for the search.

        ``filter_ef`` sets the max filtering effort.

        ``truth`` when enabled, forces the command to perform a linear scan.

        ``no_thread`` when enabled, forces the command to execute the search
        on the data structure in the main thread.

        ``epsilon`` floating point between 0 and 1; if specified, only
        elements with distance no further than the specified one are returned.

        For more information, see https://redis.io/commands/vsim.
        """
        if not input:
            raise DataError("'input' should be provided")

        pieces = []
        options = {}

        if isinstance(input, bytes):
            pieces.extend(["FP32", input])
        elif isinstance(input, list):
            pieces.extend(["VALUES", len(input)])
            pieces.extend(input)
        else:
            pieces.extend(["ELE", input])

        if with_scores or with_attribs:
            if get_protocol_version(self.client) in ["3", 3]:
                options[CallbacksOptions.RESP3.value] = True

        if with_scores:
            pieces.append("WITHSCORES")
            options[CallbacksOptions.WITHSCORES.value] = True

        if with_attribs:
            pieces.append("WITHATTRIBS")
            options[CallbacksOptions.WITHATTRIBS.value] = True

        if count:
            pieces.extend(["COUNT", count])

        if epsilon:
            pieces.extend(["EPSILON", epsilon])

        if ef:
            pieces.extend(["EF", ef])

        if filter:
            pieces.extend(["FILTER", filter])

        if filter_ef:
            pieces.extend(["FILTER-EF", filter_ef])

        if truth:
            pieces.append("TRUTH")

        if no_thread:
            pieces.append("NOTHREAD")

        return self.execute_command(VSIM_CMD, key, *pieces, **options)

    def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the dimension of a vector set.

        In the case of vectors that were populated using the `REDUCE`
        option, for random projection, the vector set will report the size of
        the projected (reduced) dimension.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information, see https://redis.io/commands/vdim.
        """
        return self.execute_command(VDIM_CMD, key)

    def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the cardinality (the number of elements) of a vector set with key ``key``.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information, see https://redis.io/commands/vcard.
        """
        return self.execute_command(VCARD_CMD, key)

    def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
        """
        Remove an element from a vector set.

        For more information, see https://redis.io/commands/vrem.
        """
        return self.execute_command(VREM_CMD, key, element)

    def vemb(
        self, key: KeyT, element: str, raw: Optional[bool] = False
    ) -> Union[
        Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
        Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
    ]:
        """
        Get the approximated vector of an element ``element`` from vector set ``key``.

        ``raw`` is a boolean flag that indicates whether to return the
        internal representation used by the vector.

        For more information, see https://redis.io/commands/vemb.
        """
        options = {}
        pieces = []
        pieces.extend([key, element])

        if get_protocol_version(self.client) in ["3", 3]:
            options[CallbacksOptions.RESP3.value] = True

        if raw:
            pieces.append("RAW")

            options[NEVER_DECODE] = True
            if (
                hasattr(self.client, "connection_pool")
                and self.client.connection_pool.connection_kwargs["decode_responses"]
            ) or (
                hasattr(self.client, "nodes_manager")
                and self.client.nodes_manager.connection_kwargs["decode_responses"]
            ):
                # allow decoding in the postprocessing callback
                # if the user set decode_responses=True
                # in the connection pool
                options[CallbacksOptions.ALLOW_DECODING.value] = True

            options[CallbacksOptions.RAW.value] = True

        return self.execute_command(VEMB_CMD, *pieces, **options)

    def vlinks(
        self, key: KeyT, element: str, with_scores: Optional[bool] = False
    ) -> Union[
        Awaitable[
            Optional[
                List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
            ]
        ],
        Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
    ]:
        """
        Returns the neighbors for each level in which the element ``element``
        exists in the vector set ``key``.

        The result is a list of lists, where each list contains the neighbors
        for one level. If the element does not exist, or if the vector set
        does not exist, None is returned.

        If the ``WITHSCORES`` option is provided, the result is a list of dicts,
        where each dict contains the neighbors for one level, with the scores as values.

        For more information, see https://redis.io/commands/vlinks
        """
        options = {}
        pieces = []
        pieces.extend([key, element])

        if with_scores:
            pieces.append("WITHSCORES")
            options[CallbacksOptions.WITHSCORES.value] = True

        return self.execute_command(VLINKS_CMD, *pieces, **options)

    def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
        """
        Get information about a vector set.

        For more information, see https://redis.io/commands/vinfo.
        """
        return self.execute_command(VINFO_CMD, key)

    def vsetattr(
        self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
    ) -> Union[Awaitable[int], int]:
        """
        Associate or remove JSON attributes ``attributes`` of element ``element``
        for vector set ``key``.

        For more information, see https://redis.io/commands/vsetattr
        """
        if attributes is None:
            attributes_json = "{}"
        elif isinstance(attributes, dict):
            # transform attributes to a JSON string
            attributes_json = json.dumps(attributes)
        else:
            attributes_json = attributes

        return self.execute_command(VSETATTR_CMD, key, element, attributes_json)

    def vgetattr(
        self, key: KeyT, element: str
    ) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
        """
        Retrieve the JSON attributes of an element ``element`` for vector set ``key``.

        If the element does not exist, or if the vector set does not exist,
        None is returned.

        For more information, see https://redis.io/commands/vgetattr.
        """
        return self.execute_command(VGETATTR_CMD, key, element)

    def vrandmember(
        self, key: KeyT, count: Optional[int] = None
    ) -> Union[
        Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
    ]:
        """
        Returns random elements from a vector set ``key``.

        ``count`` is the number of elements to return.
        If ``count`` is not provided, a single element is returned as a single string.
        If ``count`` is positive (and smaller than the number of elements
        in the vector set), the command returns a list with up to ``count``
        distinct elements from the vector set.
        If ``count`` is negative, the command returns a list with ``count`` random
        elements, potentially with duplicates.
        If ``count`` is greater than the number of elements in the vector set,
        the entire set is returned as a list.

        If the vector set does not exist, ``None`` is returned.

        For more information, see https://redis.io/commands/vrandmember.
        """
        pieces = []
        pieces.append(key)
        if count is not None:
            pieces.append(count)
        return self.execute_command(VRANDMEMBER_CMD, *pieces)
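
A short sketch exercising vadd() and vsim() above against a hypothetical "words" vector set; the vectors, element names, and attributes are made up.

import redis

r = redis.Redis(decode_responses=True)
vs = r.vset()

vs.vadd("words", [0.9, 0.1], "apple", attributes={"kind": "fruit"})
vs.vadd("words", [0.8, 0.2], "pear", attributes={"kind": "fruit"})

# Nearest neighbours of a query vector, with similarity scores
print(vs.vsim("words", [1.0, 0.0], with_scores=True, count=2))

# Or search by an existing element instead of a raw vector, with a filter
print(vs.vsim("words", "apple", filter='.kind == "fruit"'))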
@@ -0,0 +1,130 @@
import json

from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.commands import CallbacksOptions


def parse_vemb_result(response, **options):
    """
    Handle the VEMB result, since the command can return different result
    structures depending on input options and on the quantization type
    of the vector set.

    Parses the VEMB result into:
    - List[Union[bytes, Union[int, float]]]
    - Dict[str, Union[bytes, str, float]]
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.RAW.value):
        result = {}
        result["quantization"] = (
            response[0].decode("utf-8")
            if options.get(CallbacksOptions.ALLOW_DECODING.value)
            else response[0]
        )
        result["raw"] = response[1]
        result["l2"] = float(response[2])
        if len(response) > 3:
            result["range"] = float(response[3])
        return result
    else:
        if options.get(CallbacksOptions.RESP3.value):
            return response

        result = []
        for i in range(len(response)):
            try:
                result.append(int(response[i]))
            except ValueError:
                # if the value is not an integer, it should be a float
                result.append(float(response[i]))

        return result


def parse_vlinks_result(response, **options):
    """
    Handle the VLINKS result, since the command can return different result
    structures depending on input options.

    Parses the VLINKS result into:
    - List[List[str]]
    - List[Dict[str, Number]]
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.WITHSCORES.value):
        result = []
        # Redis will return a list of lists of strings.
        # This list has to be transformed into a list of dicts.
        for level_item in response:
            level_data_dict = {}
            for key, value in pairs_to_dict(level_item).items():
                value = float(value)
                level_data_dict[key] = value
            result.append(level_data_dict)
        return result
    else:
        # return the list of elements for each level
        # list of lists
        return response


def parse_vsim_result(response, **options):
    """
    Handle the VSIM result, since the command can return different result
    structures depending on input options.

    Parses the VSIM result into:
    - List[List[str]]
    - List[Dict[str, Number]] - when with_scores is used (without attributes)
    - List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores)
    - List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used
    """
    if response is None:
        return response

    withscores = bool(options.get(CallbacksOptions.WITHSCORES.value))
    withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value))

    # Exactly one of withscores or withattribs is True
    if (withscores and not withattribs) or (not withscores and withattribs):
        # Redis will return a list of pairs.
        # This list has to be transformed into a dict.
        result_dict = {}
        if options.get(CallbacksOptions.RESP3.value):
            resp_dict = response
        else:
            resp_dict = pairs_to_dict(response)
        for key, value in resp_dict.items():
            if withscores:
                value = float(value)
            else:
                value = json.loads(value) if value else None

            result_dict[key] = value
        return result_dict
    elif withscores and withattribs:
        it = iter(response)
        result_dict = {}
        if options.get(CallbacksOptions.RESP3.value):
            for elem, data in response.items():
                if data[1] is not None:
                    attribs_dict = json.loads(data[1])
                else:
                    attribs_dict = None
                result_dict[elem] = {"score": data[0], "attributes": attribs_dict}
        else:
            for elem, score, attribs in zip(it, it, it):
                if attribs is not None:
                    attribs_dict = json.loads(attribs)
                else:
                    attribs_dict = None

                result_dict[elem] = {"score": float(score), "attributes": attribs_dict}
        return result_dict
    else:
        # return the list of elements
        # list of lists
        return response
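
Hand-written RESP2-style replies to illustrate parse_vsim_result; the element names and scores are invented, not captured from a live server.

# WITHSCORES flattens element/score pairs, which the parser turns into a dict
resp2_withscores = ["apple", "0.998", "pear", "0.87"]
print(parse_vsim_result(resp2_withscores, WITHSCORES=True))
# -> {'apple': 0.998, 'pear': 0.87}

# Without options the reply is returned unchanged
print(parse_vsim_result(["apple", "pear"]))
# -> ['apple', 'pear']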
2999
backend/venv/lib/python3.9/site-packages/redis/connection.py
Normal file
File diff suppressed because it is too large
23
backend/venv/lib/python3.9/site-packages/redis/crc.py
Normal file
@@ -0,0 +1,23 @@
from binascii import crc_hqx

from redis.typing import EncodedT

# Redis Cluster's key space is divided into 16384 slots.
# For more information see: https://github.com/redis/redis/issues/2576
REDIS_CLUSTER_HASH_SLOTS = 16384

__all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"]


def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int:
    """Calculate the hash slot for a given key.

    See the keys distribution model in https://redis.io/topics/cluster-spec

    :param key: bytes
    :param bucket: int
    """
    start = key.find(b"{")
    if start > -1:
        end = key.find(b"}", start + 1)
        if end > -1 and end != start + 1:
            key = key[start + 1 : end]
    return crc_hqx(key, 0) % bucket
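
A quick illustration of the hash-tag rule implemented above: keys sharing the substring between the first "{" and the following "}" always map to the same slot, which is what makes multi-key operations possible in a cluster.

assert key_slot(b"{user:1000}.following") == key_slot(b"{user:1000}.followers")

# Without a hash tag the whole key is hashed; the result is deterministic
print(key_slot(b"foo"))   # -> 12182, matching CLUSTER KEYSLOT foo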
@@ -0,0 +1,65 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple, Union

logger = logging.getLogger(__name__)


class CredentialProvider:
    """
    Credentials Provider.
    """

    def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
        raise NotImplementedError("get_credentials must be implemented")

    async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
        logger.warning(
            "This method is added for backward compatibility. "
            "Please override it in your implementation."
        )
        return self.get_credentials()


class StreamingCredentialProvider(CredentialProvider, ABC):
    """
    Credential provider that streams credentials in the background.
    """

    @abstractmethod
    def on_next(self, callback: Callable[[Any], None]):
        """
        Specifies the callback that should be invoked
        when the next credentials are retrieved.

        :param callback: Callback to invoke with the next credentials.
        :return:
        """
        pass

    @abstractmethod
    def on_error(self, callback: Callable[[Exception], None]):
        pass

    @abstractmethod
    def is_streaming(self) -> bool:
        pass


class UsernamePasswordCredentialProvider(CredentialProvider):
    """
    Simple implementation of CredentialProvider that just wraps static
    username and password.
    """

    def __init__(self, username: Optional[str] = None, password: Optional[str] = None):
        self.username = username or ""
        self.password = password or ""

    def get_credentials(self):
        if self.username:
            return self.username, self.password
        return (self.password,)

    async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
        return self.get_credentials()
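
A minimal sketch of the one-argument and two-argument AUTH shapes the provider returns; the credentials are placeholders.

provider = UsernamePasswordCredentialProvider(username="app", password="s3cret")
assert provider.get_credentials() == ("app", "s3cret")

# With no username, only the password is returned, matching legacy AUTH
anonymous = UsernamePasswordCredentialProvider(password="s3cret")
assert anonymous.get_credentials() == ("s3cret",)

# Plugged into a client via the credential_provider kwarg:
# r = redis.Redis(credential_provider=provider)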
@@ -0,0 +1,81 @@
import threading
from typing import Any, Generic, List, TypeVar

from redis.typing import Number

T = TypeVar("T")


class WeightedList(Generic[T]):
    """
    Thread-safe weighted list, kept sorted in descending weight order.
    """

    def __init__(self):
        self._items: List[tuple[Any, Number]] = []
        self._lock = threading.RLock()

    def add(self, item: Any, weight: float) -> None:
        """Add an item with a weight, maintaining descending sorted order."""
        with self._lock:
            # Find the insertion point using binary search
            left, right = 0, len(self._items)
            while left < right:
                mid = (left + right) // 2
                if self._items[mid][1] < weight:
                    right = mid
                else:
                    left = mid + 1

            self._items.insert(left, (item, weight))

    def remove(self, item):
        """Remove the first occurrence of an item and return its weight."""
        with self._lock:
            for i, (stored_item, weight) in enumerate(self._items):
                if stored_item == item:
                    self._items.pop(i)
                    return weight
            raise ValueError("Item not found")

    def get_by_weight_range(
        self, min_weight: float, max_weight: float
    ) -> List[tuple[Any, Number]]:
        """Get all items within the given weight range."""
        with self._lock:
            result = []
            for item, weight in self._items:
                if min_weight <= weight <= max_weight:
                    result.append((item, weight))
            return result

    def get_top_n(self, n: int) -> List[tuple[Any, Number]]:
        """Get the N highest-weighted items."""
        with self._lock:
            return [(item, weight) for item, weight in self._items[:n]]

    def update_weight(self, item, new_weight: float):
        """Update the weight of an item and return its old weight."""
        with self._lock:
            old_weight = self.remove(item)
            self.add(item, new_weight)
            return old_weight

    def __iter__(self):
        """Iterate in descending weight order."""
        with self._lock:
            items_copy = (
                self._items.copy()
            )  # Create a snapshot, as the lock is released after each 'yield'

        for item, weight in items_copy:
            yield item, weight

    def __len__(self):
        with self._lock:
            return len(self._items)

    def __getitem__(self, index) -> tuple[Any, Number]:
        with self._lock:
            item, weight = self._items[index]
            return item, weight
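
A quick sketch of the descending-order invariant the class maintains; the item names are hypothetical.

wl = WeightedList()
wl.add("replica-1", 0.5)
wl.add("primary", 1.0)
wl.add("replica-2", 0.25)

assert wl.get_top_n(2) == [("primary", 1.0), ("replica-1", 0.5)]
assert wl.get_by_weight_range(0.4, 1.0) == [("primary", 1.0), ("replica-1", 0.5)]

wl.update_weight("replica-2", 2.0)    # remove + re-insert at the new position
assert wl[0] == ("replica-2", 2.0)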
468
backend/venv/lib/python3.9/site-packages/redis/event.py
Normal file
@@ -0,0 +1,468 @@
import asyncio
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Type, Union

from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider


class EventListenerInterface(ABC):
    """
    Represents a listener for a given event object.
    """

    @abstractmethod
    def listen(self, event: object):
        pass


class AsyncEventListenerInterface(ABC):
    """
    Represents an async listener for a given event object.
    """

    @abstractmethod
    async def listen(self, event: object):
        pass


class EventDispatcherInterface(ABC):
    """
    Represents a dispatcher that dispatches events to listeners
    associated with a given event.
    """

    @abstractmethod
    def dispatch(self, event: object):
        pass

    @abstractmethod
    async def dispatch_async(self, event: object):
        pass

    @abstractmethod
    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Register additional listeners."""
        pass


class EventException(Exception):
    """
    Exception wrapper that adds an event object into the exception context.
    """

    def __init__(self, exception: Exception, event: object):
        self.exception = exception
        self.event = event
        super().__init__(exception)


class EventDispatcher(EventDispatcherInterface):
    # TODO: Make the dispatcher accept external mappings.
    def __init__(
        self,
        event_listeners: Optional[
            Dict[Type[object], List[EventListenerInterface]]
        ] = None,
    ):
        """
        Dispatcher that dispatches events to listeners associated with a given event.
        """
        self._event_listeners_mapping: Dict[
            Type[object], List[EventListenerInterface]
        ] = {
            AfterConnectionReleasedEvent: [
                ReAuthConnectionListener(),
            ],
            AfterPooledConnectionsInstantiationEvent: [
                RegisterReAuthForPooledConnections()
            ],
            AfterSingleConnectionInstantiationEvent: [
                RegisterReAuthForSingleConnection()
            ],
            AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
            AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
            AsyncAfterConnectionReleasedEvent: [
                AsyncReAuthConnectionListener(),
            ],
        }

        self._lock = threading.Lock()
        self._async_lock = None

        if event_listeners:
            self.register_listeners(event_listeners)

    def dispatch(self, event: object):
        with self._lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

        for listener in listeners:
            listener.listen(event)

    async def dispatch_async(self, event: object):
        if self._async_lock is None:
            self._async_lock = asyncio.Lock()

        async with self._async_lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

        for listener in listeners:
            await listener.listen(event)

    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        with self._lock:
            for event_type in mappings:
                if event_type in self._event_listeners_mapping:
                    self._event_listeners_mapping[event_type] = list(
                        set(
                            self._event_listeners_mapping[event_type]
                            + mappings[event_type]
                        )
                    )
                else:
                    self._event_listeners_mapping[event_type] = mappings[event_type]


class AfterConnectionReleasedEvent:
    """
    Event that will be fired before each command execution.
    """

    def __init__(self, connection):
        self._connection = connection

    @property
    def connection(self):
        return self._connection


class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
    pass


class ClientType(Enum):
    SYNC = ("sync",)
    ASYNC = ("async",)


class AfterPooledConnectionsInstantiationEvent:
    """
    Event that will be fired after pooled connection instances were created.
    """

    def __init__(
        self,
        connection_pools: List,
        client_type: ClientType,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._connection_pools = connection_pools
        self._client_type = client_type
        self._credential_provider = credential_provider

    @property
    def connection_pools(self):
        return self._connection_pools

    @property
    def client_type(self) -> ClientType:
        return self._client_type

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        return self._credential_provider


class AfterSingleConnectionInstantiationEvent:
    """
    Event that will be fired after a single connection instance was created.

    :param connection_lock: For the sync client a thread lock should be provided,
        for async an asyncio.Lock
    """

    def __init__(
        self,
        connection,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection = connection
        self._client_type = client_type
        self._connection_lock = connection_lock

    @property
    def connection(self):
        return self._connection

    @property
    def client_type(self) -> ClientType:
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        return self._connection_lock


class AfterPubSubConnectionInstantiationEvent:
    def __init__(
        self,
        pubsub_connection,
        connection_pool,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._pubsub_connection = pubsub_connection
        self._connection_pool = connection_pool
        self._client_type = client_type
        self._connection_lock = connection_lock

    @property
    def pubsub_connection(self):
        return self._pubsub_connection

    @property
    def connection_pool(self):
        return self._connection_pool

    @property
    def client_type(self) -> ClientType:
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        return self._connection_lock


class AfterAsyncClusterInstantiationEvent:
    """
    Event that will be fired after an async cluster instance was created.

    The async cluster doesn't use connection pools;
    instead, ClusterNode objects manage the connections.
    """

    def __init__(
        self,
        nodes: dict,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._nodes = nodes
        self._credential_provider = credential_provider

    @property
    def nodes(self) -> dict:
        return self._nodes

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        return self._credential_provider


class OnCommandsFailEvent:
    """
    Event fired whenever a command fails during execution.
    """

    def __init__(
        self,
        commands: tuple,
        exception: Exception,
    ):
        self._commands = commands
        self._exception = exception

    @property
    def commands(self) -> tuple:
        return self._commands

    @property
    def exception(self) -> Exception:
        return self._exception


class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    pass


class ReAuthConnectionListener(EventListenerInterface):
    """
    Listener that performs re-authentication of a given connection.
    """

    def listen(self, event: AfterConnectionReleasedEvent):
        event.connection.re_auth()


class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
    """
    Async listener that performs re-authentication of a given connection.
    """

    async def listen(self, event: AsyncAfterConnectionReleasedEvent):
        await event.connection.re_auth()


class RegisterReAuthForPooledConnections(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for pooled connections.
    Required by :class:`StreamingCredentialProvider`.
    """

    def __init__(self):
        self._event = None

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        if isinstance(event.credential_provider, StreamingCredentialProvider):
            self._event = event

            if event.client_type == ClientType.SYNC:
                event.credential_provider.on_next(self._re_auth)
                event.credential_provider.on_error(self._raise_on_error)
            else:
                event.credential_provider.on_next(self._re_auth_async)
                event.credential_provider.on_error(self._raise_on_error_async)

    def _re_auth(self, token):
        for pool in self._event.connection_pools:
            pool.re_auth_callback(token)

    async def _re_auth_async(self, token):
        for pool in self._event.connection_pools:
            await pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)


class RegisterReAuthForSingleConnection(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for a single connection.
    Required by :class:`StreamingCredentialProvider`.
    """

    def __init__(self):
        self._event = None

    def listen(self, event: AfterSingleConnectionInstantiationEvent):
        if isinstance(
            event.connection.credential_provider, StreamingCredentialProvider
        ):
            self._event = event

            if event.client_type == ClientType.SYNC:
                event.connection.credential_provider.on_next(self._re_auth)
                event.connection.credential_provider.on_error(self._raise_on_error)
            else:
                event.connection.credential_provider.on_next(self._re_auth_async)
                event.connection.credential_provider.on_error(
                    self._raise_on_error_async
                )

    def _re_auth(self, token):
        with self._event.connection_lock:
            self._event.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            self._event.connection.read_response()

    async def _re_auth_async(self, token):
        async with self._event.connection_lock:
            await self._event.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await self._event.connection.read_response()

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)


class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
    def __init__(self):
        self._event = None

    def listen(self, event: AfterAsyncClusterInstantiationEvent):
        if isinstance(event.credential_provider, StreamingCredentialProvider):
            self._event = event
            event.credential_provider.on_next(self._re_auth)
            event.credential_provider.on_error(self._raise_on_error)

    async def _re_auth(self, token: TokenInterface):
        for key in self._event.nodes:
            await self._event.nodes[key].re_auth_callback(token)

    async def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)


class RegisterReAuthForPubSub(EventListenerInterface):
    def __init__(self):
        self._connection = None
        self._connection_pool = None
        self._client_type = None
        self._connection_lock = None
        self._event = None

    def listen(self, event: AfterPubSubConnectionInstantiationEvent):
        if isinstance(
            event.pubsub_connection.credential_provider, StreamingCredentialProvider
        ) and event.pubsub_connection.get_protocol() in [3, "3"]:
            self._event = event
            self._connection = event.pubsub_connection
            self._connection_pool = event.connection_pool
            self._client_type = event.client_type
            self._connection_lock = event.connection_lock

            if self._client_type == ClientType.SYNC:
                self._connection.credential_provider.on_next(self._re_auth)
                self._connection.credential_provider.on_error(self._raise_on_error)
            else:
                self._connection.credential_provider.on_next(self._re_auth_async)
                self._connection.credential_provider.on_error(
                    self._raise_on_error_async
                )

    def _re_auth(self, token: TokenInterface):
        with self._connection_lock:
            self._connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            self._connection.read_response()

        self._connection_pool.re_auth_callback(token)

    async def _re_auth_async(self, token: TokenInterface):
        async with self._connection_lock:
            await self._connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await self._connection.read_response()

        await self._connection_pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)
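
A sketch of extending the dispatcher with a custom event type; CacheInvalidatedEvent and LogInvalidationListener are hypothetical, not part of the library.

class CacheInvalidatedEvent:
    def __init__(self, key):
        self.key = key


class LogInvalidationListener(EventListenerInterface):
    def listen(self, event: CacheInvalidatedEvent):
        print(f"invalidated: {event.key}")


dispatcher = EventDispatcher()
dispatcher.register_listeners({CacheInvalidatedEvent: [LogInvalidationListener()]})
dispatcher.dispatch(CacheInvalidatedEvent("user:42"))   # -> invalidated: user:42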
255
backend/venv/lib/python3.9/site-packages/redis/exceptions.py
Normal file
@@ -0,0 +1,255 @@
"Core exceptions raised by the Redis client"


class RedisError(Exception):
    pass


class ConnectionError(RedisError):
    pass


class TimeoutError(RedisError):
    pass


class AuthenticationError(ConnectionError):
    pass


class AuthorizationError(ConnectionError):
    pass


class BusyLoadingError(ConnectionError):
    pass


class InvalidResponse(RedisError):
    pass


class ResponseError(RedisError):
    pass


class DataError(RedisError):
    pass


class PubSubError(RedisError):
    pass


class WatchError(RedisError):
    pass


class NoScriptError(ResponseError):
    pass


class OutOfMemoryError(ResponseError):
    """
    Indicates the database is full. Can only occur when either:
    * Redis maxmemory-policy=noeviction
    * Redis maxmemory-policy=volatile* and there are no evictable keys

    For more information see `Memory optimization in Redis <https://redis.io/docs/management/optimization/memory-optimization/#memory-allocation>`_.  # noqa
    """

    pass


class ExecAbortError(ResponseError):
    pass


class ReadOnlyError(ResponseError):
    pass


class NoPermissionError(ResponseError):
    pass


class ModuleError(ResponseError):
    pass


class LockError(RedisError, ValueError):
    "Errors acquiring or releasing a lock"

    # NOTE: For backwards compatibility, this class derives from ValueError.
    # This was originally chosen to behave like threading.Lock.

    def __init__(self, message=None, lock_name=None):
        self.message = message
        self.lock_name = lock_name


class LockNotOwnedError(LockError):
    "Error trying to extend or release a lock that is not owned (anymore)"

    pass


class ChildDeadlockedError(Exception):
    "Error indicating that a child process is deadlocked after a fork()"

    pass


class AuthenticationWrongNumberOfArgsError(ResponseError):
    """
    An error to indicate that the wrong number of args
    were sent to the AUTH command
    """

    pass


class RedisClusterException(Exception):
    """
    Base exception for the RedisCluster client
    """

    pass


class ClusterError(RedisError):
    """
    Cluster errors occurred multiple times, resulting in an exhaustion of the
    command execution TTL
    """

    pass


class ClusterDownError(ClusterError, ResponseError):
    """
    Error indicating that a CLUSTERDOWN error was received from the cluster.
    By default, Redis Cluster nodes stop accepting queries if they detect there
    is at least one hash slot uncovered (no available node is serving it).
    This way, if the cluster is partially down (for example, a range of hash
    slots is no longer covered), the entire cluster eventually becomes
    unavailable. It automatically becomes available again as soon as all the
    slots are covered.
    """

    def __init__(self, resp):
        self.args = (resp,)
        self.message = resp


class AskError(ResponseError):
    """
    Error indicating that an ASK error was received from the cluster.
    When a slot is set as MIGRATING, the node will accept all queries that
    pertain to this hash slot, but only if the key in question exists;
    otherwise, the query is forwarded using a -ASK redirection to the node
    that is the target of the migration.

    src node: MIGRATING to dst node
        get > ASK error
        ask dst node > ASKING command
    dst node: IMPORTING from src node
        the ASKING command only affects the next command
        any op will be allowed after the ASKING command
    """

    def __init__(self, resp):
        """should only redirect to the master node"""
        self.args = (resp,)
        self.message = resp
        slot_id, new_node = resp.split(" ")
        host, port = new_node.rsplit(":", 1)
        self.slot_id = int(slot_id)
        self.node_addr = self.host, self.port = host, int(port)


class TryAgainError(ResponseError):
    """
    Error indicating that a TRYAGAIN error was received from the cluster.
    Operations on keys that don't exist or are - during resharding - split
    between the source and destination nodes will generate a -TRYAGAIN error.
    """

    def __init__(self, *args, **kwargs):
        pass


class ClusterCrossSlotError(ResponseError):
    """
    Error indicating that a CROSSSLOT error was received from the cluster.
    A CROSSSLOT error is generated when keys in a request don't hash to the
    same slot.
    """

    message = "Keys in request don't hash to the same slot"


class MovedError(AskError):
    """
    Error indicating that a MOVED error was received from the cluster.
    A request sent to a node that doesn't serve this key will be replied to
    with a MOVED error that points to the correct node.
    """

    pass


class MasterDownError(ClusterDownError):
    """
    Error indicating that a MASTERDOWN error was received from the cluster.
    The link with the MASTER is down and replica-serve-stale-data is set to 'no'.
    """

    pass


class SlotNotCoveredError(RedisClusterException):
    """
    This error only happens when the connection pool tries to
    fetch the node that covers a given slot.

    If this error is raised, the client should drop the current node layout and
    attempt to reconnect and refresh the node layout.
    """

    pass


class MaxConnectionsError(ConnectionError):
    """
    Raised when a connection pool has reached its max_connections limit.
    This indicates pool exhaustion rather than an actual connection failure.
    """

    pass


class CrossSlotTransactionError(RedisClusterException):
    """
    Raised when a transaction or watch is triggered in a pipeline
    and not all keys or all commands belong to the same slot.
    """

    pass


class InvalidPipelineStack(RedisClusterException):
    """
    Raised on unexpected response length on pipelines. This is
    most likely a handling error on the stack.
    """

    pass


class ExternalAuthProviderError(ConnectionError):
    """
    Raised when an external authentication provider returns an error.
    """

    pass
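
A small sketch of how the AskError constructor parses a redirection payload; the payload string is a made-up example in the documented "<slot> <host>:<port>" format.

err = AskError("3999 127.0.0.1:6381")
assert err.slot_id == 3999
assert err.node_addr == ("127.0.0.1", 6381)
assert (err.host, err.port) == ("127.0.0.1", 6381)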
@@ -0,0 +1,425 @@
from __future__ import annotations

import base64
import gzip
import json
import ssl
import zlib
from dataclasses import dataclass
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode, urljoin
from urllib.request import Request, urlopen

__all__ = ["HttpClient", "HttpResponse", "HttpError", "DEFAULT_TIMEOUT"]

from redis.backoff import ExponentialWithJitterBackoff
from redis.retry import Retry
from redis.utils import dummy_fail

DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}


@dataclass
class HttpResponse:
    status: int
    headers: Dict[str, str]
    url: str
    content: bytes

    def text(self, encoding: Optional[str] = None) -> str:
        enc = encoding or self._get_encoding()
        return self.content.decode(enc, errors="replace")

    def json(self) -> Any:
        return json.loads(self.text(encoding=self._get_encoding()))

    def _get_encoding(self) -> str:
        # Try to infer the encoding from the headers; default to utf-8
        ctype = self.headers.get("content-type", "")
        # Example: application/json; charset=utf-8
        for part in ctype.split(";"):
            p = part.strip()
            if p.lower().startswith("charset="):
                return p.split("=", 1)[1].strip() or "utf-8"
        return "utf-8"
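
A sketch of the charset inference above; the header and body values are illustrative.

resp = HttpResponse(
    status=200,
    headers={"content-type": "application/json; charset=iso-8859-1"},
    url="https://example.invalid/api",
    content="caf\xe9".encode("iso-8859-1"),
)
assert resp._get_encoding() == "iso-8859-1"
assert resp.text() == "café"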
|
||||
|
||||
class HttpError(Exception):
|
||||
def __init__(self, status: int, url: str, message: Optional[str] = None):
|
||||
self.status = status
|
||||
self.url = url
|
||||
self.message = message or f"HTTP {status} for {url}"
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class HttpClient:
|
||||
"""
|
||||
A lightweight HTTP client for REST API calls.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "",
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
retry: Retry = Retry(
|
||||
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
|
||||
),
|
||||
verify_tls: bool = True,
|
||||
# TLS verification (server) options
|
||||
ca_file: Optional[str] = None,
|
||||
ca_path: Optional[str] = None,
|
||||
ca_data: Optional[Union[str, bytes]] = None,
|
||||
# Mutual TLS (client cert) options
|
||||
client_cert_file: Optional[str] = None,
|
||||
client_key_file: Optional[str] = None,
|
||||
client_key_password: Optional[str] = None,
|
||||
auth_basic: Optional[Tuple[str, str]] = None, # (username, password)
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new HTTP client instance.
|
||||
|
||||
Args:
|
||||
base_url: Base URL for all requests. Will be prefixed to all paths.
|
||||
headers: Default headers to include in all requests.
|
||||
timeout: Default timeout in seconds for requests.
|
||||
retry: Retry configuration for failed requests.
|
||||
verify_tls: Whether to verify TLS certificates.
|
||||
ca_file: Path to CA certificate file for TLS verification.
|
||||
ca_path: Path to a directory containing CA certificates.
|
||||
ca_data: CA certificate data as string or bytes.
|
||||
client_cert_file: Path to client certificate for mutual TLS.
|
||||
client_key_file: Path to a client private key for mutual TLS.
|
||||
client_key_password: Password for an encrypted client private key.
|
||||
auth_basic: Tuple of (username, password) for HTTP basic auth.
|
||||
user_agent: User-Agent header value for requests.
|
||||
|
||||
The client supports both regular HTTPS with server verification and mutual TLS
|
||||
authentication. For server verification, provide CA certificate information via
|
||||
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
|
||||
certificate and key via client_cert_file and client_key_file.
|
||||
"""
|
||||
self.base_url = (
|
||||
base_url.rstrip() + "/"
|
||||
if base_url and not base_url.endswith("/")
|
||||
else base_url
|
||||
)
|
||||
self._default_headers = {k.lower(): v for k, v in (headers or {}).items()}
|
||||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError))
|
||||
self.verify_tls = verify_tls
|
||||
|
||||
# TLS settings
|
||||
self.ca_file = ca_file
|
||||
self.ca_path = ca_path
|
||||
self.ca_data = ca_data
|
||||
self.client_cert_file = client_cert_file
|
||||
self.client_key_file = client_key_file
|
||||
self.client_key_password = client_key_password
|
||||
|
||||
self.auth_basic = auth_basic
|
||||
self.user_agent = user_agent
|
||||
|
||||
    # Public JSON-centric helpers
    def get(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        return self._json_call(
            "GET",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            body=None,
            expect_json=expect_json,
        )

    def delete(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        return self._json_call(
            "DELETE",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            body=None,
            expect_json=expect_json,
        )

    def post(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        return self._json_call(
            "POST",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            body=self._prepare_body(json_body=json_body, data=data),
            expect_json=expect_json,
        )

    def put(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        return self._json_call(
            "PUT",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            body=self._prepare_body(json_body=json_body, data=data),
            expect_json=expect_json,
        )

    def patch(
        self,
        path: str,
        json_body: Optional[Any] = None,
        data: Optional[Union[bytes, str]] = None,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        return self._json_call(
            "PATCH",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            body=self._prepare_body(json_body=json_body, data=data),
            expect_json=expect_json,
        )
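All five helpers above funnel into _json_call. A minimal usage sketch follows; the enclosing class name, its constructor signature, and the module path are not visible in this hunk, so HttpClient and the import below are assumptions for illustration only:

    # Hypothetical usage; HttpClient and its import path are assumed, not shown here.
    from redis.http.http_client import HttpClient  # assumed module path

    client = HttpClient(base_url="https://api.example.com")

    # GET /health?verbose=true, parsed as JSON because expect_json defaults to True
    status = client.get("health", params={"verbose": True})

    # POST a JSON body; _prepare_body() serializes it with compact separators
    created = client.post("users", json_body={"name": "alice"})

    # Ask for the raw HttpResponse instead of parsed JSON
    resp = client.get("metrics", expect_json=False)
    print(resp.status, resp.headers.get("content-type"))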
    # Low-level request
    def request(
        self,
        method: str,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        body: Optional[Union[bytes, str]] = None,
        timeout: Optional[float] = None,
    ) -> HttpResponse:
        url = self._build_url(path, params)
        all_headers = self._prepare_headers(headers, body)
        data = body.encode("utf-8") if isinstance(body, str) else body

        req = Request(url=url, method=method.upper(), data=data, headers=all_headers)

        context: Optional[ssl.SSLContext] = None
        if url.lower().startswith("https"):
            if self.verify_tls:
                # Use provided CA material if any; fall back to system defaults
                context = ssl.create_default_context(
                    cafile=self.ca_file,
                    capath=self.ca_path,
                    cadata=self.ca_data,
                )
                # Load client certificate for mTLS if configured
                if self.client_cert_file:
                    context.load_cert_chain(
                        certfile=self.client_cert_file,
                        keyfile=self.client_key_file,
                        password=self.client_key_password,
                    )
            else:
                # Verification disabled
                context = ssl.create_default_context()
                context.check_hostname = False
                context.verify_mode = ssl.CERT_NONE

        try:
            return self.retry.call_with_retry(
                lambda: self._make_request(req, context=context, timeout=timeout),
                lambda _: dummy_fail(),
                lambda error: self._is_retryable_http_error(error),
            )
        except HTTPError as e:
            # Retries were exhausted or the status is not retryable; read the
            # error body and surface it as a normal HttpResponse
            err_body = b""
            try:
                err_body = e.read()
            except Exception:
                pass
            headers_map = {k.lower(): v for k, v in (e.headers or {}).items()}
            err_body = self._maybe_decompress(err_body, headers_map)
            status = getattr(e, "code", 0) or 0
            response = HttpResponse(
                status=status,
                headers=headers_map,
                url=url,
                content=err_body,
            )
            return response

    def _make_request(
        self,
        request: Request,
        context: Optional[ssl.SSLContext] = None,
        timeout: Optional[float] = None,
    ):
        with urlopen(request, timeout=timeout or self.timeout, context=context) as resp:
            raw = resp.read()
            headers_map = {k.lower(): v for k, v in resp.headers.items()}
            raw = self._maybe_decompress(raw, headers_map)
            return HttpResponse(
                status=resp.status,
                headers=headers_map,
                url=resp.geturl(),
                content=raw,
            )

    def _is_retryable_http_error(self, error: Exception) -> bool:
        if isinstance(error, HTTPError):
            return self._should_retry_status(error.code)
        return False

    # Internal utilities
    def _json_call(
        self,
        method: str,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
        headers: Optional[Mapping[str, str]] = None,
        timeout: Optional[float] = None,
        body: Optional[Union[bytes, str]] = None,
        expect_json: bool = True,
    ) -> Union[HttpResponse, Any]:
        resp = self.request(
            method=method,
            path=path,
            params=params,
            headers=headers,
            body=body,
            timeout=timeout,
        )
        if not (200 <= resp.status < 400):
            raise HttpError(resp.status, resp.url, resp.text())
        if expect_json:
            return resp.json()
        return resp

    def _prepare_body(
        self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None
    ) -> Optional[Union[bytes, str]]:
        if json_body is not None and data is not None:
            raise ValueError("Provide either json_body or data, not both.")
        if json_body is not None:
            return json.dumps(json_body, ensure_ascii=False, separators=(",", ":"))
        return data

    def _build_url(
        self,
        path: str,
        params: Optional[
            Mapping[str, Union[None, str, int, float, bool, list, tuple]]
        ] = None,
    ) -> str:
        url = urljoin(self.base_url or "", path)
        if params:
            # urlencode with doseq=True supports list/tuple values
            query = urlencode(
                {k: v for k, v in params.items() if v is not None}, doseq=True
            )
            separator = "&" if ("?" in url) else "?"
            url = f"{url}{separator}{query}" if query else url
        return url
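Because _build_url drops None values and urlencodes with doseq=True, list and tuple parameters expand into repeated query keys. A standalone sketch of that behaviour, using only the standard library:

    from urllib.parse import urlencode, urljoin

    params = {"fields": ["id", "name"], "page": 1, "cursor": None}
    query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True)
    print(query)  # fields=id&fields=name&page=1

    # urljoin resolves the path against the trailing-slash base URL
    print(urljoin("https://api.example.com/v1/", "users"))
    # https://api.example.com/v1/users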
    def _prepare_headers(
        self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]
    ) -> Dict[str, str]:
        # Start with defaults
        prepared: Dict[str, str] = {}
        prepared.update(self._default_headers)

        # Standard defaults for JSON REST usage
        prepared.setdefault("accept", "application/json")
        prepared.setdefault("user-agent", self.user_agent)
        # We send a gzip/deflate accept-encoding and handle decompression manually
        prepared.setdefault("accept-encoding", "gzip, deflate")

        # If we have a string body and content-type is not specified, assume JSON
        if body is not None and isinstance(body, str):
            prepared.setdefault("content-type", "application/json; charset=utf-8")

        # Basic authentication if provided and not overridden
        if self.auth_basic and "authorization" not in prepared:
            user, pwd = self.auth_basic
            token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii")
            prepared["authorization"] = f"Basic {token}"

        # Merge per-call headers (case-insensitive)
        if headers:
            for k, v in headers.items():
                prepared[k.lower()] = v

        # urllib sometimes prefers canonically capitalized header keys, but it is
        # tolerant; return the keys as provided and let urllib handle them.
        return prepared

    def _should_retry_status(self, status: int) -> bool:
        return status in RETRY_STATUS_CODES

    def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes:
        if not content:
            return content
        encoding = (headers.get("content-encoding") or "").lower()
        try:
            if "gzip" in encoding:
                return gzip.decompress(content)
            if "deflate" in encoding:
                # Try raw deflate, then zlib-wrapped
                try:
                    return zlib.decompress(content, -zlib.MAX_WBITS)
                except zlib.error:
                    return zlib.decompress(content)
        except Exception:
            # If decompression fails, return original bytes
            return content
        return content
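The deflate branch in _maybe_decompress tries raw deflate first and then falls back to the zlib-wrapped format, because servers disagree on which of the two "deflate" means. A quick round-trip showing why both calls are needed:

    import zlib

    payload = b'{"ok": true}'

    raw = zlib.compress(payload)[2:-4]  # strip zlib header/trailer -> raw deflate
    wrapped = zlib.compress(payload)    # zlib-wrapped deflate

    print(zlib.decompress(raw, -zlib.MAX_WBITS))  # raw stream needs negative wbits
    print(zlib.decompress(wrapped))               # wrapped stream decodes directly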
343
backend/venv/lib/python3.9/site-packages/redis/lock.py
Normal file
@@ -0,0 +1,343 @@
import logging
import threading
import time as mod_time
import uuid
from types import SimpleNamespace, TracebackType
from typing import Optional, Type

from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number

logger = logging.getLogger(__name__)


class Lock:
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """

    lua_release = None
    lua_extend = None
    lua_reacquire = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # ARGV[3] - "0" if the additional time should be added to the lock's
    #           existing ttl or "1" if the existing ttl should be replaced
    # return 1 if the lock's time was extended, otherwise 0
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end

        local newttl = ARGV[2]
        if ARGV[3] == "0" then
            newttl = ARGV[2] + expiration
        end
        redis.call('pexpire', KEYS[1], newttl)
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - milliseconds
    # return 1 if the lock's time was reacquired, otherwise 0
    LUA_REACQUIRE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('pexpire', KEYS[1], ARGV[2])
        return 1
    """

    def __init__(
        self,
        redis,
        name: str,
        timeout: Optional[Number] = None,
        sleep: Number = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[Number] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock in seconds.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        the number of seconds to wait.

        ``sleep`` indicates the amount of time to sleep in seconds per loop
        iteration when the lock is in blocking mode and another client is
        currently holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or fail immediately, causing ``acquire``
        to return False and the lock not to be acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        self.raise_on_release_error = raise_on_release_error
        self.local = threading.local() if self.thread_local else SimpleNamespace()
        self.local.token = None
        self.register_scripts()

    def register_scripts(self) -> None:
        cls = self.__class__
        client = self.redis
        if cls.lua_release is None:
            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
        if cls.lua_reacquire is None:
            cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

    def __enter__(self) -> "Lock":
        if self.acquire():
            return self
        raise LockError(
            "Unable to acquire lock within the time specified",
            lock_name=self.name,
        )

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        try:
            self.release()
        except LockError:
            if self.raise_on_release_error:
                raise
            logger.warning(
                "Lock was unlocked or no longer owned when exiting context manager."
            )
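A usage sketch for the context-manager path above; the connection details are illustrative:

    import redis
    from redis.lock import Lock

    r = redis.Redis()

    # __enter__ raises LockError if the lock cannot be acquired in time
    with Lock(r, "my-lock", timeout=10, blocking_timeout=5):
        ...  # critical section; the key expires after 10s if we crash

    # Equivalent non-blocking form using acquire()/release() directly
    lock = Lock(r, "my-lock", timeout=10)
    if lock.acquire(blocking=False):
        try:
            ...
        finally:
            lock.release()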
    def acquire(
        self,
        sleep: Optional[Number] = None,
        blocking: Optional[bool] = None,
        blocking_timeout: Optional[Number] = None,
        token: Optional[str] = None,
    ):
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock.

        ``token`` specifies the token value to be used. If provided, token
        must be a bytes object or a string that can be encoded to a bytes
        object with the default encoding. If a token isn't specified, a UUID
        will be generated.
        """
        if sleep is None:
            sleep = self.sleep
        if token is None:
            token = uuid.uuid1().hex.encode()
        else:
            encoder = self.redis.get_encoder()
            token = encoder.encode(token)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = mod_time.monotonic() + blocking_timeout
        while True:
            if self.do_acquire(token):
                self.local.token = token
                return True
            if not blocking:
                return False
            next_try_at = mod_time.monotonic() + sleep
            if stop_trying_at is not None and next_try_at > stop_trying_at:
                return False
            mod_time.sleep(sleep)

    def do_acquire(self, token: str) -> bool:
        if self.timeout:
            # convert to milliseconds
            timeout = int(self.timeout * 1000)
        else:
            timeout = None
        if self.redis.set(self.name, token, nx=True, px=timeout):
            return True
        return False

    def locked(self) -> bool:
        """
        Returns True if this key is locked by any process, otherwise False.
        """
        return self.redis.get(self.name) is not None

    def owned(self) -> bool:
        """
        Returns True if this key is locked by this lock, otherwise False.
        """
        stored_token = self.redis.get(self.name)
        # need to always compare bytes to bytes
        # TODO: this can be simplified when the context manager is finished
        if stored_token and not isinstance(stored_token, bytes):
            encoder = self.redis.get_encoder()
            stored_token = encoder.encode(stored_token)
        return self.local.token is not None and stored_token == self.local.token

    def release(self) -> None:
        """
        Releases the already acquired lock
        """
        expected_token = self.local.token
        if expected_token is None:
            raise LockError(
                "Cannot release a lock that's not owned or is already unlocked.",
                lock_name=self.name,
            )
        self.local.token = None
        self.do_release(expected_token)

    def do_release(self, expected_token: str) -> None:
        if not bool(
            self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
        ):
            raise LockNotOwnedError(
                "Cannot release a lock that's no longer owned",
                lock_name=self.name,
            )

    def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
        """
        Adds more time to an already acquired lock.

        ``additional_time`` can be specified as an integer or a float, both
        representing the number of seconds to add.

        ``replace_ttl`` if False (the default), add ``additional_time`` to
        the lock's existing ttl. If True, replace the lock's ttl with
        ``additional_time``.
        """
        if self.local.token is None:
            raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
        if self.timeout is None:
            raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
        return self.do_extend(additional_time, replace_ttl)

    def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
        additional_time = int(additional_time * 1000)
        if not bool(
            self.lua_extend(
                keys=[self.name],
                args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
                client=self.redis,
            )
        ):
            raise LockNotOwnedError(
                "Cannot extend a lock that's no longer owned",
                lock_name=self.name,
            )
        return True

    def reacquire(self) -> bool:
        """
        Resets the TTL of an already acquired lock back to the timeout value.
        """
        if self.local.token is None:
            raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
        if self.timeout is None:
            raise LockError(
                "Cannot reacquire a lock with no timeout",
                lock_name=self.name,
            )
        return self.do_reacquire()

    def do_reacquire(self) -> bool:
        timeout = int(self.timeout * 1000)
        if not bool(
            self.lua_reacquire(
                keys=[self.name], args=[self.local.token, timeout], client=self.redis
            )
        ):
            raise LockNotOwnedError(
                "Cannot reacquire a lock that's no longer owned",
                lock_name=self.name,
            )
        return True
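To round off the lock API, a short sketch of keeping a lock alive past its initial timeout; extend() adds to the remaining TTL by default, while reacquire() resets it to the configured timeout:

    import redis
    from redis.lock import Lock

    r = redis.Redis()
    lock = Lock(r, "report-job", timeout=30)

    if lock.acquire(blocking=False):
        try:
            ...                  # long-running work
            lock.extend(30)      # up to ~60s of TTL now remain
            ...
            lock.reacquire()     # TTL reset back to the full 30s
        finally:
            lock.release()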
@@ -0,0 +1,810 @@
import enum
import ipaddress
import logging
import re
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Union

from redis.typing import Number


class MaintenanceState(enum.Enum):
    NONE = "none"
    MOVING = "moving"
    MAINTENANCE = "maintenance"


class EndpointType(enum.Enum):
    """Valid endpoint types used in the CLIENT MAINT_NOTIFICATIONS command."""

    INTERNAL_IP = "internal-ip"
    INTERNAL_FQDN = "internal-fqdn"
    EXTERNAL_IP = "external-ip"
    EXTERNAL_FQDN = "external-fqdn"
    NONE = "none"

    def __str__(self):
        """Return the string value of the enum."""
        return self.value


if TYPE_CHECKING:
    from redis.connection import (
        MaintNotificationsAbstractConnection,
        MaintNotificationsAbstractConnectionPool,
    )


class MaintenanceNotification(ABC):
    """
    Base class for maintenance notifications sent by the Redis server through
    push messages.

    This class provides common functionality for all maintenance notifications,
    including unique identification and TTL (Time-To-Live) functionality.

    Attributes:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
        creation_time (float): Timestamp when the notification was created/read
    """

    def __init__(self, id: int, ttl: int):
        """
        Initialize a new MaintenanceNotification with unique ID and TTL functionality.

        Args:
            id (int): Unique identifier for this notification
            ttl (int): Time-to-live in seconds for this notification
        """
        self.id = id
        self.ttl = ttl
        self.creation_time = time.monotonic()
        self.expire_at = self.creation_time + self.ttl

    def is_expired(self) -> bool:
        """
        Check if this notification has expired based on its TTL
        and creation time.

        Returns:
            bool: True if the notification has expired, False otherwise
        """
        return time.monotonic() > (self.creation_time + self.ttl)

    @abstractmethod
    def __repr__(self) -> str:
        """
        Return a string representation of the maintenance notification.

        This method must be implemented by all concrete subclasses.

        Returns:
            str: String representation of the notification
        """
        pass

    @abstractmethod
    def __eq__(self, other) -> bool:
        """
        Compare two maintenance notifications for equality.

        This method must be implemented by all concrete subclasses.
        Notifications are typically considered equal if they have the same id
        and are of the same type.

        Args:
            other: The other object to compare with

        Returns:
            bool: True if the notifications are equal, False otherwise
        """
        pass

    @abstractmethod
    def __hash__(self) -> int:
        """
        Return a hash value for the maintenance notification.

        This method must be implemented by all concrete subclasses to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value for the notification
        """
        pass


class NodeMovingNotification(MaintenanceNotification):
    """
    This notification is received when a node is replaced with a new node
    during cluster rebalancing or maintenance operations.
    """

    def __init__(
        self,
        id: int,
        new_node_host: Optional[str],
        new_node_port: Optional[int],
        ttl: int,
    ):
        """
        Initialize a new NodeMovingNotification.

        Args:
            id (int): Unique identifier for this notification
            new_node_host (Optional[str]): Hostname or IP address of the new replacement node
            new_node_port (Optional[int]): Port number of the new replacement node
            ttl (int): Time-to-live in seconds for this notification
        """
        super().__init__(id, ttl)
        self.new_node_host = new_node_host
        self.new_node_port = new_node_port

    def __repr__(self) -> str:
        expiry_time = self.expire_at
        remaining = max(0, expiry_time - time.monotonic())

        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"new_node_host='{self.new_node_host}', "
            f"new_node_port={self.new_node_port}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMovingNotification notifications are considered equal if they
        have the same id, new_node_host, and new_node_port.
        """
        if not isinstance(other, NodeMovingNotification):
            return False
        return (
            self.id == other.id
            and self.new_node_host == other.new_node_host
            and self.new_node_port == other.new_node_port
        )

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on the notification class name, id,
                new_node_host and new_node_port
        """
        try:
            node_port = int(self.new_node_port) if self.new_node_port else None
        except ValueError:
            node_port = 0

        return hash(
            (
                self.__class__.__name__,
                int(self.id),
                str(self.new_node_host),
                node_port,
            )
        )
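Since __eq__ and __hash__ above agree on (id, host, port), duplicate MOVING pushes collapse when stored in a set, which is exactly how _processed_notifications deduplicates later in this file. A small illustration:

    first = NodeMovingNotification(id=7, new_node_host="10.0.0.5", new_node_port=6379, ttl=15)
    duplicate = NodeMovingNotification(id=7, new_node_host="10.0.0.5", new_node_port=6379, ttl=15)

    seen = {first}
    seen.add(duplicate)        # equal id/host/port -> same hash, no second entry
    print(len(seen))           # 1
    print(first == duplicate)  # True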
class NodeMigratingNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of migrating slots.

    This notification is received when a node starts migrating its slots to another node
    during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratingNotification notifications are considered equal if they
        have the same id and are of the same type.
        """
        if not isinstance(other, NodeMigratingNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeMigratedNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed migrating slots.

    This notification is received when a node has finished migrating all its slots
    to other nodes during cluster rebalancing or maintenance operations.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeMigratedNotification notifications are considered equal if they
        have the same id and are of the same type.
        """
        if not isinstance(other, NodeMigratedNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailingOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node is in the process of failing over.

    This notification is received when a node starts a failover process during
    cluster maintenance operations or when handling node failures.

    Args:
        id (int): Unique identifier for this notification
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailingOverNotification notifications are considered equal if they
        have the same id and are of the same type.
        """
        if not isinstance(other, NodeFailingOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


class NodeFailedOverNotification(MaintenanceNotification):
    """
    Notification for when a Redis cluster node has completed a failover.

    This notification is received when a node has finished the failover process
    during cluster maintenance operations or after handling node failures.

    Args:
        id (int): Unique identifier for this notification
    """

    DEFAULT_TTL = 5

    def __init__(self, id: int):
        super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

    def __repr__(self) -> str:
        expiry_time = self.creation_time + self.ttl
        remaining = max(0, expiry_time - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={expiry_time}, "
            f"remaining={remaining:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Two NodeFailedOverNotification notifications are considered equal if they
        have the same id and are of the same type.
        """
        if not isinstance(other, NodeFailedOverNotification):
            return False
        return self.id == other.id and type(self) is type(other)

    def __hash__(self) -> int:
        """
        Return a hash value for the notification to allow
        instances to be used in sets and as dictionary keys.

        Returns:
            int: Hash value based on notification type and id
        """
        return hash((self.__class__.__name__, int(self.id)))


def _is_private_fqdn(host: str) -> bool:
    """
    Determine if an FQDN is likely to be internal/private.

    This uses heuristics based on RFC 952 and RFC 1123 standards:
    - .local domains (RFC 6762 - Multicast DNS)
    - .internal domains (common internal convention)
    - Single-label hostnames (no dots)
    - Common internal TLDs

    Args:
        host (str): The FQDN to check

    Returns:
        bool: True if the FQDN appears to be internal/private
    """
    host_lower = host.lower().rstrip(".")

    # Single-label hostnames (no dots) are typically internal
    if "." not in host_lower:
        return True

    # Common internal/private domain patterns
    internal_patterns = [
        r"\.local$",     # mDNS/Bonjour domains
        r"\.internal$",  # Common internal convention
        r"\.corp$",      # Corporate domains
        r"\.lan$",       # Local area network
        r"\.intranet$",  # Intranet domains
        r"\.private$",   # Private domains
    ]

    for pattern in internal_patterns:
        if re.search(pattern, host_lower):
            return True

    # If none of the internal patterns match, assume it's external
    return False
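The heuristic above only inspects the name itself, so its results are deterministic and need no DNS lookup. For instance:

    for host in ("redis-node-1", "cache.internal", "db.CORP.", "redis.example.com"):
        print(host, "->", _is_private_fqdn(host))
    # redis-node-1 -> True        (single label)
    # cache.internal -> True      (.internal suffix)
    # db.CORP. -> True            (case-folded, trailing dot stripped)
    # redis.example.com -> False  (no internal pattern matched)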
class MaintNotificationsConfig:
    """
    Configuration class for maintenance notification handling behaviour.
    Notifications are received through push messages.

    This class defines how the Redis client should react to different push
    notifications such as node moving, migrations, etc. in a Redis cluster.
    """

    def __init__(
        self,
        enabled: Union[bool, Literal["auto"]] = "auto",
        proactive_reconnect: bool = True,
        relaxed_timeout: Optional[Number] = 10,
        endpoint_type: Optional[EndpointType] = None,
    ):
        """
        Initialize a new MaintNotificationsConfig.

        Args:
            enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during
                  connection setup, otherwise a ResponseError is raised.
                - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted, but
                  failures are gracefully handled: a warning is logged and normal
                  operation continues.
                - False: Maintenance notifications are completely disabled.
                Defaults to "auto".
            proactive_reconnect (bool): Whether to proactively reconnect when a node
                is replaced. Defaults to True.
            relaxed_timeout (Number): The relaxed timeout to use for the connection
                during maintenance. If -1 is provided, the relaxed timeout is
                disabled. Defaults to 10.
            endpoint_type (Optional[EndpointType]): Override for the endpoint type to
                use in CLIENT MAINT_NOTIFICATIONS. If None, the endpoint type will be
                automatically determined based on the host and TLS configuration.
                Defaults to None.

        Raises:
            ValueError: If endpoint_type is provided but is not a valid endpoint type.
        """
        self.enabled = enabled
        self.relaxed_timeout = relaxed_timeout
        self.proactive_reconnect = proactive_reconnect
        self.endpoint_type = endpoint_type

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"enabled={self.enabled}, "
            f"proactive_reconnect={self.proactive_reconnect}, "
            f"relaxed_timeout={self.relaxed_timeout}, "
            f"endpoint_type={self.endpoint_type!r}"
            f")"
        )

    def is_relaxed_timeouts_enabled(self) -> bool:
        """
        Check if the relaxed_timeout is enabled. The -1 value is used to disable
        the relaxed_timeout. If relaxed_timeout is set to None, it will make the
        operation blocking, waiting until any response is received.

        Returns:
            True if the relaxed_timeout is enabled, False otherwise.
        """
        return self.relaxed_timeout != -1
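A construction sketch for the config, using only the parameters defined above:

    # Strict mode: fail connection setup if the server rejects the handshake,
    # keep proactive reconnects, and relax timeouts to 30s during maintenance.
    strict_cfg = MaintNotificationsConfig(
        enabled=True,
        proactive_reconnect=True,
        relaxed_timeout=30,
    )

    # Disable timeout relaxation entirely via the -1 sentinel.
    no_relax_cfg = MaintNotificationsConfig(relaxed_timeout=-1)
    print(no_relax_cfg.is_relaxed_timeouts_enabled())  # False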
    def get_endpoint_type(
        self, host: str, connection: "MaintNotificationsAbstractConnection"
    ) -> EndpointType:
        """
        Determine the appropriate endpoint type for the CLIENT MAINT_NOTIFICATIONS command.

        Logic:
        1. If endpoint_type is explicitly set, use it
        2. Otherwise, check the original host from connection.host:
           - If host is an IP address, use it directly to determine internal-ip vs external-ip
           - If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn

        Args:
            host: User provided hostname to analyze
            connection: The connection object to analyze for endpoint type determination

        Returns:
            EndpointType: The endpoint type to send in CLIENT MAINT_NOTIFICATIONS.
        """

        # If endpoint_type is explicitly set, use it
        if self.endpoint_type is not None:
            return self.endpoint_type

        # Check if the host is an IP address
        try:
            ip_addr = ipaddress.ip_address(host)
            # Host is an IP address - use it directly
            is_private = ip_addr.is_private
            return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
        except ValueError:
            # Host is an FQDN - need to check resolved IP to determine internal vs external
            pass

        # Host is an FQDN, get the resolved IP to determine if it's internal or external
        resolved_ip = connection.get_resolved_ip()

        if resolved_ip:
            try:
                ip_addr = ipaddress.ip_address(resolved_ip)
                is_private = ip_addr.is_private
                # Use FQDN types since the original host was an FQDN
                return (
                    EndpointType.INTERNAL_FQDN
                    if is_private
                    else EndpointType.EXTERNAL_FQDN
                )
            except ValueError:
                # This shouldn't happen since we got the IP from the socket, but fall back
                pass

        # Final fallback: use heuristics on the FQDN itself
        is_private = _is_private_fqdn(host)
        return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
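The IP branch of get_endpoint_type needs no connection at all, since ipaddress classifies the literal directly; only the FQDN branch consults the connection. A sketch of the IP path (passing None for the connection is safe here, though it loosens the annotation):

    cfg = MaintNotificationsConfig()

    # connection is only consulted for FQDNs, so None works for IP literals
    print(cfg.get_endpoint_type("10.1.2.3", None))  # EndpointType.INTERNAL_IP
    print(cfg.get_endpoint_type("8.8.8.8", None))   # EndpointType.EXTERNAL_IP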
class MaintNotificationsPoolHandler:
    def __init__(
        self,
        pool: "MaintNotificationsAbstractConnectionPool",
        config: MaintNotificationsConfig,
    ) -> None:
        self.pool = pool
        self.config = config
        self._processed_notifications = set()
        self._lock = threading.RLock()
        self.connection = None

    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
        self.connection = connection

    def get_handler_for_connection(self):
        # Copy all data that should be shared between connections,
        # but give each connection its own pool handler,
        # since each connection can be in a different state
        copy = MaintNotificationsPoolHandler(self.pool, self.config)
        copy._processed_notifications = self._processed_notifications
        copy._lock = self._lock
        copy.connection = None
        return copy

    def remove_expired_notifications(self):
        with self._lock:
            for notification in tuple(self._processed_notifications):
                if notification.is_expired():
                    self._processed_notifications.remove(notification)

    def handle_notification(self, notification: MaintenanceNotification):
        self.remove_expired_notifications()

        if isinstance(notification, NodeMovingNotification):
            return self.handle_node_moving_notification(notification)
        else:
            logging.error(f"Unhandled notification type: {notification}")

    def handle_node_moving_notification(self, notification: NodeMovingNotification):
        if (
            not self.config.proactive_reconnect
            and not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        with self._lock:
            if notification in self._processed_notifications:
                # Nothing to do in the connection pool handling: the
                # notification has already been handled or has expired.
                return

            with self.pool._lock:
                if (
                    self.config.proactive_reconnect
                    or self.config.is_relaxed_timeouts_enabled()
                ):
                    # Get the current connected address - if any.
                    # This is the address that is being moved, and we need to
                    # handle only connections connected to the same address.
                    moving_address_src = (
                        self.connection.getpeername() if self.connection else None
                    )

                    if getattr(self.pool, "set_in_maintenance", False):
                        # Set pool in maintenance mode - executed only if
                        # BlockingConnectionPool is used
                        self.pool.set_in_maintenance(True)

                    # Update maintenance state, timeout and optionally host address
                    # connection settings for matching connections
                    self.pool.update_connections_settings(
                        state=MaintenanceState.MOVING,
                        maintenance_notification_hash=hash(notification),
                        relaxed_timeout=self.config.relaxed_timeout,
                        host_address=notification.new_node_host,
                        matching_address=moving_address_src,
                        matching_pattern="connected_address",
                        update_notification_hash=True,
                        include_free_connections=True,
                    )

                    if self.config.proactive_reconnect:
                        if notification.new_node_host is not None:
                            self.run_proactive_reconnect(moving_address_src)
                        else:
                            threading.Timer(
                                notification.ttl / 2,
                                self.run_proactive_reconnect,
                                args=(moving_address_src,),
                            ).start()

                    # Update config for new connections:
                    # set state to MOVING, update the host and, if relaxed
                    # timeouts are enabled, update the timeouts as well
                    kwargs: dict = {
                        "maintenance_state": MaintenanceState.MOVING,
                        "maintenance_notification_hash": hash(notification),
                    }
                    if notification.new_node_host is not None:
                        # The host is not updated if the new node host is None;
                        # this happens when the MOVING push notification does not
                        # contain the new node host - in that case we only update
                        # the timeouts
                        kwargs.update(
                            {
                                "host": notification.new_node_host,
                            }
                        )
                    if self.config.is_relaxed_timeouts_enabled():
                        kwargs.update(
                            {
                                "socket_timeout": self.config.relaxed_timeout,
                                "socket_connect_timeout": self.config.relaxed_timeout,
                            }
                        )
                    self.pool.update_connection_kwargs(**kwargs)

                    if getattr(self.pool, "set_in_maintenance", False):
                        self.pool.set_in_maintenance(False)

            threading.Timer(
                notification.ttl,
                self.handle_node_moved_notification,
                args=(notification,),
            ).start()

            self._processed_notifications.add(notification)
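The two threading.Timer calls above carry the deferred half of the MOVING protocol: one schedules the reconnect halfway through the TTL window when no replacement host was announced, the other schedules the cleanup once the TTL elapses. A stripped-down sketch of that schedule-then-expire pattern, using only the standard library:

    import threading

    def deferred_reconnect():
        print("reconnecting halfway through the TTL window")

    def cleanup():
        print("TTL elapsed; reverting maintenance overrides")

    ttl = 10
    threading.Timer(ttl / 2, deferred_reconnect).start()  # mirrors the no-host MOVING case
    threading.Timer(ttl, cleanup).start()                 # mirrors handle_node_moved_notification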
    def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
        """
        Run proactive reconnect for the pool.
        Active connections are marked for reconnect after they complete the current command.
        Inactive connections are disconnected and will be reconnected on next use.
        """
        with self._lock:
            with self.pool._lock:
                # Take care of the active connections in the pool:
                # mark them for reconnect after they complete the current command
                self.pool.update_active_connections_for_reconnect(
                    moving_address_src=moving_address_src,
                )
                # Take care of the inactive connections in the pool:
                # disconnect them; they will reconnect on next use
                self.pool.disconnect_free_connections(
                    moving_address_src=moving_address_src,
                )

    def handle_node_moved_notification(self, notification: NodeMovingNotification):
        """
        Handle the cleanup after a node moving notification expires.
        """
        notification_hash = hash(notification)

        with self._lock:
            # If the current maintenance_notification_hash in kwargs does not
            # match this notification, a newer moving notification has arrived
            # after this one and we don't need to revert the kwargs yet.
            if (
                self.pool.connection_kwargs.get("maintenance_notification_hash")
                == notification_hash
            ):
                orig_host = self.pool.connection_kwargs.get("orig_host_address")
                orig_socket_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_timeout"
                )
                orig_connect_timeout = self.pool.connection_kwargs.get(
                    "orig_socket_connect_timeout"
                )
                kwargs: dict = {
                    "maintenance_state": MaintenanceState.NONE,
                    "maintenance_notification_hash": None,
                    "host": orig_host,
                    "socket_timeout": orig_socket_timeout,
                    "socket_connect_timeout": orig_connect_timeout,
                }
                self.pool.update_connection_kwargs(**kwargs)

            with self.pool._lock:
                reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
                reset_host_address = self.config.proactive_reconnect

                self.pool.update_connections_settings(
                    relaxed_timeout=-1,
                    state=MaintenanceState.NONE,
                    maintenance_notification_hash=None,
                    matching_notification_hash=notification_hash,
                    matching_pattern="notification_hash",
                    update_notification_hash=True,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                    reset_host_address=reset_host_address,
                    include_free_connections=True,
                )


class MaintNotificationsConnectionHandler:
    # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
    _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
        NodeMigratingNotification: 1,
        NodeFailingOverNotification: 1,
        NodeMigratedNotification: 0,
        NodeFailedOverNotification: 0,
    }

    def __init__(
        self,
        connection: "MaintNotificationsAbstractConnection",
        config: MaintNotificationsConfig,
    ) -> None:
        self.connection = connection
        self.config = config

    def handle_notification(self, notification: MaintenanceNotification):
        # Look up the notification type by its class in the _NOTIFICATION_TYPES dict
        notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)

        if notification_type is None:
            logging.error(f"Unhandled notification type: {notification}")
            return

        if notification_type:
            self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
        else:
            self.handle_maintenance_completed_notification()

    def handle_maintenance_start_notification(
        self, maintenance_state: MaintenanceState
    ):
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return

        self.connection.maintenance_state = maintenance_state
        self.connection.set_tmp_settings(
            tmp_relaxed_timeout=self.config.relaxed_timeout
        )
        # extend the timeout for all created connections
        self.connection.update_current_socket_timeout(self.config.relaxed_timeout)

    def handle_maintenance_completed_notification(self):
        # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
        if (
            self.connection.maintenance_state == MaintenanceState.MOVING
            or not self.config.is_relaxed_timeouts_enabled()
        ):
            return
        self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
        # Maintenance completed - reset the connection
        # timeouts by providing -1 as the relaxed timeout
        self.connection.update_current_socket_timeout(-1)
        self.connection.maintenance_state = MaintenanceState.NONE
@@ -0,0 +1,144 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable

import pybreaker

DEFAULT_GRACE_PERIOD = 60


class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half-open"


class CircuitBreaker(ABC):
    @property
    @abstractmethod
    def grace_period(self) -> float:
        """The grace period in seconds for which the circuit should be kept open."""
        pass

    @grace_period.setter
    @abstractmethod
    def grace_period(self, grace_period: float):
        """Set the grace period in seconds."""
        pass

    @property
    @abstractmethod
    def state(self) -> State:
        """The current state of the circuit."""
        pass

    @state.setter
    @abstractmethod
    def state(self, state: State):
        """Set the current state of the circuit."""
        pass

    @property
    @abstractmethod
    def database(self):
        """Database associated with this circuit."""
        pass

    @database.setter
    @abstractmethod
    def database(self, database):
        """Set the database associated with this circuit."""
        pass

    @abstractmethod
    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        """Register a callback invoked when the state of the circuit changes."""
        pass


class BaseCircuitBreaker(CircuitBreaker):
    """
    Base implementation of the CircuitBreaker interface.
    """

    def __init__(self, cb: pybreaker.CircuitBreaker):
        self._cb = cb
        self._state_pb_mapper = {
            State.CLOSED: self._cb.close,
            State.OPEN: self._cb.open,
            State.HALF_OPEN: self._cb.half_open,
        }
        self._database = None

    @property
    def grace_period(self) -> float:
        return self._cb.reset_timeout

    @grace_period.setter
    def grace_period(self, grace_period: float):
        self._cb.reset_timeout = grace_period

    @property
    def state(self) -> State:
        return State(value=self._cb.state.name)

    @state.setter
    def state(self, state: State):
        self._state_pb_mapper[state]()

    @property
    def database(self):
        return self._database

    @database.setter
    def database(self, database):
        self._database = database

    @abstractmethod
    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        """Register a callback invoked when the state of the circuit changes."""
        pass


class PBListener(pybreaker.CircuitBreakerListener):
    """Wrapper that adapts a callback to pybreaker's listener interface."""

    def __init__(
        self,
        cb: Callable[[CircuitBreaker, State, State], None],
        database,
    ):
        """
        Initialize a PBListener instance.

        Args:
            cb: Callback function that will be called when the circuit breaker state changes.
            database: Database instance associated with this circuit breaker.
        """
        self._cb = cb
        self._database = database

    def state_change(self, cb, old_state, new_state):
        cb = PBCircuitBreakerAdapter(cb)
        cb.database = self._database
        old_state = State(value=old_state.name)
        new_state = State(value=new_state.name)
        self._cb(cb, old_state, new_state)


class PBCircuitBreakerAdapter(BaseCircuitBreaker):
    def __init__(self, cb: pybreaker.CircuitBreaker):
        """
        Initialize a PBCircuitBreakerAdapter instance.

        This adapter wraps pybreaker's CircuitBreaker implementation to make it
        compatible with our CircuitBreaker interface.

        Args:
            cb: A pybreaker CircuitBreaker instance to be adapted.
        """
        super().__init__(cb)

    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
        listener = PBListener(cb, self.database)
        self._cb.add_listener(listener)
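A usage sketch of the adapter, assuming a stock pybreaker.CircuitBreaker (fail_max and reset_timeout are pybreaker's own constructor arguments):

    import pybreaker

    breaker = PBCircuitBreakerAdapter(
        pybreaker.CircuitBreaker(fail_max=5, reset_timeout=DEFAULT_GRACE_PERIOD)
    )

    def log_transition(cb, old_state, new_state):
        print(f"circuit {old_state.value} -> {new_state.value} for {cb.database}")

    breaker.on_state_changed(log_transition)
    breaker.state = State.OPEN   # routed through pybreaker; fires log_transition
    print(breaker.grace_period)  # 60, read back from pybreaker's reset_timeout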
526
backend/venv/lib/python3.9/site-packages/redis/multidb/client.py
Normal file
@@ -0,0 +1,526 @@
import logging
import threading
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Callable, List, Optional

from redis.background import BackgroundScheduler
from redis.client import PubSubWorkerThread
from redis.commands import CoreCommands, RedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import DefaultCommandExecutor
from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.multidb.failure_detector import FailureDetector
from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.utils import experimental

logger = logging.getLogger(__name__)


@experimental
class MultiDBClient(RedisModuleCommands, CoreCommands):
    """
    Client that operates on multiple logical Redis databases.
    Should be used in Active-Active database setups.
    """

    def __init__(self, config: MultiDbConfig):
        self._databases = config.databases()
        self._health_checks = config.default_health_checks()

        if config.health_checks is not None:
            self._health_checks.extend(config.health_checks)

        self._health_check_interval = config.health_check_interval
        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
            config.health_check_probes, config.health_check_probes_delay
        )
        self._failure_detectors = config.default_failure_detectors()

        if config.failure_detectors is not None:
            self._failure_detectors.extend(config.failure_detectors)

        self._failover_strategy = (
            config.default_failover_strategy()
            if config.failover_strategy is None
            else config.failover_strategy
        )
        self._failover_strategy.set_databases(self._databases)
        self._auto_fallback_interval = config.auto_fallback_interval
        self._event_dispatcher = config.event_dispatcher
        self._command_retry = config.command_retry
        self._command_retry.update_supported_errors((ConnectionRefusedError,))
        self.command_executor = DefaultCommandExecutor(
            failure_detectors=self._failure_detectors,
            databases=self._databases,
            command_retry=self._command_retry,
            failover_strategy=self._failover_strategy,
            failover_attempts=config.failover_attempts,
            failover_delay=config.failover_delay,
            event_dispatcher=self._event_dispatcher,
            auto_fallback_interval=self._auto_fallback_interval,
        )
        self.initialized = False
        self._hc_lock = threading.RLock()
        self._bg_scheduler = BackgroundScheduler()
        self._config = config

    def initialize(self):
        """
        Perform initialization of databases to define their initial state.
        """

        def raise_exception_on_failed_hc(error):
            raise error

        # Initial health check of all databases to define the initial state
        self._check_databases_health(on_error=raise_exception_on_failed_hc)

        # Start recurring health checks in the background.
        self._bg_scheduler.run_recurring(
            self._health_check_interval,
            self._check_databases_health,
        )

        is_active_db_found = False

        for database, weight in self._databases:
            # Set an on-state-changed callback for each circuit.
            database.circuit.on_state_changed(self._on_circuit_state_change_callback)

            # Set the active database according to weight and circuit state
            if database.circuit.state == CBState.CLOSED and not is_active_db_found:
                self.command_executor.active_database = database
                is_active_db_found = True

        if not is_active_db_found:
            raise NoValidDatabaseException(
                "Initial connection failed - no active database found"
            )

        self.initialized = True
def get_databases(self) -> Databases:
|
||||
"""
|
||||
Returns a sorted (by weight) list of all databases.
|
||||
"""
|
||||
return self._databases
|
||||
|
||||
def set_active_database(self, database: SyncDatabase) -> None:
|
||||
"""
|
||||
Promote one of the existing databases to become the active one.
|
||||
"""
|
||||
exists = None
|
||||
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
exists = True
|
||||
break
|
||||
|
||||
if not exists:
|
||||
raise ValueError("Given database is not a member of database list")
|
||||
|
||||
self._check_db_health(database)
|
||||
|
||||
if database.circuit.state == CBState.CLOSED:
|
||||
self.command_executor.active_database = database
|
||||
return
|
||||
|
||||
raise NoValidDatabaseException(
|
||||
"Cannot set active database, database is unhealthy"
|
||||
)
|
||||
|
||||
def add_database(self, database: SyncDatabase):
|
||||
"""
|
||||
Adds a new database to the database list.
|
||||
"""
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
raise ValueError("Given database already exists")
|
||||
|
||||
self._check_db_health(database)
|
||||
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
self._databases.add(database, database.weight)
|
||||
self._change_active_database(database, highest_weighted_db)
|
||||
|
||||
def _change_active_database(
|
||||
self, new_database: SyncDatabase, highest_weight_database: SyncDatabase
|
||||
):
|
||||
if (
|
||||
new_database.weight > highest_weight_database.weight
|
||||
and new_database.circuit.state == CBState.CLOSED
|
||||
):
|
||||
self.command_executor.active_database = new_database
|
||||
|
||||
def remove_database(self, database: SyncDatabase):
|
||||
"""
|
||||
Removes a database from the database list.
|
||||
"""
|
||||
weight = self._databases.remove(database)
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
|
||||
if (
|
||||
highest_weight <= weight
|
||||
and highest_weighted_db.circuit.state == CBState.CLOSED
|
||||
):
|
||||
self.command_executor.active_database = highest_weighted_db
|
||||
|
||||
def update_database_weight(self, database: SyncDatabase, weight: float):
|
||||
"""
|
||||
Updates the weight of a database in the database list.
|
||||
"""
|
||||
exists = None
|
||||
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
exists = True
|
||||
break
|
||||
|
||||
if not exists:
|
||||
raise ValueError("Given database is not a member of database list")
|
||||
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
self._databases.update_weight(database, weight)
|
||||
database.weight = weight
|
||||
self._change_active_database(database, highest_weighted_db)
|
||||
|
||||
def add_failure_detector(self, failure_detector: FailureDetector):
|
||||
"""
|
||||
Adds a new failure detector to the client.
|
||||
"""
|
||||
self._failure_detectors.append(failure_detector)
|
||||
|
||||
def add_health_check(self, healthcheck: HealthCheck):
|
||||
"""
|
||||
Adds a new health check to the client.
|
||||
"""
|
||||
with self._hc_lock:
|
||||
self._health_checks.append(healthcheck)
|
||||
|
||||
def execute_command(self, *args, **options):
|
||||
"""
|
||||
Executes a single command and returns its result.
|
||||
"""
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
|
||||
return self.command_executor.execute_command(*args, **options)
|
||||
|
||||
def pipeline(self):
|
||||
"""
|
||||
Enters pipeline mode of the client.
|
||||
"""
|
||||
return Pipeline(self)
|
||||
|
||||
def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):
|
||||
"""
|
||||
Executes a callable as a transaction.
|
||||
"""
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
|
||||
return self.command_executor.execute_transaction(func, *watches, **options)
|
||||
|
||||
def pubsub(self, **kwargs):
|
||||
"""
|
||||
Return a Publish/Subscribe object. With this object, you can
|
||||
subscribe to channels and listen for messages that get published to
|
||||
them.
|
||||
"""
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
|
||||
return PubSub(self, **kwargs)
|
||||
|
||||
def _check_db_health(self, database: SyncDatabase) -> bool:
|
||||
"""
|
||||
Runs health checks on the given database until the first failure.
|
||||
"""
|
||||
# Health checks will set up the circuit state
|
||||
is_healthy = self._health_check_policy.execute(self._health_checks, database)
|
||||
|
||||
if not is_healthy:
|
||||
if database.circuit.state != CBState.OPEN:
|
||||
database.circuit.state = CBState.OPEN
|
||||
return is_healthy
|
||||
elif database.circuit.state != CBState.CLOSED:
|
||||
database.circuit.state = CBState.CLOSED
|
||||
|
||||
return is_healthy
|
||||
|
||||
def _check_databases_health(self, on_error: Optional[Callable[[Exception], None]] = None):
|
||||
"""
|
||||
Runs health checks against all databases.
|
||||
Used both at initialization and as a recurring background task.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=len(self._databases)) as executor:
|
||||
# Submit all health checks
|
||||
futures = {
|
||||
executor.submit(self._check_db_health, database)
|
||||
for database, _ in self._databases
|
||||
}
|
||||
|
||||
try:
|
||||
for future in as_completed(
|
||||
futures, timeout=self._health_check_interval
|
||||
):
|
||||
try:
|
||||
future.result()
|
||||
except UnhealthyDatabaseException as e:
|
||||
unhealthy_db = e.database
|
||||
unhealthy_db.circuit.state = CBState.OPEN
|
||||
|
||||
logger.exception(
|
||||
"Health check failed, due to exception",
|
||||
exc_info=e.original_exception,
|
||||
)
|
||||
|
||||
if on_error:
|
||||
on_error(e.original_exception)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
"Health check execution exceeds health_check_interval"
|
||||
)
|
||||
|
||||
def _on_circuit_state_change_callback(
|
||||
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
|
||||
):
|
||||
if new_state == CBState.HALF_OPEN:
|
||||
self._check_db_health(circuit.database)
|
||||
return
|
||||
|
||||
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
|
||||
self._bg_scheduler.run_once(
|
||||
DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.command_executor.active_database.client.close()
|
||||
|
||||
|
||||
def _half_open_circuit(circuit: CircuitBreaker):
|
||||
circuit.state = CBState.HALF_OPEN
|
||||
|
||||
|
||||
class Pipeline(RedisModuleCommands, CoreCommands):
|
||||
"""
|
||||
Pipeline implementation for multiple logical Redis databases.
|
||||
"""
|
||||
|
||||
def __init__(self, client: MultiDBClient):
|
||||
self._command_stack = []
|
||||
self._client = client
|
||||
|
||||
def __enter__(self) -> "Pipeline":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.reset()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.reset()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._command_stack)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Pipeline instances should always evaluate to True"""
|
||||
return True
|
||||
|
||||
def reset(self) -> None:
|
||||
self._command_stack = []
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the pipeline"""
|
||||
self.reset()
|
||||
|
||||
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
|
||||
"""
|
||||
Stage a command to be executed when execute() is next called
|
||||
|
||||
Returns the current Pipeline object back so commands can be
|
||||
chained together, such as:
|
||||
|
||||
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
|
||||
|
||||
At some other point, you can then run: pipe.execute(),
|
||||
which will execute all commands queued in the pipe.
|
||||
"""
|
||||
self._command_stack.append((args, options))
|
||||
return self
|
||||
|
||||
def execute_command(self, *args, **kwargs):
|
||||
"""Adds a command to the stack"""
|
||||
return self.pipeline_execute_command(*args, **kwargs)
|
||||
|
||||
def execute(self) -> List[Any]:
|
||||
"""Execute all the commands in the current pipeline"""
|
||||
if not self._client.initialized:
|
||||
self._client.initialize()
|
||||
|
||||
try:
|
||||
return self._client.command_executor.execute_pipeline(
|
||||
tuple(self._command_stack)
|
||||
)
|
||||
finally:
|
||||
self.reset()
|
||||
|
||||
|
||||
class PubSub:
|
||||
"""
|
||||
PubSub object for the multi-database client.
|
||||
"""
|
||||
|
||||
def __init__(self, client: MultiDBClient, **kwargs):
|
||||
"""Initialize the PubSub object for a multi-database client.
|
||||
|
||||
Args:
|
||||
client: MultiDBClient instance to use for pub/sub operations
|
||||
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
|
||||
"""
|
||||
|
||||
self._client = client
|
||||
self._client.command_executor.pubsub(**kwargs)
|
||||
|
||||
def __enter__(self) -> "PubSub":
|
||||
return self
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
# if this object went out of scope prior to shutting down
|
||||
# subscriptions, close the connection manually before
|
||||
# returning it to the connection pool
|
||||
self.reset()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
return self._client.command_executor.execute_pubsub_method("reset")
|
||||
|
||||
def close(self) -> None:
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def subscribed(self) -> bool:
|
||||
return self._client.command_executor.active_pubsub.subscribed
|
||||
|
||||
def execute_command(self, *args):
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"execute_command", *args
|
||||
)
|
||||
|
||||
def psubscribe(self, *args, **kwargs):
|
||||
"""
|
||||
Subscribe to channel patterns. Patterns supplied as keyword arguments
|
||||
expect a pattern name as the key and a callable as the value. A
|
||||
pattern's callable will be invoked automatically when a message is
|
||||
received on that pattern rather than producing a message via
|
||||
``listen()``.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"psubscribe", *args, **kwargs
|
||||
)
|
||||
|
||||
def punsubscribe(self, *args):
|
||||
"""
|
||||
Unsubscribe from the supplied patterns. If empty, unsubscribe from
|
||||
all patterns.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"punsubscribe", *args
|
||||
)
|
||||
|
||||
def subscribe(self, *args, **kwargs):
|
||||
"""
|
||||
Subscribe to channels. Channels supplied as keyword arguments expect
|
||||
a channel name as the key and a callable as the value. A channel's
|
||||
callable will be invoked automatically when a message is received on
|
||||
that channel rather than producing a message via ``listen()`` or
|
||||
``get_message()``.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"subscribe", *args, **kwargs
|
||||
)
|
||||
|
||||
def unsubscribe(self, *args):
|
||||
"""
|
||||
Unsubscribe from the supplied channels. If empty, unsubscribe from
|
||||
all channels.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method("unsubscribe", *args)
|
||||
|
||||
def ssubscribe(self, *args, **kwargs):
|
||||
"""
|
||||
Subscribes the client to the specified shard channels.
|
||||
Channels supplied as keyword arguments expect a channel name as the key
|
||||
and a callable as the value. A channel's callable will be invoked automatically
|
||||
when a message is received on that channel rather than producing a message via
|
||||
``listen()`` or ``get_sharded_message()``.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"ssubscribe", *args, **kwargs
|
||||
)
|
||||
|
||||
def sunsubscribe(self, *args):
|
||||
"""
|
||||
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
|
||||
all shard_channels.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"sunsubscribe", *args
|
||||
)
|
||||
|
||||
def get_message(
|
||||
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
|
||||
):
|
||||
"""
|
||||
Get the next message if one is available, otherwise None.
|
||||
|
||||
If timeout is specified, the system will wait for `timeout` seconds
|
||||
before returning. Timeout should be specified as a floating point
|
||||
number, or None, to wait indefinitely.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"get_message",
|
||||
ignore_subscribe_messages=ignore_subscribe_messages,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def get_sharded_message(
|
||||
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
|
||||
):
|
||||
"""
|
||||
Get the next message if one is available in a sharded channel, otherwise None.
|
||||
|
||||
If timeout is specified, the system will wait for `timeout` seconds
|
||||
before returning. Timeout should be specified as a floating point
|
||||
number, or None, to wait indefinitely.
|
||||
"""
|
||||
return self._client.command_executor.execute_pubsub_method(
|
||||
"get_sharded_message",
|
||||
ignore_subscribe_messages=ignore_subscribe_messages,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def run_in_thread(
|
||||
self,
|
||||
sleep_time: float = 0.0,
|
||||
daemon: bool = False,
|
||||
exception_handler: Optional[Callable] = None,
|
||||
sharded_pubsub: bool = False,
|
||||
) -> "PubSubWorkerThread":
|
||||
return self._client.command_executor.execute_pubsub_run(
|
||||
sleep_time,
|
||||
daemon=daemon,
|
||||
exception_handler=exception_handler,
|
||||
pubsub=self,
|
||||
sharded_pubsub=sharded_pubsub,
|
||||
)
|
||||
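A minimal usage sketch for the client above, with hypothetical endpoints and weights; MultiDbConfig and DatabaseConfig are defined later in this commit (redis/multidb/config.py):

from redis.multidb.client import MultiDBClient
from redis.multidb.config import DatabaseConfig, MultiDbConfig

# Weights decide which healthy database becomes active first.
config = MultiDbConfig(
    databases_config=[
        DatabaseConfig(from_url="redis://primary.example.com:6379", weight=1.0),
        DatabaseConfig(from_url="redis://standby.example.com:6379", weight=0.5),
    ],
)

client = MultiDBClient(config)  # initialization is lazy
client.set("foo", "bar")        # first command runs health checks and picks the active database
print(client.get("foo"))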
@@ -0,0 +1,350 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from redis.client import Pipeline, PubSub, PubSubWorkerThread
|
||||
from redis.event import EventDispatcherInterface, OnCommandsFailEvent
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
from redis.multidb.database import Database, Databases, SyncDatabase
|
||||
from redis.multidb.event import (
|
||||
ActiveDatabaseChanged,
|
||||
CloseConnectionOnActiveDatabaseChanged,
|
||||
RegisterCommandFailure,
|
||||
ResubscribeOnActiveDatabaseChanged,
|
||||
)
|
||||
from redis.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
DefaultFailoverStrategyExecutor,
|
||||
FailoverStrategy,
|
||||
FailoverStrategyExecutor,
|
||||
)
|
||||
from redis.multidb.failure_detector import FailureDetector
|
||||
from redis.retry import Retry
|
||||
|
||||
|
||||
class CommandExecutor(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def auto_fallback_interval(self) -> float:
|
||||
"""Returns auto-fallback interval."""
|
||||
pass
|
||||
|
||||
@auto_fallback_interval.setter
|
||||
@abstractmethod
|
||||
def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
|
||||
"""Sets auto-fallback interval."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseCommandExecutor(CommandExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
|
||||
):
|
||||
self._auto_fallback_interval = auto_fallback_interval
|
||||
self._next_fallback_attempt: datetime
|
||||
|
||||
@property
|
||||
def auto_fallback_interval(self) -> float:
|
||||
return self._auto_fallback_interval
|
||||
|
||||
@auto_fallback_interval.setter
|
||||
def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
|
||||
self._auto_fallback_interval = auto_fallback_interval
|
||||
|
||||
def _schedule_next_fallback(self) -> None:
|
||||
if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL:
|
||||
return
|
||||
|
||||
self._next_fallback_attempt = datetime.now() + timedelta(
|
||||
seconds=self._auto_fallback_interval
|
||||
)
|
||||
|
||||
|
||||
class SyncCommandExecutor(CommandExecutor):
|
||||
@property
|
||||
@abstractmethod
|
||||
def databases(self) -> Databases:
|
||||
"""Returns a list of databases."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failure_detectors(self) -> List[FailureDetector]:
|
||||
"""Returns a list of failure detectors."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_failure_detector(self, failure_detector: FailureDetector) -> None:
|
||||
"""Adds a new failure detector to the list of failure detectors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def active_database(self) -> Optional[Database]:
|
||||
"""Returns currently active database."""
|
||||
pass
|
||||
|
||||
@active_database.setter
|
||||
@abstractmethod
|
||||
def active_database(self, database: SyncDatabase) -> None:
|
||||
"""Sets the currently active database."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def active_pubsub(self) -> Optional[PubSub]:
|
||||
"""Returns currently active pubsub."""
|
||||
pass
|
||||
|
||||
@active_pubsub.setter
|
||||
@abstractmethod
|
||||
def active_pubsub(self, pubsub: PubSub) -> None:
|
||||
"""Sets currently active pubsub."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
|
||||
"""Returns failover strategy executor."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def command_retry(self) -> Retry:
|
||||
"""Returns command retry object."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pubsub(self, **kwargs):
|
||||
"""Initializes a PubSub object on a currently active database"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(self, *args, **options):
|
||||
"""Executes a command and returns the result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute_pipeline(self, command_stack: tuple):
|
||||
"""Executes a stack of commands in pipeline."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute_transaction(
|
||||
self, transaction: Callable[[Pipeline], None], *watches, **options
|
||||
):
|
||||
"""Executes a transaction block wrapped in callback."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute_pubsub_method(self, method_name: str, *args, **kwargs):
|
||||
"""Executes a given method on active pub/sub."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
|
||||
"""Executes pub/sub run in a thread."""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
failure_detectors: List[FailureDetector],
|
||||
databases: Databases,
|
||||
command_retry: Retry,
|
||||
failover_strategy: FailoverStrategy,
|
||||
event_dispatcher: EventDispatcherInterface,
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY,
|
||||
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
|
||||
):
|
||||
"""
|
||||
Initialize the DefaultCommandExecutor instance.
|
||||
|
||||
Args:
|
||||
failure_detectors: List of failure detector instances to monitor database health
|
||||
databases: Collection of available databases to execute commands on
|
||||
command_retry: Retry policy for failed command execution
|
||||
failover_strategy: Strategy for handling database failover
|
||||
event_dispatcher: Interface for dispatching events
|
||||
failover_attempts: Number of failover attempts
|
||||
failover_delay: Delay between failover attempts
|
||||
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
|
||||
"""
|
||||
super().__init__(auto_fallback_interval)
|
||||
|
||||
for fd in failure_detectors:
|
||||
fd.set_command_executor(command_executor=self)
|
||||
|
||||
self._databases = databases
|
||||
self._failure_detectors = failure_detectors
|
||||
self._command_retry = command_retry
|
||||
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
|
||||
failover_strategy, failover_attempts, failover_delay
|
||||
)
|
||||
self._event_dispatcher = event_dispatcher
|
||||
self._active_database: Optional[Database] = None
|
||||
self._active_pubsub: Optional[PubSub] = None
|
||||
self._active_pubsub_kwargs = {}
|
||||
self._setup_event_dispatcher()
|
||||
self._schedule_next_fallback()
|
||||
|
||||
@property
|
||||
def databases(self) -> Databases:
|
||||
return self._databases
|
||||
|
||||
@property
|
||||
def failure_detectors(self) -> List[FailureDetector]:
|
||||
return self._failure_detectors
|
||||
|
||||
def add_failure_detector(self, failure_detector: FailureDetector) -> None:
|
||||
self._failure_detectors.append(failure_detector)
|
||||
|
||||
@property
|
||||
def command_retry(self) -> Retry:
|
||||
return self._command_retry
|
||||
|
||||
@property
|
||||
def active_database(self) -> Optional[SyncDatabase]:
|
||||
return self._active_database
|
||||
|
||||
@active_database.setter
|
||||
def active_database(self, database: SyncDatabase) -> None:
|
||||
old_active = self._active_database
|
||||
self._active_database = database
|
||||
|
||||
if old_active is not None and old_active is not database:
|
||||
self._event_dispatcher.dispatch(
|
||||
ActiveDatabaseChanged(
|
||||
old_active,
|
||||
self._active_database,
|
||||
self,
|
||||
**self._active_pubsub_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def active_pubsub(self) -> Optional[PubSub]:
|
||||
return self._active_pubsub
|
||||
|
||||
@active_pubsub.setter
|
||||
def active_pubsub(self, pubsub: PubSub) -> None:
|
||||
self._active_pubsub = pubsub
|
||||
|
||||
@property
|
||||
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
|
||||
return self._failover_strategy_executor
|
||||
|
||||
def execute_command(self, *args, **options):
|
||||
def callback():
|
||||
response = self._active_database.client.execute_command(*args, **options)
|
||||
self._register_command_execution(args)
|
||||
return response
|
||||
|
||||
return self._execute_with_failure_detection(callback, args)
|
||||
|
||||
def execute_pipeline(self, command_stack: tuple):
|
||||
def callback():
|
||||
with self._active_database.client.pipeline() as pipe:
|
||||
for command, options in command_stack:
|
||||
pipe.execute_command(*command, **options)
|
||||
|
||||
response = pipe.execute()
|
||||
self._register_command_execution(command_stack)
|
||||
return response
|
||||
|
||||
return self._execute_with_failure_detection(callback, command_stack)
|
||||
|
||||
def execute_transaction(
|
||||
self, transaction: Callable[[Pipeline], None], *watches, **options
|
||||
):
|
||||
def callback():
|
||||
response = self._active_database.client.transaction(
|
||||
transaction, *watches, **options
|
||||
)
|
||||
self._register_command_execution(())
|
||||
return response
|
||||
|
||||
return self._execute_with_failure_detection(callback)
|
||||
|
||||
def pubsub(self, **kwargs):
|
||||
def callback():
|
||||
if self._active_pubsub is None:
|
||||
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
|
||||
self._active_pubsub_kwargs = kwargs
|
||||
return None
|
||||
|
||||
return self._execute_with_failure_detection(callback)
|
||||
|
||||
def execute_pubsub_method(self, method_name: str, *args, **kwargs):
|
||||
def callback():
|
||||
method = getattr(self.active_pubsub, method_name)
|
||||
response = method(*args, **kwargs)
|
||||
self._register_command_execution(args)
|
||||
return response
|
||||
|
||||
return self._execute_with_failure_detection(callback, *args)
|
||||
|
||||
def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread":
|
||||
def callback():
|
||||
return self._active_pubsub.run_in_thread(sleep_time, **kwargs)
|
||||
|
||||
return self._execute_with_failure_detection(callback)
|
||||
|
||||
def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
|
||||
"""
|
||||
Execute a command callback with failure detection.
|
||||
"""
|
||||
|
||||
def wrapper():
|
||||
# On each retry we need to re-check the active database, as it might change.
|
||||
self._check_active_database()
|
||||
return callback()
|
||||
|
||||
return self._command_retry.call_with_retry(
|
||||
wrapper,
|
||||
lambda error: self._on_command_fail(error, *cmds),
|
||||
)
|
||||
|
||||
def _on_command_fail(self, error, *args):
|
||||
self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error))
|
||||
|
||||
def _check_active_database(self):
|
||||
"""
|
||||
Checks if the active database needs to be updated.
|
||||
"""
|
||||
if (
|
||||
self._active_database is None
|
||||
or self._active_database.circuit.state != CBState.CLOSED
|
||||
or (
|
||||
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
and self._next_fallback_attempt <= datetime.now()
|
||||
)
|
||||
):
|
||||
self.active_database = self._failover_strategy_executor.execute()
|
||||
self._schedule_next_fallback()
|
||||
|
||||
def _register_command_execution(self, cmd: tuple):
|
||||
for detector in self._failure_detectors:
|
||||
detector.register_command_execution(cmd)
|
||||
|
||||
def _setup_event_dispatcher(self):
|
||||
"""
|
||||
Registers necessary listeners.
|
||||
"""
|
||||
failure_listener = RegisterCommandFailure(self._failure_detectors)
|
||||
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
|
||||
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
|
||||
self._event_dispatcher.register_listeners(
|
||||
{
|
||||
OnCommandsFailEvent: [failure_listener],
|
||||
ActiveDatabaseChanged: [
|
||||
close_connection_listener,
|
||||
resubscribe_listener,
|
||||
],
|
||||
}
|
||||
)
|
||||
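Every operation above is funneled through call_with_retry with a re-check of the active database before each attempt; a standalone sketch of that pattern (the Retry usage matches redis.retry, the callbacks are stand-ins):

from redis.backoff import NoBackoff
from redis.retry import Retry

retry = Retry(backoff=NoBackoff(), retries=2)

def do_command():
    # Stand-in for the executor's wrapper: re-evaluate the active database,
    # then run the actual command callback.
    print("re-evaluating active database")
    return "PONG"

def on_fail(error):
    # Stand-in for _on_command_fail, which dispatches OnCommandsFailEvent.
    print(f"command failed: {error!r}")

print(retry.call_with_retry(do_command, on_fail))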
207
backend/venv/lib/python3.9/site-packages/redis/multidb/config.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Type, Union
|
||||
|
||||
import pybreaker
|
||||
from typing_extensions import Optional
|
||||
|
||||
from redis import ConnectionPool, Redis, RedisCluster
|
||||
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.event import EventDispatcher, EventDispatcherInterface
|
||||
from redis.multidb.circuit import (
|
||||
DEFAULT_GRACE_PERIOD,
|
||||
CircuitBreaker,
|
||||
PBCircuitBreakerAdapter,
|
||||
)
|
||||
from redis.multidb.database import Database, Databases
|
||||
from redis.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
FailoverStrategy,
|
||||
WeightBasedFailoverStrategy,
|
||||
)
|
||||
from redis.multidb.failure_detector import (
|
||||
DEFAULT_FAILURE_RATE_THRESHOLD,
|
||||
DEFAULT_FAILURES_DETECTION_WINDOW,
|
||||
DEFAULT_MIN_NUM_FAILURES,
|
||||
CommandFailureDetector,
|
||||
FailureDetector,
|
||||
)
|
||||
from redis.multidb.healthcheck import (
|
||||
DEFAULT_HEALTH_CHECK_DELAY,
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL,
|
||||
DEFAULT_HEALTH_CHECK_POLICY,
|
||||
DEFAULT_HEALTH_CHECK_PROBES,
|
||||
HealthCheck,
|
||||
HealthCheckPolicies,
|
||||
PingHealthCheck,
|
||||
)
|
||||
from redis.retry import Retry
|
||||
|
||||
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
|
||||
|
||||
|
||||
def default_event_dispatcher() -> EventDispatcherInterface:
|
||||
return EventDispatcher()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""
|
||||
Dataclass representing the configuration for a database connection.
|
||||
|
||||
This class is used to store configuration settings for a database connection,
|
||||
including client options, connection sourcing details, circuit breaker settings,
|
||||
and cluster-specific properties. It provides a structure for defining these
|
||||
attributes and allows for the creation of customized configurations for various
|
||||
database setups.
|
||||
|
||||
Attributes:
|
||||
weight (float): Weight of the database, used to determine the active one.
|
||||
client_kwargs (dict): Additional parameters for the database client connection.
|
||||
from_url (Optional[str]): Redis URL used to connect to the database.
|
||||
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
|
||||
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
|
||||
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
|
||||
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
|
||||
on public Redis Enterprise endpoints.
|
||||
|
||||
Methods:
|
||||
default_circuit_breaker:
|
||||
Generates and returns a default CircuitBreaker instance adapted for use.
|
||||
"""
|
||||
|
||||
weight: float = 1.0
|
||||
client_kwargs: dict = field(default_factory=dict)
|
||||
from_url: Optional[str] = None
|
||||
from_pool: Optional[ConnectionPool] = None
|
||||
circuit: Optional[CircuitBreaker] = None
|
||||
grace_period: float = DEFAULT_GRACE_PERIOD
|
||||
health_check_url: Optional[str] = None
|
||||
|
||||
def default_circuit_breaker(self) -> CircuitBreaker:
|
||||
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
|
||||
return PBCircuitBreakerAdapter(circuit_breaker)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDbConfig:
|
||||
"""
|
||||
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
|
||||
|
||||
Attributes:
|
||||
databases_config: A list of database configurations.
|
||||
client_class: The client class used to manage database connections.
|
||||
command_retry: Retry strategy for executing database commands.
|
||||
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
|
||||
min_num_failures: Minimum number of failures required for failover.
|
||||
failure_rate_threshold: Fraction of failed commands required for failover.
|
||||
failures_detection_window: Time interval for tracking database failures.
|
||||
health_checks: Optional list of additional health checks performed on databases.
|
||||
health_check_interval: Time interval for executing health checks.
|
||||
health_check_probes: Number of attempts to evaluate the health of a database.
|
||||
health_check_probes_delay: Delay between health check attempts.
|
||||
health_check_policy: Policy for determining database health based on health checks.
|
||||
failover_strategy: Optional strategy for handling database failover scenarios.
|
||||
failover_attempts: Number of retries allowed for failover operations.
|
||||
failover_delay: Delay between failover attempts.
|
||||
auto_fallback_interval: Time interval to trigger automatic fallback.
|
||||
event_dispatcher: Interface for dispatching events related to database operations.
|
||||
|
||||
Methods:
|
||||
databases:
|
||||
Retrieves a collection of database clients managed by weighted configurations.
|
||||
Initializes database clients based on the provided configuration and removes
|
||||
redundant retry objects for lower-level clients to rely on global retry logic.
|
||||
|
||||
default_failure_detectors:
|
||||
Returns the default list of failure detectors used to monitor database failures.
|
||||
|
||||
default_health_checks:
|
||||
Returns the default list of health checks used to monitor database health
|
||||
with specific retry and backoff strategies.
|
||||
|
||||
default_failover_strategy:
|
||||
Provides the default failover strategy used for handling failover scenarios
|
||||
with defined retry and backoff configurations.
|
||||
"""
|
||||
|
||||
databases_config: List[DatabaseConfig]
|
||||
client_class: Type[Union[Redis, RedisCluster]] = Redis
|
||||
command_retry: Retry = Retry(
|
||||
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
|
||||
)
|
||||
failure_detectors: Optional[List[FailureDetector]] = None
|
||||
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
|
||||
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
|
||||
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
|
||||
health_checks: Optional[List[HealthCheck]] = None
|
||||
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
|
||||
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
|
||||
health_check_probes_delay: float = DEFAULT_HEALTH_CHECK_DELAY
|
||||
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
|
||||
failover_strategy: Optional[FailoverStrategy] = None
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY
|
||||
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
event_dispatcher: EventDispatcherInterface = field(
|
||||
default_factory=default_event_dispatcher
|
||||
)
|
||||
|
||||
def databases(self) -> Databases:
|
||||
databases = WeightedList()
|
||||
|
||||
for database_config in self.databases_config:
|
||||
# The retry object is not used in the lower level clients, so we can safely remove it.
|
||||
# We rely on command_retry in terms of global retries.
|
||||
database_config.client_kwargs.update(
|
||||
{"retry": Retry(retries=0, backoff=NoBackoff())}
|
||||
)
|
||||
|
||||
if database_config.from_url:
|
||||
client = self.client_class.from_url(
|
||||
database_config.from_url, **database_config.client_kwargs
|
||||
)
|
||||
elif database_config.from_pool:
|
||||
database_config.from_pool.set_retry(
|
||||
Retry(retries=0, backoff=NoBackoff())
|
||||
)
|
||||
client = self.client_class.from_pool(
|
||||
connection_pool=database_config.from_pool
|
||||
)
|
||||
else:
|
||||
client = self.client_class(**database_config.client_kwargs)
|
||||
|
||||
circuit = (
|
||||
database_config.default_circuit_breaker()
|
||||
if database_config.circuit is None
|
||||
else database_config.circuit
|
||||
)
|
||||
databases.add(
|
||||
Database(
|
||||
client=client,
|
||||
circuit=circuit,
|
||||
weight=database_config.weight,
|
||||
health_check_url=database_config.health_check_url,
|
||||
),
|
||||
database_config.weight,
|
||||
)
|
||||
|
||||
return databases
|
||||
|
||||
def default_failure_detectors(self) -> List[FailureDetector]:
|
||||
return [
|
||||
CommandFailureDetector(
|
||||
min_num_failures=self.min_num_failures,
|
||||
failure_rate_threshold=self.failure_rate_threshold,
|
||||
failure_detection_window=self.failures_detection_window,
|
||||
),
|
||||
]
|
||||
|
||||
def default_health_checks(self) -> List[HealthCheck]:
|
||||
return [
|
||||
PingHealthCheck(),
|
||||
]
|
||||
|
||||
def default_failover_strategy(self) -> FailoverStrategy:
|
||||
return WeightBasedFailoverStrategy()
|
||||
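DatabaseConfig supports three ways of sourcing a connection (URL, client kwargs, or a pre-built pool); a sketch with hypothetical endpoints:

from redis import ConnectionPool
from redis.multidb.config import DatabaseConfig

by_url = DatabaseConfig(from_url="redis://db-a.example.com:6379", weight=1.0)
by_kwargs = DatabaseConfig(
    client_kwargs={"host": "db-b.example.com", "port": 6379},
    weight=0.7,
)
by_pool = DatabaseConfig(
    from_pool=ConnectionPool(host="db-c.example.com", port=6379),
    weight=0.5,
)
# Note: databases() replaces any per-client retry with Retry(retries=0, backoff=NoBackoff()),
# so only the global command_retry applies.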
@@ -0,0 +1,130 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
import redis
|
||||
from redis import RedisCluster
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import CircuitBreaker
|
||||
from redis.typing import Number
|
||||
|
||||
|
||||
class AbstractDatabase(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def weight(self) -> float:
|
||||
"""The weight of this database in compare to others. Used to determine the database failover to."""
|
||||
pass
|
||||
|
||||
@weight.setter
|
||||
@abstractmethod
|
||||
def weight(self, weight: float):
|
||||
"""Set the weight of this database in compare to others."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_url(self) -> Optional[str]:
|
||||
"""Health check URL associated with the current database."""
|
||||
pass
|
||||
|
||||
@health_check_url.setter
|
||||
@abstractmethod
|
||||
def health_check_url(self, health_check_url: Optional[str]):
|
||||
"""Set the health check URL associated with the current database."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseDatabase(AbstractDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
weight: float,
|
||||
health_check_url: Optional[str] = None,
|
||||
):
|
||||
self._weight = weight
|
||||
self._health_check_url = health_check_url
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._weight
|
||||
|
||||
@weight.setter
|
||||
def weight(self, weight: float):
|
||||
self._weight = weight
|
||||
|
||||
@property
|
||||
def health_check_url(self) -> Optional[str]:
|
||||
return self._health_check_url
|
||||
|
||||
@health_check_url.setter
|
||||
def health_check_url(self, health_check_url: Optional[str]):
|
||||
self._health_check_url = health_check_url
|
||||
|
||||
|
||||
class SyncDatabase(AbstractDatabase):
|
||||
"""Database with an underlying synchronous redis client."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def client(self) -> Union[redis.Redis, RedisCluster]:
|
||||
"""The underlying redis client."""
|
||||
pass
|
||||
|
||||
@client.setter
|
||||
@abstractmethod
|
||||
def client(self, client: Union[redis.Redis, RedisCluster]):
|
||||
"""Set the underlying redis client."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def circuit(self) -> CircuitBreaker:
|
||||
"""Circuit breaker for the current database."""
|
||||
pass
|
||||
|
||||
@circuit.setter
|
||||
@abstractmethod
|
||||
def circuit(self, circuit: CircuitBreaker):
|
||||
"""Set the circuit breaker for the current database."""
|
||||
pass
|
||||
|
||||
|
||||
Databases = WeightedList[tuple[SyncDatabase, Number]]
|
||||
|
||||
|
||||
class Database(BaseDatabase, SyncDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
client: Union[redis.Redis, RedisCluster],
|
||||
circuit: CircuitBreaker,
|
||||
weight: float,
|
||||
health_check_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new Database instance.
|
||||
|
||||
Args:
|
||||
client: Underlying Redis client instance for database operations
|
||||
circuit: Circuit breaker for handling database failures
|
||||
weight: Weight value used for database failover prioritization
|
||||
health_check_url: Health check URL associated with the current database
|
||||
"""
|
||||
self._client = client
|
||||
self._cb = circuit
|
||||
self._cb.database = self
|
||||
super().__init__(weight, health_check_url)
|
||||
|
||||
@property
|
||||
def client(self) -> Union[redis.Redis, RedisCluster]:
|
||||
return self._client
|
||||
|
||||
@client.setter
|
||||
def client(self, client: Union[redis.Redis, RedisCluster]):
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def circuit(self) -> CircuitBreaker:
|
||||
return self._cb
|
||||
|
||||
@circuit.setter
|
||||
def circuit(self, circuit: CircuitBreaker):
|
||||
self._cb = circuit
|
||||
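Normally MultiDbConfig.databases() builds Database instances; a hand-wired sketch using the default circuit breaker (the localhost endpoint is hypothetical, and no connection is made at construction time):

import redis
from redis.multidb.config import DatabaseConfig
from redis.multidb.database import Database

cfg = DatabaseConfig(weight=1.0)
db = Database(
    client=redis.Redis(host="localhost", port=6379),
    circuit=cfg.default_circuit_breaker(),
    weight=cfg.weight,
)
print(db.weight, db.circuit.state)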
@@ -0,0 +1,89 @@
|
||||
from typing import List
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.event import EventListenerInterface, OnCommandsFailEvent
|
||||
from redis.multidb.database import SyncDatabase
|
||||
from redis.multidb.failure_detector import FailureDetector
|
||||
|
||||
|
||||
class ActiveDatabaseChanged:
|
||||
"""
|
||||
Event fired when the active database has been changed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
old_database: SyncDatabase,
|
||||
new_database: SyncDatabase,
|
||||
command_executor,
|
||||
**kwargs,
|
||||
):
|
||||
self._old_database = old_database
|
||||
self._new_database = new_database
|
||||
self._command_executor = command_executor
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def old_database(self) -> SyncDatabase:
|
||||
return self._old_database
|
||||
|
||||
@property
|
||||
def new_database(self) -> SyncDatabase:
|
||||
return self._new_database
|
||||
|
||||
@property
|
||||
def command_executor(self):
|
||||
return self._command_executor
|
||||
|
||||
@property
|
||||
def kwargs(self):
|
||||
return self._kwargs
|
||||
|
||||
|
||||
class ResubscribeOnActiveDatabaseChanged(EventListenerInterface):
|
||||
"""
|
||||
Re-subscribe the currently active pub/sub to the new active database.
|
||||
"""
|
||||
|
||||
def listen(self, event: ActiveDatabaseChanged):
|
||||
old_pubsub = event.command_executor.active_pubsub
|
||||
|
||||
if old_pubsub is not None:
|
||||
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
|
||||
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
|
||||
new_pubsub.channels = old_pubsub.channels
|
||||
new_pubsub.patterns = old_pubsub.patterns
|
||||
new_pubsub.shard_channels = old_pubsub.shard_channels
|
||||
new_pubsub.on_connect(None)
|
||||
event.command_executor.active_pubsub = new_pubsub
|
||||
old_pubsub.close()
|
||||
|
||||
|
||||
class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface):
|
||||
"""
|
||||
Close connection to the old active database.
|
||||
"""
|
||||
|
||||
def listen(self, event: ActiveDatabaseChanged):
|
||||
event.old_database.client.close()
|
||||
|
||||
if isinstance(event.old_database.client, Redis):
|
||||
event.old_database.client.connection_pool.update_active_connections_for_reconnect()
|
||||
event.old_database.client.connection_pool.disconnect()
|
||||
else:
|
||||
for node in event.old_database.client.nodes_manager.nodes_cache.values():
|
||||
node.redis_connection.connection_pool.update_active_connections_for_reconnect()
|
||||
node.redis_connection.connection_pool.disconnect()
|
||||
|
||||
|
||||
class RegisterCommandFailure(EventListenerInterface):
|
||||
"""
|
||||
Event listener that registers command failures and passes them to the failure detectors.
|
||||
"""
|
||||
|
||||
def __init__(self, failure_detectors: List[FailureDetector]):
|
||||
self._failure_detectors = failure_detectors
|
||||
|
||||
def listen(self, event: OnCommandsFailEvent) -> None:
|
||||
for failure_detector in self._failure_detectors:
|
||||
failure_detector.register_failure(event.exception, event.commands)
|
||||
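Listeners implement EventListenerInterface with a single listen() method and are registered per event type, the same way the executor's _setup_event_dispatcher does; a hypothetical logging listener:

import logging

from redis.event import EventDispatcher, EventListenerInterface, OnCommandsFailEvent

logger = logging.getLogger(__name__)

class LogCommandFailure(EventListenerInterface):
    """Hypothetical listener that only logs failed commands."""

    def listen(self, event: OnCommandsFailEvent) -> None:
        logger.warning("command %s failed: %r", event.commands, event.exception)

dispatcher = EventDispatcher()
dispatcher.register_listeners({OnCommandsFailEvent: [LogCommandFailure()]})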
@@ -0,0 +1,17 @@
|
||||
class NoValidDatabaseException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnhealthyDatabaseException(Exception):
|
||||
"""Exception raised when a database is unhealthy due to an underlying exception."""
|
||||
|
||||
def __init__(self, message, database, original_exception):
|
||||
super().__init__(message)
|
||||
self.database = database
|
||||
self.original_exception = original_exception
|
||||
|
||||
|
||||
class TemporaryUnavailableException(Exception):
|
||||
"""Exception raised when all databases in setup are temporary unavailable."""
|
||||
|
||||
pass
|
||||
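TemporaryUnavailableException marks a retriable condition (raised by the failover executor below while the attempt budget lasts); a minimal caller-side sketch, with the client and key hypothetical:

import time

from redis.multidb.exception import TemporaryUnavailableException

def get_with_retry(client, key, attempts=3, delay=1.0):
    """Hypothetical helper: retry while all databases are temporarily down."""
    for attempt in range(attempts):
        try:
            return client.get(key)
        except TemporaryUnavailableException:
            if attempt == attempts - 1:
                raise
            time.sleep(delay)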
@@ -0,0 +1,125 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.database import Databases, SyncDatabase
|
||||
from redis.multidb.exception import (
|
||||
NoValidDatabaseException,
|
||||
TemporaryUnavailableException,
|
||||
)
|
||||
|
||||
DEFAULT_FAILOVER_ATTEMPTS = 10
|
||||
DEFAULT_FAILOVER_DELAY = 12
|
||||
|
||||
|
||||
class FailoverStrategy(ABC):
|
||||
@abstractmethod
|
||||
def database(self) -> SyncDatabase:
|
||||
"""Select the database according to the strategy."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_databases(self, databases: Databases) -> None:
|
||||
"""Set the database strategy operates on."""
|
||||
pass
|
||||
|
||||
|
||||
class FailoverStrategyExecutor(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_attempts(self) -> int:
|
||||
"""The number of failover attempts."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_delay(self) -> float:
|
||||
"""The delay between failover attempts."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def strategy(self) -> FailoverStrategy:
|
||||
"""The strategy to execute."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self) -> SyncDatabase:
|
||||
"""Execute the failover strategy."""
|
||||
pass
|
||||
|
||||
|
||||
class WeightBasedFailoverStrategy(FailoverStrategy):
|
||||
"""
|
||||
Failover strategy based on database weights.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._databases = WeightedList()
|
||||
|
||||
def database(self) -> SyncDatabase:
|
||||
for database, _ in self._databases:
|
||||
if database.circuit.state == CBState.CLOSED:
|
||||
return database
|
||||
|
||||
raise NoValidDatabaseException("No valid database available for communication")
|
||||
|
||||
def set_databases(self, databases: Databases) -> None:
|
||||
self._databases = databases
|
||||
|
||||
|
||||
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
|
||||
"""
|
||||
Executes the given failover strategy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: FailoverStrategy,
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY,
|
||||
):
|
||||
self._strategy = strategy
|
||||
self._failover_attempts = failover_attempts
|
||||
self._failover_delay = failover_delay
|
||||
self._next_attempt_ts: int = 0
|
||||
self._failover_counter: int = 0
|
||||
|
||||
@property
|
||||
def failover_attempts(self) -> int:
|
||||
return self._failover_attempts
|
||||
|
||||
@property
|
||||
def failover_delay(self) -> float:
|
||||
return self._failover_delay
|
||||
|
||||
@property
|
||||
def strategy(self) -> FailoverStrategy:
|
||||
return self._strategy
|
||||
|
||||
def execute(self) -> SyncDatabase:
|
||||
try:
|
||||
database = self._strategy.database()
|
||||
self._reset()
|
||||
return database
|
||||
except NoValidDatabaseException as e:
|
||||
if self._next_attempt_ts == 0:
|
||||
self._next_attempt_ts = time.time() + self._failover_delay
|
||||
self._failover_counter += 1
|
||||
elif time.time() >= self._next_attempt_ts:
|
||||
self._next_attempt_ts += self._failover_delay
|
||||
self._failover_counter += 1
|
||||
|
||||
if self._failover_counter > self._failover_attempts:
|
||||
self._reset()
|
||||
raise e
|
||||
else:
|
||||
raise TemporaryUnavailableException(
|
||||
"No database connections currently available. "
|
||||
"This is a temporary condition - please retry the operation."
|
||||
)
|
||||
|
||||
def _reset(self) -> None:
|
||||
self._next_attempt_ts = 0
|
||||
self._failover_counter = 0
|
||||
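Because one failover attempt is consumed per failover_delay seconds of continuous failure, the defaults above give an outage budget of roughly failover_attempts * failover_delay seconds before NoValidDatabaseException is re-raised:

from redis.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY

outage_budget_seconds = DEFAULT_FAILOVER_ATTEMPTS * DEFAULT_FAILOVER_DELAY
print(outage_budget_seconds)  # 10 * 12 = 120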
@@ -0,0 +1,104 @@
|
||||
import math
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Type
|
||||
|
||||
from typing_extensions import Optional
|
||||
|
||||
from redis.multidb.circuit import State as CBState
|
||||
|
||||
DEFAULT_MIN_NUM_FAILURES = 1000
|
||||
DEFAULT_FAILURE_RATE_THRESHOLD = 0.1
|
||||
DEFAULT_FAILURES_DETECTION_WINDOW = 2
|
||||
|
||||
|
||||
class FailureDetector(ABC):
|
||||
@abstractmethod
|
||||
def register_failure(self, exception: Exception, cmd: tuple) -> None:
|
||||
"""Register a failure that occurred during command execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_command_execution(self, cmd: tuple) -> None:
|
||||
"""Register a command execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_command_executor(self, command_executor) -> None:
|
||||
"""Set the command executor for this failure."""
|
||||
pass
|
||||
|
||||
|
||||
class CommandFailureDetector(FailureDetector):
|
||||
"""
|
||||
Detects a failure based on a threshold of failed commands during a specific period of time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES,
|
||||
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD,
|
||||
failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW,
|
||||
error_types: Optional[List[Type[Exception]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new CommandFailureDetector instance.
|
||||
|
||||
Args:
|
||||
min_num_failures: Minimum number of failures required for failover.
|
||||
failure_rate_threshold: Fraction of failed commands required for failover.
|
||||
failure_detection_window: Time window during which failures are tracked.
|
||||
error_types: Optional list of exception types to trigger failover. If None, all exceptions are counted.
|
||||
|
||||
The detector tracks command failures within a sliding time window. When the number of failures
|
||||
exceeds the threshold within the specified duration, it triggers failure detection.
|
||||
"""
|
||||
self._command_executor = None
|
||||
self._min_num_failures = min_num_failures
|
||||
self._failure_rate_threshold = failure_rate_threshold
|
||||
self._failure_detection_window = failure_detection_window
|
||||
self._error_types = error_types
|
||||
self._commands_executed: int = 0
|
||||
self._start_time: datetime = datetime.now()
|
||||
self._end_time: datetime = self._start_time + timedelta(
|
||||
seconds=self._failure_detection_window
|
||||
)
|
||||
self._failures_count: int = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def register_failure(self, exception: Exception, cmd: tuple) -> None:
|
||||
with self._lock:
|
||||
if self._error_types:
|
||||
if type(exception) in self._error_types:
|
||||
self._failures_count += 1
|
||||
else:
|
||||
self._failures_count += 1
|
||||
|
||||
self._check_threshold()
|
||||
|
||||
def set_command_executor(self, command_executor) -> None:
|
||||
self._command_executor = command_executor
|
||||
|
||||
def register_command_execution(self, cmd: tuple) -> None:
|
||||
with self._lock:
|
||||
if not self._start_time < datetime.now() < self._end_time:
|
||||
self._reset()
|
||||
|
||||
self._commands_executed += 1
|
||||
|
||||
def _check_threshold(self):
|
||||
if self._failures_count >= self._min_num_failures and self._failures_count >= (
|
||||
math.ceil(self._commands_executed * self._failure_rate_threshold)
|
||||
):
|
||||
self._command_executor.active_database.circuit.state = CBState.OPEN
|
||||
self._reset()
|
||||
|
||||
def _reset(self) -> None:
|
||||
with self._lock:
|
||||
self._start_time = datetime.now()
|
||||
self._end_time = self._start_time + timedelta(
|
||||
seconds=self._failure_detection_window
|
||||
)
|
||||
self._failures_count = 0
|
||||
self._commands_executed = 0
|
||||
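The circuit only opens when both threshold conditions in _check_threshold hold; a standalone check with the default values:

import math

DEFAULT_MIN_NUM_FAILURES = 1000
DEFAULT_FAILURE_RATE_THRESHOLD = 0.1

def trips(failures: int, executed: int) -> bool:
    # Mirrors CommandFailureDetector._check_threshold
    return failures >= DEFAULT_MIN_NUM_FAILURES and failures >= math.ceil(
        executed * DEFAULT_FAILURE_RATE_THRESHOLD
    )

print(trips(999, 5_000))     # False: below the minimum failure count
print(trips(1_000, 20_000))  # False: under the 10% failure rate (needs 2000)
print(trips(1_000, 9_000))   # True: 1000 >= 1000 and 1000 >= ceil(900)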
@@ -0,0 +1,282 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from time import sleep
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from redis import Redis
|
||||
from redis.backoff import NoBackoff
|
||||
from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
|
||||
from redis.multidb.exception import UnhealthyDatabaseException
|
||||
from redis.retry import Retry
|
||||
|
||||
DEFAULT_HEALTH_CHECK_PROBES = 3
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL = 5
|
||||
DEFAULT_HEALTH_CHECK_DELAY = 0.5
|
||||
DEFAULT_LAG_AWARE_TOLERANCE = 5000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthCheck(ABC):
|
||||
@abstractmethod
|
||||
def check_health(self, database) -> bool:
|
||||
"""Function to determine the health status."""
|
||||
pass
|
||||
|
||||
|
||||
class HealthCheckPolicy(ABC):
|
||||
"""
|
||||
Health checks execution policy.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_probes(self) -> int:
|
||||
"""Number of probes to execute health checks."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_delay(self) -> float:
|
||||
"""Delay between health check probes."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
"""Execute health checks and return database health status."""
|
||||
pass
|
||||
|
||||
|
||||
class AbstractHealthCheckPolicy(HealthCheckPolicy):
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
if health_check_probes < 1:
|
||||
raise ValueError("health_check_probes must be greater than 0")
|
||||
self._health_check_probes = health_check_probes
|
||||
self._health_check_delay = health_check_delay
|
||||
|
||||
@property
|
||||
def health_check_probes(self) -> int:
|
||||
return self._health_check_probes
|
||||
|
||||
@property
|
||||
def health_check_delay(self) -> float:
|
||||
return self._health_check_delay
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class HealthyAllPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if all health check probes are successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
for health_check in health_checks:
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if not health_check.check_health(database):
|
||||
return False
|
||||
except Exception as e:
|
||||
raise UnhealthyDatabaseException("Unhealthy database", database, e)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
sleep(self._health_check_delay)
|
||||
return True
|
||||
|
||||
|
||||
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if a majority of health check probes are successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
for health_check in health_checks:
|
||||
if self.health_check_probes % 2 == 0:
|
||||
allowed_unsuccessful_probes = self.health_check_probes // 2
|
||||
else:
|
||||
allowed_unsuccessful_probes = (self.health_check_probes + 1) // 2
|
||||
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if not health_check.check_health(database):
|
||||
allowed_unsuccessful_probes -= 1
|
||||
if allowed_unsuccessful_probes <= 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
allowed_unsuccessful_probes -= 1
|
||||
if allowed_unsuccessful_probes <= 0:
|
||||
raise UnhealthyDatabaseException(
|
||||
"Unhealthy database", database, e
|
||||
)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
sleep(self._health_check_delay)
|
||||
return True
|
||||
|
||||
|
||||
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if at least one health check probe is successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
is_healthy = False
|
||||
|
||||
for health_check in health_checks:
|
||||
exception = None
|
||||
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if health_check.check_health(database):
|
||||
is_healthy = True
|
||||
break
|
||||
else:
|
||||
is_healthy = False
|
||||
except Exception as e:
|
||||
exception = UnhealthyDatabaseException(
|
||||
"Unhealthy database", database, e
|
||||
)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
sleep(self._health_check_delay)
|
||||
|
||||
if not is_healthy and not exception:
|
||||
return is_healthy
|
||||
elif not is_healthy and exception:
|
||||
raise exception
|
||||
|
||||
return is_healthy
|
||||
|
||||
|
||||
class HealthCheckPolicies(Enum):
|
||||
HEALTHY_ALL = HealthyAllPolicy
|
||||
HEALTHY_MAJORITY = HealthyMajorityPolicy
|
||||
HEALTHY_ANY = HealthyAnyPolicy
|
||||
|
||||
|
||||
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
|
||||
|
||||
|
||||
class PingHealthCheck(HealthCheck):
|
||||
"""
|
||||
Health check based on PING command.
|
||||
"""
|
||||
|
||||
def check_health(self, database) -> bool:
|
||||
if isinstance(database.client, Redis):
|
||||
return database.client.execute_command("PING")
|
||||
else:
|
||||
# For a cluster, check that all nodes are healthy.
|
||||
all_nodes = database.client.get_nodes()
|
||||
for node in all_nodes:
|
||||
if not node.redis_connection.execute_command("PING"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class LagAwareHealthCheck(HealthCheck):
|
||||
"""
|
||||
Health check available for Redis Enterprise deployments.
|
||||
Verifies via the REST API that the database is healthy based on replication lag.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rest_api_port: int = 9443,
|
||||
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
auth_basic: Optional[Tuple[str, str]] = None,
|
||||
verify_tls: bool = True,
|
||||
# TLS verification (server) options
|
||||
ca_file: Optional[str] = None,
|
||||
ca_path: Optional[str] = None,
|
||||
ca_data: Optional[Union[str, bytes]] = None,
|
||||
# Mutual TLS (client cert) options
|
||||
client_cert_file: Optional[str] = None,
|
||||
client_key_file: Optional[str] = None,
|
||||
client_key_password: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize LagAwareHealthCheck with the specified parameters.
|
||||
|
||||
Args:
|
||||
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
|
||||
lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
|
||||
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
|
||||
auth_basic: Tuple of (username, password) for basic authentication
|
||||
verify_tls: Whether to verify TLS certificates (default: True)
|
||||
ca_file: Path to CA certificate file for TLS verification
|
||||
ca_path: Path to CA certificates directory for TLS verification
|
||||
ca_data: CA certificate data as string or bytes
|
||||
client_cert_file: Path to client certificate file for mutual TLS
|
||||
client_key_file: Path to client private key file for mutual TLS
|
||||
client_key_password: Password for encrypted client private key
|
||||
"""
|
||||
self._http_client = HttpClient(
|
||||
timeout=timeout,
|
||||
auth_basic=auth_basic,
|
||||
retry=Retry(NoBackoff(), retries=0),
|
||||
verify_tls=verify_tls,
|
||||
ca_file=ca_file,
|
||||
ca_path=ca_path,
|
||||
ca_data=ca_data,
|
||||
client_cert_file=client_cert_file,
|
||||
client_key_file=client_key_file,
|
||||
client_key_password=client_key_password,
|
||||
)
|
||||
self._rest_api_port = rest_api_port
|
||||
self._lag_aware_tolerance = lag_aware_tolerance
|
||||
|
||||
def check_health(self, database) -> bool:
|
||||
if database.health_check_url is None:
|
||||
raise ValueError(
|
||||
"Database health check url is not set. Please check DatabaseConfig for the current database."
|
||||
)
|
||||
|
||||
if isinstance(database.client, Redis):
|
||||
db_host = database.client.get_connection_kwargs()["host"]
|
||||
else:
|
||||
db_host = database.client.startup_nodes[0].host
|
||||
|
||||
base_url = f"{database.health_check_url}:{self._rest_api_port}"
|
||||
self._http_client.base_url = base_url
|
||||
|
||||
# Find bdb matching to the current database host
|
||||
matching_bdb = None
|
||||
for bdb in self._http_client.get("/v1/bdbs"):
|
||||
for endpoint in bdb["endpoints"]:
|
||||
if endpoint["dns_name"] == db_host:
|
||||
matching_bdb = bdb
|
||||
break
|
||||
|
||||
# In case if the host was set as public IP
|
||||
for addr in endpoint["addr"]:
|
||||
if addr == db_host:
|
||||
matching_bdb = bdb
|
||||
break
|
||||
|
||||
if matching_bdb is None:
|
||||
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
|
||||
raise ValueError("Could not find a matching bdb")
|
||||
|
||||
url = (
|
||||
f"/v1/bdbs/{matching_bdb['uid']}/availability"
|
||||
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
|
||||
)
|
||||
self._http_client.get(url, expect_json=False)
|
||||
|
||||
# Status checked in an http client, otherwise HttpError will be raised
|
||||
return True
|
||||
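# Construction sketch (illustrative; credentials and paths are
# placeholders), following the constructor signature above.
health_check = LagAwareHealthCheck(
    rest_api_port=9443,
    lag_aware_tolerance=150,  # milliseconds
    auth_basic=("monitor", "s3cret"),
    ca_file="/etc/ssl/certs/redis-enterprise-ca.pem",
)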
308
backend/venv/lib/python3.9/site-packages/redis/ocsp.py
Normal file
308
backend/venv/lib/python3.9/site-packages/redis/ocsp.py
Normal file
@@ -0,0 +1,308 @@
import base64
import datetime
import ssl
from urllib.parse import urljoin, urlparse

import cryptography.hazmat.primitives.hashes
import requests
from cryptography import hazmat, x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat import backends
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.hashes import SHA1, Hash
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.x509 import ocsp

from redis.exceptions import AuthorizationError, ConnectionError


def _verify_response(issuer_cert, ocsp_response):
    pubkey = issuer_cert.public_key()
    try:
        if isinstance(pubkey, RSAPublicKey):
            pubkey.verify(
                ocsp_response.signature,
                ocsp_response.tbs_response_bytes,
                PKCS1v15(),
                ocsp_response.signature_hash_algorithm,
            )
        elif isinstance(pubkey, DSAPublicKey):
            pubkey.verify(
                ocsp_response.signature,
                ocsp_response.tbs_response_bytes,
                ocsp_response.signature_hash_algorithm,
            )
        elif isinstance(pubkey, EllipticCurvePublicKey):
            pubkey.verify(
                ocsp_response.signature,
                ocsp_response.tbs_response_bytes,
                ECDSA(ocsp_response.signature_hash_algorithm),
            )
        else:
            pubkey.verify(ocsp_response.signature, ocsp_response.tbs_response_bytes)
    except InvalidSignature:
        raise ConnectionError("failed to validate ocsp response")
def _check_certificate(issuer_cert, ocsp_bytes, validate=True):
    """A wrapper that returns the validity of a known ocsp certificate."""

    ocsp_response = ocsp.load_der_ocsp_response(ocsp_bytes)

    if ocsp_response.response_status == ocsp.OCSPResponseStatus.UNAUTHORIZED:
        raise AuthorizationError("you are not authorized to view this ocsp certificate")
    if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL:
        if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD:
            raise ConnectionError(
                f"Received an {str(ocsp_response.certificate_status).split('.')[1]} "
                "ocsp certificate status"
            )
    else:
        raise ConnectionError(
            "failed to retrieve a successful response from the ocsp responder"
        )

    if ocsp_response.this_update >= datetime.datetime.now():
        raise ConnectionError("ocsp certificate was issued in the future")

    if (
        ocsp_response.next_update
        and ocsp_response.next_update < datetime.datetime.now()
    ):
        raise ConnectionError("ocsp certificate has an invalid update - in the past")

    responder_name = ocsp_response.responder_name
    issuer_hash = ocsp_response.issuer_key_hash
    responder_hash = ocsp_response.responder_key_hash

    cert_to_validate = issuer_cert
    if (
        responder_name is not None
        and responder_name == issuer_cert.subject
        or responder_hash == issuer_hash
    ):
        cert_to_validate = issuer_cert
    else:
        certs = ocsp_response.certificates
        responder_certs = _get_certificates(
            certs, issuer_cert, responder_name, responder_hash
        )

        try:
            responder_cert = responder_certs[0]
        except IndexError:
            raise ConnectionError("no certificates found for the responder")

        ext = responder_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage)
        if ext is None or x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING not in ext.value:
            raise ConnectionError("delegate not authorized for ocsp signing")
        cert_to_validate = responder_cert

    if validate:
        _verify_response(cert_to_validate, ocsp_response)
    return True
def _get_certificates(certs, issuer_cert, responder_name, responder_hash):
    if responder_name is None:
        certificates = [
            c
            for c in certs
            if _get_pubkey_hash(c) == responder_hash and c.issuer == issuer_cert.subject
        ]
    else:
        certificates = [
            c
            for c in certs
            if c.subject == responder_name and c.issuer == issuer_cert.subject
        ]

    return certificates


def _get_pubkey_hash(certificate):
    pubkey = certificate.public_key()

    # https://stackoverflow.com/a/46309453/600498
    if isinstance(pubkey, RSAPublicKey):
        h = pubkey.public_bytes(Encoding.DER, PublicFormat.PKCS1)
    elif isinstance(pubkey, EllipticCurvePublicKey):
        h = pubkey.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
    else:
        h = pubkey.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

    sha1 = Hash(SHA1(), backend=backends.default_backend())
    sha1.update(h)
    return sha1.finalize()
def ocsp_staple_verifier(con, ocsp_bytes, expected=None):
    """An implementation of a function for set_ocsp_client_callback in PyOpenSSL.

    This function validates that the provided ocsp_bytes response is valid,
    and matches the expected, stapled responses.
    """
    if ocsp_bytes in [b"", None]:
        raise ConnectionError("no ocsp response present")

    issuer_cert = None
    peer_cert = con.get_peer_certificate().to_cryptography()
    for c in con.get_peer_cert_chain():
        cert = c.to_cryptography()
        if cert.subject == peer_cert.issuer:
            issuer_cert = cert
            break

    if issuer_cert is None:
        raise ConnectionError("no matching issuer cert found in certificate chain")

    if expected is not None:
        e = x509.load_pem_x509_certificate(expected)
        if peer_cert != e:
            raise ConnectionError("received and expected certificates do not match")

    return _check_certificate(issuer_cert, ocsp_bytes)
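# Wiring sketch (illustrative): per the docstring above, this function is
# shaped to be installed as pyOpenSSL's OCSP client callback, which is
# invoked as (connection, stapled_bytes, user_data). Assumes pyOpenSSL is
# installed; a client must also call request_ocsp() on the connection
# before the handshake for a staple to be sent at all.
from OpenSSL import SSL

ctx = SSL.Context(SSL.TLS_CLIENT_METHOD)
ctx.set_ocsp_client_callback(ocsp_staple_verifier, data=None)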
class OCSPVerifier:
    """A class to verify ssl sockets for RFC6960/RFC6961. This can be used
    when using direct validation of OCSP responses and certificate revocations.

    @see https://datatracker.ietf.org/doc/html/rfc6960
    @see https://datatracker.ietf.org/doc/html/rfc6961
    """

    def __init__(self, sock, host, port, ca_certs=None):
        self.SOCK = sock
        self.HOST = host
        self.PORT = port
        self.CA_CERTS = ca_certs

    def _bin2ascii(self, der):
        """Convert SSL certificates in a binary (DER) format to ASCII PEM."""

        pem = ssl.DER_cert_to_PEM_cert(der)
        cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend())
        return cert

    def components_from_socket(self):
        """This function returns the certificate, primary issuer, and primary ocsp
        server in the chain for a socket already wrapped with ssl.
        """

        # convert the binary certificate to text
        der = self.SOCK.getpeercert(True)
        if der is False:
            raise ConnectionError("no certificate found for ssl peer")
        cert = self._bin2ascii(der)
        return self._certificate_components(cert)

    def _certificate_components(self, cert):
        """Given an SSL certificate, extract the useful components for
        validating the certificate status with an OCSP server.

        Args:
            cert ([bytes]): A PEM encoded ssl certificate
        """

        try:
            aia = cert.extensions.get_extension_for_oid(
                x509.oid.ExtensionOID.AUTHORITY_INFORMATION_ACCESS
            ).value
        except cryptography.x509.extensions.ExtensionNotFound:
            raise ConnectionError("No AIA information present in ssl certificate")

        # fetch certificate issuers
        issuers = [
            i
            for i in aia
            if i.access_method == x509.oid.AuthorityInformationAccessOID.CA_ISSUERS
        ]
        try:
            issuer = issuers[0].access_location.value
        except IndexError:
            issuer = None

        # now, the series of ocsp server entries
        ocsps = [
            i
            for i in aia
            if i.access_method == x509.oid.AuthorityInformationAccessOID.OCSP
        ]

        try:
            ocsp = ocsps[0].access_location.value
        except IndexError:
            raise ConnectionError("no ocsp servers in certificate")

        return cert, issuer, ocsp

    def components_from_direct_connection(self):
        """Return the certificate, primary issuer, and primary ocsp server
        from the host defined by the socket. This is useful in cases where
        different certificates are occasionally presented.
        """

        pem = ssl.get_server_certificate((self.HOST, self.PORT), ca_certs=self.CA_CERTS)
        cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend())
        return self._certificate_components(cert)

    def build_certificate_url(self, server, cert, issuer_cert):
        """Return the complete url to the ocsp"""
        orb = ocsp.OCSPRequestBuilder()

        # add_certificate returns an initialized OCSPRequestBuilder
        orb = orb.add_certificate(
            cert, issuer_cert, cryptography.hazmat.primitives.hashes.SHA256()
        )
        request = orb.build()

        path = base64.b64encode(
            request.public_bytes(hazmat.primitives.serialization.Encoding.DER)
        )
        url = urljoin(server, path.decode("ascii"))
        return url

    def check_certificate(self, server, cert, issuer_url):
        """Checks the validity of an ocsp server for an issuer"""

        r = requests.get(issuer_url)
        if not r.ok:
            raise ConnectionError("failed to fetch issuer certificate")
        der = r.content
        issuer_cert = self._bin2ascii(der)

        ocsp_url = self.build_certificate_url(server, cert, issuer_cert)

        # HTTP 1.1 mandates the addition of the Host header in ocsp responses
        header = {
            "Host": urlparse(ocsp_url).netloc,
            "Content-Type": "application/ocsp-request",
        }
        r = requests.get(ocsp_url, headers=header)
        if not r.ok:
            raise ConnectionError("failed to fetch ocsp certificate")
        return _check_certificate(issuer_cert, r.content, True)

    def is_valid(self):
        """Returns the validity of the certificate wrapping our socket.
        This first retrieves the certificate, issuer_url, and ocsp_server
        needed for validation, then retrieves the issuer certificate from
        the issuer_url, and finally checks the OCSP revocation status.
        """

        # validate the certificate
        try:
            cert, issuer_url, ocsp_server = self.components_from_socket()
            if issuer_url is None:
                raise ConnectionError("no issuers found in certificate chain")
            return self.check_certificate(ocsp_server, cert, issuer_url)
        except AuthorizationError:
            cert, issuer_url, ocsp_server = self.components_from_direct_connection()
            if issuer_url is None:
                raise ConnectionError("no issuers found in certificate chain")
            return self.check_certificate(ocsp_server, cert, issuer_url)
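# Usage sketch (illustrative; host and port are placeholders): the
# verifier only needs an already-wrapped TLS socket, plus the peer's
# host/port for the direct-connection fallback path.
import socket

ctx = ssl.create_default_context()
with socket.create_connection(("redis.example.com", 6380)) as raw_sock:
    with ctx.wrap_socket(raw_sock, server_hostname="redis.example.com") as tls_sock:
        verifier = OCSPVerifier(tls_sock, "redis.example.com", 6380)
        if verifier.is_valid():
            print("peer certificate is not revoked")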
126
backend/venv/lib/python3.9/site-packages/redis/retry.py
Normal file
126
backend/venv/lib/python3.9/site-packages/redis/retry.py
Normal file
@@ -0,0 +1,126 @@
import abc
import socket
from time import sleep
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Iterable,
    Optional,
    Tuple,
    Type,
    TypeVar,
)

from redis.exceptions import ConnectionError, TimeoutError

T = TypeVar("T")
E = TypeVar("E", bound=Exception, covariant=True)

if TYPE_CHECKING:
    from redis.backoff import AbstractBackoff


class AbstractRetry(Generic[E], abc.ABC):
    """Retry a specific number of times after a failure"""

    _supported_errors: Tuple[Type[E], ...]

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[E], ...],
    ):
        """
        Initialize a `Retry` object with a `Backoff` object
        that retries a maximum of `retries` times.
        `retries` can be negative to retry forever.
        You can specify the types of supported errors which trigger
        a retry with the `supported_errors` parameter.
        """
        self._backoff = backoff
        self._retries = retries
        self._supported_errors = supported_errors

    @abc.abstractmethod
    def __eq__(self, other: Any) -> bool:
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self._backoff, self._retries, frozenset(self._supported_errors)))

    def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None:
        """
        Updates the supported errors with the specified error types
        """
        self._supported_errors = tuple(
            set(self._supported_errors + tuple(specified_errors))
        )

    def get_retries(self) -> int:
        """
        Get the number of retries.
        """
        return self._retries

    def update_retries(self, value: int) -> None:
        """
        Set the number of retries.
        """
        self._retries = value


class Retry(AbstractRetry[Exception]):
    __hash__ = AbstractRetry.__hash__

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[Exception], ...] = (
            ConnectionError,
            TimeoutError,
            socket.timeout,
        ),
    ):
        super().__init__(backoff, retries, supported_errors)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Retry):
            return NotImplemented

        return (
            self._backoff == other._backoff
            and self._retries == other._retries
            and set(self._supported_errors) == set(other._supported_errors)
        )

    def call_with_retry(
        self,
        do: Callable[[], T],
        fail: Callable[[Exception], Any],
        is_retryable: Optional[Callable[[Exception], bool]] = None,
    ) -> T:
        """
        Execute an operation that might fail and return its result, or
        raise the exception that was thrown, depending on the `Backoff` object.
        `do`: the operation to call. Expects no argument.
        `fail`: the failure handler, expects the last error that was thrown
        """
        self._backoff.reset()
        failures = 0
        while True:
            try:
                return do()
            except self._supported_errors as error:
                if is_retryable and not is_retryable(error):
                    raise
                failures += 1
                fail(error)
                if self._retries >= 0 and failures > self._retries:
                    raise error
                backoff = self._backoff.compute(failures)
                if backoff > 0:
                    sleep(backoff)
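# Usage sketch (illustrative): drive call_with_retry with a backoff from
# redis.backoff and the ConnectionError imported above.
from redis.backoff import ExponentialBackoff

attempts = {"count": 0}

def flaky_op():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

retry = Retry(ExponentialBackoff(), retries=5)
# fail() runs on every failure; here it simply swallows the error.
assert retry.call_with_retry(flaky_op, lambda error: None) == "ok"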
425
backend/venv/lib/python3.9/site-packages/redis/sentinel.py
Normal file
425
backend/venv/lib/python3.9/site-packages/redis/sentinel.py
Normal file
@@ -0,0 +1,425 @@
import random
import weakref
from typing import Optional

from redis.client import Redis
from redis.commands import SentinelCommands
from redis.connection import Connection, ConnectionPool, SSLConnection
from redis.exceptions import (
    ConnectionError,
    ReadOnlyError,
    ResponseError,
    TimeoutError,
)


class MasterNotFoundError(ConnectionError):
    pass


class SlaveNotFoundError(ConnectionError):
    pass


class SentinelManagedConnection(Connection):
    def __init__(self, **kwargs):
        self.connection_pool = kwargs.pop("connection_pool")
        super().__init__(**kwargs)

    def __repr__(self):
        pool = self.connection_pool
        s = (
            f"<{type(self).__module__}.{type(self).__name__}"
            f"(service={pool.service_name}%s)>"
        )
        if self.host:
            host_info = f",host={self.host},port={self.port}"
            s = s % host_info
        return s

    def connect_to(self, address):
        self.host, self.port = address

        self.connect_check_health(
            check_health=self.connection_pool.check_connection,
            retry_socket_connect=False,
        )

    def _connect_retry(self):
        if self._sock:
            return  # already connected
        if self.connection_pool.is_master:
            self.connect_to(self.connection_pool.get_master_address())
        else:
            for slave in self.connection_pool.rotate_slaves():
                try:
                    return self.connect_to(slave)
                except ConnectionError:
                    continue
            raise SlaveNotFoundError  # Should never be reached

    def connect(self):
        return self.retry.call_with_retry(self._connect_retry, lambda error: None)

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error: Optional[bool] = False,
        push_request: Optional[bool] = False,
    ):
        try:
            return super().read_response(
                disable_decoding=disable_decoding,
                disconnect_on_error=disconnect_on_error,
                push_request=push_request,
            )
        except ReadOnlyError:
            if self.connection_pool.is_master:
                # When talking to a master, a ReadOnlyError most likely
                # indicates that the previous master that we're still
                # connected to has been demoted to a slave and there's a new
                # master. Calling disconnect will force the connection to
                # re-query sentinel during the next connect() attempt.
                self.disconnect()
                raise ConnectionError("The previous master is now a slave")
            raise


class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
    pass
class SentinelConnectionPoolProxy:
    def __init__(
        self,
        connection_pool,
        is_master,
        check_connection,
        service_name,
        sentinel_manager,
    ):
        self.connection_pool_ref = weakref.ref(connection_pool)
        self.is_master = is_master
        self.check_connection = check_connection
        self.service_name = service_name
        self.sentinel_manager = sentinel_manager
        self.reset()

    def reset(self):
        self.master_address = None
        self.slave_rr_counter = None

    def get_master_address(self):
        master_address = self.sentinel_manager.discover_master(self.service_name)
        if self.is_master and self.master_address != master_address:
            self.master_address = master_address
            # disconnect any idle connections so that they reconnect
            # to the new master the next time that they are used.
            connection_pool = self.connection_pool_ref()
            if connection_pool is not None:
                connection_pool.disconnect(inuse_connections=False)
        return master_address

    def rotate_slaves(self):
        slaves = self.sentinel_manager.discover_slaves(self.service_name)
        if slaves:
            if self.slave_rr_counter is None:
                self.slave_rr_counter = random.randint(0, len(slaves) - 1)
            for _ in range(len(slaves)):
                self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
                slave = slaves[self.slave_rr_counter]
                yield slave
        # Fallback to the master connection
        try:
            yield self.get_master_address()
        except MasterNotFoundError:
            pass
        raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
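# Behavior sketch (illustrative): rotate_slaves() is a generator that
# starts the round-robin at a random replica, yields each replica once,
# then falls back to the master before raising SlaveNotFoundError.
# _StubManager and _StubPool are hypothetical stand-ins for the manager
# interface used above (discover_slaves / discover_master) and for a
# weak-referenceable pool object.
class _StubManager:
    def discover_slaves(self, service_name):
        return [("10.0.0.2", 6379), ("10.0.0.3", 6379)]

    def discover_master(self, service_name):
        return ("10.0.0.1", 6379)


class _StubPool:
    pass


proxy = SentinelConnectionPoolProxy(
    connection_pool=_StubPool(),
    is_master=False,
    check_connection=False,
    service_name="mymaster",
    sentinel_manager=_StubManager(),
)
print(next(proxy.rotate_slaves()))  # one of the two replicas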
class SentinelConnectionPool(ConnectionPool):
    """
    Sentinel backed connection pool.

    If the ``check_connection`` flag is set to True, SentinelManagedConnection
    sends a PING command right after establishing the connection.
    """

    def __init__(self, service_name, sentinel_manager, **kwargs):
        kwargs["connection_class"] = kwargs.get(
            "connection_class",
            (
                SentinelManagedSSLConnection
                if kwargs.pop("ssl", False)
                else SentinelManagedConnection
            ),
        )
        self.is_master = kwargs.pop("is_master", True)
        self.check_connection = kwargs.pop("check_connection", False)
        self.proxy = SentinelConnectionPoolProxy(
            connection_pool=self,
            is_master=self.is_master,
            check_connection=self.check_connection,
            service_name=service_name,
            sentinel_manager=sentinel_manager,
        )
        super().__init__(**kwargs)
        self.connection_kwargs["connection_pool"] = self.proxy
        self.service_name = service_name
        self.sentinel_manager = sentinel_manager

    def __repr__(self):
        role = "master" if self.is_master else "slave"
        return (
            f"<{type(self).__module__}.{type(self).__name__}"
            f"(service={self.service_name}({role}))>"
        )

    def reset(self):
        super().reset()
        self.proxy.reset()

    @property
    def master_address(self):
        return self.proxy.master_address

    def owns_connection(self, connection):
        check = not self.is_master or (
            self.is_master and self.master_address == (connection.host, connection.port)
        )
        parent = super()
        return check and parent.owns_connection(connection)

    def get_master_address(self):
        return self.proxy.get_master_address()

    def rotate_slaves(self):
        "Round-robin slave balancer"
        return self.proxy.rotate_slaves()
class Sentinel(SentinelCommands):
    """
    Redis Sentinel cluster client

    >>> from redis.sentinel import Sentinel
    >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
    >>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
    >>> master.set('foo', 'bar')
    >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
    >>> slave.get('foo')
    b'bar'

    ``sentinels`` is a list of sentinel nodes. Each node is represented by
    a pair (hostname, port).

    ``min_other_sentinels`` defines a minimum number of peers for a sentinel.
    When querying a sentinel, if it doesn't meet this threshold, responses
    from that sentinel won't be considered valid.

    ``sentinel_kwargs`` is a dictionary of connection arguments used when
    connecting to sentinel instances. Any argument that can be passed to
    a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
    not specified, any socket_timeout and socket_keepalive options specified
    in ``connection_kwargs`` will be used.

    ``connection_kwargs`` are keyword arguments that will be used when
    establishing a connection to a Redis server.
    """

    def __init__(
        self,
        sentinels,
        min_other_sentinels=0,
        sentinel_kwargs=None,
        force_master_ip=None,
        **connection_kwargs,
    ):
        # if sentinel_kwargs isn't defined, use the socket_* options from
        # connection_kwargs
        if sentinel_kwargs is None:
            sentinel_kwargs = {
                k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
            }
        self.sentinel_kwargs = sentinel_kwargs

        self.sentinels = [
            Redis(hostname, port, **self.sentinel_kwargs)
            for hostname, port in sentinels
        ]
        self.min_other_sentinels = min_other_sentinels
        self.connection_kwargs = connection_kwargs
        self._force_master_ip = force_master_ip
    def execute_command(self, *args, **kwargs):
        """
        Execute a Sentinel command on the sentinel nodes.
        once - If set to True, then execute the resulting command on a single
        node at random, rather than across the entire sentinel cluster.
        """
        once = bool(kwargs.pop("once", False))

        # Check if the command is supposed to return the original
        # responses instead of a boolean value.
        return_responses = bool(kwargs.pop("return_responses", False))

        if once:
            response = random.choice(self.sentinels).execute_command(*args, **kwargs)
            if return_responses:
                return [response]
            else:
                return True if response else False

        responses = []
        for sentinel in self.sentinels:
            responses.append(sentinel.execute_command(*args, **kwargs))

        if return_responses:
            return responses

        return all(responses)

    def __repr__(self):
        sentinel_addresses = []
        for sentinel in self.sentinels:
            sentinel_addresses.append(
                "{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs)
            )
        return (
            f"<{type(self).__module__}.{type(self).__name__}"
            f"(sentinels=[{','.join(sentinel_addresses)}])>"
        )

    def check_master_state(self, state, service_name):
        if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
            return False
        # Check if our sentinel doesn't see other nodes
        if state["num-other-sentinels"] < self.min_other_sentinels:
            return False
        return True
    def discover_master(self, service_name):
        """
        Asks sentinel servers for the Redis master's address corresponding
        to the service labeled ``service_name``.

        Returns a pair (address, port) or raises MasterNotFoundError if no
        master is found.
        """
        collected_errors = list()
        for sentinel_no, sentinel in enumerate(self.sentinels):
            try:
                masters = sentinel.sentinel_masters()
            except (ConnectionError, TimeoutError) as e:
                collected_errors.append(f"{sentinel} - {e!r}")
                continue
            state = masters.get(service_name)
            if state and self.check_master_state(state, service_name):
                # Put this sentinel at the top of the list
                self.sentinels[0], self.sentinels[sentinel_no] = (
                    sentinel,
                    self.sentinels[0],
                )

                ip = (
                    self._force_master_ip
                    if self._force_master_ip is not None
                    else state["ip"]
                )
                return ip, state["port"]

        error_info = ""
        if len(collected_errors) > 0:
            error_info = f" : {', '.join(collected_errors)}"
        raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")

    def filter_slaves(self, slaves):
        "Remove slaves that are in an ODOWN or SDOWN state"
        slaves_alive = []
        for slave in slaves:
            if slave["is_odown"] or slave["is_sdown"]:
                continue
            slaves_alive.append((slave["ip"], slave["port"]))
        return slaves_alive

    def discover_slaves(self, service_name):
        "Returns a list of alive slaves for service ``service_name``"
        for sentinel in self.sentinels:
            try:
                slaves = sentinel.sentinel_slaves(service_name)
            except (ConnectionError, ResponseError, TimeoutError):
                continue
            slaves = self.filter_slaves(slaves)
            if slaves:
                return slaves
        return []
    def master_for(
        self,
        service_name,
        redis_class=Redis,
        connection_pool_class=SentinelConnectionPool,
        **kwargs,
    ):
        """
        Returns a redis client instance for the ``service_name`` master.
        The Sentinel client will detect failover and reconnect Redis clients
        automatically.

        A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
        used to retrieve the master's address before establishing a new
        connection.

        NOTE: If the master's address has changed, any cached connections to
        the old master are closed.

        By default clients will be a :py:class:`~redis.Redis` instance.
        Specify a different class to the ``redis_class`` argument if you
        desire something different.

        The ``connection_pool_class`` specifies the connection pool to
        use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
        will be used by default.

        All other keyword arguments are merged with any connection_kwargs
        passed to this class and passed to the connection pool as keyword
        arguments to be used to initialize Redis connections.
        """
        kwargs["is_master"] = True
        connection_kwargs = dict(self.connection_kwargs)
        connection_kwargs.update(kwargs)
        return redis_class.from_pool(
            connection_pool_class(service_name, self, **connection_kwargs)
        )

    def slave_for(
        self,
        service_name,
        redis_class=Redis,
        connection_pool_class=SentinelConnectionPool,
        **kwargs,
    ):
        """
        Returns a redis client instance for the ``service_name`` slave(s).

        A SentinelConnectionPool class is used to retrieve the slave's
        address before establishing a new connection.

        By default clients will be a :py:class:`~redis.Redis` instance.
        Specify a different class to the ``redis_class`` argument if you
        desire something different.

        The ``connection_pool_class`` specifies the connection pool to use.
        The SentinelConnectionPool will be used by default.

        All other keyword arguments are merged with any connection_kwargs
        passed to this class and passed to the connection pool as keyword
        arguments to be used to initialize Redis connections.
        """
        kwargs["is_master"] = False
        connection_kwargs = dict(self.connection_kwargs)
        connection_kwargs.update(kwargs)
        return redis_class.from_pool(
            connection_pool_class(service_name, self, **connection_kwargs)
        )
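# Usage sketch (illustrative; hostnames are placeholders): note that
# sentinel-connection options and Redis-connection options travel in
# separate kwargs, per the class docstring above.
sentinel = Sentinel(
    [("sentinel-1", 26379), ("sentinel-2", 26379)],
    min_other_sentinels=1,  # distrust sentinels that see no peers
    sentinel_kwargs={"socket_timeout": 0.1},  # options for sentinel hosts
    socket_timeout=0.5,  # options for the Redis connections
)
master = sentinel.master_for("mymaster")
replica = sentinel.slave_for("mymaster")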
57
backend/venv/lib/python3.9/site-packages/redis/typing.py
Normal file
57
backend/venv/lib/python3.9/site-packages/redis/typing.py
Normal file
@@ -0,0 +1,57 @@
# from __future__ import annotations

from datetime import datetime, timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Iterable,
    Mapping,
    Protocol,
    Type,
    TypeVar,
    Union,
)

if TYPE_CHECKING:
    from redis._parsers import Encoder


Number = Union[int, float]
EncodedT = Union[bytes, bytearray, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
AbsExpiryT = Union[int, datetime]
ExpiryT = Union[int, timedelta]
ZScoreBoundT = Union[float, str]  # str allows for the [ or ( prefix
BitfieldOffsetT = Union[int, str]  # str allows for #x syntax
_StringLikeT = Union[bytes, str, memoryview]
KeyT = _StringLikeT  # Main redis key space
PatternT = _StringLikeT  # Patterns matched against keys, fields etc
FieldT = EncodableT  # Fields within hash tables, streams and geo commands
KeysT = Union[KeyT, Iterable[KeyT]]
ResponseT = Union[Awaitable[Any], Any]
ChannelT = _StringLikeT
GroupT = _StringLikeT  # Consumer group
ConsumerT = _StringLikeT  # Consumer name
StreamIdT = Union[int, _StringLikeT]
ScriptTextT = _StringLikeT
TimeoutSecT = Union[int, float, _StringLikeT]
# Mapping is not covariant in the key type, which prevents
# Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using
# a TypeVar instead of a Union allows mappings with any of the permitted types
# to be passed. Care is needed if there is more than one such mapping in a
# type signature because they will all be required to be the same key type.
AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview)
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview)

ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]]


class CommandsProtocol(Protocol):
    def execute_command(self, *args, **options) -> ResponseT: ...


class ClusterCommandsProtocol(CommandsProtocol):
    encoder: "Encoder"
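# Typing sketch (illustrative; the function below is hypothetical):
# KeysT accepts one key or an iterable of keys, and ResponseT covers both
# synchronous results and awaitables returned by the async client.
def delete_keys(client: CommandsProtocol, keys: KeysT) -> ResponseT:
    if isinstance(keys, (bytes, str, memoryview)):
        keys = [keys]
    return client.execute_command("DEL", *keys)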