first commit
This commit is contained in:
@@ -0,0 +1,530 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
|
||||
|
||||
from redis.asyncio.client import PubSubHandler
|
||||
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
|
||||
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Databases
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
|
||||
from redis.background import BackgroundScheduler
|
||||
from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands
|
||||
from redis.multidb.circuit import CircuitBreaker
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException
|
||||
from redis.typing import ChannelT, EncodableT, KeyT
|
||||
from redis.utils import experimental
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@experimental
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Client that operates on multiple logical Redis databases.
    Should be used in Active-Active database setups.

    The client keeps a weighted list of databases, runs recurring health
    checks against all of them in the background, and routes commands to the
    currently active (healthy, highest-weighted) database via its command
    executor.
    """

    def __init__(self, config: MultiDbConfig):
        self._databases = config.databases()
        self._health_checks = config.default_health_checks()

        if config.health_checks is not None:
            self._health_checks.extend(config.health_checks)

        self._health_check_interval = config.health_check_interval
        self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value(
            config.health_check_probes, config.health_check_delay
        )
        self._failure_detectors = config.default_failure_detectors()

        if config.failure_detectors is not None:
            self._failure_detectors.extend(config.failure_detectors)

        self._failover_strategy = (
            config.default_failover_strategy()
            if config.failover_strategy is None
            else config.failover_strategy
        )
        self._failover_strategy.set_databases(self._databases)
        self._auto_fallback_interval = config.auto_fallback_interval
        self._event_dispatcher = config.event_dispatcher
        self._command_retry = config.command_retry
        # Connection refusals should be retried like other transient errors.
        self._command_retry.update_supported_errors([ConnectionRefusedError])
        self.command_executor = DefaultCommandExecutor(
            failure_detectors=self._failure_detectors,
            databases=self._databases,
            command_retry=self._command_retry,
            failover_strategy=self._failover_strategy,
            failover_attempts=config.failover_attempts,
            failover_delay=config.failover_delay,
            event_dispatcher=self._event_dispatcher,
            auto_fallback_interval=self._auto_fallback_interval,
        )
        self.initialized = False
        self._hc_lock = asyncio.Lock()
        self._bg_scheduler = BackgroundScheduler()
        self._config = config
        # Background task handles, cancelled on __aexit__.
        self._recurring_hc_task = None
        self._hc_tasks = []
        self._half_open_state_task = None

    async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
        if not self.initialized:
            await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Cancel all background tasks spawned by this client.
        if self._recurring_hc_task:
            self._recurring_hc_task.cancel()
        if self._half_open_state_task:
            self._half_open_state_task.cancel()
        for hc_task in self._hc_tasks:
            hc_task.cancel()

    async def initialize(self):
        """
        Perform initialization of databases to define their initial state.

        Raises:
            NoValidDatabaseException: If no database passes its initial
                health check.
        """

        async def raise_exception_on_failed_hc(error):
            raise error

        # Initial databases check to define initial state
        await self._check_databases_health(on_error=raise_exception_on_failed_hc)

        # Starts recurring health checks on the background.
        self._recurring_hc_task = asyncio.create_task(
            self._bg_scheduler.run_recurring_async(
                self._health_check_interval,
                self._check_databases_health,
            )
        )

        is_active_db_found = False

        for database, weight in self._databases:
            # Set on state changed callback for each circuit.
            database.circuit.on_state_changed(self._on_circuit_state_change_callback)

            # Databases are iterated by weight, so the first healthy one wins.
            if database.circuit.state == CBState.CLOSED and not is_active_db_found:
                await self.command_executor.set_active_database(database)
                is_active_db_found = True

        if not is_active_db_found:
            raise NoValidDatabaseException(
                "Initial connection failed - no active database found"
            )

        self.initialized = True

    def get_databases(self) -> Databases:
        """
        Returns a sorted (by weight) list of all databases.
        """
        return self._databases

    async def set_active_database(self, database: AsyncDatabase) -> None:
        """
        Promote one of the existing databases to become an active.

        Raises:
            ValueError: If the database is not a member of the database list.
            NoValidDatabaseException: If the database fails its health check.
        """
        exists = None

        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break

        if not exists:
            raise ValueError("Given database is not a member of database list")

        await self._check_db_health(database)

        if database.circuit.state == CBState.CLOSED:
            await self.command_executor.set_active_database(database)
            return

        raise NoValidDatabaseException(
            "Cannot set active database, database is unhealthy"
        )

    async def add_database(self, database: AsyncDatabase):
        """
        Adds a new database to the database list.

        Raises:
            ValueError: If the database is already a member of the list.
        """
        for existing_db, _ in self._databases:
            if existing_db == database:
                raise ValueError("Given database already exists")

        await self._check_db_health(database)

        # Capture the current top database before insertion so we can decide
        # whether the newcomer should become active.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.add(database, database.weight)
        await self._change_active_database(database, highest_weighted_db)

    async def _change_active_database(
        self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase
    ):
        # Promote the new database only if it outweighs the previous top one
        # and is currently healthy.
        if (
            new_database.weight > highest_weight_database.weight
            and new_database.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(new_database)

    async def remove_database(self, database: AsyncDatabase):
        """
        Removes a database from the database list.
        """
        weight = self._databases.remove(database)
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]

        # If the removed database was (one of) the heaviest, fail over to the
        # remaining top database, provided it is healthy.
        if (
            highest_weight <= weight
            and highest_weighted_db.circuit.state == CBState.CLOSED
        ):
            await self.command_executor.set_active_database(highest_weighted_db)

    async def update_database_weight(self, database: AsyncDatabase, weight: float):
        """
        Updates a database from the database list.

        Raises:
            ValueError: If the database is not a member of the database list.
        """
        exists = None

        for existing_db, _ in self._databases:
            if existing_db == database:
                exists = True
                break

        if not exists:
            raise ValueError("Given database is not a member of database list")

        # Capture the current top database before the update so we can decide
        # whether the re-weighted database should become active.
        highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
        self._databases.update_weight(database, weight)
        database.weight = weight
        await self._change_active_database(database, highest_weighted_db)

    def add_failure_detector(self, failure_detector: AsyncFailureDetector):
        """
        Adds a new failure detector to the database.
        """
        self._failure_detectors.append(failure_detector)

    async def add_health_check(self, healthcheck: HealthCheck):
        """
        Adds a new health check to the database.
        """
        # Guard against concurrent mutation while health checks are running.
        async with self._hc_lock:
            self._health_checks.append(healthcheck)

    async def execute_command(self, *args, **options):
        """
        Executes a single command and return its result.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_command(*args, **options)

    def pipeline(self):
        """
        Enters into pipeline mode of the client.
        """
        return Pipeline(self)

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Executes callable as transaction.
        """
        if not self.initialized:
            await self.initialize()

        return await self.command_executor.execute_transaction(
            func,
            *watches,
            shard_hint=shard_hint,
            value_from_callable=value_from_callable,
            watch_delay=watch_delay,
        )

    async def pubsub(self, **kwargs):
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        if not self.initialized:
            await self.initialize()

        return PubSub(self, **kwargs)

    async def _check_databases_health(
        self,
        on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
    ):
        """
        Runs health checks against all databases (also used as the recurring
        background task).

        Args:
            on_error: Optional coroutine function invoked with the original
                exception of each database whose health check failed.

        Raises:
            asyncio.TimeoutError: If the combined health checks take longer
                than ``health_check_interval``.
        """
        try:
            self._hc_tasks = [
                asyncio.create_task(self._check_db_health(database))
                for database, _ in self._databases
            ]
            results = await asyncio.wait_for(
                asyncio.gather(
                    *self._hc_tasks,
                    return_exceptions=True,
                ),
                timeout=self._health_check_interval,
            )
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(
                "Health check execution exceeds health_check_interval"
            )

        for result in results:
            if isinstance(result, UnhealthyDatabaseException):
                unhealthy_db = result.database
                unhealthy_db.circuit.state = CBState.OPEN

                logger.exception(
                    "Health check failed, due to exception",
                    exc_info=result.original_exception,
                )

                if on_error:
                    # Bug fix: on_error is a coroutine function (see the type
                    # hint and initialize()'s callback) and must be awaited;
                    # calling it without await silently dropped the failure.
                    await on_error(result.original_exception)

    async def _check_db_health(self, database: AsyncDatabase) -> bool:
        """
        Runs health checks on the given database until first failure.
        """
        # Health check will setup circuit state
        is_healthy = await self._health_check_policy.execute(
            self._health_checks, database
        )

        if not is_healthy:
            if database.circuit.state != CBState.OPEN:
                database.circuit.state = CBState.OPEN
            return is_healthy
        elif is_healthy and database.circuit.state != CBState.CLOSED:
            database.circuit.state = CBState.CLOSED

        return is_healthy

    def _on_circuit_state_change_callback(
        self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
    ):
        loop = asyncio.get_running_loop()

        # A half-open circuit triggers an immediate re-check of its database.
        if new_state == CBState.HALF_OPEN:
            self._half_open_state_task = asyncio.create_task(
                self._check_db_health(circuit.database)
            )
            return

        # A freshly opened circuit is scheduled to move to half-open after
        # the grace period, allowing the database to recover.
        if old_state == CBState.CLOSED and new_state == CBState.OPEN:
            loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

    async def aclose(self):
        if self.command_executor.active_database:
            await self.command_executor.active_database.client.aclose()
|
||||
|
||||
|
||||
def _half_open_circuit(circuit: CircuitBreaker):
    """Move the given circuit breaker into HALF_OPEN (scheduled after the grace period)."""
    circuit.state = CBState.HALF_OPEN
|
||||
|
||||
|
||||
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
    """
    Pipeline implementation for multiple logical Redis databases.

    Commands are buffered locally and shipped as one batch through the
    owning MultiDBClient's command executor when ``execute()`` is called.
    """

    def __init__(self, client: MultiDBClient):
        self._client = client
        self._command_stack = []

    async def __aenter__(self) -> "Pipeline":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.reset()
        await self._client.__aexit__(exc_type, exc_value, traceback)

    def __await__(self):
        return self._async_self().__await__()

    async def _async_self(self):
        return self

    def __len__(self) -> int:
        # Number of commands currently staged.
        return len(self._command_stack)

    def __bool__(self) -> bool:
        """Pipeline instances should always evaluate to True."""
        return True

    async def reset(self) -> None:
        # Discard any commands that were staged but never sent.
        self._command_stack = []

    async def aclose(self) -> None:
        """Close the pipeline by discarding all staged commands."""
        await self.reset()

    def pipeline_execute_command(self, *args, **options) -> "Pipeline":
        """
        Stage a command to run on the next ``execute()`` call.

        Returns this Pipeline object so commands can be chained, such as:

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

        At some other point, you can then run: pipe.execute(),
        which will execute all commands queued in the pipe.
        """
        self._command_stack.append((args, options))
        return self

    def execute_command(self, *args, **kwargs):
        """Adds a command to the stack"""
        return self.pipeline_execute_command(*args, **kwargs)

    async def execute(self) -> List[Any]:
        """Execute all the commands in the current pipeline"""
        client = self._client

        if not client.initialized:
            await client.initialize()

        try:
            staged = tuple(self._command_stack)
            return await client.command_executor.execute_pipeline(staged)
        finally:
            await self.reset()
|
||||
|
||||
|
||||
class PubSub:
    """
    PubSub object for multi database client.

    Every operation is proxied through the client's command executor so it
    always targets the pub/sub object of the currently active database.
    """

    def __init__(self, client: MultiDBClient, **kwargs):
        """Initialize the PubSub object for a multi-database client.

        Args:
            client: MultiDBClient instance to use for pub/sub operations
            **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
        """

        self._client = client
        self._client.command_executor.pubsub(**kwargs)

    async def __aenter__(self) -> "PubSub":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.aclose()

    async def _dispatch(self, method_name: str, *args, **kwargs):
        # Single funnel for all proxied pub/sub calls.
        executor = self._client.command_executor
        return await executor.execute_pubsub_method(method_name, *args, **kwargs)

    async def aclose(self):
        return await self._dispatch("aclose")

    @property
    def subscribed(self) -> bool:
        return self._client.command_executor.active_pubsub.subscribed

    async def execute_command(self, *args: EncodableT):
        return await self._dispatch("execute_command", *args)

    async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
        """
        Subscribe to channel patterns. Patterns supplied as keyword arguments
        expect a pattern name as the key and a callable as the value. A
        pattern's callable will be invoked automatically when a message is
        received on that pattern rather than producing a message via
        ``listen()``.
        """
        return await self._dispatch("psubscribe", *args, **kwargs)

    async def punsubscribe(self, *args: ChannelT):
        """
        Unsubscribe from the supplied patterns. If empty, unsubscribe from
        all patterns.
        """
        return await self._dispatch("punsubscribe", *args)

    async def subscribe(self, *args: ChannelT, **kwargs: Callable):
        """
        Subscribe to channels. Channels supplied as keyword arguments expect
        a channel name as the key and a callable as the value. A channel's
        callable will be invoked automatically when a message is received on
        that channel rather than producing a message via ``listen()`` or
        ``get_message()``.
        """
        return await self._dispatch("subscribe", *args, **kwargs)

    async def unsubscribe(self, *args):
        """
        Unsubscribe from the supplied channels. If empty, unsubscribe from
        all channels
        """
        return await self._dispatch("unsubscribe", *args)

    async def get_message(
        self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
    ):
        """
        Get the next message if one is available, otherwise None.

        If timeout is specified, the system will wait for `timeout` seconds
        before returning. Timeout should be specified as a floating point
        number or None to wait indefinitely.
        """
        return await self._dispatch(
            "get_message",
            ignore_subscribe_messages=ignore_subscribe_messages,
            timeout=timeout,
        )

    async def run(
        self,
        *,
        exception_handler=None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        return await self._client.command_executor.execute_pubsub_run(
            sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
        )
|
||||
@@ -0,0 +1,339 @@
|
||||
from abc import abstractmethod
|
||||
from asyncio import iscoroutinefunction
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Union
|
||||
|
||||
from redis.asyncio import RedisCluster
|
||||
from redis.asyncio.client import Pipeline, PubSub
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases
|
||||
from redis.asyncio.multidb.event import (
|
||||
AsyncActiveDatabaseChanged,
|
||||
CloseConnectionOnActiveDatabaseChanged,
|
||||
RegisterCommandFailure,
|
||||
ResubscribeOnActiveDatabaseChanged,
|
||||
)
|
||||
from redis.asyncio.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
AsyncFailoverStrategy,
|
||||
DefaultFailoverStrategyExecutor,
|
||||
FailoverStrategyExecutor,
|
||||
)
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.asyncio.retry import Retry
|
||||
from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor
|
||||
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
|
||||
from redis.typing import KeyT
|
||||
|
||||
|
||||
class AsyncCommandExecutor(CommandExecutor):
    """
    Abstract interface for executing commands, pipelines, transactions and
    pub/sub operations against the currently active database of a
    multi-database setup.
    """

    @property
    @abstractmethod
    def databases(self) -> Databases:
        """Returns a list of databases."""
        pass

    @property
    @abstractmethod
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        """Returns a list of failure detectors."""
        pass

    @abstractmethod
    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        """Adds a new failure detector to the list of failure detectors."""
        pass

    @property
    @abstractmethod
    def active_database(self) -> Optional[AsyncDatabase]:
        """Returns currently active database."""
        pass

    @abstractmethod
    async def set_active_database(self, database: AsyncDatabase) -> None:
        """Sets the currently active database."""
        pass

    @property
    @abstractmethod
    def active_pubsub(self) -> Optional[PubSub]:
        """Returns currently active pubsub."""
        pass

    @active_pubsub.setter
    @abstractmethod
    def active_pubsub(self, pubsub: PubSub) -> None:
        """Sets currently active pubsub."""
        pass

    @property
    @abstractmethod
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        """Returns failover strategy executor."""
        pass

    @property
    @abstractmethod
    def command_retry(self) -> Retry:
        """Returns command retry object."""
        pass

    @abstractmethod
    def pubsub(self, **kwargs):
        """Initializes a PubSub object on a currently active database.

        Declared synchronous (consistency fix): the concrete
        DefaultCommandExecutor implements this as a plain method, and call
        sites invoke it without awaiting.
        """
        pass

    @abstractmethod
    async def execute_command(self, *args, **options):
        """Executes a command and returns the result."""
        pass

    @abstractmethod
    async def execute_pipeline(self, command_stack: tuple):
        """Executes a stack of commands in pipeline."""
        pass

    @abstractmethod
    async def execute_transaction(
        self, transaction: Callable[[Pipeline], None], *watches, **options
    ):
        """Executes a transaction block wrapped in callback."""
        pass

    @abstractmethod
    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Executes a given method on active pub/sub."""
        pass

    @abstractmethod
    async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
        """Runs the active pub/sub message loop as a coroutine."""
        pass
