Source code for dspinlock.base

""""Module that contains query related utilities."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Hashable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from hashlib import sha256
from math import ceil
from time import sleep
from typing import Any

import redis

from .consts import SL_LOG_TAG
from .exceptions import (
    AtomicQueryInvalidStateException,
    FailIfKeyExistsIsEnabled,
    InvalidMutexReleaseEncountered,
    ProvidedObjectIsNotHashable,
    SpinlockTriesExceeded,
)
from .utils import RedisParameters, create_redis_conn

qlog = logging.getLogger(SL_LOG_TAG)
"""Get the logger."""


class QueryState(Enum):
    """The QueryState used to indicate its execution status."""

    BLOCKED: int = 0
    """Indicates that the current key is blocked due someone computing the result."""
    COMPUTED: int = 1
    """Indicates that the result is computed - value will expire after TTL expires."""


# pylint: disable=too-few-public-methods
@dataclass(order=True, frozen=True)
class UnpackedValue:
    """The unpacked value dataclass, used to store the parsed value from redis."""

    value: int
    """The value packed."""
    tag: str
    """The parsed tag."""
    req_id: str | None = None
    """The request id."""
    timestamp: datetime | None = None
    """The parsed timestamp."""


# pylint: disable=too-few-public-methods
[docs] class DSpinlockBase(ABC): """ Class that ensures the query isolation in case of parallel requests. """ max_spinlock_tries: int = 10 """The spinlock max retries, by default 10 tries.""" spinlock_sleep_thresh: float = 0.5 """The spinlock sleep threshold, by default 0.5 seconds.""" expire_at_timedelta: timedelta = timedelta(hours=1) """The `expire_at` timedelta value, by default 1 hour.""" _key_sep: str = "," """The tag separator.""" _value_sep: str = "," """The value separator.""" max_block_time: timedelta = timedelta(hours=0.5) """The max block time allowed for a query mutex to be held, if not released it is forcefully unblocked.""" # the redis session.""" _sess: redis.Redis | None = None # the key prefix to use within redis _key_prefix: str = "dspinlock" # pylint: disable=too-many-arguments
[docs] def __init__( self, obj: Any, sess: redis.Redis | None = None, redis_params: RedisParameters | None = None, fail_if_key_exists: bool = False, cached_if_computed: bool = False, ): """ The constructor which takes Parameters ---------- obj: Any The object to create the lock for. sess: redis.Redis | None = None The session to redis, which can be `None`. redis_params: RedisParameters | None = None If the `sess` above is `None` then if this flag is raised, we create the connection using sample params. fail_if_key_exists: bool = False Indicates if we fail should the key already exists - i.e. in cases when we want to block computation for a certain period. cached_if_computed: bool = False Indicates if the result is computed can be returned from a cache. Thus, we do not have to wait for its computation; hence, we can return immediately _without_ practically getting the lock. """ self._obj_uid = self._get_uid(obj) self._auto_create = redis_params self._sess = self._get_redis_sess(sess, redis_params) self._tag = self._get_tag() self._fail_if_key_exists = fail_if_key_exists self._cached_if_computed = cached_if_computed
def __enter__(self): self._acquire() # return the context return self def __exit__(self, exc_type, exc_val, exc_tb): self._release() # noinspection InsecureHash @staticmethod def _get_uid(obj: Any) -> str: """ Used to check if the object can be hashed, if not, an exception is raised. Note: In case you want to implement your own `hash` method - you're free to do so either at the object level or by overriding this method when you are subclassing. `Important note: because Python 3 does not support stable hashing for ``str``, ``bytes``, and ``datetime`` we use hashing method from ``hashlib`` that can produce such cases for these cases.` Parameters ---------- obj: Any The object to check if it can be hashed. Returns ------- str The hash of the object returned as a string. Raises -------- ProvidedObjectIsNotHashable Raised when the object provided does not implement `Hashable`. """ if isinstance(obj, str): return sha256(str.encode(obj), usedforsecurity=False).hexdigest() if isinstance(obj, bytes): return sha256(obj, usedforsecurity=False).hexdigest() if isinstance(obj, datetime): return sha256(str.encode(obj.isoformat()), usedforsecurity=False).hexdigest() if isinstance(obj, Hashable): return f"{hash(obj)}" raise ProvidedObjectIsNotHashable(f"Provided object with type: {type(obj)} does not implement `Hashable`.") @abstractmethod def _get_tag(self, obj: Any | None = None): """ Takes the object and produces a tag based on the desired attributes to take in account. Parameters ---------- obj: Any | None = None The object instance to set the tag for, if required - by default it is `None`. """ raise NotImplementedError @abstractmethod def _unpack_value(self, value: bytes | str | None) -> UnpackedValue | None: """ Unpacks the value from redis. The format of the stored value is the following, - req_id: the request id, - datetime ts: the datetime timestamp, - mutex value: the mutex current value Parameters ---------- value: bytes | str | None The value as fetched from redis. Returns ------- UnpackedValue | None The parsed `UnpackedValue` instance, `None` if the value is already `None`. """ raise NotImplementedError def _tag_match(self, val: UnpackedValue) -> bool: """ Function that checks if the unpacked values' tag matches the current instance one. Parameters ---------- val: UnpackedValue The unpacked value instance. Returns ------- bool Returns `True` if the value matches, `False` otherwise. """ return val.tag == self._tag @staticmethod def _mutex_value_match(val: UnpackedValue, match_to: QueryState) -> bool: """ Function that checks if the unpacked values' mutex value matches the target one. Parameters ---------- val: UnpackedValue The unpacked value instance. match_to: QueryState The QueryState value to check against. Returns ------- bool Returns `True` if the value matches, `False` otherwise. """ return val.value == match_to.value def _get_expiry_unix_time(self) -> int: """ Fetches the expiry unix time which is roughly 1 day after the current time. Returns ------- int The timestamp to expire at. """ return ceil((datetime.now(tz=timezone.utc) + self.expire_at_timedelta).timestamp()) def _release(self): """Releases the distributed query mutex.""" key = self.get_key() try: pipe = self._get_redis_sess().pipeline(transaction=True) pipe.watch(key) res = self._unpack_value(self._get_redis_sess().get(key)) if res is None: qlog.debug("Cannot release mutex for key: %s as the key was not existent...", self.get_key()) raise InvalidMutexReleaseEncountered( "Encountered release request before mutex acquisition, this should not happen." ) if self._mutex_value_match(res, QueryState.BLOCKED): if res.tag != self._tag: raise InvalidMutexReleaseEncountered( f"Encountered a blocked query mutex with tag: {res.tag} that is for " "another query, this should not happen." ) payload = self._generate_payload(QueryState.COMPUTED) qlog.debug("Changing the query mutex for key: '%s' to value: '%s'", self.get_key(), payload) pipe.set(key, payload, exat=self._get_expiry_unix_time()) pipe.execute() return if self._mutex_value_match(res, QueryState.COMPUTED): qlog.debug("Query mutex required no release as its already computed with tag: '%s'.", res.tag) return raise InvalidMutexReleaseEncountered("Got an unknown value at release...!") except redis.exceptions.WatchError: if not self._can_break(key, tries=1, is_release=True): raise InvalidMutexReleaseEncountered( "Mutex could not be broken during release when it was acquired - this should not happen, " "as only one instance can have the mutex lock at any given time." ) from redis.exceptions.WatchError def _generate_payload(self, val: QueryState) -> str: """ Generate the payload to store in redis. Parameters ---------- val: QueryState The query state to be stored in redis. Returns ------- str The payload value. """ return f"{self._tag}{self._value_sep}{val.value}" def _block_key(self, pipe: redis.client.Pipeline, key: str, force_unblock: bool): """ Method that blocks the specific key for the current query. Parameters ---------- pipe: redis.client.Pipeline The ``redis`` pipeline to use. key: str The key for the mutex. force_unblock: bool Flag indicating if we can grab te mutex regardless. Raises -------- redis.exceptions.WatchError Raises when the target key has its value changed - this happens by default in this function. """ qlog.debug("Blocking key with tag: '%s', blocking was forced: '%s'", self._tag, force_unblock) pipe.set(key, self._generate_payload(QueryState.BLOCKED), exat=self._get_expiry_unix_time()) pipe.execute() # pylint: disable=too-many-branches def _acquire(self): """ Tries to acquire the mutex for the given query. Raises ------ SpinlockTriesExceeded Is raised when we exceed the number of tries to "own" the query mutex. """ key = self.get_key() # the current spinlock tries to get the Query mutex tries = 1 force_unblock = False computed = False for _ in enumerate(range(self.max_spinlock_tries)): try: pipe: redis.client.Pipeline = self._get_redis_sess().pipeline(transaction=True) pipe.watch(key) # now, check if the key exists and the fail flag if this event happens is raised. if ( res := self._unpack_value(self._get_redis_sess().get(key)) ) is not None and self._fail_if_key_exists: raise FailIfKeyExistsIsEnabled( "Key already exists, cannot continue acquiring when flag to block until expiry is raised." ) # check which situation we're in - either the key exists or not... if res is None or force_unblock or (computed := self._mutex_value_match(res, QueryState.COMPUTED)): if computed and self._cached_if_computed: qlog.debug( "Query has already been computed for key: %s, can be cached flag is enabled not blocking.", self.get_key(), ) else: self._block_key(pipe, key, force_unblock) qlog.debug( f"The key {'existed' if res else 'did not exist before'}, " f"force unblock flag was: '{force_unblock}', and computed flag was: '{computed}', " "now its blocked by request with tag: %s.", self._tag, ) break # we know that `res` is not `None` if we reach here, thus the key definitely exists. # clear everything in multi block, even if its empty from set ops pipe.execute() if self._mutex_value_match(res, QueryState.BLOCKED): if self._tag_match(res): qlog.debug( "Query mutex can be acquired as tags match, tag: %s, spinlock tries: %s", res.tag, tries ) break if not self._has_exhausted_block_time(res): qlog.debug( "Query mutex is blocked by request with tag: %s, spinlock tries: %s", res.tag, tries ) sleep(self.spinlock_sleep_thresh) else: qlog.debug( "Query mutex was blocked for more than the allowed time (which was: %s seconds) " "- force blocking it upon next try.", self.max_block_time.total_seconds(), ) force_unblock = True else: raise AtomicQueryInvalidStateException(f"Encountered an unexpected state value: {res}") except redis.exceptions.WatchError: if self._can_break(key, tries, is_release=False): break tries += 1 if tries > self.max_spinlock_tries: raise SpinlockTriesExceeded( f"Spinlock tries limit of {self.max_spinlock_tries} was exceeded for key {self.get_key()}" ) def _has_exhausted_block_time(self, val: UnpackedValue) -> bool: """ Checks if the mutex block has exceeded the allowed time per query. Parameters ---------- val: UnpackedValue The unpacked value instance. Returns ------- bool Returns `True` if the mutex was blocked for more than the allowed time or was `None`, `False` otherwise. """ return val.timestamp is None or (datetime.now(tz=timezone.utc) - val.timestamp) > self.max_block_time def _can_break(self, key: str, tries: int, is_release: bool): """ Function that checks if we are able to break from the spinlock. Parameters ---------- key: str The key for the entry. tries: int The current spinlock tries. is_release: bool Checks if the break was initialised by a release operation. Returns ------- bool Returns `True` if we are able to break, otherwise `False`. """ res = self._unpack_value(self._get_redis_sess().get(key)) mutex_stage = "release" if is_release else "acquisition" if res is None: raise InvalidMutexReleaseEncountered( f"Cannot have a null key at break check during {mutex_stage}, for key: {self.get_key()}" ) if self._tag_match(res): if (is_release and self._mutex_value_match(res, QueryState.COMPUTED)) or ( not is_release and self._mutex_value_match(res, QueryState.BLOCKED) ): qlog.debug( "Query mutex value changed during its %s, but it was from us as tags match, tag: %s", mutex_stage, self._tag, ) return True qlog.debug( "Query mutex value changed during its %s, attempting to get the mutex until spinlock tries is exhausted " "current: %s out of: %s.", mutex_stage, tries, self.max_spinlock_tries, ) return False
[docs] @abstractmethod def get_key(self) -> str: """ Fetches the base key for the page. Returns ------- str The key to store the value in redis. """ raise NotImplementedError
[docs] def delete_atomic_query_mutex_state(self) -> bool: """ Attempt to delete a specific mutex state for a given key from the global state stored within redis. Returns ------- bool Returns `True` if we managed to successfully delete the mutex state, `False` otherwise. """ key = self.get_key() if (res := self._get_redis_sess().delete(key)) == 0: qlog.debug("Failed to delete mutex state for key: %s, potentially it does not exist", key) else: qlog.debug("Deleted successfully mutex state for key: %s", key) return res != 0
def _get_redis_sess( self, sess: redis.Redis | None = None, redis_params: RedisParameters | None = None ) -> redis.Redis: """ Fetch or create a redis session based on the current parameters. Note: As a general guideline `redis` should be initialised at instance creation. Returns ------- redis.Redis The redis instance to use. """ qlog.debug("Getting redis session with sess: %s and params: %s", sess, redis_params) # check if both parameters have values if redis_params and sess: raise AttributeError("Providing both a redis instance and parameters to create it is not supported") # If we provide a session to use, override the existing one even if we have it. if sess is not None: self._sess = sess # check if we already have an instance and return it - either if we already assigned or it exists already. if self._sess: return self._sess # if we reached here, then sess is `None`, thus create connection using the parameters provided (if any) self._sess = create_redis_conn(redis_params) # if after creation it is still `None`, raise an exception. if self._sess is None: raise AttributeError("No redis session was provided, cannot continue.") qlog.debug("Returning newly created redis session: %s", self._sess) # finally, return the created connection return self._sess