first commit
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
from redis.asyncio.client import Redis, StrictRedis
|
||||
from redis.asyncio.cluster import RedisCluster
|
||||
from redis.asyncio.connection import (
|
||||
BlockingConnectionPool,
|
||||
Connection,
|
||||
ConnectionPool,
|
||||
SSLConnection,
|
||||
UnixDomainSocketConnection,
|
||||
)
|
||||
from redis.asyncio.sentinel import (
|
||||
Sentinel,
|
||||
SentinelConnectionPool,
|
||||
SentinelManagedConnection,
|
||||
SentinelManagedSSLConnection,
|
||||
)
|
||||
from redis.asyncio.utils import from_url
|
||||
from redis.backoff import default_backoff
|
||||
from redis.exceptions import (
|
||||
AuthenticationError,
|
||||
AuthenticationWrongNumberOfArgsError,
|
||||
BusyLoadingError,
|
||||
ChildDeadlockedError,
|
||||
ConnectionError,
|
||||
DataError,
|
||||
InvalidResponse,
|
||||
OutOfMemoryError,
|
||||
PubSubError,
|
||||
ReadOnlyError,
|
||||
RedisError,
|
||||
ResponseError,
|
||||
TimeoutError,
|
||||
WatchError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuthenticationError",
|
||||
"AuthenticationWrongNumberOfArgsError",
|
||||
"BlockingConnectionPool",
|
||||
"BusyLoadingError",
|
||||
"ChildDeadlockedError",
|
||||
"Connection",
|
||||
"ConnectionError",
|
||||
"ConnectionPool",
|
||||
"DataError",
|
||||
"from_url",
|
||||
"default_backoff",
|
||||
"InvalidResponse",
|
||||
"PubSubError",
|
||||
"OutOfMemoryError",
|
||||
"ReadOnlyError",
|
||||
"Redis",
|
||||
"RedisCluster",
|
||||
"RedisError",
|
||||
"ResponseError",
|
||||
"Sentinel",
|
||||
"SentinelConnectionPool",
|
||||
"SentinelManagedConnection",
|
||||
"SentinelManagedSSLConnection",
|
||||
"SSLConnection",
|
||||
"StrictRedis",
|
||||
"TimeoutError",
|
||||
"UnixDomainSocketConnection",
|
||||
"WatchError",
|
||||
]
|
||||
1675
backend/venv/lib/python3.9/site-packages/redis/asyncio/client.py
Normal file
1675
backend/venv/lib/python3.9/site-packages/redis/asyncio/client.py
Normal file
File diff suppressed because it is too large
Load Diff
2448
backend/venv/lib/python3.9/site-packages/redis/asyncio/cluster.py
Normal file
2448
backend/venv/lib/python3.9/site-packages/redis/asyncio/cluster.py
Normal file
File diff suppressed because it is too large
Load Diff
1399
backend/venv/lib/python3.9/site-packages/redis/asyncio/connection.py
Normal file
1399
backend/venv/lib/python3.9/site-packages/redis/asyncio/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,265 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from redis.http.http_client import HttpClient, HttpResponse
|
||||
|
||||
DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)"
|
||||
DEFAULT_TIMEOUT = 30.0
|
||||
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
class AsyncHTTPClient(ABC):
|
||||
@abstractmethod
|
||||
async def get(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
"""
|
||||
Invoke HTTP GET request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
"""
|
||||
Invoke HTTP DELETE request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def post(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
"""
|
||||
Invoke HTTP POST request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def put(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
"""
|
||||
Invoke HTTP PUT request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def patch(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
"""
|
||||
Invoke HTTP PATCH request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
body: Optional[Union[bytes, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> HttpResponse:
|
||||
"""
|
||||
Invoke HTTP request with given method."""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncHTTPClientWrapper(AsyncHTTPClient):
|
||||
"""
|
||||
An async wrapper around sync HTTP client with thread pool execution.
|
||||
"""
|
||||
|
||||
def __init__(self, client: HttpClient, max_workers: int = 10) -> None:
|
||||
"""
|
||||
Initialize a new HTTP client instance.
|
||||
|
||||
Args:
|
||||
client: Sync HTTP client instance.
|
||||
max_workers: Maximum number of concurrent requests.
|
||||
|
||||
The client supports both regular HTTPS with server verification and mutual TLS
|
||||
authentication. For server verification, provide CA certificate information via
|
||||
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
|
||||
certificate and key via client_cert_file and client_key_file.
|
||||
"""
|
||||
self.client = client
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor, self.client.get, path, params, headers, timeout, expect_json
|
||||
)
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
self.client.delete,
|
||||
path,
|
||||
params,
|
||||
headers,
|
||||
timeout,
|
||||
expect_json,
|
||||
)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
self.client.post,
|
||||
path,
|
||||
json_body,
|
||||
data,
|
||||
params,
|
||||
headers,
|
||||
timeout,
|
||||
expect_json,
|
||||
)
|
||||
|
||||
async def put(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
self.client.put,
|
||||
path,
|
||||
json_body,
|
||||
data,
|
||||
params,
|
||||
headers,
|
||||
timeout,
|
||||
expect_json,
|
||||
)
|
||||
|
||||
async def patch(
|
||||
self,
|
||||
path: str,
|
||||
json_body: Optional[Any] = None,
|
||||
data: Optional[Union[bytes, str]] = None,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
expect_json: bool = True,
|
||||
) -> Union[HttpResponse, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
self.client.patch,
|
||||
path,
|
||||
json_body,
|
||||
data,
|
||||
params,
|
||||
headers,
|
||||
timeout,
|
||||
expect_json,
|
||||
)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[
|
||||
Mapping[str, Union[None, str, int, float, bool, list, tuple]]
|
||||
] = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
body: Optional[Union[bytes, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> HttpResponse:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self._executor,
|
||||
self.client.request,
|
||||
method,
|
||||
path,
|
||||
params,
|
||||
headers,
|
||||
body,
|
||||
timeout,
|
||||
)
|
||||
334
backend/venv/lib/python3.9/site-packages/redis/asyncio/lock.py
Normal file
334
backend/venv/lib/python3.9/site-packages/redis/asyncio/lock.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
||||
|
||||
from redis.exceptions import LockError, LockNotOwnedError
|
||||
from redis.typing import Number
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Lock:
|
||||
"""
|
||||
A shared, distributed Lock. Using Redis for locking allows the Lock
|
||||
to be shared across processes and/or machines.
|
||||
|
||||
It's left to the user to resolve deadlock issues and make sure
|
||||
multiple clients play nicely together.
|
||||
"""
|
||||
|
||||
lua_release = None
|
||||
lua_extend = None
|
||||
lua_reacquire = None
|
||||
|
||||
# KEYS[1] - lock name
|
||||
# ARGV[1] - token
|
||||
# return 1 if the lock was released, otherwise 0
|
||||
LUA_RELEASE_SCRIPT = """
|
||||
local token = redis.call('get', KEYS[1])
|
||||
if not token or token ~= ARGV[1] then
|
||||
return 0
|
||||
end
|
||||
redis.call('del', KEYS[1])
|
||||
return 1
|
||||
"""
|
||||
|
||||
# KEYS[1] - lock name
|
||||
# ARGV[1] - token
|
||||
# ARGV[2] - additional milliseconds
|
||||
# ARGV[3] - "0" if the additional time should be added to the lock's
|
||||
# existing ttl or "1" if the existing ttl should be replaced
|
||||
# return 1 if the locks time was extended, otherwise 0
|
||||
LUA_EXTEND_SCRIPT = """
|
||||
local token = redis.call('get', KEYS[1])
|
||||
if not token or token ~= ARGV[1] then
|
||||
return 0
|
||||
end
|
||||
local expiration = redis.call('pttl', KEYS[1])
|
||||
if not expiration then
|
||||
expiration = 0
|
||||
end
|
||||
if expiration < 0 then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newttl = ARGV[2]
|
||||
if ARGV[3] == "0" then
|
||||
newttl = ARGV[2] + expiration
|
||||
end
|
||||
redis.call('pexpire', KEYS[1], newttl)
|
||||
return 1
|
||||
"""
|
||||
|
||||
# KEYS[1] - lock name
|
||||
# ARGV[1] - token
|
||||
# ARGV[2] - milliseconds
|
||||
# return 1 if the locks time was reacquired, otherwise 0
|
||||
LUA_REACQUIRE_SCRIPT = """
|
||||
local token = redis.call('get', KEYS[1])
|
||||
if not token or token ~= ARGV[1] then
|
||||
return 0
|
||||
end
|
||||
redis.call('pexpire', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: Union["Redis", "RedisCluster"],
|
||||
name: Union[str, bytes, memoryview],
|
||||
timeout: Optional[float] = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
thread_local: bool = True,
|
||||
raise_on_release_error: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a new Lock instance named ``name`` using the Redis client
|
||||
supplied by ``redis``.
|
||||
|
||||
``timeout`` indicates a maximum life for the lock in seconds.
|
||||
By default, it will remain locked until release() is called.
|
||||
``timeout`` can be specified as a float or integer, both representing
|
||||
the number of seconds to wait.
|
||||
|
||||
``sleep`` indicates the amount of time to sleep in seconds per loop
|
||||
iteration when the lock is in blocking mode and another client is
|
||||
currently holding the lock.
|
||||
|
||||
``blocking`` indicates whether calling ``acquire`` should block until
|
||||
the lock has been acquired or to fail immediately, causing ``acquire``
|
||||
to return False and the lock not being acquired. Defaults to True.
|
||||
Note this value can be overridden by passing a ``blocking``
|
||||
argument to ``acquire``.
|
||||
|
||||
``blocking_timeout`` indicates the maximum amount of time in seconds to
|
||||
spend trying to acquire the lock. A value of ``None`` indicates
|
||||
continue trying forever. ``blocking_timeout`` can be specified as a
|
||||
float or integer, both representing the number of seconds to wait.
|
||||
|
||||
``thread_local`` indicates whether the lock token is placed in
|
||||
thread-local storage. By default, the token is placed in thread local
|
||||
storage so that a thread only sees its token, not a token set by
|
||||
another thread. Consider the following timeline:
|
||||
|
||||
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
|
||||
thread-1 sets the token to "abc"
|
||||
time: 1, thread-2 blocks trying to acquire `my-lock` using the
|
||||
Lock instance.
|
||||
time: 5, thread-1 has not yet completed. redis expires the lock
|
||||
key.
|
||||
time: 5, thread-2 acquired `my-lock` now that it's available.
|
||||
thread-2 sets the token to "xyz"
|
||||
time: 6, thread-1 finishes its work and calls release(). if the
|
||||
token is *not* stored in thread local storage, then
|
||||
thread-1 would see the token value as "xyz" and would be
|
||||
able to successfully release the thread-2's lock.
|
||||
|
||||
``raise_on_release_error`` indicates whether to raise an exception when
|
||||
the lock is no longer owned when exiting the context manager. By default,
|
||||
this is True, meaning an exception will be raised. If False, the warning
|
||||
will be logged and the exception will be suppressed.
|
||||
|
||||
In some use cases it's necessary to disable thread local storage. For
|
||||
example, if you have code where one thread acquires a lock and passes
|
||||
that lock instance to a worker thread to release later. If thread
|
||||
local storage isn't disabled in this case, the worker thread won't see
|
||||
the token set by the thread that acquired the lock. Our assumption
|
||||
is that these cases aren't common and as such default to using
|
||||
thread local storage.
|
||||
"""
|
||||
self.redis = redis
|
||||
self.name = name
|
||||
self.timeout = timeout
|
||||
self.sleep = sleep
|
||||
self.blocking = blocking
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self.thread_local = bool(thread_local)
|
||||
self.local = threading.local() if self.thread_local else SimpleNamespace()
|
||||
self.raise_on_release_error = raise_on_release_error
|
||||
self.local.token = None
|
||||
self.register_scripts()
|
||||
|
||||
def register_scripts(self):
|
||||
cls = self.__class__
|
||||
client = self.redis
|
||||
if cls.lua_release is None:
|
||||
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
|
||||
if cls.lua_extend is None:
|
||||
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
|
||||
if cls.lua_reacquire is None:
|
||||
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
|
||||
|
||||
async def __aenter__(self):
|
||||
if await self.acquire():
|
||||
return self
|
||||
raise LockError("Unable to acquire lock within the time specified")
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
try:
|
||||
await self.release()
|
||||
except LockError:
|
||||
if self.raise_on_release_error:
|
||||
raise
|
||||
logger.warning(
|
||||
"Lock was unlocked or no longer owned when exiting context manager."
|
||||
)
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
blocking: Optional[bool] = None,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
token: Optional[Union[str, bytes]] = None,
|
||||
):
|
||||
"""
|
||||
Use Redis to hold a shared, distributed lock named ``name``.
|
||||
Returns True once the lock is acquired.
|
||||
|
||||
If ``blocking`` is False, always return immediately. If the lock
|
||||
was acquired, return True, otherwise return False.
|
||||
|
||||
``blocking_timeout`` specifies the maximum number of seconds to
|
||||
wait trying to acquire the lock.
|
||||
|
||||
``token`` specifies the token value to be used. If provided, token
|
||||
must be a bytes object or a string that can be encoded to a bytes
|
||||
object with the default encoding. If a token isn't specified, a UUID
|
||||
will be generated.
|
||||
"""
|
||||
sleep = self.sleep
|
||||
if token is None:
|
||||
token = uuid.uuid1().hex.encode()
|
||||
else:
|
||||
try:
|
||||
encoder = self.redis.connection_pool.get_encoder()
|
||||
except AttributeError:
|
||||
# Cluster
|
||||
encoder = self.redis.get_encoder()
|
||||
token = encoder.encode(token)
|
||||
if blocking is None:
|
||||
blocking = self.blocking
|
||||
if blocking_timeout is None:
|
||||
blocking_timeout = self.blocking_timeout
|
||||
stop_trying_at = None
|
||||
if blocking_timeout is not None:
|
||||
stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
|
||||
while True:
|
||||
if await self.do_acquire(token):
|
||||
self.local.token = token
|
||||
return True
|
||||
if not blocking:
|
||||
return False
|
||||
next_try_at = asyncio.get_running_loop().time() + sleep
|
||||
if stop_trying_at is not None and next_try_at > stop_trying_at:
|
||||
return False
|
||||
await asyncio.sleep(sleep)
|
||||
|
||||
async def do_acquire(self, token: Union[str, bytes]) -> bool:
|
||||
if self.timeout:
|
||||
# convert to milliseconds
|
||||
timeout = int(self.timeout * 1000)
|
||||
else:
|
||||
timeout = None
|
||||
if await self.redis.set(self.name, token, nx=True, px=timeout):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def locked(self) -> bool:
|
||||
"""
|
||||
Returns True if this key is locked by any process, otherwise False.
|
||||
"""
|
||||
return await self.redis.get(self.name) is not None
|
||||
|
||||
async def owned(self) -> bool:
|
||||
"""
|
||||
Returns True if this key is locked by this lock, otherwise False.
|
||||
"""
|
||||
stored_token = await self.redis.get(self.name)
|
||||
# need to always compare bytes to bytes
|
||||
# TODO: this can be simplified when the context manager is finished
|
||||
if stored_token and not isinstance(stored_token, bytes):
|
||||
try:
|
||||
encoder = self.redis.connection_pool.get_encoder()
|
||||
except AttributeError:
|
||||
# Cluster
|
||||
encoder = self.redis.get_encoder()
|
||||
stored_token = encoder.encode(stored_token)
|
||||
return self.local.token is not None and stored_token == self.local.token
|
||||
|
||||
def release(self) -> Awaitable[None]:
|
||||
"""Releases the already acquired lock"""
|
||||
expected_token = self.local.token
|
||||
if expected_token is None:
|
||||
raise LockError(
|
||||
"Cannot release a lock that's not owned or is already unlocked.",
|
||||
lock_name=self.name,
|
||||
)
|
||||
self.local.token = None
|
||||
return self.do_release(expected_token)
|
||||
|
||||
async def do_release(self, expected_token: bytes) -> None:
|
||||
if not bool(
|
||||
await self.lua_release(
|
||||
keys=[self.name], args=[expected_token], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
|
||||
|
||||
def extend(
|
||||
self, additional_time: Number, replace_ttl: bool = False
|
||||
) -> Awaitable[bool]:
|
||||
"""
|
||||
Adds more time to an already acquired lock.
|
||||
|
||||
``additional_time`` can be specified as an integer or a float, both
|
||||
representing the number of seconds to add.
|
||||
|
||||
``replace_ttl`` if False (the default), add `additional_time` to
|
||||
the lock's existing ttl. If True, replace the lock's ttl with
|
||||
`additional_time`.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot extend an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot extend a lock with no timeout")
|
||||
return self.do_extend(additional_time, replace_ttl)
|
||||
|
||||
async def do_extend(self, additional_time, replace_ttl) -> bool:
|
||||
additional_time = int(additional_time * 1000)
|
||||
if not bool(
|
||||
await self.lua_extend(
|
||||
keys=[self.name],
|
||||
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
|
||||
client=self.redis,
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
|
||||
return True
|
||||
|
||||
def reacquire(self) -> Awaitable[bool]:
|
||||
"""
|
||||
Resets a TTL of an already acquired lock back to a timeout value.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot reacquire an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot reacquire a lock with no timeout")
|
||||
return self.do_reacquire()
|
||||
|
||||
async def do_reacquire(self) -> bool:
|
||||
timeout = int(self.timeout * 1000)
|
||||
if not bool(
|
||||
await self.lua_reacquire(
|
||||
keys=[self.name], args=[self.local.token, timeout], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
|
||||
return True
|
||||
@@ -0,0 +1,530 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
|
||||
|
||||
from redis.asyncio.client import PubSubHandler
|
||||
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
|
||||
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Databases
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
|
||||
from redis.background import BackgroundScheduler
|
||||
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
|
||||
from redis.multidb.circuit import CircuitBreaker
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
|
||||
from redis.typing import ChannelT, EncodableT, KeyT
|
||||
from redis.utils import experimental
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@experimental
|
||||
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
|
||||
"""
|
||||
Client that operates on multiple logical Redis databases.
|
||||
Should be used in Active-Active database setups.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MultiDbConfig):
|
||||
self._databases = config.databases()
|
||||
self._health_checks = config.default_health_checks()
|
||||
|
||||
if config.health_checks is not None:
|
||||
self._health_checks.extend(config.health_checks)
|
||||
|
||||
self._health_check_interval = config.health_check_interval
|
||||
self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
|
||||
config.health_check_probes, config.health_check_delay
|
||||
)
|
||||
self._failure_detectors = config.default_failure_detectors()
|
||||
|
||||
if config.failure_detectors is not None:
|
||||
self._failure_detectors.extend(config.failure_detectors)
|
||||
|
||||
self._failover_strategy = (
|
||||
config.default_failover_strategy()
|
||||
if config.failover_strategy is None
|
||||
else config.failover_strategy
|
||||
)
|
||||
self._failover_strategy.set_databases(self._databases)
|
||||
self._auto_fallback_interval = config.auto_fallback_interval
|
||||
self._event_dispatcher = config.event_dispatcher
|
||||
self._command_retry = config.command_retry
|
||||
self._command_retry.update_supported_errors([ConnectionRefusedError])
|
||||
self.command_executor = DefaultCommandExecutor(
|
||||
failure_detectors=self._failure_detectors,
|
||||
databases=self._databases,
|
||||
command_retry=self._command_retry,
|
||||
failover_strategy=self._failover_strategy,
|
||||
failover_attempts=config.failover_attempts,
|
||||
failover_delay=config.failover_delay,
|
||||
event_dispatcher=self._event_dispatcher,
|
||||
auto_fallback_interval=self._auto_fallback_interval,
|
||||
)
|
||||
self.initialized = False
|
||||
self._hc_lock = asyncio.Lock()
|
||||
self._bg_scheduler = BackgroundScheduler()
|
||||
self._config = config
|
||||
self._recurring_hc_task = None
|
||||
self._hc_tasks = []
|
||||
self._half_open_state_task = None
|
||||
|
||||
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
|
||||
if not self.initialized:
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
if self._recurring_hc_task:
|
||||
self._recurring_hc_task.cancel()
|
||||
if self._half_open_state_task:
|
||||
self._half_open_state_task.cancel()
|
||||
for hc_task in self._hc_tasks:
|
||||
hc_task.cancel()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
Perform initialization of databases to define their initial state.
|
||||
"""
|
||||
|
||||
async def raise_exception_on_failed_hc(error):
|
||||
raise error
|
||||
|
||||
# Initial databases check to define initial state
|
||||
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
|
||||
|
||||
# Starts recurring health checks on the background.
|
||||
self._recurring_hc_task = asyncio.create_task(
|
||||
self._bg_scheduler.run_recurring_async(
|
||||
self._health_check_interval,
|
||||
self._check_databases_health,
|
||||
)
|
||||
)
|
||||
|
||||
is_active_db_found = False
|
||||
|
||||
for database, weight in self._databases:
|
||||
# Set on state changed callback for each circuit.
|
||||
database.circuit.on_state_changed(self._on_circuit_state_change_callback)
|
||||
|
||||
# Set states according to a weights and circuit state
|
||||
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
|
||||
await self.command_executor.set_active_database(database)
|
||||
is_active_db_found = True
|
||||
|
||||
if not is_active_db_found:
|
||||
raise NoValidDatabaseException(
|
||||
"Initial connection failed - no active database found"
|
||||
)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def get_databases(self) -> Databases:
|
||||
"""
|
||||
Returns a sorted (by weight) list of all databases.
|
||||
"""
|
||||
return self._databases
|
||||
|
||||
async def set_active_database(self, database: AsyncDatabase) -> None:
|
||||
"""
|
||||
Promote one of the existing databases to become an active.
|
||||
"""
|
||||
exists = None
|
||||
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
exists = True
|
||||
break
|
||||
|
||||
if not exists:
|
||||
raise ValueError("Given database is not a member of database list")
|
||||
|
||||
await self._check_db_health(database)
|
||||
|
||||
if database.circuit.state == CBState.CLOSED:
|
||||
highest_weighted_db, _ = self._databases.get_top_n(1)[0]
|
||||
await self.command_executor.set_active_database(database)
|
||||
return
|
||||
|
||||
raise NoValidDatabaseException(
|
||||
"Cannot set active database, database is unhealthy"
|
||||
)
|
||||
|
||||
async def add_database(self, database: AsyncDatabase):
|
||||
"""
|
||||
Adds a new database to the database list.
|
||||
"""
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
raise ValueError("Given database already exists")
|
||||
|
||||
await self._check_db_health(database)
|
||||
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
self._databases.add(database, database.weight)
|
||||
await self._change_active_database(database, highest_weighted_db)
|
||||
|
||||
async def _change_active_database(
|
||||
self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
|
||||
):
|
||||
if (
|
||||
new_database.weight > highest_weight_database.weight
|
||||
and new_database.circuit.state == CBState.CLOSED
|
||||
):
|
||||
await self.command_executor.set_active_database(new_database)
|
||||
|
||||
async def remove_database(self, database: AsyncDatabase):
|
||||
"""
|
||||
Removes a database from the database list.
|
||||
"""
|
||||
weight = self._databases.remove(database)
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
|
||||
if (
|
||||
highest_weight <= weight
|
||||
and highest_weighted_db.circuit.state == CBState.CLOSED
|
||||
):
|
||||
await self.command_executor.set_active_database(highest_weighted_db)
|
||||
|
||||
async def update_database_weight(self, database: AsyncDatabase, weight: float):
|
||||
"""
|
||||
Updates a database from the database list.
|
||||
"""
|
||||
exists = None
|
||||
|
||||
for existing_db, _ in self._databases:
|
||||
if existing_db == database:
|
||||
exists = True
|
||||
break
|
||||
|
||||
if not exists:
|
||||
raise ValueError("Given database is not a member of database list")
|
||||
|
||||
highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
|
||||
self._databases.update_weight(database, weight)
|
||||
database.weight = weight
|
||||
await self._change_active_database(database, highest_weighted_db)
|
||||
|
||||
def add_failure_detector(self, failure_detector: AsyncFailureDetector):
|
||||
"""
|
||||
Adds a new failure detector to the database.
|
||||
"""
|
||||
self._failure_detectors.append(failure_detector)
|
||||
|
||||
async def add_health_check(self, healthcheck: HealthCheck):
|
||||
"""
|
||||
Adds a new health check to the database.
|
||||
"""
|
||||
async with self._hc_lock:
|
||||
self._health_checks.append(healthcheck)
|
||||
|
||||
async def execute_command(self, *args, **options):
|
||||
"""
|
||||
Executes a single command and return its result.
|
||||
"""
|
||||
if not self.initialized:
|
||||
await self.initialize()
|
||||
|
||||
return await self.command_executor.execute_command(*args, **options)
|
||||
|
||||
def pipeline(self):
|
||||
"""
|
||||
Enters into pipeline mode of the client.
|
||||
"""
|
||||
return Pipeline(self)
|
||||
|
||||
async def transaction(
|
||||
self,
|
||||
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
|
||||
*watches: KeyT,
|
||||
shard_hint: Optional[str] = None,
|
||||
value_from_callable: bool = False,
|
||||
watch_delay: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Executes callable as transaction.
|
||||
"""
|
||||
if not self.initialized:
|
||||
await self.initialize()
|
||||
|
||||
return await self.command_executor.execute_transaction(
|
||||
func,
|
||||
*watches,
|
||||
shard_hint=shard_hint,
|
||||
value_from_callable=value_from_callable,
|
||||
watch_delay=watch_delay,
|
||||
)
|
||||
|
||||
async def pubsub(self, **kwargs):
|
||||
"""
|
||||
Return a Publish/Subscribe object. With this object, you can
|
||||
subscribe to channels and listen for messages that get published to
|
||||
them.
|
||||
"""
|
||||
if not self.initialized:
|
||||
await self.initialize()
|
||||
|
||||
return PubSub(self, **kwargs)
|
||||
|
||||
async def _check_databases_health(
|
||||
self,
|
||||
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
|
||||
):
|
||||
"""
|
||||
Runs health checks as a recurring task.
|
||||
Runs health checks against all databases.
|
||||
"""
|
||||
try:
|
||||
self._hc_tasks = [
|
||||
asyncio.create_task(self._check_db_health(database))
|
||||
for database, _ in self._databases
|
||||
]
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(
|
||||
*self._hc_tasks,
|
||||
return_exceptions=True,
|
||||
),
|
||||
timeout=self._health_check_interval,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise asyncio.TimeoutError(
|
||||
"Health check execution exceeds health_check_interval"
|
||||
)
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, UnhealthyDatabaseException):
|
||||
unhealthy_db = result.database
|
||||
unhealthy_db.circuit.state = CBState.OPEN
|
||||
|
||||
logger.exception(
|
||||
"Health check failed, due to exception",
|
||||
exc_info=result.original_exception,
|
||||
)
|
||||
|
||||
if on_error:
|
||||
on_error(result.original_exception)
|
||||
|
||||
async def _check_db_health(self, database: AsyncDatabase) -> bool:
|
||||
"""
|
||||
Runs health checks on the given database until first failure.
|
||||
"""
|
||||
# Health check will setup circuit state
|
||||
is_healthy = await self._health_check_policy.execute(
|
||||
self._health_checks, database
|
||||
)
|
||||
|
||||
if not is_healthy:
|
||||
if database.circuit.state != CBState.OPEN:
|
||||
database.circuit.state = CBState.OPEN
|
||||
return is_healthy
|
||||
elif is_healthy and database.circuit.state != CBState.CLOSED:
|
||||
database.circuit.state = CBState.CLOSED
|
||||
|
||||
return is_healthy
|
||||
|
||||
def _on_circuit_state_change_callback(
|
||||
self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
|
||||
):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if new_state == CBState.HALF_OPEN:
|
||||
self._half_open_state_task = asyncio.create_task(
|
||||
self._check_db_health(circuit.database)
|
||||
)
|
||||
return
|
||||
|
||||
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
|
||||
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)
|
||||
|
||||
async def aclose(self):
|
||||
if self.command_executor.active_database:
|
||||
await self.command_executor.active_database.client.aclose()
|
||||
|
||||
|
||||
def _half_open_circuit(circuit: CircuitBreaker):
|
||||
circuit.state = CBState.HALF_OPEN
|
||||
|
||||
|
||||
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
|
||||
"""
|
||||
Pipeline implementation for multiple logical Redis databases.
|
||||
"""
|
||||
|
||||
def __init__(self, client: MultiDBClient):
|
||||
self._command_stack = []
|
||||
self._client = client
|
||||
|
||||
async def __aenter__(self: "Pipeline") -> "Pipeline":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.reset()
|
||||
await self._client.__aexit__(exc_type, exc_value, traceback)
|
||||
|
||||
def __await__(self):
|
||||
return self._async_self().__await__()
|
||||
|
||||
async def _async_self(self):
|
||||
return self
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._command_stack)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Pipeline instances should always evaluate to True"""
|
||||
return True
|
||||
|
||||
async def reset(self) -> None:
|
||||
self._command_stack = []
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the pipeline"""
|
||||
await self.reset()
|
||||
|
||||
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
|
||||
"""
|
||||
Stage a command to be executed when execute() is next called
|
||||
|
||||
Returns the current Pipeline object back so commands can be
|
||||
chained together, such as:
|
||||
|
||||
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
|
||||
|
||||
At some other point, you can then run: pipe.execute(),
|
||||
which will execute all commands queued in the pipe.
|
||||
"""
|
||||
self._command_stack.append((args, options))
|
||||
return self
|
||||
|
||||
def execute_command(self, *args, **kwargs):
|
||||
"""Adds a command to the stack"""
|
||||
return self.pipeline_execute_command(*args, **kwargs)
|
||||
|
||||
async def execute(self) -> List[Any]:
|
||||
"""Execute all the commands in the current pipeline"""
|
||||
if not self._client.initialized:
|
||||
await self._client.initialize()
|
||||
|
||||
try:
|
||||
return await self._client.command_executor.execute_pipeline(
|
||||
tuple(self._command_stack)
|
||||
)
|
||||
finally:
|
||||
await self.reset()
|
||||
|
||||
|
||||
class PubSub:
|
||||
"""
|
||||
PubSub object for multi database client.
|
||||
"""
|
||||
|
||||
def __init__(self, client: MultiDBClient, **kwargs):
|
||||
"""Initialize the PubSub object for a multi-database client.
|
||||
|
||||
Args:
|
||||
client: MultiDBClient instance to use for pub/sub operations
|
||||
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
|
||||
"""
|
||||
|
||||
self._client = client
|
||||
self._client.command_executor.pubsub(**kwargs)
|
||||
|
||||
async def __aenter__(self) -> "PubSub":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self):
|
||||
return await self._client.command_executor.execute_pubsub_method("aclose")
|
||||
|
||||
@property
|
||||
def subscribed(self) -> bool:
|
||||
return self._client.command_executor.active_pubsub.subscribed
|
||||
|
||||
async def execute_command(self, *args: EncodableT):
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"execute_command", *args
|
||||
)
|
||||
|
||||
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
|
||||
"""
|
||||
Subscribe to channel patterns. Patterns supplied as keyword arguments
|
||||
expect a pattern name as the key and a callable as the value. A
|
||||
pattern's callable will be invoked automatically when a message is
|
||||
received on that pattern rather than producing a message via
|
||||
``listen()``.
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"psubscribe", *args, **kwargs
|
||||
)
|
||||
|
||||
async def punsubscribe(self, *args: ChannelT):
|
||||
"""
|
||||
Unsubscribe from the supplied patterns. If empty, unsubscribe from
|
||||
all patterns.
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"punsubscribe", *args
|
||||
)
|
||||
|
||||
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
|
||||
"""
|
||||
Subscribe to channels. Channels supplied as keyword arguments expect
|
||||
a channel name as the key and a callable as the value. A channel's
|
||||
callable will be invoked automatically when a message is received on
|
||||
that channel rather than producing a message via ``listen()`` or
|
||||
``get_message()``.
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"subscribe", *args, **kwargs
|
||||
)
|
||||
|
||||
async def unsubscribe(self, *args):
|
||||
"""
|
||||
Unsubscribe from the supplied channels. If empty, unsubscribe from
|
||||
all channels
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"unsubscribe", *args
|
||||
)
|
||||
|
||||
async def get_message(
|
||||
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
|
||||
):
|
||||
"""
|
||||
Get the next message if one is available, otherwise None.
|
||||
|
||||
If timeout is specified, the system will wait for `timeout` seconds
|
||||
before returning. Timeout should be specified as a floating point
|
||||
number or None to wait indefinitely.
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_method(
|
||||
"get_message",
|
||||
ignore_subscribe_messages=ignore_subscribe_messages,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
exception_handler=None,
|
||||
poll_timeout: float = 1.0,
|
||||
) -> None:
|
||||
"""Process pub/sub messages using registered callbacks.
|
||||
|
||||
This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
|
||||
redis-py, but it is a coroutine. To launch it as a separate task, use
|
||||
``asyncio.create_task``:
|
||||
|
||||
>>> task = asyncio.create_task(pubsub.run())
|
||||
|
||||
To shut it down, use asyncio cancellation:
|
||||
|
||||
>>> task.cancel()
|
||||
>>> await task
|
||||
"""
|
||||
return await self._client.command_executor.execute_pubsub_run(
|
||||
sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
|
||||
)
|
||||
@@ -0,0 +1,339 @@
|
||||
from abc import abstractmethod
|
||||
from asyncio import iscoroutinefunction
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Union
|
||||
|
||||
from redis.asyncio import RedisCluster
|
||||
from redis.asyncio.client import Pipeline, PubSub
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
|
||||
from redis.asyncio.multidb.event import (
|
||||
AsyncActiveDatabaseChanged,
|
||||
CloseConnectionOnActiveDatabaseChanged,
|
||||
RegisterCommandFailure,
|
||||
ResubscribeOnActiveDatabaseChanged,
|
||||
)
|
||||
from redis.asyncio.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
AsyncFailoverStrategy,
|
||||
DefaultFailoverStrategyExecutor,
|
||||
FailoverStrategyExecutor,
|
||||
)
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.asyncio.retry import Retry
|
||||
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
|
||||
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
from redis.typing import KeyT
|
||||
|
||||
|
||||
class AsyncCommandExecutor(CommandExecutor):
|
||||
@property
|
||||
@abstractmethod
|
||||
def databases(self) -> Databases:
|
||||
"""Returns a list of databases."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failure_detectors(self) -> List[AsyncFailureDetector]:
|
||||
"""Returns a list of failure detectors."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
|
||||
"""Adds a new failure detector to the list of failure detectors."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def active_database(self) -> Optional[AsyncDatabase]:
|
||||
"""Returns currently active database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_active_database(self, database: AsyncDatabase) -> None:
|
||||
"""Sets the currently active database."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def active_pubsub(self) -> Optional[PubSub]:
|
||||
"""Returns currently active pubsub."""
|
||||
pass
|
||||
|
||||
@active_pubsub.setter
|
||||
@abstractmethod
|
||||
def active_pubsub(self, pubsub: PubSub) -> None:
|
||||
"""Sets currently active pubsub."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
|
||||
"""Returns failover strategy executor."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def command_retry(self) -> Retry:
|
||||
"""Returns command retry object."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def pubsub(self, **kwargs):
|
||||
"""Initializes a PubSub object on a currently active database"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_command(self, *args, **options):
|
||||
"""Executes a command and returns the result."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_pipeline(self, command_stack: tuple):
|
||||
"""Executes a stack of commands in pipeline."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_transaction(
|
||||
self, transaction: Callable[[Pipeline], None], *watches, **options
|
||||
):
|
||||
"""Executes a transaction block wrapped in callback."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
|
||||
"""Executes a given method on active pub/sub."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
|
||||
"""Executes pub/sub run in a thread."""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
failure_detectors: List[AsyncFailureDetector],
|
||||
databases: Databases,
|
||||
command_retry: Retry,
|
||||
failover_strategy: AsyncFailoverStrategy,
|
||||
event_dispatcher: EventDispatcherInterface,
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY,
|
||||
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
|
||||
):
|
||||
"""
|
||||
Initialize the DefaultCommandExecutor instance.
|
||||
|
||||
Args:
|
||||
failure_detectors: List of failure detector instances to monitor database health
|
||||
databases: Collection of available databases to execute commands on
|
||||
command_retry: Retry policy for failed command execution
|
||||
failover_strategy: Strategy for handling database failover
|
||||
event_dispatcher: Interface for dispatching events
|
||||
failover_attempts: Number of failover attempts
|
||||
failover_delay: Delay between failover attempts
|
||||
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
|
||||
"""
|
||||
super().__init__(auto_fallback_interval)
|
||||
|
||||
for fd in failure_detectors:
|
||||
fd.set_command_executor(command_executor=self)
|
||||
|
||||
self._databases = databases
|
||||
self._failure_detectors = failure_detectors
|
||||
self._command_retry = command_retry
|
||||
self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
|
||||
failover_strategy, failover_attempts, failover_delay
|
||||
)
|
||||
self._event_dispatcher = event_dispatcher
|
||||
self._active_database: Optional[Database] = None
|
||||
self._active_pubsub: Optional[PubSub] = None
|
||||
self._active_pubsub_kwargs = {}
|
||||
self._setup_event_dispatcher()
|
||||
self._schedule_next_fallback()
|
||||
|
||||
@property
|
||||
def databases(self) -> Databases:
|
||||
return self._databases
|
||||
|
||||
@property
|
||||
def failure_detectors(self) -> List[AsyncFailureDetector]:
|
||||
return self._failure_detectors
|
||||
|
||||
def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
|
||||
self._failure_detectors.append(failure_detector)
|
||||
|
||||
@property
|
||||
def active_database(self) -> Optional[AsyncDatabase]:
|
||||
return self._active_database
|
||||
|
||||
async def set_active_database(self, database: AsyncDatabase) -> None:
|
||||
old_active = self._active_database
|
||||
self._active_database = database
|
||||
|
||||
if old_active is not None and old_active is not database:
|
||||
await self._event_dispatcher.dispatch_async(
|
||||
AsyncActiveDatabaseChanged(
|
||||
old_active,
|
||||
self._active_database,
|
||||
self,
|
||||
**self._active_pubsub_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def active_pubsub(self) -> Optional[PubSub]:
|
||||
return self._active_pubsub
|
||||
|
||||
@active_pubsub.setter
|
||||
def active_pubsub(self, pubsub: PubSub) -> None:
|
||||
self._active_pubsub = pubsub
|
||||
|
||||
@property
|
||||
def failover_strategy_executor(self) -> FailoverStrategyExecutor:
|
||||
return self._failover_strategy_executor
|
||||
|
||||
@property
|
||||
def command_retry(self) -> Retry:
|
||||
return self._command_retry
|
||||
|
||||
def pubsub(self, **kwargs):
|
||||
if self._active_pubsub is None:
|
||||
if isinstance(self._active_database.client, RedisCluster):
|
||||
raise ValueError("PubSub is not supported for RedisCluster")
|
||||
|
||||
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
|
||||
self._active_pubsub_kwargs = kwargs
|
||||
|
||||
async def execute_command(self, *args, **options):
|
||||
async def callback():
|
||||
response = await self._active_database.client.execute_command(
|
||||
*args, **options
|
||||
)
|
||||
await self._register_command_execution(args)
|
||||
return response
|
||||
|
||||
return await self._execute_with_failure_detection(callback, args)
|
||||
|
||||
async def execute_pipeline(self, command_stack: tuple):
|
||||
async def callback():
|
||||
async with self._active_database.client.pipeline() as pipe:
|
||||
for command, options in command_stack:
|
||||
pipe.execute_command(*command, **options)
|
||||
|
||||
response = await pipe.execute()
|
||||
await self._register_command_execution(command_stack)
|
||||
return response
|
||||
|
||||
return await self._execute_with_failure_detection(callback, command_stack)
|
||||
|
||||
async def execute_transaction(
|
||||
self,
|
||||
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
|
||||
*watches: KeyT,
|
||||
shard_hint: Optional[str] = None,
|
||||
value_from_callable: bool = False,
|
||||
watch_delay: Optional[float] = None,
|
||||
):
|
||||
async def callback():
|
||||
response = await self._active_database.client.transaction(
|
||||
func,
|
||||
*watches,
|
||||
shard_hint=shard_hint,
|
||||
value_from_callable=value_from_callable,
|
||||
watch_delay=watch_delay,
|
||||
)
|
||||
await self._register_command_execution(())
|
||||
return response
|
||||
|
||||
return await self._execute_with_failure_detection(callback)
|
||||
|
||||
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
|
||||
async def callback():
|
||||
method = getattr(self.active_pubsub, method_name)
|
||||
if iscoroutinefunction(method):
|
||||
response = await method(*args, **kwargs)
|
||||
else:
|
||||
response = method(*args, **kwargs)
|
||||
|
||||
await self._register_command_execution(args)
|
||||
return response
|
||||
|
||||
return await self._execute_with_failure_detection(callback, *args)
|
||||
|
||||
async def execute_pubsub_run(
|
||||
self, sleep_time: float, exception_handler=None, pubsub=None
|
||||
) -> Any:
|
||||
async def callback():
|
||||
return await self._active_pubsub.run(
|
||||
poll_timeout=sleep_time,
|
||||
exception_handler=exception_handler,
|
||||
pubsub=pubsub,
|
||||
)
|
||||
|
||||
return await self._execute_with_failure_detection(callback)
|
||||
|
||||
async def _execute_with_failure_detection(
|
||||
self, callback: Callable, cmds: tuple = ()
|
||||
):
|
||||
"""
|
||||
Execute a commands execution callback with failure detection.
|
||||
"""
|
||||
|
||||
async def wrapper():
|
||||
# On each retry we need to check active database as it might change.
|
||||
await self._check_active_database()
|
||||
return await callback()
|
||||
|
||||
return await self._command_retry.call_with_retry(
|
||||
lambda: wrapper(),
|
||||
lambda error: self._on_command_fail(error, *cmds),
|
||||
)
|
||||
|
||||
async def _check_active_database(self):
|
||||
"""
|
||||
Checks if active a database needs to be updated.
|
||||
"""
|
||||
if (
|
||||
self._active_database is None
|
||||
or self._active_database.circuit.state != CBState.CLOSED
|
||||
or (
|
||||
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
and self._next_fallback_attempt <= datetime.now()
|
||||
)
|
||||
):
|
||||
await self.set_active_database(
|
||||
await self._failover_strategy_executor.execute()
|
||||
)
|
||||
self._schedule_next_fallback()
|
||||
|
||||
async def _on_command_fail(self, error, *args):
|
||||
await self._event_dispatcher.dispatch_async(
|
||||
AsyncOnCommandsFailEvent(args, error)
|
||||
)
|
||||
|
||||
async def _register_command_execution(self, cmd: tuple):
|
||||
for detector in self._failure_detectors:
|
||||
await detector.register_command_execution(cmd)
|
||||
|
||||
def _setup_event_dispatcher(self):
|
||||
"""
|
||||
Registers necessary listeners.
|
||||
"""
|
||||
failure_listener = RegisterCommandFailure(self._failure_detectors)
|
||||
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
|
||||
close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
|
||||
self._event_dispatcher.register_listeners(
|
||||
{
|
||||
AsyncOnCommandsFailEvent: [failure_listener],
|
||||
AsyncActiveDatabaseChanged: [
|
||||
close_connection_listener,
|
||||
resubscribe_listener,
|
||||
],
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,210 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import pybreaker
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis, RedisCluster
|
||||
from redis.asyncio.multidb.database import Database, Databases
|
||||
from redis.asyncio.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
AsyncFailoverStrategy,
|
||||
WeightBasedFailoverStrategy,
|
||||
)
|
||||
from redis.asyncio.multidb.failure_detector import (
|
||||
AsyncFailureDetector,
|
||||
FailureDetectorAsyncWrapper,
|
||||
)
|
||||
from redis.asyncio.multidb.healthcheck import (
|
||||
DEFAULT_HEALTH_CHECK_DELAY,
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL,
|
||||
DEFAULT_HEALTH_CHECK_POLICY,
|
||||
DEFAULT_HEALTH_CHECK_PROBES,
|
||||
HealthCheck,
|
||||
HealthCheckPolicies,
|
||||
PingHealthCheck,
|
||||
)
|
||||
from redis.asyncio.retry import Retry
|
||||
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.event import EventDispatcher, EventDispatcherInterface
|
||||
from redis.multidb.circuit import (
|
||||
DEFAULT_GRACE_PERIOD,
|
||||
CircuitBreaker,
|
||||
PBCircuitBreakerAdapter,
|
||||
)
|
||||
from redis.multidb.failure_detector import (
|
||||
DEFAULT_FAILURE_RATE_THRESHOLD,
|
||||
DEFAULT_FAILURES_DETECTION_WINDOW,
|
||||
DEFAULT_MIN_NUM_FAILURES,
|
||||
CommandFailureDetector,
|
||||
)
|
||||
|
||||
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
|
||||
|
||||
|
||||
def default_event_dispatcher() -> EventDispatcherInterface:
|
||||
return EventDispatcher()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""
|
||||
Dataclass representing the configuration for a database connection.
|
||||
|
||||
This class is used to store configuration settings for a database connection,
|
||||
including client options, connection sourcing details, circuit breaker settings,
|
||||
and cluster-specific properties. It provides a structure for defining these
|
||||
attributes and allows for the creation of customized configurations for various
|
||||
database setups.
|
||||
|
||||
Attributes:
|
||||
weight (float): Weight of the database to define the active one.
|
||||
client_kwargs (dict): Additional parameters for the database client connection.
|
||||
from_url (Optional[str]): Redis URL way of connecting to the database.
|
||||
from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
|
||||
circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
|
||||
grace_period (float): Grace period after which we need to check if the circuit could be closed again.
|
||||
health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
|
||||
on public Redis Enterprise endpoints.
|
||||
|
||||
Methods:
|
||||
default_circuit_breaker:
|
||||
Generates and returns a default CircuitBreaker instance adapted for use.
|
||||
"""
|
||||
|
||||
weight: float = 1.0
|
||||
client_kwargs: dict = field(default_factory=dict)
|
||||
from_url: Optional[str] = None
|
||||
from_pool: Optional[ConnectionPool] = None
|
||||
circuit: Optional[CircuitBreaker] = None
|
||||
grace_period: float = DEFAULT_GRACE_PERIOD
|
||||
health_check_url: Optional[str] = None
|
||||
|
||||
def default_circuit_breaker(self) -> CircuitBreaker:
|
||||
circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
|
||||
return PBCircuitBreakerAdapter(circuit_breaker)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDbConfig:
|
||||
"""
|
||||
Configuration class for managing multiple database connections in a resilient and fail-safe manner.
|
||||
|
||||
Attributes:
|
||||
databases_config: A list of database configurations.
|
||||
client_class: The client class used to manage database connections.
|
||||
command_retry: Retry strategy for executing database commands.
|
||||
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
|
||||
min_num_failures: Minimal count of failures required for failover
|
||||
failure_rate_threshold: Percentage of failures required for failover
|
||||
failures_detection_window: Time interval for tracking database failures.
|
||||
health_checks: Optional list of additional health checks performed on databases.
|
||||
health_check_interval: Time interval for executing health checks.
|
||||
health_check_probes: Number of attempts to evaluate the health of a database.
|
||||
health_check_delay: Delay between health check attempts.
|
||||
failover_strategy: Optional strategy for handling database failover scenarios.
|
||||
failover_attempts: Number of retries allowed for failover operations.
|
||||
failover_delay: Delay between failover attempts.
|
||||
auto_fallback_interval: Time interval to trigger automatic fallback.
|
||||
event_dispatcher: Interface for dispatching events related to database operations.
|
||||
|
||||
Methods:
|
||||
databases:
|
||||
Retrieves a collection of database clients managed by weighted configurations.
|
||||
Initializes database clients based on the provided configuration and removes
|
||||
redundant retry objects for lower-level clients to rely on global retry logic.
|
||||
|
||||
default_failure_detectors:
|
||||
Returns the default list of failure detectors used to monitor database failures.
|
||||
|
||||
default_health_checks:
|
||||
Returns the default list of health checks used to monitor database health
|
||||
with specific retry and backoff strategies.
|
||||
|
||||
default_failover_strategy:
|
||||
Provides the default failover strategy used for handling failover scenarios
|
||||
with defined retry and backoff configurations.
|
||||
"""
|
||||
|
||||
databases_config: List[DatabaseConfig]
|
||||
client_class: Type[Union[Redis, RedisCluster]] = Redis
|
||||
command_retry: Retry = Retry(
|
||||
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
|
||||
)
|
||||
failure_detectors: Optional[List[AsyncFailureDetector]] = None
|
||||
min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
|
||||
failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
|
||||
failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
|
||||
health_checks: Optional[List[HealthCheck]] = None
|
||||
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
|
||||
health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
|
||||
health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
|
||||
health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
|
||||
failover_strategy: Optional[AsyncFailoverStrategy] = None
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY
|
||||
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
event_dispatcher: EventDispatcherInterface = field(
|
||||
default_factory=default_event_dispatcher
|
||||
)
|
||||
|
||||
def databases(self) -> Databases:
|
||||
databases = WeightedList()
|
||||
|
||||
for database_config in self.databases_config:
|
||||
# The retry object is not used in the lower level clients, so we can safely remove it.
|
||||
# We rely on command_retry in terms of global retries.
|
||||
database_config.client_kwargs.update(
|
||||
{"retry": Retry(retries=0, backoff=NoBackoff())}
|
||||
)
|
||||
|
||||
if database_config.from_url:
|
||||
client = self.client_class.from_url(
|
||||
database_config.from_url, **database_config.client_kwargs
|
||||
)
|
||||
elif database_config.from_pool:
|
||||
database_config.from_pool.set_retry(
|
||||
Retry(retries=0, backoff=NoBackoff())
|
||||
)
|
||||
client = self.client_class.from_pool(
|
||||
connection_pool=database_config.from_pool
|
||||
)
|
||||
else:
|
||||
client = self.client_class(**database_config.client_kwargs)
|
||||
|
||||
circuit = (
|
||||
database_config.default_circuit_breaker()
|
||||
if database_config.circuit is None
|
||||
else database_config.circuit
|
||||
)
|
||||
databases.add(
|
||||
Database(
|
||||
client=client,
|
||||
circuit=circuit,
|
||||
weight=database_config.weight,
|
||||
health_check_url=database_config.health_check_url,
|
||||
),
|
||||
database_config.weight,
|
||||
)
|
||||
|
||||
return databases
|
||||
|
||||
def default_failure_detectors(self) -> List[AsyncFailureDetector]:
|
||||
return [
|
||||
FailureDetectorAsyncWrapper(
|
||||
CommandFailureDetector(
|
||||
min_num_failures=self.min_num_failures,
|
||||
failure_rate_threshold=self.failure_rate_threshold,
|
||||
failure_detection_window=self.failures_detection_window,
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def default_health_checks(self) -> List[HealthCheck]:
|
||||
return [
|
||||
PingHealthCheck(),
|
||||
]
|
||||
|
||||
def default_failover_strategy(self) -> AsyncFailoverStrategy:
|
||||
return WeightBasedFailoverStrategy()
|
||||
@@ -0,0 +1,69 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import CircuitBreaker
|
||||
from redis.multidb.database import AbstractDatabase, BaseDatabase
|
||||
from redis.typing import Number
|
||||
|
||||
|
||||
class AsyncDatabase(AbstractDatabase):
|
||||
"""Database with an underlying asynchronous redis client."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def client(self) -> Union[Redis, RedisCluster]:
|
||||
"""The underlying redis client."""
|
||||
pass
|
||||
|
||||
@client.setter
|
||||
@abstractmethod
|
||||
def client(self, client: Union[Redis, RedisCluster]):
|
||||
"""Set the underlying redis client."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def circuit(self) -> CircuitBreaker:
|
||||
"""Circuit breaker for the current database."""
|
||||
pass
|
||||
|
||||
@circuit.setter
|
||||
@abstractmethod
|
||||
def circuit(self, circuit: CircuitBreaker):
|
||||
"""Set the circuit breaker for the current database."""
|
||||
pass
|
||||
|
||||
|
||||
Databases = WeightedList[tuple[AsyncDatabase, Number]]
|
||||
|
||||
|
||||
class Database(BaseDatabase, AsyncDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
client: Union[Redis, RedisCluster],
|
||||
circuit: CircuitBreaker,
|
||||
weight: float,
|
||||
health_check_url: Optional[str] = None,
|
||||
):
|
||||
self._client = client
|
||||
self._cb = circuit
|
||||
self._cb.database = self
|
||||
super().__init__(weight, health_check_url)
|
||||
|
||||
@property
|
||||
def client(self) -> Union[Redis, RedisCluster]:
|
||||
return self._client
|
||||
|
||||
@client.setter
|
||||
def client(self, client: Union[Redis, RedisCluster]):
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def circuit(self) -> CircuitBreaker:
|
||||
return self._cb
|
||||
|
||||
@circuit.setter
|
||||
def circuit(self, circuit: CircuitBreaker):
|
||||
self._cb = circuit
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import List
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.multidb.database import AsyncDatabase
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
|
||||
|
||||
|
||||
class AsyncActiveDatabaseChanged:
|
||||
"""
|
||||
Event fired when an async active database has been changed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
old_database: AsyncDatabase,
|
||||
new_database: AsyncDatabase,
|
||||
command_executor,
|
||||
**kwargs,
|
||||
):
|
||||
self._old_database = old_database
|
||||
self._new_database = new_database
|
||||
self._command_executor = command_executor
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def old_database(self) -> AsyncDatabase:
|
||||
return self._old_database
|
||||
|
||||
@property
|
||||
def new_database(self) -> AsyncDatabase:
|
||||
return self._new_database
|
||||
|
||||
@property
|
||||
def command_executor(self):
|
||||
return self._command_executor
|
||||
|
||||
@property
|
||||
def kwargs(self):
|
||||
return self._kwargs
|
||||
|
||||
|
||||
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
|
||||
"""
|
||||
Re-subscribe the currently active pub / sub to a new active database.
|
||||
"""
|
||||
|
||||
async def listen(self, event: AsyncActiveDatabaseChanged):
|
||||
old_pubsub = event.command_executor.active_pubsub
|
||||
|
||||
if old_pubsub is not None:
|
||||
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
|
||||
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
|
||||
new_pubsub.channels = old_pubsub.channels
|
||||
new_pubsub.patterns = old_pubsub.patterns
|
||||
await new_pubsub.on_connect(None)
|
||||
event.command_executor.active_pubsub = new_pubsub
|
||||
await old_pubsub.aclose()
|
||||
|
||||
|
||||
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
|
||||
"""
|
||||
Close connection to the old active database.
|
||||
"""
|
||||
|
||||
async def listen(self, event: AsyncActiveDatabaseChanged):
|
||||
await event.old_database.client.aclose()
|
||||
|
||||
if isinstance(event.old_database.client, Redis):
|
||||
await event.old_database.client.connection_pool.update_active_connections_for_reconnect()
|
||||
await event.old_database.client.connection_pool.disconnect()
|
||||
|
||||
|
||||
class RegisterCommandFailure(AsyncEventListenerInterface):
|
||||
"""
|
||||
Event listener that registers command failures and passing it to the failure detectors.
|
||||
"""
|
||||
|
||||
def __init__(self, failure_detectors: List[AsyncFailureDetector]):
|
||||
self._failure_detectors = failure_detectors
|
||||
|
||||
async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
|
||||
for failure_detector in self._failure_detectors:
|
||||
await failure_detector.register_failure(event.exception, event.commands)
|
||||
@@ -0,0 +1,125 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Databases
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.exception import (
|
||||
NoValidDatabaseException,
|
||||
TemporaryUnavailableException,
|
||||
)
|
||||
|
||||
DEFAULT_FAILOVER_ATTEMPTS = 10
|
||||
DEFAULT_FAILOVER_DELAY = 12
|
||||
|
||||
|
||||
class AsyncFailoverStrategy(ABC):
|
||||
@abstractmethod
|
||||
async def database(self) -> AsyncDatabase:
|
||||
"""Select the database according to the strategy."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_databases(self, databases: Databases) -> None:
|
||||
"""Set the database strategy operates on."""
|
||||
pass
|
||||
|
||||
|
||||
class FailoverStrategyExecutor(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_attempts(self) -> int:
|
||||
"""The number of failover attempts."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def failover_delay(self) -> float:
|
||||
"""The delay between failover attempts."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def strategy(self) -> AsyncFailoverStrategy:
|
||||
"""The strategy to execute."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> AsyncDatabase:
|
||||
"""Execute the failover strategy."""
|
||||
pass
|
||||
|
||||
|
||||
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
|
||||
"""
|
||||
Failover strategy based on database weights.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._databases = WeightedList()
|
||||
|
||||
async def database(self) -> AsyncDatabase:
|
||||
for database, _ in self._databases:
|
||||
if database.circuit.state == CBState.CLOSED:
|
||||
return database
|
||||
|
||||
raise NoValidDatabaseException("No valid database available for communication")
|
||||
|
||||
def set_databases(self, databases: Databases) -> None:
|
||||
self._databases = databases
|
||||
|
||||
|
||||
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
|
||||
"""
|
||||
Executes given failover strategy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: AsyncFailoverStrategy,
|
||||
failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
|
||||
failover_delay: float = DEFAULT_FAILOVER_DELAY,
|
||||
):
|
||||
self._strategy = strategy
|
||||
self._failover_attempts = failover_attempts
|
||||
self._failover_delay = failover_delay
|
||||
self._next_attempt_ts: int = 0
|
||||
self._failover_counter: int = 0
|
||||
|
||||
@property
|
||||
def failover_attempts(self) -> int:
|
||||
return self._failover_attempts
|
||||
|
||||
@property
|
||||
def failover_delay(self) -> float:
|
||||
return self._failover_delay
|
||||
|
||||
@property
|
||||
def strategy(self) -> AsyncFailoverStrategy:
|
||||
return self._strategy
|
||||
|
||||
async def execute(self) -> AsyncDatabase:
|
||||
try:
|
||||
database = await self._strategy.database()
|
||||
self._reset()
|
||||
return database
|
||||
except NoValidDatabaseException as e:
|
||||
if self._next_attempt_ts == 0:
|
||||
self._next_attempt_ts = time.time() + self._failover_delay
|
||||
self._failover_counter += 1
|
||||
elif time.time() >= self._next_attempt_ts:
|
||||
self._next_attempt_ts += self._failover_delay
|
||||
self._failover_counter += 1
|
||||
|
||||
if self._failover_counter > self._failover_attempts:
|
||||
self._reset()
|
||||
raise e
|
||||
else:
|
||||
raise TemporaryUnavailableException(
|
||||
"No database connections currently available. "
|
||||
"This is a temporary condition - please retry the operation."
|
||||
)
|
||||
|
||||
def _reset(self) -> None:
|
||||
self._next_attempt_ts = 0
|
||||
self._failover_counter = 0
|
||||
@@ -0,0 +1,38 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.multidb.failure_detector import FailureDetector
|
||||
|
||||
|
||||
class AsyncFailureDetector(ABC):
|
||||
@abstractmethod
|
||||
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
|
||||
"""Register a failure that occurred during command execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def register_command_execution(self, cmd: tuple) -> None:
|
||||
"""Register a command execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_command_executor(self, command_executor) -> None:
|
||||
"""Set the command executor for this failure."""
|
||||
pass
|
||||
|
||||
|
||||
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
|
||||
"""
|
||||
Async wrapper for the failure detector.
|
||||
"""
|
||||
|
||||
def __init__(self, failure_detector: FailureDetector) -> None:
|
||||
self._failure_detector = failure_detector
|
||||
|
||||
async def register_failure(self, exception: Exception, cmd: tuple) -> None:
|
||||
self._failure_detector.register_failure(exception, cmd)
|
||||
|
||||
async def register_command_execution(self, cmd: tuple) -> None:
|
||||
self._failure_detector.register_command_execution(cmd)
|
||||
|
||||
def set_command_executor(self, command_executor) -> None:
|
||||
self._failure_detector.set_command_executor(command_executor)
|
||||
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
|
||||
from redis.backoff import NoBackoff
|
||||
from redis.http.http_client import HttpClient
|
||||
from redis.multidb.exception import UnhealthyDatabaseException
|
||||
from redis.retry import Retry
|
||||
|
||||
DEFAULT_HEALTH_CHECK_PROBES = 3
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL = 5
|
||||
DEFAULT_HEALTH_CHECK_DELAY = 0.5
|
||||
DEFAULT_LAG_AWARE_TOLERANCE = 5000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthCheck(ABC):
|
||||
@abstractmethod
|
||||
async def check_health(self, database) -> bool:
|
||||
"""Function to determine the health status."""
|
||||
pass
|
||||
|
||||
|
||||
class HealthCheckPolicy(ABC):
|
||||
"""
|
||||
Health checks execution policy.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_probes(self) -> int:
|
||||
"""Number of probes to execute health checks."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def health_check_delay(self) -> float:
|
||||
"""Delay between health check probes."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
"""Execute health checks and return database health status."""
|
||||
pass
|
||||
|
||||
|
||||
class AbstractHealthCheckPolicy(HealthCheckPolicy):
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
if health_check_probes < 1:
|
||||
raise ValueError("health_check_probes must be greater than 0")
|
||||
self._health_check_probes = health_check_probes
|
||||
self._health_check_delay = health_check_delay
|
||||
|
||||
@property
|
||||
def health_check_probes(self) -> int:
|
||||
return self._health_check_probes
|
||||
|
||||
@property
|
||||
def health_check_delay(self) -> float:
|
||||
return self._health_check_delay
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class HealthyAllPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if all health check probes are successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
for health_check in health_checks:
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if not await health_check.check_health(database):
|
||||
return False
|
||||
except Exception as e:
|
||||
raise UnhealthyDatabaseException("Unhealthy database", database, e)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
await asyncio.sleep(self._health_check_delay)
|
||||
return True
|
||||
|
||||
|
||||
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if a majority of health check probes are successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
for health_check in health_checks:
|
||||
if self.health_check_probes % 2 == 0:
|
||||
allowed_unsuccessful_probes = self.health_check_probes / 2
|
||||
else:
|
||||
allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2
|
||||
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if not await health_check.check_health(database):
|
||||
allowed_unsuccessful_probes -= 1
|
||||
if allowed_unsuccessful_probes <= 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
allowed_unsuccessful_probes -= 1
|
||||
if allowed_unsuccessful_probes <= 0:
|
||||
raise UnhealthyDatabaseException(
|
||||
"Unhealthy database", database, e
|
||||
)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
await asyncio.sleep(self._health_check_delay)
|
||||
return True
|
||||
|
||||
|
||||
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
|
||||
"""
|
||||
Policy that returns True if at least one health check probe is successful.
|
||||
"""
|
||||
|
||||
def __init__(self, health_check_probes: int, health_check_delay: float):
|
||||
super().__init__(health_check_probes, health_check_delay)
|
||||
|
||||
async def execute(self, health_checks: List[HealthCheck], database) -> bool:
|
||||
is_healthy = False
|
||||
|
||||
for health_check in health_checks:
|
||||
exception = None
|
||||
|
||||
for attempt in range(self.health_check_probes):
|
||||
try:
|
||||
if await health_check.check_health(database):
|
||||
is_healthy = True
|
||||
break
|
||||
else:
|
||||
is_healthy = False
|
||||
except Exception as e:
|
||||
exception = UnhealthyDatabaseException(
|
||||
"Unhealthy database", database, e
|
||||
)
|
||||
|
||||
if attempt < self.health_check_probes - 1:
|
||||
await asyncio.sleep(self._health_check_delay)
|
||||
|
||||
if not is_healthy and not exception:
|
||||
return is_healthy
|
||||
elif not is_healthy and exception:
|
||||
raise exception
|
||||
|
||||
return is_healthy
|
||||
|
||||
|
||||
class HealthCheckPolicies(Enum):
|
||||
HEALTHY_ALL = HealthyAllPolicy
|
||||
HEALTHY_MAJORITY = HealthyMajorityPolicy
|
||||
HEALTHY_ANY = HealthyAnyPolicy
|
||||
|
||||
|
||||
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
|
||||
|
||||
|
||||
class PingHealthCheck(HealthCheck):
|
||||
"""
|
||||
Health check based on PING command.
|
||||
"""
|
||||
|
||||
async def check_health(self, database) -> bool:
|
||||
if isinstance(database.client, Redis):
|
||||
return await database.client.execute_command("PING")
|
||||
else:
|
||||
# For a cluster checks if all nodes are healthy.
|
||||
all_nodes = database.client.get_nodes()
|
||||
for node in all_nodes:
|
||||
if not await node.redis_connection.execute_command("PING"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class LagAwareHealthCheck(HealthCheck):
|
||||
"""
|
||||
Health check available for Redis Enterprise deployments.
|
||||
Verify via REST API that the database is healthy based on different lags.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rest_api_port: int = 9443,
|
||||
lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
auth_basic: Optional[Tuple[str, str]] = None,
|
||||
verify_tls: bool = True,
|
||||
# TLS verification (server) options
|
||||
ca_file: Optional[str] = None,
|
||||
ca_path: Optional[str] = None,
|
||||
ca_data: Optional[Union[str, bytes]] = None,
|
||||
# Mutual TLS (client cert) options
|
||||
client_cert_file: Optional[str] = None,
|
||||
client_key_file: Optional[str] = None,
|
||||
client_key_password: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize LagAwareHealthCheck with the specified parameters.
|
||||
|
||||
Args:
|
||||
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
|
||||
lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100)
|
||||
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
|
||||
auth_basic: Tuple of (username, password) for basic authentication
|
||||
verify_tls: Whether to verify TLS certificates (default: True)
|
||||
ca_file: Path to CA certificate file for TLS verification
|
||||
ca_path: Path to CA certificates directory for TLS verification
|
||||
ca_data: CA certificate data as string or bytes
|
||||
client_cert_file: Path to client certificate file for mutual TLS
|
||||
client_key_file: Path to client private key file for mutual TLS
|
||||
client_key_password: Password for encrypted client private key
|
||||
"""
|
||||
self._http_client = AsyncHTTPClientWrapper(
|
||||
HttpClient(
|
||||
timeout=timeout,
|
||||
auth_basic=auth_basic,
|
||||
retry=Retry(NoBackoff(), retries=0),
|
||||
verify_tls=verify_tls,
|
||||
ca_file=ca_file,
|
||||
ca_path=ca_path,
|
||||
ca_data=ca_data,
|
||||
client_cert_file=client_cert_file,
|
||||
client_key_file=client_key_file,
|
||||
client_key_password=client_key_password,
|
||||
)
|
||||
)
|
||||
self._rest_api_port = rest_api_port
|
||||
self._lag_aware_tolerance = lag_aware_tolerance
|
||||
|
||||
async def check_health(self, database) -> bool:
|
||||
if database.health_check_url is None:
|
||||
raise ValueError(
|
||||
"Database health check url is not set. Please check DatabaseConfig for the current database."
|
||||
)
|
||||
|
||||
if isinstance(database.client, Redis):
|
||||
db_host = database.client.get_connection_kwargs()["host"]
|
||||
else:
|
||||
db_host = database.client.startup_nodes[0].host
|
||||
|
||||
base_url = f"{database.health_check_url}:{self._rest_api_port}"
|
||||
self._http_client.client.base_url = base_url
|
||||
|
||||
# Find bdb matching to the current database host
|
||||
matching_bdb = None
|
||||
for bdb in await self._http_client.get("/v1/bdbs"):
|
||||
for endpoint in bdb["endpoints"]:
|
||||
if endpoint["dns_name"] == db_host:
|
||||
matching_bdb = bdb
|
||||
break
|
||||
|
||||
# In case if the host was set as public IP
|
||||
for addr in endpoint["addr"]:
|
||||
if addr == db_host:
|
||||
matching_bdb = bdb
|
||||
break
|
||||
|
||||
if matching_bdb is None:
|
||||
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
|
||||
raise ValueError("Could not find a matching bdb")
|
||||
|
||||
url = (
|
||||
f"/v1/bdbs/{matching_bdb['uid']}/availability"
|
||||
f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
|
||||
)
|
||||
await self._http_client.get(url, expect_json=False)
|
||||
|
||||
# Status checked in an http client, otherwise HttpError will be raised
|
||||
return True
|
||||
@@ -0,0 +1,58 @@
|
||||
from asyncio import sleep
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
|
||||
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
from redis.retry import AbstractRetry
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.backoff import AbstractBackoff
|
||||
|
||||
|
||||
class Retry(AbstractRetry[RedisError]):
|
||||
__hash__ = AbstractRetry.__hash__
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backoff: "AbstractBackoff",
|
||||
retries: int,
|
||||
supported_errors: Tuple[Type[RedisError], ...] = (
|
||||
ConnectionError,
|
||||
TimeoutError,
|
||||
),
|
||||
):
|
||||
super().__init__(backoff, retries, supported_errors)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Retry):
|
||||
return NotImplemented
|
||||
|
||||
return (
|
||||
self._backoff == other._backoff
|
||||
and self._retries == other._retries
|
||||
and set(self._supported_errors) == set(other._supported_errors)
|
||||
)
|
||||
|
||||
async def call_with_retry(
|
||||
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
|
||||
) -> T:
|
||||
"""
|
||||
Execute an operation that might fail and returns its result, or
|
||||
raise the exception that was thrown depending on the `Backoff` object.
|
||||
`do`: the operation to call. Expects no argument.
|
||||
`fail`: the failure handler, expects the last error that was thrown
|
||||
"""
|
||||
self._backoff.reset()
|
||||
failures = 0
|
||||
while True:
|
||||
try:
|
||||
return await do()
|
||||
except self._supported_errors as error:
|
||||
failures += 1
|
||||
await fail(error)
|
||||
if self._retries >= 0 and failures > self._retries:
|
||||
raise error
|
||||
backoff = self._backoff.compute(failures)
|
||||
if backoff > 0:
|
||||
await sleep(backoff)
|
||||
@@ -0,0 +1,404 @@
|
||||
import asyncio
|
||||
import random
|
||||
import weakref
|
||||
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
|
||||
|
||||
from redis.asyncio.client import Redis
|
||||
from redis.asyncio.connection import (
|
||||
Connection,
|
||||
ConnectionPool,
|
||||
EncodableT,
|
||||
SSLConnection,
|
||||
)
|
||||
from redis.commands import AsyncSentinelCommands
|
||||
from redis.exceptions import (
|
||||
ConnectionError,
|
||||
ReadOnlyError,
|
||||
ResponseError,
|
||||
TimeoutError,
|
||||
)
|
||||
|
||||
|
||||
class MasterNotFoundError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class SlaveNotFoundError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class SentinelManagedConnection(Connection):
|
||||
def __init__(self, **kwargs):
|
||||
self.connection_pool = kwargs.pop("connection_pool")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
if self.host:
|
||||
host_info = f",host={self.host},port={self.port}"
|
||||
s += host_info
|
||||
return s + ")>"
|
||||
|
||||
async def connect_to(self, address):
|
||||
self.host, self.port = address
|
||||
await self.connect_check_health(
|
||||
check_health=self.connection_pool.check_connection,
|
||||
retry_socket_connect=False,
|
||||
)
|
||||
|
||||
async def _connect_retry(self):
|
||||
if self._reader:
|
||||
return # already connected
|
||||
if self.connection_pool.is_master:
|
||||
await self.connect_to(await self.connection_pool.get_master_address())
|
||||
else:
|
||||
async for slave in self.connection_pool.rotate_slaves():
|
||||
try:
|
||||
return await self.connect_to(slave)
|
||||
except ConnectionError:
|
||||
continue
|
||||
raise SlaveNotFoundError # Never be here
|
||||
|
||||
async def connect(self):
|
||||
return await self.retry.call_with_retry(
|
||||
self._connect_retry,
|
||||
lambda error: asyncio.sleep(0),
|
||||
)
|
||||
|
||||
async def read_response(
|
||||
self,
|
||||
disable_decoding: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
*,
|
||||
disconnect_on_error: Optional[float] = True,
|
||||
push_request: Optional[bool] = False,
|
||||
):
|
||||
try:
|
||||
return await super().read_response(
|
||||
disable_decoding=disable_decoding,
|
||||
timeout=timeout,
|
||||
disconnect_on_error=disconnect_on_error,
|
||||
push_request=push_request,
|
||||
)
|
||||
except ReadOnlyError:
|
||||
if self.connection_pool.is_master:
|
||||
# When talking to a master, a ReadOnlyError when likely
|
||||
# indicates that the previous master that we're still connected
|
||||
# to has been demoted to a slave and there's a new master.
|
||||
# calling disconnect will force the connection to re-query
|
||||
# sentinel during the next connect() attempt.
|
||||
await self.disconnect()
|
||||
raise ConnectionError("The previous master is now a slave")
|
||||
raise
|
||||
|
||||
|
||||
class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
|
||||
pass
|
||||
|
||||
|
||||
class SentinelConnectionPool(ConnectionPool):
|
||||
"""
|
||||
Sentinel backed connection pool.
|
||||
|
||||
If ``check_connection`` flag is set to True, SentinelManagedConnection
|
||||
sends a PING command right after establishing the connection.
|
||||
"""
|
||||
|
||||
def __init__(self, service_name, sentinel_manager, **kwargs):
|
||||
kwargs["connection_class"] = kwargs.get(
|
||||
"connection_class",
|
||||
(
|
||||
SentinelManagedSSLConnection
|
||||
if kwargs.pop("ssl", False)
|
||||
else SentinelManagedConnection
|
||||
),
|
||||
)
|
||||
self.is_master = kwargs.pop("is_master", True)
|
||||
self.check_connection = kwargs.pop("check_connection", False)
|
||||
super().__init__(**kwargs)
|
||||
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
|
||||
self.service_name = service_name
|
||||
self.sentinel_manager = sentinel_manager
|
||||
self.master_address = None
|
||||
self.slave_rr_counter = None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.master_address = None
|
||||
self.slave_rr_counter = None
|
||||
|
||||
def owns_connection(self, connection: Connection):
|
||||
check = not self.is_master or (
|
||||
self.is_master and self.master_address == (connection.host, connection.port)
|
||||
)
|
||||
return check and super().owns_connection(connection)
|
||||
|
||||
async def get_master_address(self):
|
||||
master_address = await self.sentinel_manager.discover_master(self.service_name)
|
||||
if self.is_master:
|
||||
if self.master_address != master_address:
|
||||
self.master_address = master_address
|
||||
# disconnect any idle connections so that they reconnect
|
||||
# to the new master the next time that they are used.
|
||||
await self.disconnect(inuse_connections=False)
|
||||
return master_address
|
||||
|
||||
async def rotate_slaves(self) -> AsyncIterator:
|
||||
"""Round-robin slave balancer"""
|
||||
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
|
||||
if slaves:
|
||||
if self.slave_rr_counter is None:
|
||||
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
|
||||
for _ in range(len(slaves)):
|
||||
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
|
||||
slave = slaves[self.slave_rr_counter]
|
||||
yield slave
|
||||
# Fallback to the master connection
|
||||
try:
|
||||
yield await self.get_master_address()
|
||||
except MasterNotFoundError:
|
||||
pass
|
||||
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
|
||||
|
||||
|
||||
class Sentinel(AsyncSentinelCommands):
|
||||
"""
|
||||
Redis Sentinel cluster client
|
||||
|
||||
>>> from redis.sentinel import Sentinel
|
||||
>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
|
||||
>>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
|
||||
>>> await master.set('foo', 'bar')
|
||||
>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
|
||||
>>> await slave.get('foo')
|
||||
b'bar'
|
||||
|
||||
``sentinels`` is a list of sentinel nodes. Each node is represented by
|
||||
a pair (hostname, port).
|
||||
|
||||
``min_other_sentinels`` defined a minimum number of peers for a sentinel.
|
||||
When querying a sentinel, if it doesn't meet this threshold, responses
|
||||
from that sentinel won't be considered valid.
|
||||
|
||||
``sentinel_kwargs`` is a dictionary of connection arguments used when
|
||||
connecting to sentinel instances. Any argument that can be passed to
|
||||
a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
|
||||
not specified, any socket_timeout and socket_keepalive options specified
|
||||
in ``connection_kwargs`` will be used.
|
||||
|
||||
``connection_kwargs`` are keyword arguments that will be used when
|
||||
establishing a connection to a Redis server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sentinels,
|
||||
min_other_sentinels=0,
|
||||
sentinel_kwargs=None,
|
||||
force_master_ip=None,
|
||||
**connection_kwargs,
|
||||
):
|
||||
# if sentinel_kwargs isn't defined, use the socket_* options from
|
||||
# connection_kwargs
|
||||
if sentinel_kwargs is None:
|
||||
sentinel_kwargs = {
|
||||
k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
|
||||
}
|
||||
self.sentinel_kwargs = sentinel_kwargs
|
||||
|
||||
self.sentinels = [
|
||||
Redis(host=hostname, port=port, **self.sentinel_kwargs)
|
||||
for hostname, port in sentinels
|
||||
]
|
||||
self.min_other_sentinels = min_other_sentinels
|
||||
self.connection_kwargs = connection_kwargs
|
||||
self._force_master_ip = force_master_ip
|
||||
|
||||
async def execute_command(self, *args, **kwargs):
|
||||
"""
|
||||
Execute Sentinel command in sentinel nodes.
|
||||
once - If set to True, then execute the resulting command on a single
|
||||
node at random, rather than across the entire sentinel cluster.
|
||||
"""
|
||||
once = bool(kwargs.pop("once", False))
|
||||
|
||||
# Check if command is supposed to return the original
|
||||
# responses instead of boolean value.
|
||||
return_responses = bool(kwargs.pop("return_responses", False))
|
||||
|
||||
if once:
|
||||
response = await random.choice(self.sentinels).execute_command(
|
||||
*args, **kwargs
|
||||
)
|
||||
if return_responses:
|
||||
return [response]
|
||||
else:
|
||||
return True if response else False
|
||||
|
||||
tasks = [
|
||||
asyncio.Task(sentinel.execute_command(*args, **kwargs))
|
||||
for sentinel in self.sentinels
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
if return_responses:
|
||||
return responses
|
||||
|
||||
return all(responses)
|
||||
|
||||
def __repr__(self):
|
||||
sentinel_addresses = []
|
||||
for sentinel in self.sentinels:
|
||||
sentinel_addresses.append(
|
||||
f"{sentinel.connection_pool.connection_kwargs['host']}:"
|
||||
f"{sentinel.connection_pool.connection_kwargs['port']}"
|
||||
)
|
||||
return (
|
||||
f"<{self.__class__}.{self.__class__.__name__}"
|
||||
f"(sentinels=[{','.join(sentinel_addresses)}])>"
|
||||
)
|
||||
|
||||
def check_master_state(self, state: dict, service_name: str) -> bool:
|
||||
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
|
||||
return False
|
||||
# Check if our sentinel doesn't see other nodes
|
||||
if state["num-other-sentinels"] < self.min_other_sentinels:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def discover_master(self, service_name: str):
|
||||
"""
|
||||
Asks sentinel servers for the Redis master's address corresponding
|
||||
to the service labeled ``service_name``.
|
||||
|
||||
Returns a pair (address, port) or raises MasterNotFoundError if no
|
||||
master is found.
|
||||
"""
|
||||
collected_errors = list()
|
||||
for sentinel_no, sentinel in enumerate(self.sentinels):
|
||||
try:
|
||||
masters = await sentinel.sentinel_masters()
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
collected_errors.append(f"{sentinel} - {e!r}")
|
||||
continue
|
||||
state = masters.get(service_name)
|
||||
if state and self.check_master_state(state, service_name):
|
||||
# Put this sentinel at the top of the list
|
||||
self.sentinels[0], self.sentinels[sentinel_no] = (
|
||||
sentinel,
|
||||
self.sentinels[0],
|
||||
)
|
||||
|
||||
ip = (
|
||||
self._force_master_ip
|
||||
if self._force_master_ip is not None
|
||||
else state["ip"]
|
||||
)
|
||||
return ip, state["port"]
|
||||
|
||||
error_info = ""
|
||||
if len(collected_errors) > 0:
|
||||
error_info = f" : {', '.join(collected_errors)}"
|
||||
raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}")
|
||||
|
||||
def filter_slaves(
|
||||
self, slaves: Iterable[Mapping]
|
||||
) -> Sequence[Tuple[EncodableT, EncodableT]]:
|
||||
"""Remove slaves that are in an ODOWN or SDOWN state"""
|
||||
slaves_alive = []
|
||||
for slave in slaves:
|
||||
if slave["is_odown"] or slave["is_sdown"]:
|
||||
continue
|
||||
slaves_alive.append((slave["ip"], slave["port"]))
|
||||
return slaves_alive
|
||||
|
||||
async def discover_slaves(
|
||||
self, service_name: str
|
||||
) -> Sequence[Tuple[EncodableT, EncodableT]]:
|
||||
"""Returns a list of alive slaves for service ``service_name``"""
|
||||
for sentinel in self.sentinels:
|
||||
try:
|
||||
slaves = await sentinel.sentinel_slaves(service_name)
|
||||
except (ConnectionError, ResponseError, TimeoutError):
|
||||
continue
|
||||
slaves = self.filter_slaves(slaves)
|
||||
if slaves:
|
||||
return slaves
|
||||
return []
|
||||
|
||||
def master_for(
|
||||
self,
|
||||
service_name: str,
|
||||
redis_class: Type[Redis] = Redis,
|
||||
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a redis client instance for the ``service_name`` master.
|
||||
Sentinel client will detect failover and reconnect Redis clients
|
||||
automatically.
|
||||
|
||||
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
|
||||
used to retrieve the master's address before establishing a new
|
||||
connection.
|
||||
|
||||
NOTE: If the master's address has changed, any cached connections to
|
||||
the old master are closed.
|
||||
|
||||
By default clients will be a :py:class:`~redis.Redis` instance.
|
||||
Specify a different class to the ``redis_class`` argument if you
|
||||
desire something different.
|
||||
|
||||
The ``connection_pool_class`` specifies the connection pool to
|
||||
use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
|
||||
will be used by default.
|
||||
|
||||
All other keyword arguments are merged with any connection_kwargs
|
||||
passed to this class and passed to the connection pool as keyword
|
||||
arguments to be used to initialize Redis connections.
|
||||
"""
|
||||
kwargs["is_master"] = True
|
||||
connection_kwargs = dict(self.connection_kwargs)
|
||||
connection_kwargs.update(kwargs)
|
||||
|
||||
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
|
||||
# The Redis object "owns" the pool
|
||||
return redis_class.from_pool(connection_pool)
|
||||
|
||||
def slave_for(
|
||||
self,
|
||||
service_name: str,
|
||||
redis_class: Type[Redis] = Redis,
|
||||
connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns redis client instance for the ``service_name`` slave(s).
|
||||
|
||||
A SentinelConnectionPool class is used to retrieve the slave's
|
||||
address before establishing a new connection.
|
||||
|
||||
By default clients will be a :py:class:`~redis.Redis` instance.
|
||||
Specify a different class to the ``redis_class`` argument if you
|
||||
desire something different.
|
||||
|
||||
The ``connection_pool_class`` specifies the connection pool to use.
|
||||
The SentinelConnectionPool will be used by default.
|
||||
|
||||
All other keyword arguments are merged with any connection_kwargs
|
||||
passed to this class and passed to the connection pool as keyword
|
||||
arguments to be used to initialize Redis connections.
|
||||
"""
|
||||
kwargs["is_master"] = False
|
||||
connection_kwargs = dict(self.connection_kwargs)
|
||||
connection_kwargs.update(kwargs)
|
||||
|
||||
connection_pool = connection_pool_class(service_name, self, **connection_kwargs)
|
||||
# The Redis object "owns" the pool
|
||||
return redis_class.from_pool(connection_pool)
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio.client import Pipeline, Redis
|
||||
|
||||
|
||||
def from_url(url, **kwargs):
|
||||
"""
|
||||
Returns an active Redis client generated from the given database URL.
|
||||
|
||||
Will attempt to extract the database id from the path url fragment, if
|
||||
none is provided.
|
||||
"""
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
return Redis.from_url(url, **kwargs)
|
||||
|
||||
|
||||
class pipeline: # noqa: N801
|
||||
def __init__(self, redis_obj: "Redis"):
|
||||
self.p: "Pipeline" = redis_obj.pipeline()
|
||||
|
||||
async def __aenter__(self) -> "Pipeline":
|
||||
return self.p
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.p.execute()
|
||||
del self.p
|
||||
Reference in New Issue
Block a user