|
||||
|
||||
|
||||
class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
    """
    Default command executor: routes commands to the active database, retries
    on failure, reports failures to the registered failure detectors, and
    fails over (or falls back) when the active database becomes unavailable.
    """

    def __init__(
        self,
        failure_detectors: List[AsyncFailureDetector],
        databases: Databases,
        command_retry: Retry,
        failover_strategy: AsyncFailoverStrategy,
        event_dispatcher: EventDispatcherInterface,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
        auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
    ):
        """
        Initialize the DefaultCommandExecutor instance.

        Args:
            failure_detectors: List of failure detector instances to monitor database health
            databases: Collection of available databases to execute commands on
            command_retry: Retry policy for failed command execution
            failover_strategy: Strategy for handling database failover
            event_dispatcher: Interface for dispatching events
            failover_attempts: Number of failover attempts
            failover_delay: Delay between failover attempts
            auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
        """
        super().__init__(auto_fallback_interval)

        # Failure detectors report back through this executor.
        for fd in failure_detectors:
            fd.set_command_executor(command_executor=self)

        self._databases = databases
        self._failure_detectors = failure_detectors
        self._command_retry = command_retry
        self._failover_strategy_executor = DefaultFailoverStrategyExecutor(
            failover_strategy, failover_attempts, failover_delay
        )
        self._event_dispatcher = event_dispatcher
        self._active_database: Optional[Database] = None
        self._active_pubsub: Optional[PubSub] = None
        # Remembered so the pubsub can be re-created on the new database
        # after an active-database change.
        self._active_pubsub_kwargs = {}
        self._setup_event_dispatcher()
        self._schedule_next_fallback()

    @property
    def databases(self) -> Databases:
        return self._databases

    @property
    def failure_detectors(self) -> List[AsyncFailureDetector]:
        return self._failure_detectors

    def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
        self._failure_detectors.append(failure_detector)

    @property
    def active_database(self) -> Optional[AsyncDatabase]:
        return self._active_database

    async def set_active_database(self, database: AsyncDatabase) -> None:
        old_active = self._active_database
        self._active_database = database

        # Notify listeners only on a genuine change of database.
        if old_active is not None and old_active is not database:
            await self._event_dispatcher.dispatch_async(
                AsyncActiveDatabaseChanged(
                    old_active,
                    self._active_database,
                    self,
                    **self._active_pubsub_kwargs,
                )
            )

    @property
    def active_pubsub(self) -> Optional[PubSub]:
        return self._active_pubsub

    @active_pubsub.setter
    def active_pubsub(self, pubsub: PubSub) -> None:
        self._active_pubsub = pubsub

    @property
    def failover_strategy_executor(self) -> FailoverStrategyExecutor:
        return self._failover_strategy_executor

    @property
    def command_retry(self) -> Retry:
        return self._command_retry

    def pubsub(self, **kwargs):
        """Initializes a PubSub object on the currently active database (idempotent)."""
        if self._active_pubsub is None:
            if isinstance(self._active_database.client, RedisCluster):
                raise ValueError("PubSub is not supported for RedisCluster")

            self._active_pubsub = self._active_database.client.pubsub(**kwargs)
            self._active_pubsub_kwargs = kwargs

    async def execute_command(self, *args, **options):
        """Executes a single command on the active database with failure detection."""

        async def callback():
            response = await self._active_database.client.execute_command(
                *args, **options
            )
            await self._register_command_execution(args)
            return response

        return await self._execute_with_failure_detection(callback, args)

    async def execute_pipeline(self, command_stack: tuple):
        """Executes a stack of commands as a pipeline on the active database."""

        async def callback():
            async with self._active_database.client.pipeline() as pipe:
                for command, options in command_stack:
                    pipe.execute_command(*command, **options)

                response = await pipe.execute()

            await self._register_command_execution(command_stack)
            return response

        return await self._execute_with_failure_detection(callback, command_stack)

    async def execute_transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """Executes a transaction callback on the active database."""

        async def callback():
            response = await self._active_database.client.transaction(
                func,
                *watches,
                shard_hint=shard_hint,
                value_from_callable=value_from_callable,
                watch_delay=watch_delay,
            )
            await self._register_command_execution(())
            return response

        return await self._execute_with_failure_detection(callback)

    async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
        """Executes a given method on the active pub/sub object."""

        async def callback():
            method = getattr(self.active_pubsub, method_name)
            if iscoroutinefunction(method):
                response = await method(*args, **kwargs)
            else:
                response = method(*args, **kwargs)

            await self._register_command_execution(args)
            return response

        # Bug fix: pass the args tuple itself, not unpacked. The previous
        # `*args` call raised TypeError whenever more than one channel/arg
        # was supplied, since _execute_with_failure_detection takes a single
        # `cmds` tuple parameter.
        return await self._execute_with_failure_detection(callback, args)

    async def execute_pubsub_run(
        self, sleep_time: float, exception_handler=None, pubsub=None
    ) -> Any:
        """Runs the active pub/sub message loop with failure detection."""

        async def callback():
            return await self._active_pubsub.run(
                poll_timeout=sleep_time,
                exception_handler=exception_handler,
                pubsub=pubsub,
            )

        return await self._execute_with_failure_detection(callback)

    async def _execute_with_failure_detection(
        self, callback: Callable, cmds: tuple = ()
    ):
        """
        Execute a commands execution callback with failure detection.
        """

        async def wrapper():
            # On each retry we need to check active database as it might change.
            await self._check_active_database()
            return await callback()

        return await self._command_retry.call_with_retry(
            lambda: wrapper(),
            lambda error: self._on_command_fail(error, *cmds),
        )

    async def _check_active_database(self):
        """
        Checks if active a database needs to be updated.
        """
        # Re-select a database when there is none, the current one is
        # unhealthy, or an auto-fallback attempt is due.
        if (
            self._active_database is None
            or self._active_database.circuit.state != CBState.CLOSED
            or (
                self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
                and self._next_fallback_attempt <= datetime.now()
            )
        ):
            await self.set_active_database(
                await self._failover_strategy_executor.execute()
            )
            self._schedule_next_fallback()

    async def _on_command_fail(self, error, *args):
        await self._event_dispatcher.dispatch_async(
            AsyncOnCommandsFailEvent(args, error)
        )

    async def _register_command_execution(self, cmd: tuple):
        # Successful executions feed the failure detectors' statistics.
        for detector in self._failure_detectors:
            await detector.register_command_execution(cmd)

    def _setup_event_dispatcher(self):
        """
        Registers necessary listeners.
        """
        failure_listener = RegisterCommandFailure(self._failure_detectors)
        resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
        close_connection_listener = CloseConnectionOnActiveDatabaseChanged()
        self._event_dispatcher.register_listeners(
            {
                AsyncOnCommandsFailEvent: [failure_listener],
                AsyncActiveDatabaseChanged: [
                    close_connection_listener,
                    resubscribe_listener,
                ],
            }
        )
