mirror of
https://github.com/solero/houdini.git
synced 2024-11-13 22:28:21 +00:00
Create generalised base-class for argument deserliazation
This commit is contained in:
parent
ce4603d496
commit
dd6732bfe2
@ -2,6 +2,159 @@ from abc import ABC
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
from Houdini.Cooldown import CooldownError
|
||||||
|
|
||||||
|
|
||||||
|
class ChecklistError(Exception):
|
||||||
|
"""Raised when a checklist fails"""
|
||||||
|
|
||||||
|
|
||||||
|
class _ArgumentDeserializer:
|
||||||
|
__slots__ = ['name', 'components', 'callback', 'parent', 'pass_raw', 'cooldown',
|
||||||
|
'checklist', 'instance', 'alias', 'rest_raw', 'string_delimiter',
|
||||||
|
'string_separator', '_signature', '_arguments', '_exception_callback',
|
||||||
|
'_exception_class']
|
||||||
|
|
||||||
|
def __init__(self, name, callback, **kwargs):
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
self.name = callback.__name__ if name is None else name
|
||||||
|
self.cooldown = kwargs.get('cooldown')
|
||||||
|
self.checklist = kwargs.get('checklist', [])
|
||||||
|
self.rest_raw = kwargs.get('rest_raw', False)
|
||||||
|
self.string_delimiter = kwargs.get('string_delimiter', [])
|
||||||
|
self.string_separator = kwargs.get('string_separator', str())
|
||||||
|
|
||||||
|
self.instance = None
|
||||||
|
|
||||||
|
self._signature = list(inspect.signature(self.callback).parameters.values())
|
||||||
|
self._arguments = inspect.getfullargspec(self.callback)
|
||||||
|
|
||||||
|
self._exception_callback = None
|
||||||
|
self._exception_class = Exception
|
||||||
|
|
||||||
|
if self.rest_raw:
|
||||||
|
self._signature = self._signature[:-1]
|
||||||
|
|
||||||
|
def _can_run(self, p):
|
||||||
|
return True if not self.checklist else all(predicate(self, p) for predicate in self.checklist)
|
||||||
|
|
||||||
|
async def _check_cooldown(self, p):
|
||||||
|
if self.cooldown is not None:
|
||||||
|
bucket = self.cooldown.get_bucket(p)
|
||||||
|
if bucket.is_cooling:
|
||||||
|
if self.cooldown.callback is not None:
|
||||||
|
if self.instance:
|
||||||
|
await self.cooldown.callback(self.instance, p)
|
||||||
|
else:
|
||||||
|
await self.cooldown.callback(p)
|
||||||
|
else:
|
||||||
|
raise CooldownError('{} invoked listener during cooldown'.format(p))
|
||||||
|
|
||||||
|
def _check_list(self, p):
|
||||||
|
if not self._can_run(p):
|
||||||
|
raise ChecklistError('Could not invoke listener due to checklist failure')
|
||||||
|
|
||||||
|
def _consume_separated_string(self, ctx):
|
||||||
|
if ctx.argument[0] in self.string_delimiter:
|
||||||
|
while not ctx.argument.endswith(ctx.argument[0]):
|
||||||
|
ctx.argument += self.string_separator + next(ctx.arguments)
|
||||||
|
ctx.argument = ctx.argument[1:-1]
|
||||||
|
|
||||||
|
def error(self, exception_class=Exception):
|
||||||
|
def decorator(exception_callback):
|
||||||
|
self._exception_callback = exception_callback
|
||||||
|
self._exception_class = exception_class
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
async def _deserialize(self, p, data):
|
||||||
|
handler_call_arguments = [self.instance, p] if self.instance is not None else [p]
|
||||||
|
handler_call_keywords = {}
|
||||||
|
|
||||||
|
arguments = itertools.islice(data, len(data) - len(self._arguments.kwonlyargs))
|
||||||
|
keyword_arguments = itertools.islice(data, len(data) - len(self._arguments.kwonlyargs), len(data))
|
||||||
|
|
||||||
|
ctx = _ConverterContext(None, arguments, None, p)
|
||||||
|
for ctx.component in itertools.islice(self._signature, len(handler_call_arguments), len(self._signature)):
|
||||||
|
if ctx.component.annotation is ctx.component.empty and ctx.component.default is not ctx.component.empty:
|
||||||
|
handler_call_arguments.append(ctx.component.default)
|
||||||
|
elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD:
|
||||||
|
ctx.argument = next(ctx.arguments)
|
||||||
|
converter = get_converter(ctx.component)
|
||||||
|
|
||||||
|
if converter == str:
|
||||||
|
self._consume_separated_string(ctx)
|
||||||
|
|
||||||
|
handler_call_arguments.append(await do_conversion(converter, ctx))
|
||||||
|
elif ctx.component.kind == ctx.component.VAR_POSITIONAL:
|
||||||
|
for argument in ctx.arguments:
|
||||||
|
ctx.argument = argument
|
||||||
|
converter = get_converter(ctx.component)
|
||||||
|
|
||||||
|
if converter == str:
|
||||||
|
self._consume_separated_string(ctx)
|
||||||
|
|
||||||
|
handler_call_arguments.append(await do_conversion(converter, ctx))
|
||||||
|
elif ctx.component.kind == ctx.component.KEYWORD_ONLY:
|
||||||
|
ctx.arguments = keyword_arguments
|
||||||
|
ctx.argument = next(keyword_arguments)
|
||||||
|
converter = get_converter(ctx.component)
|
||||||
|
|
||||||
|
if converter == str:
|
||||||
|
self._consume_separated_string(ctx)
|
||||||
|
|
||||||
|
handler_call_keywords[ctx.component.name] = await do_conversion(converter, ctx)
|
||||||
|
|
||||||
|
if self.rest_raw:
|
||||||
|
handler_call_arguments.append(list(ctx.arguments))
|
||||||
|
|
||||||
|
return handler_call_arguments, handler_call_keywords
|
||||||
|
|
||||||
|
async def __call__(self, p, data):
|
||||||
|
try:
|
||||||
|
handler_call_arguments, handler_call_keywords = await self._deserialize(p, data)
|
||||||
|
|
||||||
|
return await self.callback(*handler_call_arguments, **handler_call_keywords)
|
||||||
|
except Exception as e:
|
||||||
|
if self._exception_callback and isinstance(e, self._exception_class):
|
||||||
|
if self.instance:
|
||||||
|
await self._exception_callback(self.instance, e)
|
||||||
|
else:
|
||||||
|
await self._exception_callback(e)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.__name__())
|
||||||
|
|
||||||
|
def __name__(self):
|
||||||
|
return "{}.{}".format(self.callback.__module__, self.callback.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _listener(cls, name, **kwargs):
|
||||||
|
def decorator(callback):
|
||||||
|
if not asyncio.iscoroutinefunction(callback):
|
||||||
|
raise TypeError('All listeners must be a coroutine.')
|
||||||
|
|
||||||
|
try:
|
||||||
|
cooldown_object = callback.__cooldown
|
||||||
|
del callback.__cooldown
|
||||||
|
except AttributeError:
|
||||||
|
cooldown_object = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
checklist = callback.__checks
|
||||||
|
del callback.__checks
|
||||||
|
except AttributeError:
|
||||||
|
checklist = []
|
||||||
|
|
||||||
|
listener_object = cls(name, callback, cooldown=cooldown_object, checklist=checklist, **kwargs)
|
||||||
|
return listener_object
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class IConverter(ABC):
|
class IConverter(ABC):
|
||||||
@ -212,7 +365,7 @@ def get_converter(component):
|
|||||||
|
|
||||||
|
|
||||||
async def do_conversion(converter, ctx):
|
async def do_conversion(converter, ctx):
|
||||||
if issubclass(type(converter), IConverter) and not isinstance(converter, IConverter):
|
if issubclass(converter, IConverter) and not isinstance(converter, IConverter):
|
||||||
converter = converter()
|
converter = converter()
|
||||||
if isinstance(converter, IConverter):
|
if isinstance(converter, IConverter):
|
||||||
if asyncio.iscoroutinefunction(converter.convert):
|
if asyncio.iscoroutinefunction(converter.convert):
|
||||||
|
@ -1,26 +1,15 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import itertools
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
|
|
||||||
from Houdini.Converters import get_converter, do_conversion, _ConverterContext
|
from Houdini.Converters import _listener, _ArgumentDeserializer, get_converter, do_conversion, _ConverterContext
|
||||||
|
|
||||||
from Houdini.Cooldown import _Cooldown, _CooldownMapping, BucketType, CooldownError
|
from Houdini.Cooldown import _Cooldown, _CooldownMapping, BucketType
|
||||||
from Houdini import Plugins
|
from Houdini import Plugins
|
||||||
|
|
||||||
|
|
||||||
def get_relative_function_path(function_obj):
|
|
||||||
abs_function_file = inspect.getfile(function_obj)
|
|
||||||
rel_function_file = os.path.relpath(abs_function_file)
|
|
||||||
|
|
||||||
return rel_function_file
|
|
||||||
|
|
||||||
|
|
||||||
class ChecklistError(Exception):
|
|
||||||
"""Raised when a checklist fails"""
|
|
||||||
|
|
||||||
|
|
||||||
class AuthorityError(Exception):
|
class AuthorityError(Exception):
|
||||||
"""Raised when a packet is received but user has not yet authenticated"""
|
"""Raised when a packet is received but user has not yet authenticated"""
|
||||||
|
|
||||||
@ -59,159 +48,76 @@ class Priority(enum.Enum):
|
|||||||
Low = 1
|
Low = 1
|
||||||
|
|
||||||
|
|
||||||
class _Listener:
|
class _Listener(_ArgumentDeserializer):
|
||||||
|
|
||||||
__slots__ = ['packet', 'components', 'handler', 'priority',
|
__slots__ = ['priority', 'packet', 'overrides']
|
||||||
'cooldown', 'pass_packet', 'handler_file',
|
|
||||||
'overrides', 'pre_login', 'checklist', 'plugin']
|
|
||||||
|
|
||||||
def __init__(self, packet, components, handler_function, **kwargs):
|
def __init__(self, packet, callback, **kwargs):
|
||||||
|
super().__init__(packet.id, callback, **kwargs)
|
||||||
self.packet = packet
|
self.packet = packet
|
||||||
self.components = components
|
|
||||||
self.handler = handler_function
|
|
||||||
|
|
||||||
self.priority = kwargs.get('priority', Priority.Low)
|
self.priority = kwargs.get('priority', Priority.Low)
|
||||||
self.overrides = kwargs.get('overrides', [])
|
self.overrides = kwargs.get('overrides', [])
|
||||||
self.cooldown = kwargs.get('cooldown')
|
|
||||||
self.pass_packet = kwargs.get('pass_packet', False)
|
|
||||||
self.checklist = kwargs.get('checklist', [])
|
|
||||||
|
|
||||||
self.plugin = None
|
|
||||||
|
|
||||||
if type(self.overrides) is not list:
|
if type(self.overrides) is not list:
|
||||||
self.overrides = [self.overrides]
|
self.overrides = [self.overrides]
|
||||||
|
|
||||||
self.handler_file = get_relative_function_path(handler_function)
|
|
||||||
|
|
||||||
def _can_run(self, p):
|
|
||||||
return True if not self.checklist else all(predicate(self.packet, p) for predicate in self.checklist)
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(self.__name__())
|
|
||||||
|
|
||||||
def __name__(self):
|
|
||||||
return "{}.{}".format(self.handler.__module__, self.handler.__name__)
|
|
||||||
|
|
||||||
async def __call__(self, p, packet_data):
|
|
||||||
if isinstance(self.packet, XTPacket) and not self.pre_login and not p.joined_world:
|
|
||||||
await p.close()
|
|
||||||
raise AuthorityError('{} tried sending XT packet before authentication!'.format(p))
|
|
||||||
|
|
||||||
if self.cooldown is not None:
|
|
||||||
bucket = self.cooldown.get_bucket(p)
|
|
||||||
if bucket.is_cooling:
|
|
||||||
if self.cooldown.callback is not None:
|
|
||||||
await self.cooldown.callback(*[self.plugin, p] if self.plugin is not None else p)
|
|
||||||
else:
|
|
||||||
raise CooldownError('{} sent packet during cooldown'.format(p))
|
|
||||||
|
|
||||||
if not self._can_run(p):
|
|
||||||
raise ChecklistError('Could not handle packet due to checklist failure')
|
|
||||||
|
|
||||||
|
|
||||||
class _XTListener(_Listener):
|
class _XTListener(_Listener):
|
||||||
|
|
||||||
__slots__ = ['pre_login', 'rest_raw', 'keywords']
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
self.pre_login = kwargs.get('pre_login')
|
|
||||||
self.rest_raw = kwargs.get('rest_raw', False)
|
|
||||||
|
|
||||||
self.keywords = len(inspect.getfullargspec(self.handler).kwonlyargs)
|
|
||||||
|
|
||||||
if self.rest_raw:
|
|
||||||
self.components = self.components[:-1]
|
|
||||||
|
|
||||||
async def __call__(self, p, packet_data):
|
|
||||||
await super().__call__(p, packet_data)
|
|
||||||
|
|
||||||
handler_call_arguments = [self.plugin] if self.plugin is not None else []
|
|
||||||
handler_call_arguments += [self.packet, p] if self.pass_packet else [p]
|
|
||||||
handler_call_keywords = {}
|
|
||||||
|
|
||||||
arguments = iter(packet_data[:-self.keywords])
|
|
||||||
ctx = _ConverterContext(None, arguments, None, p)
|
|
||||||
for ctx.component in self.components:
|
|
||||||
if ctx.component.annotation is ctx.component.empty and ctx.component.default is not ctx.component.empty:
|
|
||||||
handler_call_arguments.append(ctx.component.default)
|
|
||||||
next(ctx.arguments)
|
|
||||||
elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD:
|
|
||||||
ctx.argument = next(ctx.arguments)
|
|
||||||
converter = get_converter(ctx.component)
|
|
||||||
|
|
||||||
handler_call_arguments.append(await do_conversion(converter, ctx))
|
|
||||||
elif ctx.component.kind == ctx.component.VAR_POSITIONAL:
|
|
||||||
for argument in ctx.arguments:
|
|
||||||
ctx.argument = argument
|
|
||||||
converter = get_converter(ctx.component)
|
|
||||||
|
|
||||||
handler_call_arguments.append(await do_conversion(converter, ctx))
|
|
||||||
elif ctx.component.kind == ctx.component.KEYWORD_ONLY:
|
|
||||||
ctx.argument = packet_data[-self.keywords:][len(handler_call_keywords)]
|
|
||||||
converter = get_converter(ctx.component)
|
|
||||||
handler_call_keywords[ctx.component.name] = await do_conversion(converter, ctx)
|
|
||||||
|
|
||||||
if self.rest_raw:
|
|
||||||
handler_call_arguments.append(list(ctx.arguments))
|
|
||||||
return await self.handler(*handler_call_arguments, **handler_call_keywords)
|
|
||||||
elif not len(list(ctx.arguments)):
|
|
||||||
return await self.handler(*handler_call_arguments, **handler_call_keywords)
|
|
||||||
|
|
||||||
|
|
||||||
class _XMLListener(_Listener):
|
|
||||||
__slots__ = ['pre_login']
|
__slots__ = ['pre_login']
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.pre_login = kwargs.get('pre_login')
|
||||||
|
|
||||||
async def __call__(self, p, packet_data):
|
async def __call__(self, p, packet_data):
|
||||||
|
if not self.pre_login and not p.joined_world:
|
||||||
|
await p.close()
|
||||||
|
raise AuthorityError('{} tried sending XT packet before authentication!'.format(p))
|
||||||
|
|
||||||
|
await super()._check_cooldown(p)
|
||||||
|
super()._check_list(p)
|
||||||
|
|
||||||
await super().__call__(p, packet_data)
|
await super().__call__(p, packet_data)
|
||||||
|
|
||||||
handler_call_arguments = [self.plugin] if self.plugin is not None else []
|
|
||||||
handler_call_arguments += [self.packet, p] if self.pass_packet else [p]
|
|
||||||
|
|
||||||
for index, component in enumerate(self.components):
|
class _XMLListener(_Listener):
|
||||||
if component.default is not component.empty:
|
|
||||||
handler_call_arguments.append(component.default)
|
def __init__(self, *args, **kwargs):
|
||||||
elif component.kind == component.POSITIONAL_OR_KEYWORD:
|
super().__init__(*args, **kwargs)
|
||||||
converter = get_converter(component)
|
|
||||||
ctx = _ConverterContext(component, None, packet_data, p)
|
async def __call__(self, p, packet_data):
|
||||||
|
await super()._check_cooldown(p)
|
||||||
|
super()._check_list(p)
|
||||||
|
|
||||||
|
handler_call_arguments = [self.instance, p] if self.instance is not None else [p]
|
||||||
|
|
||||||
|
ctx = _ConverterContext(None, None, packet_data, p)
|
||||||
|
for ctx.component in itertools.islice(self._signature, len(handler_call_arguments), len(self._signature)):
|
||||||
|
if ctx.component.default is not ctx.component.empty:
|
||||||
|
handler_call_arguments.append(ctx.component.default)
|
||||||
|
elif ctx.component.kind == ctx.component.POSITIONAL_OR_KEYWORD:
|
||||||
|
converter = get_converter(ctx.component)
|
||||||
|
|
||||||
handler_call_arguments.append(await do_conversion(converter, ctx))
|
handler_call_arguments.append(await do_conversion(converter, ctx))
|
||||||
return await self.handler(*handler_call_arguments)
|
return await self.callback(*handler_call_arguments)
|
||||||
|
|
||||||
|
|
||||||
|
def get_relative_function_path(function_obj):
|
||||||
|
abs_function_file = inspect.getfile(function_obj)
|
||||||
|
rel_function_file = os.path.relpath(abs_function_file)
|
||||||
|
|
||||||
|
return rel_function_file
|
||||||
|
|
||||||
|
|
||||||
def handler(packet, **kwargs):
|
def handler(packet, **kwargs):
|
||||||
def decorator(handler_function):
|
if not issubclass(type(packet), _Packet):
|
||||||
if not asyncio.iscoroutinefunction(handler_function):
|
raise TypeError('All handlers can only listen for either XMLPacket or XTPacket.')
|
||||||
raise TypeError('All handlers must be a coroutine.')
|
|
||||||
|
|
||||||
components = list(inspect.signature(handler_function).parameters.values())
|
listener_class = _XTListener if isinstance(packet, XTPacket) else _XMLListener
|
||||||
components = components[2:] if str(components[0]) == 'self' else components[1:]
|
return _listener(listener_class, packet, **kwargs)
|
||||||
|
|
||||||
if not issubclass(type(packet), _Packet):
|
|
||||||
raise TypeError('All handlers can only listen for either XMLPacket or XTPacket.')
|
|
||||||
|
|
||||||
listener_class = _XTListener if isinstance(packet, XTPacket) else _XMLListener
|
|
||||||
|
|
||||||
try:
|
|
||||||
cooldown_object = handler_function.__cooldown
|
|
||||||
del handler_function.__cooldown
|
|
||||||
except AttributeError:
|
|
||||||
cooldown_object = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
checklist = handler_function.__checks
|
|
||||||
del handler_function.__checks
|
|
||||||
except AttributeError:
|
|
||||||
checklist = []
|
|
||||||
|
|
||||||
listener_object = listener_class(packet, components, handler_function,
|
|
||||||
cooldown=cooldown_object, checklist=checklist,
|
|
||||||
**kwargs)
|
|
||||||
return listener_object
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def listener_exists(xt_listeners, xml_listeners, packet):
|
def listener_exists(xt_listeners, xml_listeners, packet):
|
||||||
@ -227,7 +133,7 @@ def listeners_from_module(xt_listeners, xml_listeners, module):
|
|||||||
listener_objects = inspect.getmembers(module, is_listener)
|
listener_objects = inspect.getmembers(module, is_listener)
|
||||||
for listener_name, listener_object in listener_objects:
|
for listener_name, listener_object in listener_objects:
|
||||||
if isinstance(module, Plugins.IPlugin):
|
if isinstance(module, Plugins.IPlugin):
|
||||||
listener_object.plugin = module
|
listener_object.instance = module
|
||||||
|
|
||||||
listener_collection = xt_listeners if type(listener_object) == _XTListener else xml_listeners
|
listener_collection = xt_listeners if type(listener_object) == _XTListener else xml_listeners
|
||||||
if listener_object.packet not in listener_collection:
|
if listener_object.packet not in listener_collection:
|
||||||
@ -251,7 +157,8 @@ def remove_handlers_by_module(xt_listeners, xml_listeners, handler_module_path):
|
|||||||
def remove_handlers(remove_handler_items):
|
def remove_handlers(remove_handler_items):
|
||||||
for handler_id, handler_listeners in remove_handler_items:
|
for handler_id, handler_listeners in remove_handler_items:
|
||||||
for handler_listener in handler_listeners:
|
for handler_listener in handler_listeners:
|
||||||
if handler_listener.handler_file == handler_module_path:
|
handler_file = get_relative_function_path(handler_listener.callback)
|
||||||
|
if handler_file == handler_module_path:
|
||||||
handler_listeners.remove(handler_listener)
|
handler_listeners.remove(handler_listener)
|
||||||
remove_handlers(xt_listeners.items())
|
remove_handlers(xt_listeners.items())
|
||||||
remove_handlers(xml_listeners.items())
|
remove_handlers(xml_listeners.items())
|
||||||
@ -278,8 +185,8 @@ def check(predicate):
|
|||||||
|
|
||||||
|
|
||||||
def allow_once():
|
def allow_once():
|
||||||
def check_for_packet(packet, p):
|
def check_for_packet(listener, p):
|
||||||
return packet not in p.received_packets
|
return listener.packet not in p.received_packets
|
||||||
return check(check_for_packet)
|
return check(check_for_packet)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user