first commit

2026-02-08 14:42:58 +08:00
commit 20e1deae21
8197 changed files with 2264639 additions and 0 deletions


@@ -0,0 +1,88 @@
from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.connection import (
BlockingConnectionPool,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
CrossSlotTransactionError,
DataError,
InvalidPipelineStack,
InvalidResponse,
MaxConnectionsError,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisClusterException,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.sentinel import (
Sentinel,
SentinelConnectionPool,
SentinelManagedConnection,
SentinelManagedSSLConnection,
)
from redis.utils import from_url
def int_or_str(value):
try:
return int(value)
except ValueError:
return value
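# --- Editorial example (not part of this commit): int_or_str keeps
# non-numeric version segments as strings, so a hypothetical pre-release
# string such as "7.0.1rc1" still produces a usable VERSION tuple. ---
def _example_version_tuple():
    assert tuple(map(int_or_str, "7.0.1".split("."))) == (7, 0, 1)
    assert tuple(map(int_or_str, "7.0.1rc1".split("."))) == (7, 0, "1rc1")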
__version__ = "7.0.1"
VERSION = tuple(map(int_or_str, __version__.split(".")))
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
"BusyLoadingError",
"ChildDeadlockedError",
"Connection",
"ConnectionError",
"ConnectionPool",
"CredentialProvider",
"CrossSlotTransactionError",
"DataError",
"from_url",
"default_backoff",
"InvalidPipelineStack",
"InvalidResponse",
"MaxConnectionsError",
"OutOfMemoryError",
"PubSubError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisClusterException",
"RedisError",
"ResponseError",
"Sentinel",
"SentinelConnectionPool",
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"UsernamePasswordCredentialProvider",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
"WatchError",
]


@@ -0,0 +1,27 @@
from .base import (
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
_AsyncRESPBase,
)
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
from .resp2 import _AsyncRESP2Parser, _RESP2Parser
from .resp3 import _AsyncRESP3Parser, _RESP3Parser
__all__ = [
"AsyncCommandsParser",
"_AsyncHiredisParser",
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"AsyncPushNotificationsParser",
"CommandsParser",
"Encoder",
"BaseParser",
"_HiredisParser",
"_RESP2Parser",
"_RESP3Parser",
"PushNotificationsParser",
]


@@ -0,0 +1,474 @@
import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Awaitable, Callable, List, Optional, Protocol, Union
from redis.maint_notifications import (
MaintenanceNotification,
NodeFailedOverNotification,
NodeFailingOverNotification,
NodeMigratedNotification,
NodeMigratingNotification,
NodeMovingNotification,
)
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
from ..exceptions import (
AskError,
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ClusterCrossSlotError,
ClusterDownError,
ConnectionError,
ExecAbortError,
ExternalAuthProviderError,
MasterDownError,
ModuleError,
MovedError,
NoPermissionError,
NoScriptError,
OutOfMemoryError,
ReadOnlyError,
RedisError,
ResponseError,
TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
"Error unloading module: the module "
"exports one or more module-side data "
"types, can't unload"
)
# the user sent an AUTH command to a server without authentication configured
NO_AUTH_SET_ERROR = {
# Redis >= 6.0
"AUTH <password> called without any password "
"configured for the default user. Are you sure "
"your configuration is correct?": AuthenticationError,
# Redis < 6.0
"Client sent AUTH, but no password is set": AuthenticationError,
}
EXTERNAL_AUTH_PROVIDER_ERROR = {
"problem with LDAP service": ExternalAuthProviderError,
}
logger = logging.getLogger(__name__)
class BaseParser(ABC):
EXCEPTION_CLASSES = {
"ERR": {
"max number of clients reached": ConnectionError,
"invalid password": AuthenticationError,
# some Redis server versions report invalid command syntax
# in lowercase
"wrong number of arguments "
"for 'auth' command": AuthenticationWrongNumberOfArgsError,
# some Redis server versions report invalid command syntax
# in uppercase
"wrong number of arguments "
"for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
MODULE_LOAD_ERROR: ModuleError,
MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
NO_SUCH_MODULE_ERROR: ModuleError,
MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
**NO_AUTH_SET_ERROR,
**EXTERNAL_AUTH_PROVIDER_ERROR,
},
"OOM": OutOfMemoryError,
"WRONGPASS": AuthenticationError,
"EXECABORT": ExecAbortError,
"LOADING": BusyLoadingError,
"NOSCRIPT": NoScriptError,
"READONLY": ReadOnlyError,
"NOAUTH": AuthenticationError,
"NOPERM": NoPermissionError,
"ASK": AskError,
"TRYAGAIN": TryAgainError,
"MOVED": MovedError,
"CLUSTERDOWN": ClusterDownError,
"CROSSSLOT": ClusterCrossSlotError,
"MASTERDOWN": MasterDownError,
}
@classmethod
def parse_error(cls, response):
"Parse an error response"
error_code = response.split(" ")[0]
if error_code in cls.EXCEPTION_CLASSES:
response = response[len(error_code) + 1 :]
exception_class = cls.EXCEPTION_CLASSES[error_code]
if isinstance(exception_class, dict):
exception_class = exception_class.get(response, ResponseError)
return exception_class(response)
return ResponseError(response)
def on_disconnect(self):
raise NotImplementedError()
def on_connect(self, connection):
raise NotImplementedError()
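# --- Editorial example (not part of this commit): parse_error picks the
# exception class from the leading token of the error string; "ERR" entries
# are further matched on the full message text. The sample strings below
# are hypothetical. ---
def _example_parse_error():
    err = BaseParser.parse_error("MOVED 3999 127.0.0.1:6381")
    assert isinstance(err, MovedError)
    err = BaseParser.parse_error("ERR unknown command 'FOO'")
    assert type(err) is ResponseError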
class _RESPBase(BaseParser):
"""Base class for sync-based resp parsing"""
def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
self.encoder = None
self._sock = None
self._buffer = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def on_connect(self, connection):
"Called when the socket connects"
self._sock = connection._sock
self._buffer = SocketBuffer(
self._sock, self.socket_read_size, connection.socket_timeout
)
self.encoder = connection.encoder
def on_disconnect(self):
"Called when the socket disconnects"
self._sock = None
if self._buffer is not None:
self._buffer.close()
self._buffer = None
self.encoder = None
def can_read(self, timeout):
return self._buffer and self._buffer.can_read(timeout)
class AsyncBaseParser(BaseParser):
"""Base parsing class for the python-backed async parser"""
__slots__ = "_stream", "_read_size"
def __init__(self, socket_read_size: int):
self._stream: Optional[StreamReader] = None
self._read_size = socket_read_size
async def can_read_destructive(self) -> bool:
raise NotImplementedError()
async def read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
raise NotImplementedError()
class MaintenanceNotificationsParser:
"""Protocol defining maintenance push notification parsing functionality"""
@staticmethod
def parse_maintenance_start_msg(response, notification_type):
# Expected message format is: <notification_type> <seq_number> <time>
id = response[1]
ttl = response[2]
return notification_type(id, ttl)
@staticmethod
def parse_maintenance_completed_msg(response, notification_type):
# Expected message format is: <notification_type> <seq_number>
id = response[1]
return notification_type(id)
@staticmethod
def parse_moving_msg(response):
# Expected message format is: MOVING <seq_number> <time> <endpoint>
id = response[1]
ttl = response[2]
if response[3] is None:
host, port = None, None
else:
value = response[3]
if isinstance(value, bytes):
value = value.decode()
host, port = value.split(":")
port = int(port) if port is not None else None
return NodeMovingNotification(id, host, port, ttl)
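# --- Editorial example (not part of this commit): a MOVING push frame is
# turned into a NodeMovingNotification; the sequence number, TTL and
# endpoint below are hypothetical. ---
def _example_parse_moving_msg():
    notification = MaintenanceNotificationsParser.parse_moving_msg(
        [b"MOVING", 12, 30, b"10.0.0.5:6379"]
    )
    # parsed as id=12, host="10.0.0.5", port=6379, ttl=30
    return notification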
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"
_MAINTENANCE_MESSAGES = (
_MIGRATING_MESSAGE,
_MIGRATED_MESSAGE,
_FAILING_OVER_MESSAGE,
_FAILED_OVER_MESSAGE,
)
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
str, tuple[type[MaintenanceNotification], Callable]
] = {
_MIGRATING_MESSAGE: (
NodeMigratingNotification,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_MIGRATED_MESSAGE: (
NodeMigratedNotification,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_FAILING_OVER_MESSAGE: (
NodeFailingOverNotification,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_FAILED_OVER_MESSAGE: (
NodeFailedOverNotification,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_MOVING_MESSAGE: (
NodeMovingNotification,
MaintenanceNotificationsParser.parse_moving_msg,
),
}
class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable] = None
maintenance_push_handler_func: Optional[Callable] = None
def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()
def handle_push_response(self, response, **kwargs):
msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()
if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return self.pubsub_push_handler_func(response)
try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return self.invalidation_push_handler_func(response)
if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification = parser_function(response)
return self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][0]
notification = parser_function(response, notification_type)
if notification is not None:
return self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)
return None
def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func
def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func
def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func
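# --- Editorial example (not part of this commit): a minimal sketch of how a
# parser mixing in PushNotificationsParser dispatches push frames. Frames
# whose type is not "invalidate", "MOVING" or a maintenance message fall
# through to the pubsub handler. The _DemoParser class is hypothetical. ---
def _example_push_dispatch():
    class _DemoParser(PushNotificationsParser):
        def __init__(self):
            self.pubsub_push_handler_func = lambda r: ("pubsub", r)
            self.invalidation_push_handler_func = lambda r: ("invalidate", r[1])
            self.node_moving_push_handler_func = None
            self.maintenance_push_handler_func = None

    parser = _DemoParser()
    assert parser.handle_push_response([b"invalidate", [b"key"]]) == (
        "invalidate",
        [b"key"],
    )
    assert parser.handle_push_response([b"message", b"chan", b"hi"])[0] == "pubsub"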
class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()
async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()
if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return await self.pubsub_push_handler_func(response)
try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return await self.invalidation_push_handler_func(response)
if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification = parser_function(response)
return await self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][0]
notification = parser_function(response, notification_type)
if notification is not None:
return await self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)
return None
def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func
def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func
def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func
class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""
__slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
def __init__(self, socket_read_size: int):
super().__init__(socket_read_size)
self.encoder: Optional[Encoder] = None
self._buffer = b""
self._chunks = []
self._pos = 0
def _clear(self):
self._buffer = b""
self._chunks.clear()
def on_connect(self, connection):
"""Called when the stream connects"""
self._stream = connection._reader
if self._stream is None:
raise RedisError("Buffer is closed.")
self.encoder = connection.encoder
self._clear()
self._connected = True
def on_disconnect(self):
"""Called when the stream disconnects"""
self._connected = False
async def can_read_destructive(self) -> bool:
if not self._connected:
raise RedisError("Buffer is closed.")
if self._buffer:
return True
try:
async with async_timeout(0):
return self._stream.at_eof()
except TimeoutError:
return False
async def _read(self, length: int) -> bytes:
"""
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
want = length + 2
end = self._pos + want
if len(self._buffer) >= end:
result = self._buffer[self._pos : end - 2]
else:
tail = self._buffer[self._pos :]
try:
data = await self._stream.readexactly(want - len(tail))
except IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += want
return result
async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
found = self._buffer.find(b"\r\n", self._pos)
if found >= 0:
result = self._buffer[self._pos : found]
else:
tail = self._buffer[self._pos :]
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
result = (tail + data)[:-2]
self._chunks.append(data)
self._pos += len(result) + 2
return result
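# --- Editorial example (not part of this commit): the two primitives above
# implement RESP framing, where each line and each bulk payload ends in CRLF.
# A minimal sketch against a locally fed StreamReader; _FakeConnection is a
# hypothetical stand-in for a redis connection object. ---
def _example_resp_framing():
    import asyncio

    class _FakeConnection:
        def __init__(self, reader):
            self._reader = reader
            self.encoder = None

    async def demo():
        reader = StreamReader()
        reader.feed_data(b"$5\r\nhello\r\n")
        reader.feed_eof()
        parser = _AsyncRESPBase(socket_read_size=65536)
        parser.on_connect(_FakeConnection(reader))
        header = await parser._readline()  # b"$5" (bulk string length header)
        body = await parser._read(5)  # b"hello", CRLF terminator discarded
        return header, body

    return asyncio.run(demo())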


@@ -0,0 +1,281 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from redis.exceptions import RedisError, ResponseError
from redis.utils import str_if_bytes
if TYPE_CHECKING:
from redis.asyncio.cluster import ClusterNode
class AbstractCommandsParser:
def _get_pubsub_keys(self, *args):
"""
Get the keys from a pubsub command.
Although PubSub commands have predetermined key locations, they are not
reported in the COMMAND command's output, so the key positions are
hardcoded in this method.
"""
if len(args) < 2:
# The command has no keys in it
return None
args = [str_if_bytes(arg) for arg in args]
command = args[0].upper()
keys = None
if command == "PUBSUB":
# the second argument is a part of the command name, e.g.
# ['PUBSUB', 'NUMSUB', 'foo'].
pubsub_type = args[1].upper()
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
keys = args[2:]
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
# format example:
# SUBSCRIBE channel [channel ...]
keys = list(args[1:])
elif command in ["PUBLISH", "SPUBLISH"]:
# format example:
# PUBLISH channel message
keys = [args[1]]
return keys
def parse_subcommand(self, command, **options):
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = int(command[1])
cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
return cmd_dict
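# --- Editorial example (not part of this commit): pubsub key extraction is
# driven entirely by the hardcoded positions above, not by COMMAND output. ---
def _example_pubsub_keys():
    parser = AbstractCommandsParser()
    assert parser._get_pubsub_keys("SUBSCRIBE", "news", "sports") == ["news", "sports"]
    assert parser._get_pubsub_keys("PUBLISH", "news", "hello") == ["news"]
    assert parser._get_pubsub_keys("PUBSUB", "NUMSUB", "news") == ["news"]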
class CommandsParser(AbstractCommandsParser):
"""
Parses Redis commands to get command keys.
COMMAND output is used to determine key locations.
Commands that do not have a predefined key location are flagged with
'movablekeys', and these commands' keys are determined by the command
'COMMAND GETKEYS'.
"""
def __init__(self, redis_connection):
self.commands = {}
self.initialize(redis_connection)
def initialize(self, r):
commands = r.command()
uppercase_commands = []
for cmd in commands:
if any(x.isupper() for x in cmd):
uppercase_commands.append(cmd)
for cmd in uppercase_commands:
commands[cmd.lower()] = commands.pop(cmd)
self.commands = commands
# As soon as this PR is merged into Redis, we should reimplement
# our logic to use COMMAND INFO changes to determine the key positions
# https://github.com/redis/redis/pull/8324
def get_keys(self, redis_conn, *args):
"""
Get the keys from the passed command.
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
if len(args) < 2:
# The command has no keys in it
return None
cmd_name = args[0].lower()
if cmd_name not in self.commands:
# try to split the command name and take only the main command,
# e.g. 'memory' for 'memory usage'
cmd_name_split = cmd_name.split()
cmd_name = cmd_name_split[0]
if cmd_name in self.commands:
# save the split command to args
args = cmd_name_split + list(args[1:])
else:
# We'll try to reinitialize the commands cache: if the engine
# version has changed, the commands may not be current
self.initialize(redis_conn)
if cmd_name not in self.commands:
raise RedisError(
f"{cmd_name.upper()} command doesn't exist in Redis commands"
)
command = self.commands.get(cmd_name)
if "movablekeys" in command["flags"]:
keys = self._get_moveable_keys(redis_conn, *args)
elif "pubsub" in command["flags"] or command["name"] == "pubsub":
keys = self._get_pubsub_keys(*args)
else:
if (
command["step_count"] == 0
and command["first_key_pos"] == 0
and command["last_key_pos"] == 0
):
is_subcmd = False
if "subcommands" in command:
subcmd_name = f"{cmd_name}|{args[1].lower()}"
for subcmd in command["subcommands"]:
if str_if_bytes(subcmd[0]) == subcmd_name:
command = self.parse_subcommand(subcmd)
is_subcmd = True
# The command doesn't have keys in it
if not is_subcmd:
return None
last_key_pos = command["last_key_pos"]
if last_key_pos < 0:
last_key_pos = len(args) - abs(last_key_pos)
keys_pos = list(
range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
)
keys = [args[pos] for pos in keys_pos]
return keys
def _get_moveable_keys(self, redis_conn, *args):
"""
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
# The command name should be split into separate arguments,
# e.g. 'MEMORY USAGE' will be split into ['MEMORY', 'USAGE']
pieces = args[0].split() + list(args[1:])
try:
keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces)
except ResponseError as e:
message = e.__str__()
if (
"Invalid arguments" in message
or "The command has no key arguments" in message
):
return None
else:
raise e
return keys
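# --- Editorial example (not part of this commit): a sketch of typical
# CommandsParser usage, assuming a Redis server reachable on localhost:6379
# (hypothetical). Commands with fixed key positions are resolved locally;
# 'movablekeys' commands trigger a COMMAND GETKEYS round-trip. ---
def _example_commands_parser():
    from redis import Redis  # local import to keep the sketch self-contained

    r = Redis(host="localhost", port=6379)
    parser = CommandsParser(r)
    assert parser.get_keys(r, "SET", "foo", "bar") == ["foo"]
    assert parser.get_keys(r, "MSET", "k1", "v1", "k2", "v2") == ["k1", "k2"]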
class AsyncCommandsParser(AbstractCommandsParser):
"""
Parses Redis commands to get command keys.
COMMAND output is used to determine key locations.
Commands that do not have a predefined key location are flagged with 'movablekeys',
and these commands' keys are determined by the command 'COMMAND GETKEYS'.
NOTE: Due to a bug in redis<7.0, this does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this with EVAL or EVALSHA.
"""
__slots__ = ("commands", "node")
def __init__(self) -> None:
self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}
async def initialize(self, node: Optional["ClusterNode"] = None) -> None:
if node:
self.node = node
commands = await self.node.execute_command("COMMAND")
self.commands = {cmd.lower(): command for cmd, command in commands.items()}
# As soon as this PR is merged into Redis, we should reimplement
# our logic to use COMMAND INFO changes to determine the key positions
# https://github.com/redis/redis/pull/8324
async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
"""
Get the keys from the passed command.
NOTE: Due to a bug in redis<7.0, this function does not work properly
for EVAL or EVALSHA when the `numkeys` arg is 0.
- issue: https://github.com/redis/redis/issues/9493
- fix: https://github.com/redis/redis/pull/9733
So, don't use this function with EVAL or EVALSHA.
"""
if len(args) < 2:
# The command has no keys in it
return None
cmd_name = args[0].lower()
if cmd_name not in self.commands:
# try to split the command name and take only the main command,
# e.g. 'memory' for 'memory usage'
cmd_name_split = cmd_name.split()
cmd_name = cmd_name_split[0]
if cmd_name in self.commands:
# save the split command to args
args = cmd_name_split + list(args[1:])
else:
# We'll try to reinitialize the commands cache: if the engine
# version has changed, the commands may not be current
await self.initialize()
if cmd_name not in self.commands:
raise RedisError(
f"{cmd_name.upper()} command doesn't exist in Redis commands"
)
command = self.commands.get(cmd_name)
if "movablekeys" in command["flags"]:
keys = await self._get_moveable_keys(*args)
elif "pubsub" in command["flags"] or command["name"] == "pubsub":
keys = self._get_pubsub_keys(*args)
else:
if (
command["step_count"] == 0
and command["first_key_pos"] == 0
and command["last_key_pos"] == 0
):
is_subcmd = False
if "subcommands" in command:
subcmd_name = f"{cmd_name}|{args[1].lower()}"
for subcmd in command["subcommands"]:
if str_if_bytes(subcmd[0]) == subcmd_name:
command = self.parse_subcommand(subcmd)
is_subcmd = True
# The command doesn't have keys in it
if not is_subcmd:
return None
last_key_pos = command["last_key_pos"]
if last_key_pos < 0:
last_key_pos = len(args) - abs(last_key_pos)
keys_pos = list(
range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
)
keys = [args[pos] for pos in keys_pos]
return keys
async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
try:
keys = await self.node.execute_command("COMMAND GETKEYS", *args)
except ResponseError as e:
message = e.__str__()
if (
"Invalid arguments" in message
or "The command has no key arguments" in message
):
return None
else:
raise e
return keys


@@ -0,0 +1,44 @@
from ..exceptions import DataError
class Encoder:
"Encode strings to bytes-like and decode bytes-like to strings"
__slots__ = "encoding", "encoding_errors", "decode_responses"
def __init__(self, encoding, encoding_errors, decode_responses):
self.encoding = encoding
self.encoding_errors = encoding_errors
self.decode_responses = decode_responses
def encode(self, value):
"Return a bytestring or bytes-like representation of the value"
if isinstance(value, (bytes, memoryview)):
return value
elif isinstance(value, bool):
# special case bool since it is a subclass of int
raise DataError(
"Invalid input of type: 'bool'. Convert to a "
"bytes, string, int or float first."
)
elif isinstance(value, (int, float)):
value = repr(value).encode()
elif not isinstance(value, str):
# a value we don't know how to deal with; raise an error
typename = type(value).__name__
raise DataError(
f"Invalid input of type: '{typename}'. "
f"Convert to a bytes, string, int or float first."
)
if isinstance(value, str):
value = value.encode(self.encoding, self.encoding_errors)
return value
def decode(self, value, force=False):
"Return a unicode string from the bytes-like representation"
if self.decode_responses or force:
if isinstance(value, memoryview):
value = value.tobytes()
if isinstance(value, bytes):
value = value.decode(self.encoding, self.encoding_errors)
return value
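# --- Editorial example (not part of this commit): Encoder round-trips values
# to bytes and back; bools are rejected explicitly because bool is a subclass
# of int and would otherwise encode ambiguously. ---
def _example_encoder():
    enc = Encoder(encoding="utf-8", encoding_errors="strict", decode_responses=True)
    assert enc.encode("hello") == b"hello"
    assert enc.encode(3.14) == b"3.14"
    assert enc.decode(b"ok") == "ok"
    try:
        enc.encode(True)
    except DataError:
        pass  # bools are not accepted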


@@ -0,0 +1,941 @@
import datetime
from redis.utils import str_if_bytes
def timestamp_to_datetime(response):
"Converts a unix timestamp to a Python datetime object"
if not response:
return None
try:
response = int(response)
except ValueError:
return None
return datetime.datetime.fromtimestamp(response)
def parse_debug_object(response):
"Parse the results of Redis's DEBUG OBJECT command into a Python dict"
# The 'type' of the object is the first item in the response, but isn't
# prefixed with a name
response = str_if_bytes(response)
response = "type:" + response
response = dict(kv.split(":") for kv in response.split())
# parse some expected int values from the string response
# note: this cmd isn't spec'd so these may not appear in all redis versions
int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle")
for field in int_fields:
if field in response:
response[field] = int(response[field])
return response
def parse_info(response):
"""Parse the result of Redis's INFO command into a Python dict"""
info = {}
response = str_if_bytes(response)
def get_value(value):
if "," not in value and "=" not in value:
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
return value
elif "=" not in value:
return [get_value(v) for v in value.split(",") if v]
else:
sub_dict = {}
for item in value.split(","):
if not item:
continue
if "=" in item:
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
else:
sub_dict[item] = True
return sub_dict
for line in response.splitlines():
if line and not line.startswith("#"):
if line.find(":") != -1:
# Split the info fields into keys and values.
# Note that the value may contain ':', but the 'host:'
# pseudo-command is the only case where the key contains ':'
key, value = line.split(":", 1)
if key == "cmdstat_host":
key, value = line.rsplit(":", 1)
if key == "module":
# Hardcode a list for the 'modules' key since there could be
# multiple lines that start with 'module'
info.setdefault("modules", []).append(get_value(value))
else:
info[key] = get_value(value)
else:
# if the line isn't splittable, append it to the "__raw__" key
info.setdefault("__raw__", []).append(line)
return info
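# --- Editorial example (not part of this commit): what parse_info produces
# for a small, hypothetical fragment of INFO output. Comment lines are
# skipped and values are coerced with get_value. ---
def _example_parse_info():
    sample = (
        "# Server\r\n"
        "redis_version:7.4.0\r\n"
        "config_file:\r\n"
        "db0:keys=2,expires=0,avg_ttl=0\r\n"
    )
    info = parse_info(sample)
    assert info["redis_version"] == "7.4.0"
    assert info["config_file"] == ""
    assert info["db0"] == {"keys": 2, "expires": 0, "avg_ttl": 0}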
def parse_memory_stats(response, **kwargs):
"""Parse the results of MEMORY STATS"""
stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
for key, value in stats.items():
if key.startswith("db.") and isinstance(value, list):
stats[key] = pairs_to_dict(
value, decode_keys=True, decode_string_values=True
)
return stats
SENTINEL_STATE_TYPES = {
"can-failover-its-master": int,
"config-epoch": int,
"down-after-milliseconds": int,
"failover-timeout": int,
"info-refresh": int,
"last-hello-message": int,
"last-ok-ping-reply": int,
"last-ping-reply": int,
"last-ping-sent": int,
"master-link-down-time": int,
"master-port": int,
"num-other-sentinels": int,
"num-slaves": int,
"o-down-time": int,
"pending-commands": int,
"parallel-syncs": int,
"port": int,
"quorum": int,
"role-reported-time": int,
"s-down-time": int,
"slave-priority": int,
"slave-repl-offset": int,
"voted-leader-epoch": int,
}
def parse_sentinel_state(item):
result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES)
flags = set(result["flags"].split(","))
for name, flag in (
("is_master", "master"),
("is_slave", "slave"),
("is_sdown", "s_down"),
("is_odown", "o_down"),
("is_sentinel", "sentinel"),
("is_disconnected", "disconnected"),
("is_master_down", "master_down"),
):
result[name] = flag in flags
return result
def parse_sentinel_master(response):
return parse_sentinel_state(map(str_if_bytes, response))
def parse_sentinel_state_resp3(response):
result = {}
for key in response:
try:
value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key]))
result[str_if_bytes(key)] = value
except Exception:
result[str_if_bytes(key)] = response[str_if_bytes(key)]
flags = set(result["flags"].split(","))
result["flags"] = flags
return result
def parse_sentinel_masters(response):
result = {}
for item in response:
state = parse_sentinel_state(map(str_if_bytes, item))
result[state["name"]] = state
return result
def parse_sentinel_masters_resp3(response):
return [parse_sentinel_state(master) for master in response]
def parse_sentinel_slaves_and_sentinels(response):
return [parse_sentinel_state(map(str_if_bytes, item)) for item in response]
def parse_sentinel_slaves_and_sentinels_resp3(response):
return [parse_sentinel_state_resp3(item) for item in response]
def parse_sentinel_get_master(response):
return response and (response[0], int(response[1])) or None
def pairs_to_dict(response, decode_keys=False, decode_string_values=False):
"""Create a dict given a list of key/value pairs"""
if response is None:
return {}
if decode_keys or decode_string_values:
# the iter form is faster, but I don't know how to make that work
# with a str_if_bytes() map
keys = response[::2]
if decode_keys:
keys = map(str_if_bytes, keys)
values = response[1::2]
if decode_string_values:
values = map(str_if_bytes, values)
return dict(zip(keys, values))
else:
it = iter(response)
return dict(zip(it, it))
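# --- Editorial example (not part of this commit): pairs_to_dict turns a flat
# reply of alternating keys and values into a dict, optionally decoding the
# keys (and values) from bytes to str. ---
def _example_pairs_to_dict():
    flat = [b"name", b"redis", b"port", b"6379"]
    assert pairs_to_dict(flat) == {b"name": b"redis", b"port": b"6379"}
    assert pairs_to_dict(flat, decode_keys=True) == {"name": b"redis", "port": b"6379"}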
def pairs_to_dict_typed(response, type_info):
it = iter(response)
result = {}
for key, value in zip(it, it):
if key in type_info:
try:
value = type_info[key](value)
except Exception:
# if for some reason the value can't be coerced, just use
# the string value
pass
result[key] = value
return result
def zset_score_pairs(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a list of (value, score) pairs
"""
if not response or not options.get("withscores"):
return response
score_cast_func = options.get("score_cast_func", float)
it = iter(response)
return list(zip(it, map(score_cast_func, it)))
def zset_score_for_rank(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a [value, score] pair
"""
if not response or not options.get("withscore"):
return response
score_cast_func = options.get("score_cast_func", float)
return [response[0], score_cast_func(response[1])]
def zset_score_pairs_resp3(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a list of [value, score] pairs
"""
if not response or not options.get("withscores"):
return response
score_cast_func = options.get("score_cast_func", float)
return [[name, score_cast_func(val)] for name, val in response]
def zset_score_for_rank_resp3(response, **options):
"""
If ``withscores`` is specified in the options, return the response as
a [value, score] pair
"""
if not response or not options.get("withscore"):
return response
score_cast_func = options.get("score_cast_func", float)
return [response[0], score_cast_func(response[1])]
def sort_return_tuples(response, **options):
"""
If ``groups`` is specified, return the response as a list of
n-element tuples with n being the value found in options['groups']
"""
if not response or not options.get("groups"):
return response
n = options["groups"]
return list(zip(*[response[i::n] for i in range(n)]))
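# --- Editorial example (not part of this commit): with groups=2 a flat
# SORT ... GET reply is regrouped into 2-element tuples. ---
def _example_sort_return_tuples():
    flat = [b"user:1", b"10", b"user:2", b"20"]
    assert sort_return_tuples(flat, groups=2) == [
        (b"user:1", b"10"),
        (b"user:2", b"20"),
    ]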
def parse_stream_list(response):
if response is None:
return None
data = []
for r in response:
if r is not None:
data.append((r[0], pairs_to_dict(r[1])))
else:
data.append((None, None))
return data
def pairs_to_dict_with_str_keys(response):
return pairs_to_dict(response, decode_keys=True)
def parse_list_of_dicts(response):
return list(map(pairs_to_dict_with_str_keys, response))
def parse_xclaim(response, **options):
if options.get("parse_justid", False):
return response
return parse_stream_list(response)
def parse_xautoclaim(response, **options):
if options.get("parse_justid", False):
return response[1]
response[1] = parse_stream_list(response[1])
return response
def parse_xinfo_stream(response, **options):
if isinstance(response, list):
data = pairs_to_dict(response, decode_keys=True)
else:
data = {str_if_bytes(k): v for k, v in response.items()}
if not options.get("full", False):
first = data.get("first-entry")
if first is not None and first[0] is not None:
data["first-entry"] = (first[0], pairs_to_dict(first[1]))
last = data["last-entry"]
if last is not None and last[0] is not None:
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
for g in data["groups"]:
if g["consumers"] and g["consumers"][0] is not None:
g["consumers"] = [
pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
for group in data["groups"]
]
return data
def parse_xread(response):
if response is None:
return []
return [[r[0], parse_stream_list(r[1])] for r in response]
def parse_xread_resp3(response):
if response is None:
return {}
return {key: [parse_stream_list(value)] for key, value in response.items()}
def parse_xpending(response, **options):
if options.get("parse_detail", False):
return parse_xpending_range(response)
consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []]
return {
"pending": response[0],
"min": response[1],
"max": response[2],
"consumers": consumers,
}
def parse_xpending_range(response):
k = ("message_id", "consumer", "time_since_delivered", "times_delivered")
return [dict(zip(k, r)) for r in response]
def float_or_none(response):
if response is None:
return None
return float(response)
def bool_ok(response, **options):
return str_if_bytes(response) == "OK"
def parse_zadd(response, **options):
if response is None:
return None
if options.get("as_score"):
return float(response)
return int(response)
def parse_client_list(response, **options):
clients = []
for c in str_if_bytes(response).splitlines():
client_dict = {}
tokens = c.split(" ")
last_key = None
for token in tokens:
if "=" in token:
# Values might contain '='
key, value = token.split("=", 1)
client_dict[key] = value
last_key = key
else:
# Values may include spaces: for instance, when Redis runs on a Unix
# socket whose path contains a space (e.g. "/tmp/redis sock/redis.sock"),
# the addr or laddr field will contain a space.
client_dict[last_key] += " " + token
if client_dict:
clients.append(client_dict)
return clients
def parse_config_get(response, **options):
response = [str_if_bytes(i) if i is not None else None for i in response]
return response and pairs_to_dict(response) or {}
def parse_scan(response, **options):
cursor, r = response
return int(cursor), r
def parse_hscan(response, **options):
cursor, r = response
no_values = options.get("no_values", False)
if no_values:
payload = r or []
else:
payload = r and pairs_to_dict(r) or {}
return int(cursor), payload
def parse_zscan(response, **options):
score_cast_func = options.get("score_cast_func", float)
cursor, r = response
it = iter(r)
return int(cursor), list(zip(it, map(score_cast_func, it)))
def parse_zmscore(response, **options):
# zmscore: list of scores (double precision floating point number) or nil
return [float(score) if score is not None else None for score in response]
def parse_slowlog_get(response, **options):
space = " " if options.get("decode_responses", False) else b" "
def parse_item(item):
result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])}
# Redis Enterprise injects another entry at index [3], which has
# the complexity info (i.e. the value N in case the command has
# an O(N) complexity) instead of the command.
if isinstance(item[3], list):
result["command"] = space.join(item[3])
# These fields are optional, depending on the environment.
if len(item) >= 6:
result["client_address"] = item[4]
result["client_name"] = item[5]
else:
result["complexity"] = item[3]
result["command"] = space.join(item[4])
# These fields are optional, depending on the environment.
if len(item) >= 7:
result["client_address"] = item[5]
result["client_name"] = item[6]
return result
return [parse_item(item) for item in response]
def parse_stralgo(response, **options):
"""
Parse the response from the `STRALGO` command.
Without modifiers the returned value is a string.
When LEN is given, the command returns the length of the result
(i.e. an integer).
When IDX is given, the command returns a dictionary with the LCS
length and all the ranges in both strings (start and end
offsets for each string) where there are matches.
When WITHMATCHLEN is given, each array representing a match will
also have the length of the match at the beginning of the array.
"""
if options.get("len", False):
return int(response)
if options.get("idx", False):
if options.get("withmatchlen", False):
matches = [
[(int(match[-1]))] + list(map(tuple, match[:-1]))
for match in response[1]
]
else:
matches = [list(map(tuple, match)) for match in response[1]]
return {
str_if_bytes(response[0]): matches,
str_if_bytes(response[2]): int(response[3]),
}
return str_if_bytes(response)
def parse_cluster_info(response, **options):
response = str_if_bytes(response)
return dict(line.split(":") for line in response.splitlines() if line)
def _parse_node_line(line):
line_items = line.split(" ")
node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
ip = addr.split("@")[0]
hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
node_dict = {
"node_id": node_id,
"hostname": hostname,
"flags": flags,
"master_id": master_id,
"last_ping_sent": ping,
"last_pong_rcvd": pong,
"epoch": epoch,
"slots": [],
"migrations": [],
"connected": True if connected == "connected" else False,
}
if len(line_items) >= 9:
slots, migrations = _parse_slots(line_items[8:])
node_dict["slots"], node_dict["migrations"] = slots, migrations
return ip, node_dict
def _parse_slots(slot_ranges):
slots, migrations = [], []
for s_range in slot_ranges:
if "->-" in s_range:
slot_id, dst_node_id = s_range[1:-1].split("->-", 1)
migrations.append(
{"slot": slot_id, "node_id": dst_node_id, "state": "migrating"}
)
elif "-<-" in s_range:
slot_id, src_node_id = s_range[1:-1].split("-<-", 1)
migrations.append(
{"slot": slot_id, "node_id": src_node_id, "state": "importing"}
)
else:
s_range = [sl for sl in s_range.split("-")]
slots.append(s_range)
return slots, migrations
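# --- Editorial example (not part of this commit): slot ranges are kept as
# [start, end] string pairs, while bracketed entries become migration records.
# The node id below is a hypothetical placeholder. ---
def _example_parse_slots():
    slots, migrations = _parse_slots(["0-5460", "[1234->-0123456789abcdef]"])
    assert slots == [["0", "5460"]]
    assert migrations == [
        {"slot": "1234", "node_id": "0123456789abcdef", "state": "migrating"}
    ]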
def parse_cluster_nodes(response, **options):
"""
@see: https://redis.io/commands/cluster-nodes # string / bytes
@see: https://redis.io/commands/cluster-replicas # list of string / bytes
"""
if isinstance(response, (str, bytes)):
response = response.splitlines()
return dict(_parse_node_line(str_if_bytes(node)) for node in response)
def parse_geosearch_generic(response, **options):
"""
Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER'
commands according to 'withdist', 'withhash' and 'withcoord' labels.
"""
try:
if options["store"] or options["store_dist"]:
# `store` and `store_dist` can't be combined
# with other command arguments.
# Relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER'
return response
except KeyError: # it means the command was sent via execute_command
return response
if not isinstance(response, list):
response_list = [response]
else:
response_list = response
if not options["withdist"] and not options["withcoord"] and not options["withhash"]:
# just a bunch of places
return response_list
cast = {
"withdist": float,
"withcoord": lambda ll: (float(ll[0]), float(ll[1])),
"withhash": int,
}
# zip all output results with each casting function to get
# the properly native Python value.
f = [lambda x: x]
f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]]
return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list]
def parse_command(response, **options):
commands = {}
for command in response:
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = int(command[1])
cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]]
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
commands[cmd_name] = cmd_dict
return commands
def parse_command_resp3(response, **options):
commands = {}
for command in response:
cmd_dict = {}
cmd_name = str_if_bytes(command[0])
cmd_dict["name"] = cmd_name
cmd_dict["arity"] = command[1]
cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]}
cmd_dict["first_key_pos"] = command[3]
cmd_dict["last_key_pos"] = command[4]
cmd_dict["step_count"] = command[5]
cmd_dict["acl_categories"] = command[6]
if len(command) > 7:
cmd_dict["tips"] = command[7]
cmd_dict["key_specifications"] = command[8]
cmd_dict["subcommands"] = command[9]
commands[cmd_name] = cmd_dict
return commands
def parse_pubsub_numsub(response, **options):
return list(zip(response[0::2], response[1::2]))
def parse_client_kill(response, **options):
if isinstance(response, int):
return response
return str_if_bytes(response) == "OK"
def parse_acl_getuser(response, **options):
if response is None:
return None
if isinstance(response, list):
data = pairs_to_dict(response, decode_keys=True)
else:
data = {str_if_bytes(key): value for key, value in response.items()}
# convert everything but user-defined data in 'keys' to native strings
data["flags"] = list(map(str_if_bytes, data["flags"]))
data["passwords"] = list(map(str_if_bytes, data["passwords"]))
data["commands"] = str_if_bytes(data["commands"])
if isinstance(data["keys"], str) or isinstance(data["keys"], bytes):
data["keys"] = list(str_if_bytes(data["keys"]).split(" "))
if data["keys"] == [""]:
data["keys"] = []
if "channels" in data:
if isinstance(data["channels"], str) or isinstance(data["channels"], bytes):
data["channels"] = list(str_if_bytes(data["channels"]).split(" "))
if data["channels"] == [""]:
data["channels"] = []
if "selectors" in data:
if data["selectors"] != [] and isinstance(data["selectors"][0], list):
data["selectors"] = [
list(map(str_if_bytes, selector)) for selector in data["selectors"]
]
elif data["selectors"] != []:
data["selectors"] = [
{str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()}
for selector in data["selectors"]
]
# split 'commands' into separate 'categories' and 'commands' lists
commands, categories = [], []
for command in data["commands"].split(" "):
categories.append(command) if "@" in command else commands.append(command)
data["commands"] = commands
data["categories"] = categories
data["enabled"] = "on" in data["flags"]
return data
def parse_acl_log(response, **options):
if response is None:
return None
if isinstance(response, list):
data = []
for log in response:
log_data = pairs_to_dict(log, True, True)
client_info = log_data.get("client-info", "")
log_data["client-info"] = parse_client_info(client_info)
# float() is lossy compared to the "double" in C
log_data["age-seconds"] = float(log_data["age-seconds"])
data.append(log_data)
else:
data = bool_ok(response)
return data
def parse_client_info(value):
"""
Parse the client-info field of an ACL LOG entry, which has the following
format: "key1=value1 key2=value2 key3=value3"
"""
client_info = {}
for info in str_if_bytes(value).strip().split():
key, value = info.split("=")
client_info[key] = value
# Those fields are defined as int in networking.c
for int_key in {
"id",
"age",
"idle",
"db",
"sub",
"psub",
"multi",
"qbuf",
"qbuf-free",
"obl",
"argv-mem",
"oll",
"omem",
"tot-mem",
}:
if int_key in client_info:
client_info[int_key] = int(client_info[int_key])
return client_info
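# --- Editorial example (not part of this commit): parse_client_info splits a
# "key=value key=value" blob and coerces the known integer fields. The sample
# line below is hypothetical. ---
def _example_parse_client_info():
    info = parse_client_info(b"id=3 addr=127.0.0.1:51234 name= db=0 idle=7")
    assert info["id"] == 3 and info["idle"] == 7
    assert info["addr"] == "127.0.0.1:51234" and info["name"] == ""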
def parse_set_result(response, **options):
"""
Handle the SET result; the GET argument is available since Redis 6.2.
The SET result is parsed into:
- a bool
- a string, when the GET argument is used
"""
if options.get("get"):
# Redis will return a getCommand result.
# See `setGenericCommand` in t_string.c
return response
return response and str_if_bytes(response) == "OK"
def string_keys_to_dict(key_string, callback):
return dict.fromkeys(key_string.split(), callback)
_RedisCallbacks = {
**string_keys_to_dict(
"AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX "
"PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE",
bool,
),
**string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float),
**string_keys_to_dict(
"ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE "
"RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH",
bool_ok,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread),
**string_keys_to_dict(
"GEORADIUS GEORADIUSBYMEMBER GEOSEARCH",
parse_geosearch_generic,
),
**string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list),
"ACL GETUSER": parse_acl_getuser,
"ACL LOAD": bool_ok,
"ACL LOG": parse_acl_log,
"ACL SETUSER": bool_ok,
"ACL SAVE": bool_ok,
"CLIENT INFO": parse_client_info,
"CLIENT KILL": parse_client_kill,
"CLIENT LIST": parse_client_list,
"CLIENT PAUSE": bool_ok,
"CLIENT SETINFO": bool_ok,
"CLIENT SETNAME": bool_ok,
"CLIENT UNBLOCK": bool,
"CLUSTER ADDSLOTS": bool_ok,
"CLUSTER ADDSLOTSRANGE": bool_ok,
"CLUSTER DELSLOTS": bool_ok,
"CLUSTER DELSLOTSRANGE": bool_ok,
"CLUSTER FAILOVER": bool_ok,
"CLUSTER FORGET": bool_ok,
"CLUSTER INFO": parse_cluster_info,
"CLUSTER MEET": bool_ok,
"CLUSTER NODES": parse_cluster_nodes,
"CLUSTER REPLICAS": parse_cluster_nodes,
"CLUSTER REPLICATE": bool_ok,
"CLUSTER RESET": bool_ok,
"CLUSTER SAVECONFIG": bool_ok,
"CLUSTER SET-CONFIG-EPOCH": bool_ok,
"CLUSTER SETSLOT": bool_ok,
"CLUSTER SLAVES": parse_cluster_nodes,
"COMMAND": parse_command,
"CONFIG RESETSTAT": bool_ok,
"CONFIG SET": bool_ok,
"FUNCTION DELETE": bool_ok,
"FUNCTION FLUSH": bool_ok,
"FUNCTION RESTORE": bool_ok,
"GEODIST": float_or_none,
"HSCAN": parse_hscan,
"INFO": parse_info,
"LASTSAVE": timestamp_to_datetime,
"MEMORY PURGE": bool_ok,
"MODULE LOAD": bool,
"MODULE UNLOAD": bool,
"PING": lambda r: str_if_bytes(r) == "PONG",
"PUBSUB NUMSUB": parse_pubsub_numsub,
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
"QUIT": bool_ok,
"SET": parse_set_result,
"SCAN": parse_scan,
"SCRIPT EXISTS": lambda r: list(map(bool, r)),
"SCRIPT FLUSH": bool_ok,
"SCRIPT KILL": bool_ok,
"SCRIPT LOAD": str_if_bytes,
"SENTINEL CKQUORUM": bool_ok,
"SENTINEL FAILOVER": bool_ok,
"SENTINEL FLUSHCONFIG": bool_ok,
"SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master,
"SENTINEL MONITOR": bool_ok,
"SENTINEL RESET": bool_ok,
"SENTINEL REMOVE": bool_ok,
"SENTINEL SET": bool_ok,
"SLOWLOG GET": parse_slowlog_get,
"SLOWLOG RESET": bool_ok,
"SORT": sort_return_tuples,
"SSCAN": parse_scan,
"TIME": lambda x: (int(x[0]), int(x[1])),
"XAUTOCLAIM": parse_xautoclaim,
"XCLAIM": parse_xclaim,
"XGROUP CREATE": bool_ok,
"XGROUP DESTROY": bool,
"XGROUP SETID": bool_ok,
"XINFO STREAM": parse_xinfo_stream,
"XPENDING": parse_xpending,
"ZSCAN": parse_zscan,
}
_RedisCallbacksRESP2 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE "
"ZREVRANGEBYSCORE ZUNION",
zset_score_pairs,
),
**string_keys_to_dict(
"ZREVRANK ZRANK",
zset_score_for_rank,
),
**string_keys_to_dict("ZINCRBY ZSCORE", float_or_none),
**string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True),
**string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None),
**string_keys_to_dict(
"BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None
),
"ACL CAT": lambda r: list(map(str_if_bytes, r)),
"ACL GENPASS": str_if_bytes,
"ACL HELP": lambda r: list(map(str_if_bytes, r)),
"ACL LIST": lambda r: list(map(str_if_bytes, r)),
"ACL USERS": lambda r: list(map(str_if_bytes, r)),
"ACL WHOAMI": str_if_bytes,
"CLIENT GETNAME": str_if_bytes,
"CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)),
"CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)),
"COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)),
"CONFIG GET": parse_config_get,
"DEBUG OBJECT": parse_debug_object,
"GEOHASH": lambda r: list(map(str_if_bytes, r)),
"GEOPOS": lambda r: list(
map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r)
),
"HGETALL": lambda r: r and pairs_to_dict(r) or {},
"MEMORY STATS": parse_memory_stats,
"MODULE LIST": lambda r: [pairs_to_dict(m) for m in r],
"RESET": str_if_bytes,
"SENTINEL MASTER": parse_sentinel_master,
"SENTINEL MASTERS": parse_sentinel_masters,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels,
"STRALGO": parse_stralgo,
"XINFO CONSUMERS": parse_list_of_dicts,
"XINFO GROUPS": parse_list_of_dicts,
"ZADD": parse_zadd,
"ZMSCORE": parse_zmscore,
}
_RedisCallbacksRESP3 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZRANGE ZINTER ZPOPMAX ZPOPMIN HGETALL XREADGROUP",
lambda r, **kwargs: r,
),
**string_keys_to_dict(
"ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE ZUNION",
zset_score_pairs_resp3,
),
**string_keys_to_dict(
"ZREVRANK ZRANK",
zset_score_for_rank_resp3,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
"ACL LOG": lambda r: (
[
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
for x in r
]
if isinstance(r, list)
else bool_ok(r)
),
"COMMAND": parse_command_resp3,
"CONFIG GET": lambda r: {
str_if_bytes(key) if key is not None else None: (
str_if_bytes(value) if value is not None else None
)
for key, value in r.items()
},
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
"SENTINEL MASTER": parse_sentinel_state_resp3,
"SENTINEL MASTERS": parse_sentinel_masters_resp3,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
"STRALGO": lambda r, **options: (
{str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
if isinstance(r, dict)
else str_if_bytes(r)
),
"XINFO CONSUMERS": lambda r: [
{str_if_bytes(key): value for key, value in x.items()} for x in r
],
"XINFO GROUPS": lambda r: [
{str_if_bytes(key): value for key, value in d.items()} for d in r
],
}
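# --- Editorial example (not part of this commit): the three dicts above map
# command names to response post-processing callbacks; the client is expected
# to overlay the RESP2 or RESP3 table on the shared one depending on the
# negotiated protocol. ---
def _example_callbacks_lookup():
    callbacks = {**_RedisCallbacks, **_RedisCallbacksRESP2}
    assert callbacks["PING"](b"PONG") is True
    assert callbacks["HGETALL"]([b"k", b"v"]) == {b"k": b"v"}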


@@ -0,0 +1,301 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import (
AsyncBaseParser,
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
)
from .socket import (
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
NONBLOCKING_EXCEPTIONS,
SENTINEL,
SERVER_CLOSED_CONNECTION_ERROR,
)
# Used to signal that hiredis-py does not have enough data to parse.
# Using `False` or `None` is not reliable, given that the parser can
# return `False` or `None` for legitimate reasons from RESP payloads.
NOT_ENOUGH_DATA = object()
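# --- Editorial example (not part of this commit): the sentinel lets
# hiredis.Reader.gets() signal "incomplete reply" unambiguously, even for
# replies that legitimately decode to False or None. Assumes a hiredis
# version that accepts the notEnoughData keyword, as passed below. ---
def _example_not_enough_data():
    import hiredis

    reader = hiredis.Reader(notEnoughData=NOT_ENOUGH_DATA)
    reader.feed(b"$5\r\nhel")  # a partial bulk string
    assert reader.gets() is NOT_ENOUGH_DATA
    reader.feed(b"lo\r\n")
    assert reader.gets() == b"hello"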
class _HiredisReaderArgs(TypedDict, total=False):
protocolError: Callable[[str], Exception]
replyError: Callable[[str], Exception]
encoding: Optional[str]
errors: Optional[str]
class _HiredisParser(BaseParser, PushNotificationsParser):
"Parser class for connections using Hiredis"
def __init__(self, socket_read_size):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
def __del__(self):
try:
self.on_disconnect()
except Exception:
pass
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection, **kwargs):
import hiredis
self._sock = connection._sock
self._socket_timeout = connection.socket_timeout
kwargs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"errors": connection.encoder.encoding_errors,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = NOT_ENOUGH_DATA
try:
self._hiredis_PushNotificationType = hiredis.PushNotification
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
def on_disconnect(self):
self._sock = None
self._reader = None
self._next_response = NOT_ENOUGH_DATA
def can_read(self, timeout):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._next_response is NOT_ENOUGH_DATA:
self._next_response = self._reader.gets()
if self._next_response is NOT_ENOUGH_DATA:
return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True
def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
sock = self._sock
custom_timeout = timeout is not SENTINEL
try:
if custom_timeout:
sock.settimeout(timeout)
bufflen = self._sock.recv_into(self._buffer)
if bufflen == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
self._reader.feed(self._buffer, 0, bufflen)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except socket.timeout:
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
if custom_timeout:
sock.settimeout(self._socket_timeout)
def read_response(self, disable_decoding=False, push_request=False):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
# _next_response might be cached from a can_read() call
if self._next_response is not NOT_ENOUGH_DATA:
response = self._next_response
self._next_response = NOT_ENOUGH_DATA
if self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
# if this is a push request return the push response
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
return response
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
elif (
isinstance(response, list)
and response
and isinstance(response[0], ConnectionError)
):
raise response[0]
return response
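# Note on the push-notification handling above: when a RESP3 push message
# arrives while an ordinary reply is expected (push_request=False), the parser
# hands it to handle_push_response() and then recurses into read_response() so
# the caller still receives the reply it originally asked for.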
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
"""Async implementation of parser class for connections using Hiredis"""
__slots__ = ("_reader",)
def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader = None
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection):
import hiredis
self._stream = connection._reader
kwargs: _HiredisReaderArgs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
kwargs["errors"] = connection.encoder.encoding_errors
self._reader = hiredis.Reader(**kwargs)
self._connected = True
        # hiredis < 3.2 does not expose PushNotification
        self._hiredis_PushNotificationType = getattr(
            hiredis, "PushNotification", None
        )
def on_disconnect(self):
self._connected = False
async def can_read_destructive(self):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets() is not NOT_ENOUGH_DATA:
return True
try:
async with async_timeout(0):
return await self.read_from_socket()
except asyncio.TimeoutError:
return False
async def read_from_socket(self):
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, List[EncodableT]]:
# If `on_disconnect()` has been called, prohibit any more reads
# even if they could happen because data might be present.
# We still allow reads in progress to finish
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
await self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = await self.handle_push_response(response)
if not push_request:
return await self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
and isinstance(response[0], ConnectionError)
):
raise response[0]
return response

View File

@@ -0,0 +1,132 @@
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP2Parser(_RESPBase):
"""RESP2 protocol implementation"""
def read_response(self, disable_decoding=False):
pos = self._buffer.get_pos() if self._buffer else None
try:
result = self._read_response(disable_decoding=disable_decoding)
except BaseException:
if self._buffer:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result
def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
byte, response = raw[:1], raw[1:]
# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# int value
elif byte == b":":
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
response = self._buffer.read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
response = [
self._read_response(disable_decoding=disable_decoding)
for i in range(int(response))
]
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if disable_decoding is False:
response = self.encoder.decode(response)
return response
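# Illustrative example of the RESP2 framing handled above (a worked example,
# not part of the parser itself): the reply b"*2\r\n$3\r\nfoo\r\n:42\r\n" is an
# array of two elements -- a bulk string and an integer -- and parses to
# [b"foo", 42] (or ["foo", 42] when decode_responses is enabled on the encoder).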
class _AsyncRESP2Parser(_AsyncRESPBase):
"""Async class for the RESP2 protocol"""
async def read_response(self, disable_decoding: bool = False):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(disable_decoding=disable_decoding)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# int value
elif byte == b":":
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
response = await self._read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
response = [
(await self._read_response(disable_decoding))
for _ in range(int(response)) # noqa
]
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if disable_decoding is False:
response = self.encoder.decode(response)
return response

View File

@@ -0,0 +1,263 @@
from logging import getLogger
from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP3Parser(_RESPBase, PushNotificationsParser):
"""RESP3 protocol implementation"""
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def read_response(self, disable_decoding=False, push_request=False):
pos = self._buffer.get_pos() if self._buffer else None
try:
result = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
except BaseException:
if self._buffer:
self._buffer.rewind(pos)
raise
else:
self._buffer.purge()
return result
def _read_response(self, disable_decoding=False, push_request=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
byte, response = raw[:1], raw[1:]
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = self._buffer.read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = self._buffer.read(int(response))
# verbatim string response
elif byte == b"=":
response = self._buffer.read(int(response))[4:]
# array response
elif byte == b"*":
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
            # so we always return sets as lists, for predictability
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
# map response
elif byte == b"%":
            # We cannot use a dict comprehension to parse the stream: the
            # left-to-right evaluation order of key:value expressions in dict
            # comprehensions only became defined in Python 3.8.
resp_dict = {}
for _ in range(int(response)):
key = self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
for _ in range(int(response))
]
response = self.handle_push_response(response)
# if this is a push request return the push response
if push_request:
return response
return self._read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response
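# Quick reference for the RESP3 type bytes handled above: "-"/"!" errors,
# "+" simple string, "_" null, ":"/"(" integers, "," double, "#" boolean,
# "$" bulk string, "=" verbatim string, "*" array, "~" set (returned as a
# list), "%" map, ">" push message. For example, b"%1\r\n$3\r\nkey\r\n#t\r\n"
# parses to {b"key": True} when decoding is disabled.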
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
):
if self._chunks:
# augment parsing buffer with previously read data
self._buffer += b"".join(self._chunks)
self._chunks.clear()
self._pos = 0
response = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
# Successfully parsing a response allows us to clear our parsing buffer
self._clear()
return response
async def _read_response(
self, disable_decoding: bool = False, push_request: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._readline()
response: Any
byte, response = raw[:1], raw[1:]
# if byte not in (b"-", b"+", b":", b"$", b"*"):
# raise InvalidResponse(f"Protocol Error: {raw!r}")
# server returned an error
if byte in (b"-", b"!"):
if byte == b"!":
response = await self._read(int(response))
response = response.decode("utf-8", errors="replace")
error = self.parse_error(response)
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._clear() # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
# and/or the pipeline's execute() will raise this error if
# necessary, so just return the exception instance here.
return error
# single value
elif byte == b"+":
pass
# null value
elif byte == b"_":
return None
# int and big int values
elif byte in (b":", b"("):
return int(response)
# double value
elif byte == b",":
return float(response)
# bool value
elif byte == b"#":
return response == b"t"
# bulk response
elif byte == b"$":
response = await self._read(int(response))
# verbatim string response
elif byte == b"=":
response = (await self._read(int(response)))[4:]
# array response
elif byte == b"*":
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we always convert to a list, to have predictable return types
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
# map response
elif byte == b"%":
            # We cannot use a dict comprehension to parse the stream: the
            # left-to-right evaluation order of key:value expressions in dict
            # comprehensions only became defined in Python 3.8.
resp_dict = {}
for _ in range(int(response)):
key = await self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
response = resp_dict
# push response
elif byte == b">":
response = [
(
await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
)
for _ in range(int(response))
]
response = await self.handle_push_response(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

View File

@@ -0,0 +1,162 @@
import errno
import io
import socket
from io import SEEK_END
from typing import Optional, Union
from ..exceptions import ConnectionError, TimeoutError
from ..utils import SSL_AVAILABLE
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}
if SSL_AVAILABLE:
import ssl
if hasattr(ssl, "SSLWantReadError"):
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2
else:
NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
SENTINEL = object()
SYM_CRLF = b"\r\n"
class SocketBuffer:
def __init__(
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
):
self._sock = socket
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer = io.BytesIO()
def unread_bytes(self) -> int:
"""
Remaining unread length of buffer
"""
pos = self._buffer.tell()
end = self._buffer.seek(0, SEEK_END)
self._buffer.seek(pos)
return end - pos
def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, object] = SENTINEL,
raise_on_timeout: Optional[bool] = True,
) -> bool:
sock = self._sock
socket_read_size = self.socket_read_size
marker = 0
custom_timeout = timeout is not SENTINEL
buf = self._buffer
current_pos = buf.tell()
buf.seek(0, SEEK_END)
if custom_timeout:
sock.settimeout(timeout)
try:
while True:
data = self._sock.recv(socket_read_size)
                # an empty byte string indicates the server shut down the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
marker += data_length
if length is not None and length > marker:
continue
return True
except socket.timeout:
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket")
return False
except NONBLOCKING_EXCEPTIONS as ex:
# if we're in nonblocking mode and the recv raises a
# blocking error, simply return False indicating that
# there's no data to be read. otherwise raise the
# original exception.
allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
if not raise_on_timeout and ex.errno == allowed:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
buf.seek(current_pos)
if custom_timeout:
sock.settimeout(self.socket_timeout)
def can_read(self, timeout: float) -> bool:
return bool(self.unread_bytes()) or self._read_from_socket(
timeout=timeout, raise_on_timeout=False
)
def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
# BufferIO will return less than requested if buffer is short
data = self._buffer.read(length)
missing = length - len(data)
if missing:
# fill up the buffer and read the remainder
self._read_from_socket(missing)
data += self._buffer.read(missing)
return data[:-2]
def readline(self) -> bytes:
buf = self._buffer
data = buf.readline()
while not data.endswith(SYM_CRLF):
# there's more data in the socket that we need
self._read_from_socket()
data += buf.readline()
return data[:-2]
def get_pos(self) -> int:
"""
Get current read position
"""
return self._buffer.tell()
def rewind(self, pos: int) -> None:
"""
Rewind the buffer to a specific position, to re-start reading
"""
self._buffer.seek(pos)
def purge(self) -> None:
"""
After a successful read, purge the read part of buffer
"""
unread = self.unread_bytes()
# Only if we have read all of the buffer do we truncate, to
# reduce the amount of memory thrashing. This heuristic
# can be changed or removed later.
if unread > 0:
return
if unread > 0:
# move unread data to the front
view = self._buffer.getbuffer()
view[:unread] = view[-unread:]
self._buffer.truncate(unread)
self._buffer.seek(0)
def close(self) -> None:
try:
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
# BadFileDescriptor error. Perhaps the client ran out of
# memory or something else? It's probably OK to ignore
# any error being raised from purge/close since we're
# removing the reference to the instance below.
pass
self._buffer = None
self._sock = None
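# Note on the read()/readline() contract above: both strip the trailing CRLF,
# so callers receive only the payload. For example, after the server sends
# b"$3\r\nfoo\r\n", readline() returns b"$3" and a following read(3) returns
# b"foo".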

View File

@@ -0,0 +1,64 @@
from redis.asyncio.client import Redis, StrictRedis
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import (
BlockingConnectionPool,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.asyncio.sentinel import (
Sentinel,
SentinelConnectionPool,
SentinelManagedConnection,
SentinelManagedSSLConnection,
)
from redis.asyncio.utils import from_url
from redis.backoff import default_backoff
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
DataError,
InvalidResponse,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
"BusyLoadingError",
"ChildDeadlockedError",
"Connection",
"ConnectionError",
"ConnectionPool",
"DataError",
"from_url",
"default_backoff",
"InvalidResponse",
"PubSubError",
"OutOfMemoryError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisError",
"ResponseError",
"Sentinel",
"SentinelConnectionPool",
"SentinelManagedConnection",
"SentinelManagedSSLConnection",
"SSLConnection",
"StrictRedis",
"TimeoutError",
"UnixDomainSocketConnection",
"WatchError",
]

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,265 @@
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Mapping, Optional, Union
from redis.http.http_client import HttpClient, HttpResponse
DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
class AsyncHTTPClient(ABC):
@abstractmethod
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP GET request."""
pass
@abstractmethod
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP DELETE request."""
pass
@abstractmethod
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP POST request."""
pass
@abstractmethod
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PUT request."""
pass
@abstractmethod
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
"""
Invoke HTTP PATCH request."""
pass
@abstractmethod
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
"""
Invoke HTTP request with given method."""
pass
class AsyncHTTPClientWrapper(AsyncHTTPClient):
"""
An async wrapper around sync HTTP client with thread pool execution.
"""
def __init__(self, client: HttpClient, max_workers: int = 10) -> None:
"""
Initialize a new HTTP client instance.
Args:
client: Sync HTTP client instance.
            max_workers: Maximum number of worker threads used to run requests concurrently.
        The TLS configuration lives on the wrapped sync client: it supports both
        regular HTTPS with server verification and mutual TLS authentication. For
        server verification, provide CA certificate information via ca_file, ca_path
        or ca_data; for mutual TLS, additionally provide a client certificate and key
        via client_cert_file and client_key_file.
"""
self.client = client
self._executor = ThreadPoolExecutor(max_workers=max_workers)
async def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor, self.client.get, path, params, headers, timeout, expect_json
)
async def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.delete,
path,
params,
headers,
timeout,
expect_json,
)
async def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.post,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.put,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.patch,
path,
json_body,
data,
params,
headers,
timeout,
expect_json,
)
async def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self._executor,
self.client.request,
method,
path,
params,
headers,
body,
timeout,
)
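# Illustrative usage sketch (not part of the module): the HttpClient
# construction arguments and the request path/params below are hypothetical.
#
#     sync_client = HttpClient(...)  # configured with base URL / TLS elsewhere
#     client = AsyncHTTPClientWrapper(sync_client, max_workers=4)
#     data = await client.get("/v1/bdbs", params={"fields": "uid,name"})
#
# Each call runs the wrapped sync client inside the thread pool, so the event
# loop is never blocked by network I/O.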

View File

@@ -0,0 +1,334 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
if TYPE_CHECKING:
from redis.asyncio import Redis, RedisCluster
logger = logging.getLogger(__name__)
class Lock:
"""
A shared, distributed Lock. Using Redis for locking allows the Lock
to be shared across processes and/or machines.
It's left to the user to resolve deadlock issues and make sure
multiple clients play nicely together.
"""
lua_release = None
lua_extend = None
lua_reacquire = None
# KEYS[1] - lock name
# ARGV[1] - token
# return 1 if the lock was released, otherwise 0
LUA_RELEASE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('del', KEYS[1])
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# ARGV[3] - "0" if the additional time should be added to the lock's
# existing ttl or "1" if the existing ttl should be replaced
# return 1 if the locks time was extended, otherwise 0
LUA_EXTEND_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
expiration = 0
end
if expiration < 0 then
return 0
end
local newttl = ARGV[2]
if ARGV[3] == "0" then
newttl = ARGV[2] + expiration
end
redis.call('pexpire', KEYS[1], newttl)
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - milliseconds
# return 1 if the locks time was reacquired, otherwise 0
LUA_REACQUIRE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""
def __init__(
self,
redis: Union["Redis", "RedisCluster"],
name: Union[str, bytes, memoryview],
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
supplied by ``redis``.
``timeout`` indicates a maximum life for the lock in seconds.
By default, it will remain locked until release() is called.
``timeout`` can be specified as a float or integer, both representing
the number of seconds to wait.
``sleep`` indicates the amount of time to sleep in seconds per loop
iteration when the lock is in blocking mode and another client is
currently holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
        example, you may have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage.
"""
self.redis = redis
self.name = name
self.timeout = timeout
self.sleep = sleep
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.raise_on_release_error = raise_on_release_error
self.local.token = None
self.register_scripts()
def register_scripts(self):
cls = self.__class__
client = self.redis
if cls.lua_release is None:
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
if cls.lua_extend is None:
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
async def __aenter__(self):
if await self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
async def acquire(
self,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[Union[str, bytes]] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
If ``blocking`` is False, always return immediately. If the lock
was acquired, return True, otherwise return False.
``blocking_timeout`` specifies the maximum number of seconds to
wait trying to acquire the lock.
``token`` specifies the token value to be used. If provided, token
must be a bytes object or a string that can be encoded to a bytes
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
if blocking_timeout is None:
blocking_timeout = self.blocking_timeout
stop_trying_at = None
if blocking_timeout is not None:
stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
return True
if not blocking:
return False
next_try_at = asyncio.get_running_loop().time() + sleep
if stop_trying_at is not None and next_try_at > stop_trying_at:
return False
await asyncio.sleep(sleep)
async def do_acquire(self, token: Union[str, bytes]) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
else:
timeout = None
if await self.redis.set(self.name, token, nx=True, px=timeout):
return True
return False
async def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return await self.redis.get(self.name) is not None
async def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
stored_token = await self.redis.get(self.name)
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
def release(self) -> Awaitable[None]:
"""Releases the already acquired lock"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
self.local.token = None
return self.do_release(expected_token)
async def do_release(self, expected_token: bytes) -> None:
if not bool(
await self.lua_release(
keys=[self.name], args=[expected_token], client=self.redis
)
):
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(
self, additional_time: Number, replace_ttl: bool = False
) -> Awaitable[bool]:
"""
Adds more time to an already acquired lock.
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
``replace_ttl`` if False (the default), add `additional_time` to
the lock's existing ttl. If True, replace the lock's ttl with
`additional_time`.
"""
if self.local.token is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)
async def do_extend(self, additional_time, replace_ttl) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
await self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
client=self.redis,
)
):
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True
def reacquire(self) -> Awaitable[bool]:
"""
        Resets the TTL of an already acquired lock back to the timeout value.
"""
if self.local.token is None:
raise LockError("Cannot reacquire an unlocked lock")
if self.timeout is None:
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()
async def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
await self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True
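# Illustrative usage sketch (not part of the module); assumes an existing
# redis.asyncio.Redis client named `r`:
#
#     lock = Lock(r, "my-lock", timeout=5, blocking_timeout=2)
#     async with lock:
#         ...  # critical section; the lock is released on exit
#
# acquire() / release() can also be called explicitly when the context-manager
# form does not fit the call site.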

View File

@@ -0,0 +1,530 @@
import asyncio
import logging
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.background import BackgroundScheduler
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import experimental
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Client that operates on multiple logical Redis databases.
Should be used in Active-Active database setups.
"""
def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.default_health_checks()
if config.health_checks is not None:
self._health_checks.extend(config.health_checks)
self._health_check_interval = config.health_check_interval
self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
config.health_check_probes, config.health_check_delay
)
self._failure_detectors = config.default_failure_detectors()
if config.failure_detectors is not None:
self._failure_detectors.extend(config.failure_detectors)
self._failover_strategy = (
config.default_failover_strategy()
if config.failover_strategy is None
else config.failover_strategy
)
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_retry = config.command_retry
self._command_retry.update_supported_errors([ConnectionRefusedError])
self.command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=self._command_retry,
failover_strategy=self._failover_strategy,
failover_attempts=config.failover_attempts,
failover_delay=config.failover_delay,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)
self.initialized = False
self._hc_lock = asyncio.Lock()
self._bg_scheduler = BackgroundScheduler()
self._config = config
self._recurring_hc_task = None
self._hc_tasks = []
self._half_open_state_task = None
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
if not self.initialized:
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._recurring_hc_task:
self._recurring_hc_task.cancel()
if self._half_open_state_task:
self._half_open_state_task.cancel()
for hc_task in self._hc_tasks:
hc_task.cancel()
async def initialize(self):
"""
Perform initialization of databases to define their initial state.
"""
async def raise_exception_on_failed_hc(error):
raise error
# Initial databases check to define initial state
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
        # Start recurring health checks in the background.
self._recurring_hc_task = asyncio.create_task(
self._bg_scheduler.run_recurring_async(
self._health_check_interval,
self._check_databases_health,
)
)
is_active_db_found = False
for database, weight in self._databases:
# Set on state changed callback for each circuit.
database.circuit.on_state_changed(self._on_circuit_state_change_callback)
            # Pick the initial active database according to weight and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
await self.command_executor.set_active_database(database)
is_active_db_found = True
if not is_active_db_found:
raise NoValidDatabaseException(
"Initial connection failed - no active database found"
)
self.initialized = True
def get_databases(self) -> Databases:
"""
Returns a sorted (by weight) list of all databases.
"""
return self._databases
async def set_active_database(self, database: AsyncDatabase) -> None:
"""
        Promote one of the existing databases to become the active one.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
await self._check_db_health(database)
if database.circuit.state == CBState.CLOSED:
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
await self.command_executor.set_active_database(database)
return
raise NoValidDatabaseException(
"Cannot set active database, database is unhealthy"
)
async def add_database(self, database: AsyncDatabase):
"""
Adds a new database to the database list.
"""
for existing_db, _ in self._databases:
if existing_db == database:
raise ValueError("Given database already exists")
await self._check_db_health(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.add(database, database.weight)
await self._change_active_database(database, highest_weighted_db)
async def _change_active_database(
self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
):
if (
new_database.weight > highest_weight_database.weight
and new_database.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(new_database)
async def remove_database(self, database: AsyncDatabase):
"""
Removes a database from the database list.
"""
weight = self._databases.remove(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
if (
highest_weight <= weight
and highest_weighted_db.circuit.state == CBState.CLOSED
):
await self.command_executor.set_active_database(highest_weighted_db)
async def update_database_weight(self, database: AsyncDatabase, weight: float):
"""
        Updates the weight of a database in the database list.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.update_weight(database, weight)
database.weight = weight
await self._change_active_database(database, highest_weighted_db)
def add_failure_detector(self, failure_detector: AsyncFailureDetector):
"""
        Adds a new failure detector to the client.
"""
self._failure_detectors.append(failure_detector)
async def add_health_check(self, healthcheck: HealthCheck):
"""
        Adds a new health check to the client.
"""
async with self._hc_lock:
self._health_checks.append(healthcheck)
async def execute_command(self, *args, **options):
"""
Executes a single command and return its result.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_command(*args, **options)
def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)
async def transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
"""
Executes callable as transaction.
"""
if not self.initialized:
await self.initialize()
return await self.command_executor.execute_transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
async def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
await self.initialize()
return PubSub(self, **kwargs)
async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
):
"""
        Runs health checks against all databases; scheduled as a recurring background task.
"""
try:
self._hc_tasks = [
asyncio.create_task(self._check_db_health(database))
for database, _ in self._databases
]
results = await asyncio.wait_for(
asyncio.gather(
*self._hc_tasks,
return_exceptions=True,
),
timeout=self._health_check_interval,
)
except asyncio.TimeoutError:
raise asyncio.TimeoutError(
"Health check execution exceeds health_check_interval"
)
for result in results:
if isinstance(result, UnhealthyDatabaseException):
unhealthy_db = result.database
unhealthy_db.circuit.state = CBState.OPEN
logger.exception(
"Health check failed, due to exception",
exc_info=result.original_exception,
)
if on_error:
                    await on_error(result.original_exception)
async def _check_db_health(self, database: AsyncDatabase) -> bool:
"""
Runs health checks on the given database until first failure.
"""
        # The health check will set up the circuit state
is_healthy = await self._health_check_policy.execute(
self._health_checks, database
)
if not is_healthy:
if database.circuit.state != CBState.OPEN:
database.circuit.state = CBState.OPEN
return is_healthy
elif is_healthy and database.circuit.state != CBState.CLOSED:
database.circuit.state = CBState.CLOSED
return is_healthy
def _on_circuit_state_change_callback(
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
):
loop = asyncio.get_running_loop()
if new_state == CBState.HALF_OPEN:
self._half_open_state_task = asyncio.create_task(
self._check_db_health(circuit.database)
)
return
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)
async def aclose(self):
if self.command_executor.active_database:
await self.command_executor.active_database.client.aclose()
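# Illustrative usage sketch (not part of the module): the MultiDbConfig
# construction below is hypothetical; the real config carries the databases,
# health checks and failover settings for the deployment.
#
#     config = MultiDbConfig(...)  # hypothetical construction
#     async with MultiDBClient(config) as client:
#         await client.set("foo", "bar")
#
# Entering the context manager runs the initial health checks and activates the
# first healthy database in weight order; subsequent commands are retried and
# failed over by the command executor.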
def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client
async def __aenter__(self: "Pipeline") -> "Pipeline":
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.reset()
await self._client.__aexit__(exc_type, exc_value, traceback)
def __await__(self):
return self._async_self().__await__()
async def _async_self(self):
return self
def __len__(self) -> int:
return len(self._command_stack)
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True
async def reset(self) -> None:
self._command_stack = []
async def aclose(self) -> None:
"""Close the pipeline"""
await self.reset()
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called
Returns the current Pipeline object back so commands can be
chained together, such as:
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self
def execute_command(self, *args, **kwargs):
"""Adds a command to the stack"""
return self.pipeline_execute_command(*args, **kwargs)
async def execute(self) -> List[Any]:
"""Execute all the commands in the current pipeline"""
if not self._client.initialized:
await self._client.initialize()
try:
return await self._client.command_executor.execute_pipeline(
tuple(self._command_stack)
)
finally:
await self.reset()
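# Illustrative usage sketch (not part of the module); assumes an initialized
# MultiDBClient named `client`:
#
#     pipe = client.pipeline()
#     pipe.set("foo", "bar").incr("counter")
#     results = await pipe.execute()
#
# Commands are buffered locally and sent to the currently active database when
# execute() runs; the stack is cleared afterwards even if execution fails.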
class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.
Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""
self._client = client
self._client.command_executor.pubsub(**kwargs)
async def __aenter__(self) -> "PubSub":
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.aclose()
async def aclose(self):
return await self._client.command_executor.execute_pubsub_method("aclose")
@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed
async def execute_command(self, *args: EncodableT):
return await self._client.command_executor.execute_pubsub_method(
"execute_command", *args
)
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"psubscribe", *args, **kwargs
)
async def punsubscribe(self, *args: ChannelT):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return await self._client.command_executor.execute_pubsub_method(
"punsubscribe", *args
)
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return await self._client.command_executor.execute_pubsub_method(
"subscribe", *args, **kwargs
)
async def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return await self._client.command_executor.execute_pubsub_method(
"unsubscribe", *args
)
async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number or None to wait indefinitely.
"""
return await self._client.command_executor.execute_pubsub_method(
"get_message",
ignore_subscribe_messages=ignore_subscribe_messages,
timeout=timeout,
)
async def run(
self,
*,
exception_handler=None,
poll_timeout: float = 1.0,
) -> None:
"""Process pub/sub messages using registered callbacks.
This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:
>>> task = asyncio.create_task(pubsub.run())
To shut it down, use asyncio cancellation:
>>> task.cancel()
>>> await task
"""
return await self._client.command_executor.execute_pubsub_run(
sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
)

View File

@@ -0,0 +1,339 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import Any, Awaitable, Callable, List, Optional, Union
from redis.asyncio import RedisCluster
from redis.asyncio.client import Pipeline, PubSub
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
from redis.asyncio.multidb.event import (
AsyncActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
DefaultFailoverStrategyExecutor,
FailoverStrategyExecutor,
)
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.retry import Retry
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT
class AsyncCommandExecutor(CommandExecutor):
@property
@abstractmethod
def databases(self) -> Databases:
"""Returns a list of databases."""
pass
@property
@abstractmethod
def failure_detectors(self) -> List[AsyncFailureDetector]:
"""Returns a list of failure detectors."""
pass
@abstractmethod
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
"""Adds a new failure detector to the list of failure detectors."""
pass
@property
@abstractmethod
def active_database(self) -> Optional[AsyncDatabase]:
"""Returns currently active database."""
pass
@abstractmethod
async def set_active_database(self, database: AsyncDatabase) -> None:
"""Sets the currently active database."""
pass
@property
@abstractmethod
def active_pubsub(self) -> Optional[PubSub]:
"""Returns currently active pubsub."""
pass
@active_pubsub.setter
@abstractmethod
def active_pubsub(self, pubsub: PubSub) -> None:
"""Sets currently active pubsub."""
pass
@property
@abstractmethod
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
"""Returns failover strategy executor."""
pass
@property
@abstractmethod
def command_retry(self) -> Retry:
"""Returns command retry object."""
pass
@abstractmethod
async def pubsub(self, **kwargs):
"""Initializes a PubSub object on a currently active database"""
pass
@abstractmethod
async def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
pass
@abstractmethod
async def execute_pipeline(self, command_stack: tuple):
"""Executes a stack of commands in pipeline."""
pass
@abstractmethod
async def execute_transaction(
self, transaction: Callable[[Pipeline], None], *watches, **options
):
"""Executes a transaction block wrapped in callback."""
pass
@abstractmethod
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""Executes a given method on active pub/sub."""
pass
@abstractmethod
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
"""Executes pub/sub run in a thread."""
pass
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
def __init__(
self,
failure_detectors: List[AsyncFailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: AsyncFailoverStrategy,
event_dispatcher: EventDispatcherInterface,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
"""
Initialize the DefaultCommandExecutor instance.
Args:
failure_detectors: List of failure detector instances to monitor database health
databases: Collection of available databases to execute commands on
command_retry: Retry policy for failed command execution
failover_strategy: Strategy for handling database failover
event_dispatcher: Interface for dispatching events
failover_attempts: Number of failover attempts
failover_delay: Delay between failover attempts
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
"""
super().__init__(auto_fallback_interval)
for fd in failure_detectors:
fd.set_command_executor(command_executor=self)
self._databases = databases
self._failure_detectors = failure_detectors
self._command_retry = command_retry
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
failover_strategy, failover_attempts, failover_delay
)
self._event_dispatcher = event_dispatcher
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()
@property
def databases(self) -> Databases:
return self._databases
@property
def failure_detectors(self) -> List[AsyncFailureDetector]:
return self._failure_detectors
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
self._failure_detectors.append(failure_detector)
@property
def active_database(self) -> Optional[AsyncDatabase]:
return self._active_database
async def set_active_database(self, database: AsyncDatabase) -> None:
old_active = self._active_database
self._active_database = database
if old_active is not None and old_active is not database:
await self._event_dispatcher.dispatch_async(
AsyncActiveDatabaseChanged(
old_active,
self._active_database,
self,
**self._active_pubsub_kwargs,
)
)
@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub
@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub
@property
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
return self._failover_strategy_executor
@property
def command_retry(self) -> Retry:
return self._command_retry
def pubsub(self, **kwargs):
if self._active_pubsub is None:
if isinstance(self._active_database.client, RedisCluster):
raise ValueError("PubSub is not supported for RedisCluster")
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
async def execute_command(self, *args, **options):
async def callback():
response = await self._active_database.client.execute_command(
*args, **options
)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, args)
async def execute_pipeline(self, command_stack: tuple):
async def callback():
async with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
pipe.execute_command(*command, **options)
response = await pipe.execute()
await self._register_command_execution(command_stack)
return response
return await self._execute_with_failure_detection(callback, command_stack)
async def execute_transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
async def callback():
response = await self._active_database.client.transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)
await self._register_command_execution(())
return response
return await self._execute_with_failure_detection(callback)
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
if iscoroutinefunction(method):
response = await method(*args, **kwargs)
else:
response = method(*args, **kwargs)
await self._register_command_execution(args)
return response
return await self._execute_with_failure_detection(callback, *args)
async def execute_pubsub_run(
self, sleep_time: float, exception_handler=None, pubsub=None
) -> Any:
async def callback():
return await self._active_pubsub.run(
poll_timeout=sleep_time,
exception_handler=exception_handler,
pubsub=pubsub,
)
return await self._execute_with_failure_detection(callback)
async def _execute_with_failure_detection(
self, callback: Callable, cmds: tuple = ()
):
"""
        Execute a command execution callback with failure detection.
"""
async def wrapper():
# On each retry we need to check active database as it might change.
await self._check_active_database()
return await callback()
return await self._command_retry.call_with_retry(
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)
async def _check_active_database(self):
"""
        Checks whether the active database needs to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
or (
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
and self._next_fallback_attempt <= datetime.now()
)
):
await self.set_active_database(
await self._failover_strategy_executor.execute()
)
self._schedule_next_fallback()
async def _on_command_fail(self, error, *args):
await self._event_dispatcher.dispatch_async(
AsyncOnCommandsFailEvent(args, error)
)
async def _register_command_execution(self, cmd: tuple):
for detector in self._failure_detectors:
await detector.register_command_execution(cmd)
def _setup_event_dispatcher(self):
"""
Registers necessary listeners.
"""
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners(
{
AsyncOnCommandsFailEvent: [failure_listener],
AsyncActiveDatabaseChanged: [
close_connection_listener,
resubscribe_listener,
],
}
)
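
# Usage sketch (illustrative only): wiring a DefaultCommandExecutor from the
# default factories of MultiDbConfig. The endpoints, weights, and the config
# module path (`redis.asyncio.multidb.config`) are assumptions for the sake of
# the example, not part of this module's API.
async def _example_execute_command() -> None:
    from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig  # assumed path

    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://db-primary.example.com:6379", weight=1.0),
            DatabaseConfig(from_url="redis://db-standby.example.com:6379", weight=0.5),
        ]
    )
    databases = config.databases()
    failover_strategy = config.default_failover_strategy()
    failover_strategy.set_databases(databases)

    executor = DefaultCommandExecutor(
        failure_detectors=config.default_failure_detectors(),
        databases=databases,
        command_retry=config.command_retry,
        failover_strategy=failover_strategy,
        event_dispatcher=config.event_dispatcher,
    )
    # Commands run against the healthy database with the highest weight,
    # are retried per `command_retry`, and failures are reported to the
    # registered failure detectors.
    await executor.execute_command("SET", "key", "value")
    print(await executor.execute_command("GET", "key"))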

View File

@@ -0,0 +1,210 @@
from dataclasses import dataclass, field
from typing import List, Optional, Type, Union
import pybreaker
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio.multidb.database import Database, Databases
from redis.asyncio.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
AsyncFailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.asyncio.multidb.failure_detector import (
AsyncFailureDetector,
FailureDetectorAsyncWrapper,
)
from redis.asyncio.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
)
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
return EventDispatcher()
@dataclass
class DatabaseConfig:
"""
Dataclass representing the configuration for a database connection.
This class is used to store configuration settings for a database connection,
including client options, connection sourcing details, circuit breaker settings,
and cluster-specific properties. It provides a structure for defining these
attributes and allows for the creation of customized configurations for various
database setups.
Attributes:
weight (float): Weight of the database to define the active one.
client_kwargs (dict): Additional parameters for the database client connection.
from_url (Optional[str]): Redis URL way of connecting to the database.
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
on public Redis Enterprise endpoints.
Methods:
default_circuit_breaker:
Generates and returns a default CircuitBreaker instance adapted for use.
"""
weight: float = 1.0
client_kwargs: dict = field(default_factory=dict)
from_url: Optional[str] = None
from_pool: Optional[ConnectionPool] = None
circuit: Optional[CircuitBreaker] = None
grace_period: float = DEFAULT_GRACE_PERIOD
health_check_url: Optional[str] = None
def default_circuit_breaker(self) -> CircuitBreaker:
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
return PBCircuitBreakerAdapter(circuit_breaker)
@dataclass
class MultiDbConfig:
"""
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
Attributes:
databases_config: A list of database configurations.
client_class: The client class used to manage database connections.
command_retry: Retry strategy for executing database commands.
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
        min_num_failures: Minimum number of failures required for failover.
        failure_rate_threshold: Percentage of failures required for failover.
failures_detection_window: Time interval for tracking database failures.
health_checks: Optional list of additional health checks performed on databases.
health_check_interval: Time interval for executing health checks.
health_check_probes: Number of attempts to evaluate the health of a database.
health_check_delay: Delay between health check attempts.
failover_strategy: Optional strategy for handling database failover scenarios.
failover_attempts: Number of retries allowed for failover operations.
failover_delay: Delay between failover attempts.
auto_fallback_interval: Time interval to trigger automatic fallback.
event_dispatcher: Interface for dispatching events related to database operations.
Methods:
databases:
Retrieves a collection of database clients managed by weighted configurations.
Initializes database clients based on the provided configuration and removes
redundant retry objects for lower-level clients to rely on global retry logic.
default_failure_detectors:
Returns the default list of failure detectors used to monitor database failures.
default_health_checks:
Returns the default list of health checks used to monitor database health
with specific retry and backoff strategies.
default_failover_strategy:
Provides the default failover strategy used for handling failover scenarios
with defined retry and backoff configurations.
"""
databases_config: List[DatabaseConfig]
client_class: Type[Union[Redis, RedisCluster]] = Redis
command_retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
)
failure_detectors: Optional[List[AsyncFailureDetector]] = None
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
health_checks: Optional[List[HealthCheck]] = None
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
failover_strategy: Optional[AsyncFailoverStrategy] = None
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
failover_delay: float = DEFAULT_FAILOVER_DELAY
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
event_dispatcher: EventDispatcherInterface = field(
default_factory=default_event_dispatcher
)
def databases(self) -> Databases:
databases = WeightedList()
for database_config in self.databases_config:
            # The retry object is not used in the lower-level clients, so we can safely override it.
            # We rely on command_retry for global retries.
database_config.client_kwargs.update(
{"retry": Retry(retries=0, backoff=NoBackoff())}
)
if database_config.from_url:
client = self.client_class.from_url(
database_config.from_url, **database_config.client_kwargs
)
elif database_config.from_pool:
database_config.from_pool.set_retry(
Retry(retries=0, backoff=NoBackoff())
)
client = self.client_class.from_pool(
connection_pool=database_config.from_pool
)
else:
client = self.client_class(**database_config.client_kwargs)
circuit = (
database_config.default_circuit_breaker()
if database_config.circuit is None
else database_config.circuit
)
databases.add(
Database(
client=client,
circuit=circuit,
weight=database_config.weight,
health_check_url=database_config.health_check_url,
),
database_config.weight,
)
return databases
def default_failure_detectors(self) -> List[AsyncFailureDetector]:
return [
FailureDetectorAsyncWrapper(
CommandFailureDetector(
min_num_failures=self.min_num_failures,
failure_rate_threshold=self.failure_rate_threshold,
failure_detection_window=self.failures_detection_window,
)
),
]
def default_health_checks(self) -> List[HealthCheck]:
return [
PingHealthCheck(),
]
def default_failover_strategy(self) -> AsyncFailoverStrategy:
return WeightBasedFailoverStrategy()
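
# Usage sketch (illustrative only): a MultiDbConfig with a lag-aware health
# check and a quick look at the weighted collection produced by `databases()`.
# Endpoints and credentials below are placeholders.
def _example_multidb_config() -> MultiDbConfig:
    from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck

    return MultiDbConfig(
        databases_config=[
            DatabaseConfig(
                from_url="redis://db-east.example.com:6379",
                weight=1.0,
                health_check_url="https://cluster-east.example.com",
            ),
            DatabaseConfig(from_url="redis://db-west.example.com:6379", weight=0.7),
        ],
        health_checks=[LagAwareHealthCheck(auth_basic=("user", "secret"))],
        failover_attempts=5,
        failover_delay=6.0,
    )


if __name__ == "__main__":
    # Each entry pairs a Database (client + circuit breaker) with its weight.
    for database, weight in _example_multidb_config().databases():
        print(weight, database.client)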

View File

@@ -0,0 +1,69 @@
from abc import abstractmethod
from typing import Optional, Union
from redis.asyncio import Redis, RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.database import AbstractDatabase, BaseDatabase
from redis.typing import Number
class AsyncDatabase(AbstractDatabase):
"""Database with an underlying asynchronous redis client."""
@property
@abstractmethod
def client(self) -> Union[Redis, RedisCluster]:
"""The underlying redis client."""
pass
@client.setter
@abstractmethod
def client(self, client: Union[Redis, RedisCluster]):
"""Set the underlying redis client."""
pass
@property
@abstractmethod
def circuit(self) -> CircuitBreaker:
"""Circuit breaker for the current database."""
pass
@circuit.setter
@abstractmethod
def circuit(self, circuit: CircuitBreaker):
"""Set the circuit breaker for the current database."""
pass
Databases = WeightedList[tuple[AsyncDatabase, Number]]
class Database(BaseDatabase, AsyncDatabase):
def __init__(
self,
client: Union[Redis, RedisCluster],
circuit: CircuitBreaker,
weight: float,
health_check_url: Optional[str] = None,
):
self._client = client
self._cb = circuit
self._cb.database = self
super().__init__(weight, health_check_url)
@property
def client(self) -> Union[Redis, RedisCluster]:
return self._client
@client.setter
def client(self, client: Union[Redis, RedisCluster]):
self._client = client
@property
def circuit(self) -> CircuitBreaker:
return self._cb
@circuit.setter
def circuit(self, circuit: CircuitBreaker):
self._cb = circuit
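
# Usage sketch (illustrative only): constructing a Database entry by hand with
# an explicit pybreaker-backed circuit breaker, the same adapter that
# DatabaseConfig.default_circuit_breaker() produces. Host and timeout values
# are placeholders.
def _example_database() -> Database:
    import pybreaker

    from redis.multidb.circuit import PBCircuitBreakerAdapter

    client = Redis(host="redis.example.com", port=6379)
    circuit = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5.0))
    return Database(client=client, circuit=circuit, weight=1.0)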

View File

@@ -0,0 +1,84 @@
from typing import List
from redis.asyncio import Redis
from redis.asyncio.multidb.database import AsyncDatabase
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
class AsyncActiveDatabaseChanged:
"""
Event fired when an async active database has been changed.
"""
def __init__(
self,
old_database: AsyncDatabase,
new_database: AsyncDatabase,
command_executor,
**kwargs,
):
self._old_database = old_database
self._new_database = new_database
self._command_executor = command_executor
self._kwargs = kwargs
@property
def old_database(self) -> AsyncDatabase:
return self._old_database
@property
def new_database(self) -> AsyncDatabase:
return self._new_database
@property
def command_executor(self):
return self._command_executor
@property
def kwargs(self):
return self._kwargs
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
    Re-subscribe the currently active pub/sub to the new active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
old_pubsub = event.command_executor.active_pubsub
if old_pubsub is not None:
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
new_pubsub.channels = old_pubsub.channels
new_pubsub.patterns = old_pubsub.patterns
await new_pubsub.on_connect(None)
event.command_executor.active_pubsub = new_pubsub
await old_pubsub.aclose()
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
"""
Close connection to the old active database.
"""
async def listen(self, event: AsyncActiveDatabaseChanged):
await event.old_database.client.aclose()
if isinstance(event.old_database.client, Redis):
await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
await event.old_database.client.connection_pool.disconnect()
class RegisterCommandFailure(AsyncEventListenerInterface):
"""
    Event listener that registers command failures and passes them to the failure detectors.
"""
def __init__(self, failure_detectors: List[AsyncFailureDetector]):
self._failure_detectors = failure_detectors
async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
for failure_detector in self._failure_detectors:
await failure_detector.register_failure(event.exception, event.commands)

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
DEFAULT_FAILOVER_ATTEMPTS = 10
DEFAULT_FAILOVER_DELAY = 12
class AsyncFailoverStrategy(ABC):
@abstractmethod
async def database(self) -> AsyncDatabase:
"""Select the database according to the strategy."""
pass
@abstractmethod
def set_databases(self, databases: Databases) -> None:
"""Set the database strategy operates on."""
pass
class FailoverStrategyExecutor(ABC):
@property
@abstractmethod
def failover_attempts(self) -> int:
"""The number of failover attempts."""
pass
@property
@abstractmethod
def failover_delay(self) -> float:
"""The delay between failover attempts."""
pass
@property
@abstractmethod
def strategy(self) -> AsyncFailoverStrategy:
"""The strategy to execute."""
pass
@abstractmethod
async def execute(self) -> AsyncDatabase:
"""Execute the failover strategy."""
pass
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
"""
Failover strategy based on database weights.
"""
def __init__(self):
self._databases = WeightedList()
async def database(self) -> AsyncDatabase:
for database, _ in self._databases:
if database.circuit.state == CBState.CLOSED:
return database
raise NoValidDatabaseException("No valid database available for communication")
def set_databases(self, databases: Databases) -> None:
self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
"""
Executes given failover strategy.
"""
def __init__(
self,
strategy: AsyncFailoverStrategy,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
):
self._strategy = strategy
self._failover_attempts = failover_attempts
self._failover_delay = failover_delay
self._next_attempt_ts: int = 0
self._failover_counter: int = 0
@property
def failover_attempts(self) -> int:
return self._failover_attempts
@property
def failover_delay(self) -> float:
return self._failover_delay
@property
def strategy(self) -> AsyncFailoverStrategy:
return self._strategy
async def execute(self) -> AsyncDatabase:
try:
database = await self._strategy.database()
self._reset()
return database
except NoValidDatabaseException as e:
if self._next_attempt_ts == 0:
self._next_attempt_ts = time.time() + self._failover_delay
self._failover_counter += 1
elif time.time() >= self._next_attempt_ts:
self._next_attempt_ts += self._failover_delay
self._failover_counter += 1
if self._failover_counter > self._failover_attempts:
self._reset()
raise e
else:
raise TemporaryUnavailableException(
"No database connections currently available. "
"This is a temporary condition - please retry the operation."
)
def _reset(self) -> None:
self._next_attempt_ts = 0
self._failover_counter = 0
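
# Behavior sketch (illustrative only): the executor returns the first healthy
# database by weight; while none is healthy it raises
# TemporaryUnavailableException, and once `failover_attempts` spaced by
# `failover_delay` seconds are exhausted, NoValidDatabaseException propagates.
# `databases` is any populated Databases collection.
async def _example_select_database(databases: Databases) -> AsyncDatabase:
    import asyncio

    strategy = WeightBasedFailoverStrategy()
    strategy.set_databases(databases)
    executor = DefaultFailoverStrategyExecutor(
        strategy, failover_attempts=3, failover_delay=2.0
    )
    while True:
        try:
            return await executor.execute()
        except TemporaryUnavailableException:
            # Temporary condition - pause briefly and retry.
            await asyncio.sleep(0.5)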

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from redis.multidb.failure_detector import FailureDetector
class AsyncFailureDetector(ABC):
@abstractmethod
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
"""Register a failure that occurred during command execution."""
pass
@abstractmethod
async def register_command_execution(self, cmd: tuple) -> None:
"""Register a command execution."""
pass
@abstractmethod
def set_command_executor(self, command_executor) -> None:
"""Set the command executor for this failure."""
pass
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
"""
Async wrapper for the failure detector.
"""
def __init__(self, failure_detector: FailureDetector) -> None:
self._failure_detector = failure_detector
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
self._failure_detector.register_failure(exception, cmd)
async def register_command_execution(self, cmd: tuple) -> None:
self._failure_detector.register_command_execution(cmd)
def set_command_executor(self, command_executor) -> None:
self._failure_detector.set_command_executor(command_executor)

View File

@@ -0,0 +1,285 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Tuple, Union
from redis.asyncio import Redis
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
from redis.backoff import NoBackoff
from redis.http.http_client import HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5
DEFAULT_LAG_AWARE_TOLERANCE = 5000
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
@abstractmethod
async def check_health(self, database) -> bool:
"""Function to determine the health status."""
pass
class HealthCheckPolicy(ABC):
"""
Health checks execution policy.
"""
@property
@abstractmethod
def health_check_probes(self) -> int:
"""Number of probes to execute health checks."""
pass
@property
@abstractmethod
def health_check_delay(self) -> float:
"""Delay between health check probes."""
pass
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
"""Execute health checks and return database health status."""
pass
class AbstractHealthCheckPolicy(HealthCheckPolicy):
def __init__(self, health_check_probes: int, health_check_delay: float):
if health_check_probes < 1:
raise ValueError("health_check_probes must be greater than 0")
self._health_check_probes = health_check_probes
self._health_check_delay = health_check_delay
@property
def health_check_probes(self) -> int:
return self._health_check_probes
@property
def health_check_delay(self) -> float:
return self._health_check_delay
@abstractmethod
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
pass
class HealthyAllPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if all health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
return False
except Exception as e:
raise UnhealthyDatabaseException("Unhealthy database", database, e)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if a majority of health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
if self.health_check_probes % 2 == 0:
allowed_unsuccessful_probes = self.health_check_probes / 2
else:
allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2
for attempt in range(self.health_check_probes):
try:
if not await health_check.check_health(database):
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
return False
except Exception as e:
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
raise UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if at least one health check probe is successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
is_healthy = False
for health_check in health_checks:
exception = None
for attempt in range(self.health_check_probes):
try:
if await health_check.check_health(database):
is_healthy = True
break
else:
is_healthy = False
except Exception as e:
exception = UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
await asyncio.sleep(self._health_check_delay)
if not is_healthy and not exception:
return is_healthy
elif not is_healthy and exception:
raise exception
return is_healthy
class HealthCheckPolicies(Enum):
HEALTHY_ALL = HealthyAllPolicy
HEALTHY_MAJORITY = HealthyMajorityPolicy
HEALTHY_ANY = HealthyAnyPolicy
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
"""
Health check based on PING command.
"""
async def check_health(self, database) -> bool:
if isinstance(database.client, Redis):
return await database.client.execute_command("PING")
else:
            # For a cluster, check that all nodes are healthy.
all_nodes = database.client.get_nodes()
for node in all_nodes:
if not await node.redis_connection.execute_command("PING"):
return False
return True
class LagAwareHealthCheck(HealthCheck):
"""
Health check available for Redis Enterprise deployments.
    Verifies via the REST API that the database is healthy, taking database lag into account.
"""
def __init__(
self,
rest_api_port: int = 9443,
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
timeout: float = DEFAULT_TIMEOUT,
auth_basic: Optional[Tuple[str, str]] = None,
verify_tls: bool = True,
# TLS verification (server) options
ca_file: Optional[str] = None,
ca_path: Optional[str] = None,
ca_data: Optional[Union[str, bytes]] = None,
# Mutual TLS (client cert) options
client_cert_file: Optional[str] = None,
client_key_file: Optional[str] = None,
client_key_password: Optional[str] = None,
):
"""
Initialize LagAwareHealthCheck with the specified parameters.
Args:
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance for lag between databases in milliseconds (default: DEFAULT_LAG_AWARE_TOLERANCE)
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
auth_basic: Tuple of (username, password) for basic authentication
verify_tls: Whether to verify TLS certificates (default: True)
ca_file: Path to CA certificate file for TLS verification
ca_path: Path to CA certificates directory for TLS verification
ca_data: CA certificate data as string or bytes
client_cert_file: Path to client certificate file for mutual TLS
client_key_file: Path to client private key file for mutual TLS
client_key_password: Password for encrypted client private key
"""
self._http_client = AsyncHTTPClientWrapper(
HttpClient(
timeout=timeout,
auth_basic=auth_basic,
retry=Retry(NoBackoff(), retries=0),
verify_tls=verify_tls,
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
client_cert_file=client_cert_file,
client_key_file=client_key_file,
client_key_password=client_key_password,
)
)
self._rest_api_port = rest_api_port
self._lag_aware_tolerance = lag_aware_tolerance
async def check_health(self, database) -> bool:
if database.health_check_url is None:
raise ValueError(
"Database health check url is not set. Please check DatabaseConfig for the current database."
)
if isinstance(database.client, Redis):
db_host = database.client.get_connection_kwargs()["host"]
else:
db_host = database.client.startup_nodes[0].host
base_url = f"{database.health_check_url}:{self._rest_api_port}"
self._http_client.client.base_url = base_url
# Find bdb matching to the current database host
matching_bdb = None
for bdb in await self._http_client.get("/v1/bdbs"):
for endpoint in bdb["endpoints"]:
if endpoint["dns_name"] == db_host:
matching_bdb = bdb
break
                # In case the host was set as a public IP
for addr in endpoint["addr"]:
if addr == db_host:
matching_bdb = bdb
break
if matching_bdb is None:
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
raise ValueError("Could not find a matching bdb")
url = (
f"/v1/bdbs/{matching_bdb['uid']}/availability"
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
)
await self._http_client.get(url, expect_json=False)
        # The status is checked by the HTTP client; an HttpError is raised otherwise
return True
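
# Usage sketch (illustrative only): probing a database three times with the
# majority policy and the default PING health check. `database` is any
# Database entry from redis.asyncio.multidb.database; UnhealthyDatabaseException
# wraps the underlying error when a probe itself raises.
async def _example_is_healthy(database) -> bool:
    policy = HealthyMajorityPolicy(
        health_check_probes=3, health_check_delay=DEFAULT_HEALTH_CHECK_DELAY
    )
    try:
        return await policy.execute([PingHealthCheck()], database)
    except UnhealthyDatabaseException:
        return False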

View File

@@ -0,0 +1,58 @@
from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
T = TypeVar("T")
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class Retry(AbstractRetry[RedisError]):
__hash__ = AbstractRetry.__hash__
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[RedisError], ...] = (
ConnectionError,
TimeoutError,
),
):
super().__init__(backoff, retries, supported_errors)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)
async def call_with_retry(
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
) -> T:
"""
        Execute an operation that might fail and return its result, or
        raise the exception that was thrown, depending on the `Backoff` object.
        `do`: the operation to call. Expects no argument.
        `fail`: the failure handler; expects the last error that was thrown.
"""
self._backoff.reset()
failures = 0
while True:
try:
return await do()
except self._supported_errors as error:
failures += 1
await fail(error)
if self._retries >= 0 and failures > self._retries:
raise error
backoff = self._backoff.compute(failures)
if backoff > 0:
await sleep(backoff)
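
# Usage sketch (illustrative only): retrying a PING up to three times with an
# exponential backoff with jitter, logging each failure before the next
# attempt. The host is a placeholder.
async def _example_ping_with_retry() -> bool:
    from redis.asyncio import Redis
    from redis.backoff import ExponentialWithJitterBackoff

    client = Redis(host="redis.example.com", port=6379)
    retry = Retry(backoff=ExponentialWithJitterBackoff(base=0.1, cap=2), retries=3)

    async def _on_failure(error: RedisError) -> None:
        # Called with the last error before the next attempt (or the final raise).
        print(f"PING failed: {error!r}")

    return await retry.call_with_retry(lambda: client.ping(), _on_failure)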

View File

@@ -0,0 +1,404 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
from redis.asyncio.client import Redis
from redis.asyncio.connection import (
Connection,
ConnectionPool,
EncodableT,
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
class MasterNotFoundError(ConnectionError):
pass
class SlaveNotFoundError(ConnectionError):
pass
class SentinelManagedConnection(Connection):
def __init__(self, **kwargs):
self.connection_pool = kwargs.pop("connection_pool")
super().__init__(**kwargs)
def __repr__(self):
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
if self.host:
host_info = f",host={self.host},port={self.port}"
s += host_info
return s + ")>"
async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
async def _connect_retry(self):
if self._reader:
return # already connected
if self.connection_pool.is_master:
await self.connect_to(await self.connection_pool.get_master_address())
else:
async for slave in self.connection_pool.rotate_slaves():
try:
return await self.connect_to(slave)
except ConnectionError:
continue
            raise SlaveNotFoundError  # Should never reach this point
async def connect(self):
return await self.retry.call_with_retry(
self._connect_retry,
lambda error: asyncio.sleep(0),
)
async def read_response(
self,
disable_decoding: bool = False,
timeout: Optional[float] = None,
*,
        disconnect_on_error: Optional[bool] = True,
push_request: Optional[bool] = False,
):
try:
return await super().read_response(
disable_decoding=disable_decoding,
timeout=timeout,
disconnect_on_error=disconnect_on_error,
push_request=push_request,
)
except ReadOnlyError:
if self.connection_pool.is_master:
                # When talking to a master, a ReadOnlyError likely indicates
                # that the previous master we're still connected to has been
                # demoted to a slave and there's a new master. Calling
                # disconnect will force the connection to re-query
# sentinel during the next connect() attempt.
await self.disconnect()
raise ConnectionError("The previous master is now a slave")
raise
class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
pass
class SentinelConnectionPool(ConnectionPool):
"""
Sentinel backed connection pool.
If ``check_connection`` flag is set to True, SentinelManagedConnection
sends a PING command right after establishing the connection.
"""
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
self.service_name = service_name
self.sentinel_manager = sentinel_manager
self.master_address = None
self.slave_rr_counter = None
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
)
def reset(self):
super().reset()
self.master_address = None
self.slave_rr_counter = None
def owns_connection(self, connection: Connection):
check = not self.is_master or (
self.is_master and self.master_address == (connection.host, connection.port)
)
return check and super().owns_connection(connection)
async def get_master_address(self):
master_address = await self.sentinel_manager.discover_master(self.service_name)
if self.is_master:
if self.master_address != master_address:
self.master_address = master_address
# disconnect any idle connections so that they reconnect
# to the new master the next time that they are used.
await self.disconnect(inuse_connections=False)
return master_address
async def rotate_slaves(self) -> AsyncIterator:
"""Round-robin slave balancer"""
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in range(len(slaves)):
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield await self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
class Sentinel(AsyncSentinelCommands):
"""
Redis Sentinel cluster client
>>> from redis.sentinel import Sentinel
>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
>>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
>>> await master.set('foo', 'bar')
>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
>>> await slave.get('foo')
b'bar'
``sentinels`` is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
    ``min_other_sentinels`` defines the minimum number of peers for a sentinel.
When querying a sentinel, if it doesn't meet this threshold, responses
from that sentinel won't be considered valid.
``sentinel_kwargs`` is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
not specified, any socket_timeout and socket_keepalive options specified
in ``connection_kwargs`` will be used.
``connection_kwargs`` are keyword arguments that will be used when
establishing a connection to a Redis server.
"""
def __init__(
self,
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
# connection_kwargs
if sentinel_kwargs is None:
sentinel_kwargs = {
k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
}
self.sentinel_kwargs = sentinel_kwargs
self.sentinels = [
Redis(host=hostname, port=port, **self.sentinel_kwargs)
for hostname, port in sentinels
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
async def execute_command(self, *args, **kwargs):
"""
Execute Sentinel command in sentinel nodes.
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
if once:
response = await random.choice(self.sentinels).execute_command(
*args, **kwargs
)
if return_responses:
return [response]
else:
return True if response else False
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
responses = await asyncio.gather(*tasks)
if return_responses:
return responses
return all(responses)
def __repr__(self):
sentinel_addresses = []
for sentinel in self.sentinels:
sentinel_addresses.append(
f"{sentinel.connection_pool.connection_kwargs['host']}:"
f"{sentinel.connection_pool.connection_kwargs['port']}"
)
return (
f"<{self.__class__}.{self.__class__.__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
def check_master_state(self, state: dict, service_name: str) -> bool:
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
return False
# Check if our sentinel doesn't see other nodes
if state["num-other-sentinels"] < self.min_other_sentinels:
return False
return True
async def discover_master(self, service_name: str):
"""
Asks sentinel servers for the Redis master's address corresponding
to the service labeled ``service_name``.
Returns a pair (address, port) or raises MasterNotFoundError if no
master is found.
"""
collected_errors = list()
for sentinel_no, sentinel in enumerate(self.sentinels):
try:
masters = await sentinel.sentinel_masters()
except (ConnectionError, TimeoutError) as e:
collected_errors.append(f"{sentinel} - {e!r}")
continue
state = masters.get(service_name)
if state and self.check_master_state(state, service_name):
# Put this sentinel at the top of the list
self.sentinels[0], self.sentinels[sentinel_no] = (
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
error_info = ""
if len(collected_errors) > 0:
error_info = f" : {', '.join(collected_errors)}"
raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")
def filter_slaves(
self, slaves: Iterable[Mapping]
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Remove slaves that are in an ODOWN or SDOWN state"""
slaves_alive = []
for slave in slaves:
if slave["is_odown"] or slave["is_sdown"]:
continue
slaves_alive.append((slave["ip"], slave["port"]))
return slaves_alive
async def discover_slaves(
self, service_name: str
) -> Sequence[Tuple[EncodableT, EncodableT]]:
"""Returns a list of alive slaves for service ``service_name``"""
for sentinel in self.sentinels:
try:
slaves = await sentinel.sentinel_slaves(service_name)
except (ConnectionError, ResponseError, TimeoutError):
continue
slaves = self.filter_slaves(slaves)
if slaves:
return slaves
return []
def master_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new
connection.
NOTE: If the master's address has changed, any cached connections to
the old master are closed.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to
use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = True
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)
def slave_for(
self,
service_name: str,
redis_class: Type[Redis] = Redis,
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
**kwargs,
):
"""
Returns redis client instance for the ``service_name`` slave(s).
A SentinelConnectionPool class is used to retrieve the slave's
address before establishing a new connection.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to use.
The SentinelConnectionPool will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = False
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
# The Redis object "owns" the pool
return redis_class.from_pool(connection_pool)
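
# Usage sketch (illustrative only): querying the monitored topology directly,
# without building a client. Sentinel addresses and the service name are
# placeholders.
async def _example_discover(service_name: str = "mymaster"):
    sentinel = Sentinel(
        [("sentinel-1.example.com", 26379), ("sentinel-2.example.com", 26379)],
        socket_timeout=0.5,
    )
    master_addr = await sentinel.discover_master(service_name)
    replica_addrs = await sentinel.discover_slaves(service_name)
    return master_addr, replica_addrs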

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis.asyncio.client import Pipeline, Redis
def from_url(url, **kwargs):
"""
Returns an active Redis client generated from the given database URL.
    Will attempt to extract the database id from the URL path, if none is
    provided.
"""
from redis.asyncio.client import Redis
return Redis.from_url(url, **kwargs)
class pipeline: # noqa: N801
def __init__(self, redis_obj: "Redis"):
self.p: "Pipeline" = redis_obj.pipeline()
async def __aenter__(self) -> "Pipeline":
return self.p
async def __aexit__(self, exc_type, exc_value, traceback):
await self.p.execute()
del self.p
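
# Usage sketch (illustrative only): the `pipeline` helper queues commands
# inside an `async with` block and executes them in one round trip on exit.
# The URL is a placeholder.
async def _example_buffered_writes() -> None:
    client = from_url("redis://redis.example.com:6379")
    async with pipeline(client) as pipe:
        pipe.set("counter", 0)
        pipe.incr("counter")
    # Both commands were sent together when the block exited.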

View File

@@ -0,0 +1,31 @@
from typing import Iterable
class RequestTokenErr(Exception):
"""
Represents an exception during token request.
"""
def __init__(self, *args):
super().__init__(*args)
class InvalidTokenSchemaErr(Exception):
"""
Represents an exception related to invalid token schema.
"""
def __init__(self, missing_fields: Iterable[str] = []):
super().__init__(
"Unexpected token schema. Following fields are missing: "
+ ", ".join(missing_fields)
)
class TokenRenewalErr(Exception):
"""
Represents an exception during token renewal process.
"""
def __init__(self, *args):
super().__init__(*args)

View File

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from redis.auth.token import TokenInterface
"""
This interface is the facade of an identity provider
"""
class IdentityProviderInterface(ABC):
"""
    Receives a token from the identity provider.
    Receiving a token only works when authenticated.
"""
@abstractmethod
def request_token(self, force_refresh=False) -> TokenInterface:
pass
class IdentityProviderConfigInterface(ABC):
"""
Configuration class that provides a configured identity provider.
"""
@abstractmethod
def get_provider(self) -> IdentityProviderInterface:
pass
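
# Implementation sketch (illustrative only): a static provider that returns a
# SimpleToken with a fixed lifetime. SimpleToken lives in redis.auth.token;
# the one-hour lifetime and the claims below are arbitrary placeholder values.
class _StaticTokenProvider(IdentityProviderInterface):
    def __init__(self, value: str, ttl_ms: float = 3_600_000):
        self._value = value
        self._ttl_ms = ttl_ms

    def request_token(self, force_refresh=False) -> TokenInterface:
        from datetime import datetime, timezone

        from redis.auth.token import SimpleToken

        now_ms = datetime.now(timezone.utc).timestamp() * 1000
        return SimpleToken(
            value=self._value,
            expires_at_ms=now_ms + self._ttl_ms,
            received_at_ms=now_ms,
            claims={"sub": "example-user"},
        )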

View File

@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from redis.auth.err import InvalidTokenSchemaErr
class TokenInterface(ABC):
@abstractmethod
def is_expired(self) -> bool:
pass
@abstractmethod
def ttl(self) -> float:
pass
@abstractmethod
def try_get(self, key: str) -> str:
pass
@abstractmethod
def get_value(self) -> str:
pass
@abstractmethod
def get_expires_at_ms(self) -> float:
pass
@abstractmethod
def get_received_at_ms(self) -> float:
pass
class TokenResponse:
def __init__(self, token: TokenInterface):
self._token = token
def get_token(self) -> TokenInterface:
return self._token
def get_ttl_ms(self) -> float:
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
class SimpleToken(TokenInterface):
def __init__(
self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
) -> None:
self.value = value
self.expires_at = expires_at_ms
self.received_at = received_at_ms
self.claims = claims
def ttl(self) -> float:
if self.expires_at == -1:
return -1
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
def is_expired(self) -> bool:
if self.expires_at == -1:
return False
return self.ttl() <= 0
def try_get(self, key: str) -> str:
return self.claims.get(key)
def get_value(self) -> str:
return self.value
def get_expires_at_ms(self) -> float:
return self.expires_at
def get_received_at_ms(self) -> float:
return self.received_at
class JWToken(TokenInterface):
REQUIRED_FIELDS = {"exp"}
def __init__(self, token: str):
try:
import jwt
except ImportError as ie:
raise ImportError(
f"The PyJWT library is required for {self.__class__.__name__}.",
) from ie
self._value = token
self._decoded = jwt.decode(
self._value,
options={"verify_signature": False},
algorithms=[jwt.get_unverified_header(self._value).get("alg")],
)
self._validate_token()
def is_expired(self) -> bool:
exp = self._decoded["exp"]
if exp == -1:
return False
return (
self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
)
def ttl(self) -> float:
exp = self._decoded["exp"]
if exp == -1:
return -1
return (
self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
)
def try_get(self, key: str) -> str:
return self._decoded.get(key)
def get_value(self) -> str:
return self._value
def get_expires_at_ms(self) -> float:
return float(self._decoded["exp"] * 1000)
def get_received_at_ms(self) -> float:
return datetime.now(timezone.utc).timestamp() * 1000
def _validate_token(self):
actual_fields = {x for x in self._decoded.keys()}
if len(self.REQUIRED_FIELDS - actual_fields) != 0:
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)
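
# Usage sketch (illustrative only): a SimpleToken that expires in five minutes,
# with its remaining TTL inspected. The value and claims are placeholders.
def _example_token() -> SimpleToken:
    now_ms = datetime.now(timezone.utc).timestamp() * 1000
    token = SimpleToken(
        value="opaque-credential",
        expires_at_ms=now_ms + 5 * 60 * 1000,
        received_at_ms=now_ms,
        claims={"sub": "example-user"},
    )
    assert not token.is_expired()
    print(f"Token expires in {token.ttl() / 1000:.0f} seconds")
    return token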

View File

@@ -0,0 +1,370 @@
import asyncio
import logging
import threading
from datetime import datetime, timezone
from time import sleep
from typing import Any, Awaitable, Callable, Union
from redis.auth.err import RequestTokenErr, TokenRenewalErr
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import TokenResponse
logger = logging.getLogger(__name__)
class CredentialsListener:
"""
    Listener that is notified of events related to credentials.
Accepts callbacks and awaitable callbacks.
"""
def __init__(self):
self._on_next = None
self._on_error = None
@property
def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
return self._on_next
@on_next.setter
def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
self._on_next = callback
@property
def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
return self._on_error
@on_error.setter
def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
self._on_error = callback
class RetryPolicy:
def __init__(self, max_attempts: int, delay_in_ms: float):
self.max_attempts = max_attempts
self.delay_in_ms = delay_in_ms
def get_max_attempts(self) -> int:
"""
        Number of retry attempts before an exception is raised.
:return: int
"""
return self.max_attempts
def get_delay_in_ms(self) -> float:
"""
        Delay between retries in milliseconds.
        :return: float
"""
return self.delay_in_ms
class TokenManagerConfig:
def __init__(
self,
expiration_refresh_ratio: float,
lower_refresh_bound_millis: int,
token_request_execution_timeout_in_ms: int,
retry_policy: RetryPolicy,
):
self._expiration_refresh_ratio = expiration_refresh_ratio
self._lower_refresh_bound_millis = lower_refresh_bound_millis
self._token_request_execution_timeout_in_ms = (
token_request_execution_timeout_in_ms
)
self._retry_policy = retry_policy
def get_expiration_refresh_ratio(self) -> float:
"""
Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501
For example, a value of 0.75 means the token should be refreshed
when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).
:return: float
"""
return self._expiration_refresh_ratio
def get_lower_refresh_bound_millis(self) -> int:
"""
        Represents the minimum time in milliseconds before token expiration
        at which a refresh should be triggered.
This value sets a fixed lower bound for when a token refresh should occur,
regardless of the token's total lifetime.
If set to 0 there will be no lower bound and the refresh will be triggered
based on the expirationRefreshRatio only.
:return: int
"""
return self._lower_refresh_bound_millis
def get_token_request_execution_timeout_in_ms(self) -> int:
"""
Represents the maximum time in milliseconds to wait
for a token request to complete.
:return: int
"""
return self._token_request_execution_timeout_in_ms
def get_retry_policy(self) -> RetryPolicy:
"""
Represents the retry policy for token requests.
:return: RetryPolicy
"""
return self._retry_policy
class TokenManager:
def __init__(
self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
):
self._idp = identity_provider
self._config = config
self._next_timer = None
self._listener = None
self._init_timer = None
self._retries = 0
def __del__(self):
        logger.info("Token manager is disposed")
self.stop()
def start(
self,
listener: CredentialsListener,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=_start_event_loop_in_thread, args=(loop,), daemon=True
)
thread.start()
# Event to block for initial execution.
init_event = asyncio.Event()
self._init_timer = loop.call_later(
0, self._renew_token, skip_initial, init_event
)
logger.info("Token manager started")
# Blocks in thread-safe manner.
asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
return self.stop
async def start_async(
self,
listener: CredentialsListener,
block_for_initial: bool = False,
initial_delay_in_ms: float = 0,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
loop = asyncio.get_running_loop()
init_event = asyncio.Event()
        # Wrap the async callback with a sync wrapper so it can be scheduled with loop.call_later()
wrapped = _async_to_sync_wrapper(
loop, self._renew_token_async, skip_initial, init_event
)
self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
logger.info("Token manager started")
if block_for_initial:
await init_event.wait()
return self.stop
def stop(self):
if self._init_timer is not None:
self._init_timer.cancel()
if self._next_timer is not None:
self._next_timer.cancel()
def acquire_token(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
return self.acquire_token(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
await asyncio.sleep(
self._config.get_retry_policy().get_delay_in_ms() / 1000
)
return await self.acquire_token_async(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)
return 0 if delay < 0 else delay / 1000
def _delay_for_lower_refresh(self, expire_date: float):
return (
expire_date
- self._config.get_lower_refresh_bound_millis()
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
token_ttl = expire_date - issue_date
refresh_before = token_ttl - (
token_ttl * self._config.get_expiration_refresh_ratio()
)
return (
expire_date
- refresh_before
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _renew_token(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Task to renew token from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = self.acquire_token(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
self._next_timer = loop.call_later(delay, self._renew_token)
logger.info(f"Next token renewal scheduled in {delay} seconds")
return token_res
except Exception as e:
if self._listener.on_error is None:
raise e
self._listener.on_error(e)
finally:
if init_event:
init_event.set()
async def _renew_token_async(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Async task to renew tokens from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = await self.acquire_token_async(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
await self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
logger.info(f"Next token renewal scheduled in {delay} seconds")
loop.call_later(delay, wrapped)
except Exception as e:
if self._listener.on_error is None:
raise e
await self._listener.on_error(e)
finally:
if init_event:
init_event.set()
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
"""
Wraps an asynchronous function so it can be used with loop.call_later.
:param loop: The event loop in which the coroutine will be executed.
:param coro_func: The coroutine function to wrap.
:param args: Positional arguments to pass to the coroutine function.
:param kwargs: Keyword arguments to pass to the coroutine function.
:return: A regular function suitable for loop.call_later.
"""
def wrapped():
# Schedule the coroutine in the event loop
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
return wrapped
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
"""
    Starts an event loop in a thread.
Used to be able to schedule tasks using loop.call_later.
:param event_loop:
:return:
"""
asyncio.set_event_loop(event_loop)
event_loop.run_forever()
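
# Usage sketch (illustrative only): driving credential refresh with a
# TokenManager. With expiration_refresh_ratio=0.8 the token is renewed once
# roughly 80% of its lifetime has elapsed, subject to the 60-second lower
# refresh bound. `provider` is any IdentityProviderInterface implementation;
# the callbacks below just log.
async def _example_token_manager(
    provider: IdentityProviderInterface,
) -> Callable[[], None]:
    config = TokenManagerConfig(
        expiration_refresh_ratio=0.8,
        lower_refresh_bound_millis=60_000,
        token_request_execution_timeout_in_ms=1_000,
        retry_policy=RetryPolicy(max_attempts=3, delay_in_ms=200),
    )
    manager = TokenManager(provider, config)
    listener = CredentialsListener()

    async def _on_next(token) -> None:
        print("received token expiring at", token.get_expires_at_ms())

    async def _on_error(error: Exception) -> None:
        print("token renewal failed:", error)

    listener.on_next = _on_next
    listener.on_error = _on_error
    # start_async returns a callable that cancels the scheduled renewals.
    return await manager.start_async(listener, block_for_initial=True)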

View File

@@ -0,0 +1,204 @@
import asyncio
import threading
from typing import Any, Callable, Coroutine
class BackgroundScheduler:
"""
    Schedules background task execution either in a separate thread or in the running event loop.
"""
def __init__(self):
self._next_timer = None
self._event_loops = []
self._lock = threading.Lock()
self._stopped = False
def __del__(self):
self.stop()
def stop(self):
"""
Stop all scheduled tasks and clean up resources.
"""
with self._lock:
if self._stopped:
return
self._stopped = True
if self._next_timer:
self._next_timer.cancel()
self._next_timer = None
# Stop all event loops
for loop in self._event_loops:
if loop.is_running():
loop.call_soon_threadsafe(loop.stop)
self._event_loops.clear()
def run_once(self, delay: float, callback: Callable, *args):
"""
        Runs a callable task once after a given delay in seconds.
"""
with self._lock:
if self._stopped:
return
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
with self._lock:
self._event_loops.append(loop)
thread = threading.Thread(
target=_start_event_loop_in_thread,
args=(loop, self._call_later, delay, callback, *args),
daemon=True,
)
thread.start()
def run_recurring(self, interval: float, callback: Callable, *args):
"""
        Runs a recurring callable task with the given interval in seconds.
"""
with self._lock:
if self._stopped:
return
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
with self._lock:
self._event_loops.append(loop)
thread = threading.Thread(
target=_start_event_loop_in_thread,
args=(loop, self._call_later_recurring, interval, callback, *args),
daemon=True,
)
thread.start()
async def run_recurring_async(
self, interval: float, coro: Callable[..., Coroutine[Any, Any, Any]], *args
):
"""
        Runs a recurring coroutine with the given interval in seconds in the current event loop.
To be used only from an async context. No additional threads are created.
"""
with self._lock:
if self._stopped:
return
loop = asyncio.get_running_loop()
wrapped = _async_to_sync_wrapper(loop, coro, *args)
def tick():
with self._lock:
if self._stopped:
return
# Schedule the coroutine
wrapped()
# Schedule next tick
self._next_timer = loop.call_later(interval, tick)
# Schedule first tick
self._next_timer = loop.call_later(interval, tick)
def _call_later(
self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args
):
with self._lock:
if self._stopped:
return
self._next_timer = loop.call_later(delay, callback, *args)
def _call_later_recurring(
self,
loop: asyncio.AbstractEventLoop,
interval: float,
callback: Callable,
*args,
):
with self._lock:
if self._stopped:
return
self._call_later(
loop, interval, self._execute_recurring, loop, interval, callback, *args
)
def _execute_recurring(
self,
loop: asyncio.AbstractEventLoop,
interval: float,
callback: Callable,
*args,
):
"""
Executes recurring callable task with given interval in seconds.
"""
with self._lock:
if self._stopped:
return
try:
callback(*args)
except Exception:
# Silently ignore exceptions during shutdown
pass
with self._lock:
if self._stopped:
return
self._call_later(
loop, interval, self._execute_recurring, loop, interval, callback, *args
)
def _start_event_loop_in_thread(
event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args
):
"""
Starts the given event loop in the current thread and schedules the callback as soon as the loop is ready.
Used to be able to schedule tasks using loop.call_later.
:param event_loop: The event loop to install and run in this thread.
:param call_soon_cb: Callback scheduled on the loop as soon as it starts running.
"""
asyncio.set_event_loop(event_loop)
event_loop.call_soon(call_soon_cb, event_loop, *args)
try:
event_loop.run_forever()
finally:
try:
# Clean up pending tasks
pending = asyncio.all_tasks(event_loop)
for task in pending:
task.cancel()
# Run loop once more to process cancellations
event_loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception:
pass
finally:
event_loop.close()
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
"""
Wraps an asynchronous function so it can be used with loop.call_later.
:param loop: The event loop in which the coroutine will be executed.
:param coro_func: The coroutine function to wrap.
:param args: Positional arguments to pass to the coroutine function.
:param kwargs: Keyword arguments to pass to the coroutine function.
:return: A regular function suitable for loop.call_later.
"""
def wrapped():
# Schedule the coroutine in the event loop
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
return wrapped
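# --- Illustrative usage sketch (not part of the original module) ---
# A hedged example of BackgroundScheduler.run_recurring: the callback fires
# roughly every 0.2 seconds on a daemon-thread event loop until stop() is
# called. The callback name and timings are hypothetical.
if __name__ == "__main__":
    import time

    def _heartbeat(name):
        print(f"tick from {name}")

    scheduler = BackgroundScheduler()
    scheduler.run_recurring(0.2, _heartbeat, "demo")
    time.sleep(1)     # let a few ticks fire on the background thread
    scheduler.stop()  # cancels pending timers and stops the background loops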

View File

@@ -0,0 +1,183 @@
import random
from abc import ABC, abstractmethod
# Maximum backoff between each retry in seconds
DEFAULT_CAP = 0.512
# Minimum backoff between each retry in seconds
DEFAULT_BASE = 0.008
class AbstractBackoff(ABC):
"""Backoff interface"""
def reset(self):
"""
Reset internal state before an operation.
`reset` is called once at the beginning of
every call to `Retry.call_with_retry`
"""
pass
@abstractmethod
def compute(self, failures: int) -> float:
"""Compute backoff in seconds upon failure"""
pass
class ConstantBackoff(AbstractBackoff):
"""Constant backoff upon failure"""
def __init__(self, backoff: float) -> None:
"""`backoff`: backoff time in seconds"""
self._backoff = backoff
def __hash__(self) -> int:
return hash((self._backoff,))
def __eq__(self, other) -> bool:
if not isinstance(other, ConstantBackoff):
return NotImplemented
return self._backoff == other._backoff
def compute(self, failures: int) -> float:
return self._backoff
class NoBackoff(ConstantBackoff):
"""No backoff upon failure"""
def __init__(self) -> None:
super().__init__(0)
class ExponentialBackoff(AbstractBackoff):
"""Exponential backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return min(self._cap, self._base * 2**failures)
class FullJitterBackoff(AbstractBackoff):
"""Full jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, FullJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return random.uniform(0, min(self._cap, self._base * 2**failures))
class EqualJitterBackoff(AbstractBackoff):
"""Equal jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
temp = min(self._cap, self._base * 2**failures) / 2
return temp + random.uniform(0, temp)
class DecorrelatedJitterBackoff(AbstractBackoff):
"""Decorrelated jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
self._previous_backoff = 0
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, DecorrelatedJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def reset(self) -> None:
self._previous_backoff = 0
def compute(self, failures: int) -> float:
max_backoff = max(self._base, self._previous_backoff * 3)
temp = random.uniform(self._base, max_backoff)
self._previous_backoff = min(self._cap, temp)
return self._previous_backoff
class ExponentialWithJitterBackoff(AbstractBackoff):
"""Exponential backoff upon failure, with jitter"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialWithJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return min(self._cap, random.random() * self._base * 2**failures)
def default_backoff():
return EqualJitterBackoff()
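# --- Illustrative usage sketch (not part of the original module) ---
# A hedged comparison of the delay (in seconds) each policy would produce for
# the first few consecutive failures. Values for the jittered policies are
# random by design; the 0.1s constant delay is an arbitrary example.
if __name__ == "__main__":
    policies = {
        "constant": ConstantBackoff(0.1),
        "exponential": ExponentialBackoff(),
        "full_jitter": FullJitterBackoff(),
        "equal_jitter": default_backoff(),  # EqualJitterBackoff()
    }
    for name, policy in policies.items():
        policy.reset()  # reset internal state before a new operation
        delays = [round(policy.compute(failures), 4) for failures in range(5)]
        print(name, delays)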

View File

@@ -0,0 +1,402 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union
class CacheEntryStatus(Enum):
VALID = "VALID"
IN_PROGRESS = "IN_PROGRESS"
class EvictionPolicyType(Enum):
time_based = "time_based"
frequency_based = "frequency_based"
@dataclass(frozen=True)
class CacheKey:
command: str
redis_keys: tuple
class CacheEntry:
def __init__(
self,
cache_key: CacheKey,
cache_value: bytes,
status: CacheEntryStatus,
connection_ref,
):
self.cache_key = cache_key
self.cache_value = cache_value
self.status = status
self.connection_ref = connection_ref
def __hash__(self):
return hash(
(self.cache_key, self.cache_value, self.status, self.connection_ref)
)
def __eq__(self, other):
return hash(self) == hash(other)
class EvictionPolicyInterface(ABC):
@property
@abstractmethod
def cache(self):
pass
@cache.setter
@abstractmethod
def cache(self, value):
pass
@property
@abstractmethod
def type(self) -> EvictionPolicyType:
pass
@abstractmethod
def evict_next(self) -> CacheKey:
pass
@abstractmethod
def evict_many(self, count: int) -> List[CacheKey]:
pass
@abstractmethod
def touch(self, cache_key: CacheKey) -> None:
pass
class CacheConfigurationInterface(ABC):
@abstractmethod
def get_cache_class(self):
pass
@abstractmethod
def get_max_size(self) -> int:
pass
@abstractmethod
def get_eviction_policy(self):
pass
@abstractmethod
def is_exceeds_max_size(self, count: int) -> bool:
pass
@abstractmethod
def is_allowed_to_cache(self, command: str) -> bool:
pass
class CacheInterface(ABC):
@property
@abstractmethod
def collection(self) -> OrderedDict:
pass
@property
@abstractmethod
def config(self) -> CacheConfigurationInterface:
pass
@property
@abstractmethod
def eviction_policy(self) -> EvictionPolicyInterface:
pass
@property
@abstractmethod
def size(self) -> int:
pass
@abstractmethod
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
pass
@abstractmethod
def set(self, entry: CacheEntry) -> bool:
pass
@abstractmethod
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
pass
@abstractmethod
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
pass
@abstractmethod
def flush(self) -> int:
pass
@abstractmethod
def is_cachable(self, key: CacheKey) -> bool:
pass
class DefaultCache(CacheInterface):
def __init__(
self,
cache_config: CacheConfigurationInterface,
) -> None:
self._cache = OrderedDict()
self._cache_config = cache_config
self._eviction_policy = self._cache_config.get_eviction_policy().value()
self._eviction_policy.cache = self
@property
def collection(self) -> OrderedDict:
return self._cache
@property
def config(self) -> CacheConfigurationInterface:
return self._cache_config
@property
def eviction_policy(self) -> EvictionPolicyInterface:
return self._eviction_policy
@property
def size(self) -> int:
return len(self._cache)
def set(self, entry: CacheEntry) -> bool:
if not self.is_cachable(entry.cache_key):
return False
self._cache[entry.cache_key] = entry
self._eviction_policy.touch(entry.cache_key)
if self._cache_config.is_exceeds_max_size(len(self._cache)):
self._eviction_policy.evict_next()
return True
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
entry = self._cache.get(key, None)
if entry is None:
return None
self._eviction_policy.touch(key)
return entry
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
response = []
for key in cache_keys:
if self.get(key) is not None:
self._cache.pop(key)
response.append(True)
else:
response.append(False)
return response
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
response = []
keys_to_delete = []
for redis_key in redis_keys:
if isinstance(redis_key, bytes):
redis_key = redis_key.decode()
for cache_key in self._cache:
if redis_key in cache_key.redis_keys:
keys_to_delete.append(cache_key)
response.append(True)
for key in keys_to_delete:
self._cache.pop(key)
return response
def flush(self) -> int:
elem_count = len(self._cache)
self._cache.clear()
return elem_count
def is_cachable(self, key: CacheKey) -> bool:
return self._cache_config.is_allowed_to_cache(key.command)
class LRUPolicy(EvictionPolicyInterface):
def __init__(self):
self.cache = None
@property
def cache(self):
return self._cache
@cache.setter
def cache(self, cache: CacheInterface):
self._cache = cache
@property
def type(self) -> EvictionPolicyType:
return EvictionPolicyType.time_based
def evict_next(self) -> CacheKey:
self._assert_cache()
popped_entry = self._cache.collection.popitem(last=False)
return popped_entry[0]
def evict_many(self, count: int) -> List[CacheKey]:
self._assert_cache()
if count > len(self._cache.collection):
raise ValueError("Evictions count is above cache size")
popped_keys = []
for _ in range(count):
popped_entry = self._cache.collection.popitem(last=False)
popped_keys.append(popped_entry[0])
return popped_keys
def touch(self, cache_key: CacheKey) -> None:
self._assert_cache()
if self._cache.collection.get(cache_key) is None:
raise ValueError("Given entry does not belong to the cache")
self._cache.collection.move_to_end(cache_key)
def _assert_cache(self):
if self.cache is None or not isinstance(self.cache, CacheInterface):
raise ValueError("Eviction policy should be associated with valid cache.")
class EvictionPolicy(Enum):
LRU = LRUPolicy
class CacheConfig(CacheConfigurationInterface):
DEFAULT_CACHE_CLASS = DefaultCache
DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
DEFAULT_MAX_SIZE = 10000
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
"EXISTS",
"GEODIST",
"GEOHASH",
"GEOPOS",
"GEORADIUSBYMEMBER_RO",
"GEORADIUS_RO",
"GEOSEARCH",
"GET",
"GETBIT",
"GETRANGE",
"HEXISTS",
"HGET",
"HGETALL",
"HKEYS",
"HLEN",
"HMGET",
"HSTRLEN",
"HVALS",
"JSON.ARRINDEX",
"JSON.ARRLEN",
"JSON.GET",
"JSON.MGET",
"JSON.OBJKEYS",
"JSON.OBJLEN",
"JSON.RESP",
"JSON.STRLEN",
"JSON.TYPE",
"LCS",
"LINDEX",
"LLEN",
"LPOS",
"LRANGE",
"MGET",
"SCARD",
"SDIFF",
"SINTER",
"SINTERCARD",
"SISMEMBER",
"SMEMBERS",
"SMISMEMBER",
"SORT_RO",
"STRLEN",
"SUBSTR",
"SUNION",
"TS.GET",
"TS.INFO",
"TS.RANGE",
"TS.REVRANGE",
"TYPE",
"XLEN",
"XPENDING",
"XRANGE",
"XREAD",
"XREVRANGE",
"ZCARD",
"ZCOUNT",
"ZDIFF",
"ZINTER",
"ZINTERCARD",
"ZLEXCOUNT",
"ZMSCORE",
"ZRANGE",
"ZRANGEBYLEX",
"ZRANGEBYSCORE",
"ZRANK",
"ZREVRANGE",
"ZREVRANGEBYLEX",
"ZREVRANGEBYSCORE",
"ZREVRANK",
"ZSCORE",
"ZUNION",
]
def __init__(
self,
max_size: int = DEFAULT_MAX_SIZE,
cache_class: Any = DEFAULT_CACHE_CLASS,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
):
self._cache_class = cache_class
self._max_size = max_size
self._eviction_policy = eviction_policy
def get_cache_class(self):
return self._cache_class
def get_max_size(self) -> int:
return self._max_size
def get_eviction_policy(self) -> EvictionPolicy:
return self._eviction_policy
def is_exceeds_max_size(self, count: int) -> bool:
return count > self._max_size
def is_allowed_to_cache(self, command: str) -> bool:
return command in self.DEFAULT_ALLOW_LIST
class CacheFactoryInterface(ABC):
@abstractmethod
def get_cache(self) -> CacheInterface:
pass
class CacheFactory(CacheFactoryInterface):
def __init__(self, cache_config: Optional[CacheConfig] = None):
self._config = cache_config
if self._config is None:
self._config = CacheConfig()
def get_cache(self) -> CacheInterface:
cache_class = self._config.get_cache_class()
return cache_class(cache_config=self._config)
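# --- Illustrative usage sketch (not part of the original module) ---
# A hedged example wiring CacheConfig, CacheFactory and DefaultCache together.
# The connection_ref would normally be a live connection object; None is used
# here purely for illustration, and the key names are hypothetical.
if __name__ == "__main__":
    cache = CacheFactory(CacheConfig(max_size=100)).get_cache()

    key = CacheKey(command="GET", redis_keys=("user:1",))
    entry = CacheEntry(
        cache_key=key,
        cache_value=b"alice",
        status=CacheEntryStatus.VALID,
        connection_ref=None,
    )
    cache.set(entry)                    # "GET" is in the default allow list
    print(cache.get(key).cache_value)   # b'alice'
    print(cache.delete_by_redis_keys(["user:1"]))  # [True]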

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands
from .core import AsyncCoreCommands, CoreCommands
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .sentinel import AsyncSentinelCommands, SentinelCommands
__all__ = [
"AsyncCoreCommands",
"AsyncRedisClusterCommands",
"AsyncRedisModuleCommands",
"AsyncSentinelCommands",
"CoreCommands",
"READ_COMMANDS",
"RedisClusterCommands",
"RedisModuleCommands",
"SentinelCommands",
"list_or_args",
]

View File

@@ -0,0 +1,253 @@
from redis._parsers.helpers import bool_ok
from ..helpers import get_protocol_version, parse_to_list
from .commands import * # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
class AbstractBloom:
"""
The client allows interaction with RedisBloom and the use of all of
its functionality.
- BF for Bloom Filter
- CF for Cuckoo Filter
- CMS for Count-Min Sketch
- TOPK for TopK Data Structure
- TDIGEST for estimating rank statistics
"""
@staticmethod
def append_items(params, items):
"""Append ITEMS to params."""
params.extend(["ITEMS"])
params += items
@staticmethod
def append_error(params, error):
"""Append ERROR to params."""
if error is not None:
params.extend(["ERROR", error])
@staticmethod
def append_capacity(params, capacity):
"""Append CAPACITY to params."""
if capacity is not None:
params.extend(["CAPACITY", capacity])
@staticmethod
def append_expansion(params, expansion):
"""Append EXPANSION to params."""
if expansion is not None:
params.extend(["EXPANSION", expansion])
@staticmethod
def append_no_scale(params, noScale):
"""Append NONSCALING tag to params."""
if noScale is not None:
params.extend(["NONSCALING"])
@staticmethod
def append_weights(params, weights):
"""Append WEIGHTS to params."""
if len(weights) > 0:
params.append("WEIGHTS")
params += weights
@staticmethod
def append_no_create(params, noCreate):
"""Append NOCREATE tag to params."""
if noCreate is not None:
params.extend(["NOCREATE"])
@staticmethod
def append_items_and_increments(params, items, increments):
"""Append pairs of items and increments to params."""
for i in range(len(items)):
params.append(items[i])
params.append(increments[i])
@staticmethod
def append_values_and_weights(params, items, weights):
"""Append pairs of items and weights to params."""
for i in range(len(items)):
params.append(items[i])
params.append(weights[i])
@staticmethod
def append_max_iterations(params, max_iterations):
"""Append MAXITERATIONS to params."""
if max_iterations is not None:
params.extend(["MAXITERATIONS", max_iterations])
@staticmethod
def append_bucket_size(params, bucket_size):
"""Append BUCKETSIZE to params."""
if bucket_size is not None:
params.extend(["BUCKETSIZE", bucket_size])
class CMSBloom(CMSCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
CMS_INITBYDIM: bool_ok,
CMS_INITBYPROB: bool_ok,
# CMS_INCRBY: spaceHolder,
# CMS_QUERY: spaceHolder,
CMS_MERGE: bool_ok,
}
_RESP2_MODULE_CALLBACKS = {
CMS_INFO: CMSInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = CMSCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class TOPKBloom(TOPKCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
TOPK_RESERVE: bool_ok,
# TOPK_QUERY: spaceHolder,
# TOPK_COUNT: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
TOPK_ADD: parse_to_list,
TOPK_INCRBY: parse_to_list,
TOPK_INFO: TopKInfo,
TOPK_LIST: parse_to_list,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = TOPKCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class CFBloom(CFCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
CF_RESERVE: bool_ok,
# CF_ADD: spaceHolder,
# CF_ADDNX: spaceHolder,
# CF_INSERT: spaceHolder,
# CF_INSERTNX: spaceHolder,
# CF_EXISTS: spaceHolder,
# CF_DEL: spaceHolder,
# CF_COUNT: spaceHolder,
# CF_SCANDUMP: spaceHolder,
# CF_LOADCHUNK: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
CF_INFO: CFInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = CFCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class TDigestBloom(TDigestCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
TDIGEST_CREATE: bool_ok,
# TDIGEST_RESET: bool_ok,
# TDIGEST_ADD: spaceHolder,
# TDIGEST_MERGE: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
TDIGEST_BYRANK: parse_to_list,
TDIGEST_BYREVRANK: parse_to_list,
TDIGEST_CDF: parse_to_list,
TDIGEST_INFO: TDigestInfo,
TDIGEST_MIN: float,
TDIGEST_MAX: float,
TDIGEST_TRIMMED_MEAN: float,
TDIGEST_QUANTILE: parse_to_list,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = TDigestCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
class BFBloom(BFCommands, AbstractBloom):
def __init__(self, client, **kwargs):
"""Create a new RedisBloom client."""
# Set the module commands' callbacks
_MODULE_CALLBACKS = {
BF_RESERVE: bool_ok,
# BF_ADD: spaceHolder,
# BF_MADD: spaceHolder,
# BF_INSERT: spaceHolder,
# BF_EXISTS: spaceHolder,
# BF_MEXISTS: spaceHolder,
# BF_SCANDUMP: spaceHolder,
# BF_LOADCHUNK: spaceHolder,
# BF_CARD: spaceHolder,
}
_RESP2_MODULE_CALLBACKS = {
BF_INFO: BFInfo,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.commandmixin = BFCommands
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
_MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
_MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in _MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
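# --- Illustrative usage sketch (not part of the original module) ---
# A hedged example of how these wrappers are normally reached: through the
# module accessors on a Redis client (client.bf() returns a BFBloom bound to
# that client). Assumes a local Redis server with RedisBloom loaded; host,
# port and key names are hypothetical.
if __name__ == "__main__":
    import redis

    client = redis.Redis(host="localhost", port=6379)
    bf = BFBloom(client)                  # roughly what client.bf() returns
    bf.create("demo:bf", 0.01, 1000)      # 1% error rate, 1000 expected items
    bf.add("demo:bf", "alice")
    print(bf.exists("demo:bf", "alice"))  # 1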

View File

@@ -0,0 +1,538 @@
from redis.client import NEVER_DECODE
from redis.utils import deprecated_function
BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
BF_MADD = "BF.MADD"
BF_INSERT = "BF.INSERT"
BF_EXISTS = "BF.EXISTS"
BF_MEXISTS = "BF.MEXISTS"
BF_SCANDUMP = "BF.SCANDUMP"
BF_LOADCHUNK = "BF.LOADCHUNK"
BF_INFO = "BF.INFO"
BF_CARD = "BF.CARD"
CF_RESERVE = "CF.RESERVE"
CF_ADD = "CF.ADD"
CF_ADDNX = "CF.ADDNX"
CF_INSERT = "CF.INSERT"
CF_INSERTNX = "CF.INSERTNX"
CF_EXISTS = "CF.EXISTS"
CF_MEXISTS = "CF.MEXISTS"
CF_DEL = "CF.DEL"
CF_COUNT = "CF.COUNT"
CF_SCANDUMP = "CF.SCANDUMP"
CF_LOADCHUNK = "CF.LOADCHUNK"
CF_INFO = "CF.INFO"
CMS_INITBYDIM = "CMS.INITBYDIM"
CMS_INITBYPROB = "CMS.INITBYPROB"
CMS_INCRBY = "CMS.INCRBY"
CMS_QUERY = "CMS.QUERY"
CMS_MERGE = "CMS.MERGE"
CMS_INFO = "CMS.INFO"
TOPK_RESERVE = "TOPK.RESERVE"
TOPK_ADD = "TOPK.ADD"
TOPK_INCRBY = "TOPK.INCRBY"
TOPK_QUERY = "TOPK.QUERY"
TOPK_COUNT = "TOPK.COUNT"
TOPK_LIST = "TOPK.LIST"
TOPK_INFO = "TOPK.INFO"
TDIGEST_CREATE = "TDIGEST.CREATE"
TDIGEST_RESET = "TDIGEST.RESET"
TDIGEST_ADD = "TDIGEST.ADD"
TDIGEST_MERGE = "TDIGEST.MERGE"
TDIGEST_CDF = "TDIGEST.CDF"
TDIGEST_QUANTILE = "TDIGEST.QUANTILE"
TDIGEST_MIN = "TDIGEST.MIN"
TDIGEST_MAX = "TDIGEST.MAX"
TDIGEST_INFO = "TDIGEST.INFO"
TDIGEST_TRIMMED_MEAN = "TDIGEST.TRIMMED_MEAN"
TDIGEST_RANK = "TDIGEST.RANK"
TDIGEST_REVRANK = "TDIGEST.REVRANK"
TDIGEST_BYRANK = "TDIGEST.BYRANK"
TDIGEST_BYREVRANK = "TDIGEST.BYREVRANK"
class BFCommands:
"""Bloom Filter commands."""
def create(self, key, errorRate, capacity, expansion=None, noScale=None):
"""
Create a new Bloom Filter `key` with the desired probability of false positives
`errorRate` and the expected number of entries to be inserted as `capacity`.
The default expansion value is 2. By default, the filter is auto-scaling.
For more information see `BF.RESERVE <https://redis.io/commands/bf.reserve>`_.
""" # noqa
params = [key, errorRate, capacity]
self.append_expansion(params, expansion)
self.append_no_scale(params, noScale)
return self.execute_command(BF_RESERVE, *params)
reserve = create
def add(self, key, item):
"""
Add to a Bloom Filter `key` an `item`.
For more information see `BF.ADD <https://redis.io/commands/bf.add>`_.
""" # noqa
return self.execute_command(BF_ADD, key, item)
def madd(self, key, *items):
"""
Add to a Bloom Filter `key` multiple `items`.
For more information see `BF.MADD <https://redis.io/commands/bf.madd>`_.
""" # noqa
return self.execute_command(BF_MADD, key, *items)
def insert(
self,
key,
items,
capacity=None,
error=None,
noCreate=None,
expansion=None,
noScale=None,
):
"""
Add to a Bloom Filter `key` multiple `items`.
If `noCreate` remains `None` and `key` does not exist, a new Bloom Filter
`key` will be created with the desired probability of false positives `error`
and the expected number of entries to be inserted as `capacity`.
For more information see `BF.INSERT <https://redis.io/commands/bf.insert>`_.
""" # noqa
params = [key]
self.append_capacity(params, capacity)
self.append_error(params, error)
self.append_expansion(params, expansion)
self.append_no_create(params, noCreate)
self.append_no_scale(params, noScale)
self.append_items(params, items)
return self.execute_command(BF_INSERT, *params)
def exists(self, key, item):
"""
Check whether an `item` exists in Bloom Filter `key`.
For more information see `BF.EXISTS <https://redis.io/commands/bf.exists>`_.
""" # noqa
return self.execute_command(BF_EXISTS, key, item)
def mexists(self, key, *items):
"""
Check whether `items` exist in Bloom Filter `key`.
For more information see `BF.MEXISTS <https://redis.io/commands/bf.mexists>`_.
""" # noqa
return self.execute_command(BF_MEXISTS, key, *items)
def scandump(self, key, iter):
"""
Begin an incremental save of the bloom filter `key`.
This is useful for large bloom filters which cannot fit into the normal SAVE and RESTORE model.
The first time this command is called, the value of `iter` should be 0.
This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
""" # noqa
params = [key, iter]
options = {}
options[NEVER_DECODE] = []
return self.execute_command(BF_SCANDUMP, *params, **options)
def loadchunk(self, key, iter, data):
"""
Restore a filter previously saved using SCANDUMP.
See the SCANDUMP command for example usage.
This command will overwrite any bloom filter stored under key.
Ensure that the bloom filter will not be modified between invocations.
For more information see `BF.LOADCHUNK <https://redis.io/commands/bf.loadchunk>`_.
""" # noqa
return self.execute_command(BF_LOADCHUNK, key, iter, data)
def info(self, key):
"""
Return capacity, size, number of filters, number of items inserted, and expansion rate.
For more information see `BF.INFO <https://redis.io/commands/bf.info>`_.
""" # noqa
return self.execute_command(BF_INFO, key)
def card(self, key):
"""
Returns the cardinality of a Bloom filter - number of items that were added to a Bloom filter and detected as unique
(items that caused at least one bit to be set in at least one sub-filter).
For more information see `BF.CARD <https://redis.io/commands/bf.card>`_.
""" # noqa
return self.execute_command(BF_CARD, key)
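# --- Illustrative sketch (comments only, not part of the original module) ---
# The SCANDUMP/LOADCHUNK pair documented above is typically driven like this,
# assuming `src` and `dst` are objects exposing these Bloom commands and the
# key name is hypothetical:
#
#     chunks = []
#     cursor = 0
#     while True:
#         cursor, data = src.scandump("bloom:key", cursor)
#         if cursor == 0:          # (0, NULL) marks the end of the dump
#             break
#         chunks.append((cursor, data))
#     for cursor, data in chunks:
#         dst.loadchunk("bloom:key", cursor, data)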
class CFCommands:
"""Cuckoo Filter commands."""
def create(
self, key, capacity, expansion=None, bucket_size=None, max_iterations=None
):
"""
Create a new Cuckoo Filter `key` with an initial capacity of `capacity` items.
For more information see `CF.RESERVE <https://redis.io/commands/cf.reserve>`_.
""" # noqa
params = [key, capacity]
self.append_expansion(params, expansion)
self.append_bucket_size(params, bucket_size)
self.append_max_iterations(params, max_iterations)
return self.execute_command(CF_RESERVE, *params)
reserve = create
def add(self, key, item):
"""
Add an `item` to a Cuckoo Filter `key`.
For more information see `CF.ADD <https://redis.io/commands/cf.add>`_.
""" # noqa
return self.execute_command(CF_ADD, key, item)
def addnx(self, key, item):
"""
Add an `item` to a Cuckoo Filter `key` only if item does not yet exist.
This command might be slower than `add`.
For more information see `CF.ADDNX <https://redis.io/commands/cf.addnx>`_.
""" # noqa
return self.execute_command(CF_ADDNX, key, item)
def insert(self, key, items, capacity=None, nocreate=None):
"""
Add multiple `items` to a Cuckoo Filter `key`, allowing the filter
to be created with a custom `capacity` if it does not yet exist.
`items` must be provided as a list.
For more information see `CF.INSERT <https://redis.io/commands/cf.insert>`_.
""" # noqa
params = [key]
self.append_capacity(params, capacity)
self.append_no_create(params, nocreate)
self.append_items(params, items)
return self.execute_command(CF_INSERT, *params)
def insertnx(self, key, items, capacity=None, nocreate=None):
"""
Add multiple `items` to a Cuckoo Filter `key` only if they do not exist yet,
allowing the filter to be created with a custom `capacity` if it does not yet exist.
`items` must be provided as a list.
For more information see `CF.INSERTNX <https://redis.io/commands/cf.insertnx>`_.
""" # noqa
params = [key]
self.append_capacity(params, capacity)
self.append_no_create(params, nocreate)
self.append_items(params, items)
return self.execute_command(CF_INSERTNX, *params)
def exists(self, key, item):
"""
Check whether an `item` exists in Cuckoo Filter `key`.
For more information see `CF.EXISTS <https://redis.io/commands/cf.exists>`_.
""" # noqa
return self.execute_command(CF_EXISTS, key, item)
def mexists(self, key, *items):
"""
Check whether `items` exist in Cuckoo Filter `key`.
For more information see `CF.MEXISTS <https://redis.io/commands/cf.mexists>`_.
""" # noqa
return self.execute_command(CF_MEXISTS, key, *items)
def delete(self, key, item):
"""
Delete `item` from `key`.
For more information see `CF.DEL <https://redis.io/commands/cf.del>`_.
""" # noqa
return self.execute_command(CF_DEL, key, item)
def count(self, key, item):
"""
Return the number of times an `item` may be in the `key`.
For more information see `CF.COUNT <https://redis.io/commands/cf.count>`_.
""" # noqa
return self.execute_command(CF_COUNT, key, item)
def scandump(self, key, iter):
"""
Begin an incremental save of the Cuckoo filter `key`.
This is useful for large Cuckoo filters which cannot fit into the normal
SAVE and RESTORE model.
The first time this command is called, the value of `iter` should be 0.
This command will return successive (iter, data) pairs until
(0, NULL) to indicate completion.
For more information see `CF.SCANDUMP <https://redis.io/commands/cf.scandump>`_.
""" # noqa
return self.execute_command(CF_SCANDUMP, key, iter)
def loadchunk(self, key, iter, data):
"""
Restore a filter previously saved using SCANDUMP. See the SCANDUMP command for example usage.
This command will overwrite any Cuckoo filter stored under key.
Ensure that the Cuckoo filter will not be modified between invocations.
For more information see `CF.LOADCHUNK <https://redis.io/commands/cf.loadchunk>`_.
""" # noqa
return self.execute_command(CF_LOADCHUNK, key, iter, data)
def info(self, key):
"""
Return size, number of buckets, number of filters, number of items inserted,
number of items deleted, bucket size, expansion rate, and max iterations.
For more information see `CF.INFO <https://redis.io/commands/cf.info>`_.
""" # noqa
return self.execute_command(CF_INFO, key)
class TOPKCommands:
"""TOP-k Filter commands."""
def reserve(self, key, k, width, depth, decay):
"""
Create a new Top-K Filter `key` that keeps track of the top `k` items,
using the given `width`, `depth` and `decay` parameters.
For more information see `TOPK.RESERVE <https://redis.io/commands/topk.reserve>`_.
""" # noqa
return self.execute_command(TOPK_RESERVE, key, k, width, depth, decay)
def add(self, key, *items):
"""
Add one or more `items` to a Top-K Filter `key`.
For more information see `TOPK.ADD <https://redis.io/commands/topk.add>`_.
""" # noqa
return self.execute_command(TOPK_ADD, key, *items)
def incrby(self, key, items, increments):
"""
Add/increase `items` in a Top-K Sketch `key` by the given `increments`.
Both `items` and `increments` are lists.
For more information see `TOPK.INCRBY <https://redis.io/commands/topk.incrby>`_.
Example:
>>> client.topk().incrby('A', ['foo'], [1])
""" # noqa
params = [key]
self.append_items_and_increments(params, items, increments)
return self.execute_command(TOPK_INCRBY, *params)
def query(self, key, *items):
"""
Check whether one or more `items` are Top-K items at `key`.
For more information see `TOPK.QUERY <https://redis.io/commands/topk.query>`_.
""" # noqa
return self.execute_command(TOPK_QUERY, key, *items)
@deprecated_function(version="4.4.0", reason="deprecated since redisbloom 2.4.0")
def count(self, key, *items):
"""
Return count for one `item` or more from `key`.
For more information see `TOPK.COUNT <https://redis.io/commands/topk.count>`_.
""" # noqa
return self.execute_command(TOPK_COUNT, key, *items)
def list(self, key, withcount=False):
"""
Return the full list of items in the Top-K list of `key`.
If `withcount` is set to True, each item is returned together
with its probabilistic count.
For more information see `TOPK.LIST <https://redis.io/commands/topk.list>`_.
""" # noqa
params = [key]
if withcount:
params.append("WITHCOUNT")
return self.execute_command(TOPK_LIST, *params)
def info(self, key):
"""
Return k, width, depth and decay values of `key`.
For more information see `TOPK.INFO <https://redis.io/commands/topk.info>`_.
""" # noqa
return self.execute_command(TOPK_INFO, key)
class TDigestCommands:
def create(self, key, compression=100):
"""
Allocate the memory and initialize the t-digest.
For more information see `TDIGEST.CREATE <https://redis.io/commands/tdigest.create>`_.
""" # noqa
return self.execute_command(TDIGEST_CREATE, key, "COMPRESSION", compression)
def reset(self, key):
"""
Reset the sketch `key` to zero - empty out the sketch and re-initialize it.
For more information see `TDIGEST.RESET <https://redis.io/commands/tdigest.reset>`_.
""" # noqa
return self.execute_command(TDIGEST_RESET, key)
def add(self, key, values):
"""
Adds one or more observations to a t-digest sketch `key`.
For more information see `TDIGEST.ADD <https://redis.io/commands/tdigest.add>`_.
""" # noqa
return self.execute_command(TDIGEST_ADD, key, *values)
def merge(self, destination_key, num_keys, *keys, compression=None, override=False):
"""
Merges all of the values from `keys` into the `destination_key` sketch.
It is mandatory to provide `num_keys` before passing the input keys and
the other (optional) arguments.
If `destination_key` already exists, its values are merged with the input keys.
To override the destination key contents instead, use the `override` parameter.
For more information see `TDIGEST.MERGE <https://redis.io/commands/tdigest.merge>`_.
""" # noqa
params = [destination_key, num_keys, *keys]
if compression is not None:
params.extend(["COMPRESSION", compression])
if override:
params.append("OVERRIDE")
return self.execute_command(TDIGEST_MERGE, *params)
def min(self, key):
"""
Return minimum value from the sketch `key`. Will return DBL_MAX if the sketch is empty.
For more information see `TDIGEST.MIN <https://redis.io/commands/tdigest.min>`_.
""" # noqa
return self.execute_command(TDIGEST_MIN, key)
def max(self, key):
"""
Return maximum value from the sketch `key`. Will return DBL_MIN if the sketch is empty.
For more information see `TDIGEST.MAX <https://redis.io/commands/tdigest.max>`_.
""" # noqa
return self.execute_command(TDIGEST_MAX, key)
def quantile(self, key, quantile, *quantiles):
"""
Returns estimates of one or more cutoffs such that a specified fraction of the
observations added to this t-digest would be less than or equal to each of the
specified cutoffs. (Multiple quantiles can be returned with one call)
For more information see `TDIGEST.QUANTILE <https://redis.io/commands/tdigest.quantile>`_.
""" # noqa
return self.execute_command(TDIGEST_QUANTILE, key, quantile, *quantiles)
def cdf(self, key, value, *values):
"""
Return, as a double, the fraction of all points added which are <= value.
For more information see `TDIGEST.CDF <https://redis.io/commands/tdigest.cdf>`_.
""" # noqa
return self.execute_command(TDIGEST_CDF, key, value, *values)
def info(self, key):
"""
Return Compression, Capacity, Merged Nodes, Unmerged Nodes, Merged Weight, Unmerged Weight
and Total Compressions.
For more information see `TDIGEST.INFO <https://redis.io/commands/tdigest.info>`_.
""" # noqa
return self.execute_command(TDIGEST_INFO, key)
def trimmed_mean(self, key, low_cut_quantile, high_cut_quantile):
"""
Return mean value from the sketch, excluding observation values outside
the low and high cutoff quantiles.
For more information see `TDIGEST.TRIMMED_MEAN <https://redis.io/commands/tdigest.trimmed_mean>`_.
""" # noqa
return self.execute_command(
TDIGEST_TRIMMED_MEAN, key, low_cut_quantile, high_cut_quantile
)
def rank(self, key, value, *values):
"""
Retrieve the estimated rank of value (the number of observations in the sketch
that are smaller than value + half the number of observations that are equal to value).
For more information see `TDIGEST.RANK <https://redis.io/commands/tdigest.rank>`_.
""" # noqa
return self.execute_command(TDIGEST_RANK, key, value, *values)
def revrank(self, key, value, *values):
"""
Retrieve the estimated rank of value (the number of observations in the sketch
that are larger than value + half the number of observations that are equal to value).
For more information see `TDIGEST.REVRANK <https://redis.io/commands/tdigest.revrank>`_.
""" # noqa
return self.execute_command(TDIGEST_REVRANK, key, value, *values)
def byrank(self, key, rank, *ranks):
"""
Retrieve an estimation of the value with the given rank.
For more information see `TDIGEST.BY_RANK <https://redis.io/commands/tdigest.by_rank>`_.
""" # noqa
return self.execute_command(TDIGEST_BYRANK, key, rank, *ranks)
def byrevrank(self, key, rank, *ranks):
"""
Retrieve an estimation of the value with the given reverse rank.
For more information see `TDIGEST.BY_REVRANK <https://redis.io/commands/tdigest.by_revrank>`_.
""" # noqa
return self.execute_command(TDIGEST_BYREVRANK, key, rank, *ranks)
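# --- Illustrative sketch (comments only, not part of the original module) ---
# A typical t-digest flow with the commands above, assuming `td` is an object
# exposing them; the key name and values are hypothetical:
#
#     td.create("latency", compression=100)
#     td.add("latency", [12.0, 15.5, 30.2, 7.9])
#     td.quantile("latency", 0.5, 0.99)    # median and p99 estimates
#     td.cdf("latency", 20.0)              # fraction of observations <= 20.0
#     td.trimmed_mean("latency", 0.1, 0.9)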
class CMSCommands:
"""Count-Min Sketch Commands"""
def initbydim(self, key, width, depth):
"""
Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user.
For more information see `CMS.INITBYDIM <https://redis.io/commands/cms.initbydim>`_.
""" # noqa
return self.execute_command(CMS_INITBYDIM, key, width, depth)
def initbyprob(self, key, error, probability):
"""
Initialize a Count-Min Sketch `key` to characteristics (`error`, `probability`) specified by user.
For more information see `CMS.INITBYPROB <https://redis.io/commands/cms.initbyprob>`_.
""" # noqa
return self.execute_command(CMS_INITBYPROB, key, error, probability)
def incrby(self, key, items, increments):
"""
Add/increase `items` in a Count-Min Sketch `key` by the given `increments`.
Both `items` and `increments` are lists.
For more information see `CMS.INCRBY <https://redis.io/commands/cms.incrby>`_.
Example:
>>> client.cms().incrby('A', ['foo'], [1])
""" # noqa
params = [key]
self.append_items_and_increments(params, items, increments)
return self.execute_command(CMS_INCRBY, *params)
def query(self, key, *items):
"""
Return count for an `item` from `key`. Multiple items can be queried with one call.
For more information see `CMS.QUERY <https://redis.io/commands/cms.query>`_.
""" # noqa
return self.execute_command(CMS_QUERY, key, *items)
def merge(self, destKey, numKeys, srcKeys, weights=[]):
"""
Merge `numKeys` of sketches into `destKey`. Sketches specified in `srcKeys`.
All sketches must have identical width and depth.
`Weights` can be used to multiply certain sketches. Default weight is 1.
Both `srcKeys` and `weights` are lists.
For more information see `CMS.MERGE <https://redis.io/commands/cms.merge>`_.
""" # noqa
params = [destKey, numKeys]
params += srcKeys
self.append_weights(params, weights)
return self.execute_command(CMS_MERGE, *params)
def info(self, key):
"""
Return width, depth and total count of the sketch.
For more information see `CMS.INFO <https://redis.io/commands/cms.info>`_.
""" # noqa
return self.execute_command(CMS_INFO, key)
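# --- Illustrative usage sketch (not part of the original module) ---
# A hedged end-to-end example of the Count-Min Sketch commands, assuming a
# local Redis server with RedisBloom loaded and a client whose cms() accessor
# returns an object backed by CMSCommands. Key names and counts are
# hypothetical.
if __name__ == "__main__":
    import redis

    cms = redis.Redis().cms()
    cms.initbydim("views", width=2000, depth=5)
    cms.incrby("views", ["page:a", "page:b"], [3, 1])
    print(cms.query("views", "page:a", "page:b"))  # approximate counts, e.g. [3, 1]
    info = cms.info("views")
    print(info.width, info.depth, info.count)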

View File

@@ -0,0 +1,120 @@
from ..helpers import nativestr
class BFInfo:
capacity = None
size = None
filterNum = None
insertedNum = None
expansionRate = None
def __init__(self, args):
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.capacity = response["Capacity"]
self.size = response["Size"]
self.filterNum = response["Number of filters"]
self.insertedNum = response["Number of items inserted"]
self.expansionRate = response["Expansion rate"]
def get(self, item):
try:
return self.__getitem__(item)
except AttributeError:
return None
def __getitem__(self, item):
return getattr(self, item)
class CFInfo:
size = None
bucketNum = None
filterNum = None
insertedNum = None
deletedNum = None
bucketSize = None
expansionRate = None
maxIteration = None
def __init__(self, args):
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.size = response["Size"]
self.bucketNum = response["Number of buckets"]
self.filterNum = response["Number of filters"]
self.insertedNum = response["Number of items inserted"]
self.deletedNum = response["Number of items deleted"]
self.bucketSize = response["Bucket size"]
self.expansionRate = response["Expansion rate"]
self.maxIteration = response["Max iterations"]
def get(self, item):
try:
return self.__getitem__(item)
except AttributeError:
return None
def __getitem__(self, item):
return getattr(self, item)
class CMSInfo:
width = None
depth = None
count = None
def __init__(self, args):
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.width = response["width"]
self.depth = response["depth"]
self.count = response["count"]
def __getitem__(self, item):
return getattr(self, item)
class TopKInfo:
k = None
width = None
depth = None
decay = None
def __init__(self, args):
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.k = response["k"]
self.width = response["width"]
self.depth = response["depth"]
self.decay = response["decay"]
def __getitem__(self, item):
return getattr(self, item)
class TDigestInfo:
compression = None
capacity = None
merged_nodes = None
unmerged_nodes = None
merged_weight = None
unmerged_weight = None
total_compressions = None
memory_usage = None
def __init__(self, args):
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.compression = response["Compression"]
self.capacity = response["Capacity"]
self.merged_nodes = response["Merged nodes"]
self.unmerged_nodes = response["Unmerged nodes"]
self.merged_weight = response["Merged weight"]
self.unmerged_weight = response["Unmerged weight"]
self.total_compressions = response["Total compressions"]
self.memory_usage = response["Memory usage"]
def get(self, item):
try:
return self.__getitem__(item)
except AttributeError:
return None
def __getitem__(self, item):
return getattr(self, item)
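# --- Illustrative usage sketch (not part of the original module) ---
# A hedged example of how these info wrappers flatten a RESP2 key/value reply
# into attributes; the reply values below are made up.
if __name__ == "__main__":
    raw_reply = [
        b"Capacity", 1000,
        b"Size", 1400,
        b"Number of filters", 1,
        b"Number of items inserted", 42,
        b"Expansion rate", 2,
    ]
    info = BFInfo(raw_reply)
    print(info.capacity, info.insertedNum)    # 1000 42
    print(info["size"], info.get("missing"))  # 1400 None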

View File

@@ -0,0 +1,919 @@
import asyncio
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
NoReturn,
Optional,
Union,
)
from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import (
AnyKeyT,
ClusterCommandsProtocol,
EncodableT,
KeysT,
KeyT,
PatternT,
ResponseT,
)
from .core import (
ACLCommands,
AsyncACLCommands,
AsyncDataAccessCommands,
AsyncFunctionCommands,
AsyncManagementCommands,
AsyncModuleCommands,
AsyncScriptCommands,
DataAccessCommands,
FunctionCommands,
ManagementCommands,
ModuleCommands,
PubSubCommands,
ScriptCommands,
)
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
if TYPE_CHECKING:
from redis.asyncio.cluster import TargetNodesT
# Not complete, but covers the major ones
# https://redis.io/commands
READ_COMMANDS = frozenset(
[
"BITCOUNT",
"BITPOS",
"EVAL_RO",
"EVALSHA_RO",
"EXISTS",
"GEODIST",
"GEOHASH",
"GEOPOS",
"GEORADIUS",
"GEORADIUSBYMEMBER",
"GET",
"GETBIT",
"GETRANGE",
"HEXISTS",
"HGET",
"HGETALL",
"HKEYS",
"HLEN",
"HMGET",
"HSTRLEN",
"HVALS",
"KEYS",
"LINDEX",
"LLEN",
"LRANGE",
"MGET",
"PTTL",
"RANDOMKEY",
"SCARD",
"SDIFF",
"SINTER",
"SISMEMBER",
"SMEMBERS",
"SRANDMEMBER",
"STRLEN",
"SUNION",
"TTL",
"ZCARD",
"ZCOUNT",
"ZRANGE",
"ZSCORE",
]
)
class ClusterMultiKeyCommands(ClusterCommandsProtocol):
"""
A class containing commands that handle more than one key
"""
def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]:
"""Split keys into a dictionary that maps a slot to a list of keys."""
slots_to_keys = {}
for key in keys:
slot = key_slot(self.encoder.encode(key))
slots_to_keys.setdefault(slot, []).append(key)
return slots_to_keys
def _partition_pairs_by_slot(
self, mapping: Mapping[AnyKeyT, EncodableT]
) -> Dict[int, List[EncodableT]]:
"""Split pairs into a dictionary that maps a slot to a list of pairs."""
slots_to_pairs = {}
for pair in mapping.items():
slot = key_slot(self.encoder.encode(pair[0]))
slots_to_pairs.setdefault(slot, []).extend(pair)
return slots_to_pairs
def _execute_pipeline_by_slot(
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
) -> List[Any]:
read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
pipe = self.pipeline()
[
pipe.execute_command(
command,
*slot_args,
target_nodes=[
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
],
)
for slot, slot_args in slots_to_args.items()
]
return pipe.execute()
def _reorder_keys_by_command(
self,
keys: Iterable[KeyT],
slots_to_args: Mapping[int, Iterable[EncodableT]],
responses: Iterable[Any],
) -> List[Any]:
results = {
k: v
for slot_values, response in zip(slots_to_args.values(), responses)
for k, v in zip(slot_values, response)
}
return [results[key] for key in keys]
def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
Returns a list of values ordered identically to ``keys``
For more information see https://redis.io/commands/mget
"""
# Concatenate all keys into a list
keys = list_or_args(keys, args)
# Split keys into slots
slots_to_keys = self._partition_keys_by_slot(keys)
# Execute commands using a pipeline
res = self._execute_pipeline_by_slot("MGET", slots_to_keys)
# Reorder keys in the order the user provided & return
return self._reorder_keys_by_command(keys, slots_to_keys, res)
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
"""
Sets key/values based on a mapping. Mapping is a dictionary of
key/value pairs. Both keys and values should be strings or types that
can be cast to a string via str().
Splits the keys into different slots and then calls MSET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
For more information see https://redis.io/commands/mset
"""
# Partition the keys by slot
slots_to_pairs = self._partition_pairs_by_slot(mapping)
# Execute commands using a pipeline & return list of replies
return self._execute_pipeline_by_slot("MSET", slots_to_pairs)
def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
"""
Runs the given command once for the keys
of each slot. Returns the sum of the return values.
"""
# Partition the keys by slot
slots_to_keys = self._partition_keys_by_slot(keys)
# Sum up the reply from each command
return sum(self._execute_pipeline_by_slot(command, slots_to_keys))
def exists(self, *keys: KeyT) -> ResponseT:
"""
Returns the number of ``keys`` that exist in the
whole cluster. The keys are first split up into slots
and then an EXISTS command is sent for every slot
For more information see https://redis.io/commands/exists
"""
return self._split_command_across_slots("EXISTS", *keys)
def delete(self, *keys: KeyT) -> ResponseT:
"""
Deletes the given keys in the cluster.
The keys are first split up into slots
and then a DEL command is sent for every slot
Non-existent keys are ignored.
Returns the number of keys that were deleted.
For more information see https://redis.io/commands/del
"""
return self._split_command_across_slots("DEL", *keys)
def touch(self, *keys: KeyT) -> ResponseT:
"""
Updates the last access time of given keys across the
cluster.
The keys are first split up into slots
and then a TOUCH command is sent for every slot
Non-existent keys are ignored.
Returns the number of keys that were touched.
For more information see https://redis.io/commands/touch
"""
return self._split_command_across_slots("TOUCH", *keys)
def unlink(self, *keys: KeyT) -> ResponseT:
"""
Remove the specified keys in a different thread.
The keys are first split up into slots
and then an UNLINK command is sent for every slot
Non-existent keys are ignored.
Returns the number of keys that were unlinked.
For more information see https://redis.io/commands/unlink
"""
return self._split_command_across_slots("UNLINK", *keys)
class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands):
"""
A class containing commands that handle more than one key
"""
async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]:
"""
Splits the keys into different slots and then calls MGET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
Returns a list of values ordered identically to ``keys``
For more information see https://redis.io/commands/mget
"""
# Concatenate all keys into a list
keys = list_or_args(keys, args)
# Split keys into slots
slots_to_keys = self._partition_keys_by_slot(keys)
# Execute commands using a pipeline
res = await self._execute_pipeline_by_slot("MGET", slots_to_keys)
# Reorder keys in the order the user provided & return
return self._reorder_keys_by_command(keys, slots_to_keys, res)
async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]:
"""
Sets key/values based on a mapping. Mapping is a dictionary of
key/value pairs. Both keys and values should be strings or types that
can be cast to a string via str().
Splits the keys into different slots and then calls MSET
for the keys of every slot. This operation will not be atomic
if keys belong to more than one slot.
For more information see https://redis.io/commands/mset
"""
# Partition the keys by slot
slots_to_pairs = self._partition_pairs_by_slot(mapping)
# Execute commands using a pipeline & return list of replies
return await self._execute_pipeline_by_slot("MSET", slots_to_pairs)
async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int:
"""
Runs the given command once for the keys
of each slot. Returns the sum of the return values.
"""
# Partition the keys by slot
slots_to_keys = self._partition_keys_by_slot(keys)
# Sum up the reply from each command
return sum(await self._execute_pipeline_by_slot(command, slots_to_keys))
async def _execute_pipeline_by_slot(
self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]]
) -> List[Any]:
if self._initialize:
await self.initialize()
read_from_replicas = self.read_from_replicas and command in READ_COMMANDS
pipe = self.pipeline()
[
pipe.execute_command(
command,
*slot_args,
target_nodes=[
self.nodes_manager.get_node_from_slot(slot, read_from_replicas)
],
)
for slot, slot_args in slots_to_args.items()
]
return await pipe.execute()
class ClusterManagementCommands(ManagementCommands):
"""
A class for Redis Cluster management commands
The class inherits from Redis's core ManagementCommands class and makes the
required adjustments to work in cluster mode.
"""
def slaveof(self, *args, **kwargs) -> NoReturn:
"""
Make the server a replica of another instance, or promote it as master.
For more information see https://redis.io/commands/slaveof
"""
raise RedisClusterException("SLAVEOF is not supported in cluster mode")
def replicaof(self, *args, **kwargs) -> NoReturn:
"""
Make the server a replica of another instance, or promote it as master.
For more information see https://redis.io/commands/replicaof
"""
raise RedisClusterException("REPLICAOF is not supported in cluster mode")
def swapdb(self, *args, **kwargs) -> NoReturn:
"""
Swaps two Redis databases.
For more information see https://redis.io/commands/swapdb
"""
raise RedisClusterException("SWAPDB is not supported in cluster mode")
def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT:
"""
Returns the node's id.
:target_node: 'ClusterNode'
The node to execute the command on
For more information check https://redis.io/commands/cluster-myid/
"""
return self.execute_command("CLUSTER MYID", target_nodes=target_node)
def cluster_addslots(
self, target_node: "TargetNodesT", *slots: EncodableT
) -> ResponseT:
"""
Assign new hash slots to receiving node. Sends to specified node.
:target_node: 'ClusterNode'
The node to execute the command on
For more information see https://redis.io/commands/cluster-addslots
"""
return self.execute_command(
"CLUSTER ADDSLOTS", *slots, target_nodes=target_node
)
def cluster_addslotsrange(
self, target_node: "TargetNodesT", *slots: EncodableT
) -> ResponseT:
"""
Similar to the CLUSTER ADDSLOTS command.
The difference between the two commands is that ADDSLOTS takes a list of slots
to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges
(specified by start and end slots) to assign to the node.
:target_node: 'ClusterNode'
The node to execute the command on
For more information see https://redis.io/commands/cluster-addslotsrange
"""
return self.execute_command(
"CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node
)
def cluster_countkeysinslot(self, slot_id: int) -> ResponseT:
"""
Return the number of local keys in the specified hash slot
Send to node based on specified slot_id
For more information see https://redis.io/commands/cluster-countkeysinslot
"""
return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id)
def cluster_count_failure_report(self, node_id: str) -> ResponseT:
"""
Return the number of failure reports active for a given node
Sends to a random node
For more information see https://redis.io/commands/cluster-count-failure-reports
"""
return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id)
def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
"""
Set hash slots as unbound in the cluster.
It determines by itself which node each slot is in and sends the command there.
Returns a list of the results for each processed slot.
For more information see https://redis.io/commands/cluster-delslots
"""
return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots]
def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT:
"""
Similar to the CLUSTER DELSLOTS command.
The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove
from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove
from the node.
For more information see https://redis.io/commands/cluster-delslotsrange
"""
return self.execute_command("CLUSTER DELSLOTSRANGE", *slots)
def cluster_failover(
self, target_node: "TargetNodesT", option: Optional[str] = None
) -> ResponseT:
"""
Forces a slave to perform a manual failover of its master
Sends to specified node
:target_node: 'ClusterNode'
The node to execute the command on
For more information see https://redis.io/commands/cluster-failover
"""
if option:
if option.upper() not in ["FORCE", "TAKEOVER"]:
raise RedisError(
f"Invalid option for CLUSTER FAILOVER command: {option}"
)
else:
return self.execute_command(
"CLUSTER FAILOVER", option, target_nodes=target_node
)
else:
return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node)
def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
"""
Provides info about Redis Cluster node state.
The command will be sent to a random node in the cluster if no target
node is specified.
For more information see https://redis.io/commands/cluster-info
"""
return self.execute_command("CLUSTER INFO", target_nodes=target_nodes)
def cluster_keyslot(self, key: str) -> ResponseT:
"""
Returns the hash slot of the specified key
Sends to random node in the cluster
For more information see https://redis.io/commands/cluster-keyslot
"""
return self.execute_command("CLUSTER KEYSLOT", key)
def cluster_meet(
self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None
) -> ResponseT:
"""
Force a cluster node to handshake with another node.
Sends to specified node.
For more information see https://redis.io/commands/cluster-meet
"""
return self.execute_command(
"CLUSTER MEET", host, port, target_nodes=target_nodes
)
def cluster_nodes(self) -> ResponseT:
"""
Get Cluster config for the node.
Sends to random node in the cluster
For more information see https://redis.io/commands/cluster-nodes
"""
return self.execute_command("CLUSTER NODES")
def cluster_replicate(
self, target_nodes: "TargetNodesT", node_id: str
) -> ResponseT:
"""
Reconfigure a node as a slave of the specified master node
For more information see https://redis.io/commands/cluster-replicate
"""
return self.execute_command(
"CLUSTER REPLICATE", node_id, target_nodes=target_nodes
)
def cluster_reset(
self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None
) -> ResponseT:
"""
Reset a Redis Cluster node
If 'soft' is True then it will send 'SOFT' argument
If 'soft' is False then it will send 'HARD' argument
For more information see https://redis.io/commands/cluster-reset
"""
return self.execute_command(
"CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes
)
def cluster_save_config(
self, target_nodes: Optional["TargetNodesT"] = None
) -> ResponseT:
"""
Forces the node to save cluster state on disk
For more information see https://redis.io/commands/cluster-saveconfig
"""
return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes)
def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT:
"""
Returns up to ``num_keys`` keys from the specified cluster slot
For more information see https://redis.io/commands/cluster-getkeysinslot
"""
return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys)
def cluster_set_config_epoch(
self, epoch: int, target_nodes: Optional["TargetNodesT"] = None
) -> ResponseT:
"""
Set the configuration epoch in a new node
For more information see https://redis.io/commands/cluster-set-config-epoch
"""
return self.execute_command(
"CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes
)
def cluster_setslot(
self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str
) -> ResponseT:
"""
Bind a hash slot to a specific node
:target_node: 'ClusterNode'
The node to execute the command on
For more information see https://redis.io/commands/cluster-setslot
"""
if state.upper() in ("IMPORTING", "NODE", "MIGRATING"):
return self.execute_command(
"CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
)
elif state.upper() == "STABLE":
raise RedisError('For "stable" state please use cluster_setslot_stable')
else:
raise RedisError(f"Invalid slot state: {state}")
def cluster_setslot_stable(self, slot_id: int) -> ResponseT:
"""
Clears migrating / importing state from the slot.
The client determines by itself which node the slot belongs to and sends the command there.
For more information see https://redis.io/commands/cluster-setslot
"""
return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE")
def cluster_replicas(
self, node_id: str, target_nodes: Optional["TargetNodesT"] = None
) -> ResponseT:
"""
Provides a list of replica nodes replicating from the specified primary
target node.
For more information see https://redis.io/commands/cluster-replicas
"""
return self.execute_command(
"CLUSTER REPLICAS", node_id, target_nodes=target_nodes
)
def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
"""
Get array of Cluster slot to node mappings
For more information see https://redis.io/commands/cluster-slots
"""
return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes)
def cluster_shards(self, target_nodes=None):
"""
Returns details about the shards of the cluster.
For more information see https://redis.io/commands/cluster-shards
"""
return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes)
def cluster_myshardid(self, target_nodes=None):
"""
Returns the shard ID of the node.
For more information see https://redis.io/commands/cluster-myshardid/
"""
return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes)
def cluster_links(self, target_node: "TargetNodesT") -> ResponseT:
"""
Each node in a Redis Cluster maintains a pair of long-lived TCP links with each
peer in the cluster: one for sending outbound messages to the peer and one
for receiving inbound messages from the peer.
This command outputs information of all such peer links as an array.
For more information see https://redis.io/commands/cluster-links
"""
return self.execute_command("CLUSTER LINKS", target_nodes=target_node)
def cluster_flushslots(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
raise NotImplementedError(
"CLUSTER FLUSHSLOTS is intentionally not implemented in the client."
)
def cluster_bumpepoch(self, target_nodes: Optional["TargetNodesT"] = None) -> None:
raise NotImplementedError(
"CLUSTER BUMPEPOCH is intentionally not implemented in the client."
)
def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
"""
Enables read queries.
The command will be sent to the default cluster node if target_nodes is
not specified.
For more information see https://redis.io/commands/readonly
"""
if target_nodes == "replicas" or target_nodes == "all":
# read_from_replicas will only be enabled if the READONLY command
# is sent to all replicas
self.read_from_replicas = True
return self.execute_command("READONLY", target_nodes=target_nodes)
def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT:
"""
Disables read queries.
The command will be sent to the default cluster node if target_nodes is
not specified.
For more information see https://redis.io/commands/readwrite
"""
# Reset read from replicas flag
self.read_from_replicas = False
return self.execute_command("READWRITE", target_nodes=target_nodes)
class AsyncClusterManagementCommands(
ClusterManagementCommands, AsyncManagementCommands
):
"""
A class for Redis Cluster management commands
The class inherits from Redis's core ManagementCommands class and makes the
required adjustments to work with cluster mode
"""
async def cluster_delslots(self, *slots: EncodableT) -> List[bool]:
"""
Set hash slots as unbound in the cluster.
The client determines by itself which node each slot belongs to and sends the command there.
Returns a list of the results for each processed slot.
For more information see https://redis.io/commands/cluster-delslots
"""
return await asyncio.gather(
*(
asyncio.create_task(self.execute_command("CLUSTER DELSLOTS", slot))
for slot in slots
)
)
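# Hedged async sketch, assuming an asyncio RedisCluster client named `arc`:
#
#   results = await arc.cluster_delslots(100, 101, 102)
#   # one DELSLOTS call per slot, issued concurrently -> e.g. [True, True, True]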
class ClusterDataAccessCommands(DataAccessCommands):
"""
A class for Redis Cluster Data Access Commands
The class inherits from Redis's core DataAccessCommands class and makes the
required adjustments to work with cluster mode
"""
def stralgo(
self,
algo: Literal["LCS"],
value1: KeyT,
value2: KeyT,
specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings",
len: bool = False,
idx: bool = False,
minmatchlen: Optional[int] = None,
withmatchlen: bool = False,
**kwargs,
) -> ResponseT:
"""
Implements complex algorithms that operate on strings.
Right now the only algorithm implemented is the LCS algorithm
(longest common substring). However new algorithms could be
implemented in the future.
``algo`` Right now must be LCS
``value1`` and ``value2`` Can be two strings or two keys
``specific_argument`` Specifying if the arguments to the algorithm
will be keys or strings. strings is the default.
``len`` Returns just the len of the match.
``idx`` Returns the match positions in each string.
``minmatchlen`` Restrict the list of matches to the ones of a given
minimal length. Can be provided only when ``idx`` is set to True.
``withmatchlen`` Returns the matches with the len of the match.
Can be provided only when ``idx`` is set to True.
For more information see https://redis.io/commands/stralgo
"""
target_nodes = kwargs.pop("target_nodes", None)
if specific_argument == "strings" and target_nodes is None:
target_nodes = "default-node"
kwargs.update({"target_nodes": target_nodes})
return super().stralgo(
algo,
value1,
value2,
specific_argument,
len,
idx,
minmatchlen,
withmatchlen,
**kwargs,
)
def scan_iter(
self,
match: Optional[PatternT] = None,
count: Optional[int] = None,
_type: Optional[str] = None,
**kwargs,
) -> Iterator:
# Do the first query with cursor=0 for all nodes
cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs)
yield from data
cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
if cursors:
# Get nodes by name
nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}
# Iterate over each node till its cursor is 0
kwargs.pop("target_nodes", None)
while cursors:
for name, cursor in cursors.items():
cur, data = self.scan(
cursor=cursor,
match=match,
count=count,
_type=_type,
target_nodes=nodes[name],
**kwargs,
)
yield from data
cursors[name] = cur[name]
cursors = {
name: cursor for name, cursor in cursors.items() if cursor != 0
}
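# Usage sketch (key pattern and handler are illustrative): scan_iter hides the
# per-node cursors and yields keys from every primary in the cluster.
#
#   for key in rc.scan_iter(match="user:*", count=500):
#       handle(key)   # `handle` is a placeholder for application code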
class AsyncClusterDataAccessCommands(
ClusterDataAccessCommands, AsyncDataAccessCommands
):
"""
A class for Redis Cluster Data Access Commands
The class inherits from Redis's core DataAccessCommands class and makes the
required adjustments to work with cluster mode
"""
async def scan_iter(
self,
match: Optional[PatternT] = None,
count: Optional[int] = None,
_type: Optional[str] = None,
**kwargs,
) -> AsyncIterator:
# Do the first query with cursor=0 for all nodes
cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs)
for value in data:
yield value
cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
if cursors:
# Get nodes by name
nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}
# Iterate over each node till its cursor is 0
kwargs.pop("target_nodes", None)
while cursors:
for name, cursor in cursors.items():
cur, data = await self.scan(
cursor=cursor,
match=match,
count=count,
_type=_type,
target_nodes=nodes[name],
**kwargs,
)
for value in data:
yield value
cursors[name] = cur[name]
cursors = {
name: cursor for name, cursor in cursors.items() if cursor != 0
}
class RedisClusterCommands(
ClusterMultiKeyCommands,
ClusterManagementCommands,
ACLCommands,
PubSubCommands,
ClusterDataAccessCommands,
ScriptCommands,
FunctionCommands,
ModuleCommands,
RedisModuleCommands,
):
"""
A class for all Redis Cluster commands
For key-based commands, the target node(s) will be internally determined
by the keys' hash slot.
Non-key-based commands can be executed with the 'target_nodes' argument to
target specific nodes. By default, if target_nodes is not specified, the
command will be executed on the default cluster node.
:param target_nodes: can be one of the following:
- nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
- 'ClusterNode'
- 'list(ClusterNodes)'
- 'dict(any:clusterNodes)'
for example:
r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
"""
class AsyncRedisClusterCommands(
AsyncClusterMultiKeyCommands,
AsyncClusterManagementCommands,
AsyncACLCommands,
AsyncClusterDataAccessCommands,
AsyncScriptCommands,
AsyncFunctionCommands,
AsyncModuleCommands,
AsyncRedisModuleCommands,
):
"""
A class for all Redis Cluster commands
For key-based commands, the target node(s) will be internally determined
by the keys' hash slot.
Non-key-based commands can be executed with the 'target_nodes' argument to
target specific nodes. By default, if target_nodes is not specified, the
command will be executed on the default cluster node.
:param target_nodes: can be one of the following:
- nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM
- 'ClusterNode'
- 'list(ClusterNodes)'
- 'dict(any:clusterNodes)'
for example:
r.cluster_info(target_nodes=RedisCluster.ALL_NODES)
"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,98 @@
import copy
import random
import string
from typing import List, Tuple
import redis
from redis.typing import KeysT, KeyT
def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]:
# returns a single new list combining keys and args
try:
iter(keys)
# a string or bytes instance can be iterated, but indicates
# keys wasn't passed as a list
if isinstance(keys, (bytes, str)):
keys = [keys]
else:
keys = list(keys)
except TypeError:
keys = [keys]
if args:
keys.extend(args)
return keys
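# A few illustrative normalizations performed by list_or_args:
#
#   list_or_args("k1", ("k2", "k3"))   # -> ["k1", "k2", "k3"]
#   list_or_args(["k1", "k2"], ())     # -> ["k1", "k2"]
#   list_or_args(b"k1", ("k2",))       # -> [b"k1", "k2"]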
def nativestr(x):
"""Return the decoded binary string, or a string, depending on type."""
r = x.decode("utf-8", "replace") if isinstance(x, bytes) else x
if r == "null":
return
return r
def delist(x):
"""Given a list of binaries, return the stringified version."""
if x is None:
return x
return [nativestr(obj) for obj in x]
def parse_to_list(response):
"""Optimistically parse the response to a list."""
res = []
special_values = {"infinity", "nan", "-infinity"}
if response is None:
return res
for item in response:
if item is None:
res.append(None)
continue
try:
item_str = nativestr(item)
except TypeError:
res.append(None)
continue
if isinstance(item_str, str) and item_str.lower() in special_values:
res.append(item_str) # Keep as string
else:
try:
res.append(int(item))
except ValueError:
try:
res.append(float(item))
except ValueError:
res.append(item_str)
return res
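# Illustrative behaviour of the optimistic parser above:
#
#   parse_to_list([b"1", b"2.5", b"nan", None, b"abc"])
#   # -> [1, 2.5, "nan", None, "abc"]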
def random_string(length=10):
"""
Return a random lowercase ASCII string of ``length`` characters.
"""
return "".join( # nosec
random.choice(string.ascii_lowercase) for x in range(length)
)
def decode_dict_keys(obj):
"""Decode the keys of the given dictionary with utf-8."""
newobj = copy.copy(obj)
for k in obj.keys():
if isinstance(k, bytes):
newobj[k.decode("utf-8")] = newobj[k]
newobj.pop(k)
return newobj
def get_protocol_version(client):
if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
return client.connection_pool.connection_kwargs.get("protocol")
elif isinstance(client, redis.cluster.AbstractRedisCluster):
return client.nodes_manager.connection_kwargs.get("protocol")
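# Quick sketch (client construction is an assumption for the example):
#
#   r = redis.Redis(protocol=3)
#   get_protocol_version(r)   # -> 3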

View File

@@ -0,0 +1,147 @@
from json import JSONDecodeError, JSONDecoder, JSONEncoder
import redis
from ..helpers import get_protocol_version, nativestr
from .commands import JSONCommands
from .decoders import bulk_of_jsons, decode_list
class JSON(JSONCommands):
"""
Create a client for talking to json.
:param decoder: An instance of json.JSONDecoder
:type decoder: json.JSONDecoder
:param encoder: An instance of json.JSONEncoder
:type encoder: json.JSONEncoder
"""
def __init__(
self, client, version=None, decoder=JSONDecoder(), encoder=JSONEncoder()
):
"""
Create a client for talking to json.
:param decoder: An instance of json.JSONDecoder
:type decoder: json.JSONDecoder
:param encoder: An instance of json.JSONEncoder
:type encoder: json.JSONEncoder
"""
# Set the module commands' callbacks
self._MODULE_CALLBACKS = {
"JSON.ARRPOP": self._decode,
"JSON.DEBUG": self._decode,
"JSON.GET": self._decode,
"JSON.MERGE": lambda r: r and nativestr(r) == "OK",
"JSON.MGET": bulk_of_jsons(self._decode),
"JSON.MSET": lambda r: r and nativestr(r) == "OK",
"JSON.RESP": self._decode,
"JSON.SET": lambda r: r and nativestr(r) == "OK",
"JSON.TOGGLE": self._decode,
}
_RESP2_MODULE_CALLBACKS = {
"JSON.ARRAPPEND": self._decode,
"JSON.ARRINDEX": self._decode,
"JSON.ARRINSERT": self._decode,
"JSON.ARRLEN": self._decode,
"JSON.ARRTRIM": self._decode,
"JSON.CLEAR": int,
"JSON.DEL": int,
"JSON.FORGET": int,
"JSON.GET": self._decode,
"JSON.NUMINCRBY": self._decode,
"JSON.NUMMULTBY": self._decode,
"JSON.OBJKEYS": self._decode,
"JSON.STRAPPEND": self._decode,
"JSON.OBJLEN": self._decode,
"JSON.STRLEN": self._decode,
"JSON.TOGGLE": self._decode,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.execute_command = client.execute_command
self.MODULE_VERSION = version
if get_protocol_version(self.client) in ["3", 3]:
self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for key, value in self._MODULE_CALLBACKS.items():
self.client.set_response_callback(key, value)
self.__encoder__ = encoder
self.__decoder__ = decoder
def _decode(self, obj):
"""Get the decoder."""
if obj is None:
return obj
try:
x = self.__decoder__.decode(obj)
if x is None:
raise TypeError
return x
except TypeError:
try:
return self.__decoder__.decode(obj.decode())
except AttributeError:
return decode_list(obj)
except (AttributeError, JSONDecodeError):
return decode_list(obj)
def _encode(self, obj):
"""Get the encoder."""
return self.__encoder__.encode(obj)
def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the JSON module, that can be used for executing
JSON commands, as well as classic core commands.
Usage example:
r = redis.Redis()
pipe = r.json().pipeline()
pipe.jsonset('foo', '.', {'hello!': 'world'})
pipe.jsonget('foo')
pipe.jsonget('notakey')
"""
if isinstance(self.client, redis.RedisCluster):
p = ClusterPipeline(
nodes_manager=self.client.nodes_manager,
commands_parser=self.client.commands_parser,
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.retry.get_retries(),
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,
)
else:
p = Pipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self._MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
p._encode = self._encode
p._decode = self._decode
return p
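# Hedged end-to-end sketch for the JSON client; key and path values are
# examples only.
#
#   import redis
#   r = redis.Redis()
#   r.json().set("user:1", "$", {"name": "Ada", "tags": ["x"]})
#   r.json().get("user:1", "$.name")               # -> ["Ada"]
#   r.json().arrappend("user:1", "$.tags", "y")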
class ClusterPipeline(JSONCommands, redis.cluster.ClusterPipeline):
"""Cluster pipeline for the module."""
class Pipeline(JSONCommands, redis.client.Pipeline):
"""Pipeline for the module."""

View File

@@ -0,0 +1,5 @@
from typing import List, Mapping, Union
JsonType = Union[
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]

View File

@@ -0,0 +1,431 @@
import os
from json import JSONDecodeError, loads
from typing import Dict, List, Optional, Tuple, Union
from redis.exceptions import DataError
from redis.utils import deprecated_function
from ._util import JsonType
from .decoders import decode_dict_keys
from .path import Path
class JSONCommands:
"""json commands."""
def arrappend(
self, name: str, path: Optional[str] = Path.root_path(), *args: JsonType
) -> List[Optional[int]]:
"""Append the objects ``args`` to the array under the
``path`` in key ``name``.
For more information see `JSON.ARRAPPEND <https://redis.io/commands/json.arrappend>`_.
""" # noqa
pieces = [name, str(path)]
for o in args:
pieces.append(self._encode(o))
return self.execute_command("JSON.ARRAPPEND", *pieces)
def arrindex(
self,
name: str,
path: str,
scalar: int,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[Optional[int]]:
"""
Return the index of ``scalar`` in the JSON array under ``path`` at key
``name``.
The search can be limited using the optional inclusive ``start``
and exclusive ``stop`` indices.
For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
""" # noqa
pieces = [name, str(path), self._encode(scalar)]
if start is not None:
pieces.append(start)
if stop is not None:
pieces.append(stop)
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
def arrinsert(
self, name: str, path: str, index: int, *args: JsonType
) -> List[Optional[int]]:
"""Insert the objects ``args`` to the array at index ``index``
under the ``path`` in key ``name``.
For more information see `JSON.ARRINSERT <https://redis.io/commands/json.arrinsert>`_.
""" # noqa
pieces = [name, str(path), index]
for o in args:
pieces.append(self._encode(o))
return self.execute_command("JSON.ARRINSERT", *pieces)
def arrlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[int]]:
"""Return the length of the array JSON value under ``path``
at key ``name``.
For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
""" # noqa
return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])
def arrpop(
self,
name: str,
path: Optional[str] = Path.root_path(),
index: Optional[int] = -1,
) -> List[Optional[str]]:
"""Pop the element at ``index`` in the array JSON value under
``path`` at key ``name``.
For more information see `JSON.ARRPOP <https://redis.io/commands/json.arrpop>`_.
""" # noqa
return self.execute_command("JSON.ARRPOP", name, str(path), index)
def arrtrim(
self, name: str, path: str, start: int, stop: int
) -> List[Optional[int]]:
"""Trim the array JSON value under ``path`` at key ``name`` to the
inclusive range given by ``start`` and ``stop``.
For more information see `JSON.ARRTRIM <https://redis.io/commands/json.arrtrim>`_.
""" # noqa
return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop)
def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]:
"""Get the type of the JSON value under ``path`` from key ``name``.
For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
""" # noqa
return self.execute_command("JSON.TYPE", name, str(path), keys=[name])
def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
"""Return the JSON value under ``path`` at key ``name``.
For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
""" # noqa
return self.execute_command("JSON.RESP", name, str(path), keys=[name])
def objkeys(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[List[str]]]:
"""Return the key names in the dictionary JSON value under ``path`` at
key ``name``.
For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
""" # noqa
return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])
def objlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[int]]:
"""Return the length of the dictionary JSON value under ``path`` at key
``name``.
For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
""" # noqa
return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])
def numincrby(self, name: str, path: str, number: int) -> str:
"""Increment the numeric (integer or floating point) JSON value under
``path`` at key ``name`` by the provided ``number``.
For more information see `JSON.NUMINCRBY <https://redis.io/commands/json.numincrby>`_.
""" # noqa
return self.execute_command(
"JSON.NUMINCRBY", name, str(path), self._encode(number)
)
@deprecated_function(version="4.0.0", reason="deprecated since redisjson 1.0.0")
def nummultby(self, name: str, path: str, number: int) -> str:
"""Multiply the numeric (integer or floating point) JSON value under
``path`` at key ``name`` with the provided ``number``.
For more information see `JSON.NUMMULTBY <https://redis.io/commands/json.nummultby>`_.
""" # noqa
return self.execute_command(
"JSON.NUMMULTBY", name, str(path), self._encode(number)
)
def clear(self, name: str, path: Optional[str] = Path.root_path()) -> int:
"""Empty arrays and objects (to have zero slots/keys without deleting the
array/object).
Return the count of cleared paths (ignoring non-array and non-object
paths).
For more information see `JSON.CLEAR <https://redis.io/commands/json.clear>`_.
""" # noqa
return self.execute_command("JSON.CLEAR", name, str(path))
def delete(self, key: str, path: Optional[str] = Path.root_path()) -> int:
"""Delete the JSON value stored at key ``key`` under ``path``.
For more information see `JSON.DEL <https://redis.io/commands/json.del>`_.
"""
return self.execute_command("JSON.DEL", key, str(path))
# forget is an alias for delete
forget = delete
def get(
self, name: str, *args, no_escape: Optional[bool] = False
) -> Optional[List[JsonType]]:
"""
Get the object stored as a JSON value at key ``name``.
``args`` is zero or more paths, and defaults to the root path.
``no_escape`` is a boolean flag to add the no_escape option, to get
non-ASCII characters
For more information see `JSON.GET <https://redis.io/commands/json.get>`_.
""" # noqa
pieces = [name]
if no_escape:
pieces.append("noescape")
if len(args) == 0:
pieces.append(Path.root_path())
else:
for p in args:
pieces.append(str(p))
# Handle case where key doesn't exist. The JSONDecoder would raise a
# TypeError exception since it can't decode None
try:
return self.execute_command("JSON.GET", *pieces, keys=[name])
except TypeError:
return None
def mget(self, keys: List[str], path: str) -> List[JsonType]:
"""
Get the objects stored as JSON values under ``path``. ``keys``
is a list of one or more keys.
For more information see `JSON.MGET <https://redis.io/commands/json.mget>`_.
""" # noqa
pieces = []
pieces += keys
pieces.append(str(path))
return self.execute_command("JSON.MGET", *pieces, keys=keys)
def set(
self,
name: str,
path: str,
obj: JsonType,
nx: Optional[bool] = False,
xx: Optional[bool] = False,
decode_keys: Optional[bool] = False,
) -> Optional[str]:
"""
Set the JSON value at key ``name`` under the ``path`` to ``obj``.
``nx`` if set to True, set ``value`` only if it does not exist.
``xx`` if set to True, set ``value`` only if it exists.
``decode_keys`` If set to True, the keys of ``obj`` will be decoded
with utf-8.
For the purpose of using this within a pipeline, this command is also
aliased to JSON.SET.
For more information see `JSON.SET <https://redis.io/commands/json.set>`_.
"""
if decode_keys:
obj = decode_dict_keys(obj)
pieces = [name, str(path), self._encode(obj)]
# Handle existential modifiers
if nx and xx:
raise Exception(
"nx and xx are mutually exclusive: use one, the "
"other or neither - but not both"
)
elif nx:
pieces.append("NX")
elif xx:
pieces.append("XX")
return self.execute_command("JSON.SET", *pieces)
def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]:
"""
Set the JSON value for one or more keys, each under its own path.
``triplets`` is a list of one or more triplets of key, path, value.
For the purpose of using this within a pipeline, this command is also
aliased to JSON.MSET.
For more information see `JSON.MSET <https://redis.io/commands/json.mset>`_.
"""
pieces = []
for triplet in triplets:
pieces.extend([triplet[0], str(triplet[1]), self._encode(triplet[2])])
return self.execute_command("JSON.MSET", *pieces)
def merge(
self,
name: str,
path: str,
obj: JsonType,
decode_keys: Optional[bool] = False,
) -> Optional[str]:
"""
Merges a given JSON value into matching paths. Consequently, JSON values
at matching paths are updated, deleted, or expanded with new children
``decode_keys`` If set to True, the keys of ``obj`` will be decoded
with utf-8.
For more information see `JSON.MERGE <https://redis.io/commands/json.merge>`_.
"""
if decode_keys:
obj = decode_dict_keys(obj)
pieces = [name, str(path), self._encode(obj)]
return self.execute_command("JSON.MERGE", *pieces)
def set_file(
self,
name: str,
path: str,
file_name: str,
nx: Optional[bool] = False,
xx: Optional[bool] = False,
decode_keys: Optional[bool] = False,
) -> Optional[str]:
"""
Set the JSON value at key ``name`` under the ``path`` to the content
of the json file ``file_name``.
``nx`` if set to True, set ``value`` only if it does not exist.
``xx`` if set to True, set ``value`` only if it exists.
``decode_keys`` If set to True, the keys of ``obj`` will be decoded
with utf-8.
"""
with open(file_name) as fp:
file_content = loads(fp.read())
return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys)
def set_path(
self,
json_path: str,
root_folder: str,
nx: Optional[bool] = False,
xx: Optional[bool] = False,
decode_keys: Optional[bool] = False,
) -> Dict[str, bool]:
"""
Iterate over ``root_folder`` and, for each JSON file found, set its content
under ``json_path``, using the file path without its extension as the key.
``nx`` if set to True, set ``value`` only if it does not exist.
``xx`` if set to True, set ``value`` only if it exists.
``decode_keys`` If set to True, the keys of ``obj`` will be decoded
with utf-8.
"""
set_files_result = {}
for root, dirs, files in os.walk(root_folder):
for file in files:
file_path = os.path.join(root, file)
try:
file_name = file_path.rsplit(".", 1)[0]
self.set_file(
file_name,
json_path,
file_path,
nx=nx,
xx=xx,
decode_keys=decode_keys,
)
set_files_result[file_path] = True
except JSONDecodeError:
set_files_result[file_path] = False
return set_files_result
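# Hedged sketch: bulk-loading every JSON file under a folder, one key per file.
# The folder path is a placeholder and `j` is assumed to be a JSON client.
#
#   results = j.set_path("$", "/data/fixtures")
#   failed = [p for p, ok in results.items() if not ok]   # files that were not valid JSON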
def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
"""Return the length of the string JSON value under ``path`` at key
``name``.
For more information see `JSON.STRLEN <https://redis.io/commands/json.strlen>`_.
""" # noqa
pieces = [name]
if path is not None:
pieces.append(str(path))
return self.execute_command("JSON.STRLEN", *pieces, keys=[name])
def toggle(
self, name: str, path: Optional[str] = Path.root_path()
) -> Union[bool, List[Optional[int]]]:
"""Toggle boolean value under ``path`` at key ``name``.
returning the new value.
For more information see `JSON.TOGGLE <https://redis.io/commands/json.toggle>`_.
""" # noqa
return self.execute_command("JSON.TOGGLE", name, str(path))
def strappend(
self, name: str, value: str, path: Optional[str] = Path.root_path()
) -> Union[int, List[Optional[int]]]:
"""Append to the string JSON value. If two options are specified after
the key name, the path is determined to be the first. If a single
option is passed, then the root path (i.e. Path.root_path()) is used.
For more information see `JSON.STRAPPEND <https://redis.io/commands/json.strappend>`_.
""" # noqa
pieces = [name, str(path), self._encode(value)]
return self.execute_command("JSON.STRAPPEND", *pieces)
def debug(
self,
subcommand: str,
key: Optional[str] = None,
path: Optional[str] = Path.root_path(),
) -> Union[int, List[str]]:
"""Return the memory usage in bytes of a value under ``path`` from
key ``key``.
For more information see `JSON.DEBUG <https://redis.io/commands/json.debug>`_.
""" # noqa
valid_subcommands = ["MEMORY", "HELP"]
if subcommand not in valid_subcommands:
raise DataError("The only valid subcommands are ", str(valid_subcommands))
pieces = [subcommand]
if subcommand == "MEMORY":
if key is None:
raise DataError("No key specified")
pieces.append(key)
pieces.append(str(path))
return self.execute_command("JSON.DEBUG", *pieces)
@deprecated_function(
version="4.0.0", reason="redisjson-py supported this, call get directly."
)
def jsonget(self, *args, **kwargs):
return self.get(*args, **kwargs)
@deprecated_function(
version="4.0.0", reason="redisjson-py supported this, call get directly."
)
def jsonmget(self, *args, **kwargs):
return self.mget(*args, **kwargs)
@deprecated_function(
version="4.0.0", reason="redisjson-py supported this, call get directly."
)
def jsonset(self, *args, **kwargs):
return self.set(*args, **kwargs)

View File

@@ -0,0 +1,60 @@
import copy
import re
from ..helpers import nativestr
def bulk_of_jsons(d):
"""Replace serialized JSON values with objects in a
bulk array response (list).
"""
def _f(b):
for index, item in enumerate(b):
if item is not None:
b[index] = d(item)
return b
return _f
def decode_dict_keys(obj):
"""Decode the keys of the given dictionary with utf-8."""
newobj = copy.copy(obj)
for k in obj.keys():
if isinstance(k, bytes):
newobj[k.decode("utf-8")] = newobj[k]
newobj.pop(k)
return newobj
def unstring(obj):
"""
Attempt to parse a string into a native numeric format.
One can't simply call int/float in a try/catch because there is a
semantic difference between (for example) 15.0 and 15.
"""
floatreg = "^\\d+\\.\\d+$"
match = re.findall(floatreg, obj)
if match != []:
return float(match[0])
intreg = "^\\d+$"
match = re.findall(intreg, obj)
if match != []:
return int(match[0])
return obj
def decode_list(b):
"""
Given a non-deserializable object, make a best effort to
return a useful set of results.
"""
if isinstance(b, list):
return [nativestr(obj) for obj in b]
elif isinstance(b, bytes):
return unstring(nativestr(b))
elif isinstance(b, str):
return unstring(b)
return b
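# Illustrative best-effort decoding:
#
#   decode_list(b"15")          # -> 15
#   decode_list(b"1.5")         # -> 1.5
#   decode_list([b"a", b"b"])   # -> ["a", "b"]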

View File

@@ -0,0 +1,16 @@
class Path:
"""This class represents a path in a JSON value."""
strPath = ""
@staticmethod
def root_path():
"""Return the root path's string representation."""
return "."
def __init__(self, path):
"""Make a new path based on the string representation in `path`."""
self.strPath = path
def __repr__(self):
return self.strPath
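# Paths are passed to the JSON commands as str(path); two illustrative values:
#
#   Path.root_path()            # "." -- the legacy RedisJSON root path
#   str(Path("$.user.name"))    # "$.user.name" -- a JSONPath-style path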

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from json import JSONDecoder, JSONEncoder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
from .json import JSON
from .search import AsyncSearch, Search
from .timeseries import TimeSeries
from .vectorset import VectorSet
class RedisModuleCommands:
"""This class contains the wrapper functions to bring supported redis
modules into the command namespace.
"""
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
"""Access the json namespace, providing support for redis json."""
from .json import JSON
jj = JSON(client=self, encoder=encoder, decoder=decoder)
return jj
def ft(self, index_name="idx") -> Search:
"""Access the search namespace, providing support for redis search."""
from .search import Search
s = Search(client=self, index_name=index_name)
return s
def ts(self) -> TimeSeries:
"""Access the timeseries namespace, providing support for
redis timeseries data.
"""
from .timeseries import TimeSeries
s = TimeSeries(client=self)
return s
def bf(self) -> BFBloom:
"""Access the bloom namespace."""
from .bf import BFBloom
bf = BFBloom(client=self)
return bf
def cf(self) -> CFBloom:
"""Access the bloom namespace."""
from .bf import CFBloom
cf = CFBloom(client=self)
return cf
def cms(self) -> CMSBloom:
"""Access the bloom namespace."""
from .bf import CMSBloom
cms = CMSBloom(client=self)
return cms
def topk(self) -> TOPKBloom:
"""Access the bloom namespace."""
from .bf import TOPKBloom
topk = TOPKBloom(client=self)
return topk
def tdigest(self) -> TDigestBloom:
"""Access the bloom namespace."""
from .bf import TDigestBloom
tdigest = TDigestBloom(client=self)
return tdigest
def vset(self) -> VectorSet:
"""Access the VectorSet commands namespace."""
from .vectorset import VectorSet
vset = VectorSet(client=self)
return vset
class AsyncRedisModuleCommands(RedisModuleCommands):
def ft(self, index_name="idx") -> AsyncSearch:
"""Access the search namespace, providing support for redis search."""
from .search import AsyncSearch
s = AsyncSearch(client=self, index_name=index_name)
return s

View File

@@ -0,0 +1,189 @@
from redis.client import Pipeline as RedisPipeline
from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
AGGREGATE_CMD,
CONFIG_CMD,
INFO_CMD,
PROFILE_CMD,
SEARCH_CMD,
SPELLCHECK_CMD,
SYNDUMP_CMD,
AsyncSearchCommands,
SearchCommands,
)
class Search(SearchCommands):
"""
Create a client for talking to search.
It abstracts the API of the module and lets you just use the engine.
"""
class BatchIndexer:
"""
A batch indexer allows you to automatically batch
document indexing in pipelines, flushing it every N documents.
"""
def __init__(self, client, chunk_size=1000):
self.client = client
self.execute_command = client.execute_command
self._pipeline = client.pipeline(transaction=False, shard_hint=None)
self.total = 0
self.chunk_size = chunk_size
self.current_chunk = 0
def __del__(self):
if self.current_chunk:
self.commit()
def add_document(
self,
doc_id,
nosave=False,
score=1.0,
payload=None,
replace=False,
partial=False,
no_create=False,
**fields,
):
"""
Add a document to the batch query
"""
self.client._add_document(
doc_id,
conn=self._pipeline,
nosave=nosave,
score=score,
payload=payload,
replace=replace,
partial=partial,
no_create=no_create,
**fields,
)
self.current_chunk += 1
self.total += 1
if self.current_chunk >= self.chunk_size:
self.commit()
def add_document_hash(self, doc_id, score=1.0, replace=False):
"""
Add a hash to the batch query
"""
self.client._add_document_hash(
doc_id, conn=self._pipeline, score=score, replace=replace
)
self.current_chunk += 1
self.total += 1
if self.current_chunk >= self.chunk_size:
self.commit()
def commit(self):
"""
Manually commit and flush the batch indexing query
"""
self._pipeline.execute()
self.current_chunk = 0
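# Hedged usage sketch for the batch indexer; the index name, field and document
# ids are assumptions, and the index is expected to already exist.
#
#   searcher = redis.Redis().ft("docs")
#   indexer = Search.BatchIndexer(searcher, chunk_size=500)
#   for i in range(2000):
#       indexer.add_document(f"doc:{i}", body=f"text {i}")
#   indexer.commit()   # flush the final, partially filled chunk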
def __init__(self, client, index_name="idx"):
"""
Create a new Client for the given index_name.
The default name is `idx`.
The given `client` provides the already existing redis connection used to
execute the commands.
"""
self._MODULE_CALLBACKS = {}
self.client = client
self.index_name = index_name
self.execute_command = client.execute_command
self._pipeline = client.pipeline
self._RESP2_MODULE_CALLBACKS = {
INFO_CMD: self._parse_info,
SEARCH_CMD: self._parse_search,
AGGREGATE_CMD: self._parse_aggregate,
PROFILE_CMD: self._parse_profile,
SPELLCHECK_CMD: self._parse_spellcheck,
CONFIG_CMD: self._parse_config_get,
SYNDUMP_CMD: self._parse_syndump,
}
def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the SEARCH module, that can be used for executing
SEARCH commands, as well as classic core commands.
"""
p = Pipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self._MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
p.index_name = self.index_name
return p
class AsyncSearch(Search, AsyncSearchCommands):
class BatchIndexer(Search.BatchIndexer):
"""
A batch indexer allows you to automatically batch
document indexing in pipelines, flushing it every N documents.
"""
async def add_document(
self,
doc_id,
nosave=False,
score=1.0,
payload=None,
replace=False,
partial=False,
no_create=False,
**fields,
):
"""
Add a document to the batch query
"""
self.client._add_document(
doc_id,
conn=self._pipeline,
nosave=nosave,
score=score,
payload=payload,
replace=replace,
partial=partial,
no_create=no_create,
**fields,
)
self.current_chunk += 1
self.total += 1
if self.current_chunk >= self.chunk_size:
await self.commit()
async def commit(self):
"""
Manually commit and flush the batch indexing query
"""
await self._pipeline.execute()
self.current_chunk = 0
def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the SEARCH module, that can be used for executing
SEARCH commands, as well as classic core commands.
"""
p = AsyncPipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self._MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
p.index_name = self.index_name
return p
class Pipeline(SearchCommands, RedisPipeline):
"""Pipeline for the module."""
class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
"""AsyncPipeline for the module."""

View File

@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode(encoding, "ignore")
else:
return s # Not a string we care about

View File

@@ -0,0 +1,399 @@
from typing import List, Optional, Tuple, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
FIELDNAME = object()
class Limit:
def __init__(self, offset: int = 0, count: int = 0) -> None:
self.offset = offset
self.count = count
def build_args(self):
if self.count:
return ["LIMIT", str(self.offset), str(self.count)]
else:
return []
class Reducer:
"""
Base reducer object for all reducers.
See the `redisearch.reducers` module for the actual reducers.
"""
NAME = None
def __init__(self, *args: str) -> None:
self._args: Tuple[str, ...] = args
self._field: Optional[str] = None
self._alias: Optional[str] = None
def alias(self, alias: str) -> "Reducer":
"""
Set the alias for this reducer.
### Parameters
- **alias**: The value of the alias for this reducer. If this is the
special value `aggregation.FIELDNAME` then this reducer will be
aliased using the same name as the field upon which it operates.
Note that using `FIELDNAME` is only possible on reducers which
operate on a single field value.
This method returns the `Reducer` object making it suitable for
chaining.
"""
if alias is FIELDNAME:
if not self._field:
raise ValueError("Cannot use FIELDNAME alias with no field")
else:
# Chop off initial '@'
alias = self._field[1:]
self._alias = alias
return self
@property
def args(self) -> Tuple[str, ...]:
return self._args
class SortDirection:
"""
This special class is used to indicate sort direction.
"""
DIRSTRING: Optional[str] = None
def __init__(self, field: str) -> None:
self.field = field
class Asc(SortDirection):
"""
Indicate that the given field should be sorted in ascending order
"""
DIRSTRING = "ASC"
class Desc(SortDirection):
"""
Indicate that the given field should be sorted in descending order
"""
DIRSTRING = "DESC"
class AggregateRequest:
"""
Aggregation request which can be passed to `Client.aggregate`.
"""
def __init__(self, query: str = "*") -> None:
"""
Create an aggregation request. This request may then be passed to
`client.aggregate()`.
In order for the request to be usable, it must contain at least one
group.
- **query** Query string for filtering records.
All member methods (except `build_args()`)
return the object itself, making them useful for chaining.
"""
self._query: str = query
self._aggregateplan: List[str] = []
self._loadfields: List[str] = []
self._loadall: bool = False
self._max: int = 0
self._with_schema: bool = False
self._verbatim: bool = False
self._cursor: List[str] = []
self._dialect: int = DEFAULT_DIALECT
self._add_scores: bool = False
self._scorer: str = "TFIDF"
def load(self, *fields: str) -> "AggregateRequest":
"""
Indicate the fields to be returned in the response. These fields are
returned in addition to any others implicitly specified.
### Parameters
- **fields**: If no fields are specified, all fields will be loaded.
Otherwise, fields should be given in the format of `@field`.
"""
if fields:
self._loadfields.extend(fields)
else:
self._loadall = True
return self
def group_by(
self, fields: Union[str, List[str]], *reducers: Reducer
) -> "AggregateRequest":
"""
Specify by which fields to group the aggregation.
### Parameters
- **fields**: Fields to group by. This can either be a single string,
or a list of strings. In both cases, each field should be specified as
`@field`.
- **reducers**: One or more reducers. Reducers may be found in the
`aggregation` module.
"""
fields = [fields] if isinstance(fields, str) else fields
ret = ["GROUPBY", str(len(fields)), *fields]
for reducer in reducers:
ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
ret.extend(reducer.args)
if reducer._alias is not None:
ret += ["AS", reducer._alias]
self._aggregateplan.extend(ret)
return self
def apply(self, **kwexpr) -> "AggregateRequest":
"""
Specify one or more projection expressions to add to each result
### Parameters
- **kwexpr**: One or more key-value pairs for a projection. The key is
the alias for the projection, and the value is the projection
expression itself, for example `apply(square_root="sqrt(@foo)")`
"""
for alias, expr in kwexpr.items():
ret = ["APPLY", expr]
if alias is not None:
ret += ["AS", alias]
self._aggregateplan.extend(ret)
return self
def limit(self, offset: int, num: int) -> "AggregateRequest":
"""
Sets the limit for the most recent group or query.
If no group has been defined yet (via `group_by()`) then this sets
the limit for the initial pool of results from the query. Otherwise,
this limits the number of items operated on from the previous group.
Setting a limit on the initial search results may be useful when
attempting to execute an aggregation on a sample of a large data set.
### Parameters
- **offset**: Result offset from which to begin paging
- **num**: Number of results to return
Example of limiting the initial results:
```
AggregateRequest("@sale_amount:[10000, inf]")\
.limit(0, 10)\
.group_by("@state", r.count())
```
Will only group by the states found in the first 10 results of the
query `@sale_amount:[10000, inf]`. On the other hand,
```
AggregateRequest("@sale_amount:[10000, inf]")\
.limit(0, 1000)\
.group_by("@state", r.count()\
.limit(0, 10)
```
Will group all the results matching the query, but only return the
first 10 groups.
If you only wish to return a *top-N* style query, consider using
`sort_by()` instead.
"""
_limit = Limit(offset, num)
self._aggregateplan.extend(_limit.build_args())
return self
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
"""
Indicate how the results should be sorted. This can also be used for
*top-N* style queries
### Parameters
- **fields**: The fields by which to sort. This can be either a single
field or a list of fields. If you wish to specify order, you can
use the `Asc` or `Desc` wrapper classes.
- **max**: Maximum number of results to return. This can be
used instead of `LIMIT` and is also faster.
Example of sorting by `foo` ascending and `bar` descending:
```
sort_by(Asc("@foo"), Desc("@bar"))
```
Return the top 10 customers:
```
AggregateRequest()\
.group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
.sort_by(Desc("@paid"), max=10)
```
"""
fields_args = []
for f in fields:
if isinstance(f, (Asc, Desc)):
fields_args += [f.field, f.DIRSTRING]
else:
fields_args += [f]
ret = ["SORTBY", str(len(fields_args))]
ret.extend(fields_args)
max = kwargs.get("max", 0)
if max > 0:
ret += ["MAX", str(max)]
self._aggregateplan.extend(ret)
return self
def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
"""
Specify filter for post-query results using predicates relating to
values in the result set.
### Parameters
- **expressions**: A single filter expression or a list of expressions,
referencing fields in the result set.
"""
if isinstance(expressions, str):
expressions = [expressions]
for expression in expressions:
self._aggregateplan.extend(["FILTER", expression])
return self
def with_schema(self) -> "AggregateRequest":
"""
If set, the `schema` property will contain a list of `[field, type]`
entries in the result object.
"""
self._with_schema = True
return self
def add_scores(self) -> "AggregateRequest":
"""
If set, includes the score as an ordinary field of the row.
"""
self._add_scores = True
return self
def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self
def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
args = ["WITHCURSOR"]
if count:
args += ["COUNT", str(count)]
if max_idle:
args += ["MAXIDLE", str(max_idle * 1000)]
self._cursor = args
return self
def build_args(self) -> List[str]:
# @foo:bar ...
ret = [self._query]
if self._with_schema:
ret.append("WITHSCHEMA")
if self._verbatim:
ret.append("VERBATIM")
if self._scorer:
ret.extend(["SCORER", self._scorer])
if self._add_scores:
ret.append("ADDSCORES")
if self._cursor:
ret += self._cursor
if self._loadall:
ret.append("LOAD")
ret.append("*")
elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
ret.extend(self._loadfields)
if self._dialect:
ret.extend(["DIALECT", str(self._dialect)])
ret.extend(self._aggregateplan)
return ret
def dialect(self, dialect: int) -> "AggregateRequest":
"""
Add a dialect field to the aggregate command.
- **dialect** - dialect version to execute the query under
"""
self._dialect = dialect
return self
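# Hedged sketch of composing a request; the index, field names and reducer are
# illustrative, and `client` is assumed to be a redis.Redis instance with the
# search module available.
#
#   from redis.commands.search import reducers
#   req = (
#       AggregateRequest("@brand:sony")
#       .group_by("@category", reducers.count().alias("count"))
#       .sort_by(Desc("@count"), max=5)
#   )
#   res = client.ft("idx").aggregate(req)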
class Cursor:
def __init__(self, cid: int) -> None:
self.cid = cid
self.max_idle = 0
self.count = 0
def build_args(self):
args = [str(self.cid)]
if self.max_idle:
args += ["MAXIDLE", str(self.max_idle)]
if self.count:
args += ["COUNT", str(self.count)]
return args
class AggregateResult:
def __init__(self, rows, cursor: Cursor, schema) -> None:
self.rows = rows
self.cursor = cursor
self.schema = schema
def __repr__(self) -> str:
cid = self.cursor.cid if self.cursor else -1
return (
f"<{self.__class__.__name__} at 0x{id(self):x} "
f"Rows={len(self.rows)}, Cursor={cid}>"
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
DEFAULT_DIALECT = 2

View File

@@ -0,0 +1,17 @@
class Document:
"""
Represents a single document in a result set
"""
def __init__(self, id, payload=None, **fields):
self.id = id
self.payload = payload
for k, v in fields.items():
setattr(self, k, v)
def __repr__(self):
return f"Document {self.__dict__}"
def __getitem__(self, item):
value = getattr(self, item)
return value

View File

@@ -0,0 +1,210 @@
from typing import List
from redis import DataError
class Field:
"""
A class representing a field in a document.
"""
NUMERIC = "NUMERIC"
TEXT = "TEXT"
WEIGHT = "WEIGHT"
GEO = "GEO"
TAG = "TAG"
VECTOR = "VECTOR"
SORTABLE = "SORTABLE"
NOINDEX = "NOINDEX"
AS = "AS"
GEOSHAPE = "GEOSHAPE"
INDEX_MISSING = "INDEXMISSING"
INDEX_EMPTY = "INDEXEMPTY"
def __init__(
self,
name: str,
args: List[str] = None,
sortable: bool = False,
no_index: bool = False,
index_missing: bool = False,
index_empty: bool = False,
as_name: str = None,
):
"""
Create a new field object.
Args:
name: The name of the field.
args:
sortable: If `True`, the field will be sortable.
no_index: If `True`, the field will not be indexed.
index_missing: If `True`, it will be possible to search for documents that
have this field missing.
index_empty: If `True`, it will be possible to search for documents that
have this field empty.
as_name: If provided, this alias will be used for the field.
"""
if args is None:
args = []
self.name = name
self.args = args
self.args_suffix = list()
self.as_name = as_name
if no_index:
self.args_suffix.append(Field.NOINDEX)
if index_missing:
self.args_suffix.append(Field.INDEX_MISSING)
if index_empty:
self.args_suffix.append(Field.INDEX_EMPTY)
if sortable:
self.args_suffix.append(Field.SORTABLE)
if no_index and not sortable:
raise ValueError("Non-Sortable non-Indexable fields are ignored")
def append_arg(self, value):
self.args.append(value)
def redis_args(self):
args = [self.name]
if self.as_name:
args += [self.AS, self.as_name]
args += self.args
args += self.args_suffix
return args
class TextField(Field):
"""
TextField is used to define a text field in a schema definition
"""
NOSTEM = "NOSTEM"
PHONETIC = "PHONETIC"
def __init__(
self,
name: str,
weight: float = 1.0,
no_stem: bool = False,
phonetic_matcher: str = None,
withsuffixtrie: bool = False,
**kwargs,
):
Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)
if no_stem:
Field.append_arg(self, self.NOSTEM)
if phonetic_matcher and phonetic_matcher in [
"dm:en",
"dm:fr",
"dm:pt",
"dm:es",
]:
Field.append_arg(self, self.PHONETIC)
Field.append_arg(self, phonetic_matcher)
if withsuffixtrie:
Field.append_arg(self, "WITHSUFFIXTRIE")
class NumericField(Field):
"""
NumericField is used to define a numeric field in a schema definition
"""
def __init__(self, name: str, **kwargs):
Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)
class GeoShapeField(Field):
"""
GeoShapeField is used to enable within/contain indexing/searching
"""
SPHERICAL = "SPHERICAL"
FLAT = "FLAT"
def __init__(self, name: str, coord_system=None, **kwargs):
args = [Field.GEOSHAPE]
if coord_system:
args.append(coord_system)
Field.__init__(self, name, args=args, **kwargs)
class GeoField(Field):
"""
GeoField is used to define a geo-indexing field in a schema definition
"""
def __init__(self, name: str, **kwargs):
Field.__init__(self, name, args=[Field.GEO], **kwargs)
class TagField(Field):
"""
TagField is a tag-indexing field with simpler compression and tokenization.
See http://redisearch.io/Tags/
"""
SEPARATOR = "SEPARATOR"
CASESENSITIVE = "CASESENSITIVE"
def __init__(
self,
name: str,
separator: str = ",",
case_sensitive: bool = False,
withsuffixtrie: bool = False,
**kwargs,
):
args = [Field.TAG, self.SEPARATOR, separator]
if case_sensitive:
args.append(self.CASESENSITIVE)
if withsuffixtrie:
args.append("WITHSUFFIXTRIE")
Field.__init__(self, name, args=args, **kwargs)
class VectorField(Field):
"""
Allows vector similarity queries against the value in this attribute.
See https://oss.redis.com/redisearch/Vectors/#vector_fields.
"""
def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
"""
Create a vector field. Note that a vector field cannot be sortable or
no_index, even though it is a Field.
``name`` is the name of the field.
``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".
``attributes`` each algorithm can have specific attributes. Some of them
are mandatory and some of them are optional. See
https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
for more information.
"""
sort = kwargs.get("sortable", False)
noindex = kwargs.get("no_index", False)
if sort or noindex:
raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")
if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
raise DataError(
"Realtime vector indexing supporting 3 Indexing Methods:"
"'FLAT', 'HNSW', and 'SVS-VAMANA'."
)
attr_li = []
for key, value in attributes.items():
attr_li.extend([key, value])
Field.__init__(
self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs
)
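# Minimal sketch of an HNSW vector field; the attribute values are examples.
#
#   VectorField(
#       "embedding",
#       "HNSW",
#       {"TYPE": "FLOAT32", "DIM": 768, "DISTANCE_METRIC": "COSINE"},
#   )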

View File

@@ -0,0 +1,79 @@
from enum import Enum
class IndexType(Enum):
"""Enum of the currently supported index types."""
HASH = 1
JSON = 2
class IndexDefinition:
"""IndexDefinition is used to define a index definition for automatic
indexing on Hash or Json update."""
def __init__(
self,
prefix=[],
filter=None,
language_field=None,
language=None,
score_field=None,
score=1.0,
payload_field=None,
index_type=None,
):
self.args = []
self._append_index_type(index_type)
self._append_prefix(prefix)
self._append_filter(filter)
self._append_language(language_field, language)
self._append_score(score_field, score)
self._append_payload(payload_field)
def _append_index_type(self, index_type):
"""Append `ON HASH` or `ON JSON` according to the enum."""
if index_type is IndexType.HASH:
self.args.extend(["ON", "HASH"])
elif index_type is IndexType.JSON:
self.args.extend(["ON", "JSON"])
elif index_type is not None:
raise RuntimeError(f"index_type must be one of {list(IndexType)}")
def _append_prefix(self, prefix):
"""Append PREFIX."""
if len(prefix) > 0:
self.args.append("PREFIX")
self.args.append(len(prefix))
for p in prefix:
self.args.append(p)
def _append_filter(self, filter):
"""Append FILTER."""
if filter is not None:
self.args.append("FILTER")
self.args.append(filter)
def _append_language(self, language_field, language):
"""Append LANGUAGE_FIELD and LANGUAGE."""
if language_field is not None:
self.args.append("LANGUAGE_FIELD")
self.args.append(language_field)
if language is not None:
self.args.append("LANGUAGE")
self.args.append(language)
def _append_score(self, score_field, score):
"""Append SCORE_FIELD and SCORE."""
if score_field is not None:
self.args.append("SCORE_FIELD")
self.args.append(score_field)
if score is not None:
self.args.append("SCORE")
self.args.append(score)
def _append_payload(self, payload_field):
"""Append PAYLOAD_FIELD."""
if payload_field is not None:
self.args.append("PAYLOAD_FIELD")
self.args.append(payload_field)
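# Illustrative definition: index hashes whose keys start with "doc:", with a
# default document score of 0.5.
#
#   IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH, score=0.5)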

View File

@@ -0,0 +1,14 @@
from typing import Any
class ProfileInformation:
"""
Wrapper around FT.PROFILE response
"""
def __init__(self, info: Any) -> None:
self._info: Any = info
@property
def info(self) -> Any:
return self._info

View File

@@ -0,0 +1,381 @@
from typing import List, Optional, Tuple, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
class Query:
"""
Query is used to build complex queries that have more parameters than just
the query string. The query string is set in the constructor, and other
options have setter functions.
The setter functions return the query object so they can be chained.
i.e. `Query("foo").verbatim().filter(...)` etc.
"""
def __init__(self, query_string: str) -> None:
"""
Create a new query object.
The query string is set in the constructor, and other options have
setter functions.
"""
self._query_string: str = query_string
self._offset: int = 0
self._num: int = 10
self._no_content: bool = False
self._no_stopwords: bool = False
self._fields: Optional[List[str]] = None
self._verbatim: bool = False
self._with_payloads: bool = False
self._with_scores: bool = False
self._scorer: Optional[str] = None
self._filters: List = list()
self._ids: Optional[Tuple[str, ...]] = None
self._slop: int = -1
self._timeout: Optional[float] = None
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
self._expander: Optional[str] = None
self._dialect: int = DEFAULT_DIALECT
def query_string(self) -> str:
"""Return the query string of this query only."""
return self._query_string
def limit_ids(self, *ids) -> "Query":
"""Limit the results to a specific set of pre-known document
ids of any length."""
self._ids = ids
return self
def return_fields(self, *fields) -> "Query":
"""Add fields to return fields."""
for field in fields:
self.return_field(field)
return self
def return_field(
self,
field: str,
as_field: Optional[str] = None,
decode_field: Optional[bool] = True,
encoding: Optional[str] = "utf8",
) -> "Query":
"""
Add a field to the list of fields to return.
- **field**: The field to include in query results
- **as_field**: The alias for the field
- **decode_field**: Whether to decode the field from bytes to string
- **encoding**: The encoding to use when decoding the field
"""
self._return_fields.append(field)
self._return_fields_decode_as[field] = encoding if decode_field else None
if as_field is not None:
self._return_fields += ("AS", as_field)
return self
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
if not fields:
return []
return [fields] if isinstance(fields, str) else list(fields)
def summarize(
self,
fields: Optional[List] = None,
context_len: Optional[int] = None,
num_frags: Optional[int] = None,
sep: Optional[str] = None,
) -> "Query":
"""
Return an abridged format of the field, containing only the segments of
the field that contain the matching term(s).
If `fields` is specified, then only the mentioned fields are
summarized; otherwise, all results are summarized.
Server-side defaults are used for each option (except `fields`)
if not specified
- **fields** List of fields to summarize. All fields are summarized
if not specified
- **context_len** Amount of context to include with each fragment
- **num_frags** Number of fragments per document
- **sep** Separator string to separate fragments
"""
args = ["SUMMARIZE"]
fields = self._mk_field_list(fields)
if fields:
args += ["FIELDS", str(len(fields))] + fields
if context_len is not None:
args += ["LEN", str(context_len)]
if num_frags is not None:
args += ["FRAGS", str(num_frags)]
if sep is not None:
args += ["SEPARATOR", sep]
self._summarize_fields = args
return self
def highlight(
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
) -> "Query":
"""
Apply specified markup to matched term(s) within the returned field(s).
- **fields** If specified, then only those mentioned fields are
highlighted, otherwise all fields are highlighted
- **tags** A list of two strings to surround the match.
"""
args = ["HIGHLIGHT"]
fields = self._mk_field_list(fields)
if fields:
args += ["FIELDS", str(len(fields))] + fields
if tags:
args += ["TAGS"] + list(tags)
self._highlight_fields = args
return self
def language(self, language: str) -> "Query":
"""
Analyze the query as being in the specified language.
:param language: The language (e.g. `chinese` or `english`)
"""
self._language = language
return self
def slop(self, slop: int) -> "Query":
"""Allow a maximum of N intervening non-matched terms between
phrase terms (0 means exact phrase).
"""
self._slop = slop
return self
def timeout(self, timeout: float) -> "Query":
"""overrides the timeout parameter of the module"""
self._timeout = timeout
return self
def in_order(self) -> "Query":
"""
Match only documents where the query terms appear in
the same order in the document.
i.e., for the query "hello world", we do not match "world hello"
"""
self._in_order = True
return self
def scorer(self, scorer: str) -> "Query":
"""
Use a different scoring function to evaluate document relevance.
        The default is `TFIDF`; since Redis 8.0 the default is `BM25STD`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self
def get_args(self) -> List[Union[str, int, float]]:
"""Format the redis arguments for this query and return them."""
args: List[Union[str, int, float]] = [self._query_string]
args += self._get_args_tags()
args += self._summarize_fields + self._highlight_fields
args += ["LIMIT", self._offset, self._num]
return args
def _get_args_tags(self) -> List[Union[str, int, float]]:
args: List[Union[str, int, float]] = []
if self._no_content:
args.append("NOCONTENT")
if self._fields:
args.append("INFIELDS")
args.append(len(self._fields))
args += self._fields
if self._verbatim:
args.append("VERBATIM")
if self._no_stopwords:
args.append("NOSTOPWORDS")
if self._filters:
for flt in self._filters:
if not isinstance(flt, Filter):
raise AttributeError("Did not receive a Filter object.")
args += flt.args
if self._with_payloads:
args.append("WITHPAYLOADS")
if self._scorer:
args += ["SCORER", self._scorer]
if self._with_scores:
args.append("WITHSCORES")
if self._ids:
args.append("INKEYS")
args.append(len(self._ids))
args += self._ids
if self._slop >= 0:
args += ["SLOP", self._slop]
if self._timeout is not None:
args += ["TIMEOUT", self._timeout]
if self._in_order:
args.append("INORDER")
if self._return_fields:
args.append("RETURN")
args.append(len(self._return_fields))
args += self._return_fields
if self._sortby:
if not isinstance(self._sortby, SortbyField):
raise AttributeError("Did not receive a SortByField.")
args.append("SORTBY")
args += self._sortby.args
if self._language:
args += ["LANGUAGE", self._language]
if self._expander:
args += ["EXPANDER", self._expander]
if self._dialect:
args += ["DIALECT", self._dialect]
return args
def paging(self, offset: int, num: int) -> "Query":
"""
Set the paging for the query (defaults to 0..10).
- **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results to return
"""
self._offset = offset
self._num = num
return self
def verbatim(self) -> "Query":
"""Set the query to be verbatim, i.e., use no query expansion
or stemming.
"""
self._verbatim = True
return self
def no_content(self) -> "Query":
"""Set the query to only return ids and not the document content."""
self._no_content = True
return self
def no_stopwords(self) -> "Query":
"""
Prevent the query from being filtered for stopwords.
Only useful in very big queries that you are certain contain
no stopwords.
"""
self._no_stopwords = True
return self
def with_payloads(self) -> "Query":
"""Ask the engine to return document payloads."""
self._with_payloads = True
return self
def with_scores(self) -> "Query":
"""Ask the engine to return document search scores."""
self._with_scores = True
return self
def limit_fields(self, *fields: str) -> "Query":
"""
Limit the search to specific TEXT fields only.
        - **fields**: Each element should be a string, a case-sensitive field name
          from the defined schema.
"""
self._fields = list(fields)
return self
def add_filter(self, flt: "Filter") -> "Query":
"""
Add a numeric or geo filter to the query.
**Currently, only one of each filter is supported by the engine**
- **flt**: A NumericFilter or GeoFilter object, used on a
corresponding field
"""
self._filters.append(flt)
return self
def sort_by(self, field: str, asc: bool = True) -> "Query":
"""
Add a sortby field to the query.
- **field** - the name of the field to sort by
- **asc** - when `True`, sorting will be done in ascending order
"""
self._sortby = SortbyField(field, asc)
return self
def expander(self, expander: str) -> "Query":
"""
Add an expander field to the query.
- **expander** - the name of the expander
"""
self._expander = expander
return self
def dialect(self, dialect: int) -> "Query":
"""
Add a dialect field to the query.
- **dialect** - dialect version to execute the query under
"""
self._dialect = dialect
return self
class Filter:
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
self.args = [keyword, field] + list(args)
class NumericFilter(Filter):
INF = "+inf"
NEG_INF = "-inf"
def __init__(
self,
field: str,
minval: Union[int, str],
maxval: Union[int, str],
minExclusive: bool = False,
maxExclusive: bool = False,
) -> None:
args = [
minval if not minExclusive else f"({minval}",
maxval if not maxExclusive else f"({maxval}",
]
Filter.__init__(self, "FILTER", field, *args)
class GeoFilter(Filter):
METERS = "m"
KILOMETERS = "km"
FEET = "ft"
MILES = "mi"
def __init__(
self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
) -> None:
Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)
class SortbyField:
def __init__(self, field: str, asc=True) -> None:
self.args = [field, "ASC" if asc else "DESC"]

View File

@@ -0,0 +1,317 @@
def tags(*t):
"""
Indicate that the values should be matched to a tag field
### Parameters
- **t**: Tags to search for
"""
if not t:
raise ValueError("At least one tag must be specified")
return TagValue(*t)
def between(a, b, inclusive_min=True, inclusive_max=True):
"""
Indicate that value is a numeric range
"""
return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)
def equal(n):
"""
Match a numeric value
"""
return between(n, n)
def lt(n):
"""
Match any value less than n
"""
return between(None, n, inclusive_max=False)
def le(n):
"""
Match any value less or equal to n
"""
return between(None, n, inclusive_max=True)
def gt(n):
"""
Match any value greater than n
"""
return between(n, None, inclusive_min=False)
def ge(n):
"""
Match any value greater or equal to n
"""
return between(n, None, inclusive_min=True)
def geo(lat, lon, radius, unit="km"):
"""
Indicate that value is a geo region
"""
return GeoValue(lat, lon, radius, unit)
class Value:
@property
def combinable(self):
"""
Whether this type of value may be combined with other values
for the same field. This makes the filter potentially more efficient
"""
return False
@staticmethod
def make_value(v):
"""
Convert an object to a value, if it is not a value already
"""
if isinstance(v, Value):
return v
return ScalarValue(v)
def to_string(self):
raise NotImplementedError()
def __str__(self):
return self.to_string()
class RangeValue(Value):
combinable = False
def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
if a is None:
a = "-inf"
if b is None:
b = "inf"
self.range = [str(a), str(b)]
self.inclusive_min = inclusive_min
self.inclusive_max = inclusive_max
def to_string(self):
return "[{1}{0[0]} {2}{0[1]}]".format(
self.range,
"(" if not self.inclusive_min else "",
"(" if not self.inclusive_max else "",
)
class ScalarValue(Value):
combinable = True
def __init__(self, v):
self.v = str(v)
def to_string(self):
return self.v
class TagValue(Value):
combinable = False
def __init__(self, *tags):
self.tags = tags
def to_string(self):
return "{" + " | ".join(str(t) for t in self.tags) + "}"
class GeoValue(Value):
def __init__(self, lon, lat, radius, unit="km"):
self.lon = lon
self.lat = lat
self.radius = radius
self.unit = unit
def to_string(self):
return f"[{self.lon} {self.lat} {self.radius} {self.unit}]"
class Node:
def __init__(self, *children, **kwparams):
"""
Create a node
### Parameters
- **children**: One or more sub-conditions. These can be additional
`intersect`, `disjunct`, `union`, `optional`, or any other `Node`
type.
The semantics of multiple conditions are dependent on the type of
query. For an `intersection` node, this amounts to a logical AND,
for a `union` node, this amounts to a logical `OR`.
- **kwparams**: key-value parameters. Each key is the name of a field,
and the value should be a field value. This can be one of the
following:
- Simple string (for text field matches)
- value returned by one of the helper functions
- list of either a string or a value
### Examples
Field `num` should be between 1 and 10
```
        intersect(num=between(1, 10))
```
Name can either be `bob` or `john`
```
union(name=("bob", "john"))
```
Don't select countries in Israel, Japan, or US
```
disjunct_union(country=("il", "jp", "us"))
```
"""
self.params = []
kvparams = {}
for k, v in kwparams.items():
curvals = kvparams.setdefault(k, [])
if isinstance(v, (str, int, float)):
curvals.append(Value.make_value(v))
elif isinstance(v, Value):
curvals.append(v)
else:
curvals.extend(Value.make_value(subv) for subv in v)
self.params += [Node.to_node(p) for p in children]
for k, v in kvparams.items():
self.params.extend(self.join_fields(k, v))
def join_fields(self, key, vals):
if len(vals) == 1:
return [BaseNode(f"@{key}:{vals[0].to_string()}")]
if not vals[0].combinable:
return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
return [s]
@classmethod
def to_node(cls, obj): # noqa
if isinstance(obj, Node):
return obj
return BaseNode(obj)
@property
def JOINSTR(self):
raise NotImplementedError()
def to_string(self, with_parens=None):
with_parens = self._should_use_paren(with_parens)
pre, post = ("(", ")") if with_parens else ("", "")
return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"
def _should_use_paren(self, optval):
if optval is not None:
return optval
return len(self.params) > 1
def __str__(self):
return self.to_string()
class BaseNode(Node):
def __init__(self, s):
super().__init__()
self.s = str(s)
def to_string(self, with_parens=None):
return self.s
class IntersectNode(Node):
"""
Create an intersection node. All children need to be satisfied in order for
this node to evaluate as true
"""
JOINSTR = " "
class UnionNode(Node):
"""
Create a union node. Any of the children need to be satisfied in order for
this node to evaluate as true
"""
JOINSTR = "|"
class DisjunctNode(IntersectNode):
"""
Create a disjunct node. In order for this node to be true, all of its
children must evaluate to false
"""
def to_string(self, with_parens=None):
with_parens = self._should_use_paren(with_parens)
ret = super().to_string(with_parens=False)
if with_parens:
return "(-" + ret + ")"
else:
return "-" + ret
class DistjunctUnion(DisjunctNode):
"""
This node is true if *all* of its children are false. This is equivalent to
```
disjunct(union(...))
```
"""
JOINSTR = "|"
class OptionalNode(IntersectNode):
"""
    Create an optional node. If this node evaluates to true, then the document
will be rated higher in score/rank.
"""
def to_string(self, with_parens=None):
with_parens = self._should_use_paren(with_parens)
ret = super().to_string(with_parens=False)
if with_parens:
return "(~" + ret + ")"
else:
return "~" + ret
def intersect(*args, **kwargs):
return IntersectNode(*args, **kwargs)
def union(*args, **kwargs):
return UnionNode(*args, **kwargs)
def disjunct(*args, **kwargs):
return DisjunctNode(*args, **kwargs)
def disjunct_union(*args, **kwargs):
return DistjunctUnion(*args, **kwargs)
def querystring(*args, **kwargs):
return intersect(*args, **kwargs).to_string()
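# Usage sketch (not part of the original module): composing a filter expression with the
# helpers above. The field names and values are placeholder assumptions.
if __name__ == "__main__":
    expr = querystring(category=tags("electronics", "phones"), price=between(100, 500))
    # -> '(@category:{electronics | phones} @price:[100 500])'
    print(expr)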

View File

@@ -0,0 +1,182 @@
from typing import Union
from .aggregation import Asc, Desc, Reducer, SortDirection
class FieldOnlyReducer(Reducer):
"""See https://redis.io/docs/interact/search-and-query/search/aggregations/"""
def __init__(self, field: str) -> None:
super().__init__(field)
self._field = field
class count(Reducer):
"""
Counts the number of results in the group
"""
NAME = "COUNT"
def __init__(self) -> None:
super().__init__()
class sum(FieldOnlyReducer):
"""
Calculates the sum of all the values in the given fields within the group
"""
NAME = "SUM"
def __init__(self, field: str) -> None:
super().__init__(field)
class min(FieldOnlyReducer):
"""
Calculates the smallest value in the given field within the group
"""
NAME = "MIN"
def __init__(self, field: str) -> None:
super().__init__(field)
class max(FieldOnlyReducer):
"""
Calculates the largest value in the given field within the group
"""
NAME = "MAX"
def __init__(self, field: str) -> None:
super().__init__(field)
class avg(FieldOnlyReducer):
"""
Calculates the mean value in the given field within the group
"""
NAME = "AVG"
def __init__(self, field: str) -> None:
super().__init__(field)
class tolist(FieldOnlyReducer):
"""
Returns all the matched properties in a list
"""
NAME = "TOLIST"
def __init__(self, field: str) -> None:
super().__init__(field)
class count_distinct(FieldOnlyReducer):
"""
Calculate the number of distinct values contained in all the results in
the group for the given field
"""
NAME = "COUNT_DISTINCT"
def __init__(self, field: str) -> None:
super().__init__(field)
class count_distinctish(FieldOnlyReducer):
"""
Calculate the number of distinct values contained in all the results in the
group for the given field. This uses a faster algorithm than
`count_distinct` but is less accurate
"""
NAME = "COUNT_DISTINCTISH"
class quantile(Reducer):
"""
Return the value for the nth percentile within the range of values for the
field within the group.
"""
NAME = "QUANTILE"
def __init__(self, field: str, pct: float) -> None:
super().__init__(field, str(pct))
self._field = field
class stddev(FieldOnlyReducer):
"""
Return the standard deviation for the values within the group
"""
NAME = "STDDEV"
def __init__(self, field: str) -> None:
super().__init__(field)
class first_value(Reducer):
"""
Selects the first value within the group according to sorting parameters
"""
NAME = "FIRST_VALUE"
def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
"""
Selects the first value of the given field within the group.
### Parameter
- **field**: Source field used for the value
- **byfields**: How to sort the results. This can be either the
*class* of `aggregation.Asc` or `aggregation.Desc` in which
case the field `field` is also used as the sort input.
`byfields` can also be one or more *instances* of `Asc` or `Desc`
indicating the sort order for these fields
"""
fieldstrs = []
if (
len(byfields) == 1
and isinstance(byfields[0], type)
and issubclass(byfields[0], SortDirection)
):
byfields = [byfields[0](field)]
for f in byfields:
fieldstrs += [f.field, f.DIRSTRING]
args = [field]
if fieldstrs:
args += ["BY"] + fieldstrs
super().__init__(*args)
self._field = field
class random_sample(Reducer):
"""
Returns a random sample of items from the dataset, from the given property
"""
NAME = "RANDOM_SAMPLE"
def __init__(self, field: str, size: int) -> None:
"""
### Parameter
**field**: Field to sample from
**size**: Return this many items (can be less)
"""
args = [field, str(size)]
super().__init__(*args)
self._field = field
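# Usage sketch (not part of the original module): these reducers are meant to be passed to
# ``AggregateRequest.group_by`` from redis.commands.search.aggregation, which is not shown
# in this file; the index and field names below are placeholder assumptions.
#
#     from redis.commands.search.aggregation import AggregateRequest
#
#     request = AggregateRequest("*").group_by("@brand", count(), avg("@price"))
#     # client.ft("idx").aggregate(request) would then issue the FT.AGGREGATE call.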

View File

@@ -0,0 +1,87 @@
from typing import Optional
from ._util import to_string
from .document import Document
class Result:
"""
Represents the result of a search query, and has an array of Document
objects
"""
def __init__(
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""
self.total = res[0]
self.duration = duration
self.docs = []
step = 1
if hascontent:
step = step + 1
if has_payload:
step = step + 1
if with_scores:
step = step + 1
offset = 2 if with_scores else 1
for i in range(1, len(res), step):
id = to_string(res[i])
payload = to_string(res[i + offset]) if has_payload else None
# fields_offset = 2 if has_payload else 1
fields_offset = offset + 1 if has_payload else offset
score = float(res[i + 1]) if with_scores else None
fields = {}
if hascontent and res[i + fields_offset] is not None:
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]
for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
continue
encoding = field_encodings[key]
# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)
try:
del fields["id"]
except KeyError:
pass
try:
fields["json"] = fields["$"]
del fields["$"]
except KeyError:
pass
doc = (
Document(id, score=score, payload=payload, **fields)
if with_scores
else Document(id, payload=payload, **fields)
)
self.docs.append(doc)
def __repr__(self) -> str:
return f"Result{{{self.total} total, docs: {self.docs}}}"

View File

@@ -0,0 +1,55 @@
from typing import Optional
from ._util import to_string
class Suggestion:
"""
Represents a single suggestion being sent or returned from the
autocomplete server
"""
def __init__(
self, string: str, score: float = 1.0, payload: Optional[str] = None
) -> None:
self.string = to_string(string)
self.payload = to_string(payload)
self.score = score
def __repr__(self) -> str:
return self.string
class SuggestionParser:
"""
Internal class used to parse results from the `SUGGET` command.
This needs to consume either 1, 2, or 3 values at a time from
the return value depending on what objects were requested
"""
def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
self.with_scores = with_scores
self.with_payloads = with_payloads
if with_scores and with_payloads:
self.sugsize = 3
self._scoreidx = 1
self._payloadidx = 2
elif with_scores:
self.sugsize = 2
self._scoreidx = 1
elif with_payloads:
self.sugsize = 2
self._payloadidx = 1
else:
self.sugsize = 1
self._scoreidx = -1
self._sugs = ret
def __iter__(self):
for i in range(0, len(self._sugs), self.sugsize):
ss = self._sugs[i]
score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
yield Suggestion(ss, score, payload)
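# Parsing sketch (not part of the original module): FT.SUGGET ... WITHSCORES returns a flat
# [suggestion, score, suggestion, score, ...] list; the values below are placeholders.
if __name__ == "__main__":
    parser = SuggestionParser(
        with_scores=True, with_payloads=False, ret=[b"hello", b"0.7", b"help", b"0.5"]
    )
    print([(s.string, s.score) for s in parser])  # -> [('hello', 0.7), ('help', 0.5)]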

View File

@@ -0,0 +1,129 @@
import warnings
class SentinelCommands:
"""
A class containing the commands specific to redis sentinel. This class is
to be used as a mixin.
"""
def sentinel(self, *args):
"""Redis Sentinel's SENTINEL command."""
warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))
def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
"""
Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL GET-MASTER-ADDR-BY-NAME",
service_name,
once=True,
return_responses=return_responses,
)
def sentinel_master(self, service_name, return_responses=False):
"""
        Returns a dictionary containing the specified master's state, when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL MASTER", service_name, return_responses=return_responses
)
def sentinel_masters(self):
"""
Returns a list of dictionaries containing each master's state.
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
return self.execute_command("SENTINEL MASTERS")
def sentinel_monitor(self, name, ip, port, quorum):
"""Add a new master to Sentinel to be monitored"""
return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum)
def sentinel_remove(self, name):
"""Remove a master from Sentinel's monitoring"""
return self.execute_command("SENTINEL REMOVE", name)
def sentinel_sentinels(self, service_name, return_responses=False):
"""
Returns a list of sentinels for ``service_name``, when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL SENTINELS", service_name, return_responses=return_responses
)
def sentinel_set(self, name, option, value):
"""Set Sentinel monitoring parameters for a given master"""
return self.execute_command("SENTINEL SET", name, option, value)
def sentinel_slaves(self, service_name):
"""
Returns a list of slaves for ``service_name``
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
return self.execute_command("SENTINEL SLAVES", service_name)
def sentinel_reset(self, pattern):
"""
This command will reset all the masters with matching name.
The pattern argument is a glob-style pattern.
The reset process clears any previous state in a master (including a
failover in progress), and removes every slave and sentinel already
discovered and associated with the master.
"""
return self.execute_command("SENTINEL RESET", pattern, once=True)
def sentinel_failover(self, new_master_name):
"""
Force a failover as if the master was not reachable, and without
asking for agreement to other Sentinels (however a new version of the
configuration will be published so that the other Sentinels will
update their configurations).
"""
return self.execute_command("SENTINEL FAILOVER", new_master_name)
def sentinel_ckquorum(self, new_master_name):
"""
Check if the current Sentinel configuration is able to reach the
quorum needed to failover a master, and the majority needed to
authorize the failover.
This command should be used in monitoring systems to check if a
Sentinel deployment is ok.
"""
return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True)
def sentinel_flushconfig(self):
"""
Force Sentinel to rewrite its configuration on disk, including the
current Sentinel state.
Normally Sentinel rewrites the configuration every time something
changes in its state (in the context of the subset of the state which
is persisted on disk across restart).
However sometimes it is possible that the configuration file is lost
because of operation errors, disk failures, package upgrade scripts or
        configuration managers. In those cases a way to force Sentinel to
rewrite the configuration file is handy.
This command works even if the previous configuration file is
completely missing.
"""
return self.execute_command("SENTINEL FLUSHCONFIG")
class AsyncSentinelCommands(SentinelCommands):
async def sentinel(self, *args) -> None:
"""Redis Sentinel's SENTINEL command."""
super().sentinel(*args)
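# Usage sketch (not part of the original module): in redis-py the ``Redis`` client mixes in
# SentinelCommands, so these methods can be called on a client connected to a Sentinel
# instance. The host/port and master name below are placeholder assumptions, and a running
# Sentinel is required.
#
#     import redis
#
#     sentinel_client = redis.Redis(host="localhost", port=26379)
#     sentinel_client.sentinel_masters()
#     sentinel_client.sentinel_get_master_addr_by_name("mymaster", return_responses=True)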

View File

@@ -0,0 +1,108 @@
import redis
from redis._parsers.helpers import bool_ok
from ..helpers import get_protocol_version, parse_to_list
from .commands import (
ALTER_CMD,
CREATE_CMD,
CREATERULE_CMD,
DEL_CMD,
DELETERULE_CMD,
GET_CMD,
INFO_CMD,
MGET_CMD,
MRANGE_CMD,
MREVRANGE_CMD,
QUERYINDEX_CMD,
RANGE_CMD,
REVRANGE_CMD,
TimeSeriesCommands,
)
from .info import TSInfo
from .utils import parse_get, parse_m_get, parse_m_range, parse_range
class TimeSeries(TimeSeriesCommands):
"""
    This class wraps a redis-py `Redis` client and implements RedisTimeSeries's
    commands (prefixed with "ts").
    The client allows you to interact with RedisTimeSeries and use all of its
    functionality.
"""
def __init__(self, client=None, **kwargs):
"""Create a new RedisTimeSeries client."""
# Set the module commands' callbacks
self._MODULE_CALLBACKS = {
ALTER_CMD: bool_ok,
CREATE_CMD: bool_ok,
CREATERULE_CMD: bool_ok,
DELETERULE_CMD: bool_ok,
}
_RESP2_MODULE_CALLBACKS = {
DEL_CMD: int,
GET_CMD: parse_get,
INFO_CMD: TSInfo,
MGET_CMD: parse_m_get,
MRANGE_CMD: parse_m_range,
MREVRANGE_CMD: parse_m_range,
RANGE_CMD: parse_range,
REVRANGE_CMD: parse_range,
QUERYINDEX_CMD: parse_to_list,
}
_RESP3_MODULE_CALLBACKS = {}
self.client = client
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
else:
self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
for k, v in self._MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)
def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the TimeSeries module, that can be used
for executing only TimeSeries commands and core commands.
Usage example:
r = redis.Redis()
pipe = r.ts().pipeline()
for i in range(100):
pipeline.add("with_pipeline", i, 1.1 * i)
pipeline.execute()
"""
if isinstance(self.client, redis.RedisCluster):
p = ClusterPipeline(
nodes_manager=self.client.nodes_manager,
commands_parser=self.client.commands_parser,
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.retry.get_retries(),
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,
)
else:
p = Pipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self._MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
return p
class ClusterPipeline(TimeSeriesCommands, redis.cluster.ClusterPipeline):
"""Cluster pipeline for the module."""
class Pipeline(TimeSeriesCommands, redis.client.Pipeline):
"""Pipeline for the module."""

File diff suppressed because it is too large

View File

@@ -0,0 +1,91 @@
from ..helpers import nativestr
from .utils import list_to_dict
class TSInfo:
"""
Hold information and statistics on the time-series.
Can be created using ``tsinfo`` command
https://redis.io/docs/latest/commands/ts.info/
"""
rules = []
labels = []
sourceKey = None
chunk_count = None
memory_usage = None
total_samples = None
retention_msecs = None
last_time_stamp = None
first_time_stamp = None
max_samples_per_chunk = None
chunk_size = None
duplicate_policy = None
def __init__(self, args):
"""
Hold information and statistics on the time-series.
The supported params that can be passed as args:
rules:
A list of compaction rules of the time series.
sourceKey:
Key name for source time series in case the current series
is a target of a rule.
chunkCount:
Number of Memory Chunks used for the time series.
memoryUsage:
Total number of bytes allocated for the time series.
totalSamples:
Total number of samples in the time series.
labels:
A list of label-value pairs that represent the metadata
labels of the time series.
retentionTime:
Retention time, in milliseconds, for the time series.
lastTimestamp:
Last timestamp present in the time series.
firstTimestamp:
First timestamp present in the time series.
maxSamplesPerChunk:
Deprecated.
chunkSize:
Amount of memory, in bytes, allocated for data.
duplicatePolicy:
Policy that will define handling of duplicate samples.
Can read more about on
https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
"""
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.rules = response.get("rules")
self.source_key = response.get("sourceKey")
self.chunk_count = response.get("chunkCount")
self.memory_usage = response.get("memoryUsage")
self.total_samples = response.get("totalSamples")
self.labels = list_to_dict(response.get("labels"))
self.retention_msecs = response.get("retentionTime")
self.last_timestamp = response.get("lastTimestamp")
self.first_timestamp = response.get("firstTimestamp")
if "maxSamplesPerChunk" in response:
self.max_samples_per_chunk = response["maxSamplesPerChunk"]
self.chunk_size = (
self.max_samples_per_chunk * 16
) # backward compatible changes
if "chunkSize" in response:
self.chunk_size = response["chunkSize"]
if "duplicatePolicy" in response:
self.duplicate_policy = response["duplicatePolicy"]
if isinstance(self.duplicate_policy, bytes):
self.duplicate_policy = self.duplicate_policy.decode()
def get(self, item):
try:
return self.__getitem__(item)
except AttributeError:
return None
def __getitem__(self, item):
return getattr(self, item)

View File

@@ -0,0 +1,44 @@
from ..helpers import nativestr
def list_to_dict(aList):
return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))}
def parse_range(response, **kwargs):
"""Parse range response. Used by TS.RANGE and TS.REVRANGE."""
    return [(r[0], float(r[1])) for r in response]
def parse_m_range(response):
"""Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE."""
res = []
for item in response:
res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]})
return sorted(res, key=lambda d: list(d.keys()))
def parse_get(response):
"""Parse get response. Used by TS.GET."""
if not response:
return None
return int(response[0]), float(response[1])
def parse_m_get(response):
"""Parse multi get response. Used by TS.MGET."""
res = []
for item in response:
if not item[2]:
res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]})
else:
res.append(
{
nativestr(item[0]): [
list_to_dict(item[1]),
int(item[2][0]),
float(item[2][1]),
]
}
)
return sorted(res, key=lambda d: list(d.keys()))
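# Parsing sketch (not part of the original module), with placeholder timestamps and values:
if __name__ == "__main__":
    print(parse_range([[1650000000000, b"1.5"], [1650000060000, b"2.5"]]))
    # -> [(1650000000000, 1.5), (1650000060000, 2.5)]
    print(parse_get([1650000000000, b"2.5"]))
    # -> (1650000000000, 2.5)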

View File

@@ -0,0 +1,46 @@
import json
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.utils import (
parse_vemb_result,
parse_vlinks_result,
parse_vsim_result,
)
from ..helpers import get_protocol_version
from .commands import (
VEMB_CMD,
VGETATTR_CMD,
VINFO_CMD,
VLINKS_CMD,
VSIM_CMD,
VectorSetCommands,
)
class VectorSet(VectorSetCommands):
def __init__(self, client, **kwargs):
"""Create a new VectorSet client."""
# Set the module commands' callbacks
self._MODULE_CALLBACKS = {
VEMB_CMD: parse_vemb_result,
VSIM_CMD: parse_vsim_result,
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
}
self._RESP2_MODULE_CALLBACKS = {
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
VLINKS_CMD: parse_vlinks_result,
}
self._RESP3_MODULE_CALLBACKS = {}
self.client = client
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
else:
self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)
for k, v in self._MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)

View File

@@ -0,0 +1,392 @@
import json
from enum import Enum
from typing import Any, Awaitable, Dict, List, Optional, Union
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version
from redis.exceptions import DataError
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
VADD_CMD = "VADD"
VSIM_CMD = "VSIM"
VREM_CMD = "VREM"
VDIM_CMD = "VDIM"
VCARD_CMD = "VCARD"
VEMB_CMD = "VEMB"
VLINKS_CMD = "VLINKS"
VINFO_CMD = "VINFO"
VSETATTR_CMD = "VSETATTR"
VGETATTR_CMD = "VGETATTR"
VRANDMEMBER_CMD = "VRANDMEMBER"
# Return type for vsim command
VSimResult = Optional[
List[
Union[
List[EncodableT], Dict[EncodableT, Number], Dict[EncodableT, Dict[str, Any]]
]
]
]
class QuantizationOptions(Enum):
"""Quantization options for the VADD command."""
NOQUANT = "NOQUANT"
BIN = "BIN"
Q8 = "Q8"
class CallbacksOptions(Enum):
"""Options that can be set for the commands callbacks"""
RAW = "RAW"
WITHSCORES = "WITHSCORES"
WITHATTRIBS = "WITHATTRIBS"
ALLOW_DECODING = "ALLOW_DECODING"
RESP3 = "RESP3"
class VectorSetCommands(CommandsProtocol):
"""Redis VectorSet commands"""
def vadd(
self,
key: KeyT,
vector: Union[List[float], bytes],
element: str,
reduce_dim: Optional[int] = None,
cas: Optional[bool] = False,
quantization: Optional[QuantizationOptions] = None,
ef: Optional[Number] = None,
attributes: Optional[Union[dict, str]] = None,
numlinks: Optional[int] = None,
) -> Union[Awaitable[int], int]:
"""
Add vector ``vector`` for element ``element`` to a vector set ``key``.
``reduce_dim`` sets the dimensions to reduce the vector to.
If not provided, the vector is not reduced.
``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style)
when adding the vector. If not provided, CAS is not used.
``quantization`` sets the quantization type to use.
If not provided, int8 quantization is used.
The options are:
- NOQUANT: No quantization
- BIN: Binary quantization
- Q8: Signed 8-bit quantization
``ef`` sets the exploration factor to use.
If not provided, the default exploration factor is used.
``attributes`` is a dictionary or json string that contains the attributes to set for the vector.
If not provided, no attributes are set.
``numlinks`` sets the number of links to create for the vector.
If not provided, the default number of links is used.
For more information, see https://redis.io/commands/vadd.
"""
if not vector or not element:
raise DataError("Both vector and element must be provided")
pieces = []
if reduce_dim:
pieces.extend(["REDUCE", reduce_dim])
values_pieces = []
if isinstance(vector, bytes):
values_pieces.extend(["FP32", vector])
else:
values_pieces.extend(["VALUES", len(vector)])
values_pieces.extend(vector)
pieces.extend(values_pieces)
pieces.append(element)
if cas:
pieces.append("CAS")
if quantization:
pieces.append(quantization.value)
if ef:
pieces.extend(["EF", ef])
if attributes:
if isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
pieces.extend(["SETATTR", attributes_json])
if numlinks:
pieces.extend(["M", numlinks])
return self.execute_command(VADD_CMD, key, *pieces)
def vsim(
self,
key: KeyT,
input: Union[List[float], bytes, str],
with_scores: Optional[bool] = False,
with_attribs: Optional[bool] = False,
count: Optional[int] = None,
ef: Optional[Number] = None,
filter: Optional[str] = None,
filter_ef: Optional[str] = None,
truth: Optional[bool] = False,
no_thread: Optional[bool] = False,
epsilon: Optional[Number] = None,
) -> Union[Awaitable[VSimResult], VSimResult]:
"""
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
``with_scores`` sets if similarity scores should be returned for each element in the result.
        ``with_attribs`` sets if the results should be returned with the
attributes of the elements in the result, or None when no attributes are present.
``count`` sets the number of results to return.
``ef`` sets the exploration factor.
``filter`` sets the filter that should be applied for the search.
``filter_ef`` sets the max filtering effort.
``truth`` when enabled, forces the command to perform a linear scan.
``no_thread`` when enabled forces the command to execute the search
on the data structure in the main thread.
``epsilon`` floating point between 0 and 1, if specified will return
only elements with distance no further than the specified one.
For more information, see https://redis.io/commands/vsim.
"""
if not input:
raise DataError("'input' should be provided")
pieces = []
options = {}
if isinstance(input, bytes):
pieces.extend(["FP32", input])
elif isinstance(input, list):
pieces.extend(["VALUES", len(input)])
pieces.extend(input)
else:
pieces.extend(["ELE", input])
if with_scores or with_attribs:
if get_protocol_version(self.client) in ["3", 3]:
options[CallbacksOptions.RESP3.value] = True
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
if with_attribs:
pieces.append("WITHATTRIBS")
options[CallbacksOptions.WITHATTRIBS.value] = True
if count:
pieces.extend(["COUNT", count])
if epsilon:
pieces.extend(["EPSILON", epsilon])
if ef:
pieces.extend(["EF", ef])
if filter:
pieces.extend(["FILTER", filter])
if filter_ef:
pieces.extend(["FILTER-EF", filter_ef])
if truth:
pieces.append("TRUTH")
if no_thread:
pieces.append("NOTHREAD")
return self.execute_command(VSIM_CMD, key, *pieces, **options)
def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the dimension of a vector set.
In the case of vectors that were populated using the `REDUCE`
option, for random projection, the vector set will report the size of
the projected (reduced) dimension.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information, see https://redis.io/commands/vdim.
"""
return self.execute_command(VDIM_CMD, key)
def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the cardinality(the number of elements) of a vector set with key ``key``.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information, see https://redis.io/commands/vcard.
"""
return self.execute_command(VCARD_CMD, key)
def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
"""
Remove an element from a vector set.
For more information, see https://redis.io/commands/vrem.
"""
return self.execute_command(VREM_CMD, key, element)
def vemb(
self, key: KeyT, element: str, raw: Optional[bool] = False
) -> Union[
Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
]:
"""
Get the approximated vector of an element ``element`` from vector set ``key``.
``raw`` is a boolean flag that indicates whether to return the
internal representation used by the vector.
For more information, see https://redis.io/commands/vemb.
"""
options = {}
pieces = []
pieces.extend([key, element])
if get_protocol_version(self.client) in ["3", 3]:
options[CallbacksOptions.RESP3.value] = True
if raw:
pieces.append("RAW")
options[NEVER_DECODE] = True
if (
hasattr(self.client, "connection_pool")
and self.client.connection_pool.connection_kwargs["decode_responses"]
) or (
hasattr(self.client, "nodes_manager")
and self.client.nodes_manager.connection_kwargs["decode_responses"]
):
# allow decoding in the postprocessing callback
# if the user set decode_responses=True
# in the connection pool
options[CallbacksOptions.ALLOW_DECODING.value] = True
options[CallbacksOptions.RAW.value] = True
return self.execute_command(VEMB_CMD, *pieces, **options)
def vlinks(
self, key: KeyT, element: str, with_scores: Optional[bool] = False
) -> Union[
Awaitable[
Optional[
List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
]
],
Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
]:
"""
        Returns the neighbors of the element ``element`` in the vector set ``key``, for each level the element exists in.
The result is a list of lists, where each list contains the neighbors for one level.
If the element does not exist, or if the vector set does not exist, None is returned.
If the ``WITHSCORES`` option is provided, the result is a list of dicts,
where each dict contains the neighbors for one level, with the scores as values.
For more information, see https://redis.io/commands/vlinks
"""
options = {}
pieces = []
pieces.extend([key, element])
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
return self.execute_command(VLINKS_CMD, *pieces, **options)
def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
"""
Get information about a vector set.
For more information, see https://redis.io/commands/vinfo.
"""
return self.execute_command(VINFO_CMD, key)
def vsetattr(
self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
) -> Union[Awaitable[int], int]:
"""
Associate or remove JSON attributes ``attributes`` of element ``element``
for vector set ``key``.
For more information, see https://redis.io/commands/vsetattr
"""
if attributes is None:
attributes_json = "{}"
elif isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
return self.execute_command(VSETATTR_CMD, key, element, attributes_json)
def vgetattr(
self, key: KeyT, element: str
) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
"""
        Retrieve the JSON attributes of an element ``element`` for vector set ``key``.
If the element does not exist, or if the vector set does not exist, None is
returned.
For more information, see https://redis.io/commands/vgetattr.
"""
return self.execute_command(VGETATTR_CMD, key, element)
def vrandmember(
self, key: KeyT, count: Optional[int] = None
) -> Union[
Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
]:
"""
Returns random elements from a vector set ``key``.
``count`` is the number of elements to return.
If ``count`` is not provided, a single element is returned as a single string.
        If ``count`` is positive (and smaller than the number of elements
        in the vector set), the command returns a list with up to ``count``
        distinct elements from the vector set.
        If ``count`` is negative, the command returns a list with ``count`` random elements,
        potentially with duplicates.
        If ``count`` is greater than the number of elements in the vector set,
        the entire set is returned as a list.
If the vector set does not exist, ``None`` is returned.
For more information, see https://redis.io/commands/vrandmember.
"""
pieces = []
pieces.append(key)
if count is not None:
pieces.append(count)
return self.execute_command(VRANDMEMBER_CMD, *pieces)
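# Usage sketch (not part of the original module): a Redis server with vector set support is
# assumed, and the key/element names and vectors are placeholders. ``r.vset()`` returns the
# VectorSet wrapper that exposes these commands.
#
#     import redis
#
#     r = redis.Redis()
#     vs = r.vset()
#     vs.vadd("points", [0.1, 0.2, 0.3], "p1", quantization=QuantizationOptions.Q8)
#     vs.vsim("points", [0.1, 0.2, 0.25], with_scores=True, count=5)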

View File

@@ -0,0 +1,130 @@
import json
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.commands import CallbacksOptions
def parse_vemb_result(response, **options):
"""
    Handle the VEMB result, since the command can return different result
    structures depending on the input options and on the quantization type of the vector set.
Parsing VEMB result into:
- List[Union[bytes, Union[int, float]]]
- Dict[str, Union[bytes, str, float]]
"""
if response is None:
return response
if options.get(CallbacksOptions.RAW.value):
result = {}
result["quantization"] = (
response[0].decode("utf-8")
if options.get(CallbacksOptions.ALLOW_DECODING.value)
else response[0]
)
result["raw"] = response[1]
result["l2"] = float(response[2])
if len(response) > 3:
result["range"] = float(response[3])
return result
else:
if options.get(CallbacksOptions.RESP3.value):
return response
result = []
for i in range(len(response)):
try:
result.append(int(response[i]))
except ValueError:
# if the value is not an integer, it should be a float
result.append(float(response[i]))
return result
def parse_vlinks_result(response, **options):
"""
    Handle the VLINKS result, since the command can return different result
    structures depending on the input options.
Parsing VLINKS result into:
- List[List[str]]
- List[Dict[str, Number]]
"""
if response is None:
return response
if options.get(CallbacksOptions.WITHSCORES.value):
result = []
        # Redis will return one flat list of element/score pairs per level.
        # Each of these lists has to be transformed into a dict
for level_item in response:
level_data_dict = {}
for key, value in pairs_to_dict(level_item).items():
value = float(value)
level_data_dict[key] = value
result.append(level_data_dict)
return result
else:
# return the list of elements for each level
# list of lists
return response
def parse_vsim_result(response, **options):
"""
Handle VSIM result since the command can be returning different result
structures depending on input options.
Parsing VSIM result into:
- List[List[str]]
- List[Dict[str, Number]] - when with_scores is used (without attributes)
- List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores)
- List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used
"""
if response is None:
return response
withscores = bool(options.get(CallbacksOptions.WITHSCORES.value))
withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value))
# Exactly one of withscores or withattribs is True
if (withscores and not withattribs) or (not withscores and withattribs):
        # Redis will return a flat list of element/value pairs.
        # This list has to be transformed into a dict
result_dict = {}
if options.get(CallbacksOptions.RESP3.value):
resp_dict = response
else:
resp_dict = pairs_to_dict(response)
for key, value in resp_dict.items():
if withscores:
value = float(value)
else:
value = json.loads(value) if value else None
result_dict[key] = value
return result_dict
elif withscores and withattribs:
it = iter(response)
result_dict = {}
if options.get(CallbacksOptions.RESP3.value):
for elem, data in response.items():
if data[1] is not None:
attribs_dict = json.loads(data[1])
else:
attribs_dict = None
result_dict[elem] = {"score": data[0], "attributes": attribs_dict}
else:
for elem, score, attribs in zip(it, it, it):
if attribs is not None:
attribs_dict = json.loads(attribs)
else:
attribs_dict = None
result_dict[elem] = {"score": float(score), "attributes": attribs_dict}
return result_dict
else:
# return the list of elements for each level
# list of lists
return response
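# Parsing sketch (not part of the original module): a RESP2 VSIM ... WITHSCORES reply is a
# flat [element, score, element, score, ...] list; the values below are placeholders.
if __name__ == "__main__":
    reply = [b"p1", b"0.99", b"p2", b"0.97"]
    print(parse_vsim_result(reply, **{CallbacksOptions.WITHSCORES.value: True}))
    # -> {b'p1': 0.99, b'p2': 0.97}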

File diff suppressed because it is too large

View File

@@ -0,0 +1,23 @@
from binascii import crc_hqx
from redis.typing import EncodedT
# Redis Cluster's key space is divided into 16384 slots.
# For more information see: https://github.com/redis/redis/issues/2576
REDIS_CLUSTER_HASH_SLOTS = 16384
__all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"]
def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int:
"""Calculate key slot for a given key.
See Keys distribution model in https://redis.io/topics/cluster-spec
:param key - bytes
:param bucket - int
"""
start = key.find(b"{")
if start > -1:
end = key.find(b"}", start + 1)
if end > -1 and end != start + 1:
key = key[start + 1 : end]
return crc_hqx(key, 0) % bucket
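# Usage sketch (not part of the original module): keys that share a hash tag (the part
# between "{" and "}") map to the same slot, which keeps them on the same cluster node.
if __name__ == "__main__":
    print(key_slot(b"{user:1000}.following") == key_slot(b"{user:1000}.followers"))  # True
    print(key_slot(b"plain-key"))  # some slot in range(REDIS_CLUSTER_HASH_SLOTS)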

View File

@@ -0,0 +1,65 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple, Union
logger = logging.getLogger(__name__)
class CredentialProvider:
"""
Credentials Provider.
"""
def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
raise NotImplementedError("get_credentials must be implemented")
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
logger.warning(
"This method is added for backward compatability. "
"Please override it in your implementation."
)
return self.get_credentials()
class StreamingCredentialProvider(CredentialProvider, ABC):
"""
Credential provider that streams credentials in the background.
"""
@abstractmethod
def on_next(self, callback: Callable[[Any], None]):
"""
Specifies the callback that should be invoked
when the next credentials will be retrieved.
        :param callback: Callback that will be invoked with the newly retrieved credentials.
:return:
"""
pass
@abstractmethod
def on_error(self, callback: Callable[[Exception], None]):
pass
@abstractmethod
def is_streaming(self) -> bool:
pass
class UsernamePasswordCredentialProvider(CredentialProvider):
"""
Simple implementation of CredentialProvider that just wraps static
username and password.
"""
def __init__(self, username: Optional[str] = None, password: Optional[str] = None):
self.username = username or ""
self.password = password or ""
def get_credentials(self):
if self.username:
return self.username, self.password
return (self.password,)
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
return self.get_credentials()
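# Usage sketch (not part of the original module): the username/password values below are
# placeholders.
if __name__ == "__main__":
    provider = UsernamePasswordCredentialProvider("app_user", "s3cret")
    print(provider.get_credentials())  # -> ('app_user', 's3cret')
    print(UsernamePasswordCredentialProvider(password="s3cret").get_credentials())  # -> ('s3cret',)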

View File

@@ -0,0 +1,81 @@
import threading
from typing import Any, Generic, List, TypeVar
from redis.typing import Number
T = TypeVar("T")
class WeightedList(Generic[T]):
"""
Thread-safe weighted list.
"""
def __init__(self):
self._items: List[tuple[Any, Number]] = []
self._lock = threading.RLock()
def add(self, item: Any, weight: float) -> None:
"""Add item with weight, maintaining sorted order"""
with self._lock:
# Find insertion point using binary search
left, right = 0, len(self._items)
while left < right:
mid = (left + right) // 2
if self._items[mid][1] < weight:
right = mid
else:
left = mid + 1
self._items.insert(left, (item, weight))
def remove(self, item):
"""Remove first occurrence of item"""
with self._lock:
for i, (stored_item, weight) in enumerate(self._items):
if stored_item == item:
self._items.pop(i)
return weight
raise ValueError("Item not found")
def get_by_weight_range(
self, min_weight: float, max_weight: float
) -> List[tuple[Any, Number]]:
"""Get all items within weight range"""
with self._lock:
result = []
for item, weight in self._items:
if min_weight <= weight <= max_weight:
result.append((item, weight))
return result
def get_top_n(self, n: int) -> List[tuple[Any, Number]]:
"""Get top N the highest weighted items"""
with self._lock:
return [(item, weight) for item, weight in self._items[:n]]
    def update_weight(self, item, new_weight: float):
        """Update weight of an item"""
        with self._lock:
            old_weight = self.remove(item)
            self.add(item, new_weight)
            return old_weight
def __iter__(self):
"""Iterate in descending weight order"""
with self._lock:
items_copy = (
self._items.copy()
) # Create snapshot as lock released after each 'yield'
for item, weight in items_copy:
yield item, weight
def __len__(self):
with self._lock:
return len(self._items)
def __getitem__(self, index) -> tuple[Any, Number]:
with self._lock:
item, weight = self._items[index]
return item, weight
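# Usage sketch (not part of the original module): item names and weights are placeholders.
if __name__ == "__main__":
    wl = WeightedList()
    wl.add("replica-1", 0.5)
    wl.add("primary", 1.0)
    wl.add("replica-2", 0.2)
    print(wl.get_top_n(2))  # -> [('primary', 1.0), ('replica-1', 0.5)]
    print(list(wl))         # all items in descending weight order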

View File

@@ -0,0 +1,468 @@
import asyncio
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Type, Union
from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
class EventListenerInterface(ABC):
"""
Represents a listener for given event object.
"""
@abstractmethod
def listen(self, event: object):
pass
class AsyncEventListenerInterface(ABC):
"""
Represents an async listener for given event object.
"""
@abstractmethod
async def listen(self, event: object):
pass
class EventDispatcherInterface(ABC):
"""
Represents a dispatcher that dispatches events to listeners
associated with given event.
"""
@abstractmethod
def dispatch(self, event: object):
pass
@abstractmethod
async def dispatch_async(self, event: object):
pass
@abstractmethod
def register_listeners(
self,
mappings: Dict[
Type[object],
List[Union[EventListenerInterface, AsyncEventListenerInterface]],
],
):
"""Register additional listeners."""
pass
class EventException(Exception):
"""
Exception wrapper that adds an event object into exception context.
"""
def __init__(self, exception: Exception, event: object):
self.exception = exception
self.event = event
super().__init__(exception)
class EventDispatcher(EventDispatcherInterface):
# TODO: Make dispatcher to accept external mappings.
def __init__(
self,
event_listeners: Optional[
Dict[Type[object], List[EventListenerInterface]]
] = None,
):
"""
Dispatcher that dispatches events to listeners associated with given event.
"""
self._event_listeners_mapping: Dict[
Type[object], List[EventListenerInterface]
] = {
AfterConnectionReleasedEvent: [
ReAuthConnectionListener(),
],
AfterPooledConnectionsInstantiationEvent: [
RegisterReAuthForPooledConnections()
],
AfterSingleConnectionInstantiationEvent: [
RegisterReAuthForSingleConnection()
],
AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
AsyncAfterConnectionReleasedEvent: [
AsyncReAuthConnectionListener(),
],
}
self._lock = threading.Lock()
self._async_lock = None
if event_listeners:
self.register_listeners(event_listeners)
def dispatch(self, event: object):
with self._lock:
listeners = self._event_listeners_mapping.get(type(event), [])
for listener in listeners:
listener.listen(event)
async def dispatch_async(self, event: object):
if self._async_lock is None:
self._async_lock = asyncio.Lock()
async with self._async_lock:
listeners = self._event_listeners_mapping.get(type(event), [])
for listener in listeners:
await listener.listen(event)
def register_listeners(
self,
mappings: Dict[
Type[object],
List[Union[EventListenerInterface, AsyncEventListenerInterface]],
],
):
with self._lock:
for event_type in mappings:
if event_type in self._event_listeners_mapping:
self._event_listeners_mapping[event_type] = list(
set(
self._event_listeners_mapping[event_type]
+ mappings[event_type]
)
)
else:
self._event_listeners_mapping[event_type] = mappings[event_type]
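# Usage sketch (not part of the original module): registering an additional listener for an
# event type and dispatching an event. ``AuditListener`` is a hypothetical listener, and
# ``OnCommandsFailEvent`` is defined further down in this file.
#
#     class AuditListener(EventListenerInterface):
#         def listen(self, event: "OnCommandsFailEvent"):
#             print("command failed:", event.exception)
#
#     dispatcher = EventDispatcher()
#     dispatcher.register_listeners({OnCommandsFailEvent: [AuditListener()]})
#     dispatcher.dispatch(OnCommandsFailEvent(("GET", "key"), ValueError("boom")))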
class AfterConnectionReleasedEvent:
"""
Event that will be fired before each command execution.
"""
def __init__(self, connection):
self._connection = connection
@property
def connection(self):
return self._connection
class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
pass
class ClientType(Enum):
SYNC = ("sync",)
ASYNC = ("async",)
class AfterPooledConnectionsInstantiationEvent:
"""
    Event that will be fired after pooled connection instances were created.
"""
def __init__(
self,
connection_pools: List,
client_type: ClientType,
credential_provider: Optional[CredentialProvider] = None,
):
self._connection_pools = connection_pools
self._client_type = client_type
self._credential_provider = credential_provider
@property
def connection_pools(self):
return self._connection_pools
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider
class AfterSingleConnectionInstantiationEvent:
"""
    Event that will be fired after a single connection instance was created.
    :param connection_lock: For a sync client a threading lock should be provided,
    for an async client an asyncio.Lock
"""
def __init__(
self,
connection,
client_type: ClientType,
connection_lock: Union[threading.RLock, asyncio.Lock],
):
self._connection = connection
self._client_type = client_type
self._connection_lock = connection_lock
@property
def connection(self):
return self._connection
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
return self._connection_lock
class AfterPubSubConnectionInstantiationEvent:
def __init__(
self,
pubsub_connection,
connection_pool,
client_type: ClientType,
connection_lock: Union[threading.RLock, asyncio.Lock],
):
self._pubsub_connection = pubsub_connection
self._connection_pool = connection_pool
self._client_type = client_type
self._connection_lock = connection_lock
@property
def pubsub_connection(self):
return self._pubsub_connection
@property
def connection_pool(self):
return self._connection_pool
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
return self._connection_lock
class AfterAsyncClusterInstantiationEvent:
"""
    Event that will be fired after an async cluster instance was created.
Async cluster doesn't use connection pools,
instead ClusterNode object manages connections.
"""
def __init__(
self,
nodes: dict,
credential_provider: Optional[CredentialProvider] = None,
):
self._nodes = nodes
self._credential_provider = credential_provider
@property
def nodes(self) -> dict:
return self._nodes
@property
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider
class OnCommandsFailEvent:
"""
    Event fired whenever a command fails during execution.
"""
def __init__(
self,
commands: tuple,
exception: Exception,
):
self._commands = commands
self._exception = exception
@property
def commands(self) -> tuple:
return self._commands
@property
def exception(self) -> Exception:
return self._exception
class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
pass
class ReAuthConnectionListener(EventListenerInterface):
"""
Listener that performs re-authentication of given connection.
"""
def listen(self, event: AfterConnectionReleasedEvent):
event.connection.re_auth()
class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
"""
Async listener that performs re-authentication of given connection.
"""
async def listen(self, event: AsyncAfterConnectionReleasedEvent):
await event.connection.re_auth()
class RegisterReAuthForPooledConnections(EventListenerInterface):
"""
Listener that registers a re-authentication callback for pooled connections.
Required by :class:`StreamingCredentialProvider`.
"""
def __init__(self):
self._event = None
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
if isinstance(event.credential_provider, StreamingCredentialProvider):
self._event = event
if event.client_type == ClientType.SYNC:
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)
else:
event.credential_provider.on_next(self._re_auth_async)
event.credential_provider.on_error(self._raise_on_error_async)
def _re_auth(self, token):
for pool in self._event.connection_pools:
pool.re_auth_callback(token)
async def _re_auth_async(self, token):
for pool in self._event.connection_pools:
await pool.re_auth_callback(token)
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForSingleConnection(EventListenerInterface):
"""
Listener that registers a re-authentication callback for a single connection.
Required by :class:`StreamingCredentialProvider`.
"""
def __init__(self):
self._event = None
def listen(self, event: AfterSingleConnectionInstantiationEvent):
if isinstance(
event.connection.credential_provider, StreamingCredentialProvider
):
self._event = event
if event.client_type == ClientType.SYNC:
event.connection.credential_provider.on_next(self._re_auth)
event.connection.credential_provider.on_error(self._raise_on_error)
else:
event.connection.credential_provider.on_next(self._re_auth_async)
event.connection.credential_provider.on_error(
self._raise_on_error_async
)
def _re_auth(self, token):
with self._event.connection_lock:
self._event.connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
self._event.connection.read_response()
async def _re_auth_async(self, token):
async with self._event.connection_lock:
await self._event.connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
await self._event.connection.read_response()
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
def __init__(self):
self._event = None
def listen(self, event: AfterAsyncClusterInstantiationEvent):
if isinstance(event.credential_provider, StreamingCredentialProvider):
self._event = event
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)
async def _re_auth(self, token: TokenInterface):
for key in self._event.nodes:
await self._event.nodes[key].re_auth_callback(token)
async def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForPubSub(EventListenerInterface):
def __init__(self):
self._connection = None
self._connection_pool = None
self._client_type = None
self._connection_lock = None
self._event = None
def listen(self, event: AfterPubSubConnectionInstantiationEvent):
if isinstance(
event.pubsub_connection.credential_provider, StreamingCredentialProvider
) and event.pubsub_connection.get_protocol() in [3, "3"]:
self._event = event
self._connection = event.pubsub_connection
self._connection_pool = event.connection_pool
self._client_type = event.client_type
self._connection_lock = event.connection_lock
if self._client_type == ClientType.SYNC:
self._connection.credential_provider.on_next(self._re_auth)
self._connection.credential_provider.on_error(self._raise_on_error)
else:
self._connection.credential_provider.on_next(self._re_auth_async)
self._connection.credential_provider.on_error(
self._raise_on_error_async
)
def _re_auth(self, token: TokenInterface):
with self._connection_lock:
self._connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
self._connection.read_response()
self._connection_pool.re_auth_callback(token)
async def _re_auth_async(self, token: TokenInterface):
async with self._connection_lock:
await self._connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
await self._connection.read_response()
await self._connection_pool.re_auth_callback(token)
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)
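# --- Illustrative usage (not part of the library) ---
# A minimal sketch of the re-auth listener contract shown above: any object
# exposing a ``re_auth()`` method can be wrapped in an
# AfterConnectionReleasedEvent and handed to ReAuthConnectionListener.
# ``FakeConnection`` is a hypothetical stand-in used only for this example.
class FakeConnection:
    def __init__(self):
        self.re_auth_calls = 0

    def re_auth(self):
        self.re_auth_calls += 1


if __name__ == "__main__":
    conn = FakeConnection()
    listener = ReAuthConnectionListener()
    listener.listen(AfterConnectionReleasedEvent(conn))
    assert conn.re_auth_calls == 1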

View File

@@ -0,0 +1,255 @@
"Core exceptions raised by the Redis client"
class RedisError(Exception):
pass
class ConnectionError(RedisError):
pass
class TimeoutError(RedisError):
pass
class AuthenticationError(ConnectionError):
pass
class AuthorizationError(ConnectionError):
pass
class BusyLoadingError(ConnectionError):
pass
class InvalidResponse(RedisError):
pass
class ResponseError(RedisError):
pass
class DataError(RedisError):
pass
class PubSubError(RedisError):
pass
class WatchError(RedisError):
pass
class NoScriptError(ResponseError):
pass
class OutOfMemoryError(ResponseError):
"""
Indicates the database is full. Can only occur when either:
* Redis maxmemory-policy=noeviction
* Redis maxmemory-policy=volatile* and there are no evictable keys
For more information see `Memory optimization in Redis <https://redis.io/docs/management/optimization/memory-optimization/#memory-allocation>`_. # noqa
"""
pass
class ExecAbortError(ResponseError):
pass
class ReadOnlyError(ResponseError):
pass
class NoPermissionError(ResponseError):
pass
class ModuleError(ResponseError):
pass
class LockError(RedisError, ValueError):
"Errors acquiring or releasing a lock"
# NOTE: For backwards compatibility, this class derives from ValueError.
# This was originally chosen to behave like threading.Lock.
def __init__(self, message=None, lock_name=None):
self.message = message
self.lock_name = lock_name
class LockNotOwnedError(LockError):
"Error trying to extend or release a lock that is not owned (anymore)"
pass
class ChildDeadlockedError(Exception):
"Error indicating that a child process is deadlocked after a fork()"
pass
class AuthenticationWrongNumberOfArgsError(ResponseError):
"""
An error to indicate that the wrong number of args
were sent to the AUTH command
"""
pass
class RedisClusterException(Exception):
"""
Base exception for the RedisCluster client
"""
pass
class ClusterError(RedisError):
"""
Raised when cluster errors occurred multiple times, resulting in an
exhaustion of the command execution TTL
"""
pass
class ClusterDownError(ClusterError, ResponseError):
"""
Error indicating a CLUSTERDOWN error was received from the cluster.
By default Redis Cluster nodes stop accepting queries if they detect that
at least one hash slot is uncovered (no available node is serving it).
This way, if the cluster is partially down (for example a range of hash
slots is no longer covered), the entire cluster eventually becomes
unavailable. It automatically becomes available again as soon as all the
slots are covered.
"""
def __init__(self, resp):
self.args = (resp,)
self.message = resp
class AskError(ResponseError):
"""
Error indicating an ASK error was received from the cluster.
When a slot is set as MIGRATING, the node will accept all queries that
pertain to this hash slot, but only if the key in question exists;
otherwise the query is forwarded using a -ASK redirection to the node that
is the target of the migration.
src node: MIGRATING to dst node
get > ASK error
ask dst node > ASKING command
dst node: IMPORTING from src node
asking command only affects next command
any op will be allowed after asking command
"""
def __init__(self, resp):
"""should only redirect to master node"""
self.args = (resp,)
self.message = resp
slot_id, new_node = resp.split(" ")
host, port = new_node.rsplit(":", 1)
self.slot_id = int(slot_id)
self.node_addr = self.host, self.port = host, int(port)
class TryAgainError(ResponseError):
"""
Error indicating a TRYAGAIN error was received from the cluster.
Operations on keys that don't exist, or that are split between the source
and destination nodes during resharding, will generate a -TRYAGAIN error.
"""
def __init__(self, *args, **kwargs):
pass
class ClusterCrossSlotError(ResponseError):
"""
Error indicating a CROSSSLOT error was received from the cluster.
A CROSSSLOT error is generated when keys in a request don't hash to the
same slot.
"""
message = "Keys in request don't hash to the same slot"
class MovedError(AskError):
"""
Error indicating a MOVED error was received from the cluster.
A request sent to a node that doesn't serve this key will be answered with
a MOVED error that points to the correct node.
"""
pass
class MasterDownError(ClusterDownError):
"""
Error indicating a MASTERDOWN error was received from the cluster.
Link with MASTER is down and replica-serve-stale-data is set to 'no'.
"""
pass
class SlotNotCoveredError(RedisClusterException):
"""
This error only happens when the connection pool tries to determine which
node covers a given slot and no node is found.
If this error is raised, the client should drop the current node layout and
attempt to reconnect and refresh the node layout.
"""
pass
class MaxConnectionsError(ConnectionError):
"""
Raised when a connection pool has reached its max_connections limit.
This indicates pool exhaustion rather than an actual connection failure.
"""
pass
class CrossSlotTransactionError(RedisClusterException):
"""
Raised when a transaction or watch is triggered in a pipeline
and not all keys or all commands belong to the same slot.
"""
pass
class InvalidPipelineStack(RedisClusterException):
"""
Raised on unexpected response length on pipelines. This is
most likely a handling error on the stack.
"""
pass
class ExternalAuthProviderError(ConnectionError):
"""
Raised when an external authentication provider returns an error.
"""
pass
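# --- Illustrative usage (not part of the library) ---
# A small sketch of how a cluster redirection error carries its target: the
# message portion of an -ASK/-MOVED reply ("<slot> <host>:<port>") is parsed
# by AskError.__init__ into slot_id, host and port.  The literal values below
# are made up for the example.
if __name__ == "__main__":
    err = AskError("3999 127.0.0.1:6381")
    assert err.slot_id == 3999
    assert err.node_addr == ("127.0.0.1", 6381)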

View File

@@ -0,0 +1,425 @@
from __future__ import annotations
import base64
import gzip
import json
import ssl
import zlib
from dataclasses import dataclass
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode, urljoin
from urllib.request import Request, urlopen
__all__ = ["HttpClient", "HttpResponse", "HttpError", "DEFAULT_TIMEOUT"]
from redis.backoff import ExponentialWithJitterBackoff
from redis.retry import Retry
from redis.utils import dummy_fail
DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
DEFAULT_TIMEOUT = 30.0
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
@dataclass
class HttpResponse:
status: int
headers: Dict[str, str]
url: str
content: bytes
def text(self, encoding: Optional[str] = None) -> str:
enc = encoding or self._get_encoding()
return self.content.decode(enc, errors="replace")
def json(self) -> Any:
return json.loads(self.text(encoding=self._get_encoding()))
def _get_encoding(self) -> str:
# Try to infer encoding from headers; default to utf-8
ctype = self.headers.get("content-type", "")
# Example: application/json; charset=utf-8
for part in ctype.split(";"):
p = part.strip()
if p.lower().startswith("charset="):
return p.split("=", 1)[1].strip() or "utf-8"
return "utf-8"
class HttpError(Exception):
def __init__(self, status: int, url: str, message: Optional[str] = None):
self.status = status
self.url = url
self.message = message or f"HTTP {status} for {url}"
super().__init__(self.message)
class HttpClient:
"""
A lightweight HTTP client for REST API calls.
"""
def __init__(
self,
base_url: str = "",
headers: Optional[Mapping[str, str]] = None,
timeout: float = DEFAULT_TIMEOUT,
retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
),
verify_tls: bool = True,
# TLS verification (server) options
ca_file: Optional[str] = None,
ca_path: Optional[str] = None,
ca_data: Optional[Union[str, bytes]] = None,
# Mutual TLS (client cert) options
client_cert_file: Optional[str] = None,
client_key_file: Optional[str] = None,
client_key_password: Optional[str] = None,
auth_basic: Optional[Tuple[str, str]] = None, # (username, password)
user_agent: str = DEFAULT_USER_AGENT,
) -> None:
"""
Initialize a new HTTP client instance.
Args:
base_url: Base URL for all requests. Will be prefixed to all paths.
headers: Default headers to include in all requests.
timeout: Default timeout in seconds for requests.
retry: Retry configuration for failed requests.
verify_tls: Whether to verify TLS certificates.
ca_file: Path to CA certificate file for TLS verification.
ca_path: Path to a directory containing CA certificates.
ca_data: CA certificate data as string or bytes.
client_cert_file: Path to client certificate for mutual TLS.
client_key_file: Path to a client private key for mutual TLS.
client_key_password: Password for an encrypted client private key.
auth_basic: Tuple of (username, password) for HTTP basic auth.
user_agent: User-Agent header value for requests.
The client supports both regular HTTPS with server verification and mutual TLS
authentication. For server verification, provide CA certificate information via
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
certificate and key via client_cert_file and client_key_file.
"""
self.base_url = (
base_url.rstrip() + "/"
if base_url and not base_url.endswith("/")
else base_url
)
self._default_headers = {k.lower(): v for k, v in (headers or {}).items()}
self.timeout = timeout
self.retry = retry
self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError))
self.verify_tls = verify_tls
# TLS settings
self.ca_file = ca_file
self.ca_path = ca_path
self.ca_data = ca_data
self.client_cert_file = client_cert_file
self.client_key_file = client_key_file
self.client_key_password = client_key_password
self.auth_basic = auth_basic
self.user_agent = user_agent
# Public JSON-centric helpers
def get(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
return self._json_call(
"GET",
path,
params=params,
headers=headers,
timeout=timeout,
body=None,
expect_json=expect_json,
)
def delete(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
return self._json_call(
"DELETE",
path,
params=params,
headers=headers,
timeout=timeout,
body=None,
expect_json=expect_json,
)
def post(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
return self._json_call(
"POST",
path,
params=params,
headers=headers,
timeout=timeout,
body=self._prepare_body(json_body=json_body, data=data),
expect_json=expect_json,
)
def put(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
return self._json_call(
"PUT",
path,
params=params,
headers=headers,
timeout=timeout,
body=self._prepare_body(json_body=json_body, data=data),
expect_json=expect_json,
)
def patch(
self,
path: str,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
return self._json_call(
"PATCH",
path,
params=params,
headers=headers,
timeout=timeout,
body=self._prepare_body(json_body=json_body, data=data),
expect_json=expect_json,
)
# Low-level request
def request(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[Union[bytes, str]] = None,
timeout: Optional[float] = None,
) -> HttpResponse:
url = self._build_url(path, params)
all_headers = self._prepare_headers(headers, body)
data = body.encode("utf-8") if isinstance(body, str) else body
req = Request(url=url, method=method.upper(), data=data, headers=all_headers)
context: Optional[ssl.SSLContext] = None
if url.lower().startswith("https"):
if self.verify_tls:
# Use provided CA material if any; fall back to system defaults
context = ssl.create_default_context(
cafile=self.ca_file,
capath=self.ca_path,
cadata=self.ca_data,
)
# Load client certificate for mTLS if configured
if self.client_cert_file:
context.load_cert_chain(
certfile=self.client_cert_file,
keyfile=self.client_key_file,
password=self.client_key_password,
)
else:
# Verification disabled
context = ssl.create_default_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
try:
return self.retry.call_with_retry(
lambda: self._make_request(req, context=context, timeout=timeout),
lambda _: dummy_fail(),
lambda error: self._is_retryable_http_error(error),
)
except HTTPError as e:
# Retries are exhausted or the status is not retryable:
# read the error body and return it as a response
err_body = b""
try:
err_body = e.read()
except Exception:
pass
headers_map = {k.lower(): v for k, v in (e.headers or {}).items()}
err_body = self._maybe_decompress(err_body, headers_map)
status = getattr(e, "code", 0) or 0
response = HttpResponse(
status=status,
headers=headers_map,
url=url,
content=err_body,
)
return response
def _make_request(
self,
request: Request,
context: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
):
with urlopen(request, timeout=timeout or self.timeout, context=context) as resp:
raw = resp.read()
headers_map = {k.lower(): v for k, v in resp.headers.items()}
raw = self._maybe_decompress(raw, headers_map)
return HttpResponse(
status=resp.status,
headers=headers_map,
url=resp.geturl(),
content=raw,
)
def _is_retryable_http_error(self, error: Exception) -> bool:
if isinstance(error, HTTPError):
return self._should_retry_status(error.code)
return False
# Internal utilities
def _json_call(
self,
method: str,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: Optional[float] = None,
body: Optional[Union[bytes, str]] = None,
expect_json: bool = True,
) -> Union[HttpResponse, Any]:
resp = self.request(
method=method,
path=path,
params=params,
headers=headers,
body=body,
timeout=timeout,
)
if not (200 <= resp.status < 400):
raise HttpError(resp.status, resp.url, resp.text())
if expect_json:
return resp.json()
return resp
def _prepare_body(
self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None
) -> Optional[Union[bytes, str]]:
if json_body is not None and data is not None:
raise ValueError("Provide either json_body or data, not both.")
if json_body is not None:
return json.dumps(json_body, ensure_ascii=False, separators=(",", ":"))
return data
def _build_url(
self,
path: str,
params: Optional[
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
] = None,
) -> str:
url = urljoin(self.base_url or "", path)
if params:
# urlencode with doseq=True supports list/tuple values
query = urlencode(
{k: v for k, v in params.items() if v is not None}, doseq=True
)
separator = "&" if ("?" in url) else "?"
url = f"{url}{separator}{query}" if query else url
return url
def _prepare_headers(
self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]
) -> Dict[str, str]:
# Start with defaults
prepared: Dict[str, str] = {}
prepared.update(self._default_headers)
# Standard defaults for JSON REST usage
prepared.setdefault("accept", "application/json")
prepared.setdefault("user-agent", self.user_agent)
# We will send gzip accept-encoding; handle decompression manually
prepared.setdefault("accept-encoding", "gzip, deflate")
# If we have a string body and content-type not specified, assume JSON
if body is not None and isinstance(body, str):
prepared.setdefault("content-type", "application/json; charset=utf-8")
# Basic authentication if provided and not overridden
if self.auth_basic and "authorization" not in prepared:
user, pwd = self.auth_basic
token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii")
prepared["authorization"] = f"Basic {token}"
# Merge per-call headers (case-insensitive)
if headers:
for k, v in headers.items():
prepared[k.lower()] = v
# urllib sometimes expects header keys in canonical capitalization, but it's tolerant;
# we return them as provided and let urllib handle it.
return prepared
def _should_retry_status(self, status: int) -> bool:
return status in RETRY_STATUS_CODES
def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes:
if not content:
return content
encoding = (headers.get("content-encoding") or "").lower()
try:
if "gzip" in encoding:
return gzip.decompress(content)
if "deflate" in encoding:
# Try raw deflate, then zlib-wrapped
try:
return zlib.decompress(content, -zlib.MAX_WBITS)
except zlib.error:
return zlib.decompress(content)
except Exception:
# If decompression fails, return original bytes
return content
return content
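# --- Illustrative usage (not part of the library) ---
# A minimal sketch of a JSON GET against a hypothetical REST endpoint; the
# base URL and path are placeholders.  With expect_json=True (the default), a
# 2xx/3xx response is decoded and returned as a Python object, while other
# statuses raise HttpError.
if __name__ == "__main__":
    client = HttpClient(base_url="https://api.example.invalid/v1/")
    try:
        payload = client.get("status", params={"verbose": True})
        print(payload)
    except HttpError as e:
        print(f"request failed: HTTP {e.status} for {e.url}")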

View File

@@ -0,0 +1,343 @@
import logging
import threading
import time as mod_time
import uuid
from types import SimpleNamespace, TracebackType
from typing import Optional, Type
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
logger = logging.getLogger(__name__)
class Lock:
"""
A shared, distributed Lock. Using Redis for locking allows the Lock
to be shared across processes and/or machines.
It's left to the user to resolve deadlock issues and make sure
multiple clients play nicely together.
"""
lua_release = None
lua_extend = None
lua_reacquire = None
# KEYS[1] - lock name
# ARGV[1] - token
# return 1 if the lock was released, otherwise 0
LUA_RELEASE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('del', KEYS[1])
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - additional milliseconds
# ARGV[3] - "0" if the additional time should be added to the lock's
# existing ttl or "1" if the existing ttl should be replaced
# return 1 if the lock's time was extended, otherwise 0
LUA_EXTEND_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
local expiration = redis.call('pttl', KEYS[1])
if not expiration then
expiration = 0
end
if expiration < 0 then
return 0
end
local newttl = ARGV[2]
if ARGV[3] == "0" then
newttl = ARGV[2] + expiration
end
redis.call('pexpire', KEYS[1], newttl)
return 1
"""
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - milliseconds
# return 1 if the lock's time was reacquired, otherwise 0
LUA_REACQUIRE_SCRIPT = """
local token = redis.call('get', KEYS[1])
if not token or token ~= ARGV[1] then
return 0
end
redis.call('pexpire', KEYS[1], ARGV[2])
return 1
"""
def __init__(
self,
redis,
name: str,
timeout: Optional[Number] = None,
sleep: Number = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
supplied by ``redis``.
``timeout`` indicates a maximum life for the lock in seconds.
By default, it will remain locked until release() is called.
``timeout`` can be specified as a float or integer, both representing
the number of seconds to wait.
``sleep`` indicates the amount of time to sleep in seconds per loop
iteration when the lock is in blocking mode and another client is
currently holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
is that these cases aren't common, and as such we default to using
thread-local storage.
"""
self.redis = redis
self.name = name
self.timeout = timeout
self.sleep = sleep
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.raise_on_release_error = raise_on_release_error
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.local.token = None
self.register_scripts()
def register_scripts(self) -> None:
cls = self.__class__
client = self.redis
if cls.lua_release is None:
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
if cls.lua_extend is None:
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
def __enter__(self) -> "Lock":
if self.acquire():
return self
raise LockError(
"Unable to acquire lock within the time specified",
lock_name=self.name,
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
try:
self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
def acquire(
self,
sleep: Optional[Number] = None,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[str] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
If ``blocking`` is False, always return immediately. If the lock
was acquired, return True, otherwise return False.
``blocking_timeout`` specifies the maximum number of seconds to
wait trying to acquire the lock.
``token`` specifies the token value to be used. If provided, token
must be a bytes object or a string that can be encoded to a bytes
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
if sleep is None:
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
if blocking_timeout is None:
blocking_timeout = self.blocking_timeout
stop_trying_at = None
if blocking_timeout is not None:
stop_trying_at = mod_time.monotonic() + blocking_timeout
while True:
if self.do_acquire(token):
self.local.token = token
return True
if not blocking:
return False
next_try_at = mod_time.monotonic() + sleep
if stop_trying_at is not None and next_try_at > stop_trying_at:
return False
mod_time.sleep(sleep)
def do_acquire(self, token: str) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
else:
timeout = None
if self.redis.set(self.name, token, nx=True, px=timeout):
return True
return False
def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return self.redis.get(self.name) is not None
def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
stored_token = self.redis.get(self.name)
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
def release(self) -> None:
"""
Releases the already acquired lock
"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
self.local.token = None
self.do_release(expected_token)
def do_release(self, expected_token: str) -> None:
if not bool(
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
):
raise LockNotOwnedError(
"Cannot release a lock that's no longer owned",
lock_name=self.name,
)
def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
"""
Adds more time to an already acquired lock.
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
``replace_ttl`` if False (the default), add `additional_time` to
the lock's existing ttl. If True, replace the lock's ttl with
`additional_time`.
"""
if self.local.token is None:
raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
return self.do_extend(additional_time, replace_ttl)
def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
client=self.redis,
)
):
raise LockNotOwnedError(
"Cannot extend a lock that's no longer owned",
lock_name=self.name,
)
return True
def reacquire(self) -> bool:
"""
Resets the TTL of an already acquired lock back to the timeout value.
"""
if self.local.token is None:
raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
if self.timeout is None:
raise LockError(
"Cannot reacquire a lock with no timeout",
lock_name=self.name,
)
return self.do_reacquire()
def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError(
"Cannot reacquire a lock that's no longer owned",
lock_name=self.name,
)
return True
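# --- Illustrative usage (not part of the library) ---
# A minimal sketch of the context-manager flow described in the docstrings:
# acquire with a 10 second TTL, extend while still working, and release on
# exit.  The Redis connection parameters are placeholders.
if __name__ == "__main__":
    import redis

    r = redis.Redis(host="localhost", port=6379)
    with Lock(r, "my-lock", timeout=10, blocking_timeout=5) as lock:
        # ... do work that must not run concurrently ...
        lock.extend(10)  # add 10 more seconds to the existing TTL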

View File

@@ -0,0 +1,810 @@
import enum
import ipaddress
import logging
import re
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Optional, Union
from redis.typing import Number
class MaintenanceState(enum.Enum):
NONE = "none"
MOVING = "moving"
MAINTENANCE = "maintenance"
class EndpointType(enum.Enum):
"""Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""
INTERNAL_IP = "internal-ip"
INTERNAL_FQDN = "internal-fqdn"
EXTERNAL_IP = "external-ip"
EXTERNAL_FQDN = "external-fqdn"
NONE = "none"
def __str__(self):
"""Return the string value of the enum."""
return self.value
if TYPE_CHECKING:
from redis.connection import (
MaintNotificationsAbstractConnection,
MaintNotificationsAbstractConnectionPool,
)
class MaintenanceNotification(ABC):
"""
Base class for maintenance notifications sent through push messages by Redis server.
This class provides common functionality for all maintenance notifications including
unique identification and TTL (Time-To-Live) functionality.
Attributes:
id (int): Unique identifier for this notification
ttl (int): Time-to-live in seconds for this notification
creation_time (float): Timestamp when the notification was created/read
"""
def __init__(self, id: int, ttl: int):
"""
Initialize a new MaintenanceNotification with unique ID and TTL functionality.
Args:
id (int): Unique identifier for this notification
ttl (int): Time-to-live in seconds for this notification
"""
self.id = id
self.ttl = ttl
self.creation_time = time.monotonic()
self.expire_at = self.creation_time + self.ttl
def is_expired(self) -> bool:
"""
Check if this notification has expired based on its TTL
and creation time.
Returns:
bool: True if the notification has expired, False otherwise
"""
return time.monotonic() > (self.creation_time + self.ttl)
@abstractmethod
def __repr__(self) -> str:
"""
Return a string representation of the maintenance notification.
This method must be implemented by all concrete subclasses.
Returns:
str: String representation of the notification
"""
pass
@abstractmethod
def __eq__(self, other) -> bool:
"""
Compare two maintenance notifications for equality.
This method must be implemented by all concrete subclasses.
Notifications are typically considered equal if they have the same id
and are of the same type.
Args:
other: The other object to compare with
Returns:
bool: True if the notifications are equal, False otherwise
"""
pass
@abstractmethod
def __hash__(self) -> int:
"""
Return a hash value for the maintenance notification.
This method must be implemented by all concrete subclasses to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value for the notification
"""
pass
class NodeMovingNotification(MaintenanceNotification):
"""
This notification is received when a node is replaced with a new node
during cluster rebalancing or maintenance operations.
"""
def __init__(
self,
id: int,
new_node_host: Optional[str],
new_node_port: Optional[int],
ttl: int,
):
"""
Initialize a new NodeMovingNotification.
Args:
id (int): Unique identifier for this notification
new_node_host (str): Hostname or IP address of the new replacement node
new_node_port (int): Port number of the new replacement node
ttl (int): Time-to-live in seconds for this notification
"""
super().__init__(id, ttl)
self.new_node_host = new_node_host
self.new_node_port = new_node_port
def __repr__(self) -> str:
expiry_time = self.expire_at
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"new_node_host='{self.new_node_host}', "
f"new_node_port={self.new_node_port}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
f"remaining={remaining:.1f}s, "
f"expired={self.is_expired()}"
f")"
)
def __eq__(self, other) -> bool:
"""
Two NodeMovingNotification notifications are considered equal if they have the same
id, new_node_host, and new_node_port.
"""
if not isinstance(other, NodeMovingNotification):
return False
return (
self.id == other.id
and self.new_node_host == other.new_node_host
and self.new_node_port == other.new_node_port
)
def __hash__(self) -> int:
"""
Return a hash value for the notification to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value based on notification type class name, id,
new_node_host and new_node_port
"""
try:
node_port = int(self.new_node_port) if self.new_node_port else None
except ValueError:
node_port = 0
return hash(
(
self.__class__.__name__,
int(self.id),
str(self.new_node_host),
node_port,
)
)
class NodeMigratingNotification(MaintenanceNotification):
"""
Notification for when a Redis cluster node is in the process of migrating slots.
This notification is received when a node starts migrating its slots to another node
during cluster rebalancing or maintenance operations.
Args:
id (int): Unique identifier for this notification
ttl (int): Time-to-live in seconds for this notification
"""
def __init__(self, id: int, ttl: int):
super().__init__(id, ttl)
def __repr__(self) -> str:
expiry_time = self.creation_time + self.ttl
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
f"remaining={remaining:.1f}s, "
f"expired={self.is_expired()}"
f")"
)
def __eq__(self, other) -> bool:
"""
Two NodeMigratingNotification notifications are considered equal if they have the same
id and are of the same type.
"""
if not isinstance(other, NodeMigratingNotification):
return False
return self.id == other.id and type(self) is type(other)
def __hash__(self) -> int:
"""
Return a hash value for the notification to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value based on notification type and id
"""
return hash((self.__class__.__name__, int(self.id)))
class NodeMigratedNotification(MaintenanceNotification):
"""
Notification for when a Redis cluster node has completed migrating slots.
This notification is received when a node has finished migrating all its slots
to other nodes during cluster rebalancing or maintenance operations.
Args:
id (int): Unique identifier for this notification
"""
DEFAULT_TTL = 5
def __init__(self, id: int):
super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)
def __repr__(self) -> str:
expiry_time = self.creation_time + self.ttl
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
f"remaining={remaining:.1f}s, "
f"expired={self.is_expired()}"
f")"
)
def __eq__(self, other) -> bool:
"""
Two NodeMigratedNotification notifications are considered equal if they have the same
id and are of the same type.
"""
if not isinstance(other, NodeMigratedNotification):
return False
return self.id == other.id and type(self) is type(other)
def __hash__(self) -> int:
"""
Return a hash value for the notification to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value based on notification type and id
"""
return hash((self.__class__.__name__, int(self.id)))
class NodeFailingOverNotification(MaintenanceNotification):
"""
Notification for when a Redis cluster node is in the process of failing over.
This notification is received when a node starts a failover process during
cluster maintenance operations or when handling node failures.
Args:
id (int): Unique identifier for this notification
ttl (int): Time-to-live in seconds for this notification
"""
def __init__(self, id: int, ttl: int):
super().__init__(id, ttl)
def __repr__(self) -> str:
expiry_time = self.creation_time + self.ttl
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
f"remaining={remaining:.1f}s, "
f"expired={self.is_expired()}"
f")"
)
def __eq__(self, other) -> bool:
"""
Two NodeFailingOverNotification notifications are considered equal if they have the same
id and are of the same type.
"""
if not isinstance(other, NodeFailingOverNotification):
return False
return self.id == other.id and type(self) is type(other)
def __hash__(self) -> int:
"""
Return a hash value for the notification to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value based on notification type and id
"""
return hash((self.__class__.__name__, int(self.id)))
class NodeFailedOverNotification(MaintenanceNotification):
"""
Notification for when a Redis cluster node has completed a failover.
This notification is received when a node has finished the failover process
during cluster maintenance operations or after handling node failures.
Args:
id (int): Unique identifier for this notification
"""
DEFAULT_TTL = 5
def __init__(self, id: int):
super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)
def __repr__(self) -> str:
expiry_time = self.creation_time + self.ttl
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
f"remaining={remaining:.1f}s, "
f"expired={self.is_expired()}"
f")"
)
def __eq__(self, other) -> bool:
"""
Two NodeFailedOverNotification notifications are considered equal if they have the same
id and are of the same type.
"""
if not isinstance(other, NodeFailedOverNotification):
return False
return self.id == other.id and type(self) is type(other)
def __hash__(self) -> int:
"""
Return a hash value for the notification to allow
instances to be used in sets and as dictionary keys.
Returns:
int: Hash value based on notification type and id
"""
return hash((self.__class__.__name__, int(self.id)))
def _is_private_fqdn(host: str) -> bool:
"""
Determine if an FQDN is likely to be internal/private.
This uses heuristics based on RFC 952 and RFC 1123 standards:
- .local domains (RFC 6762 - Multicast DNS)
- .internal domains (common internal convention)
- Single-label hostnames (no dots)
- Common internal TLDs
Args:
host (str): The FQDN to check
Returns:
bool: True if the FQDN appears to be internal/private
"""
host_lower = host.lower().rstrip(".")
# Single-label hostnames (no dots) are typically internal
if "." not in host_lower:
return True
# Common internal/private domain patterns
internal_patterns = [
r"\.local$", # mDNS/Bonjour domains
r"\.internal$", # Common internal convention
r"\.corp$", # Corporate domains
r"\.lan$", # Local area network
r"\.intranet$", # Intranet domains
r"\.private$", # Private domains
]
for pattern in internal_patterns:
if re.search(pattern, host_lower):
return True
# If none of the internal patterns match, assume it's external
return False
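# Illustrative expectations for the heuristic above (not part of the library):
#   _is_private_fqdn("redis-node-1")        -> True  (single-label hostname)
#   _is_private_fqdn("cache.internal.")     -> True  (matches the .internal pattern)
#   _is_private_fqdn("redis.example.com")   -> False (no internal pattern matches)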
class MaintNotificationsConfig:
"""
Configuration class for maintenance notification handling behaviour.
Notifications are received as push messages.
This class defines how the Redis client should react to different push notifications
such as node moving, migrations, etc. in a Redis cluster.
"""
def __init__(
self,
enabled: Union[bool, Literal["auto"]] = "auto",
proactive_reconnect: bool = True,
relaxed_timeout: Optional[Number] = 10,
endpoint_type: Optional[EndpointType] = None,
):
"""
Initialize a new MaintNotificationsConfig.
Args:
enabled (bool | "auto"): Controls maintenance notifications handling behavior.
- True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
otherwise a ResponseError is raised.
- "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
gracefully handled - a warning is logged and normal operation continues.
- False: Maintenance notifications are completely disabled.
Defaults to "auto".
proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
Defaults to True.
relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.
If -1 is provided, the relaxed timeout is disabled. Defaults to 10.
endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
If None, the endpoint type will be automatically determined based on the host and TLS configuration.
Defaults to None.
Raises:
ValueError: If endpoint_type is provided but is not a valid endpoint type.
"""
self.enabled = enabled
self.relaxed_timeout = relaxed_timeout
self.proactive_reconnect = proactive_reconnect
self.endpoint_type = endpoint_type
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"enabled={self.enabled}, "
f"proactive_reconnect={self.proactive_reconnect}, "
f"relaxed_timeout={self.relaxed_timeout}, "
f"endpoint_type={self.endpoint_type!r}"
f")"
)
def is_relaxed_timeouts_enabled(self) -> bool:
"""
Check if the relaxed_timeout is enabled. The '-1' value is used to disable the relaxed_timeout.
If relaxed_timeout is set to None, operations block
and wait until a response is received.
Returns:
True if the relaxed_timeout is enabled, False otherwise.
"""
return self.relaxed_timeout != -1
def get_endpoint_type(
self, host: str, connection: "MaintNotificationsAbstractConnection"
) -> EndpointType:
"""
Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.
Logic:
1. If endpoint_type is explicitly set, use it
2. Otherwise, check the original host from connection.host:
- If host is an IP address, use it directly to determine internal-ip vs external-ip
- If host is an FQDN, get the resolved IP to determine internal-fqdn vs external-fqdn
Args:
host: User provided hostname to analyze
connection: The connection object to analyze for endpoint type determination
Returns:
"""
# If endpoint_type is explicitly set, use it
if self.endpoint_type is not None:
return self.endpoint_type
# Check if the host is an IP address
try:
ip_addr = ipaddress.ip_address(host)
# Host is an IP address - use it directly
is_private = ip_addr.is_private
return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
except ValueError:
# Host is an FQDN - need to check resolved IP to determine internal vs external
pass
# Host is an FQDN, get the resolved IP to determine if it's internal or external
resolved_ip = connection.get_resolved_ip()
if resolved_ip:
try:
ip_addr = ipaddress.ip_address(resolved_ip)
is_private = ip_addr.is_private
# Use FQDN types since the original host was an FQDN
return (
EndpointType.INTERNAL_FQDN
if is_private
else EndpointType.EXTERNAL_FQDN
)
except ValueError:
# This shouldn't happen since we got the IP from the socket, but fallback
pass
# Final fallback: use heuristics on the FQDN itself
is_private = _is_private_fqdn(host)
return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
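# Illustrative outcomes of the logic above (hosts are made up):
#   host="10.0.0.12"      -> EndpointType.INTERNAL_IP  (private IP literal)
#   host="8.8.8.8"        -> EndpointType.EXTERNAL_IP  (public IP literal)
#   host="cache.internal" -> INTERNAL_FQDN or EXTERNAL_FQDN, depending on the
#                            resolved IP (FQDN heuristics if resolution info is unavailable)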
class MaintNotificationsPoolHandler:
def __init__(
self,
pool: "MaintNotificationsAbstractConnectionPool",
config: MaintNotificationsConfig,
) -> None:
self.pool = pool
self.config = config
self._processed_notifications = set()
self._lock = threading.RLock()
self.connection = None
def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
self.connection = connection
def get_handler_for_connection(self):
# Copy all data that should be shared between connections
# but each connection should have its own pool handler
# since each connection can be in a different state
copy = MaintNotificationsPoolHandler(self.pool, self.config)
copy._processed_notifications = self._processed_notifications
copy._lock = self._lock
copy.connection = None
return copy
def remove_expired_notifications(self):
with self._lock:
for notification in tuple(self._processed_notifications):
if notification.is_expired():
self._processed_notifications.remove(notification)
def handle_notification(self, notification: MaintenanceNotification):
self.remove_expired_notifications()
if isinstance(notification, NodeMovingNotification):
return self.handle_node_moving_notification(notification)
else:
logging.error(f"Unhandled notification type: {notification}")
def handle_node_moving_notification(self, notification: NodeMovingNotification):
if (
not self.config.proactive_reconnect
and not self.config.is_relaxed_timeouts_enabled()
):
return
with self._lock:
if notification in self._processed_notifications:
# nothing to do in the connection pool handling
# the notification has already been handled or is expired
# just return
return
with self.pool._lock:
if (
self.config.proactive_reconnect
or self.config.is_relaxed_timeouts_enabled()
):
# Get the current connected address - if any
# This is the address that is being moved
# and we need to handle only connections
# connected to the same address
moving_address_src = (
self.connection.getpeername() if self.connection else None
)
if getattr(self.pool, "set_in_maintenance", False):
# Set pool in maintenance mode - executed only if
# BlockingConnectionPool is used
self.pool.set_in_maintenance(True)
# Update maintenance state, timeout and optionally host address
# connection settings for matching connections
self.pool.update_connections_settings(
state=MaintenanceState.MOVING,
maintenance_notification_hash=hash(notification),
relaxed_timeout=self.config.relaxed_timeout,
host_address=notification.new_node_host,
matching_address=moving_address_src,
matching_pattern="connected_address",
update_notification_hash=True,
include_free_connections=True,
)
if self.config.proactive_reconnect:
if notification.new_node_host is not None:
self.run_proactive_reconnect(moving_address_src)
else:
threading.Timer(
notification.ttl / 2,
self.run_proactive_reconnect,
args=(moving_address_src,),
).start()
# Update config for new connections:
# Set state to MOVING
# update host
# if relax timeouts are enabled - update timeouts
kwargs: dict = {
"maintenance_state": MaintenanceState.MOVING,
"maintenance_notification_hash": hash(notification),
}
if notification.new_node_host is not None:
# the host is not updated if the new node host is None
# this happens when the MOVING push notification does not contain
# the new node host - in this case we only update the timeouts
kwargs.update(
{
"host": notification.new_node_host,
}
)
if self.config.is_relaxed_timeouts_enabled():
kwargs.update(
{
"socket_timeout": self.config.relaxed_timeout,
"socket_connect_timeout": self.config.relaxed_timeout,
}
)
self.pool.update_connection_kwargs(**kwargs)
if getattr(self.pool, "set_in_maintenance", False):
self.pool.set_in_maintenance(False)
threading.Timer(
notification.ttl,
self.handle_node_moved_notification,
args=(notification,),
).start()
self._processed_notifications.add(notification)
def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
"""
Run proactive reconnect for the pool.
Active connections are marked for reconnect after they complete the current command.
Inactive connections are disconnected and will be connected on next use.
"""
with self._lock:
with self.pool._lock:
# take care for the active connections in the pool
# mark them for reconnect after they complete the current command
self.pool.update_active_connections_for_reconnect(
moving_address_src=moving_address_src,
)
# take care for the inactive connections in the pool
# delete them and create new ones
self.pool.disconnect_free_connections(
moving_address_src=moving_address_src,
)
def handle_node_moved_notification(self, notification: NodeMovingNotification):
"""
Handle the cleanup after a node moving notification expires.
"""
notification_hash = hash(notification)
with self._lock:
# if the current maintenance_notification_hash in kwargs is not matching the notification
# it means there has been a new moving notification after this one
# and we don't need to revert the kwargs yet
if (
self.pool.connection_kwargs.get("maintenance_notification_hash")
== notification_hash
):
orig_host = self.pool.connection_kwargs.get("orig_host_address")
orig_socket_timeout = self.pool.connection_kwargs.get(
"orig_socket_timeout"
)
orig_connect_timeout = self.pool.connection_kwargs.get(
"orig_socket_connect_timeout"
)
kwargs: dict = {
"maintenance_state": MaintenanceState.NONE,
"maintenance_notification_hash": None,
"host": orig_host,
"socket_timeout": orig_socket_timeout,
"socket_connect_timeout": orig_connect_timeout,
}
self.pool.update_connection_kwargs(**kwargs)
with self.pool._lock:
reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
reset_host_address = self.config.proactive_reconnect
self.pool.update_connections_settings(
relaxed_timeout=-1,
state=MaintenanceState.NONE,
maintenance_notification_hash=None,
matching_notification_hash=notification_hash,
matching_pattern="notification_hash",
update_notification_hash=True,
reset_relaxed_timeout=reset_relaxed_timeout,
reset_host_address=reset_host_address,
include_free_connections=True,
)
class MaintNotificationsConnectionHandler:
# 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
_NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
NodeMigratingNotification: 1,
NodeFailingOverNotification: 1,
NodeMigratedNotification: 0,
NodeFailedOverNotification: 0,
}
def __init__(
self,
connection: "MaintNotificationsAbstractConnection",
config: MaintNotificationsConfig,
) -> None:
self.connection = connection
self.config = config
def handle_notification(self, notification: MaintenanceNotification):
# get the notification type by checking its class in the _NOTIFICATION_TYPES dict
notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)
if notification_type is None:
logging.error(f"Unhandled notification type: {notification}")
return
if notification_type:
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
else:
self.handle_maintenance_completed_notification()
def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState
):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
or not self.config.is_relaxed_timeouts_enabled()
):
return
self.connection.maintenance_state = maintenance_state
self.connection.set_tmp_settings(
tmp_relaxed_timeout=self.config.relaxed_timeout
)
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
def handle_maintenance_completed_notification(self):
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
if (
self.connection.maintenance_state == MaintenanceState.MOVING
or not self.config.is_relaxed_timeouts_enabled()
):
return
self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
# Maintenance completed - reset the connection
# timeouts by providing -1 as the relaxed timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE
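# --- Illustrative usage (not part of the library) ---
# A minimal sketch of how a MOVING notification and the config interact: two
# notifications with the same id/host/port hash identically (so the pool
# handler can deduplicate them), and the config decides whether relaxed
# timeouts apply during the move.  Host and port values are placeholders.
if __name__ == "__main__":
    config = MaintNotificationsConfig(enabled="auto", relaxed_timeout=10)
    n1 = NodeMovingNotification(id=1, new_node_host="10.0.0.5", new_node_port=6379, ttl=15)
    n2 = NodeMovingNotification(id=1, new_node_host="10.0.0.5", new_node_port=6379, ttl=15)
    assert n1 == n2 and hash(n1) == hash(n2)
    assert config.is_relaxed_timeouts_enabled()
    assert not n1.is_expired()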

View File

@@ -0,0 +1,144 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable
import pybreaker
DEFAULT_GRACE_PERIOD = 60
class State(Enum):
CLOSED = "closed"
OPEN = "open"
HALF_OPEN = "half-open"
class CircuitBreaker(ABC):
@property
@abstractmethod
def grace_period(self) -> float:
"""The grace period in seconds when the circle should be kept open."""
pass
@grace_period.setter
@abstractmethod
def grace_period(self, grace_period: float):
"""Set the grace period in seconds."""
@property
@abstractmethod
def state(self) -> State:
"""The current state of the circuit."""
pass
@state.setter
@abstractmethod
def state(self, state: State):
"""Set current state of the circuit."""
pass
@property
@abstractmethod
def database(self):
"""Database associated with this circuit."""
pass
@database.setter
@abstractmethod
def database(self, database):
"""Set database associated with this circuit."""
pass
@abstractmethod
def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
"""Callback called when the state of the circuit changes."""
pass
class BaseCircuitBreaker(CircuitBreaker):
"""
Base implementation of the CircuitBreaker interface.
"""
def __init__(self, cb: pybreaker.CircuitBreaker):
self._cb = cb
self._state_pb_mapper = {
State.CLOSED: self._cb.close,
State.OPEN: self._cb.open,
State.HALF_OPEN: self._cb.half_open,
}
self._database = None
@property
def grace_period(self) -> float:
return self._cb.reset_timeout
@grace_period.setter
def grace_period(self, grace_period: float):
self._cb.reset_timeout = grace_period
@property
def state(self) -> State:
return State(value=self._cb.state.name)
@state.setter
def state(self, state: State):
self._state_pb_mapper[state]()
@property
def database(self):
return self._database
@database.setter
def database(self, database):
self._database = database
@abstractmethod
def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
"""Callback called when the state of the circuit changes."""
pass
class PBListener(pybreaker.CircuitBreakerListener):
"""Wrapper for callback to be compatible with pybreaker implementation."""
def __init__(
self,
cb: Callable[[CircuitBreaker, State, State], None],
database,
):
"""
Initialize a PBListener instance.
Args:
cb: Callback function that will be called when the circuit breaker state changes.
database: Database instance associated with this circuit breaker.
"""
self._cb = cb
self._database = database
def state_change(self, cb, old_state, new_state):
cb = PBCircuitBreakerAdapter(cb)
cb.database = self._database
old_state = State(value=old_state.name)
new_state = State(value=new_state.name)
self._cb(cb, old_state, new_state)
class PBCircuitBreakerAdapter(BaseCircuitBreaker):
def __init__(self, cb: pybreaker.CircuitBreaker):
"""
Initialize a PBCircuitBreakerAdapter instance.
This adapter wraps pybreaker's CircuitBreaker implementation to make it compatible
with our CircuitBreaker interface.
Args:
cb: A pybreaker CircuitBreaker instance to be adapted.
"""
super().__init__(cb)
def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
listener = PBListener(cb, self.database)
self._cb.add_listener(listener)
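A minimal sketch of wrapping pybreaker with the adapter above and reacting to state transitions (the redis.multidb.circuit import path is taken from the configuration module further below; the callback body is illustrative):

import pybreaker
from redis.multidb.circuit import PBCircuitBreakerAdapter, State

def log_transition(circuit, old_state, new_state):
    # Invoked through PBListener whenever the underlying pybreaker circuit changes state.
    print(f"circuit for {circuit.database}: {old_state.value} -> {new_state.value}")

breaker = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=60))
breaker.on_state_changed(log_transition)
breaker.state = State.OPEN      # delegates to pybreaker's open()
print(breaker.grace_period)     # 60, mirrors pybreaker's reset_timeout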

View File

@@ -0,0 +1,526 @@
import logging
import threading
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from redis.background import BackgroundScheduler
from redis.client import PubSubWorkerThread
from redis.commands import CoreCommands, RedisModuleCommands
from redis.multidb.circuit import CircuitBreaker
from redis.multidb.circuit import State as CBState
from redis.multidb.command_executor import DefaultCommandExecutor
from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
from redis.multidb.failure_detector import FailureDetector
from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy
from redis.utils import experimental
logger = logging.getLogger(__name__)
@experimental
class MultiDBClient(RedisModuleCommands, CoreCommands):
"""
Client that operates on multiple logical Redis databases.
Should be used in Active-Active database setups.
"""
def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.default_health_checks()
if config.health_checks is not None:
self._health_checks.extend(config.health_checks)
self._health_check_interval = config.health_check_interval
self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
config.health_check_probes, config.health_check_probes_delay
)
self._failure_detectors = config.default_failure_detectors()
if config.failure_detectors is not None:
self._failure_detectors.extend(config.failure_detectors)
self._failover_strategy = (
config.default_failover_strategy()
if config.failover_strategy is None
else config.failover_strategy
)
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_retry = config.command_retry
self._command_retry.update_supported_errors((ConnectionRefusedError,))
self.command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=self._command_retry,
failover_strategy=self._failover_strategy,
failover_attempts=config.failover_attempts,
failover_delay=config.failover_delay,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)
self.initialized = False
self._hc_lock = threading.RLock()
self._bg_scheduler = BackgroundScheduler()
self._config = config
def initialize(self):
"""
Perform initialization of databases to define their initial state.
"""
def raise_exception_on_failed_hc(error):
raise error
# Initial databases check to define initial state
self._check_databases_health(on_error=raise_exception_on_failed_hc)
# Start recurring health checks in the background.
self._bg_scheduler.run_recurring(
self._health_check_interval,
self._check_databases_health,
)
is_active_db_found = False
for database, weight in self._databases:
# Set on state changed callback for each circuit.
database.circuit.on_state_changed(self._on_circuit_state_change_callback)
# Set states according to weights and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
self.command_executor.active_database = database
is_active_db_found = True
if not is_active_db_found:
raise NoValidDatabaseException(
"Initial connection failed - no active database found"
)
self.initialized = True
def get_databases(self) -> Databases:
"""
Returns a sorted (by weight) list of all databases.
"""
return self._databases
def set_active_database(self, database: SyncDatabase) -> None:
"""
Promote one of the existing databases to become the active one.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
self._check_db_health(database)
if database.circuit.state == CBState.CLOSED:
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
self.command_executor.active_database = database
return
raise NoValidDatabaseException(
"Cannot set active database, database is unhealthy"
)
def add_database(self, database: SyncDatabase):
"""
Adds a new database to the database list.
"""
for existing_db, _ in self._databases:
if existing_db == database:
raise ValueError("Given database already exists")
self._check_db_health(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.add(database, database.weight)
self._change_active_database(database, highest_weighted_db)
def _change_active_database(
self, new_database: SyncDatabase, highest_weight_database: SyncDatabase
):
if (
new_database.weight > highest_weight_database.weight
and new_database.circuit.state == CBState.CLOSED
):
self.command_executor.active_database = new_database
def remove_database(self, database: Database):
"""
Removes a database from the database list.
"""
weight = self._databases.remove(database)
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
if (
highest_weight <= weight
and highest_weighted_db.circuit.state == CBState.CLOSED
):
self.command_executor.active_database = highest_weighted_db
def update_database_weight(self, database: SyncDatabase, weight: float):
"""
Updates the weight of a database in the database list.
"""
exists = None
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break
if not exists:
raise ValueError("Given database is not a member of database list")
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.update_weight(database, weight)
database.weight = weight
self._change_active_database(database, highest_weighted_db)
def add_failure_detector(self, failure_detector: FailureDetector):
"""
Adds a new failure detector to the client.
"""
self._failure_detectors.append(failure_detector)
def add_health_check(self, healthcheck: HealthCheck):
"""
Adds a new health check to the client.
"""
with self._hc_lock:
self._health_checks.append(healthcheck)
def execute_command(self, *args, **options):
"""
Executes a single command and return its result.
"""
if not self.initialized:
self.initialize()
return self.command_executor.execute_command(*args, **options)
def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)
def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):
"""
Executes callable as transaction.
"""
if not self.initialized:
self.initialize()
return self.command_executor.execute_transaction(func, *watches, **options)
def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
self.initialize()
return PubSub(self, **kwargs)
def _check_db_health(self, database: SyncDatabase) -> bool:
"""
Runs health checks on the given database until first failure.
"""
# Health check will setup circuit state
is_healthy = self._health_check_policy.execute(self._health_checks, database)
if not is_healthy:
if database.circuit.state != CBState.OPEN:
database.circuit.state = CBState.OPEN
return is_healthy
elif is_healthy and database.circuit.state != CBState.CLOSED:
database.circuit.state = CBState.CLOSED
return is_healthy
def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
"""
Runs health checks as a recurring task.
Runs health checks against all databases.
"""
with ThreadPoolExecutor(max_workers=len(self._databases)) as executor:
# Submit all health checks
futures = {
executor.submit(self._check_db_health, database)
for database, _ in self._databases
}
try:
for future in as_completed(
futures, timeout=self._health_check_interval
):
try:
future.result()
except UnhealthyDatabaseException as e:
unhealthy_db = e.database
unhealthy_db.circuit.state = CBState.OPEN
logger.exception(
"Health check failed, due to exception",
exc_info=e.original_exception,
)
if on_error:
on_error(e.original_exception)
except TimeoutError:
raise TimeoutError(
"Health check execution exceeds health_check_interval"
)
def _on_circuit_state_change_callback(
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
):
if new_state == CBState.HALF_OPEN:
self._check_db_health(circuit.database)
return
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
self._bg_scheduler.run_once(
DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit
)
def close(self):
self.command_executor.active_database.client.close()
def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
class Pipeline(RedisModuleCommands, CoreCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client
def __enter__(self) -> "Pipeline":
return self
def __exit__(self, exc_type, exc_value, traceback):
self.reset()
def __del__(self):
try:
self.reset()
except Exception:
pass
def __len__(self) -> int:
return len(self._command_stack)
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True
def reset(self) -> None:
self._command_stack = []
def close(self) -> None:
"""Close the pipeline"""
self.reset()
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called
Returns the current Pipeline object back so commands can be
chained together, such as:
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self
def execute_command(self, *args, **kwargs):
"""Adds a command to the stack"""
return self.pipeline_execute_command(*args, **kwargs)
def execute(self) -> List[Any]:
"""Execute all the commands in the current pipeline"""
if not self._client.initialized:
self._client.initialize()
try:
return self._client.command_executor.execute_pipeline(
tuple(self._command_stack)
)
finally:
self.reset()
class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.
Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""
self._client = client
self._client.command_executor.pubsub(**kwargs)
def __enter__(self) -> "PubSub":
return self
def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
# subscriptions, close the connection manually before
# returning it to the connection pool
self.reset()
except Exception:
pass
def reset(self) -> None:
return self._client.command_executor.execute_pubsub_method("reset")
def close(self) -> None:
self.reset()
@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed
def execute_command(self, *args):
return self._client.command_executor.execute_pubsub_method(
"execute_command", *args
)
def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return self._client.command_executor.execute_pubsub_method(
"psubscribe", *args, **kwargs
)
def punsubscribe(self, *args):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return self._client.command_executor.execute_pubsub_method(
"punsubscribe", *args
)
def subscribe(self, *args, **kwargs):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return self._client.command_executor.execute_pubsub_method(
"subscribe", *args, **kwargs
)
def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return self._client.command_executor.execute_pubsub_method("unsubscribe", *args)
def ssubscribe(self, *args, **kwargs):
"""
Subscribes the client to the specified shard channels.
Channels supplied as keyword arguments expect a channel name as the key
and a callable as the value. A channel's callable will be invoked automatically
when a message is received on that channel rather than producing a message via
``listen()`` or ``get_sharded_message()``.
"""
return self._client.command_executor.execute_pubsub_method(
"ssubscribe", *args, **kwargs
)
def sunsubscribe(self, *args):
"""
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
all shard_channels
"""
return self._client.command_executor.execute_pubsub_method(
"sunsubscribe", *args
)
def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
):
"""
Get the next message if one is available, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number, or None, to wait indefinitely.
"""
return self._client.command_executor.execute_pubsub_method(
"get_message",
ignore_subscribe_messages=ignore_subscribe_messages,
timeout=timeout,
)
def get_sharded_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
):
"""
Get the next message if one is available in a sharded channel, otherwise None.
If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number, or None, to wait indefinitely.
"""
return self._client.command_executor.execute_pubsub_method(
"get_sharded_message",
ignore_subscribe_messages=ignore_subscribe_messages,
timeout=timeout,
)
def run_in_thread(
self,
sleep_time: float = 0.0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
sharded_pubsub: bool = False,
) -> "PubSubWorkerThread":
return self._client.command_executor.execute_pubsub_run(
sleep_time,
daemon=daemon,
exception_handler=exception_handler,
pubsub=self,
sharded_pubsub=sharded_pubsub,
)
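A minimal sketch of constructing and using the client above (the redis.multidb.client import path, the endpoints and the weights are assumptions for illustration):

from redis.multidb.client import MultiDBClient
from redis.multidb.config import DatabaseConfig, MultiDbConfig

config = MultiDbConfig(
    databases_config=[
        DatabaseConfig(from_url="redis://primary.example.com:6379", weight=1.0),
        DatabaseConfig(from_url="redis://standby.example.com:6379", weight=0.5),
    ]
)
client = MultiDBClient(config)   # initialize() runs lazily on the first command
client.set("foo", "bar")
with client.pipeline() as pipe:
    pipe.set("k1", "v1").incr("counter")
    pipe.execute()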

View File

@@ -0,0 +1,350 @@
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, List, Optional
from redis.client import Pipeline, PubSub, PubSubWorkerThread
from redis.event import EventDispatcherInterface, OnCommandsFailEvent
from redis.multidb.circuit import State as CBState
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.event import (
ActiveDatabaseChanged,
CloseConnectionOnActiveDatabaseChanged,
RegisterCommandFailure,
ResubscribeOnActiveDatabaseChanged,
)
from redis.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
DefaultFailoverStrategyExecutor,
FailoverStrategy,
FailoverStrategyExecutor,
)
from redis.multidb.failure_detector import FailureDetector
from redis.retry import Retry
class CommandExecutor(ABC):
@property
@abstractmethod
def auto_fallback_interval(self) -> float:
"""Returns auto-fallback interval."""
pass
@auto_fallback_interval.setter
@abstractmethod
def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
"""Sets auto-fallback interval."""
pass
class BaseCommandExecutor(CommandExecutor):
def __init__(
self,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
self._auto_fallback_interval = auto_fallback_interval
self._next_fallback_attempt: datetime
@property
def auto_fallback_interval(self) -> float:
return self._auto_fallback_interval
@auto_fallback_interval.setter
def auto_fallback_interval(self, auto_fallback_interval: int) -> None:
self._auto_fallback_interval = auto_fallback_interval
def _schedule_next_fallback(self) -> None:
if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL:
return
self._next_fallback_attempt = datetime.now() + timedelta(
seconds=self._auto_fallback_interval
)
class SyncCommandExecutor(CommandExecutor):
@property
@abstractmethod
def databases(self) -> Databases:
"""Returns a list of databases."""
pass
@property
@abstractmethod
def failure_detectors(self) -> List[FailureDetector]:
"""Returns a list of failure detectors."""
pass
@abstractmethod
def add_failure_detector(self, failure_detector: FailureDetector) -> None:
"""Adds a new failure detector to the list of failure detectors."""
pass
@property
@abstractmethod
def active_database(self) -> Optional[Database]:
"""Returns currently active database."""
pass
@active_database.setter
@abstractmethod
def active_database(self, database: SyncDatabase) -> None:
"""Sets the currently active database."""
pass
@property
@abstractmethod
def active_pubsub(self) -> Optional[PubSub]:
"""Returns currently active pubsub."""
pass
@active_pubsub.setter
@abstractmethod
def active_pubsub(self, pubsub: PubSub) -> None:
"""Sets currently active pubsub."""
pass
@property
@abstractmethod
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
"""Returns failover strategy executor."""
pass
@property
@abstractmethod
def command_retry(self) -> Retry:
"""Returns command retry object."""
pass
@abstractmethod
def pubsub(self, **kwargs):
"""Initializes a PubSub object on a currently active database"""
pass
@abstractmethod
def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
pass
@abstractmethod
def execute_pipeline(self, command_stack: tuple):
"""Executes a stack of commands in pipeline."""
pass
@abstractmethod
def execute_transaction(
self, transaction: Callable[[Pipeline], None], *watches, **options
):
"""Executes a transaction block wrapped in callback."""
pass
@abstractmethod
def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""Executes a given method on active pub/sub."""
pass
@abstractmethod
def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
"""Executes pub/sub run in a thread."""
pass
class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor):
def __init__(
self,
failure_detectors: List[FailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: FailoverStrategy,
event_dispatcher: EventDispatcherInterface,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
"""
Initialize the DefaultCommandExecutor instance.
Args:
failure_detectors: List of failure detector instances to monitor database health
databases: Collection of available databases to execute commands on
command_retry: Retry policy for failed command execution
failover_strategy: Strategy for handling database failover
event_dispatcher: Interface for dispatching events
failover_attempts: Number of failover attempts
failover_delay: Delay between failover attempts
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
"""
super().__init__(auto_fallback_interval)
for fd in failure_detectors:
fd.set_command_executor(command_executor=self)
self._databases = databases
self._failure_detectors = failure_detectors
self._command_retry = command_retry
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
failover_strategy, failover_attempts, failover_delay
)
self._event_dispatcher = event_dispatcher
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()
@property
def databases(self) -> Databases:
return self._databases
@property
def failure_detectors(self) -> List[FailureDetector]:
return self._failure_detectors
def add_failure_detector(self, failure_detector: FailureDetector) -> None:
self._failure_detectors.append(failure_detector)
@property
def command_retry(self) -> Retry:
return self._command_retry
@property
def active_database(self) -> Optional[SyncDatabase]:
return self._active_database
@active_database.setter
def active_database(self, database: SyncDatabase) -> None:
old_active = self._active_database
self._active_database = database
if old_active is not None and old_active is not database:
self._event_dispatcher.dispatch(
ActiveDatabaseChanged(
old_active,
self._active_database,
self,
**self._active_pubsub_kwargs,
)
)
@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub
@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub
@property
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
return self._failover_strategy_executor
def execute_command(self, *args, **options):
def callback():
response = self._active_database.client.execute_command(*args, **options)
self._register_command_execution(args)
return response
return self._execute_with_failure_detection(callback, args)
def execute_pipeline(self, command_stack: tuple):
def callback():
with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
pipe.execute_command(*command, **options)
response = pipe.execute()
self._register_command_execution(command_stack)
return response
return self._execute_with_failure_detection(callback, command_stack)
def execute_transaction(
self, transaction: Callable[[Pipeline], None], *watches, **options
):
def callback():
response = self._active_database.client.transaction(
transaction, *watches, **options
)
self._register_command_execution(())
return response
return self._execute_with_failure_detection(callback)
def pubsub(self, **kwargs):
def callback():
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
return None
return self._execute_with_failure_detection(callback)
def execute_pubsub_method(self, method_name: str, *args, **kwargs):
def callback():
method = getattr(self.active_pubsub, method_name)
response = method(*args, **kwargs)
self._register_command_execution(args)
return response
return self._execute_with_failure_detection(callback, *args)
def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread":
def callback():
return self._active_pubsub.run_in_thread(sleep_time, **kwargs)
return self._execute_with_failure_detection(callback)
def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
"""
Execute a command execution callback with failure detection.
"""
def wrapper():
# On each retry we need to check active database as it might change.
self._check_active_database()
return callback()
return self._command_retry.call_with_retry(
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)
def _on_command_fail(self, error, *args):
self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error))
def _check_active_database(self):
"""
Checks if the active database needs to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
or (
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
and self._next_fallback_attempt <= datetime.now()
)
):
self.active_database = self._failover_strategy_executor.execute()
self._schedule_next_fallback()
def _register_command_execution(self, cmd: tuple):
for detector in self._failure_detectors:
detector.register_command_execution(cmd)
def _setup_event_dispatcher(self):
"""
Registers necessary listeners.
"""
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners(
{
OnCommandsFailEvent: [failure_listener],
ActiveDatabaseChanged: [
close_connection_listener,
resubscribe_listener,
],
}
)
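A condensed, standalone illustration of the control flow in _execute_with_failure_detection above: every retry attempt re-resolves the active database before re-sending the command, and every failure is reported so the detectors can open the failing database's circuit (the helper below is illustrative, not part of the library):

from redis.backoff import NoBackoff
from redis.retry import Retry

def run_with_failover(pick_database, send, retry=Retry(NoBackoff(), retries=3)):
    # pick_database() returns a healthy database handle; send(db) runs the command.
    def attempt():
        # Re-resolve the active database on every attempt, since a previous
        # failure may have opened its circuit and triggered a failover.
        db = pick_database()
        return send(db)

    def on_failure(error):
        # Mirrors _on_command_fail: report the error so failure detectors
        # can react before the next attempt.
        print(f"command failed: {error!r}")

    return retry.call_with_retry(attempt, on_failure)

print(run_with_failover(lambda: "db-1", lambda db: f"PONG from {db}"))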

View File

@@ -0,0 +1,207 @@
from dataclasses import dataclass, field
from typing import List, Type, Union
import pybreaker
from typing_extensions import Optional
from redis import ConnectionPool, Redis, RedisCluster
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
from redis.data_structure import WeightedList
from redis.event import EventDispatcher, EventDispatcherInterface
from redis.multidb.circuit import (
DEFAULT_GRACE_PERIOD,
CircuitBreaker,
PBCircuitBreakerAdapter,
)
from redis.multidb.database import Database, Databases
from redis.multidb.failover import (
DEFAULT_FAILOVER_ATTEMPTS,
DEFAULT_FAILOVER_DELAY,
FailoverStrategy,
WeightBasedFailoverStrategy,
)
from redis.multidb.failure_detector import (
DEFAULT_FAILURE_RATE_THRESHOLD,
DEFAULT_FAILURES_DETECTION_WINDOW,
DEFAULT_MIN_NUM_FAILURES,
CommandFailureDetector,
FailureDetector,
)
from redis.multidb.healthcheck import (
DEFAULT_HEALTH_CHECK_DELAY,
DEFAULT_HEALTH_CHECK_INTERVAL,
DEFAULT_HEALTH_CHECK_POLICY,
DEFAULT_HEALTH_CHECK_PROBES,
HealthCheck,
HealthCheckPolicies,
PingHealthCheck,
)
from redis.retry import Retry
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
def default_event_dispatcher() -> EventDispatcherInterface:
return EventDispatcher()
@dataclass
class DatabaseConfig:
"""
Dataclass representing the configuration for a database connection.
This class is used to store configuration settings for a database connection,
including client options, connection sourcing details, circuit breaker settings,
and cluster-specific properties. It provides a structure for defining these
attributes and allows for the creation of customized configurations for various
database setups.
Attributes:
weight (float): Weight of the database to define the active one.
client_kwargs (dict): Additional parameters for the database client connection.
from_url (Optional[str]): Redis URL way of connecting to the database.
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
on public Redis Enterprise endpoints.
Methods:
default_circuit_breaker:
Generates and returns a default CircuitBreaker instance adapted for use.
"""
weight: float = 1.0
client_kwargs: dict = field(default_factory=dict)
from_url: Optional[str] = None
from_pool: Optional[ConnectionPool] = None
circuit: Optional[CircuitBreaker] = None
grace_period: float = DEFAULT_GRACE_PERIOD
health_check_url: Optional[str] = None
def default_circuit_breaker(self) -> CircuitBreaker:
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
return PBCircuitBreakerAdapter(circuit_breaker)
@dataclass
class MultiDbConfig:
"""
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
Attributes:
databases_config: A list of database configurations.
client_class: The client class used to manage database connections.
command_retry: Retry strategy for executing database commands.
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
min_num_failures: Minimum number of failures required for failover
failure_rate_threshold: Fraction of failed commands (0.0-1.0) required for failover
failures_detection_window: Time window, in seconds, for tracking database failures.
health_checks: Optional list of additional health checks performed on databases.
health_check_interval: Time interval for executing health checks.
health_check_probes: Number of attempts to evaluate the health of a database.
health_check_probes_delay: Delay between health check attempts.
health_check_policy: Policy for determining database health based on health checks.
failover_strategy: Optional strategy for handling database failover scenarios.
failover_attempts: Number of retries allowed for failover operations.
failover_delay: Delay between failover attempts.
auto_fallback_interval: Time interval to trigger automatic fallback.
event_dispatcher: Interface for dispatching events related to database operations.
Methods:
databases:
Retrieves a collection of database clients managed by weighted configurations.
Initializes database clients based on the provided configuration and removes
redundant retry objects for lower-level clients to rely on global retry logic.
default_failure_detectors:
Returns the default list of failure detectors used to monitor database failures.
default_health_checks:
Returns the default list of health checks used to monitor database health
with specific retry and backoff strategies.
default_failover_strategy:
Provides the default failover strategy used for handling failover scenarios
with defined retry and backoff configurations.
"""
databases_config: List[DatabaseConfig]
client_class: Type[Union[Redis, RedisCluster]] = Redis
command_retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
)
failure_detectors: Optional[List[FailureDetector]] = None
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
health_checks: Optional[List[HealthCheck]] = None
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
health_check_probes_delay: float = DEFAULT_HEALTH_CHECK_DELAY
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
failover_strategy: Optional[FailoverStrategy] = None
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
failover_delay: float = DEFAULT_FAILOVER_DELAY
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
event_dispatcher: EventDispatcherInterface = field(
default_factory=default_event_dispatcher
)
def databases(self) -> Databases:
databases = WeightedList()
for database_config in self.databases_config:
# Retries are disabled in the lower-level clients by overriding any provided
# retry object; we rely on command_retry for global retries.
database_config.client_kwargs.update(
{"retry": Retry(retries=0, backoff=NoBackoff())}
)
if database_config.from_url:
client = self.client_class.from_url(
database_config.from_url, **database_config.client_kwargs
)
elif database_config.from_pool:
database_config.from_pool.set_retry(
Retry(retries=0, backoff=NoBackoff())
)
client = self.client_class.from_pool(
connection_pool=database_config.from_pool
)
else:
client = self.client_class(**database_config.client_kwargs)
circuit = (
database_config.default_circuit_breaker()
if database_config.circuit is None
else database_config.circuit
)
databases.add(
Database(
client=client,
circuit=circuit,
weight=database_config.weight,
health_check_url=database_config.health_check_url,
),
database_config.weight,
)
return databases
def default_failure_detectors(self) -> List[FailureDetector]:
return [
CommandFailureDetector(
min_num_failures=self.min_num_failures,
failure_rate_threshold=self.failure_rate_threshold,
failure_detection_window=self.failures_detection_window,
),
]
def default_health_checks(self) -> List[HealthCheck]:
return [
PingHealthCheck(),
]
def default_failover_strategy(self) -> FailoverStrategy:
return WeightBasedFailoverStrategy()
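A sketch of a configuration that mixes connection styles and tunes the circuit and fallback behaviour (the endpoints and values are illustrative):

from redis.multidb.config import DatabaseConfig, MultiDbConfig

cfg = MultiDbConfig(
    databases_config=[
        # URL-based connection with a longer grace period before an opened
        # circuit is probed again, plus a REST endpoint for lag-aware checks.
        DatabaseConfig(
            from_url="rediss://db-east.example.com:6379",
            weight=1.0,
            grace_period=30,
            health_check_url="https://cluster-east.example.com",
        ),
        # kwargs-based connection; any 'retry' passed here is overwritten by
        # databases(), which disables low-level retries.
        DatabaseConfig(
            client_kwargs={"host": "db-west.example.com", "port": 6379},
            weight=0.7,
        ),
    ],
    health_check_interval=10,
    auto_fallback_interval=180,
)
databases = cfg.databases()   # WeightedList of (Database, weight) pairs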

View File

@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
import redis
from redis import RedisCluster
from redis.data_structure import WeightedList
from redis.multidb.circuit import CircuitBreaker
from redis.typing import Number
class AbstractDatabase(ABC):
@property
@abstractmethod
def weight(self) -> float:
"""The weight of this database in compare to others. Used to determine the database failover to."""
pass
@weight.setter
@abstractmethod
def weight(self, weight: float):
"""Set the weight of this database in compare to others."""
pass
@property
@abstractmethod
def health_check_url(self) -> Optional[str]:
"""Health check URL associated with the current database."""
pass
@health_check_url.setter
@abstractmethod
def health_check_url(self, health_check_url: Optional[str]):
"""Set the health check URL associated with the current database."""
pass
class BaseDatabase(AbstractDatabase):
def __init__(
self,
weight: float,
health_check_url: Optional[str] = None,
):
self._weight = weight
self._health_check_url = health_check_url
@property
def weight(self) -> float:
return self._weight
@weight.setter
def weight(self, weight: float):
self._weight = weight
@property
def health_check_url(self) -> Optional[str]:
return self._health_check_url
@health_check_url.setter
def health_check_url(self, health_check_url: Optional[str]):
self._health_check_url = health_check_url
class SyncDatabase(AbstractDatabase):
"""Database with an underlying synchronous redis client."""
@property
@abstractmethod
def client(self) -> Union[redis.Redis, RedisCluster]:
"""The underlying redis client."""
pass
@client.setter
@abstractmethod
def client(self, client: Union[redis.Redis, RedisCluster]):
"""Set the underlying redis client."""
pass
@property
@abstractmethod
def circuit(self) -> CircuitBreaker:
"""Circuit breaker for the current database."""
pass
@circuit.setter
@abstractmethod
def circuit(self, circuit: CircuitBreaker):
"""Set the circuit breaker for the current database."""
pass
Databases = WeightedList[tuple[SyncDatabase, Number]]
class Database(BaseDatabase, SyncDatabase):
def __init__(
self,
client: Union[redis.Redis, RedisCluster],
circuit: CircuitBreaker,
weight: float,
health_check_url: Optional[str] = None,
):
"""
Initialize a new Database instance.
Args:
client: Underlying Redis client instance for database operations
circuit: Circuit breaker for handling database failures
weight: Weight value used for database failover prioritization
health_check_url: Health check URL associated with the current database
"""
self._client = client
self._cb = circuit
self._cb.database = self
super().__init__(weight, health_check_url)
@property
def client(self) -> Union[redis.Redis, RedisCluster]:
return self._client
@client.setter
def client(self, client: Union[redis.Redis, RedisCluster]):
self._client = client
@property
def circuit(self) -> CircuitBreaker:
return self._cb
@circuit.setter
def circuit(self, circuit: CircuitBreaker):
self._cb = circuit
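A direct construction of the Database wrapper above, outside of MultiDbConfig (the connection parameters are illustrative):

import pybreaker
import redis
from redis.multidb.circuit import PBCircuitBreakerAdapter
from redis.multidb.database import Database

client = redis.Redis(host="localhost", port=6379)
circuit = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=60))
db = Database(client=client, circuit=circuit, weight=1.0)
assert db.circuit.database is db   # the circuit is back-linked to its database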

View File

@@ -0,0 +1,89 @@
from typing import List
from redis.client import Redis
from redis.event import EventListenerInterface, OnCommandsFailEvent
from redis.multidb.database import SyncDatabase
from redis.multidb.failure_detector import FailureDetector
class ActiveDatabaseChanged:
"""
Event fired when an active database has been changed.
"""
def __init__(
self,
old_database: SyncDatabase,
new_database: SyncDatabase,
command_executor,
**kwargs,
):
self._old_database = old_database
self._new_database = new_database
self._command_executor = command_executor
self._kwargs = kwargs
@property
def old_database(self) -> SyncDatabase:
return self._old_database
@property
def new_database(self) -> SyncDatabase:
return self._new_database
@property
def command_executor(self):
return self._command_executor
@property
def kwargs(self):
return self._kwargs
class ResubscribeOnActiveDatabaseChanged(EventListenerInterface):
"""
Re-subscribe the currently active pub / sub to a new active database.
"""
def listen(self, event: ActiveDatabaseChanged):
old_pubsub = event.command_executor.active_pubsub
if old_pubsub is not None:
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
new_pubsub.channels = old_pubsub.channels
new_pubsub.patterns = old_pubsub.patterns
new_pubsub.shard_channels = old_pubsub.shard_channels
new_pubsub.on_connect(None)
event.command_executor.active_pubsub = new_pubsub
old_pubsub.close()
class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface):
"""
Close connection to the old active database.
"""
def listen(self, event: ActiveDatabaseChanged):
event.old_database.client.close()
if isinstance(event.old_database.client, Redis):
event.old_database.client.connection_pool.update_active_connections_for_reconnect()
event.old_database.client.connection_pool.disconnect()
else:
for node in event.old_database.client.nodes_manager.nodes_cache.values():
node.redis_connection.connection_pool.update_active_connections_for_reconnect()
node.redis_connection.connection_pool.disconnect()
class RegisterCommandFailure(EventListenerInterface):
"""
Event listener that registers command failures and passes them to the failure detectors.
"""
def __init__(self, failure_detectors: List[FailureDetector]):
self._failure_detectors = failure_detectors
def listen(self, event: OnCommandsFailEvent) -> None:
for failure_detector in self._failure_detectors:
failure_detector.register_failure(event.exception, event.commands)
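The listeners above are normally wired up by DefaultCommandExecutor._setup_event_dispatcher; a standalone sketch of the same wiring for the failure listener (the detector thresholds and the failing command are illustrative):

from redis.event import EventDispatcher, OnCommandsFailEvent
from redis.multidb.event import RegisterCommandFailure
from redis.multidb.failure_detector import CommandFailureDetector

detector = CommandFailureDetector(min_num_failures=5, failure_rate_threshold=0.5)
dispatcher = EventDispatcher()
dispatcher.register_listeners({OnCommandsFailEvent: [RegisterCommandFailure([detector])]})
# A failed command is reported and forwarded to every registered detector:
dispatcher.dispatch(OnCommandsFailEvent(("SET", "key", "value"), ConnectionError("node down")))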

View File

@@ -0,0 +1,17 @@
class NoValidDatabaseException(Exception):
pass
class UnhealthyDatabaseException(Exception):
"""Exception raised when a database is unhealthy due to an underlying exception."""
def __init__(self, message, database, original_exception):
super().__init__(message)
self.database = database
self.original_exception = original_exception
class TemporaryUnavailableException(Exception):
"""Exception raised when all databases in setup are temporary unavailable."""
pass

View File

@@ -0,0 +1,125 @@
import time
from abc import ABC, abstractmethod
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.multidb.database import Databases, SyncDatabase
from redis.multidb.exception import (
NoValidDatabaseException,
TemporaryUnavailableException,
)
DEFAULT_FAILOVER_ATTEMPTS = 10
DEFAULT_FAILOVER_DELAY = 12
class FailoverStrategy(ABC):
@abstractmethod
def database(self) -> SyncDatabase:
"""Select the database according to the strategy."""
pass
@abstractmethod
def set_databases(self, databases: Databases) -> None:
"""Set the database strategy operates on."""
pass
class FailoverStrategyExecutor(ABC):
@property
@abstractmethod
def failover_attempts(self) -> int:
"""The number of failover attempts."""
pass
@property
@abstractmethod
def failover_delay(self) -> float:
"""The delay between failover attempts."""
pass
@property
@abstractmethod
def strategy(self) -> FailoverStrategy:
"""The strategy to execute."""
pass
@abstractmethod
def execute(self) -> SyncDatabase:
"""Execute the failover strategy."""
pass
class WeightBasedFailoverStrategy(FailoverStrategy):
"""
Failover strategy based on database weights.
"""
def __init__(self) -> None:
self._databases = WeightedList()
def database(self) -> SyncDatabase:
for database, _ in self._databases:
if database.circuit.state == CBState.CLOSED:
return database
raise NoValidDatabaseException("No valid database available for communication")
def set_databases(self, databases: Databases) -> None:
self._databases = databases
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
"""
Executes given failover strategy.
"""
def __init__(
self,
strategy: FailoverStrategy,
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
failover_delay: float = DEFAULT_FAILOVER_DELAY,
):
self._strategy = strategy
self._failover_attempts = failover_attempts
self._failover_delay = failover_delay
self._next_attempt_ts: int = 0
self._failover_counter: int = 0
@property
def failover_attempts(self) -> int:
return self._failover_attempts
@property
def failover_delay(self) -> float:
return self._failover_delay
@property
def strategy(self) -> FailoverStrategy:
return self._strategy
def execute(self) -> SyncDatabase:
try:
database = self._strategy.database()
self._reset()
return database
except NoValidDatabaseException as e:
if self._next_attempt_ts == 0:
self._next_attempt_ts = time.time() + self._failover_delay
self._failover_counter += 1
elif time.time() >= self._next_attempt_ts:
self._next_attempt_ts += self._failover_delay
self._failover_counter += 1
if self._failover_counter > self._failover_attempts:
self._reset()
raise e
else:
raise TemporaryUnavailableException(
"No database connections currently available. "
"This is a temporary condition - please retry the operation."
)
def _reset(self) -> None:
self._next_attempt_ts = 0
self._failover_counter = 0
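A sketch of how the executor above degrades when no healthy database is available; with an empty strategy it fails immediately, which exercises the temporary-unavailability path (the attempt and delay values are illustrative):

from redis.multidb.exception import TemporaryUnavailableException
from redis.multidb.failover import (
    DefaultFailoverStrategyExecutor,
    WeightBasedFailoverStrategy,
)

strategy = WeightBasedFailoverStrategy()          # no databases registered yet
executor = DefaultFailoverStrategyExecutor(strategy, failover_attempts=3, failover_delay=1.0)
try:
    executor.execute()
except TemporaryUnavailableException:
    # No closed circuit was found. Repeated calls keep raising this until
    # failover_attempts windows of failover_delay seconds have elapsed,
    # after which NoValidDatabaseException propagates instead.
    pass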

View File

@@ -0,0 +1,104 @@
import math
import threading
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Type
from typing_extensions import Optional
from redis.multidb.circuit import State as CBState
DEFAULT_MIN_NUM_FAILURES = 1000
DEFAULT_FAILURE_RATE_THRESHOLD = 0.1
DEFAULT_FAILURES_DETECTION_WINDOW = 2
class FailureDetector(ABC):
@abstractmethod
def register_failure(self, exception: Exception, cmd: tuple) -> None:
"""Register a failure that occurred during command execution."""
pass
@abstractmethod
def register_command_execution(self, cmd: tuple) -> None:
"""Register a command execution."""
pass
@abstractmethod
def set_command_executor(self, command_executor) -> None:
"""Set the command executor for this failure."""
pass
class CommandFailureDetector(FailureDetector):
"""
Detects a failure based on a threshold of failed commands during a specific period of time.
"""
def __init__(
self,
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES,
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD,
failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW,
error_types: Optional[List[Type[Exception]]] = None,
) -> None:
"""
Initialize a new CommandFailureDetector instance.
Args:
min_num_failures: Minimum number of failures required for failover
failure_rate_threshold: Fraction of failed commands (0.0-1.0) required for failover
failure_detection_window: Sliding time window, in seconds, over which failures are tracked.
error_types: Optional list of exception types to trigger failover. If None, all exceptions are counted.
The detector tracks command failures within a sliding time window. When the number of failures
exceeds the threshold within the specified duration, it triggers failure detection.
"""
self._command_executor = None
self._min_num_failures = min_num_failures
self._failure_rate_threshold = failure_rate_threshold
self._failure_detection_window = failure_detection_window
self._error_types = error_types
self._commands_executed: int = 0
self._start_time: datetime = datetime.now()
self._end_time: datetime = self._start_time + timedelta(
seconds=self._failure_detection_window
)
self._failures_count: int = 0
self._lock = threading.RLock()
def register_failure(self, exception: Exception, cmd: tuple) -> None:
with self._lock:
if self._error_types:
if type(exception) in self._error_types:
self._failures_count += 1
else:
self._failures_count += 1
self._check_threshold()
def set_command_executor(self, command_executor) -> None:
self._command_executor = command_executor
def register_command_execution(self, cmd: tuple) -> None:
with self._lock:
if not self._start_time < datetime.now() < self._end_time:
self._reset()
self._commands_executed += 1
def _check_threshold(self):
if self._failures_count >= self._min_num_failures and self._failures_count >= (
math.ceil(self._commands_executed * self._failure_rate_threshold)
):
self._command_executor.active_database.circuit.state = CBState.OPEN
self._reset()
def _reset(self) -> None:
with self._lock:
self._start_time = datetime.now()
self._end_time = self._start_time + timedelta(
seconds=self._failure_detection_window
)
self._failures_count = 0
self._commands_executed = 0
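The detector above opens the active database's circuit only when two conditions hold at once: at least min_num_failures failures have occurred, and the failure count is at least failure_rate_threshold of the commands executed in the current window. A small numeric illustration (the traffic numbers and the 0.25 rate are illustrative; 1000 is the default minimum):

import math

min_num_failures = 1000
failure_rate_threshold = 0.25
commands_executed = 12_000

# Failures needed in the window before the circuit is opened:
needed = max(min_num_failures, math.ceil(commands_executed * failure_rate_threshold))
print(needed)   # 3000: the rate threshold dominates once traffic is high enough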

View File

@@ -0,0 +1,282 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from time import sleep
from typing import List, Optional, Tuple, Union
from redis import Redis
from redis.backoff import NoBackoff
from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
from redis.multidb.exception import UnhealthyDatabaseException
from redis.retry import Retry
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5
DEFAULT_LAG_AWARE_TOLERANCE = 5000
logger = logging.getLogger(__name__)
class HealthCheck(ABC):
@abstractmethod
def check_health(self, database) -> bool:
"""Function to determine the health status."""
pass
class HealthCheckPolicy(ABC):
"""
Health checks execution policy.
"""
@property
@abstractmethod
def health_check_probes(self) -> int:
"""Number of probes to execute health checks."""
pass
@property
@abstractmethod
def health_check_delay(self) -> float:
"""Delay between health check probes."""
pass
@abstractmethod
def execute(self, health_checks: List[HealthCheck], database) -> bool:
"""Execute health checks and return database health status."""
pass
class AbstractHealthCheckPolicy(HealthCheckPolicy):
def __init__(self, health_check_probes: int, health_check_delay: float):
if health_check_probes < 1:
raise ValueError("health_check_probes must be greater than 0")
self._health_check_probes = health_check_probes
self._health_check_delay = health_check_delay
@property
def health_check_probes(self) -> int:
return self._health_check_probes
@property
def health_check_delay(self) -> float:
return self._health_check_delay
@abstractmethod
def execute(self, health_checks: List[HealthCheck], database) -> bool:
pass
class HealthyAllPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if all health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
for attempt in range(self.health_check_probes):
try:
if not health_check.check_health(database):
return False
except Exception as e:
raise UnhealthyDatabaseException("Unhealthy database", database, e)
if attempt < self.health_check_probes - 1:
sleep(self._health_check_delay)
return True
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if a majority of health check probes are successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
def execute(self, health_checks: List[HealthCheck], database) -> bool:
for health_check in health_checks:
if self.health_check_probes % 2 == 0:
allowed_unsuccessful_probes = self.health_check_probes / 2
else:
allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2
for attempt in range(self.health_check_probes):
try:
if not health_check.check_health(database):
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
return False
except Exception as e:
allowed_unsuccessful_probes -= 1
if allowed_unsuccessful_probes <= 0:
raise UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
sleep(self._health_check_delay)
return True
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
"""
Policy that returns True if at least one health check probe is successful.
"""
def __init__(self, health_check_probes: int, health_check_delay: float):
super().__init__(health_check_probes, health_check_delay)
def execute(self, health_checks: List[HealthCheck], database) -> bool:
is_healthy = False
for health_check in health_checks:
exception = None
for attempt in range(self.health_check_probes):
try:
if health_check.check_health(database):
is_healthy = True
break
else:
is_healthy = False
except Exception as e:
exception = UnhealthyDatabaseException(
"Unhealthy database", database, e
)
if attempt < self.health_check_probes - 1:
sleep(self._health_check_delay)
if not is_healthy and not exception:
return is_healthy
elif not is_healthy and exception:
raise exception
return is_healthy
class HealthCheckPolicies(Enum):
HEALTHY_ALL = HealthyAllPolicy
HEALTHY_MAJORITY = HealthyMajorityPolicy
HEALTHY_ANY = HealthyAnyPolicy
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
class PingHealthCheck(HealthCheck):
"""
Health check based on PING command.
"""
def check_health(self, database) -> bool:
if isinstance(database.client, Redis):
return database.client.execute_command("PING")
else:
# For a cluster, check that all nodes are healthy.
all_nodes = database.client.get_nodes()
for node in all_nodes:
if not node.redis_connection.execute_command("PING"):
return False
return True
class LagAwareHealthCheck(HealthCheck):
"""
Health check available for Redis Enterprise deployments.
Verifies via the REST API that the database is healthy, based on replication lag.
"""
def __init__(
self,
rest_api_port: int = 9443,
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
timeout: float = DEFAULT_TIMEOUT,
auth_basic: Optional[Tuple[str, str]] = None,
verify_tls: bool = True,
# TLS verification (server) options
ca_file: Optional[str] = None,
ca_path: Optional[str] = None,
ca_data: Optional[Union[str, bytes]] = None,
# Mutual TLS (client cert) options
client_cert_file: Optional[str] = None,
client_key_file: Optional[str] = None,
client_key_password: Optional[str] = None,
):
"""
Initialize LagAwareHealthCheck with the specified parameters.
Args:
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
lag_aware_tolerance: Tolerance in lag between databases, in milliseconds (default: DEFAULT_LAG_AWARE_TOLERANCE)
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
auth_basic: Tuple of (username, password) for basic authentication
verify_tls: Whether to verify TLS certificates (default: True)
ca_file: Path to CA certificate file for TLS verification
ca_path: Path to CA certificates directory for TLS verification
ca_data: CA certificate data as string or bytes
client_cert_file: Path to client certificate file for mutual TLS
client_key_file: Path to client private key file for mutual TLS
client_key_password: Password for encrypted client private key
"""
self._http_client = HttpClient(
timeout=timeout,
auth_basic=auth_basic,
retry=Retry(NoBackoff(), retries=0),
verify_tls=verify_tls,
ca_file=ca_file,
ca_path=ca_path,
ca_data=ca_data,
client_cert_file=client_cert_file,
client_key_file=client_key_file,
client_key_password=client_key_password,
)
self._rest_api_port = rest_api_port
self._lag_aware_tolerance = lag_aware_tolerance
def check_health(self, database) -> bool:
if database.health_check_url is None:
raise ValueError(
"Database health check url is not set. Please check DatabaseConfig for the current database."
)
if isinstance(database.client, Redis):
db_host = database.client.get_connection_kwargs()["host"]
else:
db_host = database.client.startup_nodes[0].host
base_url = f"{database.health_check_url}:{self._rest_api_port}"
self._http_client.base_url = base_url
# Find bdb matching to the current database host
matching_bdb = None
for bdb in self._http_client.get("/v1/bdbs"):
for endpoint in bdb["endpoints"]:
if endpoint["dns_name"] == db_host:
matching_bdb = bdb
break
# In case the host was set as a public IP
for addr in endpoint["addr"]:
if addr == db_host:
matching_bdb = bdb
break
if matching_bdb is None:
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
raise ValueError("Could not find a matching bdb")
url = (
f"/v1/bdbs/{matching_bdb['uid']}/availability"
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
)
self._http_client.get(url, expect_json=False)
# Status checked in an http client, otherwise HttpError will be raised
return True
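A sketch of adding the Redis Enterprise lag-aware check above next to the default PING check (the endpoint, port, credentials and certificate path are illustrative):

from redis.multidb.healthcheck import LagAwareHealthCheck

lag_check = LagAwareHealthCheck(
    rest_api_port=9443,
    lag_aware_tolerance=5000,          # allowed lag between databases, in ms
    auth_basic=("monitor", "secret"),
    ca_file="/etc/ssl/certs/redis-enterprise-ca.pem",
)
# Pass it as MultiDbConfig(health_checks=[lag_check], ...) or attach it to a
# running client with client.add_health_check(lag_check). Each database then
# needs DatabaseConfig.health_check_url set to its cluster FQDN.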

View File

@@ -0,0 +1,308 @@
import base64
import datetime
import ssl
from urllib.parse import urljoin, urlparse
import cryptography.hazmat.primitives.hashes
import requests
from cryptography import hazmat, x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat import backends
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.hashes import SHA1, Hash
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.x509 import ocsp
from redis.exceptions import AuthorizationError, ConnectionError
def _verify_response(issuer_cert, ocsp_response):
pubkey = issuer_cert.public_key()
try:
if isinstance(pubkey, RSAPublicKey):
pubkey.verify(
ocsp_response.signature,
ocsp_response.tbs_response_bytes,
PKCS1v15(),
ocsp_response.signature_hash_algorithm,
)
elif isinstance(pubkey, DSAPublicKey):
pubkey.verify(
ocsp_response.signature,
ocsp_response.tbs_response_bytes,
ocsp_response.signature_hash_algorithm,
)
elif isinstance(pubkey, EllipticCurvePublicKey):
pubkey.verify(
ocsp_response.signature,
ocsp_response.tbs_response_bytes,
ECDSA(ocsp_response.signature_hash_algorithm),
)
else:
pubkey.verify(ocsp_response.signature, ocsp_response.tbs_response_bytes)
except InvalidSignature:
raise ConnectionError("failed to valid ocsp response")
def _check_certificate(issuer_cert, ocsp_bytes, validate=True):
"""A wrapper the return the validity of a known ocsp certificate"""
ocsp_response = ocsp.load_der_ocsp_response(ocsp_bytes)
if ocsp_response.response_status == ocsp.OCSPResponseStatus.UNAUTHORIZED:
raise AuthorizationError("you are not authorized to view this ocsp certificate")
if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL:
if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD:
raise ConnectionError(
f"Received an {str(ocsp_response.certificate_status).split('.')[1]} "
"ocsp certificate status"
)
else:
raise ConnectionError(
"failed to retrieve a successful response from the ocsp responder"
)
if ocsp_response.this_update >= datetime.datetime.now():
raise ConnectionError("ocsp certificate was issued in the future")
if (
ocsp_response.next_update
and ocsp_response.next_update < datetime.datetime.now()
):
raise ConnectionError("ocsp certificate has invalid update - in the past")
responder_name = ocsp_response.responder_name
issuer_hash = ocsp_response.issuer_key_hash
responder_hash = ocsp_response.responder_key_hash
cert_to_validate = issuer_cert
if (
responder_name is not None
and responder_name == issuer_cert.subject
or responder_hash == issuer_hash
):
cert_to_validate = issuer_cert
else:
certs = ocsp_response.certificates
responder_certs = _get_certificates(
certs, issuer_cert, responder_name, responder_hash
)
try:
responder_cert = responder_certs[0]
except IndexError:
raise ConnectionError("no certificates found for the responder")
ext = responder_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage)
if ext is None or x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING not in ext.value:
raise ConnectionError("delegate not autorized for ocsp signing")
cert_to_validate = responder_cert
if validate:
_verify_response(cert_to_validate, ocsp_response)
return True
def _get_certificates(certs, issuer_cert, responder_name, responder_hash):
if responder_name is None:
certificates = [
c
for c in certs
if _get_pubkey_hash(c) == responder_hash and c.issuer == issuer_cert.subject
]
else:
certificates = [
c
for c in certs
if c.subject == responder_name and c.issuer == issuer_cert.subject
]
return certificates
def _get_pubkey_hash(certificate):
pubkey = certificate.public_key()
# https://stackoverflow.com/a/46309453/600498
if isinstance(pubkey, RSAPublicKey):
h = pubkey.public_bytes(Encoding.DER, PublicFormat.PKCS1)
elif isinstance(pubkey, EllipticCurvePublicKey):
h = pubkey.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
else:
h = pubkey.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
sha1 = Hash(SHA1(), backend=backends.default_backend())
sha1.update(h)
return sha1.finalize()
def ocsp_staple_verifier(con, ocsp_bytes, expected=None):
"""An implementation of a function for set_ocsp_client_callback in PyOpenSSL.
This function validates that the provided ocsp_bytes response is valid,
and matches the expected, stapled responses.
"""
if ocsp_bytes in [b"", None]:
raise ConnectionError("no ocsp response present")
issuer_cert = None
peer_cert = con.get_peer_certificate().to_cryptography()
for c in con.get_peer_cert_chain():
cert = c.to_cryptography()
if cert.subject == peer_cert.issuer:
issuer_cert = cert
break
if issuer_cert is None:
raise ConnectionError("no matching issuer cert found in certificate chain")
if expected is not None:
e = x509.load_pem_x509_certificate(expected)
if peer_cert != e:
raise ConnectionError("received and expected certificates do not match")
return _check_certificate(issuer_cert, ocsp_bytes)
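# --- Illustrative example (not part of the original file) -----------------
# A minimal sketch of wiring ocsp_staple_verifier into a PyOpenSSL client
# handshake. The host, port, and helper name below are assumptions made for
# illustration only and are not defined elsewhere in this library.
def _example_stapled_handshake(host="redis.example.com", port=6380):
    import socket

    from OpenSSL import SSL

    ctx = SSL.Context(SSL.TLS_CLIENT_METHOD)
    # PyOpenSSL invokes the callback as callback(connection, ocsp_bytes, data);
    # the optional data argument maps onto the ``expected`` parameter above.
    ctx.set_ocsp_client_callback(ocsp_staple_verifier, data=None)
    conn = SSL.Connection(ctx, socket.create_connection((host, port)))
    conn.set_connect_state()
    conn.set_tlsext_host_name(host.encode())
    conn.request_ocsp()
    conn.do_handshake()  # ocsp_staple_verifier runs during the handshake
    return conn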
class OCSPVerifier:
"""A class to verify ssl sockets for RFC6960/RFC6961. This can be used
when using direct validation of OCSP responses and certificate revocations.
@see https://datatracker.ietf.org/doc/html/rfc6960
@see https://datatracker.ietf.org/doc/html/rfc6961
"""
def __init__(self, sock, host, port, ca_certs=None):
self.SOCK = sock
self.HOST = host
self.PORT = port
self.CA_CERTS = ca_certs
def _bin2ascii(self, der):
"""Convert SSL certificates in a binary (DER) format to ASCII PEM."""
pem = ssl.DER_cert_to_PEM_cert(der)
cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend())
return cert
def components_from_socket(self):
"""This function returns the certificate, primary issuer, and primary ocsp
server in the chain for a socket already wrapped with ssl.
"""
# convert the binary certificate to text
der = self.SOCK.getpeercert(True)
if not der:
raise ConnectionError("no certificate found for ssl peer")
cert = self._bin2ascii(der)
return self._certificate_components(cert)
def _certificate_components(self, cert):
"""Given an SSL certificate, retract the useful components for
validating the certificate status with an OCSP server.
Args:
cert (x509.Certificate): An ssl certificate parsed from PEM data
"""
try:
aia = cert.extensions.get_extension_for_oid(
x509.oid.ExtensionOID.AUTHORITY_INFORMATION_ACCESS
).value
except cryptography.x509.extensions.ExtensionNotFound:
raise ConnectionError("No AIA information present in ssl certificate")
# fetch certificate issuers
issuers = [
i
for i in aia
if i.access_method == x509.oid.AuthorityInformationAccessOID.CA_ISSUERS
]
try:
issuer = issuers[0].access_location.value
except IndexError:
issuer = None
# now, the series of ocsp server entries
ocsps = [
i
for i in aia
if i.access_method == x509.oid.AuthorityInformationAccessOID.OCSP
]
try:
ocsp = ocsps[0].access_location.value
except IndexError:
raise ConnectionError("no ocsp servers in certificate")
return cert, issuer, ocsp
def components_from_direct_connection(self):
"""Return the certificate, primary issuer, and primary ocsp server
from the host defined by the socket. This is useful in cases where
different certificates are occasionally presented.
"""
pem = ssl.get_server_certificate((self.HOST, self.PORT), ca_certs=self.CA_CERTS)
cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend())
return self._certificate_components(cert)
def build_certificate_url(self, server, cert, issuer_cert):
"""Return the complete url to the ocsp"""
orb = ocsp.OCSPRequestBuilder()
# add_certificate returns an initialized OCSPRequestBuilder
orb = orb.add_certificate(
cert, issuer_cert, cryptography.hazmat.primitives.hashes.SHA256()
)
request = orb.build()
path = base64.b64encode(
request.public_bytes(hazmat.primitives.serialization.Encoding.DER)
)
url = urljoin(server, path.decode("ascii"))
return url
def check_certificate(self, server, cert, issuer_url):
"""Checks the validity of an ocsp server for an issuer"""
r = requests.get(issuer_url)
if not r.ok:
raise ConnectionError("failed to fetch issuer certificate")
der = r.content
issuer_cert = self._bin2ascii(der)
ocsp_url = self.build_certificate_url(server, cert, issuer_cert)
# HTTP 1.1 mandates the addition of the Host header in ocsp requests
header = {
"Host": urlparse(ocsp_url).netloc,
"Content-Type": "application/ocsp-request",
}
r = requests.get(ocsp_url, headers=header)
if not r.ok:
raise ConnectionError("failed to fetch ocsp certificate")
return _check_certificate(issuer_cert, r.content, True)
def is_valid(self):
"""Returns the validity of the certificate wrapping our socket.
This first retrieves the certificate, issuer_url, and ocsp_server
needed for validation, then fetches the issuer certificate from the
issuer_url, and finally checks the OCSP revocation status.
"""
# validate the certificate
try:
cert, issuer_url, ocsp_server = self.components_from_socket()
if issuer_url is None:
raise ConnectionError("no issuers found in certificate chain")
return self.check_certificate(ocsp_server, cert, issuer_url)
except AuthorizationError:
cert, issuer_url, ocsp_server = self.components_from_direct_connection()
if issuer_url is None:
raise ConnectionError("no issuers found in certificate chain")
return self.check_certificate(ocsp_server, cert, issuer_url)
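# --- Illustrative example (not part of the original file) -----------------
# A minimal sketch of checking revocation status for an already-established
# TLS socket with OCSPVerifier. The hostname, port, and CA bundle path are
# assumptions made for illustration only.
def _example_ocsp_check(host="redis.example.com", port=6380, ca="/path/to/ca.pem"):
    import socket

    ctx = ssl.create_default_context(cafile=ca)
    with socket.create_connection((host, port)) as raw_sock:
        with ctx.wrap_socket(raw_sock, server_hostname=host) as tls_sock:
            verifier = OCSPVerifier(tls_sock, host, port, ca_certs=ca)
            # is_valid() returns True, or raises ConnectionError on failure
            return verifier.is_valid()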

View File

@@ -0,0 +1,126 @@
import abc
import socket
from time import sleep
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Optional,
Tuple,
Type,
TypeVar,
)
from redis.exceptions import ConnectionError, TimeoutError
T = TypeVar("T")
E = TypeVar("E", bound=Exception, covariant=True)
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class AbstractRetry(Generic[E], abc.ABC):
"""Retry a specific number of times after a failure"""
_supported_errors: Tuple[Type[E], ...]
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[E], ...],
):
"""
Initialize a `Retry` object with a `Backoff` object
that retries a maximum of `retries` times.
`retries` can be negative to retry forever.
You can specify the types of supported errors which trigger
a retry with the `supported_errors` parameter.
"""
self._backoff = backoff
self._retries = retries
self._supported_errors = supported_errors
@abc.abstractmethod
def __eq__(self, other: Any) -> bool:
return NotImplemented
def __hash__(self) -> int:
return hash((self._backoff, self._retries, frozenset(self._supported_errors)))
def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None:
"""
Updates the supported errors with the specified error types
"""
self._supported_errors = tuple(
set(self._supported_errors + tuple(specified_errors))
)
def get_retries(self) -> int:
"""
Get the number of retries.
"""
return self._retries
def update_retries(self, value: int) -> None:
"""
Set the number of retries.
"""
self._retries = value
class Retry(AbstractRetry[Exception]):
__hash__ = AbstractRetry.__hash__
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[Exception], ...] = (
ConnectionError,
TimeoutError,
socket.timeout,
),
):
super().__init__(backoff, retries, supported_errors)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)
def call_with_retry(
self,
do: Callable[[], T],
fail: Callable[[Exception], Any],
is_retryable: Optional[Callable[[Exception], bool]] = None,
) -> T:
"""
Execute an operation that might fail and return its result, or
raise the last exception once the configured retries are exhausted.
`do`: the operation to call. Expects no argument.
`fail`: the failure handler, expects the last error that was thrown
"""
self._backoff.reset()
failures = 0
while True:
try:
return do()
except self._supported_errors as error:
if is_retryable and not is_retryable(error):
raise
failures += 1
fail(error)
if self._retries >= 0 and failures > self._retries:
raise error
backoff = self._backoff.compute(failures)
if backoff > 0:
sleep(backoff)
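# --- Illustrative example (not part of the original file) -----------------
# A minimal sketch of driving Retry directly with the library's default
# backoff policy. The flaky_call and on_failure callables below are
# hypothetical stand-ins for an operation that may raise ConnectionError or
# TimeoutError and for a per-failure hook (e.g. logging or reconnecting).
def _example_retry_usage():
    from redis.backoff import default_backoff

    retry = Retry(default_backoff(), retries=3)

    def flaky_call():
        return "ok"  # e.g. a socket read that may raise ConnectionError

    def on_failure(error):
        print(f"attempt failed: {error!r}")  # called after each failure

    return retry.call_with_retry(flaky_call, on_failure)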

View File

@@ -0,0 +1,425 @@
import random
import weakref
from typing import Optional
from redis.client import Redis
from redis.commands import SentinelCommands
from redis.connection import Connection, ConnectionPool, SSLConnection
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
class MasterNotFoundError(ConnectionError):
pass
class SlaveNotFoundError(ConnectionError):
pass
class SentinelManagedConnection(Connection):
def __init__(self, **kwargs):
self.connection_pool = kwargs.pop("connection_pool")
super().__init__(**kwargs)
def __repr__(self):
pool = self.connection_pool
s = (
f"<{type(self).__module__}.{type(self).__name__}"
f"(service={pool.service_name}%s)>"
)
if self.host:
host_info = f",host={self.host},port={self.port}"
s = s % host_info
return s
def connect_to(self, address):
self.host, self.port = address
self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
def _connect_retry(self):
if self._sock:
return # already connected
if self.connection_pool.is_master:
self.connect_to(self.connection_pool.get_master_address())
else:
for slave in self.connection_pool.rotate_slaves():
try:
return self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError  # should not be reachable
def connect(self):
return self.retry.call_with_retry(self._connect_retry, lambda error: None)
def read_response(
self,
disable_decoding=False,
*,
disconnect_on_error: Optional[bool] = False,
push_request: Optional[bool] = False,
):
try:
return super().read_response(
disable_decoding=disable_decoding,
disconnect_on_error=disconnect_on_error,
push_request=push_request,
)
except ReadOnlyError:
if self.connection_pool.is_master:
# When talking to a master, a ReadOnlyError most likely
# indicates that the previous master that we're still connected
# to has been demoted to a slave and there's a new master.
# calling disconnect will force the connection to re-query
# sentinel during the next connect() attempt.
self.disconnect()
raise ConnectionError("The previous master is now a slave")
raise
class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
pass
class SentinelConnectionPoolProxy:
def __init__(
self,
connection_pool,
is_master,
check_connection,
service_name,
sentinel_manager,
):
self.connection_pool_ref = weakref.ref(connection_pool)
self.is_master = is_master
self.check_connection = check_connection
self.service_name = service_name
self.sentinel_manager = sentinel_manager
self.reset()
def reset(self):
self.master_address = None
self.slave_rr_counter = None
def get_master_address(self):
master_address = self.sentinel_manager.discover_master(self.service_name)
if self.is_master and self.master_address != master_address:
self.master_address = master_address
# disconnect any idle connections so that they reconnect
# to the new master the next time that they are used.
connection_pool = self.connection_pool_ref()
if connection_pool is not None:
connection_pool.disconnect(inuse_connections=False)
return master_address
def rotate_slaves(self):
slaves = self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in range(len(slaves)):
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
class SentinelConnectionPool(ConnectionPool):
"""
Sentinel backed connection pool.
If ``check_connection`` flag is set to True, SentinelManagedConnection
sends a PING command right after establishing the connection.
"""
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
self.proxy = SentinelConnectionPoolProxy(
connection_pool=self,
is_master=self.is_master,
check_connection=self.check_connection,
service_name=service_name,
sentinel_manager=sentinel_manager,
)
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = self.proxy
self.service_name = service_name
self.sentinel_manager = sentinel_manager
def __repr__(self):
role = "master" if self.is_master else "slave"
return (
f"<{type(self).__module__}.{type(self).__name__}"
f"(service={self.service_name}({role}))>"
)
def reset(self):
super().reset()
self.proxy.reset()
@property
def master_address(self):
return self.proxy.master_address
def owns_connection(self, connection):
check = not self.is_master or (
self.is_master and self.master_address == (connection.host, connection.port)
)
parent = super()
return check and parent.owns_connection(connection)
def get_master_address(self):
return self.proxy.get_master_address()
def rotate_slaves(self):
"Round-robin slave balancer"
return self.proxy.rotate_slaves()
class Sentinel(SentinelCommands):
"""
Redis Sentinel cluster client
>>> from redis.sentinel import Sentinel
>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
>>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
>>> master.set('foo', 'bar')
>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
>>> slave.get('foo')
b'bar'
``sentinels`` is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
``min_other_sentinels`` defines a minimum number of peers for a sentinel.
When querying a sentinel, if it doesn't meet this threshold, responses
from that sentinel won't be considered valid.
``sentinel_kwargs`` is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
not specified, any socket_* options (such as socket_timeout and
socket_keepalive) specified in ``connection_kwargs`` will be used.
``connection_kwargs`` are keyword arguments that will be used when
establishing a connection to a Redis server.
"""
def __init__(
self,
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
# connection_kwargs
if sentinel_kwargs is None:
sentinel_kwargs = {
k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
}
self.sentinel_kwargs = sentinel_kwargs
self.sentinels = [
Redis(hostname, port, **self.sentinel_kwargs)
for hostname, port in sentinels
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
def execute_command(self, *args, **kwargs):
"""
Execute Sentinel command in sentinel nodes.
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
if once:
response = random.choice(self.sentinels).execute_command(*args, **kwargs)
if return_responses:
return [response]
else:
return True if response else False
responses = []
for sentinel in self.sentinels:
responses.append(sentinel.execute_command(*args, **kwargs))
if return_responses:
return responses
return all(responses)
def __repr__(self):
sentinel_addresses = []
for sentinel in self.sentinels:
sentinel_addresses.append(
"{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs)
)
return (
f"<{type(self).__module__}.{type(self).__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
def check_master_state(self, state, service_name):
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
return False
# Check that our sentinel sees enough other sentinels
if state["num-other-sentinels"] < self.min_other_sentinels:
return False
return True
def discover_master(self, service_name):
"""
Asks sentinel servers for the Redis master's address corresponding
to the service labeled ``service_name``.
Returns a pair (address, port) or raises MasterNotFoundError if no
master is found.
"""
collected_errors = list()
for sentinel_no, sentinel in enumerate(self.sentinels):
try:
masters = sentinel.sentinel_masters()
except (ConnectionError, TimeoutError) as e:
collected_errors.append(f"{sentinel} - {e!r}")
continue
state = masters.get(service_name)
if state and self.check_master_state(state, service_name):
# Put this sentinel at the top of the list
self.sentinels[0], self.sentinels[sentinel_no] = (
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
error_info = ""
if len(collected_errors) > 0:
error_info = f" : {', '.join(collected_errors)}"
raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")
def filter_slaves(self, slaves):
"Remove slaves that are in an ODOWN or SDOWN state"
slaves_alive = []
for slave in slaves:
if slave["is_odown"] or slave["is_sdown"]:
continue
slaves_alive.append((slave["ip"], slave["port"]))
return slaves_alive
def discover_slaves(self, service_name):
"Returns a list of alive slaves for service ``service_name``"
for sentinel in self.sentinels:
try:
slaves = sentinel.sentinel_slaves(service_name)
except (ConnectionError, ResponseError, TimeoutError):
continue
slaves = self.filter_slaves(slaves)
if slaves:
return slaves
return []
def master_for(
self,
service_name,
redis_class=Redis,
connection_pool_class=SentinelConnectionPool,
**kwargs,
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new
connection.
NOTE: If the master's address has changed, any cached connections to
the old master are closed.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to
use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = True
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
return redis_class.from_pool(
connection_pool_class(service_name, self, **connection_kwargs)
)
def slave_for(
self,
service_name,
redis_class=Redis,
connection_pool_class=SentinelConnectionPool,
**kwargs,
):
"""
Returns a redis client instance for the ``service_name`` slave(s).
A SentinelConnectionPool class is used to retrieve the slave's
address before establishing a new connection.
By default clients will be a :py:class:`~redis.Redis` instance.
Specify a different class to the ``redis_class`` argument if you
desire something different.
The ``connection_pool_class`` specifies the connection pool to use.
The SentinelConnectionPool will be used by default.
All other keyword arguments are merged with any connection_kwargs
passed to this class and passed to the connection pool as keyword
arguments to be used to initialize Redis connections.
"""
kwargs["is_master"] = False
connection_kwargs = dict(self.connection_kwargs)
connection_kwargs.update(kwargs)
return redis_class.from_pool(
connection_pool_class(service_name, self, **connection_kwargs)
)
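# --- Illustrative example (not part of the original file) -----------------
# A minimal sketch showing how keyword arguments passed to master_for /
# slave_for are merged on top of the connection_kwargs given to Sentinel.
# The sentinel addresses and service name are assumptions for illustration.
def _example_sentinel_usage():
    sentinel = Sentinel(
        [("sentinel-1.example.com", 26379), ("sentinel-2.example.com", 26379)],
        min_other_sentinels=1,
        socket_timeout=0.5,  # socket_* options are reused for sentinel_kwargs
    )
    # db=1 is merged with socket_timeout=0.5 from the constructor above
    master = sentinel.master_for("mymaster", db=1)
    replica = sentinel.slave_for("mymaster", db=1)
    master.set("foo", "bar")
    return replica.get("foo")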

View File

@@ -0,0 +1,57 @@
# from __future__ import annotations
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Iterable,
Mapping,
Protocol,
Type,
TypeVar,
Union,
)
if TYPE_CHECKING:
from redis._parsers import Encoder
Number = Union[int, float]
EncodedT = Union[bytes, bytearray, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
AbsExpiryT = Union[int, datetime]
ExpiryT = Union[int, timedelta]
ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix
BitfieldOffsetT = Union[int, str] # str allows for #x syntax
_StringLikeT = Union[bytes, str, memoryview]
KeyT = _StringLikeT # Main redis key space
PatternT = _StringLikeT # Patterns matched against keys, fields etc
FieldT = EncodableT # Fields within hash tables, streams and geo commands
KeysT = Union[KeyT, Iterable[KeyT]]
ResponseT = Union[Awaitable[Any], Any]
ChannelT = _StringLikeT
GroupT = _StringLikeT # Consumer group
ConsumerT = _StringLikeT # Consumer name
StreamIdT = Union[int, _StringLikeT]
ScriptTextT = _StringLikeT
TimeoutSecT = Union[int, float, _StringLikeT]
# Mapping is not covariant in the key type, which prevents
# Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using
# a TypeVar instead of a Union allows mappings with any of the permitted types
# to be passed. Care is needed if there is more than one such mapping in a
# type signature because they will all be required to be the same key type.
AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview)
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview)
ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]]
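# --- Illustrative example (not part of the original file) -----------------
# A minimal sketch of why AnyKeyT is a TypeVar rather than a Union: a helper
# annotated with Mapping[AnyKeyT, EncodableT] accepts Dict[str, ...] and
# Dict[bytes, ...] alike, which Mapping[_StringLikeT, ...] would reject.
# The function name below is hypothetical and only illustrates the annotation.
def _count_field_value_pairs(mapping: Mapping[AnyKeyT, EncodableT]) -> int:
    # e.g. the field/value mapping handed to an HSET-style command
    return len(mapping)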
class CommandsProtocol(Protocol):
def execute_command(self, *args, **options) -> ResponseT: ...
class ClusterCommandsProtocol(CommandsProtocol):
encoder: "Encoder"

Some files were not shown because too many files have changed in this diff.