|
||||
@@ -0,0 +1,210 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import pybreaker
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis, RedisCluster
|
||||
from redis.asyncio.multidb.database import Database, Databases
|
||||
from redis.asyncio.multidb.failover import (
|
||||
DEFAULT_FAILOVER_ATTEMPTS,
|
||||
DEFAULT_FAILOVER_DELAY,
|
||||
AsyncFailoverStrategy,
|
||||
WeightBasedFailoverStrategy,
|
||||
)
|
||||
from redis.asyncio.multidb.failure_detector import (
|
||||
AsyncFailureDetector,
|
||||
FailureDetectorAsyncWrapper,
|
||||
)
|
||||
from redis.asyncio.multidb.healthcheck import (
|
||||
DEFAULT_HEALTH_CHECK_DELAY,
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL,
|
||||
DEFAULT_HEALTH_CHECK_POLICY,
|
||||
DEFAULT_HEALTH_CHECK_PROBES,
|
||||
HealthCheck,
|
||||
HealthCheckPolicies,
|
||||
PingHealthCheck,
|
||||
)
|
||||
from redis.asyncio.retry import Retry
|
||||
from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.event import EventDispatcher, EventDispatcherInterface
|
||||
from redis.multidb.circuit import (
|
||||
DEFAULT_GRACE_PERIOD,
|
||||
CircuitBreaker,
|
||||
PBCircuitBreakerAdapter,
|
||||
)
|
||||
from redis.multidb.failure_detector import (
|
||||
DEFAULT_FAILURE_RATE_THRESHOLD,
|
||||
DEFAULT_FAILURES_DETECTION_WINDOW,
|
||||
DEFAULT_MIN_NUM_FAILURES,
|
||||
CommandFailureDetector,
|
||||
)
|
||||
|
||||
# Default interval between automatic fallback attempts to a higher-weighted
# database (presumably seconds — TODO confirm against the scheduler that consumes it).
DEFAULT_AUTO_FALLBACK_INTERVAL = 120
|
||||
|
||||
|
||||
def default_event_dispatcher() -> EventDispatcherInterface:
    """Return the dispatcher used when a config does not supply its own."""
    dispatcher = EventDispatcher()
    return dispatcher
