# Source code for almanac.hooks.exception_hook_dispatch_table

import inspect

from typing import Callable, MutableMapping, Optional, Type

from .assertions import assert_async_callback
from .types import AsyncExceptionHookCallback
from ..errors import (
    ConflictingExceptionCallbacksError,
    InvalidCallbackTypeError,
    MissingRequiredParameterError
)

# Storage type for the dispatch table: each registered exception type is
# associated with exactly one async hook callback.
_HookMapping = MutableMapping[
    Type[Exception],
    AsyncExceptionHookCallback,
]

class ExceptionHookDispatchTable:
    """A table for storing and dispatching exception hooks.

    Each exception type maps to at most one async hook. Lookups resolve an
    exception type to the hook registered for the class that appears
    earliest in that type's method resolution order.
    """

    def __init__(
        self
    ) -> None:
        self._callback_table: _HookMapping = {}

    def __call__(
        self,
        *exception_types: Type[Exception],
        allow_overwrite: bool = False
    ) -> Callable[[AsyncExceptionHookCallback], AsyncExceptionHookCallback]:
        """A decorator for adding a callback for when exceptions occur.

        Args:
            *exception_types: The exception types the decorated hook should
                handle. A type listed more than once is registered once.
            allow_overwrite: Whether the hook may replace hooks that are
                already registered for any of ``exception_types``.

        Returns:
            A decorator that registers the hook for every given type and
            returns the hook unchanged.

        Raises:
            MissingRequiredParameterError: If no exception types are given.
            ConflictingExceptionCallbacksError: Raised by the returned
                decorator if any type already has a hook and
                ``allow_overwrite`` is False. Nothing is registered then.
            InvalidCallbackTypeError: Presumably raised by the returned
                decorator (via ``assert_async_callback``) when the hook is
                not an acceptable async callback -- TODO confirm.
        """
        if not exception_types:
            raise MissingRequiredParameterError(
                'Must specify at least one exception_type'
            )

        # dict.fromkeys de-duplicates while keeping the caller's order, so a
        # type listed twice does not conflict with its own registration.
        unique_types = tuple(dict.fromkeys(exception_types))

        def decorator(
            hook_coro: AsyncExceptionHookCallback
        ) -> AsyncExceptionHookCallback:
            # Validate everything up front so a conflict on a later type
            # cannot leave earlier types already registered.
            assert_async_callback(hook_coro)
            for exc_type in unique_types:
                self._check_overwrite_allowed(exc_type, allow_overwrite)

            for exc_type in unique_types:
                self.set_hook_for_exc_type(
                    exc_type, hook_coro, allow_overwrite=allow_overwrite
                )
            return hook_coro

        return decorator

    def set_hook_for_exc_type(
        self,
        exc_type: Type[Exception],
        hook_coro: AsyncExceptionHookCallback,
        allow_overwrite: bool = False
    ) -> None:
        """Set a hook for an exception type.

        Args:
            exc_type: The exception type to register the hook for.
            hook_coro: The async hook; checked with ``assert_async_callback``.
            allow_overwrite: Whether an existing hook for ``exc_type`` may be
                replaced.

        Raises:
            ConflictingExceptionCallbacksError: If ``exc_type`` already has a
                hook and ``allow_overwrite`` is False.
        """
        # Any error from the callback check propagates to the caller as-is.
        assert_async_callback(hook_coro)
        self._check_overwrite_allowed(exc_type, allow_overwrite)
        self._callback_table[exc_type] = hook_coro

    def _check_overwrite_allowed(
        self,
        exc_type: Type[Exception],
        allow_overwrite: bool
    ) -> None:
        """Raise if registering ``exc_type`` would silently replace a hook."""
        if exc_type in self._callback_table and not allow_overwrite:
            raise ConflictingExceptionCallbacksError(
                'Attempted to overwrite existing exception handler for '
                f'{exc_type} without explicit allow_overwrite=True'
            )

    def get_hook_for_exc_type(
        self,
        exc_type: Type[Exception]
    ) -> Optional[AsyncExceptionHookCallback]:
        """Return the closest matching hook for the specified exception type.

        The match is the hook registered for whichever class appears earliest
        in ``exc_type``'s MRO (``exc_type`` itself first, then its bases).

        Returns:
            That hook, or ``None`` if no registered type is in the MRO.
        """
        matching_hook: Optional[AsyncExceptionHookCallback] = None
        min_mro_dist = float('inf')

        # Look for the registered exception type that is "closest" in the class
        # hierarchy to the exception type we are resolving. Unrelated types
        # have infinite distance and therefore never win.
        for registered_exc_type, hook_coro in self._callback_table.items():
            mro_dist = _mro_distance(exc_type, registered_exc_type)
            if mro_dist < min_mro_dist:
                min_mro_dist = mro_dist
                matching_hook = hook_coro

        return matching_hook
def _mro_distance( sub_cls: Type, super_cls: Type ) -> float: try: sub_cls_mro = inspect.getmro(sub_cls) return sub_cls_mro.index(super_cls) except ValueError: return float('inf')