|
||||
|
||||
|
||||
@dataclass
class DatabaseConfig:
    """
    Configuration for a single database connection.

    Holds client options, connection sourcing details (URL or pre-built pool),
    circuit breaker settings, and an optional health-check endpoint, so that
    customized configurations can be built for various database setups.

    Attributes:
        weight (float): Weight of the database to define the active one.
        client_kwargs (dict): Additional parameters for the database client connection.
        from_url (Optional[str]): Redis URL way of connecting to the database.
        from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use.
        circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation.
        grace_period (float): Grace period after which we need to check if the circuit could be closed again.
        health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used
            on public Redis Enterprise endpoints.
    """

    weight: float = 1.0
    client_kwargs: dict = field(default_factory=dict)
    from_url: Optional[str] = None
    from_pool: Optional[ConnectionPool] = None
    circuit: Optional[CircuitBreaker] = None
    grace_period: float = DEFAULT_GRACE_PERIOD
    health_check_url: Optional[str] = None

    def default_circuit_breaker(self) -> CircuitBreaker:
        """Build a pybreaker-backed circuit breaker honoring ``grace_period``."""
        return PBCircuitBreakerAdapter(
            pybreaker.CircuitBreaker(reset_timeout=self.grace_period)
        )
|
||||
|
||||
|
||||
@dataclass
class MultiDbConfig:
    """
    Configuration class for managing multiple database connections in a resilient and fail-safe manner.

    Attributes:
        databases_config: A list of database configurations.
        client_class: The client class used to manage database connections.
        command_retry: Retry strategy for executing database commands.
        failure_detectors: Optional list of additional failure detectors for monitoring database failures.
        min_num_failures: Minimal count of failures required for failover
        failure_rate_threshold: Percentage of failures required for failover
        failures_detection_window: Time interval for tracking database failures.
        health_checks: Optional list of additional health checks performed on databases.
        health_check_interval: Time interval for executing health checks.
        health_check_probes: Number of attempts to evaluate the health of a database.
        health_check_delay: Delay between health check attempts.
        health_check_policy: Policy deciding how probe results combine into a verdict.
        failover_strategy: Optional strategy for handling database failover scenarios.
        failover_attempts: Number of retries allowed for failover operations.
        failover_delay: Delay between failover attempts.
        auto_fallback_interval: Time interval to trigger automatic fallback.
        event_dispatcher: Interface for dispatching events related to database operations.
    """

    databases_config: List[DatabaseConfig]
    client_class: Type[Union[Redis, RedisCluster]] = Redis
    command_retry: Retry = Retry(
        backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
    )
    failure_detectors: Optional[List[AsyncFailureDetector]] = None
    min_num_failures: int = DEFAULT_MIN_NUM_FAILURES
    failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD
    failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW
    health_checks: Optional[List[HealthCheck]] = None
    health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
    health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES
    health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY
    health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY
    failover_strategy: Optional[AsyncFailoverStrategy] = None
    failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS
    failover_delay: float = DEFAULT_FAILOVER_DELAY
    auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL
    event_dispatcher: EventDispatcherInterface = field(
        default_factory=default_event_dispatcher
    )

    def databases(self) -> Databases:
        """
        Build the weighted collection of Database instances from `databases_config`.

        Low-level client retries are disabled (retries=0) because `command_retry`
        provides the global retry logic.
        """
        databases = WeightedList()

        for database_config in self.databases_config:
            # The retry object is not used in the lower level clients, so we can safely remove it.
            # We rely on command_retry in terms of global retries.
            # NOTE: build a copy instead of mutating database_config.client_kwargs
            # in place, so the caller-owned DatabaseConfig stays untouched and
            # repeated databases() calls are side-effect free.
            client_kwargs = {
                **database_config.client_kwargs,
                "retry": Retry(retries=0, backoff=NoBackoff()),
            }

            if database_config.from_url:
                client = self.client_class.from_url(
                    database_config.from_url, **client_kwargs
                )
            elif database_config.from_pool:
                database_config.from_pool.set_retry(
                    Retry(retries=0, backoff=NoBackoff())
                )
                client = self.client_class.from_pool(
                    connection_pool=database_config.from_pool
                )
            else:
                client = self.client_class(**client_kwargs)

            circuit = (
                database_config.default_circuit_breaker()
                if database_config.circuit is None
                else database_config.circuit
            )
            databases.add(
                Database(
                    client=client,
                    circuit=circuit,
                    weight=database_config.weight,
                    health_check_url=database_config.health_check_url,
                ),
                database_config.weight,
            )

        return databases

    def default_failure_detectors(self) -> List[AsyncFailureDetector]:
        """Default failure detectors used to monitor database failures."""
        return [
            FailureDetectorAsyncWrapper(
                CommandFailureDetector(
                    min_num_failures=self.min_num_failures,
                    failure_rate_threshold=self.failure_rate_threshold,
                    failure_detection_window=self.failures_detection_window,
                )
            ),
        ]

    def default_health_checks(self) -> List[HealthCheck]:
        """Default health checks used to monitor database health."""
        return [
            PingHealthCheck(),
        ]

    def default_failover_strategy(self) -> AsyncFailoverStrategy:
        """Default strategy used for handling failover scenarios."""
        return WeightBasedFailoverStrategy()
|
||||
@@ -0,0 +1,69 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import CircuitBreaker
|
||||
from redis.multidb.database import AbstractDatabase, BaseDatabase
|
||||
from redis.typing import Number
|
||||
|
||||
|
||||
class AsyncDatabase(AbstractDatabase):
    """Database with an underlying asynchronous redis client.

    Abstract interface: concrete implementations expose the wrapped client
    and the circuit breaker guarding it.
    """

    @property
    @abstractmethod
    def client(self) -> Union[Redis, RedisCluster]:
        """The underlying redis client."""
        pass

    @client.setter
    @abstractmethod
    def client(self, client: Union[Redis, RedisCluster]):
        """Set the underlying redis client."""
        pass

    @property
    @abstractmethod
    def circuit(self) -> CircuitBreaker:
        """Circuit breaker for the current database."""
        pass

    @circuit.setter
    @abstractmethod
    def circuit(self, circuit: CircuitBreaker):
        """Set the circuit breaker for the current database."""
        pass
||||
|
||||
|
||||
# Weighted collection of databases: each entry pairs an AsyncDatabase with its weight.
Databases = WeightedList[tuple[AsyncDatabase, Number]]
|
||||
|
||||
|
||||
class Database(BaseDatabase, AsyncDatabase):
    """Concrete AsyncDatabase backed by a standalone or cluster async redis client."""

    def __init__(
        self,
        client: Union[Redis, RedisCluster],
        circuit: CircuitBreaker,
        weight: float,
        health_check_url: Optional[str] = None,
    ):
        """
        Args:
            client: Underlying asynchronous redis client.
            circuit: Circuit breaker guarding this database.
            weight: Weight used when selecting the active database.
            health_check_url: Optional endpoint used by URL-based health checks.
        """
        self._client = client
        self._cb = circuit
        # Back-reference so the circuit breaker can identify the database it guards.
        self._cb.database = self
        super().__init__(weight, health_check_url)

    @property
    def client(self) -> Union[Redis, RedisCluster]:
        return self._client

    @client.setter
    def client(self, client: Union[Redis, RedisCluster]):
        self._client = client

    @property
    def circuit(self) -> CircuitBreaker:
        return self._cb

    @circuit.setter
    def circuit(self, circuit: CircuitBreaker):
        self._cb = circuit
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import List
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.multidb.database import AsyncDatabase
|
||||
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
|
||||
from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent
|
||||
|
||||
|
||||
class AsyncActiveDatabaseChanged:
    """
    Event fired when an async active database has been changed.

    Carries the previous and the newly promoted database, the command
    executor that performed the switch, and any extra keyword arguments
    supplied by the emitter.
    """

    def __init__(
        self,
        old_database: AsyncDatabase,
        new_database: AsyncDatabase,
        command_executor,
        **kwargs,
    ):
        self._old_database = old_database
        self._new_database = new_database
        self._command_executor = command_executor
        self._kwargs = kwargs

    @property
    def old_database(self) -> AsyncDatabase:
        """Database that was active before the switch."""
        return self._old_database

    @property
    def new_database(self) -> AsyncDatabase:
        """Database that is active after the switch."""
        return self._new_database

    @property
    def command_executor(self):
        """Executor that triggered the change."""
        return self._command_executor

    @property
    def kwargs(self):
        """Extra keyword arguments supplied when the event was emitted."""
        return self._kwargs
|
||||
|
||||
|
||||
class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Re-subscribe the currently active pub / sub to a new active database.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        current_pubsub = event.command_executor.active_pubsub
        if current_pubsub is None:
            return

        # Re-assign old channels and patterns so they will be automatically subscribed on connection.
        replacement = event.new_database.client.pubsub(**event.kwargs)
        replacement.channels = current_pubsub.channels
        replacement.patterns = current_pubsub.patterns
        await replacement.on_connect(None)
        # Swap in the new pubsub before closing the old one.
        event.command_executor.active_pubsub = replacement
        await current_pubsub.aclose()
|
||||
|
||||
|
||||
class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface):
    """
    Close connection to the old active database.
    """

    async def listen(self, event: AsyncActiveDatabaseChanged):
        old_client = event.old_database.client
        await old_client.aclose()

        # Standalone clients own a connection pool that must be drained too.
        if isinstance(old_client, Redis):
            pool = old_client.connection_pool
            await pool.update_active_connections_for_reconnect()
            await pool.disconnect()
|
||||
|
||||
|
||||
class RegisterCommandFailure(AsyncEventListenerInterface):
    """
    Event listener that registers command failures and passing it to the failure detectors.
    """

    def __init__(self, failure_detectors: List[AsyncFailureDetector]):
        self._detectors = failure_detectors

    async def listen(self, event: AsyncOnCommandsFailEvent) -> None:
        # Fan the failure out to every configured detector.
        for detector in self._detectors:
            await detector.register_failure(event.exception, event.commands)
|
||||
@@ -0,0 +1,125 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.asyncio.multidb.database import AsyncDatabase, Databases
|
||||
from redis.data_structure import WeightedList
|
||||
from redis.multidb.circuit import State as CBState
|
||||
from redis.multidb.exception import (
|
||||
NoValidDatabaseException,
|
||||
TemporaryUnavailableException,
|
||||
)
|
||||
|
||||
# Default number of failover attempts before the executor gives up and
# re-raises NoValidDatabaseException.
DEFAULT_FAILOVER_ATTEMPTS = 10
# Default delay between failover attempts (epoch-seconds arithmetic via time.time()).
DEFAULT_FAILOVER_DELAY = 12
|
||||
|
||||
|
||||
class AsyncFailoverStrategy(ABC):
    """Strategy for choosing which database should become active."""

    @abstractmethod
    async def database(self) -> AsyncDatabase:
        """Select the database according to the strategy."""
        pass

    @abstractmethod
    def set_databases(self, databases: Databases) -> None:
        """Set the database strategy operates on."""
        pass
|
||||
|
||||
|
||||
class FailoverStrategyExecutor(ABC):
    """Executes a failover strategy, adding retry/attempt bookkeeping on top of it."""

    @property
    @abstractmethod
    def failover_attempts(self) -> int:
        """The number of failover attempts."""
        pass

    @property
    @abstractmethod
    def failover_delay(self) -> float:
        """The delay between failover attempts."""
        pass

    @property
    @abstractmethod
    def strategy(self) -> AsyncFailoverStrategy:
        """The strategy to execute."""
        pass

    @abstractmethod
    async def execute(self) -> AsyncDatabase:
        """Execute the failover strategy."""
        pass
|
||||
|
||||
|
||||
class WeightBasedFailoverStrategy(AsyncFailoverStrategy):
    """
    Failover strategy based on database weights.
    """

    def __init__(self):
        self._databases = WeightedList()

    async def database(self) -> AsyncDatabase:
        # Databases iterate in weight order; pick the first one whose
        # circuit breaker is closed (i.e. considered healthy).
        selected = next(
            (
                candidate
                for candidate, _ in self._databases
                if candidate.circuit.state == CBState.CLOSED
            ),
            None,
        )
        if selected is None:
            raise NoValidDatabaseException("No valid database available for communication")
        return selected

    def set_databases(self, databases: Databases) -> None:
        self._databases = databases
|
||||
|
||||
|
||||
class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor):
    """
    Executes given failover strategy.

    While no valid database is available, callers receive
    TemporaryUnavailableException ("retry later") and at most one attempt is
    counted per elapsed `failover_delay` window; once the counter exceeds
    `failover_attempts`, the underlying NoValidDatabaseException is re-raised.
    """

    def __init__(
        self,
        strategy: AsyncFailoverStrategy,
        failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS,
        failover_delay: float = DEFAULT_FAILOVER_DELAY,
    ):
        self._strategy = strategy
        self._failover_attempts = failover_attempts
        self._failover_delay = failover_delay
        # Epoch timestamp after which the next attempt may be counted;
        # 0 means "no outage in progress". (float: assigned time.time() values)
        self._next_attempt_ts: float = 0
        # Attempts counted during the current outage.
        self._failover_counter: int = 0

    @property
    def failover_attempts(self) -> int:
        return self._failover_attempts

    @property
    def failover_delay(self) -> float:
        return self._failover_delay

    @property
    def strategy(self) -> AsyncFailoverStrategy:
        return self._strategy

    async def execute(self) -> AsyncDatabase:
        try:
            database = await self._strategy.database()
            # Success ends the outage: clear the bookkeeping.
            self._reset()
            return database
        except NoValidDatabaseException as e:
            # First failure of a new outage: open the first delay window.
            if self._next_attempt_ts == 0:
                self._next_attempt_ts = time.time() + self._failover_delay
                self._failover_counter += 1
            # A full delay window has elapsed: count one more attempt.
            elif time.time() >= self._next_attempt_ts:
                self._next_attempt_ts += self._failover_delay
                self._failover_counter += 1

            if self._failover_counter > self._failover_attempts:
                self._reset()
                raise e
            else:
                raise TemporaryUnavailableException(
                    "No database connections currently available. "
                    "This is a temporary condition - please retry the operation."
                )

    def _reset(self) -> None:
        # Return to the "no outage in progress" state.
        self._next_attempt_ts = 0
        self._failover_counter = 0
|
||||
@@ -0,0 +1,38 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.multidb.failure_detector import FailureDetector
|
||||
|
||||
|
||||
class AsyncFailureDetector(ABC):
    """Async interface for tracking command executions and failures."""

    @abstractmethod
    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        """Register a failure that occurred during command execution."""
        pass

    @abstractmethod
    async def register_command_execution(self, cmd: tuple) -> None:
        """Register a command execution."""
        pass

    @abstractmethod
    def set_command_executor(self, command_executor) -> None:
        """Set the command executor for this failure detector."""
        pass
|
||||
|
||||
|
||||
class FailureDetectorAsyncWrapper(AsyncFailureDetector):
    """
    Adapts a synchronous FailureDetector to the async detector interface.

    Each async method simply delegates to the wrapped synchronous detector.
    """

    def __init__(self, failure_detector: FailureDetector) -> None:
        self._wrapped = failure_detector

    async def register_failure(self, exception: Exception, cmd: tuple) -> None:
        self._wrapped.register_failure(exception, cmd)

    async def register_command_execution(self, cmd: tuple) -> None:
        self._wrapped.register_command_execution(cmd)

    def set_command_executor(self, command_executor) -> None:
        self._wrapped.set_command_executor(command_executor)
|
||||
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
|
||||
from redis.backoff import NoBackoff
|
||||
from redis.http.http_client import HttpClient
|
||||
from redis.multidb.exception import UnhealthyDatabaseException
|
||||
from redis.retry import Retry
|
||||
|
||||
# Default number of probes executed per health check evaluation.
DEFAULT_HEALTH_CHECK_PROBES = 3
# Default interval between health check runs (presumably seconds — TODO confirm
# against the scheduler that consumes it).
DEFAULT_HEALTH_CHECK_INTERVAL = 5
# Default delay in seconds between consecutive probes (passed to asyncio.sleep).
DEFAULT_HEALTH_CHECK_DELAY = 0.5
# Default tolerated lag in milliseconds for LagAwareHealthCheck.
DEFAULT_LAG_AWARE_TOLERANCE = 5000

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthCheck(ABC):
    """A single kind of health probe that can be run against a database."""

    @abstractmethod
    async def check_health(self, database) -> bool:
        """Function to determine the health status."""
        pass
|
||||
|
||||
|
||||
class HealthCheckPolicy(ABC):
    """
    Health checks execution policy.

    Defines how multiple probes of multiple health checks combine into a
    single healthy/unhealthy verdict for a database.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Number of probes to execute health checks."""
        pass

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Delay between health check probes."""
        pass

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Execute health checks and return database health status."""
        pass
|
||||
|
||||
|
||||
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """Shared probe-count/delay bookkeeping for concrete policies."""

    def __init__(self, health_check_probes: int, health_check_delay: float):
        # At least one probe is required to produce any verdict.
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")
        self._health_check_probes = health_check_probes
        self._health_check_delay = health_check_delay

    @property
    def health_check_probes(self) -> int:
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        return self._health_check_delay

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        pass
|
||||
|
||||
|
||||
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if all health check probes are successful.

    Any probe returning False makes the database unhealthy; any probe raising
    wraps the error in UnhealthyDatabaseException.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        last_probe = self.health_check_probes - 1
        for check in health_checks:
            for probe in range(self.health_check_probes):
                try:
                    healthy = await check.check_health(database)
                except Exception as e:
                    raise UnhealthyDatabaseException("Unhealthy database", database, e)
                if not healthy:
                    return False
                # Pause between probes, but not after the final one.
                if probe < last_probe:
                    await asyncio.sleep(self._health_check_delay)
        return True
|
||||
|
||||
|
||||
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if a majority of health check probes are successful.

    Tracks a per-check budget of tolerated failures; once the budget is
    exhausted the database is unhealthy (returning False for failed probes,
    or raising UnhealthyDatabaseException when the decisive probe raised).
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        probes = self.health_check_probes
        for check in health_checks:
            # (probes + 1) // 2 equals probes/2 for even counts and
            # (probes + 1)/2 for odd ones — same thresholds as branching
            # on parity, without float arithmetic.
            failures_left = (probes + 1) // 2

            for probe in range(probes):
                try:
                    healthy = await check.check_health(database)
                except Exception as e:
                    failures_left -= 1
                    if failures_left <= 0:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, e
                        )
                else:
                    if not healthy:
                        failures_left -= 1
                        if failures_left <= 0:
                            return False

                # Pause between probes, but not after the final one.
                if probe < probes - 1:
                    await asyncio.sleep(self._health_check_delay)
        return True
|
||||
|
||||
|
||||
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.

    Each health check needs at least one successful probe. If a check never
    succeeds: returns False when its probes merely returned False, or raises
    UnhealthyDatabaseException when one of them raised.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        is_healthy = False

        for health_check in health_checks:
            # Last wrapped error seen for this check, if any probe raised.
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if await health_check.check_health(database):
                        # One success satisfies this check; stop probing it.
                        is_healthy = True
                        break
                    else:
                        is_healthy = False
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                # Pause between probes, but not after the final one.
                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

            # Check exhausted all probes without success: report or raise.
            if not is_healthy and not exception:
                return is_healthy
            elif not is_healthy and exception:
                raise exception

        return is_healthy
|
||||
|
||||
|
||||
class HealthCheckPolicies(Enum):
    """Available health check policies; each value is the policy class itself."""

    HEALTHY_ALL = HealthyAllPolicy
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    HEALTHY_ANY = HealthyAnyPolicy
|
||||
|
||||
|
||||
# By default every probe of every check must pass for a database to be healthy.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
|
||||
|
||||
|
||||
class PingHealthCheck(HealthCheck):
    """
    Health check based on PING command.

    A standalone client is healthy if its PING succeeds; a cluster client is
    healthy only if every known node answers PING.
    """

    async def check_health(self, database) -> bool:
        if isinstance(database.client, Redis):
            # Coerce the raw PING reply so the declared `-> bool` holds;
            # truthiness is unchanged for callers.
            return bool(await database.client.execute_command("PING"))

        # For a cluster checks if all nodes are healthy.
        for node in database.client.get_nodes():
            # assumes every node has a live redis_connection — TODO confirm
            if not await node.redis_connection.execute_command("PING"):
                return False

        return True
|
||||
|
||||
|
||||
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in MS
                (default: DEFAULT_LAG_AWARE_TOLERANCE)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    @staticmethod
    def _bdb_matches_host(bdb, db_host) -> bool:
        """True if any endpoint of *bdb* matches *db_host* by DNS name or address."""
        for endpoint in bdb["endpoints"]:
            if endpoint["dns_name"] == db_host:
                return True
            # In case if the host was set as public IP
            if db_host in endpoint["addr"]:
                return True
        return False

    async def check_health(self, database) -> bool:
        """
        Query the Enterprise REST API for the bdb matching this database's
        host and check its availability with the configured lag tolerance.

        Raises:
            ValueError: If health_check_url is unset or no bdb matches the host.
            (HTTP errors from the availability endpoint propagate from the client.)
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        self._http_client.client.base_url = base_url

        # Find bdb matching to the current database host.
        # Stop at the first match instead of scanning (and potentially
        # re-assigning on) the remaining bdbs.
        matching_bdb = None
        for bdb in await self._http_client.get("/v1/bdbs"):
            if self._bdb_matches_host(bdb, db_host):
                matching_bdb = bdb
                break

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True
|
||||
Reference in New Issue
Block a user